Tapkee
neighbors.hpp
Go to the documentation of this file.
00001 /* This software is distributed under BSD 3-clause license (see LICENSE file).
00002  *
00003  * Copyright (c) 2012-2013 Sergey Lisitsyn, Fernando J. Iglesias Garcia
00004  */
00005 
00006 #ifndef TAPKEE_NEIGHBORS_H_
00007 #define TAPKEE_NEIGHBORS_H_
00008 
00009 /* Tapkee includes */
00010 #include <tapkee/defines.hpp>
00011 #ifdef TAPKEE_USE_LGPL_COVERTREE
00012     #include <tapkee/neighbors/covertree.hpp>
00013 #endif
00014 #include <tapkee/neighbors/connected.hpp>
00015 #include <tapkee/neighbors/vptree.hpp>
00016 /* End of Tapkee includes */
00017 
00018 #include <vector>
00019 #include <utility>
00020 #include <algorithm>
00021 
00022 namespace tapkee
00023 {
00024 namespace tapkee_internal
00025 {
00026 
/** Orders distance records by their distance component.
 *
 * A DistanceRecord is any pair-like type whose `.second` member holds the
 * distance; the comparator yields true when the left record is strictly
 * closer than the right one, which makes it usable with std algorithms
 * such as std::nth_element.
 */
template <class DistanceRecord>
struct distances_comparator
{
    inline bool operator()(const DistanceRecord& lhs, const DistanceRecord& rhs) const
    {
        // Written as "rhs is farther than lhs"; equivalent to lhs < rhs.
        return rhs.second > lhs.second;
    }
};
00035 
00036 struct KernelType 
00037 {
00038 };
00039 
00040 template <class RandomAccessIterator, class Callback>
00041 struct KernelDistance
00042 {
00043     KernelDistance(const Callback& cb) : callback(cb) {  } 
00044     inline ScalarType operator()(const RandomAccessIterator& l, const RandomAccessIterator& r)
00045     {
00046         return callback.kernel(*l,*r);
00047     }
00048     inline ScalarType distance(const RandomAccessIterator& l, const RandomAccessIterator& r)
00049     {
00050         return sqrt(callback.kernel(*l,*l) - 2*callback.kernel(*l,*r) + callback.kernel(*r,*r));
00051     }
00052     typedef KernelType type;
00053     Callback callback;
00054 };
00055 
00056 struct DistanceType
00057 {
00058 };
00059 
00060 template <class RandomAccessIterator, class Callback>
00061 struct PlainDistance
00062 {
00063     PlainDistance(const Callback& cb) : callback(cb) {  }
00064     inline ScalarType operator()(const RandomAccessIterator& l, const RandomAccessIterator& r)
00065     {
00066         return callback.distance(*l,*r);
00067     }
00068     inline ScalarType distance(const RandomAccessIterator& l, const RandomAccessIterator& r)
00069     {
00070         return callback.distance(*l,*r);
00071     }
00072     typedef DistanceType type;
00073     Callback callback;
00074 };
00075 
#ifdef TAPKEE_USE_LGPL_COVERTREE
/** Cover-tree based k-nearest-neighbour search (only built when
 * TAPKEE_USE_LGPL_COVERTREE is defined).
 *
 * @param begin    first object
 * @param end      one past the last object
 * @param callback distance adapter; it is passed to the cover tree routines
 *                 and called as callback(iter,iter) for every point
 * @param k        number of neighbours requested per object
 * @return one LocalNeighbors list (indices relative to begin) per object,
 *         stored at the position of the query object
 */
template <class RandomAccessIterator, class Callback>
Neighbors find_neighbors_covertree_impl(RandomAccessIterator begin, RandomAccessIterator end, 
                         Callback callback, IndexType k)
{
    timed_context context("Covertree-based neighbors search");

    typedef CoverTreePoint<RandomAccessIterator> TreePoint;
    v_array<TreePoint> points;
    // Each tree point stores its iterator plus the value callback(iter,iter).
    // NOTE(review): presumably a self term used by CoverTreePoint's distance;
    // confirm in covertree.hpp.
    for (RandomAccessIterator iter=begin; iter!=end; ++iter)
        push(points, TreePoint(iter, callback(iter,iter)));

    node<TreePoint> ct = batch_create(callback, points);

    v_array< v_array<TreePoint> > res;
    ++k; // because one of the neighbors will be the actual query point
    // Query the tree against itself: every point is also a query.
    k_nearest_neighbor(callback,ct,ct,res,k);

    Neighbors neighbors;
    neighbors.resize(end-begin);
    assert(end-begin==res.index);
    for (int i=0; i<res.index; ++i)
    {
        LocalNeighbors local_neighbors;
        local_neighbors.reserve(k);
        
        // NOTE(review): if the query point does not appear in res[i][1..k]
        // (e.g. several duplicates at distance 0), all k entries are kept,
        // i.e. one more than the originally requested count - confirm intended.
        for (IndexType j=1; j<=k; ++j) // j=0 is the query point
        {
            // The actual query point is found as a neighbor, just ignore it
            if (res[i][j].iter_-begin==res[i][0].iter_-begin)
                continue;
            local_neighbors.push_back(res[i][j].iter_-begin);
        }
        // Result order of res is not the input order, so store by the query's index.
        neighbors[res[i][0].iter_-begin] = local_neighbors;
        // v_array buffers are released manually with free().
        free(res[i].elements);
    };
    free(res.elements);
    free_children(ct);
    free(points.elements);
    return neighbors;
}
#endif
00118 
00119 template <class RandomAccessIterator, class Callback>
00120 Neighbors find_neighbors_bruteforce_impl(const RandomAccessIterator& begin, const RandomAccessIterator& end, 
00121                                          Callback callback, IndexType k)
00122 {
00123     timed_context context("Distance sorting based neighbors search");
00124     typedef std::pair<RandomAccessIterator, ScalarType> DistanceRecord;
00125     typedef std::vector<DistanceRecord> Distances;
00126 
00127     Neighbors neighbors;
00128     neighbors.reserve(end-begin);
00129     for (RandomAccessIterator iter=begin; iter!=end; ++iter)
00130     {
00131         Distances distances;
00132         for (RandomAccessIterator around_iter=begin; around_iter!=end; ++around_iter)
00133             distances.push_back(std::make_pair(around_iter, callback.distance(iter,around_iter)));
00134 
00135         std::nth_element(distances.begin(),distances.begin()+k+1,distances.end(),
00136                          distances_comparator<DistanceRecord>());
00137 
00138         LocalNeighbors local_neighbors;
00139         local_neighbors.reserve(k);
00140         for (typename Distances::const_iterator neighbors_iter=distances.begin(); 
00141                 neighbors_iter!=distances.begin()+k+1; ++neighbors_iter)
00142         {
00143             if (neighbors_iter->first != iter) 
00144                 local_neighbors.push_back(neighbors_iter->first - begin);
00145         }
00146         neighbors.push_back(local_neighbors);
00147     }
00148     return neighbors;
00149 }
00150 
00151 template <class RandomAccessIterator, class Callback>
00152 Neighbors find_neighbors_vptree_impl(const RandomAccessIterator& begin, const RandomAccessIterator& end, 
00153                                      Callback callback, IndexType k)
00154 {
00155     timed_context context("VP-Tree based neighbors search");
00156 
00157     Neighbors neighbors;
00158     neighbors.reserve(end-begin);
00159 
00160     VantagePointTree<RandomAccessIterator,Callback> tree(begin,end,callback);
00161 
00162     for (RandomAccessIterator i=begin; i!=end; ++i)
00163     {
00164         LocalNeighbors local_neighbors = tree.search(i,k+1);
00165         std::remove(local_neighbors.begin(),local_neighbors.end(),i-begin);
00166         neighbors.push_back(local_neighbors);
00167     }
00168 
00169     return neighbors;
00170 }
00171 
00172 template <class RandomAccessIterator, class Callback>
00173 Neighbors find_neighbors(NeighborsMethod method, const RandomAccessIterator& begin, 
00174                          const RandomAccessIterator& end, const Callback& callback, 
00175                          IndexType k, bool check_connectivity)
00176 {
00177     if (k > static_cast<IndexType>(end-begin-1))
00178     {
00179         LoggingSingleton::instance().message_warning("Number of neighbors is greater than number of objects to embed. "
00180                                                      "Using greatest possible number of neighbors.");
00181         k = static_cast<IndexType>(end-begin-1);
00182     }
00183     LoggingSingleton::instance().message_info("Using the " + get_neighbors_method_name(method) + " neighbors computation method.");
00184     Neighbors neighbors;
00185     switch (method)
00186     {
00187         case Brute: neighbors = find_neighbors_bruteforce_impl(begin,end,callback,k); break;
00188         case VpTree: neighbors = find_neighbors_vptree_impl(begin,end,callback,k); break;
00189 #ifdef TAPKEE_USE_LGPL_COVERTREE
00190         case CoverTree: neighbors = find_neighbors_covertree_impl(begin,end,callback,k); break;
00191 #endif
00192         default: break;
00193     }
00194 
00195     if (check_connectivity)
00196     {
00197         if (!is_connected(begin,end,neighbors))
00198             LoggingSingleton::instance().message_warning("The neighborhood graph is not connected.");
00199     }
00200     return neighbors;
00201 }
00202 
00203 } // End of namespace tapkee_internal
00204 } // End of namespace tapkee
00205 
00206 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines