[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest.hxx VIGRA

00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 
00036 
00037 #ifndef VIGRA_RANDOM_FOREST_HXX
00038 #define VIGRA_RANDOM_FOREST_HXX
00039 
00040 #include <iostream>
00041 #include <algorithm>
00042 #include <map>
00043 #include <set>
00044 #include <list>
00045 #include <numeric>
00046 #include "mathutil.hxx"
00047 #include "array_vector.hxx"
00048 #include "sized_int.hxx"
00049 #include "matrix.hxx"
00050 #include "random.hxx"
00051 #include "functorexpression.hxx"
00052 #include "random_forest/rf_common.hxx"
00053 #include "random_forest/rf_nodeproxy.hxx"
00054 #include "random_forest/rf_split.hxx"
00055 #include "random_forest/rf_decisionTree.hxx"
00056 #include "random_forest/rf_visitors.hxx"
00057 #include "random_forest/rf_region.hxx"
00058 #include "sampling.hxx"
00059 #include "random_forest/rf_preprocessing.hxx"
00060 #include "random_forest/rf_online_prediction_set.hxx"
00061 #include "random_forest/rf_earlystopping.hxx"
00062 #include "random_forest/rf_ridge_split.hxx"
00063 namespace vigra
00064 {
00065 
00066 /** \addtogroup MachineLearning Machine Learning
00067 
00068     This module provides classification algorithms that map 
00069     features to labels or label probabilities.
00070  *    Look at the RandomForest class first for an overview of most of the 
00071     functionality provided as well as use cases. 
00072 **/
00073 //@{
00074 
00075 namespace detail
00076 {
00077 
00078 
00079 
00080 /* \brief sampling option factory function
00081  */
00082 inline SamplerOptions make_sampler_opt ( RandomForestOptions     & RF_opt)
00083 {
00084     SamplerOptions return_opt;
00085     return_opt.withReplacement(RF_opt.sample_with_replacement_);
00086     return_opt.stratified(RF_opt.stratification_method_ == RF_EQUAL);
00087     return return_opt;
00088 }
00089 }//namespace detail
00090 
00091 /** Random Forest class
00092  *
00093  * \tparam <PrprocessorTag = ClassificationTag> Class used to preprocess
00094  *          the input while learning and predicting. Currently Available:
00095  *          ClassificationTag and RegressionTag. It is recommended to use
00096  *          Splitfunctor::Preprocessor_t while using custom splitfunctors
00097  *          as they may need the data to be in a different format. 
00098  *          \sa Preprocessor
00099  *  
00100  *  simple usage for classification (regression is not yet supported):
00101  *  look at RandomForest::learn() as well as RandomForestOptions() for additional
00102  *  options. 
00103  *
00104  *  \code
00105  *  using namespace vigra;
00106  *  using namespace rf;
00107  *  typedef xxx feature_t; \\ replace xxx with whichever type
00108  *  typedef yyy label_t;   \\ likewise 
00109  *  
00110  *  // allocate the training data
00111  *  MultiArrayView<2, feature_t> f = get_training_features();
00112  *  MultiArrayView<2, label_t>   l = get_training_labels();
00113  *  
00114  *  RandomForest<> rf;
00115  *
00116  *  // construct visitor to calculate out-of-bag error
00117  *  visitors::OOB_Error oob_v;
00118  *
00119  *  // perform training
00120  *  rf.learn(f, l, visitors::create_visitor(oob_v));
00121  *
00122  *  std::cout << "the out-of-bag error is: " << oob_v.oob_breiman << "\n";
00123  *      
00124  *  // get features for new data to be used for prediction
00125  *  MultiArrayView<2, feature_t> pf = get_features();
00126  *
00127  *  // allocate space for the response (pf.shape(0) is the number of samples)
00128  *  MultiArrayView<2, label_t> prediction(pf.shape(0), 1);
00129  *  MultiArrayView<2, double> prob(pf.shape(0), rf.class_count());
00130  *      
00131  *  // perform prediction on new data
00132  *  rf.predict_labels(pf, prediction);
00133  *  rf.predict_probabilities(pf, prob);
00134  *
00135  *  \endcode
00136  *
00137  *  Additional information such as Variable Importance measures are accessed
00138  *  via Visitors defined in rf::visitors. 
00139  *  Have a look at rf::split for other splitting methods.
00140  *
00141 */
00142 template <class LabelType = double , class PreprocessorTag = ClassificationTag >
00143 class RandomForest
00144 {
00145 
00146   public:
00147     //public typedefs
00148     typedef RandomForestOptions             Options_t;
00149     typedef detail::DecisionTree            DecisionTree_t;
00150     typedef ProblemSpec<LabelType>          ProblemSpec_t;
00151     typedef GiniSplit                       Default_Split_t;
00152     typedef EarlyStoppStd                   Default_Stop_t;
00153     typedef rf::visitors::StopVisiting      Default_Visitor_t;
00154     typedef  DT_StackEntry<ArrayVectorView<Int32>::iterator>
00155                     StackEntry_t;
00156     typedef LabelType                       LabelT; 
00157   protected:
00158 
00159     /** optimisation for predictLabels
00160      * */
00161     mutable MultiArray<2, double> garbage_prediction_;
00162 
00163   public:
00164 
00165     //problem independent data.
00166     Options_t                                   options_;
00167     //problem dependent data members - is only set if
00168     //a copy constructor, some sort of import
00169     //function or the learn function is called
00170     ArrayVector<DecisionTree_t>                 trees_;
00171     ProblemSpec_t                               ext_param_;
00172     /*mutable ArrayVector<int>                    tree_indices_;*/
00173     rf::visitors::OnlineLearnVisitor            online_visitor_;
00174 
00175 
00176     void reset()
00177     {
00178         ext_param_.clear();
00179         trees_.clear();
00180     }
00181 
00182   public:
00183 
00184     /** \name Constructors
00185      * Note: No copy Constructor specified as no pointers are manipulated
00186      * in this class
00187      */
00188     /*\{*/
00189     /**\brief default constructor
00190      *
00191      * \param options   general options to the Random Forest. Must be of Type
00192      *                  Options_t
00193      * \param ext_param problem specific values that can be supplied 
00194      *                  additionally. (class weights , labels etc)
00195      * \sa  RandomForestOptions, ProblemSpec
00196      *
00197      */
00198     RandomForest(Options_t const & options = Options_t(), 
00199                  ProblemSpec_t const & ext_param = ProblemSpec_t())
00200     :
00201         options_(options),
00202         ext_param_(ext_param)/*,
00203         tree_indices_(options.tree_count_,0)*/
00204     {
00205         /*for(int ii = 0 ; ii < int(tree_indices_.size()); ++ii)
00206             tree_indices_[ii] = ii;*/
00207     }
00208 
00209     /**\brief Create RF from external source
00210      * \param treeCount Number of trees to add.
00211      * \param topology_begin     
00212      *                  Iterator to a Container where the topology_ data
00213      *                  of the trees are stored.
00214      *                  Iterator should support at least treeCount forward 
00215      *                  iterations. (i.e. topology_end - topology_begin >= treeCount
00216      * \param parameter_begin  
00217      *                  iterator to a Container where the parameters_ data
00218      *                  of the trees are stored. Iterator should support at 
00219      *                  least treeCount forward iterations.
00220      * \param problem_spec 
00221      *                  Extrinsic parameters that specify the problem e.g.
00222      *                  ClassCount, featureCount etc.
00223      * \param options   (optional) specify options used to train the original
00224      *                  Random forest. This parameter is not used anywhere
00225      *                  during prediction and thus is optional.
00226      *
00227      */
00228      /* TODO: This constructor may be replaced by a Constructor using
00229      * NodeProxy iterators to encapsulate the underlying data type.
00230      */
00231     template<class TopologyIterator, class ParameterIterator>
00232     RandomForest(int                       treeCount,
00233                   TopologyIterator         topology_begin,
00234                   ParameterIterator        parameter_begin,
00235                   ProblemSpec_t const & problem_spec,
00236                   Options_t const &     options = Options_t())
00237     :
00238         trees_(treeCount, DecisionTree_t(problem_spec)),
00239         ext_param_(problem_spec),
00240         options_(options)
00241     {
00242         for(unsigned int k=0; k<treeCount; ++k, ++topology_begin, ++parameter_begin)
00243         {
00244             trees_[k].topology_ = *topology_begin;
00245             trees_[k].parameters_ = *parameter_begin;
00246         }
00247     }
00248 
00249     /*\}*/
00250 
00251 
00252     /** \name Data Access
00253      * data access interface - usage of member variables is deprecated
00254      */
00255 
00256     /*\{*/
00257 
00258 
00259     /**\brief return external parameters for viewing
00260      * \return ProblemSpec_t
00261      */
00262     ProblemSpec_t const & ext_param() const
00263     {
00264         vigra_precondition(ext_param_.used() == true,
00265            "RandomForest::ext_param(): "
00266            "Random forest has not been trained yet.");
00267         return ext_param_;
00268     }
00269 
00270     /**\brief set external parameters
00271      *
00272      *  \param in external parameters to be set
00273      *
00274      * set external parameters explicitly. 
00275      * If Random Forest has not been trained the preprocessor will 
00276      * either ignore filling values set this way or will throw an exception 
00277      * if values specified manually do not match the value calculated 
00278      & during the preparation step.
00279      */
00280     void set_ext_param(ProblemSpec_t const & in)
00281     {
00282         vigra_precondition(ext_param_.used() == false,
00283             "RandomForest::set_ext_param():"
00284             "Random forest has been trained! Call reset()"
00285             "before specifying new extrinsic parameters.");
00286     }
00287 
00288     /**\brief access random forest options
00289      *
00290      * \return random forest options
00291      */
00292     Options_t & set_options()
00293     {
00294         return options;
00295     }
00296 
00297 
00298     /**\brief access const random forest options
00299      *
00300      * \return const Option_t
00301      */
00302     Options_t const & options() const
00303     {
00304         return options_;
00305     }
00306 
00307     /**\brief access const trees
00308      */
00309     DecisionTree_t const & tree(int index) const
00310     {
00311         return trees_[index];
00312     }
00313 
00314     /**\brief access trees
00315      */
00316     DecisionTree_t & tree(int index)
00317     {
00318         return trees_[index];
00319     }
00320 
00321     /*\}*/
00322 
00323     /**\brief return number of features used while 
00324      * training.
00325      */
00326     int feature_count() const
00327     {
00328       return ext_param_.column_count_;
00329     }
00330     
00331     
00332     /**\brief return number of features used while 
00333      * training.
00334      *
00335      * deprecated. Use feature_count() instead.
00336      */
00337     int column_count() const
00338     {
00339       return ext_param_.column_count_;
00340     }
00341 
00342     /**\brief return number of classes used while 
00343      * training.
00344      */
00345     int class_count() const
00346     {
00347       return ext_param_.class_count_;
00348     }
00349 
00350     /**\brief return number of trees
00351      */
00352     int tree_count() const
00353     {
00354       return options_.tree_count_;
00355     }
00356 
00357 
00358     
00359     template<class U,class C1,
00360         class U2, class C2,
00361         class Split_t,
00362         class Stop_t,
00363         class Visitor_t,
00364         class Random_t>
00365     void onlineLearn(   MultiArrayView<2,U,C1> const & features,
00366                         MultiArrayView<2,U2,C2> const & response,
00367                         int new_start_index,
00368                         Visitor_t visitor_,
00369                         Split_t split_,
00370                         Stop_t stop_,
00371                         Random_t & random,
00372                         bool adjust_thresholds=false);
00373 
00374     template <class U, class C1, class U2,class C2>
00375     void onlineLearn(   MultiArrayView<2, U, C1> const  & features,
00376                         MultiArrayView<2, U2,C2> const  & labels,int new_start_index,bool adjust_thresholds=false)
00377     {
00378         RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
00379         onlineLearn(features, 
00380                     labels, 
00381                     new_start_index,
00382                     rf_default(), 
00383                     rf_default(), 
00384                     rf_default(),
00385                     rnd,
00386                     adjust_thresholds);
00387     }
00388 
00389     template<class U,class C1,
00390         class U2, class C2,
00391         class Split_t,
00392         class Stop_t,
00393         class Visitor_t,
00394         class Random_t>
00395     void reLearnTree(MultiArrayView<2,U,C1> const & features,
00396                      MultiArrayView<2,U2,C2> const & response,
00397                      int treeId,
00398                      Visitor_t visitor_,
00399                      Split_t split_,
00400                      Stop_t stop_,
00401                      Random_t & random);
00402 
00403     template<class U, class C1, class U2, class C2>
00404     void reLearnTree(MultiArrayView<2, U, C1> const & features,
00405                      MultiArrayView<2, U2, C2> const & labels,
00406                      int treeId)
00407     {
00408         RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
00409         reLearnTree(features,
00410                     labels,
00411                     treeId,
00412                     rf_default(),
00413                     rf_default(),
00414                     rf_default(),
00415                     rnd);
00416     }
00417 
00418 
00419     /**\name Learning
00420      * Following functions differ in the degree of customization
00421      * allowed
00422      */
00423     /*\{*/
00424     /**\brief learn on data with custom config and random number generator
00425      *
00426      * \param features  a N x M matrix containing N samples with M
00427      *                  features
00428      * \param response  a N x D matrix containing the corresponding
00429      *                  response. Current split functors assume D to
00430      *                  be 1 and ignore any additional columns.
00431      *                  This is not enforced to allow future support
00432      *                  for uncertain labels, label independent strata etc.
00433      *                  The Preprocessor specified during construction
00434      *                  should be able to handle features and labels
00435      *                  features and the labels.
00436      *                  see also: SplitFunctor, Preprocessing
00437      *
00438      * \param visitor   visitor which is to be applied after each split,
00439      *                  tree and at the end. Use rf_default for using
00440      *                  default value. (No Visitors)
00441      *                  see also: rf::visitors
00442      * \param split     split functor to be used to calculate each split
00443      *                  use rf_default() for using default value. (GiniSplit)
00444      *                  see also:  rf::split 
00445      * \param stop
00446      *                  predicate to be used to calculate each split
00447      *                  use rf_default() for using default value. (EarlyStoppStd)
00448      * \param random    RandomNumberGenerator to be used. Use
00449      *                  rf_default() to use default value.(RandomMT19337)
00450      *
00451      *
00452      */
00453     template <class U, class C1,
00454              class U2,class C2,
00455              class Split_t,
00456              class Stop_t,
00457              class Visitor_t,
00458              class Random_t>
00459     void learn( MultiArrayView<2, U, C1> const  &   features,
00460                 MultiArrayView<2, U2,C2> const  &   response,
00461                 Visitor_t                           visitor,
00462                 Split_t                             split,
00463                 Stop_t                              stop,
00464                 Random_t                 const  &   random);
00465 
00466     template <class U, class C1,
00467              class U2,class C2,
00468              class Split_t,
00469              class Stop_t,
00470              class Visitor_t>
00471     void learn( MultiArrayView<2, U, C1> const  &   features,
00472                 MultiArrayView<2, U2,C2> const  &   response,
00473                 Visitor_t                           visitor,
00474                 Split_t                             split,
00475                 Stop_t                              stop)
00476 
00477     {
00478         RandomNumberGenerator<> rnd = RandomNumberGenerator<>(RandomSeed);
00479         learn(  features, 
00480                 response,
00481                 visitor, 
00482                 split, 
00483                 stop,
00484                 rnd);
00485     }
00486 
00487     template <class U, class C1, class U2,class C2, class Visitor_t>
00488     void learn( MultiArrayView<2, U, C1> const  & features,
00489                 MultiArrayView<2, U2,C2> const  & labels,
00490                 Visitor_t                         visitor)
00491     {
00492         learn(  features, 
00493                 labels, 
00494                 visitor, 
00495                 rf_default(), 
00496                 rf_default());
00497     }
00498 
00499     template <class U, class C1, class U2,class C2, 
00500               class Visitor_t, class Split_t>
00501     void learn(   MultiArrayView<2, U, C1> const  & features,
00502                   MultiArrayView<2, U2,C2> const  & labels,
00503                   Visitor_t                         visitor,
00504                   Split_t                           split)
00505     {
00506         learn(  features, 
00507                 labels, 
00508                 visitor, 
00509                 split, 
00510                 rf_default());
00511     }
00512 
00513     /**\brief learn on data with default configuration
00514      *
00515      * \param features  a N x M matrix containing N samples with M
00516      *                  features
00517      * \param labels    a N x D matrix containing the corresponding
00518      *                  N labels. Current split functors assume D to
00519      *                  be 1 and ignore any additional columns.
00520      *                  this is not enforced to allow future support
00521      *                  for uncertain labels.
00522      *
00523      * learning is done with:
00524      *
00525      * \sa rf::split, EarlyStoppStd
00526      *
00527      * - Randomly seeded random number generator
00528      * - default gini split functor as described by Breiman
00529      * - default The standard early stopping criterion
00530      */
00531     template <class U, class C1, class U2,class C2>
00532     void learn(   MultiArrayView<2, U, C1> const  & features,
00533                     MultiArrayView<2, U2,C2> const  & labels)
00534     {
00535         learn(  features, 
00536                 labels, 
00537                 rf_default(), 
00538                 rf_default(), 
00539                 rf_default());
00540     }
00541     /*\}*/
00542 
00543 
00544 
00545     /**\name prediction
00546      */
00547     /*\{*/
00548     /** \brief predict a label given a feature.
00549      *
00550      * \param features: a 1 by featureCount matrix containing
00551      *        data point to be predicted (this only works in
00552      *        classification setting)
00553      * \param stop: early stopping criterion
00554      * \return double value representing class. You can use the
00555      *         predictLabels() function together with the
00556      *         rf.external_parameter().class_type_ attribute
00557      *         to get back the same type used during learning. 
00558      */
00559     template <class U, class C, class Stop>
00560     LabelType predictLabel(MultiArrayView<2, U, C>const & features, Stop & stop) const;
00561 
00562     template <class U, class C>
00563     LabelType predictLabel(MultiArrayView<2, U, C>const & features)
00564     {
00565         return predictLabel(features, rf_default()); 
00566     } 
00567     /** \brief predict a label with features and class priors
00568      *
00569      * \param features: same as above.
00570      * \param prior:   iterator to prior weighting of classes
00571      * \return sam as above.
00572      */
00573     template <class U, class C>
00574     LabelType predictLabel(MultiArrayView<2, U, C> const & features,
00575                                 ArrayVectorView<double> prior) const;
00576 
00577     /** \brief predict multiple labels with given features
00578      *
00579      * \param features: a n by featureCount matrix containing
00580      *        data point to be predicted (this only works in
00581      *        classification setting)
00582      * \param labels: a n by 1 matrix passed by reference to store
00583      *        output.
00584      */
00585     template <class U, class C1, class T, class C2>
00586     void predictLabels(MultiArrayView<2, U, C1>const & features,
00587                        MultiArrayView<2, T, C2> & labels) const
00588     {
00589         vigra_precondition(features.shape(0) == labels.shape(0),
00590             "RandomForest::predictLabels(): Label array has wrong size.");
00591         for(int k=0; k<features.shape(0); ++k)
00592             labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), rf_default()));
00593     }
00594 
00595     template <class U, class C1, class T, class C2, class Stop>
00596     void predictLabels(MultiArrayView<2, U, C1>const & features,
00597                        MultiArrayView<2, T, C2> & labels,
00598                        Stop                     & stop) const
00599     {
00600         vigra_precondition(features.shape(0) == labels.shape(0),
00601             "RandomForest::predictLabels(): Label array has wrong size.");
00602         for(int k=0; k<features.shape(0); ++k)
00603             labels(k,0) = detail::RequiresExplicitCast<T>::cast(predictLabel(rowVector(features, k), stop));
00604     }
00605     /** \brief predict the class probabilities for multiple labels
00606      *
00607      *  \param features same as above
00608      *  \param prob a n x class_count_ matrix. passed by reference to
00609      *  save class probabilities
00610      *  \param stop earlystopping criterion
00611      *  \sa EarlyStopping
00612      */
00613     template <class U, class C1, class T, class C2, class Stop>
00614     void predictProbabilities(MultiArrayView<2, U, C1>const &   features,
00615                               MultiArrayView<2, T, C2> &        prob,
00616                               Stop                     &        stop) const;
00617     template <class T1,class T2, class C>
00618     void predictProbabilities(OnlinePredictionSet<T1> &  predictionSet,
00619                                MultiArrayView<2, T2, C> &       prob);
00620 
00621     /** \brief predict the class probabilities for multiple labels
00622      *
00623      *  \param features same as above
00624      *  \param prob a n x class_count_ matrix. passed by reference to
00625      *  save class probabilities
00626      */
00627     template <class U, class C1, class T, class C2>
00628     void predictProbabilities(MultiArrayView<2, U, C1>const &   features,
00629                               MultiArrayView<2, T, C2> &        prob)  const
00630     {
00631         predictProbabilities(features, prob, rf_default()); 
00632     }   
00633 
00634     template <class U, class C1, class T, class C2>
00635     void predictRaw(MultiArrayView<2, U, C1>const &   features,
00636                     MultiArrayView<2, T, C2> &        prob)  const;
00637 
00638 
00639     /*\}*/
00640 
00641 };
00642 
00643 
00644 template <class LabelType, class PreprocessorTag>
00645 template<class U,class C1,
00646     class U2, class C2,
00647     class Split_t,
00648     class Stop_t,
00649     class Visitor_t,
00650     class Random_t>
00651 void RandomForest<LabelType, PreprocessorTag>::onlineLearn(MultiArrayView<2,U,C1> const & features,
00652                                                              MultiArrayView<2,U2,C2> const & response,
00653                                                              int new_start_index,
00654                                                              Visitor_t visitor_,
00655                                                              Split_t split_,
00656                                                              Stop_t stop_,
00657                                                              Random_t & random,
00658                                                              bool adjust_thresholds)
00659 {
00660     online_visitor_.activate();
00661     online_visitor_.adjust_thresholds=adjust_thresholds;
00662 
00663     using namespace rf;
00664     //typedefs
00665     typedef Processor<PreprocessorTag,LabelType,U,C1,U2,C2> Preprocessor_t;
00666     typedef          UniformIntRandomFunctor<Random_t>
00667                                                     RandFunctor_t;
00668     // default values and initialization
00669     // Value Chooser chooses second argument as value if first argument
00670     // is of type RF_DEFAULT. (thanks to template magic - don't care about
00671     // it - just smile and wave.
00672     
00673     #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 
00674     Default_Stop_t default_stop(options_);
00675     typename RF_CHOOSER(Stop_t)::type stop
00676             = RF_CHOOSER(Stop_t)::choose(stop_, default_stop); 
00677     Default_Split_t default_split;
00678     typename RF_CHOOSER(Split_t)::type split 
00679             = RF_CHOOSER(Split_t)::choose(split_, default_split); 
00680     rf::visitors::StopVisiting stopvisiting;
00681     typedef  rf::visitors::detail::VisitorNode
00682                 <rf::visitors::OnlineLearnVisitor, 
00683                  typename RF_CHOOSER(Visitor_t)::type> 
00684                                                         IntermedVis; 
00685     IntermedVis
00686         visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
00687     #undef RF_CHOOSER
00688 
00689     // Preprocess the data to get something the split functor can work
00690     // with. Also fill the ext_param structure by preprocessing
00691     // option parameters that could only be completely evaluated
00692     // when the training data is known.
00693     ext_param_.class_count_=0;
00694     Preprocessor_t preprocessor(    features, response,
00695                                     options_, ext_param_);
00696 
00697     // Make stl compatible random functor.
00698     RandFunctor_t           randint     ( random);
00699 
00700     // Give the Split functor information about the data.
00701     split.set_external_parameters(ext_param_);
00702     stop.set_external_parameters(ext_param_);
00703 
00704 
00705     //Create poisson samples
00706     PoissonSampler<RandomTT800> poisson_sampler(1.0,vigra::Int32(new_start_index),vigra::Int32(ext_param().row_count_));
00707 
00708     //TODO: visitors for online learning
00709     //visitor.visit_at_beginning(*this, preprocessor);
00710 
00711     // THE MAIN EFFING RF LOOP - YEAY DUDE!
00712     for(int ii = 0; ii < (int)trees_.size(); ++ii)
00713     {
00714         online_visitor_.tree_id=ii;
00715         poisson_sampler.sample();
00716         std::map<int,int> leaf_parents;
00717         leaf_parents.clear();
00718         //Get all the leaf nodes for that sample
00719         for(int s=0;s<poisson_sampler.numOfSamples();++s)
00720         {
00721             int sample=poisson_sampler[s];
00722             online_visitor_.current_label=preprocessor.response()(sample,0);
00723             online_visitor_.last_node_id=StackEntry_t::DecisionTreeNoParent;
00724             int leaf=trees_[ii].getToLeaf(rowVector(features,sample),online_visitor_);
00725 
00726 
00727             //Add to the list for that leaf
00728             online_visitor_.add_to_index_list(ii,leaf,sample);
00729             //TODO: Class count?
00730             //Store parent
00731             if(Node<e_ConstProbNode>(trees_[ii].topology_,trees_[ii].parameters_,leaf).prob_begin()[preprocessor.response()(sample,0)]!=1.0)
00732             {
00733                 leaf_parents[leaf]=online_visitor_.last_node_id;
00734             }
00735         }
00736 
00737 
00738         std::map<int,int>::iterator leaf_iterator;
00739         for(leaf_iterator=leaf_parents.begin();leaf_iterator!=leaf_parents.end();++leaf_iterator)
00740         {
00741             int leaf=leaf_iterator->first;
00742             int parent=leaf_iterator->second;
00743             int lin_index=online_visitor_.trees_online_information[ii].exterior_to_index[leaf];
00744             ArrayVector<Int32> indeces;
00745             indeces.clear();
00746             indeces.swap(online_visitor_.trees_online_information[ii].index_lists[lin_index]);
00747             StackEntry_t stack_entry(indeces.begin(),
00748                                      indeces.end(),
00749                                      ext_param_.class_count_);
00750 
00751 
00752             if(parent!=-1)
00753             {
00754                 if(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(0)==leaf)
00755                 {
00756                     stack_entry.leftParent=parent;
00757                 }
00758                 else
00759                 {
00760                     vigra_assert(NodeBase(trees_[ii].topology_,trees_[ii].parameters_,parent).child(1)==leaf,"last_node_id seems to be wrong");
00761                     stack_entry.rightParent=parent;
00762                 }
00763             }
00764             //trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,leaf);
00765             trees_[ii].continueLearn(preprocessor.features(),preprocessor.response(),stack_entry,split,stop,visitor,randint,-1);
00766             //Now, the last one moved onto leaf
00767             online_visitor_.move_exterior_node(ii,trees_[ii].topology_.size(),ii,leaf);
00768             //Now it should be classified correctly!
00769         }
00770 
00771         /*visitor
00772             .visit_after_tree(  *this,
00773                                 preprocessor,
00774                                 poisson_sampler,
00775                                 stack_entry,
00776                                 ii);*/
00777     }
00778 
00779     //visitor.visit_at_end(*this, preprocessor);
00780     online_visitor_.deactivate();
00781 }
00782 
00783 template<class LabelType, class PreprocessorTag>
00784 template<class U,class C1,
00785     class U2, class C2,
00786     class Split_t,
00787     class Stop_t,
00788     class Visitor_t,
00789     class Random_t>
00790 void RandomForest<LabelType, PreprocessorTag>::reLearnTree(MultiArrayView<2,U,C1> const & features,
00791                  MultiArrayView<2,U2,C2> const & response,
00792                  int treeId,
00793                  Visitor_t visitor_,
00794                  Split_t split_,
00795                  Stop_t stop_,
00796                  Random_t & random)
00797 {
00798     using namespace rf;
00799     
00800     
00801     typedef          UniformIntRandomFunctor<Random_t>
00802                                                     RandFunctor_t;
00803 
00804     // See rf_preprocessing.hxx for more info on this
00805     ext_param_.class_count_=0;
00806     typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t;
00807     
00808     // default values and initialization
00809     // Value Chooser chooses second argument as value if first argument
00810     // is of type RF_DEFAULT. (thanks to template magic - don't care about
00811     // it - just smile and wave.
00812     
00813     #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 
00814     Default_Stop_t default_stop(options_);
00815     typename RF_CHOOSER(Stop_t)::type stop
00816             = RF_CHOOSER(Stop_t)::choose(stop_, default_stop); 
00817     Default_Split_t default_split;
00818     typename RF_CHOOSER(Split_t)::type split 
00819             = RF_CHOOSER(Split_t)::choose(split_, default_split); 
00820     rf::visitors::StopVisiting stopvisiting;
00821     typedef  rf::visitors::detail::VisitorNode
00822                 <rf::visitors::OnlineLearnVisitor, 
00823                 typename RF_CHOOSER(Visitor_t)::type> IntermedVis; 
00824     IntermedVis
00825         visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
00826     #undef RF_CHOOSER
00827     vigra_precondition(options_.prepare_online_learning_,"reLearnTree: Re learning trees only makes sense, if online learning is enabled");
00828     online_visitor_.activate();
00829 
00830     // Make stl compatible random functor.
00831     RandFunctor_t           randint     ( random);
00832 
00833     // Preprocess the data to get something the split functor can work
00834     // with. Also fill the ext_param structure by preprocessing
00835     // option parameters that could only be completely evaluated
00836     // when the training data is known.
00837     Preprocessor_t preprocessor(    features, response,
00838                                     options_, ext_param_);
00839 
00840     // Give the Split functor information about the data.
00841     split.set_external_parameters(ext_param_);
00842     stop.set_external_parameters(ext_param_);
00843 
00844     /**\todo    replace this crappy class out. It uses function pointers.
00845      *          and is making code slower according to me.
00846      *          Comment from Nathan: This is copied from Rahul, so me=Rahul
00847      */
00848     Sampler<Random_t > sampler(preprocessor.strata().begin(),
00849                                preprocessor.strata().end(),
00850                                detail::make_sampler_opt(options_)
00851                                         .sampleSize(ext_param().actual_msample_),
00852                                     random);
00853     //initialize First region/node/stack entry
00854     sampler
00855         .sample();
00856 
00857     StackEntry_t
00858         first_stack_entry(  sampler.sampledIndices().begin(),
00859                             sampler.sampledIndices().end(),
00860                             ext_param_.class_count_);
00861     first_stack_entry
00862         .set_oob_range(     sampler.oobIndices().begin(),
00863                             sampler.oobIndices().end());
00864     online_visitor_.reset_tree(treeId);
00865     online_visitor_.tree_id=treeId;
00866     trees_[treeId].reset();
00867     trees_[treeId]
00868         .learn( preprocessor.features(),
00869                 preprocessor.response(),
00870                 first_stack_entry,
00871                 split,
00872                 stop,
00873                 visitor,
00874                 randint);
00875     visitor
00876         .visit_after_tree(  *this,
00877                             preprocessor,
00878                             sampler,
00879                             first_stack_entry,
00880                             treeId);
00881 
00882     online_visitor_.deactivate();
00883 }
00884 
/** \brief Learn a random forest from scratch.
 *
 *  Preprocesses the training data (filling ext_param_), then grows
 *  options_.tree_count_ decision trees, each on its own bootstrap sample
 *  drawn by the (possibly stratified) sampler. Visitors are notified at
 *  the beginning, after every tree, and at the end of training.
 *
 *  \param features  n x p training feature matrix.
 *  \param response  n x 1 training response/label matrix (same row count).
 *  \param visitor_  visitor chain (or RF_DEFAULT).
 *  \param split_    split functor (or RF_DEFAULT).
 *  \param stop_     stop functor (or RF_DEFAULT).
 *  \param random    random number generator used for sampling.
 *
 *  NOTE: the order of the setup steps matters — the Preprocessor_t
 *  constructor fills ext_param_, which split/stop read afterwards.
 */
