1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include <algorithm>
11 #include <type_traits>
12 #include <vector>
13 
14 #include <faiss/utils/Heap.h>
15 #include <faiss/utils/simdlib.h>
16 
17 #include <faiss/impl/platform_macros.h>
18 #include <faiss/utils/AlignedTable.h>
19 #include <faiss/utils/partitioning.h>
20 
/** This file contains callbacks for kernels that compute distances.
 *
 * The SIMDResultHandler object is intended to be templated and inlined.
 * Methods:
 * - handle(): called when 32 distances have been computed; they are provided
 *   as two simd16uint16 vectors. (q, b) identify the query and the 32-wide
 *   database sub-block within the current block.
 * - set_block_origin(): set the origin of the sub-matrix that is being
 *   computed
 */
29 
30 namespace faiss {
31 
32 namespace simd_result_handlers {
33 
/** Dummy structure that just computes a checksum on the results
 * (so that the computation is not optimized away) */
36 struct DummyResultHandler {
37     size_t cs = 0;
38 
handleDummyResultHandler39     void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
40         cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0();
41     }
42 
set_block_originDummyResultHandler43     void set_block_origin(size_t, size_t) {}
44 };
45 
/** Stores results in an nq-by-nb matrix.
 *
 * (i0, j0) is the upper-left corner of the current block in that matrix.
 */
50 struct StoreResultHandler {
51     uint16_t* data;
52     size_t ld; // total number of columns
53     size_t i0 = 0;
54     size_t j0 = 0;
55 
StoreResultHandlerStoreResultHandler56     StoreResultHandler(uint16_t* data, size_t ld) : data(data), ld(ld) {}
57 
handleStoreResultHandler58     void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
59         size_t ofs = (q + i0) * ld + j0 + b * 32;
60         d0.store(data + ofs);
61         d1.store(data + ofs + 16);
62     }
63 
set_block_originStoreResultHandler64     void set_block_origin(size_t i0, size_t j0) {
65         this->i0 = i0;
66         this->j0 = j0;
67     }
68 };
69 
70 /** stores results in fixed-size matrix. */
71 template <int NQ, int BB>
72 struct FixedStorageHandler {
73     simd16uint16 dis[NQ][BB];
74     int i0 = 0;
75 
handleFixedStorageHandler76     void handle(int q, int b, simd16uint16 d0, simd16uint16 d1) {
77         dis[q + i0][2 * b] = d0;
78         dis[q + i0][2 * b + 1] = d1;
79     }
80 
set_block_originFixedStorageHandler81     void set_block_origin(size_t i0, size_t j0) {
82         this->i0 = i0;
83         assert(j0 == 0);
84     }
85 
86     template <class OtherResultHandler>
to_other_handlerFixedStorageHandler87     void to_other_handler(OtherResultHandler& other) const {
88         for (int q = 0; q < NQ; q++) {
89             for (int b = 0; b < BB; b += 2) {
90                 other.handle(q, b / 2, dis[q][b], dis[q][b + 1]);
91             }
92         }
93     }
94 };
95 
/** Base class of the result handlers below: records the origin of the
 * current block and provides the helpers they share. */
