Tapkee
external/barnes_hut_sne/vptree.hpp
Go to the documentation of this file.
1 
32 #include <stdlib.h>
33 #include <algorithm>
34 #include <vector>
35 #include <stdio.h>
36 #include <queue>
37 #include <limits>
38 
39 #ifndef VPTREE_H
40 #define VPTREE_H
41 
42 namespace tsne
43 {
44 
45 class DataPoint
46 {
47  int _ind;
48  std::vector<ScalarType> _x;
49 
50 public:
51  DataPoint() : _ind(-1), _x(0) { }
52  DataPoint(int Dv, int indv, ScalarType* xv) : _ind(indv), _x(Dv)
53  {
54  for(int d = 0; d < dimensionality(); d++) _x[d] = xv[d];
55  }
56  DataPoint(const DataPoint& other) : _ind(0), _x(0) // this makes a deep copy -- should not free anything
57  {
58  if(this != &other) {
59  _ind = other.index();
60  _x.resize(other.dimensionality());
61  for(int d = 0; d < dimensionality(); d++) _x[d] = other.x(d);
62  }
63  }
64  ~DataPoint() { }
65  DataPoint& operator= (const DataPoint& other) { // asignment should free old object
66  if(this != &other) {
67  _ind = other.index();
68  _x.resize(other.dimensionality());
69  for(int d = 0; d < dimensionality(); d++) _x[d] = other.x(d);
70  }
71  return *this;
72  }
73  int index() const { return _ind; }
74  int dimensionality() const { return _x.size(); }
75  ScalarType x(int d) const { return _x[d]; }
76 };
77 
78 
80  ScalarType dd = .0;
81  for(int d = 0; d < t1.dimensionality(); d++) dd += (t1.x(d) - t2.x(d)) * (t1.x(d) - t2.x(d));
82  return dd;
83 }
84 
85 
86 template<typename T, ScalarType (*distance)( const T&, const T& )>
87 class VpTree
88 {
89 public:
90 
91  // Default constructor
92  VpTree() : _items(), _tau(0.0), _root(0) {}
93 
94  // Destructor
95  ~VpTree() {
96  delete _root;
97  }
98 
99  // Function to create a new VpTree from data
100  void create(const std::vector<T>& items) {
101  delete _root;
102  _items = items;
103  _root = buildFromPoints(0, items.size());
104  }
105 
106  // Function that uses the tree to find the k nearest neighbors of target
107  void search(const T& target, int k, std::vector<T>* results, std::vector<ScalarType>* distances)
108  {
109 
110  // Use a priority queue to store intermediate results on
111  std::priority_queue<HeapItem> heap;
112 
113  // Variable that tracks the distance to the farthest point in our results
114  _tau = DBL_MAX;
115 
116  // Perform the searcg
117  search(_root, target, k, heap);
118 
119  // Gather final results
120  results->clear(); distances->clear();
121  while(!heap.empty()) {
122  results->push_back(_items[heap.top().index]);
123  distances->push_back(heap.top().dist);
124  heap.pop();
125  }
126 
127  // Results are in reverse order
128  std::reverse(results->begin(), results->end());
129  std::reverse(distances->begin(), distances->end());
130  }
131 
132 private:
133 
134  VpTree(const VpTree&);
135  VpTree& operator=(const VpTree&);
136 
137  std::vector<T> _items;
139 
140  // Single node of a VP tree (has a point and radius; left children are closer to point than the radius)
141  struct Node
142  {
143  int index; // index of point in node
144  ScalarType threshold; // radius(?)
145  Node* left; // points closer by than threshold
146  Node* right; // points farther away than threshold
147 
148  Node() : index(0), threshold(0.), left(0), right(0) {}
149 
150  ~Node()
151  {
152  delete left;
153  delete right;
154  }
155 
156  Node(const Node&);
157  Node& operator=(const Node&);
158 
159  }* _root;
160 
161 
162  // An item on the intermediate result queue
163  struct HeapItem {
164  HeapItem(int indexv, ScalarType distv) :
165  index(indexv), dist(distv) {}
166  int index;
168  bool operator<(const HeapItem& o) const {
169  return dist < o.dist;
170  }
171  };
172 
173  // Distance comparator for use in std::nth_element
175  {
176  const T& item;
177  DistanceComparator(const T& itemv) : item(itemv) {}
178  bool operator()(const T& a, const T& b) {
179  return distance(item, a) < distance(item, b);
180  }
181  };
182 
183  // Function that (recursively) fills the tree
184  Node* buildFromPoints( int lower, int upper )
185  {
186  if (upper == lower) { // indicates that we're done here!
187  return NULL;
188  }
189 
190  // Lower index is center of current node
191  Node* node = new Node();
192  node->index = lower;
193 
194  if (upper - lower > 1) { // if we did not arrive at leaf yet
195 
196  // Choose an arbitrary point and move it to the start
197  int i = (int) (tapkee::uniform_random() * (upper - lower - 1)) + lower;
198  std::swap(_items[lower], _items[i]);
199 
200  // Partition around the median distance
201  int median = (upper + lower) / 2;
202  std::nth_element(_items.begin() + lower + 1,
203  _items.begin() + median,
204  _items.begin() + upper,
205  DistanceComparator(_items[lower]));
206 
207  // Threshold of the new node will be the distance to the median
208  node->threshold = distance(_items[lower], _items[median]);
209 
210  // Recursively build tree
211  node->index = lower;
212  node->left = buildFromPoints(lower + 1, median);
213  node->right = buildFromPoints(median, upper);
214  }
215 
216  // Return result
217  return node;
218  }
219 
220  // Helper function that searches the tree
221  void search(Node* node, const T& target, int k, std::priority_queue<HeapItem>& heap)
222  {
223  if(node == NULL) return; // indicates that we're done here
224 
225  // Compute distance between target and current node
226  ScalarType dist = distance(_items[node->index], target);
227 
228  // If current node within radius tau
229  if(dist < _tau) {
230  if(heap.size() == static_cast<size_t>(k)) heap.pop(); // remove furthest node from result list (if we already have k results)
231  heap.push(HeapItem(node->index, dist)); // add current node to result list
232  if(heap.size() == static_cast<size_t>(k)) _tau = heap.top().dist; // update value of tau (farthest point in result list)
233  }
234 
235  // Return if we arrived at a leaf
236  if(node->left == NULL && node->right == NULL) {
237  return;
238  }
239 
240  // If the target lies within the radius of ball
241  if(dist < node->threshold) {
242  search(node->left, target, k, heap);
243 
244  if(dist + _tau >= node->threshold) { // if there can still be neighbors outside the ball, recursively search right child
245  search(node->right, target, k, heap);
246  }
247 
248  // If the target lies outsize the radius of the ball
249  } else {
250  search(node->right, target, k, heap);
251 
252  if (dist - _tau <= node->threshold) { // if there can still be neighbors inside the ball, recursively search left child
253  search(node->left, target, k, heap);
254  }
255  }
256  }
257 };
258 
259 }
260 
261 #endif
ScalarType distance(Callback &cb, const CoverTreePoint< RandomAccessIterator > &l, const CoverTreePoint< RandomAccessIterator > &r, ScalarType upper_bound)
Namespace containing implementation of t-SNE algorithm.
Definition: quadtree.hpp:41
ScalarType x(int d) const
DataPoint(const DataPoint &other)
HeapItem(int indexv, ScalarType distv)
void search(const T &target, int k, std::vector< T > *results, std::vector< ScalarType > *distances)
ScalarType uniform_random()
Definition: random.hpp:30
double ScalarType
default scalar value (can be overridden with the TAPKEE_CUSTOM_INTERNAL_NUMTYPE define) ...
Definition: types.hpp:15
std::vector< ScalarType > _x
void search(Node *node, const T &target, int k, std::priority_queue< HeapItem > &heap)
ScalarType euclidean_distance(const DataPoint &t1, const DataPoint &t2)
void create(const std::vector< T > &items)
bool operator<(const HeapItem &o) const
DataPoint(int Dv, int indv, ScalarType *xv)
DataPoint & operator=(const DataPoint &other)
Node * buildFromPoints(int lower, int upper)
static const NeighborsMethod VpTree("Vantage point tree")
Vantage point tree -based method.