1 /*********************************************************************** 2 * Software License Agreement (BSD License) 3 * 4 * Copyright 2008-2009 Marius Muja (mariusm@cs.ubc.ca). All rights reserved. 5 * Copyright 2008-2009 David G. Lowe (lowe@cs.ubc.ca). All rights reserved. 6 * 7 * THE BSD LICENSE 8 * 9 * Redistribution and use in source and binary forms, with or without 10 * modification, are permitted provided that the following conditions 11 * are met: 12 * 13 * 1. Redistributions of source code must retain the above copyright 14 * notice, this list of conditions and the following disclaimer. 15 * 2. Redistributions in binary form must reproduce the above copyright 16 * notice, this list of conditions and the following disclaimer in the 17 * documentation and/or other materials provided with the distribution. 18 * 19 * THIS SOFTWARE IS PROVIDED BY THE AUTHOR ``AS IS'' AND ANY EXPRESS OR 20 * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES 21 * OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. 22 * IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY DIRECT, INDIRECT, 23 * INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT 24 * NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, 25 * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY 26 * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 27 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF 28 * THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
 *************************************************************************/

#ifndef FLANN_KMEANS_INDEX_H_
#define FLANN_KMEANS_INDEX_H_

#include <algorithm>
#include <string>
#include <map>
#include <cassert>
#include <limits>
#include <cmath>

#include "flann/general.h"
#include "flann/algorithms/nn_index.h"
#include "flann/algorithms/dist.h"
#include "flann/algorithms/center_chooser.h"
#include "flann/util/matrix.h"
#include "flann/util/result_set.h"
#include "flann/util/heap.h"
#include "flann/util/allocator.h"
#include "flann/util/random.h"
#include "flann/util/saving.h"
#include "flann/util/logger.h"



namespace flann
{

/**
 * Parameter set for building a hierarchical k-means index.
 * All values are stored into the underlying IndexParams map.
 */
struct KMeansIndexParams : public IndexParams
{
    KMeansIndexParams(int branching = 32, int iterations = 11,
                      flann_centers_init_t centers_init = FLANN_CENTERS_RANDOM, float cb_index = 0.2 )
    {
        (*this)["algorithm"] = FLANN_INDEX_KMEANS;
        // branching factor
        (*this)["branching"] = branching;
        // max iterations to perform in one kmeans clustering (kmeans tree)
        // (negative values are interpreted by the index as "until convergence")
        (*this)["iterations"] = iterations;
        // algorithm used for picking the initial cluster centers for kmeans tree
        (*this)["centers_init"] = centers_init;
        // cluster boundary index. Used when searching the kmeans tree
        (*this)["cb_index"] = cb_index;
    }
};


/**
 * Hierarchical kmeans index
 *
 * Contains a tree constructed through a hierarchical kmeans clustering
 * and other information for indexing a set of points for nearest-neighbour matching.
 */
template <typename Distance>
class KMeansIndex : public NNIndex<Distance>
{
public:
    typedef typename Distance::ElementType ElementType;
    typedef typename Distance::ResultType DistanceType;

    typedef NNIndex<Distance> BaseClass;

    // Trait marker: this index only makes sense for true vector-space distances
    // (cluster centers are computed as coordinate-wise means).
    typedef bool needs_vector_space_distance;



    flann_algorithm_t getType() const
    {
        return FLANN_INDEX_KMEANS;
    }

    /**
     * Index constructor
     *
     * Params:
     *          inputData = dataset with the input features
     *          params = parameters passed to the hierarchical k-means algorithm
     */
    KMeansIndex(const Matrix<ElementType>& inputData, const IndexParams& params = KMeansIndexParams(),
                Distance d = Distance())
        : BaseClass(params,d), root_(NULL), memoryCounter_(0)
    {
        branching_ = get_param(params,"branching",32);
        iterations_ = get_param(params,"iterations",11);
        if (iterations_<0) {
            // a negative iteration count means "iterate until convergence"
            iterations_ = (std::numeric_limits<int>::max)();
        }
        centers_init_ = get_param(params,"centers_init",FLANN_CENTERS_RANDOM);
        cb_index_ = get_param(params,"cb_index",0.4f);

        initCenterChooser();
        chooseCenters_->setDataset(inputData);

        setDataset(inputData);
    }


    /**
     * Index constructor
     *
     * Params:
     *          inputData = dataset with the input features
     *          params = parameters passed to the hierarchical k-means algorithm
     */
    KMeansIndex(const IndexParams& params = KMeansIndexParams(), Distance d = Distance())
        : BaseClass(params, d), root_(NULL), memoryCounter_(0)
    {
        branching_ = get_param(params,"branching",32);
        iterations_ = get_param(params,"iterations",11);
        if (iterations_<0) {
            // a negative iteration count means "iterate until convergence"
            iterations_ = (std::numeric_limits<int>::max)();
        }
        centers_init_ = get_param(params,"centers_init",FLANN_CENTERS_RANDOM);
        cb_index_ = get_param(params,"cb_index",0.4f);

        initCenterChooser();
    }


    /**
     * Copy constructor: deep-copies the whole tree (pivots included).
     */
    KMeansIndex(const KMeansIndex& other) : BaseClass(other),
        branching_(other.branching_),
        iterations_(other.iterations_),
        centers_init_(other.centers_init_),
        cb_index_(other.cb_index_),
        memoryCounter_(other.memoryCounter_)
    {
        initCenterChooser();

        copyTree(root_, other.root_);
    }

    // Copy-and-swap assignment: parameter is taken by value on purpose.
    KMeansIndex& operator=(KMeansIndex other)
    {
        this->swap(other);
        return *this;
    }


    /**
     * Instantiates the center chooser matching centers_init_.
     * Throws FLANNException for an unrecognized value.
     */
    void initCenterChooser()
    {
        switch(centers_init_) {
        case FLANN_CENTERS_RANDOM:
            chooseCenters_ = new RandomCenterChooser<Distance>(distance_);
            break;
        case FLANN_CENTERS_GONZALES:
            chooseCenters_ = new GonzalesCenterChooser<Distance>(distance_);
            break;
        case FLANN_CENTERS_KMEANSPP:
            chooseCenters_ = new KMeansppCenterChooser<Distance>(distance_);
            break;
        default:
            throw FLANNException("Unknown algorithm for choosing initial centers.");
        }
    }

    /**
     * Index destructor.
     *
     * Release the memory used by the index.
     */
    virtual ~KMeansIndex()
    {
        delete chooseCenters_;
        freeIndex();
    }

    BaseClass* clone() const
    {
        return new KMeansIndex(*this);
    }


    void set_cb_index( float index)
    {
        cb_index_ = index;
    }

    /**
     * Computes the index memory usage
     * Returns: memory used by the index
     */
    int usedMemory() const
    {
        return pool_.usedMemory+pool_.wastedMemory+memoryCounter_;
    }

    using BaseClass::buildIndex;

    /**
     * Incrementally adds points. If the dataset grew by more than
     * rebuild_threshold since the last build the whole tree is rebuilt,
     * otherwise each point is pushed down the existing tree.
     */
    void addPoints(const Matrix<ElementType>& points, float rebuild_threshold = 2)
    {
        assert(points.cols==veclen_);
        size_t old_size = size_;

        extendDataset(points);

        if (rebuild_threshold>1 && size_at_build_*rebuild_threshold<size_) {
            buildIndex();
        }
        else {
            for (size_t i=0;i<points.rows;++i) {
                DistanceType dist = distance_(root_->pivot, points[i], veclen_);
                addPointToTree(root_, old_size + i, dist);
            }
        }
    }

    // Serializes/deserializes the index (direction depends on the Archive type).
    template<typename Archive>
    void serialize(Archive& ar)
    {
        ar.setObject(this);

        ar & *static_cast<NNIndex<Distance>*>(this);

        ar & branching_;
        ar & iterations_;
        ar & memoryCounter_;
        ar & cb_index_;
        ar & centers_init_;

        if (Archive::is_loading::value) {
            root_ = new(pool_) Node();
        }
        ar & *root_;

        if (Archive::is_loading::value) {
            index_params_["algorithm"] = getType();
            index_params_["branching"] = branching_;
            index_params_["iterations"] = iterations_;
            index_params_["centers_init"] = centers_init_;
            index_params_["cb_index"] = cb_index_;
        }
    }

    void saveIndex(FILE* stream)
    {
        serialization::SaveArchive sa(stream);
        sa & *this;
    }

    void loadIndex(FILE* stream)
    {
        freeIndex();
        serialization::LoadArchive la(stream);
        la & *this;
    }

    /**
     * Find set of nearest neighbors to vec. Their indices are stored inside
     * the result object.
     *
     * Params:
     *     result = the result object in which the indices of the nearest-neighbors are stored
     *     vec = the vector for which to search the nearest neighbors
     *     searchParams = parameters that influence the search algorithm (checks, cb_index)
     */

    void findNeighbors(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
    {
        if (removed_) {
            findNeighborsWithRemoved<true>(result, vec, searchParams);
        }
        else {
            findNeighborsWithRemoved<false>(result, vec, searchParams);
        }

    }

    /**
     * Clustering function that takes a cut in the hierarchical k-means
     * tree and return the clusters centers of that clustering.
     * Params:
     *     numClusters = number of clusters to have in the clustering computed
     * Returns: number of cluster centers
     */
    int getClusterCenters(Matrix<DistanceType>& centers)
    {
        int numClusters = centers.rows;
        if (numClusters<1) {
            throw FLANNException("Number of clusters must be at least 1");
        }

        DistanceType variance;
        std::vector<NodePtr> clusters(numClusters);

        int clusterCount = getMinVarianceClusters(root_, clusters, numClusters, variance);

        Logger::info("Clusters requested: %d, returning %d\n",numClusters, clusterCount);

        for (int i=0; i<clusterCount; ++i) {
            DistanceType* center = clusters[i]->pivot;
            for (size_t j=0; j<veclen_; ++j) {
                centers[i][j] = center[j];
            }
        }

        return clusterCount;
    }

protected:
    /**
     * Builds the index
     */
    void buildIndexImpl()
    {
        if (branching_<2) {
            throw FLANNException("Branching factor must be at least 2");
        }

        std::vector<int> indices(size_);
        for (size_t i=0; i<size_; ++i) {
            indices[i] = int(i);
        }

        root_ = new(pool_) Node();
        computeNodeStatistics(root_, indices);
        computeClustering(root_, &indices[0], (int)size_, branching_);
    }

private:

    // Index + pointer pair for a point stored in a leaf node.
    struct PointInfo
    {
        size_t index;
        ElementType* point;
private:
        template<typename Archive>
        void serialize(Archive& ar)
        {
            typedef KMeansIndex<Distance> Index;
            Index* obj = static_cast<Index*>(ar.getObject());

            ar & index;
            //	ar & point;

            // the pointer is not serialized; re-derive it from the dataset on load
            if (Archive::is_loading::value) point = obj->points_[index];
        }
        friend struct serialization::access;
    };

    /**
     * Structure representing a node in the hierarchical k-means tree.
     */
    struct Node
    {
        /**
         * The cluster center.
         */
        DistanceType* pivot;
        /**
         * The cluster radius.
         */
        DistanceType radius;
        /**
         * The cluster variance.
         */
        DistanceType variance;
        /**
         * The cluster size (number of points in the cluster)
         */
        int size;
        /**
         * Child nodes (only for non-terminal nodes)
         */
        std::vector<Node*> childs;
        /**
         * Node points (only for terminal nodes)
         */
        std::vector<PointInfo> points;
        /**
         * Level
         */
        //		int level;

        // Nodes live in the pooled allocator, so only the destructor is run
        // explicitly; the pivot array is heap-allocated and must be freed here.
        ~Node()
        {
            delete[] pivot;
            if (!childs.empty()) {
                for (size_t i=0; i<childs.size(); ++i) {
                    childs[i]->~Node();
                }
            }
        }

        template<typename Archive>
        void serialize(Archive& ar)
        {
            typedef KMeansIndex<Distance> Index;
            Index* obj = static_cast<Index*>(ar.getObject());

            if (Archive::is_loading::value) {
                pivot = new DistanceType[obj->veclen_];
            }
            ar & serialization::make_binary_object(pivot, obj->veclen_*sizeof(DistanceType));
            ar & radius;
            ar & variance;
            ar & size;

            size_t childs_size;
            if (Archive::is_saving::value) {
                childs_size = childs.size();
            }
            ar & childs_size;

            if (childs_size==0) {
                ar & points;
            }
            else {
                if (Archive::is_loading::value) {
                    childs.resize(childs_size);
                }
                for (size_t i=0;i<childs_size;++i) {
                    if (Archive::is_loading::value) {
                        childs[i] = new(obj->pool_) Node();
                    }
                    ar & *childs[i];
                }
            }
        }
        friend struct serialization::access;
    };
    typedef Node* NodePtr;

    /**
     * Alias definition for a nicer syntax.
     */
    typedef BranchStruct<NodePtr, DistanceType> BranchSt;


    /**
     * Helper function
     */
    void freeIndex()
    {
        // explicit destructor call: node memory itself belongs to pool_
        if (root_) root_->~Node();
        root_ = NULL;
        pool_.free();
    }

    // Recursively deep-copies the tree rooted at src into dst.
    void copyTree(NodePtr& dst, const NodePtr& src)
    {
        dst = new(pool_) Node();
        dst->pivot = new DistanceType[veclen_];
        std::copy(src->pivot, src->pivot+veclen_, dst->pivot);
        dst->radius = src->radius;
        dst->variance = src->variance;
        dst->size = src->size;

        if (src->childs.size()==0) {
            dst->points = src->points;
        }
        else {
            dst->childs.resize(src->childs.size());
            for (size_t i=0;i<src->childs.size();++i) {
                copyTree(dst->childs[i], src->childs[i]);
            }
        }
    }


    /**
     * Computes the statistics of a node (mean, radius, variance).
     *
     * Params:
     *     node = the node to use
     *     indices = the indices of the points belonging to the node
     */
    void computeNodeStatistics(NodePtr node, const std::vector<int>& indices)
    {
        size_t size = indices.size();

        DistanceType* mean = new DistanceType[veclen_];
        memoryCounter_ += int(veclen_*sizeof(DistanceType));
        memset(mean,0,veclen_*sizeof(DistanceType));

        for (size_t i=0; i<size; ++i) {
            ElementType* vec = points_[indices[i]];
            for (size_t j=0; j<veclen_; ++j) {
                mean[j] += vec[j];
            }
        }
        DistanceType div_factor = DistanceType(1)/size;
        for (size_t j=0; j<veclen_; ++j) {
            mean[j] *= div_factor;
        }

        DistanceType radius = 0;
        DistanceType variance = 0;
        for (size_t i=0; i<size; ++i) {
            DistanceType dist = distance_(mean, points_[indices[i]], veclen_);
            if (dist>radius) {
                radius = dist;
            }
            variance += dist;
        }
        variance /= size;

        node->variance = variance;
        node->radius = radius;
        node->pivot = mean;
    }


    /**
     * The method responsible with actually doing the recursive hierarchical
     * clustering
     *
     * Params:
     *     node = the node to cluster
     *     indices = indices of the points belonging to the current node
     *     branching = the branching factor to use in the clustering
     *
     * TODO: for 1-sized clusters don't store a cluster center (it's the same as the single cluster point)
     */
    void computeClustering(NodePtr node, int* indices, int indices_length, int branching)
    {
        node->size = indices_length;

        // too few points to split further: make this a leaf
        if (indices_length < branching) {
            node->points.resize(indices_length);
            for (int i=0;i<indices_length;++i) {
                node->points[i].index = indices[i];
                node->points[i].point = points_[indices[i]];
            }
            node->childs.clear();
            return;
        }

        std::vector<int> centers_idx(branching);
        int centers_length;
        (*chooseCenters_)(branching, indices, indices_length, &centers_idx[0], centers_length);

        // the chooser could not produce enough distinct centers: make a leaf
        if (centers_length<branching) {
            node->points.resize(indices_length);
            for (int i=0;i<indices_length;++i) {
                node->points[i].index = indices[i];
                node->points[i].point = points_[indices[i]];
            }
            node->childs.clear();
            return;
        }


        // centers are accumulated in double precision to limit rounding error
        Matrix<double> dcenters(new double[branching*veclen_],branching,veclen_);
        for (int i=0; i<centers_length; ++i) {
            ElementType* vec = points_[centers_idx[i]];
            for (size_t k=0; k<veclen_; ++k) {
                dcenters[i][k] = double(vec[k]);
            }
        }

        std::vector<DistanceType> radiuses(branching,0);
        std::vector<int> count(branching,0);

        //	assign points to clusters
        std::vector<int> belongs_to(indices_length);
        for (int i=0; i<indices_length; ++i) {

            DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
            belongs_to[i] = 0;
            for (int j=1; j<branching; ++j) {
                DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
                if (sq_dist>new_sq_dist) {
                    belongs_to[i] = j;
                    sq_dist = new_sq_dist;
                }
            }
            if (sq_dist>radiuses[belongs_to[i]]) {
                radiuses[belongs_to[i]] = sq_dist;
            }
            count[belongs_to[i]]++;
        }

        bool converged = false;
        int iteration = 0;
        while (!converged && iteration<iterations_) {
            converged = true;
            iteration++;

            // compute the new cluster centers
            for (int i=0; i<branching; ++i) {
                memset(dcenters[i],0,sizeof(double)*veclen_);
                radiuses[i] = 0;
            }
            for (int i=0; i<indices_length; ++i) {
                ElementType* vec = points_[indices[i]];
                double* center = dcenters[belongs_to[i]];
                for (size_t k=0; k<veclen_; ++k) {
                    center[k] += vec[k];
                }
            }
            for (int i=0; i<branching; ++i) {
                int cnt = count[i];
                double div_factor = 1.0/cnt;
                for (size_t k=0; k<veclen_; ++k) {
                    dcenters[i][k] *= div_factor;
                }
            }

            // reassign points to clusters
            for (int i=0; i<indices_length; ++i) {
                DistanceType sq_dist = distance_(points_[indices[i]], dcenters[0], veclen_);
                int new_centroid = 0;
                for (int j=1; j<branching; ++j) {
                    DistanceType new_sq_dist = distance_(points_[indices[i]], dcenters[j], veclen_);
                    if (sq_dist>new_sq_dist) {
                        new_centroid = j;
                        sq_dist = new_sq_dist;
                    }
                }
                if (sq_dist>radiuses[new_centroid]) {
                    radiuses[new_centroid] = sq_dist;
                }
                if (new_centroid != belongs_to[i]) {
                    count[belongs_to[i]]--;
                    count[new_centroid]++;
                    belongs_to[i] = new_centroid;

                    converged = false;
                }
            }

            for (int i=0; i<branching; ++i) {
                // if one cluster converges to an empty cluster,
                // move an element into that cluster
                if (count[i]==0) {
                    int j = (i+1)%branching;
                    // find a donor cluster with more than one member
                    while (count[j]<=1) {
                        j = (j+1)%branching;
                    }

                    for (int k=0; k<indices_length; ++k) {
                        if (belongs_to[k]==j) {
                            belongs_to[k] = i;
                            count[j]--;
                            count[i]++;
                            break;
                        }
                    }
                    converged = false;
                }
            }

        }

        // convert the double-precision centers to DistanceType pivots
        std::vector<DistanceType*> centers(branching);

        for (int i=0; i<branching; ++i) {
            centers[i] = new DistanceType[veclen_];
            memoryCounter_ += veclen_*sizeof(DistanceType);
            for (size_t k=0; k<veclen_; ++k) {
                centers[i][k] = (DistanceType)dcenters[i][k];
            }
        }


        // compute kmeans clustering for each of the resulting clusters
        // (indices is partitioned in place so each child gets a contiguous range)
        node->childs.resize(branching);
        int start = 0;
        int end = start;
        for (int c=0; c<branching; ++c) {
            int s = count[c];

            DistanceType variance = 0;
            for (int i=0; i<indices_length; ++i) {
                if (belongs_to[i]==c) {
                    variance += distance_(centers[c], points_[indices[i]], veclen_);
                    std::swap(indices[i],indices[end]);
                    std::swap(belongs_to[i],belongs_to[end]);
                    end++;
                }
            }
            variance /= s;

            node->childs[c] = new(pool_) Node();
            node->childs[c]->radius = radiuses[c];
            node->childs[c]->pivot = centers[c];
            node->childs[c]->variance = variance;
            computeClustering(node->childs[c],indices+start, end-start, branching);
            start=end;
        }

        delete[] dcenters.ptr();
    }


    // Dispatches between unlimited (exact) and bounded (best-bin-first) search.
    template<bool with_removed>
    void findNeighborsWithRemoved(ResultSet<DistanceType>& result, const ElementType* vec, const SearchParams& searchParams) const
    {

        int maxChecks = searchParams.checks;

        if (maxChecks==FLANN_CHECKS_UNLIMITED) {
            findExactNN<with_removed>(root_, result, vec);
        }
        else {
            // Priority queue storing intermediate branches in the best-bin-first search
            Heap<BranchSt>* heap = new Heap<BranchSt>((int)size_);

            int checks = 0;
            findNN<with_removed>(root_, result, vec, checks, maxChecks, heap);

            BranchSt branch;
            while (heap->popMin(branch) && (checks<maxChecks || !result.full())) {
                NodePtr node = branch.node;
                findNN<with_removed>(node, result, vec, checks, maxChecks, heap);
            }

            delete heap;
        }

    }


    /**
     * Performs one descent in the hierarchical k-means tree. The branches not
     * visited are stored in a priority queue.
     *
     * Params:
     *     node = node to explore
     *     result = container for the k-nearest neighbors found
     *     vec = query points
     *     checks = how many points in the dataset have been checked so far
     *     maxChecks = maximum dataset points to check
     */

    template<bool with_removed>
    void findNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec, int& checks, int maxChecks,
                Heap<BranchSt>* heap) const
    {
        // Ignore those clusters that are too far away
        {
            DistanceType bsq = distance_(vec, node->pivot, veclen_);
            DistanceType rsq = node->radius;
            DistanceType wsq = result.worstDist();

            // pruning test for squared distances (see the FLANN paper);
            // val2 guards against the sign ambiguity of the squared form
            DistanceType val = bsq-rsq-wsq;
            DistanceType val2 = val*val-4*rsq*wsq;

            //if (val>0) {
            if ((val>0)&&(val2>0)) {
                return;
            }
        }

        if (node->childs.empty()) {
            if (checks>=maxChecks) {
                if (result.full()) return;
            }
            for (int i=0; i<node->size; ++i) {
                PointInfo& point_info = node->points[i];
                int index = point_info.index;
                if (with_removed) {
                    if (removed_points_.test(index)) continue;
                }
                DistanceType dist = distance_(point_info.point, vec, veclen_);
                result.addPoint(dist, index);
                ++checks;
            }
        }
        else {
            // descend into the closest child; the rest go onto the heap
            int closest_center = exploreNodeBranches(node, vec, heap);
            findNN<with_removed>(node->childs[closest_center],result,vec, checks, maxChecks, heap);
        }
    }

    /**
     * Helper function that computes the nearest childs of a node to a given query point.
     * Params:
     *     node = the node
     *     q = the query point
     *     distances = array with the distances to each child node.
     * Returns:
     */
    int exploreNodeBranches(NodePtr node, const ElementType* q, Heap<BranchSt>* heap) const
    {
        std::vector<DistanceType> domain_distances(branching_);
        int best_index = 0;
        domain_distances[best_index] = distance_(q, node->childs[best_index]->pivot, veclen_);
        for (int i=1; i<branching_; ++i) {
            domain_distances[i] = distance_(q, node->childs[i]->pivot, veclen_);
            if (domain_distances[i]<domain_distances[best_index]) {
                best_index = i;
            }
        }

        //		float* best_center = node->childs[best_index]->pivot;
        for (int i=0; i<branching_; ++i) {
            if (i != best_index) {
                // cb_index_ discounts large clusters so they are explored earlier
                domain_distances[i] -= cb_index_*node->childs[i]->variance;

                //		float dist_to_border = getDistanceToBorder(node.childs[i].pivot,best_center,q);
                //		if (domain_distances[i]<dist_to_border) {
                //			domain_distances[i] = dist_to_border;
                //		}
                heap->insert(BranchSt(node->childs[i],domain_distances[i]));
            }
        }

        return best_index;
    }


    /**
     * Function that performs exact nearest neighbor search by traversing the entire tree.
     */
    template<bool with_removed>
    void findExactNN(NodePtr node, ResultSet<DistanceType>& result, const ElementType* vec) const
    {
        // Ignore those clusters that are too far away
        {
            DistanceType bsq = distance_(vec, node->pivot, veclen_);
            DistanceType rsq = node->radius;
            DistanceType wsq = result.worstDist();

            DistanceType val = bsq-rsq-wsq;
            DistanceType val2 = val*val-4*rsq*wsq;

            //                  if (val>0) {
            if ((val>0)&&(val2>0)) {
                return;
            }
        }

        if (node->childs.empty()) {
            for (int i=0; i<node->size; ++i) {
                PointInfo& point_info = node->points[i];
                int index = point_info.index;
                if (with_removed) {
                    if (removed_points_.test(index)) continue;
                }
                DistanceType dist = distance_(point_info.point, vec, veclen_);
                result.addPoint(dist, index);
            }
        }
        else {
            // visit children nearest-first to tighten worstDist early
            std::vector<int> sort_indices(branching_);
            getCenterOrdering(node, vec, sort_indices);

            for (int i=0; i<branching_; ++i) {
                findExactNN<with_removed>(node->childs[sort_indices[i]],result,vec);
            }

        }
    }


    /**
     * Helper function.
     *
     * It computes the order in which to traverse the child nodes of a particular node.
     */
    void getCenterOrdering(NodePtr node, const ElementType* q, std::vector<int>& sort_indices) const
    {
        // insertion sort of the children by distance to the query
        std::vector<DistanceType> domain_distances(branching_);
        for (int i=0; i<branching_; ++i) {
            DistanceType dist = distance_(q, node->childs[i]->pivot, veclen_);

            int j=0;
            while (domain_distances[j]<dist && j<i) j++;
            for (int k=i; k>j; --k) {
                domain_distances[k] = domain_distances[k-1];
                sort_indices[k] = sort_indices[k-1];
            }
            domain_distances[j] = dist;
            sort_indices[j] = i;
        }
    }

    /**
     * Method that computes the squared distance from the query point q
     * from inside region with center c to the border between this
     * region and the region with center p
     */
    DistanceType getDistanceToBorder(DistanceType* p, DistanceType* c, DistanceType* q) const
    {
        DistanceType sum = 0;
        DistanceType sum2 = 0;

        for (int i=0; i<veclen_; ++i) {
            DistanceType t = c[i]-p[i];
            sum += t*(q[i]-(c[i]+p[i])/2);
            sum2 += t*t;
        }

        return sum*sum/sum2;
    }