97 template <class C, bool with_id_map>
98 struct SIMDResultHandler {
99     using TI = typename C::TI;
100 
101     bool disable = false;
102 
103     int64_t i0 = 0; // query origin
104     int64_t j0 = 0; // db origin
105     size_t ntotal;  // ignore excess elements after ntotal
106 
107     /// these fields are used mainly for the IVF variants (with_id_map=true)
108     const TI* id_map;      // map offset in invlist to vector id
109     const int* q_map;      // map q to global query
110     const uint16_t* dbias; // table of biases to add to each query
111 
SIMDResultHandlerSIMDResultHandler112     explicit SIMDResultHandler(size_t ntotal)
113             : ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr) {}
114 
set_block_originSIMDResultHandler115     void set_block_origin(size_t i0, size_t j0) {
116         this->i0 = i0;
117         this->j0 = j0;
118     }
119 
120     // adjust handler data for IVF.
adjust_with_originSIMDResultHandler121     void adjust_with_origin(size_t& q, simd16uint16& d0, simd16uint16& d1) {
122         q += i0;
123 
124         if (dbias) {
125             simd16uint16 dbias16(dbias[q]);
126             d0 += dbias16;
127             d1 += dbias16;
128         }
129 
130         if (with_id_map) { // FIXME test on q_map instead
131             q = q_map[q];
132         }
133     }
134 
135     // compute and adjust idx
adjust_idSIMDResultHandler136     int64_t adjust_id(size_t b, size_t j) {
137         int64_t idx = j0 + 32 * b + j;
138         if (with_id_map) {
139             idx = id_map[idx];
140         }
141         return idx;
142     }
143 
144     /// return binary mask of elements below thr in (d0, d1)
145     /// inverse_test returns elements above
get_lt_maskSIMDResultHandler146     uint32_t get_lt_mask(
147             uint16_t thr,
148             size_t b,
149             simd16uint16 d0,
150             simd16uint16 d1) {
151         simd16uint16 thr16(thr);
152         uint32_t lt_mask;
153 
154         constexpr bool keep_min = C::is_max;
155         if (keep_min) {
156             lt_mask = ~cmp_ge32(d0, d1, thr16);
157         } else {
158             lt_mask = ~cmp_le32(d0, d1, thr16);
159         }
160 
161         if (lt_mask == 0) {
162             return 0;
163         }
164         uint64_t idx = j0 + b * 32;
165         if (idx + 32 > ntotal) {
166             if (idx >= ntotal) {
167                 return 0;
168             }
169             int nbit = (ntotal - idx);
170             lt_mask &= (uint32_t(1) << nbit) - 1;
171         }
172         return lt_mask;
173     }
174 
175     virtual void to_flat_arrays(
176             float* distances,
177             int64_t* labels,
178             const float* normalizers = nullptr) = 0;
179 
~SIMDResultHandlerSIMDResultHandler180     virtual ~SIMDResultHandler() {}
181 };
182 
183 /** Special version for k=1 */
184 template <class C, bool with_id_map = false>
185 struct SingleResultHandler : SIMDResultHandler<C, with_id_map> {
186     using T = typename C::T;
187     using TI = typename C::TI;
188 
189     struct Result {
190         T val;
191         TI id;
192     };
193     std::vector<Result> results;
194 
SingleResultHandlerSingleResultHandler195     SingleResultHandler(size_t nq, size_t ntotal)
196             : SIMDResultHandler<C, with_id_map>(ntotal), results(nq) {
197         for (int i = 0; i < nq; i++) {
198             Result res = {C::neutral(), -1};
199             results[i] = res;
200         }
201     }
202 
handleSingleResultHandler203     void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
204         if (this->disable) {
205             return;
206         }
207 
208         this->adjust_with_origin(q, d0, d1);
209 
210         Result& res = results[q];
211         uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1);
212         if (!lt_mask) {
213             return;
214         }
215 
216         ALIGNED(32) uint16_t d32tab[32];
217         d0.store(d32tab);
218         d1.store(d32tab + 16);
219 
220         while (lt_mask) {
221             // find first non-zero
222             int j = __builtin_ctz(lt_mask);
223             lt_mask -= 1 << j;
224             T dis = d32tab[j];
225             if (C::cmp(res.val, dis)) {
226                 res.val = dis;
227                 res.id = this->adjust_id(b, j);
228             }
229         }
230     }
231 
232     void to_flat_arrays(
233             float* distances,
234             int64_t* labels,
235             const float* normalizers = nullptr) override {
236         for (int q = 0; q < results.size(); q++) {
237             if (!normalizers) {
238                 distances[q] = results[q].val;
239             } else {
240                 float one_a = 1 / normalizers[2 * q];
241                 float b = normalizers[2 * q + 1];
242                 distances[q] = b + results[q].val * one_a;
243             }
244             labels[q] = results[q].id;
245         }
246     }
247 };
248 
249 /** Structure that collects results in a min- or max-heap */
250 template <class C, bool with_id_map = false>
251 struct HeapHandler : SIMDResultHandler<C, with_id_map> {
252     using T = typename C::T;
253     using TI = typename C::TI;
254 
255     int nq;
256     T* heap_dis_tab;
257     TI* heap_ids_tab;
258 
259     int64_t k; // number of results to keep
260 
HeapHandlerHeapHandler261     HeapHandler(
262             int nq,
263             T* heap_dis_tab,
264             TI* heap_ids_tab,
265             size_t k,
266             size_t ntotal)
267             : SIMDResultHandler<C, with_id_map>(ntotal),
268               nq(nq),
269               heap_dis_tab(heap_dis_tab),
270               heap_ids_tab(heap_ids_tab),
271               k(k) {
272         for (int q = 0; q < nq; q++) {
273             T* heap_dis_in = heap_dis_tab + q * k;
274             TI* heap_ids_in = heap_ids_tab + q * k;
275             heap_heapify<C>(k, heap_dis_in, heap_ids_in);
276         }
277     }
278 
handleHeapHandler279     void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
280         if (this->disable) {
281             return;
282         }
283 
284         this->adjust_with_origin(q, d0, d1);
285 
286         T* heap_dis = heap_dis_tab + q * k;
287         TI* heap_ids = heap_ids_tab + q * k;
288 
289         uint16_t cur_thresh =
290                 heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) : 0xffff;
291 
292         // here we handle the reverse comparison case as well
293         uint32_t lt_mask = this->get_lt_mask(cur_thresh, b, d0, d1);
294 
295         if (!lt_mask) {
296             return;
297         }
298 
299         ALIGNED(32) uint16_t d32tab[32];
300         d0.store(d32tab);
301         d1.store(d32tab + 16);
302 
303         while (lt_mask) {
304             // find first non-zero
305             int j = __builtin_ctz(lt_mask);
306             lt_mask -= 1 << j;
307             T dis = d32tab[j];
308             if (C::cmp(heap_dis[0], dis)) {
309                 int64_t idx = this->adjust_id(b, j);
310                 heap_pop<C>(k, heap_dis, heap_ids);
311                 heap_push<C>(k, heap_dis, heap_ids, dis, idx);
312             }
313         }
314     }
315 
316     void to_flat_arrays(
317             float* distances,
318             int64_t* labels,
319             const float* normalizers = nullptr) override {
320         for (int q = 0; q < nq; q++) {
321             T* heap_dis_in = heap_dis_tab + q * k;
322             TI* heap_ids_in = heap_ids_tab + q * k;
323             heap_reorder<C>(k, heap_dis_in, heap_ids_in);
324             int64_t* heap_ids = labels + q * k;
325             float* heap_dis = distances + q * k;
326 
327             float one_a = 1.0, b = 0.0;
328             if (normalizers) {
329                 one_a = 1 / normalizers[2 * q];
330                 b = normalizers[2 * q + 1];
331             }
332             for (int j = 0; j < k; j++) {
333                 heap_ids[j] = heap_ids_in[j];
334                 heap_dis[j] = heap_dis_in[j] * one_a + b;
335             }
336         }
337     }
338 };
339 
340 /** Simple top-N implementation using a reservoir.
341  *
342  * Results are stored when they are below the threshold until the capacity is
343  * reached. Then a partition sort is used to update the threshold. */
344 
namespace {

/// Read the CPU timestamp counter for micro-benchmark instrumentation.
///
/// Returns 0 unless MICRO_BENCHMARK is defined *and* the target is x86
/// (the `rdtsc` instruction is x86-only, so other architectures previously
/// failed to build with MICRO_BENCHMARK). With the default 0, the timing
/// counters kept by the handlers in this file stay at zero.
///
/// Internal linkage is kept on purpose: translation units may be compiled
/// with and without MICRO_BENCHMARK, and differing definitions of an
/// external inline function would violate the ODR. `inline` keeps
/// compilers from warning about an unused function in translation units
/// that include this header without using it.
inline uint64_t get_cy() {
#if defined(MICRO_BENCHMARK) && (defined(__x86_64__) || defined(__i386__))
    uint32_t high, low;
    asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high));
    return ((uint64_t)high << 32) | (low);