template <class LabelType, class PreprocessorTag>
template <class U, class C1,
         class U2,class C2,
         class Split_t,
         class Stop_t,
         class Visitor_t,
         class Random_t>
void RandomForest<LabelType, PreprocessorTag>::
                     learn( MultiArrayView<2, U, C1> const  &   features,
                            MultiArrayView<2, U2,C2> const  &   response,
                            Visitor_t                           visitor_,
                            Split_t                             split_,
                            Stop_t                              stop_,
                            Random_t                 const  &   random)
{
    using namespace rf;
    //this->reset();
    //typedefs
    typedef          UniformIntRandomFunctor<Random_t>
                                                    RandFunctor_t;

    // See rf_preprocessing.hxx for more info on this
    typedef Processor<PreprocessorTag,LabelType, U, C1, U2, C2> Preprocessor_t;

    vigra_precondition(features.shape(0) == response.shape(0),
        "RandomForest::learn(): shape mismatch between features and response.");
    
    // default values and initialization
    // Value Chooser chooses second argument as value if first argument
    // is of type RF_DEFAULT. (thanks to template magic - don't care about
    // it - just smile and wave.
    
    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop); 
    Default_Split_t default_split;
    typename RF_CHOOSER(Split_t)::type split 
            = RF_CHOOSER(Split_t)::choose(split_, default_split); 
    rf::visitors::StopVisiting stopvisiting;
    // The online-learn visitor is always prepended to the user's visitor
    // chain; it is only active when prepare_online_learning_ is set.
    typedef  rf::visitors::detail::VisitorNode<
                rf::visitors::OnlineLearnVisitor, 
                typename RF_CHOOSER(Visitor_t)::type> IntermedVis; 
    IntermedVis
        visitor(online_visitor_, RF_CHOOSER(Visitor_t)::choose(visitor_, stopvisiting));
    #undef RF_CHOOSER
    if(options_.prepare_online_learning_)
        online_visitor_.activate();
    else
        online_visitor_.deactivate();


    // Make stl compatible random functor.
    RandFunctor_t           randint     ( random);


    // Preprocess the data to get something the split functor can work
    // with. Also fill the ext_param structure by preprocessing
    // option parameters that could only be completely evaluated
    // when the training data is known.
    Preprocessor_t preprocessor(    features, response,
                                    options_, ext_param_);

    // Give the Split functor information about the data.
    split.set_external_parameters(ext_param_);
    stop.set_external_parameters(ext_param_);


    //initialize trees.
    trees_.resize(options_.tree_count_  , DecisionTree_t(ext_param_));

    // Sampler draws actual_msample_ indices per tree, stratified according
    // to preprocessor.strata() and the sampling options.
    Sampler<Random_t > sampler(preprocessor.strata().begin(),
                               preprocessor.strata().end(),
                               detail::make_sampler_opt(options_)
                                        .sampleSize(ext_param().actual_msample_),
                                    random);

    visitor.visit_at_beginning(*this, preprocessor);
    // THE MAIN EFFING RF LOOP - YEAY DUDE!
    
    for(int ii = 0; ii < (int)trees_.size(); ++ii)
    {
        //initialize First region/node/stack entry
        // Fresh bootstrap sample per tree; the root StackEntry_t carries
        // the in-bag range plus the out-of-bag range for the visitors.
        sampler
            .sample();  
        StackEntry_t
            first_stack_entry(  sampler.sampledIndices().begin(),
                                sampler.sampledIndices().end(),
                                ext_param_.class_count_);
        first_stack_entry
            .set_oob_range(     sampler.oobIndices().begin(),
                                sampler.oobIndices().end());
        trees_[ii]
            .learn(             preprocessor.features(),
                                preprocessor.response(),
                                first_stack_entry,
                                split,
                                stop,
                                visitor,
                                randint);
        visitor
            .visit_after_tree(  *this,
                                preprocessor,
                                sampler,
                                first_stack_entry,
                                ii);
    }

    visitor.visit_at_end(*this, preprocessor);
    // Only for online learning?
    online_visitor_.deactivate();
}
00997 
00998 
00999 
01000 
01001 template <class LabelType, class Tag>
01002 template <class U, class C, class Stop>
01003 LabelType RandomForest<LabelType, Tag>
01004     ::predictLabel(MultiArrayView<2, U, C> const & features, Stop & stop) const
01005 {
01006     vigra_precondition(columnCount(features) >= ext_param_.column_count_,
01007         "RandomForestn::predictLabel():"
01008             " Too few columns in feature matrix.");
01009     vigra_precondition(rowCount(features) == 1,
01010         "RandomForestn::predictLabel():"
01011             " Feature matrix must have a singlerow.");
01012     typedef MultiArrayShape<2>::type Shp;
01013     garbage_prediction_.reshape(Shp(1, ext_param_.class_count_), 0.0);
01014     LabelType          d;
01015     predictProbabilities(features, garbage_prediction_, stop);
01016     ext_param_.to_classlabel(argMax(garbage_prediction_), d);
01017     return d;
01018 }
01019 
01020 
01021 //Same thing as above with priors for each label !!!
01022 template <class LabelType, class PreprocessorTag>
01023 template <class U, class C>
01024 LabelType RandomForest<LabelType, PreprocessorTag>
01025     ::predictLabel( MultiArrayView<2, U, C> const & features,
01026                     ArrayVectorView<double> priors) const
01027 {
01028     using namespace functor;
01029     vigra_precondition(columnCount(features) >= ext_param_.column_count_,
01030         "RandomForestn::predictLabel(): Too few columns in feature matrix.");
01031     vigra_precondition(rowCount(features) == 1,
01032         "RandomForestn::predictLabel():"
01033         " Feature matrix must have a single row.");
01034     Matrix<double>  prob(1,ext_param_.class_count_);
01035     predictProbabilities(features, prob);
01036     std::transform( prob.begin(), prob.end(),
01037                     priors.begin(), prob.begin(),
01038                     Arg1()*Arg2());
01039     LabelType          d;
01040     ext_param_.to_classlabel(argMax(prob), d);
01041     return d;
01042 }
01043 
/** \brief Predict class probabilities for all samples of an
 *  OnlinePredictionSet at once.
 *
 *  Rather than routing each sample down each tree separately, samples are
 *  grouped into SampleRange objects carrying per-feature min/max
 *  boundaries. A whole range is pushed down a tree in one go and is only
 *  partitioned at a threshold node whose threshold falls inside the
 *  range's boundaries, saving per-sample decisions.
 *
 *  \param predictionSet  samples plus range bookkeeping (features,
 *                        indices, ranges, cumulativePredTime); modified
 *                        in place as ranges are split and erased.
 *  \param prob           n x class_count output matrix, receives the
 *                        normalized per-class vote weights.
 *
 *  Only forests consisting purely of threshold nodes are supported; a
 *  std::runtime_error is thrown otherwise.
 */
