MLPACK  1.0.7
dtree.hpp
Go to the documentation of this file.
1 
23 #ifndef __MLPACK_METHODS_DET_DTREE_HPP
24 #define __MLPACK_METHODS_DET_DTREE_HPP
25 
26 #include <mlpack/core.hpp>
27 
28 namespace mlpack {
29 namespace det {
30 
54 class DTree
55 {
56  public:
60  DTree();
61 
70  DTree(const arma::vec& maxVals,
71  const arma::vec& minVals,
72  const size_t totalPoints);
73 
82  DTree(arma::mat& data);
83 
96  DTree(const arma::vec& maxVals,
97  const arma::vec& minVals,
98  const size_t start,
99  const size_t end,
100  const double logNegError);
101 
113  DTree(const arma::vec& maxVals,
114  const arma::vec& minVals,
115  const size_t totalPoints,
116  const size_t start,
117  const size_t end);
118 
120  ~DTree();
121 
132  double Grow(arma::mat& data,
133  arma::Col<size_t>& oldFromNew,
134  const bool useVolReg = false,
135  const size_t maxLeafSize = 10,
136  const size_t minLeafSize = 5);
137 
146  double PruneAndUpdate(const double oldAlpha,
147  const size_t points,
148  const bool useVolReg = false);
149 
155  double ComputeValue(const arma::vec& query) const;
156 
164  void WriteTree(FILE *fp, const size_t level = 0) const;
165 
173  int TagTree(const int tag = 0);
174 
181  int FindBucket(const arma::vec& query) const;
182 
188  void ComputeVariableImportance(arma::vec& importances) const;
189 
196  double LogNegativeError(const size_t totalPoints) const;
197 
201  bool WithinRange(const arma::vec& query) const;
202 
203  private:
204  // The indices in the complete set of points
205  // (after all forms of swapping in the original data
206  // matrix to align all the points in a node
207  // consecutively in the matrix. The 'old_from_new' array
208  // maps the points back to their original indices.
209 
212  size_t start;
215  size_t end;
216 
218  arma::vec maxVals;
220  arma::vec minVals;
221 
223  size_t splitDim;
224 
226  double splitValue;
227 
229  double logNegError;
230 
233 
236 
238  bool root;
239 
241  double ratio;
242 
244  double logVolume;
245 
248 
250  double alphaUpper;
251 
256 
257  public:
259  size_t Start() const { return start; }
261  size_t End() const { return end; }
263  size_t SplitDim() const { return splitDim; }
265  double SplitValue() const { return splitValue; }
267  double LogNegError() const { return logNegError; }
271  size_t SubtreeLeaves() const { return subtreeLeaves; }
274  double Ratio() const { return ratio; }
276  double LogVolume() const { return logVolume; }
278  DTree* Left() const { return left; }
280  DTree* Right() const { return right; }
282  bool Root() const { return root; }
284  double AlphaUpper() const { return alphaUpper; }
285 
287  const arma::vec& MaxVals() const { return maxVals; }
289  arma::vec& MaxVals() { return maxVals; }
290 
292  const arma::vec& MinVals() const { return minVals; }
294  arma::vec& MinVals() { return minVals; }
295 
296  private:
297 
298  // Utility methods.
299 
303  bool FindSplit(const arma::mat& data,
304  size_t& splitDim,
305  double& splitValue,
306  double& leftError,
307  double& rightError,
308  const size_t maxLeafSize = 10,
309  const size_t minLeafSize = 5) const;
310 
314  size_t SplitData(arma::mat& data,
315  const size_t splitDim,
316  const double splitValue,
317  arma::Col<size_t>& oldFromNew) const;
318 
319 };
320 
321 }; // namespace det
322 }; // namespace mlpack
323 
324 #endif // __MLPACK_METHODS_DET_DTREE_HPP