[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
vigra/random_forest/rf_visitors.hxx
00001 /************************************************************************/ 00002 /* */ 00003 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */ 00004 /* */ 00005 /* This file is part of the VIGRA computer vision library. */ 00006 /* The VIGRA Website is */ 00007 /* http://hci.iwr.uni-heidelberg.de/vigra/ */ 00008 /* Please direct questions, bug reports, and contributions to */ 00009 /* ullrich.koethe@iwr.uni-heidelberg.de or */ 00010 /* vigra@informatik.uni-hamburg.de */ 00011 /* */ 00012 /* Permission is hereby granted, free of charge, to any person */ 00013 /* obtaining a copy of this software and associated documentation */ 00014 /* files (the "Software"), to deal in the Software without */ 00015 /* restriction, including without limitation the rights to use, */ 00016 /* copy, modify, merge, publish, distribute, sublicense, and/or */ 00017 /* sell copies of the Software, and to permit persons to whom the */ 00018 /* Software is furnished to do so, subject to the following */ 00019 /* conditions: */ 00020 /* */ 00021 /* The above copyright notice and this permission notice shall be */ 00022 /* included in all copies or substantial portions of the */ 00023 /* Software. */ 00024 /* */ 00025 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */ 00026 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */ 00027 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */ 00028 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */ 00029 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */ 00030 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */ 00031 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */ 00032 /* OTHER DEALINGS IN THE SOFTWARE. 
*/ 00033 /* */ 00034 /************************************************************************/ 00035 #ifndef RF_VISITORS_HXX 00036 #define RF_VISITORS_HXX 00037 00038 #ifdef HasHDF5 00039 # include "vigra/hdf5impex.hxx" 00040 #endif // HasHDF5 00041 #include <vigra/windows.h> 00042 #include <iostream> 00043 #include <iomanip> 00044 00045 #include <vigra/multi_pointoperators.hxx> 00046 #include <vigra/timing.hxx> 00047 00048 namespace vigra 00049 { 00050 namespace rf 00051 { 00052 /** \addtogroup MachineLearning Machine Learning 00053 **/ 00054 //@{ 00055 00056 /** 00057 This namespace contains all classes and methods related to extracting information during 00058 learning of the random forest. All Visitors share the same interface defined in 00059 visitors::VisitorBase. The member methods are invoked at certain points of the main code in 00060 the order they were supplied. 00061 00062 For the Random Forest the Visitor concept is implemented as a statically linked list 00063 (Using templates). Each Visitor object is encapsulated in a detail::VisitorNode object. The 00064 VisitorNode object calls the Next Visitor after one of its visit() methods have terminated. 00065 00066 To simplify usage create_visitor() factory methods are supplied. 00067 Use the create_visitor() method to supply visitor objects to the RandomForest::learn() method. 00068 It is possible to supply more than one visitor. They will then be invoked in serial order. 00069 00070 The calculated information are stored as public data members of the class. - see documentation 00071 of the individual visitors 00072 00073 While creating a new visitor the new class should therefore publicly inherit from this class 00074 (i.e.: see visitors::OOB_Error). 00075 00076 \code 00077 00078 typedef xxx feature_t \\ replace xxx with whichever type 00079 typedef yyy label_t \\ meme chose. 
MultiArrayView<2, feature_t> f = get_some_features();
MultiArrayView<2, label_t> l = get_some_labels();
RandomForest<> rf;

//calculate OOB Error
visitors::OOB_Error oob_v;
//calculate Variable Importance
visitors::VariableImportanceVisitor varimp_v;

double oob_error = rf.learn(f, l, visitors::create_visitor(oob_v, varimp_v));
//the data can be found in the attributes of oob_v and varimp_v now

\endcode
*/
namespace visitors
{


/** Base class from which all Visitors derive. Can be used as a template to
 *  create new Visitors.
 *
 *  All callbacks are no-ops here; derived visitors override only the hooks
 *  they need. The active_ flag lets the visitor chain skip this visitor at
 *  runtime without removing it from the (statically linked) list.
 */
class VisitorBase
{
  public:
    // whether the visit_*() callbacks of this visitor should be invoked
    bool active_;

    /** \return true if this visitor is currently enabled. */
    bool is_active()
    {
        return active_;
    }

    /** \return true if this visitor supplies a value via return_val().
     *  The base implementation has none.
     */
    bool has_value()
    {
        return false;
    }

    /** Visitors start out active. */
    VisitorBase()
    : active_(true)
    {}

    /** Disable all visit_*() callbacks of this visitor. */
    void deactivate()
    {
        active_ = false;
    }

    /** (Re-)enable all visit_*() callbacks of this visitor. */
    void activate()
    {
        active_ = true;
    }

    /** do something after the Split has decided how to process the Region
     * (stack entry)
     *
     * \param tree      reference to the tree that is currently being learned
     * \param split     reference to the split object
     * \param parent    current stack entry which was used to decide the split
     * \param leftChild left stack entry that will be pushed
     * \param rightChild
     *                  right stack entry that will be pushed.
     * \param features  features matrix
     * \param labels    label matrix
     * \sa RF_Traits::StackEntry_t
     */
    template<class Tree, class Split, class Region, class Feature_t, class Label_t>
    void visit_after_split( Tree & tree,
                            Split & split,
                            Region & parent,
                            Region & leftChild,
                            Region & rightChild,
                            Feature_t & features,
                            Label_t & labels)
    {}

    /** do something after each tree has been learned
     *
     * \param rf        reference to the random forest object that called this
     *                  visitor
     * \param pr        reference to the preprocessor that processed the input
     * \param sm        reference to the sampler object
     * \param st        reference to the first stack entry
     * \param index     index of current tree
     */
    template<class RF, class PR, class SM, class ST>
    void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
    {}

    /** do something after all trees have been learned
     *
     * \param rf        reference to the random forest object that called this
     *                  visitor
     * \param pr        reference to the preprocessor that processed the input
     */
    template<class RF, class PR>
    void visit_at_end(RF const & rf, PR const & pr)
    {}

    /** do something before learning starts
     *
     * \param rf        reference to the random forest object that called this
     *                  visitor
     * \param pr        reference to the Processor class used.
     */
    template<class RF, class PR>
    void visit_at_beginning(RF const & rf, PR const & pr)
    {}

    /** do something while traversing tree after it has been learned
     * (external nodes)
     *
     * \param tr        reference to the tree object that called this visitor
     * \param index     index in the topology_ array we currently are at
     * \param node_t    type of node we have (will be one of the e_... tags)
     * \param features  feature matrix
     * \sa NodeTags;
     *
     * you can create the node by using a switch on node_tag and using the
     * corresponding Node objects. Or - if you do not care about the type -
     * use the NodeBase class.
     */
    template<class TR, class IntT, class TopT,class Feat>
    void visit_external_node(TR & tr, IntT index, TopT node_t,Feat & features)
    {}

    /** do something when visiting an internal node after it has been learned
     *
     * \sa visit_external_node
     */
    template<class TR, class IntT, class TopT,class Feat>
    void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features)
    {}

    /** return a double value. The value of the first
     * visitor encountered that has a return value is returned with the
     * RandomForest::learn() method - or -1.0 if no return value visitor
     * existed. This functionality basically only exists so that the
     * OOB - visitor can return the oob error rate like in the old version
     * of the random forest.
     */
    double return_val()
    {
        return -1.0;
    }
};


/** Last Visitor that should be called to stop the recursion of the
 *  statically linked visitor list. Always "has a value" so that
 *  return_val() terminates the chain with -1.0.
 */
class StopVisiting: public VisitorBase
{
  public:
    bool has_value()
    {
        return true;
    }
    double return_val()
    {
        return -1.0;
    }
};

namespace detail
{
/** Container elements of the statically linked Visitor list.
 *
 * use the create_visitor() factory functions to create visitors up to size 10;
 *
 */
template <class Visitor, class Next = StopVisiting>
class VisitorNode
{
  public:

    // fallback chain terminator, used when no explicit next node is supplied
    StopVisiting    stop_;
    Next            next_;
    Visitor &       visitor_;

    VisitorNode(Visitor & visitor, Next & next)
    :
        next_(next), visitor_(visitor)
    {}

    VisitorNode(Visitor & visitor)
    :
        next_(stop_), visitor_(visitor)
    {}

    /** Forward to visitor_ (only if it is active), then unconditionally
     *  to the rest of the chain. Same pattern for all hooks below.
     */
    template<class Tree, class Split, class Region, class Feature_t, class Label_t>
    void visit_after_split( Tree & tree,
                            Split & split,
                            Region & parent,
                            Region & leftChild,
                            Region & rightChild,
                            Feature_t & features,
                            Label_t & labels)
    {
        if(visitor_.is_active())
            visitor_.visit_after_split(tree, split,
                                       parent, leftChild, rightChild,
                                       features, labels);
        next_.visit_after_split(tree, split, parent, leftChild, rightChild,
                                features, labels);
    }

    template<class RF, class PR, class SM, class ST>
    void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
    {
        if(visitor_.is_active())
            visitor_.visit_after_tree(rf, pr, sm, st, index);
        next_.visit_after_tree(rf, pr, sm, st, index);
    }

    template<class RF, class PR>
    void visit_at_beginning(RF & rf, PR & pr)
    {
        if(visitor_.is_active())
            visitor_.visit_at_beginning(rf, pr);
        next_.visit_at_beginning(rf, pr);
    }

    template<class RF, class PR>
    void visit_at_end(RF & rf, PR & pr)
    {
        if(visitor_.is_active())
            visitor_.visit_at_end(rf, pr);
        next_.visit_at_end(rf, pr);
    }

    template<class TR, class IntT, class TopT,class Feat>
    void visit_external_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
    {
        if(visitor_.is_active())
            visitor_.visit_external_node(tr, index, node_t,features);
        next_.visit_external_node(tr, index, node_t,features);
    }

    template<class TR, class IntT, class TopT,class Feat>
    void visit_internal_node(TR & tr, IntT & index, TopT & node_t,Feat & features)
    {
        if(visitor_.is_active())
            visitor_.visit_internal_node(tr, index, node_t,features);
        next_.visit_internal_node(tr, index, node_t,features);
    }

    /** Value of the first active visitor in the chain that has one;
     *  falls through to StopVisiting (-1.0) otherwise.
     */
    double return_val()
    {
        if(visitor_.is_active() && visitor_.has_value())
            return visitor_.return_val();
        return next_.return_val();
    }
};

} //namespace detail

//////////////////////////////////////////////////////////////////////////////
//  Visitor Factory function up to 10 visitors                              //
//////////////////////////////////////////////////////////////////////////////

/** factory method to be used with RandomForest::learn()
 */
template<class A>
detail::VisitorNode<A>
create_visitor(A & a)
{
    typedef detail::VisitorNode<A> _0_t;
    _0_t _0(a);
    return _0;
}


/** factory method to be used with RandomForest::learn()
 */
template<class A, class B>
detail::VisitorNode<A, detail::VisitorNode<B> >
create_visitor(A & a, B & b)
{
    typedef detail::VisitorNode<B> _1_t;
    _1_t _1(b);
    typedef detail::VisitorNode<A, _1_t> _0_t;
    _0_t _0(a, _1);
    return _0;
}


/** factory method to be used with RandomForest::learn()
 */
template<class A, class B, class C>
detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C> > >
create_visitor(A & a, B & b, C & c)
{
    typedef detail::VisitorNode<C> _2_t;
    _2_t _2(c);
    typedef detail::VisitorNode<B, _2_t> _1_t;
    _1_t _1(b, _2);
    typedef detail::VisitorNode<A, _1_t> _0_t;
    _0_t _0(a, _1);
    return _0;
}


/** factory method to be used with
RandomForest::learn()
 */
template<class A, class B, class C, class D>
detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
    detail::VisitorNode<D> > > >
create_visitor(A & a, B & b, C & c, D & d)
{
    typedef detail::VisitorNode<D> _3_t;
    _3_t _3(d);
    typedef detail::VisitorNode<C, _3_t> _2_t;
    _2_t _2(c, _3);
    typedef detail::VisitorNode<B, _2_t> _1_t;
    _1_t _1(b, _2);
    typedef detail::VisitorNode<A, _1_t> _0_t;
    _0_t _0(a, _1);
    return _0;
}


/** factory method to be used with RandomForest::learn()
 */
template<class A, class B, class C, class D, class E>
detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
    detail::VisitorNode<D, detail::VisitorNode<E> > > > >
create_visitor(A & a, B & b, C & c,
               D & d, E & e)
{
    typedef detail::VisitorNode<E> _4_t;
    _4_t _4(e);
    typedef detail::VisitorNode<D, _4_t> _3_t;
    _3_t _3(d, _4);
    typedef detail::VisitorNode<C, _3_t> _2_t;
    _2_t _2(c, _3);
    typedef detail::VisitorNode<B, _2_t> _1_t;
    _1_t _1(b, _2);
    typedef detail::VisitorNode<A, _1_t> _0_t;
    _0_t _0(a, _1);
    return _0;
}


/** factory method to be used with RandomForest::learn()
 */
template<class A, class B, class C, class D, class E,
         class F>
detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
    detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F> > > > > >
create_visitor(A & a, B & b, C & c,
               D & d, E & e, F & f)
{
    typedef detail::VisitorNode<F> _5_t;
    _5_t _5(f);
    typedef detail::VisitorNode<E, _5_t> _4_t;
    _4_t _4(e, _5);
    typedef detail::VisitorNode<D, _4_t> _3_t;
    _3_t _3(d, _4);
    typedef detail::VisitorNode<C, _3_t> _2_t;
    _2_t _2(c, _3);
    typedef detail::VisitorNode<B, _2_t> _1_t;
    _1_t _1(b, _2);
    typedef detail::VisitorNode<A, _1_t> _0_t;
    _0_t _0(a, _1);
    return _0;
}


/** factory method to be used with RandomForest::learn()
 */
template<class A, class B, class C, class D, class E,
         class F, class G>
detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
    detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
    detail::VisitorNode<G> > > > > > >
create_visitor(A & a, B & b, C & c,
               D & d, E & e, F & f, G & g)
{
    typedef detail::VisitorNode<G> _6_t;
    _6_t _6(g);
    typedef detail::VisitorNode<F, _6_t> _5_t;
    _5_t _5(f, _6);
    typedef detail::VisitorNode<E, _5_t> _4_t;
    _4_t _4(e, _5);
    typedef detail::VisitorNode<D, _4_t> _3_t;
    _3_t _3(d, _4);
    typedef detail::VisitorNode<C, _3_t> _2_t;
    _2_t _2(c, _3);
    typedef detail::VisitorNode<B, _2_t> _1_t;
    _1_t _1(b, _2);
    typedef detail::VisitorNode<A, _1_t> _0_t;
    _0_t _0(a, _1);
    return _0;
}


/** factory method to be used with RandomForest::learn()
 */
template<class A, class B, class C, class D, class E,
         class F, class G, class H>
detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
    detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
    detail::VisitorNode<G, detail::VisitorNode<H> > > > > > > >
create_visitor(A & a, B & b, C & c,
               D & d, E & e, F & f,
               G & g, H & h)
{
    typedef detail::VisitorNode<H> _7_t;
    _7_t _7(h);
    typedef detail::VisitorNode<G, _7_t> _6_t;
    _6_t _6(g, _7);
    typedef detail::VisitorNode<F, _6_t> _5_t;
    _5_t _5(f, _6);
    typedef detail::VisitorNode<E, _5_t> _4_t;
    _4_t _4(e, _5);
    typedef detail::VisitorNode<D, _4_t> _3_t;
    _3_t _3(d, _4);
    typedef detail::VisitorNode<C, _3_t> _2_t;
    _2_t _2(c, _3);
    typedef detail::VisitorNode<B, _2_t> _1_t;
    _1_t _1(b, _2);
    typedef detail::VisitorNode<A, _1_t> _0_t;
    _0_t _0(a, _1);
    return _0;
}


/** factory method to be used with RandomForest::learn()
 */
template<class A, class B, class C, class D, class E,
         class F, class G, class H, class I>
detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
    detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
    detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I> > > > > > > > >
create_visitor(A & a, B & b, C & c,
               D & d, E & e, F & f,
               G & g, H & h, I & i)
{
    typedef detail::VisitorNode<I> _8_t;
    _8_t _8(i);
    typedef detail::VisitorNode<H, _8_t> _7_t;
    _7_t _7(h, _8);
    typedef detail::VisitorNode<G, _7_t> _6_t;
    _6_t _6(g, _7);
    typedef detail::VisitorNode<F, _6_t> _5_t;
    _5_t _5(f, _6);
    typedef detail::VisitorNode<E, _5_t> _4_t;
    _4_t _4(e, _5);
    typedef detail::VisitorNode<D, _4_t> _3_t;
    _3_t _3(d, _4);
    typedef detail::VisitorNode<C, _3_t> _2_t;
    _2_t _2(c, _3);
    typedef detail::VisitorNode<B, _2_t> _1_t;
    _1_t _1(b, _2);
    typedef detail::VisitorNode<A, _1_t> _0_t;
    _0_t _0(a, _1);
    return _0;
}

/** factory method to be used with RandomForest::learn()
 */
template<class A, class B, class C, class D, class E,
         class F, class G, class H, class I, class J>
detail::VisitorNode<A, detail::VisitorNode<B, detail::VisitorNode<C,
    detail::VisitorNode<D, detail::VisitorNode<E, detail::VisitorNode<F,
    detail::VisitorNode<G, detail::VisitorNode<H, detail::VisitorNode<I,
    detail::VisitorNode<J> > > > > > > > > >
create_visitor(A & a, B & b, C & c,
               D & d, E & e, F & f,
               G & g, H & h, I & i,
               J & j)
{
    typedef detail::VisitorNode<J> _9_t;
    _9_t _9(j);
    typedef
detail::VisitorNode<I, _9_t> _8_t;
    _8_t _8(i, _9);
    typedef detail::VisitorNode<H, _8_t> _7_t;
    _7_t _7(h, _8);
    typedef detail::VisitorNode<G, _7_t> _6_t;
    _6_t _6(g, _7);
    typedef detail::VisitorNode<F, _6_t> _5_t;
    _5_t _5(f, _6);
    typedef detail::VisitorNode<E, _5_t> _4_t;
    _4_t _4(e, _5);
    typedef detail::VisitorNode<D, _4_t> _3_t;
    _3_t _3(d, _4);
    typedef detail::VisitorNode<C, _3_t> _2_t;
    _2_t _2(c, _3);
    typedef detail::VisitorNode<B, _2_t> _1_t;
    _1_t _1(b, _2);
    typedef detail::VisitorNode<A, _1_t> _0_t;
    _0_t _0(a, _1);
    return _0;
}

//////////////////////////////////////////////////////////////////////////////
//  Visitors of communal interest.                                          //
//////////////////////////////////////////////////////////////////////////////


/** Visitor to gain information, later needed for online learning.
 */
class OnlineLearnVisitor: public VisitorBase
{
  public:
    // Set if we adjust thresholds
    bool adjust_thresholds;
    // Current tree id
    int tree_id;
    // Last node id for finding parent
    int last_node_id;
    // Need to know the label for interior node visiting
    vigra::Int32 current_label;

    OnlineLearnVisitor():
        adjust_thresholds(false), tree_id(0), last_node_id(0), current_label(0)
    {}

    /** Marginal class distribution at an interior (threshold) node,
     *  together with the empirical "gap" around the split threshold.
     */
    struct MarginalDistribution
    {
        ArrayVector<Int32> leftCounts;
        Int32 leftTotalCounts;
        ArrayVector<Int32> rightCounts;
        Int32 rightTotalCounts;
        double gap_left;
        double gap_right;
    };
    typedef ArrayVector<vigra::Int32> IndexList;

    /** All information for one tree. */
    struct TreeOnlineInformation
    {
        std::vector<MarginalDistribution> mag_distributions;
        std::vector<IndexList> index_lists;
        // map for linear index of mag_distributions
        std::map<int,int> interior_to_index;
        // map for linear index of index_lists
        std::map<int,int> exterior_to_index;
    };

    // per-tree state, indexed by tree id
    std::vector<TreeOnlineInformation> trees_online_information;

    /** Initialize, set the number of trees
     */
    template<class RF,class PR>
    void visit_at_beginning(RF & rf,const PR & pr)
    {
        tree_id=0;
        trees_online_information.resize(rf.options_.tree_count_);
    }

    /** Reset a tree
     */
    void reset_tree(int tree_id)
    {
        trees_online_information[tree_id].mag_distributions.clear();
        trees_online_information[tree_id].index_lists.clear();
        trees_online_information[tree_id].interior_to_index.clear();
        trees_online_information[tree_id].exterior_to_index.clear();
    }

    /** simply increase the tree count
     */
    template<class RF, class PR, class SM, class ST>
    void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
    {
        tree_id++;
    }

    /** Record, for the node the split just created, either the marginal
     *  class distribution (threshold nodes, when adjust_thresholds is set)
     *  or the list of sample indices (leaf nodes).
     */
    template<class Tree, class Split, class Region, class Feature_t, class Label_t>
    void visit_after_split( Tree & tree,
                            Split & split,
                            Region & parent,
                            Region & leftChild,
                            Region & rightChild,
                            Feature_t & features,
                            Label_t & labels)
    {
        int linear_index;
        // address of the new node = current end of the topology array
        int addr=tree.topology_.size();
        if(split.createNode().typeID() == i_ThresholdNode)
        {
            if(adjust_thresholds)
            {
                //Store marginal distribution
                linear_index=trees_online_information[tree_id].mag_distributions.size();
                trees_online_information[tree_id].interior_to_index[addr]=linear_index;
                trees_online_information[tree_id].mag_distributions.push_back(MarginalDistribution());

                trees_online_information[tree_id].mag_distributions.back().leftCounts=leftChild.classCounts_;
                trees_online_information[tree_id].mag_distributions.back().rightCounts=rightChild.classCounts_;

                trees_online_information[tree_id].mag_distributions.back().leftTotalCounts=leftChild.size_;
                trees_online_information[tree_id].mag_distributions.back().rightTotalCounts=rightChild.size_;
                //Store the gap
                // NOTE(review): gap_left is the *maximum* split-column value
                // in the left child, gap_right the *minimum* in the right
                // child, i.e. the empirical margin around the threshold.
                double gap_left,gap_right;
                int i;
                gap_left=features(leftChild[0],split.bestSplitColumn());
                for(i=1;i<leftChild.size();++i)
                    if(features(leftChild[i],split.bestSplitColumn())>gap_left)
                        gap_left=features(leftChild[i],split.bestSplitColumn());
                gap_right=features(rightChild[0],split.bestSplitColumn());
                for(i=1;i<rightChild.size();++i)
                    if(features(rightChild[i],split.bestSplitColumn())<gap_right)
                        gap_right=features(rightChild[i],split.bestSplitColumn());
                trees_online_information[tree_id].mag_distributions.back().gap_left=gap_left;
                trees_online_information[tree_id].mag_distributions.back().gap_right=gap_right;
            }
        }
        else
        {
            //Store index list
            linear_index=trees_online_information[tree_id].index_lists.size();
            trees_online_information[tree_id].exterior_to_index[addr]=linear_index;

            trees_online_information[tree_id].index_lists.push_back(IndexList());

            trees_online_information[tree_id].index_lists.back().resize(parent.size_,0);
            std::copy(parent.begin_,parent.end_,trees_online_information[tree_id].index_lists.back().begin());
        }
    }

    /** Append a sample index to the index list of an exterior node. */
    void add_to_index_list(int tree,int node,int index)
    {
        if(!this->active_)
            return;
        TreeOnlineInformation &ti=trees_online_information[tree];
        ti.index_lists[ti.exterior_to_index[node]].push_back(index);
    }

    /** Transfer exterior-node bookkeeping from (src_tree,src_index)
     *  to (dst_tree,dst_index).
     */
    void move_exterior_node(int src_tree,int src_index,int dst_tree,int dst_index)
    {
        if(!this->active_)
            return;
        trees_online_information[dst_tree].exterior_to_index[dst_index]=trees_online_information[src_tree].exterior_to_index[src_index];
        trees_online_information[src_tree].exterior_to_index.erase(src_index);
} 00707 /** do something when visiting a internal node during getToLeaf 00708 * 00709 * remember as last node id, for finding the parent of the last external node 00710 * also: adjust class counts and borders 00711 */ 00712 template<class TR, class IntT, class TopT,class Feat> 00713 void visit_internal_node(TR & tr, IntT index, TopT node_t,Feat & features) 00714 { 00715 last_node_id=index; 00716 if(adjust_thresholds) 00717 { 00718 vigra_assert(node_t==i_ThresholdNode,"We can only visit threshold nodes"); 00719 //Check if we are in the gap 00720 double value=features(0, Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).column()); 00721 TreeOnlineInformation &ti=trees_online_information[tree_id]; 00722 MarginalDistribution &m=ti.mag_distributions[ti.interior_to_index[index]]; 00723 if(value>m.gap_left && value<m.gap_right) 00724 { 00725 //Check which site we want to go 00726 if(m.leftCounts[current_label]/double(m.leftTotalCounts)>m.rightCounts[current_label]/double(m.rightTotalCounts)) 00727 { 00728 //We want to go left 00729 m.gap_left=value; 00730 } 00731 else 00732 { 00733 //We want to go right 00734 m.gap_right=value; 00735 } 00736 Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()=(m.gap_right+m.gap_left)/2.0; 00737 } 00738 //Adjust class counts 00739 if(value>Node<i_ThresholdNode>(tr.topology_,tr.parameters_,index).threshold()) 00740 { 00741 ++m.rightTotalCounts; 00742 ++m.rightCounts[current_label]; 00743 } 00744 else 00745 { 00746 ++m.leftTotalCounts; 00747 ++m.rightCounts[current_label]; 00748 } 00749 } 00750 } 00751 /** do something when visiting a extern node during getToLeaf 00752 * 00753 * Store the new index! 
 */
};

//////////////////////////////////////////////////////////////////////////////
//  Out of Bag Error estimates                                              //
//////////////////////////////////////////////////////////////////////////////


/** Visitor that calculates the oob error of each individual randomized
 * decision tree.
 *
 * After training a tree, all those samples that are OOB for this particular tree
 * are put down the tree and the error estimated.
 * The per tree oob error is the average of the individual error estimates.
 * (oobError = average error of one randomized tree)
 * Note: This is Not the OOB - Error estimate suggested by Breiman (See OOB_Error
 * visitor)
 */
class OOB_PerTreeError:public VisitorBase
{
  public:
    /** Average error of one randomized decision tree
     */
    double oobError;

    // number of samples that were OOB for at least one tree
    int totalOobCount;
    // per sample: #trees for which it was OOB / #misclassifications while OOB
    ArrayVector<int> oobCount,oobErrorCount;

    OOB_PerTreeError()
    : oobError(0.0),
      totalOobCount(0)
    {}


    bool has_value()
    {
        return true;
    }


    /** does the basic calculation per tree*/
    template<class RF, class PR, class SM, class ST>
    void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
    {
        //do the first time called.
        if(int(oobCount.size()) != rf.ext_param_.row_count_)
        {
            oobCount.resize(rf.ext_param_.row_count_, 0);
            oobErrorCount.resize(rf.ext_param_.row_count_, 0);
        }
        // go through the samples
        for(int l = 0; l < rf.ext_param_.row_count_; ++l)
        {
            // if the lth sample is oob...
            if(!sm.is_used()[l])
            {
                ++oobCount[l];
                if( rf.tree(index)
                        .predictLabel(rowVector(pr.features(), l))
                    != pr.response()(l,0))
                {
                    ++oobErrorCount[l];
                }
            }

        }
    }

    /** Does the normalisation
     */
    template<class RF, class PR>
    void visit_at_end(RF & rf, PR & pr)
    {
        // do some normalisation
        for(int l=0; l < (int)rf.ext_param_.row_count_; ++l)
        {
            if(oobCount[l])
            {
                oobError += double(oobErrorCount[l]) / oobCount[l];
                ++totalOobCount;
            }
        }
        // NOTE(review): if no sample was ever OOB, totalOobCount is 0 and this
        // division is undefined — confirm callers always produce OOB samples.
        oobError/=totalOobCount;
    }

};

/** Visitor that calculates the oob error of the ensemble
 *  This rate should be used to estimate the crossvalidation
 *  error rate.
 *  Here each sample is put down those trees, for which this sample
 *  is OOB i.e. if sample #1 is OOB for trees 1, 3 and 5 we calculate
 *  the output using the ensemble consisting only of trees 1 3 and 5.
 *
 *  Using normal bagged sampling each sample is OOB for approx. 33% of trees
 *  The error rate obtained as such therefore corresponds to crossvalidation
 *  rate obtained using a ensemble containing 33% of the trees.
 */
class OOB_Error : public VisitorBase
{
    typedef MultiArrayShape<2>::type Shp;
    int class_count;
    bool is_weighted;
    // scratch buffer for the per-leaf class probabilities of one sample
    MultiArray<2,double> tmp_prob;
  public:

    // accumulated (possibly weighted) OOB class probabilities per sample
    MultiArray<2, double> prob_oob;
    /** Ensemble oob error rate
     */
    double oob_breiman;

    // per sample: number of trees for which the sample was OOB
    MultiArray<2, double> oobCount;
    ArrayVector< int> indices;

    OOB_Error() : VisitorBase(), oob_breiman(0.0) {}

#ifdef HasHDF5
    /** save the Breiman OOB error rate to an HDF5 file */
    void save(std::string filen, std::string pathn)
    {
        if(*(pathn.end()-1) != '/')
            pathn += "/";
        const char* filename = filen.c_str();
        MultiArray<2, double> temp(Shp(1,1), 0.0);
        temp[0] = oob_breiman;
        writeHDF5(filename, (pathn + "breiman_error").c_str(), temp);
    }
#endif
    // negative value if sample was ib, number indicates how often.
    //  value >=0  if sample was oob, 0 means fail 1, correct

    template<class RF, class PR>
    void visit_at_beginning(RF & rf, PR & pr)
    {
        class_count = rf.class_count();
        tmp_prob.reshape(Shp(1, class_count), 0);
        prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0);
        is_weighted = rf.options().predict_weighted_;
        indices.resize(rf.ext_param().row_count_);
        if(int(oobCount.size()) != rf.ext_param_.row_count_)
        {
            oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0);
        }
        for(int ii = 0; ii < rf.ext_param().row_count_; ++ii)
        {
            indices[ii] = ii;
        }
    }

    template<class RF, class PR, class SM, class ST>
    void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index)
    {
        // go through the samples
        int total_oob =0;
        // FIXME: magic number 10000: invoke special treatment when msample << sample_count
        //        (i.e. the OOB sample is very large)
        //        40000: use at most 40000 OOB samples per class for OOB error estimate
        if(rf.ext_param_.actual_msample_ < pr.features().shape(0) - 10000)
        {
            ArrayVector<int> oob_indices;
            ArrayVector<int> cts(class_count, 0);
            // NOTE(review): std::random_shuffle was deprecated in C++14 and
            // removed in C++17; this should eventually migrate to std::shuffle
            // with an explicit random engine.
            std::random_shuffle(indices.begin(), indices.end());
            for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii)
            {
                if(!sm.is_used()[indices[ii]] && cts[pr.response()(indices[ii], 0)] < 40000)
                {
                    oob_indices.push_back(indices[ii]);
                    ++cts[pr.response()(indices[ii], 0)];
                }
            }
            for(unsigned int ll = 0; ll < oob_indices.size(); ++ll)
            {
                // update number of trees in which current sample is oob
                ++oobCount[oob_indices[ll]];

                // update number of oob samples in this tree.
                ++total_oob;
                // get the predicted votes ---> tmp_prob;
                int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),oob_indices[ll]));
                Node<e_ConstProbNode> node ( rf.tree(index).topology_,
                                             rf.tree(index).parameters_,
                                             pos);
                tmp_prob.init(0);
                for(int ii = 0; ii < class_count; ++ii)
                {
                    tmp_prob[ii] = node.prob_begin()[ii];
                }
                if(is_weighted)
                {
                    for(int ii = 0; ii < class_count; ++ii)
                        tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
                }
                rowVector(prob_oob, oob_indices[ll]) += tmp_prob;

            }
        }else
        {
            for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll)
            {
                // if the lth sample is oob...
                if(!sm.is_used()[ll])
                {
                    // update number of trees in which current sample is oob
                    ++oobCount[ll];

                    // update number of oob samples in this tree.
                    ++total_oob;
                    // get the predicted votes ---> tmp_prob;
                    int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll));
                    Node<e_ConstProbNode> node ( rf.tree(index).topology_,
                                                 rf.tree(index).parameters_,
                                                 pos);
                    tmp_prob.init(0);
                    for(int ii = 0; ii < class_count; ++ii)
                    {
                        tmp_prob[ii] = node.prob_begin()[ii];
                    }
                    if(is_weighted)
                    {
                        for(int ii = 0; ii < class_count; ++ii)
                            tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1));
                    }
                    rowVector(prob_oob, ll) += tmp_prob;
                }
            }
        }
        // go through the ib samples;
    }

    /** Normalise variable importance after the number of trees is known.
     */
    template<class RF, class PR>
    void visit_at_end(RF & rf, PR & pr)
    {
        // ullis original metric and breiman style stuff
        int totalOobCount =0;
        int breimanstyle = 0;
        for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll)
        {
            if(oobCount[ll])
            {
                if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0))
                    ++breimanstyle;
                ++totalOobCount;
            }
        }
        // NOTE(review): division is undefined if no sample was ever OOB.
        oob_breiman = double(breimanstyle)/totalOobCount;
    }
};


