[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_earlystopping.hxx VIGRA

00001 #ifndef RF_EARLY_STOPPING_P_HXX
00002 #define RF_EARLY_STOPPING_P_HXX
00003 #include <cmath>
00004 #include "rf_common.hxx"
00005 
00006 namespace vigra
00007 {
00008 
00009 #if 0    
00010 namespace es_detail
00011 {
00012     template<class T>
00013     T power(T const & in, int n)
00014     {
00015         T result = NumericTraits<T>::one();
00016         for(int ii = 0; ii < n ;++ii)
00017             result *= in;
00018         return result;
00019     }
00020 }
00021 #endif
00022 
00023 /**Base class from which all EarlyStopping Functors derive.
00024  */
00025 class StopBase
00026 {
00027 protected:
00028     ProblemSpec<> ext_param_;
00029     int tree_count_ ;
00030     bool is_weighted_;
00031 
00032 public:
00033     template<class T>
00034     void set_external_parameters(ProblemSpec<T> const  &prob, int tree_count = 0, bool is_weighted = false)
00035     {
00036         ext_param_ = prob; 
00037         is_weighted_ = is_weighted;
00038         tree_count_ = tree_count;
00039     }
00040 
00041 #ifdef DOXYGEN
00042         /** called after the prediction of a tree was added to the total prediction
00043          * \param weightIter Iterator to the weights delivered by current tree.
00044          * \param k          after kth tree
00045          * \param prob       Total probability array
00046          * \param totalCt    sum of probability array. 
00047          */
00048     template<class WeightIter, class T, class C>
00049     bool after_prediction(WeightIter weightIter, int k, MultiArrayView<2, T, C> const &  prob , double totalCt)
00050 #else
00051     template<class WeightIter, class T, class C>
00052     bool after_prediction(WeightIter,  int /* k */, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */)
00053     {return false;}
00054 #endif //DOXYGEN
00055 };
00056 
00057 
00058 /**Stop predicting after a set number of trees
00059  */
00060 class StopAfterTree : public StopBase
00061 {
00062 public:
00063     double max_tree_p;
00064     int max_tree_;
00065     typedef StopBase SB;
00066     
00067     ArrayVector<double> depths;
00068     
00069     /** Constructor
00070      * \param max_tree number of trees to be used for prediction
00071      */
00072     StopAfterTree(double max_tree)
00073     :
00074         max_tree_p(max_tree)
00075     {}
00076 
00077     template<class T>
00078     void set_external_parameters(ProblemSpec<T> const  &prob, int tree_count = 0, bool is_weighted = false)
00079     {
00080         max_tree_ = ceil(max_tree_p * tree_count);
00081         SB::set_external_parameters(prob, tree_count, is_weighted);
00082     }
00083 
00084     template<class WeightIter, class T, class C>
00085     bool after_prediction(WeightIter,  int k, MultiArrayView<2, T, C> const & /* prob */, double /* totalCt */)
00086     {
00087         if(k == SB::tree_count_ -1)
00088         {
00089                 depths.push_back(double(k+1)/double(SB::tree_count_));
00090                 return false;
00091         }
00092         if(k < max_tree_)
00093            return false;
00094         depths.push_back(double(k+1)/double(SB::tree_count_));
00095         return true;  
00096     }
00097 };
00098 
00099 /** Stop predicting after a certain amount of votes exceed certain proportion.
00100  *  case unweighted voting: stop if the leading class exceeds proportion * SB::tree_count_ 
00101  *  case weighted voting: stop if the leading class exceeds proportion * msample_ * SB::tree_count_ ;
00102  *                          (maximal number of votes possible in both cases)
00103  */
00104 class StopAfterVoteCount : public StopBase
00105 {
00106 public:
00107     double proportion_;
00108     typedef StopBase SB;
00109     ArrayVector<double> depths;
00110 
00111     /** Constructor
00112      * \param proportion specify proportion to be used.
00113      */
00114     StopAfterVoteCount(double proportion)
00115     :
00116         proportion_(proportion)
00117     {}
00118 
00119     template<class WeightIter, class T, class C>
00120     bool after_prediction(WeightIter,  int k, MultiArrayView<2, T, C> const & prob, double /* totalCt */)
00121     {
00122         if(k == SB::tree_count_ -1)
00123         {
00124                 depths.push_back(double(k+1)/double(SB::tree_count_));
00125                 return false;
00126         }
00127 
00128 
00129         if(SB::is_weighted_)
00130         {
00131             if(prob[argMax(prob)] > proportion_ *SB::ext_param_.actual_msample_* SB::tree_count_)
00132             {
00133                 depths.push_back(double(k+1)/double(SB::tree_count_));
00134                 return true;
00135             }
00136         }
00137         else
00138         {
00139             if(prob[argMax(prob)] > proportion_ * SB::tree_count_)
00140             {
00141                 depths.push_back(double(k+1)/double(SB::tree_count_));
00142                 return true;
00143             }
00144         }
00145         return false;
00146     }
00147 
00148 };
00149 
00150 
00151 /** Stop predicting if the 2norm of the probabilities does not change*/
00152 class StopIfConverging : public StopBase
00153 
00154 {
00155 public:
00156     double thresh_;
00157     int num_;
00158     MultiArray<2, double> last_;
00159     MultiArray<2, double> cur_;
00160     ArrayVector<double> depths;
00161     typedef StopBase SB;
00162 
00163     /** Constructor
00164      * \param thresh: If the two norm of the probabilities changes less then thresh then stop
00165      * \param num   : look at atleast num trees before stopping
00166      */
00167     StopIfConverging(double thresh, int num = 10)
00168     :
00169         thresh_(thresh), 
00170         num_(num)
00171     {}
00172 
00173     template<class T>
00174     void set_external_parameters(ProblemSpec<T> const  &prob, int tree_count = 0, bool is_weighted = false)
00175     {
00176         last_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0);
00177         cur_.reshape(MultiArrayShape<2>::type(1, prob.class_count_), 0);
00178         SB::set_external_parameters(prob, tree_count, is_weighted);
00179     }
00180     template<class WeightIter, class T, class C>
00181     bool after_prediction(WeightIter iter,  int k, MultiArrayView<2, T, C> const & prob, double totalCt)
00182     {
00183         if(k == SB::tree_count_ -1)
00184         {
00185                 depths.push_back(double(k+1)/double(SB::tree_count_));
00186                 return false;
00187         }
00188         if(k <= num_)
00189         {
00190             last_ = prob;
00191             last_/= last_.norm(1);
00192             return false;
00193         }
00194         else 
00195         {
00196             cur_ = prob;
00197             cur_ /= cur_.norm(1);
00198             last_ -= cur_;
00199             double nrm = last_.norm(); 
00200             if(nrm < thresh_)
00201             {
00202                 depths.push_back(double(k+1)/double(SB::tree_count_));
00203                 return true;
00204             }
00205             else
00206             {
00207                 last_ = cur_;
00208             }
00209         }
00210         return false;
00211     }
00212 };
00213 
00214 
00215 /** Stop predicting if the margin prob(leading class) - prob(second class) exceeds a proportion
00216  *  case unweighted voting: stop if margin exceeds proportion * SB::tree_count_ 
00217  *  case weighted voting: stop if margin exceeds proportion * msample_ * SB::tree_count_ ;
00218  *                          (maximal number of votes possible in both cases)
00219  */
00220 class StopIfMargin : public StopBase  
00221 {
00222 public:
00223     double proportion_;
00224     typedef StopBase SB;
00225     ArrayVector<double> depths;
00226 
00227     /** Constructor
00228      * \param proportion specify proportion to be used.
00229      */
00230     StopIfMargin(double proportion)
00231     :
00232         proportion_(proportion)
00233     {}
00234 
00235     template<class WeightIter, class T, class C>
00236     bool after_prediction(WeightIter,  int k, MultiArrayView<2, T, C> prob, double /* totalCt */)
00237     {
00238         if(k == SB::tree_count_ -1)
00239         {
00240                 depths.push_back(double(k+1)/double(SB::tree_count_));
00241                 return false;
00242         }
00243         int index = argMax(prob);
00244         double a = prob[argMax(prob)];
00245         prob[argMax(prob)] = 0;
00246         double b = prob[argMax(prob)];
00247         prob[index] = a; 
00248         double margin = a - b;
00249         if(SB::is_weighted_)
00250         {
00251             if(margin > proportion_ *SB::ext_param_.actual_msample_ * SB::tree_count_)
00252             {
00253                 depths.push_back(double(k+1)/double(SB::tree_count_));
00254                 return true;
00255             }
00256         }
00257         else
00258         {
00259             if(prob[argMax(prob)] > proportion_ * SB::tree_count_)
00260             {
00261                 depths.push_back(double(k+1)/double(SB::tree_count_));
00262                 return true;
00263             }
00264         }
00265         return false;
00266     }
00267 };
00268 
00269 
00270 /**Probabilistic Stopping criterion (binomial test)
00271  *
00272  * Can only be used in a two class setting
00273  *
00274  * Stop if the Parameters estimated for the underlying binomial distribution
00275  * can be estimated with certainty over 1-alpha.
00276  * (Thesis, Rahul Nair Page 80 onwards: called the "binomial" criterion
00277  */
00278 class StopIfBinTest : public StopBase  
00279 {
00280 public:
00281     double alpha_;  
00282     MultiArrayView<2, double> n_choose_k;
00283     /** Constructor
00284      * \param alpha specify alpha (=proportion) value for binomial test.
00285      * \param nck_ Matrix with precomputed values for n choose k
00286      * nck_(n, k) is n choose k. 
00287      */
00288     StopIfBinTest(double alpha, MultiArrayView<2, double> nck_)
00289     :
00290         alpha_(alpha),
00291         n_choose_k(nck_)
00292     {}
00293     typedef StopBase SB;
00294     
00295     /**ArrayVector that will contain the fraction of trees that was visited before terminating
00296      */
00297     ArrayVector<double> depths;
00298 
00299     double binomial(int N, int k, double p)
00300     {
00301 //        return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k);
00302         return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
00303     }
00304 
00305     template<class WeightIter, class T, class C>
00306     bool after_prediction(WeightIter iter,  int k, MultiArrayView<2, T, C> prob, double totalCt)
00307     {
00308         if(k == SB::tree_count_ -1)
00309         {
00310                 depths.push_back(double(k+1)/double(SB::tree_count_));
00311                 return false;
00312         }
00313         if(k < 10)
00314         {
00315             return false;
00316         }
00317         int index = argMax(prob);
00318         int n_a  = prob[index];
00319         int n_b  = prob[(index+1)%2];
00320         int n_tilde = (SB::tree_count_ - n_a + n_b);
00321         double p_a = double(n_b - n_a + n_tilde)/double(2* n_tilde);
00322         vigra_precondition(p_a <= 1, "probability should be smaller than 1");
00323         double cum_val = 0;
00324         int c = 0; 
00325   //      std::cerr << "prob: " << p_a << std::endl;
00326         if(n_a <= 0)n_a = 0;
00327         if(n_b <= 0)n_b = 0;
00328         for(int ii = 0; ii <= n_b + n_a;++ii)
00329         {
00330 //            std::cerr << "nb +ba " << n_b + n_a << " " << ii <<std::endl;
00331             cum_val += binomial(n_b + n_a, ii, p_a); 
00332             if(cum_val >= 1 -alpha_)
00333             {
00334                 c = ii;
00335                 break;
00336             }
00337         }
00338 //        std::cerr << c << " " << n_a << " " << n_b << " " << p_a <<   alpha_ << std::endl;
00339         if(c < n_a)
00340         {
00341             depths.push_back(double(k+1)/double(SB::tree_count_));
00342             return true;
00343         }
00344 
00345         return false;
00346     }
00347 };
00348 
00349 /**Probabilistic Stopping criteria. (toChange)
00350  *
00351  * Can only be used in a two class setting
00352  *
00353  * Stop if the probability that the decision will change after seeing all trees falls under
00354  * a specified value alpha.
00355  * (Thesis, Rahul Nair Page 80 onwards: called the "toChange" criterion
00356  */
00357 class StopIfProb : public StopBase  
00358 {
00359 public:
00360     double alpha_;  
00361     MultiArrayView<2, double> n_choose_k;
00362     
00363     
00364     /** Constructor
00365      * \param alpha specify alpha (=proportion) value
00366      * \param nck_ Matrix with precomputed values for n choose k
00367      * nck_(n, k) is n choose k. 
00368      */
00369     StopIfProb(double alpha, MultiArrayView<2, double> nck_)
00370     :
00371         alpha_(alpha),
00372         n_choose_k(nck_)
00373     {}
00374     typedef StopBase SB;
00375     /**ArrayVector that will contain the fraction of trees that was visited before terminating
00376      */
00377     ArrayVector<double> depths;
00378 
00379     double binomial(int N, int k, double p)
00380     {
00381 //        return n_choose_k(N, k) * es_detail::power(p, k) *es_detail::power(1 - p, N-k);
00382         return n_choose_k(N, k) * std::pow(p, k) * std::pow(1 - p, N-k);
00383     }
00384 
00385     template<class WeightIter, class T, class C>
00386     bool after_prediction(WeightIter iter,  int k, MultiArrayView<2, T, C> prob, double totalCt)
00387     {
00388         if(k == SB::tree_count_ -1)
00389         {
00390                 depths.push_back(double(k+1)/double(SB::tree_count_));
00391                 return false;
00392         }
00393         if(k <= 10)
00394         {
00395             return false;
00396         }
00397         int index = argMax(prob);
00398         int n_a  = prob[index];
00399         int n_b  = prob[(index+1)%2];
00400         int n_needed = ceil(double(SB::tree_count_)/2.0)-n_a;
00401         int n_tilde = SB::tree_count_ - (n_a +n_b);
00402         if(n_tilde <= 0) n_tilde = 0;
00403         if(n_needed <= 0) n_needed = 0;
00404         double p = 0;
00405         for(int ii = n_needed; ii < n_tilde; ++ii)
00406             p += binomial(n_tilde, ii, 0.5);
00407         
00408         if(p >= 1-alpha_)
00409         {
00410             depths.push_back(double(k+1)/double(SB::tree_count_));
00411             return true;
00412         }
00413 
00414         return false;
00415     }
00416 };
00417 } //namespace vigra;
00418 #endif //RF_EARLY_STOPPING_P_HXX

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.8.0 (20 Sep 2011)