TooN 2.1
optimization/conjugate_gradient.h
#include <TooN/optimization/brent.h>
#include <utility>
#include <cmath>
#include <cassert>
#include <cstdlib>

namespace TooN{
    namespace Internal{


    ///Turn a multidimensional function into a 1D function by specifying a
    ///point and a direction. A new function is defined:
    ///\f[
    /// g(a) = f(\Vec{s} + a\Vec{d})
    ///\f]
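    ///
    ///A minimal usage sketch (hypothetical names: \c MyFunc is any functor type
    ///taking a Vector and returning a \c double, and \c my_func is an instance of it):
    ///@code
    /// Vector<2> s = makeVector(1, 2);
    /// Vector<2> d = makeVector(0, 1);
    /// LineSearch<2, double, MyFunc> line(s, d, my_func);
    /// double v = line(0.5);   // evaluates my_func(s + 0.5*d)
    ///@endcode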
    ///@ingroup gOptimize
    template<int Size, typename Precision, typename Func> struct LineSearch
    {
        const Vector<Size, Precision>& start; ///< \f$\Vec{s}\f$
        const Vector<Size, Precision>& direction;///< \f$\Vec{d}\f$

        const Func& f;///< \f$f(\cdotp)\f$

        ///Set up the line search class.
        ///@param s Start point, \f$\Vec{s}\f$.
        ///@param d Direction, \f$\Vec{d}\f$.
        ///@param func Function, \f$f(\cdotp)\f$.
        LineSearch(const Vector<Size, Precision>& s, const Vector<Size, Precision>& d, const Func& func)
        :start(s),direction(d),f(func)
        {}

        ///@param x Position to evaluate function
        ///@return \f$f(\Vec{s} + x\Vec{d})\f$
        Precision operator()(Precision x) const
        {
            return f(start + x * direction);
        }
    };

    ///Bracket a 1D function by searching forward from zero. The assumption
    ///is that a minimum exists in \f$f(x),\ x>0\f$, and this function searches
    ///for a bracket using exponentially growing or shrinking steps.
    ///@param a_val The value of the function at zero.
    ///@param func Function to bracket
    ///@param initial_lambda Initial stepsize
    ///@param zeps Minimum bracket size.
    ///@return <code>m[i][0]</code> contains the values of \f$x\f$ for the bracket, in increasing order,
    ///        and <code>m[i][1]</code> contains the corresponding values of \f$f(x)\f$. If the bracket
    ///        drops below the minimum bracket size, all zeros are returned.
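    ///
    ///A minimal calling sketch (hypothetical: \c line is any 1D functor, such as a
    ///LineSearch instance; the numeric arguments are illustrative values, not prescribed ones):
    ///@code
    /// Matrix<3,2> m = bracket_minimum_forward(line(0.0), line, 1.0, 1e-20);
    /// // On success, m[i][0] are the bracket positions and m[i][1] the
    /// // corresponding function values (see the caveats above).
    ///@endcode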
    ///@ingroup gOptimize
    template<typename Precision, typename Func> Matrix<3,2,Precision> bracket_minimum_forward(Precision a_val, const Func& func, Precision initial_lambda, Precision zeps)
    {
        //Get a, b, c to bracket a minimum along a line
        Precision a, b, c, b_val, c_val;

        a=0;

        //Search forward in steps of lambda
        Precision lambda=initial_lambda;
        b = lambda;
        b_val = func(b);

        while(std::isnan(b_val))
        {
            //We've probably gone in to an invalid region. This can happen even
            //if following the gradient would never get us there.
            //try backing off lambda
            lambda*=.5;
            b = lambda;
            b_val = func(b);

        }


        if(b_val < a_val) //We've gone downhill, so keep searching until we go back up
        {
            double last_good_lambda = lambda;

            for(;;)
            {
                lambda *= 2;
                c = lambda;
                c_val = func(c);

                if(std::isnan(c_val))
                    break;
                last_good_lambda = lambda;
                if(c_val >  b_val) // we have a bracket
                    break;
                else
                {
                    a = b;
                    a_val = b_val;
                    b=c;
                    b_val=c_val;

                }
            }

            //We took a step too far.
            //Back up: this will not attempt to ensure a bracket
            if(std::isnan(c_val))
            {
                double bad_lambda=lambda;
                double l=1;

                for(;;)
                {
                    l*=.5;
                    c = last_good_lambda + (bad_lambda - last_good_lambda)*l;
                    c_val = func(c);

                    if(!std::isnan(c_val))
                        break;
                }


            }

        }
        else //We've overshot the minimum, so back up
        {
            c = b;
            c_val = b_val;
            //Here, c_val > a_val

            for(;;)
            {
                lambda *= .5;
                b = lambda;
                b_val = func(b);

                if(b_val < a_val)// we have a bracket
                    break;
                else if(lambda < zeps)
                    return Zeros;
                else //Contract the bracket
                {
                    c = b;
                    c_val = b_val;
                }
            }
        }

        Matrix<3,2,Precision> ret;
        ret[0] = makeVector(a, a_val);
        ret[1] = makeVector(b, b_val);
        ret[2] = makeVector(c, c_val);

        return ret;
    }

}


/** This class provides a nonlinear conjugate-gradient optimizer. The following
code snippet will perform an optimization on the Rosenbrock banana function in
two dimensions:

@code
#include <TooN/TooN.h>
#include <TooN/optimization/conjugate_gradient.h>
#include <iostream>
using namespace TooN;
using namespace std;

double sq(double x){ return x*x; } //Square of x, used by the functions below.

double Rosenbrock(const Vector<2>& v)
{
    return sq(1 - v[0]) + 100 * sq(v[1] - sq(v[0]));
}

Vector<2> RosenbrockDerivatives(const Vector<2>& v)
{
    double x = v[0];
    double y = v[1];

    Vector<2> ret;
    ret[0] = -2+2*x-400*(y-sq(x))*x;
    ret[1] = 200*y-200*sq(x);

    return ret;
}

int main()
{
    ConjugateGradient<2> cg(makeVector(0,0), Rosenbrock, RosenbrockDerivatives);

    while(cg.iterate(Rosenbrock, RosenbrockDerivatives))
        cout << "y_" << cg.iterations << " = " << cg.y << endl;

    cout << "Optimal value: " << cg.y << endl;
}
@endcode

The chances are that you will want to read the documentation for
ConjugateGradient::ConjugateGradient and ConjugateGradient::iterate.

The line search is currently performed using Brent's method (see
brent_line_search()), and conjugate vector updates are performed using the
Polak-Ribiere equations. There are many tunable parameters, and the internals
are readily accessible, so alternative termination conditions etc. can easily
be substituted, as in the sketch below. However, usually this will not be
necessary.
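
As a sketch of how the internals can be driven by hand (member names as
documented below; the step-size test and the numbers here are purely
illustrative, not the built-in termination rule):

@code
ConjugateGradient<2> cg(makeVector(0,0), Rosenbrock, RosenbrockDerivatives);
cg.max_iterations = 50;              //Cap the number of outer iterations.
cg.bracket_initial_lambda = 0.1;     //Take smaller initial bracketing steps.

//Hand-rolled optimization loop with an alternative termination test
//based on the size of the last step.
do
{
    cg.find_next_point(Rosenbrock);                     //Line search along the conjugate direction
    cg.update_vectors_PR(RosenbrockDerivatives(cg.x));  //Polak-Ribiere update of g and h
}
while(cg.iterations < cg.max_iterations && norm(cg.x - cg.old_x) > 1e-12);
@endcode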

@ingroup gOptimize
*/
template<int Size=Dynamic, class Precision=double> struct ConjugateGradient
{
    const int size;      ///< Dimensionality of the space.
    Vector<Size> g;      ///< Gradient vector used by the next call to iterate()
    Vector<Size> h;      ///< Conjugate vector to be searched along in the next call to iterate()
    Vector<Size> minus_h;///< Negative of \c h; stored because it must be passed by reference into the line search (so it cannot be a temporary)
    Vector<Size> old_g;  ///< Gradient vector used to compute \c h in the last call to iterate()
    Vector<Size> old_h;  ///< Conjugate vector searched along in the last call to iterate()
    Vector<Size> x;      ///< Current position (best known point)
    Vector<Size> old_x;  ///< Previous best known point (not set at construction)
    Precision y;         ///< Function value at \f$x\f$
    Precision old_y;     ///< Function value at \c old_x

    Precision tolerance; ///< Tolerance used to determine if the optimization is complete. Defaults to the square root of machine precision.
    Precision epsilon;   ///< Additive term in the tolerance to prevent excessive iterations if \f$x_\mathrm{optimal} = 0\f$. Known as \c ZEPS in Numerical Recipes. Defaults to 1e-20
    int       max_iterations; ///< Maximum number of iterations. Defaults to \c size\f$*100\f$

    Precision bracket_initial_lambda;///< Initial stepsize used in bracketing the minimum for the line search. Defaults to 1.
    Precision linesearch_tolerance; ///< Tolerance used to determine if the linesearch is complete. Defaults to the square root of machine precision.
    Precision linesearch_epsilon; ///< Additive term in the tolerance to prevent excessive iterations if \f$x_\mathrm{optimal} = 0\f$. Known as \c ZEPS in Numerical Recipes. Defaults to 1e-20
    int linesearch_max_iterations;  ///< Maximum number of iterations in the linesearch. Defaults to 100.

    Precision bracket_epsilon; ///< Minimum size for the initial bracketing of the minimum. Below this, it is assumed that the system has converged. Defaults to 1e-20.

    int iterations; ///< Number of iterations performed

    ///Initialize the ConjugateGradient class with sensible values.
    ///@param start Starting point, \e x
    ///@param func  Function \e f  to compute \f$f(x)\f$
    ///@param deriv  Function to compute \f$\nabla f(x)\f$
    template<class Func, class Deriv> ConjugateGradient(const Vector<Size>& start, const Func& func, const Deriv& deriv)
    : size(start.size()),
      g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
    {
        init(start, func(start), deriv(start));
    }

    ///Initialize the ConjugateGradient class with sensible values.
    ///@param start Starting point, \e x
    ///@param func  Function \e f  to compute \f$f(x)\f$
    ///@param deriv  \f$\nabla f(x)\f$
    template<class Func> ConjugateGradient(const Vector<Size>& start, const Func& func, const Vector<Size>& deriv)
    : size(start.size()),
      g(size),h(size),minus_h(size),old_g(size),old_h(size),x(start),old_x(size)
    {
        init(start, func(start), deriv);
    }

    ///Initialize the ConjugateGradient class with sensible values. Used internally.
    ///@param start Starting point, \e x
    ///@param func  \f$f(x)\f$
    ///@param deriv  \f$\nabla f(x)\f$
    void init(const Vector<Size>& start, const Precision& func, const Vector<Size>& deriv)
    {

        using std::numeric_limits;
        using std::sqrt;
        x = start;

        //Start with the conjugate direction aligned with
        //the gradient
        g = deriv;
        h = g;
        minus_h=-h;

        y = func;
        old_y = y;

        tolerance = sqrt(numeric_limits<Precision>::epsilon());
        epsilon = 1e-20;
        max_iterations = size * 100;

        bracket_initial_lambda = 1;

        linesearch_tolerance =  sqrt(numeric_limits<Precision>::epsilon());
        linesearch_epsilon = 1e-20;
        linesearch_max_iterations=100;

        bracket_epsilon=1e-20;

        iterations=0;
    }


    ///Perform a linesearch from the current point (x) along the current
    ///conjugate vector (h).  The linesearch does not make use of derivatives.
    ///You probably do not want to use this function. See iterate() instead.
    ///This function updates:
    /// - x
    /// - old_x
    /// - y
    /// - old_y
    /// - iterations
    /// Note that the conjugate direction and gradient are not updated.
    /// If bracket_minimum_forward detects a local maximum, then essentially a
    /// zero-sized step is taken.
    /// @param func Functor returning the function value at a given point.
    template<class Func> void find_next_point(const Func& func)
    {
        Internal::LineSearch<Size, Precision, Func> line(x, minus_h, func);

        //Always search in the conjugate direction (h)
        //First bracket a minimum.
        Matrix<3,2,Precision> bracket = Internal::bracket_minimum_forward(y, line, bracket_initial_lambda, bracket_epsilon);

        double a = bracket[0][0];
        double b = bracket[1][0];
        double c = bracket[2][0];

        double a_val = bracket[0][1];
        double b_val = bracket[1][1];
        double c_val = bracket[2][1];

        old_y = y;
        old_x = x;
        iterations++;

        //Local maximum achieved!
        if(a==0 && b== 0 && c == 0)
            return;

        //We should have a bracket here

        if(c < b)
        {
            //Failed to bracket due to NaN, so c is the best known point.
            //Simply go there.
            x-=h * c;
            y=c_val;

        }
        else
        {
            assert(a < b && b < c);
            assert(a_val > b_val && b_val < c_val);

            //Find the real minimum
            Vector<2, Precision>  m = brent_line_search(a, b, c, b_val, line, linesearch_max_iterations, linesearch_tolerance, linesearch_epsilon);

            assert(m[0] >= a && m[0] <= c);
            assert(m[1] <= b_val);

            //Update the current position and value
            x -= m[0] * h;
            y = m[1];
        }
    }

    ///Check to see if iteration should stop. You probably do not want to use
    ///this function. See iterate() instead. This function updates nothing.
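    ///
    ///The test applied (as implemented below) is
    ///\f[
    /// i > i_{\max} \quad \mbox{or} \quad 2|y - y_{\mathrm{old}}| \le \mathrm{tolerance}\,(|y| + |y_{\mathrm{old}}| + \mathrm{epsilon}).
    ///\f]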
    bool finished()
    {
        using std::abs;
        return iterations > max_iterations || 2*abs(y - old_y) <= tolerance * (abs(y) + abs(old_y) + epsilon);
    }

    ///After an iteration, update the gradient and conjugate vector using the
    ///Polak-Ribiere equations, given below.
    ///This function updates:
    ///- g
    ///- old_g
    ///- h
    ///- old_h
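    ///
    ///The update applied (matching the code below) is
    ///\f[
    /// \gamma = \frac{\Vec{g}\cdot(\Vec{g} - \Vec{g}_{\mathrm{old}})}{\Vec{g}_{\mathrm{old}}\cdot\Vec{g}_{\mathrm{old}}},
    /// \qquad
    /// \Vec{h} = \Vec{g} + \gamma\,\Vec{h}_{\mathrm{old}},
    ///\f]
    ///where \f$\Vec{g}\f$ is the new gradient \c grad, and the line search then proceeds along \f$-\Vec{h}\f$.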
    ///@param grad The derivatives of the function at \e x
    void update_vectors_PR(const Vector<Size>& grad)
    {
        //Update the gradient and conjugate directions
        old_g = g;
        old_h = h;

        g = grad;
        Precision gamma = (g * g - old_g*g)/(old_g * old_g);
        h = g + gamma * old_h;
        minus_h=-h;
    }

    ///Use this function to iterate over the optimization. Note that after
    ///iterate returns false, g, h, old_g and old_h will not have been
    ///updated.
    ///This function updates:
    /// - x
    /// - old_x
    /// - y
    /// - old_y
    /// - iterations
    /// - g*
    /// - old_g*
    /// - h*
    /// - old_h*
    /// Starred (*) variables are not updated on the last iteration.
    ///@param func Functor returning the function value at a given point.
    ///@param deriv Functor to compute derivatives at the specified point.
    ///@return Whether to continue.
    template<class Func, class Deriv> bool iterate(const Func& func, const Deriv& deriv)
    {
        find_next_point(func);

        if(!finished())
        {
            update_vectors_PR(deriv(x));
            return true;
        }
        else
            return false;
    }
};

}