/**
 * @file cosine_tree.hpp
 *
 * Declaration of the CosineTree class, the CompareCosineNode comparator and
 * the CosineNodeQueue priority-queue type (MLPACK 1.0.10).
 */
23 #ifndef __MLPACK_CORE_TREE_COSINE_TREE_COSINE_TREE_HPP
24 #define __MLPACK_CORE_TREE_COSINE_TREE_COSINE_TREE_HPP
25 
26 #include <mlpack/core.hpp>
27 #include <boost/heap/priority_queue.hpp>
28 
29 namespace mlpack {
30 namespace tree {
31 
32 // Predeclare classes for CosineNodeQueue typedef.
33 class CompareCosineNode;
34 class CosineTree;
35 
36 // CosineNodeQueue typedef.
37 typedef boost::heap::priority_queue<CosineTree*,
38  boost::heap::compare<CompareCosineNode> > CosineNodeQueue;
39 
41 {
42  public:
43 
52  CosineTree(const arma::mat& dataset);
53 
63  CosineTree(CosineTree& parentNode, const std::vector<size_t>& subIndices);
64 
79  CosineTree(const arma::mat& dataset,
80  const double epsilon,
81  const double delta);
82 
92  void ModifiedGramSchmidt(CosineNodeQueue& treeQueue,
93  arma::vec& centroid,
94  arma::vec& newBasisVector,
95  arma::vec* addBasisVector = NULL);
96 
109  double MonteCarloError(CosineTree* node,
110  CosineNodeQueue& treeQueue,
111  arma::vec* addBasisVector1 = NULL,
112  arma::vec* addBasisVector2 = NULL);
113 
119  void ConstructBasis(CosineNodeQueue& treeQueue);
120 
126  void CosineNodeSplit();
127 
134  void ColumnSamplesLS(std::vector<size_t>& sampledIndices,
135  arma::vec& probabilities, size_t numSamples);
136 
143  size_t ColumnSampleLS();
144 
157  size_t BinarySearch(arma::vec& cDistribution, double value, size_t start,
158  size_t end);
159 
167  void CalculateCosines(arma::vec& cosines);
168 
173  void CalculateCentroid();
174 
176  void GetFinalBasis(arma::mat& finalBasis) { finalBasis = basis; }
177 
179  const arma::mat& GetDataset() const { return dataset; }
180 
182  std::vector<size_t>& VectorIndices() { return indices; }
183 
185  void L2Error(const double error) { this->l2Error = error; }
186 
188  double L2Error() const { return l2Error; }
189 
191  arma::vec& Centroid() { return centroid; }
192 
194  void BasisVector(arma::vec& bVector) { this->basisVector = bVector; }
195 
197  arma::vec& BasisVector() { return basisVector; }
198 
200  CosineTree* Left() { return left; }
201 
203  CosineTree* Right() { return right; }
204 
206  size_t NumColumns() const { return numColumns; }
207 
209  double FrobNormSquared() const { return frobNormSquared; }
210 
212  size_t SplitPointIndex() const { return indices[splitPointIndex]; }
213 
214  private:
216  const arma::mat& dataset;
218  double epsilon;
220  double delta;
222  arma::mat basis;
230  std::vector<size_t> indices;
232  arma::vec l2NormsSquared;
234  arma::vec centroid;
236  arma::vec basisVector;
240  size_t numColumns;
242  double l2Error;
245 };
246 
248 {
249  public:
250 
251  // Comparison function for construction of priority queue.
252  bool operator() (const CosineTree* a, const CosineTree* b) const
253  {
254  return a->L2Error() < b->L2Error();
255  }
256 };
257 
258 }; // namespace tree
259 }; // namespace mlpack
260 
261 #endif