1 /**
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8 #pragma once
9
10 #include <algorithm>
11 #include <type_traits>
12 #include <vector>
13
14 #include <faiss/utils/Heap.h>
15 #include <faiss/utils/simdlib.h>
16
17 #include <faiss/impl/platform_macros.h>
18 #include <faiss/utils/AlignedTable.h>
19 #include <faiss/utils/partitioning.h>
20
21 /** This file contains callbacks for kernels that compute distances.
22 *
23 * The SIMDResultHandler object is intended to be templated and inlined.
24 * Methods:
25 * - handle(): called when 32 distances are computed and provided in two
26 * simd16uint16. (q, b) indicate which entry it is in the block.
27 * - set_block_origin(): set the sub-matrix that is being computed
28 */
29
30 namespace faiss {
31
32 namespace simd_result_handlers {
33
34 /** Dummy structure that just computes a checksum on results
35 * (to avoid the computation to be optimized away) */
36 struct DummyResultHandler {
37 size_t cs = 0;
38
handleDummyResultHandler39 void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
40 cs += q * 123 + b * 789 + d0.get_scalar_0() + d1.get_scalar_0();
41 }
42
set_block_originDummyResultHandler43 void set_block_origin(size_t, size_t) {}
44 };
45
46 /** memorize results in a nq-by-nb matrix.
47 *
48 * j0 is the current upper-left block of the matrix
49 */
50 struct StoreResultHandler {
51 uint16_t* data;
52 size_t ld; // total number of columns
53 size_t i0 = 0;
54 size_t j0 = 0;
55
StoreResultHandlerStoreResultHandler56 StoreResultHandler(uint16_t* data, size_t ld) : data(data), ld(ld) {}
57
handleStoreResultHandler58 void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
59 size_t ofs = (q + i0) * ld + j0 + b * 32;
60 d0.store(data + ofs);
61 d1.store(data + ofs + 16);
62 }
63
set_block_originStoreResultHandler64 void set_block_origin(size_t i0, size_t j0) {
65 this->i0 = i0;
66 this->j0 = j0;
67 }
68 };
69
70 /** stores results in fixed-size matrix. */
71 template <int NQ, int BB>
72 struct FixedStorageHandler {
73 simd16uint16 dis[NQ][BB];
74 int i0 = 0;
75
handleFixedStorageHandler76 void handle(int q, int b, simd16uint16 d0, simd16uint16 d1) {
77 dis[q + i0][2 * b] = d0;
78 dis[q + i0][2 * b + 1] = d1;
79 }
80
set_block_originFixedStorageHandler81 void set_block_origin(size_t i0, size_t j0) {
82 this->i0 = i0;
83 assert(j0 == 0);
84 }
85
86 template <class OtherResultHandler>
to_other_handlerFixedStorageHandler87 void to_other_handler(OtherResultHandler& other) const {
88 for (int q = 0; q < NQ; q++) {
89 for (int b = 0; b < BB; b += 2) {
90 other.handle(q, b / 2, dis[q][b], dis[q][b + 1]);
91 }
92 }
93 }
94 };
95
96 /** Record origin of current block */
97 template <class C, bool with_id_map>
98 struct SIMDResultHandler {
99 using TI = typename C::TI;
100
101 bool disable = false;
102
103 int64_t i0 = 0; // query origin
104 int64_t j0 = 0; // db origin
105 size_t ntotal; // ignore excess elements after ntotal
106
107 /// these fields are used mainly for the IVF variants (with_id_map=true)
108 const TI* id_map; // map offset in invlist to vector id
109 const int* q_map; // map q to global query
110 const uint16_t* dbias; // table of biases to add to each query
111
SIMDResultHandlerSIMDResultHandler112 explicit SIMDResultHandler(size_t ntotal)
113 : ntotal(ntotal), id_map(nullptr), q_map(nullptr), dbias(nullptr) {}
114
set_block_originSIMDResultHandler115 void set_block_origin(size_t i0, size_t j0) {
116 this->i0 = i0;
117 this->j0 = j0;
118 }
119
120 // adjust handler data for IVF.
adjust_with_originSIMDResultHandler121 void adjust_with_origin(size_t& q, simd16uint16& d0, simd16uint16& d1) {
122 q += i0;
123
124 if (dbias) {
125 simd16uint16 dbias16(dbias[q]);
126 d0 += dbias16;
127 d1 += dbias16;
128 }
129
130 if (with_id_map) { // FIXME test on q_map instead
131 q = q_map[q];
132 }
133 }
134
135 // compute and adjust idx
adjust_idSIMDResultHandler136 int64_t adjust_id(size_t b, size_t j) {
137 int64_t idx = j0 + 32 * b + j;
138 if (with_id_map) {
139 idx = id_map[idx];
140 }
141 return idx;
142 }
143
144 /// return binary mask of elements below thr in (d0, d1)
145 /// inverse_test returns elements above
get_lt_maskSIMDResultHandler146 uint32_t get_lt_mask(
147 uint16_t thr,
148 size_t b,
149 simd16uint16 d0,
150 simd16uint16 d1) {
151 simd16uint16 thr16(thr);
152 uint32_t lt_mask;
153
154 constexpr bool keep_min = C::is_max;
155 if (keep_min) {
156 lt_mask = ~cmp_ge32(d0, d1, thr16);
157 } else {
158 lt_mask = ~cmp_le32(d0, d1, thr16);
159 }
160
161 if (lt_mask == 0) {
162 return 0;
163 }
164 uint64_t idx = j0 + b * 32;
165 if (idx + 32 > ntotal) {
166 if (idx >= ntotal) {
167 return 0;
168 }
169 int nbit = (ntotal - idx);
170 lt_mask &= (uint32_t(1) << nbit) - 1;
171 }
172 return lt_mask;
173 }
174
175 virtual void to_flat_arrays(
176 float* distances,
177 int64_t* labels,
178 const float* normalizers = nullptr) = 0;
179
~SIMDResultHandlerSIMDResultHandler180 virtual ~SIMDResultHandler() {}
181 };
182
183 /** Special version for k=1 */
184 template <class C, bool with_id_map = false>
185 struct SingleResultHandler : SIMDResultHandler<C, with_id_map> {
186 using T = typename C::T;
187 using TI = typename C::TI;
188
189 struct Result {
190 T val;
191 TI id;
192 };
193 std::vector<Result> results;
194
SingleResultHandlerSingleResultHandler195 SingleResultHandler(size_t nq, size_t ntotal)
196 : SIMDResultHandler<C, with_id_map>(ntotal), results(nq) {
197 for (int i = 0; i < nq; i++) {
198 Result res = {C::neutral(), -1};
199 results[i] = res;
200 }
201 }
202
handleSingleResultHandler203 void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
204 if (this->disable) {
205 return;
206 }
207
208 this->adjust_with_origin(q, d0, d1);
209
210 Result& res = results[q];
211 uint32_t lt_mask = this->get_lt_mask(res.val, b, d0, d1);
212 if (!lt_mask) {
213 return;
214 }
215
216 ALIGNED(32) uint16_t d32tab[32];
217 d0.store(d32tab);
218 d1.store(d32tab + 16);
219
220 while (lt_mask) {
221 // find first non-zero
222 int j = __builtin_ctz(lt_mask);
223 lt_mask -= 1 << j;
224 T dis = d32tab[j];
225 if (C::cmp(res.val, dis)) {
226 res.val = dis;
227 res.id = this->adjust_id(b, j);
228 }
229 }
230 }
231
232 void to_flat_arrays(
233 float* distances,
234 int64_t* labels,
235 const float* normalizers = nullptr) override {
236 for (int q = 0; q < results.size(); q++) {
237 if (!normalizers) {
238 distances[q] = results[q].val;
239 } else {
240 float one_a = 1 / normalizers[2 * q];
241 float b = normalizers[2 * q + 1];
242 distances[q] = b + results[q].val * one_a;
243 }
244 labels[q] = results[q].id;
245 }
246 }
247 };
248
249 /** Structure that collects results in a min- or max-heap */
250 template <class C, bool with_id_map = false>
251 struct HeapHandler : SIMDResultHandler<C, with_id_map> {
252 using T = typename C::T;
253 using TI = typename C::TI;
254
255 int nq;
256 T* heap_dis_tab;
257 TI* heap_ids_tab;
258
259 int64_t k; // number of results to keep
260
HeapHandlerHeapHandler261 HeapHandler(
262 int nq,
263 T* heap_dis_tab,
264 TI* heap_ids_tab,
265 size_t k,
266 size_t ntotal)
267 : SIMDResultHandler<C, with_id_map>(ntotal),
268 nq(nq),
269 heap_dis_tab(heap_dis_tab),
270 heap_ids_tab(heap_ids_tab),
271 k(k) {
272 for (int q = 0; q < nq; q++) {
273 T* heap_dis_in = heap_dis_tab + q * k;
274 TI* heap_ids_in = heap_ids_tab + q * k;
275 heap_heapify<C>(k, heap_dis_in, heap_ids_in);
276 }
277 }
278
handleHeapHandler279 void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
280 if (this->disable) {
281 return;
282 }
283
284 this->adjust_with_origin(q, d0, d1);
285
286 T* heap_dis = heap_dis_tab + q * k;
287 TI* heap_ids = heap_ids_tab + q * k;
288
289 uint16_t cur_thresh =
290 heap_dis[0] < 65536 ? (uint16_t)(heap_dis[0]) : 0xffff;
291
292 // here we handle the reverse comparison case as well
293 uint32_t lt_mask = this->get_lt_mask(cur_thresh, b, d0, d1);
294
295 if (!lt_mask) {
296 return;
297 }
298
299 ALIGNED(32) uint16_t d32tab[32];
300 d0.store(d32tab);
301 d1.store(d32tab + 16);
302
303 while (lt_mask) {
304 // find first non-zero
305 int j = __builtin_ctz(lt_mask);
306 lt_mask -= 1 << j;
307 T dis = d32tab[j];
308 if (C::cmp(heap_dis[0], dis)) {
309 int64_t idx = this->adjust_id(b, j);
310 heap_pop<C>(k, heap_dis, heap_ids);
311 heap_push<C>(k, heap_dis, heap_ids, dis, idx);
312 }
313 }
314 }
315
316 void to_flat_arrays(
317 float* distances,
318 int64_t* labels,
319 const float* normalizers = nullptr) override {
320 for (int q = 0; q < nq; q++) {
321 T* heap_dis_in = heap_dis_tab + q * k;
322 TI* heap_ids_in = heap_ids_tab + q * k;
323 heap_reorder<C>(k, heap_dis_in, heap_ids_in);
324 int64_t* heap_ids = labels + q * k;
325 float* heap_dis = distances + q * k;
326
327 float one_a = 1.0, b = 0.0;
328 if (normalizers) {
329 one_a = 1 / normalizers[2 * q];
330 b = normalizers[2 * q + 1];
331 }
332 for (int j = 0; j < k; j++) {
333 heap_ids[j] = heap_ids_in[j];
334 heap_dis[j] = heap_dis_in[j] * one_a + b;
335 }
336 }
337 }
338 };
339
340 /** Simple top-N implementation using a reservoir.
341 *
342 * Results are stored when they are below the threshold until the capacity is
343 * reached. Then a partition sort is used to update the threshold. */
344
namespace {

/// Cycle counter used for the timing statistics. Reads the time-stamp
/// counter with rdtsc when MICRO_BENCHMARK is defined, and returns 0
/// otherwise.
uint64_t get_cy() {
#ifndef MICRO_BENCHMARK
    return 0;
#else
    uint32_t high, low;
    asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high));
    return (static_cast<uint64_t>(high) << 32) | low;
