#include <TooN/optimization/brent.h>
#include <utility>
#include <cmath>
#include <cassert>
#include <cstdlib>

namespace TooN{
namespace Internal{


	///Turn a multidimensional function into a 1D function by specifying a
	///point and direction. A new function is defined:
	///\f[
	/// g(a) = f(\Vec{s} + a \Vec{d})
	///\f]
	///@ingroup gOptimize
	template<int Size, typename Precision, typename Func> struct LineSearch
	{
		const Vector<Size, Precision>& start;     ///< \f$\Vec{s}\f$
		const Vector<Size, Precision>& direction; ///< \f$\Vec{d}\f$

		const Func& f; ///< \f$f(\cdotp)\f$

		///Set up the line search class.
		///@param s Start point, \f$\Vec{s}\f$.
		///@param d Direction, \f$\Vec{d}\f$.
		///@param func Function, \f$f(\cdotp)\f$.
		LineSearch(const Vector<Size, Precision>& s, const Vector<Size, Precision>& d, const Func& func)
		:start(s),direction(d),f(func)
		{}

		///@param x Position to evaluate function
		///@return \f$f(\Vec{s} + x\Vec{d})\f$
		Precision operator()(Precision x) const
		{
			return f(start + x * direction);
		}
	};

	///Bracket a 1D function by searching forward from zero. The assumption
	///is that a minimum exists in \f$f(x),\ x>0\f$, and this function searches
	///for a bracket using exponentially growing or shrinking steps.
	///@param a_val The value of the function at zero.
	///@param func Function to bracket
	///@param initial_lambda Initial step size
	///@param zeps Minimum bracket size.
	///@return <code>m[i][0]</code> contains the values of \f$x\f$ for the bracket, in increasing order,
	///        and <code>m[i][1]</code> contains the corresponding values of \f$f(x)\f$. If the bracket
	///        drops below the minimum bracket size, all zeros are returned.
	///@ingroup gOptimize
	template<typename Precision, typename Func> Matrix<3,2,Precision> bracket_minimum_forward(Precision a_val, const Func& func, Precision initial_lambda, Precision zeps)
	{
		//Get a, b, c to bracket a minimum along a line
		Precision a, b, c, b_val, c_val;

		a = 0;

		//Search forward in steps of lambda
		Precision lambda = initial_lambda;
		b = lambda;
		b_val = func(b);

		if(b_val < a_val) //We've gone downhill, so keep searching until we go back up
		{
			for(;;)
			{
				lambda *= 2;
				c = lambda;
				c_val = func(c);

				if(c_val > b_val) // we have a bracket
					break;
				else //Shift the bracket forward
				{
					a = b;
					a_val = b_val;
					b = c;
					b_val = c_val;
				}
			}
		}
		else //We've overshot the minimum, so back up
		{
			c = b;
			c_val = b_val;
			//Here, c_val > a_val

			for(;;)
			{
				lambda *= .5;
				b = lambda;
				b_val = func(b);

				if(b_val < a_val) // we have a bracket
					break;
				else if(lambda < zeps)
					return Zeros;
				else //Contract the bracket
				{
					c = b;
					c_val = b_val;
				}
			}
		}

		//Use Precision here so the bracket matches the declared return type.
		Matrix<3,2,Precision> ret;
		ret[0] = makeVector(a, a_val);
		ret[1] = makeVector(b, b_val);
		ret[2] = makeVector(c, c_val);

		return ret;
	}

}
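//A minimal sketch of how the bracketing helper above behaves on a plain 1D
//function (Parabola and the step sizes are illustrative assumptions, not
//library code):
//
//  double Parabola(double x) { return (x - 2.5) * (x - 2.5); }
//
//  //Bracket the minimum, starting from f(0) with an initial step of 1:
//  Matrix<3,2> m = Internal::bracket_minimum_forward(Parabola(0.0), Parabola, 1.0, 1e-20);
//  //Column 0 of m now holds x = 1, 2, 4 and column 1 the matching f(x);
//  //the middle value is strictly below both ends, so a minimum lies inside.
//
//LineSearch turns an N-dimensional function into exactly such a 1D function,
//so the same bracketing can be run along any ray through a point.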

/** This class provides a nonlinear conjugate-gradient optimizer. The following
code snippet will perform an optimization on the Rosenbrock banana function in
two dimensions:

@code
double Rosenbrock(const Vector<2>& v)
{
	return sq(1 - v[0]) + 100 * sq(v[1] - sq(v[0]));
}

Vector<2> RosenbrockDerivatives(const Vector<2>& v)
{
	double x = v[0];
	double y = v[1];

	Vector<2> ret;
	ret[0] = -2+2*x-400*(y-sq(x))*x;
	ret[1] = 200*y-200*sq(x);

	return ret;
}

int main()
{
	ConjugateGradient<2> cg(makeVector(0,0), Rosenbrock, RosenbrockDerivatives);

	while(cg.iterate(Rosenbrock, RosenbrockDerivatives))
		cout << "y_" << cg.iterations << " = " << cg.y << endl;

	cout << "Optimal value: " << cg.y << endl;
}
@endcode

The chances are that you will want to read the documentation for
ConjugateGradient::ConjugateGradient and ConjugateGradient::iterate.

The linesearch is currently performed using Brent's method (see brent_line_search),
and conjugate vector updates are performed using the Polak-Ribiere equations.
There are many tunable parameters, and the internals are readily accessible, so
alternative termination conditions etc. can easily be substituted. However,
usually these will not be necessary.

@ingroup gOptimize
*/
template<int Size, class Precision=double> struct ConjugateGradient
{
	const int size;     ///< Dimensionality of the space.
	Vector<Size> g;     ///< Gradient vector used by the next call to iterate()
	Vector<Size> h;     ///< Conjugate vector to be searched along in the next call to iterate()
	Vector<Size> old_g; ///< Gradient vector used to compute \f$h\f$ in the last call to iterate()
	Vector<Size> old_h; ///< Conjugate vector searched along in the last call to iterate()
	Vector<Size> x;     ///< Current position (best known point)
	Vector<Size> old_x; ///< Previous best known point (not set at construction)
	Precision y;        ///< Function value at \f$x\f$
	Precision old_y;    ///< Function value at old_x

	Precision tolerance; ///< Tolerance used to determine if the optimization is complete. Defaults to the square root of machine precision.
	Precision epsilon;   ///< Additive term in the tolerance to prevent excessive iterations if \f$x_\mathrm{optimal} = 0\f$. Known as \c ZEPS in Numerical Recipes. Defaults to 1e-20.
	int max_iterations;  ///< Maximum number of iterations. Defaults to \c size\f$*100\f$.

	Precision bracket_initial_lambda; ///< Initial step size used in bracketing the minimum for the line search. Defaults to 1.
	Precision linesearch_tolerance;   ///< Tolerance used to determine if the linesearch is complete. Defaults to the square root of machine precision.
	Precision linesearch_epsilon;     ///< Additive term in the tolerance to prevent excessive iterations if \f$x_\mathrm{optimal} = 0\f$. Known as \c ZEPS in Numerical Recipes. Defaults to 1e-20.
	int linesearch_max_iterations;    ///< Maximum number of iterations in the linesearch. Defaults to 100.

	Precision bracket_epsilon; ///< Minimum size for initial minimum bracketing. Below this, it is assumed that the system has converged. Defaults to 1e-20.

	int iterations; ///< Number of iterations performed
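
	//All of the parameters above are public and may be adjusted after
	//construction, before calling iterate(). A sketch, using the Rosenbrock
	//functions from the class documentation (the values here are illustrative
	//assumptions, not recommendations):
	//
	//  ConjugateGradient<2> cg(makeVector(0,0), Rosenbrock, RosenbrockDerivatives);
	//  cg.tolerance = 1e-10;               //tighter termination test
	//  cg.max_iterations = 50;             //give up sooner
	//  cg.linesearch_max_iterations = 200; //allow a more exact linesearch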

	///Initialize the ConjugateGradient class with sensible values.
	///@param start Starting point, \e x
	///@param func  Function \e f to compute \f$f(x)\f$
	///@param deriv Function to compute \f$\nabla f(x)\f$
	template<class Func, class Deriv> ConjugateGradient(const Vector<Size>& start, const Func& func, const Deriv& deriv)
	: size(start.size()),
	  g(size),h(size),old_g(size),old_h(size),x(start),old_x(size)
	{
		init(start, func(start), deriv(start));
	}

	///Initialize the ConjugateGradient class with sensible values.
	///@param start Starting point, \e x
	///@param func  Function \e f to compute \f$f(x)\f$
	///@param deriv \f$\nabla f\f$ evaluated at \e start
	template<class Func> ConjugateGradient(const Vector<Size>& start, const Func& func, const Vector<Size>& deriv)
	: size(start.size()),
	  g(size),h(size),old_g(size),old_h(size),x(start),old_x(size)
	{
		init(start, func(start), deriv);
	}

	///Initialize the ConjugateGradient class with sensible values. Used internally.
	///@param start Starting point, \e x
	///@param func  Value of \f$f\f$ at \e start
	///@param deriv \f$\nabla f\f$ at \e start
	void init(const Vector<Size>& start, const Precision& func, const Vector<Size>& deriv)
	{
		using std::numeric_limits;
		x = start;

		//Start with the conjugate direction aligned with
		//the gradient
		g = deriv;
		h = g;

		y = func;
		old_y = y;

		tolerance = sqrt(numeric_limits<Precision>::epsilon());
		epsilon = 1e-20;
		max_iterations = size * 100;

		bracket_initial_lambda = 1;

		linesearch_tolerance = sqrt(numeric_limits<Precision>::epsilon());
		linesearch_epsilon = 1e-20;
		linesearch_max_iterations = 100;

		bracket_epsilon = 1e-20;

		iterations = 0;
	}
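
	//The two constructors above differ only in how the initial gradient is
	//supplied. A sketch, assuming the Rosenbrock functions from the class
	//documentation:
	//
	//  //Derivative functor: the gradient at the start point is computed for you
	//  ConjugateGradient<2> cg1(makeVector(0,0), Rosenbrock, RosenbrockDerivatives);
	//
	//  //Precomputed gradient: useful if the gradient at the start point is
	//  //already known
	//  ConjugateGradient<2> cg2(makeVector(0,0), Rosenbrock,
	//                           RosenbrockDerivatives(makeVector(0,0)));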

	///Perform a linesearch from the current point (x) along the current
	///conjugate vector (h). The linesearch does not make use of derivatives.
	///You probably do not want to use this function. See iterate() instead.
	///This function updates:
	/// - x
	/// - old_x
	/// - y
	/// - old_y
	/// - iterations
	///Note that the conjugate direction and gradient are not updated.
	///If bracket_minimum_forward detects a local maximum, then essentially a
	///zero-sized step is taken.
	///@param func Functor returning the function value at a given point.
	template<class Func> void find_next_point(const Func& func)
	{
		//Always search in the conjugate direction (h). LineSearch holds
		//references, so negate h into a named vector rather than binding
		//the dead temporary -h to line.direction.
		Vector<Size> dir = -h;
		Internal::LineSearch<Size, Precision, Func> line(x, dir, func);

		//First bracket a minimum.
		Matrix<3,2,Precision> bracket = Internal::bracket_minimum_forward(y, line, bracket_initial_lambda, bracket_epsilon);

		double a = bracket[0][0];
		double b = bracket[1][0];
		double c = bracket[2][0];

		double a_val = bracket[0][1];
		double b_val = bracket[1][1];
		double c_val = bracket[2][1];

		old_y = y;
		old_x = x;
		iterations++;

		//Local maximum achieved!
		if(a == 0 && b == 0 && c == 0)
			return;

		//We should have a bracket here
		assert(a < b && b < c);
		assert(a_val > b_val && b_val < c_val);

		//Find the real minimum
		Vector<2, Precision> m = brent_line_search(a, b, c, b_val, line, linesearch_max_iterations, linesearch_tolerance, linesearch_epsilon);

		assert(m[0] >= a && m[0] <= c);
		assert(m[1] <= b_val);

		//Update the current position and value
		x -= m[0] * h;
		y = m[1];
	}

	///Check to see if iteration should stop. You probably do not want to use
	///this function. See iterate() instead. This function updates nothing.
	bool finished()
	{
		using std::abs;
		return iterations > max_iterations || 2*abs(y - old_y) <= tolerance * (abs(y) + abs(old_y) + epsilon);
	}

	///After an iteration, update the gradient and conjugate vector using the
	///Polak-Ribiere equations.
	///This function updates:
	///- g
	///- old_g
	///- h
	///- old_h
	///@param grad The derivatives of the function at \e x
	void update_vectors_PR(const Vector<Size>& grad)
	{
		//Update the gradient and conjugate directions
		old_g = g;
		old_h = h;

		g = grad;
		//Polak-Ribiere: gamma = g.(g - old_g) / (old_g.old_g)
		Precision gamma = (g * g - old_g * g) / (old_g * old_g);
		h = g + gamma * old_h;
	}

	///Use this function to iterate over the optimization. Note that after
	///iterate returns false, g, h, old_g and old_h will not have been
	///updated.
	///This function updates:
	/// - x
	/// - old_x
	/// - y
	/// - old_y
	/// - iterations
	/// - g*
	/// - old_g*
	/// - h*
	/// - old_h*
	///*'d variables are not updated on the last iteration.
	///@param func Functor returning the function value at a given point.
	///@param deriv Functor to compute derivatives at the specified point.
	///@return Whether to continue.
	template<class Func, class Deriv> bool iterate(const Func& func, const Deriv& deriv)
	{
		find_next_point(func);

		if(!finished())
		{
			update_vectors_PR(deriv(x));
			return true;
		}
		else
			return false;
	}
};

}
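
//Because the internals are public, iterate() can be replaced by a hand-rolled
//loop with a different termination test, as the class documentation notes. A
//sketch, assuming the Rosenbrock functions from the class documentation and an
//illustrative gradient-norm threshold:
//
//  ConjugateGradient<2> cg(makeVector(0,0), Rosenbrock, RosenbrockDerivatives);
//
//  for(;;)
//  {
//      cg.find_next_point(Rosenbrock); //linesearch along cg.h
//
//      Vector<2> grad = RosenbrockDerivatives(cg.x);
//      if(cg.iterations > cg.max_iterations || norm(grad) < 1e-8) //custom test
//          break;
//
//      cg.update_vectors_PR(grad); //Polak-Ribiere update
//  }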