    /**
     * Helper function that descends in the hierarchical k-means tree by splitting those clusters that minimize
     * the overall variance of the clustering.
     * Params:
     *     root = root node
     *     clusters = array with clusters centers (return value)
     *     varianceValue = variance of the clustering (return value)
     * Returns:
     */
    int getMinVarianceClusters(NodePtr root, std::vector<NodePtr>& clusters, int clusters_length, DistanceType& varianceValue) const
    {
        int clusterCount = 1;
        clusters[0] = root;

        DistanceType meanVariance = root->variance*root->size;

        while (clusterCount<clusters_length) {
            DistanceType minVariance = (std::numeric_limits<DistanceType>::max)();
            int splitIndex = -1;

            for (int i=0; i<clusterCount; ++i) {
                if (!clusters[i]->childs.empty()) {

                    // variance of the clustering if cluster i were split
                    DistanceType variance = meanVariance - clusters[i]->variance*clusters[i]->size;

                    for (int j=0; j<branching_; ++j) {
                        variance += clusters[i]->childs[j]->variance*clusters[i]->childs[j]->size;
                    }
                    if (variance<minVariance) {
                        minVariance = variance;
                        splitIndex = i;
                    }
                }
            }

            if (splitIndex==-1) break;
            // splitting would exceed the requested number of clusters
            if ( (branching_+clusterCount-1) > clusters_length) break;

            meanVariance = minVariance;

            // split node
            NodePtr toSplit = clusters[splitIndex];
            clusters[splitIndex] = toSplit->childs[0];
            for (int i=1; i<branching_; ++i) {
                clusters[clusterCount++] = toSplit->childs[i];
            }
        }

        varianceValue = meanVariance/root->size;
        return clusterCount;
    }

