[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

rf_common.hxx VIGRA

/************************************************************************/
/*                                                                      */
/*        Copyright 2008-2009 by Ullrich Koethe and Rahul Nair          */
/*                                                                      */
/*    This file is part of the VIGRA computer vision library.           */
/*    The VIGRA Website is                                              */
/*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
/*    Please direct questions, bug reports, and contributions to        */
/*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
/*        vigra@informatik.uni-hamburg.de                               */
/*                                                                      */
/*    Permission is hereby granted, free of charge, to any person       */
/*    obtaining a copy of this software and associated documentation    */
/*    files (the "Software"), to deal in the Software without           */
/*    restriction, including without limitation the rights to use,      */
/*    copy, modify, merge, publish, distribute, sublicense, and/or      */
/*    sell copies of the Software, and to permit persons to whom the    */
/*    Software is furnished to do so, subject to the following          */
/*    conditions:                                                       */
/*                                                                      */
/*    The above copyright notice and this permission notice shall be    */
/*    included in all copies or substantial portions of the             */
/*    Software.                                                         */
/*                                                                      */
/*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
/*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
/*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
/*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
/*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
/*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
/*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
/*    OTHER DEALINGS IN THE SOFTWARE.                                   */
/*                                                                      */
/************************************************************************/


#ifndef VIGRA_RF_COMMON_HXX
#define VIGRA_RF_COMMON_HXX

namespace vigra
{

/** \brief tag selecting the classification variant of the random forest
 *
 * Empty marker type: carries no data, only dispatches the learning
 * algorithm at compile time.
 */
struct ClassificationTag
{};

/** \brief tag selecting the regression variant of the random forest
 *
 * Empty marker type, counterpart of ClassificationTag.
 */
struct RegressionTag
{};
49 
namespace detail
{
    // forward declaration so rf_default() below can be declared before
    // RF_DEFAULT is fully defined
    class RF_DEFAULT;
}
// factory returning the singleton default tag; declared inline here so the
// later out-of-class definition does not violate the ODR when this header
// is included in several translation units
inline detail::RF_DEFAULT& rf_default();
55 namespace detail
56 {
57 
58 /* \brief singleton default tag class -
59  *
60  * use the rf_default() factory function to use the tag.
61  * \sa RandomForest<>::learn();
62  */
63 class RF_DEFAULT
64 {
65  private:
66  RF_DEFAULT()
67  {}
68  public:
69  friend RF_DEFAULT& ::vigra::rf_default();
70 
71  /** ok workaround for automatic choice of the decisiontree
72  * stackentry.
73  */
74 };
75 
76 /* \brief chooses between default type and type supplied
77  *
78  * This is an internal class and you shouldn't really care about it.
79  * Just pass on used in RandomForest.learn()
80  * Usage:
81  *\code
82  * // example: use container type supplied by user or ArrayVector if
83  * // rf_default() was specified as argument;
84  * template<class Container_t>
85  * void do_some_foo(Container_t in)
86  * {
87  * typedef ArrayVector<int> Default_Container_t;
88  * Default_Container_t default_value;
89  * Value_Chooser<Container_t, Default_Container_t>
90  * choose(in, default_value);
91  *
92  * // if the user didn't care and the in was of type
93  * // RF_DEFAULT then default_value is used.
94  * do_some_more_foo(choose.value());
95  * }
96  * Value_Chooser choose_val<Type, Default_Type>
97  *\endcode
98  */
99 template<class T, class C>
100 class Value_Chooser
101 {
102 public:
103  typedef T type;
104  static T & choose(T & t, C &)
105  {
106  return t;
107  }
108 };
109 
110 template<class C>
111 class Value_Chooser<detail::RF_DEFAULT, C>
112 {
113 public:
114  typedef C type;
115 
116  static C & choose(detail::RF_DEFAULT &, C & c)
117  {
118  return c;
119  }
120 };



} //namespace detail


128 /**\brief factory function to return a RF_DEFAULT tag
129  * \sa RandomForest<>::learn()
130  */
131 detail::RF_DEFAULT& rf_default()
132 {
133  static detail::RF_DEFAULT result;
134  return result;
135 }
136 
/** tags used with the RandomForestOptions class
 *
 * NOTE: the implicit enumerator values (0..8) are relied upon by the
 * serialize()/unserialize() round trip below — do not reorder.
 * \sa RF_Traits::Option_t
 */
enum RF_OptionTag
{
    RF_EQUAL,           // stratification: equal number of samples per class
    RF_PROPORTIONAL,    // sampling/stratification proportional to class frequency
    RF_EXTERNAL,        // strata weights supplied externally via ProblemSpec
    RF_NONE,            // no stratification
    RF_FUNCTION,        // value computed by a user-supplied function pointer
    RF_LOG,             // mtry = log(column count)
    RF_SQRT,            // mtry = sqrt(column count)
    RF_CONST,           // value given as an explicit constant
    RF_ALL              // use all features
};
149 
150 
/** \addtogroup MachineLearning
**/
//@{

155 /**\brief Options object for the random forest
156  *
157  * usage:
158  * RandomForestOptions a = RandomForestOptions()
159  * .param1(value1)
160  * .param2(value2)
161  * ...
162  *
163  * This class only contains options/parameters that are not problem
164  * dependent. The ProblemSpec class contains methods to set class weights
165  * if necessary.
166  *
167  * Note that the return value of all methods is *this which makes
168  * concatenating of options as above possible.
169  */
171 {
172  public:
173  /**\name sampling options*/
174  /*\{*/
175  // look at the member access functions for documentation
176  double training_set_proportion_;
177  int training_set_size_;
178  int (*training_set_func_)(int);
180  training_set_calc_switch_;
181 
182  bool sample_with_replacement_;
184  stratification_method_;
185 
186 
187  /**\name general random forest options
188  *
189  * these usually will be used by most split functors and
190  * stopping predicates
191  */
192  /*\{*/
193  RF_OptionTag mtry_switch_;
194  int mtry_;
195  int (*mtry_func_)(int) ;
196 
197  bool predict_weighted_;
198  int tree_count_;
199  int min_split_node_size_;
200  bool prepare_online_learning_;
201  /*\}*/
202 
204  typedef std::map<std::string, double_array> map_type;
205 
206  int serialized_size() const
207  {
208  return 12;
209  }
210 
211 
212  bool operator==(RandomForestOptions & rhs) const
213  {
214  bool result = true;
215  #define COMPARE(field) result = result && (this->field == rhs.field);
216  COMPARE(training_set_proportion_);
217  COMPARE(training_set_size_);
218  COMPARE(training_set_calc_switch_);
219  COMPARE(sample_with_replacement_);
220  COMPARE(stratification_method_);
221  COMPARE(mtry_switch_);
222  COMPARE(mtry_);
223  COMPARE(tree_count_);
224  COMPARE(min_split_node_size_);
225  COMPARE(predict_weighted_);
226  #undef COMPARE
227 
228  return result;
229  }
230  bool operator!=(RandomForestOptions & rhs_) const
231  {
232  return !(*this == rhs_);
233  }
234  template<class Iter>
235  void unserialize(Iter const & begin, Iter const & end)
236  {
237  Iter iter = begin;
238  vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
239  "RandomForestOptions::unserialize():"
240  "wrong number of parameters");
241  #define PULL(item_, type_) item_ = type_(*iter); ++iter;
242  PULL(training_set_proportion_, double);
243  PULL(training_set_size_, int);
244  ++iter; //PULL(training_set_func_, double);
245  PULL(training_set_calc_switch_, (RF_OptionTag)int);
246  PULL(sample_with_replacement_, 0 != );
247  PULL(stratification_method_, (RF_OptionTag)int);
248  PULL(mtry_switch_, (RF_OptionTag)int);
249  PULL(mtry_, int);
250  ++iter; //PULL(mtry_func_, double);
251  PULL(tree_count_, int);
252  PULL(min_split_node_size_, int);
253  PULL(predict_weighted_, 0 !=);
254  #undef PULL
255  }
256  template<class Iter>
257  void serialize(Iter const & begin, Iter const & end) const
258  {
259  Iter iter = begin;
260  vigra_precondition(static_cast<int>(end - begin) == serialized_size(),
261  "RandomForestOptions::serialize():"
262  "wrong number of parameters");
263  #define PUSH(item_) *iter = double(item_); ++iter;
264  PUSH(training_set_proportion_);
265  PUSH(training_set_size_);
266  if(training_set_func_ != 0)
267  {
268  PUSH(1);
269  }
270  else
271  {
272  PUSH(0);
273  }
274  PUSH(training_set_calc_switch_);
275  PUSH(sample_with_replacement_);
276  PUSH(stratification_method_);
277  PUSH(mtry_switch_);
278  PUSH(mtry_);
279  if(mtry_func_ != 0)
280  {
281  PUSH(1);
282  }
283  else
284  {
285  PUSH(0);
286  }
287  PUSH(tree_count_);
288  PUSH(min_split_node_size_);
289  PUSH(predict_weighted_);
290  #undef PUSH
291  }
292 
293  void make_from_map(map_type & in) // -> const: .operator[] -> .find
294  {
295  #define PULL(item_, type_) item_ = type_(in[#item_][0]);
296  #define PULLBOOL(item_, type_) item_ = type_(in[#item_][0] > 0);
297  PULL(training_set_proportion_,double);
298  PULL(training_set_size_, int);
299  PULL(mtry_, int);
300  PULL(tree_count_, int);
301  PULL(min_split_node_size_, int);
302  PULLBOOL(sample_with_replacement_, bool);
303  PULLBOOL(prepare_online_learning_, bool);
304  PULLBOOL(predict_weighted_, bool);
305 
306  PULL(training_set_calc_switch_, (RF_OptionTag)(int));
307 
308  PULL(stratification_method_, (RF_OptionTag)(int));
309  PULL(mtry_switch_, (RF_OptionTag)(int));
310 
311  /*don't pull*/
312  //PULL(mtry_func_!=0, int);
313  //PULL(training_set_func,int);
314  #undef PULL
315  #undef PULLBOOL
316  }
317  void make_map(map_type & in) const
318  {
319  #define PUSH(item_, type_) in[#item_] = double_array(1, double(item_));
320  #define PUSHFUNC(item_, type_) in[#item_] = double_array(1, double(item_!=0));
321  PUSH(training_set_proportion_,double);
322  PUSH(training_set_size_, int);
323  PUSH(mtry_, int);
324  PUSH(tree_count_, int);
325  PUSH(min_split_node_size_, int);
326  PUSH(sample_with_replacement_, bool);
327  PUSH(prepare_online_learning_, bool);
328  PUSH(predict_weighted_, bool);
329 
330  PUSH(training_set_calc_switch_, RF_OptionTag);
331  PUSH(stratification_method_, RF_OptionTag);
332  PUSH(mtry_switch_, RF_OptionTag);
333 
334  PUSHFUNC(mtry_func_, int);
335  PUSHFUNC(training_set_func_,int);
336  #undef PUSH
337  #undef PUSHFUNC
338  }
339 
340 
341  /**\brief create a RandomForestOptions object with default initialisation.
342  *
343  * look at the other member functions for more information on default
344  * values
345  */
347  :
348  training_set_proportion_(1.0),
349  training_set_size_(0),
350  training_set_func_(0),
351  training_set_calc_switch_(RF_PROPORTIONAL),
352  sample_with_replacement_(true),
353  stratification_method_(RF_NONE),
354  mtry_switch_(RF_SQRT),
355  mtry_(0),
356  mtry_func_(0),
357  predict_weighted_(false),
358  tree_count_(256),
359  min_split_node_size_(1),
360  prepare_online_learning_(false)
361  {}
362 
363  /**\brief specify stratification strategy
364  *
365  * default: RF_NONE
366  * possible values: RF_EQUAL, RF_PROPORTIONAL,
367  * RF_EXTERNAL, RF_NONE
368  * RF_EQUAL: get equal amount of samples per class.
369  * RF_PROPORTIONAL: sample proportional to fraction of class samples
370  * in population
371  * RF_EXTERNAL: strata_weights_ field of the ProblemSpec_t object
372  * has been set externally. (defunct)
373  */
375  {
376  vigra_precondition(in == RF_EQUAL ||
377  in == RF_PROPORTIONAL ||
378  in == RF_EXTERNAL ||
379  in == RF_NONE,
380  "RandomForestOptions::use_stratification()"
381  "input must be RF_EQUAL, RF_PROPORTIONAL,"
382  "RF_EXTERNAL or RF_NONE");
383  stratification_method_ = in;
384  return *this;
385  }
386 
387  RandomForestOptions & prepare_online_learning(bool in)
388  {
389  prepare_online_learning_=in;
390  return *this;
391  }
392 
393  /**\brief sample from training population with or without replacement?
394  *
395  * <br> Default: true
396  */
398  {
399  sample_with_replacement_ = in;
400  return *this;
401  }
402 
403  /**\brief specify the fraction of the total number of samples
404  * used per tree for learning.
405  *
406  * This value should be in [0.0 1.0] if sampling without
407  * replacement has been specified.
408  *
409  * <br> default : 1.0
410  */
412  {
413  training_set_proportion_ = in;
414  training_set_calc_switch_ = RF_PROPORTIONAL;
415  return *this;
416  }
417 
418  /**\brief directly specify the number of samples per tree
419  */
421  {
422  training_set_size_ = in;
423  training_set_calc_switch_ = RF_CONST;
424  return *this;
425  }
426 
427  /**\brief use external function to calculate the number of samples each
428  * tree should be learnt with.
429  *
430  * \param in function pointer that takes the number of rows in the
431  * learning data and outputs the number samples per tree.
432  */
434  {
435  training_set_func_ = in;
436  training_set_calc_switch_ = RF_FUNCTION;
437  return *this;
438  }
439 
440  /**\brief weight each tree with number of samples in that node
441  */
443  {
444  predict_weighted_ = true;
445  return *this;
446  }
447 
448  /**\brief use built in mapping to calculate mtry
449  *
450  * Use one of the built in mappings to calculate mtry from the number
451  * of columns in the input feature data.
452  * \param in possible values: RF_LOG, RF_SQRT or RF_ALL
453  * <br> default: RF_SQRT.
454  */
456  {
457  vigra_precondition(in == RF_LOG ||
458  in == RF_SQRT||
459  in == RF_ALL,
460  "RandomForestOptions()::features_per_node():"
461  "input must be of type RF_LOG or RF_SQRT");
462  mtry_switch_ = in;
463  return *this;
464  }
465 
466  /**\brief Set mtry to a constant value
467  *
468  * mtry is the number of columns/variates/variables randomly chosen
469  * to select the best split from.
470  *
471  */
473  {
474  mtry_ = in;
475  mtry_switch_ = RF_CONST;
476  return *this;
477  }
478 
479  /**\brief use a external function to calculate mtry
480  *
481  * \param in function pointer that takes int (number of columns
482  * of the and outputs int (mtry)
483  */
485  {
486  mtry_func_ = in;
487  mtry_switch_ = RF_FUNCTION;
488  return *this;
489  }
490 
491  /** How many trees to create?
492  *
493  * <br> Default: 255.
494  */
496  {
497  tree_count_ = in;
498  return *this;
499  }
500 
501  /**\brief Number of examples required for a node to be split.
502  *
503  * When the number of examples in a node is below this number,
504  * the node is not split even if class separation is not yet perfect.
505  * Instead, the node returns the proportion of each class
506  * (among the remaining examples) during the prediction phase.
507  * <br> Default: 1 (complete growing)
508  */
510  {
511  min_split_node_size_ = in;
512  return *this;
513  }
514 };
515 
516 
/** \brief problem types
 *
 * NOTE: enumerator values (0..2) are written out by
 * ProblemSpec::serialize() — do not reorder.
 */
enum Problem_t{REGRESSION, CLASSIFICATION, CHECKLATER};
520 
521 
522 /** \brief problem specification class for the random forest.
523  *
524  * This class contains all the problem specific parameters the random
525  * forest needs for learning. Specification of an instance of this class
526  * is optional as all necessary fields will be computed prior to learning
527  * if not specified.
528  *
529  * if needed usage is similar to that of RandomForestOptions
530  */
531 
532 template<class LabelType = double>
534 {
535 
536 
537 public:
538 
539  /** \brief problem class
540  */
541 
542  typedef LabelType Label_t;
543  ArrayVector<Label_t> classes;
545  typedef std::map<std::string, double_array> map_type;
546 
547  int column_count_; // number of features
548  int class_count_; // number of classes
549  int row_count_; // number of samples
550 
551  int actual_mtry_; // mtry used in training
552  int actual_msample_; // number if in-bag samples per tree
553 
554  Problem_t problem_type_; // classification or regression
555 
556  int used_; // this ProblemSpec is valid
557  ArrayVector<double> class_weights_; // if classes have different importance
558  int is_weighted_; // class_weights_ are used
559  double precision_; // termination criterion for regression loss
560  int response_size_;
561 
562  template<class T>
563  void to_classlabel(int index, T & out) const
564  {
565  out = T(classes[index]);
566  }
567  template<class T>
568  int to_classIndex(T index) const
569  {
570  return std::find(classes.begin(), classes.end(), index) - classes.begin();
571  }
572 
573  #define EQUALS(field) field(rhs.field)
574  ProblemSpec(ProblemSpec const & rhs)
575  :
576  EQUALS(column_count_),
577  EQUALS(class_count_),
578  EQUALS(row_count_),
579  EQUALS(actual_mtry_),
580  EQUALS(actual_msample_),
581  EQUALS(problem_type_),
582  EQUALS(used_),
583  EQUALS(class_weights_),
584  EQUALS(is_weighted_),
585  EQUALS(precision_),
586  EQUALS(response_size_)
587  {
588  std::back_insert_iterator<ArrayVector<Label_t> >
589  iter(classes);
590  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
591  }
592  #undef EQUALS
593  #define EQUALS(field) field(rhs.field)
594  template<class T>
595  ProblemSpec(ProblemSpec<T> const & rhs)
596  :
597  EQUALS(column_count_),
598  EQUALS(class_count_),
599  EQUALS(row_count_),
600  EQUALS(actual_mtry_),
601  EQUALS(actual_msample_),
602  EQUALS(problem_type_),
603  EQUALS(used_),
604  EQUALS(class_weights_),
605  EQUALS(is_weighted_),
606  EQUALS(precision_),
607  EQUALS(response_size_)
608  {
609  std::back_insert_iterator<ArrayVector<Label_t> >
610  iter(classes);
611  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
612  }
613  #undef EQUALS
614 
615  #define EQUALS(field) (this->field = rhs.field);
616  ProblemSpec & operator=(ProblemSpec const & rhs)
617  {
618  EQUALS(column_count_);
619  EQUALS(class_count_);
620  EQUALS(row_count_);
621  EQUALS(actual_mtry_);
622  EQUALS(actual_msample_);
623  EQUALS(problem_type_);
624  EQUALS(used_);
625  EQUALS(is_weighted_);
626  EQUALS(precision_);
627  EQUALS(response_size_)
628  class_weights_.clear();
629  std::back_insert_iterator<ArrayVector<double> >
630  iter2(class_weights_);
631  std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
632  classes.clear();
633  std::back_insert_iterator<ArrayVector<Label_t> >
634  iter(classes);
635  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
636  return *this;
637  }
638 
639  template<class T>
640  ProblemSpec<Label_t> & operator=(ProblemSpec<T> const & rhs)
641  {
642  EQUALS(column_count_);
643  EQUALS(class_count_);
644  EQUALS(row_count_);
645  EQUALS(actual_mtry_);
646  EQUALS(actual_msample_);
647  EQUALS(problem_type_);
648  EQUALS(used_);
649  EQUALS(is_weighted_);
650  EQUALS(precision_);
651  EQUALS(response_size_)
652  class_weights_.clear();
653  std::back_insert_iterator<ArrayVector<double> >
654  iter2(class_weights_);
655  std::copy(rhs.class_weights_.begin(), rhs.class_weights_.end(), iter2);
656  classes.clear();
657  std::back_insert_iterator<ArrayVector<Label_t> >
658  iter(classes);
659  std::copy(rhs.classes.begin(), rhs.classes.end(), iter);
660  return *this;
661  }
662  #undef EQUALS
663 
664  template<class T>
665  bool operator==(ProblemSpec<T> const & rhs)
666  {
667  bool result = true;
668  #define COMPARE(field) result = result && (this->field == rhs.field);
669  COMPARE(column_count_);
670  COMPARE(class_count_);
671  COMPARE(row_count_);
672  COMPARE(actual_mtry_);
673  COMPARE(actual_msample_);
674  COMPARE(problem_type_);
675  COMPARE(is_weighted_);
676  COMPARE(precision_);
677  COMPARE(used_);
678  COMPARE(class_weights_);
679  COMPARE(classes);
680  COMPARE(response_size_)
681  #undef COMPARE
682  return result;
683  }
684 
685  bool operator!=(ProblemSpec & rhs)
686  {
687  return !(*this == rhs);
688  }
689 
690 
691  size_t serialized_size() const
692  {
693  return 10 + class_count_ *int(is_weighted_+1);
694  }
695 
696 
697  template<class Iter>
698  void unserialize(Iter const & begin, Iter const & end)
699  {
700  Iter iter = begin;
701  vigra_precondition(end - begin >= 10,
702  "ProblemSpec::unserialize():"
703  "wrong number of parameters");
704  #define PULL(item_, type_) item_ = type_(*iter); ++iter;
705  PULL(column_count_,int);
706  PULL(class_count_, int);
707 
708  vigra_precondition(end - begin >= 10 + class_count_,
709  "ProblemSpec::unserialize(): 1");
710  PULL(row_count_, int);
711  PULL(actual_mtry_,int);
712  PULL(actual_msample_, int);
713  PULL(problem_type_, Problem_t);
714  PULL(is_weighted_, int);
715  PULL(used_, int);
716  PULL(precision_, double);
717  PULL(response_size_, int);
718  if(is_weighted_)
719  {
720  vigra_precondition(end - begin == 10 + 2*class_count_,
721  "ProblemSpec::unserialize(): 2");
722  class_weights_.insert(class_weights_.end(),
723  iter,
724  iter + class_count_);
725  iter += class_count_;
726  }
727  classes.insert(classes.end(), iter, end);
728  #undef PULL
729  }
730 
731 
732  template<class Iter>
733  void serialize(Iter const & begin, Iter const & end) const
734  {
735  Iter iter = begin;
736  vigra_precondition(end - begin == serialized_size(),
737  "RandomForestOptions::serialize():"
738  "wrong number of parameters");
739  #define PUSH(item_) *iter = double(item_); ++iter;
740  PUSH(column_count_);
741  PUSH(class_count_)
742  PUSH(row_count_);
743  PUSH(actual_mtry_);
744  PUSH(actual_msample_);
745  PUSH(problem_type_);
746  PUSH(is_weighted_);
747  PUSH(used_);
748  PUSH(precision_);
749  PUSH(response_size_);
750  if(is_weighted_)
751  {
752  std::copy(class_weights_.begin(),
753  class_weights_.end(),
754  iter);
755  iter += class_count_;
756  }
757  std::copy(classes.begin(),
758  classes.end(),
759  iter);
760  #undef PUSH
761  }
762 
763  void make_from_map(map_type & in) // -> const: .operator[] -> .find
764  {
765  #define PULL(item_, type_) item_ = type_(in[#item_][0]);
766  PULL(column_count_,int);
767  PULL(class_count_, int);
768  PULL(row_count_, int);
769  PULL(actual_mtry_,int);
770  PULL(actual_msample_, int);
771  PULL(problem_type_, (Problem_t)int);
772  PULL(is_weighted_, int);
773  PULL(used_, int);
774  PULL(precision_, double);
775  PULL(response_size_, int);
776  class_weights_ = in["class_weights_"];
777  #undef PULL
778  }
779  void make_map(map_type & in) const
780  {
781  #define PUSH(item_) in[#item_] = double_array(1, double(item_));
782  PUSH(column_count_);
783  PUSH(class_count_)
784  PUSH(row_count_);
785  PUSH(actual_mtry_);
786  PUSH(actual_msample_);
787  PUSH(problem_type_);
788  PUSH(is_weighted_);
789  PUSH(used_);
790  PUSH(precision_);
791  PUSH(response_size_);
792  in["class_weights_"] = class_weights_;
793  #undef PUSH
794  }
795 
796  /**\brief set default values (-> values not set)
797  */
799  : column_count_(0),
800  class_count_(0),
801  row_count_(0),
802  actual_mtry_(0),
803  actual_msample_(0),
804  problem_type_(CHECKLATER),
805  used_(false),
806  is_weighted_(false),
807  precision_(0.0),
808  response_size_(1)
809  {}
810 
811 
812  ProblemSpec & column_count(int in)
813  {
814  column_count_ = in;
815  return *this;
816  }
817 
818  /**\brief supply with class labels -
819  *
820  * the preprocessor will not calculate the labels needed in this case.
821  */
822  template<class C_Iter>
823  ProblemSpec & classes_(C_Iter begin, C_Iter end)
824  {
825  classes.clear();
826  int size = end-begin;
827  for(int k=0; k<size; ++k, ++begin)
828  classes.push_back(detail::RequiresExplicitCast<LabelType>::cast(*begin));
829  class_count_ = size;
830  return *this;
831  }
832 
833  /** \brief supply with class weights -
834  *
835  * this is the only case where you would really have to
836  * create a ProblemSpec object.
837  */
838  template<class W_Iter>
839  ProblemSpec & class_weights(W_Iter begin, W_Iter end)
840  {
841  class_weights_.clear();
842  class_weights_.insert(class_weights_.end(), begin, end);
843  is_weighted_ = true;
844  return *this;
845  }
846 
847 
848 
849  void clear()
850  {
851  used_ = false;
852  classes.clear();
853  class_weights_.clear();
854  column_count_ = 0 ;
855  class_count_ = 0;
856  actual_mtry_ = 0;
857  actual_msample_ = 0;
858  problem_type_ = CHECKLATER;
859  is_weighted_ = false;
860  precision_ = 0.0;
861  response_size_ = 0;
862 
863  }
864 
865  bool used() const
866  {
867  return used_ != 0;
868  }
869 };


//@}



876 /**\brief Standard early stopping criterion
877  *
878  * Stop if region.size() < min_split_node_size_;
879  */
881 {
882  public:
883  int min_split_node_size_;
884 
885  template<class Opt>
886  EarlyStoppStd(Opt opt)
887  : min_split_node_size_(opt.min_split_node_size_)
888  {}
889 
890  template<class T>
891  void set_external_parameters(ProblemSpec<T>const &, int /* tree_count */ = 0, bool /* is_weighted_ */ = false)
892  {}
893 
894  template<class Region>
895  bool operator()(Region& region)
896  {
897  return region.size() < min_split_node_size_;
898  }
899 
900  template<class WeightIter, class T, class C>
901  bool after_prediction(WeightIter, int /* k */, MultiArrayView<2, T, C> /* prob */, double /* totalCt */)
902  {
903  return false;
904  }
905 };


} // namespace vigra

#endif //VIGRA_RF_COMMON_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.10.0 (Thu Jan 8 2015)