Tapkee
neighbors.hpp
Go to the documentation of this file.
1 /* This software is distributed under BSD 3-clause license (see LICENSE file).
2  *
3  * Copyright (c) 2012-2013 Sergey Lisitsyn, Fernando J. Iglesias Garcia
4  */
5 
6 #ifndef TAPKEE_NEIGHBORS_H_
7 #define TAPKEE_NEIGHBORS_H_
8 
9 /* Tapkee includes */
10 #include <tapkee/defines.hpp>
11 #ifdef TAPKEE_USE_LGPL_COVERTREE
13 #endif
16 /* End of Tapkee includes */
17 
18 #include <vector>
19 #include <utility>
20 #include <algorithm>
21 
22 namespace tapkee
23 {
24 namespace tapkee_internal
25 {
26 
/** Comparator for distance records: a record goes before another one
 * when its `second` member (the distance) is smaller.
 *
 * NOTE(review): the declaration line of this struct was missing from the
 * listing; the name `distances_comparator` is a reconstruction - confirm
 * it against the upstream header.
 */
template <class DistanceRecord>
struct distances_comparator
{
	inline bool operator()(const DistanceRecord& l, const DistanceRecord& r) const
	{
		return (l.second < r.second);
	}
};
35 
/** Empty tag type; KernelDistance publishes it as its nested `type` typedef. */
struct KernelType { };
39 
40 template <class RandomAccessIterator, class Callback>
42 {
43  KernelDistance(const Callback& cb) : callback(cb) { }
44  inline ScalarType operator()(const RandomAccessIterator& l, const RandomAccessIterator& r)
45  {
46  return callback.kernel(*l,*r);
47  }
48  inline ScalarType distance(const RandomAccessIterator& l, const RandomAccessIterator& r)
49  {
50  return sqrt(callback.kernel(*l,*l) - 2*callback.kernel(*l,*r) + callback.kernel(*r,*r));
51  }
52  typedef KernelType type;
53  Callback callback;
54 };
55 
/** Empty tag type; PlainDistance publishes it as its nested `type` typedef. */
struct DistanceType
{
};
59 
60 template <class RandomAccessIterator, class Callback>
62 {
63  PlainDistance(const Callback& cb) : callback(cb) { }
64  inline ScalarType operator()(const RandomAccessIterator& l, const RandomAccessIterator& r)
65  {
66  return callback.distance(*l,*r);
67  }
68  inline ScalarType distance(const RandomAccessIterator& l, const RandomAccessIterator& r)
69  {
70  return callback.distance(*l,*r);
71  }
72  typedef DistanceType type;
73  Callback callback;
74 };
75 
76 #ifdef TAPKEE_USE_LGPL_COVERTREE
77 template <class RandomAccessIterator, class Callback>
78 Neighbors find_neighbors_covertree_impl(RandomAccessIterator begin, RandomAccessIterator end,
79  Callback callback, IndexType k)
80 {
81  timed_context context("Covertree-based neighbors search");
82 
83  typedef CoverTreePoint<RandomAccessIterator> TreePoint;
84  v_array<TreePoint> points;
85  for (RandomAccessIterator iter=begin; iter!=end; ++iter)
86  push(points, TreePoint(iter, callback(iter,iter)));
87 
88  node<TreePoint> ct = batch_create(callback, points);
89 
91  ++k; // because one of the neighbors will be the actual query point
92  k_nearest_neighbor(callback,ct,ct,res,k);
93 
94  Neighbors neighbors;
95  neighbors.resize(end-begin);
96  assert(end-begin==res.index);
97  for (int i=0; i<res.index; ++i)
98  {
99  LocalNeighbors local_neighbors;
100  local_neighbors.reserve(k);
101 
102  for (IndexType j=1; j<=k; ++j) // j=0 is the query point
103  {
104  // The actual query point is found as a neighbor, just ignore it
105  if (res[i][j].iter_-begin==res[i][0].iter_-begin)
106  continue;
107  local_neighbors.push_back(res[i][j].iter_-begin);
108  }
109  neighbors[res[i][0].iter_-begin] = local_neighbors;
110  free(res[i].elements);
111  };
112  free(res.elements);
113  free_children(ct);
114  free(points.elements);
115  return neighbors;
116 }
117 #endif
118 
119 template <class RandomAccessIterator, class Callback>
120 Neighbors find_neighbors_bruteforce_impl(const RandomAccessIterator& begin, const RandomAccessIterator& end,
121  Callback callback, IndexType k)
122 {
123  timed_context context("Distance sorting based neighbors search");
124  typedef std::pair<RandomAccessIterator, ScalarType> DistanceRecord;
125  typedef std::vector<DistanceRecord> Distances;
126 
127  Neighbors neighbors;
128  neighbors.reserve(end-begin);
129  for (RandomAccessIterator iter=begin; iter!=end; ++iter)
130  {
131  Distances distances;
132  for (RandomAccessIterator around_iter=begin; around_iter!=end; ++around_iter)
133  distances.push_back(std::make_pair(around_iter, callback.distance(iter,around_iter)));
134 
135  std::nth_element(distances.begin(),distances.begin()+k+1,distances.end(),
137 
138  LocalNeighbors local_neighbors;
139  local_neighbors.reserve(k);
140  for (typename Distances::const_iterator neighbors_iter=distances.begin();
141  neighbors_iter!=distances.begin()+k+1; ++neighbors_iter)
142  {
143  if (neighbors_iter->first != iter)
144  local_neighbors.push_back(neighbors_iter->first - begin);
145  }
146  neighbors.push_back(local_neighbors);
147  }
148  return neighbors;
149 }
150 
151 template <class RandomAccessIterator, class Callback>
152 Neighbors find_neighbors_vptree_impl(const RandomAccessIterator& begin, const RandomAccessIterator& end,
153  Callback callback, IndexType k)
154 {
155  timed_context context("VP-Tree based neighbors search");
156 
157  Neighbors neighbors;
158  neighbors.reserve(end-begin);
159 
160  VantagePointTree<RandomAccessIterator,Callback> tree(begin,end,callback);
161 
162  for (RandomAccessIterator i=begin; i!=end; ++i)
163  {
164  LocalNeighbors local_neighbors = tree.search(i,k+1);
165  std::remove(local_neighbors.begin(),local_neighbors.end(),i-begin);
166  neighbors.push_back(local_neighbors);
167  }
168 
169  return neighbors;
170 }
171 
172 template <class RandomAccessIterator, class Callback>
173 Neighbors find_neighbors(NeighborsMethod method, const RandomAccessIterator& begin,
174  const RandomAccessIterator& end, const Callback& callback,
175  IndexType k, bool check_connectivity)
176 {
177  if (k > static_cast<IndexType>(end-begin-1))
178  {
179  LoggingSingleton::instance().message_warning("Number of neighbors is greater than number of objects to embed. "
180  "Using greatest possible number of neighbors.");
181  k = static_cast<IndexType>(end-begin-1);
182  }
183  LoggingSingleton::instance().message_info("Using the " + get_neighbors_method_name(method) + " neighbors computation method.");
184 
185  Neighbors neighbors;
186  if (method.is(Brute))
187  neighbors = find_neighbors_bruteforce_impl(begin,end,callback,k);
188  if (method.is(VpTree))
189  neighbors = find_neighbors_vptree_impl(begin,end,callback,k);
190 #ifdef TAPKEE_USE_LGPL_COVERTREE
191  if (method.is(CoverTree))
192  neighbors = find_neighbors_covertree_impl(begin,end,callback,k);
193 #endif
194 
195  if (check_connectivity)
196  {
197  if (!is_connected(begin,end,neighbors))
198  LoggingSingleton::instance().message_warning("The neighborhood graph is not connected.");
199  }
200  return neighbors;
201 }
202 
203 } // End of namespace tapkee_internal
204 } // End of namespace tapkee
205 
206 #endif
ScalarType distance(const RandomAccessIterator &l, const RandomAccessIterator &r)
Definition: neighbors.hpp:68
bool operator()(const DistanceRecord &l, const DistanceRecord &r) const
Definition: neighbors.hpp:30
Neighbors find_neighbors(NeighborsMethod method, const RandomAccessIterator &begin, const RandomAccessIterator &end, const Callback &callback, IndexType k, bool check_connectivity)
Definition: neighbors.hpp:173
void k_nearest_neighbor(DistanceCallback &dcb, const node< P > &top_node, const node< P > &query, v_array< v_array< P > > &results, int k)
Definition: covertree.hpp:828
ScalarType operator()(const RandomAccessIterator &l, const RandomAccessIterator &r)
Definition: neighbors.hpp:44
node< P > batch_create(DistanceCallback &dcb, v_array< P > points)
Definition: covertree.hpp:299
ScalarType distance(const RandomAccessIterator &l, const RandomAccessIterator &r)
Definition: neighbors.hpp:48
static const NeighborsMethod Brute("Brute-force")
Brute-force method with no less than $O(N^2)$ time complexity. Recommended to be used only for debugging purposes...
Class v_array taken directly from JL's implementation.
Neighbors find_neighbors_covertree_impl(RandomAccessIterator begin, RandomAccessIterator end, Callback callback, IndexType k)
Definition: neighbors.hpp:78
double ScalarType
default scalar value (can be overridden with the TAPKEE_CUSTOM_INTERNAL_NUMTYPE define) ...
Definition: types.hpp:15
void free_children(const node< P > &n)
Definition: covertree.hpp:69
TAPKEE_INTERNAL_VECTOR< tapkee::IndexType > LocalNeighbors
Definition: synonyms.hpp:39
int IndexType
indexing type (non-overridable) set to int for compatibility with OpenMP 2.0
Definition: types.hpp:19
ScalarType operator()(const RandomAccessIterator &l, const RandomAccessIterator &r)
Definition: neighbors.hpp:64
void push(v_array< T > &v, const T &new_ele)
void message_info(const std::string &msg)
Definition: logging.hpp:115
TAPKEE_INTERNAL_VECTOR< tapkee::tapkee_internal::LocalNeighbors > Neighbors
Definition: synonyms.hpp:40
Neighbors find_neighbors_vptree_impl(const RandomAccessIterator &begin, const RandomAccessIterator &end, Callback callback, IndexType k)
Definition: neighbors.hpp:152
bool is_connected(RandomAccessIterator begin, RandomAccessIterator end, const Neighbors &neighbors)
Definition: connected.hpp:18
void message_warning(const std::string &msg)
Definition: logging.hpp:116
static LoggingSingleton & instance()
Definition: logging.hpp:102
std::vector< IndexType > search(const RandomAccessIterator &target, int k)
Class Point to use with John Langford's CoverTree. This class must have some associated functions def...
std::string get_neighbors_method_name(const NeighborsMethod &m)
Definition: naming.hpp:42
static const NeighborsMethod VpTree("Vantage point tree")
Vantage point tree-based method.
bool is(const M &m) const
static const NeighborsMethod CoverTree("Cover tree")
Covertree-based method with approximately $O(\log N)$ time complexity. Recommended to be used as a default method...
Neighbors find_neighbors_bruteforce_impl(const RandomAccessIterator &begin, const RandomAccessIterator &end, Callback callback, IndexType k)
Definition: neighbors.hpp:120