[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]

vigra/random_forest/rf_preprocessing.hxx VIGRA

00001 /************************************************************************/
00002 /*                                                                      */
00003 /*        Copyright 2008-2009 by  Ullrich Koethe and Rahul Nair         */
00004 /*                                                                      */
00005 /*    This file is part of the VIGRA computer vision library.           */
00006 /*    The VIGRA Website is                                              */
00007 /*        http://hci.iwr.uni-heidelberg.de/vigra/                       */
00008 /*    Please direct questions, bug reports, and contributions to        */
00009 /*        ullrich.koethe@iwr.uni-heidelberg.de    or                    */
00010 /*        vigra@informatik.uni-hamburg.de                               */
00011 /*                                                                      */
00012 /*    Permission is hereby granted, free of charge, to any person       */
00013 /*    obtaining a copy of this software and associated documentation    */
00014 /*    files (the "Software"), to deal in the Software without           */
00015 /*    restriction, including without limitation the rights to use,      */
00016 /*    copy, modify, merge, publish, distribute, sublicense, and/or      */
00017 /*    sell copies of the Software, and to permit persons to whom the    */
00018 /*    Software is furnished to do so, subject to the following          */
00019 /*    conditions:                                                       */
00020 /*                                                                      */
00021 /*    The above copyright notice and this permission notice shall be    */
00022 /*    included in all copies or substantial portions of the             */
00023 /*    Software.                                                         */
00024 /*                                                                      */
00025 /*    THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND    */
00026 /*    EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES   */
00027 /*    OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND          */
00028 /*    NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT       */
00029 /*    HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,      */
00030 /*    WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING      */
00031 /*    FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR     */
00032 /*    OTHER DEALINGS IN THE SOFTWARE.                                   */
00033 /*                                                                      */
00034 /************************************************************************/
00035 
00036 #ifndef VIGRA_RF_PREPROCESSING_HXX
00037 #define VIGRA_RF_PREPROCESSING_HXX
00038 
00039 #include <limits>
00040 #include "rf_common.hxx"
00041 
00042 namespace vigra
00043 {
00044 
/** Preprocessor used by the Random Forest (currently only during learn).
 *
 * The Random Forest learn function uses this class internally. Different
 * split functors may need the data to be processed in different ways:
 * regression labels, for example, must be left untouched, whereas
 * classification labels have to be converted into an integral format.
 *
 * The primary template is only declared; usable versions exist solely as
 * specializations for a fixed Tag.
 *
 * The Tag is determined by Splitfunctor::Preprocessor_t and is currently
 * either ClassificationTag or RegressionTag. The RegressionTag
 * specialization shows the basic interface and is the place to start when
 * writing a new preprocessor (e.g. for soft labels).
 */
template <typename Tag,
          typename LabelType,
          typename T1, typename C1,
          typename T2, typename C2>
class Processor;
00063 
00064 namespace detail
00065 {
00066 
00067     /* Common helper function used in all Processors. 
00068      * This function analyses the options struct and calculates the real 
00069      * values needed for the current problem (data)
00070      */
00071     template<class T>
00072     void fill_external_parameters(RandomForestOptions const  & options,
00073                                   ProblemSpec<T> & ext_param)
00074     {
00075         // set correct value for mtry
00076         switch(options.mtry_switch_)
00077         {
00078             case RF_SQRT:
00079                 ext_param.actual_mtry_ =
00080                     int(std::floor(
00081                             std::sqrt(double(ext_param.column_count_))
00082                             + 0.5));
00083                 break;
00084             case RF_LOG:
00085                 // this is in Breimans original paper
00086                 ext_param.actual_mtry_ =
00087                     int(1+(std::log(double(ext_param.column_count_))
00088                            /std::log(2.0)));
00089                 break;
00090             case RF_FUNCTION:
00091                 ext_param.actual_mtry_ =
00092                     options.mtry_func_(ext_param.column_count_);
00093                 break;
00094             case RF_ALL:
00095                 ext_param.actual_mtry_ = ext_param.column_count_;
00096                 break;
00097             default:
00098                 ext_param.actual_mtry_ =
00099                     options.mtry_;
00100         }
00101         // set correct value for msample
00102         switch(options.training_set_calc_switch_)
00103         {
00104             case RF_CONST:
00105                 ext_param.actual_msample_ =
00106                     options.training_set_size_;
00107                 break;
00108             case RF_PROPORTIONAL:
00109                 ext_param.actual_msample_ =
00110                     (int)std::ceil(  options.training_set_proportion_ *
00111                                      ext_param.row_count_);
00112                     break;
00113             case RF_FUNCTION:
00114                 ext_param.actual_msample_ =
00115                     options.training_set_func_(ext_param.row_count_);
00116                 break;
00117             default:
00118                 vigra_precondition(1!= 1, "unexpected error");
00119 
00120         }
00121 
00122     }
00123     
00124     /* Returns true if MultiArray contains NaNs
00125      */
00126     template<unsigned int N, class T, class C>
00127     bool contains_nan(MultiArrayView<N, T, C> const & in)
00128     {
00129         for(int ii = 0; ii < in.size(); ++ii)
00130             if(in[ii] != in[ii])
00131                 return true;
00132         return false; 
00133     }
00134     
00135     /* Returns true if MultiArray contains Infs
00136      */
00137     template<unsigned int N, class T, class C>
00138     bool contains_inf(MultiArrayView<N, T, C> const & in)
00139     {
00140          if(!std::numeric_limits<T>::has_infinity)
00141              return false;
00142          for(int ii = 0; ii < in.size(); ++ii)
00143             if(in[ii] == std::numeric_limits<T>::infinity())
00144                 return true;
00145          return false;
00146     }
00147 } // namespace detail
00148 
00149 
00150 
00151 /** Preprocessor used during Classification
00152  *
00153  * This class converts the labels int Integral labels which are used by the 
00154  * standard split functor to address memory in the node objects.
00155  */
00156 template<class LabelType, class T1, class C1, class T2, class C2>
00157 class Processor<ClassificationTag, LabelType, T1, C1, T2, C2>
00158 {
00159     public:
00160     typedef Int32 LabelInt;
00161     typedef MultiArrayView<2, T1, C1> Feature_t;
00162     typedef MultiArray<2, T1> FeatureWithMemory_t;
00163     typedef MultiArrayView<2,LabelInt> Label_t;
00164     MultiArrayView<2, T1, C1>const &    features_;
00165     MultiArray<2, LabelInt>             intLabels_;
00166     MultiArrayView<2, LabelInt>         strata_;
00167 
00168     template<class T>
00169     Processor(MultiArrayView<2, T1, C1>const & features,   
00170               MultiArrayView<2, T2, C2>const & response,
00171               RandomForestOptions &options,         
00172               ProblemSpec<T> &ext_param)
00173     :
00174         features_( features) // do not touch the features. 
00175     {
00176         vigra_precondition(!detail::contains_nan(features), "Processor(): Feature Matrix "
00177                                                            "Contains NaNs");
00178         vigra_precondition(!detail::contains_nan(response), "Processor(): Response "
00179                                                            "Contains NaNs");
00180         vigra_precondition(!detail::contains_inf(features), "Processor(): Feature Matrix "
00181                                                            "Contains inf");
00182         vigra_precondition(!detail::contains_inf(response), "Processor(): Response "
00183                                                            "Contains inf");
00184         // set some of the problem specific parameters 
00185         ext_param.column_count_  = features.shape(1);
00186         ext_param.row_count_     = features.shape(0);
00187         ext_param.problem_type_  = CLASSIFICATION;
00188         ext_param.used_          = true;
00189         intLabels_.reshape(response.shape());
00190 
00191         //get the class labels
00192         if(ext_param.class_count_ == 0)
00193         {
00194             // fill up a map with the current labels and then create the 
00195             // integral labels.
00196             std::set<T2>                    labelToInt;
00197             for(MultiArrayIndex k = 0; k < features.shape(0); ++k)
00198                 labelToInt.insert(response(k,0));
00199             std::vector<T2> tmp_(labelToInt.begin(), labelToInt.end());
00200             ext_param.classes_(tmp_.begin(), tmp_.end());
00201         }
00202         for(MultiArrayIndex k = 0; k < features.shape(0); ++k)
00203         {
00204             if(std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0)) == ext_param.classes.end())
00205             {
00206                 throw std::runtime_error("unknown label type");
00207             }
00208             else
00209                 intLabels_(k, 0) = std::find(ext_param.classes.begin(), ext_param.classes.end(), response(k,0))
00210                                     - ext_param.classes.begin();
00211         }
00212         // set class weights
00213         if(ext_param.class_weights_.size() == 0)
00214         {
00215             ArrayVector<T2> 
00216                 tmp((std::size_t)ext_param.class_count_, 
00217                     NumericTraits<T2>::one());
00218             ext_param.class_weights(tmp.begin(), tmp.end());
00219         }
00220 
00221         // set mtry and msample
00222         detail::fill_external_parameters(options, ext_param);
00223 
00224         // set strata
00225         strata_ = intLabels_;
00226 
00227     }
00228 
00229     /** Access the processed features
00230      */
00231     MultiArrayView<2, T1, C1>const & features()
00232     {
00233         return features_;
00234     }
00235 
00236     /** Access processed labels
00237      */
00238     MultiArrayView<2, LabelInt>& response()
00239     {
00240         return intLabels_;
00241     }
00242 
00243     /** Access processed strata
00244      */
00245     ArrayVectorView < LabelInt>  strata()
00246     {
00247         return ArrayVectorView<LabelInt>(intLabels_.size(), intLabels_.data());
00248     }
00249 
00250     /** Access strata fraction sized - not used currently
00251      */
00252     ArrayVectorView< double> strata_prob()
00253     {
00254         return ArrayVectorView< double>();
00255     }
00256 };
00257 
00258 
00259 
00260 /** Regression Preprocessor - This basically does not do anything with the
00261  * data.
00262  */
00263 template<class LabelType, class T1, class C1, class T2, class C2>
00264 class Processor<RegressionTag,LabelType, T1, C1, T2, C2>
00265 {
00266 public:
00267     // only views are created - no data copied.
00268     MultiArrayView<2, T1, C1>   features_;
00269     MultiArrayView<2, T2, C2>   response_;
00270     RandomForestOptions const & options_;
00271     ProblemSpec<LabelType> const &
00272                                 ext_param_;
00273     // will only be filled if needed
00274     MultiArray<2, int>      strata_;
00275     bool strata_filled;
00276 
00277     // copy the views.
00278     template<class T>
00279     Processor(  MultiArrayView<2, T1, C1>   features,
00280                 MultiArrayView<2, T2, C2>   response,
00281                 RandomForestOptions const & options,
00282                 ProblemSpec<T>& ext_param)
00283     :
00284         features_(features),
00285         response_(response),
00286         options_(options),
00287         ext_param_(ext_param)
00288     {
00289         // set some of the problem specific parameters 
00290         ext_param.column_count_  = features.shape(1);
00291         ext_param.row_count_     = features.shape(0);
00292         ext_param.problem_type_  = REGRESSION;
00293         ext_param.used_          = true;
00294         detail::fill_external_parameters(options, ext_param);
00295         vigra_precondition(!detail::contains_nan(features), "Processor(): Feature Matrix "
00296                                                            "Contains NaNs");
00297         vigra_precondition(!detail::contains_nan(response), "Processor(): Response "
00298                                                            "Contains NaNs");
00299         vigra_precondition(!detail::contains_inf(features), "Processor(): Feature Matrix "
00300                                                            "Contains inf");
00301         vigra_precondition(!detail::contains_inf(response), "Processor(): Response "
00302                                                            "Contains inf");
00303         strata_ = MultiArray<2, int> (MultiArrayShape<2>::type(response_.shape(0), 1));
00304         ext_param.response_size_ = response.shape(1);
00305         ext_param.class_count_ = response_.shape(1);
00306         std::vector<T2> tmp_(ext_param.class_count_, 0);
00307             ext_param.classes_(tmp_.begin(), tmp_.end());
00308     }
00309 
00310     /** access preprocessed features
00311      */
00312     MultiArrayView<2, T1, C1> & features()
00313     {
00314         return features_;
00315     }
00316 
00317     /** access preprocessed response
00318      */
00319     MultiArrayView<2, T2, C2> & response()
00320     {
00321         return response_;
00322     }
00323 
00324     /** access strata - this is not used currently
00325      */
00326     MultiArray<2, int> & strata()
00327     {
00328         return strata_;
00329     }
00330 };
00331 }
00332 #endif //VIGRA_RF_PREPROCESSING_HXX
00333 
00334 
00335 

© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de)
Heidelberg Collaboratory for Image Processing, University of Heidelberg, Germany

html generated using doxygen and Python
vigra 1.8.0 (20 Sep 2011)