1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include <stdint.h>
11 
12 #include <random>
13 #include <string>
14 #include <unordered_map>
15 #include <vector>
16 
17 #include <faiss/impl/AdditiveQuantizer.h>
18 #include <faiss/utils/utils.h>
19 
20 namespace faiss {
21 
22 /** Implementation of LSQ/LSQ++ described in the following two papers:
23  *
24  * Revisiting additive quantization
25  * Julieta Martinez, et al. ECCV 2016
26  *
27  * LSQ++: Lower running time and higher recall in multi-codebook quantization
28  * Julieta Martinez, et al. ECCV 2018
29  *
30  * This implementation is mostly translated from the Julia implementations
31  * by Julieta Martinez:
32  * (https://github.com/una-dinosauria/local-search-quantization,
33  *  https://github.com/una-dinosauria/Rayuela.jl)
34  *
 * The trained codebooks are stored in `codebooks`, which corresponds to
 * `centroids` in PQ and RQ.
37  */
38 
struct LocalSearchQuantizer : AdditiveQuantizer {
    size_t K; ///< number of codes per codebook

    size_t train_iters; ///< number of iterations in training

    size_t encode_ils_iters; ///< iterations of local search in encoding
    size_t train_ils_iters;  ///< iterations of local search in training
    size_t icm_iters;        ///< number of iterations in icm

    float p;     ///< temperature factor
    float lambd; ///< regularization factor

    size_t chunk_size; ///< nb of vectors to encode at a time

    int random_seed; ///< seed for random generator
    size_t nperts;   ///< number of perturbations in each code

    /** Construct a local search quantizer.
     *
     * The fields above have no in-class default values; presumably the
     * constructor definition (not in this header) initializes them.
     * NOTE(review): K is presumably 1 << nbits — confirm in the .cpp.
     */
    LocalSearchQuantizer(
            size_t d,      /* dimensionality of the input vectors */
            size_t M,      /* number of subquantizers */
            size_t nbits); /* number of bits per subvector index */

    /** Train the local search quantizer (the result is stored in `codebooks`).
     *
     * @param n  number of training vectors
     * @param x  training vectors, presumably size n * d like compute_codes
     *           — TODO confirm
     */
    void train(size_t n, const float* x) override;

    /** Encode a set of vectors
     *
     * @param x      vectors to encode, size n * d
     * @param codes  output codes, size n * code_size
     * @param n      number of vectors to encode
     */
    void compute_codes(const float* x, uint8_t* codes, size_t n) const override;

    /** Update codebooks given encodings
     *
     * @param x      training vectors, size n * d
     * @param codes  encoded training vectors, size n * M
     * @param n      number of training vectors
     */
    void update_codebooks(const float* x, const int32_t* codes, size_t n);

    /** Encode vectors given codebooks using iterated conditional modes (ICM).
     *
     * @param x         vectors to encode, size n * d
     * @param codes     output codes, size n * M
     * @param n         number of vectors to encode
     * @param ils_iters number of iterations of iterated local search
     * @param gen       random generator; presumably drives the random
     *                  perturbations of the local search — TODO confirm
     */
    void icm_encode(
            const float* x,
            int32_t* codes,
            size_t n,
            size_t ils_iters,
            std::mt19937& gen) const;

    /** Encode a subset of vectors with ICM, given precomputed binary terms.
     *
     * NOTE(review): described from the name and signature only. `index` is
     * presumably the index of the chunk being encoded, and `binaries` the
     * output of compute_binary_terms (size M * M * K * K); confirm in the .cpp.
     *
     * @param index     see note above
     * @param x         vectors to encode, presumably size n * d
     * @param codes     codes, presumably size n * M (non-const: may be written)
     * @param n         number of vectors
     * @param binaries  precomputed binary terms
     * @param ils_iters number of iterations of iterated local search
     * @param gen       random generator
     */
    void icm_encode_partial(
            size_t index,
            const float* x,
            int32_t* codes,
            size_t n,
            const float* binaries,
            size_t ils_iters,
            std::mt19937& gen) const;

    /** Run one ICM step over n vectors.
     *
     * NOTE(review): presumably `unaries` comes from compute_unary_terms
     * (size n * M * K), `binaries` from compute_binary_terms
     * (size M * M * K * K), and `codes` (size n * M) is updated in place;
     * confirm in the .cpp.
     */
    void icm_encode_step(
            const float* unaries,
            const float* binaries,
            int32_t* codes,
            size_t n) const;

    /** Add some perturbation to codebooks
     *
     * @param T         temperature of simulated annealing
     * @param stddev    standard deviations of each dimension in training data
     * @param gen       random generator, presumably used to draw the noise
     */
    void perturb_codebooks(
            float T,
            const std::vector<float>& stddev,
            std::mt19937& gen);

    /** Add some perturbation to codes
     *
     * NOTE(review): presumably `nperts` entries of each code are perturbed —
     * confirm in the .cpp.
     *
     * @param codes codes to be perturbed, size n * M
     * @param n     number of encoded vectors
     * @param gen   random generator
     */
    void perturb_codes(int32_t* codes, size_t n, std::mt19937& gen) const;

    /** Compute binary terms
     *
     * @param binaries binary terms, size M * M * K * K
     */
    void compute_binary_terms(float* binaries) const;

    /** Compute unary terms
     *
     * @param x       vectors to encode, size n * d
     * @param unaries unary terms, size n * M * K
     * @param n       number of vectors
     */
    void compute_unary_terms(const float* x, float* unaries, size_t n) const;

    /** Helper function to compute reconstruction error
     *
     * @param codes encoded codes, size n * M
     * @param x     vectors that were encoded, size n * d
     * @param n     number of vectors
     * @param objs  if not null, receives the reconstruction error of each
     *              vector, size n
     * @return      aggregated reconstruction error; NOTE(review): whether it
     *              is a sum or a mean is not visible here — confirm in the .cpp
     */
    float evaluate(
            const int32_t* codes,
            const float* x,
            size_t n,
            float* objs = nullptr) const;
};
148 
/** Helper that measures time spent in named stages of training.
 *
 *  Each stage is identified by a string key passed to start()/end().
 *  It is NOT thread-safe.
 */
struct LSQTimer {
    std::unordered_map<std::string, double> duration;
    std::unordered_map<std::string, double> t0;
    std::unordered_map<std::string, bool> started;

    /// Build a timer; its bookkeeping tables are prepared by reset().
    LSQTimer() {
        this->reset();
    }

    /// Reset the timing bookkeeping (defined out of line).
    void reset();

    /// Begin timing the stage `name` (defined out of line).
    void start(const std::string& name);

    /// Stop timing the stage `name` (defined out of line).
    void end(const std::string& name);

    /// Query the recorded time of the stage `name`; presumably the value
    /// accumulated in `duration` — confirm in the .cpp.
    double get(const std::string& name);
};
169 
FAISS_API extern LSQTimer lsq_timer; ///< global LSQTimer instance used to time LSQ training stages
171 
172 } // namespace faiss
173