// Include guard for the DecisionTreeTrainer implementation header, followed by the
// template parameter list of the DecisionTreeTrainer constructor definition.
// NOTE(review): this text is a source-listing extract. The leading numbers are the
// original file's line numbers, and several original lines (the constructor's
// signature and the initializers on original lines 49-51) are not present here.
38 #ifndef PCL_ML_DT_DECISION_TREE_TRAINER_HPP_ 39 #define PCL_ML_DT_DECISION_TREE_TRAINER_HPP_ 42 template <
class FeatureType,
class DataSet,
class LabelType,
class ExampleIndex,
class NodeType>
// Default training settings:
//   max tree depth 15, 1000 random candidate features, 10 thresholds per feature,
//   per-split-node feature generation disabled, data provider empty.
// The feature handler and the stats estimator start out null. train() and
// trainDecisionTreeNode() dereference both, so they must be set before training.
// NOTE(review): decision_tree_trainer_data_provider_ is tested for truthiness in
// train(); presumably it is a smart pointer. TODO confirm against the class.
44 : max_tree_depth_ (15)
45 , num_of_features_ (1000)
46 , num_of_thresholds_ (10)
47 , feature_handler_ (NULL)
48 , stats_estimator_ (NULL)
52 , decision_tree_trainer_data_provider_ ()
53 , random_features_at_split_node_(false)
// Template parameter list of a member definition whose signature and body are not
// present in this extract. It is presumably the virtual destructor, which sits
// between the constructor and train(). TODO confirm.
59 template <
class FeatureType,
class DataSet,
class LabelType,
class ExampleIndex,
class NodeType>
// train(): trains `tree` from the training data and settings configured on this
// trainer. The signature on original lines 67-71 is not present in this extract.
//
// Flow:
//  - Unless random features are drawn per split node, num_of_features_ random
//    candidate features are generated once here and shared by every node.
//  - If a data provider is set, data_set_, label_data_ and examples_ are passed to
//    its getDatasetAndLabels() before training (presumably it fills them).
//  - The tree is then grown recursively from its root node, with max_tree_depth_
//    as the depth budget.
66 template <
class FeatureType,
class DataSet,
class LabelType,
class ExampleIndex,
class NodeType>
72 std::vector<FeatureType> features;
// When random_features_at_split_node_ is set, `features` stays empty here.
// trainDecisionTreeNode() then generates candidates for each node instead.
74 if (!random_features_at_split_node_)
75 feature_handler_->createRandomFeatures (num_of_features_, features);
81 if (decision_tree_trainer_data_provider_)
// NOTE(review): unconditional diagnostic output on std::cerr. Consider routing it
// through the library's logging macros instead.
83 std::cerr <<
"use decision_tree_trainer_data_provider_" << std::endl;
85 decision_tree_trainer_data_provider_->getDatasetAndLabels (data_set_, label_data_, examples_);
86 trainDecisionTreeNode (features, examples_, label_data_, max_tree_depth_, tree.
getRoot ());
// Presumably the path without a data provider; original lines 87-92 are not present
// in this extract. It trains on the previously stored examples_ and label_data_.
93 trainDecisionTreeNode (features, examples_, label_data_, max_tree_depth_, tree.
getRoot ());
// trainDecisionTreeNode(): recursively trains `node` on the given examples.
//
// Parameters (visible parts of the signature; original lines 100-101 and 106-107
// are not present in this extract):
//   features   - candidate features. When random_features_at_split_node_ is set,
//                they are regenerated here, which is why the vector is non-const.
//   examples   - indices of the training examples that reach this node.
//   label_data - labels, index-aligned with `examples`.
//   max_depth  - remaining depth budget; each child receives max_depth-1.
//   node       - node to fill (NodeType &, per the in-class declaration).
//
// Outline:
//   1. Report an error for an empty example set. Make the node a leaf when fewer
//      than min_examples_for_split_ examples remain, or in the case on original
//      line 117 (its condition is not in this extract; presumably the depth limit).
//   2. For every candidate feature, evaluate it on the examples. Score thresholds
//      by information gain, using the fixed list thresholds_ when it is non-empty
//      and otherwise num_of_thresholds_ thresholds derived from the responses.
//   3. If no feature/threshold pair yields a strictly positive gain, make the node
//      a leaf.
//   4. Otherwise store the best feature and threshold, partition the examples into
//      branches, and recurse into each non-empty branch. Empty branches become
//      leaves built from this node's examples and labels.
99 template <
class FeatureType,
class DataSet,
class LabelType,
class ExampleIndex,
class NodeType>
102 std::vector<FeatureType> & features,
103 std::vector<ExampleIndex> & examples,
104 std::vector<LabelType> & label_data,
105 const size_t max_depth,
108 const size_t num_of_examples = examples.size ();
// An empty example set is treated as an invalid state (see the error message).
109 if (num_of_examples == 0)
111 PCL_ERROR (
"Reached invalid point in decision tree training: Number of examples is 0!");
// Leaf case. Its guarding condition (original lines 112-116) is not in this
// extract; presumably it checks the depth budget. TODO confirm.
117 stats_estimator_->computeAndSetNodeStats(data_set_, examples, label_data, node);
// Too few examples to split further: make this node a leaf.
121 if(examples.size () < min_examples_for_split_) {
122 stats_estimator_->computeAndSetNodeStats (data_set_, examples, label_data, node);
// Optionally draw fresh candidate features for this node. This writes into the
// caller's vector (passed by reference). The parent has already copied its chosen
// feature into its node before recursing, so the parent's split is unaffected.
126 if(random_features_at_split_node_) {
128 feature_handler_->createRandomFeatures (num_of_features_, features);
// Per-example feature responses and flags, reserved once for all features.
// Presumably filled by evaluateFeature(); its arguments are not in this extract.
131 std::vector<float> feature_results;
132 std::vector<unsigned char> flags;
134 feature_results.reserve (num_of_examples);
135 flags.reserve (num_of_examples);
// Best split found so far. Index -1 means no split with a positive gain yet.
// Because the gain starts at 0 and is compared with '>', only strictly positive
// gains are accepted.
138 int best_feature_index = -1;
139 float best_feature_threshold = 0.0f;
140 float best_feature_information_gain = 0.0f;
142 const size_t num_of_features = features.size ();
// Evaluate every candidate feature on this node's examples.
143 for (
size_t feature_index = 0; feature_index < num_of_features; ++feature_index)
146 feature_handler_->evaluateFeature (features[feature_index],
// Fixed threshold list: score each entry of thresholds_.
153 if (thresholds_.size () > 0)
156 for (
size_t threshold_index = 0; threshold_index < thresholds_.size (); ++threshold_index)
159 const float information_gain = stats_estimator_->computeInformationGain (data_set_,
164 thresholds_[threshold_index]);
166 if (information_gain > best_feature_information_gain)
168 best_feature_information_gain = information_gain;
169 best_feature_index =
static_cast<int> (feature_index);
170 best_feature_threshold = thresholds_[threshold_index];
// No fixed thresholds: derive num_of_thresholds_ thresholds from the range of
// this feature's responses, then score each one.
176 std::vector<float> thresholds;
177 thresholds.reserve (num_of_thresholds_);
178 createThresholdsUniform (num_of_thresholds_, feature_results, thresholds);
181 for (
size_t threshold_index = 0; threshold_index < num_of_thresholds_; ++threshold_index)
183 const float threshold = thresholds[threshold_index];
186 const float information_gain = stats_estimator_->computeInformationGain (data_set_,
193 if (information_gain > best_feature_information_gain)
195 best_feature_information_gain = information_gain;
196 best_feature_index =
static_cast<int> (feature_index);
197 best_feature_threshold = threshold;
// No feature/threshold improved on zero information gain: make this node a leaf.
203 if (best_feature_index == -1)
205 stats_estimator_->computeAndSetNodeStats (data_set_, examples, label_data, node);
// Re-evaluate the winning feature and map each example to a branch index.
210 std::vector<unsigned char> branch_indices;
211 branch_indices.reserve (num_of_examples);
213 feature_handler_->evaluateFeature (features[best_feature_index],
219 stats_estimator_->computeBranchIndices (feature_results,
221 best_feature_threshold,
// This node's statistics are set even though it becomes a split node.
225 stats_estimator_->computeAndSetNodeStats (data_set_, examples, label_data, node);
229 const size_t num_of_branches = stats_estimator_->getNumOfBranches ();
// Count how many examples fall into each branch.
231 std::vector<size_t> branch_counts (num_of_branches, 0);
232 for (
size_t example_index = 0; example_index < num_of_examples; ++example_index)
234 ++branch_counts[branch_indices[example_index]];
// Store the split and allocate one child per branch.
237 node.feature = features[best_feature_index];
238 node.threshold = best_feature_threshold;
239 node.sub_nodes.resize (num_of_branches);
241 for (
size_t branch_index = 0; branch_index < num_of_branches; ++branch_index)
// Empty branch: make the child a leaf using this node's examples and labels.
243 if (branch_counts[branch_index] == 0)
245 NodeType branch_node;
246 stats_estimator_->computeAndSetNodeStats (data_set_, examples, label_data, branch_node);
249 node.sub_nodes[branch_index] = branch_node;
// Gather the examples and labels routed to this branch, then recurse with a
// reduced depth budget.
254 std::vector<LabelType> branch_labels;
255 std::vector<ExampleIndex> branch_examples;
256 branch_labels.reserve (branch_counts[branch_index]);
257 branch_examples.reserve (branch_counts[branch_index]);
259 for (
size_t example_index = 0; example_index < num_of_examples; ++example_index)
261 if (branch_indices[example_index] == branch_index)
263 branch_examples.push_back (examples[example_index]);
264 branch_labels.push_back (label_data[example_index]);
268 trainDecisionTreeNode (features, branch_examples, branch_labels, max_depth-1, node.sub_nodes[branch_index]);
// createThresholdsUniform(): fills `thresholds` with num_of_thresholds evenly
// spaced values inside the range [min, max] of `values`. Signature per the
// in-class declaration:
//   static void createThresholdsUniform (const size_t num_of_thresholds,
//       std::vector<float> & values, std::vector<float> & thresholds)
// `values` is only read here.
//
// The range is split into num_of_thresholds + 2 steps. Threshold i is placed at
// min + step * (i + 1) for i = 0 .. num_of_thresholds - 1. The first threshold is
// therefore one step above the minimum and the last is two steps below the maximum.
// NOTE(review): the spacing is not symmetric (one-step gap at the bottom, two at
// the top). Dividing by num_of_thresholds + 1 would centre the thresholds; confirm
// whether the asymmetry is intended.
// NOTE(review): for an empty `values`, min/max keep their sentinel values. The
// range then overflows to -inf and every threshold becomes -inf. The caller visible
// here presumably only reaches this after its empty-example check.
275 template <
class FeatureType,
class DataSet,
class LabelType,
class ExampleIndex,
class NodeType>
278 const size_t num_of_thresholds,
279 std::vector<float> & values,
280 std::vector<float> & thresholds)
// Sentinels chosen so that the first value updates both bounds.
283 float min_value = ::std::numeric_limits<float>::max();
284 float max_value = -::std::numeric_limits<float>::max();
286 const size_t num_of_values = values.size ();
287 for (
size_t value_index = 0; value_index < num_of_values; ++value_index)
289 const float value = values[value_index];
291 if (value < min_value) min_value = value;
292 if (value > max_value) max_value = value;
// Distance between consecutive thresholds.
295 const float range = max_value - min_value;
296 const float step = range /
static_cast<float>(num_of_thresholds+2);
// Resize to exactly num_of_thresholds; every element is assigned below.
299 thresholds.resize (num_of_thresholds);
301 for (
size_t threshold_index = 0; threshold_index < num_of_thresholds; ++threshold_index)
303 thresholds[threshold_index] = min_value + step*(
static_cast<float>(threshold_index+1));
void train(DecisionTree< NodeType > &tree)
Trains a decision tree using the training data and settings previously set on the trainer.
Class representing a decision tree.
DecisionTreeTrainer()
Constructor.
void setRoot(const NodeType &root)
Sets the root node of the tree.
NodeType & getRoot()
Returns the root node of the tree.
void trainDecisionTreeNode(std::vector< FeatureType > &features, std::vector< ExampleIndex > &examples, std::vector< LabelType > &label_data, size_t max_depth, NodeType &node)
Trains a decision tree node from the specified features, label data, and examples.
static void createThresholdsUniform(const size_t num_of_thresholds, std::vector< float > &values, std::vector< float > &thresholds)
Creates uniformly distributed thresholds over the range of the supplied values.
virtual ~DecisionTreeTrainer()
Destructor.