    // Pushes a single new point down the tree, updating node statistics and
    // re-clustering a leaf when it grows to the branching factor.
    void addPointToTree(NodePtr node, size_t index, DistanceType dist_to_pivot)
    {
        ElementType* point = points_[index];
        if (dist_to_pivot>node->radius) {
            node->radius = dist_to_pivot;
        }
        // if radius changed above, the variance will be an approximation
        node->variance = (node->size*node->variance+dist_to_pivot)/(node->size+1);
        node->size++;

        if (node->childs.empty()) { // leaf node
            PointInfo point_info;
            point_info.index = index;
            point_info.point = point;
            node->points.push_back(point_info);

            std::vector<int> indices(node->points.size());
            for (size_t i=0;i<node->points.size();++i) {
                indices[i] = node->points[i].index;
            }
            computeNodeStatistics(node, indices);
            if (indices.size()>=size_t(branching_)) {
                computeClustering(node, &indices[0], indices.size(), branching_);
            }
        }
        else {
            // find the closest child
            int closest = 0;
            DistanceType dist = distance_(node->childs[closest]->pivot, point, veclen_);
            for (size_t i=1;i<size_t(branching_);++i) {
                DistanceType crt_dist = distance_(node->childs[i]->pivot, point, veclen_);
                if (crt_dist<dist) {
                    dist = crt_dist;
                    closest = i;
                }
            }
            addPointToTree(node->childs[closest], index, dist);
        }
    }


    // Member-wise swap used by the copy-and-swap assignment operator.
    void swap(KMeansIndex& other)
    {
        std::swap(branching_, other.branching_);
        std::swap(iterations_, other.iterations_);
        std::swap(centers_init_, other.centers_init_);
        std::swap(cb_index_, other.cb_index_);
        std::swap(root_, other.root_);
        std::swap(pool_, other.pool_);
        std::swap(memoryCounter_, other.memoryCounter_);
        std::swap(chooseCenters_, other.chooseCenters_);
    }


private:
    /** The branching factor used in the hierarchical k-means clustering */
    int branching_;

    /** Maximum number of iterations to use when performing k-means clustering */
    int iterations_;

    /** Algorithm for choosing the cluster centers */
    flann_centers_init_t centers_init_;

    /**
     * Cluster border index. This is used in the tree search phase when determining
     * the closest cluster to explore next. A zero value takes into account only
     * the cluster centres, a value greater then zero also take into account the size
     * of the cluster.
     */
    float cb_index_;

    /**
     * The root node in the tree.
     */
    NodePtr root_;

    /**
     * Pooled memory allocator.
     */
    PooledAllocator pool_;

    /**
     * Memory occupied by the index.
     */
    int memoryCounter_;

    /**
     * Algorithm used to choose initial centers
     */
    CenterChooser<Distance>* chooseCenters_;

    USING_BASECLASS_SYMBOLS
};

}

#endif //FLANN_KMEANS_INDEX_H_