LCOV - code coverage report
Current view: top level - neighbors - vptree.hpp (source / functions) Hit Total Coverage
Test: clean.info Lines: 0 76 0.0 %
Date: 2013-05-24 Functions: 0 117 0.0 %
Branches: 0 463 0.0 %

           Branch data     Line data    Source code
       1                 :            : /* This software is distributed under BSD 3-clause license (see LICENSE file).
       2                 :            :  *
       3                 :            :  * Copyright (c) 2012-2013 Laurens van der Maaten, Sergey Lisitsyn
       4                 :            :  */
       5                 :            : 
       6                 :            : #ifndef TAPKEE_VPTREE_H_
       7                 :            : #define TAPKEE_VPTREE_H_
       8                 :            : 
       9                 :            : /* Tapkee includes */
      10                 :            : #include <tapkee/defines.hpp>
      11                 :            : /* End of Tapkee includes */
      12                 :            : 
      13                 :            : #include <vector>
      14                 :            : #include <queue>
      15                 :            : #include <algorithm>
      16                 :            : #include <limits>
      17                 :            : 
      18                 :            : namespace tapkee
      19                 :            : {
      20                 :            : namespace tapkee_internal
      21                 :            : {
      22                 :            : 
      23                 :            : template<class Type, class RandomAccessIterator, class DistanceCallback> 
      24                 :            : struct compare_impl;
      25                 :            : 
      26                 :            : template<class RandomAccessIterator, class DistanceCallback>
      27                 :            : struct DistanceComparator
      28                 :            : {
      29                 :            :         DistanceCallback callback;
      30                 :            :         const RandomAccessIterator item;
      31                 :          0 :         DistanceComparator(const DistanceCallback& c, const RandomAccessIterator& i) :
      32                 :          0 :                 callback(c), item(i) {}
      33                 :          0 :         inline bool operator()(const RandomAccessIterator& a, const RandomAccessIterator& b)
      34                 :            :         {
      35                 :            :                 return compare_impl<typename DistanceCallback::type,RandomAccessIterator,DistanceCallback>()
      36                 :          0 :                         (callback,item,a,b);
      37                 :            :         }
      38                 :            : };
      39                 :            : 
      40                 :            : struct KernelType;
      41                 :            : 
      42                 :            : template<class RandomAccessIterator, class DistanceCallback> 
      43                 :            : struct compare_impl<KernelType,RandomAccessIterator,DistanceCallback>
      44                 :            : {
      45                 :          0 :         inline bool operator()(DistanceCallback& callback, const RandomAccessIterator& item,
      46                 :            :                                const RandomAccessIterator& a, const RandomAccessIterator& b)
      47                 :            :         {
      48                 :          0 :                 return (-2*callback(item,a) + callback(a,a)) < (-2*callback(item,b) + callback(b,b));
      49                 :            :         }
      50                 :            : };
      51                 :            : 
      52                 :            : struct DistanceType;
      53                 :            : 
      54                 :            : template<class RandomAccessIterator, class DistanceCallback> 
      55                 :            : struct compare_impl<DistanceType,RandomAccessIterator,DistanceCallback>
      56                 :            : {
      57                 :          0 :         inline bool operator()(DistanceCallback& callback, const RandomAccessIterator& item,
      58                 :            :                                const RandomAccessIterator& a, const RandomAccessIterator& b)
      59                 :            :         {
      60                 :          0 :                 return callback(item,a) < callback(item,b);
      61                 :            :         }
      62                 :            : };
      63                 :            : 
      64                 :            : template<class RandomAccessIterator, class DistanceCallback>
      65                 :            : class VantagePointTree
      66                 :            : {
      67                 :            : public:
      68                 :            : 
      69                 :            :         // Default constructor
      70                 :          0 :         VantagePointTree(RandomAccessIterator b, RandomAccessIterator e, DistanceCallback c) :  
      71                 :          0 :                 begin(b), items(), callback(c), tau(0.0), root(0)
      72                 :            :         {
      73   [ #  #  #  #  :          0 :                 items.reserve(e-b);
          #  #  #  #  #  
          #  #  #  #  #  
                      # ]
           [ #  #  #  # ]
                 [ #  # ]
      74 [ #  # ][ #  # ]:          0 :                 for (RandomAccessIterator i=b; i!=e; ++i)
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
      75 [ #  # ][ #  # ]:          0 :                         items.push_back(i);
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
      76 [ #  # ][ #  # ]:          0 :                 root = buildFromPoints(0, items.size());
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
      77                 :          0 :         }
      78                 :            : 
      79                 :            :         // Destructor
      80                 :          0 :         ~VantagePointTree() 
      81                 :            :         {
      82 [ #  # ][ #  # ]:          0 :                 delete root;
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
      83                 :          0 :         }
      84                 :            : 
      85                 :            :         // Function that uses the tree to find the k nearest neighbors of target
      86                 :          0 :         std::vector<IndexType> search(const RandomAccessIterator& target, int k)
      87                 :            :         {
      88                 :          0 :                 std::vector<IndexType> results;
      89                 :            :                 // Use a priority queue to store intermediate results on
      90 [ #  # ][ #  # ]:          0 :                 std::priority_queue<HeapItem> heap;
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
      91                 :            : 
      92                 :            :                 // Variable that tracks the distance to the farthest point in our results
      93                 :          0 :                 tau = std::numeric_limits<double>::max();
      94                 :            : 
      95                 :            :                 // Perform the searcg
      96   [ #  #  #  #  :          0 :                 search(root, target, k, heap);
          #  #  #  #  #  
             #  #  #  #  
                      # ]
      97                 :            : 
      98                 :            :                 // Gather final results
      99 [ #  # ][ #  # ]:          0 :                 results.reserve(k);
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
     100 [ #  # ][ #  # ]:          0 :                 while(!heap.empty()) {
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
     101 [ #  # ][ #  # ]:          0 :                         results.push_back(items[heap.top().index]-begin);
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
               [ # ][ # ]
         [ #  # ][ #  # ]
                 [ #  # ]
     102 [ #  # ][ #  # ]:          0 :                         heap.pop();
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
     103                 :            :                 }
     104 [ #  # ][ #  # ]:          0 :                 return results;
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
     105                 :            :         }
     106                 :            : 
     107                 :            : private:
     108                 :            : 
     109                 :            :         VantagePointTree(const VantagePointTree&);
     110                 :            :         VantagePointTree& operator=(const VantagePointTree&);
     111                 :            : 
     112                 :            :         RandomAccessIterator begin;
     113                 :            :         std::vector<RandomAccessIterator> items;
     114                 :            :         DistanceCallback callback;
     115                 :            :         double tau;
     116                 :            : 
     117                 :            :         struct Node
     118                 :            :         {
     119                 :            :                 int index;
     120                 :            :                 double threshold;
     121                 :            :                 Node* left;
     122                 :            :                 Node* right;
     123                 :            : 
     124                 :          0 :                 Node() : 
     125                 :            :                         index(0), threshold(0.), 
     126                 :          0 :                         left(0), right(0) 
     127                 :            :                 {
     128                 :          0 :                 }
     129                 :            : 
     130                 :          0 :                 ~Node() 
     131                 :            :                 {
     132 [ #  # ][ #  # ]:          0 :                         delete left;
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
     133 [ #  # ][ #  # ]:          0 :                         delete right;
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
     134                 :          0 :                 }
     135                 :            : 
     136                 :            :                 Node(const Node&);
     137                 :            :                 Node& operator=(const Node&);
     138                 :            : 
     139                 :            :         }* root;
     140                 :            : 
     141                 :          0 :         struct HeapItem {
     142                 :          0 :                 HeapItem(int i, double d) :
     143                 :          0 :                         index(i), distance(d) {}
     144                 :            :                 int index;
     145                 :            :                 double distance;
     146                 :          0 :                 bool operator<(const HeapItem& o) const {
     147                 :          0 :                         return distance < o.distance;
     148                 :            :                 }
     149                 :            :         };
     150                 :            : 
     151                 :            : 
     152                 :          0 :         Node* buildFromPoints(int lower, int upper)
     153                 :            :         {
     154 [ #  # ][ #  # ]:          0 :                 if (upper == lower)
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
     155                 :            :                 {
     156                 :          0 :                         return NULL;
     157                 :            :                 }
     158                 :            : 
     159                 :          0 :                 Node* node = new Node();
     160                 :          0 :                 node->index = lower;
     161                 :            : 
     162   [ #  #  #  #  :          0 :                 if (upper - lower > 1)
          #  #  #  #  #  
             #  #  #  #  
                      # ]
     163                 :            :                 {
     164                 :          0 :                         int i = (int) (tapkee::uniform_random() * (upper - lower - 1)) + lower;
     165                 :          0 :                         std::swap(items[lower], items[i]);
     166                 :            : 
     167                 :          0 :                         int median = (upper + lower) / 2;
     168                 :          0 :                         std::nth_element(items.begin() + lower + 1, items.begin() + median, items.begin() + upper, 
     169                 :            :                                 DistanceComparator<RandomAccessIterator,DistanceCallback>(callback,items[lower]));
     170                 :            : 
     171                 :          0 :                         node->threshold = callback.distance(items[lower], items[median]);
     172                 :          0 :                         node->index = lower;
     173                 :          0 :                         node->left = buildFromPoints(lower + 1, median);
     174                 :          0 :                         node->right = buildFromPoints(median, upper);
     175                 :            :                 }
     176                 :            : 
     177                 :          0 :                 return node;
     178                 :            :         }
     179                 :            : 
     180                 :          0 :         void search(Node* node, const RandomAccessIterator& target, int k, std::priority_queue<HeapItem>& heap)
     181                 :            :         {
     182 [ #  # ][ #  # ]:          0 :                 if (node == NULL) 
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
     183                 :          0 :                         return;
     184                 :            : 
     185                 :          0 :                 double distance = callback.distance(items[node->index], target);
     186                 :            : 
     187   [ #  #  #  #  :          0 :                 if (distance < tau) 
          #  #  #  #  #  
             #  #  #  #  
                      # ]
     188                 :            :                 {
     189 [ #  # ][ #  # ]:          0 :                         if (heap.size() == static_cast<size_t>(k)) 
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
     190                 :          0 :                                 heap.pop();
     191                 :            : 
     192                 :          0 :                         heap.push(HeapItem(node->index, distance));
     193                 :            : 
     194   [ #  #  #  #  :          0 :                         if (heap.size() == static_cast<size_t>(k))
          #  #  #  #  #  
             #  #  #  #  
                      # ]
     195                 :          0 :                                 tau = heap.top().distance;
     196                 :            :                 }
     197                 :            : 
     198 [ #  # ][ #  # ]:          0 :                 if (node->left == NULL && node->right == NULL) 
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
     199                 :            :                 {
     200                 :          0 :                         return;
     201                 :            :                 }
     202                 :            : 
     203 [ #  # ][ #  # ]:          0 :                 if (distance < node->threshold)
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
     204                 :            :                 {
     205 [ #  # ][ #  # ]:          0 :                         if ((distance - tau) <= node->threshold) 
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
     206                 :          0 :                                 search(node->left, target, k, heap);
     207                 :            : 
     208 [ #  # ][ #  # ]:          0 :                         if ((distance + tau) >= node->threshold) 
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
     209                 :          0 :                                 search(node->right, target, k, heap);
     210                 :            :                 } 
     211                 :            :                 else
     212                 :            :                 {
     213 [ #  # ][ #  # ]:          0 :                         if ((distance + tau) >= node->threshold) 
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
     214                 :          0 :                                 search(node->right, target, k, heap);
     215                 :            : 
     216 [ #  # ][ #  # ]:          0 :                         if ((distance - tau) <= node->threshold) 
         [ #  # ][ #  # ]
         [ #  # ][ #  # ]
                 [ #  # ]
     217                 :          0 :                                 search(node->left, target, k, heap);
     218                 :            :                 }
     219                 :            :         }
     220                 :            : };
     221                 :            : 
     222                 :            : }
     223                 :            : }
     224                 :            : #endif

Generated by: LCOV version 1.9