
vigra/random_forest/rf_common.hxx

/************************************************************************/
/*                                                                      */
/*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/

#ifndef VIGRA_RF_COMMON_HXX
#define VIGRA_RF_COMMON_HXX

namespace vigra
{


struct ClassificationTag
{};

struct RegressionTag
{};

namespace detail
{
    class RF_DEFAULT;
}
inline detail::RF_DEFAULT& rf_default();
namespace detail
{
/* \brief singleton default tag class
 *
 *  Obtain the tag via the rf_default() factory function.
 *  \sa RandomForest<>::learn()
 */
class RF_DEFAULT
{
    private:
        RF_DEFAULT()
        {}
    public:
        friend RF_DEFAULT& ::vigra::rf_default();

        /** workaround for automatic choice of the decision tree
         * stack entry.
         */
};

/* \brief chooses between the default type and the type supplied
 *
 * This is an internal class; you shouldn't need to care about it.
 * It is used to pass on arguments in RandomForest::learn().
 * Usage:
 *\code
 *      // example: use the container type supplied by the user, or
 *      //          ArrayVector if rf_default() was passed as argument
 *      template<class Container_t>
 *      void do_some_foo(Container_t in)
 *      {
 *          typedef ArrayVector<int>    Default_Container_t;
 *          Default_Container_t         default_value;
 *
 *          // if the user didn't care and in is of type RF_DEFAULT,
 *          // then default_value is used.
 *          do_some_more_foo(Value_Chooser<Container_t, Default_Container_t>
 *                               ::choose(in, default_value));
 *      }
 *\endcode
 */
template<class T, class C>
class Value_Chooser
{
public:
    typedef T type;
    static T & choose(T & t, C &)
    {
        return t;
    }
};

template<class C>
class Value_Chooser<detail::RF_DEFAULT, C>
{
public:
    typedef C type;

    static C & choose(detail::RF_DEFAULT &, C & c)
    {
        return c;
    }
};


} //namespace detail


/**\brief factory function to return a RF_DEFAULT tag
 * \sa RandomForest<>::learn()
 */
detail::RF_DEFAULT& rf_default()
{
    static detail::RF_DEFAULT result;
    return result;
}
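
/* Example (editor's sketch, not part of the original header): how
 * rf_default() and Value_Chooser typically interact. All names except
 * rf_default(), Value_Chooser and EarlyStoppStd are hypothetical.
 *\code
 *   template<class Stop_t>
 *   void learn_impl(Stop_t stop_in)
 *   {
 *       EarlyStoppStd default_stop(RandomForestOptions());
 *       // picks stop_in, unless Stop_t is detail::RF_DEFAULT:
 *       typename detail::Value_Chooser<Stop_t, EarlyStoppStd>::type & stop
 *           = detail::Value_Chooser<Stop_t, EarlyStoppStd>::choose(stop_in, default_stop);
 *   }
 *
 *   // caller side:
 *   // learn_impl(rf_default());        // use the default stopping criterion
 *   // learn_impl(my_custom_stopper);   // use a user-supplied one
 *\endcode
 */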

/** tags used with the RandomForestOptions class
 * \sa RF_Traits::Option_t
 */
enum RF_OptionTag   { RF_EQUAL,
                      RF_PROPORTIONAL,
                      RF_EXTERNAL,
                      RF_NONE,
                      RF_FUNCTION,
                      RF_LOG,
                      RF_SQRT,
                      RF_CONST,
                      RF_ALL};


/** \addtogroup MachineLearning
**/
//@{
/**\brief Options object for the random forest
 *
 * usage:
 *\code
 *      RandomForestOptions a =  RandomForestOptions()
 *                                   .param1(value1)
 *                                   .param2(value2)
 *                                   ...
 *\endcode
 *
 * This class only contains options/parameters that are not problem
 * dependent. The ProblemSpec class contains methods to set class weights
 * if necessary.
 *
 * Note that all methods return *this, which makes chaining of options
 * as above possible.
 */
class RandomForestOptions
{
  public:
    /**\name sampling options*/
    /*\{*/
    // see the member access functions for documentation
    double  training_set_proportion_;
    int     training_set_size_;
    int (*training_set_func_)(int);
    RF_OptionTag
        training_set_calc_switch_;

    bool    sample_with_replacement_;
    RF_OptionTag
            stratification_method_;
    /*\}*/

    /**\name general random forest options
     *
     * these will usually be used by most split functors and
     * stopping predicates
     */
    /*\{*/
    RF_OptionTag    mtry_switch_;
    int     mtry_;
    int (*mtry_func_)(int);

    bool predict_weighted_;
    int tree_count_;
    int min_split_node_size_;
    bool prepare_online_learning_;
    /*\}*/

    typedef ArrayVector<double> double_array;
    typedef std::map<std::string, double_array> map_type;

    int serialized_size() const
    {
        return 12;   // number of doubles written by serialize()
    }


    bool operator==(RandomForestOptions const & rhs) const
    {
        bool result = true;
        #define COMPARE(field) result = result && (this->field == rhs.field);
        COMPARE(training_set_proportion_);
        COMPARE(training_set_size_);
        COMPARE(training_set_calc_switch_);
        COMPARE(sample_with_replacement_);
        COMPARE(stratification_method_);
        COMPARE(mtry_switch_);
        COMPARE(mtry_);
        COMPARE(tree_count_);
        COMPARE(min_split_node_size_);
        COMPARE(predict_weighted_);
        #undef COMPARE

        return result;
    }
    bool operator!=(RandomForestOptions const & rhs_) const
    {
        return !(*this == rhs_);
    }
    template<class Iter>
    void unserialize(Iter const & begin, Iter const & end)
    {
        Iter iter = begin;
        vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
                           "RandomForestOptions::unserialize(): "
                           "wrong number of parameters");
        #define PULL(item_, type_) item_ = type_(*iter); ++iter;
        PULL(training_set_proportion_, double);
        PULL(training_set_size_, int);
        ++iter; // skip function pointer slot  //PULL(training_set_func_, double);
        PULL(training_set_calc_switch_, (RF_OptionTag)int);
        PULL(sample_with_replacement_, 0 !=);
        PULL(stratification_method_, (RF_OptionTag)int);
        PULL(mtry_switch_, (RF_OptionTag)int);
        PULL(mtry_, int);
        ++iter; // skip function pointer slot  //PULL(mtry_func_, double);
        PULL(tree_count_, int);
        PULL(min_split_node_size_, int);
        PULL(predict_weighted_, 0 !=);
        #undef PULL
    }
    template<class Iter>
    void serialize(Iter const & begin, Iter const & end) const
    {
        Iter iter = begin;
        vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
                           "RandomForestOptions::serialize(): "
                           "wrong number of parameters");
        #define PUSH(item_) *iter = double(item_); ++iter;
        PUSH(training_set_proportion_);
        PUSH(training_set_size_);
        if(training_set_func_ != 0)
        {
            PUSH(1);
        }
        else
        {
            PUSH(0);
        }
        PUSH(training_set_calc_switch_);
        PUSH(sample_with_replacement_);
        PUSH(stratification_method_);
        PUSH(mtry_switch_);
        PUSH(mtry_);
        if(mtry_func_ != 0)
        {
            PUSH(1);
        }
        else
        {
            PUSH(0);
        }
        PUSH(tree_count_);
        PUSH(min_split_node_size_);
        PUSH(predict_weighted_);
        #undef PUSH
    }
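
    /* Example (editor's sketch, not part of the original header): a
     * serialization round trip. The buffer size follows from
     * serialized_size(); everything else here is hypothetical.
     *\code
     *   RandomForestOptions opts = RandomForestOptions().tree_count(100);
     *   ArrayVector<double> buffer(opts.serialized_size());
     *   opts.serialize(buffer.begin(), buffer.end());
     *
     *   RandomForestOptions restored;
     *   restored.unserialize(buffer.begin(), buffer.end());
     *   // restored == opts holds for all serialized fields; the two
     *   // function pointer slots are stored only as 0/1 flags.
     *\endcode
     */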

    void make_from_map(map_type & in) // could be const if operator[] were replaced by find()
    {
        typedef MultiArrayShape<2>::type Shp;
        #define PULL(item_, type_) item_ = type_(in[#item_][0]);
        #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0);
        PULL(training_set_proportion_, double);
        PULL(training_set_size_, int);
        PULL(mtry_, int);
        PULL(tree_count_, int);
        PULL(min_split_node_size_, int);
        PULLBOOL(sample_with_replacement_, bool);
        PULLBOOL(prepare_online_learning_, bool);
        PULLBOOL(predict_weighted_, bool);

        PULL(training_set_calc_switch_, (RF_OptionTag)(int));

        PULL(stratification_method_, (RF_OptionTag)(int));
        PULL(mtry_switch_, (RF_OptionTag)(int));

        /* don't pull the function pointers */
        //PULL(mtry_func_!=0, int);
        //PULL(training_set_func,int);
        #undef PULL
        #undef PULLBOOL
    }
    void make_map(map_type & in) const
    {
        typedef MultiArrayShape<2>::type Shp;
        #define PUSH(item_, type_) in[#item_] = double_array(1, double(item_));
        #define PUSHFUNC(item_, type_) in[#item_] = double_array(1, double(item_ != 0));
        PUSH(training_set_proportion_, double);
        PUSH(training_set_size_, int);
        PUSH(mtry_, int);
        PUSH(tree_count_, int);
        PUSH(min_split_node_size_, int);
        PUSH(sample_with_replacement_, bool);
        PUSH(prepare_online_learning_, bool);
        PUSH(predict_weighted_, bool);

        PUSH(training_set_calc_switch_, RF_OptionTag);
        PUSH(stratification_method_, RF_OptionTag);
        PUSH(mtry_switch_, RF_OptionTag);

        PUSHFUNC(mtry_func_, int);
        PUSHFUNC(training_set_func_, int);
        #undef PUSH
        #undef PUSHFUNC
    }
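
    /* Example (editor's sketch, not part of the original header): the map
     * representation keys entries by member name, so options can be
     * inspected or transported as name/value pairs.
     *\code
     *   RandomForestOptions opts = RandomForestOptions().tree_count(100);
     *   RandomForestOptions::map_type m;
     *   opts.make_map(m);
     *   // m["tree_count_"][0] == 100.0
     *
     *   RandomForestOptions restored;
     *   restored.make_from_map(m);
     *\endcode
     */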

    /**\brief create a RandomForestOptions object with default initialisation.
     *
     * See the other member functions for more information on default
     * values.
     */
    RandomForestOptions()
    :
        training_set_proportion_(1.0),
        training_set_size_(0),
        training_set_func_(0),
        training_set_calc_switch_(RF_PROPORTIONAL),
        sample_with_replacement_(true),
        stratification_method_(RF_NONE),
        mtry_switch_(RF_SQRT),
        mtry_(0),
        mtry_func_(0),
        predict_weighted_(false),
        tree_count_(256),
        min_split_node_size_(1),
        prepare_online_learning_(false)
    {}

    /**\brief specify stratification strategy
     *
     * default: RF_NONE
     * possible values: RF_EQUAL, RF_PROPORTIONAL,
     *                  RF_EXTERNAL, RF_NONE
     * RF_EQUAL:        draw an equal number of samples per class.
     * RF_PROPORTIONAL: sample proportionally to the fraction of class
     *                  samples in the population
     * RF_EXTERNAL:     the strata_weights_ field of the ProblemSpec_t object
     *                  has been set externally. (defunct)
     */
    RandomForestOptions & use_stratification(RF_OptionTag in)
    {
        vigra_precondition(in == RF_EQUAL ||
                           in == RF_PROPORTIONAL ||
                           in == RF_EXTERNAL ||
                           in == RF_NONE,
                           "RandomForestOptions::use_stratification(): "
                           "input must be RF_EQUAL, RF_PROPORTIONAL, "
                           "RF_EXTERNAL or RF_NONE");
        stratification_method_ = in;
        return *this;
    }

    /**\brief prepare the random forest for online learning
     */
    RandomForestOptions & prepare_online_learning(bool in)
    {
        prepare_online_learning_ = in;
        return *this;
    }

    /**\brief sample from the training population with or without replacement?
     *
     * <br> Default: true
     */
    RandomForestOptions & sample_with_replacement(bool in)
    {
        sample_with_replacement_ = in;
        return *this;
    }

    /**\brief specify the fraction of the total number of samples
     * used per tree for learning.
     *
     * This value should be in [0.0, 1.0] if sampling without
     * replacement has been specified.
     *
     * <br> default: 1.0
     */
    RandomForestOptions & samples_per_tree(double in)
    {
        training_set_proportion_ = in;
        training_set_calc_switch_ = RF_PROPORTIONAL;
        return *this;
    }

    /**\brief directly specify the number of samples per tree
     */
    RandomForestOptions & samples_per_tree(int in)
    {
        training_set_size_ = in;
        training_set_calc_switch_ = RF_CONST;
        return *this;
    }

    /**\brief use an external function to calculate the number of samples
     *        each tree should be learnt with.
     *
     * \param in function pointer that takes the number of rows in the
     *           learning data and outputs the number of samples per tree.
     */
    RandomForestOptions & samples_per_tree(int (*in)(int))
    {
        training_set_func_ = in;
        training_set_calc_switch_ = RF_FUNCTION;
        return *this;
    }
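
    /* Example (editor's sketch, not part of the original header): the three
     * samples_per_tree() overloads. half_of() is a hypothetical helper.
     *\code
     *   int half_of(int row_count) { return row_count / 2; }
     *
     *   RandomForestOptions a = RandomForestOptions().samples_per_tree(0.8);      // fraction
     *   RandomForestOptions b = RandomForestOptions().samples_per_tree(1000);     // absolute count
     *   RandomForestOptions c = RandomForestOptions().samples_per_tree(&half_of); // computed
     *\endcode
     */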

    /**\brief weight each tree's vote with the number of samples
     *        in the node it predicts from
     */
    RandomForestOptions & predict_weighted()
    {
        predict_weighted_ = true;
        return *this;
    }

    /**\brief use a built-in mapping to calculate mtry
     *
     * Use one of the built-in mappings to calculate mtry from the number
     * of columns in the input feature data.
     * \param in possible values: RF_LOG, RF_SQRT or RF_ALL
     *           <br> default: RF_SQRT.
     */
    RandomForestOptions & features_per_node(RF_OptionTag in)
    {
        vigra_precondition(in == RF_LOG ||
                           in == RF_SQRT ||
                           in == RF_ALL,
                           "RandomForestOptions::features_per_node(): "
                           "input must be RF_LOG, RF_SQRT or RF_ALL");
        mtry_switch_ = in;
        return *this;
    }

    /**\brief set mtry to a constant value
     *
     * mtry is the number of columns/variates/variables randomly chosen
     * to select the best split from.
     */
    RandomForestOptions & features_per_node(int in)
    {
        mtry_ = in;
        mtry_switch_ = RF_CONST;
        return *this;
    }

    /**\brief use an external function to calculate mtry
     *
     * \param in function pointer that takes the number of columns
     *           of the feature data and outputs mtry
     */
    RandomForestOptions & features_per_node(int(*in)(int))
    {
        mtry_func_ = in;
        mtry_switch_ = RF_FUNCTION;
        return *this;
    }
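
    /* Example (editor's sketch, not part of the original header): a custom
     * mtry function. two_thirds() is a hypothetical helper.
     *\code
     *   int two_thirds(int column_count) { return (2 * column_count) / 3; }
     *
     *   RandomForestOptions opts = RandomForestOptions()
     *                                  .features_per_node(&two_thirds);
     *\endcode
     */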

    /** How many trees to create?
     *
     * <br> Default: 256 (see the default constructor).
     */
    RandomForestOptions & tree_count(int in)
    {
        tree_count_ = in;
        return *this;
    }

    /**\brief Number of examples required for a node to be split.
     *
     *  When the number of examples in a node is below this number,
     *  the node is not split even if class separation is not yet perfect.
     *  Instead, the node returns the proportion of each class
     *  (among the remaining examples) during the prediction phase.
     *  <br> Default: 1 (complete growing)
     */
    RandomForestOptions & min_split_node_size(int in)
    {
        min_split_node_size_ = in;
        return *this;
    }
};
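
/* Example (editor's sketch, not part of the original header): the chained
 * option style from the class documentation, using a few of the setters
 * defined above.
 *\code
 *   RandomForestOptions options = RandomForestOptions()
 *                                     .tree_count(100)
 *                                     .samples_per_tree(0.8)
 *                                     .features_per_node(RF_SQRT)
 *                                     .min_split_node_size(5);
 *\endcode
 */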


/** \brief problem types
 */
enum Problem_t{REGRESSION, CLASSIFICATION, CHECKLATER};


/** \brief problem specification class for the random forest.
 *
 * This class contains all the problem specific parameters the random
 * forest needs for learning. Specification of an instance of this class
 * is optional, as all necessary fields will be computed prior to learning
 * if not specified.
 *
 * If needed, usage is similar to that of RandomForestOptions.
 */

template<class LabelType = double>
class ProblemSpec
{

public:

    /** \brief the label type of the problem
     */
    typedef LabelType       Label_t;
    ArrayVector<Label_t>    classes;
    typedef ArrayVector<double>                 double_array;
    typedef std::map<std::string, double_array> map_type;

    int                     column_count_;    // number of features
    int                     class_count_;     // number of classes
    int                     row_count_;       // number of samples

    int                     actual_mtry_;     // mtry used in training
    int                     actual_msample_;  // number of in-bag samples per tree

    Problem_t               problem_type_;    // classification or regression

    int used_;                                // this ProblemSpec is valid
    ArrayVector<double>     class_weights_;   // if classes have different importance
    int                     is_weighted_;     // class_weights_ are used
    double                  precision_;       // termination criterion for regression loss
    int                     response_size_;

    template<class T>
    void to_classlabel(int index, T & out) const
    {
        out = T(classes[index]);
    }
    template<class T>
    int to_classIndex(T index) const
    {
        return std::find(classes.begin(), classes.end(), index) - classes.begin();
    }
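
    /* Example (editor's sketch, not part of the original header): mapping
     * between class labels and indices. Note that to_classIndex() returns
     * classes.size() when the label is not found.
     *\code
     *   ProblemSpec<int> spec;
     *   int labels[] = {10, 20, 30};
     *   spec.classes_(labels, labels + 3);
     *
     *   int index = spec.to_classIndex(20);   // == 1
     *   int label;
     *   spec.to_classlabel(2, label);         // label == 30
     *\endcode
     */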

    #define EQUALS(field) field(rhs.field)
    ProblemSpec(ProblemSpec const & rhs)
    :
        EQUALS(column_count_),
        EQUALS(class_count_),
        EQUALS(row_count_),
        EQUALS(actual_mtry_),
        EQUALS(actual_msample_),
        EQUALS(problem_type_),
        EQUALS(used_),
        EQUALS(class_weights_),
        EQUALS(is_weighted_),
        EQUALS(precision_),
        EQUALS(response_size_)
    {
        std::back_insert_iterator<ArrayVector<Label_t> >
                        iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
    }
    #undef EQUALS
    #define EQUALS(field) field(rhs.field)
    template<class T>
    ProblemSpec(ProblemSpec<T> const & rhs)
    :
        EQUALS(column_count_),
        EQUALS(class_count_),
        EQUALS(row_count_),
        EQUALS(actual_mtry_),
        EQUALS(actual_msample_),
        EQUALS(problem_type_),
        EQUALS(used_),
        EQUALS(class_weights_),
        EQUALS(is_weighted_),
        EQUALS(precision_),
        EQUALS(response_size_)
    {
        std::back_insert_iterator<ArrayVector<Label_t> >
                        iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
    }
    #undef EQUALS

    #define EQUALS(field) (this->field = rhs.field);
    ProblemSpec & operator=(ProblemSpec const & rhs)
    {
        EQUALS(column_count_);
        EQUALS(class_count_);
        EQUALS(row_count_);
        EQUALS(actual_mtry_);
        EQUALS(actual_msample_);
        EQUALS(problem_type_);
        EQUALS(used_);
        EQUALS(is_weighted_);
        EQUALS(precision_);
        EQUALS(response_size_)
        class_weights_.clear();
        std::back_insert_iterator<ArrayVector<double> >
                        iter2(class_weights_);
        std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
        classes.clear();
        std::back_insert_iterator<ArrayVector<Label_t> >
                        iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
        return *this;
    }

    template<class T>
    ProblemSpec<Label_t> & operator=(ProblemSpec<T> const & rhs)
    {
        EQUALS(column_count_);
        EQUALS(class_count_);
        EQUALS(row_count_);
        EQUALS(actual_mtry_);
        EQUALS(actual_msample_);
        EQUALS(problem_type_);
        EQUALS(used_);
        EQUALS(is_weighted_);
        EQUALS(precision_);
        EQUALS(response_size_)
        class_weights_.clear();
        std::back_insert_iterator<ArrayVector<double> >
                        iter2(class_weights_);
        std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
        classes.clear();
        std::back_insert_iterator<ArrayVector<Label_t> >
                        iter(classes);
        std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
        return *this;
    }
    #undef EQUALS

    template<class T>
    bool operator==(ProblemSpec<T> const & rhs) const
    {
        bool result = true;
        #define COMPARE(field) result = result && (this->field == rhs.field);
        COMPARE(column_count_);
        COMPARE(class_count_);
        COMPARE(row_count_);
        COMPARE(actual_mtry_);
        COMPARE(actual_msample_);
        COMPARE(problem_type_);
        COMPARE(is_weighted_);
        COMPARE(precision_);
        COMPARE(used_);
        COMPARE(class_weights_);
        COMPARE(classes);
        COMPARE(response_size_)
        #undef COMPARE
        return result;
    }

    bool operator!=(ProblemSpec const & rhs) const
    {
        return !(*this == rhs);
    }


    size_t serialized_size() const
    {
        // 10 scalar fields, plus class_count_ labels, plus (if weighted)
        // class_count_ class weights
        return 10 + class_count_ * int(is_weighted_ + 1);
    }


    template<class Iter>
    void unserialize(Iter const & begin, Iter const & end)
    {
        Iter iter = begin;
        vigra_precondition(end - begin >= 10,
                           "ProblemSpec::unserialize(): "
                           "wrong number of parameters");
        #define PULL(item_, type_) item_ = type_(*iter); ++iter;
        PULL(column_count_, int);
        PULL(class_count_, int);

        vigra_precondition(end - begin >= 10 + class_count_,
                           "ProblemSpec::unserialize(): 1");
        PULL(row_count_, int);
        PULL(actual_mtry_, int);
        PULL(actual_msample_, int);
        PULL(problem_type_, Problem_t);
        PULL(is_weighted_, int);
        PULL(used_, int);
        PULL(precision_, double);
        PULL(response_size_, int);
        if(is_weighted_)
        {
            vigra_precondition(end - begin == 10 + 2*class_count_,
                               "ProblemSpec::unserialize(): 2");
            class_weights_.insert(class_weights_.end(),
                                  iter,
                                  iter + class_count_);
            iter += class_count_;
        }
        classes.insert(classes.end(), iter, end);
        #undef PULL
    }


    template<class Iter>
    void serialize(Iter const & begin, Iter const & end) const
    {
        Iter iter = begin;
        vigra_precondition(static_cast<size_t>(end - begin) == serialized_size(),
                           "ProblemSpec::serialize(): "
                           "wrong number of parameters");
        #define PUSH(item_) *iter = double(item_); ++iter;
        PUSH(column_count_);
        PUSH(class_count_);
        PUSH(row_count_);
        PUSH(actual_mtry_);
        PUSH(actual_msample_);
        PUSH(problem_type_);
        PUSH(is_weighted_);
        PUSH(used_);
        PUSH(precision_);
        PUSH(response_size_);
        if(is_weighted_)
        {
            std::copy(class_weights_.begin(),
                      class_weights_.end(),
                      iter);
            iter += class_count_;
        }
        std::copy(classes.begin(),
                  classes.end(),
                  iter);
        #undef PUSH
    }
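
    /* Example (editor's sketch, not part of the original header): the
     * serialized layout of a ProblemSpec, as produced by serialize() above.
     *\code
     *   // [0 .. 10)                  10 scalar fields, each cast to double
     *   // [10 .. 10+class_count_)    class weights (only if is_weighted_)
     *   // [.. serialized_size())     the class labels themselves
     *
     *   ProblemSpec<> spec;                    // LabelType = double
     *   double labels[] = {0.0, 1.0};
     *   spec.classes_(labels, labels + 2);
     *   ArrayVector<double> buffer(spec.serialized_size());
     *   spec.serialize(buffer.begin(), buffer.end());
     *\endcode
     */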

    void make_from_map(map_type & in) // could be const if operator[] were replaced by find()
    {
        typedef MultiArrayShape<2>::type Shp;
        #define PULL(item_, type_) item_ = type_(in[#item_][0]);
        PULL(column_count_, int);
        PULL(class_count_, int);
        PULL(row_count_, int);
        PULL(actual_mtry_, int);
        PULL(actual_msample_, int);
        PULL(problem_type_, (Problem_t)int);
        PULL(is_weighted_, int);
        PULL(used_, int);
        PULL(precision_, double);
        PULL(response_size_, int);
        class_weights_ = in["class_weights_"];
        #undef PULL
    }
    void make_map(map_type & in) const
    {
        typedef MultiArrayShape<2>::type Shp;
        #define PUSH(item_) in[#item_] = double_array(1, double(item_));
        PUSH(column_count_);
        PUSH(class_count_);
        PUSH(row_count_);
        PUSH(actual_mtry_);
        PUSH(actual_msample_);
        PUSH(problem_type_);
        PUSH(is_weighted_);
        PUSH(used_);
        PUSH(precision_);
        PUSH(response_size_);
        in["class_weights_"] = class_weights_;
        #undef PUSH
    }

    /**\brief set default values (i.e. mark all fields as not yet set)
     */
    ProblemSpec()
    :   column_count_(0),
        class_count_(0),
        row_count_(0),
        actual_mtry_(0),
        actual_msample_(0),
        problem_type_(CHECKLATER),
        used_(false),
        is_weighted_(false),
        precision_(0.0),
        response_size_(1)
    {}


    ProblemSpec & column_count(int in)
    {
        column_count_ = in;
        return *this;
    }

    /**\brief supply the class labels explicitly
     *
     * The preprocessor will then not need to compute the labels itself.
     */
    template<class C_Iter>
    ProblemSpec & classes_(C_Iter begin, C_Iter end)
    {
        int size = end - begin;
        for(int k = 0; k < size; ++k, ++begin)
            classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
        class_count_ = size;
        return *this;
    }

    /** \brief supply class weights
     *
     * This is the only case where you would really have to
     * create a ProblemSpec object.
     */
    template<class W_Iter>
    ProblemSpec & class_weights(W_Iter begin, W_Iter end)
    {
        class_weights_.insert(class_weights_.end(), begin, end);
        is_weighted_ = true;
        return *this;
    }
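
    /* Example (editor's sketch, not part of the original header): a
     * ProblemSpec with explicit labels and per-class weights, built in
     * the same chained style as RandomForestOptions.
     *\code
     *   double labels[]  = {0.0, 1.0};
     *   double weights[] = {1.0, 4.0};   // e.g. up-weight a rare class
     *   ProblemSpec<> spec = ProblemSpec<>()
     *                            .classes_(labels, labels + 2)
     *                            .class_weights(weights, weights + 2);
     *\endcode
     */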


    void clear()
    {
        used_ = false;
        classes.clear();
        class_weights_.clear();
        column_count_ = 0;
        class_count_ = 0;
        actual_mtry_ = 0;
        actual_msample_ = 0;
        problem_type_ = CHECKLATER;
        is_weighted_ = false;
        precision_   = 0.0;
        response_size_ = 0;
    }

    bool used() const
    {
        return used_ != 0;
    }
};


//@}



/**\brief Standard early stopping criterion
 *
 * Stop if region.size() < min_split_node_size_;
 */
class EarlyStoppStd
{
    public:
    int min_split_node_size_;

    template<class Opt>
    EarlyStoppStd(Opt opt)
    :   min_split_node_size_(opt.min_split_node_size_)
    {}

    template<class T>
    void set_external_parameters(ProblemSpec<T> const &, int /* tree_count */ = 0, bool /* is_weighted_ */ = false)
    {}

    template<class Region>
    bool operator()(Region & region)
    {
        return region.size() < min_split_node_size_;
    }

    template<class WeightIter, class T, class C>
    bool after_prediction(WeightIter, int /* k */, MultiArrayView<2, T, C> /* prob */, double /* totalCt */)
    {
        return false;
    }
};
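
/* Example (editor's sketch, not part of the original header): EarlyStoppStd
 * is constructed from any object with a min_split_node_size_ member, such
 * as RandomForestOptions, and queried per region during tree construction.
 *\code
 *   EarlyStoppStd stopper(RandomForestOptions().min_split_node_size(10));
 *   // during tree construction (region is hypothetical):
 *   // if(stopper(region))  -> fewer than 10 samples left, make a leaf
 *\endcode
 */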


} // namespace vigra

#endif //VIGRA_RF_COMMON_HXX