template<class LabelType,class PreprocessorTag>
template <class T1,class T2, class C>
void RandomForest<LabelType,PreprocessorTag>
    ::predictProbabilities(OnlinePredictionSet<T1> &  predictionSet,
                          MultiArrayView<2, T2, C> &       prob)
{
    //Features are n xp
    //prob is n x NumOfLabel probability for each feature in each class
    
    vigra_precondition(rowCount(predictionSet.features) == rowCount(prob),
                       "RandomFroest::predictProbabilities():"
                       " Feature matrix and probability matrix size mismatch.");
    // num of features must be bigger than num of features in Random forest training
    // but why bigger?
    vigra_precondition( columnCount(predictionSet.features) >= ext_param_.column_count_,
      "RandomForestn::predictProbabilities():"
        " Too few columns in feature matrix.");
    vigra_precondition( columnCount(prob)
                        == (MultiArrayIndex)ext_param_.class_count_,
      "RandomForestn::predictProbabilities():"
      " Probability matrix must have as many columns as there are classes.");
    prob.init(0.0);
    //store total weights
    // Per-sample sum of all accumulated leaf weights over all trees;
    // used for the final normalization pass below.
    std::vector<T1> totalWeights(predictionSet.indices[0].size(),0.0);
    //Go through all trees
    // Each tree uses the next index/range set in round-robin fashion.
    // NOTE(review): the modulus is indices[0].size() (samples in set 0),
    // not the number of sets — looks suspicious; verify against the
    // OnlinePredictionSet layout.
    int set_id=-1;
    for(int k=0; k<options_.tree_count_; ++k)
    {
        set_id=(set_id+1) % predictionSet.indices[0].size();
        typedef std::set<SampleRange<T1> > my_set;
        typedef typename my_set::iterator set_it;
        //typedef std::set<std::pair<int,SampleRange<T1> > >::iterator set_it;
        //Build a stack with all the ranges we have
        // Each stack entry is (topology index of node to visit, range).
        // The traversal starts at index 2 — presumably the root node's
        // offset in topology_; TODO(review) confirm against the
        // DecisionTree layout.
        std::vector<std::pair<int,set_it> > stack;
        stack.clear();
        for(set_it i=predictionSet.ranges[set_id].begin();
             i!=predictionSet.ranges[set_id].end();++i)
            stack.push_back(std::pair<int,set_it>(2,i));
        //get weights predicted by single tree
        int num_decisions=0;
        while(!stack.empty())
        {
            set_it range=stack.back().second;
            int index=stack.back().first;
            stack.pop_back();
            ++num_decisions;

            if(trees_[k].isLeafNode(trees_[k].topology_[index]))
            {
                // Leaf reached: add the leaf's per-class weights to every
                // sample in this range.
                ArrayVector<double>::iterator weights=Node<e_ConstProbNode>(trees_[k].topology_,
                                                                            trees_[k].parameters_,
                                                                            index).prob_begin();
                for(int i=range->start;i!=range->end;++i)
                {
                    //update votecount.
                    for(int l=0; l<ext_param_.class_count_; ++l)
                    {
                        prob(predictionSet.indices[set_id][i], l) += (T2)weights[l];
                        //every weight in totalWeight.
                        totalWeights[predictionSet.indices[set_id][i]] += (T1)weights[l];
                    }
                }
            }

            else
            {
                if(trees_[k].topology_[index]!=i_ThresholdNode)
                {
                    throw std::runtime_error("predicting with online prediction sets is only supported for RFs with threshold nodes");
                }
                Node<i_ThresholdNode> node(trees_[k].topology_,trees_[k].parameters_,index);
                // If the whole range lies on one side of the threshold,
                // forward it unsplit to the corresponding child.
                if(range->min_boundaries[node.column()]>=node.threshold())
                {
                    //Everything goes to right child
                    stack.push_back(std::pair<int,set_it>(node.child(1),range));
                    continue;
                }
                if(range->max_boundaries[node.column()]<node.threshold())
                {
                    //Everything goes to the left child
                    stack.push_back(std::pair<int,set_it>(node.child(0),range));
                    continue;
                }
                //We have to split at this node
                // Partition in place: samples >= threshold are swapped to
                // the tail of the index range (becoming new_range), samples
                // < threshold stay in front; boundaries are tightened to
                // the actually observed feature values.
                SampleRange<T1> new_range=*range;
                new_range.min_boundaries[node.column()]=FLT_MAX;
                range->max_boundaries[node.column()]=-FLT_MAX;
                new_range.start=new_range.end=range->end;
                int i=range->start;
                while(i!=range->end)
                {
                    //Decide for range->indices[i]
                    if(predictionSet.features(predictionSet.indices[set_id][i],node.column())>=node.threshold())
                    {
                        new_range.min_boundaries[node.column()]=std::min(new_range.min_boundaries[node.column()],
                                                                    predictionSet.features(predictionSet.indices[set_id][i],node.column()));
                        --range->end;
                        --new_range.start;
                        std::swap(predictionSet.indices[set_id][i],predictionSet.indices[set_id][range->end]);

                    }
                    else
                    {
                        range->max_boundaries[node.column()]=std::max(range->max_boundaries[node.column()],
                                                                 predictionSet.features(predictionSet.indices[set_id][i],node.column()));
                        ++i;
                    }
                }
                //The old one ...
                // An emptied range is removed from the set entirely;
                // otherwise it continues down the left child.
                if(range->start==range->end)
                {
                    predictionSet.ranges[set_id].erase(range);
                }
                else
                {
                    stack.push_back(std::pair<int,set_it>(node.child(0),range));
                }
                //And the new one ...
                // A non-empty tail partition is inserted as a new range and
                // continues down the right child.
                if(new_range.start!=new_range.end)
                {
                    std::pair<set_it,bool> new_it=predictionSet.ranges[set_id].insert(new_range);
                    stack.push_back(std::pair<int,set_it>(node.child(1),new_it.first));
                }
            }
        }
        // Record how many node visits this tree needed — a proxy for the
        // prediction cost of the range-based traversal.
        predictionSet.cumulativePredTime[k]=num_decisions;
    }
    for(unsigned int i=0;i<totalWeights.size();++i)
    {
        double test=0.0;
        //Normalise votes in each row by total VoteCount (totalWeight
        for(int l=0; l<ext_param_.class_count_; ++l)
        {
            test+=prob(i,l);
            prob(i, l) /= totalWeights[i];
        }
        // NOTE(review): exact floating-point equality — this assert can
        // fire spuriously due to rounding even when the sums are correct.
        assert(test==totalWeights[i]);
        assert(totalWeights[i]>0.0);
    }
}
01184 
/** \brief Predict class probabilities for every row of \a features.
 *
 *  Each tree votes with its leaf's per-class weights; votes are summed per
 *  row and normalized by the total accumulated weight. The early-stopping
 *  functor may terminate the tree loop for a row once its decision is
 *  confident enough.
 *
 *  \param features  n x p feature matrix (p >= ext_param_.column_count_).
 *  \param prob      n x class_count output matrix for the normalized
 *                   probabilities.
 *  \param stop_     early-stopping functor (or RF_DEFAULT); consulted
 *                   after every tree's prediction.
 */
