learn_fast_tree.cc File Reference


Detailed Description

Main file for the learn_fast_tree executable.

learn_fast_tree [--weights.x weight ] ... < infile > outfile

learn_fast_tree uses ID3 to learn a ternary decision tree for corner detection. The data is read from the standard input, and the tree is written to the standard output. This is designed to learn FAST feature detectors, and does not allow for the possibility of ambiguity in the input data.

Input data

The input data has the following format:


5
[-1 -1] [1 1] [3 4] [5 6] [-3 4]
bbbbb 1     0
bsdsb 1000  1
.
.
.

The first row is the number of features. The second row is the list of offsets associated with each feature. This list has no effect on the learning of the tree, but it is passed through to the output for convenience.

The remaining rows contain the data. The first field is the ternary feature vector. The three characters "b", "d" and "s" correspond to brighter, darker and similar respectively, with the first feature being stored in the first character and so on.

The next field is the number of instances of the particular feature. The third field is the class, with 1 for corner, and 0 for background.

Generating input data

Ideally, input data will be generated from some sample images. The program FIXME can be used to do this.

Additionally, the program fast_N_features can be used to generate all possible feature combinations for FAST-N features. When run without arguments, it generates data for FAST-9 features; otherwise, the argument can be used to specify N.

Output data

The program does not generate source code directly; rather, it generates an easily parsable representation of a decision tree which can be turned into source code.

The structure of the tree is described in detail in print_tree.

Definition in file learn_fast_tree.cc.

Go to the source code of this file.

Classes

struct  datapoint< FEATURE_SIZE >
 This structure represents a datapoint. More...
struct  tree
 This class represents a decision tree. More...

Defines

#define fatal(E, S,...)   vfatal((E), (S), (tag::Fmt,## __VA_ARGS__))

Enumerations

enum  Ternary { Brighter = 'b', Darker = 'd', Similar = 's' }

Functions

template<class C>
void vfatal (int err, const string &s, const C &list)
template<int S>
V_tuple< shared_ptr< vector< datapoint< S > > >, uint64_t >::type 
load_features (unsigned int nfeats)
double entropy (uint64_t n, uint64_t c1)
template<int S>
int find_best_split (const vector< datapoint< S > > &fs, const vector< double > &weights, unsigned int nfeats)
template<int S>
shared_ptr< tree > build_tree (vector< datapoint< S > > &corners, const vector< double > &weights, int nfeats)
void print_tree (const tree *node, ostream &o, const string &i="")
template<int S>
V_tuple< shared_ptr< tree >, uint64_t >::type 
load_and_build_tree (unsigned int num_features, const vector< double > &weights)
int main (int argc, char **argv)


Enumeration Type Documentation

enum Ternary

Representations of ternary digits.

Enumerator:
Brighter 
Darker 
Similar 

Definition at line 114 of file learn_fast_tree.cc.

00115 {
00116     Brighter='b',
00117     Darker  ='d',
00118     Similar ='s'
00119 };


Function Documentation

template<int S>
V_tuple<shared_ptr<vector<datapoint<S> > >, uint64_t >::type load_features ( unsigned int  nfeats  )  [inline]

This function loads as many datapoints from the standard input as possible.

Datapoints consist of a feature vector (a string containing the characters "b", "d" and "s"), a number of instances and a class.

See datapoint::pack_trits for a more complete description of the feature vector.

The tokens are whitespace separated.

Parameters:
nfeats Number of features in a feature vector. Used to spot errors.
Returns:
Loaded datapoints and total number of instances.

Definition at line 246 of file learn_fast_tree.cc.

References fatal.

00247 {
00248     shared_ptr<vector<datapoint<S> > > ret(new vector<datapoint<S> >);
00249 
00250 
00251     string unpacked_feature;
00252     
00253     uint64_t total_num = 0;
00254 
00255     uint64_t line_num=2;
00256 
00257     for(;;)
00258     {
00259         uint64_t count;
00260         bool is;
00261 
00262         cin >> unpacked_feature >> count >> is;
00263 
00264         if(!cin)
00265             break;
00266 
00267         line_num++;
00268 
00269         if(unpacked_feature.size() != nfeats)
00270             fatal(1, "Feature string length is %i, not %i on line %i", unpacked_feature.size(), nfeats, line_num);
00271 
00272         if(count == 0)
00273             fatal(4, "Zero count is invalid");
00274 
00275         ret->push_back(datapoint<S>(unpacked_feature, count, is));
00276 
00277         total_num += count;
00278     }
00279 
00280     cerr << "Num features: " << total_num << endl
00281          << "Num distinct: " << ret->size() << endl;
00282 
00283     return make_vtuple(ret, total_num);
00284 }

double entropy ( uint64_t  n,
uint64_t  c1 
)

Compute the entropy of a set with binary annotations.

Parameters:
n Number of elements in the set
c1 Number of elements in class 1
Returns:
The set entropy.

Definition at line 291 of file learn_fast_tree.cc.

Referenced by find_best_split().

00292 {
00293     assert(c1 <= n);
00294     //n is total number, c1 in num in class 1
00295     if(n == 0)
00296         return 0;
00297     else if(c1 == 0 || c1 == n)
00298         return 0;
00299     else
00300     {
00301         double p1 = (double)c1 / n;
00302         double p2 = 1-p1;
00303 
00304         return -(double)n*(p1*log(p1) + p2*log(p2)) / log(2.f);
00305     }
00306 }

template<int S>
int find_best_split ( const vector< datapoint< S > > &  fs,
const vector< double > &  weights,
unsigned int  nfeats 
) [inline]

Find the feature that has the highest weighted entropy change.

Parameters:
fs datapoints to split in to three subsets.
weights weights on features
nfeats Number of features in use.
Returns:
best feature.

Definition at line 313 of file learn_fast_tree.cc.

References Brighter, Darker, entropy(), fatal, and Similar.

00314 {
00315     assert(nfeats == weights.size());
00316     uint64_t num_total = 0, num_corners=0;
00317 
00318     for(typename vector<datapoint<S> >::const_iterator i=fs.begin(); i != fs.end(); i++)
00319     {
00320         num_total += i->count;
00321         if(i->is_a_corner)
00322             num_corners += i->count;
00323     }
00324 
00325     double total_entropy = entropy(num_total, num_corners);
00326     
00327     double biggest_delta = 0;
00328     int   feature_num = -1;
00329 
00330     for(unsigned int i=0; i < nfeats; i++)
00331     {
00332         uint64_t num_bri = 0, num_dar = 0, num_sim = 0;
00333         uint64_t cor_bri = 0, cor_dar = 0, cor_sim = 0;
00334 
00335         for(typename vector<datapoint<S> >::const_iterator f=fs.begin(); f != fs.end(); f++)
00336         {
00337             switch(f->get_trit(i))
00338             {
00339                 case Brighter:
00340                     num_bri += f->count;
00341                     if(f->is_a_corner)
00342                         cor_bri += f->count;
00343                     break;
00344 
00345                 case Darker:
00346                     num_dar += f->count;
00347                     if(f->is_a_corner)
00348                         cor_dar += f->count;
00349                     break;
00350 
00351                 case Similar:
00352                     num_sim += f->count;
00353                     if(f->is_a_corner)
00354                         cor_sim += f->count;
00355                     break;
00356             }
00357         }
00358 
00359         double delta_e = total_entropy - (entropy(num_bri, cor_bri) + entropy(num_dar, cor_dar) + entropy(num_sim, cor_sim));
00360 
00361         delta_e *= weights[i];
00362 
00363         if(delta_e > biggest_delta)
00364         {       
00365             biggest_delta = delta_e;
00366             feature_num = i;
00367         }   
00368     }
00369 
00370     if(feature_num == -1)
00371         fatal(3, "Couldn't find a split.");
00372 
00373     return feature_num;
00374 }

template<int S>
shared_ptr<tree> build_tree ( vector< datapoint< S > > &  corners,
const vector< double > &  weights,
int  nfeats 
) [inline]

This function uses ID3 to construct a decision tree.

The entropy changes are weighted by the list of weights, to allow bias towards certain features. This function assumes that the class is an exact function of the data. If datapoints with different classes share the same feature vector, the program will exit with error code 3.

Parameters:
corners Datapoints in this part of the subtree to classify
weights Weights on the features
nfeats Number of features actually used
Returns:
The tree required to classify corners

Definition at line 468 of file learn_fast_tree.cc.

References Brighter, tree::CornerLeaf(), Darker, tree::NonCornerLeaf(), and Similar.

00469 {
00470     //Find the split
00471     int f = find_best_split<S>(corners, weights, nfeats);
00472 
00473     //Split corners in to the three chunks, based on the result of find_best_split.
00474     //Also, count how many of each class ends up in each of the three bins.
00475     //It may appear to be inefficient to use a vector here instead of a list, in terms
00476     //of memory, but the per-element storage overhead of the list is such that it uses
00477     //considerably more memory and is much slower.
00478     vector<datapoint<S> > brighter, darker, similar;
00479     uint64_t num_bri=0, cor_bri=0, num_dar=0, cor_dar=0, num_sim=0, cor_sim=0;
00480 
00481     for(size_t i=0; i < corners.size(); i++)
00482     {
00483         switch(corners[i].get_trit(f))
00484         {
00485             case Brighter:
00486                 brighter.push_back(corners[i]);
00487                 num_bri += corners[i].count;
00488                 if(corners[i].is_a_corner)
00489                     cor_bri += corners[i].count;
00490                 break;
00491 
00492             case Darker:
00493                 darker.push_back(corners[i]);
00494                 num_dar += corners[i].count;
00495                 if(corners[i].is_a_corner)
00496                     cor_dar += corners[i].count;
00497                 break;
00498 
00499             case Similar:
00500                 similar.push_back(corners[i]);
00501                 num_sim += corners[i].count;
00502                 if(corners[i].is_a_corner)
00503                     cor_sim += corners[i].count;
00504                 break;
00505         }
00506     }
00507     
00508     //Deallocate the memory now it's no longer needed.
00509     corners.clear();
00510     
00511     //This is not the same as corners.size(), since the corners (datapoints)
00512     //have a count associated with them.
00513     uint64_t num_tests =  num_bri + num_dar + num_sim;
00514 
00515     
00516     //Build the subtrees
00517     shared_ptr<tree> b_tree, d_tree, s_tree;
00518 
00519     
00520     //If the sublist contains a single class, then instantiate a leaf,
00521     //otherwise recursively build the tree.
00522     if(cor_bri == 0)
00523         b_tree = tree::NonCornerLeaf(num_bri);
00524     else if(cor_bri == num_bri)
00525         b_tree = tree::CornerLeaf(num_bri);
00526     else
00527         b_tree = build_tree<S>(brighter, weights, nfeats);
00528     
00529 
00530     if(cor_dar == 0)
00531         d_tree = tree::NonCornerLeaf(num_dar);
00532     else if(cor_dar == num_dar)
00533         d_tree = tree::CornerLeaf(num_dar);
00534     else
00535         d_tree = build_tree<S>(darker, weights, nfeats);
00536 
00537 
00538     if(cor_sim == 0)
00539         s_tree = tree::NonCornerLeaf(num_sim);
00540     else if(cor_sim == num_sim)
00541         s_tree = tree::CornerLeaf(num_sim);
00542     else
00543         s_tree = build_tree<S>(similar, weights, nfeats);
00544     
00545     return shared_ptr<tree>(new tree(b_tree, d_tree, s_tree, f, num_tests));
00546 }

void print_tree ( const tree *  node,
ostream &  o,
const string &  i = "" 
)

This function traverses the tree and produces a textual representation of it.

Additionally, if any of the subtrees are the same, then a single subtree is produced and the test is removed.

A subtree has the following format:

    subtree = leaf | node;
    
    leaf = "corner" | "background" ;

    node = node2 | node3;

    node3 = "if_brighter" feature_number n1 n2 n3
                subtree
            "elsf_darker" feature_number
                subtree
            "else"
                subtree
            "end";

     node2= if_statement feature_number n1 n2
                subtree
            "else"
                subtree
            "end";

    if_statement = "if_brighter" | "if_darker" | "if_either";
    feature_number = integer;
    n1 = integer;
    n2 = integer;
    n3 = integer;

feature_number refers to the index of the feature that the test is performed on.

In node3, a 3-way test is performed. n1, n2 and n3 refer to the number of training examples landing in the if block, the elsf block and the else block respectively.

In a node2 node, one of the tests has been removed. n1 and n2 refer to the number of training examples landing in the if block and the else block respectively.

Although not mentioned in the grammar, the indenting is kept very strict.

This representation has been designed to be parsed very easily with simple regular expressions, hence the use of "elsf" as opposed to "elif" or "elseif".
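For illustration, a hypothetical node2 tree that tests feature 3 (with invented counts: 512 examples brighter, all corners; 1024 examples in the merged else branch, all background) would serialize as:

```
if_brighter 3 512 1024
 corner
else
 background
end
```

Note the single space of indentation per nesting level, which print_tree adds for each recursion.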

Parameters:
node (sub)tree to serialize
o Stream to serialize to.
i Indent to print before each line of the serialized tree.

Definition at line 601 of file learn_fast_tree.cc.

References tree::brighter, tree::Corner, tree::darker, tree::feature_to_test, tree::is_a_corner, tree::NonCorner, tree::num_datapoints, and tree::similar.

Referenced by main().

00602 {
00603     if(node->is_a_corner == tree::Corner)
00604         o << i << "corner" << endl;
00605     else if(node->is_a_corner == tree::NonCorner)
00606         o << i << "background" << endl;
00607     else
00608     {
00609         string b = node->brighter->stringify();
00610         string d = node->darker->stringify();
00611         string s = node->similar->stringify();
00612 
00613         const tree * bt = node->brighter.get();
00614         const tree * dt = node->darker.get();
00615         const tree * st = node->similar.get();
00616         string ii = i + " ";
00617 
00618         int f = node->feature_to_test;
00619     
00620         if(b == d && d == s) //All the same
00621         {
00622             //o << i << "if " << f << " is whatever\n";
00623             print_tree(st, o, i);
00624         }
00625         else if(d == s)  //Bright is different
00626         {
00627             o << i << "if_brighter " << f << " " << bt->num_datapoints << " " << dt->num_datapoints+st->num_datapoints << endl;
00628                 print_tree(bt, o, ii);
00629             o << i << "else" << endl;
00630                 print_tree(st, o, ii);
00631             o << i << "end" << endl;
00632 
00633         }
00634         else if(b == s) //Dark is different
00635         {   
00636             o << i << "if_darker " << f << " " << dt->num_datapoints << " " << bt->num_datapoints + st->num_datapoints << endl;
00637                 print_tree(dt, o, ii);
00638             o << i << "else" << endl;
00639                 print_tree(st, o, ii);
00640             o << i << "end" << endl;
00641         }
00642         else if(b == d) //Similar is different
00643         {
00644             o << i << "if_either " << f << " " <<  bt->num_datapoints + dt->num_datapoints  << " " << st->num_datapoints << endl;
00645                 print_tree(bt, o, ii);
00646             o << i << "else" << endl;
00647                 print_tree(st, o, ii);
00648             o << i << "end" << endl;
00649         }
00650         else //All different
00651         {
00652             o << i << "if_brighter " << f << " "  <<  bt->num_datapoints << " " << dt->num_datapoints  << " " << st->num_datapoints << endl;
00653                 print_tree(bt, o, ii);
00654             o << i << "elsf_darker " << f << endl;
00655                 print_tree(dt, o, ii);
00656             o << i << "else" << endl;
00657                 print_tree(st, o, ii);
00658             o << i << "end" << endl;
00659         }
00660     }
00661 }

template<int S>
V_tuple<shared_ptr<tree>, uint64_t>::type load_and_build_tree ( unsigned int  num_features,
const vector< double > &  weights 
) [inline]

This function loads data and builds a tree.

It is templated because datapoint is templated, for reasons of memory efficiency.

Parameters:
num_features Number of features used
weights Weights on each feature.
Returns:
The learned tree, and number of datapoints.

Definition at line 668 of file learn_fast_tree.cc.

00669 {
00670     assert(weights.size() == num_features);
00671 
00672     shared_ptr<vector<datapoint<S> > > l;
00673     uint64_t num_datapoints;
00674     
00675     //Load the data
00676     make_rtuple(l, num_datapoints) = load_features<S>(num_features);
00677     
00678     cerr << "Loaded.\n";
00679     
00680     //Build the tree
00681     shared_ptr<tree> tree;
00682     tree  = build_tree<S>(*l, weights, num_features);
00683 
00684     return make_vtuple(tree, num_datapoints);
00685 }

int main ( int  argc,
char **  argv 
)

The main program.

Parameters:
argc Number of commandline arguments
argv Commandline arguments

Each feature takes up 2 bits. Since GCC doesn't pack any finer than 32 bits for heterogeneous structs, there is no point in having granularity finer than 16 features.

Definition at line 692 of file learn_fast_tree.cc.

References fatal, offsets, and print_tree().

00693 {
00694     //Set up default arguments
00695     GUI.parseArguments(argc, argv);
00696 
00697     cin.sync_with_stdio(false);
00698     cout.sync_with_stdio(false);
00699 
00700     
00701     ///////////////////
00702     //read file
00703     
00704     //Read number of features
00705     unsigned int num_features;
00706     cin >> num_features;
00707     if(!cin.good() || cin.eof())
00708         fatal(6, "Error reading number of features.");
00709 
00710     //Read offset list
00711     vector<ImageRef> offsets(num_features);
00712     for(unsigned int i=0; i < num_features; i++)
00713         cin >> offsets[i];
00714     if(!cin.good() || cin.eof())
00715         fatal(7, "Error reading offset list.");
00716 
00717     //Read weights for the various offsets
00718     vector<double> weights(offsets.size());
00719     for(unsigned int i=0; i < weights.size(); i++)
00720         weights[i] = GV3::get<double>(sPrintf("weights.%i", i), 1, 1);
00721 
00722 
00723     shared_ptr<tree> tree;
00724     uint64_t num_datapoints;
00725 
00726     ///Each feature takes up 2 bits. Since GCC doesn't pack any finer
00727     ///than 32 bits for heterogeneous structs, there is no point in having
00728     ///granularity finer than 16 features.
00729     if(num_features <= 16)
00730         make_rtuple(tree, num_datapoints) = load_and_build_tree<16>(num_features, weights);
00731     else if(num_features <= 32)
00732         make_rtuple(tree, num_datapoints) = load_and_build_tree<32>(num_features, weights);
00733     else if(num_features <= 48)
00734         make_rtuple(tree, num_datapoints) = load_and_build_tree<48>(num_features, weights);
00735     else if(num_features <= 64)
00736         make_rtuple(tree, num_datapoints) = load_and_build_tree<64>(num_features, weights);
00737     else
00738         fatal(8, "Too many features (%i). To learn from this, see %s, line %i.", num_features, __FILE__, __LINE__);
00739 
00740     
00741     cout << num_features << endl;
00742     copy(offsets.begin(), offsets.end(), ostream_iterator<ImageRef>(cout, " "));
00743     cout << endl;
00744     print_tree(tree.get(), cout);
00745 }


Generated on Mon Mar 2 12:47:12 2009 for FAST-ER by  doxygen 1.5.3