[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_decisionTree.hxx VIGRA

00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 
00036 #ifndef VIGRA_RANDOM_FOREST_DT_HXX
00037 #define VIGRA_RANDOM_FOREST_DT_HXX
00038 
00039 #include <algorithm>
00040 #include <map>
00041 #include <numeric>
00042 #include "vigra/multi_array.hxx"
00043 #include "vigra/mathutil.hxx"
00044 #include "vigra/array_vector.hxx"
00045 #include "vigra/sized_int.hxx"
00046 #include "vigra/matrix.hxx"
00047 #include "vigra/random.hxx"
00048 #include "vigra/functorexpression.hxx"
00049 #include <vector>
00050 
00051 #include "rf_common.hxx"
00052 #include "rf_visitors.hxx"
00053 #include "rf_nodeproxy.hxx"
00054 namespace vigra
00055 {
00056 
00057 namespace detail
00058 {
00059  // todo FINALLY DECIDE TO USE CAMEL CASE OR UNDERSCORES !!!!!!
00060 /* decisiontree classifier. 
00061  *
00062  * This class is actually meant to be used in conjunction with the 
00063  * Random Forest Classifier 
00064  * - My suggestion would be to use the RandomForest classifier with 
00065  *   following parameters instead of directly using this 
00066  *   class (Preprocessing default values etc is handled in there):
00067  *
00068  * \code
00069  *      RandomForest decisionTree(RF_Traits::Options_t()
00070  *                                  .features_per_node(RF_ALL)
00071  *                                  .tree_count(1)            );
00072  * \endcode
00073  * 
00074  * \todo remove the classCount and featurecount from the topology
00075  *       array. Pass ext_param_ to the nodes!
00076  * \todo Use relative addressing of nodes?
00077  */
00078 class DecisionTree
00079 {
00080     /* \todo make private?*/
00081   public:
00082     
00083     /* value type of container array. use whenever referencing it
00084      */
00085     typedef Int32 TreeInt;
00086 
00087     ArrayVector<TreeInt>  topology_;
00088     ArrayVector<double>   parameters_;
00089 
00090     ProblemSpec<> ext_param_;
00091     unsigned int classCount_;
00092 
00093 
00094   public:
00095     /* \brief Create tree with parameters */
00096     template<class T>
00097     DecisionTree(ProblemSpec<T> ext_param)
00098     :
00099         ext_param_(ext_param),
00100         classCount_(ext_param.class_count_)
00101     {}
00102 
00103     /* clears all memory used.
00104      */
00105     void reset(unsigned int classCount = 0)
00106     {
00107         if(classCount)
00108             classCount_ = classCount;
00109         topology_.clear();
00110         parameters_.clear();
00111     }
00112 
00113 
00114     /* learn a Tree
00115      *
00116      * \tparam  StackEntry_t The Stackentry containing Node/StackEntry_t 
00117      *          Information used during learning. Each Split functor has a 
00118      *          Stack entry associated with it (Split_t::StackEntry_t)
00119      * \sa RandomForest::learn()
00120      */
00121     template <  class U, class C,
00122                 class U2, class C2,
00123                 class StackEntry_t,
00124                 class Stop_t,
00125                 class Split_t,
00126                 class Visitor_t,
00127                 class Random_t >
00128     void learn(     MultiArrayView<2, U, C> const      & features,
00129                     MultiArrayView<2, U2, C2> const    & labels,
00130                     StackEntry_t const &                 stack_entry,
00131                     Split_t                              split,
00132                     Stop_t                               stop,
00133                     Visitor_t &                          visitor,
00134                     Random_t &                           randint);
00135     template <  class U, class C,
00136              class U2, class C2,
00137              class StackEntry_t,
00138              class Stop_t,
00139              class Split_t,
00140              class Visitor_t,
00141              class Random_t>
00142     void continueLearn(   MultiArrayView<2, U, C> const       & features,
00143                           MultiArrayView<2, U2, C2> const     & labels,
00144                           StackEntry_t const &                  stack_entry,
00145                           Split_t                               split,
00146                           Stop_t                                stop,
00147                           Visitor_t &                           visitor,
00148                           Random_t &                            randint,
00149                           //an index to which the last created exterior node will be moved (because it is not used anymore)
00150                           int                                   garbaged_child=-1);
00151 
00152     /* is a node a Leaf Node? */
00153     inline bool isLeafNode(TreeInt in) const
00154     {
00155         return (in & LeafNodeTag) == LeafNodeTag;
00156     }
00157 
00158     /* data driven traversal from root to leaf
00159      *
00160      * traverse through tree with data given in features. Use Visitors to 
00161      * collect statistics along the way. 
00162      */
00163     template<class U, class C, class Visitor_t>
00164     TreeInt getToLeaf(MultiArrayView<2, U, C> const & features, 
00165                       Visitor_t  & visitor) const
00166     {
00167         TreeInt index = 2;
00168         while(!isLeafNode(topology_[index]))
00169         {
00170             visitor.visit_internal_node(*this, index, topology_[index],features);
00171             switch(topology_[index])
00172             {
00173                 case i_ThresholdNode:
00174                 {
00175                     Node<i_ThresholdNode> 
00176                                 node(topology_, parameters_, index);
00177                     index = node.next(features);
00178                     break;
00179                 }
00180                 case i_HyperplaneNode:
00181                 {
00182                     Node<i_HyperplaneNode> 
00183                                 node(topology_, parameters_, index);
00184                     index = node.next(features);
00185                     break;
00186                 }
00187                 case i_HypersphereNode:
00188                 {
00189                     Node<i_HypersphereNode> 
00190                                 node(topology_, parameters_, index);
00191                     index = node.next(features);
00192                     break;
00193                 }
00194 #if 0 
00195                 // for quick prototyping! has to be implemented.
00196                 case i_VirtualNode:
00197                 {
00198                     Node<i_VirtualNode> 
00199                                 node(topology_, parameters, index);
00200                     index = node.next(features);
00201                 }
00202 #endif
00203                 default:
00204                     vigra_fail("DecisionTree::getToLeaf():"
00205                                "encountered unknown internal Node Type");
00206             }
00207         }
00208         visitor.visit_external_node(*this, index, topology_[index],features);
00209         return index;
00210     }
00211     /* traverse tree to get statistics
00212      *
00213      * Tree is traversed in order the Nodes are in memory (i.e. if no 
00214      * relearning//pruning scheme is utilized this will be pre order)
00215      */
00216     template<class Visitor_t>
00217     void traverse_mem_order(Visitor_t visitor) const
00218     {
00219         TreeInt index = 2;
00220         Int32 ii = 0;
00221         while(index < topology_.size())
00222         {
00223             if(isLeafNode(topology_[index]))
00224             {
00225                 visitor
00226                     .visit_external_node(*this, index, topology_[index]);
00227             }
00228             else
00229             {
00230                 visitor
00231                     ._internal_node(*this, index, topology_[index]);
00232             }
00233         }
00234     }
00235 
00236     template<class Visitor_t>
00237     void traverse_post_order(Visitor_t visitor,  TreeInt start = 2) const
00238     {
00239         typedef TinyVector<double, 2> Entry; 
00240         std::vector<Entry > stack;
00241         std::vector<double> result_stack;
00242         stack.push_back(Entry(2, 0));
00243         int addr; 
00244         while(!stack.empty())
00245         {
00246             addr = stack.back()[0];
00247             NodeBase node(topology_, parameters_, stack.back()[0]);
00248             if(stack.back()[1] == 1)
00249             {
00250                 stack.pop_back();
00251                 double leftRes = result_stack.back();
00252                 double rightRes = result_stack.back();
00253                 result_stack.pop_back();
00254                 result_stack.pop_back();
00255                 result_stack.push_back(rightRes+ leftRes);
00256                 visitor.visit_internal_node(*this, 
00257                                             addr, 
00258                                             node.typeID(), 
00259                                             rightRes+leftRes);
00260             }
00261             else
00262             {
00263                 if(isLeafNode(node.typeID()))
00264                 {
00265                     visitor.visit_external_node(*this, 
00266                                                 addr, 
00267                                                 node.typeID(), 
00268                                                 node.weights());
00269                     stack.pop_back();
00270                     result_stack.push_back(node.weights());
00271                 }
00272                 else
00273                 {
00274                     stack.back()[1] = 1; 
00275                     stack.push_back(Entry(node.child(0), 0));
00276                     stack.push_back(Entry(node.child(1), 0));
00277                 }
00278                     
00279             }
00280         }
00281     }
00282 
00283     /* same thing as above, without any visitors */
00284     template<class U, class C>
00285     TreeInt getToLeaf(MultiArrayView<2, U, C> const & features) const
00286     {
00287         ::vigra::rf::visitors::StopVisiting stop;
00288         return getToLeaf(features, stop);
00289     }
00290 
00291 
00292     template <class U, class C>
00293     ArrayVector<double>::iterator
00294     predict(MultiArrayView<2, U, C> const & features) const
00295     {
00296         TreeInt nodeindex = getToLeaf(features);
00297         switch(topology_[nodeindex])
00298         {
00299             case e_ConstProbNode:
00300                 return Node<e_ConstProbNode>(topology_, 
00301                                              parameters_,
00302                                              nodeindex).prob_begin();
00303                 break;
00304 #if 0 
00305             //first make the Logistic regression stuff...
00306             case e_LogRegProbNode:
00307                 return Node<e_LogRegProbNode>(topology_, 
00308                                               parameters_,
00309                                               nodeindex).prob_begin();
00310 #endif            
00311             default:
00312                 vigra_fail("DecisionTree::predict() :"
00313                            " encountered unknown external Node Type");
00314         }
00315         return ArrayVector<double>::iterator();
00316     }
00317 
00318 
00319 
00320     template <class U, class C>
00321     Int32 predictLabel(MultiArrayView<2, U, C> const & features) const
00322     {
00323         ArrayVector<double>::const_iterator weights = predict(features);
00324         return argMax(weights, weights+classCount_) - weights;
00325     }
00326 
00327 };
00328 
00329 
00330 template <  class U, class C,
00331             class U2, class C2,
00332             class StackEntry_t,
00333             class Stop_t,
00334             class Split_t,
00335             class Visitor_t,
00336             class Random_t>
00337 void DecisionTree::learn(   MultiArrayView<2, U, C> const       & features,
00338                             MultiArrayView<2, U2, C2> const     & labels,
00339                             StackEntry_t const &                  stack_entry,
00340                             Split_t                               split,
00341                             Stop_t                                stop,
00342                             Visitor_t &                           visitor,
00343                             Random_t &                            randint)
00344 {
00345     this->reset();
00346     topology_.reserve(256);
00347     parameters_.reserve(256);
00348     topology_.push_back(features.shape(1));
00349     topology_.push_back(classCount_);
00350     continueLearn(features,labels,stack_entry,split,stop,visitor,randint);
00351 }
00352 
/* grow (or continue growing) the tree by repeatedly splitting stack
 * entries until the work stack is empty.
 *
 * Nodes are appended to topology_/parameters_ in the order they are
 * created; parent records are back-patched with their children's
 * addresses as soon as those addresses are known.
 *
 * \param garbaged_child if != -1, the last created exterior node is moved
 *        into this (no longer used) slot and the arrays are shrunk
 *        accordingly (used by relearning, see caller).
 */
template <  class U, class C,
            class U2, class C2,
            class StackEntry_t,
            class Stop_t,
            class Split_t,
            class Visitor_t,
            class Random_t>
void DecisionTree::continueLearn(   MultiArrayView<2, U, C> const       & features,
                            MultiArrayView<2, U2, C2> const     & labels,
                            StackEntry_t const &                  stack_entry,
                            Split_t                               split,
                            Stop_t                                stop,
                            Visitor_t &                           visitor,
                            Random_t &                            randint,
                            //an index to which the last created exterior node will be moved (because it is not used anymore)
                            int                                   garbaged_child)
{
    // work stack of regions still to be split
    std::vector<StackEntry_t> stack;
    stack.reserve(128);
    // scratch entries filled by the split functor for left/right child
    ArrayVector<StackEntry_t> child_stack_entry(2, stack_entry);
    stack.push_back(stack_entry);
    // address of the most recently appended node (needed after the loop
    // for the garbaged_child relocation)
    size_t last_node_pos = 0;
    StackEntry_t top=stack.back();

    while(!stack.empty())
    {

        // Take an element of the stack. Obvious ain't it?
        top = stack.back();
        stack.pop_back();

        // Make sure no data from the last round has remained in Pipeline;
        child_stack_entry[0].reset();
        child_stack_entry[1].reset();
        split.reset();


        //Either the Stopping criterion decides that the split should 
        //produce a Terminal Node or the Split itself decides what 
        //kind of node to make
        TreeInt NodeID;
        
        if(stop(top))
            NodeID = split.makeTerminalNode(features, 
                                            labels, 
                                            top, 
                                            randint);
        else
        {
            //TIC;
            NodeID = split.findBestSplit(features, 
                                         labels, 
                                         top, 
                                         child_stack_entry, 
                                         randint);
            //std::cerr << TOC <<" " << NodeID << ";" <<std::endl;
        }

        // do some visiting yawn - just added this comment as eye candy
        // (looks odd otherwise with my syntax highlighting....
        visitor.visit_after_split(*this, split, top, 
                                  child_stack_entry[0], 
                                  child_stack_entry[1],
                                  features, 
                                  labels);


        // Update the Child entries of the parent
        // Using InteriorNodeBase because exact parameter form not needed.
        // look at the Node base before getting scared.
        // The new node will be appended at topology_.size(); store that
        // address into the parent's left or right child slot.
        last_node_pos = topology_.size();
        if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
        {
            NodeBase(topology_, 
                     parameters_, 
                     top.leftParent).child(0) = last_node_pos;
        }
        else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
        {
            NodeBase(topology_, 
                     parameters_, 
                     top.rightParent).child(1) = last_node_pos;
        }


        // Supply the split functor with the Node type it requires.
        // set the address to which the children of this node should point 
        // to and push back children onto stack
        if(!isLeafNode(NodeID))
        {
            // each child remembers the address of the parent record it
            // must later be patched into (left xor right, -1 = none)
            child_stack_entry[0].leftParent = topology_.size();
            child_stack_entry[1].rightParent = topology_.size();    
            child_stack_entry[0].rightParent = -1;
            child_stack_entry[1].leftParent = -1;
            stack.push_back(child_stack_entry[0]);
            stack.push_back(child_stack_entry[1]);
        }

        //copy the newly created node form the split functor to the
        //decision tree.
        NodeBase(split.createNode(), topology_, parameters_ );
    }
    if(garbaged_child!=-1)
    {
        // Move the last created exterior node into the unused slot at
        // garbaged_child, then shrink both arrays by that node's size.
        Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).copy(Node<e_ConstProbNode>(topology_,parameters_,last_node_pos));

        int last_parameter_size = Node<e_ConstProbNode>(topology_,parameters_,garbaged_child).parameters_size();
        topology_.resize(last_node_pos);
        parameters_.resize(parameters_.size() - last_parameter_size);
    
        // Re-patch the parent of the moved node to its new address.
        if(top.leftParent != StackEntry_t::DecisionTreeNoParent)
            NodeBase(topology_, 
                     parameters_, 
                     top.leftParent).child(0) = garbaged_child;
        else if(top.rightParent != StackEntry_t::DecisionTreeNoParent)
            NodeBase(topology_, 
                     parameters_, 
                     top.rightParent).child(1) = garbaged_child;
    }
}
00473 
00474 } //namespace detail
00475 
00476 } //namespace vigra
00477 
00478 #endif //VIGRA_RANDOM_FOREST_DT_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

HTML generated using Doxygen and Python
vigra 1.8.0 (20 Sep 2011)