Tapkee
covertree.hpp
Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Copyright (c) 2009-2013 John Langford, Dinoj Surendran, Fernando José Iglesias García
00008  */
00009 
00010 #ifndef COVERTREE_H_
00011 #define COVERTREE_H_
00012 
00013 /* Tapkee includes */
00014 #include <tapkee/neighbors/covertree_point.hpp>
00015 /* End of Tapkee includes */
00016 
00017 #include <cmath>
00018 #include <limits>
00019 #include <stdio.h>
00020 #include <assert.h>
00021 
00022 /* First written by John Langford jl@hunch.net
00023    Templatization by Dinoj Surendran dinojs@gmail.com
00024    Adaptation to Shogun by Fernando José Iglesias García
00025  */
00026 namespace tapkee
00027 {
00028 namespace tapkee_internal
00029 {
00030 
00034 template<class P>
00035 struct node 
00036 {
00037     node() : p(), max_dist(0.0), parent_dist(0.0), 
00038         children(NULL), num_children(0), scale(0) 
00039     {
00040     }
00041 
00042     node(P _p, ScalarType _max_dist, ScalarType _parent_dist, node<P>* _children,
00043          unsigned short int _num_children, short int _scale) : p(_p), 
00044         max_dist(_max_dist), parent_dist(_parent_dist), children(_children),
00045         num_children(_num_children), scale(_scale) 
00046     {
00047     }
00048 
00050     P p;
00051 
00053     ScalarType max_dist;
00054 
00056     ScalarType parent_dist;
00057 
00059     node<P>* children;
00060 
00062     unsigned short int num_children;
00063 
00065     short int scale;
00066 };
00067 
00068 template<class P>
00069 void free_children(const node<P>& n)
00070 {
00071     for (int i=0; i<n.num_children; i++)
00072     {
00073         free_children<P>(n.children[i]);
00074         n.children[i].~node<P>();
00075     }
00076     free(n.children);
00077 }
00078 
00079 
00083 template<class P>
00084 struct ds_node {
00085 
00086     ds_node() : dist(), p() {}
00087 
00089     v_array<ScalarType> dist;
00090 
00092     P p;
00093 };
00094 
/** Base of the geometric scale ladder: dist_of_scale(s) returns base^s. */
static ScalarType base = COVERTREE_BASE;
/** Cached 1 / ln(base), so that il2 * log(d) is log_base(d) (used by get_scale). */
static ScalarType il2 = 1. / log(base);
00097 
00098 inline ScalarType dist_of_scale (int s)
00099 {
00100     return pow(base, s);
00101 }
00102 
00103 inline int get_scale(ScalarType d)
00104 {
00105     return (int)ceil(il2 * log(d));
00106 }
00107 
00108     template<class P>
00109 node<P> new_node(const P &p)
00110 {
00111     node<P> new_node;
00112     new_node.p = p;
00113     return new_node;
00114 }
00115 
00116     template<class P>
00117 node<P> new_leaf(const P &p)
00118 {
00119     node<P> new_leaf(p,0.,0.,NULL,0,100);
00120     return new_leaf;
00121 }
00122 
00123     template<class P>
00124 ScalarType max_set(v_array<ds_node<P> > &v)
00125 {
00126     ScalarType max = 0.;
00127     for (int i = 0; i < v.index; i++)
00128         if ( max < v[i].dist.last()) 
00129             max = v[i].dist.last();
00130     return max;
00131 }
00132 
/** Writes s space characters to stdout (indentation helper for print).
 *
 * Declared inline because it is defined in a header: a plain external
 * definition would be emitted by every translation unit that includes this
 * file, violating the one-definition rule (duplicate-symbol link error).
 */
inline void print_space(int s)
{
    for (int i = 0; i < s; i++)
        printf(" ");
}
00138 
00139 template<class P>
00140 void print(int depth, node<P> &top_node)
00141 {
00142     print_space(depth);
00143     print(top_node.p);
00144     if ( top_node.num_children > 0 ) 
00145     {
00146         print_space(depth); 
00147         printf("scale = %i\n",top_node.scale);
00148         print_space(depth); 
00149         printf("max_dist = %f\n",top_node.max_dist);
00150         print_space(depth); 
00151         printf("num children = %i\n",top_node.num_children);
00152         for (int i = 0; i < top_node.num_children;i++)
00153             print(depth+1, top_node.children[i]);
00154     }
00155 }
00156 
00157 template<class P>
00158 void split(v_array<ds_node<P> >& point_set, v_array<ds_node<P> >& far_set, int max_scale)
00159 {
00160     IndexType new_index = 0;
00161     ScalarType fmax = dist_of_scale(max_scale);
00162     for (int i = 0; i < point_set.index; i++)
00163     {
00164         if (point_set[i].dist.last() <= fmax) 
00165         {
00166             point_set[new_index++] = point_set[i];
00167         }
00168         else
00169             push(far_set,point_set[i]);
00170     }
00171     point_set.index=new_index;  
00172 }
00173 
00174 template<class P, class DistanceCallback>
00175 void dist_split(DistanceCallback& dcb, v_array<ds_node<P> >& point_set,
00176         v_array<ds_node<P> >& new_point_set, 
00177         P new_point, 
00178         int max_scale)
00179 {
00180     IndexType new_index = 0;
00181     ScalarType fmax = dist_of_scale(max_scale);
00182     for(int i = 0; i < point_set.index; i++) 
00183     {
00184         ScalarType new_d;
00185         new_d = distance(dcb, new_point, point_set[i].p, fmax);
00186         if (new_d <= fmax ) 
00187         {
00188             push(point_set[i].dist, new_d);
00189             push(new_point_set,point_set[i]);
00190         }
00191         else
00192             point_set[new_index++] = point_set[i];
00193     }
00194     point_set.index = new_index;
00195 }
00196 
00197 /*
00198    max_scale is the maximum scale of the node we might create here.
00199    point_set contains points which are 2*max_scale or less away.
00200    */
00201 template <class P, class DistanceCallback>
00202 node<P> batch_insert(DistanceCallback& dcb, const P& p,
00203         int max_scale, 
00204         int top_scale,
00205         v_array<ds_node<P> >& point_set, 
00206         v_array<ds_node<P> >& consumed_set,
00207         v_array<v_array<ds_node<P> > >& stack)
00208 {
00209     if (point_set.index == 0) 
00210         return new_leaf(p);
00211     else {
00212         ScalarType max_dist = max_set(point_set); //O(|point_set|)
00213         int next_scale = std::min(max_scale - 1, get_scale(max_dist));
00214         if (next_scale == -2147483647-1) // We have points with distance 0.
00215         {
00216             v_array<node<P> > children;
00217             push(children,new_leaf(p));
00218             while (point_set.index > 0)
00219             {
00220                 push(children,new_leaf(point_set.last().p));
00221                 push(consumed_set,point_set.last());
00222                 point_set.decr();
00223             }
00224             node<P> n = new_node(p);
00225             n.scale = 100; // A magic number meant to be larger than all scales.  
00226             n.max_dist = 0;
00227             alloc(children,children.index);
00228             n.num_children = children.index;
00229             n.children = children.elements;
00230             return n;
00231         }
00232         else
00233         {
00234             v_array<ds_node<P> > far = pop(stack);
00235             split(point_set,far,max_scale); //O(|point_set|)
00236 
00237             node<P> child = batch_insert(dcb, p, next_scale, top_scale, point_set, consumed_set, stack);
00238 
00239             if (point_set.index == 0)
00240             {
00241                 push(stack,point_set);
00242                 point_set=far;
00243                 return child;
00244             }
00245             else {
00246                 node<P> n = new_node(p);
00247                 v_array<node<P> > children;
00248                 push(children, child);
00249                 v_array<ds_node<P> > new_point_set = pop(stack);
00250                 v_array<ds_node<P> > new_consumed_set = pop(stack);
00251                 while (point_set.index != 0) { //O(|point_set| * num_children)
00252                     P new_point = point_set.last().p;
00253                     ScalarType new_dist = point_set.last().dist.last();
00254                     push(consumed_set, point_set.last());
00255                     point_set.decr();
00256 
00257                     dist_split(dcb,point_set,new_point_set,new_point,max_scale); //O(|point_saet|)
00258                     dist_split(dcb,far,new_point_set,new_point,max_scale); //O(|far|)
00259 
00260                     node<P> new_child = 
00261                         batch_insert(dcb, new_point, next_scale, top_scale, new_point_set, new_consumed_set, stack);
00262                     new_child.parent_dist = new_dist;
00263 
00264                     push(children, new_child);
00265 
00266                     ScalarType fmax = dist_of_scale(max_scale);
00267                     for(int i = 0; i< new_point_set.index; i++) //O(|new_point_set|)
00268                     {
00269                         new_point_set[i].dist.decr();
00270                         if (new_point_set[i].dist.last() <= fmax)
00271                             push(point_set, new_point_set[i]);
00272                         else
00273                             push(far, new_point_set[i]);
00274                     }
00275                     for(int i = 0; i< new_consumed_set.index; i++) //O(|new_point_set|)
00276                     {
00277                         new_consumed_set[i].dist.decr();
00278                         push(consumed_set, new_consumed_set[i]);
00279                     }
00280                     new_point_set.index = 0;
00281                     new_consumed_set.index = 0;
00282                 }
00283                 push(stack,new_point_set);
00284                 push(stack,new_consumed_set);
00285                 push(stack,point_set);
00286                 point_set=far;
00287                 n.scale = top_scale - max_scale;
00288                 n.max_dist = max_set(consumed_set);
00289                 alloc(children,children.index);
00290                 n.num_children = children.index;
00291                 n.children = children.elements;
00292                 return n;
00293             }
00294         }
00295     }
00296 }
00297 
00298 template<class P, class DistanceCallback>
00299 node<P> batch_create(DistanceCallback& dcb, v_array<P> points)
00300 {
00301     assert(points.index > 0);
00302     v_array<ds_node<P> > point_set;
00303     v_array<v_array<ds_node<P> > > stack;
00304 
00305     for (int i = 1; i < points.index; i++) {
00306         ds_node<P> temp;
00307         push(temp.dist, distance(dcb, points[0], points[i], std::numeric_limits<ScalarType>::max()));
00308         temp.p = points[i];
00309         push(point_set,temp);
00310     }
00311 
00312     v_array<ds_node<P> > consumed_set;
00313 
00314     ScalarType max_dist = max_set(point_set);
00315 
00316     node<P> top = batch_insert (dcb, points[0],
00317             get_scale(max_dist),
00318             get_scale(max_dist),
00319             point_set, 
00320             consumed_set,
00321             stack);
00322     for (int i = 0; i<consumed_set.index;i++)
00323         free(consumed_set[i].dist.elements);
00324     free(consumed_set.elements);
00325     for (int i = 0; i<stack.index;i++)
00326         free(stack[i].elements);
00327     free(stack.elements);
00328     free(point_set.elements);
00329     return top;
00330 }
00331 
00332 void add_height(int d, v_array<int> &heights)
00333 {
00334     if (heights.index <= d)
00335         for(;heights.index <= d;)
00336             push(heights,0);
00337     heights[d] = heights[d] + 1;
00338 }
00339 
00340 template <class P>
00341 int height_dist(const node<P> top_node,v_array<int> &heights)
00342 {
00343     if (top_node.num_children == 0)
00344     {
00345         add_height(0,heights);
00346         return 0;
00347     }
00348     else
00349     {
00350         int max_v=0;
00351         for (int i = 0; i<top_node.num_children ;i++)
00352         {
00353             int d = height_dist(top_node.children[i], heights);
00354             if (d > max_v)
00355                 max_v = d;
00356         }
00357         add_height(1 + max_v, heights);
00358         return (1 + max_v);
00359     }
00360 }
00361 
00362 template <class P>
00363 void depth_dist(int top_scale, const node<P> top_node,v_array<int> &depths)
00364 {
00365     if (top_node.num_children > 0)
00366         for (int i = 0; i<top_node.num_children ;i++)
00367         {
00368             add_height(top_node.scale, depths);
00369             depth_dist(top_scale, top_node.children[i], depths);
00370         }
00371 }
00372 
00373 template <class P>
00374 void breadth_dist(const node<P> top_node,v_array<int> &breadths)
00375 {
00376     if (top_node.num_children == 0)
00377         add_height(0,breadths);
00378     else
00379     {
00380         for (int i = 0; i<top_node.num_children ;i++)
00381             breadth_dist(top_node.children[i], breadths);
00382         add_height(top_node.num_children, breadths);
00383     }
00384 }
00385 
/** Query-time cover-set entry: a reference-tree node paired with a distance.
 *
 * Kept as a plain aggregate because the rest of the file builds entries with
 * brace initialisers of the form {dist, n}, so the member order matters.
 */
template <class P>
struct d_node 
{
    /** Distance from the current query point to n->p (as stored by descend and the copy_* helpers). */
    ScalarType dist;

    /** Reference node; not freed through this struct anywhere in this file. */
    const node<P> *n;
};
00398 
00399 template <class P>
00400 inline ScalarType compare(const d_node<P> *p1, const d_node<P>* p2)
00401 {
00402     return p1 -> dist - p2 -> dist;
00403 }
00404 
00405 template <class P>
00406 void halfsort (v_array<d_node<P> > cover_set)
00407 {
00408     if (cover_set.index <= 1)
00409         return;
00410     register d_node<P> *base_ptr =  cover_set.elements;
00411 
00412     d_node<P> *hi = &base_ptr[cover_set.index - 1];
00413     d_node<P> *right_ptr = hi;
00414     d_node<P> *left_ptr;
00415 
00416     while (right_ptr > base_ptr)
00417     {
00418         d_node<P> *mid = base_ptr + ((hi - base_ptr) >> 1);
00419 
00420         if (compare ( mid,  base_ptr) < 0.)
00421             std::swap(*mid, *base_ptr);
00422         if (compare ( hi,  mid) < 0.)
00423             std::swap(*mid, *hi);
00424         else
00425             goto jump_over;
00426         if (compare ( mid,  base_ptr) < 0.)
00427             std::swap(*mid, *base_ptr);
00428 jump_over:;
00429 
00430         left_ptr  = base_ptr + 1;
00431         right_ptr = hi - 1;
00432 
00433         do
00434         {
00435             while (compare (left_ptr, mid) < 0.)
00436                 left_ptr++;
00437 
00438             while (compare (mid, right_ptr) < 0.)
00439                 right_ptr--;
00440 
00441             if (left_ptr < right_ptr)
00442             {
00443                 std::swap(*left_ptr, *right_ptr);
00444                 if (mid == left_ptr)
00445                     mid = right_ptr;
00446                 else if (mid == right_ptr)
00447                     mid = left_ptr;
00448                 left_ptr++;
00449                 right_ptr--;
00450             }
00451             else if (left_ptr == right_ptr)
00452             {
00453                 left_ptr ++;
00454                 right_ptr --;
00455                 break;
00456             }
00457         }
00458         while (left_ptr <= right_ptr);
00459         hi = right_ptr;
00460     }
00461 }
00462 
00463 template <class P>
00464 v_array<v_array<d_node<P> > > get_cover_sets(v_array<v_array<v_array<d_node<P> > > > &spare_cover_sets)
00465 {
00466     v_array<v_array<d_node<P> > > ret = pop(spare_cover_sets);
00467     while (ret.index < 101)
00468     {
00469         v_array<d_node<P> > temp;
00470         push(ret, temp);
00471     }
00472     return ret;
00473 }
00474 
00475 inline bool shell(ScalarType parent_query_dist, ScalarType child_parent_dist, ScalarType upper_bound)
00476 {
00477     return parent_query_dist - child_parent_dist <= upper_bound;
00478     //    && child_parent_dist - parent_query_dist <= upper_bound;
00479 }
00480 
00481 int internal_k =1;
00482 void update_k(ScalarType *k_upper_bound, ScalarType upper_bound)
00483 {
00484     ScalarType *end = k_upper_bound + internal_k-1;
00485     ScalarType *begin = k_upper_bound;
00486     for (;end != begin; begin++)
00487     {
00488         if (upper_bound < *(begin+1))
00489             *begin = *(begin+1);
00490         else {
00491             *begin = upper_bound;
00492             break;
00493         }
00494     }
00495     if (end == begin)
00496         *begin = upper_bound;
00497 }
00498 ScalarType *alloc_k()
00499 {
00500     return (ScalarType*)malloc(sizeof(ScalarType) * internal_k);
00501 }
00502 void set_k(ScalarType* begin, ScalarType max)
00503 {
00504     for(ScalarType *end = begin+internal_k;end != begin; begin++)
00505         *begin = max;
00506 }
00507 
00508 ScalarType internal_epsilon =0.;
00509 //void update_epsilon(ScalarType *upper_bound, ScalarType new_dist) {}
00510 ScalarType *alloc_epsilon()
00511 {
00512     return (ScalarType *)malloc(sizeof(ScalarType));
00513 }
00514 void set_epsilon(ScalarType* begin)
00515 {
00516     *begin = internal_epsilon;
00517 }
00518 
00519 void update_unequal(ScalarType *upper_bound, ScalarType new_dist) 
00520 {
00521     if (new_dist != 0.)
00522         *upper_bound = new_dist;
00523 }
00524 ScalarType* (*alloc_unequal)() = alloc_epsilon;
00525 void set_unequal(ScalarType* begin, ScalarType max)
00526 {
00527     *begin = max;
00528 }
00529 
00530 void (*update)(ScalarType *foo, ScalarType bar) = update_k;
00531 void (*setter)(ScalarType *foo, ScalarType bar) = set_k;
00532 ScalarType* (*alloc_upper)() = alloc_k;
00533 
00534 template <class P, class DistanceCallback>
00535 inline void copy_zero_set(DistanceCallback& dcb, node<P>* query_chi,
00536         ScalarType* new_upper_bound, v_array<d_node<P> > &zero_set,
00537         v_array<d_node<P> > &new_zero_set)
00538 {
00539     new_zero_set.index = 0;
00540     d_node<P> *end = zero_set.elements + zero_set.index;
00541     for (d_node<P> *ele = zero_set.elements; ele != end ; ele++)
00542     {
00543         ScalarType upper_dist = *new_upper_bound + query_chi->max_dist;
00544         if (shell(ele->dist, query_chi->parent_dist, upper_dist))
00545         {
00546             ScalarType d = distance(dcb, query_chi->p, ele->n->p, upper_dist);
00547 
00548             if (d <= upper_dist)
00549             {
00550                 if (d < *new_upper_bound) 
00551                     update(new_upper_bound, d);
00552                 d_node<P> temp = {d, ele->n};
00553                 push(new_zero_set,temp);
00554             }
00555         }
00556     }
00557 }
00558 
00559 template <class P, class DistanceCallback>
00560 inline void copy_cover_sets(DistanceCallback& dcb, node<P>* query_chi,
00561         ScalarType* new_upper_bound,
00562         v_array<v_array<d_node<P> > > &cover_sets,
00563         v_array<v_array<d_node<P> > > &new_cover_sets,
00564         int current_scale, int max_scale)
00565 {
00566     for (; current_scale <= max_scale; current_scale++)
00567     {
00568         d_node<P>* ele = cover_sets[current_scale].elements;
00569         d_node<P>* end = cover_sets[current_scale].elements + cover_sets[current_scale].index;
00570         for (; ele != end; ele++)
00571         { 
00572             ScalarType upper_dist = *new_upper_bound + query_chi->max_dist + ele->n->max_dist;
00573             if (shell(ele->dist, query_chi->parent_dist, upper_dist))
00574             {
00575                 ScalarType d = distance(dcb, query_chi->p, ele->n->p, upper_dist);
00576 
00577                 if (d <= upper_dist)
00578                 {
00579                     if (d < *new_upper_bound)
00580                         update(new_upper_bound,d);
00581                     d_node<P> temp = {d, ele->n};
00582                     push(new_cover_sets[current_scale],temp);
00583                 }
00584             }
00585         }
00586     }
00587 }
00588 
00589 template <class P>
00590 void print_query(const node<P> *top_node)
00591 {
00592     printf("query = \n");
00593     print(top_node->p);
00594     if ( top_node->num_children > 0 ) {
00595         printf("scale = %i\n",top_node->scale);
00596         printf("max_dist = %f\n",top_node->max_dist);
00597         printf("num children = %i\n",top_node->num_children);
00598     }
00599 }
00600 
00601 template <class P>
00602 void print_cover_sets(v_array<v_array<d_node<P> > > &cover_sets,
00603         v_array<d_node<P> > &zero_set,
00604         int current_scale, int max_scale)
00605 {
00606     printf("cover set = \n");
00607     for (; current_scale <= max_scale; current_scale++)
00608     {
00609         d_node<P> *ele = cover_sets[current_scale].elements;
00610         d_node<P> *end = cover_sets[current_scale].elements + cover_sets[current_scale].index;
00611         printf("%i\n", current_scale);
00612         for (; ele != end; ele++)
00613         {
00614             node<P> *n = (node<P> *)ele->n;
00615             print(n->p);
00616         }
00617     }
00618     d_node<P> *end = zero_set.elements + zero_set.index;
00619     printf("infinity\n");
00620     for (d_node<P> *ele = zero_set.elements; ele != end ; ele++)
00621     {
00622         node<P> *n = (node<P> *)ele->n;
00623         print(n->p);
00624     }
00625 }
00626 
/*
   An optimization to consider:
   Make all distance evaluations occur in descend.

   Instead of passing a cover_set, pass a stack of cover sets.  The
   last element holds d_nodes with your distance.  The next lower
   element holds a d_node with the distance to your query parent,
   next = query grand parent, etc..

   Compute distances in the presence of the tighter upper bound.
   */
/** Expands every reference node in cover_sets[current_scale] into its children.
 *
 * @param query        current query node; twice its max_dist widens the search radius.
 * @param upper_bound  current bound; *upper_bound is tightened through update()
 *                     whenever a child closer than it is found.
 * @param max_scale    raised to the largest scale of any non-leaf child pushed.
 * @param cover_sets   per-scale cover sets; surviving non-leaf children are
 *                     appended to cover_sets[child->scale].
 * @param zero_set     surviving leaf children are appended here.
 */
template <class P, class DistanceCallback>
inline 
void descend(DistanceCallback& dcb, const node<P>* query, ScalarType* upper_bound,
        int current_scale,int &max_scale, v_array<v_array<d_node<P> > > &cover_sets,
        v_array<d_node<P> > &zero_set)
{
    d_node<P> *end = cover_sets[current_scale].elements + cover_sets[current_scale].index;
    for (d_node<P> *parent = cover_sets[current_scale].elements; parent != end; parent++)
    {
        const node<P> *par = parent->n;
        // Search radius around the query, widened by twice the query's own max_dist.
        ScalarType upper_dist = *upper_bound + query->max_dist + query->max_dist;
        // Skip this reference node entirely if even its subtree bound is out of reach.
        if (parent->dist <= upper_dist + par->max_dist)
        {
            // The first child is built on the same point as its parent (see
            // batch_insert), so parent->dist is reused without a distance call.
            node<P> *chi = par->children;
            if (parent->dist <= upper_dist + chi->max_dist)
            {
                if (chi->num_children > 0)
                {
                    if (max_scale < chi->scale)
                        max_scale = chi->scale;
                    d_node<P> temp = {parent->dist, chi};
                    push(cover_sets[chi->scale], temp);
                }
                else if (parent->dist <= upper_dist)
                {
                    d_node<P> temp = {parent->dist, chi};
                    push(zero_set, temp);
                }
            }
            // Remaining children: cheap triangle-inequality prune (shell) first,
            // then the real distance, bounded by upper_chi.
            node<P> *child_end = par->children + par->num_children;
            for (chi++; chi != child_end; chi++)
            {
                ScalarType upper_chi = *upper_bound + chi->max_dist + query->max_dist + query->max_dist;
                if (shell(parent->dist, chi->parent_dist, upper_chi))
                {
                    ScalarType d = distance(dcb, query->p, chi->p, upper_chi);
                    if (d <= upper_chi) 
                    {
                        // Tighten the bound as soon as a closer point is seen.
                        if (d < *upper_bound)
                            update(upper_bound, d);
                        if (chi->num_children > 0)
                        {
                            if (max_scale < chi->scale)
                                max_scale = chi->scale;
                            d_node<P> temp = {d, chi};
                            push(cover_sets[chi->scale],temp);
                        }
                        else 
                            // Leaf: accept only within the radius excluding its own max_dist.
                            if (d <= upper_chi - chi->max_dist)
                            {
                                d_node<P> temp = {d, chi};
                                push(zero_set, temp);
                            }
                    }
                }
            }
        }
    }
}
00697 
00698 template <class P, class DistanceCallback>
00699 void brute_nearest(DistanceCallback& dcb, const node<P>* query,
00700         v_array<d_node<P> > zero_set, ScalarType* upper_bound,
00701         v_array<v_array<P> > &results,
00702         v_array<v_array<d_node<P> > > &spare_zero_sets)
00703 {
00704     if (query->num_children > 0)
00705     {
00706         v_array<d_node<P> > new_zero_set = pop(spare_zero_sets);
00707         node<P> * query_chi = query->children; 
00708         brute_nearest(dcb, query_chi, zero_set, upper_bound, results, spare_zero_sets);
00709         ScalarType* new_upper_bound = alloc_upper();
00710 
00711         node<P> *child_end = query->children + query->num_children;
00712         for (query_chi++;query_chi != child_end; query_chi++)
00713         {
00714             setter(new_upper_bound,*upper_bound + query_chi->parent_dist);
00715             copy_zero_set(dcb, query_chi, new_upper_bound, zero_set, new_zero_set);
00716             brute_nearest(dcb, query_chi, new_zero_set, new_upper_bound, results, spare_zero_sets);
00717         }
00718         free (new_upper_bound);
00719         new_zero_set.index = 0;
00720         push(spare_zero_sets, new_zero_set);
00721     }
00722     else 
00723     {
00724         v_array<P> temp;
00725         push(temp, query->p);
00726         d_node<P> *end = zero_set.elements + zero_set.index;
00727         for (d_node<P> *ele = zero_set.elements; ele != end ; ele++)
00728             if (ele->dist <= *upper_bound) 
00729                 push(temp, ele->n->p);
00730         push(results,temp);
00731     }
00732 }
00733 
/** Recursive driver of batch_nearest_neighbor, walking the query tree and the
 *  reference tree together.
 *
 * Each call does exactly one of:
 *  - current_scale > max_scale: every remaining candidate is in zero_set, so
 *    finish with brute_nearest;
 *  - the query node's scale is <= current_scale (and is not the marker 100):
 *    split the query, giving each child filtered copies of the sets;
 *  - otherwise: expand the reference nodes of cover_sets[current_scale]
 *    (descend) and recurse at current_scale + 1.
 * Scratch sets are borrowed from and returned to spare_cover_sets /
 * spare_zero_sets instead of being reallocated.
 */
template <class P, class DistanceCallback>
void internal_batch_nearest_neighbor(DistanceCallback& dcb, const node<P> *query,
        v_array<v_array<d_node<P> > > &cover_sets,
        v_array<d_node<P> > &zero_set,
        int current_scale,
        int max_scale,
        ScalarType* upper_bound,
        v_array<v_array<P> > &results,
        v_array<v_array<v_array<d_node<P> > > > &spare_cover_sets,
        v_array<v_array<d_node<P> > > &spare_zero_sets)
{
    if (current_scale > max_scale) // All remaining points are in the zero set. 
        brute_nearest(dcb, query, zero_set, upper_bound, results, spare_zero_sets);
    else
        if (query->scale <= current_scale && query->scale != 100) 
            // Our query has too much scale.  Reduce.
        { 
            // Children other than the first get filtered copies of the sets and a
            // bound seeded from this node's bound plus their distance to it.
            node<P> *query_chi = query->children;
            v_array<d_node<P> > new_zero_set = pop(spare_zero_sets);
            v_array<v_array<d_node<P> > > new_cover_sets = get_cover_sets(spare_cover_sets);
            ScalarType* new_upper_bound = alloc_upper();

            node<P> *child_end = query->children + query->num_children;
            for (query_chi++; query_chi != child_end; query_chi++)
            {
                setter(new_upper_bound,*upper_bound + query_chi->parent_dist);
                copy_zero_set(dcb, query_chi, new_upper_bound, zero_set, new_zero_set);
                copy_cover_sets(dcb, query_chi, new_upper_bound, cover_sets, new_cover_sets,
                        current_scale, max_scale);
                internal_batch_nearest_neighbor(dcb, query_chi, new_cover_sets, new_zero_set,
                        current_scale, max_scale, new_upper_bound, 
                        results, spare_cover_sets, spare_zero_sets);
            }
            free (new_upper_bound);
            new_zero_set.index = 0;
            push(spare_zero_sets, new_zero_set);
            push(spare_cover_sets, new_cover_sets);
            // The first child is handled last, with this node's own (unfiltered)
            // sets and bound: the recursion modifies those sets, and for trees
            // built by batch_insert the first child shares this node's point, so
            // the stored distances remain valid for it.
            internal_batch_nearest_neighbor(dcb, query->children, cover_sets, zero_set,
                    current_scale, max_scale, upper_bound, results, 
                    spare_cover_sets, spare_zero_sets);
        }
        else // reduce cover set scale
        {
            // Partially order by distance (presumably so closer nodes are
            // expanded first and tighten the bound sooner).
            halfsort(cover_sets[current_scale]);
            descend(dcb, query, upper_bound, current_scale, max_scale,cover_sets, zero_set);
            // This scale has been fully expanded: empty it and move one scale down.
            cover_sets[current_scale++].index = 0;
            internal_batch_nearest_neighbor(dcb, query, cover_sets, zero_set,
                    current_scale, max_scale, upper_bound, results, 
                    spare_cover_sets, spare_zero_sets);
        }
}
00785 
00786 template <class P, class DistanceCallback>
00787 void batch_nearest_neighbor(DistanceCallback &dcb, const node<P> &top_node,
00788         const node<P> &query, v_array<v_array<P> > &results)
00789 {
00790     v_array<v_array<v_array<d_node<P> > > > spare_cover_sets;
00791     v_array<v_array<d_node<P> > > spare_zero_sets;
00792 
00793     v_array<v_array<d_node<P> > > cover_sets = get_cover_sets(spare_cover_sets);
00794     v_array<d_node<P> > zero_set = pop(spare_zero_sets);
00795 
00796     ScalarType* upper_bound = alloc_upper();
00797     setter(upper_bound, std::numeric_limits<ScalarType>::max());
00798 
00799     ScalarType top_dist = distance(dcb, query.p, top_node.p, std::numeric_limits<ScalarType>::max());
00800     update(upper_bound, top_dist);
00801 
00802     d_node<P> temp = {top_dist, &top_node};
00803     push(cover_sets[0], temp);
00804 
00805     internal_batch_nearest_neighbor(dcb, &query,cover_sets,zero_set,0,0,upper_bound,results,
00806             spare_cover_sets,spare_zero_sets);
00807 
00808     free(upper_bound);
00809     push(spare_cover_sets, cover_sets);
00810 
00811     for (int i = 0; i < spare_cover_sets.index; i++)
00812     {
00813         v_array<v_array<d_node<P> > > cover_sets2 = spare_cover_sets[i];
00814         for (int j = 0; j < cover_sets2.index; j++)
00815             free (cover_sets2[j].elements);
00816         free(cover_sets2.elements);
00817     }
00818     free(spare_cover_sets.elements);
00819 
00820     push(spare_zero_sets, zero_set);
00821 
00822     for (int i = 0; i < spare_zero_sets.index; i++)
00823         free(spare_zero_sets[i].elements);
00824     free(spare_zero_sets.elements);
00825 }
00826 
00827 template <class P, class DistanceCallback>
00828 void k_nearest_neighbor(DistanceCallback &dcb, const node<P> &top_node,
00829         const node<P> &query, v_array<v_array<P> > &results, int k)
00830 {
00831     internal_k = k;
00832     update = update_k;
00833     setter = set_k;
00834     alloc_upper = alloc_k;
00835 
00836     batch_nearest_neighbor(dcb, top_node, query, results);
00837 }
00838 /*
00839 template <class P, class DistanceCallback>
00840 void epsilon_nearest_neighbor(DistanceCallback &dcb, const node<P> &top_node,
00841         const node<P> &query, v_array<v_array<P> > &results,
00842         ScalarType epsilon)
00843 {
00844     internal_epsilon = epsilon;
00845     //  update = update_epsilon;
00846     setter = set_epsilon;
00847     alloc_upper = alloc_epsilon;
00848 
00849     batch_nearest_neighbor(dcb, top_node, query, results);
00850 }
00851 
00852 template <class P, class DistanceCallback>
00853 void unequal_nearest_neighbor(DistanceCallback &dcb, const node<P> &top_node,
00854         const node<P> &query, v_array<v_array<P> > &results)
00855 {
00856     update = update_unequal;
00857     setter = set_unequal;
00858     alloc_upper = alloc_unequal;
00859 
00860     batch_nearest_neighbor(dcb, top_node, query, results);
00861 }
00862 */
00863 
00864 }
00865 }
00866 #endif
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Defines