Tapkee
external/barnes_hut_sne/vptree.hpp
Go to the documentation of this file.
00001 
00032 #include <stdlib.h>
00033 #include <algorithm>
00034 #include <vector>
00035 #include <stdio.h>
00036 #include <queue>
00037 #include <limits>
00038 
00039 #ifndef VPTREE_H
00040 #define VPTREE_H
00041 
00042 namespace tsne 
00043 {
00044 
/**
 * A point with an integer index and its own private copy of the coordinates.
 *
 * The coordinates are held in a std::vector, so the compiler-generated copy
 * constructor, copy assignment and destructor already perform a deep copy and
 * release the storage. In particular, copying or assigning a
 * default-constructed point is safe: it simply yields another point without
 * coordinates.
 *
 * A default-constructed point reports dimensionality() == 1 and index() == -1
 * but stores no coordinates, so x() must not be called on it.
 */
class DataPoint
{
    int _D;                  // dimensionality reported by dimensionality()
    int _ind;                // index supplied by the caller (-1 when default-constructed)
    std::vector<double> _x;  // owned coordinates; empty for a default-constructed point

public:
    /** Creates a placeholder point that holds no coordinates. */
    DataPoint() : _D(1), _ind(-1), _x() { }

    /**
     * Creates a point by copying Dv coordinates from xv.
     *
     * @param Dv    number of coordinates to copy; nothing is copied when Dv <= 0
     * @param indv  index to associate with the point
     * @param xv    array of at least Dv doubles; only read, never retained
     */
    DataPoint(int Dv, int indv, double* xv) : _D(Dv), _ind(indv), _x()
    {
        // Guard against non-positive sizes so the pointer range is always valid
        if(Dv > 0) _x.assign(xv, xv + Dv);
    }

    /** @return the index given at construction (-1 for a default-constructed point). */
    int index() const { return _ind; }

    /** @return the dimensionality given at construction (1 for a default-constructed point). */
    int dimensionality() const { return _D; }

    /**
     * @param d  coordinate position; no bounds check is performed, so d must
     *           address a coordinate that was copied in at construction
     * @return the d-th coordinate
     */
    double x(int d) const { return _x[d]; }
};
00082 
00083 
00084 double euclidean_distance(const DataPoint &t1, const DataPoint &t2) {
00085     double dd = .0;
00086     for(int d = 0; d < t1.dimensionality(); d++) dd += (t1.x(d) - t2.x(d)) * (t1.x(d) - t2.x(d));
00087     return dd;
00088 }
00089 
00090 
00091 template<typename T, double (*distance)( const T&, const T& )>
00092 class VpTree
00093 {
00094 public:
00095 
00096     // Default constructor
00097     VpTree() :  _items(), _tau(0.0), _root(0) {}
00098 
00099     // Destructor
00100     ~VpTree() {
00101         delete _root;
00102     }
00103 
00104     // Function to create a new VpTree from data
00105     void create(const std::vector<T>& items) {
00106         delete _root;
00107         _items = items;
00108         _root = buildFromPoints(0, items.size());
00109     }
00110 
00111     // Function that uses the tree to find the k nearest neighbors of target
00112     void search(const T& target, int k, std::vector<T>* results, std::vector<double>* distances)
00113     {
00114 
00115         // Use a priority queue to store intermediate results on
00116         std::priority_queue<HeapItem> heap;
00117 
00118         // Variable that tracks the distance to the farthest point in our results
00119         _tau = DBL_MAX;
00120 
00121         // Perform the searcg
00122         search(_root, target, k, heap);
00123 
00124         // Gather final results
00125         results->clear(); distances->clear();
00126         while(!heap.empty()) {
00127             results->push_back(_items[heap.top().index]);
00128             distances->push_back(heap.top().dist);
00129             heap.pop();
00130         }
00131 
00132         // Results are in reverse order
00133         std::reverse(results->begin(), results->end());
00134         std::reverse(distances->begin(), distances->end());
00135     }
00136 
00137 private:
00138 
00139     VpTree(const VpTree&);
00140     VpTree& operator=(const VpTree&);
00141 
00142     std::vector<T> _items;
00143     double _tau;
00144 
00145     // Single node of a VP tree (has a point and radius; left children are closer to point than the radius)
00146     struct Node
00147     {
00148         int index;              // index of point in node
00149         double threshold;       // radius(?)
00150         Node* left;             // points closer by than threshold
00151         Node* right;            // points farther away than threshold
00152 
00153         Node() : index(0), threshold(0.), left(0), right(0) {}
00154 
00155         ~Node() 
00156         {
00157             delete left;
00158             delete right;
00159         }
00160 
00161         Node(const Node&);
00162         Node& operator=(const Node&);
00163 
00164     }* _root;
00165 
00166 
00167     // An item on the intermediate result queue
00168     struct HeapItem {
00169         HeapItem(int indexv, double distv) :
00170             index(indexv), dist(distv) {}
00171         int index;
00172         double dist;
00173         bool operator<(const HeapItem& o) const {
00174             return dist < o.dist;
00175         }
00176     };
00177 
00178     // Distance comparator for use in std::nth_element
00179     struct DistanceComparator
00180     {
00181         const T& item;
00182         DistanceComparator(const T& itemv) : item(itemv) {}
00183         bool operator()(const T& a, const T& b) {
00184             return distance(item, a) < distance(item, b);
00185         }
00186     };
00187 
00188     // Function that (recursively) fills the tree
00189     Node* buildFromPoints( int lower, int upper )
00190     {
00191         if (upper == lower) {     // indicates that we're done here!
00192             return NULL;
00193         }
00194 
00195         // Lower index is center of current node
00196         Node* node = new Node();
00197         node->index = lower;
00198 
00199         if (upper - lower > 1) {      // if we did not arrive at leaf yet
00200 
00201             // Choose an arbitrary point and move it to the start
00202             int i = (int) (tapkee::uniform_random() * (upper - lower - 1)) + lower;
00203             std::swap(_items[lower], _items[i]);
00204 
00205             // Partition around the median distance
00206             int median = (upper + lower) / 2;
00207             std::nth_element(_items.begin() + lower + 1,
00208                     _items.begin() + median,
00209                     _items.begin() + upper,
00210                     DistanceComparator(_items[lower]));
00211 
00212             // Threshold of the new node will be the distance to the median
00213             node->threshold = distance(_items[lower], _items[median]);
00214 
00215             // Recursively build tree
00216             node->index = lower;
00217             node->left = buildFromPoints(lower + 1, median);
00218             node->right = buildFromPoints(median, upper);
00219         }
00220 
00221         // Return result
00222         return node;
00223     }
00224 
00225     // Helper function that searches the tree    
00226     void search(Node* node, const T& target, int k, std::priority_queue<HeapItem>& heap)
00227     {
00228         if(node == NULL) return;     // indicates that we're done here
00229 
00230         // Compute distance between target and current node
00231         double dist = distance(_items[node->index], target);
00232 
00233         // If current node within radius tau
00234         if(dist < _tau) {
00235             if(heap.size() == static_cast<size_t>(k)) heap.pop(); // remove furthest node from result list (if we already have k results)
00236             heap.push(HeapItem(node->index, dist));           // add current node to result list
00237             if(heap.size() == static_cast<size_t>(k)) _tau = heap.top().dist;     // update value of tau (farthest point in result list)
00238         }
00239 
00240         // Return if we arrived at a leaf
00241         if(node->left == NULL && node->right == NULL) {
00242             return;
00243         }
00244 
00245         // If the target lies within the radius of ball
00246         if(dist < node->threshold) {
00247             if(dist - _tau <= node->threshold) {         // if there can still be neighbors inside the ball, recursively search left child first
00248                 search(node->left, target, k, heap);
00249             }
00250 
00251             if(dist + _tau >= node->threshold) {         // if there can still be neighbors outside the ball, recursively search right child
00252                 search(node->right, target, k, heap);
00253             }
00254 
00255             // If the target lies outsize the radius of the ball
00256         } else {
00257             if(dist + _tau >= node->threshold) {         // if there can still be neighbors outside the ball, recursively search right child first
00258                 search(node->right, target, k, heap);
00259             }
00260 
00261             if (dist - _tau <= node->threshold) {         // if there can still be neighbors inside the ball, recursively search left child
00262                 search(node->left, target, k, heap);
00263             }
00264         }
00265     }
00266 };
00267 
00268 }
00269 
00270 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines