Tapkee
neighbors/vptree.hpp
Go to the documentation of this file.
00001 /* This software is distributed under BSD 3-clause license (see LICENSE file).
00002  *
00003  * Copyright (c) 2012-2013 Laurens van der Maaten, Sergey Lisitsyn
00004  */
00005 
00006 #ifndef TAPKEE_VPTREE_H_
00007 #define TAPKEE_VPTREE_H_
00008 
00009 /* Tapkee includes */
00010 #include <tapkee/defines.hpp>
00011 /* End of Tapkee includes */
00012 
00013 #include <vector>
00014 #include <queue>
00015 #include <algorithm>
00016 #include <limits>
00017 
00018 namespace tapkee
00019 {
00020 namespace tapkee_internal
00021 {
00022 
// Primary template; only the KernelType / DistanceType specializations below
// are ever defined.
template<class Type, class RandomAccessIterator, class DistanceCallback>
struct compare_impl;

// Comparison functor ordering iterators by closeness to a fixed reference
// item. The actual test is delegated to the compare_impl specialization
// selected by DistanceCallback::type, so the same tree-building code works
// for both kernel-valued and distance-valued callbacks.
template<class RandomAccessIterator, class DistanceCallback>
struct DistanceComparator
{
    // Callback evaluated by the selected compare_impl (stored by value).
    DistanceCallback callback;
    // Reference item all comparisons are made relative to.
    const RandomAccessIterator item;

    DistanceComparator(const DistanceCallback& c, const RandomAccessIterator& i) :
        callback(c), item(i)
    {
    }

    // True when a ranks before b, i.e. a is "closer" to item according to
    // compare_impl<DistanceCallback::type, ...>.
    inline bool operator()(const RandomAccessIterator& a, const RandomAccessIterator& b)
    {
        typedef typename DistanceCallback::type CallbackKind;
        typedef compare_impl<CallbackKind,RandomAccessIterator,DistanceCallback> Impl;
        Impl impl;
        return impl(callback,item,a,b);
    }
};
00039 
00040 struct KernelType;
00041 
00042 template<class RandomAccessIterator, class DistanceCallback> 
00043 struct compare_impl<KernelType,RandomAccessIterator,DistanceCallback>
00044 {
00045     inline bool operator()(DistanceCallback& callback, const RandomAccessIterator& item,
00046                            const RandomAccessIterator& a, const RandomAccessIterator& b)
00047     {
00048         return (-2*callback(item,a) + callback(a,a)) < (-2*callback(item,b) + callback(b,b));
00049     }
00050 };
00051 
00052 struct DistanceType;
00053 
00054 template<class RandomAccessIterator, class DistanceCallback> 
00055 struct compare_impl<DistanceType,RandomAccessIterator,DistanceCallback>
00056 {
00057     inline bool operator()(DistanceCallback& callback, const RandomAccessIterator& item,
00058                            const RandomAccessIterator& a, const RandomAccessIterator& b)
00059     {
00060         return callback(item,a) < callback(item,b);
00061     }
00062 };
00063 
00064 template<class RandomAccessIterator, class DistanceCallback>
00065 class VantagePointTree
00066 {
00067 public:
00068 
00069     // Default constructor
00070     VantagePointTree(RandomAccessIterator b, RandomAccessIterator e, DistanceCallback c) :  
00071         begin(b), items(), callback(c), tau(0.0), root(0)
00072     {
00073         items.reserve(e-b);
00074         for (RandomAccessIterator i=b; i!=e; ++i)
00075             items.push_back(i);
00076         root = buildFromPoints(0, items.size());
00077     }
00078 
00079     // Destructor
00080     ~VantagePointTree() 
00081     {
00082         delete root;
00083     }
00084 
00085     // Function that uses the tree to find the k nearest neighbors of target
00086     std::vector<IndexType> search(const RandomAccessIterator& target, int k)
00087     {
00088         std::vector<IndexType> results;
00089         // Use a priority queue to store intermediate results on
00090         std::priority_queue<HeapItem> heap;
00091 
00092         // Variable that tracks the distance to the farthest point in our results
00093         tau = std::numeric_limits<double>::max();
00094 
00095         // Perform the searcg
00096         search(root, target, k, heap);
00097 
00098         // Gather final results
00099         results.reserve(k);
00100         while(!heap.empty()) {
00101             results.push_back(items[heap.top().index]-begin);
00102             heap.pop();
00103         }
00104         return results;
00105     }
00106 
00107 private:
00108 
00109     VantagePointTree(const VantagePointTree&);
00110     VantagePointTree& operator=(const VantagePointTree&);
00111 
00112     RandomAccessIterator begin;
00113     std::vector<RandomAccessIterator> items;
00114     DistanceCallback callback;
00115     double tau;
00116 
00117     struct Node
00118     {
00119         int index;
00120         double threshold;
00121         Node* left;
00122         Node* right;
00123 
00124         Node() : 
00125             index(0), threshold(0.), 
00126             left(0), right(0) 
00127         {
00128         }
00129 
00130         ~Node() 
00131         {
00132             delete left;
00133             delete right;
00134         }
00135 
00136         Node(const Node&);
00137         Node& operator=(const Node&);
00138 
00139     }* root;
00140 
00141     struct HeapItem {
00142         HeapItem(int i, double d) :
00143             index(i), distance(d) {}
00144         int index;
00145         double distance;
00146         bool operator<(const HeapItem& o) const {
00147             return distance < o.distance;
00148         }
00149     };
00150 
00151 
00152     Node* buildFromPoints(int lower, int upper)
00153     {
00154         if (upper == lower)
00155         {
00156             return NULL;
00157         }
00158 
00159         Node* node = new Node();
00160         node->index = lower;
00161 
00162         if (upper - lower > 1)
00163         {
00164             int i = (int) (tapkee::uniform_random() * (upper - lower - 1)) + lower;
00165             std::swap(items[lower], items[i]);
00166 
00167             int median = (upper + lower) / 2;
00168             std::nth_element(items.begin() + lower + 1, items.begin() + median, items.begin() + upper, 
00169                 DistanceComparator<RandomAccessIterator,DistanceCallback>(callback,items[lower]));
00170 
00171             node->threshold = callback.distance(items[lower], items[median]);
00172             node->index = lower;
00173             node->left = buildFromPoints(lower + 1, median);
00174             node->right = buildFromPoints(median, upper);
00175         }
00176 
00177         return node;
00178     }
00179 
00180     void search(Node* node, const RandomAccessIterator& target, int k, std::priority_queue<HeapItem>& heap)
00181     {
00182         if (node == NULL) 
00183             return;
00184 
00185         double distance = callback.distance(items[node->index], target);
00186 
00187         if (distance < tau) 
00188         {
00189             if (heap.size() == static_cast<size_t>(k)) 
00190                 heap.pop();
00191 
00192             heap.push(HeapItem(node->index, distance));
00193 
00194             if (heap.size() == static_cast<size_t>(k))
00195                 tau = heap.top().distance;
00196         }
00197 
00198         if (node->left == NULL && node->right == NULL) 
00199         {
00200             return;
00201         }
00202 
00203         if (distance < node->threshold)
00204         {
00205             if ((distance - tau) <= node->threshold) 
00206                 search(node->left, target, k, heap);
00207 
00208             if ((distance + tau) >= node->threshold) 
00209                 search(node->right, target, k, heap);
00210         } 
00211         else
00212         {
00213             if ((distance + tau) >= node->threshold) 
00214                 search(node->right, target, k, heap);
00215 
00216             if ((distance - tau) <= node->threshold) 
00217                 search(node->left, target, k, heap);
00218         }
00219     }
00220 };
00221 
00222 }
00223 }
00224 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines