
vigra/random_forest/rf_visitors.hxx

00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 #ifndef RF_VISITORS_HXX
00036 #define RF_VISITORS_HXX
00037 
00038 #ifdef HasHDF5
00039 # include "vigra/hdf5impex.hxx"
00040 #endif // HasHDF5
00041 #include <vigra/windows.h>
00042 #include <iostream>
00043 #include <iomanip>
00044 
00045 #include <vigra/multi_pointoperators.hxx>
00046 #include <vigra/timing.hxx>
00047 
00048 namespace vigra
00049 {
00050 namespace rf
00051 {
00052 /** \addtogroup MachineLearning Machine Learning
00053 **/
00054 //@{
00055 
00056 /**
00057     This namespace contains all classes and methods related to extracting information during 
00058     learning of the random forest. All Visitors share the same interface defined in 
00059     visitors::VisitorBase. The member methods are invoked at certain points of the main code in 
00060     the order they were supplied.
00061     
00062     For the Random Forest the Visitor concept is implemented as a statically linked list 
00063     (using templates). Each Visitor object is encapsulated in a detail::VisitorNode object. The 
00064     VisitorNode object calls the next Visitor after one of its visit() methods has terminated.
00065     
00066     To simplify usage create_visitor() factory methods are supplied.
00067     Use the create_visitor() method to supply visitor objects to the RandomForest::learn() method.
00068     It is possible to supply more than one visitor. They will then be invoked in serial order.
00069 
00070     The calculated information is stored in public data members of the respective visitor class - see the
00071     documentation of the individual visitors.
00072     
00073     A new visitor class should therefore publicly inherit from visitors::VisitorBase 
00074     (see e.g. visitors::OOB_Error).
00075 
00076     \code
00077 
00078       typedef xxx feature_t; // replace xxx with the actual feature type
00079       typedef yyy label_t;   // likewise for the label type
00080       MultiArrayView<2, feature_t> f = get_some_features();
00081       MultiArrayView<2, label_t>   l = get_some_labels();
00082       RandomForest<> rf;
00083     
00084       //calculate OOB Error
00085       visitors::OOB_Error oob_v;
00086       //calculate Variable Importance
00087       visitors::VariableImportanceVisitor varimp_v;
00088 
00089       double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
00090       //the data can be found in the attributes of oob_v and varimp_v now
00091       
00092     \endcode
00093 */
00094 namespace visitors
00095 {
00096     
00097     
00098 /** Base class from which all visitors derive. It can be used as a template for creating new 
00099  * visitors (a minimal sketch of a derived visitor is given after the class definition below).
00100  */
00101 class VisitorBase
00102 {
00103     public:
00104     bool active_;   
00105     bool is_active()
00106     {
00107         return active_;
00108     }
00109 
00110     bool has_value()
00111     {
00112         return false;
00113     }
00114 
00115     VisitorBase()
00116         : active_(true)
00117     {}
00118 
00119     void deactivate()
00120     {
00121         active_ = false;
00122     }
00123     void activate()
00124     {
00125         active_ = true;
00126     }
00127     
00128     /** do something after the Split has decided how to process the Region
00129      * (stack entry)
00130      *
00131      * \param tree      reference to the tree that is currently being learned
00132      * \param split     reference to the split object
00133      * \param parent    current stack entry  which was used to decide the split
00134      * \param leftChild left stack entry that will be pushed
00135      * \param rightChild
00136      *                  right stack entry that will be pushed.
00137      * \param features  features matrix
00138      * \param labels    label matrix
00139      * \sa RF_Traits::StackEntry_t
00140      */
00141     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
00142     void visit_after_split( Tree          & tree, 
00143                             Split         & split,
00144                             Region        & parent,
00145                             Region        & leftChild,
00146                             Region        & rightChild,
00147                             Feature_t     & features,
00148                             Label_t       & labels)
00149     {}
00150     
00151     /** do something after each tree has been learned
00152      *
00153      * \param rf        reference to the random forest object that called this
00154      *                  visitor
00155      * \param pr        reference to the preprocessor that processed the input
00156      * \param sm        reference to the sampler object
00157      * \param st        reference to the first stack entry
00158      * \param index     index of current tree
00159      */
00160     template<class RF, class PR, class SM, class ST>
00161     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00162     {}
00163     
00164     /** do something after all trees have been learned
00165      *
00166      * \param rf        reference to the random forest object that called this
00167      *                  visitor
00168      * \param pr        reference to the preprocessor that processed the input
00169      */
00170     template<class RF, class PR>
00171     void visit_at_end(RF const & rf, PR const & pr)
00172     {}
00173     
00174     /** do something before learning starts 
00175      *
00176      * \param rf        reference to the random forest object that called this
00177      *                  visitor
00178      * \param pr        reference to the Processor class used.
00179      */
00180     template<class RF, class PR>
00181     void visit_at_beginning(RF const & rf, PR const & pr)
00182     {}
00183     /** do something while traversing the tree after it has been learned 
00184      *  (external nodes)
00185      *
00186      * \param tr        reference to the tree object that called this visitor
00187      * \param index     index in the topology_ array we currently are at
00188      * \param node_t    type of the node we are at (one of the e_... node tags)
00189      * \param features  feature matrix
00190      * \sa  NodeTags;
00191      *
00192      * You can create the node by switching on node_t and using the 
00193      * corresponding Node objects. Or - if you do not care about the type - 
00194      * use the NodeBase class.
00195      */
00196     template<class TR, class IntT, class TopT,class Feat>
00197     void visit_external_node(TR & tr, IntT index, TopT node_t,Feat & features)
00198     {}
00199     
00200     /** do something when visiting an internal node after it has been learned
00201      *
00202      * \sa visit_external_node
00203      */
00204     template<class TR, class IntT, class TopT,class Feat>
00205     void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
00206     {}
00207 
00208     /** return a double value.  The value of the first 
00209      * visitor encountered that has a return value is returned by the
00210      * RandomForest::learn() method - or -1.0 if no visitor with a return value
00211      * existed. This functionality exists mainly so that the 
00212      * OOB visitor can return the OOB error rate as in the old version 
00213      * of the random forest.
00214      */
00215     double return_val()
00216     {
00217         return -1.0;
00218     }
00219 };
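
/* A minimal sketch of a user-defined visitor derived from VisitorBase. The class name
 * and the counting logic below are illustrative only (not part of VIGRA); they merely
 * show which hooks a custom visitor would typically override and that results are
 * exposed as public data members, as for the library visitors.
 *
 * \code
 * class SplitCountVisitor : public visitors::VisitorBase
 * {
 *   public:
 *     int split_count;                       // public result, queried after learning
 *
 *     SplitCountVisitor() : split_count(0) {}
 *
 *     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
 *     void visit_after_split(Tree &, Split &, Region &, Region &, Region &,
 *                            Feature_t &, Label_t &)
 *     {
 *         ++split_count;                     // invoked once per split during learning
 *     }
 * };
 *
 * // usage (features, labels and rf as in the namespace example above):
 * SplitCountVisitor counter;
 * rf.learn(features, labels, visitors::create_visitor(counter));
 * \endcode
 */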
00220 
00221 
00222 /** Last Visitor that should be called to stop the recursion.
00223  */
00224 class StopVisiting: public VisitorBase
00225 {
00226     public:
00227     bool has_value()
00228     {
00229         return true;
00230     }
00231     double return_val()
00232     {
00233         return -1.0;
00234     }
00235 };
00236 namespace detail
00237 {
00238 /** Container element of the statically linked visitor list.
00239  *
00240  * Use the create_visitor() factory functions to create visitor chains of up to 10 visitors.
00241  *
00242  */
00243 template <class Visitor, class Next = StopVisiting>
00244 class VisitorNode
00245 {
00246     public:
00247     
00248     StopVisiting    stop_;
00249     Next            next_;
00250     Visitor &       visitor_;   
00251     VisitorNode(Visitor & visitor, Next & next) 
00252     : 
00253         next_(next), visitor_(visitor)
00254     {}
00255 
00256     VisitorNode(Visitor &  visitor) 
00257     : 
00258         next_(stop_), visitor_(visitor)
00259     {}
00260 
00261     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
00262     void visit_after_split( Tree          & tree, 
00263                             Split         & split,
00264                             Region        & parent,
00265                             Region        & leftChild,
00266                             Region        & rightChild,
00267                             Feature_t     & features,
00268                             Label_t       & labels)
00269     {
00270         if(visitor_.is_active())
00271             visitor_.visit_after_split(tree, split, 
00272                                        parent, leftChild, rightChild,
00273                                        features, labels);
00274         next_.visit_after_split(tree, split, parent, leftChild, rightChild,
00275                                 features, labels);
00276     }
00277 
00278     template<class RF, class PR, class SM, class ST>
00279     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00280     {
00281         if(visitor_.is_active())
00282             visitor_.visit_after_tree(rf, pr, sm, st, index);
00283         next_.visit_after_tree(rf, pr, sm, st, index);
00284     }
00285 
00286     template<class RF, class PR>
00287     void visit_at_beginning(RF & rf, PR & pr)
00288     {
00289         if(visitor_.is_active())
00290             visitor_.visit_at_beginning(rf, pr);
00291         next_.visit_at_beginning(rf, pr);
00292     }
00293     template<class RF, class PR>
00294     void visit_at_end(RF & rf, PR & pr)
00295     {
00296         if(visitor_.is_active())
00297             visitor_.visit_at_end(rf, pr);
00298         next_.visit_at_end(rf, pr);
00299     }
00300     
00301     template<class TR, class IntT, class TopT,class Feat>
00302     void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
00303     {
00304         if(visitor_.is_active())
00305             visitor_.visit_external_node(tr, index, node_t,features);
00306         next_.visit_external_node(tr, index, node_t,features);
00307     }
00308     template<class TR, class IntT, class TopT,class Feat>
00309     void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
00310     {
00311         if(visitor_.is_active())
00312             visitor_.visit_internal_node(tr, index, node_t,features);
00313         next_.visit_internal_node(tr, index, node_t,features);
00314     }
00315 
00316     double return_val()
00317     {
00318         if(visitor_.is_active() && visitor_.has_value())
00319             return visitor_.return_val();
00320         return next_.return_val();
00321     }
00322 };
00323 
00324 } //namespace detail
00325 
00326 //////////////////////////////////////////////////////////////////////////////
00327 //  Visitor Factory function up to 10 visitors                              //
00328 //////////////////////////////////////////////////////////////////////////////
00329 
00330 /** factory method to be used with RandomForest::learn()
00331  */
00332 template<class A>
00333 detail::VisitorNode<A>
00334 create_visitor(A & a)
00335 {
00336    typedef detail::VisitorNode<A> _0_t;
00337    _0_t _0(a);
00338    return _0;
00339 }
00340 
00341 
00342 /** factory method to be used with RandomForest::learn()
00343  */
00344 template<class A, class B>
00345 detail::VisitorNode<A, detail::VisitorNode<B> >
00346 create_visitor(A & a, B & b)
00347 {
00348    typedef detail::VisitorNode<B> _1_t;
00349    _1_t _1(b);
00350    typedef detail::VisitorNode<A, _1_t> _0_t;
00351    _0_t _0(a, _1);
00352    return _0;
00353 }
00354 
00355 
00356 /** factory method to be used with RandomForest::learn()
00357  */
00358 template<class A, class B, class C>
00359 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
00360 create_visitor(A & a, B & b, C & c)
00361 {
00362    typedef detail::VisitorNode<C> _2_t;
00363    _2_t _2(c);
00364    typedef detail::VisitorNode<B, _2_t> _1_t;
00365    _1_t _1(b, _2);
00366    typedef detail::VisitorNode<A, _1_t> _0_t;
00367    _0_t _0(a, _1);
00368    return _0;
00369 }
00370 
00371 
00372 /** factory method to be used with RandomForest::learn()
00373  */
00374 template<class A, class B, class C, class D>
00375 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00376     detail::VisitorNode<D> > > >
00377 create_visitor(A & a, B & b, C & c, D & d)
00378 {
00379    typedef detail::VisitorNode<D> _3_t;
00380    _3_t _3(d);
00381    typedef detail::VisitorNode<C, _3_t> _2_t;
00382    _2_t _2(c, _3);
00383    typedef detail::VisitorNode<B, _2_t> _1_t;
00384    _1_t _1(b, _2);
00385    typedef detail::VisitorNode<A, _1_t> _0_t;
00386    _0_t _0(a, _1);
00387    return _0;
00388 }
00389 
00390 
00391 /** factory method to be used with RandomForest::learn()
00392  */
00393 template<class A, class B, class C, class D, class E>
00394 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00395     detail::VisitorNode<D, detail::VisitorNode<E> > > > >
00396 create_visitor(A & a, B & b, C & c, 
00397                D & d, E & e)
00398 {
00399    typedef detail::VisitorNode<E> _4_t;
00400    _4_t _4(e);
00401    typedef detail::VisitorNode<D, _4_t> _3_t;
00402    _3_t _3(d, _4);
00403    typedef detail::VisitorNode<C, _3_t> _2_t;
00404    _2_t _2(c, _3);
00405    typedef detail::VisitorNode<B, _2_t> _1_t;
00406    _1_t _1(b, _2);
00407    typedef detail::VisitorNode<A, _1_t> _0_t;
00408    _0_t _0(a, _1);
00409    return _0;
00410 }
00411 
00412 
00413 /** factory method to be used with RandomForest::learn()
00414  */
00415 template<class A, class B, class C, class D, class E,
00416          class F>
00417 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00418     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
00419 create_visitor(A & a, B & b, C & c, 
00420                D & d, E & e, F & f)
00421 {
00422    typedef detail::VisitorNode<F> _5_t;
00423    _5_t _5(f);
00424    typedef detail::VisitorNode<E, _5_t> _4_t;
00425    _4_t _4(e, _5);
00426    typedef detail::VisitorNode<D, _4_t> _3_t;
00427    _3_t _3(d, _4);
00428    typedef detail::VisitorNode<C, _3_t> _2_t;
00429    _2_t _2(c, _3);
00430    typedef detail::VisitorNode<B, _2_t> _1_t;
00431    _1_t _1(b, _2);
00432    typedef detail::VisitorNode<A, _1_t> _0_t;
00433    _0_t _0(a, _1);
00434    return _0;
00435 }
00436 
00437 
00438 /** factory method to be used with RandomForest::learn()
00439  */
00440 template<class A, class B, class C, class D, class E,
00441          class F, class G>
00442 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00443     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 
00444     detail::VisitorNode<G> > > > > > >
00445 create_visitor(A & a, B & b, C & c, 
00446                D & d, E & e, F & f, G & g)
00447 {
00448    typedef detail::VisitorNode<G> _6_t;
00449    _6_t _6(g);
00450    typedef detail::VisitorNode<F, _6_t> _5_t;
00451    _5_t _5(f, _6);
00452    typedef detail::VisitorNode<E, _5_t> _4_t;
00453    _4_t _4(e, _5);
00454    typedef detail::VisitorNode<D, _4_t> _3_t;
00455    _3_t _3(d, _4);
00456    typedef detail::VisitorNode<C, _3_t> _2_t;
00457    _2_t _2(c, _3);
00458    typedef detail::VisitorNode<B, _2_t> _1_t;
00459    _1_t _1(b, _2);
00460    typedef detail::VisitorNode<A, _1_t> _0_t;
00461    _0_t _0(a, _1);
00462    return _0;
00463 }
00464 
00465 
00466 /** factory method to be used with RandomForest::learn()
00467  */
00468 template<class A, class B, class C, class D, class E,
00469          class F, class G, class H>
00470 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00471     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 
00472     detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
00473 create_visitor(A & a, B & b, C & c, 
00474                D & d, E & e, F & f, 
00475                G & g, H & h)
00476 {
00477    typedef detail::VisitorNode<H> _7_t;
00478    _7_t _7(h);
00479    typedef detail::VisitorNode<G, _7_t> _6_t;
00480    _6_t _6(g, _7);
00481    typedef detail::VisitorNode<F, _6_t> _5_t;
00482    _5_t _5(f, _6);
00483    typedef detail::VisitorNode<E, _5_t> _4_t;
00484    _4_t _4(e, _5);
00485    typedef detail::VisitorNode<D, _4_t> _3_t;
00486    _3_t _3(d, _4);
00487    typedef detail::VisitorNode<C, _3_t> _2_t;
00488    _2_t _2(c, _3);
00489    typedef detail::VisitorNode<B, _2_t> _1_t;
00490    _1_t _1(b, _2);
00491    typedef detail::VisitorNode<A, _1_t> _0_t;
00492    _0_t _0(a, _1);
00493    return _0;
00494 }
00495 
00496 
00497 /** factory method to be used with RandomForest::learn()
00498  */
00499 template<class A, class B, class C, class D, class E,
00500          class F, class G, class H, class I>
00501 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00502     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 
00503     detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
00504 create_visitor(A & a, B & b, C & c, 
00505                D & d, E & e, F & f, 
00506                G & g, H & h, I & i)
00507 {
00508    typedef detail::VisitorNode<I> _8_t;
00509    _8_t _8(i);
00510    typedef detail::VisitorNode<H, _8_t> _7_t;
00511    _7_t _7(h, _8);
00512    typedef detail::VisitorNode<G, _7_t> _6_t;
00513    _6_t _6(g, _7);
00514    typedef detail::VisitorNode<F, _6_t> _5_t;
00515    _5_t _5(f, _6);
00516    typedef detail::VisitorNode<E, _5_t> _4_t;
00517    _4_t _4(e, _5);
00518    typedef detail::VisitorNode<D, _4_t> _3_t;
00519    _3_t _3(d, _4);
00520    typedef detail::VisitorNode<C, _3_t> _2_t;
00521    _2_t _2(c, _3);
00522    typedef detail::VisitorNode<B, _2_t> _1_t;
00523    _1_t _1(b, _2);
00524    typedef detail::VisitorNode<A, _1_t> _0_t;
00525    _0_t _0(a, _1);
00526    return _0;
00527 }
00528 
00529 /** factory method to be used with RandomForest::learn()
00530  */
00531 template<class A, class B, class C, class D, class E,
00532          class F, class G, class H, class I, class J>
00533 detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C, 
00534     detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F, 
00535     detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
00536     detail::VisitorNode<J> > > > > > > > > >
00537 create_visitor(A & a, B & b, C & c, 
00538                D & d, E & e, F & f, 
00539                G & g, H & h, I & i,
00540                J & j)
00541 {
00542    typedef detail::VisitorNode<J> _9_t;
00543    _9_t _9(j);
00544    typedef detail::VisitorNode<I, _9_t> _8_t;
00545    _8_t _8(i, _9);
00546    typedef detail::VisitorNode<H, _8_t> _7_t;
00547    _7_t _7(h, _8);
00548    typedef detail::VisitorNode<G, _7_t> _6_t;
00549    _6_t _6(g, _7);
00550    typedef detail::VisitorNode<F, _6_t> _5_t;
00551    _5_t _5(f, _6);
00552    typedef detail::VisitorNode<E, _5_t> _4_t;
00553    _4_t _4(e, _5);
00554    typedef detail::VisitorNode<D, _4_t> _3_t;
00555    _3_t _3(d, _4);
00556    typedef detail::VisitorNode<C, _3_t> _2_t;
00557    _2_t _2(c, _3);
00558    typedef detail::VisitorNode<B, _2_t> _1_t;
00559    _1_t _1(b, _2);
00560    typedef detail::VisitorNode<A, _1_t> _0_t;
00561    _0_t _0(a, _1);
00562    return _0;
00563 }
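
/* Note on chaining (an illustrative sketch, not additional library code): the factory
 * functions above differ only in arity. Each one wraps its arguments into nested
 * detail::VisitorNode objects, terminated by a StopVisiting node, so that every visit_*()
 * call is forwarded through the whole chain in the order the visitors were supplied.
 * For two visitors, the factory call is equivalent to building the chain by hand:
 *
 * \code
 * visitors::OOB_Error                 oob_v;
 * visitors::VariableImportanceVisitor varimp_v;
 *
 * // via the factory:
 * rf.learn(features, labels, visitors::create_visitor(oob_v, varimp_v));
 *
 * // equivalent manual construction of the visitor chain:
 * visitors::detail::VisitorNode<visitors::VariableImportanceVisitor> tail(varimp_v);
 * visitors::detail::VisitorNode<visitors::OOB_Error,
 *     visitors::detail::VisitorNode<visitors::VariableImportanceVisitor> > chain(oob_v, tail);
 * rf.learn(features, labels, chain);
 * \endcode
 */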
00564 
00565 //////////////////////////////////////////////////////////////////////////////
00566 // Visitors of communal interest.                                           //
00567 //////////////////////////////////////////////////////////////////////////////
00568 
00569 
00570 /** Visitor that gathers the information needed later for online learning
00571  *  (a usage sketch follows the class definition below). */
00572 
00573 class OnlineLearnVisitor: public VisitorBase
00574 {
00575 public:
00576     //Set if we adjust thresholds
00577     bool adjust_thresholds;
00578     //Current tree id
00579     int tree_id;
00580     //Last node id for finding parent
00581     int last_node_id;
00582     //Need to know the label for interior node visiting
00583     vigra::Int32 current_label;
00584     //marginal distribution for interior nodes
00585     //
00586     OnlineLearnVisitor():
00587         adjust_thresholds(false), tree_id(0), last_node_id(0), current_label(0)
00588     {}
00589     struct MarginalDistribution
00590     {
00591         ArrayVector<Int32> leftCounts;
00592         Int32 leftTotalCounts;
00593         ArrayVector<Int32> rightCounts;
00594         Int32 rightTotalCounts;
00595         double gap_left;
00596         double gap_right;
00597     };
00598     typedef ArrayVector<vigra::Int32> IndexList;
00599 
00600     //All information for one tree
00601     struct TreeOnlineInformation
00602     {
00603         std::vector<MarginalDistribution> mag_distributions;
00604         std::vector<IndexList> index_lists;
00605         //map for linear index of mag_distributions
00606         std::map<int,int> interior_to_index;
00607         //map for linear index of index_lists
00608         std::map<int,int> exterior_to_index;
00609     };
00610 
00611     //All trees
00612     std::vector<TreeOnlineInformation> trees_online_information;
00613 
00614     /** Initialize, set the number of trees
00615      */
00616     template<class RF,class PR>
00617     void visit_at_beginning(RF & rf,const PR & pr)
00618     {
00619         tree_id=0;
00620         trees_online_information.resize(rf.options_.tree_count_);
00621     }
00622 
00623     /** Reset a tree
00624      */
00625     void reset_tree(int tree_id)
00626     {
00627         trees_online_information[tree_id].mag_distributions.clear();
00628         trees_online_information[tree_id].index_lists.clear();
00629         trees_online_information[tree_id].interior_to_index.clear();
00630         trees_online_information[tree_id].exterior_to_index.clear();
00631     }
00632 
00633     /** simply increase the current tree id
00634     */
00635     template<class RF, class PR, class SM, class ST>
00636     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00637     {
00638         tree_id++;
00639     }
00640     
00641     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
00642     void visit_after_split( Tree          & tree, 
00643                 Split         & split,
00644                             Region       & parent,
00645                             Region        & leftChild,
00646                             Region        & rightChild,
00647                             Feature_t     & features,
00648                             Label_t       & labels)
00649     {
00650         int linear_index;
00651         int addr=tree.topology_.size();
00652         if(split.createNode().typeID() == i_ThresholdNode)
00653         {
00654             if(adjust_thresholds)
00655             {
00656                 //Store marginal distribution
00657                 linear_index=trees_online_information[tree_id].mag_distributions.size();
00658                 trees_online_information[tree_id].interior_to_index[addr]=linear_index;
00659                 trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());
00660 
00661                 trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
00662                 trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;
00663 
00664                 trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
00665                 trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
00666                 //Store the gap
00667                 double gap_left,gap_right;
00668                 int i;
00669                 gap_left=features(leftChild[0],split.bestSplitColumn());
00670                 for(i=1;i<leftChild.size();++i)
00671                     if(features(leftChild[i],split.bestSplitColumn())>gap_left)
00672                         gap_left=features(leftChild[i],split.bestSplitColumn());
00673                 gap_right=features(rightChild[0],split.bestSplitColumn());
00674                 for(i=1;i<rightChild.size();++i)
00675                     if(features(rightChild[i],split.bestSplitColumn())<gap_right)
00676                         gap_right=features(rightChild[i],split.bestSplitColumn());
00677                 trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
00678                 trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
00679             }
00680         }
00681         else
00682         {
00683             //Store index list
00684             linear_index=trees_online_information[tree_id].index_lists.size();
00685             trees_online_information[tree_id].exterior_to_index[addr]=linear_index;
00686 
00687             trees_online_information[tree_id].index_lists.push_back(IndexList());
00688 
00689             trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
00690             std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
00691         }
00692     }
00693     void add_to_index_list(int tree,int node,int index)
00694     {
00695         if(!this->active_)
00696             return;
00697         TreeOnlineInformation &ti=trees_online_information[tree];
00698         ti.index_lists[ti.exterior_to_index[node]].push_back(index);
00699     }
00700     void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index)
00701     {
00702         if(!this->active_)
00703             return;
00704         trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
00705         trees_online_information[src_tree].exterior_to_index.erase(src_index);
00706     }
00707     /** do something when visiting an internal node during getToLeaf
00708      *
00709      * remember the index as last node id, for finding the parent of the last external node;
00710      * also: adjust class counts and borders
00711      */
00712     template<class TR, class IntT, class TopT,class Feat>
00713         void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
00714         {
00715             last_node_id=index;
00716             if(adjust_thresholds)
00717             {
00718                 vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes");
00719                 //Check if we are in the gap
00720                 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column());
00721                 TreeOnlineInformation &ti=trees_online_information[tree_id];
00722                 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]];
00723                 if(value>m.gap_left && value<m.gap_right)
00724                 {
00725                     //Check which side we want to go to
00726                     if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts))
00727                     {
00728                         //We want to go left
00729                         m.gap_left=value;
00730                     }
00731                     else
00732                     {
00733                         //We want to go right
00734                         m.gap_right=value;
00735                     }
00736                     Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0;
00737                 }
00738                 //Adjust class counts
00739                 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold())
00740                 {
00741                     ++m.rightTotalCounts;
00742                     ++m.rightCounts[current_label];
00743                 }
00744                 else
00745                 {
00746                     ++m.leftTotalCounts;
00747                     ++m.leftCounts[current_label];
00748                 }
00749             }
00750         }
00751     /** do something when visiting an external node during getToLeaf
00752      * 
00753      * Store the new index!
00754      */
00755 };
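
/* Usage sketch (illustrative only, not additional library code): the visitor is registered
 * at learning time like any other visitor; with adjust_thresholds enabled it additionally
 * records the per-node marginal distributions needed to shift thresholds later on.
 *
 * \code
 * visitors::OnlineLearnVisitor online_v;
 * online_v.adjust_thresholds = true;                    // optional, defaults to false
 * rf.learn(features, labels, visitors::create_visitor(online_v));
 * // online_v.trees_online_information now holds the per-tree bookkeeping data
 * \endcode
 */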
00756 
00757 //////////////////////////////////////////////////////////////////////////////
00758 // Out of Bag Error estimates                                               //
00759 //////////////////////////////////////////////////////////////////////////////
00760 
00761 
00762 /** Visitor that calculates the OOB error of each individual randomized
00763  * decision tree.
00764  *
00765  * After training a tree, all those samples that are OOB for this particular tree
00766  * are put down the tree and the error is estimated.
00767  * The per-tree OOB error is the average of the individual error estimates
00768  * (oobError = average error of one randomized tree).
00769  * Note: this is not the OOB error estimate suggested by Breiman (see the OOB_Error 
00770  * visitor). A usage sketch follows the class definition below.
00771  */
00772 class OOB_PerTreeError:public VisitorBase
00773 {
00774 public:
00775     /** Average error of one randomized decision tree
00776      */
00777     double oobError;
00778 
00779     int totalOobCount;
00780     ArrayVector<int> oobCount,oobErrorCount;
00781 
00782     OOB_PerTreeError()
00783     : oobError(0.0),
00784       totalOobCount(0)
00785     {}
00786 
00787 
00788     bool has_value()
00789     {
00790         return true;
00791     }
00792 
00793 
00794     /** does the basic calculation per tree*/
00795     template<class RF, class PR, class SM, class ST>
00796     void visit_after_tree(    RF& rf, PR & pr,  SM & sm, ST & st, int index)
00797     {
00798         // only done the first time this is called.
00799         if(int(oobCount.size()) != rf.ext_param_.row_count_)
00800         {
00801             oobCount.resize(rf.ext_param_.row_count_, 0);
00802             oobErrorCount.resize(rf.ext_param_.row_count_, 0);
00803         }
00804         // go through the samples
00805         for(int l = 0; l < rf.ext_param_.row_count_; ++l)
00806         {
00807             // if the lth sample is oob...
00808             if(!sm.is_used()[l])
00809             {
00810                 ++oobCount[l];
00811                 if(     rf.tree(index)
00812                             .predictLabel(rowVector(pr.features(), l)) 
00813                     !=  pr.response()(l,0))
00814                 {
00815                     ++oobErrorCount[l];
00816                 }
00817             }
00818 
00819         }
00820     }
00821 
00822     /** Does the normalisation
00823      */
00824     template<class RF, class PR>
00825     void visit_at_end(RF & rf, PR & pr)
00826     {
00827         // do some normalisation
00828         for(int l=0; l < (int)rf.ext_param_.row_count_; ++l)
00829         {
00830             if(oobCount[l])
00831             {
00832                 oobError += double(oobErrorCount[l]) / oobCount[l];
00833                 ++totalOobCount;
00834             }
00835         } 
00836         oobError/=totalOobCount;
00837     }
00838     
00839 };
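
/* Usage sketch (illustrative only): the averaged per-tree OOB error is available in the
 * public member oobError after learning.
 *
 * \code
 * visitors::OOB_PerTreeError pertree_v;
 * rf.learn(features, labels, visitors::create_visitor(pertree_v));
 * std::cout << "average per-tree OOB error: " << pertree_v.oobError << std::endl;
 * \endcode
 */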
00840 
00841 /** Visitor that calculates the OOB error of the ensemble.
00842  *  This rate should be used to estimate the cross-validation 
00843  *  error rate.
00844  *  Here each sample is put down only those trees for which this sample
00845  *  is OOB, i.e. if sample #1 is OOB for trees 1, 3 and 5, we calculate
00846  *  its output using the ensemble consisting only of trees 1, 3 and 5. 
00847  *
00848  *  With normal bagged sampling each sample is OOB for approx. 33% of the trees;
00849  *  the error rate obtained this way therefore corresponds to the cross-validation
00850  *  error rate of an ensemble containing 33% of the trees. (Usage sketch below.)
00851  */
00852 class OOB_Error : public VisitorBase
00853 {
00854     typedef MultiArrayShape<2>::type Shp;
00855     int class_count;
00856     bool is_weighted;
00857     MultiArray<2,double> tmp_prob;
00858     public:
00859 
00860     MultiArray<2, double>       prob_oob; 
00861     /** Ensemble oob error rate
00862      */
00863     double                      oob_breiman;
00864 
00865     MultiArray<2, double>       oobCount;
00866     ArrayVector< int>           indices; 
00867     OOB_Error() : VisitorBase(), oob_breiman(0.0) {}
00868 #ifdef HasHDF5
00869     void save(std::string filen, std::string pathn)
00870     {
00871         if(*(pathn.end()-1) != '/')
00872             pathn += "/";
00873         const char* filename = filen.c_str();
00874         MultiArray<2, double> temp(Shp(1,1), 0.0); 
00875         temp[0] = oob_breiman;
00876         writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
00877     }
00878 #endif
00879     // negative value if sample was in-bag, the number indicates how often.
00880     // value >= 0 if sample was oob; 0 means fail, 1 correct
00881 
00882     template<class RF, class PR>
00883     void visit_at_beginning(RF & rf, PR & pr)
00884     {
00885         class_count = rf.class_count();
00886         tmp_prob.reshape(Shp(1, class_count), 0); 
00887         prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
00888         is_weighted = rf.options().predict_weighted_;
00889         indices.resize(rf.ext_param().row_count_);
00890         if(int(oobCount.size()) != rf.ext_param_.row_count_)
00891         {
00892             oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
00893         }
00894         for(int ii = 0; ii < rf.ext_param().row_count_; ++ii)
00895         {
00896             indices[ii] = ii;
00897         }
00898     }
00899 
00900     template<class RF, class PR, class SM, class ST>
00901     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
00902     {
00903         // go through the samples
00904         int total_oob =0;
00905         // FIXME: magic number 10000: invoke special treatment when msample << sample_count
00906         //                            (i.e. the OOB sample is very large)
00907         //                     40000: use at most 40000 OOB samples per class for the OOB error estimate 
00908         if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
00909         {
00910             ArrayVector<int> oob_indices;
00911             ArrayVector<int> cts(class_count, 0);
00912             std::random_shuffle(indices.begin(), indices.end());
00913             for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
00914             {
00915                 if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
00916                 {
00917                     oob_indices.push_back(indices[ii]);
00918                     ++cts[pr.response()(indices[ii], 0)];
00919                 }
00920             }
00921             for(unsigned int ll = 0; ll < oob_indices.size(); ++ll)
00922             {
00923                 // update number of trees in which current sample is oob
00924                 ++oobCount[oob_indices[ll]];
00925 
00926                 // update number of oob samples in this tree.
00927                 ++total_oob; 
00928                 // get the predicted votes ---> tmp_prob;
00929                 int pos =  rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll]));
00930                 Node<e_ConstProbNode> node ( rf.tree(index).topology_, 
00931                                                     rf.tree(index).parameters_,
00932                                                     pos);
00933                 tmp_prob.init(0); 
00934                 for(int ii = 0; ii < class_count; ++ii)
00935                 {
00936                     tmp_prob[ii] = node.prob_begin()[ii];
00937                 }
00938                 if(is_weighted)
00939                 {
00940                     for(int ii = 0; ii < class_count; ++ii)
00941                         tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
00942                 }
00943                 rowVector(prob_oob, oob_indices[ll]) += tmp_prob;
00944                 
00945             }
00946         }else
00947         {
00948             for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
00949             {
00950                 // if the lth sample is oob...
00951                 if(!sm.is_used()[ll])
00952                 {
00953                     // update number of trees in which current sample is oob
00954                     ++oobCount[ll];
00955 
00956                     // update number of oob samples in this tree.
00957                     ++total_oob; 
00958                     // get the predicted votes ---> tmp_prob;
00959                     int pos =  rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
00960                     Node<e_ConstProbNode> node ( rf.tree(index).topology_, 
00961                                                         rf.tree(index).parameters_,
00962                                                         pos);
00963                     tmp_prob.init(0); 
00964                     for(int ii = 0; ii < class_count; ++ii)
00965                     {
00966                         tmp_prob[ii] = node.prob_begin()[ii];
00967                     }
00968                     if(is_weighted)
00969                     {
00970                         for(int ii = 0; ii < class_count; ++ii)
00971                             tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
00972                     }
00973                     rowVector(prob_oob, ll) += tmp_prob;
00974                 }
00975             }
00976         }
00977         // go through the ib samples; 
00978     }
00979 
00980     /** Compute the ensemble (Breiman-style) OOB error after the number of trees is known.
00981      */
00982     template<class RF, class PR>
00983     void visit_at_end(RF & rf, PR & pr)
00984     {
00985         // Breiman-style ensemble OOB error
00986         int totalOobCount =0;
00987         int breimanstyle = 0;
00988         for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
00989         {
00990             if(oobCount[ll])
00991             {
00992                 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
00993                    ++breimanstyle;
00994                 ++totalOobCount;
00995             }
00996         }
00997         oob_breiman = double(breimanstyle)/totalOobCount; 
00998     }
00999 };
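
/* Usage sketch (illustrative only): after learning, the ensemble OOB error is available
 * in the public member oob_breiman; with HDF5 support compiled in, save() writes it to
 * the given file and group (the file and group names below are placeholders).
 *
 * \code
 * visitors::OOB_Error oob_v;
 * rf.learn(features, labels, visitors::create_visitor(oob_v));
 * std::cout << "ensemble OOB error: " << oob_v.oob_breiman << std::endl;
 * #ifdef HasHDF5
 * oob_v.save("rf_results.h5", "oob");                   // writes <group>/breiman_error
 * #endif
 * \endcode
 */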
01000 
01001 
01002 /** Visitor that calculates different OOB error statistics
01003  */
01004 class CompleteOOBInfo : public VisitorBase
01005 {
01006     typedef MultiArrayShape<2>::type Shp;
01007     int class_count;
01008     bool is_weighted;
01009     MultiArray<2,double> tmp_prob;
01010     public:
01011 
01012     /** OOB Error rate of each individual tree
01013      */
01014     MultiArray<2, double>       oob_per_tree;
01015     /** Mean of oob_per_tree
01016      */
01017     double                      oob_mean;
01018     /**Standard deviation of oob_per_tree
01019      */
01020     double                      oob_std;
01021     
01022     MultiArray<2, double>       prob_oob; 
01023     /** Ensemble OOB error
01024      *
01025      * \sa OOB_Error
01026      */
01027     double                      oob_breiman;
01028 
01029     MultiArray<2, double>       oobCount;
01030     MultiArray<2, double>       oobErrorCount;
01031     /** Per Tree OOB error calculated as in OOB_PerTreeError
01032      * (Ulli's version)
01033      */
01034     double                      oob_per_tree2;
01035 
01036     /**Column containing the development of the Ensemble
01037      * error rate with increasing number of trees
01038      */
01039     MultiArray<2, double>       breiman_per_tree;
01040     /** 4-dimensional array containing the development of confusion matrices 
01041      * with the number of trees - can be used to estimate ROC curves etc.
01042      *
01043      * oobroc_per_tree(ii, jj, kk, ll) corresponds to:
01044      * true label = ii, 
01045      * predicted label = jj,
01046      * confusion matrix after ll trees.
01047      *
01048      * explanation of the third index kk:
01049      *
01050      * Two-class case:
01051      * kk = 0 ... (treeCount-1);
01052      *         the threshold on the probability for class 0 is kk/(treeCount-1).
01053      * More classes:
01054      * kk = 0; the threshold on the probability is set by the argMax of the probability array.
01055      */
01056     MultiArray<4, double>       oobroc_per_tree;
01057     
01058     CompleteOOBInfo() : VisitorBase(), oob_mean(0), oob_std(0), oob_per_tree2(0)  {}
01059 
01060 #ifdef HasHDF5
01061     /** save to HDF5 file
01062      */
01063     void save(std::string filen, std::string pathn)
01064     {
01065         if(*(pathn.end()-1) != '/')
01066             pathn += "/";
01067         const char* filename = filen.c_str();
01068         MultiArray<2, double> temp(Shp(1,1), 0.0); 
01069         writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree);
01070         writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree);
01071         writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree);
01072         temp[0] = oob_mean;
01073         writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp);
01074         temp[0] = oob_std;
01075         writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp);
01076         temp[0] = oob_breiman;
01077         writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
01078         temp[0] = oob_per_tree2;
01079         writeHDF5(filename, (pathn + "ulli_error").c_str(), temp);
01080     }
01081 #endif
01082     // negative value if sample was in-bag, the number indicates how often.
01083     // value >= 0 if sample was oob; 0 means fail, 1 correct
01084 
01085     template<class RF, class PR>
01086     void visit_at_beginning(RF & rf, PR & pr)
01087     {
01088         class_count = rf.class_count();
01089         if(class_count == 2)
01090             oobroc_per_tree.reshape(MultiArrayShape<4>::type(2,2,rf.tree_count(), rf.tree_count()));
01091         else
01092             oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(),rf.class_count(),1, rf.tree_count()));
01093         tmp_prob.reshape(Shp(1, class_count), 0); 
01094         prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
01095         is_weighted = rf.options().predict_weighted_;
01096         oob_per_tree.reshape(Shp(1, rf.tree_count()), 0);
01097         breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0);
01098         // only done the first time this is called.
01099         if(int(oobCount.size()) != rf.ext_param_.row_count_)
01100         {
01101             oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
01102             oobErrorCount.reshape(Shp(rf.ext_param_.row_count_,1), 0);
01103         }
01104     }
01105 
01106     template<class RF, class PR, class SM, class ST>
01107     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
01108     {
01109         // go through the samples
01110         int total_oob =0;
01111         int wrong_oob =0;
01112         for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
01113         {
01114             // if the lth sample is oob...
01115             if(!sm.is_used()[ll])
01116             {
01117                 // update number of trees in which current sample is oob
01118                 ++oobCount[ll];
01119 
01120                 // update number of oob samples in this tree.
01121                 ++total_oob; 
01122                 // get the predicted votes ---> tmp_prob;
01123                 int pos =  rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
01124                 Node<e_ConstProbNode> node ( rf.tree(index).topology_, 
01125                                                     rf.tree(index).parameters_,
01126                                                     pos);
01127                 tmp_prob.init(0); 
01128                 for(int ii = 0; ii < class_count; ++ii)
01129                 {
01130                     tmp_prob[ii] = node.prob_begin()[ii];
01131                 }
01132                 if(is_weighted)
01133                 {
01134                     for(int ii = 0; ii < class_count; ++ii)
01135                         tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
01136                 }
01137                 rowVector(prob_oob, ll) += tmp_prob;
01138                 int label = argMax(tmp_prob); 
01139                 
01140                 if(label != pr.response()(ll, 0))
01141                 {
01142                     // update number of wrong oob samples in this tree.
01143                     ++wrong_oob;
01144                     // update number of trees in which current sample is wrong oob
01145                     ++oobErrorCount[ll];
01146                 }
01147             }
01148         }
01149         int breimanstyle = 0;
01150         int totalOobCount = 0;
01151         for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
01152         {
01153             if(oobCount[ll])
01154             {
01155                 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
01156                    ++breimanstyle;
01157                 ++totalOobCount;
01158                 if(oobroc_per_tree.shape(2) == 1)
01159                 {
01160                     oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++;
01161                 }
01162             }
01163         }
01164         if(oobroc_per_tree.shape(2) == 1)
01165             oobroc_per_tree.bindOuter(index)/=totalOobCount;
01166         if(oobroc_per_tree.shape(2) > 1)
01167         {
01168             MultiArrayView<3, double> current_roc 
01169                     = oobroc_per_tree.bindOuter(index);
01170             for(int gg = 0; gg < current_roc.shape(2); ++gg)
01171             {
01172                 for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
01173                 {
01174                     if(oobCount[ll])
01175                     {
01176                         int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))?
01177                                         1 : 0; 
01178                         current_roc(pr.response()(ll, 0), pred, gg)+= 1; 
01179                     }
01180                 }
01181                 current_roc.bindOuter(gg)/= totalOobCount;
01182             }
01183         }
01184         breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount);
01185         oob_per_tree[index] = double(wrong_oob)/double(total_oob);
01186         // go through the ib samples; 
01187     }
01188 
01189     /** Compute the remaining OOB statistics after the number of trees is known.
01190      */
01191     template<class RF, class PR>
01192     void visit_at_end(RF & rf, PR & pr)
01193     {
01194         // Ulli's original per-tree metric and the Breiman-style ensemble error
01195         oob_per_tree2 = 0; 
01196         int totalOobCount =0;
01197         int breimanstyle = 0;
01198         for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
01199         {
01200             if(oobCount[ll])
01201             {
01202                 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
01203                    ++breimanstyle;
01204                 oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll];
01205                 ++totalOobCount;
01206             }
01207         }
01208         oob_per_tree2 /= totalOobCount; 
01209         oob_breiman = double(breimanstyle)/totalOobCount; 
01210         // mean error of each tree
01211         MultiArrayView<2, double> mean(Shp(1,1), &oob_mean);
01212         MultiArrayView<2, double> stdDev(Shp(1,1), &oob_std);
01213         rowStatistics(oob_per_tree, mean, stdDev);
01214     }
01215 };
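
/* Usage sketch (illustrative only): all statistics are public members and can be read
 * directly after learning; oobroc_per_tree is documented in the member description above.
 *
 * \code
 * visitors::CompleteOOBInfo complete_v;
 * rf.learn(features, labels, visitors::create_visitor(complete_v));
 * std::cout << "mean per-tree OOB error: " << complete_v.oob_mean
 *           << " +/- "                     << complete_v.oob_std     << std::endl;
 * std::cout << "ensemble OOB error:      " << complete_v.oob_breiman << std::endl;
 * \endcode
 */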
01216 
01217 /** calculate variable importance while learning.
01218  */
01219 class VariableImportanceVisitor : public VisitorBase
01220 {
01221     public:
01222 
01223     /** This array has the same entries as the R random forest variable
01224      *  importance.
01225      *  The matrix is featureCount by (classCount + 2).
01226      *  variable_importance_(ii,jj) is the variable importance measure of 
01227      *  the ii-th variable according to:
01228      *  jj = 0 - (classCount-1)
01229      *      classwise permutation importance 
01230      *  jj = columnCount(variable_importance_) - 2  (i.e. jj = classCount)
01231      *      permutation importance
01232      *  jj = columnCount(variable_importance_) - 1  (i.e. jj = classCount + 1)
01233      *      gini decrease importance.
01234      *
01235      *  permutation importance:
01236      *  The difference between the fraction of OOB samples classified correctly
01237      *  before and after permuting (randomizing) the ii-th column is calculated.
01238      *  The ii-th column is permuted rep_cnt times.
01239      *
01240      *  classwise permutation importance:
01241      *  same as permutation importance, but we only look at those OOB samples whose 
01242      *  response corresponds to class jj.
01243      *
01244      *  gini decrease importance:
01245      *  entry ii corresponds to the sum of all gini decreases induced by variable ii 
01246      *  over all nodes of the random forest.
01247      */
01248     MultiArray<2, double>       variable_importance_;
01249     int                         repetition_count_;
01250     bool                        in_place_;
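
    /* Reading the importance matrix (illustrative sketch): with class_count classes the
     * matrix has class_count + 2 columns. The names varimp_v, ii and class_count below
     * are placeholders for a visitor instance, a feature index and the class count.
     *
     * \code
     * double class0_imp = varimp_v.variable_importance_(ii, 0);               // classwise permutation importance, class 0
     * double perm_imp   = varimp_v.variable_importance_(ii, class_count);     // overall permutation importance
     * double gini_imp   = varimp_v.variable_importance_(ii, class_count + 1); // gini decrease importance
     * \endcode
     */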
01251 
01252 #ifdef HasHDF5
01253     void save(std::string filename, std::string prefix)
01254     {
01255         prefix = "variable_importance_" + prefix;
01256         writeHDF5(filename.c_str(), 
01257                         prefix.c_str(), 
01258                         variable_importance_);
01259     }
01260 #endif
01261 
01262     /** Constructor
01263      * \param rep_cnt (default: 10) how often the permutation should 
01264      * take place. Set to 1 to make the calculation faster (but
01265      * possibly less stable)
01266      */
01267     VariableImportanceVisitor(int rep_cnt = 10) 
01268     :   repetition_count_(rep_cnt),
01269         in_place_(false)
01270     {}
01271 
01272     /** calculates impurity decrease based variable importance after every
01273      * split.  
01274      */
01275     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
01276     void visit_after_split( Tree          & tree, 
01277                             Split         & split,
01278                             Region        & parent,
01279                             Region        & leftChild,
01280                             Region        & rightChild,
01281                             Feature_t     & features,
01282                             Label_t       & labels)
01283     {
01284         //resize to right size when called the first time
01285         
01286         Int32 const  class_count = tree.ext_param_.class_count_;
01287         Int32 const  column_count = tree.ext_param_.column_count_;
01288         if(variable_importance_.size() == 0)
01289         {
01290             
01291             variable_importance_
01292                 .reshape(MultiArrayShape<2>::type(column_count, 
01293                                                  class_count+2));
01294         }
01295 
01296         if(split.createNode().typeID() == i_ThresholdNode)
01297         {
01298             Node<i_ThresholdNode> node(split.createNode());
01299             variable_importance_(node.column(),class_count+1) 
01300                 += split.region_gini_ - split.minGini();
01301         }
01302     }
01303 
01304     /** compute permutation based variable importance. 
01305      * (Only an array of size oob_sample_count x 1 is created
01306      *  - as opposed to oob_sample_count x feature_count in the other method.)
01307      * 
01308      * \sa FieldProxy
01309      */
01310     template<class RF, class PR, class SM, class ST>
01311     void after_tree_ip_impl(RF& rf, PR & pr,  SM & sm, ST & st, int index)
01312     {
01313         typedef MultiArrayShape<2>::type Shp_t;
01314         Int32                   column_count = rf.ext_param_.column_count_;
01315         Int32                   class_count  = rf.ext_param_.class_count_;  
01316         
01317         /* This solution would save memory but is not multithreading
01318          * compatible
01319          */
01320         // remove the const cast on the features (yep, I know what I am 
01321         // doing here.) The data is not destroyed.
01322         //typename PR::Feature_t & features 
01323         //    = const_cast<typename PR::Feature_t &>(pr.features());
01324 
01325         typedef typename PR::FeatureWithMemory_t FeatureArray;
01326         typedef typename FeatureArray::value_type FeatureValue;
01327 
01328         FeatureArray features = pr.features();
01329 
01330         //find the oob indices of current tree. 
01331         ArrayVector<Int32>      oob_indices;
01332         ArrayVector<Int32>::iterator
01333                                 iter;
01334         for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
01335             if(!sm.is_used()[ii])
01336                 oob_indices.push_back(ii);
01337 
01338         //create space to back up a column      
01339         ArrayVector<FeatureValue>     backup_column;
01340 
01341         // Random foo
01342 #ifdef CLASSIFIER_TEST
01343         RandomMT19937           random(1);
01344 #else 
01345         RandomMT19937           random(RandomSeed);
01346 #endif
01347         UniformIntRandomFunctor<RandomMT19937>  
01348                                 randint(random);
01349 
01350 
01351         //make some space for the results
01352         MultiArray<2, double>
01353                     oob_right(Shp_t(1, class_count + 1)); 
01354         MultiArray<2, double>
01355                     perm_oob_right (Shp_t(1, class_count + 1)); 
01356             
01357         
01358         // get the oob success rate with the original samples
01359         for(iter = oob_indices.begin(); 
01360             iter != oob_indices.end(); 
01361             ++iter)
01362         {
01363             if(rf.tree(index)
01364                     .predictLabel(rowVector(features, *iter)) 
01365                 ==  pr.response()(*iter, 0))
01366             {
01367                 //per class
01368                 ++oob_right[pr.response()(*iter,0)];
01369                 //total
01370                 ++oob_right[class_count];
01371             }
01372         }
01373         // for each dimension ii: get the oob rate after permuting that dimension
01374         for(int ii = 0; ii < column_count; ++ii)
01375         {
01376             perm_oob_right.init(0.0); 
01377             //make backup of original column
01378             backup_column.clear();
01379             for(iter = oob_indices.begin(); 
01380                 iter != oob_indices.end(); 
01381                 ++iter)
01382             {
01383                 backup_column.push_back(features(*iter,ii));
01384             }
01385             
01386             //get the oob rate after permuting the ii'th dimension.
01387             for(int rr = 0; rr < repetition_count_; ++rr)
01388             {               
01389                 //permute dimension. 
01390                 int n = oob_indices.size();
01391                 for(int jj = 1; jj < n; ++jj)
01392                     std::swap(features(oob_indices[jj], ii), 
01393                               features(oob_indices[randint(jj+1)], ii));
01394 
01395                 //get the oob success rate after permuting
01396                 for(iter = oob_indices.begin(); 
01397                     iter != oob_indices.end(); 
01398                     ++iter)
01399                 {
01400                     if(rf.tree(index)
01401                             .predictLabel(rowVector(features, *iter)) 
01402                         ==  pr.response()(*iter, 0))
01403                     {
01404                         //per class
01405                         ++perm_oob_right[pr.response()(*iter, 0)];
01406                         //total
01407                         ++perm_oob_right[class_count];
01408                     }
01409                 }
01410             }
01411             
01412             
01413             // normalise and add to the variable_importance_ array
01414             perm_oob_right /= repetition_count_;
01415             perm_oob_right -= oob_right;
01416             perm_oob_right *= -1;
01417             perm_oob_right /= oob_indices.size();
01418             variable_importance_
01419                 .subarray(Shp_t(ii,0), 
01420                           Shp_t(ii+1,class_count+1)) += perm_oob_right;
01421             //copy back permuted dimension
01422             for(int jj = 0; jj < int(oob_indices.size()); ++jj)
01423                 features(oob_indices[jj], ii) = backup_column[jj];
01424         }
01425     }
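
    /* Summary of the quantity accumulated above (reconstructed from the code in
     * after_tree_ip_impl, so treat it as a sketch rather than normative
     * documentation): for tree t with out-of-bag set OOB_t, repetition count R,
     * and column ii, the contribution added to variable_importance_(ii, c) is
     * \f[
     *     \frac{\mathrm{oob\_right}[c]
     *           - \frac{1}{R}\sum_{r=1}^{R}\mathrm{perm\_oob\_right}^{(r)}[c]}
     *          {|OOB_t|}
     * \f]
     * where c runs over the classes plus one "total" entry at index class_count;
     * visit_at_end() finally divides by the number of trees.
     */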
01426 
01427     /** Calculate permutation-based variable importance after every tree has
01428      * been learned. By default this happens out of place.
01429      * If you have very big data sets and want to avoid copying the data,
01430      * set the in_place_ flag to true.
01431      */
01432     template<class RF, class PR, class SM, class ST>
01433     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index)
01434     {
01435             after_tree_ip_impl(rf, pr, sm, st, index);
01436     }
01437 
01438     /** Normalise variable importance after the number of trees is known.
01439      */
01440     template<class RF, class PR>
01441     void visit_at_end(RF & rf, PR & pr)
01442     {
01443         variable_importance_ /= rf.trees_.size();
01444     }
01445 };
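
/** A minimal usage sketch for VariableImportanceVisitor. It assumes the usual
 *  RandomForest::learn() / visitors::create_visitor() interface of this module;
 *  the arrays \a features and \a labels are hypothetical placeholders supplied
 *  by the caller:
 *
 *  \code
 *  vigra::RandomForest<> rf(vigra::RandomForestOptions().tree_count(255));
 *  vigra::rf::visitors::VariableImportanceVisitor var_imp(10); // 10 permutation repetitions
 *  rf.learn(features, labels, vigra::rf::visitors::create_visitor(var_imp));
 *
 *  // var_imp.variable_importance_ has shape column_count x (class_count + 2):
 *  //   columns 0 .. class_count-1 : permutation importance per class
 *  //   column  class_count        : permutation importance over all classes
 *  //   column  class_count + 1    : Gini (impurity decrease) importance
 *  \endcode
 */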
01446 
01447 /** Verbose output of the random forest learning progress.
01448  */
01449 class RandomForestProgressVisitor : public VisitorBase {
01450     public:
01451     RandomForestProgressVisitor() : VisitorBase() {}
01452 
01453     template<class RF, class PR, class SM, class ST>
01454     void visit_after_tree(RF& rf, PR & pr,  SM & sm, ST & st, int index){
01455         if(index != rf.options().tree_count_-1) {
01456             std::cout << "\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << "%]"
01457                       << " (" << index+1 << " of " << rf.options().tree_count_ << ") done" << std::flush;
01458         }
01459         else {
01460             std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl;
01461         }
01462     }
01463     
01464     template<class RF, class PR>
01465     void visit_at_end(RF const & rf, PR const & pr) {
01466         std::string a = TOCS;
01467         std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << a  << std::endl;
01468     }
01469     
01470     template<class RF, class PR>
01471     void visit_at_beginning(RF const & rf, PR const & pr) {
01472         TIC;
01473         std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl;
01474     }
01475     
01476     private:
01477     USETICTOC;
01478 };
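
/** Usage sketch for RandomForestProgressVisitor (same hedged assumptions as in
 *  the VariableImportanceVisitor example above; create_visitor() is assumed to
 *  accept several visitors at once):
 *
 *  \code
 *  vigra::rf::visitors::RandomForestProgressVisitor progress;
 *  rf.learn(features, labels,
 *           vigra::rf::visitors::create_visitor(progress, var_imp));
 *  \endcode
 */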
01479 
01480 
01481 /** Computes Correlation/Similarity Matrix of features while learning
01482  * random forest.
01483  */
01484 class CorrelationVisitor : public VisitorBase
01485 {
01486     public:
01487     /** gini_missc(ii, jj) describes how well variable jj can describe a partition
01488      * created on variable ii (when variable ii was chosen).
01489      */ 
01490     MultiArray<2, double>   gini_missc;
01491     MultiArray<2, int>      tmp_labels;
01492     /** additional noise features. 
01493      */
01494     MultiArray<2, double>   noise;
01495     MultiArray<2, double>   noise_l;
01496     /** how well can a noise column describe a partition created on variable ii.
01497      */
01498     MultiArray<2, double>   corr_noise;
01499     MultiArray<2, double>   corr_l;
01500 
01501     /** Similarity Matrix
01502      * 
01503      * (number_of_features + 1) x (number_of_features + 1) matrix derived from
01504      * gini_missc:
01505      *  - each row is normalised by the number of times the column was chosen,
01506      *  - the mean of corr_noise is subtracted,
01507      *  - and the result is symmetrised.
01508      *          
01509      */
01510     MultiArray<2, double>   similarity;
01511     /** Distance Matrix: maximum similarity value minus similarity.
01512      */
01513     MultiArray<2, double>   distance;
01514     ArrayVector<int>        tmp_cc;
01515     
01516     /** How often was variable ii chosen
01517      */
01518     ArrayVector<int>        numChoices;
01519     typedef BestGiniOfColumn<GiniCriterion> ColumnDecisionFunctor;
01520     BestGiniOfColumn<GiniCriterion>         bgfunc;
01521     void save(std::string file, std::string prefix)
01522     {
01523         /*
01524         std::string tmp;
01525 #define VAR_WRITE(NAME) \
01526         tmp = #NAME;\
01527         tmp += "_";\
01528         tmp += prefix;\
01529         vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME);
01530         VAR_WRITE(gini_missc);
01531         VAR_WRITE(corr_noise);
01532         VAR_WRITE(distance);
01533         VAR_WRITE(similarity);
01534         vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data()));
01535 #undef VAR_WRITE
01536 */
01537     }
01538     template<class RF, class PR>
01539     void visit_at_beginning(RF const & rf, PR  & pr)
01540     {
01541         typedef MultiArrayShape<2>::type Shp;
01542         int n = rf.ext_param_.column_count_;
01543         gini_missc.reshape(Shp(n+1, n+1));
01544         corr_noise.reshape(Shp(n+1, 10));
01545         corr_l.reshape(Shp(n+1, 10));
01546 
01547         noise.reshape(Shp(pr.features().shape(0), 10));
01548         noise_l.reshape(Shp(pr.features().shape(0), 10));
01549         RandomMT19937 random(RandomSeed);
01550         for(int ii = 0; ii < noise.size(); ++ii)
01551         {
01552             noise[ii]   = random.uniform53();
01553             noise_l[ii] = random.uniform53()  > 0.5;
01554         }
01555         bgfunc = ColumnDecisionFunctor( rf.ext_param_);
01556         tmp_labels.reshape(pr.response().shape()); 
01557         tmp_cc.resize(2);
01558         numChoices.resize(n+1);
01559         // look at all axes
01560     }
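
    /* visit_at_end() (below) turns the accumulated counts into the similarity
     * and distance matrices: entries are normalised by how often the
     * corresponding variable was chosen, the mean noise response is subtracted,
     * absolute values are taken, the matrix is symmetrised, and finally
     * distance = max(similarity) - similarity.
     */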
01561     template<class RF, class PR>
01562     void visit_at_end(RF const & rf, PR const & pr)
01563     {
01564         typedef MultiArrayShape<2>::type Shp;
01565         similarity.reshape(gini_missc.shape());
01566         similarity = gini_missc;
01567         MultiArray<2, double> mean_noise(Shp(corr_noise.shape(0), 1));
01568         rowStatistics(corr_noise, mean_noise);
01569         mean_noise /= MultiArrayView<2, int>(mean_noise.shape(), numChoices.data());
01570         int rC = similarity.shape(0);
01571         for(int jj = 0; jj < rC-1; ++jj)
01572         {
01573             rowVector(similarity, jj) /= numChoices[jj];
01574             rowVector(similarity, jj) -= mean_noise(jj, 0);
01575         }
01576         for(int jj = 0; jj < rC; ++jj)
01577         {
01578             similarity(rC -1, jj) /= numChoices[jj];
01579         }
01580         rowVector(similarity, rC -  1) -= mean_noise(rC-1, 0);
01581         similarity = abs(similarity);
01582         FindMinMax<double> minmax;
01583         inspectMultiArray(srcMultiArrayRange(similarity), minmax);
01584         
01585         for(int jj = 0; jj < rC; ++jj)
01586             similarity(jj, jj) = minmax.max;
01587         
01588         similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)) 
01589             += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose();
01590         similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2;  
01591         columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose();
01592         for(int jj = 0; jj < rC; ++jj)
01593             similarity(jj, jj) = 0;
01594         
01595         FindMinMax<double> minmax2;
01596         inspectMultiArray(srcMultiArrayRange(similarity), minmax2);
01597         for(int jj = 0; jj < rC; ++jj)
01598             similarity(jj, jj) = minmax2.max;
01599         distance.reshape(gini_missc.shape(), minmax2.max);
01600         distance -= similarity; 
01601     }
01602 
01603     template<class Tree, class Split, class Region, class Feature_t, class Label_t>
01604     void visit_after_split( Tree          & tree, 
01605                             Split         & split,
01606                             Region        & parent,
01607                             Region        & leftChild,
01608                             Region        & rightChild,
01609                             Feature_t     & features,
01610                             Label_t       & labels)
01611     {
01612         if(split.createNode().typeID() == i_ThresholdNode)
01613         {
01614             double wgini;
01615             tmp_cc.init(0); 
01616             for(int ii = 0; ii < parent.size(); ++ii)
01617             {
01618                 tmp_labels[parent[ii]] 
01619                     = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold());
01620                 ++tmp_cc[tmp_labels[parent[ii]]];
01621             }
01622             double region_gini = bgfunc.loss_of_region(tmp_labels, 
01623                                                        parent.begin(),
01624                                                        parent.end(),
01625                                                        tmp_cc);
01626 
01627             int n = split.bestSplitColumn(); 
01628             ++numChoices[n];
01629             ++(*(numChoices.end()-1));
01630             //this functor does all the work
01631             for(int k = 0; k < features.shape(1); ++k)
01632             {
01633                 bgfunc(columnVector(features, k),
01634                        tmp_labels, 
01635                        parent.begin(), parent.end(), 
01636                        tmp_cc);
01637                 wgini = (region_gini - bgfunc.min_gini_);
01638                 gini_missc(n, k) 
01639                     += wgini;
01640             }
01641             for(int k = 0; k < 10; ++k)
01642             {
01643                 bgfunc(columnVector(noise, k),
01644                        tmp_labels, 
01645                        parent.begin(), parent.end(), 
01646                        tmp_cc);
01647                 wgini = (region_gini - bgfunc.min_gini_);
01648                 corr_noise(n, k) 
01649                     += wgini;
01650             }
01651             
01652             for(int k = 0; k < 10; ++k)
01653             {
01654                 bgfunc(columnVector(noise_l, k),
01655                        tmp_labels, 
01656                        parent.begin(), parent.end(), 
01657                        tmp_cc);
01658                 wgini = (region_gini - bgfunc.min_gini_);
01659                 corr_l(n, k) 
01660                     += wgini;
01661             }
01662             bgfunc(labels, tmp_labels, parent.begin(), parent.end(),tmp_cc);
01663             wgini = (region_gini - bgfunc.min_gini_);
01664             gini_missc(n, columnCount(gini_missc)-1) 
01665                 += wgini;
01666             
01667             region_gini = split.region_gini_;
01668 #if 1 
01669             Node<i_ThresholdNode> node(split.createNode());
01670             gini_missc(rowCount(gini_missc)-1, 
01671                                   node.column()) 
01672                  +=split.region_gini_ - split.minGini();
01673 #endif
01674             for(int k = 0; k < 10; ++k)
01675             {
01676                 split.bgfunc(columnVector(noise, k),
01677                              labels, 
01678                              parent.begin(), parent.end(), 
01679                              parent.classCounts());
01680                 wgini = region_gini - split.bgfunc.min_gini_;
01681                 corr_noise(rowCount(gini_missc)-1, k) 
01682                      += wgini;
01683             }
01684 #if 0
01685             for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k)
01686             {
01687                 wgini = region_gini - split.min_gini_[k];
01688                 
01689                 gini_missc(rowCount(gini_missc)-1, 
01690                                       split.splitColumns[k]) 
01691                      += wgini;
01692             }
01693             
01694             for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k)
01695             {
01696                 split.bgfunc(columnVector(features, split.splitColumns[k]),
01697                              labels, 
01698                              parent.begin(), parent.end(), 
01699                              parent.classCounts());
01700                 wgini = region_gini - split.bgfunc.min_gini_;
01701                 gini_missc(rowCount(gini_missc)-1, 
01702                                       split.splitColumns[k]) += wgini;
01703             }
01704 #endif
01705             gini_missc(rowCount(gini_missc)-1, 
01706                        columnCount(gini_missc)-1) 
01707                  += region_gini;
01708             // remember to partition the data according to the best split.
01709             SortSamplesByDimensions<Feature_t> 
01710                 sorter(features, split.bestSplitColumn(), split.bestSplitThreshold());
01711             std::partition(parent.begin(), parent.end(), sorter);
01712         }
01713     }
01714 };
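
/** Usage sketch for CorrelationVisitor (same hedged assumptions about the
 *  learn()/create_visitor() interface as in the examples above):
 *
 *  \code
 *  vigra::rf::visitors::CorrelationVisitor corr;
 *  rf.learn(features, labels, vigra::rf::visitors::create_visitor(corr));
 *  // corr.similarity and corr.distance are (column_count + 1) x (column_count + 1)
 *  // matrices, see the member documentation above.
 *  \endcode
 */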
01715 
01716 
01717 } // namespace visitors
01718 } // namespace rf
01719 } // namespace vigra
01720 
01721 //@}
01722 #endif // RF_VISITORS_HXX
