vigra/random_forest/rf_split.hxx
/************************************************************************/
/*                                                                      */
/*        Copyright 2008-2009 by Ullrich Koethe and Rahul Nair          */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/
#ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX
#define VIGRA_RANDOM_FOREST_SPLIT_HXX
#include <algorithm>
#include <map>
#include <numeric>
#include <math.h>
#include "../mathutil.hxx"
#include "../array_vector.hxx"
#include "../sized_int.hxx"
#include "../matrix.hxx"
#include "../random.hxx"
#include "../functorexpression.hxx"
#include "rf_nodeproxy.hxx"
//#include "rf_sampling.hxx"
#include "rf_region.hxx"
//#include "../hokashyap.hxx"
//#include "vigra/rf_helpers.hxx"

namespace vigra
{

// Incomplete class to ensure that findBestSplit is always implemented in
// the derived classes of SplitBase
class CompileTimeError;


namespace detail
{
    template<class Tag>
    class Normalise
    {
      public:
        template<class Iter>
        static void exec(Iter begin, Iter end)
        {}
    };

    template<>
    class Normalise<ClassificationTag>
    {
      public:
        template<class Iter>
        static void exec(Iter begin, Iter end)
        {
            double bla = std::accumulate(begin, end, 0.0);
            for(int ii = 0; ii < end - begin; ++ii)
                begin[ii] = begin[ii] / bla;
        }
    };
} // namespace detail
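
/* Example (editor's sketch, not part of the original header): for the
   classification case, Normalise::exec() turns a class histogram into a
   probability distribution by dividing each bin by the total count.
   The histogram values below are made up for illustration.

       ArrayVector<double> hist(2);
       hist[0] = 4.0;
       hist[1] = 2.0;
       detail::Normalise<ClassificationTag>::exec(hist.begin(), hist.end());
       // hist now holds {2/3, 1/3}
*/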

/** Base class for all split functors used with the \ref RandomForest class;
    defines the interface used while learning a tree.
**/
template<class Tag>
class SplitBase
{
  public:

    typedef Tag RF_Tag;
    typedef DT_StackEntry<ArrayVectorView<Int32>::iterator>
                                    StackEntry_t;

    ProblemSpec<>                   ext_param_;

    NodeBase::T_Container_type      t_data;
    NodeBase::P_Container_type      p_data;

    NodeBase                        node_;

    template<class T>
    void set_external_parameters(ProblemSpec<T> const & in)
    {
        ext_param_ = in;
        t_data.push_back(in.column_count_);
        t_data.push_back(in.class_count_);
    }

    /** returns the decision tree node created by
        \ref findBestSplit or \ref makeTerminalNode.
    **/
    NodeBase & createNode()
    {
        return node_;
    }

    int classCount() const
    {
        return int(t_data[1]);
    }

    int featureCount() const
    {
        return int(t_data[0]);
    }

    /** resets internal data. Should always be called before
        calling findBestSplit or makeTerminalNode.
    **/
    void reset()
    {
        t_data.resize(2);
        p_data.resize(0);
    }


    /** findBestSplit has to be implemented in the derived split functor.
        This stub only ensures that a compile-time error is issued
        if no such method was defined.
    **/
    template<class T, class C, class T2, class C2, class Region, class Random>
    int findBestSplit(MultiArrayView<2, T, C> features,
                      MultiArrayView<2, T2, C2> labels,
                      Region region,
                      ArrayVector<Region> childs,
                      Random randint)
    {
        CompileTimeError SplitFunctor__findBestSplit_member_was_not_defined;
        return 0;
    }

    /** default action for creating a terminal node:
        sets the class probabilities of the remaining region according to
        the class histogram.
    **/
    template<class T, class C, class T2, class C2, class Region, class Random>
    int makeTerminalNode(MultiArrayView<2, T, C> features,
                         MultiArrayView<2, T2, C2> labels,
                         Region & region,
                         Random randint)
    {
        Node<e_ConstProbNode> ret(t_data, p_data);
        node_ = ret;
        if(ext_param_.class_weights_.size() != region.classCounts().size())
        {
            std::copy(region.classCounts().begin(),
                      region.classCounts().end(),
                      ret.prob_begin());
        }
        else
        {
            std::transform(region.classCounts().begin(),
                           region.classCounts().end(),
                           ext_param_.class_weights_.begin(),
                           ret.prob_begin(), std::multiplies<double>());
        }
        detail::Normalise<RF_Tag>::exec(ret.prob_begin(), ret.prob_end());
        ret.weights() = region.size();
        return e_ConstProbNode;
    }

};
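
/* Example (editor's sketch, not part of the original header): the minimal
   shape of a user-defined split functor. The name TrivialSplit and its
   behaviour (always creating a terminal node) are made up for illustration;
   only the interface mirrors the definitions above.

       class TrivialSplit : public SplitBase<ClassificationTag>
       {
         public:
           typedef SplitBase<ClassificationTag> SB;

           template<class T, class C, class T2, class C2, class Region, class Random>
           int findBestSplit(MultiArrayView<2, T, C>   features,
                             MultiArrayView<2, T2, C2> labels,
                             Region & region,
                             ArrayVector<Region> & childRegions,
                             Random & randint)
           {
               // give up immediately: turn the whole region into a leaf that
               // stores its class histogram (assumes region.classCounts()
               // is up to date)
               return SB::makeTerminalNode(features, labels, region, randint);
           }
       };
*/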

/** Functor to sort the indices of a feature matrix by a certain dimension.
**/
template<class DataMatrix>
class SortSamplesByDimensions
{
    DataMatrix const &  data_;
    MultiArrayIndex     sortColumn_;
    double              thresVal_;

  public:

    SortSamplesByDimensions(DataMatrix const & data,
                            MultiArrayIndex sortColumn,
                            double thresVal = 0.0)
    : data_(data),
      sortColumn_(sortColumn),
      thresVal_(thresVal)
    {}

    void setColumn(MultiArrayIndex sortColumn)
    {
        sortColumn_ = sortColumn;
    }

    void setThreshold(double value)
    {
        thresVal_ = value;
    }

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return data_(l, sortColumn_) < data_(r, sortColumn_);
    }

    bool operator()(MultiArrayIndex l) const
    {
        return data_(l, sortColumn_) < thresVal_;
    }
};

template<class DataMatrix>
class DimensionNotEqual
{
    DataMatrix const &  data_;
    MultiArrayIndex     sortColumn_;

  public:

    DimensionNotEqual(DataMatrix const & data,
                      MultiArrayIndex sortColumn)
    : data_(data),
      sortColumn_(sortColumn)
    {}

    void setColumn(MultiArrayIndex sortColumn)
    {
        sortColumn_ = sortColumn;
    }

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return data_(l, sortColumn_) != data_(r, sortColumn_);
    }
};

template<class DataMatrix>
class SortSamplesByHyperplane
{
    DataMatrix              const & data_;
    Node<i_HyperplaneNode>  const & node_;

  public:

    SortSamplesByHyperplane(DataMatrix const & data,
                            Node<i_HyperplaneNode> const & node)
    :
        data_(data),
        node_(node)
    {}

    /** calculate the distance of a sample point to a hyperplane
     */
    double operator[](MultiArrayIndex l) const
    {
        double result_l = -1 * node_.intercept();
        for(int ii = 0; ii < node_.columns_size(); ++ii)
        {
            result_l += rowVector(data_, l)[node_.columns_begin()[ii]]
                        * node_.weights()[ii];
        }
        return result_l;
    }

    bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
    {
        return (*this)[l] < (*this)[r];
    }
};

/** makes a class histogram given indices into a labels_ array.
 *  usage:
 *      MultiArrayView<2, T2, C2> labels = makeSomeLabels();
 *      ArrayVector<int> hist(numberOfLabels(labels), 0);
 *      RandomForestClassCounter<MultiArrayView<2, T2, C2>, ArrayVector<int> >
 *          counter(labels, hist);
 *
 *      Container<int> indices = getSomeIndices();
 *      std::for_each(indices.begin(), indices.end(), counter);
 */
template <class DataSource, class CountArray>
class RandomForestClassCounter
{
    DataSource const &  labels_;
    CountArray &        counts_;

  public:

    RandomForestClassCounter(DataSource const & labels,
                             CountArray & counts)
    : labels_(labels),
      counts_(counts)
    {
        reset();
    }

    void reset()
    {
        counts_.init(0);
    }

    void operator()(MultiArrayIndex l) const
    {
        counts_[labels_[l]] += 1;
    }
};


/** Functor to calculate the best possible split based on the Gini index,
    given labels and features along a given axis.
*/

namespace detail
{
    template<int N>
    class ConstArr
    {
      public:
        double operator[](size_t) const
        {
            return (double)N;
        }
    };
}

/** Functor to calculate the entropy based impurity
 */
class EntropyCriterion
{
  public:
    /** calculate the weighted entropy based impurity from a class histogram
     *  and class weights
     */
    template<class Array, class Array2>
    double operator()(Array const & hist,
                      Array2 const & weights,
                      double total = 1.0) const
    {
        return impurity(hist, weights, total);
    }

    /** calculate the entropy based impurity from a class histogram
     */
    template<class Array>
    double operator()(Array const & hist, double total = 1.0) const
    {
        return impurity(hist, total);
    }

    /** static version of operator()(hist, total)
     */
    template<class Array>
    static double impurity(Array const & hist, double total)
    {
        return impurity(hist, detail::ConstArr<1>(), total);
    }

    /** static version of operator()(hist, weights, total)
     */
    template<class Array, class Array2>
    static double impurity(Array const & hist,
                           Array2 const & weights,
                           double total)
    {
        int class_count = hist.size();
        double entropy = 0.0;
        if(class_count == 2)
        {
            double p0 = (hist[0] / total);
            double p1 = (hist[1] / total);
            entropy = 0 - weights[0]*p0*std::log(p0) - weights[1]*p1*std::log(p1);
        }
        else
        {
            for(int ii = 0; ii < class_count; ++ii)
            {
                double w   = weights[ii];
                double pii = hist[ii] / total;
                entropy -= w * (pii * std::log(pii));
            }
        }
        entropy = total * entropy;
        return entropy;
    }
};

/** Functor to calculate the gini impurity
 */
class GiniCriterion
{
  public:
    /** calculate the weighted gini impurity from a class histogram
     *  and class weights
     */
    template<class Array, class Array2>
    double operator()(Array const & hist,
                      Array2 const & weights,
                      double total = 1.0) const
    {
        return impurity(hist, weights, total);
    }

    /** calculate the gini based impurity from a class histogram
     */
    template<class Array>
    double operator()(Array const & hist, double total = 1.0) const
    {
        return impurity(hist, total);
    }

    /** static version of operator()(hist, total)
     */
    template<class Array>
    static double impurity(Array const & hist, double total)
    {
        return impurity(hist, detail::ConstArr<1>(), total);
    }

    /** static version of operator()(hist, weights, total)
     */
    template<class Array, class Array2>
    static double impurity(Array const & hist,
                           Array2 const & weights,
                           double total)
    {
        int class_count = hist.size();
        double gini = 0.0;
        if(class_count == 2)
        {
            double w = weights[0] * weights[1];
            gini = w * (hist[0] * hist[1] / total);
        }
        else
        {
            for(int ii = 0; ii < class_count; ++ii)
            {
                double w = weights[ii];
                gini += w * (hist[ii] * (1.0 - w * hist[ii] / total));
            }
        }
        return gini;
    }
};
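
/* Worked example (editor's addition, not part of the original header):
   for a two-class histogram hist = {4, 2} with unit class weights and
   total = 6, the two criteria above give

       GiniCriterion::impurity(hist, 6.0)    =  4 * 2 / 6                            ~ 1.333
       EntropyCriterion::impurity(hist, 6.0) =  6 * (-(4/6)*ln(4/6) - (2/6)*ln(2/6)) ~ 3.819

   Both values grow with the size of the region rather than being normalised
   impurities; the incremental loss classes below simply accumulate and
   compare them.

       ArrayVector<double> hist(2);
       hist[0] = 4.0;  hist[1] = 2.0;
       double g = GiniCriterion::impurity(hist, 6.0);     // ~1.333
       double e = EntropyCriterion::impurity(hist, 6.0);  // ~3.819
*/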

template <class DataSource, class Impurity = GiniCriterion>
class ImpurityLoss
{
    DataSource const &          labels_;
    ArrayVector<double>         counts_;
    ArrayVector<double> const   class_weights_;
    double                      total_counts_;
    Impurity                    impurity_;

  public:

    template<class T>
    ImpurityLoss(DataSource const & labels,
                 ProblemSpec<T> const & ext_)
    : labels_(labels),
      counts_(ext_.class_count_, 0.0),
      class_weights_(ext_.class_weights_),
      total_counts_(0.0)
    {}

    void reset()
    {
        counts_.init(0);
        total_counts_ = 0.0;
    }

    template<class Counts>
    double increment_histogram(Counts const & counts)
    {
        std::transform(counts.begin(), counts.end(),
                       counts_.begin(), counts_.begin(),
                       std::plus<double>());
        total_counts_ = std::accumulate(counts_.begin(),
                                        counts_.end(),
                                        0.0);
        return impurity_(counts_, class_weights_, total_counts_);
    }

    template<class Counts>
    double decrement_histogram(Counts const & counts)
    {
        // subtract the given histogram from the accumulated one
        std::transform(counts_.begin(), counts_.end(),
                       counts.begin(), counts_.begin(),
                       std::minus<double>());
        total_counts_ = std::accumulate(counts_.begin(),
                                        counts_.end(),
                                        0.0);
        return impurity_(counts_, class_weights_, total_counts_);
    }

    template<class Iter>
    double increment(Iter begin, Iter end)
    {
        for(Iter iter = begin; iter != end; ++iter)
        {
            counts_[labels_(*iter, 0)] += 1.0;
            total_counts_ += 1.0;
        }
        return impurity_(counts_, class_weights_, total_counts_);
    }

    template<class Iter>
    double decrement(Iter const & begin, Iter const & end)
    {
        for(Iter iter = begin; iter != end; ++iter)
        {
            counts_[labels_(*iter, 0)] -= 1.0;
            total_counts_ -= 1.0;
        }
        return impurity_(counts_, class_weights_, total_counts_);
    }

    template<class Iter, class Resp_t>
    double init(Iter begin, Iter end, Resp_t resp)
    {
        reset();
        std::copy(resp.begin(), resp.end(), counts_.begin());
        total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0);
        return impurity_(counts_, class_weights_, total_counts_);
    }

    ArrayVector<double> const & response()
    {
        return counts_;
    }
};
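
/* Example (editor's sketch, not part of the original header): this is the
   pattern in which BestGiniOfColumn (further below) uses ImpurityLoss -- one
   counter for the growing left region, one for the shrinking right region.
   `labels`, `ext_param`, the sorted index range `begin`/`end`, the class
   histogram `region_counts` and the split position `k` are assumed to exist.

       typedef MultiArrayView<2, Int32> Labels_t;
       ImpurityLoss<Labels_t, GiniCriterion> left(labels, ext_param);
       ImpurityLoss<Labels_t, GiniCriterion> right(labels, ext_param);
       right.init(begin, end, region_counts);          // loss of the whole region
       // move the first k samples from the right to the left side and obtain
       // the combined loss of the two resulting regions
       double loss = right.decrement(begin, begin + k)
                   + left.increment(begin, begin + k);
*/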

template <class DataSource>
class RegressionForestCounter
{
    typedef MultiArrayShape<2>::type Shp;
    DataSource const &      labels_;
    ArrayVector<double>     mean_;
    ArrayVector<double>     variance_;
    ArrayVector<double>     tmp_;
    size_t                  count_;

  public:

    template<class T>
    RegressionForestCounter(DataSource const & labels,
                            ProblemSpec<T> const & ext_)
    :
        labels_(labels),
        mean_(ext_.response_size, 0.0),
        variance_(ext_.response_size, 0.0),
        tmp_(ext_.response_size),
        count_(0)
    {}

    // West's algorithm for incremental variance calculation
    template<class Iter>
    double increment(Iter begin, Iter end)
    {
        for(Iter iter = begin; iter != end; ++iter)
        {
            ++count_;
            for(int ii = 0; ii < mean_.size(); ++ii)
                tmp_[ii] = labels_(*iter, ii) - mean_[ii];
            double f  = 1.0 / count_,
                   f1 = 1.0 - f;
            for(int ii = 0; ii < mean_.size(); ++ii)
                mean_[ii] += f * tmp_[ii];
            for(int ii = 0; ii < mean_.size(); ++ii)
                variance_[ii] += f1 * sq(tmp_[ii]);
        }
        return std::accumulate(variance_.begin(),
                               variance_.end(),
                               0.0,
                               std::plus<double>())
               / (count_ - 1);
    }

    template<class Iter>
    double decrement(Iter begin, Iter end)
    {
        for(Iter iter = begin; iter != end; ++iter)
        {
            --count_;
            for(int ii = 0; ii < mean_.size(); ++ii)
                tmp_[ii] = labels_(*iter, ii) - mean_[ii];
            double f  = 1.0 / count_,
                   f1 = 1.0 + f;
            for(int ii = 0; ii < mean_.size(); ++ii)
                mean_[ii] -= f * tmp_[ii];
            for(int ii = 0; ii < mean_.size(); ++ii)
                variance_[ii] -= f1 * sq(tmp_[ii]);
        }
        return std::accumulate(variance_.begin(),
                               variance_.end(),
                               0.0,
                               std::plus<double>())
               / (count_ - 1);
    }

    template<class Iter, class Resp_t>
    double init(Iter begin, Iter end, Resp_t resp)
    {
        reset();
        return increment(begin, end);
    }


    ArrayVector<double> const & response()
    {
        return mean_;
    }

    void reset()
    {
        mean_.init(0.0);
        variance_.init(0.0);
        count_ = 0;
    }
};
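
/* Note and example (editor's addition, not part of the original header):
   increment()/decrement() maintain the running mean and the per-dimension
   sum of squared deviations with West's update formulas; the returned value
   is those sums added up and divided by (count_ - 1), i.e. the summed
   per-dimension sample variance of the current region. This acts as the
   least-squares loss for regression splits (see the LSQLoss trait below).
   `labels`, `ext_param`, `begin` and `end` are assumed to exist.

       RegressionForestCounter<MultiArrayView<2, double> > counter(labels, ext_param);
       double loss = counter.increment(begin, end);   // summed per-dimension variance
*/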

template<class Tag, class Datatyp>
struct LossTraits;

struct LSQLoss
{};

template<class Datatype>
struct LossTraits<GiniCriterion, Datatype>
{
    typedef ImpurityLoss<Datatype, GiniCriterion> type;
};

template<class Datatype>
struct LossTraits<EntropyCriterion, Datatype>
{
    typedef ImpurityLoss<Datatype, EntropyCriterion> type;
};

template<class Datatype>
struct LossTraits<LSQLoss, Datatype>
{
    typedef RegressionForestCounter<Datatype> type;
};

/** Given a column, choose a split that minimizes some loss.
 */
template<class LineSearchLossTag>
class BestGiniOfColumn
{
  public:
    ArrayVector<double>     class_weights_;
    ArrayVector<double>     bestCurrentCounts[2];
    double                  min_gini_;
    ptrdiff_t               min_index_;
    double                  min_threshold_;
    ProblemSpec<>           ext_param_;

    BestGiniOfColumn()
    {}

    template<class T>
    BestGiniOfColumn(ProblemSpec<T> const & ext)
    :
        class_weights_(ext.class_weights_),
        ext_param_(ext)
    {
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }

    template<class T>
    void set_external_parameters(ProblemSpec<T> const & ext)
    {
        class_weights_ = ext.class_weights_;
        ext_param_ = ext;
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }

    /** calculate the best gini split along a feature column
     *  \param column  the feature matrix - has to support the () operator
     *  \param g       index of the column to be used
     *  \param labels  the label vector
     *  \param begin
     *  \param end     (in and out)
     *                 begin and end iterators to the indices of the
     *                 samples in the current region;
     *                 the range begin - end is sorted by the supplied column
     *                 during function execution.
     *  \param region_response
     *                 class histogram of the range.
     *
     *  precondition:  begin, end is a valid range;
     *                 region_response is a positive integer valued array
     *                 holding the class counts in the current range;
     *                 labels.size() >= max(begin, end).
     *  postcondition:
     *                 begin, end is sorted by the given column;
     *                 min_gini_ contains the minimum gini found, or
     *                 NumericTraits<double>::max if no split was found;
     *                 min_index_ contains the splitting index in the range,
     *                 or invalid data if no split was found;
     *                 bestCurrentCounts[0] and [1] contain the class
     *                 histograms of the left and right regions.
     */
    template<class DataSourceF_t,
             class DataSource_t,
             class I_Iter,
             class Array>
    void operator()(DataSourceF_t const & column,
                    int g,
                    DataSource_t const & labels,
                    I_Iter & begin,
                    I_Iter & end,
                    Array const & region_response)
    {
        std::sort(begin, end,
                  SortSamplesByDimensions<DataSourceF_t>(column, g));
        typedef typename
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss left(labels, ext_param_);
        LineSearchLoss right(labels, ext_param_);

        min_gini_ = right.init(begin, end, region_response);
        min_threshold_ = *begin;
        min_index_ = 0;
        DimensionNotEqual<DataSourceF_t> comp(column, g);

        I_Iter iter = begin;
        I_Iter next = std::adjacent_find(iter, end, comp);
        while(next != end)
        {
            double loss = right.decrement(iter, next + 1)
                        + left.increment(iter, next + 1);
#ifdef CLASSIFIER_TEST
            if(loss < min_gini_ && !closeAtTolerance(loss, min_gini_))
#else
            if(loss < min_gini_)
#endif
            {
                bestCurrentCounts[0] = left.response();
                bestCurrentCounts[1] = right.response();
#ifdef CLASSIFIER_TEST
                min_gini_ = loss < min_gini_ ? loss : min_gini_;
#else
                min_gini_ = loss;
#endif
                min_index_ = next - begin + 1;
                min_threshold_ = (double(column(*next, g)) + double(column(*(next + 1), g))) / 2.0;
            }
            iter = next + 1;
            next = std::adjacent_find(iter, end, comp);
        }
    }

    template<class DataSource_t, class Iter, class Array>
    double loss_of_region(DataSource_t const & labels,
                          Iter & begin,
                          Iter & end,
                          Array const & region_response) const
    {
        typedef typename
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss region_loss(labels, ext_param_);
        return region_loss.init(begin, end, region_response);
    }

};
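
/* Example (editor's sketch, not part of the original header): evaluating a
   single feature column the way ThresholdSplit below does it.  `features`,
   `labels`, `region` and `ext_param` are assumed to exist; column 3 is an
   arbitrary choice.

       BestGiniOfColumn<GiniCriterion> bgfunc(ext_param);
       bgfunc(features, 3,                       // look at feature column 3
              labels,
              region.begin(), region.end(),      // gets sorted by that column
              region.classCounts());
       // bgfunc.min_gini_, bgfunc.min_threshold_ and bgfunc.min_index_ now
       // describe the best threshold split found along column 3
*/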

/** Chooses mtry columns and applies the ColumnDecisionFunctor to each of them.
    Then chooses the column that is best.
 */
template<class ColumnDecisionFunctor, class Tag = ClassificationTag>
class ThresholdSplit : public SplitBase<Tag>
{
  public:

    typedef SplitBase<Tag> SB;

    ArrayVector<Int32>          splitColumns;
    ColumnDecisionFunctor       bgfunc;

    double                      region_gini_;
    ArrayVector<double>         min_gini_;
    ArrayVector<ptrdiff_t>      min_indices_;
    ArrayVector<double>         min_thresholds_;

    int                         bestSplitIndex;

    double minGini() const
    {
        return min_gini_[bestSplitIndex];
    }

    int bestSplitColumn() const
    {
        return splitColumns[bestSplitIndex];
    }

    double bestSplitThreshold() const
    {
        return min_thresholds_[bestSplitIndex];
    }

    template<class T>
    void set_external_parameters(ProblemSpec<T> const & in)
    {
        SB::set_external_parameters(in);
        bgfunc.set_external_parameters(SB::ext_param_);
        int featureCount_ = SB::ext_param_.column_count_;
        splitColumns.resize(featureCount_);
        for(int k = 0; k < featureCount_; ++k)
            splitColumns[k] = k;
        min_gini_.resize(featureCount_);
        min_indices_.resize(featureCount_);
        min_thresholds_.resize(featureCount_);
    }


    template<class T, class C, class T2, class C2, class Region, class Random>
    int findBestSplit(MultiArrayView<2, T, C> features,
                      MultiArrayView<2, T2, C2> labels,
                      Region & region,
                      ArrayVector<Region> & childRegions,
                      Random & randint)
    {
        typedef typename Region::IndexIterator IndexIterator;
        if(region.size() == 0)
        {
            std::cerr << "SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n"
                         "continuing learning process....";
        }

        // calculate things that haven't been calculated yet.
        if(std::accumulate(region.classCounts().begin(),
                           region.classCounts().end(), 0) != region.size())
        {
            RandomForestClassCounter<MultiArrayView<2, T2, C2>,
                                     ArrayVector<double> >
                counter(labels, region.classCounts());
            std::for_each(region.begin(), region.end(), counter);
            region.classCountsIsValid = true;
        }

        // Is the region pure already?
        region_gini_ = bgfunc.loss_of_region(labels,
                                             region.begin(),
                                             region.end(),
                                             region.classCounts());
        if(region_gini_ <= SB::ext_param_.precision_)
            return makeTerminalNode(features, labels, region, randint);

        // select columns to be tried.
        for(int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
            std::swap(splitColumns[ii],
                      splitColumns[ii + randint(features.shape(1) - ii)]);

        // find the best gini index
        bestSplitIndex = 0;
        double current_min_gini = region_gini_;
        int num2try = features.shape(1);
        for(int k = 0; k < num2try; ++k)
        {
            // this functor does all the work
            bgfunc(features,
                   splitColumns[k],
                   labels,
                   region.begin(), region.end(),
                   region.classCounts());
            min_gini_[k]       = bgfunc.min_gini_;
            min_indices_[k]    = bgfunc.min_index_;
            min_thresholds_[k] = bgfunc.min_threshold_;
#ifdef CLASSIFIER_TEST
            if(bgfunc.min_gini_ < current_min_gini
               && !closeAtTolerance(bgfunc.min_gini_, current_min_gini))
#else
            if(bgfunc.min_gini_ < current_min_gini)
#endif
            {
                current_min_gini = bgfunc.min_gini_;
                childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
                childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
                childRegions[0].classCountsIsValid = true;
                childRegions[1].classCountsIsValid = true;

                bestSplitIndex = k;
                num2try = SB::ext_param_.actual_mtry_;
            }
        }

        // did not find any suitable split
        if(closeAtTolerance(current_min_gini, region_gini_))
            return makeTerminalNode(features, labels, region, randint);

        // create a Node for output
        Node<i_ThresholdNode> node(SB::t_data, SB::p_data);
        SB::node_ = node;
        node.threshold() = min_thresholds_[bestSplitIndex];
        node.column()    = splitColumns[bestSplitIndex];

        // partition the range according to the best dimension
        SortSamplesByDimensions<MultiArrayView<2, T, C> >
            sorter(features, node.column(), node.threshold());
        IndexIterator bestSplit =
            std::partition(region.begin(), region.end(), sorter);

        // Save the ranges of the child stack entries.
        childRegions[0].setRange(region.begin(), bestSplit);
        childRegions[0].rule = region.rule;
        childRegions[0].rule.push_back(std::make_pair(1, 1.0));
        childRegions[1].setRange(bestSplit, region.end());
        childRegions[1].rule = region.rule;
        childRegions[1].rule.push_back(std::make_pair(1, 1.0));

        return i_ThresholdNode;
    }
};

typedef ThresholdSplit<BestGiniOfColumn<GiniCriterion> >            GiniSplit;
typedef ThresholdSplit<BestGiniOfColumn<EntropyCriterion> >         EntropySplit;
typedef ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag>    RegressionSplit;
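
/* Example (editor's sketch, not part of the original header): the typedefs
   above are meant to be passed to RandomForest::learn() as its split
   argument.  The learn() overload and argument order shown here are an
   assumption and should be checked against random_forest.hxx.

       vigra::RandomForest<> rf;
       rf.learn(features, labels,
                vigra::rf_default(),       // visitor
                vigra::EntropySplit(),     // e.g. entropy instead of the gini split
                vigra::rf_default());      // stop criterion
*/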

namespace rf
{

/** This namespace contains additional split functors.
 *
 *  The split functor classes are designed in a modular fashion because new split functors may
 *  share a lot of code with existing ones.
 *
 *  ThresholdSplit implements the functionality needed by any split functor that makes its
 *  decision via one-dimensional, axis-parallel cuts. The template parameter defines how the split
 *  along one dimension is chosen.
 *
 *  The BestGiniOfColumn class chooses a split that minimizes one of the supplied loss functions
 *  (GiniCriterion for classification and LSQLoss for regression). Median chooses the split in a
 *  kD-tree fashion.
 *
 *
 *  Currently defined typedefs:
 *  \code
 *  typedef ThresholdSplit<BestGiniOfColumn<GiniCriterion> >          GiniSplit;
 *  typedef ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag>  RegressionSplit;
 *  typedef ThresholdSplit<Median>                                    MedianSplit;
 *  \endcode
 */
namespace split
{

/** This functor chooses the median value of a column.
 */
class Median
{
  public:

    typedef GiniCriterion   LineSearchLossTag;
    ArrayVector<double>     class_weights_;
    ArrayVector<double>     bestCurrentCounts[2];
    double                  min_gini_;
    ptrdiff_t               min_index_;
    double                  min_threshold_;
    ProblemSpec<>           ext_param_;

    Median()
    {}

    template<class T>
    Median(ProblemSpec<T> const & ext)
    :
        class_weights_(ext.class_weights_),
        ext_param_(ext)
    {
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }

    template<class T>
    void set_external_parameters(ProblemSpec<T> const & ext)
    {
        class_weights_ = ext.class_weights_;
        ext_param_ = ext;
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }

    template<class DataSourceF_t,
             class DataSource_t,
             class I_Iter,
             class Array>
    void operator()(DataSourceF_t const & column,
                    DataSource_t const & labels,
                    I_Iter & begin,
                    I_Iter & end,
                    Array const & region_response)
    {
        std::sort(begin, end,
                  SortSamplesByDimensions<DataSourceF_t>(column, 0));
        typedef typename
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss left(labels, ext_param_);
        LineSearchLoss right(labels, ext_param_);
        right.init(begin, end, region_response);

        min_gini_ = NumericTraits<double>::max();
        min_index_ = floor(double(end - begin) / 2.0);
        min_threshold_ = column[*(begin + min_index_)];
        SortSamplesByDimensions<DataSourceF_t>
            sorter(column, 0, min_threshold_);
        I_Iter part = std::partition(begin, end, sorter);
        DimensionNotEqual<DataSourceF_t> comp(column, 0);
        if(part == begin)
        {
            part = std::adjacent_find(part, end, comp) + 1;
        }
        if(part >= end)
        {
            return;
        }
        else
        {
            min_threshold_ = column[*part];
        }
        min_gini_ = right.decrement(begin, part)
                  + left.increment(begin, part);

        bestCurrentCounts[0] = left.response();
        bestCurrentCounts[1] = right.response();

        min_index_ = part - begin;
    }

    template<class DataSource_t, class Iter, class Array>
    double loss_of_region(DataSource_t const & labels,
                          Iter & begin,
                          Iter & end,
                          Array const & region_response) const
    {
        typedef typename
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss region_loss(labels, ext_param_);
        return region_loss.init(begin, end, region_response);
    }

};

typedef ThresholdSplit<Median> MedianSplit;

/** This functor chooses a random value of a column.
 */
class RandomSplitOfColumn
{
  public:

    typedef GiniCriterion   LineSearchLossTag;
    ArrayVector<double>     class_weights_;
    ArrayVector<double>     bestCurrentCounts[2];
    double                  min_gini_;
    ptrdiff_t               min_index_;
    double                  min_threshold_;
    ProblemSpec<>           ext_param_;
    typedef RandomMT19937   Random_t;
    Random_t                random;

    RandomSplitOfColumn()
    {}

    template<class T>
    RandomSplitOfColumn(ProblemSpec<T> const & ext)
    :
        class_weights_(ext.class_weights_),
        ext_param_(ext),
        random(RandomSeed)
    {
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }

    template<class T>
    RandomSplitOfColumn(ProblemSpec<T> const & ext, Random_t & random_)
    :
        class_weights_(ext.class_weights_),
        ext_param_(ext),
        random(random_)
    {
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }

    template<class T>
    void set_external_parameters(ProblemSpec<T> const & ext)
    {
        class_weights_ = ext.class_weights_;
        ext_param_ = ext;
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }

    template<class DataSourceF_t,
             class DataSource_t,
             class I_Iter,
             class Array>
    void operator()(DataSourceF_t const & column,
                    DataSource_t const & labels,
                    I_Iter & begin,
                    I_Iter & end,
                    Array const & region_response)
    {
        std::sort(begin, end,
                  SortSamplesByDimensions<DataSourceF_t>(column, 0));
        typedef typename
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss left(labels, ext_param_);
        LineSearchLoss right(labels, ext_param_);
        right.init(begin, end, region_response);

        min_gini_ = NumericTraits<double>::max();

        min_index_ = random.uniformInt(end - begin);
        min_threshold_ = column[*(begin + min_index_)];
        SortSamplesByDimensions<DataSourceF_t>
            sorter(column, 0, min_threshold_);
        I_Iter part = std::partition(begin, end, sorter);
        DimensionNotEqual<DataSourceF_t> comp(column, 0);
        if(part == begin)
        {
            part = std::adjacent_find(part, end, comp) + 1;
        }
        if(part >= end)
        {
            return;
        }
        else
        {
            min_threshold_ = column[*part];
        }
        min_gini_ = right.decrement(begin, part)
                  + left.increment(begin, part);

        bestCurrentCounts[0] = left.response();
        bestCurrentCounts[1] = right.response();

        min_index_ = part - begin;
    }

    template<class DataSource_t, class Iter, class Array>
    double loss_of_region(DataSource_t const & labels,
                          Iter & begin,
                          Iter & end,
                          Array const & region_response) const
    {
        typedef typename
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss region_loss(labels, ext_param_);
        return region_loss.init(begin, end, region_response);
    }

};

typedef ThresholdSplit<RandomSplitOfColumn> RandomSplit;

} // namespace split
} // namespace rf

} // namespace vigra

#endif // VIGRA_RANDOM_FOREST_SPLIT_HXX