[ VIGRA Homepage | Function Index | Class Index | Namespaces | File List | Main Page ]
#include <vigra/random_forest.hxx>
Public Member Functions | |
int | class_count () const |
return number of classes used while training. | |
int | column_count () const |
return number of features used while training. | |
int | feature_count () const |
return number of features used while training. | |
template<class U , class C1 , class U2 , class C2 , class Split_t , class Stop_t , class Visitor_t , class Random_t > | |
void | reLearnTree (MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, int treeId, Visitor_t visitor_, Split_t split_, Stop_t stop_, Random_t &random) |
int | tree_count () const |
return number of trees | |
Data Access | |
ProblemSpec_t const & | ext_param () const |
return external parameters for viewing | |
Options_t const & | options () const |
access const random forest options | |
void | set_ext_param (ProblemSpec_t const &in) |
set external parameters | |
Options_t & | set_options () |
access random forest options | |
DecisionTree_t & | tree (int index) |
access trees | |
DecisionTree_t const & | tree (int index) const |
access const trees | |
Learning | |
template<class U , class C1 , class U2 , class C2 > | |
void | learn (MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels) |
learn on data with default configuration | |
template<class U , class C1 , class U2 , class C2 , class Visitor_t , class Split_t > | |
void | learn (MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels, Visitor_t visitor, Split_t split) |
template<class U , class C1 , class U2 , class C2 , class Visitor_t > | |
void | learn (MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &labels, Visitor_t visitor) |
template<class U , class C1 , class U2 , class C2 , class Split_t , class Stop_t , class Visitor_t > | |
void | learn (MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop) |
template<class U , class C1 , class U2 , class C2 , class Split_t , class Stop_t , class Visitor_t , class Random_t > | |
void | learn (MultiArrayView< 2, U, C1 > const &features, MultiArrayView< 2, U2, C2 > const &response, Visitor_t visitor, Split_t split, Stop_t stop, Random_t const &random) |
learn on data with custom config and random number generator | |
prediction | |
template<class U , class C > | |
LabelType | predictLabel (MultiArrayView< 2, U, C > const &features, ArrayVectorView< double > prior) const |
predict a label with features and class priors | |
template<class U , class C > | |
LabelType | predictLabel (MultiArrayView< 2, U, C >const &features) |
template<class U , class C , class Stop > | |
LabelType | predictLabel (MultiArrayView< 2, U, C >const &features, Stop &stop) const |
predict a label given a feature. | |
template<class U , class C1 , class T , class C2 , class Stop > | |
void | predictLabels (MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels, Stop &stop) const |
template<class U , class C1 , class T , class C2 > | |
void | predictLabels (MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &labels) const |
predict multiple labels with given features | |
template<class U , class C1 , class T , class C2 > | |
void | predictProbabilities (MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob) const |
predict the class probabilities for multiple labels | |
template<class T1 , class T2 , class C > | |
void | predictProbabilities (OnlinePredictionSet< T1 > &predictionSet, MultiArrayView< 2, T2, C > &prob) |
template<class U , class C1 , class T , class C2 , class Stop > | |
void | predictProbabilities (MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob, Stop &stop) const |
predict the class probabilities for multiple labels | |
template<class U , class C1 , class T , class C2 > | |
void | predictRaw (MultiArrayView< 2, U, C1 >const &features, MultiArrayView< 2, T, C2 > &prob) const |
Constructors | |
template<class TopologyIterator , class ParameterIterator > | |
RandomForest (int treeCount, TopologyIterator topology_begin, ParameterIterator parameter_begin, ProblemSpec_t const &problem_spec, Options_t const &options=Options_t()) | |
Create RF from external source. | |
RandomForest (Options_t const &options=Options_t(), ProblemSpec_t const &ext_param=ProblemSpec_t()) | |
default constructor | |
Protected Attributes | |
MultiArray< 2, double > | garbage_prediction_ |
Random Forest class
<PreprocessorTag | = ClassificationTag> Class used to preprocess the input while learning and predicting. Currently available: ClassificationTag and RegressionTag. It is recommended to use Splitfunctor::Preprocessor_t when using custom split functors, as they may need the data to be in a different format. |
simple usage for classification (regression is not yet supported): look at RandomForest::learn() as well as RandomForestOptions() for additional options.
using namespace vigra;
using namespace rf;
typedef xxx feature_t; // replace xxx with whichever type
typedef yyy label_t;   // likewise

