[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
vigra/random_forest/rf_split.hxx
00001 /************************************************************************/ 00002 /* */ 00003 /* Copyright 2008-2009 by Ullrich Koethe and Rahul Nair */ 00004 /* */ 00005 /* This file is part of the VIGRA computer vision library. */ 00006 /* The VIGRA Website is */ 00007 /* http://hci.iwr.uni-heidelberg.de/vigra/ */ 00008 /* Please direct questions, bug reports, and contributions to */ 00009 /* ullrich.koethe@iwr.uni-heidelberg.de or */ 00010 /* vigra@informatik.uni-hamburg.de */ 00011 /* */ 00012 /* Permission is hereby granted, free of charge, to any person */ 00013 /* obtaining a copy of this software and associated documentation */ 00014 /* files (the "Software"), to deal in the Software without */ 00015 /* restriction, including without limitation the rights to use, */ 00016 /* copy, modify, merge, publish, distribute, sublicense, and/or */ 00017 /* sell copies of the Software, and to permit persons to whom the */ 00018 /* Software is furnished to do so, subject to the following */ 00019 /* conditions: */ 00020 /* */ 00021 /* The above copyright notice and this permission notice shall be */ 00022 /* included in all copies or substantial portions of the */ 00023 /* Software. */ 00024 /* */ 00025 /* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND */ 00026 /* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES */ 00027 /* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND */ 00028 /* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT */ 00029 /* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, */ 00030 /* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING */ 00031 /* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR */ 00032 /* OTHER DEALINGS IN THE SOFTWARE. 
*/ 00033 /* */ 00034 /************************************************************************/ 00035 #ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX 00036 #define VIGRA_RANDOM_FOREST_SPLIT_HXX 00037 #include <algorithm> 00038 #include <cstddef> 00039 #include <map> 00040 #include <numeric> 00041 #include <math.h> 00042 #include "../mathutil.hxx" 00043 #include "../array_vector.hxx" 00044 #include "../sized_int.hxx" 00045 #include "../matrix.hxx" 00046 #include "../random.hxx" 00047 #include "../functorexpression.hxx" 00048 #include "rf_nodeproxy.hxx" 00049 //#include "rf_sampling.hxx" 00050 #include "rf_region.hxx" 00051 //#include "../hokashyap.hxx" 00052 //#include "vigra/rf_helpers.hxx" 00053 00054 namespace vigra 00055 { 00056 00057 // Incomplete Class to ensure that findBestSplit is always implemented in 00058 // the derived classes of SplitBase 00059 class CompileTimeError; 00060 00061 00062 namespace detail 00063 { 00064 template<class Tag> 00065 class Normalise 00066 { 00067 public: 00068 template<class Iter> 00069 static void exec(Iter begin, Iter end) 00070 {} 00071 }; 00072 00073 template<> 00074 class Normalise<ClassificationTag> 00075 { 00076 public: 00077 template<class Iter> 00078 static void exec (Iter begin, Iter end) 00079 { 00080 double bla = std::accumulate(begin, end, 0.0); 00081 for(int ii = 0; ii < end - begin; ++ii) 00082 begin[ii] = begin[ii]/bla ; 00083 } 00084 }; 00085 } 00086 00087 00088 /** Base Class for all SplitFunctors used with the \ref RandomForest class 00089 defines the interface used while learning a tree. 
00090 **/ 00091 template<class Tag> 00092 class SplitBase 00093 { 00094 public: 00095 00096 typedef Tag RF_Tag; 00097 typedef DT_StackEntry<ArrayVectorView<Int32>::iterator> 00098 StackEntry_t; 00099 00100 ProblemSpec<> ext_param_; 00101 00102 NodeBase::T_Container_type t_data; 00103 NodeBase::P_Container_type p_data; 00104 00105 NodeBase node_; 00106 00107 /** returns the DecisionTree Node created by 00108 \ref SplitBase::findBestSplit() or \ref SplitBase::makeTerminalNode(). 00109 **/ 00110 00111 template<class T> 00112 void set_external_parameters(ProblemSpec<T> const & in) 00113 { 00114 ext_param_ = in; 00115 t_data.push_back(in.column_count_); 00116 t_data.push_back(in.class_count_); 00117 } 00118 00119 NodeBase & createNode() 00120 { 00121 return node_; 00122 } 00123 00124 int classCount() const 00125 { 00126 return int(t_data[1]); 00127 } 00128 00129 int featureCount() const 00130 { 00131 return int(t_data[0]); 00132 } 00133 00134 /** resets internal data. Should always be called before 00135 calling findBestSplit or makeTerminalNode 00136 **/ 00137 void reset() 00138 { 00139 t_data.resize(2); 00140 p_data.resize(0); 00141 } 00142 00143 00144 /** findBestSplit has to be re-implemented in derived split functor. 00145 The defaut implementation only insures that a CompileTime error is issued 00146 if no such method was defined. 00147 **/ 00148 00149 template<class T, class C, class T2, class C2, class Region, class Random> 00150 int findBestSplit(MultiArrayView<2, T, C> features, 00151 MultiArrayView<2, T2, C2> labels, 00152 Region region, 00153 ArrayVector<Region> childs, 00154 Random randint) 00155 { 00156 #ifndef __clang__ 00157 // FIXME: This compile-time checking trick does not work for clang. 00158 CompileTimeError SplitFunctor__findBestSplit_member_was_not_defined; 00159 #endif 00160 return 0; 00161 } 00162 00163 /** Default action for creating a terminal Node. 
00164 sets the Class probability of the remaining region according to 00165 the class histogram 00166 **/ 00167 template<class T, class C, class T2,class C2, class Region, class Random> 00168 int makeTerminalNode(MultiArrayView<2, T, C> features, 00169 MultiArrayView<2, T2, C2> labels, 00170 Region & region, 00171 Random randint) 00172 { 00173 Node<e_ConstProbNode> ret(t_data, p_data); 00174 node_ = ret; 00175 if(ext_param_.class_weights_.size() != region.classCounts().size()) 00176 { 00177 std::copy(region.classCounts().begin(), 00178 region.classCounts().end(), 00179 ret.prob_begin()); 00180 } 00181 else 00182 { 00183 std::transform(region.classCounts().begin(), 00184 region.classCounts().end(), 00185 ext_param_.class_weights_.begin(), 00186 ret.prob_begin(), std::multiplies<double>()); 00187 } 00188 detail::Normalise<RF_Tag>::exec(ret.prob_begin(), ret.prob_end()); 00189 // std::copy(ret.prob_begin(), ret.prob_end(), std::ostream_iterator<double>(std::cerr, ", " )); 00190 // std::cerr << std::endl; 00191 ret.weights() = region.size(); 00192 return e_ConstProbNode; 00193 } 00194 00195 00196 }; 00197 00198 /** Functor to sort the indices of a feature Matrix by a certain dimension 00199 **/ 00200 template<class DataMatrix> 00201 class SortSamplesByDimensions 00202 { 00203 DataMatrix const & data_; 00204 MultiArrayIndex sortColumn_; 00205 double thresVal_; 00206 public: 00207 00208 SortSamplesByDimensions(DataMatrix const & data, 00209 MultiArrayIndex sortColumn, 00210 double thresVal = 0.0) 00211 : data_(data), 00212 sortColumn_(sortColumn), 00213 thresVal_(thresVal) 00214 {} 00215 00216 void setColumn(MultiArrayIndex sortColumn) 00217 { 00218 sortColumn_ = sortColumn; 00219 } 00220 void setThreshold(double value) 00221 { 00222 thresVal_ = value; 00223 } 00224 00225 bool operator()(MultiArrayIndex l, MultiArrayIndex r) const 00226 { 00227 return data_(l, sortColumn_) < data_(r, sortColumn_); 00228 } 00229 bool operator()(MultiArrayIndex l) const 00230 { 00231 
return data_(l, sortColumn_) < thresVal_; 00232 } 00233 }; 00234 00235 template<class DataMatrix> 00236 class DimensionNotEqual 00237 { 00238 DataMatrix const & data_; 00239 MultiArrayIndex sortColumn_; 00240 00241 public: 00242 00243 DimensionNotEqual(DataMatrix const & data, 00244 MultiArrayIndex sortColumn) 00245 : data_(data), 00246 sortColumn_(sortColumn) 00247 {} 00248 00249 void setColumn(MultiArrayIndex sortColumn) 00250 { 00251 sortColumn_ = sortColumn; 00252 } 00253 00254 bool operator()(MultiArrayIndex l, MultiArrayIndex r) const 00255 { 00256 return data_(l, sortColumn_) != data_(r, sortColumn_); 00257 } 00258 }; 00259 00260 template<class DataMatrix> 00261 class SortSamplesByHyperplane 00262 { 00263 DataMatrix const & data_; 00264 Node<i_HyperplaneNode> const & node_; 00265 00266 public: 00267 00268 SortSamplesByHyperplane(DataMatrix const & data, 00269 Node<i_HyperplaneNode> const & node) 00270 : 00271 data_(data), 00272 node_(node) 00273 {} 00274 00275 /** calculate the distance of a sample point to a hyperplane 00276 */ 00277 double operator[](MultiArrayIndex l) const 00278 { 00279 double result_l = -1 * node_.intercept(); 00280 for(int ii = 0; ii < node_.columns_size(); ++ii) 00281 { 00282 result_l += rowVector(data_, l)[node_.columns_begin()[ii]] 00283 * node_.weights()[ii]; 00284 } 00285 return result_l; 00286 } 00287 00288 bool operator()(MultiArrayIndex l, MultiArrayIndex r) const 00289 { 00290 return (*this)[l] < (*this)[r]; 00291 } 00292 00293 }; 00294 00295 /** makes a Class Histogram given indices in a labels_ array 00296 * usage: 00297 * MultiArrayView<2, T2, C2> labels = makeSomeLabels() 00298 * ArrayVector<int> hist(numberOfLabels(labels), 0); 00299 * RandomForestClassCounter<T2, C2, ArrayVector> counter(labels, hist); 00300 * 00301 * Container<int> indices = getSomeIndices() 00302 * std::for_each(indices, counter); 00303 */ 00304 template <class DataSource, class CountArray> 00305 class RandomForestClassCounter 00306 { 00307 DataSource 
const & labels_; 00308 CountArray & counts_; 00309 00310 public: 00311 00312 RandomForestClassCounter(DataSource const & labels, 00313 CountArray & counts) 00314 : labels_(labels), 00315 counts_(counts) 00316 { 00317 reset(); 00318 } 00319 00320 void reset() 00321 { 00322 counts_.init(0); 00323 } 00324 00325 void operator()(MultiArrayIndex l) const 00326 { 00327 counts_[labels_[l]] +=1; 00328 } 00329 }; 00330 00331 00332 /** Functor To Calculate the Best possible Split Based on the Gini Index 00333 given Labels and Features along a given Axis 00334 */ 00335 00336 namespace detail 00337 { 00338 template<int N> 00339 class ConstArr 00340 { 00341 public: 00342 double operator[](size_t) const 00343 { 00344 return (double)N; 00345 } 00346 }; 00347 00348 00349 } 00350 00351 00352 00353 00354 /** Functor to calculate the entropy based impurity 00355 */ 00356 class EntropyCriterion 00357 { 00358 public: 00359 /**calculate the weighted gini impurity based on class histogram 00360 * and class weights 00361 */ 00362 template<class Array, class Array2> 00363 double operator() (Array const & hist, 00364 Array2 const & weights, 00365 double total = 1.0) const 00366 { 00367 return impurity(hist, weights, total); 00368 } 00369 00370 /** calculate the gini based impurity based on class histogram 00371 */ 00372 template<class Array> 00373 double operator()(Array const & hist, double total = 1.0) const 00374 { 00375 return impurity(hist, total); 00376 } 00377 00378 /** static version of operator(hist total) 00379 */ 00380 template<class Array> 00381 static double impurity(Array const & hist, double total) 00382 { 00383 return impurity(hist, detail::ConstArr<1>(), total); 00384 } 00385 00386 /** static version of operator(hist, weights, total) 00387 */ 00388 template<class Array, class Array2> 00389 static double impurity (Array const & hist, 00390 Array2 const & weights, 00391 double total) 00392 { 00393 00394 int class_count = hist.size(); 00395 double entropy = 0.0; 00396 
if(class_count == 2) 00397 { 00398 double p0 = (hist[0]/total); 00399 double p1 = (hist[1]/total); 00400 entropy = 0 - weights[0]*p0*std::log(p0) - weights[1]*p1*std::log(p1); 00401 } 00402 else 00403 { 00404 for(int ii = 0; ii < class_count; ++ii) 00405 { 00406 double w = weights[ii]; 00407 double pii = hist[ii]/total; 00408 entropy -= w*( pii*std::log(pii)); 00409 } 00410 } 00411 entropy = total * entropy; 00412 return entropy; 00413 } 00414 }; 00415 00416 /** Functor to calculate the gini impurity 00417 */ 00418 class GiniCriterion 00419 { 00420 public: 00421 /**calculate the weighted gini impurity based on class histogram 00422 * and class weights 00423 */ 00424 template<class Array, class Array2> 00425 double operator() (Array const & hist, 00426 Array2 const & weights, 00427 double total = 1.0) const 00428 { 00429 return impurity(hist, weights, total); 00430 } 00431 00432 /** calculate the gini based impurity based on class histogram 00433 */ 00434 template<class Array> 00435 double operator()(Array const & hist, double total = 1.0) const 00436 { 00437 return impurity(hist, total); 00438 } 00439 00440 /** static version of operator(hist total) 00441 */ 00442 template<class Array> 00443 static double impurity(Array const & hist, double total) 00444 { 00445 return impurity(hist, detail::ConstArr<1>(), total); 00446 } 00447 00448 /** static version of operator(hist, weights, total) 00449 */ 00450 template<class Array, class Array2> 00451 static double impurity (Array const & hist, 00452 Array2 const & weights, 00453 double total) 00454 { 00455 00456 int class_count = hist.size(); 00457 double gini = 0.0; 00458 if(class_count == 2) 00459 { 00460 double w = weights[0] * weights[1]; 00461 gini = w * (hist[0] * hist[1] / total); 00462 } 00463 else 00464 { 00465 for(int ii = 0; ii < class_count; ++ii) 00466 { 00467 double w = weights[ii]; 00468 gini += w*( hist[ii]*( 1.0 - w * hist[ii]/total ) ); 00469 } 00470 } 00471 return gini; 00472 } 00473 }; 00474 00475 00476 
template <class DataSource, class Impurity= GiniCriterion> 00477 class ImpurityLoss 00478 { 00479 00480 DataSource const & labels_; 00481 ArrayVector<double> counts_; 00482 ArrayVector<double> const class_weights_; 00483 double total_counts_; 00484 Impurity impurity_; 00485 00486 public: 00487 00488 template<class T> 00489 ImpurityLoss(DataSource const & labels, 00490 ProblemSpec<T> const & ext_) 00491 : labels_(labels), 00492 counts_(ext_.class_count_, 0.0), 00493 class_weights_(ext_.class_weights_), 00494 total_counts_(0.0) 00495 {} 00496 00497 void reset() 00498 { 00499 counts_.init(0); 00500 total_counts_ = 0.0; 00501 } 00502 00503 template<class Counts> 00504 double increment_histogram(Counts const & counts) 00505 { 00506 std::transform(counts.begin(), counts.end(), 00507 counts_.begin(), counts_.begin(), 00508 std::plus<double>()); 00509 total_counts_ = std::accumulate( counts_.begin(), 00510 counts_.end(), 00511 0.0); 00512 return impurity_(counts_, class_weights_, total_counts_); 00513 } 00514 00515 template<class Counts> 00516 double decrement_histogram(Counts const & counts) 00517 { 00518 std::transform(counts.begin(), counts.end(), 00519 counts_.begin(), counts_.begin(), 00520 std::minus<double>()); 00521 total_counts_ = std::accumulate( counts_.begin(), 00522 counts_.end(), 00523 0.0); 00524 return impurity_(counts_, class_weights_, total_counts_); 00525 } 00526 00527 template<class Iter> 00528 double increment(Iter begin, Iter end) 00529 { 00530 for(Iter iter = begin; iter != end; ++iter) 00531 { 00532 counts_[labels_(*iter, 0)] +=1.0; 00533 total_counts_ +=1.0; 00534 } 00535 return impurity_(counts_, class_weights_, total_counts_); 00536 } 00537 00538 template<class Iter> 00539 double decrement(Iter const & begin, Iter const & end) 00540 { 00541 for(Iter iter = begin; iter != end; ++iter) 00542 { 00543 counts_[labels_(*iter,0)] -=1.0; 00544 total_counts_ -=1.0; 00545 } 00546 return impurity_(counts_, class_weights_, total_counts_); 00547 } 00548 00549 
template<class Iter, class Resp_t> 00550 double init (Iter begin, Iter end, Resp_t resp) 00551 { 00552 reset(); 00553 std::copy(resp.begin(), resp.end(), counts_.begin()); 00554 total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0); 00555 return impurity_(counts_,class_weights_, total_counts_); 00556 } 00557 00558 ArrayVector<double> const & response() 00559 { 00560 return counts_; 00561 } 00562 }; 00563 00564 00565 00566 template <class DataSource> 00567 class RegressionForestCounter 00568 { 00569 public: 00570 typedef MultiArrayShape<2>::type Shp; 00571 DataSource const & labels_; 00572 ArrayVector <double> mean_; 00573 ArrayVector <double> variance_; 00574 ArrayVector <double> tmp_; 00575 size_t count_; 00576 int* end_; 00577 00578 template<class T> 00579 RegressionForestCounter(DataSource const & labels, 00580 ProblemSpec<T> const & ext_) 00581 : 00582 labels_(labels), 00583 mean_(ext_.response_size_, 0.0), 00584 variance_(ext_.response_size_, 0.0), 00585 tmp_(ext_.response_size_), 00586 count_(0) 00587 {} 00588 00589 template<class Iter> 00590 double increment (Iter begin, Iter end) 00591 { 00592 for(Iter iter = begin; iter != end; ++iter) 00593 { 00594 ++count_; 00595 for(unsigned int ii = 0; ii < mean_.size(); ++ii) 00596 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 00597 double f = 1.0 / count_, 00598 f1 = 1.0 - f; 00599 for(unsigned int ii = 0; ii < mean_.size(); ++ii) 00600 mean_[ii] += f*tmp_[ii]; 00601 for(unsigned int ii = 0; ii < mean_.size(); ++ii) 00602 variance_[ii] += f1*sq(tmp_[ii]); 00603 } 00604 double res = std::accumulate(variance_.begin(), 00605 variance_.end(), 00606 0.0, 00607 std::plus<double>()); 00608 //std::cerr << res << " ) = "; 00609 return res; 00610 } 00611 00612 template<class Iter> //This is BROKEN 00613 double decrement (Iter begin, Iter end) 00614 { 00615 for(Iter iter = begin; iter != end; ++iter) 00616 { 00617 --count_; 00618 } 00619 00620 begin = end; 00621 end = end + count_; 00622 00623 00624 for(unsigned int 
ii = 0; ii < mean_.size(); ++ii) 00625 { 00626 mean_[ii] = 0; 00627 for(Iter iter = begin; iter != end; ++iter) 00628 { 00629 mean_[ii] += labels_(*iter, ii); 00630 } 00631 mean_[ii] /= count_; 00632 variance_[ii] = 0; 00633 for(Iter iter = begin; iter != end; ++iter) 00634 { 00635 variance_[ii] += (labels_(*iter, ii) - mean_[ii])*(labels_(*iter, ii) - mean_[ii]); 00636 } 00637 } 00638 double res = std::accumulate(variance_.begin(), 00639 variance_.end(), 00640 0.0, 00641 std::plus<double>()); 00642 //std::cerr << res << " ) = "; 00643 return res; 00644 } 00645 00646 00647 template<class Iter, class Resp_t> 00648 double init (Iter begin, Iter end, Resp_t resp) 00649 { 00650 reset(); 00651 return this->increment(begin, end); 00652 00653 } 00654 00655 00656 ArrayVector<double> const & response() 00657 { 00658 return mean_; 00659 } 00660 00661 void reset() 00662 { 00663 mean_.init(0.0); 00664 variance_.init(0.0); 00665 count_ = 0; 00666 } 00667 }; 00668 00669 00670 template <class DataSource> 00671 class RegressionForestCounter2 00672 { 00673 public: 00674 typedef MultiArrayShape<2>::type Shp; 00675 DataSource const & labels_; 00676 ArrayVector <double> mean_; 00677 ArrayVector <double> variance_; 00678 ArrayVector <double> tmp_; 00679 size_t count_; 00680 00681 template<class T> 00682 RegressionForestCounter2(DataSource const & labels, 00683 ProblemSpec<T> const & ext_) 00684 : 00685 labels_(labels), 00686 mean_(ext_.response_size_, 0.0), 00687 variance_(ext_.response_size_, 0.0), 00688 tmp_(ext_.response_size_), 00689 count_(0) 00690 {} 00691 00692 template<class Iter> 00693 double increment (Iter begin, Iter end) 00694 { 00695 for(Iter iter = begin; iter != end; ++iter) 00696 { 00697 ++count_; 00698 for(int ii = 0; ii < mean_.size(); ++ii) 00699 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 00700 double f = 1.0 / count_, 00701 f1 = 1.0 - f; 00702 for(int ii = 0; ii < mean_.size(); ++ii) 00703 mean_[ii] += f*tmp_[ii]; 00704 for(int ii = 0; ii < mean_.size(); ++ii) 
00705 variance_[ii] += f1*sq(tmp_[ii]); 00706 } 00707 double res = std::accumulate(variance_.begin(), 00708 variance_.end(), 00709 0.0, 00710 std::plus<double>()) 00711 /((count_ == 1)? 1:(count_ -1)); 00712 //std::cerr << res << " ) = "; 00713 return res; 00714 } 00715 00716 template<class Iter> //This is BROKEN 00717 double decrement (Iter begin, Iter end) 00718 { 00719 for(Iter iter = begin; iter != end; ++iter) 00720 { 00721 double f = 1.0 / count_, 00722 f1 = 1.0 - f; 00723 for(int ii = 0; ii < mean_.size(); ++ii) 00724 mean_[ii] = (mean_[ii] - f*labels_(*iter,ii))/(1-f); 00725 for(int ii = 0; ii < mean_.size(); ++ii) 00726 variance_[ii] -= f1*sq(labels_(*iter,ii) - mean_[ii]); 00727 --count_; 00728 } 00729 double res = std::accumulate(variance_.begin(), 00730 variance_.end(), 00731 0.0, 00732 std::plus<double>()) 00733 /((count_ == 1)? 1:(count_ -1)); 00734 //std::cerr << "( " << res << " + "; 00735 return res; 00736 } 00737 /* west's algorithm for incremental variance 00738 // calculation 00739 template<class Iter> 00740 double increment (Iter begin, Iter end) 00741 { 00742 for(Iter iter = begin; iter != end; ++iter) 00743 { 00744 ++count_; 00745 for(int ii = 0; ii < mean_.size(); ++ii) 00746 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 00747 double f = 1.0 / count_, 00748 f1 = 1.0 - f; 00749 for(int ii = 0; ii < mean_.size(); ++ii) 00750 mean_[ii] += f*tmp_[ii]; 00751 for(int ii = 0; ii < mean_.size(); ++ii) 00752 variance_[ii] += f1*sq(tmp_[ii]); 00753 } 00754 return std::accumulate(variance_.begin(), 00755 variance_.end(), 00756 0.0, 00757 std::plus<double>()) 00758 /(count_ -1); 00759 } 00760 00761 template<class Iter> 00762 double decrement (Iter begin, Iter end) 00763 { 00764 for(Iter iter = begin; iter != end; ++iter) 00765 { 00766 --count_; 00767 for(int ii = 0; ii < mean_.size(); ++ii) 00768 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 00769 double f = 1.0 / count_, 00770 f1 = 1.0 + f; 00771 for(int ii = 0; ii < mean_.size(); ++ii) 00772 mean_[ii] -= 
f*tmp_[ii]; 00773 for(int ii = 0; ii < mean_.size(); ++ii) 00774 variance_[ii] -= f1*sq(tmp_[ii]); 00775 } 00776 return std::accumulate(variance_.begin(), 00777 variance_.end(), 00778 0.0, 00779 std::plus<double>()) 00780 /(count_ -1); 00781 }*/ 00782 00783 template<class Iter, class Resp_t> 00784 double init (Iter begin, Iter end, Resp_t resp) 00785 { 00786 reset(); 00787 return this->increment(begin, end, resp); 00788 } 00789 00790 00791 ArrayVector<double> const & response() 00792 { 00793 return mean_; 00794 } 00795 00796 void reset() 00797 { 00798 mean_.init(0.0); 00799 variance_.init(0.0); 00800 count_ = 0; 00801 } 00802 }; 00803 00804 template<class Tag, class Datatyp> 00805 struct LossTraits; 00806 00807 struct LSQLoss 00808 {}; 00809 00810 template<class Datatype> 00811 struct LossTraits<GiniCriterion, Datatype> 00812 { 00813 typedef ImpurityLoss<Datatype, GiniCriterion> type; 00814 }; 00815 00816 template<class Datatype> 00817 struct LossTraits<EntropyCriterion, Datatype> 00818 { 00819 typedef ImpurityLoss<Datatype, EntropyCriterion> type; 00820 }; 00821 00822 template<class Datatype> 00823 struct LossTraits<LSQLoss, Datatype> 00824 { 00825 typedef RegressionForestCounter<Datatype> type; 00826 }; 00827 00828 /** Given a column, choose a split that minimizes some loss 00829 */ 00830 template<class LineSearchLossTag> 00831 class BestGiniOfColumn 00832 { 00833 public: 00834 ArrayVector<double> class_weights_; 00835 ArrayVector<double> bestCurrentCounts[2]; 00836 double min_gini_; 00837 std::ptrdiff_t min_index_; 00838 double min_threshold_; 00839 ProblemSpec<> ext_param_; 00840 00841 BestGiniOfColumn() 00842 {} 00843 00844 template<class T> 00845 BestGiniOfColumn(ProblemSpec<T> const & ext) 00846 : 00847 class_weights_(ext.class_weights_), 00848 ext_param_(ext) 00849 { 00850 bestCurrentCounts[0].resize(ext.class_count_); 00851 bestCurrentCounts[1].resize(ext.class_count_); 00852 } 00853 template<class T> 00854 void set_external_parameters(ProblemSpec<T> const 
& ext) 00855 { 00856 class_weights_ = ext.class_weights_; 00857 ext_param_ = ext; 00858 bestCurrentCounts[0].resize(ext.class_count_); 00859 bestCurrentCounts[1].resize(ext.class_count_); 00860 } 00861 /** calculate the best gini split along a Feature Column 00862 * \param column the feature vector - has to support the [] operator 00863 * \param labels the label vector 00864 * \param begin 00865 * \param end (in and out) 00866 * begin and end iterators to the indices of the 00867 * samples in the current region. 00868 * the range begin - end is sorted by the column supplied 00869 * during function execution. 00870 * \param region_response 00871 * ??? 00872 * class histogram of the range. 00873 * 00874 * precondition: begin, end valid range, 00875 * class_counts positive integer valued array with the 00876 * class counts in the current range. 00877 * labels.size() >= max(begin, end); 00878 * postcondition: 00879 * begin, end sorted by column given. 00880 * min_gini_ contains the minimum gini found or 00881 * NumericTraits<double>::max if no split was found. 00882 * min_index_ contains the splitting index in the range 00883 * or invalid data if no split was found. 00884 * BestCirremtcounts[0] and [1] contain the 00885 * class histogram of the left and right region of 00886 * the left and right regions. 
00887 */ 00888 template< class DataSourceF_t, 00889 class DataSource_t, 00890 class I_Iter, 00891 class Array> 00892 void operator()(DataSourceF_t const & column, 00893 DataSource_t const & labels, 00894 I_Iter & begin, 00895 I_Iter & end, 00896 Array const & region_response) 00897 { 00898 std::sort(begin, end, 00899 SortSamplesByDimensions<DataSourceF_t>(column, 0)); 00900 typedef typename 00901 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss; 00902 LineSearchLoss left(labels, ext_param_); //initialize left and right region 00903 LineSearchLoss right(labels, ext_param_); 00904 00905 00906 00907 min_gini_ = right.init(begin, end, region_response); 00908 min_threshold_ = *begin; 00909 min_index_ = 0; //the starting point where to split 00910 DimensionNotEqual<DataSourceF_t> comp(column, 0); 00911 00912 I_Iter iter = begin; 00913 I_Iter next = std::adjacent_find(iter, end, comp); 00914 //std::cerr << std::distance(begin, end) << std::endl; 00915 while( next != end) 00916 { 00917 double lr = right.decrement(iter, next + 1); 00918 double ll = left.increment(iter , next + 1); 00919 double loss = lr +ll; 00920 //std::cerr <<lr << " + "<< ll << " " << loss << " "; 00921 #ifdef CLASSIFIER_TEST 00922 if(loss < min_gini_ && !closeAtTolerance(loss, min_gini_)) 00923 #else 00924 if(loss < min_gini_ ) 00925 #endif 00926 { 00927 bestCurrentCounts[0] = left.response(); 00928 bestCurrentCounts[1] = right.response(); 00929 #ifdef CLASSIFIER_TEST 00930 min_gini_ = loss < min_gini_? 
loss : min_gini_; 00931 #else 00932 min_gini_ = loss; 00933 #endif 00934 min_index_ = next - begin +1 ; 00935 min_threshold_ = (double(column(*next,0)) + double(column(*(next +1), 0)))/2.0; 00936 } 00937 iter = next +1 ; 00938 next = std::adjacent_find(iter, end, comp); 00939 } 00940 //std::cerr << std::endl << " 000 " << std::endl; 00941 //int in; 00942 //std::cin >> in; 00943 } 00944 00945 template<class DataSource_t, class Iter, class Array> 00946 double loss_of_region(DataSource_t const & labels, 00947 Iter & begin, 00948 Iter & end, 00949 Array const & region_response) const 00950 { 00951 typedef typename 00952 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss; 00953 LineSearchLoss region_loss(labels, ext_param_); 00954 return 00955 region_loss.init(begin, end, region_response); 00956 } 00957 00958 }; 00959 00960 namespace detail 00961 { 00962 template<class T> 00963 struct Correction 00964 { 00965 template<class Region, class LabelT> 00966 static void exec(Region & in, LabelT & labels) 00967 {} 00968 }; 00969 00970 template<> 00971 struct Correction<ClassificationTag> 00972 { 00973 template<class Region, class LabelT> 00974 static void exec(Region & region, LabelT & labels) 00975 { 00976 if(std::accumulate(region.classCounts().begin(), 00977 region.classCounts().end(), 0.0) != region.size()) 00978 { 00979 RandomForestClassCounter< LabelT, 00980 ArrayVector<double> > 00981 counter(labels, region.classCounts()); 00982 std::for_each( region.begin(), region.end(), counter); 00983 region.classCountsIsValid = true; 00984 } 00985 } 00986 }; 00987 } 00988 00989 /** Chooses mtry columns and applies ColumnDecisionFunctor to each of the 00990 * columns. 
Then Chooses the column that is best 00991 */ 00992 template<class ColumnDecisionFunctor, class Tag = ClassificationTag> 00993 class ThresholdSplit: public SplitBase<Tag> 00994 { 00995 public: 00996 00997 00998 typedef SplitBase<Tag> SB; 00999 01000 ArrayVector<Int32> splitColumns; 01001 ColumnDecisionFunctor bgfunc; 01002 01003 double region_gini_; 01004 ArrayVector<double> min_gini_; 01005 ArrayVector<std::ptrdiff_t> min_indices_; 01006 ArrayVector<double> min_thresholds_; 01007 01008 int bestSplitIndex; 01009 01010 double minGini() const 01011 { 01012 return min_gini_[bestSplitIndex]; 01013 } 01014 int bestSplitColumn() const 01015 { 01016 return splitColumns[bestSplitIndex]; 01017 } 01018 double bestSplitThreshold() const 01019 { 01020 return min_thresholds_[bestSplitIndex]; 01021 } 01022 01023 template<class T> 01024 void set_external_parameters(ProblemSpec<T> const & in) 01025 { 01026 SB::set_external_parameters(in); 01027 bgfunc.set_external_parameters( SB::ext_param_); 01028 int featureCount_ = SB::ext_param_.column_count_; 01029 splitColumns.resize(featureCount_); 01030 for(int k=0; k<featureCount_; ++k) 01031 splitColumns[k] = k; 01032 min_gini_.resize(featureCount_); 01033 min_indices_.resize(featureCount_); 01034 min_thresholds_.resize(featureCount_); 01035 } 01036 01037 01038 template<class T, class C, class T2, class C2, class Region, class Random> 01039 int findBestSplit(MultiArrayView<2, T, C> features, 01040 MultiArrayView<2, T2, C2> labels, 01041 Region & region, 01042 ArrayVector<Region>& childRegions, 01043 Random & randint) 01044 { 01045 01046 typedef typename Region::IndexIterator IndexIterator; 01047 if(region.size() == 0) 01048 { 01049 std::cerr << "SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n" 01050 "continuing learning process...."; 01051 } 01052 // calculate things that haven't been calculated yet. 01053 detail::Correction<Tag>::exec(region, labels); 01054 01055 01056 // Is the region pure already? 
01057 region_gini_ = bgfunc.loss_of_region(labels, 01058 region.begin(), 01059 region.end(), 01060 region.classCounts()); 01061 if(region_gini_ <= SB::ext_param_.precision_) 01062 return this->makeTerminalNode(features, labels, region, randint); 01063 01064 // select columns to be tried. 01065 for(int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii) 01066 std::swap(splitColumns[ii], 01067 splitColumns[ii+ randint(features.shape(1) - ii)]); 01068 01069 // find the best gini index 01070 bestSplitIndex = 0; 01071 double current_min_gini = region_gini_; 01072 int num2try = features.shape(1); 01073 for(int k=0; k<num2try; ++k) 01074 { 01075 //this functor does all the work 01076 bgfunc(columnVector(features, splitColumns[k]), 01077 labels, 01078 region.begin(), region.end(), 01079 region.classCounts()); 01080 min_gini_[k] = bgfunc.min_gini_; 01081 min_indices_[k] = bgfunc.min_index_; 01082 min_thresholds_[k] = bgfunc.min_threshold_; 01083 #ifdef CLASSIFIER_TEST 01084 if( bgfunc.min_gini_ < current_min_gini 01085 && !closeAtTolerance(bgfunc.min_gini_, current_min_gini)) 01086 #else 01087 if(bgfunc.min_gini_ < current_min_gini) 01088 #endif 01089 { 01090 current_min_gini = bgfunc.min_gini_; 01091 childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0]; 01092 childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1]; 01093 childRegions[0].classCountsIsValid = true; 01094 childRegions[1].classCountsIsValid = true; 01095 01096 bestSplitIndex = k; 01097 num2try = SB::ext_param_.actual_mtry_; 01098 } 01099 } 01100 //std::cerr << current_min_gini << "curr " << region_gini_ << std::endl; 01101 // did not find any suitable split 01102 if(closeAtTolerance(current_min_gini, region_gini_)) 01103 return this->makeTerminalNode(features, labels, region, randint); 01104 01105 //create a Node for output 01106 Node<i_ThresholdNode> node(SB::t_data, SB::p_data); 01107 SB::node_ = node; 01108 node.threshold() = min_thresholds_[bestSplitIndex]; 01109 node.column() = 
splitColumns[bestSplitIndex]; 01110 01111 // partition the range according to the best dimension 01112 SortSamplesByDimensions<MultiArrayView<2, T, C> > 01113 sorter(features, node.column(), node.threshold()); 01114 IndexIterator bestSplit = 01115 std::partition(region.begin(), region.end(), sorter); 01116 // Save the ranges of the child stack entries. 01117 childRegions[0].setRange( region.begin() , bestSplit ); 01118 childRegions[0].rule = region.rule; 01119 childRegions[0].rule.push_back(std::make_pair(1, 1.0)); 01120 childRegions[1].setRange( bestSplit , region.end() ); 01121 childRegions[1].rule = region.rule; 01122 childRegions[1].rule.push_back(std::make_pair(1, 1.0)); 01123 01124 return i_ThresholdNode; 01125 } 01126 }; 01127 01128 typedef ThresholdSplit<BestGiniOfColumn<GiniCriterion> > GiniSplit; 01129 typedef ThresholdSplit<BestGiniOfColumn<EntropyCriterion> > EntropySplit; 01130 typedef ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag> RegressionSplit; 01131 01132 namespace rf 01133 { 01134 01135 /** This namespace contains additional Splitfunctors. 01136 * 01137 * The Split functor classes are designed in a modular fashion because new split functors may 01138 * share a lot of code with existing ones. 01139 * 01140 * ThresholdSplit implements the functionality needed for any split functor, that makes its 01141 * decision via one dimensional axis-parallel cuts. The Template parameter defines how the split 01142 * along one dimension is chosen. 01143 * 01144 * The BestGiniOfColumn class chooses a split that minimizes one of the Loss functions supplied 01145 * (GiniCriterion for classification and LSQLoss for regression). Median chooses the Split in a 01146 * kD tree fashion. 
01147 * 01148 * 01149 * Currently defined typedefs: 01150 * \code 01151 * typedef ThresholdSplit<BestGiniOfColumn<GiniCriterion> > GiniSplit; 01152 * typedef ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag> RegressionSplit; 01153 * typedef ThresholdSplit<Median> MedianSplit; 01154 * \endcode 01155 */ 01156 namespace split 01157 { 01158 01159 /** This Functor chooses the median value of a column 01160 */ 01161 class Median 01162 { 01163 public: 01164 01165 typedef GiniCriterion LineSearchLossTag; 01166 ArrayVector<double> class_weights_; 01167 ArrayVector<double> bestCurrentCounts[2]; 01168 double min_gini_; 01169 std::ptrdiff_t min_index_; 01170 double min_threshold_; 01171 ProblemSpec<> ext_param_; 01172 01173 Median() 01174 {} 01175 01176 template<class T> 01177 Median(ProblemSpec<T> const & ext) 01178 : 01179 class_weights_(ext.class_weights_), 01180 ext_param_(ext) 01181 { 01182 bestCurrentCounts[0].resize(ext.class_count_); 01183 bestCurrentCounts[1].resize(ext.class_count_); 01184 } 01185 01186 template<class T> 01187 void set_external_parameters(ProblemSpec<T> const & ext) 01188 { 01189 class_weights_ = ext.class_weights_; 01190 ext_param_ = ext; 01191 bestCurrentCounts[0].resize(ext.class_count_); 01192 bestCurrentCounts[1].resize(ext.class_count_); 01193 } 01194 01195 template< class DataSourceF_t, 01196 class DataSource_t, 01197 class I_Iter, 01198 class Array> 01199 void operator()(DataSourceF_t const & column, 01200 DataSource_t const & labels, 01201 I_Iter & begin, 01202 I_Iter & end, 01203 Array const & region_response) 01204 { 01205 std::sort(begin, end, 01206 SortSamplesByDimensions<DataSourceF_t>(column, 0)); 01207 typedef typename 01208 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss; 01209 LineSearchLoss left(labels, ext_param_); 01210 LineSearchLoss right(labels, ext_param_); 01211 right.init(begin, end, region_response); 01212 01213 min_gini_ = NumericTraits<double>::max(); 01214 min_index_ = floor(double(end - 
begin)/2.0); 01215 min_threshold_ = column[*(begin + min_index_)]; 01216 SortSamplesByDimensions<DataSourceF_t> 01217 sorter(column, 0, min_threshold_); 01218 I_Iter part = std::partition(begin, end, sorter); 01219 DimensionNotEqual<DataSourceF_t> comp(column, 0); 01220 if(part == begin) 01221 { 01222 part= std::adjacent_find(part, end, comp)+1; 01223 01224 } 01225 if(part >= end) 01226 { 01227 return; 01228 } 01229 else 01230 { 01231 min_threshold_ = column[*part]; 01232 } 01233 min_gini_ = right.decrement(begin, part) 01234 + left.increment(begin , part); 01235 01236 bestCurrentCounts[0] = left.response(); 01237 bestCurrentCounts[1] = right.response(); 01238 01239 min_index_ = part - begin; 01240 } 01241 01242 template<class DataSource_t, class Iter, class Array> 01243 double loss_of_region(DataSource_t const & labels, 01244 Iter & begin, 01245 Iter & end, 01246 Array const & region_response) const 01247 { 01248 typedef typename 01249 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss; 01250 LineSearchLoss region_loss(labels, ext_param_); 01251 return 01252 region_loss.init(begin, end, region_response); 01253 } 01254 01255 }; 01256 01257 typedef ThresholdSplit<Median> MedianSplit; 01258 01259 01260 /** This Functor chooses a random value of a column 01261 */ 01262 class RandomSplitOfColumn 01263 { 01264 public: 01265 01266 typedef GiniCriterion LineSearchLossTag; 01267 ArrayVector<double> class_weights_; 01268 ArrayVector<double> bestCurrentCounts[2]; 01269 double min_gini_; 01270 std::ptrdiff_t min_index_; 01271 double min_threshold_; 01272 ProblemSpec<> ext_param_; 01273 typedef RandomMT19937 Random_t; 01274 Random_t random; 01275 01276 RandomSplitOfColumn() 01277 {} 01278 01279 template<class T> 01280 RandomSplitOfColumn(ProblemSpec<T> const & ext) 01281 : 01282 class_weights_(ext.class_weights_), 01283 ext_param_(ext), 01284 random(RandomSeed) 01285 { 01286 bestCurrentCounts[0].resize(ext.class_count_); 01287 
bestCurrentCounts[1].resize(ext.class_count_); 01288 } 01289 01290 template<class T> 01291 RandomSplitOfColumn(ProblemSpec<T> const & ext, Random_t & random_) 01292 : 01293 class_weights_(ext.class_weights_), 01294 ext_param_(ext), 01295 random(random_) 01296 { 01297 bestCurrentCounts[0].resize(ext.class_count_); 01298 bestCurrentCounts[1].resize(ext.class_count_); 01299 } 01300 01301 template<class T> 01302 void set_external_parameters(ProblemSpec<T> const & ext) 01303 { 01304 class_weights_ = ext.class_weights_; 01305 ext_param_ = ext; 01306 bestCurrentCounts[0].resize(ext.class_count_); 01307 bestCurrentCounts[1].resize(ext.class_count_); 01308 } 01309 01310 template< class DataSourceF_t, 01311 class DataSource_t, 01312 class I_Iter, 01313 class Array> 01314 void operator()(DataSourceF_t const & column, 01315 DataSource_t const & labels, 01316 I_Iter & begin, 01317 I_Iter & end, 01318 Array const & region_response) 01319 { 01320 std::sort(begin, end, 01321 SortSamplesByDimensions<DataSourceF_t>(column, 0)); 01322 typedef typename 01323 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss; 01324 LineSearchLoss left(labels, ext_param_); 01325 LineSearchLoss right(labels, ext_param_); 01326 right.init(begin, end, region_response); 01327 01328 01329 min_gini_ = NumericTraits<double>::max(); 01330 int tmp_pt = random.uniformInt(std::distance(begin, end)); 01331 min_index_ = tmp_pt; 01332 min_threshold_ = column[*(begin + min_index_)]; 01333 SortSamplesByDimensions<DataSourceF_t> 01334 sorter(column, 0, min_threshold_); 01335 I_Iter part = std::partition(begin, end, sorter); 01336 DimensionNotEqual<DataSourceF_t> comp(column, 0); 01337 if(part == begin) 01338 { 01339 part= std::adjacent_find(part, end, comp)+1; 01340 01341 } 01342 if(part >= end) 01343 { 01344 return; 01345 } 01346 else 01347 { 01348 min_threshold_ = column[*part]; 01349 } 01350 min_gini_ = right.decrement(begin, part) 01351 + left.increment(begin , part); 01352 01353 bestCurrentCounts[0] 
= left.response(); 01354 bestCurrentCounts[1] = right.response(); 01355 01356 min_index_ = part - begin; 01357 } 01358 01359 template<class DataSource_t, class Iter, class Array> 01360 double loss_of_region(DataSource_t const & labels, 01361 Iter & begin, 01362 Iter & end, 01363 Array const & region_response) const 01364 { 01365 typedef typename 01366 LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss; 01367 LineSearchLoss region_loss(labels, ext_param_); 01368 return 01369 region_loss.init(begin, end, region_response); 01370 } 01371 01372 }; 01373 01374 typedef ThresholdSplit<RandomSplitOfColumn> RandomSplit; 01375 } 01376 } 01377 01378 01379 } //namespace vigra 01380 #endif // VIGRA_RANDOM_FOREST_SPLIT_HXX
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|