[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_split.hxx VIGRA

00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 #ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX
00036 #define VIGRA_RANDOM_FOREST_SPLIT_HXX
00037 #include <algorithm>
00038 #include <cstddef>
00039 #include <map>
00040 #include <numeric>
00041 #include <math.h>
00042 #include "../mathutil.hxx"
00043 #include "../array_vector.hxx"
00044 #include "../sized_int.hxx"
00045 #include "../matrix.hxx"
00046 #include "../random.hxx"
00047 #include "../functorexpression.hxx"
00048 #include "rf_nodeproxy.hxx"
00049 //#include "rf_sampling.hxx"
00050 #include "rf_region.hxx"
00051 //#include "../hokashyap.hxx"
00052 //#include "vigra/rf_helpers.hxx"
00053 
00054 namespace vigra
00055 {
00056 
00057 // Incomplete Class to ensure that findBestSplit is always implemented in
00058 // the derived classes of SplitBase
00059 class CompileTimeError;
00060 
00061 
00062 namespace detail
00063 {
00064     template<class Tag>
00065     class Normalise
00066     {
00067       public:
00068         template<class Iter>
00069         static void exec(Iter begin, Iter  end)
00070         {}
00071     };
00072 
00073     template<>
00074     class Normalise<ClassificationTag>
00075     {
00076       public:
00077         template<class Iter>
00078         static void exec (Iter begin, Iter end)
00079         {
00080             double bla = std::accumulate(begin, end, 0.0);
00081             for(int ii = 0; ii < end - begin; ++ii)
00082                 begin[ii] = begin[ii]/bla ;
00083         }
00084     };
00085 }
00086 
00087 
/** Base Class for all SplitFunctors used with the \ref RandomForest class
    defines the interface used while learning a tree.
**/
template<class Tag>
class SplitBase
{
  public:

    typedef Tag           RF_Tag;
    typedef DT_StackEntry<ArrayVectorView<Int32>::iterator>
                                        StackEntry_t;

    // Problem specification (column count, class count, class weights, ...).
    ProblemSpec<>           ext_param_;

    // Topology (t_data) and parameter (p_data) storage for the node under
    // construction. By convention t_data[0] caches the feature count and
    // t_data[1] the class count (see set_external_parameters()).
    NodeBase::T_Container_type          t_data;
    NodeBase::P_Container_type          p_data;

    // Node produced by findBestSplit()/makeTerminalNode().
    NodeBase                            node_;

    /** store the problem specification and cache the column and class
        counts in t_data (read back via featureCount()/classCount()).
    **/
    template<class T>
    void set_external_parameters(ProblemSpec<T> const & in)
    {
        ext_param_ = in;
        t_data.push_back(in.column_count_);
        t_data.push_back(in.class_count_);
    }

    /** returns the DecisionTree Node created by
        \ref SplitBase::findBestSplit() or \ref SplitBase::makeTerminalNode().
    **/
    NodeBase & createNode() 
    {
        return node_;
    }

    /** number of classes (cached in t_data[1]) */
    int classCount() const
    {
        return int(t_data[1]);
    }

    /** number of feature columns (cached in t_data[0]) */
    int featureCount() const
    {
        return int(t_data[0]);
    }

    /** resets internal data. Should always be called before
        calling findBestSplit or makeTerminalNode
    **/
    void reset()
    {
        t_data.resize(2);
        p_data.resize(0);
    }


    /** findBestSplit has to be re-implemented in derived split functors.
        The default implementation only ensures that a compile-time error is
        issued if no such method was defined: declaring a variable of the
        incomplete type CompileTimeError fails as soon as this default is
        instantiated.
    **/
    template<class T, class C, class T2, class C2, class Region, class Random>
    int findBestSplit(MultiArrayView<2, T, C> features,
                      MultiArrayView<2, T2, C2> labels,
                      Region region,
                      ArrayVector<Region> childs,
                      Random randint)
    {
#ifndef __clang__   
        // FIXME: This compile-time checking trick does not work for clang.
        CompileTimeError SplitFunctor__findBestSplit_member_was_not_defined;
#endif
        return 0;
    }

    /** Default action for creating a terminal Node.
        sets the Class probability of the remaining region according to
        the class histogram
    **/
    template<class T, class C, class T2,class C2, class Region, class Random>
    int makeTerminalNode(MultiArrayView<2, T, C> features,
                      MultiArrayView<2, T2, C2>  labels,
                      Region &                   region,
                      Random                     randint)
    {
        Node<e_ConstProbNode> ret(t_data, p_data);
        node_ = ret;
        // Copy the raw class counts unless exactly one weight per class was
        // supplied; in that case weight the histogram first.
        if(ext_param_.class_weights_.size() != region.classCounts().size())
        {
            std::copy(region.classCounts().begin(),
                      region.classCounts().end(),
                      ret.prob_begin());
        }
        else
        {
            std::transform(region.classCounts().begin(),
                           region.classCounts().end(),
                           ext_param_.class_weights_.begin(),
                           ret.prob_begin(), std::multiplies<double>());
        }
        // For classification, turn the histogram into a probability
        // distribution (no-op for other tags, see detail::Normalise).
        detail::Normalise<RF_Tag>::exec(ret.prob_begin(), ret.prob_end());
//        std::copy(ret.prob_begin(), ret.prob_end(), std::ostream_iterator<double>(std::cerr, ", " ));
//        std::cerr << std::endl;
        ret.weights() = region.size();  
        return e_ConstProbNode;
    }


};
00197 
00198 /** Functor to sort the indices of a feature Matrix by a certain dimension
00199 **/
00200 template<class DataMatrix>
00201 class SortSamplesByDimensions
00202 {
00203     DataMatrix const & data_;
00204     MultiArrayIndex sortColumn_;
00205     double thresVal_;
00206   public:
00207 
00208     SortSamplesByDimensions(DataMatrix const & data, 
00209                             MultiArrayIndex sortColumn,
00210                             double thresVal = 0.0)
00211     : data_(data),
00212       sortColumn_(sortColumn),
00213       thresVal_(thresVal)
00214     {}
00215 
00216     void setColumn(MultiArrayIndex sortColumn)
00217     {
00218         sortColumn_ = sortColumn;
00219     }
00220     void setThreshold(double value)
00221     {
00222         thresVal_ = value; 
00223     }
00224 
00225     bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
00226     {
00227         return data_(l, sortColumn_) < data_(r, sortColumn_);
00228     }
00229     bool operator()(MultiArrayIndex l) const
00230     {
00231         return data_(l, sortColumn_) < thresVal_;
00232     }
00233 };
00234 
00235 template<class DataMatrix>
00236 class DimensionNotEqual
00237 {
00238     DataMatrix const & data_;
00239     MultiArrayIndex sortColumn_;
00240 
00241   public:
00242 
00243     DimensionNotEqual(DataMatrix const & data, 
00244                             MultiArrayIndex sortColumn)
00245     : data_(data),
00246       sortColumn_(sortColumn)
00247     {}
00248 
00249     void setColumn(MultiArrayIndex sortColumn)
00250     {
00251         sortColumn_ = sortColumn;
00252     }
00253 
00254     bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
00255     {
00256         return data_(l, sortColumn_) != data_(r, sortColumn_);
00257     }
00258 };
00259 
00260 template<class DataMatrix>
00261 class SortSamplesByHyperplane
00262 {
00263     DataMatrix const & data_;
00264     Node<i_HyperplaneNode> const & node_;
00265 
00266   public:
00267 
00268     SortSamplesByHyperplane(DataMatrix              const & data, 
00269                             Node<i_HyperplaneNode>  const & node)
00270     :       
00271             data_(data), 
00272             node_(node)
00273     {}
00274 
00275     /** calculate the distance of a sample point to a hyperplane
00276      */
00277     double operator[](MultiArrayIndex l) const
00278     {
00279         double result_l = -1 * node_.intercept();
00280         for(int ii = 0; ii < node_.columns_size(); ++ii)
00281         {
00282             result_l +=     rowVector(data_, l)[node_.columns_begin()[ii]] 
00283                         *   node_.weights()[ii];
00284         }
00285         return result_l;
00286     }
00287 
00288     bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
00289     {
00290         return (*this)[l]  < (*this)[r];
00291     }
00292 
00293 };
00294 
00295 /** makes a Class Histogram given indices in a labels_ array
00296  *  usage: 
00297  *      MultiArrayView<2, T2, C2> labels = makeSomeLabels()
00298  *      ArrayVector<int> hist(numberOfLabels(labels), 0);
00299  *      RandomForestClassCounter<T2, C2, ArrayVector> counter(labels, hist);
00300  *
00301  *      Container<int> indices = getSomeIndices()
00302  *      std::for_each(indices, counter);
00303  */
00304 template <class DataSource, class CountArray>
00305 class RandomForestClassCounter
00306 {
00307     DataSource  const &     labels_;
00308     CountArray        &     counts_;
00309 
00310   public:
00311 
00312     RandomForestClassCounter(DataSource  const & labels, 
00313                              CountArray & counts)
00314     : labels_(labels),
00315       counts_(counts)
00316     {
00317         reset();
00318     }
00319 
00320     void reset()
00321     {
00322         counts_.init(0);
00323     }
00324 
00325     void operator()(MultiArrayIndex l) const
00326     {
00327         counts_[labels_[l]] +=1;
00328     }
00329 };
00330 
00331 
00332 /** Functor To Calculate the Best possible Split Based on the Gini Index
00333     given Labels and Features along a given Axis
00334 */
00335 
00336 namespace detail
00337 {
00338     template<int N>
00339     class ConstArr
00340     {
00341     public:
00342         double operator[](size_t) const
00343         {
00344             return (double)N;
00345         }
00346     };
00347 
00348 
00349 }
00350 
00351 
00352 
00353 
00354 /** Functor to calculate the entropy based impurity
00355  */
00356 class EntropyCriterion
00357 {
00358 public:
00359     /**calculate the weighted gini impurity based on class histogram
00360      * and class weights
00361      */
00362     template<class Array, class Array2>
00363     double operator()        (Array     const & hist, 
00364                               Array2    const & weights, 
00365                               double            total = 1.0) const
00366     {
00367         return impurity(hist, weights, total);
00368     }
00369     
00370     /** calculate the gini based impurity based on class histogram
00371      */
00372     template<class Array>
00373     double operator()(Array const & hist, double total = 1.0) const
00374     {
00375         return impurity(hist, total);
00376     }
00377     
00378     /** static version of operator(hist total)
00379      */
00380     template<class Array>
00381     static double impurity(Array const & hist, double total)
00382     {
00383         return impurity(hist, detail::ConstArr<1>(), total);
00384     }
00385 
00386     /** static version of operator(hist, weights, total)
00387      */
00388     template<class Array, class Array2>
00389     static double impurity   (Array     const & hist, 
00390                               Array2    const & weights, 
00391                               double            total)
00392     {
00393 
00394         int     class_count     = hist.size();
00395         double  entropy            = 0.0;
00396         if(class_count == 2)
00397         {
00398             double p0           = (hist[0]/total);
00399             double p1           = (hist[1]/total);
00400             entropy             = 0 - weights[0]*p0*std::log(p0) - weights[1]*p1*std::log(p1);
00401         }
00402         else
00403         {
00404             for(int ii = 0; ii < class_count; ++ii)
00405             {
00406                 double w        = weights[ii];
00407                 double pii      = hist[ii]/total;
00408                 entropy         -= w*( pii*std::log(pii));
00409             }
00410         }
00411         entropy             = total * entropy;
00412         return entropy; 
00413     }
00414 };
00415 
00416 /** Functor to calculate the gini impurity
00417  */
00418 class GiniCriterion
00419 {
00420 public:
00421     /**calculate the weighted gini impurity based on class histogram
00422      * and class weights
00423      */
00424     template<class Array, class Array2>
00425     double operator()        (Array     const & hist, 
00426                               Array2    const & weights, 
00427                               double            total = 1.0) const
00428     {
00429         return impurity(hist, weights, total);
00430     }
00431     
00432     /** calculate the gini based impurity based on class histogram
00433      */
00434     template<class Array>
00435     double operator()(Array const & hist, double total = 1.0) const
00436     {
00437         return impurity(hist, total);
00438     }
00439     
00440     /** static version of operator(hist total)
00441      */
00442     template<class Array>
00443     static double impurity(Array const & hist, double total)
00444     {
00445         return impurity(hist, detail::ConstArr<1>(), total);
00446     }
00447 
00448     /** static version of operator(hist, weights, total)
00449      */
00450     template<class Array, class Array2>
00451     static double impurity   (Array     const & hist, 
00452                               Array2    const & weights, 
00453                               double            total)
00454     {
00455 
00456         int     class_count     = hist.size();
00457         double  gini            = 0.0;
00458         if(class_count == 2)
00459         {
00460             double w            = weights[0] * weights[1];
00461             gini                = w * (hist[0] * hist[1] / total);
00462         }
00463         else
00464         {
00465             for(int ii = 0; ii < class_count; ++ii)
00466             {
00467                 double w        = weights[ii];
00468                 gini           += w*( hist[ii]*( 1.0 - w * hist[ii]/total ) );
00469             }
00470         }
00471         return gini; 
00472     }
00473 };
00474 
00475 
00476 template <class DataSource, class Impurity= GiniCriterion>
00477 class ImpurityLoss
00478 {
00479 
00480     DataSource  const &         labels_;
00481     ArrayVector<double>        counts_;
00482     ArrayVector<double> const  class_weights_;
00483     double                      total_counts_;
00484     Impurity                    impurity_;
00485 
00486   public:
00487 
00488     template<class T>
00489     ImpurityLoss(DataSource  const & labels, 
00490                                 ProblemSpec<T> const & ext_)
00491     : labels_(labels),
00492       counts_(ext_.class_count_, 0.0),
00493       class_weights_(ext_.class_weights_),
00494       total_counts_(0.0)
00495     {}
00496 
00497     void reset()
00498     {
00499         counts_.init(0);
00500         total_counts_ = 0.0;
00501     }
00502 
00503     template<class Counts>
00504     double increment_histogram(Counts const & counts)
00505     {
00506         std::transform(counts.begin(), counts.end(),
00507                        counts_.begin(), counts_.begin(),
00508                        std::plus<double>());
00509         total_counts_ = std::accumulate( counts_.begin(), 
00510                                          counts_.end(),
00511                                          0.0);
00512         return impurity_(counts_, class_weights_, total_counts_);
00513     }
00514 
00515     template<class Counts>
00516     double decrement_histogram(Counts const & counts)
00517     {
00518         std::transform(counts.begin(), counts.end(),
00519                        counts_.begin(), counts_.begin(),
00520                        std::minus<double>());
00521         total_counts_ = std::accumulate( counts_.begin(), 
00522                                          counts_.end(),
00523                                          0.0);
00524         return impurity_(counts_, class_weights_, total_counts_);
00525     }
00526 
00527     template<class Iter>
00528     double increment(Iter begin, Iter end)
00529     {
00530         for(Iter iter = begin; iter != end; ++iter)
00531         {
00532             counts_[labels_(*iter, 0)] +=1.0;
00533             total_counts_ +=1.0;
00534         }
00535         return impurity_(counts_, class_weights_, total_counts_);
00536     }
00537 
00538     template<class Iter>
00539     double decrement(Iter const &  begin, Iter const & end)
00540     {
00541         for(Iter iter = begin; iter != end; ++iter)
00542         {
00543             counts_[labels_(*iter,0)] -=1.0;
00544             total_counts_ -=1.0;
00545         }
00546         return impurity_(counts_, class_weights_, total_counts_);
00547     }
00548 
00549     template<class Iter, class Resp_t>
00550     double init (Iter begin, Iter end, Resp_t resp)
00551     {
00552         reset();
00553         std::copy(resp.begin(), resp.end(), counts_.begin());
00554         total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0); 
00555         return impurity_(counts_,class_weights_, total_counts_);
00556     }
00557     
00558     ArrayVector<double> const & response()
00559     {
00560         return counts_;
00561     }
00562 };
00563     
00564     
00565     
    /** Regression loss functor: tracks per-dimension mean and sum of
        squared deviations over a set of samples; increment() uses a
        Welford-style incremental update, decrement() recomputes from
        scratch over an assumed layout (see note there).
    */
    template <class DataSource>
    class RegressionForestCounter
    {
    public:
        typedef MultiArrayShape<2>::type Shp;
        DataSource const &      labels_;     // response matrix: labels_(sample, dim)
        ArrayVector <double>    mean_;       // running mean per response dimension
        ArrayVector <double>    variance_;   // running sum of squared deviations
        ArrayVector <double>    tmp_;        // scratch: deviation of current sample
        size_t                  count_;      // number of samples accumulated
        int*                    end_;        // NOTE(review): never initialised or used in this class
        
        /** \param labels  response matrix
            \param ext_    problem specification (supplies response_size_)
        */
        template<class T>
        RegressionForestCounter(DataSource const & labels, 
                                ProblemSpec<T> const & ext_)
        :
        labels_(labels),
        mean_(ext_.response_size_, 0.0),
        variance_(ext_.response_size_, 0.0),
        tmp_(ext_.response_size_),
        count_(0)
        {}
        
        /** add samples [begin, end) to the running statistics (Welford
            update) and return the total sum of squared deviations.
        */
        template<class Iter>
        double increment (Iter begin, Iter end)
        {
            for(Iter iter = begin; iter != end; ++iter)
            {
                ++count_;
                for(unsigned int ii = 0; ii < mean_.size(); ++ii)
                    tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 
                double f  = 1.0 / count_,
                f1 = 1.0 - f;
                for(unsigned int ii = 0; ii < mean_.size(); ++ii)
                    mean_[ii] += f*tmp_[ii]; 
                for(unsigned int ii = 0; ii < mean_.size(); ++ii)
                    variance_[ii] += f1*sq(tmp_[ii]);
            }
            double res = std::accumulate(variance_.begin(), 
                                         variance_.end(),
                                         0.0,
                                         std::plus<double>());
            //std::cerr << res << "  ) = ";
            return res;
        }
        
        /** remove samples [begin, end), then recompute mean and variance
            from scratch over the window [end, end + count_).
            NOTE(review): marked BROKEN upstream - it reassigns its
            iterator arguments and assumes the remaining count_ samples
            sit immediately after 'end'; verify callers before relying
            on it.
        */
        template<class Iter>      //This is BROKEN
        double decrement (Iter begin, Iter end)
        {
            for(Iter iter = begin; iter != end; ++iter)
            {
                --count_;
            }

            begin = end;
            end   = end + count_;
            

            for(unsigned int ii = 0; ii < mean_.size(); ++ii)
            {
                mean_[ii] = 0;        
                for(Iter iter = begin; iter != end; ++iter)
                {
                    mean_[ii] += labels_(*iter, ii);
                }
                mean_[ii] /= count_;
                variance_[ii] = 0;
                for(Iter iter = begin; iter != end; ++iter)
                {
                    variance_[ii] += (labels_(*iter, ii) - mean_[ii])*(labels_(*iter, ii) - mean_[ii]);
                }
            }
            double res = std::accumulate(variance_.begin(), 
                                         variance_.end(),
                                         0.0,
                                         std::plus<double>());
            //std::cerr << res << "  ) = ";
            return res;
        }

        
        /** reset and accumulate statistics over [begin, end);
            resp is unused (kept for interface compatibility).
        */
        template<class Iter, class Resp_t>
        double init (Iter begin, Iter end, Resp_t resp)
        {
            reset();
            return this->increment(begin, end);
            
        }
        
        
        /** current per-dimension mean of the accumulated samples */
        ArrayVector<double> const & response()
        {
            return mean_;
        }
        
        /** clear all accumulated statistics */
        void reset()
        {
            mean_.init(0.0);
            variance_.init(0.0);
            count_ = 0; 
        }
    };
00668     
00669     
00670 template <class DataSource>
00671 class RegressionForestCounter2
00672 {
00673 public:
00674     typedef MultiArrayShape<2>::type Shp;
00675     DataSource const &      labels_;
00676     ArrayVector <double>    mean_;
00677     ArrayVector <double>    variance_;
00678     ArrayVector <double>    tmp_;
00679     size_t                  count_;
00680 
00681     template<class T>
00682     RegressionForestCounter2(DataSource const & labels, 
00683                             ProblemSpec<T> const & ext_)
00684     :
00685         labels_(labels),
00686         mean_(ext_.response_size_, 0.0),
00687         variance_(ext_.response_size_, 0.0),
00688         tmp_(ext_.response_size_),
00689         count_(0)
00690     {}
00691     
00692     template<class Iter>
00693     double increment (Iter begin, Iter end)
00694     {
00695         for(Iter iter = begin; iter != end; ++iter)
00696         {
00697             ++count_;
00698             for(int ii = 0; ii < mean_.size(); ++ii)
00699                 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 
00700             double f  = 1.0 / count_,
00701                    f1 = 1.0 - f;
00702             for(int ii = 0; ii < mean_.size(); ++ii)
00703                 mean_[ii] += f*tmp_[ii]; 
00704             for(int ii = 0; ii < mean_.size(); ++ii)
00705                 variance_[ii] += f1*sq(tmp_[ii]);
00706         }
00707         double res = std::accumulate(variance_.begin(), 
00708                                variance_.end(),
00709                                0.0,
00710                                std::plus<double>())
00711                 /((count_ == 1)? 1:(count_ -1));
00712         //std::cerr << res << "  ) = ";
00713         return res;
00714     }
00715 
00716     template<class Iter>      //This is BROKEN
00717     double decrement (Iter begin, Iter end)
00718     {
00719         for(Iter iter = begin; iter != end; ++iter)
00720         {
00721             double f  = 1.0 / count_,
00722                    f1 = 1.0 - f;
00723             for(int ii = 0; ii < mean_.size(); ++ii)
00724                 mean_[ii] = (mean_[ii] - f*labels_(*iter,ii))/(1-f); 
00725             for(int ii = 0; ii < mean_.size(); ++ii)
00726                 variance_[ii] -= f1*sq(labels_(*iter,ii) - mean_[ii]);
00727             --count_;
00728         }
00729         double res =  std::accumulate(variance_.begin(), 
00730                                variance_.end(),
00731                                0.0,
00732                                std::plus<double>())
00733                 /((count_ == 1)? 1:(count_ -1));
00734         //std::cerr << "( " << res << " + ";
00735         return res;
00736     }
00737     /* west's algorithm for incremental variance
00738     // calculation
00739     template<class Iter>
00740     double increment (Iter begin, Iter end)
00741     {
00742         for(Iter iter = begin; iter != end; ++iter)
00743         {
00744             ++count_;
00745             for(int ii = 0; ii < mean_.size(); ++ii)
00746                 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 
00747             double f  = 1.0 / count_,
00748                    f1 = 1.0 - f;
00749             for(int ii = 0; ii < mean_.size(); ++ii)
00750                 mean_[ii] += f*tmp_[ii]; 
00751             for(int ii = 0; ii < mean_.size(); ++ii)
00752                 variance_[ii] += f1*sq(tmp_[ii]);
00753         }
00754         return std::accumulate(variance_.begin(), 
00755                                variance_.end(),
00756                                0.0,
00757                                std::plus<double>())
00758                 /(count_ -1);
00759     }
00760 
00761     template<class Iter>
00762     double decrement (Iter begin, Iter end)
00763     {
00764         for(Iter iter = begin; iter != end; ++iter)
00765         {
00766             --count_;
00767             for(int ii = 0; ii < mean_.size(); ++ii)
00768                 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 
00769             double f  = 1.0 / count_,
00770                    f1 = 1.0 + f;
00771             for(int ii = 0; ii < mean_.size(); ++ii)
00772                 mean_[ii] -= f*tmp_[ii]; 
00773             for(int ii = 0; ii < mean_.size(); ++ii)
00774                 variance_[ii] -= f1*sq(tmp_[ii]);
00775         }
00776         return std::accumulate(variance_.begin(), 
00777                                variance_.end(),
00778                                0.0,
00779                                std::plus<double>())
00780                 /(count_ -1);
00781     }*/
00782 
00783     template<class Iter, class Resp_t>
00784     double init (Iter begin, Iter end, Resp_t resp)
00785     {
00786         reset();
00787         return this->increment(begin, end, resp);
00788     }
00789     
00790 
00791     ArrayVector<double> const & response()
00792     {
00793         return mean_;
00794     }
00795 
00796     void reset()
00797     {
00798         mean_.init(0.0);
00799         variance_.init(0.0);
00800         count_ = 0; 
00801     }
00802 };
00803 
// Maps a loss/criterion tag to the functor class that implements it for a
// given label data type (primary template intentionally left undefined).
template<class Tag, class Datatyp>
struct LossTraits;

// Tag type selecting least-squares loss (regression forests).
struct LSQLoss
{};

// Gini criterion -> impurity loss over a class histogram.
template<class Datatype>
struct LossTraits<GiniCriterion, Datatype>
{
    typedef ImpurityLoss<Datatype, GiniCriterion> type;
};

// Entropy criterion -> impurity loss over a class histogram.
template<class Datatype>
struct LossTraits<EntropyCriterion, Datatype>
{
    typedef ImpurityLoss<Datatype, EntropyCriterion> type;
};

// Least-squares loss -> regression counter (mean/variance tracking).
template<class Datatype>
struct LossTraits<LSQLoss, Datatype>
{
    typedef RegressionForestCounter<Datatype> type;
};
00827 
00828 /** Given a column, choose a split that minimizes some loss
00829  */
00830 template<class LineSearchLossTag>
00831 class BestGiniOfColumn
00832 {
public:
    ArrayVector<double>     class_weights_;        // per-class weights (may be empty)
    ArrayVector<double>     bestCurrentCounts[2];  // class histograms left/right of the best split
    double                  min_gini_;             // smallest loss found so far
    std::ptrdiff_t               min_index_;       // split position within the sorted range
    double                  min_threshold_;        // feature value at the best split
    ProblemSpec<>           ext_param_;            // problem specification

    // Default constructor: call set_external_parameters() before use.
    BestGiniOfColumn()
    {}

    /** \param ext  problem specification supplying class count and weights;
        sizes the left/right histograms accordingly.
    */
    template<class T> 
    BestGiniOfColumn(ProblemSpec<T> const & ext)
    :
        class_weights_(ext.class_weights_),
        ext_param_(ext)
    {
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }
    /** (re)initialise from a problem specification - same effect as the
        constructor taking ProblemSpec.
    */
    template<class T> 
    void set_external_parameters(ProblemSpec<T> const & ext)
    {
        class_weights_ = ext.class_weights_; 
        ext_param_ = ext;
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }
00861     /** calculate the best gini split along a Feature Column
00862      * \param column  the feature vector - has to support the [] operator
00863      * \param labels  the label vector 
00864      * \param begin 
00865      * \param end     (in and out)
00866      *                begin and end iterators to the indices of the
00867      *                samples in the current region. 
00868      *                the range begin - end is sorted by the column supplied
00869      *                during function execution.
00870      * \param region_response
00871      *                class histogram of the samples in the
00872      *                range [begin, end).
00873      *
00874      *  precondition: begin, end valid range, 
00875      *                class_counts positive integer valued array with the 
00876      *                class counts in the current range.
00877      *                labels.size() >= max(begin, end); 
00878      *  postcondition:
00879      *                begin, end sorted by column given. 
00880      *                min_gini_ contains the minimum gini found or 
00881      *                NumericTraits<double>::max if no split was found.
00882      *                min_index_ contains the splitting index in the range
00883      *                or invalid data if no split was found.
00884      *                BestCirremtcounts[0] and [1] contain the 
00885      *                class histogram of the left and right region of 
00886      *                the left and right regions. 
00887      */
00888     template<   class DataSourceF_t,
00889                 class DataSource_t, 
00890                 class I_Iter, 
00891                 class Array>
00892     void operator()(DataSourceF_t   const & column,
00893                     DataSource_t    const & labels,
00894                     I_Iter                & begin, 
00895                     I_Iter                & end,
00896                     Array           const & region_response)
00897     {
00898         std::sort(begin, end, 
00899                   SortSamplesByDimensions<DataSourceF_t>(column, 0));
00900         typedef typename 
00901             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
00902         LineSearchLoss left(labels, ext_param_); //initialize left and right region
00903         LineSearchLoss right(labels, ext_param_);
00904 
00905         
00906 
00907         min_gini_ = right.init(begin, end, region_response);  
00908         min_threshold_ = *begin;
00909         min_index_     = 0;  //the starting point where to split 
00910         DimensionNotEqual<DataSourceF_t> comp(column, 0); 
00911         
00912         I_Iter iter = begin;
00913         I_Iter next = std::adjacent_find(iter, end, comp);
00914         //std::cerr << std::distance(begin, end) << std::endl;
00915         while( next  != end)
00916         {
00917             double lr  =  right.decrement(iter, next + 1);
00918             double ll  =  left.increment(iter , next + 1);
00919             double loss = lr +ll;
00920             //std::cerr <<lr << " + "<< ll << " " << loss << " ";
00921 #ifdef CLASSIFIER_TEST
00922             if(loss < min_gini_ && !closeAtTolerance(loss, min_gini_))
00923 #else
00924             if(loss < min_gini_ )
00925 #endif 
00926             {
00927                 bestCurrentCounts[0] = left.response();
00928                 bestCurrentCounts[1] = right.response();
00929 #ifdef CLASSIFIER_TEST
00930                 min_gini_       = loss < min_gini_? loss : min_gini_;
00931 #else
00932                 min_gini_       = loss; 
00933 #endif
00934                 min_index_      = next - begin +1 ;
00935                 min_threshold_  = (double(column(*next,0)) + double(column(*(next +1), 0)))/2.0;
00936             }
00937             iter = next +1 ;
00938             next = std::adjacent_find(iter, end, comp);
00939         }
00940         //std::cerr << std::endl << " 000 " << std::endl;
00941         //int in;
00942         //std::cin >> in;
00943     }
00944 
00945     template<class DataSource_t, class Iter, class Array>
00946     double loss_of_region(DataSource_t const & labels,
00947                           Iter & begin, 
00948                           Iter & end, 
00949                           Array const & region_response) const
00950     {
00951         typedef typename 
00952             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
00953         LineSearchLoss region_loss(labels, ext_param_);
00954         return 
00955             region_loss.init(begin, end, region_response);
00956     }
00957 
00958 };
00959 
00960 namespace detail
00961 {
00962     template<class T>
00963     struct Correction
00964     {
00965         template<class Region, class LabelT>
00966         static void exec(Region & in, LabelT & labels)
00967         {}
00968     };
00969     
00970     template<>
00971     struct Correction<ClassificationTag>
00972     {
00973         template<class Region, class LabelT>
00974         static void exec(Region & region, LabelT & labels) 
00975         {
00976             if(std::accumulate(region.classCounts().begin(),
00977                                region.classCounts().end(), 0.0) != region.size())
00978             {
00979                 RandomForestClassCounter<   LabelT, 
00980                                             ArrayVector<double> >
00981                     counter(labels, region.classCounts());
00982                 std::for_each(  region.begin(), region.end(), counter);
00983                 region.classCountsIsValid = true;
00984             }
00985         }
00986     };
00987 }
00988 
/** Chooses mtry columns and applies ColumnDecisionFunctor to each of the
 * columns. Then Chooses the column that is best
 */
template<class ColumnDecisionFunctor, class Tag = ClassificationTag>
class ThresholdSplit: public SplitBase<Tag>
{
  public:


    typedef SplitBase<Tag> SB;
    
    // permutation of all column indices; a random prefix of length
    // actual_mtry_ serves as the candidate columns for each split
    ArrayVector<Int32>          splitColumns;
    // functor that finds the best threshold within a single column
    ColumnDecisionFunctor       bgfunc;

    double                      region_gini_;     // loss of the region before splitting
    ArrayVector<double>         min_gini_;        // per candidate column: best loss found
    ArrayVector<std::ptrdiff_t>      min_indices_;     // per candidate column: position of best split
    ArrayVector<double>         min_thresholds_;  // per candidate column: best threshold value

    int                         bestSplitIndex;   // index of the winning column (into the arrays above)

    /// loss of the best split found by the last findBestSplit() call
    double minGini() const
    {
        return min_gini_[bestSplitIndex];
    }
    /// feature column used by the best split
    int bestSplitColumn() const
    {
        return splitColumns[bestSplitIndex];
    }
    /// threshold value of the best split
    double bestSplitThreshold() const
    {
        return min_thresholds_[bestSplitIndex];
    }

    /** Forward the problem specification to the base class and the column
     *  functor, and size all per-column bookkeeping arrays.
     */
    template<class T>
    void set_external_parameters(ProblemSpec<T> const & in)
    {
        SB::set_external_parameters(in);        
        bgfunc.set_external_parameters( SB::ext_param_);
        int featureCount_ = SB::ext_param_.column_count_;
        splitColumns.resize(featureCount_);
        for(int k=0; k<featureCount_; ++k)
            splitColumns[k] = k;
        min_gini_.resize(featureCount_);
        min_indices_.resize(featureCount_);
        min_thresholds_.resize(featureCount_);
    }


    /** Find the best axis-parallel split for 'region' and set up the two
     *  child regions accordingly.
     *  \return i_ThresholdNode on success, or the result of
     *          makeTerminalNode() when the region is pure or no useful
     *          split exists.
     */
    template<class T, class C, class T2, class C2, class Region, class Random>
    int findBestSplit(MultiArrayView<2, T, C> features,
                      MultiArrayView<2, T2, C2>  labels,
                      Region & region,
                      ArrayVector<Region>& childRegions,
                      Random & randint)
    {

        typedef typename Region::IndexIterator IndexIterator;
        if(region.size() == 0)
        {
           std::cerr << "SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n"
                        "continuing learning process...."; 
        }
        // calculate things that haven't been calculated yet. 
        detail::Correction<Tag>::exec(region, labels);


        // Is the region pure already?
        region_gini_ = bgfunc.loss_of_region(labels,
                                             region.begin(), 
                                             region.end(),
                                             region.classCounts());
        if(region_gini_ <= SB::ext_param_.precision_)
            return  this->makeTerminalNode(features, labels, region, randint);

        // select columns  to be tried.
        // (partial Fisher-Yates shuffle: move actual_mtry_ randomly chosen
        // column indices to the front of splitColumns)
        for(int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
            std::swap(splitColumns[ii], 
                      splitColumns[ii+ randint(features.shape(1) - ii)]);

        // find the best gini index
        // num2try starts at the total column count: if none of the first
        // actual_mtry_ columns yields an improvement, the search falls
        // through to the remaining columns; once an improvement is found,
        // num2try is reduced to actual_mtry_ so only the chosen subset
        // is considered.
        bestSplitIndex              = 0;
        double  current_min_gini    = region_gini_;
        int     num2try             = features.shape(1);
        for(int k=0; k<num2try; ++k)
        {
            //this functor does all the work
            bgfunc(columnVector(features, splitColumns[k]),
                   labels, 
                   region.begin(), region.end(), 
                   region.classCounts());
            min_gini_[k]            = bgfunc.min_gini_; 
            min_indices_[k]         = bgfunc.min_index_;
            min_thresholds_[k]      = bgfunc.min_threshold_;
#ifdef CLASSIFIER_TEST
            if(     bgfunc.min_gini_ < current_min_gini
               &&  !closeAtTolerance(bgfunc.min_gini_, current_min_gini))
#else
            if(bgfunc.min_gini_ < current_min_gini)
#endif
            {
                // remember this column as the current best and reuse the
                // class histograms the functor computed for its two sides
                current_min_gini = bgfunc.min_gini_;
                childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
                childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
                childRegions[0].classCountsIsValid = true;
                childRegions[1].classCountsIsValid = true;

                bestSplitIndex   = k;
                num2try = SB::ext_param_.actual_mtry_;
            }
        }
        //std::cerr << current_min_gini << "curr " << region_gini_ << std::endl;
        // did not find any suitable split
        if(closeAtTolerance(current_min_gini, region_gini_))
            return  this->makeTerminalNode(features, labels, region, randint);
        
        //create a Node for output
        Node<i_ThresholdNode>   node(SB::t_data, SB::p_data);
        SB::node_ = node;
        node.threshold()    = min_thresholds_[bestSplitIndex];
        node.column()       = splitColumns[bestSplitIndex];
        
        // partition the range according to the best dimension 
        SortSamplesByDimensions<MultiArrayView<2, T, C> > 
            sorter(features, node.column(), node.threshold());
        IndexIterator bestSplit =
            std::partition(region.begin(), region.end(), sorter);
        // Save the ranges of the child stack entries.
        childRegions[0].setRange(   region.begin()  , bestSplit       );
        childRegions[0].rule = region.rule;
        childRegions[0].rule.push_back(std::make_pair(1, 1.0));
        childRegions[1].setRange(   bestSplit       , region.end()    );
        childRegions[1].rule = region.rule;
        childRegions[1].rule.push_back(std::make_pair(1, 1.0));

        return i_ThresholdNode;
    }
};
01127 
// Standard threshold splits: exhaustive per-column search minimizing the
// Gini index, entropy, or least-squares loss respectively.
typedef  ThresholdSplit<BestGiniOfColumn<GiniCriterion> >                      GiniSplit;
typedef  ThresholdSplit<BestGiniOfColumn<EntropyCriterion> >                 EntropySplit;
typedef  ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag>              RegressionSplit;
01131 
01132 namespace rf
01133 {
01134 
/** This namespace contains additional split functors.
 *
 * The split functor classes are designed in a modular fashion because new split functors may 
 * share a lot of code with existing ones. 
 * 
 * ThresholdSplit implements the functionality needed by any split functor that makes its 
 * decision via one-dimensional axis-parallel cuts. The template parameter defines how the split
 * along one dimension is chosen. 
 *
 * The BestGiniOfColumn class chooses a split that minimizes one of the loss functions supplied
 * (GiniCriterion for classification and LSQLoss for regression). Median chooses the split in a 
 * kD-tree fashion, and RandomSplitOfColumn at a randomly chosen sample value.
 *
 *
 * Currently defined typedefs: 
 * \code
 * typedef  ThresholdSplit<BestGiniOfColumn<GiniCriterion> >                 GiniSplit;
 * typedef  ThresholdSplit<BestGiniOfColumn<EntropyCriterion> >              EntropySplit;
 * typedef  ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag>         RegressionSplit;
 * typedef  ThresholdSplit<Median>                                           MedianSplit;
 * typedef  ThresholdSplit<RandomSplitOfColumn>                              RandomSplit;
 * \endcode
 */
01156 namespace split
01157 {
01158 
/** This Functor chooses the median value of a column
 */
class Median
{
public:

    typedef GiniCriterion   LineSearchLossTag;    // loss used to score the (fixed) median split
    ArrayVector<double>     class_weights_;       // per-class weights from the problem spec
    ArrayVector<double>     bestCurrentCounts[2]; // class histograms of the left/right regions
    double                  min_gini_;            // loss of the split, NumericTraits<double>::max() if none
    std::ptrdiff_t          min_index_;           // split position within [begin, end)
    double                  min_threshold_;       // feature value at the split point
    ProblemSpec<>           ext_param_;           // cached problem specification

    Median()
    {}

    /** Construct from a problem specification and pre-size the histograms. */
    template<class T> 
    Median(ProblemSpec<T> const & ext)
    :
        class_weights_(ext.class_weights_),
        ext_param_(ext)
    {
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }
  
    /** (Re)initialize from a problem specification. */
    template<class T> 
    void set_external_parameters(ProblemSpec<T> const & ext)
    {
        class_weights_ = ext.class_weights_; 
        ext_param_ = ext;
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }
     
    /** Split [begin, end) at the median value of 'column'.
     *  Unlike BestGiniOfColumn, no search is performed: the split point is
     *  fixed at the middle of the sorted range (adjusted if the left side
     *  would be empty) and the loss of that single split is computed.
     */
    template<   class DataSourceF_t,
                class DataSource_t, 
                class I_Iter, 
                class Array>
    void operator()(DataSourceF_t   const & column,
                    DataSource_t    const & labels,
                    I_Iter                & begin, 
                    I_Iter                & end,
                    Array           const & region_response)
    {
        std::sort(begin, end, 
                  SortSamplesByDimensions<DataSourceF_t>(column, 0));
        typedef typename 
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss left(labels, ext_param_);
        LineSearchLoss right(labels, ext_param_);
        right.init(begin, end, region_response);

        min_gini_ = NumericTraits<double>::max();
        // median position of the sorted range; its feature value is the threshold
        min_index_ = floor(double(end - begin)/2.0); 
        min_threshold_ =  column[*(begin + min_index_)];
        // move all samples with value <= threshold to the front
        SortSamplesByDimensions<DataSourceF_t> 
            sorter(column, 0, min_threshold_);
        I_Iter part = std::partition(begin, end, sorter);
        DimensionNotEqual<DataSourceF_t> comp(column, 0); 
        if(part == begin)
        {
            // left side empty (median == minimum value): advance to the
            // first position where the feature value changes.
            // NOTE(review): if all values in the range are equal,
            // adjacent_find returns end and 'part' becomes end+1; the
            // 'part >= end' test below relies on that - confirm the
            // iterators are random access.
            part= std::adjacent_find(part, end, comp)+1;
            
        }
        if(part >= end)
        {
            // constant column: no split possible, min_gini_ stays at max()
            return; 
        }
        else
        {
            min_threshold_ = column[*part];
        }
        min_gini_ = right.decrement(begin, part) 
              +     left.increment(begin , part);

        bestCurrentCounts[0] = left.response();
        bestCurrentCounts[1] = right.response();
        
        min_index_      = part - begin;
    }

    /** compute the loss of the whole (unsplit) region. */
    template<class DataSource_t, class Iter, class Array>
    double loss_of_region(DataSource_t const & labels,
                          Iter & begin, 
                          Iter & end, 
                          Array const & region_response) const
    {
        typedef typename 
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss region_loss(labels, ext_param_);
        return 
            region_loss.init(begin, end, region_response);
    }

};
01256 
01257 typedef  ThresholdSplit<Median> MedianSplit;
01258 
01259 
/** This Functor chooses a random value of a column
 */
class RandomSplitOfColumn
{
public:

    typedef GiniCriterion   LineSearchLossTag;    // loss used to score the (fixed) random split
    ArrayVector<double>     class_weights_;       // per-class weights from the problem spec
    ArrayVector<double>     bestCurrentCounts[2]; // class histograms of the left/right regions
    double                  min_gini_;            // loss of the split, NumericTraits<double>::max() if none
    std::ptrdiff_t          min_index_;           // split position within [begin, end)
    double                  min_threshold_;       // feature value at the split point
    ProblemSpec<>           ext_param_;           // cached problem specification
    typedef RandomMT19937   Random_t;
    Random_t                random;               // RNG used to pick the split sample

    RandomSplitOfColumn()
    {}

    /** Construct from a problem specification; seeds the RNG with RandomSeed. */
    template<class T> 
    RandomSplitOfColumn(ProblemSpec<T> const & ext)
    :
        class_weights_(ext.class_weights_),
        ext_param_(ext),
        random(RandomSeed)
    {
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }
    
    /** Construct from a problem specification and an external RNG.
     *  NOTE(review): the RNG is copied by value here, so this object does
     *  not advance the state of 'random_' - confirm that is intended.
     */
    template<class T> 
    RandomSplitOfColumn(ProblemSpec<T> const & ext, Random_t & random_)
    :
        class_weights_(ext.class_weights_),
        ext_param_(ext),
        random(random_)
    {
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }
  
    /** (Re)initialize from a problem specification (RNG state is kept). */
    template<class T> 
    void set_external_parameters(ProblemSpec<T> const & ext)
    {
        class_weights_ = ext.class_weights_; 
        ext_param_ = ext;
        bestCurrentCounts[0].resize(ext.class_count_);
        bestCurrentCounts[1].resize(ext.class_count_);
    }
     
    /** Split [begin, end) at the value of a uniformly chosen sample.
     *  Like Median, no search is performed: a single split point is chosen
     *  (at random here) and the loss of that one split is computed.
     */
    template<   class DataSourceF_t,
                class DataSource_t, 
                class I_Iter, 
                class Array>
    void operator()(DataSourceF_t   const & column,
                    DataSource_t    const & labels,
                    I_Iter                & begin, 
                    I_Iter                & end,
                    Array           const & region_response)
    {
        std::sort(begin, end, 
                  SortSamplesByDimensions<DataSourceF_t>(column, 0));
        typedef typename 
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss left(labels, ext_param_);
        LineSearchLoss right(labels, ext_param_);
        right.init(begin, end, region_response);

        
        min_gini_ = NumericTraits<double>::max();
        // pick a random sample of the range; its feature value is the threshold
        int tmp_pt = random.uniformInt(std::distance(begin, end));
        min_index_ = tmp_pt;
        min_threshold_ =  column[*(begin + min_index_)];
        // move all samples with value <= threshold to the front
        SortSamplesByDimensions<DataSourceF_t> 
            sorter(column, 0, min_threshold_);
        I_Iter part = std::partition(begin, end, sorter);
        DimensionNotEqual<DataSourceF_t> comp(column, 0); 
        if(part == begin)
        {
            // left side empty (threshold == minimum value): advance to the
            // first position where the feature value changes.
            // NOTE(review): if all values in the range are equal,
            // adjacent_find returns end and 'part' becomes end+1; the
            // 'part >= end' test below relies on that - confirm the
            // iterators are random access.
            part= std::adjacent_find(part, end, comp)+1;
            
        }
        if(part >= end)
        {
            // constant column: no split possible, min_gini_ stays at max()
            return; 
        }
        else
        {
            min_threshold_ = column[*part];
        }
        min_gini_ = right.decrement(begin, part) 
              +     left.increment(begin , part);

        bestCurrentCounts[0] = left.response();
        bestCurrentCounts[1] = right.response();
        
        min_index_      = part - begin;
    }

    /** compute the loss of the whole (unsplit) region. */
    template<class DataSource_t, class Iter, class Array>
    double loss_of_region(DataSource_t const & labels,
                          Iter & begin, 
                          Iter & end, 
                          Array const & region_response) const
    {
        typedef typename 
            LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
        LineSearchLoss region_loss(labels, ext_param_);
        return 
            region_loss.init(begin, end, region_response);
    }

};
01373 
01374 typedef  ThresholdSplit<RandomSplitOfColumn> RandomSplit;
01375 }
01376 }
01377 
01378 
01379 } //namespace vigra
01380 #endif // VIGRA_RANDOM_FOREST_SPLIT_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.8.0 (20 Sep 2011)