1 // Copyright (c) 2018, ETH Zurich and UNC Chapel Hill.
2 // All rights reserved.
3 //
4 // Redistribution and use in source and binary forms, with or without
5 // modification, are permitted provided that the following conditions are met:
6 //
7 //     * Redistributions of source code must retain the above copyright
8 //       notice, this list of conditions and the following disclaimer.
9 //
10 //     * Redistributions in binary form must reproduce the above copyright
11 //       notice, this list of conditions and the following disclaimer in the
12 //       documentation and/or other materials provided with the distribution.
13 //
14 //     * Neither the name of ETH Zurich and UNC Chapel Hill nor the names of
15 //       its contributors may be used to endorse or promote products derived
16 //       from this software without specific prior written permission.
17 //
18 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
19 // AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
20 // IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
21 // ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDERS OR CONTRIBUTORS BE
22 // LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
23 // CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
24 // SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
25 // INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
26 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
27 // ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
28 // POSSIBILITY OF SUCH DAMAGE.
29 //
30 // Author: Johannes L. Schoenberger (jsch-at-demuc-dot-de)
31 
32 #ifndef COLMAP_SRC_RETRIEVAL_VISUAL_INDEX_H_
33 #define COLMAP_SRC_RETRIEVAL_VISUAL_INDEX_H_
34 
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <fstream>
#include <functional>
#include <string>
#include <thread>
#include <type_traits>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include <boost/heap/fibonacci_heap.hpp>
#include <Eigen/Core>

