00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037 #ifndef OMPL_DATASTRUCTURES_NEAREST_NEIGHBORS_GNAT_
00038 #define OMPL_DATASTRUCTURES_NEAREST_NEIGHBORS_GNAT_
00039
00040 #include "ompl/datastructures/NearestNeighbors.h"
00041 #include "ompl/datastructures/GreedyKCenters.h"
00042 #include "ompl/util/Exception.h"
00043 #include <boost/unordered_set.hpp>
00044 #include <queue>
00045 #include <algorithm>
00046
00047 namespace ompl
00048 {
00049
00058 template<typename _T>
00059 class NearestNeighborsGNAT : public NearestNeighbors<_T>
00060 {
00061 protected:
00063
00064
00065 typedef std::pair<const _T*,double> DataDist;
00066 struct DataDistCompare
00067 {
00068 bool operator()(const DataDist& d0, const DataDist& d1)
00069 {
00070 return d0.second < d1.second;
00071 }
00072 };
00073 typedef std::priority_queue<DataDist, std::vector<DataDist>, DataDistCompare> NearQueue;
00074
00075
00076
00077 class Node;
00078 typedef std::pair<Node*,double> NodeDist;
00079 struct NodeDistCompare
00080 {
00081 bool operator()(const NodeDist& n0, const NodeDist& n1) const
00082 {
00083 return (n0.second - n0.first->maxRadius_) > (n1.second - n1.first->maxRadius_);
00084 }
00085 };
00086 typedef std::priority_queue<NodeDist, std::vector<NodeDist>, NodeDistCompare> NodeQueue;
00088
00089 public:
00090 NearestNeighborsGNAT(unsigned int degree = 4, unsigned int minDegree = 2,
00091 unsigned int maxDegree = 6, unsigned int maxNumPtsPerLeaf = 50,
00092 unsigned int removedCacheSize = 50, bool rebalancing = false)
00093 : NearestNeighbors<_T>(), tree_(NULL), degree_(degree),
00094 minDegree_(std::min(degree,minDegree)), maxDegree_(std::max(maxDegree,degree)),
00095 maxNumPtsPerLeaf_(maxNumPtsPerLeaf), size_(0),
00096 rebuildSize_(rebalancing ? maxNumPtsPerLeaf*degree : std::numeric_limits<std::size_t>::max()),
00097 removedCacheSize_(removedCacheSize)
00098 {
00099 }
00100
00101 virtual ~NearestNeighborsGNAT(void)
00102 {
00103 if (tree_)
00104 delete tree_;
00105 }
00107 virtual void setDistanceFunction(const typename NearestNeighbors<_T>::DistanceFunction &distFun)
00108 {
00109 NearestNeighbors<_T>::setDistanceFunction(distFun);
00110 pivotSelector_.setDistanceFunction(distFun);
00111 }
00112 virtual void clear(void)
00113 {
00114 if (tree_)
00115 {
00116 delete tree_;
00117 tree_ = NULL;
00118 }
00119 size_ = 0;
00120 removed_.clear();
00121 }
00122
00123 virtual void add(const _T &data)
00124 {
00125 if (tree_)
00126 tree_->add(*this, data);
00127 else
00128 {
00129 tree_ = new Node(degree_, maxNumPtsPerLeaf_, data);
00130 size_ = 1;
00131 }
00132 }
00133 virtual void add(const std::vector<_T> &data)
00134 {
00135 if (tree_)
00136 NearestNeighbors<_T>::add(data);
00137 else if (data.size()>0)
00138 {
00139 tree_ = new Node(degree_, maxNumPtsPerLeaf_, data[0]);
00140 for (unsigned int i=1; i<data.size(); ++i)
00141 tree_->data_.push_back(data[i]);
00142 if (tree_->needToSplit(*this))
00143 tree_->split(*this);
00144 }
00145 size_ += data.size();
00146 }
00148 void rebuildDataStructure()
00149 {
00150 std::vector<_T> lst;
00151 list(lst);
00152 clear();
00153 add(lst);
00154 }
00160 virtual bool remove(const _T &data)
00161 {
00162 if (!tree_) return false;
00163 NearQueue nbhQueue;
00164
00165 bool isPivot = nearestKInternal(data, 1, nbhQueue);
00166 if (*nbhQueue.top().first != data)
00167 return false;
00168 removed_.insert(nbhQueue.top().first);
00169 size_--;
00170
00171
00172 if (isPivot || removed_.size()>=removedCacheSize_)
00173 rebuildDataStructure();
00174 return true;
00175 }
00176
00177 virtual _T nearest(const _T &data) const
00178 {
00179 if (tree_)
00180 {
00181 std::vector<_T> nbh;
00182 nearestK(data, 1, nbh);
00183 if (!nbh.empty()) return nbh[0];
00184 }
00185 throw Exception("No elements found");
00186 }
00187
00188 virtual void nearestK(const _T &data, std::size_t k, std::vector<_T> &nbh) const
00189 {
00190 nbh.clear();
00191 if (k == 0) return;
00192 if (tree_)
00193 {
00194 NearQueue nbhQueue;
00195 nearestKInternal(data, k, nbhQueue);
00196 postprocessNearest(nbhQueue, nbh);
00197 }
00198 }
00199
00200 virtual void nearestR(const _T &data, double radius, std::vector<_T> &nbh) const
00201 {
00202 nbh.clear();
00203 if (tree_)
00204 {
00205 NearQueue nbhQueue;
00206 nearestRInternal(data, radius, nbhQueue);
00207 postprocessNearest(nbhQueue, nbh);
00208 }
00209 }
00210
00211 virtual std::size_t size(void) const
00212 {
00213 return size_;
00214 }
00215
00216 virtual void list(std::vector<_T> &data) const
00217 {
00218 data.clear();
00219 data.reserve(size());
00220 if (tree_)
00221 tree_->list(*this, data);
00222 }
00223
00225 friend std::ostream& operator<<(std::ostream& out, const NearestNeighborsGNAT<_T>& gnat)
00226 {
00227 if (gnat.tree_)
00228 {
00229 out << *gnat.tree_;
00230 if (!gnat.removed_.empty())
00231 {
00232 out << "Elements marked for removal:\n";
00233 for (typename boost::unordered_set<const _T*>::const_iterator it = gnat.removed_.begin();
00234 it != gnat.removed_.end(); it++)
00235 out << **it << '\t';
00236 out << std::endl;
00237 }
00238 }
00239 return out;
00240 }
00241
00242
00243 void integrityCheck()
00244 {
00245 std::vector<_T> lst;
00246 boost::unordered_set<const _T*> tmp;
00247
00248 removed_.swap(tmp);
00249 list(lst);
00250
00251 for (typename boost::unordered_set<const _T*>::iterator it=tmp.begin(); it!=tmp.end(); it++)
00252 {
00253 unsigned int i;
00254 for (i=0; i<lst.size(); ++i)
00255 if (lst[i]==**it)
00256 break;
00257 if (i == lst.size())
00258 {
00259
00260 std::cout << "***** FAIL!! ******\n" << *this << '\n';
00261 for (unsigned int j=0; j<lst.size(); ++j) std::cout<<lst[j]<<'\t';
00262 std::cout<<std::endl;
00263 }
00264 assert(i != lst.size());
00265 }
00266
00267 removed_.swap(tmp);
00268
00269 list(lst);
00270 if (lst.size() != size_)
00271 std::cout << "#########################################\n" << *this << std::endl;
00272 assert(lst.size() == size_);
00273 }
00274 protected:
00275 typedef NearestNeighborsGNAT<_T> GNAT;
00276
00278 bool isRemoved(const _T& data) const
00279 {
00280 return !removed_.empty() && removed_.find(&data) != removed_.end();
00281 }
00282
00287 bool nearestKInternal(const _T &data, std::size_t k, NearQueue& nbhQueue) const
00288 {
00289 bool isPivot;
00290 double dist;
00291 NodeDist nodeDist;
00292 NodeQueue nodeQueue;
00293
00294 isPivot = tree_->insertNeighborK(nbhQueue, k, tree_->pivot_, data,
00295 NearestNeighbors<_T>::distFun_(data, tree_->pivot_));
00296 tree_->nearestK(*this, data, k, nbhQueue, nodeQueue, isPivot);
00297 while (nodeQueue.size() > 0)
00298 {
00299 dist = nbhQueue.top().second;
00300 nodeDist = nodeQueue.top();
00301 nodeQueue.pop();
00302 if (nbhQueue.size() == k &&
00303 (nodeDist.second > nodeDist.first->maxRadius_ + dist ||
00304 nodeDist.second < nodeDist.first->minRadius_ - dist))
00305 break;
00306 nodeDist.first->nearestK(*this, data, k, nbhQueue, nodeQueue, isPivot);
00307 }
00308 return isPivot;
00309 }
00311 void nearestRInternal(const _T &data, double radius, NearQueue& nbhQueue) const
00312 {
00313 double dist = radius;
00314 NodeQueue nodeQueue;
00315 NodeDist nodeDist;
00316
00317 tree_->insertNeighborR(nbhQueue, radius, tree_->pivot_,
00318 NearestNeighbors<_T>::distFun_(data, tree_->pivot_));
00319 tree_->nearestR(*this, data, radius, nbhQueue, nodeQueue);
00320 while (nodeQueue.size() > 0)
00321 {
00322 nodeDist = nodeQueue.top();
00323 nodeQueue.pop();
00324 if (nodeDist.second > nodeDist.first->maxRadius_ + dist ||
00325 nodeDist.second < nodeDist.first->minRadius_ - dist)
00326 break;
00327 nodeDist.first->nearestR(*this, data, radius, nbhQueue, nodeQueue);
00328 }
00329 }
00332 void postprocessNearest(NearQueue& nbhQueue, std::vector<_T> &nbh) const
00333 {
00334 typename std::vector<_T>::reverse_iterator it;
00335 nbh.resize(nbhQueue.size());
00336 for (it=nbh.rbegin(); it!=nbh.rend(); it++, nbhQueue.pop())
00337 *it = *nbhQueue.top().first;
00338 }
00339
00341 class Node
00342 {
00343 public:
00346 Node(int degree, int capacity, const _T& pivot)
00347 : degree_(degree), pivot_(pivot),
00348 minRadius_(std::numeric_limits<double>::infinity()),
00349 maxRadius_(-minRadius_), minRange_(degree, minRadius_),
00350 maxRange_(degree, maxRadius_)
00351 {
00352
00353 data_.reserve(capacity+1);
00354 }
00355
00356 ~Node()
00357 {
00358 for (unsigned int i=0; i<children_.size(); ++i)
00359 delete children_[i];
00360 }
00361
00364 void updateRadius(double dist)
00365 {
00366 if (minRadius_ > dist)
00367 minRadius_ = dist;
00368 if (maxRadius_ < dist)
00369 maxRadius_ = dist;
00370 }
00374 void updateRange(unsigned int i, double dist)
00375 {
00376 if (minRange_[i] > dist)
00377 minRange_[i] = dist;
00378 if (maxRange_[i] < dist)
00379 maxRange_[i] = dist;
00380 }
00382 void add(GNAT& gnat, const _T& data)
00383 {
00384 if (children_.size()==0)
00385 {
00386 data_.push_back(data);
00387 gnat.size_++;
00388 if (needToSplit(gnat))
00389 {
00390 if (gnat.removed_.size() > 0)
00391 gnat.rebuildDataStructure();
00392 else if (gnat.size_ >= gnat.rebuildSize_)
00393 {
00394 gnat.rebuildSize_ <<= 1;
00395 gnat.rebuildDataStructure();
00396 }
00397 else
00398 split(gnat);
00399 }
00400 }
00401 else
00402 {
00403 std::vector<double> dist(children_.size());
00404 double minDist = dist[0] = gnat.distFun_(data, children_[0]->pivot_);
00405 int minInd = 0;
00406
00407 for (unsigned int i=1; i<children_.size(); ++i)
00408 if ((dist[i] = gnat.distFun_(data, children_[i]->pivot_)) < minDist)
00409 {
00410 minDist = dist[i];
00411 minInd = i;
00412 }
00413 for (unsigned int i=0; i<children_.size(); ++i)
00414 children_[i]->updateRange(minInd, dist[i]);
00415 children_[minInd]->updateRadius(minDist);
00416 children_[minInd]->add(gnat, data);
00417 }
00418 }
00420 bool needToSplit(const GNAT& gnat) const
00421 {
00422 unsigned int sz = data_.size();
00423 return sz > gnat.maxNumPtsPerLeaf_ && sz > degree_;
00424 }
00428 void split(GNAT& gnat)
00429 {
00430 std::vector<std::vector<double> > dists;
00431 std::vector<unsigned int> pivots;
00432
00433 children_.reserve(degree_);
00434 gnat.pivotSelector_.kcenters(data_, degree_, pivots, dists);
00435 for(unsigned int i=0; i<pivots.size(); i++)
00436 children_.push_back(new Node(degree_, gnat.maxNumPtsPerLeaf_, data_[pivots[i]]));
00437 degree_ = pivots.size();
00438 for (unsigned int j=0; j<data_.size(); ++j)
00439 {
00440 unsigned int k = 0;
00441 for (unsigned int i=1; i<degree_; ++i)
00442 if (dists[j][i] < dists[j][k])
00443 k = i;
00444 Node* child = children_[k];
00445 if (j != pivots[k])
00446 {
00447 child->data_.push_back(data_[j]);
00448 child->updateRadius(dists[j][k]);
00449 }
00450 for (unsigned int i=0; i<degree_; ++i)
00451 children_[i]->updateRange(k, dists[j][i]);
00452 }
00453
00454 for (unsigned int i=0; i<degree_; ++i)
00455 {
00456
00457 children_[i]->degree_ = std::min(std::max(
00458 degree_ * (unsigned int)(children_[i]->data_.size() / data_.size()),
00459 gnat.minDegree_), gnat.maxDegree_);
00460
00461 if (children_[i]->minRadius_ == std::numeric_limits<double>::infinity())
00462 children_[i]->minRadius_ = children_[i]->maxRadius_ = 0.;
00463 }
00464
00465 std::vector<_T> tmp;
00466 data_.swap(tmp);
00467
00468 for (unsigned int i=0; i<degree_; ++i)
00469 if (children_[i]->needToSplit(gnat))
00470 children_[i]->split(gnat);
00471 }
00472
00474 bool insertNeighborK(NearQueue& nbh, std::size_t k, const _T& data, const _T& key, double dist) const
00475 {
00476 if (nbh.size() < k)
00477 {
00478 nbh.push(std::make_pair(&data, dist));
00479 return true;
00480 }
00481 else if (dist < nbh.top().second ||
00482 (dist < std::numeric_limits<double>::epsilon() && data==key))
00483 {
00484 nbh.pop();
00485 nbh.push(std::make_pair(&data, dist));
00486 return true;
00487 }
00488 return false;
00489 }
00490
00496 void nearestK(const GNAT& gnat, const _T &data, std::size_t k,
00497 NearQueue& nbh, NodeQueue& nodeQueue, bool& isPivot) const
00498 {
00499 for (unsigned int i=0; i<data_.size(); ++i)
00500 if (!gnat.isRemoved(data_[i]))
00501 {
00502 if (insertNeighborK(nbh, k, data_[i], data, gnat.distFun_(data, data_[i])))
00503 isPivot = false;
00504 }
00505 if (children_.size() > 0)
00506 {
00507 double dist;
00508 Node* child;
00509 std::vector<double> distToPivot(children_.size());
00510 std::vector<int> permutation(children_.size());
00511
00512 for (unsigned int i=0; i<permutation.size(); ++i)
00513 permutation[i] = i;
00514 std::random_shuffle(permutation.begin(), permutation.end());
00515
00516 for (unsigned int i=0; i<children_.size(); ++i)
00517 if (permutation[i] >= 0)
00518 {
00519 child = children_[permutation[i]];
00520 distToPivot[permutation[i]] = gnat.distFun_(data, child->pivot_);
00521 if (insertNeighborK(nbh, k, child->pivot_, data, distToPivot[permutation[i]]))
00522 isPivot = true;
00523 if (nbh.size()==k)
00524 {
00525 dist = nbh.top().second;
00526 for (unsigned int j=0; j<children_.size(); ++j)
00527 if (permutation[j] >=0 && i != j &&
00528 (distToPivot[permutation[i]] - dist > child->maxRange_[permutation[j]] ||
00529 distToPivot[permutation[i]] + dist < child->minRange_[permutation[j]]))
00530 permutation[j] = -1;
00531 }
00532 }
00533
00534 dist = nbh.top().second;
00535 for (unsigned int i=0; i<children_.size(); ++i)
00536 if (permutation[i] >= 0)
00537 {
00538 child = children_[permutation[i]];
00539 if (nbh.size()<k ||
00540 (distToPivot[permutation[i]] - dist <= child->maxRadius_ &&
00541 distToPivot[permutation[i]] + dist >= child->minRadius_))
00542 nodeQueue.push(std::make_pair(child, distToPivot[permutation[i]]));
00543 }
00544 }
00545 }
00547 void insertNeighborR(NearQueue& nbh, double r, const _T& data, double dist) const
00548 {
00549 if (dist <= r)
00550 nbh.push(std::make_pair(&data, dist));
00551 }
00555 void nearestR(const GNAT& gnat, const _T &data, double r, NearQueue& nbh, NodeQueue& nodeQueue) const
00556 {
00557 double dist = r;
00558
00559 for (unsigned int i=0; i<data_.size(); ++i)
00560 if (!gnat.isRemoved(data_[i]))
00561 insertNeighborR(nbh, r, data_[i], gnat.distFun_(data, data_[i]));
00562 if (children_.size() > 0)
00563 {
00564 Node* child;
00565 std::vector<double> distToPivot(children_.size());
00566 std::vector<int> permutation(children_.size());
00567
00568 for (unsigned int i=0; i<permutation.size(); ++i)
00569 permutation[i] = i;
00570 std::random_shuffle(permutation.begin(), permutation.end());
00571
00572 for (unsigned int i=0; i<children_.size(); ++i)
00573 if (permutation[i] >= 0)
00574 {
00575 child = children_[permutation[i]];
00576 distToPivot[i] = gnat.distFun_(data, child->pivot_);
00577 insertNeighborR(nbh, r, child->pivot_, distToPivot[i]);
00578 for (unsigned int j=0; j<children_.size(); ++j)
00579 if (permutation[j] >=0 && i != j &&
00580 (distToPivot[i] - dist > child->maxRange_[permutation[j]] ||
00581 distToPivot[i] + dist < child->minRange_[permutation[j]]))
00582 permutation[j] = -1;
00583 }
00584
00585 for (unsigned int i=0; i<children_.size(); ++i)
00586 if (permutation[i] >= 0)
00587 {
00588 child = children_[permutation[i]];
00589 if (distToPivot[i] - dist <= child->maxRadius_ &&
00590 distToPivot[i] + dist >= child->minRadius_)
00591 nodeQueue.push(std::make_pair(child, distToPivot[i]));
00592 }
00593 }
00594 }
00595
00596 void list(const GNAT& gnat, std::vector<_T> &data) const
00597 {
00598 if (!gnat.isRemoved(pivot_))
00599 data.push_back(pivot_);
00600 for (unsigned int i=0; i<data_.size(); ++i)
00601 if(!gnat.isRemoved(data_[i]))
00602 data.push_back(data_[i]);
00603 for (unsigned int i=0; i<children_.size(); ++i)
00604 children_[i]->list(gnat, data);
00605 }
00606
00607 friend std::ostream& operator<<(std::ostream& out, const Node& node)
00608 {
00609 out << "\ndegree:\t" << node.degree_;
00610 out << "\nminRadius:\t" << node.minRadius_;
00611 out << "\nmaxRadius:\t" << node.maxRadius_;
00612 out << "\nminRange:\t";
00613 for (unsigned int i=0; i<node.minRange_.size(); ++i)
00614 out << node.minRange_[i] << '\t';
00615 out << "\nmaxRange: ";
00616 for (unsigned int i=0; i<node.maxRange_.size(); ++i)
00617 out << node.maxRange_[i] << '\t';
00618 out << "\npivot:\t" << node.pivot_;
00619 out << "\ndata: ";
00620 for (unsigned int i=0; i<node.data_.size(); ++i)
00621 out << node.data_[i] << '\t';
00622 out << "\nthis:\t" << &node;
00623 out << "\nchildren:\n";
00624 for (unsigned int i=0; i<node.children_.size(); ++i)
00625 out << node.children_[i] << '\t';
00626 out << '\n';
00627 for (unsigned int i=0; i<node.children_.size(); ++i)
00628 out << *node.children_[i] << '\n';
00629 return out;
00630 }
00631
00633 unsigned int degree_;
00635 const _T pivot_;
00637 double minRadius_;
00639 double maxRadius_;
00642 std::vector<double> minRange_;
00645 std::vector<double> maxRange_;
00648 std::vector<_T> data_;
00651 std::vector<Node*> children_;
00652 };
00653
00655 Node* tree_;
00657 unsigned int degree_;
00662 unsigned int minDegree_;
00667 unsigned int maxDegree_;
00670 unsigned int maxNumPtsPerLeaf_;
00672 std::size_t size_;
00675 std::size_t rebuildSize_;
00679 std::size_t removedCacheSize_;
00681 GreedyKCenters<_T> pivotSelector_;
00683 boost::unordered_set<const _T*> removed_;
00684 };
00685
00686 }
00687
00688 #endif