// allocate the training data
MultiArrayView<2, feature_t> f = get_training_features();
MultiArrayView<2, label_t> l = get_training_labels();

RandomForest<> rf;

// construct visitor to calculate out-of-bag error
visitors::OOB_Error oob_v;

// perform training
rf.learn(f, l, visitors::create_visitor(oob_v));

std::cout << "the out-of-bag error is: " << oob_v.oob_breiman << "\n";

// get features for new data to be used for prediction
MultiArrayView<2, feature_t> pf = get_features();

// allocate space for the response (pf.shape(0) is the number of samples)
MultiArray<2, label_t> prediction(MultiArrayShape<2>::type(pf.shape(0), 1));
MultiArray<2, double> prob(MultiArrayShape<2>::type(pf.shape(0), rf.class_count()));

// perform prediction on new data
rf.predictLabels(pf, prediction);
rf.predictProbabilities(pf, prob);
Additional information such as Variable Importance measures are accessed via Visitors defined in rf::visitors. Have a look at rf::split for other splitting methods.
RandomForest | ( | Options_t const & | options = Options_t() , |
|
ProblemSpec_t const & | ext_param = ProblemSpec_t() | |||
) |
default constructor
options | general options to the Random Forest. Must be of Type Options_t | |
ext_param | problem specific values that can be supplied additionally (class weights, labels, etc.) |
RandomForest | ( | int | treeCount, | |
TopologyIterator | topology_begin, | |||
ParameterIterator | parameter_begin, | |||
ProblemSpec_t const & | problem_spec, | |||
Options_t const & | options = Options_t() | |||
) |
Create RF from external source.
treeCount | Number of trees to add. | |
topology_begin | Iterator to a container where the topology_ data of the trees are stored. The iterator should support at least treeCount forward iterations (i.e. topology_end - topology_begin >= treeCount). | |
parameter_begin | iterator to a Container where the parameters_ data of the trees are stored. Iterator should support at least treeCount forward iterations. | |
problem_spec | Extrinsic parameters that specify the problem e.g. ClassCount, featureCount etc. | |
options | (optional) specify options used to train the original Random forest. This parameter is not used anywhere during prediction and thus is optional. |
ProblemSpec_t const& ext_param | ( | ) | const |
return external parameters for viewing
void set_ext_param | ( | ProblemSpec_t const & | in | ) |
set external parameters
in | external parameters to be set |
Set external parameters explicitly. If the Random Forest has not been trained yet, the preprocessor will either skip filling in the values set this way, or will throw an exception if the manually specified values do not match the values calculated during the preparation step.
Options_t& set_options | ( | ) |
access random forest options
Options_t const& options | ( | ) | const |
access const random forest options
int column_count | ( | ) | const |
return number of features used while training.
deprecated. Use feature_count() instead.
void learn | ( | MultiArrayView< 2, U, C1 > const & | features, | |
MultiArrayView< 2, U2, C2 > const & | labels | |||
) |
learn on data with default configuration
features | an N x M matrix containing N samples with M features | |
labels | an N x D matrix containing the corresponding N labels. Current split functors assume D to be 1 and ignore any additional columns. This is not enforced, in order to allow future support for uncertain labels. |
learning is done with:
MultiArray<2, double> garbage_prediction_ [mutable, protected] |
optimisation for predictLabels
© Ullrich Köthe (ullrich.koethe@iwr.uni-heidelberg.de) |
html generated using doxygen and Python
|