1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 // -*- c++ -*-
9 
/*
 * Implementation of Hamming-related functions (distances, and smallest-distance
 * selection with regular heap|radix and probabilistic heap|radix).
 *
 * IMPLEMENTATION NOTES
 * Bitvectors are generally assumed to be multiples of 64 bits.
 *
 * hamdis_t is used for distances because at this time
 * it is not clear how we will need to balance
 * - flexibility in vector size (unclear whether we need more than 2^16 or
 *   even 2^8 bitvectors)
 * - memory usage
 * - cache misses when dealing with large volumes of data (fewer bits is better)
 *
 * hamdis_t should ideally be compatible with one of the Torch Storage types
 * (Byte, Short, Long) and therefore should be signed for 2-byte and 4-byte
 * widths.
 */
26 
27 #include <faiss/utils/hamming.h>
28 
#include <math.h>
#include <stdio.h>
#include <string.h>
#include <algorithm>
#include <limits>
#include <memory>
#include <vector>
34 
35 #include <faiss/impl/AuxIndexStructures.h>
36 #include <faiss/impl/FaissAssert.h>
37 #include <faiss/utils/Heap.h>
38 #include <faiss/utils/utils.h>
39 
// Query block size constant (8192); not referenced in this part of the file.
static constexpr size_t BLOCKSIZE_QUERY = 8192;
41 
42 namespace faiss {
43 
44 size_t hamming_batch_size = 65536;
45 
46 const uint8_t hamdis_tab_ham_bytes[256] = {
47         0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4, 1, 2, 2, 3, 2, 3, 3, 4,
48         2, 3, 3, 4, 3, 4, 4, 5, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
49         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 1, 2, 2, 3, 2, 3, 3, 4,
50         2, 3, 3, 4, 3, 4, 4, 5, 2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
51         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6,
52         4, 5, 5, 6, 5, 6, 6, 7, 1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
53         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 2, 3, 3, 4, 3, 4, 4, 5,
54         3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
55         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6, 3, 4, 4, 5, 4, 5, 5, 6,
56         4, 5, 5, 6, 5, 6, 6, 7, 3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
57         4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8};
58 
59 /* Elementary Hamming distance computation: unoptimized  */
60 template <size_t nbits, typename T>
hamming(const uint8_t * bs1,const uint8_t * bs2)61 T hamming(const uint8_t* bs1, const uint8_t* bs2) {
62     const size_t nbytes = nbits / 8;
63     size_t i;
64     T h = 0;
65     for (i = 0; i < nbytes; i++)
66         h += (T)hamdis_tab_ham_bytes[bs1[i] ^ bs2[i]];
67     return h;
68 }
69 
70 /* Hamming distances for multiples of 64 bits */
71 template <size_t nbits>
hamming(const uint64_t * bs1,const uint64_t * bs2)72 hamdis_t hamming(const uint64_t* bs1, const uint64_t* bs2) {
73     const size_t nwords = nbits / 64;
74     size_t i;
75     hamdis_t h = 0;
76     for (i = 0; i < nwords; i++)
77         h += popcount64(bs1[i] ^ bs2[i]);
78     return h;
79 }
80 
81 /* specialized (optimized) functions */
82 template <>
hamming(const uint64_t * pa,const uint64_t * pb)83 hamdis_t hamming<64>(const uint64_t* pa, const uint64_t* pb) {
84     return popcount64(pa[0] ^ pb[0]);
85 }
86 
87 template <>
hamming(const uint64_t * pa,const uint64_t * pb)88 hamdis_t hamming<128>(const uint64_t* pa, const uint64_t* pb) {
89     return popcount64(pa[0] ^ pb[0]) + popcount64(pa[1] ^ pb[1]);
90 }
91 
92 template <>
hamming(const uint64_t * pa,const uint64_t * pb)93 hamdis_t hamming<256>(const uint64_t* pa, const uint64_t* pb) {
94     return popcount64(pa[0] ^ pb[0]) + popcount64(pa[1] ^ pb[1]) +
95             popcount64(pa[2] ^ pb[2]) + popcount64(pa[3] ^ pb[3]);
96 }
97 
98 /* Hamming distances for multiple of 64 bits */
hamming(const uint64_t * bs1,const uint64_t * bs2,size_t nwords)99 hamdis_t hamming(const uint64_t* bs1, const uint64_t* bs2, size_t nwords) {
100     size_t i;
101     hamdis_t h = 0;
102     for (i = 0; i < nwords; i++)
103         h += popcount64(bs1[i] ^ bs2[i]);
104     return h;
105 }
106 
107 template <size_t nbits>
hammings(const uint64_t * bs1,const uint64_t * bs2,size_t n1,size_t n2,hamdis_t * dis)108 void hammings(
109         const uint64_t* bs1,
110         const uint64_t* bs2,
111         size_t n1,
112         size_t n2,
113         hamdis_t* dis)
114 
115 {
116     size_t i, j;
117     const size_t nwords = nbits / 64;
118     for (i = 0; i < n1; i++) {
119         const uint64_t* __restrict bs1_ = bs1 + i * nwords;
120         hamdis_t* __restrict dis_ = dis + i * n2;
121         for (j = 0; j < n2; j++)
122             dis_[j] = hamming<nbits>(bs1_, bs2 + j * nwords);
123     }
124 }
125 
hammings(const uint64_t * bs1,const uint64_t * bs2,size_t n1,size_t n2,size_t nwords,hamdis_t * __restrict dis)126 void hammings(
127         const uint64_t* bs1,
128         const uint64_t* bs2,
129         size_t n1,
130         size_t n2,
131         size_t nwords,
132         hamdis_t* __restrict dis) {
133     size_t i, j;
134     n1 *= nwords;
135     n2 *= nwords;
136     for (i = 0; i < n1; i += nwords) {
137         const uint64_t* bs1_ = bs1 + i;
138         for (j = 0; j < n2; j += nwords)
139             dis[j] = hamming(bs1_, bs2 + j, nwords);
140     }
141 }
142 
143 /* Count number of matches given a max threshold */
144 template <size_t nbits>
hamming_count_thres(const uint64_t * bs1,const uint64_t * bs2,size_t n1,size_t n2,hamdis_t ht,size_t * nptr)145 void hamming_count_thres(
146         const uint64_t* bs1,
147         const uint64_t* bs2,
148         size_t n1,
149         size_t n2,
150         hamdis_t ht,
151         size_t* nptr) {
152     const size_t nwords = nbits / 64;
153     size_t i, j, posm = 0;
154     const uint64_t* bs2_ = bs2;
155 
156     for (i = 0; i < n1; i++) {
157         bs2 = bs2_;
158         for (j = 0; j < n2; j++) {
159             /* collect the match only if this satisfies the threshold */
160             if (hamming<nbits>(bs1, bs2) <= ht)
161                 posm++;
162             bs2 += nwords;
163         }
164         bs1 += nwords; /* next signature */
165     }
166     *nptr = posm;
167 }
168 
169 template <size_t nbits>
crosshamming_count_thres(const uint64_t * dbs,size_t n,int ht,size_t * nptr)170 void crosshamming_count_thres(
171         const uint64_t* dbs,
172         size_t n,
173         int ht,
174         size_t* nptr) {
175     const size_t nwords = nbits / 64;
176     size_t i, j, posm = 0;
177     const uint64_t* bs1 = dbs;
178     for (i = 0; i < n; i++) {
179         const uint64_t* bs2 = bs1 + 2;
180         for (j = i + 1; j < n; j++) {
181             /* collect the match only if this satisfies the threshold */
182             if (hamming<nbits>(bs1, bs2) <= ht)
183                 posm++;
184             bs2 += nwords;
185         }
186         bs1 += nwords;
187     }
188     *nptr = posm;
189 }
190 
191 template <size_t nbits>
match_hamming_thres(const uint64_t * bs1,const uint64_t * bs2,size_t n1,size_t n2,int ht,int64_t * idx,hamdis_t * hams)192 size_t match_hamming_thres(
193         const uint64_t* bs1,
194         const uint64_t* bs2,
195         size_t n1,
196         size_t n2,
197         int ht,
198         int64_t* idx,
199         hamdis_t* hams) {
200     const size_t nwords = nbits / 64;
201     size_t i, j, posm = 0;
202     hamdis_t h;
203     const uint64_t* bs2_ = bs2;
204     for (i = 0; i < n1; i++) {
205         bs2 = bs2_;
206         for (j = 0; j < n2; j++) {
207             /* Here perform the real work of computing the distance */
208             h = hamming<nbits>(bs1, bs2);
209 
210             /* collect the match only if this satisfies the threshold */
211             if (h <= ht) {
212                 /* Enough space to store another match ? */
213                 *idx = i;
214                 idx++;
215                 *idx = j;
216                 idx++;
217                 *hams = h;
218                 hams++;
219                 posm++;
220             }
221             bs2 += nwords; /* next signature */
222         }
223         bs1 += nwords;
224     }
225     return posm;
226 }
227 
228 /* Return closest neighbors w.r.t Hamming distance, using a heap. */
229 template <class HammingComputer>
hammings_knn_hc(int bytes_per_code,int_maxheap_array_t * ha,const uint8_t * bs1,const uint8_t * bs2,size_t n2,bool order=true,bool init_heap=true)230 static void hammings_knn_hc(
231         int bytes_per_code,
232         int_maxheap_array_t* ha,
233         const uint8_t* bs1,
234         const uint8_t* bs2,
235         size_t n2,
236         bool order = true,
237         bool init_heap = true) {
238     size_t k = ha->k;
239     if (init_heap)
240         ha->heapify();
241 
242     const size_t block_size = hamming_batch_size;
243     for (size_t j0 = 0; j0 < n2; j0 += block_size) {
244         const size_t j1 = std::min(j0 + block_size, n2);
245 #pragma omp parallel for
246         for (int64_t i = 0; i < ha->nh; i++) {
247             HammingComputer hc(bs1 + i * bytes_per_code, bytes_per_code);
248 
249             const uint8_t* bs2_ = bs2 + j0 * bytes_per_code;
250             hamdis_t dis;
251             hamdis_t* __restrict bh_val_ = ha->val + i * k;
252             int64_t* __restrict bh_ids_ = ha->ids + i * k;
253             size_t j;
254             for (j = j0; j < j1; j++, bs2_ += bytes_per_code) {
255                 dis = hc.hamming(bs2_);
256                 if (dis < bh_val_[0]) {
257                     faiss::maxheap_replace_top<hamdis_t>(
258                             k, bh_val_, bh_ids_, dis, j);
259                 }
260             }
261         }
262     }
263     if (order)
264         ha->reorder();
265 }
266 
267 /* Return closest neighbors w.r.t Hamming distance, using max count. */
268 template <class HammingComputer>
hammings_knn_mc(int bytes_per_code,const uint8_t * a,const uint8_t * b,size_t na,size_t nb,size_t k,int32_t * distances,int64_t * labels)269 static void hammings_knn_mc(
270         int bytes_per_code,
271         const uint8_t* a,
272         const uint8_t* b,
273         size_t na,
274         size_t nb,
275         size_t k,
276         int32_t* distances,
277         int64_t* labels) {
278     const int nBuckets = bytes_per_code * 8 + 1;
279     std::vector<int> all_counters(na * nBuckets, 0);
280     std::unique_ptr<int64_t[]> all_ids_per_dis(new int64_t[na * nBuckets * k]);
281 
282     std::vector<HCounterState<HammingComputer>> cs;
283     for (size_t i = 0; i < na; ++i) {
284         cs.push_back(HCounterState<HammingComputer>(
285                 all_counters.data() + i * nBuckets,
286                 all_ids_per_dis.get() + i * nBuckets * k,
287                 a + i * bytes_per_code,
288                 8 * bytes_per_code,
289                 k));
290     }
291 
292     const size_t block_size = hamming_batch_size;
293     for (size_t j0 = 0; j0 < nb; j0 += block_size) {
294         const size_t j1 = std::min(j0 + block_size, nb);
295 #pragma omp parallel for
296         for (int64_t i = 0; i < na; ++i) {
297             for (size_t j = j0; j < j1; ++j) {
298                 cs[i].update_counter(b + j * bytes_per_code, j);
299             }
300         }
301     }
302 
303     for (size_t i = 0; i < na; ++i) {
304         HCounterState<HammingComputer>& csi = cs[i];
305 
306         int nres = 0;
307         for (int b = 0; b < nBuckets && nres < k; b++) {
308             for (int l = 0; l < csi.counters[b] && nres < k; l++) {
309                 labels[i * k + nres] = csi.ids_per_dis[b * k + l];
310                 distances[i * k + nres] = b;
311                 nres++;
312             }
313         }
314         while (nres < k) {
315             labels[i * k + nres] = -1;
316             distances[i * k + nres] = std::numeric_limits<int32_t>::max();
317             ++nres;
318         }
319     }
320 }
321 
322 // works faster than the template version
hammings_knn_hc_1(int_maxheap_array_t * ha,const uint64_t * bs1,const uint64_t * bs2,size_t n2,bool order=true,bool init_heap=true)323 static void hammings_knn_hc_1(
324         int_maxheap_array_t* ha,
325         const uint64_t* bs1,
326         const uint64_t* bs2,
327         size_t n2,
328         bool order = true,
329         bool init_heap = true) {
330     const size_t nwords = 1;
331     size_t k = ha->k;
332 
333     if (init_heap) {
334         ha->heapify();
335     }
336 
337 #pragma omp parallel for
338     for (int64_t i = 0; i < ha->nh; i++) {
339         const uint64_t bs1_ = bs1[i];
340         const uint64_t* bs2_ = bs2;
341         hamdis_t dis;
342         hamdis_t* bh_val_ = ha->val + i * k;
343         hamdis_t bh_val_0 = bh_val_[0];
344         int64_t* bh_ids_ = ha->ids + i * k;
345         size_t j;
346         for (j = 0; j < n2; j++, bs2_ += nwords) {
347             dis = popcount64(bs1_ ^ *bs2_);
348             if (dis < bh_val_0) {
349                 faiss::maxheap_replace_top<hamdis_t>(
350                         k, bh_val_, bh_ids_, dis, j);
351                 bh_val_0 = bh_val_[0];
352             }
353         }
354     }
355     if (order) {
356         ha->reorder();
357     }
358 }
359 
/* Functions to map vectors to bits. They assume proper allocation was done
   beforehand, i.e. b must be able to receive as many bits as x produces. */

/*
 * Binarize the d-dimensional vector x into b: output bit j is 1 iff
 * x[j] >= 0 (so NaN maps to 0). Dimension 0 corresponds to the least
 * significant bit of b[0], or equivalently to the lsb of the first byte that
 * is stored. Exactly (d + 7) / 8 bytes are written; the unused high bits of
 * the last byte are 0.
 */
void fvec2bitvec(const float* x, uint8_t* b, size_t d) {
    // size_t indices: d is a size_t, and an int counter would mix signedness
    // in the comparison and overflow for very large d
    for (size_t i = 0; i < d; i += 8) {
        const size_t nj = std::min<size_t>(8, d - i);
        uint8_t w = 0;
        uint8_t mask = 1;
        for (size_t j = 0; j < nj; j++) {
            if (x[i + j] >= 0)
                w |= mask;
            mask <<= 1;
        }
        *b++ = w;
    }
}
381 
382 /* Same but for n vectors.
383    Ensure that the ouptut b is byte-aligned (pad with 0s). */
fvecs2bitvecs(const float * x,uint8_t * b,size_t d,size_t n)384 void fvecs2bitvecs(const float* x, uint8_t* b, size_t d, size_t n) {
385     const int64_t ncodes = ((d + 7) / 8);
386 #pragma omp parallel for if (n > 100000)
387     for (int64_t i = 0; i < n; i++)
388         fvec2bitvec(x + i * d, b + i * ncodes, d);
389 }
390 
bitvecs2fvecs(const uint8_t * b,float * x,size_t d,size_t n)391 void bitvecs2fvecs(const uint8_t* b, float* x, size_t d, size_t n) {
392     const int64_t ncodes = ((d + 7) / 8);
393 #pragma omp parallel for if (n > 100000)
394     for (int64_t i = 0; i < n; i++) {
395         binary_to_real(d, b + i * ncodes, x + i * d);
396     }
397 }
398 
/* Return b with the order of its 64 bits reversed (bit 0 <-> bit 63,
 * bit 1 <-> bit 62, ...). Not optimized: only used for printing. */
static uint64_t uint64_reverse_bits(uint64_t b) {
    uint64_t revb = 0;
    for (int bit = 0; bit < 64; ++bit) {
        if ((b >> bit) & 1) {
            revb |= uint64_t(1) << (63 - bit);
        }
    }
    return revb;
}
410 
411 /* print the bit vector */
bitvec_print(const uint8_t * b,size_t d)412 void bitvec_print(const uint8_t* b, size_t d) {
413     size_t i, j;
414     for (i = 0; i < d;) {
415         uint64_t brev = uint64_reverse_bits(*(uint64_t*)b);
416         for (j = 0; j < 64 && i < d; j++, i++) {
417             printf("%d", (int)(brev & 1));
418             brev >>= 1;
419         }
420         b += 8;
421         printf(" ");
422     }
423 }
424 
bitvec_shuffle(size_t n,size_t da,size_t db,const int * order,const uint8_t * a,uint8_t * b)425 void bitvec_shuffle(
426         size_t n,
427         size_t da,
428         size_t db,
429         const int* order,
430         const uint8_t* a,
431         uint8_t* b) {
432     for (size_t i = 0; i < db; i++) {
433         FAISS_THROW_IF_NOT(order[i] >= 0 && order[i] < da);
434     }
435     size_t lda = (da + 7) / 8;
436     size_t ldb = (db + 7) / 8;
437 
438 #pragma omp parallel for if (n > 10000)
439     for (int64_t i = 0; i < n; i++) {
440         const uint8_t* ai = a + i * lda;
441         uint8_t* bi = b + i * ldb;
442         memset(bi, 0, ldb);
443         for (size_t j = 0; j < db; j++) {
444             int o = order[j];
445             uint8_t the_bit = (ai[o >> 3] >> (o & 7)) & 1;
446             bi[j >> 3] |= the_bit << (j & 7);
447         }
448     }
449 }
450 
451 /*----------------------------------------*/
452 /* Hamming distance computation and k-nn  */
453 
// Reinterpret a byte-code pointer as a pointer to 64-bit words. The cast
// keeps const (every callee takes const uint64_t*) and the argument is
// parenthesized so that expressions such as C64(p + off) offset in bytes.
#define C64(x) ((const uint64_t*)(x))
455 
456 /* Compute a set of Hamming distances */
hammings(const uint8_t * a,const uint8_t * b,size_t na,size_t nb,size_t ncodes,hamdis_t * __restrict dis)457 void hammings(
458         const uint8_t* a,
459         const uint8_t* b,
460         size_t na,
461         size_t nb,
462         size_t ncodes,
463         hamdis_t* __restrict dis) {
464     FAISS_THROW_IF_NOT(ncodes % 8 == 0);
465     switch (ncodes) {
466         case 8:
467             faiss::hammings<64>(C64(a), C64(b), na, nb, dis);
468             return;
469         case 16:
470             faiss::hammings<128>(C64(a), C64(b), na, nb, dis);
471             return;
472         case 32:
473             faiss::hammings<256>(C64(a), C64(b), na, nb, dis);
474             return;
475         case 64:
476             faiss::hammings<512>(C64(a), C64(b), na, nb, dis);
477             return;
478         default:
479             faiss::hammings(C64(a), C64(b), na, nb, ncodes * 8, dis);
480             return;
481     }
482 }
483 
hammings_knn(int_maxheap_array_t * ha,const uint8_t * a,const uint8_t * b,size_t nb,size_t ncodes,int order)484 void hammings_knn(
485         int_maxheap_array_t* ha,
486         const uint8_t* a,
487         const uint8_t* b,
488         size_t nb,
489         size_t ncodes,
490         int order) {
491     hammings_knn_hc(ha, a, b, nb, ncodes, order);
492 }
493 
hammings_knn_hc(int_maxheap_array_t * ha,const uint8_t * a,const uint8_t * b,size_t nb,size_t ncodes,int order)494 void hammings_knn_hc(
495         int_maxheap_array_t* ha,
496         const uint8_t* a,
497         const uint8_t* b,
498         size_t nb,
499         size_t ncodes,
500         int order) {
501     switch (ncodes) {
502         case 4:
503             hammings_knn_hc<faiss::HammingComputer4>(
504                     4, ha, a, b, nb, order, true);
505             break;
506         case 8:
507             hammings_knn_hc_1(ha, C64(a), C64(b), nb, order, true);
508             // hammings_knn_hc<faiss::HammingComputer8>
509             //      (8, ha, a, b, nb, order, true);
510             break;
511         case 16:
512             hammings_knn_hc<faiss::HammingComputer16>(
513                     16, ha, a, b, nb, order, true);
514             break;
515         case 32:
516             hammings_knn_hc<faiss::HammingComputer32>(
517                     32, ha, a, b, nb, order, true);
518             break;
519         default:
520             hammings_knn_hc<faiss::HammingComputerDefault>(
521                     ncodes, ha, a, b, nb, order, true);
522             break;
523     }
524 }
525 
hammings_knn_mc(const uint8_t * a,const uint8_t * b,size_t na,size_t nb,size_t k,size_t ncodes,int32_t * distances,int64_t * labels)526 void hammings_knn_mc(
527         const uint8_t* a,
528         const uint8_t* b,
529         size_t na,
530         size_t nb,
531         size_t k,
532         size_t ncodes,
533         int32_t* distances,
534         int64_t* labels) {
535     switch (ncodes) {
536         case 4:
537             hammings_knn_mc<faiss::HammingComputer4>(
538                     4, a, b, na, nb, k, distances, labels);
539             break;
540         case 8:
541             // TODO(hoss): Write analog to hammings_knn_hc_1
542             // hammings_knn_hc_1 (ha, C64(a), C64(b), nb, order, true);
543             hammings_knn_mc<faiss::HammingComputer8>(
544                     8, a, b, na, nb, k, distances, labels);
545             break;
546         case 16:
547             hammings_knn_mc<faiss::HammingComputer16>(
548                     16, a, b, na, nb, k, distances, labels);
549             break;
550         case 32:
551             hammings_knn_mc<faiss::HammingComputer32>(
552                     32, a, b, na, nb, k, distances, labels);
553             break;
554         default:
555             hammings_knn_mc<faiss::HammingComputerDefault>(
556                     ncodes, a, b, na, nb, k, distances, labels);
557             break;
558     }
559 }
560 template <class HammingComputer>
hamming_range_search_template(const uint8_t * a,const uint8_t * b,size_t na,size_t nb,int radius,size_t code_size,RangeSearchResult * res)561 static void hamming_range_search_template(
562         const uint8_t* a,
563         const uint8_t* b,
564         size_t na,
565         size_t nb,
566         int radius,
567         size_t code_size,
568         RangeSearchResult* res) {
569 #pragma omp parallel
570     {
571         RangeSearchPartialResult pres(res);
572 
573 #pragma omp for
574         for (int64_t i = 0; i < na; i++) {
575             HammingComputer hc(a + i * code_size, code_size);
576             const uint8_t* yi = b;
577             RangeQueryResult& qres = pres.new_result(i);
578 
579             for (size_t j = 0; j < nb; j++) {
580                 int dis = hc.hamming(yi);
581                 if (dis < radius) {
582                     qres.add(dis, j);
583                 }
584                 yi += code_size;
585             }
586         }
587         pres.finalize();
588     }
589 }
590 
hamming_range_search(const uint8_t * a,const uint8_t * b,size_t na,size_t nb,int radius,size_t code_size,RangeSearchResult * result)591 void hamming_range_search(
592         const uint8_t* a,
593         const uint8_t* b,
594         size_t na,
595         size_t nb,
596         int radius,
597         size_t code_size,
598         RangeSearchResult* result) {
599 #define HC(name) \
600     hamming_range_search_template<name>(a, b, na, nb, radius, code_size, result)
601 
602     switch (code_size) {
603         case 4:
604             HC(HammingComputer4);
605             break;
606         case 8:
607             HC(HammingComputer8);
608             break;
609         case 16:
610             HC(HammingComputer16);
611             break;
612         case 32:
613             HC(HammingComputer32);
614             break;
615         default:
616             HC(HammingComputerDefault);
617             break;
618     }
619 #undef HC
620 }
621 
622 /* Count number of matches given a max threshold            */
hamming_count_thres(const uint8_t * bs1,const uint8_t * bs2,size_t n1,size_t n2,hamdis_t ht,size_t ncodes,size_t * nptr)623 void hamming_count_thres(
624         const uint8_t* bs1,
625         const uint8_t* bs2,
626         size_t n1,
627         size_t n2,
628         hamdis_t ht,
629         size_t ncodes,
630         size_t* nptr) {
631     switch (ncodes) {
632         case 8:
633             faiss::hamming_count_thres<64>(
634                     C64(bs1), C64(bs2), n1, n2, ht, nptr);
635             return;
636         case 16:
637             faiss::hamming_count_thres<128>(
638                     C64(bs1), C64(bs2), n1, n2, ht, nptr);
639             return;
640         case 32:
641             faiss::hamming_count_thres<256>(
642                     C64(bs1), C64(bs2), n1, n2, ht, nptr);
643             return;
644         case 64:
645             faiss::hamming_count_thres<512>(
646                     C64(bs1), C64(bs2), n1, n2, ht, nptr);
647             return;
648         default:
649             FAISS_THROW_FMT("not implemented for %zu bits", ncodes);
650     }
651 }
652 
653 /* Count number of cross-matches given a threshold */
crosshamming_count_thres(const uint8_t * dbs,size_t n,hamdis_t ht,size_t ncodes,size_t * nptr)654 void crosshamming_count_thres(
655         const uint8_t* dbs,
656         size_t n,
657         hamdis_t ht,
658         size_t ncodes,
659         size_t* nptr) {
660     switch (ncodes) {
661         case 8:
662             faiss::crosshamming_count_thres<64>(C64(dbs), n, ht, nptr);
663             return;
664         case 16:
665             faiss::crosshamming_count_thres<128>(C64(dbs), n, ht, nptr);
666             return;
667         case 32:
668             faiss::crosshamming_count_thres<256>(C64(dbs), n, ht, nptr);
669             return;
670         case 64:
671             faiss::crosshamming_count_thres<512>(C64(dbs), n, ht, nptr);
672             return;
673         default:
674             FAISS_THROW_FMT("not implemented for %zu bits", ncodes);
675     }
676 }
677 
678 /* Returns all matches given a threshold */
match_hamming_thres(const uint8_t * bs1,const uint8_t * bs2,size_t n1,size_t n2,hamdis_t ht,size_t ncodes,int64_t * idx,hamdis_t * dis)679 size_t match_hamming_thres(
680         const uint8_t* bs1,
681         const uint8_t* bs2,
682         size_t n1,
683         size_t n2,
684         hamdis_t ht,
685         size_t ncodes,
686         int64_t* idx,
687         hamdis_t* dis) {
688     switch (ncodes) {
689         case 8:
690             return faiss::match_hamming_thres<64>(
691                     C64(bs1), C64(bs2), n1, n2, ht, idx, dis);
692         case 16:
693             return faiss::match_hamming_thres<128>(
694                     C64(bs1), C64(bs2), n1, n2, ht, idx, dis);
695         case 32:
696             return faiss::match_hamming_thres<256>(
697                     C64(bs1), C64(bs2), n1, n2, ht, idx, dis);
698         case 64:
699             return faiss::match_hamming_thres<512>(
700                     C64(bs1), C64(bs2), n1, n2, ht, idx, dis);
701         default:
702             FAISS_THROW_FMT("not implemented for %zu bits", ncodes);
703             return 0;
704     }
705 }
706 
707 #undef C64
708 
709 /*************************************
710  * generalized Hamming distances
711  ************************************/
712 
713 template <class HammingComputer>
hamming_dis_inner_loop(const uint8_t * ca,const uint8_t * cb,size_t nb,size_t code_size,int k,hamdis_t * bh_val_,int64_t * bh_ids_)714 static void hamming_dis_inner_loop(
715         const uint8_t* ca,
716         const uint8_t* cb,
717         size_t nb,
718         size_t code_size,
719         int k,
720         hamdis_t* bh_val_,
721         int64_t* bh_ids_) {
722     HammingComputer hc(ca, code_size);
723 
724     for (size_t j = 0; j < nb; j++) {
725         int ndiff = hc.hamming(cb);
726         cb += code_size;
727         if (ndiff < bh_val_[0]) {
728             maxheap_replace_top<hamdis_t>(k, bh_val_, bh_ids_, ndiff, j);
729         }
730     }
731 }
732 
generalized_hammings_knn_hc(int_maxheap_array_t * ha,const uint8_t * a,const uint8_t * b,size_t nb,size_t code_size,int ordered)733 void generalized_hammings_knn_hc(
734         int_maxheap_array_t* ha,
735         const uint8_t* a,
736         const uint8_t* b,
737         size_t nb,
738         size_t code_size,
739         int ordered) {
740     int na = ha->nh;
741     int k = ha->k;
742 
743     if (ordered)
744         ha->heapify();
745 
746 #pragma omp parallel for
747     for (int i = 0; i < na; i++) {
748         const uint8_t* ca = a + i * code_size;
749         const uint8_t* cb = b;
750 
751         hamdis_t* bh_val_ = ha->val + i * k;
752         int64_t* bh_ids_ = ha->ids + i * k;
753 
754         switch (code_size) {
755             case 8:
756                 hamming_dis_inner_loop<GenHammingComputer8>(
757                         ca, cb, nb, 8, k, bh_val_, bh_ids_);
758                 break;
759             case 16:
760                 hamming_dis_inner_loop<GenHammingComputer16>(
761                         ca, cb, nb, 16, k, bh_val_, bh_ids_);
762                 break;
763             case 32:
764                 hamming_dis_inner_loop<GenHammingComputer32>(
765                         ca, cb, nb, 32, k, bh_val_, bh_ids_);
766                 break;
767             default:
768                 hamming_dis_inner_loop<GenHammingComputerM8>(
769                         ca, cb, nb, code_size, k, bh_val_, bh_ids_);
770                 break;
771         }
772     }
773 
774     if (ordered)
775         ha->reorder();
776 }
777 
778 } // namespace faiss
779