#else
    return 0;
#endif
}

} // anonymous namespace
358 
359 template <class C>
360 struct ReservoirTopN {
361     using T = typename C::T;
362     using TI = typename C::TI;
363 
364     T* vals;
365     TI* ids;
366 
367     size_t i;        // number of stored elements
368     size_t n;        // number of requested elements
369     size_t capacity; // size of storage
370     size_t cycles = 0;
371 
372     T threshold; // current threshold
373 
ReservoirTopNReservoirTopN374     ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
375             : vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
376         assert(n < capacity);
377         threshold = C::neutral();
378     }
379 
addReservoirTopN380     void add(T val, TI id) {
381         if (C::cmp(threshold, val)) {
382             if (i == capacity) {
383                 shrink_fuzzy();
384             }
385             vals[i] = val;
386             ids[i] = id;
387             i++;
388         }
389     }
390 
391     /// shrink number of stored elements to n
shrink_xxReservoirTopN392     void shrink_xx() {
393         uint64_t t0 = get_cy();
394         qselect(vals, ids, i, n);
395         i = n; // forget all elements above i = n
396         threshold = C::Crev::neutral();
397         for (size_t j = 0; j < n; j++) {
398             if (C::cmp(vals[j], threshold)) {
399                 threshold = vals[j];
400             }
401         }
402         cycles += get_cy() - t0;
403     }
404 
shrinkReservoirTopN405     void shrink() {
406         uint64_t t0 = get_cy();
407         threshold = partition<C>(vals, ids, i, n);
408         i = n;
409         cycles += get_cy() - t0;
410     }
411 
shrink_fuzzyReservoirTopN412     void shrink_fuzzy() {
413         uint64_t t0 = get_cy();
414         assert(i == capacity);
415         threshold = partition_fuzzy<C>(
416                 vals, ids, capacity, n, (capacity + n) / 2, &i);
417         cycles += get_cy() - t0;
418     }
419 };
420 
421 /** Handler built from several ReservoirTopN (one per query) */
422 template <class C, bool with_id_map = false>
423 struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
424     using T = typename C::T;
425     using TI = typename C::TI;
426 
427     size_t capacity; // rounded up to multiple of 16
428     std::vector<TI> all_ids;
429     AlignedTable<T> all_vals;
430 
431     std::vector<ReservoirTopN<C>> reservoirs;
432 
433     uint64_t times[4];
434 
ReservoirHandlerReservoirHandler435     ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in)
436             : SIMDResultHandler<C, with_id_map>(ntotal),
437               capacity((capacity_in + 15) & ~15),
438               all_ids(nq * capacity),
439               all_vals(nq * capacity) {
440         assert(capacity % 16 == 0);
441         for (size_t i = 0; i < nq; i++) {
442             reservoirs.emplace_back(
443                     n,
444                     capacity,
445                     all_vals.get() + i * capacity,
446                     all_ids.data() + i * capacity);
447         }
448         times[0] = times[1] = times[2] = times[3] = 0;
449     }
450 
handleReservoirHandler451     void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
452         uint64_t t0 = get_cy();
453         if (this->disable) {
454             return;
455         }
456         this->adjust_with_origin(q, d0, d1);
457 
458         ReservoirTopN<C>& res = reservoirs[q];
459         uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1);
460         uint64_t t1 = get_cy();
461         times[0] += t1 - t0;
462 
463         if (!lt_mask) {
464             return;
465         }
466         ALIGNED(32) uint16_t d32tab[32];
467         d0.store(d32tab);
468         d1.store(d32tab + 16);
469 
470         while (lt_mask) {
471             // find first non-zero
472             int j = __builtin_ctz(lt_mask);
473             lt_mask -= 1 << j;
474             T dis = d32tab[j];
475             res.add(dis, this->adjust_id(b, j));
476         }
477         times[1] += get_cy() - t1;
478     }
479 
480     void to_flat_arrays(
481             float* distances,
482             int64_t* labels,
483             const float* normalizers = nullptr) override {
484         using Cf = typename std::conditional<
485                 C::is_max,
486                 CMax<float, int64_t>,
487                 CMin<float, int64_t>>::type;
488 
489         uint64_t t0 = get_cy();
490         uint64_t t3 = 0;
491         std::vector<int> perm(reservoirs[0].n);
492         for (int q = 0; q < reservoirs.size(); q++) {
493             ReservoirTopN<C>& res = reservoirs[q];
494             size_t n = res.n;
495 
496             if (res.i > res.n) {
497                 res.shrink();
498             }
499             int64_t* heap_ids = labels + q * n;
500             float* heap_dis = distances + q * n;
501 
502             float one_a = 1.0, b = 0.0;
503             if (normalizers) {
504                 one_a = 1 / normalizers[2 * q];
505                 b = normalizers[2 * q + 1];
506             }
507             for (int i = 0; i < res.i; i++) {
508                 perm[i] = i;
509             }
510             // indirect sort of result arrays
511             std::sort(perm.begin(), perm.begin() + res.i, [&res](int i, int j) {
512                 return C::cmp(res.vals[j], res.vals[i]);
513             });
514             for (int i = 0; i < res.i; i++) {
515                 heap_dis[i] = res.vals[perm[i]] * one_a + b;
516                 heap_ids[i] = res.ids[perm[i]];
517             }
518 
519             // possibly add empty results
520             heap_heapify<Cf>(n - res.i, heap_dis + res.i, heap_ids + res.i);
521 
522             t3 += res.cycles;
523         }
524         times[2] += get_cy() - t0;
525         times[3] += t3;
526     }
527 };
528 
529 } // namespace simd_result_handlers
530 
531 } // namespace faiss
532