1 /***********************************************************************
2  * Software License Agreement (BSD License)
3  *
4  * Copyright 2008-2009  Marius Muja (mariusm@cs.ubc.ca). All rights reserved.
5  * Copyright 2008-2009  David G. Lowe (lowe@cs.ubc.ca). All rights reserved.
6  *
7  * THE BSD LICENSE
8  *
9  * Redistribution and use in source and binary forms, with or without
10  * modification, are permitted provided that the following conditions
11  * are met:
12  *
13  * 1. Redistributions of source code must retain the above copyright
14  *    notice, this list of conditions and the following disclaimer.
15  * 2. Redistributions in binary form must reproduce the above copyright
16  *    notice, this list of conditions and the following disclaimer in the
17  *    documentation and/or other materials provided with the distribution.
18  *
19  * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR
20  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
21  * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
22  * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT,
23  * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT
24  * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
25  * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
26  * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF
28  * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29  *************************************************************************/
30 
31 #ifndef FLANN_KMEANS_INDEX_H_
32 #define FLANN_KMEANS_INDEX_H_
33 
34 #include <algorithm>
35 #include <string>
36 #include <map>
37 #include <cassert>
38 #include <limits>
39 #include <cmath>
40 
41 #include "flann/general.h"
42 #include "flann/algorithms/nn_index.h"
43 #include "flann/algorithms/dist.h"
44 #include "flann/algorithms/center_chooser.h"
45 #include "flann/util/matrix.h"
46 #include "flann/util/result_set.h"
47 #include "flann/util/heap.h"
48 #include "flann/util/allocator.h"
49 #include "flann/util/random.h"
50 #include "flann/util/saving.h"
51 #include "flann/util/logger.h"
52 
53 
54 
55 namespace flann
56 {
57 
58 struct KMeansIndexParams : public IndexParams
59 {
60     KMeansIndexParams(int branching = 32, int iterations = 11,
61                       flann_centers_init_t centers_init = FLANN_CENTERS_RANDOM, float cb_index = 0.2 )
62     {
63         (*this)["algorithm"] = FLANN_INDEX_KMEANS;
64         // branching factor
65         (*this)["branching"] = branching;
66         // max iterations to perform in one kmeans clustering (kmeans tree)
67         (*this)["iterations"] = iterations;
68         // algorithm used for picking the initial cluster centers for kmeans tree
69         (*this)["centers_init"] = centers_init;
70         // cluster boundary index. Used when searching the kmeans tree
71         (*this)["cb_index"] = cb_index;
72     }
73 };
74 
75 
76 /**
77  * Hierarchical kmeans index
78  *
79  * Contains a tree constructed through a hierarchical kmeans clustering
80  * and other information for indexing a set of points for nearest-neighbour matching.
81  */
82 template <typename Distance>
83 class KMeansIndex : public NNIndex<Distance>
84 {
85 public:
    // Component type of the indexed vectors, taken from the Distance functor.
86     typedef typename Distance::ElementType ElementType;
    // Value type returned by the Distance functor; cluster centers (pivots)
    // are also stored with this type.
87     typedef typename Distance::ResultType DistanceType;
88 
    // Shorthand for the generic index base class.
89     typedef NNIndex<Distance> BaseClass;
90 
    // Marker typedef. Presumably inspected elsewhere to require a distance that
    // lives in a vector space (centers are computed as means) — TODO confirm.
91     typedef bool needs_vector_space_distance;
92 
93 
94 
    /**
     * Returns the algorithm identifier of this index (FLANN_INDEX_KMEANS).
     */
getType()95     flann_algorithm_t getType() const
96     {
97         return FLANN_INDEX_KMEANS;
98     }
99 
100     /**
101      * Index constructor
102      *
103      * Params:
104      *          inputData = dataset with the input features
105      *          params = parameters passed to the hierarchical k-means algorithm
106      */
107     KMeansIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KMeansIndexParams(),
108                 Distance d = Distance())
BaseClass(params,d)109         : BaseClass(params,d), root_(NULL), memoryCounter_(0)
110     {
111         branching_ = get_param(params,"branching",32);
112         iterations_ = get_param(params,"iterations",11);
113         if (iterations_<0) {
114             iterations_ = (std::numeric_limits<int>::max)();
115         }
116         centers_init_  = get_param(params,"centers_init",FLANN_CENTERS_RANDOM);
117         cb_index_  = get_param(params,"cb_index",0.4f);
118 
119         initCenterChooser();
120         chooseCenters_->setDataset(inputData);
121 
122         setDataset(inputData);
123     }
124 
125 
126     /**
127      * Index constructor
128      *
129      * Params:
130      *          inputData = dataset with the input features
131      *          params = parameters passed to the hierarchical k-means algorithm
132      */
133     KMeansIndex(const IndexParams& params = KMeansIndexParams(), Distance d = Distance())
BaseClass(params,d)134         : BaseClass(params, d), root_(NULL), memoryCounter_(0)
135     {
136         branching_ = get_param(params,"branching",32);
137         iterations_ = get_param(params,"iterations",11);
138         if (iterations_<0) {
139             iterations_ = (std::numeric_limits<int>::max)();
140         }
141         centers_init_  = get_param(params,"centers_init",FLANN_CENTERS_RANDOM);
142         cb_index_  = get_param(params,"cb_index",0.4f);
143 
144         initCenterChooser();
145     }
146 
147 
KMeansIndex(const KMeansIndex & other)148     KMeansIndex(const KMeansIndex& other) : BaseClass(other),
149     		branching_(other.branching_),
150     		iterations_(other.iterations_),
151     		centers_init_(other.centers_init_),
152     		cb_index_(other.cb_index_),
153     		memoryCounter_(other.memoryCounter_)
154     {
155     	initCenterChooser();
156 
157     	copyTree(root_, other.root_);
158     }
159 
/**
 * Copy assignment. `other` is taken by value, so it is already a copy made by
 * the copy constructor; its contents are exchanged with *this and the previous
 * contents of *this are released when `other` goes out of scope.
 * (Assumes swap() exchanges every member — it is defined outside this view.)
 */
160     KMeansIndex& operator=(KMeansIndex other)
161     {
162     	this->swap(other);
163     	return *this;
164     }
165 
166 
initCenterChooser()167     void initCenterChooser()
168     {
169         switch(centers_init_) {
170         case FLANN_CENTERS_RANDOM:
171         	chooseCenters_ = new RandomCenterChooser<Distance>(distance_);
172         	break;
173         case FLANN_CENTERS_GONZALES:
174         	chooseCenters_ = new GonzalesCenterChooser<Distance>(distance_);
175         	break;
176         case FLANN_CENTERS_KMEANSPP:
177             chooseCenters_ = new KMeansppCenterChooser<Distance>(distance_);
178         	break;
179         default:
180             throw FLANNException("Unknown algorithm for choosing initial centers.");
181         }
182     }
183 
184     /**
185      * Index destructor.
186      *
187      * Release the memory used by the index.
188      */
~KMeansIndex()189     virtual ~KMeansIndex()
190     {
191     	delete chooseCenters_;
192     	freeIndex();
193     }
194 
clone()195     BaseClass* clone() const
196     {
197     	return new KMeansIndex(*this);
198     }
199 
200 
set_cb_index(float index)201     void set_cb_index( float index)
202     {
203         cb_index_ = index;
204     }
205 
206     /**
207      * Computes the inde memory usage
208      * Returns: memory used by the index
209      */
usedMemory()210     int usedMemory() const
211     {
212         return pool_.usedMemory+pool_.wastedMemory+memoryCounter_;
213     }
214 
215     using BaseClass::buildIndex;
216 
217     void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
218     {
219         assert(points.cols==veclen_);
220         size_t old_size = size_;
221 
222         extendDataset(points);
223 
224         if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
225             buildIndex();
226         }
227         else {
228             for (size_t i=0;i<points.rows;++i) {
229                 DistanceType dist = distance_(root_->pivot, points[i], veclen_);
230                 addPointToTree(root_, old_size + i, dist);
231             }
232         }
233     }
234 
/**
 * Saves or loads the index through a FLANN serialization archive.
 *
 * Archive order: base-class state, branching_, iterations_, memoryCounter_,
 * cb_index_, centers_init_, then the tree starting at the root. When loading,
 * a new root is allocated from pool_ (loadIndex frees the old tree first) and
 * index_params_ is rebuilt from the loaded values.
 *
 * NOTE(review): chooseCenters_ is not touched here, so after loading it still
 * reflects the centers_init_ given at construction even if the loaded value
 * differs. Saving an index whose root_ is NULL dereferences a null pointer.
 */