#include "FLANN/flann.hpp"
#include "feature/types.h"
#include "retrieval/inverted_file.h"
#include "retrieval/inverted_index.h"
#include "retrieval/vote_and_verify.h"
#include "util/alignment.h"
#include "util/endian.h"
#include "util/logging.h"
#include "util/math.h"
47 
48 namespace colmap {
49 namespace retrieval {
50 
51 // Visual index for image retrieval using a vocabulary tree with Hamming
52 // embedding, based on the papers:
53 //
54 //    Schönberger, Price, Sattler, Pollefeys, Frahm. "A Vote-and-Verify Strategy
55 //    for Fast Spatial Verification in Image Retrieval". ACCV 2016.
56 //
57 //    Arandjelovic, Zisserman: Scalable descriptor
58 //    distinctiveness for location recognition. ACCV 2014.
59 template <typename kDescType = uint8_t, int kDescDim = 128,
60           int kEmbeddingDim = 64>
61 class VisualIndex {
62  public:
63   static const int kMaxNumThreads = -1;
64   typedef InvertedIndex<kDescType, kDescDim, kEmbeddingDim> InvertedIndexType;
65   typedef FeatureKeypoints GeomType;
66   typedef typename InvertedIndexType::DescType DescType;
67   typedef typename InvertedIndexType::EntryType EntryType;
68 
69   struct IndexOptions {
70     // The number of nearest neighbor visual words that each feature descriptor
71     // is assigned to.
72     int num_neighbors = 1;
73 
74     // The number of checks in the nearest neighbor search.
75     int num_checks = 256;
76 
77     // The number of threads used in the index.
78     int num_threads = kMaxNumThreads;
79   };
80 
81   struct QueryOptions {
82     // The maximum number of most similar images to retrieve.
83     int max_num_images = -1;
84 
85     // The number of nearest neighbor visual words that each feature descriptor
86     // is assigned to.
87     int num_neighbors = 5;
88 
89     // The number of checks in the nearest neighbor search.
90     int num_checks = 256;
91 
92     // Whether to perform spatial verification after image retrieval.
93     int num_images_after_verification = 0;
94 
95     // The number of threads used in the index.
96     int num_threads = kMaxNumThreads;
97   };
98 
99   struct BuildOptions {
100     // The desired number of visual words, i.e. the number of leaf node
101     // clusters. Note that the actual number of visual words might be less.
102     int num_visual_words = 256 * 256;
103 
104     // The branching factor of the hierarchical k-means tree.
105     int branching = 256;
106 
107     // The number of iterations for the clustering.
108     int num_iterations = 11;
109 
110     // The target precision of the visual word search index.
111     double target_precision = 0.95;
112 
113     // The number of checks in the nearest neighbor search.
114     int num_checks = 256;
115 
116     // The number of threads used in the index.
117     int num_threads = kMaxNumThreads;
118   };
119 
120   VisualIndex();
121   ~VisualIndex();
122 
123   size_t NumVisualWords() const;
124 
125   // Add image to the visual index.
126   void Add(const IndexOptions& options, const int image_id,
127            const GeomType& geometries, const DescType& descriptors);
128 
129   // Check if an image has been indexed.
130   bool ImageIndexed(const int image_id) const;
131 
132   // Query for most similar images in the visual index.
133   void Query(const QueryOptions& options, const DescType& descriptors,
134              std::vector<ImageScore>* image_scores) const;
135 
136   // Query for most similar images in the visual index.
137   void Query(const QueryOptions& options, const GeomType& geometries,
138              const DescType& descriptors,
139              std::vector<ImageScore>* image_scores) const;
140 
141   // Prepare the index after adding images and before querying.
142   void Prepare();
143 
144   // Build a visual index from a set of training descriptors by quantizing the
145   // descriptor space into visual words and compute their Hamming embedding.
146   void Build(const BuildOptions& options, const DescType& descriptors);
147 
148   // Read and write the visual index. This can be done for an index with and
149   // without indexed images.
150   void Read(const std::string& path);
151   void Write(const std::string& path);
152 
153  private:
154   // Quantize the descriptor space into visual words.
155   void Quantize(const BuildOptions& options, const DescType& descriptors);
156 
157   // Query for nearest neighbor images and return nearest neighbor visual word
158   // identifiers for each descriptor.
159   void QueryAndFindWordIds(const QueryOptions& options,
160                            const DescType& descriptors,
161                            std::vector<ImageScore>* image_scores,
162                            Eigen::MatrixXi* word_ids) const;
163 
164   // Find the nearest neighbor visual words for the given descriptors.
165   Eigen::MatrixXi FindWordIds(const DescType& descriptors,
166                               const int num_neighbors, const int num_checks,
167                               const int num_threads) const;
168 
169   // The search structure on the quantized descriptor space.
170   flann::AutotunedIndex<flann::L2<kDescType>> visual_word_index_;
171 
172   // The centroids of the visual words.
173   flann::Matrix<kDescType> visual_words_;
174 
175   // The inverted index of the database.
176   InvertedIndexType inverted_index_;
177 
178   // Identifiers of all indexed images.
179   std::unordered_set<int> image_ids_;
180 
181   // Whether the index is prepared.
182   bool prepared_;
183 };
184 
185 ////////////////////////////////////////////////////////////////////////////////
186 // Implementation
187 ////////////////////////////////////////////////////////////////////////////////
188 
189 template <typename kDescType, int kDescDim, int kEmbeddingDim>
VisualIndex()190 VisualIndex<kDescType, kDescDim, kEmbeddingDim>::VisualIndex()
191     : prepared_(false) {}
192 
193 template <typename kDescType, int kDescDim, int kEmbeddingDim>
~VisualIndex()194 VisualIndex<kDescType, kDescDim, kEmbeddingDim>::~VisualIndex() {
195   if (visual_words_.ptr() != nullptr) {
196     delete[] visual_words_.ptr();
197   }
198 }
199 
200 template <typename kDescType, int kDescDim, int kEmbeddingDim>
NumVisualWords()201 size_t VisualIndex<kDescType, kDescDim, kEmbeddingDim>::NumVisualWords() const {
202   return visual_words_.rows;
203 }
204 
205 template <typename kDescType, int kDescDim, int kEmbeddingDim>
Add(const IndexOptions & options,const int image_id,const GeomType & geometries,const DescType & descriptors)206 void VisualIndex<kDescType, kDescDim, kEmbeddingDim>::Add(
207     const IndexOptions& options, const int image_id, const GeomType& geometries,
208     const DescType& descriptors) {
209   CHECK_EQ(geometries.size(), descriptors.rows());
210 
211   // If the image is already indexed, do nothing.
212   if (ImageIndexed(image_id)) {
213     return;
214   }
215 
216   image_ids_.insert(image_id);
217 
218   prepared_ = false;
219 
220   if (descriptors.rows() == 0) {
221     return;
222   }
223 
224   const Eigen::MatrixXi word_ids =
225       FindWordIds(descriptors, options.num_neighbors, options.num_checks,
226                   options.num_threads);
227 
228   for (typename DescType::Index i = 0; i < descriptors.rows(); ++i) {
229     const auto& descriptor = descriptors.row(i);
230 
231     typename InvertedIndexType::GeomType geometry;
232     geometry.x = geometries[i].x;
233     geometry.y = geometries[i].y;
234     geometry.scale = geometries[i].ComputeScale();
235     geometry.orientation = geometries[i].ComputeOrientation();
236 
237     for (int n = 0; n < options.num_neighbors; ++n) {
238       const int word_id = word_ids(i, n);
239       if (word_id != InvertedIndexType::kInvalidWordId) {
240         inverted_index_.AddEntry(image_id, word_id, i, descriptor, geometry);
241       }
242     }
243   }
244 }
245 
246 template <typename kDescType, int kDescDim, int kEmbeddingDim>
ImageIndexed(const int image_id)247 bool VisualIndex<kDescType, kDescDim, kEmbeddingDim>::ImageIndexed(
248     const int image_id) const {
249   return image_ids_.count(image_id) != 0;
250 }
251 
252 template <typename kDescType, int kDescDim, int kEmbeddingDim>
Query(const QueryOptions & options,const DescType & descriptors,std::vector<ImageScore> * image_scores)253 void VisualIndex<kDescType, kDescDim, kEmbeddingDim>::Query(
254     const QueryOptions& options,
255     const DescType& descriptors, std::vector<ImageScore>* image_scores) const {
256   const GeomType geometries;
257   Query(options, geometries, descriptors, image_scores);
258 }
259 
260 template <typename kDescType, int kDescDim, int kEmbeddingDim>
Query(const QueryOptions & options,const GeomType & geometries,const DescType & descriptors,std::vector<ImageScore> * image_scores)261 void VisualIndex<kDescType, kDescDim, kEmbeddingDim>::Query(
262     const QueryOptions& options, const GeomType& geometries,
263     const DescType& descriptors, std::vector<ImageScore>* image_scores) const {
264   Eigen::MatrixXi word_ids;
265   QueryAndFindWordIds(options, descriptors, image_scores, &word_ids);
266 
267   if (options.num_images_after_verification <= 0) {
268     return;
269   }
270 
271   CHECK_EQ(descriptors.rows(), geometries.size());
272 
273   // Extract top-ranked images to verify.
274   std::unordered_set<int> image_ids;
275   for (const auto& image_score : *image_scores) {
276     image_ids.insert(image_score.image_id);
277   }
278 
279   // Find matches for top-ranked images
280   typedef std::vector<
281       std::pair<float, std::pair<const EntryType*, const EntryType*>>>
282       OrderedMatchListType;
283 
284   // Reference our matches (with their lowest distance) for both
285   // {query feature => db feature} and vice versa.
286   std::unordered_map<int, std::unordered_map<int, OrderedMatchListType>>
287       query_to_db_matches;
288   std::unordered_map<int, std::unordered_map<int, OrderedMatchListType>>
289       db_to_query_matches;
290 
291   std::vector<const EntryType*> word_matches;
292 
293   std::vector<EntryType> query_entries;  // Convert query features, too.
294   query_entries.reserve(descriptors.rows());
295 
296   // NOTE: Currently, we are redundantly computing the feature weighting.
297   const HammingDistWeightFunctor<kEmbeddingDim> hamming_dist_weight_functor;
298 
299   for (typename DescType::Index i = 0; i < descriptors.rows(); ++i) {
300     const auto& descriptor = descriptors.row(i);
301 
302     EntryType query_entry;
303     query_entry.feature_idx = i;
304     query_entry.geometry.x = geometries[i].x;
305     query_entry.geometry.y = geometries[i].y;
306     query_entry.geometry.scale = geometries[i].ComputeScale();
307     query_entry.geometry.orientation = geometries[i].ComputeOrientation();
308     query_entries.push_back(query_entry);
309 
310     // For each db feature, keep track of the lowest distance (if db features
311     // are mapped to more than one visual word).
312     std::unordered_map<
313         int, std::unordered_map<int, std::pair<float, const EntryType*>>>
314         image_matches;
315 
316     for (int j = 0; j < word_ids.cols(); ++j) {
317       const int word_id = word_ids(i, j);
318 
319       if (word_id != InvertedIndexType::kInvalidWordId) {
320         inverted_index_.ConvertToBinaryDescriptor(word_id, descriptor,
321                                                   &query_entries[i].descriptor);
322 
323         const auto idf_weight = inverted_index_.GetIDFWeight(word_id);
324         const auto squared_idf_weight = idf_weight * idf_weight;
325 
326         inverted_index_.FindMatches(word_id, image_ids, &word_matches);
327 
328         for (const auto& match : word_matches) {
329           const size_t hamming_dist =
330               (query_entries[i].descriptor ^ match->descriptor).count();
331 
332           if (hamming_dist <= hamming_dist_weight_functor.kMaxHammingDistance) {
333             const float dist =
334                 hamming_dist_weight_functor(hamming_dist) * squared_idf_weight;
335 
336             auto& feature_matches = image_matches[match->image_id];
337             const auto feature_match = feature_matches.find(match->feature_idx);
338 
339             if (feature_match == feature_matches.end() ||
340                 feature_match->first < dist) {
341               feature_matches[match->feature_idx] = std::make_pair(dist, match);
342             }
343           }
344         }
345       }
346     }
347 
348     // Finally, cross-reference the query and db feature matches.
349     for (const auto& feature_matches : image_matches) {
350       const auto image_id = feature_matches.first;
351 
352       for (const auto& feature_match : feature_matches.second) {
353         const auto feature_idx = feature_match.first;
354         const auto dist = feature_match.second.first;
355         const auto db_match = feature_match.second.second;
356 
357         const auto entry_pair = std::make_pair(&query_entries[i], db_match);
358 
359         query_to_db_matches[image_id][i].emplace_back(dist, entry_pair);
360         db_to_query_matches[image_id][feature_idx].emplace_back(dist,
361                                                                 entry_pair);
362       }
363     }
364   }
365 
366   // Verify top-ranked images using the found matches.
367   for (auto& image_score : *image_scores) {
368     auto& query_matches = query_to_db_matches[image_score.image_id];
369     auto& db_matches = db_to_query_matches[image_score.image_id];
370 
371     // No matches found.
372     if (query_matches.empty()) {
373       continue;
374     }
375 
376     // Enforce 1-to-1 matching: Build Fibonacci heaps for the query and database
377     // features, ordered by the minimum number of matches per feature. We'll
378     // select these matches one at a time. For convenience, we'll also pre-sort
379     // the matched feature lists by matching score.
380 
381     typedef boost::heap::fibonacci_heap<std::pair<int, int>> FibonacciHeapType;
382     FibonacciHeapType query_heap;
383     FibonacciHeapType db_heap;
384     std::unordered_map<int, typename FibonacciHeapType::handle_type>
385         query_heap_handles;
386     std::unordered_map<int, typename FibonacciHeapType::handle_type>
387         db_heap_handles;
388 
389     for (auto& match_data : query_matches) {
390       std::sort(match_data.second.begin(), match_data.second.end(),
391                 std::greater<std::pair<
392                     float, std::pair<const EntryType*, const EntryType*>>>());
393 
394       query_heap_handles[match_data.first] = query_heap.push(std::make_pair(
395           -static_cast<int>(match_data.second.size()), match_data.first));
396     }
397 
398     for (auto& match_data : db_matches) {
399       std::sort(match_data.second.begin(), match_data.second.end(),
400                 std::greater<std::pair<
401                     float, std::pair<const EntryType*, const EntryType*>>>());
402 
403       db_heap_handles[match_data.first] = db_heap.push(std::make_pair(
404           -static_cast<int>(match_data.second.size()), match_data.first));
405     }
406 
407     // Keep tabs on what features have been already matched.
408     std::vector<FeatureGeometryMatch> matches;
409 
410     auto db_top = db_heap.top();  // (-num_available_matches, feature_idx)
411     auto query_top = query_heap.top();
412 
413     while (!db_heap.empty() && !query_heap.empty()) {
414       // Take the query or database feature with the smallest number of
415       // available matches.
416       const bool use_query =
417           (query_top.first >= db_top.first) && !query_heap.empty();
418 
419       // Find the best matching feature that hasn't already been matched.
420       auto& heap1 = (use_query) ? query_heap : db_heap;
421       auto& heap2 = (use_query) ? db_heap : query_heap;
422       auto& handles1 = (use_query) ? query_heap_handles : db_heap_handles;
423       auto& handles2 = (use_query) ? db_heap_handles : query_heap_handles;
424       auto& matches1 = (use_query) ? query_matches : db_matches;
425       auto& matches2 = (use_query) ? db_matches : query_matches;
426 
427       const auto idx1 = heap1.top().second;
428       heap1.pop();
429 
430       // Entries that have been matched (or processed and subsequently ignored)
431       // get their handles removed.
432       if (handles1.count(idx1) > 0) {
433         handles1.erase(idx1);
434 
435         bool match_found = false;
436 
437         // The matches have been ordered by Hamming distance, already --
438         // select the lowest available match.
439         for (auto& entry2 : matches1[idx1]) {
440           const auto idx2 = (use_query) ? entry2.second.second->feature_idx
441                                         : entry2.second.first->feature_idx;
442 
443           if (handles2.count(idx2) > 0) {
444             if (!match_found) {
445               match_found = true;
446               FeatureGeometryMatch match;
447               match.geometry1 = entry2.second.first->geometry;
448               match.geometries2.push_back(entry2.second.second->geometry);
449               matches.push_back(match);
450 
451               handles2.erase(idx2);
452 
453               // Remove this feature from consideration for all other features
454               // that matched to it.
455               for (auto& entry1 : matches2[idx2]) {
456                 const auto other_idx1 = (use_query)
457                                             ? entry1.second.first->feature_idx
458                                             : entry1.second.second->feature_idx;
459                 if (handles1.count(other_idx1) > 0) {
460                   (*handles1[other_idx1]).first += 1;
461                   heap1.increase(handles1[other_idx1]);
462                 }
463               }
464             } else {
465               (*handles2[idx2]).first += 1;
466               heap2.increase(handles2[idx2]);
467             }
468           }
469         }
470       }
471 
472       if (!query_heap.empty()) {
473         query_top = query_heap.top();
474       }
475 
476       if (!db_heap.empty()) {
477         db_top = db_heap.top();
478       }
479     }
480 
481     // Finally, run verification for the current image.
482     VoteAndVerifyOptions vote_and_verify_options;
483     image_score.score += VoteAndVerify(vote_and_verify_options, matches);
484   }
485 
486   // Re-rank the images using the spatial verification scores.
487 
488   const size_t num_images = std::min<size_t>(
489       image_scores->size(), options.num_images_after_verification);
490 
491   auto SortFunc = [](const ImageScore& score1, const ImageScore& score2) {
492     return score1.score > score2.score;
493   };
494 
495   if (num_images == image_scores->size()) {
496     std::sort(image_scores->begin(), image_scores->end(), SortFunc);
497   } else {
498     std::partial_sort(image_scores->begin(), image_scores->begin() + num_images,
499                       image_scores->end(), SortFunc);
500     image_scores->resize(num_images);
501   }
502 }
503 
504 template <typename kDescType, int kDescDim, int kEmbeddingDim>
Prepare()505 void VisualIndex<kDescType, kDescDim, kEmbeddingDim>::Prepare() {
506   inverted_index_.Finalize();
507   prepared_ = true;
508 }
509 
510 template <typename kDescType, int kDescDim, int kEmbeddingDim>
Build(const BuildOptions & options,const DescType & descriptors)511 void VisualIndex<kDescType, kDescDim, kEmbeddingDim>::Build(
512     const BuildOptions& options, const DescType& descriptors) {
513   // Quantize the descriptor space into visual words.
514   Quantize(options, descriptors);
515 
516   // Build the search index on the visual words.
517   flann::AutotunedIndexParams index_params;
518   index_params["target_precision"] =
519       static_cast<float>(options.target_precision);
520   visual_word_index_ =
521       flann::AutotunedIndex<flann::L2<kDescType>>(index_params);
522   visual_word_index_.buildIndex(visual_words_);
523 
524   // Initialize a new inverted index.
525   inverted_index_ = InvertedIndexType();
526   inverted_index_.Initialize(NumVisualWords());
527 
528   // Generate descriptor projection matrix.
529   inverted_index_.GenerateHammingEmbeddingProjection();
530 
531   // Learn the Hamming embedding.
532   const int kNumNeighbors = 1;
533   const Eigen::MatrixXi word_ids = FindWordIds(
534       descriptors, kNumNeighbors, options.num_checks, options.num_threads);
535   inverted_index_.ComputeHammingEmbedding(descriptors, word_ids);
536 }
537 
538 template <typename kDescType, int kDescDim, int kEmbeddingDim>
Read(const std::string & path)539 void VisualIndex<kDescType, kDescDim, kEmbeddingDim>::Read(
540     const std::string& path) {
541   long int file_offset = 0;
542 
543   // Read the visual words.
544 
545   {
546     if (visual_words_.ptr() != nullptr) {
547       delete[] visual_words_.ptr();
548     }
549 
550     std::ifstream file(path, std::ios::binary);
551     CHECK(file.is_open()) << path;
552     const uint64_t rows = ReadBinaryLittleEndian<uint64_t>(&file);
553     const uint64_t cols = ReadBinaryLittleEndian<uint64_t>(&file);
554     kDescType* visual_words_data = new kDescType[rows * cols];
555     for (size_t i = 0; i < rows * cols; ++i) {
556       visual_words_data[i] = ReadBinaryLittleEndian<kDescType>(&file);
557     }
558     visual_words_ = flann::Matrix<kDescType>(visual_words_data, rows, cols);
559     file_offset = file.tellg();
560   }
561 
562   // Read the visual words search index.
563 
564   visual_word_index_ =
565       flann::AutotunedIndex<flann::L2<kDescType>>(visual_words_);
566 
567   {
568     FILE* fin = fopen(path.c_str(), "rb");
569     CHECK_NOTNULL(fin);
570     fseek(fin, file_offset, SEEK_SET);
571     visual_word_index_.loadIndex(fin);
572     file_offset = ftell(fin);
573     fclose(fin);
574   }
575 
576   // Read the inverted index.
577 
578   {
579     std::ifstream file(path, std::ios::binary);
580     CHECK(file.is_open()) << path;
581     file.seekg(file_offset, std::ios::beg);
582     inverted_index_.Read(&file);
583   }
584 
585   image_ids_.clear();
586   inverted_index_.GetImageIds(&image_ids_);
587 }
588 
589 template <typename kDescType, int kDescDim, int kEmbeddingDim>
Write(const std::string & path)590 void VisualIndex<kDescType, kDescDim, kEmbeddingDim>::Write(
591     const std::string& path) {
592   // Write the visual words.
593 
594   {
595     CHECK_NOTNULL(visual_words_.ptr());
596     std::ofstream file(path, std::ios::binary);
597     CHECK(file.is_open()) << path;
598     WriteBinaryLittleEndian<uint64_t>(&file, visual_words_.rows);
599     WriteBinaryLittleEndian<uint64_t>(&file, visual_words_.cols);
600     for (size_t i = 0; i < visual_words_.rows * visual_words_.cols; ++i) {
601       WriteBinaryLittleEndian<kDescType>(&file, visual_words_.ptr()[i]);
602     }
603   }
604 
605   // Write the visual words search index.
606 
607   {
608     FILE* fout = fopen(path.c_str(), "ab");
609     CHECK_NOTNULL(fout);
610     visual_word_index_.saveIndex(fout);
611     fclose(fout);
612   }
613 
614   // Write the inverted index.
615 
616   {
617     std::ofstream file(path, std::ios::binary | std::ios::app);
618     CHECK(file.is_open()) << path;
619     inverted_index_.Write(&file);
620   }
621 }
622 
623 template <typename kDescType, int kDescDim, int kEmbeddingDim>
Quantize(const BuildOptions & options,const DescType & descriptors)624 void VisualIndex<kDescType, kDescDim, kEmbeddingDim>::Quantize(
625     const BuildOptions& options, const DescType& descriptors) {
626   static_assert(DescType::IsRowMajor, "Descriptors must be row-major.");
627 
628   CHECK_GE(options.num_visual_words, options.branching);
629   CHECK_GE(descriptors.rows(), options.num_visual_words);
630 
631   const flann::Matrix<kDescType> descriptor_matrix(
632       const_cast<kDescType*>(descriptors.data()), descriptors.rows(),
633       descriptors.cols());
634 
635   std::vector<typename flann::L2<kDescType>::ResultType> centers_data(
636       options.num_visual_words * descriptors.cols());
637   flann::Matrix<typename flann::L2<kDescType>::ResultType> centers(
638       centers_data.data(), options.num_visual_words, descriptors.cols());
639 
640   flann::KMeansIndexParams index_params;
641   index_params["branching"] = options.branching;
642   index_params["iterations"] = options.num_iterations;
643   index_params["centers_init"] = flann::FLANN_CENTERS_KMEANSPP;
644   const int num_centers = flann::hierarchicalClustering<flann::L2<kDescType>>(
645       descriptor_matrix, centers, index_params);
646 
647   CHECK_LE(num_centers, options.num_visual_words);
648 
649   const size_t visual_word_data_size = num_centers * descriptors.cols();
650   kDescType* visual_words_data = new kDescType[visual_word_data_size];
651   for (size_t i = 0; i < visual_word_data_size; ++i) {
652     if (std::is_integral<kDescType>::value) {
653       visual_words_data[i] = std::round(centers_data[i]);
654     } else {
655       visual_words_data[i] = centers_data[i];
656     }
657   }
658 
659   if (visual_words_.ptr() != nullptr) {
660     delete[] visual_words_.ptr();
661   }
662 
663   visual_words_ = flann::Matrix<kDescType>(visual_words_data, num_centers,
664                                            descriptors.cols());
665 }
666 
667 template <typename kDescType, int kDescDim, int kEmbeddingDim>
QueryAndFindWordIds(const QueryOptions & options,const DescType & descriptors,std::vector<ImageScore> * image_scores,Eigen::MatrixXi * word_ids)668 void VisualIndex<kDescType, kDescDim, kEmbeddingDim>::QueryAndFindWordIds(
669     const QueryOptions& options, const DescType& descriptors,
670     std::vector<ImageScore>* image_scores, Eigen::MatrixXi* word_ids) const {
671   CHECK(prepared_);
672 
673   if (descriptors.rows() == 0) {
674     image_scores->clear();
675     return;
676   }
677 
678   *word_ids = FindWordIds(descriptors, options.num_neighbors,
679                           options.num_checks, options.num_threads);
680   inverted_index_.Query(descriptors, *word_ids, image_scores);
681 
682   auto SortFunc = [](const ImageScore& score1, const ImageScore& score2) {
683     return score1.score > score2.score;
684   };
685 
686   size_t num_images = image_scores->size();
687   if (options.max_num_images >= 0) {
688     num_images = std::min<size_t>(image_scores->size(), options.max_num_images);
689   }
690 
691   if (num_images == image_scores->size()) {
692     std::sort(image_scores->begin(), image_scores->end(), SortFunc);
693   } else {
694     std::partial_sort(image_scores->begin(), image_scores->begin() + num_images,
695                       image_scores->end(), SortFunc);
696     image_scores->resize(num_images);
697   }
698 }
699 
700 template <typename kDescType, int kDescDim, int kEmbeddingDim>
FindWordIds(const DescType & descriptors,const int num_neighbors,const int num_checks,const int num_threads)701 Eigen::MatrixXi VisualIndex<kDescType, kDescDim, kEmbeddingDim>::FindWordIds(
702     const DescType& descriptors, const int num_neighbors, const int num_checks,
703     const int num_threads) const {
704   static_assert(DescType::IsRowMajor, "Descriptors must be row-major");
705 
706   CHECK_GT(descriptors.rows(), 0);
707   CHECK_GT(num_neighbors, 0);
708 
709   Eigen::Matrix<size_t, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>
710       word_ids(descriptors.rows(), num_neighbors);
711   word_ids.setConstant(InvertedIndexType::kInvalidWordId);
712   flann::Matrix<size_t> indices(word_ids.data(), descriptors.rows(),
713                                 num_neighbors);
714 
715   Eigen::Matrix<typename flann::L2<kDescType>::ResultType, Eigen::Dynamic,
716                 Eigen::Dynamic, Eigen::RowMajor>
717       distance_matrix(descriptors.rows(), num_neighbors);
718   flann::Matrix<typename flann::L2<kDescType>::ResultType> distances(
719       distance_matrix.data(), descriptors.rows(), num_neighbors);
720 
721   const flann::Matrix<kDescType> query(
722       const_cast<kDescType*>(descriptors.data()), descriptors.rows(),
723       descriptors.cols());
724 
725   flann::SearchParams search_params(num_checks);
726   if (num_threads < 0) {
727     search_params.cores = std::thread::hardware_concurrency();
728   } else {
729     search_params.cores = num_threads;
730   }
731   if (search_params.cores <= 0) {
732     search_params.cores = 1;
733   }
734 
735   visual_word_index_.knnSearch(query, indices, distances, num_neighbors,
736                                search_params);
737 
738   return word_ids.cast<int>();
739 }
740 
741 }  // namespace retrieval
742 }  // namespace colmap
743 
744 #endif  // COLMAP_SRC_RETRIEVAL_VISUAL_INDEX_H_
745