00001 #ifndef TOON_DOWNHILL_SIMPLEX_H
00002 #define TOON_DOWNHILL_SIMPLEX_H
00003 #include <TooN/TooN.h>
00004 #include <TooN/helpers.h>
00005 #include <algorithm>
00006 #include <vector>
00007 #include <math.h>
00008
00009 namespace TooN
00010 {
00011
00012 template<int D> struct DSBase
00013 {
00014 typedef Vector<D> Vec;
00015 typedef Vector<D+1> Values;
00016 typedef std::vector<Vector<D> > Simplex;
00017
00018 static const int Dim = D;
00019
00020
00021 DSBase(int) { };
00022
00023 void resize_simplex(Simplex&s) {
00024 s.resize(Dim+1);
00025 }
00026 void resize_values(Values&) {}
00027 void resize_vector(Vec&) {}
00028
00029 };
00030
00031 template<> struct DSBase<-1>
00032 {
00033 typedef Vector<> Vec;
00034 typedef Vector<> Values;
00035 typedef std::vector<Vector<> > Simplex;
00036 int Dim;
00037
00038 DSBase(int d)
00039 {
00040 Dim = d;
00041 };
00042
00043 void resize_simplex(Simplex& s)
00044 {
00045 s.resize(Dim+1, Vector<>(Dim));
00046 }
00047
00048 void resize_values(Values& v)
00049 {
00050 v.resize(Dim+1);
00051 }
00052
00053 void resize_vector(Vec& v)
00054 {
00055 v.resize(Dim);
00056 }
00057 };
00058
00127 template<int N=-1> class DownhillSimplex: public DSBase<N>
00128 {
00129 typedef typename DSBase<N>::Vec Vec;
00130 typedef typename DSBase<N>::Values Values;
00131 typedef typename DSBase<N>::Simplex Simplex;
00132
00133 using DSBase<N>::Dim;
00134
00135 public:
00144 template<class Function> DownhillSimplex(const Function& func, const Vec& c, double spread=1)
00145 :DSBase<N>(c.size())
00146 {
00147 resize_simplex(simplex);
00148 resize_values(values);
00149
00150 for(int i=0; i < Dim+1; i++)
00151 simplex[i] = c;
00152
00153 for(int i=0; i < Dim; i++)
00154 simplex[i][i] += spread;
00155
00156 alpha = 1.0;
00157 rho = 2.0;
00158 gamma = 0.5;
00159 sigma = 0.5;
00160
00161 for(int i=0; i < Dim+1; i++)
00162 values[i] = func(simplex[i]);
00163 }
00164
00166 const Simplex& get_simplex() const
00167 {
00168 return simplex;
00169 }
00170
00172 const Values& get_values() const
00173 {
00174 return values;
00175 }
00176
00178 int get_best() const
00179 {
00180 return min_element(values.begin(), values.end()) - values.begin();
00181 }
00182
00184 int get_worst() const
00185 {
00186 return max_element(values.begin(), values.end()) - values.begin();
00187 }
00188
00191 template<class Function> void iterate(const Function& func)
00192 {
00193
00194
00195
00196
00197
00198 int worst = get_worst();
00199 double second_worst_val=-HUGE_VAL, bestval = HUGE_VAL, worst_val = values[worst];
00200 int best=0;
00201 Vec x0;
00202 resize_vector(x0);
00203 Zero(x0);
00204
00205
00206 for(int i=0; i < Dim+1; i++)
00207 {
00208 if(values[i] < bestval)
00209 {
00210 bestval = values[i];
00211 best = i;
00212 }
00213
00214 if(i != worst)
00215 {
00216 if(values[i] > second_worst_val)
00217 second_worst_val = values[i];
00218
00219
00220 x0 += simplex[i];
00221 }
00222 }
00223 x0 *= 1.0 / Dim;
00224
00225
00226
00227 Vec xr = (1 + alpha) * x0 - alpha * simplex[worst];
00228 double fr = func(xr);
00229
00230 if(fr < bestval)
00231 {
00232
00233 Vec xe = rho * xr + (1-rho) * x0;
00234 double fe = func(xe);
00235
00236
00237 if(fe < fr)
00238 {
00239 simplex[worst] = xe;
00240 values[worst] = fe;
00241 }
00242 else
00243 {
00244 simplex[worst] = xr;
00245 values[worst] = fr;
00246 }
00247
00248 return;
00249 }
00250
00251
00252
00253 if(fr < second_worst_val)
00254 {
00255 simplex[worst] = xr;
00256 values[worst] = fr;
00257 return;
00258 }
00259
00260
00261
00262
00263
00264 if(fr < worst_val)
00265 {
00266 Vec xc = (1 + gamma) * x0 - gamma * simplex[worst];
00267 double fc = func(xc);
00268
00269
00270 if(fc <= fr)
00271 {
00272 simplex[worst] = xc;
00273 values[worst] = fc;
00274 return;
00275 }
00276 }
00277
00278
00279
00280 for(int i=0; i < Dim+1; i++)
00281 if(i != best)
00282 {
00283 simplex[i] = simplex[best] + sigma * (simplex[i] - simplex[best]);
00284 values[i] = func(simplex[i]);
00285 }
00286 }
00287
00288
00289
00290
00291
00292 private:
00293 float alpha, rho, gamma, sigma;
00294
00295
00296 Simplex simplex;
00297
00298
00299 Values values;
00300
00301
00302 };
00303 }
00304 #endif