/** Visitor that calculates different OOB error statistics
 */
class CompleteOOBInfo : public VisitorBase
{
    typedef MultiArrayShape<2>::type Shp;
    int class_count;
    bool is_weighted;
    MultiArray<2,double> tmp_prob;
  public:

    /** OOB Error rate of each individual tree
     */
    MultiArray<2, double> oob_per_tree;
    /** Mean of oob_per_tree
     */
    double oob_mean;
    /** Standard deviation of oob_per_tree
     */
    double oob_std;

    MultiArray<2, double> prob_oob;
    /** Ensemble OOB error
     *
     * \sa OOB_Error
     */
    double oob_breiman;

    MultiArray<2, double> oobCount;
    MultiArray<2, double> oobErrorCount;
    /** Per Tree OOB
error calculated as in OOB_PerTreeError 01032 * (Ulli's version) 01033 */ 01034 double oob_per_tree2; 01035 01036 /**Column containing the development of the Ensemble 01037 * error rate with increasing number of trees 01038 */ 01039 MultiArray<2, double> breiman_per_tree; 01040 /** 4 dimensional array containing the development of confusion matrices 01041 * with number of trees - can be used to estimate ROC curves etc. 01042 * 01043 * oobroc_per_tree(ii,jj,kk,ll) 01044 * corresponds true label = ii 01045 * predicted label = jj 01046 * confusion matrix after ll trees 01047 * 01048 * explanation of third index: 01049 * 01050 * Two class case: 01051 * kk = 0 - (treeCount-1) 01052 * Threshold is on Probability for class 0 is kk/(treeCount-1); 01053 * More classes: 01054 * kk = 0. Threshold on probability set by argMax of the probability array. 01055 */ 01056 MultiArray<4, double> oobroc_per_tree; 01057 01058 CompleteOOBInfo() : VisitorBase(), oob_mean(0), oob_std(0), oob_per_tree2(0) {} 01059 01060 #ifdef HasHDF5 01061 /** save to HDF5 file 01062 */ 01063 void save(std::string filen, std::string pathn) 01064 { 01065 if(*(pathn.end()-1) != '/') 01066 pathn += "/"; 01067 const char* filename = filen.c_str(); 01068 MultiArray<2, double> temp(Shp(1,1), 0.0); 01069 writeHDF5(filename, (pathn + "oob_per_tree").c_str(), oob_per_tree); 01070 writeHDF5(filename, (pathn + "oobroc_per_tree").c_str(), oobroc_per_tree); 01071 writeHDF5(filename, (pathn + "breiman_per_tree").c_str(), breiman_per_tree); 01072 temp[0] = oob_mean; 01073 writeHDF5(filename, (pathn + "per_tree_error").c_str(), temp); 01074 temp[0] = oob_std; 01075 writeHDF5(filename, (pathn + "per_tree_error_std").c_str(), temp); 01076 temp[0] = oob_breiman; 01077 writeHDF5(filename, (pathn + "breiman_error").c_str(), temp); 01078 temp[0] = oob_per_tree2; 01079 writeHDF5(filename, (pathn + "ulli_error").c_str(), temp); 01080 } 01081 #endif 01082 // negative value if sample was ib, number indicates how often. 
01083 // value >=0 if sample was oob, 0 means fail 1, correct 01084 01085 template<class RF, class PR> 01086 void visit_at_beginning(RF & rf, PR & pr) 01087 { 01088 class_count = rf.class_count(); 01089 if(class_count == 2) 01090 oobroc_per_tree.reshape(MultiArrayShape<4>::type(2,2,rf.tree_count(), rf.tree_count())); 01091 else 01092 oobroc_per_tree.reshape(MultiArrayShape<4>::type(rf.class_count(),rf.class_count(),1, rf.tree_count())); 01093 tmp_prob.reshape(Shp(1, class_count), 0); 01094 prob_oob.reshape(Shp(rf.ext_param().row_count_,class_count), 0); 01095 is_weighted = rf.options().predict_weighted_; 01096 oob_per_tree.reshape(Shp(1, rf.tree_count()), 0); 01097 breiman_per_tree.reshape(Shp(1, rf.tree_count()), 0); 01098 //do the first time called. 01099 if(int(oobCount.size()) != rf.ext_param_.row_count_) 01100 { 01101 oobCount.reshape(Shp(rf.ext_param_.row_count_, 1), 0); 01102 oobErrorCount.reshape(Shp(rf.ext_param_.row_count_,1), 0); 01103 } 01104 } 01105 01106 template<class RF, class PR, class SM, class ST> 01107 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 01108 { 01109 // go through the samples 01110 int total_oob =0; 01111 int wrong_oob =0; 01112 for(int ll = 0; ll < rf.ext_param_.row_count_; ++ll) 01113 { 01114 // if the lth sample is oob... 01115 if(!sm.is_used()[ll]) 01116 { 01117 // update number of trees in which current sample is oob 01118 ++oobCount[ll]; 01119 01120 // update number of oob samples in this tree. 
01121 ++total_oob; 01122 // get the predicted votes ---> tmp_prob; 01123 int pos = rf.tree(index).getToLeaf(rowVector(pr.features(),ll)); 01124 Node<e_ConstProbNode> node ( rf.tree(index).topology_, 01125 rf.tree(index).parameters_, 01126 pos); 01127 tmp_prob.init(0); 01128 for(int ii = 0; ii < class_count; ++ii) 01129 { 01130 tmp_prob[ii] = node.prob_begin()[ii]; 01131 } 01132 if(is_weighted) 01133 { 01134 for(int ii = 0; ii < class_count; ++ii) 01135 tmp_prob[ii] = tmp_prob[ii] * (*(node.prob_begin()-1)); 01136 } 01137 rowVector(prob_oob, ll) += tmp_prob; 01138 int label = argMax(tmp_prob); 01139 01140 if(label != pr.response()(ll, 0)) 01141 { 01142 // update number of wrong oob samples in this tree. 01143 ++wrong_oob; 01144 // update number of trees in which current sample is wrong oob 01145 ++oobErrorCount[ll]; 01146 } 01147 } 01148 } 01149 int breimanstyle = 0; 01150 int totalOobCount = 0; 01151 for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll) 01152 { 01153 if(oobCount[ll]) 01154 { 01155 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0)) 01156 ++breimanstyle; 01157 ++totalOobCount; 01158 if(oobroc_per_tree.shape(2) == 1) 01159 { 01160 oobroc_per_tree(pr.response()(ll,0), argMax(rowVector(prob_oob, ll)),0 ,index)++; 01161 } 01162 } 01163 } 01164 if(oobroc_per_tree.shape(2) == 1) 01165 oobroc_per_tree.bindOuter(index)/=totalOobCount; 01166 if(oobroc_per_tree.shape(2) > 1) 01167 { 01168 MultiArrayView<3, double> current_roc 01169 = oobroc_per_tree.bindOuter(index); 01170 for(int gg = 0; gg < current_roc.shape(2); ++gg) 01171 { 01172 for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll) 01173 { 01174 if(oobCount[ll]) 01175 { 01176 int pred = prob_oob(ll, 1) > (double(gg)/double(current_roc.shape(2)))? 
01177 1 : 0; 01178 current_roc(pr.response()(ll, 0), pred, gg)+= 1; 01179 } 01180 } 01181 current_roc.bindOuter(gg)/= totalOobCount; 01182 } 01183 } 01184 breiman_per_tree[index] = double(breimanstyle)/double(totalOobCount); 01185 oob_per_tree[index] = double(wrong_oob)/double(total_oob); 01186 // go through the ib samples; 01187 } 01188 01189 /** Normalise variable importance after the number of trees is known. 01190 */ 01191 template<class RF, class PR> 01192 void visit_at_end(RF & rf, PR & pr) 01193 { 01194 // ullis original metric and breiman style stuff 01195 oob_per_tree2 = 0; 01196 int totalOobCount =0; 01197 int breimanstyle = 0; 01198 for(int ll=0; ll < (int)rf.ext_param_.row_count_; ++ll) 01199 { 01200 if(oobCount[ll]) 01201 { 01202 if(argMax(rowVector(prob_oob, ll)) != pr.response()(ll, 0)) 01203 ++breimanstyle; 01204 oob_per_tree2 += double(oobErrorCount[ll]) / oobCount[ll]; 01205 ++totalOobCount; 01206 } 01207 } 01208 oob_per_tree2 /= totalOobCount; 01209 oob_breiman = double(breimanstyle)/totalOobCount; 01210 // mean error of each tree 01211 MultiArrayView<2, double> mean(Shp(1,1), &oob_mean); 01212 MultiArrayView<2, double> stdDev(Shp(1,1), &oob_std); 01213 rowStatistics(oob_per_tree, mean, stdDev); 01214 } 01215 }; 01216 01217 /** calculate variable importance while learning. 01218 */ 01219 class VariableImportanceVisitor : public VisitorBase 01220 { 01221 public: 01222 01223 /** This Array has the same entries as the R - random forest variable 01224 * importance. 01225 * Matrix is featureCount by (classCount +2) 01226 * variable_importance_(ii,jj) is the variable importance measure of 01227 * the ii-th variable according to: 01228 * jj = 0 - (classCount-1) 01229 * classwise permutation importance 01230 * jj = rowCount(variable_importance_) -2 01231 * permutation importance 01232 * jj = rowCount(variable_importance_) -1 01233 * gini decrease importance. 
01234 * 01235 * permutation importance: 01236 * The difference between the fraction of OOB samples classified correctly 01237 * before and after permuting (randomizing) the ii-th column is calculated. 01238 * The ii-th column is permuted rep_cnt times. 01239 * 01240 * class wise permutation importance: 01241 * same as permutation importance. We only look at those OOB samples whose 01242 * response corresponds to class jj. 01243 * 01244 * gini decrease importance: 01245 * row ii corresponds to the sum of all gini decreases induced by variable ii 01246 * in each node of the random forest. 01247 */ 01248 MultiArray<2, double> variable_importance_; 01249 int repetition_count_; 01250 bool in_place_; 01251 01252 #ifdef HasHDF5 01253 void save(std::string filename, std::string prefix) 01254 { 01255 prefix = "variable_importance_" + prefix; 01256 writeHDF5(filename.c_str(), 01257 prefix.c_str(), 01258 variable_importance_); 01259 } 01260 #endif 01261 01262 /** Constructor 01263 * \param rep_cnt (defautl: 10) how often should 01264 * the permutation take place. Set to 1 to make calculation faster (but 01265 * possibly more instable) 01266 */ 01267 VariableImportanceVisitor(int rep_cnt = 10) 01268 : repetition_count_(rep_cnt) 01269 01270 {} 01271 01272 /** calculates impurity decrease based variable importance after every 01273 * split. 
01274 */ 01275 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 01276 void visit_after_split( Tree & tree, 01277 Split & split, 01278 Region & parent, 01279 Region & leftChild, 01280 Region & rightChild, 01281 Feature_t & features, 01282 Label_t & labels) 01283 { 01284 //resize to right size when called the first time 01285 01286 Int32 const class_count = tree.ext_param_.class_count_; 01287 Int32 const column_count = tree.ext_param_.column_count_; 01288 if(variable_importance_.size() == 0) 01289 { 01290 01291 variable_importance_ 01292 .reshape(MultiArrayShape<2>::type(column_count, 01293 class_count+2)); 01294 } 01295 01296 if(split.createNode().typeID() == i_ThresholdNode) 01297 { 01298 Node<i_ThresholdNode> node(split.createNode()); 01299 variable_importance_(node.column(),class_count+1) 01300 += split.region_gini_ - split.minGini(); 01301 } 01302 } 01303 01304 /**compute permutation based var imp. 01305 * (Only an Array of size oob_sample_count x 1 is created. 01306 * - apposed to oob_sample_count x feature_count in the other method. 01307 * 01308 * \sa FieldProxy 01309 */ 01310 template<class RF, class PR, class SM, class ST> 01311 void after_tree_ip_impl(RF& rf, PR & pr, SM & sm, ST & st, int index) 01312 { 01313 typedef MultiArrayShape<2>::type Shp_t; 01314 Int32 column_count = rf.ext_param_.column_count_; 01315 Int32 class_count = rf.ext_param_.class_count_; 01316 01317 /* This solution saves memory uptake but not multithreading 01318 * compatible 01319 */ 01320 // remove the const cast on the features (yep , I know what I am 01321 // doing here.) data is not destroyed. 01322 //typename PR::Feature_t & features 01323 // = const_cast<typename PR::Feature_t &>(pr.features()); 01324 01325 typedef typename PR::FeatureWithMemory_t FeatureArray; 01326 typedef typename FeatureArray::value_type FeatureValue; 01327 01328 FeatureArray features = pr.features(); 01329 01330 //find the oob indices of current tree. 
01331 ArrayVector<Int32> oob_indices; 01332 ArrayVector<Int32>::iterator 01333 iter; 01334 for(int ii = 0; ii < rf.ext_param_.row_count_; ++ii) 01335 if(!sm.is_used()[ii]) 01336 oob_indices.push_back(ii); 01337 01338 //create space to back up a column 01339 ArrayVector<FeatureValue> backup_column; 01340 01341 // Random foo 01342 #ifdef CLASSIFIER_TEST 01343 RandomMT19937 random(1); 01344 #else 01345 RandomMT19937 random(RandomSeed); 01346 #endif 01347 UniformIntRandomFunctor<RandomMT19937> 01348 randint(random); 01349 01350 01351 //make some space for the results 01352 MultiArray<2, double> 01353 oob_right(Shp_t(1, class_count + 1)); 01354 MultiArray<2, double> 01355 perm_oob_right (Shp_t(1, class_count + 1)); 01356 01357 01358 // get the oob success rate with the original samples 01359 for(iter = oob_indices.begin(); 01360 iter != oob_indices.end(); 01361 ++iter) 01362 { 01363 if(rf.tree(index) 01364 .predictLabel(rowVector(features, *iter)) 01365 == pr.response()(*iter, 0)) 01366 { 01367 //per class 01368 ++oob_right[pr.response()(*iter,0)]; 01369 //total 01370 ++oob_right[class_count]; 01371 } 01372 } 01373 //get the oob rate after permuting the ii'th dimension. 01374 for(int ii = 0; ii < column_count; ++ii) 01375 { 01376 perm_oob_right.init(0.0); 01377 //make backup of original column 01378 backup_column.clear(); 01379 for(iter = oob_indices.begin(); 01380 iter != oob_indices.end(); 01381 ++iter) 01382 { 01383 backup_column.push_back(features(*iter,ii)); 01384 } 01385 01386 //get the oob rate after permuting the ii'th dimension. 01387 for(int rr = 0; rr < repetition_count_; ++rr) 01388 { 01389 //permute dimension. 
01390 int n = oob_indices.size(); 01391 for(int jj = 1; jj < n; ++jj) 01392 std::swap(features(oob_indices[jj], ii), 01393 features(oob_indices[randint(jj+1)], ii)); 01394 01395 //get the oob success rate after permuting 01396 for(iter = oob_indices.begin(); 01397 iter != oob_indices.end(); 01398 ++iter) 01399 { 01400 if(rf.tree(index) 01401 .predictLabel(rowVector(features, *iter)) 01402 == pr.response()(*iter, 0)) 01403 { 01404 //per class 01405 ++perm_oob_right[pr.response()(*iter, 0)]; 01406 //total 01407 ++perm_oob_right[class_count]; 01408 } 01409 } 01410 } 01411 01412 01413 //normalise and add to the variable_importance array. 01414 perm_oob_right /= repetition_count_; 01415 perm_oob_right -=oob_right; 01416 perm_oob_right *= -1; 01417 perm_oob_right /= oob_indices.size(); 01418 variable_importance_ 01419 .subarray(Shp_t(ii,0), 01420 Shp_t(ii+1,class_count+1)) += perm_oob_right; 01421 //copy back permuted dimension 01422 for(int jj = 0; jj < int(oob_indices.size()); ++jj) 01423 features(oob_indices[jj], ii) = backup_column[jj]; 01424 } 01425 } 01426 01427 /** calculate permutation based impurity after every tree has been 01428 * learned default behaviour is that this happens out of place. 01429 * If you have very big data sets and want to avoid copying of data 01430 * set the in_place_ flag to true. 01431 */ 01432 template<class RF, class PR, class SM, class ST> 01433 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index) 01434 { 01435 after_tree_ip_impl(rf, pr, sm, st, index); 01436 } 01437 01438 /** Normalise variable importance after the number of trees is known. 
01439 */ 01440 template<class RF, class PR> 01441 void visit_at_end(RF & rf, PR & pr) 01442 { 01443 variable_importance_ /= rf.trees_.size(); 01444 } 01445 }; 01446 01447 /** Verbose output 01448 */ 01449 class RandomForestProgressVisitor : public VisitorBase { 01450 public: 01451 RandomForestProgressVisitor() : VisitorBase() {} 01452 01453 template<class RF, class PR, class SM, class ST> 01454 void visit_after_tree(RF& rf, PR & pr, SM & sm, ST & st, int index){ 01455 if(index != rf.options().tree_count_-1) { 01456 std::cout << "\r[" << std::setw(10) << (index+1)/static_cast<double>(rf.options().tree_count_)*100 << "%]" 01457 << " (" << index+1 << " of " << rf.options().tree_count_ << ") done" << std::flush; 01458 } 01459 else { 01460 std::cout << "\r[" << std::setw(10) << 100.0 << "%]" << std::endl; 01461 } 01462 } 01463 01464 template<class RF, class PR> 01465 void visit_at_end(RF const & rf, PR const & pr) { 01466 std::string a = TOCS; 01467 std::cout << "all " << rf.options().tree_count_ << " trees have been learned in " << a << std::endl; 01468 } 01469 01470 template<class RF, class PR> 01471 void visit_at_beginning(RF const & rf, PR const & pr) { 01472 TIC; 01473 std::cout << "growing random forest, which will have " << rf.options().tree_count_ << " trees" << std::endl; 01474 } 01475 01476 private: 01477 USETICTOC; 01478 }; 01479 01480 01481 /** Computes Correlation/Similarity Matrix of features while learning 01482 * random forest. 01483 */ 01484 class CorrelationVisitor : public VisitorBase 01485 { 01486 public: 01487 /** gini_missc(ii, jj) describes how well variable jj can describe a partition 01488 * created on variable ii(when variable ii was chosen) 01489 */ 01490 MultiArray<2, double> gini_missc; 01491 MultiArray<2, int> tmp_labels; 01492 /** additional noise features. 01493 */ 01494 MultiArray<2, double> noise; 01495 MultiArray<2, double> noise_l; 01496 /** how well can a noise column describe a partition created on variable ii. 
01497 */ 01498 MultiArray<2, double> corr_noise; 01499 MultiArray<2, double> corr_l; 01500 01501 /** Similarity Matrix 01502 * 01503 * (numberOfFeatures + 1) by (number Of Features + 1) Matrix 01504 * gini_missc 01505 * - row normalized by the number of times the column was chosen 01506 * - mean of corr_noise subtracted 01507 * - and symmetrised. 01508 * 01509 */ 01510 MultiArray<2, double> similarity; 01511 /** Distance Matrix 1-similarity 01512 */ 01513 MultiArray<2, double> distance; 01514 ArrayVector<int> tmp_cc; 01515 01516 /** How often was variable ii chosen 01517 */ 01518 ArrayVector<int> numChoices; 01519 typedef BestGiniOfColumn<GiniCriterion> ColumnDecisionFunctor; 01520 BestGiniOfColumn<GiniCriterion> bgfunc; 01521 void save(std::string file, std::string prefix) 01522 { 01523 /* 01524 std::string tmp; 01525 #define VAR_WRITE(NAME) \ 01526 tmp = #NAME;\ 01527 tmp += "_";\ 01528 tmp += prefix;\ 01529 vigra::writeToHDF5File(file.c_str(), tmp.c_str(), NAME); 01530 VAR_WRITE(gini_missc); 01531 VAR_WRITE(corr_noise); 01532 VAR_WRITE(distance); 01533 VAR_WRITE(similarity); 01534 vigra::writeToHDF5File(file.c_str(), "nChoices", MultiArrayView<2, int>(MultiArrayShape<2>::type(numChoices.size(),1), numChoices.data())); 01535 #undef VAR_WRITE 01536 */ 01537 } 01538 template<class RF, class PR> 01539 void visit_at_beginning(RF const & rf, PR & pr) 01540 { 01541 typedef MultiArrayShape<2>::type Shp; 01542 int n = rf.ext_param_.column_count_; 01543 gini_missc.reshape(Shp(n +1,n+ 1)); 01544 corr_noise.reshape(Shp(n + 1, 10)); 01545 corr_l.reshape(Shp(n +1, 10)); 01546 01547 noise.reshape(Shp(pr.features().shape(0), 10)); 01548 noise_l.reshape(Shp(pr.features().shape(0), 10)); 01549 RandomMT19937 random(RandomSeed); 01550 for(int ii = 0; ii < noise.size(); ++ii) 01551 { 01552 noise[ii] = random.uniform53(); 01553 noise_l[ii] = random.uniform53() > 0.5; 01554 } 01555 bgfunc = ColumnDecisionFunctor( rf.ext_param_); 01556 tmp_labels.reshape(pr.response().shape()); 01557 
tmp_cc.resize(2); 01558 numChoices.resize(n+1); 01559 // look at all axes 01560 } 01561 template<class RF, class PR> 01562 void visit_at_end(RF const & rf, PR const & pr) 01563 { 01564 typedef MultiArrayShape<2>::type Shp; 01565 similarity.reshape(gini_missc.shape()); 01566 similarity = gini_missc;; 01567 MultiArray<2, double> mean_noise(Shp(corr_noise.shape(0), 1)); 01568 rowStatistics(corr_noise, mean_noise); 01569 mean_noise/= MultiArrayView<2, int>(mean_noise.shape(), numChoices.data()); 01570 int rC = similarity.shape(0); 01571 for(int jj = 0; jj < rC-1; ++jj) 01572 { 01573 rowVector(similarity, jj) /= numChoices[jj]; 01574 rowVector(similarity, jj) -= mean_noise(jj, 0); 01575 } 01576 for(int jj = 0; jj < rC; ++jj) 01577 { 01578 similarity(rC -1, jj) /= numChoices[jj]; 01579 } 01580 rowVector(similarity, rC - 1) -= mean_noise(rC-1, 0); 01581 similarity = abs(similarity); 01582 FindMinMax<double> minmax; 01583 inspectMultiArray(srcMultiArrayRange(similarity), minmax); 01584 01585 for(int jj = 0; jj < rC; ++jj) 01586 similarity(jj, jj) = minmax.max; 01587 01588 similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)) 01589 += similarity.subarray(Shp(0,0), Shp(rC-1, rC-1)).transpose(); 01590 similarity.subarray(Shp(0,0), Shp(rC-1, rC-1))/= 2; 01591 columnVector(similarity, rC-1) = rowVector(similarity, rC-1).transpose(); 01592 for(int jj = 0; jj < rC; ++jj) 01593 similarity(jj, jj) = 0; 01594 01595 FindMinMax<double> minmax2; 01596 inspectMultiArray(srcMultiArrayRange(similarity), minmax2); 01597 for(int jj = 0; jj < rC; ++jj) 01598 similarity(jj, jj) = minmax2.max; 01599 distance.reshape(gini_missc.shape(), minmax2.max); 01600 distance -= similarity; 01601 } 01602 01603 template<class Tree, class Split, class Region, class Feature_t, class Label_t> 01604 void visit_after_split( Tree & tree, 01605 Split & split, 01606 Region & parent, 01607 Region & leftChild, 01608 Region & rightChild, 01609 Feature_t & features, 01610 Label_t & labels) 01611 { 01612 
if(split.createNode().typeID() == i_ThresholdNode) 01613 { 01614 double wgini; 01615 tmp_cc.init(0); 01616 for(int ii = 0; ii < parent.size(); ++ii) 01617 { 01618 tmp_labels[parent[ii]] 01619 = (features(parent[ii], split.bestSplitColumn()) < split.bestSplitThreshold()); 01620 ++tmp_cc[tmp_labels[parent[ii]]]; 01621 } 01622 double region_gini = bgfunc.loss_of_region(tmp_labels, 01623 parent.begin(), 01624 parent.end(), 01625 tmp_cc); 01626 01627 int n = split.bestSplitColumn(); 01628 ++numChoices[n]; 01629 ++(*(numChoices.end()-1)); 01630 //this functor does all the work 01631 for(int k = 0; k < features.shape(1); ++k) 01632 { 01633 bgfunc(columnVector(features, k), 01634 tmp_labels, 01635 parent.begin(), parent.end(), 01636 tmp_cc); 01637 wgini = (region_gini - bgfunc.min_gini_); 01638 gini_missc(n, k) 01639 += wgini; 01640 } 01641 for(int k = 0; k < 10; ++k) 01642 { 01643 bgfunc(columnVector(noise, k), 01644 tmp_labels, 01645 parent.begin(), parent.end(), 01646 tmp_cc); 01647 wgini = (region_gini - bgfunc.min_gini_); 01648 corr_noise(n, k) 01649 += wgini; 01650 } 01651 01652 for(int k = 0; k < 10; ++k) 01653 { 01654 bgfunc(columnVector(noise_l, k), 01655 tmp_labels, 01656 parent.begin(), parent.end(), 01657 tmp_cc); 01658 wgini = (region_gini - bgfunc.min_gini_); 01659 corr_l(n, k) 01660 += wgini; 01661 } 01662 bgfunc(labels, tmp_labels, parent.begin(), parent.end(),tmp_cc); 01663 wgini = (region_gini - bgfunc.min_gini_); 01664 gini_missc(n, columnCount(gini_missc)-1) 01665 += wgini; 01666 01667 region_gini = split.region_gini_; 01668 #if 1 01669 Node<i_ThresholdNode> node(split.createNode()); 01670 gini_missc(rowCount(gini_missc)-1, 01671 node.column()) 01672 +=split.region_gini_ - split.minGini(); 01673 #endif 01674 for(int k = 0; k < 10; ++k) 01675 { 01676 split.bgfunc(columnVector(noise, k), 01677 labels, 01678 parent.begin(), parent.end(), 01679 parent.classCounts()); 01680 corr_noise(rowCount(gini_missc)-1, 01681 k) 01682 += wgini; 01683 } 01684 #if 0 01685 
for(int k = 0; k < tree.ext_param_.actual_mtry_; ++k) 01686 { 01687 wgini = region_gini - split.min_gini_[k]; 01688 01689 gini_missc(rowCount(gini_missc)-1, 01690 split.splitColumns[k]) 01691 += wgini; 01692 } 01693 01694 for(int k=tree.ext_param_.actual_mtry_; k<features.shape(1); ++k) 01695 { 01696 split.bgfunc(columnVector(features, split.splitColumns[k]), 01697 labels, 01698 parent.begin(), parent.end(), 01699 parent.classCounts()); 01700 wgini = region_gini - split.bgfunc.min_gini_; 01701 gini_missc(rowCount(gini_missc)-1, 01702 split.splitColumns[k]) += wgini; 01703 } 01704 #endif 01705 // remember to partition the data according to the best. 01706 gini_missc(rowCount(gini_missc)-1, 01707 columnCount(gini_missc)-1) 01708 += region_gini; 01709 SortSamplesByDimensions<Feature_t> 01710 sorter(features, split.bestSplitColumn(), split.bestSplitThreshold()); 01711 std::partition(parent.begin(), parent.end(), sorter); 01712 } 01713 } 01714 }; 01715 01716 01717 } // namespace visitors 01718 } // namespace rf 01719 } // namespace vigra 01720 01721 //@} 01722 #endif // RF_VISITORS_HXX
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|