LCOV - code coverage report
Current view: top level - external/barnes_hut_sne - vptree.hpp (source / functions) Hit Total Coverage
Test: clean.info Lines: 94 94 100.0 %
Date: 2013-05-24 Functions: 21 21 100.0 %
Branches: 63 88 71.6 %

           Branch data     Line data    Source code
       1                 :            : /**
       2                 :            :  * Copyright (c) 2013, Laurens van der Maaten (Delft University of Technology)
       3                 :            :  * All rights reserved.
       4                 :            :  *
       5                 :            :  * Redistribution and use in source and binary forms, with or without
       6                 :            :  * modification, are permitted provided that the following conditions are met:
       7                 :            :  * 1. Redistributions of source code must retain the above copyright
       8                 :            :  *    notice, this list of conditions and the following disclaimer.
       9                 :            :  * 2. Redistributions in binary form must reproduce the above copyright
      10                 :            :  *    notice, this list of conditions and the following disclaimer in the
      11                 :            :  *    documentation and/or other materials provided with the distribution.
      12                 :            :  * 3. All advertising materials mentioning features or use of this software
      13                 :            :  *    must display the following acknowledgement:
      14                 :            :  *    This product includes software developed by the Delft University of Technology.
      15                 :            :  * 4. Neither the name of the Delft University of Technology nor the names of
      16                 :            :  *    its contributors may be used to endorse or promote products derived from
      17                 :            :  *    this software without specific prior written permission.
      18                 :            :  *
      19                 :            :  * THIS SOFTWARE IS PROVIDED BY LAURENS VAN DER MAATEN ''AS IS'' AND ANY EXPRESS
      20                 :            :  * OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
      21                 :            :  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO
      22                 :            :  * EVENT SHALL LAURENS VAN DER MAATEN BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
      23                 :            :  * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
      24                 :            :  * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
      25                 :            :  * BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
      26                 :            :  * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING
      27                 :            :  * IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY
      28                 :            :  * OF SUCH DAMAGE.
      29                 :            :  *
      30                 :            :  */
      31                 :            : 
      32                 :            : #include <stdlib.h>
      33                 :            : #include <algorithm>
      34                 :            : #include <vector>
      35                 :            : #include <stdio.h>
      36                 :            : #include <queue>
      37                 :            : #include <limits>
      38                 :            : 
      39                 :            : #ifndef VPTREE_H
      40                 :            : #define VPTREE_H
      41                 :            : 
      42                 :            : namespace tsne 
      43                 :            : {
      44                 :            : 
      45                 :            : class DataPoint
      46                 :            : {
      47                 :            :         int _D;
      48                 :            :         int _ind;
      49                 :            :         double* _x;
      50                 :            : 
      51                 :            : public:
      52                 :            :         DataPoint() : _D(1), _ind(-1), _x(NULL) { }
      53                 :         51 :         DataPoint(int Dv, int indv, double* xv) : _D(Dv), _ind(indv), _x(NULL) 
      54                 :            :         {
      55                 :         51 :                 _x = (double*) malloc(_D * sizeof(double));
      56         [ +  + ]:        204 :                 for(int d = 0; d < _D; d++) _x[d] = xv[d];
      57                 :         51 :         }
      58                 :       2575 :         DataPoint(const DataPoint& other) : _D(), _ind(0), _x(NULL) // this makes a deep copy -- should not free anything
      59                 :            :         {
      60         [ +  - ]:       2575 :                 if(this != &other) {
      61                 :       2575 :                         _D = other.dimensionality();
      62                 :       2575 :                         _ind = other.index();
      63                 :       2575 :                         _x = (double*) malloc(_D * sizeof(double));      
      64         [ +  + ]:      10300 :                         for(int d = 0; d < _D; d++) _x[d] = other.x(d);
      65                 :            :                 }
      66                 :       2575 :         }
      67         [ +  - ]:       2626 :         ~DataPoint() { if(_x != NULL) free(_x); }
      68                 :       1831 :         DataPoint& operator= (const DataPoint& other) {         // asignment should free old object
      69         [ +  + ]:       1831 :                 if(this != &other) {
      70         [ +  - ]:       1816 :                         if(_x != NULL) free(_x);
      71                 :       1816 :                         _D = other.dimensionality();
      72                 :       1816 :                         _ind = other.index();
      73                 :       1816 :                         _x = (double*) malloc(_D * sizeof(double));
      74         [ +  + ]:       7264 :                         for(int d = 0; d < _D; d++) _x[d] = other.x(d);
      75                 :            :                 }
      76                 :       1831 :                 return *this;
      77                 :            :         }
      78                 :       5891 :         int index() const { return _ind; }
      79                 :      17879 :         int dimensionality() const { return _D; }
      80                 :      53637 :         double x(int d) const { return _x[d]; }
      81                 :            : };
      82                 :            : 
      83                 :            : 
      84                 :       3372 : double euclidean_distance(const DataPoint &t1, const DataPoint &t2) {
      85                 :       3372 :         double dd = .0;
      86         [ +  + ]:      13488 :         for(int d = 0; d < t1.dimensionality(); d++) dd += (t1.x(d) - t2.x(d)) * (t1.x(d) - t2.x(d));
      87                 :       3372 :         return dd;
      88                 :            : }
      89                 :            : 
      90                 :            : 
      91                 :            : template<typename T, double (*distance)( const T&, const T& )>
      92                 :            : class VpTree
      93                 :            : {
      94                 :            : public:
      95                 :            : 
      96                 :            :         // Default constructor
      97                 :          1 :         VpTree() :  _items(), _tau(0.0), _root(0) {}
      98                 :            : 
      99                 :            :         // Destructor
     100                 :          1 :         ~VpTree() {
     101 [ +  - ][ +  - ]:          1 :                 delete _root;
     102                 :          1 :         }
     103                 :            : 
     104                 :            :         // Function to create a new VpTree from data
     105                 :          1 :         void create(const std::vector<T>& items) {
     106         [ -  + ]:          1 :                 delete _root;
     107                 :          1 :                 _items = items;
     108                 :          1 :                 _root = buildFromPoints(0, items.size());
     109                 :          1 :         }
     110                 :            : 
     111                 :            :         // Function that uses the tree to find the k nearest neighbors of target
     112                 :         50 :         void search(const T& target, int k, std::vector<T>* results, std::vector<double>* distances)
     113                 :            :         {
     114                 :            : 
     115                 :            :                 // Use a priority queue to store intermediate results on
     116 [ +  - ][ +  - ]:         50 :                 std::priority_queue<HeapItem> heap;
     117                 :            : 
     118                 :            :                 // Variable that tracks the distance to the farthest point in our results
     119                 :         50 :                 _tau = DBL_MAX;
     120                 :            : 
     121                 :            :                 // Perform the searcg
     122         [ +  - ]:         50 :                 search(_root, target, k, heap);
     123                 :            : 
     124                 :            :                 // Gather final results
     125 [ +  - ][ +  - ]:         50 :                 results->clear(); distances->clear();
     126 [ +  - ][ +  + ]:       1600 :                 while(!heap.empty()) {
     127 [ +  - ][ +  - ]:       1550 :                         results->push_back(_items[heap.top().index]);
     128 [ +  - ][ +  - ]:       1550 :                         distances->push_back(heap.top().dist);
     129         [ +  - ]:       1550 :                         heap.pop();
     130                 :            :                 }
     131                 :            : 
     132                 :            :                 // Results are in reverse order
     133 [ +  - ][ +  - ]:         50 :                 std::reverse(results->begin(), results->end());
                 [ +  - ]
     134 [ +  - ][ +  - ]:         50 :                 std::reverse(distances->begin(), distances->end());
                 [ +  - ]
     135                 :         50 :         }
     136                 :            : 
     137                 :            : private:
     138                 :            : 
     139                 :            :         VpTree(const VpTree&);
     140                 :            :         VpTree& operator=(const VpTree&);
     141                 :            : 
     142                 :            :         std::vector<T> _items;
     143                 :            :         double _tau;
     144                 :            : 
     145                 :            :         // Single node of a VP tree (has a point and radius; left children are closer to point than the radius)
     146                 :            :         struct Node
     147                 :            :         {
     148                 :            :                 int index;              // index of point in node
     149                 :            :                 double threshold;       // radius(?)
     150                 :            :                 Node* left;             // points closer by than threshold
     151                 :            :                 Node* right;            // points farther away than threshold
     152                 :            : 
     153                 :         50 :                 Node() : index(0), threshold(0.), left(0), right(0) {}
     154                 :            : 
     155                 :         50 :                 ~Node() 
     156                 :            :                 {
     157         [ +  + ]:         50 :                         delete left;
     158         [ +  + ]:         50 :                         delete right;
     159                 :         50 :                 }
     160                 :            : 
     161                 :            :                 Node(const Node&);
     162                 :            :                 Node& operator=(const Node&);
     163                 :            : 
     164                 :            :         }* _root;
     165                 :            : 
     166                 :            : 
     167                 :            :         // An item on the intermediate result queue
     168                 :       2178 :         struct HeapItem {
     169                 :       2178 :                 HeapItem(int indexv, double distv) :
     170                 :       2178 :                         index(indexv), dist(distv) {}
     171                 :            :                 int index;
     172                 :            :                 double dist;
     173                 :      13666 :                 bool operator<(const HeapItem& o) const {
     174                 :      13666 :                         return dist < o.dist;
     175                 :            :                 }
     176                 :            :         };
     177                 :            : 
     178                 :            :         // Distance comparator for use in std::nth_element
     179                 :            :         struct DistanceComparator
     180                 :            :         {
     181                 :            :                 const T& item;
     182                 :         33 :                 DistanceComparator(const T& itemv) : item(itemv) {}
     183                 :        536 :                 bool operator()(const T& a, const T& b) {
     184                 :        536 :                         return distance(item, a) < distance(item, b);
     185                 :            :                 }
     186                 :            :         };
     187                 :            : 
     188                 :            :         // Function that (recursively) fills the tree
     189                 :         67 :         Node* buildFromPoints( int lower, int upper )
     190                 :            :         {
     191         [ +  + ]:         67 :                 if (upper == lower) {     // indicates that we're done here!
     192                 :         17 :                         return NULL;
     193                 :            :                 }
     194                 :            : 
     195                 :            :                 // Lower index is center of current node
     196                 :         50 :                 Node* node = new Node();
     197                 :         50 :                 node->index = lower;
     198                 :            : 
     199         [ +  + ]:         50 :                 if (upper - lower > 1) {      // if we did not arrive at leaf yet
     200                 :            : 
     201                 :            :                         // Choose an arbitrary point and move it to the start
     202                 :         33 :                         int i = (int) (tapkee::uniform_random() * (upper - lower - 1)) + lower;
     203                 :         33 :                         std::swap(_items[lower], _items[i]);
     204                 :            : 
     205                 :            :                         // Partition around the median distance
     206                 :         33 :                         int median = (upper + lower) / 2;
     207                 :         33 :                         std::nth_element(_items.begin() + lower + 1,
     208                 :            :                                         _items.begin() + median,
     209                 :            :                                         _items.begin() + upper,
     210                 :            :                                         DistanceComparator(_items[lower]));
     211                 :            : 
     212                 :            :                         // Threshold of the new node will be the distance to the median
     213                 :         33 :                         node->threshold = distance(_items[lower], _items[median]);
     214                 :            : 
     215                 :            :                         // Recursively build tree
     216                 :         33 :                         node->index = lower;
     217                 :         33 :                         node->left = buildFromPoints(lower + 1, median);
     218                 :         33 :                         node->right = buildFromPoints(median, upper);
     219                 :            :                 }
     220                 :            : 
     221                 :            :                 // Return result
     222                 :         67 :                 return node;
     223                 :            :         }
     224                 :            : 
     225                 :            :         // Helper function that searches the tree    
     226                 :       2989 :         void search(Node* node, const T& target, int k, std::priority_queue<HeapItem>& heap)
     227                 :            :         {
     228         [ +  + ]:       2989 :                 if(node == NULL) return;     // indicates that we're done here
     229                 :            : 
     230                 :            :                 // Compute distance between target and current node
     231                 :       2267 :                 double dist = distance(_items[node->index], target);
     232                 :            : 
     233                 :            :                 // If current node within radius tau
     234         [ +  + ]:       2267 :                 if(dist < _tau) {
     235         [ +  + ]:       2178 :                         if(heap.size() == static_cast<size_t>(k)) heap.pop(); // remove furthest node from result list (if we already have k results)
     236                 :       2178 :                         heap.push(HeapItem(node->index, dist));           // add current node to result list
     237         [ +  + ]:       2178 :                         if(heap.size() == static_cast<size_t>(k)) _tau = heap.top().dist;     // update value of tau (farthest point in result list)
     238                 :            :                 }
     239                 :            : 
     240                 :            :                 // Return if we arrived at a leaf
     241 [ +  + ][ +  + ]:       2267 :                 if(node->left == NULL && node->right == NULL) {
     242                 :        745 :                         return;
     243                 :            :                 }
     244                 :            : 
     245                 :            :                 // If the target lies within the radius of ball
     246         [ +  + ]:       1522 :                 if(dist < node->threshold) {
     247         [ +  - ]:        336 :                         if(dist - _tau <= node->threshold) {         // if there can still be neighbors inside the ball, recursively search left child first
     248                 :        336 :                                 search(node->left, target, k, heap);
     249                 :            :                         }
     250                 :            : 
     251         [ +  + ]:        336 :                         if(dist + _tau >= node->threshold) {         // if there can still be neighbors outside the ball, recursively search right child
     252                 :        328 :                                 search(node->right, target, k, heap);
     253                 :            :                         }
     254                 :            : 
     255                 :            :                         // If the target lies outsize the radius of the ball
     256                 :            :                 } else {
     257         [ +  - ]:       1186 :                         if(dist + _tau >= node->threshold) {         // if there can still be neighbors outside the ball, recursively search right child first
     258                 :       1186 :                                 search(node->right, target, k, heap);
     259                 :            :                         }
     260                 :            : 
     261         [ +  + ]:       1186 :                         if (dist - _tau <= node->threshold) {         // if there can still be neighbors inside the ball, recursively search left child
     262                 :       2989 :                                 search(node->left, target, k, heap);
     263                 :            :                         }
     264                 :            :                 }
     265                 :            :         }
     266                 :            : };
     267                 :            : 
     268                 :            : }
     269                 :            : 
     270                 :            : #endif

Generated by: LCOV version 1.9