optimization/conjugate_gradient.h

00001 #include <TooN/optimization/brent.h>
00002 #include <utility>
00003 #include <cmath>
00004 #include <cassert>
00005 #include <cstdlib>
00006 
00007 namespace TooN{
00008 
00009     namespace Internal{
00010     
00011 
00018     template<int Size, typename Precision, typename Func> struct LineSearch
00019     {
00020         const Vector<Size, Precision>& start; 
00021         const Vector<Size, Precision>& direction;
00022 
00023         const Func& f;
00024 
00029         LineSearch(const Vector<Size, Precision>& s, const Vector<Size, Precision>& d, const Func& func)
00030         :start(s),direction(d),f(func)
00031         {}
00032         
00035         Precision operator()(Precision x) const
00036         {
00037             return f(start + x * direction); 
00038         }
00039     };
00040     
00050     template<typename Precision, typename Func> Matrix<3,2,Precision> bracket_minimum_forward(Precision a_val, const Func& func, Precision initial_lambda=1)
00051     {
00052         //Get a, b, c to  bracket a minimum along a line
00053         Precision a, b, c, b_val, c_val;
00054 
00055         a=0;
00056 
00057         //Search forward in steps of lambda
00058         Precision lambda=initial_lambda;
00059         b = lambda;
00060         b_val = func(b);
00061 
00062         if(b_val < a_val) //We've gone downhill, so keep searching until we go back up
00063         {
00064             for(;;)
00065             {
00066                 lambda *= 2;
00067                 c = lambda;
00068                 c_val = func(c);
00069 
00070                 if(c_val >  b_val) // we have a bracket
00071                     break;
00072                 else
00073                 {
00074                     a = b;
00075                     a_val = b_val;
00076                     b=c;
00077                     b_val=c_val;
00078 
00079                 }
00080             }
00081         }
00082         else //We've overshot the minimum, so back up
00083         {
00084             c = b;
00085             c_val = b_val;
00086             //Here, c_val > a_val
00087 
00088             for(;;)
00089             {
00090                 lambda *= .5;
00091                 b = lambda;
00092                 b_val = func(b);
00093 
00094                 if(b_val < a_val)// we have a bracket
00095                     break;
00096                 else //Contract the bracket
00097                 {
00098                     c = b;
00099                     c_val = b_val;
00100                 }
00101             }
00102         }
00103         
00104         Matrix<3,2> ret;
00105         ret[0] = makeVector(a, a_val);
00106         ret[1] = makeVector(b, b_val);
00107         ret[2] = makeVector(c, c_val);
00108 
00109         return ret;
00110     }
00111 
00112 }
00113 
00114 
00159 template<int Size, class Precision=double> struct ConjugateGradient
00160 {
00161     const int size;      
00162     Vector<Size> g;      
00163     Vector<Size> h;      
00164     Vector<Size> old_g;  
00165     Vector<Size> old_h;  
00166     Vector<Size> x;      
00167     Vector<Size> old_x;  
00168     Precision y;         
00169     Precision old_y;     
00170 
00171     Precision tolerance; 
00172     Precision epsilon;   
00173     int       max_iterations; 
00174 
00175     Precision bracket_initial_lambda;
00176     Precision linesearch_tolerance; 
00177     Precision linesearch_epsilon; 
00178     int linesearch_max_iterations;  
00179 
00180     int iterations; 
00181     
00186     template<class Func, class Deriv> ConjugateGradient(const Vector<Size>& start, const Func& func, const Deriv& deriv)
00187     : size(start.size()),
00188       g(size),h(size),old_g(size),old_h(size),x(start),old_x(size)
00189     {
00190         using std::numeric_limits;
00191 
00192         x = start;
00193         
00194         //Start with the conjugate direction aligned with
00195         //the gradient
00196         g = deriv(x);
00197         h = g;
00198 
00199         y = func(x);
00200         old_y = y;
00201 
00202         tolerance = sqrt(numeric_limits<Precision>::epsilon());
00203         epsilon = 1e-20;
00204         max_iterations = size * 100;
00205 
00206         bracket_initial_lambda = 1;
00207 
00208         linesearch_tolerance =  sqrt(numeric_limits<Precision>::epsilon());
00209         linesearch_epsilon = 1e-20;
00210         linesearch_max_iterations=100;
00211 
00212         iterations=0;
00213     }
00214     
00215 
00227     template<class Func> void find_next_point(const Func& func)
00228     {
00229         Internal::LineSearch<Size, Precision, Func> line(x, -h, func);
00230 
00231         //Always search in the conjugate direction (h)
00232 
00233         //First bracket a minimum. 
00234         Matrix<3,2,Precision> bracket = Internal::bracket_minimum_forward(y, line, bracket_initial_lambda);
00235 
00236         double a = bracket[0][0];
00237         double b = bracket[1][0];
00238         double c = bracket[2][0];
00239 
00240         double a_val = bracket[0][1];
00241         double b_val = bracket[1][1];
00242         double c_val = bracket[2][1];
00243 
00244         //We should have a bracket here
00245         assert(a < b && b < c);
00246         assert(a_val > b_val && b_val < c_val);
00247         
00248         //Find the real minimum
00249         Vector<2, Precision>  m = brent_line_search(a, b, c, b_val, line, linesearch_max_iterations, linesearch_tolerance, linesearch_epsilon); 
00250     
00251         assert(m[0] >= a && m[0] <= c);
00252         assert(m[1] <= b_val);
00253         
00254         //Update the current position and value
00255         old_y = y;
00256         old_x = x;
00257 
00258         x -= m[0] * h;
00259         y = m[1];
00260         
00261         iterations++;
00262     }
00263     
00266     bool finished()
00267     {
00268         using std::abs;
00269         return iterations > max_iterations || 2*abs(y - old_y) <= tolerance * (abs(y) + abs(old_y) + epsilon);
00270     }
00271     
00279     template<class Deriv> void update_vectors_PR(const Deriv& deriv)
00280     {
00281         //Update the position, gradient and conjugate directions
00282         old_g = g;
00283         old_h = h;
00284 
00285         g = deriv(x);
00286         //Precision gamma = (g * g - oldg*g)/(oldg * oldg);
00287         Precision gamma = (g * g - old_g*g)/(old_g * old_g);
00288         h = g + gamma * old_h;
00289     }
00290     
00308     template<class Func, class Deriv> bool iterate(const Func& func, const Deriv& deriv)
00309     {
00310         find_next_point(func);
00311 
00312         if(!finished())
00313         {
00314             update_vectors_PR(deriv);
00315             return 1;
00316         }
00317         else
00318             return 0;
00319     }
00320 };
00321 
00322 }

Generated on Thu May 7 20:28:40 2009 for TooN by  doxygen 1.5.3