
vigra/random_forest/rf_split.hxx

00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 #ifndef VIGRA_RANDOM_FOREST_SPLIT_HXX
00036 #define VIGRA_RANDOM_FOREST_SPLIT_HXX
00037 #include <algorithm>
00038 #include <map>
00039 #include <numeric>
00040 #include <math.h>
00041 #include "../mathutil.hxx"
00042 #include "../array_vector.hxx"
00043 #include "../sized_int.hxx"
00044 #include "../matrix.hxx"
00045 #include "../random.hxx"
00046 #include "../functorexpression.hxx"
00047 #include "rf_nodeproxy.hxx"
00048 //#include "rf_sampling.hxx"
00049 #include "rf_region.hxx"
00050 //#include "../hokashyap.hxx"
00051 //#include "vigra/rf_helpers.hxx"
00052 
00053 namespace vigra
00054 {
00055 
00056 // Incomplete Class to ensure that findBestSplit is always implemented in
00057 // the derived classes of SplitBase
00058 class CompileTimeError;
00059 
00060 
00061 namespace detail
00062 {
00063     template<class Tag>
00064     class Normalise
00065     {
00066     public:
00067         template<class Iter>
00068         static void exec(Iter begin, Iter  end)
00069         {}
00070     };
00071 
00072     template<>
00073     class Normalise<ClassificationTag>
00074     {
00075     public:
00076         template<class Iter>
00077         static void exec (Iter begin, Iter end)
00078         {
00079             double sum = std::accumulate(begin, end, 0.0);
00080             for(int ii = 0; ii < end - begin; ++ii)
00081                 begin[ii] = begin[ii] / sum;
00082         }
00083     };
00084 }
00085 
00086 
00087 /** Base class for all split functors used with the \ref RandomForest class;
00088     defines the interface used while learning a tree.
00089 **/
00090 template<class Tag>
00091 class SplitBase
00092 {
00093   public:
00094 
00095     typedef Tag           RF_Tag;
00096     typedef DT_StackEntry<ArrayVectorView<Int32>::iterator>
00097                                         StackEntry_t;
00098 
00099     ProblemSpec<>           ext_param_;
00100 
00101     NodeBase::T_Container_type          t_data;
00102     NodeBase::P_Container_type          p_data;
00103 
00104     NodeBase                            node_;
00105 
00106     template<class T>
00107     void set_external_parameters(ProblemSpec<T> const & in)
00108     {
00109         ext_param_ = in;
00110         t_data.push_back(in.column_count_);
00111         t_data.push_back(in.class_count_);
00112     }
00113 
00114     /** returns the DecisionTree node created by
00115         \ref findBestSplit or \ref makeTerminalNode.
00116     **/
00117     NodeBase & createNode() 
00118     {
00119         return node_;
00120     }
00122 
00123     int classCount() const
00124     {
00125         return int(t_data[1]);
00126     }
00127 
00128     int featureCount() const
00129     {
00130         return int(t_data[0]);
00131     }
00132 
00133     /** resets internal data. Should always be called before
00134         calling findBestSplit or makeTerminalNode
00135     **/
00136     void reset()
00137     {
00138         t_data.resize(2);
00139         p_data.resize(0);
00140     }
00141 
00142 
00143     /** findBestSplit has to be implemented in the derived split functor.
00144         This stub only ensures that a compile-time error is issued
00145         if no such method was defined.
00146     **/
00147 
00148     template<class T, class C, class T2, class C2, class Region, class Random>
00149     int findBestSplit(MultiArrayView<2, T, C> features,
00150                       MultiArrayView<2, T2, C2> labels,
00151                       Region region,
00152                       ArrayVector<Region> childs,
00153                       Random randint)
00154     {
00155         CompileTimeError SplitFunctor__findBestSplit_member_was_not_defined;
00156         return 0;
00157     }
00158 
00159     /** default action for creating a terminal node:
00160         sets the class probabilities of the remaining region according to
00161         its class histogram.
00162     **/
00163     template<class T, class C, class T2,class C2, class Region, class Random>
00164     int makeTerminalNode(MultiArrayView<2, T, C> features,
00165                       MultiArrayView<2, T2, C2>  labels,
00166                       Region &                   region,
00167                       Random                     randint)
00168     {
00169         Node<e_ConstProbNode> ret(t_data, p_data);
00170         node_ = ret;
00171         if(ext_param_.class_weights_.size() != region.classCounts().size())
00172         {
00173             std::copy(      region.classCounts().begin(),
00174                             region.classCounts().end(),
00175                             ret.prob_begin());
00176         }
00177         else
00178         {
00179             std::transform( region.classCounts().begin(),
00180                             region.classCounts().end(),
00181                             ext_param_.class_weights_.begin(),
00182                             ret.prob_begin(), std::multiplies<double>());
00183         }
00184         detail::Normalise<RF_Tag>::exec(ret.prob_begin(), ret.prob_end());
00185         ret.weights() = region.size();  
00186         return e_ConstProbNode;
00187     }
00188 
00189 
00190 };
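/* Illustration (not part of the library): a split functor derives from SplitBase and
   provides findBestSplit(). The class StopSplit below is a hypothetical minimal sketch
   that never splits and turns every region into a terminal node via makeTerminalNode().

    class StopSplit : public SplitBase<ClassificationTag>
    {
      public:
        template<class T, class C, class T2, class C2, class Region, class Random>
        int findBestSplit(MultiArrayView<2, T, C> features,
                          MultiArrayView<2, T2, C2> labels,
                          Region & region,
                          ArrayVector<Region> & childRegions,
                          Random & randint)
        {
            // never split: every region becomes a terminal node
            return this->makeTerminalNode(features, labels, region, randint);
        }
    };
*/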
00191 
00192 /** Functor to sort the indices of a feature Matrix by a certain dimension
00193 **/
00194 template<class DataMatrix>
00195 class SortSamplesByDimensions
00196 {
00197     DataMatrix const & data_;
00198     MultiArrayIndex sortColumn_;
00199     double thresVal_;
00200   public:
00201 
00202     SortSamplesByDimensions(DataMatrix const & data, 
00203                             MultiArrayIndex sortColumn,
00204                             double thresVal = 0.0)
00205     : data_(data),
00206       sortColumn_(sortColumn),
00207       thresVal_(thresVal)
00208     {}
00209 
00210     void setColumn(MultiArrayIndex sortColumn)
00211     {
00212         sortColumn_ = sortColumn;
00213     }
00214     void setThreshold(double value)
00215     {
00216         thresVal_ = value; 
00217     }
00218 
00219     bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
00220     {
00221         return data_(l, sortColumn_) < data_(r, sortColumn_);
00222     }
00223     bool operator()(MultiArrayIndex l) const
00224     {
00225         return data_(l, sortColumn_) < thresVal_;
00226     }
00227 };
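/* Illustration (not part of the library): SortSamplesByDimensions is used as a
   comparison / threshold predicate on sample indices. The names 'features' (a
   MultiArrayView<2, double>) and 'sampleIndices' (an ArrayVector<Int32>) are placeholders.

    SortSamplesByDimensions<MultiArrayView<2, double> > byColumn3(features, 3);
    std::sort(sampleIndices.begin(), sampleIndices.end(), byColumn3);  // order by feature 3
    byColumn3.setThreshold(0.5);
    bool below = byColumn3(sampleIndices[0]);   // features(sampleIndices[0], 3) < 0.5 ?
*/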
00228 
00229 template<class DataMatrix>
00230 class DimensionNotEqual
00231 {
00232     DataMatrix const & data_;
00233     MultiArrayIndex sortColumn_;
00234 
00235   public:
00236 
00237     DimensionNotEqual(DataMatrix const & data, 
00238                             MultiArrayIndex sortColumn)
00239     : data_(data),
00240       sortColumn_(sortColumn)
00241     {}
00242 
00243     void setColumn(MultiArrayIndex sortColumn)
00244     {
00245         sortColumn_ = sortColumn;
00246     }
00247 
00248     bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
00249     {
00250         return data_(l, sortColumn_) != data_(r, sortColumn_);
00251     }
00252 };
00253 
00254 template<class DataMatrix>
00255 class SortSamplesByHyperplane
00256 {
00257     DataMatrix const & data_;
00258     Node<i_HyperplaneNode> const & node_;
00259 
00260   public:
00261 
00262     SortSamplesByHyperplane(DataMatrix              const & data, 
00263                             Node<i_HyperplaneNode>  const & node)
00264     :       
00265             data_(data), 
00266             node_(node)
00267     {}
00268 
00269     /** calculate the distance of a sample point to a hyperplane
00270      */
00271     double operator[](MultiArrayIndex l) const
00272     {
00273         double result_l = -1 * node_.intercept();
00274         for(int ii = 0; ii < node_.columns_size(); ++ii)
00275         {
00276             result_l +=     rowVector(data_, l)[node_.columns_begin()[ii]] 
00277                         *   node_.weights()[ii];
00278         }
00279         return result_l;
00280     }
00281 
00282     bool operator()(MultiArrayIndex l, MultiArrayIndex r) const
00283     {
00284         return (*this)[l]  < (*this)[r];
00285     }
00286 
00287 };
00288 
00289 /** makes a class histogram given indices into a labels array
00290  *  usage: 
00291  *      MultiArrayView<2, T2, C2> labels = makeSomeLabels();
00292  *      ArrayVector<int> hist(numberOfLabels(labels), 0);
00293  *      RandomForestClassCounter<MultiArrayView<2, T2, C2>, ArrayVector<int> > counter(labels, hist);
00294  *
00295  *      Container<int> indices = getSomeIndices();
00296  *      std::for_each(indices.begin(), indices.end(), counter);
00297  */
00298 template <class DataSource, class CountArray>
00299 class RandomForestClassCounter
00300 {
00301     DataSource  const &     labels_;
00302     CountArray        &     counts_;
00303 
00304   public:
00305 
00306     RandomForestClassCounter(DataSource  const & labels, 
00307                              CountArray & counts)
00308     : labels_(labels),
00309       counts_(counts)
00310     {
00311         reset();
00312     }
00313 
00314     void reset()
00315     {
00316         counts_.init(0);
00317     }
00318 
00319     void operator()(MultiArrayIndex l) const
00320     {
00321         counts_[labels_[l]] +=1;
00322     }
00323 };
00324 
00325 
00326 /** Functor to calculate the best possible split based on the Gini index,
00327     given labels and features along a given axis
00328 */
00329 
00330 namespace detail
00331 {
00332     template<int N>
00333     class ConstArr
00334     {
00335     public:
00336         double operator[](size_t) const
00337         {
00338             return (double)N;
00339         }
00340     };
00341 
00342 
00343 }
00344 
00345 
00346 
00347 
00348 /** Functor to calculate the entropy based impurity
00349  */
00350 class EntropyCriterion
00351 {
00352 public:
00353     /** calculate the weighted entropy impurity based on the class histogram
00354      * and class weights
00355      */
00356     template<class Array, class Array2>
00357     double operator()        (Array     const & hist, 
00358                               Array2    const & weights, 
00359                               double            total = 1.0) const
00360     {
00361         return impurity(hist, weights, total);
00362     }
00363     
00364     /** calculate the entropy-based impurity from the class histogram
00365      */
00366     template<class Array>
00367     double operator()(Array const & hist, double total = 1.0) const
00368     {
00369         return impurity(hist, total);
00370     }
00371     
00372     /** static version of operator()(hist, total)
00373      */
00374     template<class Array>
00375     static double impurity(Array const & hist, double total)
00376     {
00377         return impurity(hist, detail::ConstArr<1>(), total);
00378     }
00379 
00380     /** static version of operator()(hist, weights, total)
00381      */
00382     template<class Array, class Array2>
00383     static double impurity   (Array     const & hist, 
00384                               Array2    const & weights, 
00385                               double            total)
00386     {
00387 
00388         int     class_count     = hist.size();
00389         double  entropy            = 0.0;
00390         if(class_count == 2)
00391         {
00392             double p0           = (hist[0]/total);
00393             double p1           = (hist[1]/total);
00394             entropy             = 0 - weights[0]*p0*std::log(p0) - weights[1]*p1*std::log(p1);
00395         }
00396         else
00397         {
00398             for(int ii = 0; ii < class_count; ++ii)
00399             {
00400                 double w        = weights[ii];
00401                 double pii      = hist[ii]/total;
00402                 entropy         -= w*( pii*std::log(pii));
00403             }
00404         }
00405         entropy             = total * entropy;
00406         return entropy; 
00407     }
00408 };
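/* Note (added for clarity): for a class histogram h with weights w and n = total, the
   impurity computed above is  n * ( - sum_i w[i] * (h[i]/n) * log(h[i]/n) ), i.e. the
   weighted entropy (natural logarithm) scaled by the sample count; e.g. h = {1, 1}
   with unit weights gives 2 * ln(2). */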
00409 
00410 /** Functor to calculate the gini impurity
00411  */
00412 class GiniCriterion
00413 {
00414 public:
00415     /** calculate the weighted gini impurity based on the class histogram
00416      * and class weights
00417      */
00418     template<class Array, class Array2>
00419     double operator()        (Array     const & hist, 
00420                               Array2    const & weights, 
00421                               double            total = 1.0) const
00422     {
00423         return impurity(hist, weights, total);
00424     }
00425     
00426     /** calculate the gini-based impurity from the class histogram
00427      */
00428     template<class Array>
00429     double operator()(Array const & hist, double total = 1.0) const
00430     {
00431         return impurity(hist, total);
00432     }
00433     
00434     /** static version of operator()(hist, total)
00435      */
00436     template<class Array>
00437     static double impurity(Array const & hist, double total)
00438     {
00439         return impurity(hist, detail::ConstArr<1>(), total);
00440     }
00441 
00442     /** static version of operator()(hist, weights, total)
00443      */
00444     template<class Array, class Array2>
00445     static double impurity   (Array     const & hist, 
00446                               Array2    const & weights, 
00447                               double            total)
00448     {
00449 
00450         int     class_count     = hist.size();
00451         double  gini            = 0.0;
00452         if(class_count == 2)
00453         {
00454             double w            = weights[0] * weights[1];
00455             gini                = w * (hist[0] * hist[1] / total);
00456         }
00457         else
00458         {
00459             for(int ii = 0; ii < class_count; ++ii)
00460             {
00461                 double w        = weights[ii];
00462                 gini           += w*( hist[ii]*( 1.0 - w * hist[ii]/total ) );
00463             }
00464         }
00465         return gini; 
00466     }
00467 };
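/* Illustration (not part of the library): the static interface can be called directly.
   For an unweighted two-class histogram h = {10, 30} with total = 40, the implementation
   above returns h[0]*h[1]/total = 7.5, which equals total * p0 * p1.

    ArrayVector<double> hist(2);
    hist[0] = 10.0;
    hist[1] = 30.0;
    double g = GiniCriterion::impurity(hist, 40.0);   // 7.5
*/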
00468 
00469 
00470 template <class DataSource, class Impurity= GiniCriterion>
00471 class ImpurityLoss
00472 {
00473 
00474     DataSource  const &         labels_;
00475     ArrayVector<double>        counts_;
00476     ArrayVector<double> const  class_weights_;
00477     double                      total_counts_;
00478     Impurity                    impurity_;
00479 
00480   public:
00481 
00482     template<class T>
00483     ImpurityLoss(DataSource  const & labels, 
00484                                 ProblemSpec<T> const & ext_)
00485     : labels_(labels),
00486       counts_(ext_.class_count_, 0.0),
00487       class_weights_(ext_.class_weights_),
00488       total_counts_(0.0)
00489     {}
00490 
00491     void reset()
00492     {
00493         counts_.init(0);
00494         total_counts_ = 0.0;
00495     }
00496 
00497     template<class Counts>
00498     double increment_histogram(Counts const & counts)
00499     {
00500         std::transform(counts.begin(), counts.end(),
00501                        counts_.begin(), counts_.begin(),
00502                        std::plus<double>());
00503         total_counts_ = std::accumulate( counts_.begin(), 
00504                                          counts_.end(),
00505                                          0.0);
00506         return impurity_(counts_, class_weights_, total_counts_);
00507     }
00508 
00509     template<class Counts>
00510     double decrement_histogram(Counts const & counts)
00511     {
00512         std::transform(counts.begin(), counts.end(),
00513                        counts_.begin(), counts_.begin(),
00514                        std::minus<double>());
00515         total_counts_ = std::accumulate( counts_.begin(), 
00516                                          counts_.end(),
00517                                          0.0);
00518         return impurity_(counts_, class_weights_, total_counts_);
00519     }
00520 
00521     template<class Iter>
00522     double increment(Iter begin, Iter end)
00523     {
00524         for(Iter iter = begin; iter != end; ++iter)
00525         {
00526             counts_[labels_(*iter, 0)] +=1.0;
00527             total_counts_ +=1.0;
00528         }
00529         return impurity_(counts_, class_weights_, total_counts_);
00530     }
00531 
00532     template<class Iter>
00533     double decrement(Iter const &  begin, Iter const & end)
00534     {
00535         for(Iter iter = begin; iter != end; ++iter)
00536         {
00537             counts_[labels_(*iter,0)] -=1.0;
00538             total_counts_ -=1.0;
00539         }
00540         return impurity_(counts_, class_weights_, total_counts_);
00541     }
00542 
00543     template<class Iter, class Resp_t>
00544     double init (Iter begin, Iter end, Resp_t resp)
00545     {
00546         reset();
00547         std::copy(resp.begin(), resp.end(), counts_.begin());
00548         total_counts_ = std::accumulate(counts_.begin(), counts_.end(), 0.0); 
00549         return impurity_(counts_,class_weights_, total_counts_);
00550     }
00551     
00552     ArrayVector<double> const & response()
00553     {
00554         return counts_;
00555     }
00556 };
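/* Illustration (not part of the library): ImpurityLoss acts as an incremental loss
   accumulator during the line search in BestGiniOfColumn. Sketch of the pattern,
   assuming a label matrix 'labels', a ProblemSpec<> 'spec' and a sorted index range
   [begin, end) whose class histogram is 'counts':

    ImpurityLoss<MultiArrayView<2, Int32>, GiniCriterion> left(labels, spec);
    ImpurityLoss<MultiArrayView<2, Int32>, GiniCriterion> right(labels, spec);
    right.init(begin, end, counts);                  // all samples start on the right
    double loss = right.decrement(begin, begin + 1)  // move the first sample ...
                +  left.increment(begin, begin + 1); // ... from the right to the left
*/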
00557 
00558 template <class DataSource>
00559 class RegressionForestCounter
00560 {
00561     typedef MultiArrayShape<2>::type Shp;
00562     DataSource const &      labels_;
00563     ArrayVector <double>    mean_;
00564     ArrayVector <double>    variance_;
00565     ArrayVector <double>    tmp_;
00566     size_t                  count_;
00567 
  public:
00568     template<class T>
00569     RegressionForestCounter(DataSource const & labels, 
00570                             ProblemSpec<T> const & ext_)
00571     :
00572         labels_(labels),
00573         mean_(ext_.response_size, 0.0),
00574         variance_(ext_.response_size, 0.0),
00575         tmp_(ext_.response_size),
00576         count_(0)
00577     {}
00578     
00579     // West's algorithm for incremental variance
00580     // calculation
00581     template<class Iter>
00582     double increment (Iter begin, Iter end)
00583     {
00584         for(Iter iter = begin; iter != end; ++iter)
00585         {
00586             ++count_;
00587             for(int ii = 0; ii < mean_.size(); ++ii)
00588                 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 
00589             double f  = 1.0 / count_,
00590                    f1 = 1.0 - f;
00591             for(int ii = 0; ii < mean_.size(); ++ii)
00592                 mean_[ii] += f*tmp_[ii]; 
00593             for(int ii = 0; ii < mean_.size(); ++ii)
00594                 variance_[ii] += f1*sq(tmp_[ii]);
00595         }
00596         return std::accumulate(variance_.begin(), 
00597                                variance_.end(),
00598                                0.0,
00599                                std::plus<double>())
00600                 /(count_ -1);
00601     }
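        /* Note (added for clarity): the loop above is West's/Welford's recurrence - with
           delta = x - mean and f = 1/count, the mean becomes mean + f*delta and the sum of
           squared deviations grows by (1 - f)*delta*delta; the return value is that sum,
           accumulated over all response dimensions, divided by (count - 1). */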
00602 
00603     template<class Iter>
00604     double decrement (Iter begin, Iter end)
00605     {
00606         for(Iter iter = begin; iter != end; ++iter)
00607         {
00608             --count_;
00609             for(int ii = 0; ii < mean_.size(); ++ii)
00610                 tmp_[ii] = labels_(*iter, ii) - mean_[ii]; 
00611             double f  = 1.0 / count_,
00612                    f1 = 1.0 + f;
00613             for(int ii = 0; ii < mean_.size(); ++ii)
00614                 mean_[ii] -= f*tmp_[ii]; 
00615             for(int ii = 0; ii < mean_.size(); ++ii)
00616                 variance_[ii] -= f1*sq(tmp_[ii]);
00617         }
00618         return std::accumulate(variance_.begin(), 
00619                                variance_.end(),
00620                                0.0,
00621                                std::plus<double>())
00622                 /(count_ -1);
00623     }
00624 
00625     template<class Iter, class Resp_t>
00626     double init (Iter begin, Iter end, Resp_t resp)
00627     {
00628         reset();
00629         return increment(begin, end);
00630     }
00631     
00632 
00633     ArrayVector<double> const & response()
00634     {
00635         return mean_;
00636     }
00637 
00638     void reset()
00639     {
00640         mean_.init(0.0);
00641         variance_.init(0.0);
00642         count_ = 0; 
00643     }
00644 };
00645 
00646 template<class Tag, class Datatype>
00647 struct LossTraits;
00648 
00649 struct LSQLoss
00650 {};
00651 
00652 template<class Datatype>
00653 struct LossTraits<GiniCriterion, Datatype>
00654 {
00655     typedef ImpurityLoss<Datatype, GiniCriterion> type;
00656 };
00657 
00658 template<class Datatype>
00659 struct LossTraits<EntropyCriterion, Datatype>
00660 {
00661     typedef ImpurityLoss<Datatype, EntropyCriterion> type;
00662 };
00663 
00664 template<class Datatype>
00665 struct LossTraits<LSQLoss, Datatype>
00666 {
00667     typedef RegressionForestCounter<Datatype> type;
00668 };
00669 
00670 /** Given a column, choose a split that minimizes some loss
00671  */
00672 template<class LineSearchLossTag>
00673 class BestGiniOfColumn
00674 {
00675 public:
00676     ArrayVector<double>     class_weights_;
00677     ArrayVector<double>     bestCurrentCounts[2];
00678     double                  min_gini_;
00679     ptrdiff_t               min_index_;
00680     double                  min_threshold_;
00681     ProblemSpec<>           ext_param_;
00682 
00683     BestGiniOfColumn()
00684     {}
00685 
00686     template<class T> 
00687     BestGiniOfColumn(ProblemSpec<T> const & ext)
00688     :
00689         class_weights_(ext.class_weights_),
00690         ext_param_(ext)
00691     {
00692         bestCurrentCounts[0].resize(ext.class_count_);
00693         bestCurrentCounts[1].resize(ext.class_count_);
00694     }
00695     template<class T> 
00696     void set_external_parameters(ProblemSpec<T> const & ext)
00697     {
00698         class_weights_ = ext.class_weights_; 
00699         ext_param_ = ext;
00700         bestCurrentCounts[0].resize(ext.class_count_);
00701         bestCurrentCounts[1].resize(ext.class_count_);
00702     }
00703     /** calculate the best gini split along a feature column
00704      * \param column   the feature matrix - has to support (row, column) access
00705      * \param g        index of the column to be considered
00706      * \param labels   the label vector 
00707      * \param begin 
00708      * \param end      (in and out)
00709      *                begin and end iterators to the indices of the
00710      *                samples in the current region. 
00711      *                the range begin - end is sorted by the column supplied
00712      *                during function execution.
00713      * \param region_response
00714      *                class histogram of the range. 
00715      *
00716      *  precondition: begin, end is a valid range, 
00717      *                region_response is a positive integer valued array with the 
00718      *                class counts of the current range,
00719      *                labels.size() >= max(begin, end). 
00720      *  postcondition:
00721      *                begin, end sorted by the given column. 
00722      *                min_gini_ contains the minimum gini found or 
00723      *                NumericTraits<double>::max if no split was found.
00724      *                min_index_ contains the splitting index in the range
00725      *                or invalid data if no split was found.
00726      *                bestCurrentCounts[0] and [1] contain the 
00727      *                class histograms of the left and right regions.
00728      */
00729     template<   class DataSourceF_t,
00730                 class DataSource_t, 
00731                 class I_Iter, 
00732                 class Array>
00733     void operator()(DataSourceF_t   const & column,
00734                     int                     g,
00735                     DataSource_t    const & labels,
00736                     I_Iter                & begin, 
00737                     I_Iter                & end,
00738                     Array           const & region_response)
00739     {
00740         std::sort(begin, end, 
00741                   SortSamplesByDimensions<DataSourceF_t>(column, g));
00742         typedef typename 
00743             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
00744         LineSearchLoss left(labels, ext_param_);
00745         LineSearchLoss right(labels, ext_param_);
00746 
00747         
00748 
00749         min_gini_ = right.init(begin, end, region_response);
00750         min_threshold_ = *begin;
00751         min_index_     = 0;
00752         DimensionNotEqual<DataSourceF_t> comp(column, g); 
00753         
00754         I_Iter iter = begin;
00755         I_Iter next = std::adjacent_find(iter, end, comp);
00756         while( next  != end)
00757         {
00758 
00759             double loss = right.decrement(iter, next + 1) 
00760                      +     left.increment(iter , next + 1);
00761 #ifdef CLASSIFIER_TEST
00762             if(loss < min_gini_ && !closeAtTolerance(loss, min_gini_))
00763 #else
00764             if(loss < min_gini_ )
00765 #endif 
00766             {
00767                 bestCurrentCounts[0] = left.response();
00768                 bestCurrentCounts[1] = right.response();
00769 #ifdef CLASSIFIER_TEST
00770                 min_gini_       = loss < min_gini_? loss : min_gini_;
00771 #else
00772                 min_gini_       = loss; 
00773 #endif
00774                 min_index_      = next - begin +1 ;
00775                 min_threshold_  = (double(column(*next,g)) + double(column(*(next +1), g)))/2.0;
00776             }
00777             iter = next +1 ;
00778             next = std::adjacent_find(iter, end, comp);
00779         }
00780     }
00781 
00782     template<class DataSource_t, class Iter, class Array>
00783     double loss_of_region(DataSource_t const & labels,
00784                           Iter & begin, 
00785                           Iter & end, 
00786                           Array const & region_response) const
00787     {
00788         typedef typename 
00789             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
00790         LineSearchLoss region_loss(labels, ext_param_);
00791         return 
00792             region_loss.init(begin, end, region_response);
00793     }
00794 
00795 };
00796 
00797 
00798 /** Chooses mtry columns and applies the ColumnDecisionFunctor to each of
00799  * them. Then chooses the column that is best.
00800  */
00801 template<class ColumnDecisionFunctor, class Tag = ClassificationTag>
00802 class ThresholdSplit: public SplitBase<Tag>
00803 {
00804   public:
00805 
00806 
00807     typedef SplitBase<Tag> SB;
00808     
00809     ArrayVector<Int32>          splitColumns;
00810     ColumnDecisionFunctor       bgfunc;
00811 
00812     double                      region_gini_;
00813     ArrayVector<double>         min_gini_;
00814     ArrayVector<ptrdiff_t>      min_indices_;
00815     ArrayVector<double>         min_thresholds_;
00816 
00817     int                         bestSplitIndex;
00818 
00819     double minGini() const
00820     {
00821         return min_gini_[bestSplitIndex];
00822     }
00823     int bestSplitColumn() const
00824     {
00825         return splitColumns[bestSplitIndex];
00826     }
00827     double bestSplitThreshold() const
00828     {
00829         return min_thresholds_[bestSplitIndex];
00830     }
00831 
00832     template<class T>
00833     void set_external_parameters(ProblemSpec<T> const & in)
00834     {
00835         SB::set_external_parameters(in);        
00836         bgfunc.set_external_parameters( SB::ext_param_);
00837         int featureCount_ = SB::ext_param_.column_count_;
00838         splitColumns.resize(featureCount_);
00839         for(int k=0; k<featureCount_; ++k)
00840             splitColumns[k] = k;
00841         min_gini_.resize(featureCount_);
00842         min_indices_.resize(featureCount_);
00843         min_thresholds_.resize(featureCount_);
00844     }
00845 
00846 
00847     template<class T, class C, class T2, class C2, class Region, class Random>
00848     int findBestSplit(MultiArrayView<2, T, C> features,
00849                       MultiArrayView<2, T2, C2>  labels,
00850                       Region & region,
00851                       ArrayVector<Region>& childRegions,
00852                       Random & randint)
00853     {
00854 
00855         typedef typename Region::IndexIterator IndexIterator;
00856         if(region.size() == 0)
00857         {
00858            std::cerr << "SplitFunctor::findBestSplit(): stackentry with 0 examples encountered\n"
00859                         "continuing learning process...."; 
00860         }
00861         // calculate things that haven't been calculated yet. 
00862         
00863         if(std::accumulate(region.classCounts().begin(),
00864                            region.classCounts().end(), 0) != region.size())
00865         {
00866             RandomForestClassCounter<   MultiArrayView<2,T2, C2>, 
00867                                         ArrayVector<double> >
00868                 counter(labels, region.classCounts());
00869             std::for_each(  region.begin(), region.end(), counter);
00870             region.classCountsIsValid = true;
00871         }
00872 
00873         // Is the region pure already?
00874         region_gini_ = bgfunc.loss_of_region(labels,
00875                                              region.begin(), 
00876                                              region.end(),
00877                                              region.classCounts());
00878         if(region_gini_ <= SB::ext_param_.precision_)
00879             return  makeTerminalNode(features, labels, region, randint);
00880 
00881         // select columns  to be tried.
00882         for(int ii = 0; ii < SB::ext_param_.actual_mtry_; ++ii)
00883             std::swap(splitColumns[ii], 
00884                       splitColumns[ii+ randint(features.shape(1) - ii)]);
00885 
00886         // find the best gini index
00887         bestSplitIndex              = 0;
00888         double  current_min_gini    = region_gini_;
00889         int     num2try             = features.shape(1);
00890         for(int k=0; k<num2try; ++k)
00891         {
00892             //this functor does all the work
00893             bgfunc(features,
00894                    splitColumns[k],
00895                    labels, 
00896                    region.begin(), region.end(), 
00897                    region.classCounts());
00898             min_gini_[k]            = bgfunc.min_gini_; 
00899             min_indices_[k]         = bgfunc.min_index_;
00900             min_thresholds_[k]      = bgfunc.min_threshold_;
00901 #ifdef CLASSIFIER_TEST
00902             if(     bgfunc.min_gini_ < current_min_gini
00903                &&  !closeAtTolerance(bgfunc.min_gini_, current_min_gini))
00904 #else
00905             if(bgfunc.min_gini_ < current_min_gini)
00906 #endif
00907             {
00908                 current_min_gini = bgfunc.min_gini_;
00909                 childRegions[0].classCounts() = bgfunc.bestCurrentCounts[0];
00910                 childRegions[1].classCounts() = bgfunc.bestCurrentCounts[1];
00911                 childRegions[0].classCountsIsValid = true;
00912                 childRegions[1].classCountsIsValid = true;
00913 
00914                 bestSplitIndex   = k;
00915                 num2try = SB::ext_param_.actual_mtry_;
00916             }
00917         }
00918 
00919         // did not find any suitable split
00920         if(closeAtTolerance(current_min_gini, region_gini_))
00921             return  makeTerminalNode(features, labels, region, randint);
00922         
00923         //create a Node for output
00924         Node<i_ThresholdNode>   node(SB::t_data, SB::p_data);
00925         SB::node_ = node;
00926         node.threshold()    = min_thresholds_[bestSplitIndex];
00927         node.column()       = splitColumns[bestSplitIndex];
00928         
00929         // partition the range according to the best dimension 
00930         SortSamplesByDimensions<MultiArrayView<2, T, C> > 
00931             sorter(features, node.column(), node.threshold());
00932         IndexIterator bestSplit =
00933             std::partition(region.begin(), region.end(), sorter);
00934         // Save the ranges of the child stack entries.
00935         childRegions[0].setRange(   region.begin()  , bestSplit       );
00936         childRegions[0].rule = region.rule;
00937         childRegions[0].rule.push_back(std::make_pair(1, 1.0));
00938         childRegions[1].setRange(   bestSplit       , region.end()    );
00939         childRegions[1].rule = region.rule;
00940         childRegions[1].rule.push_back(std::make_pair(1, 1.0));
00941 
00942         return i_ThresholdNode;
00943     }
00944 };
00945 
00946 typedef  ThresholdSplit<BestGiniOfColumn<GiniCriterion> >                    GiniSplit;
00947 typedef  ThresholdSplit<BestGiniOfColumn<EntropyCriterion> >                 EntropySplit;
00948 typedef  ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag>            RegressionSplit;
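/* Illustration (not part of the library code): a split functor instance is passed to
   RandomForest::learn(). The sketch below assumes the learn() overload taking a visitor
   and a split functor (see random_forest.hxx); 'features' and 'labels' are placeholder
   MultiArrayViews holding the training data.

    vigra::RandomForest<int> rf(vigra::RandomForestOptions().tree_count(100));
    rf.learn(features, labels, vigra::rf_default(), vigra::EntropySplit());
*/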
00949 
00950 namespace rf
00951 {
00952 
00953 /** This namespace contains additional split functors.
00954  *
00955  * The split functor classes are designed in a modular fashion because new split functors may 
00956  * share a lot of code with existing ones. 
00957  * 
00958  * ThresholdSplit implements the functionality needed for any split functor that makes its 
00959  * decision via one-dimensional, axis-parallel cuts. The template parameter defines how the split
00960  * along one dimension is chosen. 
00961  *
00962  * The BestGiniOfColumn class chooses a split that minimizes one of the supplied loss functions
00963  * (GiniCriterion for classification and LSQLoss for regression). Median chooses the split in a 
00964  * kd-tree fashion. 
00965  *
00966  *
00967  * Currently defined typedefs: 
00968  * \code
00969  * typedef  ThresholdSplit<BestGiniOfColumn<GiniCriterion> >                 GiniSplit;
00970  * typedef  ThresholdSplit<BestGiniOfColumn<LSQLoss>, RegressionTag>         RegressionSplit;
00971  * typedef  ThresholdSplit<Median> MedianSplit;
00972  * \endcode
00973  */
00974 namespace split
00975 {
00976 
00977 /** This functor splits at the median value of a column
00978  */
00979 class Median
00980 {
00981 public:
00982 
00983     typedef GiniCriterion   LineSearchLossTag;
00984     ArrayVector<double>     class_weights_;
00985     ArrayVector<double>     bestCurrentCounts[2];
00986     double                  min_gini_;
00987     ptrdiff_t               min_index_;
00988     double                  min_threshold_;
00989     ProblemSpec<>           ext_param_;
00990 
00991     Median()
00992     {}
00993 
00994     template<class T> 
00995     Median(ProblemSpec<T> const & ext)
00996     :
00997         class_weights_(ext.class_weights_),
00998         ext_param_(ext)
00999     {
01000         bestCurrentCounts[0].resize(ext.class_count_);
01001         bestCurrentCounts[1].resize(ext.class_count_);
01002     }
01003   
01004     template<class T> 
01005     void set_external_parameters(ProblemSpec<T> const & ext)
01006     {
01007         class_weights_ = ext.class_weights_; 
01008         ext_param_ = ext;
01009         bestCurrentCounts[0].resize(ext.class_count_);
01010         bestCurrentCounts[1].resize(ext.class_count_);
01011     }
01012      
01013     template<   class DataSourceF_t,
01014                 class DataSource_t, 
01015                 class I_Iter, 
01016                 class Array>
01017     void operator()(DataSourceF_t   const & column,
01018                     DataSource_t    const & labels,
01019                     I_Iter                & begin, 
01020                     I_Iter                & end,
01021                     Array           const & region_response)
01022     {
01023         std::sort(begin, end, 
01024                   SortSamplesByDimensions<DataSourceF_t>(column, 0));
01025         typedef typename 
01026             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
01027         LineSearchLoss left(labels, ext_param_);
01028         LineSearchLoss right(labels, ext_param_);
01029         right.init(begin, end, region_response);
01030 
01031         min_gini_ = NumericTraits<double>::max();
01032         min_index_ = floor(double(end - begin)/2.0); 
01033         min_threshold_ =  column[*(begin + min_index_)];
01034         SortSamplesByDimensions<DataSourceF_t> 
01035             sorter(column, 0, min_threshold_);
01036         I_Iter part = std::partition(begin, end, sorter);
01037         DimensionNotEqual<DataSourceF_t> comp(column, 0); 
01038         if(part == begin)
01039         {
01040             part= std::adjacent_find(part, end, comp)+1;
01041             
01042         }
01043         if(part >= end)
01044         {
01045             return; 
01046         }
01047         else
01048         {
01049             min_threshold_ = column[*part];
01050         }
01051         min_gini_ = right.decrement(begin, part) 
01052               +     left.increment(begin , part);
01053 
01054         bestCurrentCounts[0] = left.response();
01055         bestCurrentCounts[1] = right.response();
01056         
01057         min_index_      = part - begin;
01058     }
01059 
01060     template<class DataSource_t, class Iter, class Array>
01061     double loss_of_region(DataSource_t const & labels,
01062                           Iter & begin, 
01063                           Iter & end, 
01064                           Array const & region_response) const
01065     {
01066         typedef typename 
01067             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
01068         LineSearchLoss region_loss(labels, ext_param_);
01069         return 
01070             region_loss.init(begin, end, region_response);
01071     }
01072 
01073 };
01074 
01075 typedef  ThresholdSplit<Median> MedianSplit;
01076 
01077 
01078 /** This functor splits at a randomly chosen value of a column
01079  */
01080 class RandomSplitOfColumn
01081 {
01082 public:
01083 
01084     typedef GiniCriterion   LineSearchLossTag;
01085     ArrayVector<double>     class_weights_;
01086     ArrayVector<double>     bestCurrentCounts[2];
01087     double                  min_gini_;
01088     ptrdiff_t               min_index_;
01089     double                  min_threshold_;
01090     ProblemSpec<>           ext_param_;
01091     typedef RandomMT19937   Random_t;
01092     Random_t                random;
01093 
01094     RandomSplitOfColumn()
01095     {}
01096 
01097     template<class T> 
01098     RandomSplitOfColumn(ProblemSpec<T> const & ext)
01099     :
01100         class_weights_(ext.class_weights_),
01101         ext_param_(ext),
01102         random(RandomSeed)
01103     {
01104         bestCurrentCounts[0].resize(ext.class_count_);
01105         bestCurrentCounts[1].resize(ext.class_count_);
01106     }
01107     
01108     template<class T> 
01109     RandomSplitOfColumn(ProblemSpec<T> const & ext, Random_t & random_)
01110     :
01111         class_weights_(ext.class_weights_),
01112         ext_param_(ext),
01113         random(random_)
01114     {
01115         bestCurrentCounts[0].resize(ext.class_count_);
01116         bestCurrentCounts[1].resize(ext.class_count_);
01117     }
01118   
01119     template<class T> 
01120     void set_external_parameters(ProblemSpec<T> const & ext)
01121     {
01122         class_weights_ = ext.class_weights_; 
01123         ext_param_ = ext;
01124         bestCurrentCounts[0].resize(ext.class_count_);
01125         bestCurrentCounts[1].resize(ext.class_count_);
01126     }
01127      
01128     template<   class DataSourceF_t,
01129                 class DataSource_t, 
01130                 class I_Iter, 
01131                 class Array>
01132     void operator()(DataSourceF_t   const & column,
01133                     DataSource_t    const & labels,
01134                     I_Iter                & begin, 
01135                     I_Iter                & end,
01136                     Array           const & region_response)
01137     {
01138         std::sort(begin, end, 
01139                   SortSamplesByDimensions<DataSourceF_t>(column, 0));
01140         typedef typename 
01141             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
01142         LineSearchLoss left(labels, ext_param_);
01143         LineSearchLoss right(labels, ext_param_);
01144         right.init(begin, end, region_response);
01145 
01146         
01147         min_gini_ = NumericTraits<double>::max();
01148         
01149         min_index_ = random.uniformInt(end - begin);
01150         min_threshold_ =  column[*(begin + min_index_)];
01151         SortSamplesByDimensions<DataSourceF_t> 
01152             sorter(column, 0, min_threshold_);
01153         I_Iter part = std::partition(begin, end, sorter);
01154         DimensionNotEqual<DataSourceF_t> comp(column, 0); 
01155         if(part == begin)
01156         {
01157             part= std::adjacent_find(part, end, comp)+1;
01158             
01159         }
01160         if(part >= end)
01161         {
01162             return; 
01163         }
01164         else
01165         {
01166             min_threshold_ = column[*part];
01167         }
01168         min_gini_ = right.decrement(begin, part) 
01169               +     left.increment(begin , part);
01170 
01171         bestCurrentCounts[0] = left.response();
01172         bestCurrentCounts[1] = right.response();
01173         
01174         min_index_      = part - begin;
01175     }
01176 
01177     template<class DataSource_t, class Iter, class Array>
01178     double loss_of_region(DataSource_t const & labels,
01179                           Iter & begin, 
01180                           Iter & end, 
01181                           Array const & region_response) const
01182     {
01183         typedef typename 
01184             LossTraits<LineSearchLossTag, DataSource_t>::type LineSearchLoss;
01185         LineSearchLoss region_loss(labels, ext_param_);
01186         return 
01187             region_loss.init(begin, end, region_response);
01188     }
01189 
01190 };
01191 
01192 typedef  ThresholdSplit<RandomSplitOfColumn> RandomSplit;
01193 }
01194 }
01195 
01196 
01197 } //namespace vigra
01198 #endif // VIGRA_RANDOM_FOREST_SPLIT_HXX
