1 /** 2 * Copyright (c) Facebook, Inc. and its affiliates. 3 * 4 * This source code is licensed under the MIT license found in the 5 * LICENSE file in the root directory of this source tree. 6 */ 7 8 #pragma once 9 10 #include <stdint.h> 11 12 #include <random> 13 #include <string> 14 #include <unordered_map> 15 #include <vector> 16 17 #include <faiss/impl/AdditiveQuantizer.h> 18 #include <faiss/utils/utils.h> 19 20 namespace faiss { 21 22 /** Implementation of LSQ/LSQ++ described in the following two papers: 23 * 24 * Revisiting additive quantization 25 * Julieta Martinez, et al. ECCV 2016 26 * 27 * LSQ++: Lower running time and higher recall in multi-codebook quantization 28 * Julieta Martinez, et al. ECCV 2018 29 * 30 * This implementation is mostly translated from the Julia implementations 31 * by Julieta Martinez: 32 * (https://github.com/una-dinosauria/local-search-quantization, 33 * https://github.com/una-dinosauria/Rayuela.jl) 34 * 35 * The trained codes are stored in `codebooks` which is called 36 * `centroids` in PQ and RQ. 37 */ 38 39 struct LocalSearchQuantizer : AdditiveQuantizer { 40 size_t K; ///< number of codes per codebook 41 42 size_t train_iters; ///< number of iterations in training 43 44 size_t encode_ils_iters; ///< iterations of local search in encoding 45 size_t train_ils_iters; ///< iterations of local search in training 46 size_t icm_iters; ///< number of iterations in icm 47 48 float p; ///< temperature factor 49 float lambd; ///< regularization factor 50 51 size_t chunk_size; ///< nb of vectors to encode at a time 52 53 int random_seed; ///< seed for random generator 54 size_t nperts; ///< number of perturbation in each code 55 56 LocalSearchQuantizer( 57 size_t d, /* dimensionality of the input vectors */ 58 size_t M, /* number of subquantizers */ 59 size_t nbits); /* number of bit per subvector index */ 60 61 // Train the local search quantizer 62 void train(size_t n, const float* x) override; 63 64 /** Encode a set of vectors 65 * 66 * @param x vectors to encode, size n * d 67 * @param codes output codes, size n * code_size 68 */ 69 void compute_codes(const float* x, uint8_t* codes, size_t n) const override; 70 71 /** Update codebooks given encodings 72 * 73 * @param x training vectors, size n * d 74 * @param codes encoded training vectors, size n * M 75 */ 76 void update_codebooks(const float* x, const int32_t* codes, size_t n); 77 78 /** Encode vectors given codebooks using iterative conditional mode (icm). 79 * 80 * @param x vectors to encode, size n * d 81 * @param codes output codes, size n * M 82 * @param ils_iters number of iterations of iterative local search 83 */ 84 void icm_encode( 85 const float* x, 86 int32_t* codes, 87 size_t n, 88 size_t ils_iters, 89 std::mt19937& gen) const; 90 91 void icm_encode_partial( 92 size_t index, 93 const float* x, 94 int32_t* codes, 95 size_t n, 96 const float* binaries, 97 size_t ils_iters, 98 std::mt19937& gen) const; 99 100 void icm_encode_step( 101 const float* unaries, 102 const float* binaries, 103 int32_t* codes, 104 size_t n) const; 105 106 /** Add some perturbation to codebooks 107 * 108 * @param T temperature of simulated annealing 109 * @param stddev standard derivations of each dimension in training data 110 */ 111 void perturb_codebooks( 112 float T, 113 const std::vector<float>& stddev, 114 std::mt19937& gen); 115 116 /** Add some perturbation to codes 117 * 118 * @param codes codes to be perturbed, size n * M 119 */ 120 void perturb_codes(int32_t* codes, size_t n, std::mt19937& gen) const; 121 122 /** Compute binary terms 123 * 124 * @param binaries binary terms, size M * M * K * K 125 */ 126 void compute_binary_terms(float* binaries) const; 127 128 /** Compute unary terms 129 * 130 * @param x vectors to encode, size n * d 131 * @param unaries unary terms, size n * M * K 132 */ 133 void compute_unary_terms(const float* x, float* unaries, size_t n) const; 134 135 /** Helper function to compute reconstruction error 136 * 137 * @param x vectors to encode, size n * d 138 * @param codes encoded codes, size n * M 139 * @param objs if it is not null, store reconstruction 140 error of each vector into it, size n 141 */ 142 float evaluate( 143 const int32_t* codes, 144 const float* x, 145 size_t n, 146 float* objs = nullptr) const; 147 }; 148 149 /** A helper struct to count consuming time during training. 150 * It is NOT thread-safe. 151 */ 152 struct LSQTimer { 153 std::unordered_map<std::string, double> duration; 154 std::unordered_map<std::string, double> t0; 155 std::unordered_map<std::string, bool> started; 156 LSQTimerLSQTimer157 LSQTimer() { 158 reset(); 159 } 160 161 double get(const std::string& name); 162 163 void start(const std::string& name); 164 165 void end(const std::string& name); 166 167 void reset(); 168 }; 169 170 FAISS_API extern LSQTimer lsq_timer; ///< timer to count consuming time 171 172 } // namespace faiss 173