TooN 2.1
optimization/downhill_simplex.h
00001 #ifndef TOON_DOWNHILL_SIMPLEX_H
00002 #define TOON_DOWNHILL_SIMPLEX_H
#include <TooN/TooN.h>
#include <TooN/helpers.h>
#include <algorithm>
#include <cmath>
#include <cstdlib>
#include <limits>
00007 
00008 namespace TooN
00009 {
00010 
00011 /** This is an implementation of the Downhill Simplex (Nelder & Mead, 1965)
00012     algorithm. This particular instance will minimize a given function.
00013     
00014     The function maintains \f$N+1\f$ points for a $N$ dimensional function, \f$f\f$
00015     
00016     At each iteration, the following algorithm is performed:
00017     - Find the worst (largest) point, \f$x_w\f$.
00018     - Find the centroid of the remaining points, \f$x_0\f$.
00019     - Let \f$v = x_0 - x_w\f$
00020     - Compute a reflected point, \f$ x_r = x_0 + \alpha v\f$
00021     - If \f$f(x_r)\f$ is better than the best point
00022       - Expand the simplex by extending the reflection to \f$x_e = x_0 + \rho \alpha v \f$
00023       - Replace \f$x_w\f$ with the best point of \f$x_e\f$,  and \f$x_r\f$.
00024     - Else, if  \f$f(x_r)\f$ is between the best and second-worst point
00025       - Replace \f$x_w\f$ with \f$x_r\f$.
00026     - Else, if  \f$f(x_r)\f$ is better than \f$x_w\f$
00027       - Contract the simplex by computing \f$x_c = x_0 + \gamma v\f$
00028       - If \f$f(x_c) < f(x_r)\f$
00029         - Replace \f$x_w\f$ with \f$x_c\f$.
00030     - If \f$x_w\f$ has not been replaced, then shrink the simplex by a factor of \f$\sigma\f$ around the best point.
00031 
00032     This implementation uses:
00033     - \f$\alpha = 1\f$
00034     - \f$\rho = 2\f$
00035     - \f$\gamma = 1/2\f$
00036     - \f$\sigma = 1/2\f$
00037     
00038     Example usage:
00039     @code
00040 #include <TooN/optimization/downhill_simplex.h>
00041 using namespace std;
00042 using namespace TooN;
00043 
00044 double sq(double x)
00045 {
00046     return x*x;
00047 }
00048 
00049 double Rosenbrock(const Vector<2>& v)
00050 {
00051         return sq(1 - v[0]) + 100 * sq(v[1] - sq(v[0]));
00052 }
00053 
00054 int main()
00055 {
00056         Vector<2> starting_point = makeVector( -1, 1);
00057 
00058         DownhillSimplex<2> dh_fixed(Rosenbrock, starting_point, 1);
00059 
00060         while(dh_fixed.iterate(Rosenbrock))
00061         {
00062             cout << dh.get_values()[dh.get_best()] << endl;
00063         }
00064         
00065         cout << dh_fixed.get_simplex()[dh_fixed.get_best()] << endl;
00066 }
00067 
00068     @endcode
00069 
00070 
00071     @ingroup gOptimize
00072     @param   N The dimension of the function to optimize. As usual, the default value of <i>N</i> (-1) indicates
00073              that the class is sized at run-time.
00074 
00075 
00076 **/
00077 template<int N=-1, typename Precision=double> class DownhillSimplex
00078 {
00079     static const int Vertices = (N==-1?-1:N+1);
00080     typedef Matrix<Vertices, N, Precision> Simplex;
00081     typedef Vector<Vertices, Precision> Values;
00082 
00083     public:
00084         /// Initialize the DownhillSimplex class. The simplex is automatically
00085         /// generated. One point is at <i>c</i>, the remaining points are made by moving
00086         /// <i>c</i> by <i>spread</i> along each axis aligned unit vector.
00087         ///
00088         ///@param func       Functor to minimize.
00089         ///@param c          Origin of the initial simplex. The dimension of this vector
00090         ///                  is used to determine the dimension of the run-time sized version.
00091         ///@param spread     Size of the initial simplex.
00092         template<class Function> DownhillSimplex(const Function& func, const Vector<N>& c, Precision spread=1)
00093         :simplex(c.size()+1, c.size()),values(c.size()+1)
00094         {
00095             alpha = 1.0;
00096             rho = 2.0;
00097             gamma = 0.5;
00098             sigma = 0.5;
00099 
00100             using std::sqrt;
00101             epsilon = sqrt(numeric_limits<Precision>::epsilon());
00102             zero_epsilon = 1e-20;
00103 
00104             restart(func, c, spread);
00105         }
00106         
00107         /// This function sets up the simplex around, with one point at \e c and the remaining
00108         /// points are made by moving by \e spread along each axis aligned unit vector.
00109         ///
00110         ///@param func       Functor to minimize.
00111         ///@param c          \e c corner point of the simplex
00112         ///@param spread     \e spread simplex size
00113         template<class Function> void restart(const Function& func, const Vector<N>& c, Precision spread)
00114         {
00115             for(int i=0; i < simplex.num_rows(); i++)
00116                 simplex[i] = c;
00117 
00118             for(int i=0; i < simplex.num_cols(); i++)
00119                 simplex[i][i] += spread;
00120 
00121             for(int i=0; i < values.size(); i++)
00122                 values[i] = func(simplex[i]);
00123         }
00124         
00125         ///Check to see it iteration should stop. You probably do not want to use
00126         ///this function. See iterate() instead. This function updates nothing.
00127         ///The termination criterion is that the simplex span (distancve between
00128         ///the best and worst vertices) is small compared to the scale or 
00129         ///small overall.
00130         bool finished()
00131         {
00132             Precision span =  norm(simplex[get_best()] - simplex[get_worst()]);
00133             Precision scale = norm(simplex[get_best()]);
00134 
00135             if(span/scale < epsilon || span < zero_epsilon)
00136                 return 1;
00137             else 
00138                 return 0;
00139         }
00140         
00141         /// This function resets the simplex around the best current point.
00142         ///
00143         ///@param func       Functor to minimize.
00144         ///@param spread     simplex size
00145         template<class Function> void restart(const Function& func, Precision spread)
00146         {
00147             restart(func, simplex[get_best()], spread);
00148         }
00149 
00150         ///Return the simplex
00151         const Simplex& get_simplex() const
00152         {
00153             return simplex;
00154         }
00155         
00156         ///Return the score at the vertices
00157         const Values& get_values() const
00158         {
00159             return values;
00160         }
00161         
00162         ///Get the index of the best vertex
00163         int get_best() const 
00164         {
00165             return std::min_element(&values[0], &values[0] + values.size()) - &values[0];
00166         }
00167         
00168         ///Get the index of the worst vertex
00169         int get_worst() const 
00170         {
00171             return std::max_element(&values[0], &values[0] + values.size()) - &values[0];
00172         }
00173 
00174         ///Perform one iteration of the downhill Simplex algorithm
00175         ///@param func Functor to minimize
00176         template<class Function> void find_next_point(const Function& func)
00177         {
00178             //Find various things:
00179             // - The worst point
00180             // - The second worst point
00181             // - The best point
00182             // - The centroid of all the points but the worst
00183             int worst = get_worst();
00184             Precision second_worst_val=-HUGE_VAL, bestval = HUGE_VAL, worst_val = values[worst];
00185             int best=0;
00186             Vector<N> x0 = Zeros(simplex.num_cols());
00187 
00188 
00189             for(int i=0; i < simplex.num_rows(); i++)
00190             {
00191                 if(values[i] < bestval)
00192                 {
00193                     bestval = values[i];
00194                     best = i;
00195                 }
00196 
00197                 if(i != worst)
00198                 {
00199                     if(values[i] > second_worst_val)
00200                         second_worst_val = values[i];
00201 
00202                     //Compute the centroid of the non-worst points;
00203                     x0 += simplex[i];
00204                 }
00205             }
00206             x0 *= 1.0 / simplex.num_cols();
00207 
00208 
00209             //Reflect the worst point about the centroid.
00210             Vector<N> xr = (1 + alpha) * x0 - alpha * simplex[worst];
00211             Precision fr = func(xr);
00212 
00213             if(fr < bestval)
00214             {
00215                 //If the new point is better than the smallest, then try expanding the simplex.
00216                 Vector<N> xe = rho * xr + (1-rho) * x0;
00217                 Precision fe = func(xe);
00218 
00219                 //Keep whichever is best
00220                 if(fe < fr)
00221                 {
00222                     simplex[worst] = xe;
00223                     values[worst] = fe;
00224                 }
00225                 else
00226                 {
00227                     simplex[worst] = xr;
00228                     values[worst] = fr;
00229                 }
00230 
00231                 return;
00232             }
00233 
00234             //Otherwise, if the new point lies between the other points
00235             //then keep it and move on to the next iteration.
00236             if(fr < second_worst_val)
00237             {
00238                 simplex[worst] = xr;
00239                 values[worst] = fr;
00240                 return;
00241             }
00242 
00243 
00244             //Otherwise, if the new point is a bit better than the worst point,
00245             //(ie, it's got just a little bit better) then contract the simplex
00246             //a bit.
00247             if(fr < worst_val)
00248             {
00249                 Vector<N> xc = (1 + gamma) * x0 - gamma * simplex[worst];
00250                 Precision fc = func(xc);
00251 
00252                 //If this helped, use it
00253                 if(fc <= fr)
00254                 {
00255                     simplex[worst] = xc;
00256                     values[worst] = fc;
00257                     return;
00258                 }
00259             }
00260             
00261             //Otherwise, fr is worse than the worst point, or the fc was worse
00262             //than fr. So shrink the whole simplex around the best point.
00263             for(int i=0; i < simplex.num_rows(); i++)
00264                 if(i != best)
00265                 {
00266                     simplex[i] = simplex[best] + sigma * (simplex[i] - simplex[best]);
00267                     values[i] = func(simplex[i]);
00268                 }
00269         }
00270 
00271         ///Perform one iteration of the downhill Simplex algorithm, and return the result
00272         ///of not DownhillSimplex::finished.
00273         ///@param func Functor to minimize
00274         template<class Function> bool iterate(const Function& func)
00275         {
00276             find_next_point(func);
00277             return !finished();
00278         }
00279 
00280         Precision alpha; ///< Reflected size. Defaults to 1.
00281         Precision rho;   ///< Expansion ratio. Defaults to 2.
00282         Precision gamma; ///< Contraction ratio. Defaults to .5.
00283         Precision sigma; ///< Shrink ratio. Defaults to .5.
00284         Precision epsilon;  ///< Tolerance used to determine if the optimization is complete. Defaults to square root of machine precision.
00285         Precision zero_epsilon; ///< Additive term in tolerance to prevent excessive iterations if \f$x_\mathrm{optimal} = 0\f$. Known as \c ZEPS in numerical recipies. Defaults to 1e-20
00286 
00287     private:
00288 
00289         //Each row is a simplex vertex
00290         Simplex simplex;
00291 
00292         //Function values for each vertex
00293         Values values;
00294 
00295 
00296 };
00297 }
00298 #endif