235     template<typename Archive>
serialize(Archive & ar)236     void serialize(Archive& ar)
237     {
        // Lets Node/PointInfo::serialize reach veclen_, pool_ and points_
        // through ar.getObject().
238     	ar.setObject(this);
239 
240     	ar & *static_cast<NNIndex<Distance>*>(this);
241 
242     	ar & branching_;
243     	ar & iterations_;
244     	ar & memoryCounter_;
245     	ar & cb_index_;
246     	ar & centers_init_;
247 
248     	if (Archive::is_loading::value) {
249     		root_ = new(pool_) Node();
250     	}
        // Recursively saves/loads the whole tree.
251     	ar & *root_;
252 
253     	if (Archive::is_loading::value) {
            // Keep the reported parameters consistent with the loaded index.
254             index_params_["algorithm"] = getType();
255             index_params_["branching"] = branching_;
256             index_params_["iterations"] = iterations_;
257             index_params_["centers_init"] = centers_init_;
258             index_params_["cb_index"] = cb_index_;
259     	}
260     }
261 
saveIndex(FILE * stream)262     void saveIndex(FILE* stream)
263     {
264     	serialization::SaveArchive sa(stream);
265     	sa & *this;
266     }
267 
loadIndex(FILE * stream)268     void loadIndex(FILE* stream)
269     {
270     	freeIndex();
271     	serialization::LoadArchive la(stream);
272     	la & *this;
273     }
274 
275     /**
276      * Find set of nearest neighbors to vec. Their indices are stored inside
277      * the result object.
278      *
279      * Params:
280      *     result = the result object in which the indices of the nearest-neighbors are stored
281      *     vec = the vector for which to search the nearest neighbors
282      *     searchParams = parameters that influence the search algorithm (checks, cb_index)
283      */
284 
findNeighbors(ResultSet<DistanceType> & result,const ElementType * vec,const SearchParams & searchParams)285     void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
286     {
287     	if (removed_) {
288     		findNeighborsWithRemoved<true>(result, vec, searchParams);
289     	}
290     	else {
291     		findNeighborsWithRemoved<false>(result, vec, searchParams);
292     	}
293 
294     }
295 
296     /**
297      * Clustering function that takes a cut in the hierarchical k-means
298      * tree and return the clusters centers of that clustering.
299      * Params:
300      *     numClusters = number of clusters to have in the clustering computed
301      * Returns: number of cluster centers
302      */
getClusterCenters(Matrix<DistanceType> & centers)303     int getClusterCenters(Matrix<DistanceType>& centers)
304     {
305         int numClusters = centers.rows;
306         if (numClusters<1) {
307             throw FLANNException("Number of clusters must be at least 1");
308         }
309 
310         DistanceType variance;
311         std::vector<NodePtr> clusters(numClusters);
312 
313         int clusterCount = getMinVarianceClusters(root_, clusters, numClusters, variance);
314 
315         Logger::info("Clusters requested: %d, returning %d\n",numClusters, clusterCount);
316 
317         for (int i=0; i<clusterCount; ++i) {
318             DistanceType* center = clusters[i]->pivot;
319             for (size_t j=0; j<veclen_; ++j) {
320                 centers[i][j] = center[j];
321             }
322         }
323 
324         return clusterCount;
325     }
326 
327 protected:
328     /**
329      * Builds the index
330      */
buildIndexImpl()331     void buildIndexImpl()
332     {
333         if (branching_<2) {
334             throw FLANNException("Branching factor must be at least 2");
335         }
336 
337         std::vector<int> indices(size_);
338         for (size_t i=0; i<size_; ++i) {
339         	indices[i] = int(i);
340         }
341 
342         root_ = new(pool_) Node();
343         computeNodeStatistics(root_, indices);
344         computeClustering(root_, &indices[0], (int)size_, branching_);
345     }
346 
347 private:
348 
    /**
     * A point stored in a leaf node: its dataset index plus a cached pointer
     * to its data (points_[index]).
     */
