1 /**
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8 // -*- c++ -*-
9
10 /*
 * Implementation of Hamming related functions (distances, smallest distance
 * selection with regular heap|radix and probabilistic heap|radix).
13 *
14 * IMPLEMENTATION NOTES
15 * Bitvectors are generally assumed to be multiples of 64 bits.
16 *
17 * hamdis_t is used for distances because at this time
18 * it is not clear how we will need to balance
19 * - flexibility in vector size (unclear more than 2^16 or even 2^8 bitvectors)
20 * - memory usage
21 * - cache-misses when dealing with large volumes of data (lower bits is better)
22 *
 * The hamdis_t should optimally be compatible with one of the Torch Storage
 * (Byte,Short,Long) and therefore should be signed for 2-bytes and 4-bytes
25 */
26
27 #include <faiss/utils/hamming.h>
28
29 #include <math.h>
30 #include <stdio.h>
31 #include <algorithm>
32 #include <memory>
33 #include <vector>
34
35 #include <faiss/impl/AuxIndexStructures.h>
36 #include <faiss/impl/FaissAssert.h>
37 #include <faiss/utils/Heap.h>
38 #include <faiss/utils/utils.h>
39
// File-local constant. NOTE(review): no code visible in this file references
// it; presumably a query block size -- confirm whether it is still needed.
static const size_t BLOCKSIZE_QUERY = 8192;
41
42 namespace faiss {
43
// Number of database codes handled per block by the templated k-NN loops
// below (hammings_knn_hc and hammings_knn_mc read it as their block size).
size_t hamming_batch_size = 65536;
45
// Lookup table indexed by a byte value: entry v is the number of bits set in
// v (e.g. [0] = 0, [1] = 1, [255] = 8). Used by the byte-wise hamming()
// template to count differing bits of an XORed byte.
const uint8_t hamdis_tab_ham_bytes[256] = {
        0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4,
        2, 3, 3, 4, 3, 4, 4, 5, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
        2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4,
        2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
        2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6,
        4, 5, 5, 6, 5, 6, 6, 7, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
        2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5,
        3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
        2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6,
        4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
        4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8};
58
59 /* Elementary Hamming distance computation: unoptimized */
60 template <size_t nbits, typename T>
hamming(const uint8_t * bs1,const uint8_t * bs2)61 T hamming(const uint8_t* bs1, const uint8_t* bs2) {
62 const size_t nbytes = nbits / 8;
63 size_t i;
64 T h = 0;
65 for (i = 0; i < nbytes; i++)
66 h += (T)hamdis_tab_ham_bytes[bs1[i] ^ bs2[i]];
67 return h;
68 }
69
70 /* Hamming distances for multiples of 64 bits */
71 template <size_t nbits>
hamming(const uint64_t * bs1,const uint64_t * bs2)72 hamdis_t hamming(const uint64_t* bs1, const uint64_t* bs2) {
73 const size_t nwords = nbits / 64;
74 size_t i;
75 hamdis_t h = 0;
76 for (i = 0; i < nwords; i++)
77 h += popcount64(bs1[i] ^ bs2[i]);
78 return h;
79 }
80
81 /* specialized (optimized) functions */
82 template <>
hamming(const uint64_t * pa,const uint64_t * pb)83 hamdis_t hamming<64>(const uint64_t* pa, const uint64_t* pb) {
84 return popcount64(pa[0] ^ pb[0]);
85 }
86
87 template <>
hamming(const uint64_t * pa,const uint64_t * pb)88 hamdis_t hamming<128>(const uint64_t* pa, const uint64_t* pb) {
89 return popcount64(pa[0] ^ pb[0]) + popcount64(pa[1] ^ pb[1]);
90 }
91
92 template <>
hamming(const uint64_t * pa,const uint64_t * pb)93 hamdis_t hamming<256>(const uint64_t* pa, const uint64_t* pb) {
94 return popcount64(pa[0] ^ pb[0]) + popcount64(pa[1] ^ pb[1]) +
95 popcount64(pa[2] ^ pb[2]) + popcount64(pa[3] ^ pb[3]);
96 }
97
98 /* Hamming distances for multiple of 64 bits */
hamming(const uint64_t * bs1,const uint64_t * bs2,size_t nwords)99 hamdis_t hamming(const uint64_t* bs1, const uint64_t* bs2, size_t nwords) {
100 size_t i;
101 hamdis_t h = 0;
102 for (i = 0; i < nwords; i++)
103 h += popcount64(bs1[i] ^ bs2[i]);
104 return h;
105 }
106
107 template <size_t nbits>
hammings(const uint64_t * bs1,const uint64_t * bs2,size_t n1,size_t n2,hamdis_t * dis)108 void hammings(
109 const uint64_t* bs1,
110 const uint64_t* bs2,
111 size_t n1,
112 size_t n2,
113 hamdis_t* dis)
114
115 {
116 size_t i, j;
117 const size_t nwords = nbits / 64;
118 for (i = 0; i < n1; i++) {
119 const uint64_t* __restrict bs1_ = bs1 + i * nwords;
120 hamdis_t* __restrict dis_ = dis + i * n2;
121 for (j = 0; j < n2; j++)
122 dis_[j] = hamming<nbits>(bs1_, bs2 + j * nwords);
123 }
124 }
125
hammings(const uint64_t * bs1,const uint64_t * bs2,size_t n1,size_t n2,size_t nwords,hamdis_t * __restrict dis)126 void hammings(
127 const uint64_t* bs1,
128 const uint64_t* bs2,
129 size_t n1,
130 size_t n2,
131 size_t nwords,
132 hamdis_t* __restrict dis) {
133 size_t i, j;
134 n1 *= nwords;
135 n2 *= nwords;
136 for (i = 0; i < n1; i += nwords) {
137 const uint64_t* bs1_ = bs1 + i;
138 for (j = 0; j < n2; j += nwords)
139 dis[j] = hamming(bs1_, bs2 + j, nwords);
140 }
141 }
142
143 /* Count number of matches given a max threshold */
144 template <size_t nbits>
hamming_count_thres(const uint64_t * bs1,const uint64_t * bs2,size_t n1,size_t n2,hamdis_t ht,size_t * nptr)145 void hamming_count_thres(
146 const uint64_t* bs1,
147 const uint64_t* bs2,
148 size_t n1,
149 size_t n2,
150 hamdis_t ht,
151 size_t* nptr) {
152 const size_t nwords = nbits / 64;
153 size_t i, j, posm = 0;
154 const uint64_t* bs2_ = bs2;
155
156 for (i = 0; i < n1; i++) {
157 bs2 = bs2_;
158 for (j = 0; j < n2; j++) {
159 /* collect the match only if this satisfies the threshold */
160 if (hamming<nbits>(bs1, bs2) <= ht)
161 posm++;
162 bs2 += nwords;
163 }
164 bs1 += nwords; /* next signature */
165 }
166 *nptr = posm;
167 }
168
169 template <size_t nbits>
crosshamming_count_thres(const uint64_t * dbs,size_t n,int ht,size_t * nptr)170 void crosshamming_count_thres(
171 const uint64_t* dbs,
172 size_t n,
173 int ht,
174 size_t* nptr) {
175 const size_t nwords = nbits / 64;
176 size_t i, j, posm = 0;
177 const uint64_t* bs1 = dbs;
178 for (i = 0; i < n; i++) {
179 const uint64_t* bs2 = bs1 + 2;
180 for (j = i + 1; j < n; j++) {
181 /* collect the match only if this satisfies the threshold */
182 if (hamming<nbits>(bs1, bs2) <= ht)
183 posm++;
184 bs2 += nwords;
185 }
186 bs1 += nwords;
187 }
188 *nptr = posm;
189 }
190
191 template <size_t nbits>
match_hamming_thres(const uint64_t * bs1,const uint64_t * bs2,size_t n1,size_t n2,int ht,int64_t * idx,hamdis_t * hams)192 size_t match_hamming_thres(
193 const uint64_t* bs1,
194 const uint64_t* bs2,
195 size_t n1,
196 size_t n2,
197 int ht,
198 int64_t* idx,
199 hamdis_t* hams) {
200 const size_t nwords = nbits / 64;
201 size_t i, j, posm = 0;
202 hamdis_t h;
203 const uint64_t* bs2_ = bs2;
204 for (i = 0; i < n1; i++) {
205 bs2 = bs2_;
206 for (j = 0; j < n2; j++) {
207 /* Here perform the real work of computing the distance */
208 h = hamming<nbits>(bs1, bs2);
209
210 /* collect the match only if this satisfies the threshold */
211 if (h <= ht) {
212 /* Enough space to store another match ? */
213 *idx = i;
214 idx++;
215 *idx = j;
216 idx++;
217 *hams = h;
218 hams++;
219 posm++;
220 }
221 bs2 += nwords; /* next signature */
222 }
223 bs1 += nwords;
224 }
225 return posm;
226 }
227
228 /* Return closest neighbors w.r.t Hamming distance, using a heap. */
229 template <class HammingComputer>
hammings_knn_hc(int bytes_per_code,int_maxheap_array_t * ha,const uint8_t * bs1,const uint8_t * bs2,size_t n2,bool order=true,bool init_heap=true)230 static void hammings_knn_hc(
231 int bytes_per_code,
232 int_maxheap_array_t* ha,
233 const uint8_t* bs1,
234 const uint8_t* bs2,
235 size_t n2,
236 bool order = true,
237 bool init_heap = true) {
238 size_t k = ha->k;
239 if (init_heap)
240 ha->heapify();
241
242 const size_t block_size = hamming_batch_size;
243 for (size_t j0 = 0; j0 < n2; j0 += block_size) {
244 const size_t j1 = std::min(j0 + block_size, n2);
245 #pragma omp parallel for
246 for (int64_t i = 0; i < ha->nh; i++) {
247 HammingComputer hc(bs1 + i * bytes_per_code, bytes_per_code);
248
249 const uint8_t* bs2_ = bs2 + j0 * bytes_per_code;
250 hamdis_t dis;
251 hamdis_t* __restrict bh_val_ = ha->val + i * k;
252 int64_t* __restrict bh_ids_ = ha->ids + i * k;
253 size_t j;
254 for (j = j0; j < j1; j++, bs2_ += bytes_per_code) {
255 dis = hc.hamming(bs2_);
256 if (dis < bh_val_[0]) {
257 faiss::maxheap_replace_top<hamdis_t>(
258 k, bh_val_, bh_ids_, dis, j);
259 }
260 }
261 }
262 }
263 if (order)
264 ha->reorder();
265 }
266
/* Return closest neighbors w.r.t Hamming distance, using max count.
 *
 * For each of the na queries, one HCounterState is built over a slice of
 * nBuckets counters (one per possible distance 0 .. 8 * bytes_per_code) and
 * a slice of nBuckets * k id slots. Every database code is fed to
 * update_counter(); the results are then read back bucket by bucket, in
 * increasing distance order, until k results are collected.
 *
 * @param bytes_per_code  code size in bytes
 * @param a               na query codes
 * @param b               nb database codes
 * @param k               number of neighbors per query
 * @param distances       output, k entries per query
 * @param labels          output, k entries per query
 */
template <class HammingComputer>
static void hammings_knn_mc(
        int bytes_per_code,
        const uint8_t* a,
        const uint8_t* b,
        size_t na,
        size_t nb,
        size_t k,
        int32_t* distances,
        int64_t* labels) {
    // one bucket per distance value, 0 and 8 * bytes_per_code included
    const int nBuckets = bytes_per_code * 8 + 1;
    std::vector<int> all_counters(na * nBuckets, 0);
    std::unique_ptr<int64_t[]> all_ids_per_dis(new int64_t[na * nBuckets * k]);

    // per-query state, each pointing into its own slice of the two buffers
    std::vector<HCounterState<HammingComputer>> cs;
    for (size_t i = 0; i < na; ++i) {
        cs.push_back(HCounterState<HammingComputer>(
                all_counters.data() + i * nBuckets,
                all_ids_per_dis.get() + i * nBuckets * k,
                a + i * bytes_per_code,
                8 * bytes_per_code,
                k));
    }

    // scan the database in blocks of hamming_batch_size codes; queries are
    // distributed over threads, each thread touching only its own cs[i]
    const size_t block_size = hamming_batch_size;
    for (size_t j0 = 0; j0 < nb; j0 += block_size) {
        const size_t j1 = std::min(j0 + block_size, nb);
#pragma omp parallel for
        for (int64_t i = 0; i < na; ++i) {
            for (size_t j = j0; j < j1; ++j) {
                cs[i].update_counter(b + j * bytes_per_code, j);
            }
        }
    }

    // gather the results of each query
    for (size_t i = 0; i < na; ++i) {
        HCounterState<HammingComputer>& csi = cs[i];

        int nres = 0;
        // NB: the loop variable b shadows the parameter b; here it is the
        // bucket index, i.e. the distance value
        for (int b = 0; b < nBuckets && nres < k; b++) {
            for (int l = 0; l < csi.counters[b] && nres < k; l++) {
                labels[i * k + nres] = csi.ids_per_dis[b * k + l];
                distances[i * k + nres] = b;
                nres++;
            }
        }
        // fewer than k results: pad with label -1 and the largest int32
        while (nres < k) {
            labels[i * k + nres] = -1;
            distances[i * k + nres] = std::numeric_limits<int32_t>::max();
            ++nres;
        }
    }
}
321
322 // works faster than the template version
hammings_knn_hc_1(int_maxheap_array_t * ha,const uint64_t * bs1,const uint64_t * bs2,size_t n2,bool order=true,bool init_heap=true)323 static void hammings_knn_hc_1(
324 int_maxheap_array_t* ha,
325 const uint64_t* bs1,
326 const uint64_t* bs2,
327 size_t n2,
328 bool order = true,
329 bool init_heap = true) {
330 const size_t nwords = 1;
331 size_t k = ha->k;
332
333 if (init_heap) {
334 ha->heapify();
335 }
336
337 #pragma omp parallel for
338 for (int64_t i = 0; i < ha->nh; i++) {
339 const uint64_t bs1_ = bs1[i];
340 const uint64_t* bs2_ = bs2;
341 hamdis_t dis;
342 hamdis_t* bh_val_ = ha->val + i * k;
343 hamdis_t bh_val_0 = bh_val_[0];
344 int64_t* bh_ids_ = ha->ids + i * k;
345 size_t j;
346 for (j = 0; j < n2; j++, bs2_ += nwords) {
347 dis = popcount64(bs1_ ^ *bs2_);
348 if (dis < bh_val_0) {
349 faiss::maxheap_replace_top<hamdis_t>(
350 k, bh_val_, bh_ids_, dis, j);
351 bh_val_0 = bh_val_[0];
352 }
353 }
354 }
355 if (order) {
356 ha->reorder();
357 }
358 }
359
/* Functions to map vectors to bits. Assume proper allocation done beforehand,
   meaning that b should be able to receive as many bits as x may produce. */
362
/*
 * Binarize a float vector: output bit i is 1 iff x[i] >= 0.
 *
 * dimension 0 corresponds to the least significant bit of b[0], or
 * equivalently to the lsb of the first byte that is stored.
 *
 * @param x  input vector of d floats
 * @param b  output, (d + 7) / 8 bytes are written; the unused high bits of
 *           the last byte are left at 0
 * @param d  number of dimensions
 */
void fvec2bitvec(const float* x, uint8_t* b, size_t d) {
    // size_t counters: an int counter mixes signedness in the comparison
    // with d and overflows when d exceeds INT_MAX
    for (size_t i = 0; i < d; i += 8) {
        uint8_t w = 0;
        uint8_t mask = 1;
        // number of dimensions packed into this byte (< 8 for the last one)
        const size_t nj = d - i < 8 ? d - i : 8;
        for (size_t j = 0; j < nj; j++) {
            if (x[i + j] >= 0) {
                w |= mask;
            }
            mask <<= 1;
        }
        *b = w;
        b++;
    }
}
381
382 /* Same but for n vectors.
383 Ensure that the ouptut b is byte-aligned (pad with 0s). */
fvecs2bitvecs(const float * x,uint8_t * b,size_t d,size_t n)384 void fvecs2bitvecs(const float* x, uint8_t* b, size_t d, size_t n) {
385 const int64_t ncodes = ((d + 7) / 8);
386 #pragma omp parallel for if (n > 100000)
387 for (int64_t i = 0; i < n; i++)
388 fvec2bitvec(x + i * d, b + i * ncodes, d);
389 }
390
bitvecs2fvecs(const uint8_t * b,float * x,size_t d,size_t n)391 void bitvecs2fvecs(const uint8_t* b, float* x, size_t d, size_t n) {
392 const int64_t ncodes = ((d + 7) / 8);
393 #pragma omp parallel for if (n > 100000)
394 for (int64_t i = 0; i < n; i++) {
395 binary_to_real(d, b + i * ncodes, x + i * d);
396 }
397 }
398
/* Reverse the bit order of a 64-bit word: bit 0 becomes bit 63 and so on
 * (NOT an optimized function, only used for print purpose) */
static uint64_t uint64_reverse_bits(uint64_t b) {
    uint64_t mirrored = 0;
    for (int remaining = 64; remaining > 0; --remaining) {
        // push the current lowest bit of b into the bottom of the result
        mirrored = (mirrored << 1) | (b & 1);
        b >>= 1;
    }
    return mirrored;
}
410
411 /* print the bit vector */
bitvec_print(const uint8_t * b,size_t d)412 void bitvec_print(const uint8_t* b, size_t d) {
413 size_t i, j;
414 for (i = 0; i < d;) {
415 uint64_t brev = uint64_reverse_bits(*(uint64_t*)b);
416 for (j = 0; j < 64 && i < d; j++, i++) {
417 printf("%d", (int)(brev & 1));
418 brev >>= 1;
419 }
420 b += 8;
421 printf(" ");
422 }
423 }
424
bitvec_shuffle(size_t n,size_t da,size_t db,const int * order,const uint8_t * a,uint8_t * b)425 void bitvec_shuffle(
426 size_t n,
427 size_t da,
428 size_t db,
429 const int* order,
430 const uint8_t* a,
431 uint8_t* b) {
432 for (size_t i = 0; i < db; i++) {
433 FAISS_THROW_IF_NOT(order[i] >= 0 && order[i] < da);
434 }
435 size_t lda = (da + 7) / 8;
436 size_t ldb = (db + 7) / 8;
437
438 #pragma omp parallel for if (n > 10000)
439 for (int64_t i = 0; i < n; i++) {
440 const uint8_t* ai = a + i * lda;
441 uint8_t* bi = b + i * ldb;
442 memset(bi, 0, ldb);
443 for (size_t j = 0; j < db; j++) {
444 int o = order[j];
445 uint8_t the_bit = (ai[o >> 3] >> (o & 7)) & 1;
446 bi[j >> 3] |= the_bit << (j & 7);
447 }
448 }
449 }
450
451 /*----------------------------------------*/
452 /* Hamming distance computation and k-nn */
453
/* Reinterpret a pointer as a pointer to 64-bit words. This is a C-style cast,
 * so it also drops const-ness; the macro argument is not parenthesized.
 * Undefined again further down with #undef C64. */
#define C64(x) ((uint64_t*)x)
455
456 /* Compute a set of Hamming distances */
hammings(const uint8_t * a,const uint8_t * b,size_t na,size_t nb,size_t ncodes,hamdis_t * __restrict dis)457 void hammings(
458 const uint8_t* a,
459 const uint8_t* b,
460 size_t na,
461 size_t nb,
462 size_t ncodes,
463 hamdis_t* __restrict dis) {
464 FAISS_THROW_IF_NOT(ncodes % 8 == 0);
465 switch (ncodes) {
466 case 8:
467 faiss::hammings<64>(C64(a), C64(b), na, nb, dis);
468 return;
469 case 16:
470 faiss::hammings<128>(C64(a), C64(b), na, nb, dis);
471 return;
472 case 32:
473 faiss::hammings<256>(C64(a), C64(b), na, nb, dis);
474 return;
475 case 64:
476 faiss::hammings<512>(C64(a), C64(b), na, nb, dis);
477 return;
478 default:
479 faiss::hammings(C64(a), C64(b), na, nb, ncodes * 8, dis);
480 return;
481 }
482 }
483
/* k-NN search entry point: forwards all its arguments unchanged to the
 * heap-based implementation hammings_knn_hc. */
void hammings_knn(
        int_maxheap_array_t* ha,
        const uint8_t* a,
        const uint8_t* b,
        size_t nb,
        size_t ncodes,
        int order) {
    hammings_knn_hc(ha, a, b, nb, ncodes, order);
}
493
hammings_knn_hc(int_maxheap_array_t * ha,const uint8_t * a,const uint8_t * b,size_t nb,size_t ncodes,int order)494 void hammings_knn_hc(
495 int_maxheap_array_t* ha,
496 const uint8_t* a,
497 const uint8_t* b,
498 size_t nb,
499 size_t ncodes,
500 int order) {
501 switch (ncodes) {
502 case 4:
503 hammings_knn_hc<faiss::HammingComputer4>(
504 4, ha, a, b, nb, order, true);
505 break;
506 case 8:
507 hammings_knn_hc_1(ha, C64(a), C64(b), nb, order, true);
508 // hammings_knn_hc<faiss::HammingComputer8>
509 // (8, ha, a, b, nb, order, true);
510 break;
511 case 16:
512 hammings_knn_hc<faiss::HammingComputer16>(
513 16, ha, a, b, nb, order, true);
514 break;
515 case 32:
516 hammings_knn_hc<faiss::HammingComputer32>(
517 32, ha, a, b, nb, order, true);
518 break;
519 default:
520 hammings_knn_hc<faiss::HammingComputerDefault>(
521 ncodes, ha, a, b, nb, order, true);
522 break;
523 }
524 }
525
hammings_knn_mc(const uint8_t * a,const uint8_t * b,size_t na,size_t nb,size_t k,size_t ncodes,int32_t * distances,int64_t * labels)526 void hammings_knn_mc(
527 const uint8_t* a,
528 const uint8_t* b,
529 size_t na,
530 size_t nb,
531 size_t k,
532 size_t ncodes,
533 int32_t* distances,
534 int64_t* labels) {
535 switch (ncodes) {
536 case 4:
537 hammings_knn_mc<faiss::HammingComputer4>(
538 4, a, b, na, nb, k, distances, labels);
539 break;
540 case 8:
541 // TODO(hoss): Write analog to hammings_knn_hc_1
542 // hammings_knn_hc_1 (ha, C64(a), C64(b), nb, order, true);
543 hammings_knn_mc<faiss::HammingComputer8>(
544 8, a, b, na, nb, k, distances, labels);
545 break;
546 case 16:
547 hammings_knn_mc<faiss::HammingComputer16>(
548 16, a, b, na, nb, k, distances, labels);
549 break;
550 case 32:
551 hammings_knn_mc<faiss::HammingComputer32>(
552 32, a, b, na, nb, k, distances, labels);
553 break;
554 default:
555 hammings_knn_mc<faiss::HammingComputerDefault>(
556 ncodes, a, b, na, nb, k, distances, labels);
557 break;
558 }
559 }
560 template <class HammingComputer>
hamming_range_search_template(const uint8_t * a,const uint8_t * b,size_t na,size_t nb,int radius,size_t code_size,RangeSearchResult * res)561 static void hamming_range_search_template(
562 const uint8_t* a,
563 const uint8_t* b,
564 size_t na,
565 size_t nb,
566 int radius,
567 size_t code_size,
568 RangeSearchResult* res) {
569 #pragma omp parallel
570 {
571 RangeSearchPartialResult pres(res);
572
573 #pragma omp for
574 for (int64_t i = 0; i < na; i++) {
575 HammingComputer hc(a + i * code_size, code_size);
576 const uint8_t* yi = b;
577 RangeQueryResult& qres = pres.new_result(i);
578
579 for (size_t j = 0; j < nb; j++) {
580 int dis = hc.hamming(yi);
581 if (dis < radius) {
582 qres.add(dis, j);
583 }
584 yi += code_size;
585 }
586 }
587 pres.finalize();
588 }
589 }
590
hamming_range_search(const uint8_t * a,const uint8_t * b,size_t na,size_t nb,int radius,size_t code_size,RangeSearchResult * result)591 void hamming_range_search(
592 const uint8_t* a,
593 const uint8_t* b,
594 size_t na,
595 size_t nb,
596 int radius,
597 size_t code_size,
598 RangeSearchResult* result) {
599 #define HC(name) \
600 hamming_range_search_template<name>(a, b, na, nb, radius, code_size, result)
601
602 switch (code_size) {
603 case 4:
604 HC(HammingComputer4);
605 break;
606 case 8:
607 HC(HammingComputer8);
608 break;
609 case 16:
610 HC(HammingComputer16);
611 break;
612 case 32:
613 HC(HammingComputer32);
614 break;
615 default:
616 HC(HammingComputerDefault);
617 break;
618 }
619 #undef HC
620 }
621
622 /* Count number of matches given a max threshold */
hamming_count_thres(const uint8_t * bs1,const uint8_t * bs2,size_t n1,size_t n2,hamdis_t ht,size_t ncodes,size_t * nptr)623 void hamming_count_thres(
624 const uint8_t* bs1,
625 const uint8_t* bs2,
626 size_t n1,
627 size_t n2,
628 hamdis_t ht,
629 size_t ncodes,
630 size_t* nptr) {
631 switch (ncodes) {
632 case 8:
633 faiss::hamming_count_thres<64>(
634 C64(bs1), C64(bs2), n1, n2, ht, nptr);
635 return;
636 case 16:
637 faiss::hamming_count_thres<128>(
638 C64(bs1), C64(bs2), n1, n2, ht, nptr);
639 return;
640 case 32:
641 faiss::hamming_count_thres<256>(
642 C64(bs1), C64(bs2), n1, n2, ht, nptr);
643 return;
644 case 64:
645 faiss::hamming_count_thres<512>(
646 C64(bs1), C64(bs2), n1, n2, ht, nptr);
647 return;
648 default:
649 FAISS_THROW_FMT("not implemented for %zu bits", ncodes);
650 }
651 }
652
653 /* Count number of cross-matches given a threshold */
crosshamming_count_thres(const uint8_t * dbs,size_t n,hamdis_t ht,size_t ncodes,size_t * nptr)654 void crosshamming_count_thres(
655 const uint8_t* dbs,
656 size_t n,
657 hamdis_t ht,
658 size_t ncodes,
659 size_t* nptr) {
660 switch (ncodes) {
661 case 8:
662 faiss::crosshamming_count_thres<64>(C64(dbs), n, ht, nptr);
663 return;
664 case 16:
665 faiss::crosshamming_count_thres<128>(C64(dbs), n, ht, nptr);
666 return;
667 case 32:
668 faiss::crosshamming_count_thres<256>(C64(dbs), n, ht, nptr);
669 return;
670 case 64:
671 faiss::crosshamming_count_thres<512>(C64(dbs), n, ht, nptr);
672 return;
673 default:
674 FAISS_THROW_FMT("not implemented for %zu bits", ncodes);
675 }
676 }
677
678 /* Returns all matches given a threshold */
match_hamming_thres(const uint8_t * bs1,const uint8_t * bs2,size_t n1,size_t n2,hamdis_t ht,size_t ncodes,int64_t * idx,hamdis_t * dis)679 size_t match_hamming_thres(
680 const uint8_t* bs1,
681 const uint8_t* bs2,
682 size_t n1,
683 size_t n2,
684 hamdis_t ht,
685 size_t ncodes,
686 int64_t* idx,
687 hamdis_t* dis) {
688 switch (ncodes) {
689 case 8:
690 return faiss::match_hamming_thres<64>(
691 C64(bs1), C64(bs2), n1, n2, ht, idx, dis);
692 case 16:
693 return faiss::match_hamming_thres<128>(
694 C64(bs1), C64(bs2), n1, n2, ht, idx, dis);
695 case 32:
696 return faiss::match_hamming_thres<256>(
697 C64(bs1), C64(bs2), n1, n2, ht, idx, dis);
698 case 64:
699 return faiss::match_hamming_thres<512>(
700 C64(bs1), C64(bs2), n1, n2, ht, idx, dis);
701 default:
702 FAISS_THROW_FMT("not implemented for %zu bits", ncodes);
703 return 0;
704 }
705 }
706
707 #undef C64
708
709 /*************************************
710 * generalized Hamming distances
711 ************************************/
712
713 template <class HammingComputer>
hamming_dis_inner_loop(const uint8_t * ca,const uint8_t * cb,size_t nb,size_t code_size,int k,hamdis_t * bh_val_,int64_t * bh_ids_)714 static void hamming_dis_inner_loop(
715 const uint8_t* ca,
716 const uint8_t* cb,
717 size_t nb,
718 size_t code_size,
719 int k,
720 hamdis_t* bh_val_,
721 int64_t* bh_ids_) {
722 HammingComputer hc(ca, code_size);
723
724 for (size_t j = 0; j < nb; j++) {
725 int ndiff = hc.hamming(cb);
726 cb += code_size;
727 if (ndiff < bh_val_[0]) {
728 maxheap_replace_top<hamdis_t>(k, bh_val_, bh_ids_, ndiff, j);
729 }
730 }
731 }
732
generalized_hammings_knn_hc(int_maxheap_array_t * ha,const uint8_t * a,const uint8_t * b,size_t nb,size_t code_size,int ordered)733 void generalized_hammings_knn_hc(
734 int_maxheap_array_t* ha,
735 const uint8_t* a,
736 const uint8_t* b,
737 size_t nb,
738 size_t code_size,
739 int ordered) {
740 int na = ha->nh;
741 int k = ha->k;
742
743 if (ordered)
744 ha->heapify();
745
746 #pragma omp parallel for
747 for (int i = 0; i < na; i++) {
748 const uint8_t* ca = a + i * code_size;
749 const uint8_t* cb = b;
750
751 hamdis_t* bh_val_ = ha->val + i * k;
752 int64_t* bh_ids_ = ha->ids + i * k;
753
754 switch (code_size) {
755 case 8:
756 hamming_dis_inner_loop<GenHammingComputer8>(
757 ca, cb, nb, 8, k, bh_val_, bh_ids_);
758 break;
759 case 16:
760 hamming_dis_inner_loop<GenHammingComputer16>(
761 ca, cb, nb, 16, k, bh_val_, bh_ids_);
762 break;
763 case 32:
764 hamming_dis_inner_loop<GenHammingComputer32>(
765 ca, cb, nb, 32, k, bh_val_, bh_ids_);
766 break;
767 default:
768 hamming_dis_inner_loop<GenHammingComputerM8>(
769 ca, cb, nb, code_size, k, bh_val_, bh_ids_);
770 break;
771 }
772 }
773
774 if (ordered)
775 ha->reorder();
776 }
777
778 } // namespace faiss
779