00001 #include <TooN/optimization/brent.h>
00002 #include <utility>
00003 #include <cmath>
00004 #include <cassert>
00005 #include <cstdlib>
00006
00007 namespace TooN{
00008
00009 namespace Internal{
00010
00011
00018 template<int Size, typename Precision, typename Func> struct LineSearch
00019 {
00020 const Vector<Size, Precision>& start;
00021 const Vector<Size, Precision>& direction;
00022
00023 const Func& f;
00024
00029 LineSearch(const Vector<Size, Precision>& s, const Vector<Size, Precision>& d, const Func& func)
00030 :start(s),direction(d),f(func)
00031 {}
00032
00035 Precision operator()(Precision x) const
00036 {
00037 return f(start + x * direction);
00038 }
00039 };
00040
00050 template<typename Precision, typename Func> Matrix<3,2,Precision> bracket_minimum_forward(Precision a_val, const Func& func, Precision initial_lambda=1)
00051 {
00052
00053 Precision a, b, c, b_val, c_val;
00054
00055 a=0;
00056
00057
00058 Precision lambda=initial_lambda;
00059 b = lambda;
00060 b_val = func(b);
00061
00062 if(b_val < a_val)
00063 {
00064 for(;;)
00065 {
00066 lambda *= 2;
00067 c = lambda;
00068 c_val = func(c);
00069
00070 if(c_val > b_val)
00071 break;
00072 else
00073 {
00074 a = b;
00075 a_val = b_val;
00076 b=c;
00077 b_val=c_val;
00078
00079 }
00080 }
00081 }
00082 else
00083 {
00084 c = b;
00085 c_val = b_val;
00086
00087
00088 for(;;)
00089 {
00090 lambda *= .5;
00091 b = lambda;
00092 b_val = func(b);
00093
00094 if(b_val < a_val)
00095 break;
00096 else
00097 {
00098 c = b;
00099 c_val = b_val;
00100 }
00101 }
00102 }
00103
00104 Matrix<3,2> ret;
00105 ret[0] = makeVector(a, a_val);
00106 ret[1] = makeVector(b, b_val);
00107 ret[2] = makeVector(c, c_val);
00108
00109 return ret;
00110 }
00111
00112 }
00113
00114
00159 template<int Size, class Precision=double> struct ConjugateGradient
00160 {
00161 const int size;
00162 Vector<Size> g;
00163 Vector<Size> h;
00164 Vector<Size> old_g;
00165 Vector<Size> old_h;
00166 Vector<Size> x;
00167 Vector<Size> old_x;
00168 Precision y;
00169 Precision old_y;
00170
00171 Precision tolerance;
00172 Precision epsilon;
00173 int max_iterations;
00174
00175 Precision bracket_initial_lambda;
00176 Precision linesearch_tolerance;
00177 Precision linesearch_epsilon;
00178 int linesearch_max_iterations;
00179
00180 int iterations;
00181
00186 template<class Func, class Deriv> ConjugateGradient(const Vector<Size>& start, const Func& func, const Deriv& deriv)
00187 : size(start.size()),
00188 g(size),h(size),old_g(size),old_h(size),x(start),old_x(size)
00189 {
00190 using std::numeric_limits;
00191
00192 x = start;
00193
00194
00195
00196 g = deriv(x);
00197 h = g;
00198
00199 y = func(x);
00200 old_y = y;
00201
00202 tolerance = sqrt(numeric_limits<Precision>::epsilon());
00203 epsilon = 1e-20;
00204 max_iterations = size * 100;
00205
00206 bracket_initial_lambda = 1;
00207
00208 linesearch_tolerance = sqrt(numeric_limits<Precision>::epsilon());
00209 linesearch_epsilon = 1e-20;
00210 linesearch_max_iterations=100;
00211
00212 iterations=0;
00213 }
00214
00215
00227 template<class Func> void find_next_point(const Func& func)
00228 {
00229 Internal::LineSearch<Size, Precision, Func> line(x, -h, func);
00230
00231
00232
00233
00234 Matrix<3,2,Precision> bracket = Internal::bracket_minimum_forward(y, line, bracket_initial_lambda);
00235
00236 double a = bracket[0][0];
00237 double b = bracket[1][0];
00238 double c = bracket[2][0];
00239
00240 double a_val = bracket[0][1];
00241 double b_val = bracket[1][1];
00242 double c_val = bracket[2][1];
00243
00244
00245 assert(a < b && b < c);
00246 assert(a_val > b_val && b_val < c_val);
00247
00248
00249 Vector<2, Precision> m = brent_line_search(a, b, c, b_val, line, linesearch_max_iterations, linesearch_tolerance, linesearch_epsilon);
00250
00251 assert(m[0] >= a && m[0] <= c);
00252 assert(m[1] <= b_val);
00253
00254
00255 old_y = y;
00256 old_x = x;
00257
00258 x -= m[0] * h;
00259 y = m[1];
00260
00261 iterations++;
00262 }
00263
00266 bool finished()
00267 {
00268 using std::abs;
00269 return iterations > max_iterations || 2*abs(y - old_y) <= tolerance * (abs(y) + abs(old_y) + epsilon);
00270 }
00271
00279 template<class Deriv> void update_vectors_PR(const Deriv& deriv)
00280 {
00281
00282 old_g = g;
00283 old_h = h;
00284
00285 g = deriv(x);
00286
00287 Precision gamma = (g * g - old_g*g)/(old_g * old_g);
00288 h = g + gamma * old_h;
00289 }
00290
00308 template<class Func, class Deriv> bool iterate(const Func& func, const Deriv& deriv)
00309 {
00310 find_next_point(func);
00311
00312 if(!finished())
00313 {
00314 update_vectors_PR(deriv);
00315 return 1;
00316 }
00317 else
00318 return 0;
00319 }
00320 };
00321
00322 }