#endif
}

} // anonymous namespace
358
359 template <class C>
360 struct ReservoirTopN {
361 using T = typename C::T;
362 using TI = typename C::TI;
363
364 T* vals;
365 TI* ids;
366
367 size_t i; // number of stored elements
368 size_t n; // number of requested elements
369 size_t capacity; // size of storage
370 size_t cycles = 0;
371
372 T threshold; // current threshold
373
ReservoirTopNReservoirTopN374 ReservoirTopN(size_t n, size_t capacity, T* vals, TI* ids)
375 : vals(vals), ids(ids), i(0), n(n), capacity(capacity) {
376 assert(n < capacity);
377 threshold = C::neutral();
378 }
379
addReservoirTopN380 void add(T val, TI id) {
381 if (C::cmp(threshold, val)) {
382 if (i == capacity) {
383 shrink_fuzzy();
384 }
385 vals[i] = val;
386 ids[i] = id;
387 i++;
388 }
389 }
390
391 /// shrink number of stored elements to n
shrink_xxReservoirTopN392 void shrink_xx() {
393 uint64_t t0 = get_cy();
394 qselect(vals, ids, i, n);
395 i = n; // forget all elements above i = n
396 threshold = C::Crev::neutral();
397 for (size_t j = 0; j < n; j++) {
398 if (C::cmp(vals[j], threshold)) {
399 threshold = vals[j];
400 }
401 }
402 cycles += get_cy() - t0;
403 }
404
shrinkReservoirTopN405 void shrink() {
406 uint64_t t0 = get_cy();
407 threshold = partition<C>(vals, ids, i, n);
408 i = n;
409 cycles += get_cy() - t0;
410 }
411
shrink_fuzzyReservoirTopN412 void shrink_fuzzy() {
413 uint64_t t0 = get_cy();
414 assert(i == capacity);
415 threshold = partition_fuzzy<C>(
416 vals, ids, capacity, n, (capacity + n) / 2, &i);
417 cycles += get_cy() - t0;
418 }
419 };
420
421 /** Handler built from several ReservoirTopN (one per query) */
422 template <class C, bool with_id_map = false>
423 struct ReservoirHandler : SIMDResultHandler<C, with_id_map> {
424 using T = typename C::T;
425 using TI = typename C::TI;
426
427 size_t capacity; // rounded up to multiple of 16
428 std::vector<TI> all_ids;
429 AlignedTable<T> all_vals;
430
431 std::vector<ReservoirTopN<C>> reservoirs;
432
433 uint64_t times[4];
434
ReservoirHandlerReservoirHandler435 ReservoirHandler(size_t nq, size_t ntotal, size_t n, size_t capacity_in)
436 : SIMDResultHandler<C, with_id_map>(ntotal),
437 capacity((capacity_in + 15) & ~15),
438 all_ids(nq * capacity),
439 all_vals(nq * capacity) {
440 assert(capacity % 16 == 0);
441 for (size_t i = 0; i < nq; i++) {
442 reservoirs.emplace_back(
443 n,
444 capacity,
445 all_vals.get() + i * capacity,
446 all_ids.data() + i * capacity);
447 }
448 times[0] = times[1] = times[2] = times[3] = 0;
449 }
450
handleReservoirHandler451 void handle(size_t q, size_t b, simd16uint16 d0, simd16uint16 d1) {
452 uint64_t t0 = get_cy();
453 if (this->disable) {
454 return;
455 }
456 this->adjust_with_origin(q, d0, d1);
457
458 ReservoirTopN<C>& res = reservoirs[q];
459 uint32_t lt_mask = this->get_lt_mask(res.threshold, b, d0, d1);
460 uint64_t t1 = get_cy();
461 times[0] += t1 - t0;
462
463 if (!lt_mask) {
464 return;
465 }
466 ALIGNED(32) uint16_t d32tab[32];
467 d0.store(d32tab);
468 d1.store(d32tab + 16);
469
470 while (lt_mask) {
471 // find first non-zero
472 int j = __builtin_ctz(lt_mask);
473 lt_mask -= 1 << j;
474 T dis = d32tab[j];
475 res.add(dis, this->adjust_id(b, j));
476 }
477 times[1] += get_cy() - t1;
478 }
479
480 void to_flat_arrays(
481 float* distances,
482 int64_t* labels,
483 const float* normalizers = nullptr) override {
484 using Cf = typename std::conditional<
485 C::is_max,
486 CMax<float, int64_t>,
487 CMin<float, int64_t>>::type;
488
489 uint64_t t0 = get_cy();
490 uint64_t t3 = 0;
491 std::vector<int> perm(reservoirs[0].n);
492 for (int q = 0; q < reservoirs.size(); q++) {
493 ReservoirTopN<C>& res = reservoirs[q];
494 size_t n = res.n;
495
496 if (res.i > res.n) {
497 res.shrink();
498 }
499 int64_t* heap_ids = labels + q * n;
500 float* heap_dis = distances + q * n;
501
502 float one_a = 1.0, b = 0.0;
503 if (normalizers) {
504 one_a = 1 / normalizers[2 * q];
505 b = normalizers[2 * q + 1];
506 }
507 for (int i = 0; i < res.i; i++) {
508 perm[i] = i;
509 }
510 // indirect sort of result arrays
511 std::sort(perm.begin(), perm.begin() + res.i, [&res](int i, int j) {
512 return C::cmp(res.vals[j], res.vals[i]);
513 });
514 for (int i = 0; i < res.i; i++) {
515 heap_dis[i] = res.vals[perm[i]] * one_a + b;
516 heap_ids[i] = res.ids[perm[i]];
517 }
518
519 // possibly add empty results
520 heap_heapify<Cf>(n - res.i, heap_dis + res.i, heap_ids + res.i);
521
522 t3 += res.cycles;
523 }
524 times[2] += get_cy() - t0;
525 times[3] += t3;
526 }
527 };
528
529 } // namespace simd_result_handlers
530
531 } // namespace faiss
532