349     struct PointInfo
350     {
351     	size_t index;
352     	ElementType* point;
353     private:
        // Only the index is written to the archive; when loading, the data
        // pointer is re-derived from the owning index's points_, because a raw
        // pointer would not be valid after a save/load round trip.
354     	template<typename Archive>
serializePointInfo355     	void serialize(Archive& ar)
356     	{
357     		typedef KMeansIndex<Distance> Index;
358     		Index* obj = static_cast<Index*>(ar.getObject());
359 
360     		ar & index;
361 //    		ar & point;
362 
363 			if (Archive::is_loading::value) point = obj->points_[index];
364     	}
365     	friend struct serialization::access;
366     };
367 
368     /**
369      * Structure representing a node in the hierarchical k-means tree.
     *
     * Nodes are created with placement new on the index's pool_, so they are
     * destroyed through explicit destructor calls rather than delete. The
     * pivot array is allocated with new[] and owned by the node.
370      */
371     struct Node
372     {
373         /**
374          * The cluster center (veclen_ values; freed in ~Node).
375          */
376         DistanceType* pivot;
377         /**
378          * The cluster radius: the largest distance_ value between the pivot
         * and a point of the cluster. Search uses it to prune clusters.
379          */
380         DistanceType radius;
381         /**
382          * The cluster variance: the mean distance_ value between the pivot
         * and the points of the cluster.
383          */
384         DistanceType variance;
385         /**
386          * The cluster size (number of points in the whole subtree)
387          */
388         int size;
389         /**
390          * Child nodes (only for non-terminal nodes; empty for leaves)
391          */
392         std::vector<Node*> childs;
393         /**
394          * Node points (only for terminal nodes)
395          */
396         std::vector<PointInfo> points;
397         /**
398          * Level (currently unused; the member below is commented out)
399          */
400 //        int level;
401 
        // Frees the pivot and recursively runs the children's destructors.
        // The children's own memory belongs to the pool and is released by
        // the pool, not here.
~NodeNode402         ~Node()
403         {
404             delete[] pivot;
405             if (!childs.empty()) {
406                 for (size_t i=0; i<childs.size(); ++i) {
407                     childs[i]->~Node();
408                 }
409             }
410         }
411 
        // Saves/loads the node recursively in this order: pivot (raw bytes),
        // radius, variance, size, number of children, then either the leaf
        // points or each child. ar.getObject() must be the owning KMeansIndex
        // (set by KMeansIndex::serialize).
412     	template<typename Archive>
serializeNode413     	void serialize(Archive& ar)
414     	{
415     		typedef KMeansIndex<Distance> Index;
416     		Index* obj = static_cast<Index*>(ar.getObject());
417 
418     		if (Archive::is_loading::value) {
419     			pivot = new DistanceType[obj->veclen_];
420     		}
421     		ar & serialization::make_binary_object(pivot, obj->veclen_*sizeof(DistanceType));
422     		ar & radius;
423     		ar & variance;
424     		ar & size;
425 
            // Assigned only when saving; when loading, the archive fills it in.
426     		size_t childs_size;
427     		if (Archive::is_saving::value) {
428     			childs_size = childs.size();
429     		}
430     		ar & childs_size;
431 
432     		if (childs_size==0) {
433     			ar & points;
434     		}
435     		else {
436     			if (Archive::is_loading::value) {
437     				childs.resize(childs_size);
438     			}
439     			for (size_t i=0;i<childs_size;++i) {
440     				if (Archive::is_loading::value) {
441     					childs[i] = new(obj->pool_) Node();
442     				}
443     				ar & *childs[i];
444     			}
445     		}
446     	}
447     	friend struct serialization::access;
448     };
// Pointer type used for all links in the tree.
449     typedef Node* NodePtr;
450 
451     /**
452      * Priority-queue entry for the best-bin-first search: a tree node paired
     * with its priority (see exploreNodeBranches).
453      */
454     typedef BranchStruct<NodePtr, DistanceType> BranchSt;
455 
456 
457     /**
458      * Helper function
459      */
freeIndex()460     void freeIndex()
461     {
462     	if (root_) root_->~Node();
463     	root_ = NULL;
464     	pool_.free();
465     }
466 
copyTree(NodePtr & dst,const NodePtr & src)467     void copyTree(NodePtr& dst, const NodePtr& src)
468     {
469     	dst = new(pool_) Node();
470     	dst->pivot = new DistanceType[veclen_];
471     	std::copy(src->pivot, src->pivot+veclen_, dst->pivot);
472     	dst->radius = src->radius;
473     	dst->variance = src->variance;
474     	dst->size = src->size;
475 
476     	if (src->childs.size()==0) {
477     		dst->points = src->points;
478     	}
479     	else {
480     		dst->childs.resize(src->childs.size());
481     		for (size_t i=0;i<src->childs.size();++i) {
482     			copyTree(dst->childs[i], src->childs[i]);
483     		}
484     	}
485     }
486 
487 
488     /**
489      * Computes the statistics of a node (mean, radius, variance).
490      *
491      * Params:
492      *     node = the node to use
493      *     indices = the indices of the points belonging to the node
494      */
computeNodeStatistics(NodePtr node,const std::vector<int> & indices)495     void computeNodeStatistics(NodePtr node, const std::vector<int>& indices)
496     {
497         size_t size = indices.size();
498 
499         DistanceType* mean = new DistanceType[veclen_];
500         memoryCounter_ += int(veclen_*sizeof(DistanceType));
501         memset(mean,0,veclen_*sizeof(DistanceType));
502 
503         for (size_t i=0; i<size; ++i) {
504             ElementType* vec = points_[indices[i]];
505             for (size_t j=0; j<veclen_; ++j) {
506                 mean[j] += vec[j];
507             }
508         }
509         DistanceType div_factor = DistanceType(1)/size;
510         for (size_t j=0; j<veclen_; ++j) {
511             mean[j] *= div_factor;
512         }
513 
514         DistanceType radius = 0;
515         DistanceType variance = 0;
516         for (size_t i=0; i<size; ++i) {
517             DistanceType dist = distance_(mean, points_[indices[i]], veclen_);
518             if (dist>radius) {
519                 radius = dist;
520             }
521             variance += dist;
522         }
523         variance /= size;
524 
525         node->variance = variance;
526         node->radius = radius;
527         node->pivot = mean;
528     }
529 
530 
531     /**
532      * The method responsible with actually doing the recursive hierarchical
533      * clustering
534      *
535      * Params:
536      *     node = the node to cluster
537      *     indices = indices of the points belonging to the current node
538      *     branching = the branching factor to use in the clustering
539      *
540      * TODO: for 1-sized clusters don't store a cluster center (it's the same as the single cluster point)
541      */
computeClustering(NodePtr node,int * indices,int indices_length,int branching)542     void computeClustering(NodePtr node, int* indices, int indices_length, int branching)
543     {
544         node->size = indices_length;
545 
546         if (indices_length < branching) {
547             node->points.resize(indices_length);
548             for (int i=0;i<indices_length;++i) {
549             	node->points[i].index = indices[i];
550             	node->points[i].point = points_[indices[i]];
551             }
552             node->childs.clear();
553             return;
554         }
555 
556         std::vector<int> centers_idx(branching);
557         int centers_length;
558         (*chooseCenters_)(branching, indices, indices_length, &centers_idx[0], centers_length);
559 
560         if (centers_length<branching) {
561             node->points.resize(indices_length);
562             for (int i=0;i<indices_length;++i) {
563             	node->points[i].index = indices[i];
564             	node->points[i].point = points_[indices[i]];
565             }
566             node->childs.clear();
567             return;
568         }
569 
570 
571         Matrix<double> dcenters(new double[branching*veclen_],branching,veclen_);
572         for (int i=0; i<centers_length; ++i) {
573             ElementType* vec = points_[centers_idx[i]];
574             for (size_t k=0; k<veclen_; ++k) {
575                 dcenters[i][k] = double(vec[k]);
576             }
577         }
578 
579         std::vector<DistanceType> radiuses(branching,0);
580         std::vector<int> count(branching,0);
581 
582         //	assign points to clusters
583         std::vector<int> belongs_to(indices_length);
584         for (int i=0; i<indices_length; ++i) {
585 
586             DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
587             belongs_to[i] = 0;
588             for (int j=1; j<branching; ++j) {
589                 DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
590                 if (sq_dist>new_sq_dist) {
591                     belongs_to[i] = j;
592                     sq_dist = new_sq_dist;
593                 }
594             }
595             if (sq_dist>radiuses[belongs_to[i]]) {
596                 radiuses[belongs_to[i]] = sq_dist;
597             }
598             count[belongs_to[i]]++;
599         }
600 
601         bool converged = false;
602         int iteration = 0;
603         while (!converged && iteration<iterations_) {
604             converged = true;
605             iteration++;
606 
607             // compute the new cluster centers
608             for (int i=0; i<branching; ++i) {
609                 memset(dcenters[i],0,sizeof(double)*veclen_);
610                 radiuses[i] = 0;
611             }
612             for (int i=0; i<indices_length; ++i) {
613                 ElementType* vec = points_[indices[i]];
614                 double* center = dcenters[belongs_to[i]];
615                 for (size_t k=0; k<veclen_; ++k) {
616                     center[k] += vec[k];
617                 }
618             }
619             for (int i=0; i<branching; ++i) {
620                 int cnt = count[i];
621                 double div_factor = 1.0/cnt;
622                 for (size_t k=0; k<veclen_; ++k) {
623                     dcenters[i][k] *= div_factor;
624                 }
625             }
626 
627             // reassign points to clusters
628             for (int i=0; i<indices_length; ++i) {
629                 DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
630                 int new_centroid = 0;
631                 for (int j=1; j<branching; ++j) {
632                     DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
633                     if (sq_dist>new_sq_dist) {
634                         new_centroid = j;
635                         sq_dist = new_sq_dist;
636                     }
637                 }
638                 if (sq_dist>radiuses[new_centroid]) {
639                     radiuses[new_centroid] = sq_dist;
640                 }
641                 if (new_centroid != belongs_to[i]) {
642                     count[belongs_to[i]]--;
643                     count[new_centroid]++;
644                     belongs_to[i] = new_centroid;
645 
646                     converged = false;
647                 }
648             }
649 
650             for (int i=0; i<branching; ++i) {
651                 // if one cluster converges to an empty cluster,
652                 // move an element into that cluster
653                 if (count[i]==0) {
654                     int j = (i+1)%branching;
655                     while (count[j]<=1) {
656                         j = (j+1)%branching;
657                     }
658 
659                     for (int k=0; k<indices_length; ++k) {
660                         if (belongs_to[k]==j) {
661                             belongs_to[k] = i;
662                             count[j]--;
663                             count[i]++;
664                             break;
665                         }
666                     }
667                     converged = false;
668                 }
669             }
670 
671         }
672 
673         std::vector<DistanceType*> centers(branching);
674 
675         for (int i=0; i<branching; ++i) {
676             centers[i] = new DistanceType[veclen_];
677             memoryCounter_ += veclen_*sizeof(DistanceType);
678             for (size_t k=0; k<veclen_; ++k) {
679                 centers[i][k] = (DistanceType)dcenters[i][k];
680             }
681         }
682 
683 
684         // compute kmeans clustering for each of the resulting clusters
685         node->childs.resize(branching);
686         int start = 0;
687         int end = start;
688         for (int c=0; c<branching; ++c) {
689             int s = count[c];
690 
691             DistanceType variance = 0;
692             for (int i=0; i<indices_length; ++i) {
693                 if (belongs_to[i]==c) {
694                     variance += distance_(centers[c], points_[indices[i]], veclen_);
695                     std::swap(indices[i],indices[end]);
696                     std::swap(belongs_to[i],belongs_to[end]);
697                     end++;
698                 }
699             }
700             variance /= s;
701 
702             node->childs[c] = new(pool_) Node();
703             node->childs[c]->radius = radiuses[c];
704             node->childs[c]->pivot = centers[c];
705             node->childs[c]->variance = variance;
706             computeClustering(node->childs[c],indices+start, end-start, branching);
707             start=end;
708         }
709 
710         delete[] dcenters.ptr();
711     }
712 
713 
714     template<bool with_removed>
findNeighborsWithRemoved(ResultSet<DistanceType> & result,const ElementType * vec,const SearchParams & searchParams)715     void findNeighborsWithRemoved(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
716     {
717 
718         int maxChecks = searchParams.checks;
719 
720         if (maxChecks==FLANN_CHECKS_UNLIMITED) {
721             findExactNN<with_removed>(root_, result, vec);
722         }
723         else {
724             // Priority queue storing intermediate branches in the best-bin-first search
725             Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);
726 
727             int checks = 0;
728             findNN<with_removed>(root_, result, vec, checks, maxChecks, heap);
729 
730             BranchSt branch;
731             while (heap->popMin(branch) && (checks<maxChecks || !result.full())) {
732                 NodePtr node = branch.node;
733                 findNN<with_removed>(node, result, vec, checks, maxChecks, heap);
734             }
735 
736             delete heap;
737         }
738 
739     }
740 
741 
742     /**
743      * Performs one descent in the hierarchical k-means tree. The branches not
744      * visited are stored in a priority queue.
745      *
746      * Params:
747      *      node = node to explore
748      *      result = container for the k-nearest neighbors found
749      *      vec = query points
750      *      checks = how many points in the dataset have been checked so far
751      *      maxChecks = maximum dataset points to checks
752      */
753 
754     template<bool with_removed>
findNN(NodePtr node,ResultSet<DistanceType> & result,const ElementType * vec,int & checks,int maxChecks,Heap<BranchSt> * heap)755     void findNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec, int& checks, int maxChecks,
756                 Heap<BranchSt>* heap) const
757     {
758         // Ignore those clusters that are too far away
759         {
760             DistanceType bsq = distance_(vec, node->pivot, veclen_);
761             DistanceType rsq = node->radius;
762             DistanceType wsq = result.worstDist();
763 
764             DistanceType val = bsq-rsq-wsq;
765             DistanceType val2 = val*val-4*rsq*wsq;
766 
767             //if (val>0) {
768             if ((val>0)&&(val2>0)) {
769                 return;
770             }
771         }
772 
773         if (node->childs.empty()) {
774             if (checks>=maxChecks) {
775                 if (result.full()) return;
776             }
777             for (int i=0; i<node->size; ++i) {
778             	PointInfo& point_info = node->points[i];
779                 int index = point_info.index;
780                 if (with_removed) {
781                 	if (removed_points_.test(index)) continue;
782                 }
783                 DistanceType dist = distance_(point_info.point, vec, veclen_);
784                 result.addPoint(dist, index);
785                 ++checks;
786             }
787         }
788         else {
789             int closest_center = exploreNodeBranches(node, vec, heap);
790             findNN<with_removed>(node->childs[closest_center],result,vec, checks, maxChecks, heap);
791         }
792     }
793 
794     /**
795      * Helper function that computes the nearest childs of a node to a given query point.
796      * Params:
797      *     node = the node
798      *     q = the query point
799      *     distances = array with the distances to each child node.
800      * Returns:
801      */
exploreNodeBranches(NodePtr node,const ElementType * q,Heap<BranchSt> * heap)802     int exploreNodeBranches(NodePtr node, const ElementType* q, Heap<BranchSt>* heap) const
803     {
804         std::vector<DistanceType> domain_distances(branching_);
805         int best_index = 0;
806         domain_distances[best_index] = distance_(q, node->childs[best_index]->pivot, veclen_);
807         for (int i=1; i<branching_; ++i) {
808             domain_distances[i] = distance_(q, node->childs[i]->pivot, veclen_);
809             if (domain_distances[i]<domain_distances[best_index]) {
810                 best_index = i;
811             }
812         }
813 
814         //		float* best_center = node->childs[best_index]->pivot;
815         for (int i=0; i<branching_; ++i) {
816             if (i != best_index) {
817                 domain_distances[i] -= cb_index_*node->childs[i]->variance;
818 
819                 //				float dist_to_border = getDistanceToBorder(node.childs[i].pivot,best_center,q);
820                 //				if (domain_distances[i]<dist_to_border) {
821                 //					domain_distances[i] = dist_to_border;
822                 //				}
823                 heap->insert(BranchSt(node->childs[i],domain_distances[i]));
824             }
825         }
826 
827         return best_index;
828     }
829 
830 
831     /**
832      * Function the performs exact nearest neighbor search by traversing the entire tree.
833      */
834     template<bool with_removed>
findExactNN(NodePtr node,ResultSet<DistanceType> & result,const ElementType * vec)835     void findExactNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec) const
836     {
837         // Ignore those clusters that are too far away
838         {
839             DistanceType bsq = distance_(vec, node->pivot, veclen_);
840             DistanceType rsq = node->radius;
841             DistanceType wsq = result.worstDist();
842 
843             DistanceType val = bsq-rsq-wsq;
844             DistanceType val2 = val*val-4*rsq*wsq;
845 
846             //                  if (val>0) {
847             if ((val>0)&&(val2>0)) {
848                 return;
849             }
850         }
851 
852         if (node->childs.empty()) {
853             for (int i=0; i<node->size; ++i) {
854             	PointInfo& point_info = node->points[i];
855                 int index = point_info.index;
856                 if (with_removed) {
857                 	if (removed_points_.test(index)) continue;
858                 }
859                 DistanceType dist = distance_(point_info.point, vec, veclen_);
860                 result.addPoint(dist, index);
861             }
862         }
863         else {
864             std::vector<int> sort_indices(branching_);
865             getCenterOrdering(node, vec, sort_indices);
866 
867             for (int i=0; i<branching_; ++i) {
868                 findExactNN<with_removed>(node->childs[sort_indices[i]],result,vec);
869             }
870 
871         }
872     }
873 
874 
875     /**
876      * Helper function.
877      *
878      * I computes the order in which to traverse the child nodes of a particular node.
879      */
getCenterOrdering(NodePtr node,const ElementType * q,std::vector<int> & sort_indices)880     void getCenterOrdering(NodePtr node, const ElementType* q, std::vector<int>& sort_indices) const
881     {
882         std::vector<DistanceType> domain_distances(branching_);
883         for (int i=0; i<branching_; ++i) {
884             DistanceType dist = distance_(q, node->childs[i]->pivot, veclen_);
885 
886             int j=0;
887             while (domain_distances[j]<dist && j<i) j++;
888             for (int k=i; k>j; --k) {
889                 domain_distances[k] = domain_distances[k-1];
890                 sort_indices[k] = sort_indices[k-1];
891             }
892             domain_distances[j] = dist;
893             sort_indices[j] = i;
894         }
895     }
896 
897     /**
898      * Method that computes the squared distance from the query point q
899      * from inside region with center c to the border between this
900      * region and the region with center p
901      */
getDistanceToBorder(DistanceType * p,DistanceType * c,DistanceType * q)902     DistanceType getDistanceToBorder(DistanceType* p, DistanceType* c, DistanceType* q) const
903     {
904         DistanceType sum = 0;
905         DistanceType sum2 = 0;
906 
907         for (int i=0; i<veclen_; ++i) {
908             DistanceType t = c[i]-p[i];
909             sum += t*(q[i]-(c[i]+p[i])/2);
910             sum2 += t*t;
911         }
912 
913         return sum*sum/sum2;
914     }
915 
916 
917     /**
918      * Helper function the descends in the hierarchical k-means tree by spliting those clusters that minimize
919      * the overall variance of the clustering.
920      * Params:
921      *     root = root node
922      *     clusters = array with clusters centers (return value)
923      *     varianceValue = variance of the clustering (return value)
924      * Returns:
925      */
getMinVarianceClusters(NodePtr root,std::vector<NodePtr> & clusters,int clusters_length,DistanceType & varianceValue)926     int getMinVarianceClusters(NodePtr root, std::vector<NodePtr>& clusters, int clusters_length, DistanceType& varianceValue) const
927     {
928         int clusterCount = 1;
929         clusters[0] = root;
930 
931         DistanceType meanVariance = root->variance*root->size;
932 
933         while (clusterCount<clusters_length) {
934             DistanceType minVariance = (std::numeric_limits<DistanceType>::max)();
935             int splitIndex = -1;
936 
937             for (int i=0; i<clusterCount; ++i) {
938                 if (!clusters[i]->childs.empty()) {
939 
940                     DistanceType variance = meanVariance - clusters[i]->variance*clusters[i]->size;
941 
942                     for (int j=0; j<branching_; ++j) {
943                         variance += clusters[i]->childs[j]->variance*clusters[i]->childs[j]->size;
944                     }
945                     if (variance<minVariance) {
946                         minVariance = variance;
947                         splitIndex = i;
948                     }
949                 }
950             }
951 
952             if (splitIndex==-1) break;
953             if ( (branching_+clusterCount-1) > clusters_length) break;
954 
955             meanVariance = minVariance;
956 
957             // split node
958             NodePtr toSplit = clusters[splitIndex];
959             clusters[splitIndex] = toSplit->childs[0];
960             for (int i=1; i<branching_; ++i) {
961                 clusters[clusterCount++] = toSplit->childs[i];
962             }
963         }
964 
965         varianceValue = meanVariance/root->size;
966         return clusterCount;
967     }
968 
addPointToTree(NodePtr node,size_t index,DistanceType dist_to_pivot)969     void addPointToTree(NodePtr node, size_t index, DistanceType dist_to_pivot)
970     {
971         ElementType* point = points_[index];
972         if (dist_to_pivot>node->radius) {
973             node->radius = dist_to_pivot;
974         }
975         // if radius changed above, the variance will be an approximation
976         node->variance = (node->size*node->variance+dist_to_pivot)/(node->size+1);
977         node->size++;
978 
979         if (node->childs.empty()) { // leaf node
980         	PointInfo point_info;
981         	point_info.index = index;
982         	point_info.point = point;
983         	node->points.push_back(point_info);
984 
985             std::vector<int> indices(node->points.size());
986             for (size_t i=0;i<node->points.size();++i) {
987             	indices[i] = node->points[i].index;
988             }
989             computeNodeStatistics(node, indices);
990             if (indices.size()>=size_t(branching_)) {
991                 computeClustering(node, &indices[0], indices.size(), branching_);
992             }
993         }
994         else {
995             // find the closest child
996             int closest = 0;
997             DistanceType dist = distance_(node->childs[closest]->pivot, point, veclen_);
998             for (size_t i=1;i<size_t(branching_);++i) {
999                 DistanceType crt_dist = distance_(node->childs[i]->pivot, point, veclen_);
1000                 if (crt_dist<dist) {
1001                     dist = crt_dist;
1002                     closest = i;
1003                 }
1004             }
1005             addPointToTree(node->childs[closest], index, dist);
1006         }
1007     }
1008 
1009 
swap(KMeansIndex & other)1010     void swap(KMeansIndex& other)
1011     {
1012     	std::swap(branching_, other.branching_);
1013     	std::swap(iterations_, other.iterations_);
1014     	std::swap(centers_init_, other.centers_init_);
1015     	std::swap(cb_index_, other.cb_index_);
1016     	std::swap(root_, other.root_);
1017     	std::swap(pool_, other.pool_);
1018     	std::swap(memoryCounter_, other.memoryCounter_);
1019     	std::swap(chooseCenters_, other.chooseCenters_);
1020     }
1021 
1022 
1023 private:
1024     /** The branching factor used in the hierarchical k-means clustering */
1025     int branching_;
1026 
1027     /** Maximum number of iterations to use when performing k-means clustering */
1028     int iterations_;
1029 
1030     /** Algorithm for choosing the cluster centers */
1031     flann_centers_init_t centers_init_;
1032 
1033     /**
1034      * Cluster border index. This is used in the tree search phase when determining
1035      * the closest cluster to explore next. A zero value takes into account only
1036      * the cluster centres, a value greater then zero also take into account the size
1037      * of the cluster.
1038      */
1039     float cb_index_;
1040 
1041     /**
1042      * The root node in the tree.
1043      */
1044     NodePtr root_;
1045 
1046     /**
1047      * Pooled memory allocator.
1048      */
1049     PooledAllocator pool_;
1050 
1051     /**
1052      * Memory occupied by the index.
1053      */
1054     int memoryCounter_;
1055 
1056     /**
1057      * Algorithm used to choose initial centers
1058      */
1059     CenterChooser<Distance>* chooseCenters_;
1060 
1061     USING_BASECLASS_SYMBOLS
1062 };
1063 
1064 }
1065 
1066 #endif //FLANN_KMEANS_INDEX_H_
1067