optimization/downhill_simplex.h

00001 #ifndef TOON_DOWNHILL_SIMPLEX_H
00002 #define TOON_DOWNHILL_SIMPLEX_H
00003 #include <TooN/TooN.h>
00004 #include <TooN/helpers.h>
00005 #include <algorithm>
00006 #include <vector>
00007 #include <math.h>
00008 
00009 namespace TooN
00010 {
00011 
00012 template<int D> struct DSBase
00013 {
00014     typedef Vector<D> Vec;
00015     typedef Vector<D+1> Values;
00016     typedef std::vector<Vector<D> > Simplex;
00017 
00018     static const int Dim = D;
00019 
00020 
00021     DSBase(int) { };
00022 
00023     void resize_simplex(Simplex&s) {
00024         s.resize(Dim+1);
00025     }
00026     void resize_values(Values&) {}
00027     void resize_vector(Vec&) {}
00028 
00029 };
00030 
00031 template<> struct DSBase<-1>
00032 {
00033     typedef Vector<> Vec;
00034     typedef Vector<> Values;
00035     typedef std::vector<Vector<> > Simplex;
00036     int Dim;
00037 
00038     DSBase(int d)
00039     {
00040         Dim = d;
00041     };
00042 
00043     void resize_simplex(Simplex& s)
00044     {
00045         s.resize(Dim+1, Vector<>(Dim));
00046     }
00047 
00048     void resize_values(Values& v)
00049     {
00050         v.resize(Dim+1);
00051     }
00052 
00053     void resize_vector(Vec& v)
00054     {
00055         v.resize(Dim);
00056     }
00057 };
00058 
00127 template<int N=-1> class DownhillSimplex: public DSBase<N>
00128 {
00129     typedef typename DSBase<N>::Vec Vec;
00130     typedef typename DSBase<N>::Values Values;
00131     typedef typename DSBase<N>::Simplex Simplex;
00132 
00133     using DSBase<N>::Dim;
00134 
00135     public:
00144         template<class Function> DownhillSimplex(const Function& func, const Vec& c, double spread=1)
00145         :DSBase<N>(c.size())
00146         {
00147             resize_simplex(simplex);
00148             resize_values(values);
00149 
00150             for(int i=0; i < Dim+1; i++)
00151                 simplex[i] = c;
00152 
00153             for(int i=0; i < Dim; i++)
00154                 simplex[i][i] += spread;
00155 
00156             alpha = 1.0;
00157             rho = 2.0;
00158             gamma = 0.5;
00159             sigma = 0.5;
00160 
00161             for(int i=0; i < Dim+1; i++)
00162                 values[i] = func(simplex[i]);
00163         }
00164 
00166         const Simplex& get_simplex() const
00167         {
00168             return simplex;
00169         }
00170         
00172         const Values& get_values() const
00173         {
00174             return values;
00175         }
00176         
00178         int get_best() const 
00179         {
00180             return min_element(values.begin(), values.end()) - values.begin();
00181         }
00182         
00184         int get_worst() const 
00185         {
00186             return max_element(values.begin(), values.end()) - values.begin();
00187         }
00188 
00191         template<class Function> void iterate(const Function& func)
00192         {
00193             //Find various things:
00194             // - The worst point
00195             // - The second worst point
00196             // - The best point
00197             // - The centroid of all the points but the worst
00198             int worst = get_worst();
00199             double second_worst_val=-HUGE_VAL, bestval = HUGE_VAL, worst_val = values[worst];
00200             int best=0;
00201             Vec x0;
00202             resize_vector(x0);
00203             Zero(x0);
00204 
00205 
00206             for(int i=0; i < Dim+1; i++)
00207             {
00208                 if(values[i] < bestval)
00209                 {
00210                     bestval = values[i];
00211                     best = i;
00212                 }
00213 
00214                 if(i != worst)
00215                 {
00216                     if(values[i] > second_worst_val)
00217                         second_worst_val = values[i];
00218 
00219                     //Compute the centroid of the non-worst points;
00220                     x0 += simplex[i];
00221                 }
00222             }
00223             x0 *= 1.0 / Dim;
00224 
00225 
00226             //Reflect the worst point about the centroid.
00227             Vec xr = (1 + alpha) * x0 - alpha * simplex[worst];
00228             double fr = func(xr);
00229 
00230             if(fr < bestval)
00231             {
00232                 //If the new point is better than the smallest, then try expanding the simplex.
00233                 Vec xe = rho * xr + (1-rho) * x0;
00234                 double fe = func(xe);
00235 
00236                 //Keep whichever is best
00237                 if(fe < fr)
00238                 {
00239                     simplex[worst] = xe;
00240                     values[worst] = fe;
00241                 }
00242                 else
00243                 {
00244                     simplex[worst] = xr;
00245                     values[worst] = fr;
00246                 }
00247 
00248                 return;
00249             }
00250 
00251             //Otherwise, if the new point lies between the other points
00252             //then keep it and move on to the next iteration.
00253             if(fr < second_worst_val)
00254             {
00255                 simplex[worst] = xr;
00256                 values[worst] = fr;
00257                 return;
00258             }
00259 
00260 
00261             //Otherwise, if the new point is a bit better than the worst point,
00262             //(ie, it's got just a little bit better) then contract the simplex
00263             //a bit.
00264             if(fr < worst_val)
00265             {
00266                 Vec xc = (1 + gamma) * x0 - gamma * simplex[worst];
00267                 double fc = func(xc);
00268 
00269                 //If this helped, use it
00270                 if(fc <= fr)
00271                 {
00272                     simplex[worst] = xc;
00273                     values[worst] = fc;
00274                     return;
00275                 }
00276             }
00277             
00278             //Otherwise, fr is worse than the worst point, or the fc was worse
00279             //than fr. So shrink the whole simplex around the best point.
00280             for(int i=0; i < Dim+1; i++)
00281                 if(i != best)
00282                 {
00283                     simplex[i] = simplex[best] + sigma * (simplex[i] - simplex[best]);
00284                     values[i] = func(simplex[i]);
00285                 }
00286         }
00287 
00288 
00289 
00290 
00291 
00292     private:
00293         float alpha, rho, gamma, sigma;
00294 
00295         //Each row is a simplex vertex
00296         Simplex simplex;
00297 
00298         //Function values for each vertex
00299         Values  values;
00300 
00301 
00302 };
00303 }
00304 #endif

Generated on Thu May 7 20:28:40 2009 for TooN by  doxygen 1.5.3