1 /**
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8 // -*- c++ -*-
9
10 #include <faiss/impl/lattice_Zn.h>
11
12 #include <cassert>
13 #include <cmath>
14 #include <cstdlib>
15 #include <cstring>
16
17 #include <algorithm>
18 #include <queue>
19 #include <unordered_map>
20 #include <unordered_set>
21
22 #include <faiss/impl/platform_macros.h>
23 #include <faiss/utils/distances.h>
24
25 namespace faiss {
26
27 /********************************************
28 * small utility functions
29 ********************************************/
30
31 namespace {
32
// Square of a float, used for exact comparisons of (integer) squared norms.
inline float sqr(float x) {
    const float squared = x * x;
    return squared;
}
36
// Flat list of points: consecutive groups of equal size, one group per point.
using point_list_t = std::vector<float>;
38
// Table of binomial coefficients C(n, p) for 0 <= p, n < nmax, filled from
// Pascal's triangle and stored row-major (row n, column p).
// Entries are uint64_t and silently wrap for large n (C(68, 34) > 2^64).
struct Comb {
    std::vector<uint64_t> tab; // Pascal's triangle
    int nmax;

    explicit Comb(int nmax) : nmax(nmax) {
        tab.assign(nmax * nmax, 0);
        tab[0] = 1;
        for (int row = 1; row < nmax; ++row) {
            uint64_t* cur = &tab[row * nmax];
            const uint64_t* prev = cur - nmax;
            cur[0] = 1;
            for (int col = 1; col <= row; ++col) {
                cur[col] = prev[col] + prev[col - 1];
            }
        }
    }

    // C(n, p); 0 when p > n. Both arguments must be below nmax.
    uint64_t operator()(int n, int p) const {
        assert(n < nmax && p < nmax);
        return p > n ? 0 : tab[n * nmax + p];
    }
};

// Shared table covering C(n, p) for n, p < 100.
Comb comb(100);
64
65 // compute combinations of n integer values <= v that sum up to total (squared)
sum_of_sq(float total,int v,int n,float add=0)66 point_list_t sum_of_sq(float total, int v, int n, float add = 0) {
67 if (total < 0) {
68 return point_list_t();
69 } else if (n == 1) {
70 while (sqr(v + add) > total)
71 v--;
72 if (sqr(v + add) == total) {
73 return point_list_t(1, v + add);
74 } else {
75 return point_list_t();
76 }
77 } else {
78 point_list_t res;
79 while (v >= 0) {
80 point_list_t sub_points =
81 sum_of_sq(total - sqr(v + add), v, n - 1, add);
82 for (size_t i = 0; i < sub_points.size(); i += n - 1) {
83 res.push_back(v + add);
84 for (int j = 0; j < n - 1; j++) {
85 res.push_back(sub_points[i + j]);
86 }
87 }
88 v--;
89 }
90 return res;
91 }
92 }
93
decode_comb_1(uint64_t * n,int k1,int r)94 int decode_comb_1(uint64_t* n, int k1, int r) {
95 while (comb(r, k1) > *n) {
96 r--;
97 }
98 *n -= comb(r, k1);
99 return r;
100 }
101
102 // optimized version for < 64 bits
repeats_encode_64(const std::vector<Repeat> & repeats,int dim,const float * c)103 uint64_t repeats_encode_64(
104 const std::vector<Repeat>& repeats,
105 int dim,
106 const float* c) {
107 uint64_t coded = 0;
108 int nfree = dim;
109 uint64_t code = 0, shift = 1;
110 for (auto r = repeats.begin(); r != repeats.end(); ++r) {
111 int rank = 0, occ = 0;
112 uint64_t code_comb = 0;
113 uint64_t tosee = ~coded;
114 for (;;) {
115 // directly jump to next available slot.
116 int i = __builtin_ctzll(tosee);
117 tosee &= ~(uint64_t{1} << i);
118 if (c[i] == r->val) {
119 code_comb += comb(rank, occ + 1);
120 occ++;
121 coded |= uint64_t{1} << i;
122 if (occ == r->n)
123 break;
124 }
125 rank++;
126 }
127 uint64_t max_comb = comb(nfree, r->n);
128 code += shift * code_comb;
129 shift *= max_comb;
130 nfree -= r->n;
131 }
132 return code;
133 }
134
repeats_decode_64(const std::vector<Repeat> & repeats,int dim,uint64_t code,float * c)135 void repeats_decode_64(
136 const std::vector<Repeat>& repeats,
137 int dim,
138 uint64_t code,
139 float* c) {
140 uint64_t decoded = 0;
141 int nfree = dim;
142 for (auto r = repeats.begin(); r != repeats.end(); ++r) {
143 uint64_t max_comb = comb(nfree, r->n);
144 uint64_t code_comb = code % max_comb;
145 code /= max_comb;
146
147 int occ = 0;
148 int rank = nfree;
149 int next_rank = decode_comb_1(&code_comb, r->n, rank);
150 uint64_t tosee = ((uint64_t{1} << dim) - 1) ^ decoded;
151 for (;;) {
152 int i = 63 - __builtin_clzll(tosee);
153 tosee &= ~(uint64_t{1} << i);
154 rank--;
155 if (rank == next_rank) {
156 decoded |= uint64_t{1} << i;
157 c[i] = r->val;
158 occ++;
159 if (occ == r->n)
160 break;
161 next_rank = decode_comb_1(&code_comb, r->n - occ, next_rank);
162 }
163 }
164 nfree -= r->n;
165 }
166 }
167
168 } // anonymous namespace
169
Repeats(int dim,const float * c)170 Repeats::Repeats(int dim, const float* c) : dim(dim) {
171 for (int i = 0; i < dim; i++) {
172 int j = 0;
173 for (;;) {
174 if (j == repeats.size()) {
175 repeats.push_back(Repeat{c[i], 1});
176 break;
177 }
178 if (repeats[j].val == c[i]) {
179 repeats[j].n++;
180 break;
181 }
182 j++;
183 }
184 }
185 }
186
count() const187 uint64_t Repeats::count() const {
188 uint64_t accu = 1;
189 int remain = dim;
190 for (int i = 0; i < repeats.size(); i++) {
191 accu *= comb(remain, repeats[i].n);
192 remain -= repeats[i].n;
193 }
194 return accu;
195 }
196
197 // version with a bool vector that works for > 64 dim
encode(const float * c) const198 uint64_t Repeats::encode(const float* c) const {
199 if (dim < 64) {
200 return repeats_encode_64(repeats, dim, c);
201 }
202 std::vector<bool> coded(dim, false);
203 int nfree = dim;
204 uint64_t code = 0, shift = 1;
205 for (auto r = repeats.begin(); r != repeats.end(); ++r) {
206 int rank = 0, occ = 0;
207 uint64_t code_comb = 0;
208 for (int i = 0; i < dim; i++) {
209 if (!coded[i]) {
210 if (c[i] == r->val) {
211 code_comb += comb(rank, occ + 1);
212 occ++;
213 coded[i] = true;
214 if (occ == r->n)
215 break;
216 }
217 rank++;
218 }
219 }
220 uint64_t max_comb = comb(nfree, r->n);
221 code += shift * code_comb;
222 shift *= max_comb;
223 nfree -= r->n;
224 }
225 return code;
226 }
227
decode(uint64_t code,float * c) const228 void Repeats::decode(uint64_t code, float* c) const {
229 if (dim < 64) {
230 repeats_decode_64(repeats, dim, code, c);
231 return;
232 }
233
234 std::vector<bool> decoded(dim, false);
235 int nfree = dim;
236 for (auto r = repeats.begin(); r != repeats.end(); ++r) {
237 uint64_t max_comb = comb(nfree, r->n);
238 uint64_t code_comb = code % max_comb;
239 code /= max_comb;
240
241 int occ = 0;
242 int rank = nfree;
243 int next_rank = decode_comb_1(&code_comb, r->n, rank);
244 for (int i = dim - 1; i >= 0; i--) {
245 if (!decoded[i]) {
246 rank--;
247 if (rank == next_rank) {
248 decoded[i] = true;
249 c[i] = r->val;
250 occ++;
251 if (occ == r->n)
252 break;
253 next_rank =
254 decode_comb_1(&code_comb, r->n - occ, next_rank);
255 }
256 }
257 }
258 nfree -= r->n;
259 }
260 }
261
262 /********************************************
263 * EnumeratedVectors functions
264 ********************************************/
265
encode_multi(size_t n,const float * c,uint64_t * codes) const266 void EnumeratedVectors::encode_multi(size_t n, const float* c, uint64_t* codes)
267 const {
268 #pragma omp parallel if (n > 1000)
269 {
270 #pragma omp for
271 for (int i = 0; i < n; i++) {
272 codes[i] = encode(c + i * dim);
273 }
274 }
275 }
276
decode_multi(size_t n,const uint64_t * codes,float * c) const277 void EnumeratedVectors::decode_multi(size_t n, const uint64_t* codes, float* c)
278 const {
279 #pragma omp parallel if (n > 1000)
280 {
281 #pragma omp for
282 for (int i = 0; i < n; i++) {
283 decode(codes[i], c + i * dim);
284 }
285 }
286 }
287
find_nn(size_t nc,const uint64_t * codes,size_t nq,const float * xq,int64_t * labels,float * distances)288 void EnumeratedVectors::find_nn(
289 size_t nc,
290 const uint64_t* codes,
291 size_t nq,
292 const float* xq,
293 int64_t* labels,
294 float* distances) {
295 for (size_t i = 0; i < nq; i++) {
296 distances[i] = -1e20;
297 labels[i] = -1;
298 }
299
300 std::vector<float> c(dim);
301 for (size_t i = 0; i < nc; i++) {
302 uint64_t code = codes[nc];
303 decode(code, c.data());
304 for (size_t j = 0; j < nq; j++) {
305 const float* x = xq + j * dim;
306 float dis = fvec_inner_product(x, c.data(), dim);
307 if (dis > distances[j]) {
308 distances[j] = dis;
309 labels[j] = i;
310 }
311 }
312 }
313 }
314
315 /**********************************************************
316 * ZnSphereSearch
317 **********************************************************/
318
ZnSphereSearch(int dim,int r2)319 ZnSphereSearch::ZnSphereSearch(int dim, int r2) : dimS(dim), r2(r2) {
320 voc = sum_of_sq(r2, int(ceil(sqrt(r2)) + 1), dim);
321 natom = voc.size() / dim;
322 }
323
search(const float * x,float * c) const324 float ZnSphereSearch::search(const float* x, float* c) const {
325 std::vector<float> tmp(dimS * 2);
326 std::vector<int> tmp_int(dimS);
327 return search(x, c, tmp.data(), tmp_int.data());
328 }
329
search(const float * x,float * c,float * tmp,int * tmp_int,int * ibest_out) const330 float ZnSphereSearch::search(
331 const float* x,
332 float* c,
333 float* tmp, // size 2 *dim
334 int* tmp_int, // size dim
335 int* ibest_out) const {
336 int dim = dimS;
337 assert(natom > 0);
338 int* o = tmp_int;
339 float* xabs = tmp;
340 float* xperm = tmp + dim;
341
342 // argsort
343 for (int i = 0; i < dim; i++) {
344 o[i] = i;
345 xabs[i] = fabsf(x[i]);
346 }
347 std::sort(o, o + dim, [xabs](int a, int b) { return xabs[a] > xabs[b]; });
348 for (int i = 0; i < dim; i++) {
349 xperm[i] = xabs[o[i]];
350 }
351 // find best
352 int ibest = -1;
353 float dpbest = -100;
354 for (int i = 0; i < natom; i++) {
355 float dp = fvec_inner_product(voc.data() + i * dim, xperm, dim);
356 if (dp > dpbest) {
357 dpbest = dp;
358 ibest = i;
359 }
360 }
361 // revert sort
362 const float* cin = voc.data() + ibest * dim;
363 for (int i = 0; i < dim; i++) {
364 c[o[i]] = copysignf(cin[i], x[o[i]]);
365 }
366 if (ibest_out) {
367 *ibest_out = ibest;
368 }
369 return dpbest;
370 }
371
search_multi(int n,const float * x,float * c_out,float * dp_out)372 void ZnSphereSearch::search_multi(
373 int n,
374 const float* x,
375 float* c_out,
376 float* dp_out) {
377 #pragma omp parallel if (n > 1000)
378 {
379 #pragma omp for
380 for (int i = 0; i < n; i++) {
381 dp_out[i] = search(x + i * dimS, c_out + i * dimS);
382 }
383 }
384 }
385
386 /**********************************************************
387 * ZnSphereCodec
388 **********************************************************/
389
ZnSphereCodec(int dim,int r2)390 ZnSphereCodec::ZnSphereCodec(int dim, int r2)
391 : ZnSphereSearch(dim, r2), EnumeratedVectors(dim) {
392 nv = 0;
393 for (int i = 0; i < natom; i++) {
394 Repeats repeats(dim, &voc[i * dim]);
395 CodeSegment cs(repeats);
396 cs.c0 = nv;
397 Repeat& br = repeats.repeats.back();
398 cs.signbits = br.val == 0 ? dim - br.n : dim;
399 code_segments.push_back(cs);
400 nv += repeats.count() << cs.signbits;
401 }
402
403 uint64_t nvx = nv;
404 code_size = 0;
405 while (nvx > 0) {
406 nvx >>= 8;
407 code_size++;
408 }
409 }
410
search_and_encode(const float * x) const411 uint64_t ZnSphereCodec::search_and_encode(const float* x) const {
412 std::vector<float> tmp(dim * 2);
413 std::vector<int> tmp_int(dim);
414 int ano; // atom number
415 std::vector<float> c(dim);
416 search(x, c.data(), tmp.data(), tmp_int.data(), &ano);
417 uint64_t signs = 0;
418 std::vector<float> cabs(dim);
419 int nnz = 0;
420 for (int i = 0; i < dim; i++) {
421 cabs[i] = fabs(c[i]);
422 if (c[i] != 0) {
423 if (c[i] < 0) {
424 signs |= uint64_t{1} << nnz;
425 }
426 nnz++;
427 }
428 }
429 const CodeSegment& cs = code_segments[ano];
430 assert(nnz == cs.signbits);
431 uint64_t code = cs.c0 + signs;
432 code += cs.encode(cabs.data()) << cs.signbits;
433 return code;
434 }
435
encode(const float * x) const436 uint64_t ZnSphereCodec::encode(const float* x) const {
437 return search_and_encode(x);
438 }
439
decode(uint64_t code,float * c) const440 void ZnSphereCodec::decode(uint64_t code, float* c) const {
441 int i0 = 0, i1 = natom;
442 while (i0 + 1 < i1) {
443 int imed = (i0 + i1) / 2;
444 if (code_segments[imed].c0 <= code)
445 i0 = imed;
446 else
447 i1 = imed;
448 }
449 const CodeSegment& cs = code_segments[i0];
450 code -= cs.c0;
451 uint64_t signs = code;
452 code >>= cs.signbits;
453 cs.decode(code, c);
454
455 int nnz = 0;
456 for (int i = 0; i < dim; i++) {
457 if (c[i] != 0) {
458 if (signs & (1UL << nnz)) {
459 c[i] = -c[i];
460 }
461 nnz++;
462 }
463 }
464 }
465
466 /**************************************************************
467 * ZnSphereCodecRec
468 **************************************************************/
469
get_nv(int ld,int r2a) const470 uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const {
471 return all_nv[ld * (r2 + 1) + r2a];
472 }
473
get_nv_cum(int ld,int r2t,int r2a) const474 uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const {
475 return all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a];
476 }
477
set_nv_cum(int ld,int r2t,int r2a,uint64_t cum)478 void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum) {
479 all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a] = cum;
480 }
481
ZnSphereCodecRec(int dim,int r2)482 ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2)
483 : EnumeratedVectors(dim), r2(r2) {
484 log2_dim = 0;
485 while (dim > (1 << log2_dim)) {
486 log2_dim++;
487 }
488 assert(dim == (1 << log2_dim) || !"dimension must be a power of 2");
489
490 all_nv.resize((log2_dim + 1) * (r2 + 1));
491 all_nv_cum.resize((log2_dim + 1) * (r2 + 1) * (r2 + 1));
492
493 for (int r2a = 0; r2a <= r2; r2a++) {
494 int r = int(sqrt(r2a));
495 if (r * r == r2a) {
496 all_nv[r2a] = r == 0 ? 1 : 2;
497 } else {
498 all_nv[r2a] = 0;
499 }
500 }
501
502 for (int ld = 1; ld <= log2_dim; ld++) {
503 for (int r2sub = 0; r2sub <= r2; r2sub++) {
504 uint64_t nv = 0;
505 for (int r2a = 0; r2a <= r2sub; r2a++) {
506 int r2b = r2sub - r2a;
507 set_nv_cum(ld, r2sub, r2a, nv);
508 nv += get_nv(ld - 1, r2a) * get_nv(ld - 1, r2b);
509 }
510 all_nv[ld * (r2 + 1) + r2sub] = nv;
511 }
512 }
513 nv = get_nv(log2_dim, r2);
514
515 uint64_t nvx = nv;
516 code_size = 0;
517 while (nvx > 0) {
518 nvx >>= 8;
519 code_size++;
520 }
521
522 int cache_level = std::min(3, log2_dim - 1);
523 decode_cache_ld = 0;
524 assert(cache_level <= log2_dim);
525 decode_cache.resize((r2 + 1));
526
527 for (int r2sub = 0; r2sub <= r2; r2sub++) {
528 int ld = cache_level;
529 uint64_t nvi = get_nv(ld, r2sub);
530 std::vector<float>& cache = decode_cache[r2sub];
531 int dimsub = (1 << cache_level);
532 cache.resize(nvi * dimsub);
533 std::vector<float> c(dim);
534 uint64_t code0 = get_nv_cum(cache_level + 1, r2, r2 - r2sub);
535 for (int i = 0; i < nvi; i++) {
536 decode(i + code0, c.data());
537 memcpy(&cache[i * dimsub],
538 c.data() + dim - dimsub,
539 dimsub * sizeof(*c.data()));
540 }
541 }
542 decode_cache_ld = cache_level;
543 }
544
encode(const float * c) const545 uint64_t ZnSphereCodecRec::encode(const float* c) const {
546 return encode_centroid(c);
547 }
548
encode_centroid(const float * c) const549 uint64_t ZnSphereCodecRec::encode_centroid(const float* c) const {
550 std::vector<uint64_t> codes(dim);
551 std::vector<int> norm2s(dim);
552 for (int i = 0; i < dim; i++) {
553 if (c[i] == 0) {
554 codes[i] = 0;
555 norm2s[i] = 0;
556 } else {
557 int r2i = int(c[i] * c[i]);
558 norm2s[i] = r2i;
559 codes[i] = c[i] >= 0 ? 0 : 1;
560 }
561 }
562 int dim2 = dim / 2;
563 for (int ld = 1; ld <= log2_dim; ld++) {
564 for (int i = 0; i < dim2; i++) {
565 int r2a = norm2s[2 * i];
566 int r2b = norm2s[2 * i + 1];
567
568 uint64_t code_a = codes[2 * i];
569 uint64_t code_b = codes[2 * i + 1];
570
571 codes[i] = get_nv_cum(ld, r2a + r2b, r2a) +
572 code_a * get_nv(ld - 1, r2b) + code_b;
573 norm2s[i] = r2a + r2b;
574 }
575 dim2 /= 2;
576 }
577 return codes[0];
578 }
579
decode(uint64_t code,float * c) const580 void ZnSphereCodecRec::decode(uint64_t code, float* c) const {
581 std::vector<uint64_t> codes(dim);
582 std::vector<int> norm2s(dim);
583 codes[0] = code;
584 norm2s[0] = r2;
585
586 int dim2 = 1;
587 for (int ld = log2_dim; ld > decode_cache_ld; ld--) {
588 for (int i = dim2 - 1; i >= 0; i--) {
589 int r2sub = norm2s[i];
590 int i0 = 0, i1 = r2sub + 1;
591 uint64_t codei = codes[i];
592 const uint64_t* cum =
593 &all_nv_cum[(ld * (r2 + 1) + r2sub) * (r2 + 1)];
594 while (i1 > i0 + 1) {
595 int imed = (i0 + i1) / 2;
596 if (cum[imed] <= codei)
597 i0 = imed;
598 else
599 i1 = imed;
600 }
601 int r2a = i0, r2b = r2sub - i0;
602 codei -= cum[r2a];
603 norm2s[2 * i] = r2a;
604 norm2s[2 * i + 1] = r2b;
605
606 uint64_t code_a = codei / get_nv(ld - 1, r2b);
607 uint64_t code_b = codei % get_nv(ld - 1, r2b);
608
609 codes[2 * i] = code_a;
610 codes[2 * i + 1] = code_b;
611 }
612 dim2 *= 2;
613 }
614
615 if (decode_cache_ld == 0) {
616 for (int i = 0; i < dim; i++) {
617 if (norm2s[i] == 0) {
618 c[i] = 0;
619 } else {
620 float r = sqrt(norm2s[i]);
621 assert(r * r == norm2s[i]);
622 c[i] = codes[i] == 0 ? r : -r;
623 }
624 }
625 } else {
626 int subdim = 1 << decode_cache_ld;
627 assert((dim2 * subdim) == dim);
628
629 for (int i = 0; i < dim2; i++) {
630 const std::vector<float>& cache = decode_cache[norm2s[i]];
631 assert(codes[i] < cache.size());
632 memcpy(c + i * subdim,
633 &cache[codes[i] * subdim],
634 sizeof(*c) * subdim);
635 }
636 }
637 }
638
639 // if not use_rec, instanciate an arbitrary harmless znc_rec
ZnSphereCodecAlt(int dim,int r2)640 ZnSphereCodecAlt::ZnSphereCodecAlt(int dim, int r2)
641 : ZnSphereCodec(dim, r2),
642 use_rec((dim & (dim - 1)) == 0),
643 znc_rec(use_rec ? dim : 8, use_rec ? r2 : 14) {}
644
encode(const float * x) const645 uint64_t ZnSphereCodecAlt::encode(const float* x) const {
646 if (!use_rec) {
647 // it's ok if the vector is not normalized
648 return ZnSphereCodec::encode(x);
649 } else {
650 // find nearest centroid
651 std::vector<float> centroid(dim);
652 search(x, centroid.data());
653 return znc_rec.encode(centroid.data());
654 }
655 }
656
decode(uint64_t code,float * c) const657 void ZnSphereCodecAlt::decode(uint64_t code, float* c) const {
658 if (!use_rec) {
659 ZnSphereCodec::decode(code, c);
660 } else {
661 znc_rec.decode(code, c);
662 }
663 }
664
665 } // namespace faiss
666