template <class LabelType, class PreprocessorTag>
template <class U, class C1, class T, class C2, class Stop_t>
void RandomForest<LabelType, PreprocessorTag>
    ::predictProbabilities(MultiArrayView<2, U, C1>const &  features,
                           MultiArrayView<2, T, C2> &       prob,
                           Stop_t                   &       stop_) const
{
    //Features are n xp
    //prob is n x NumOfLabel probability for each feature in each class

    vigra_precondition(rowCount(features) == rowCount(prob),
      "RandomForestn::predictProbabilities():"
        " Feature matrix and probability matrix size mismatch.");

    // num of features must be bigger than num of features in Random forest training
    // but why bigger?
    vigra_precondition( columnCount(features) >= ext_param_.column_count_,
      "RandomForestn::predictProbabilities():"
        " Too few columns in feature matrix.");
    vigra_precondition( columnCount(prob)
                        == (MultiArrayIndex)ext_param_.class_count_,
      "RandomForestn::predictProbabilities():"
      " Probability matrix must have as many columns as there are classes.");

    // Substitute the default stopping functor if RF_DEFAULT was passed.
    #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 
    Default_Stop_t default_stop(options_);
    typename RF_CHOOSER(Stop_t)::type & stop
            = RF_CHOOSER(Stop_t)::choose(stop_, default_stop); 
    #undef RF_CHOOSER 
    stop.set_external_parameters(ext_param_, tree_count());
    prob.init(NumericTraits<T>::zero());
    /* This code was originally there for testing early stopping
     * - we wanted the order of the trees to be randomized
    if(tree_indices_.size() != 0)
    {
       std::random_shuffle(tree_indices_.begin(),
                           tree_indices_.end()); 
    }
    */
    //Classify for each row.
    for(int row=0; row < rowCount(features); ++row)
    {
        ArrayVector<double>::const_iterator weights;

        //totalWeight == totalVoteCount!
        double totalWeight = 0.0;

        //Let each tree classify...
        for(int k=0; k<options_.tree_count_; ++k)
        {
            //get weights predicted by single tree
            weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));

            //update votecount.
            // If predict_weighted_ is set, scale the leaf's class weights
            // by *(weights-1) — apparently a per-tree/leaf weight stored
            // directly before prob_begin(); TODO(review) confirm in the
            // node proxy layout. Otherwise cur_w == weights[l].
            int weighted = options_.predict_weighted_;
            for(int l=0; l<ext_param_.class_count_; ++l)
            {
                double cur_w = weights[l] * (weighted * (*(weights-1))
                                           + (1-weighted));
                prob(row, l) += (T)cur_w;
                //every weight in totalWeight.
                totalWeight += cur_w;
            }
            // Allow the early-stopping functor to cut the tree loop short
            // once the accumulated votes are decisive.
            if(stop.after_prediction(weights, 
                                     k,
                                     rowVector(prob, row),
                                     totalWeight))
            {
                break;
            }
        }

        //Normalise votes in each row by total VoteCount (totalWeight
        for(int l=0; l< ext_param_.class_count_; ++l)
        {
            prob(row, l) /= detail::RequiresExplicitCast<T>::cast(totalWeight);
        }
    }

}
01265 
01266 template <class LabelType, class PreprocessorTag>
01267 template <class U, class C1, class T, class C2>
01268 void RandomForest<LabelType, PreprocessorTag>
01269     ::predictRaw(MultiArrayView<2, U, C1>const &  features,
01270                            MultiArrayView<2, T, C2> &       prob) const
01271 {
01272     //Features are n xp
01273     //prob is n x NumOfLabel probability for each feature in each class
01274 
01275     vigra_precondition(rowCount(features) == rowCount(prob),
01276       "RandomForestn::predictProbabilities():"
01277         " Feature matrix and probability matrix size mismatch.");
01278 
01279     // num of features must be bigger than num of features in Random forest training
01280     // but why bigger?
01281     vigra_precondition( columnCount(features) >= ext_param_.column_count_,
01282       "RandomForestn::predictProbabilities():"
01283         " Too few columns in feature matrix.");
01284     vigra_precondition( columnCount(prob)
01285                         == (MultiArrayIndex)ext_param_.class_count_,
01286       "RandomForestn::predictProbabilities():"
01287       " Probability matrix must have as many columns as there are classes.");
01288 
01289     #define RF_CHOOSER(type_) detail::Value_Chooser<type_, Default_##type_> 
01290     prob.init(NumericTraits<T>::zero());
01291     /* This code was originally there for testing early stopping
01292      * - we wanted the order of the trees to be randomized
01293     if(tree_indices_.size() != 0)
01294     {
01295        std::random_shuffle(tree_indices_.begin(),
01296                            tree_indices_.end()); 
01297     }
01298     */
01299     //Classify for each row.
01300     for(int row=0; row < rowCount(features); ++row)
01301     {
01302         ArrayVector<double>::const_iterator weights;
01303 
01304         //totalWeight == totalVoteCount!
01305         double totalWeight = 0.0;
01306 
01307         //Let each tree classify...
01308         for(int k=0; k<options_.tree_count_; ++k)
01309         {
01310             //get weights predicted by single tree
01311             weights = trees_[k /*tree_indices_[k]*/].predict(rowVector(features, row));
01312 
01313             //update votecount.
01314             int weighted = options_.predict_weighted_;
01315             for(int l=0; l<ext_param_.class_count_; ++l)
01316             {
01317                 double cur_w = weights[l] * (weighted * (*(weights-1))
01318                                            + (1-weighted));
01319                 prob(row, l) += (T)cur_w;
01320                 //every weight in totalWeight.
01321                 totalWeight += cur_w;
01322             }
01323         }
01324     }
01325     prob/= options_.tree_count_;
01326 
01327 }
01328 
01329 //@}
01330 
01331 } // namespace vigra
01332 
01333 #include "random_forest/rf_algorithm.hxx"
01334 #endif // VIGRA_RANDOM_FOREST_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.8.0 (20 Sep 2011)