1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 // -*- c++ -*-
9 
10 #include <faiss/impl/lattice_Zn.h>
11 
12 #include <cassert>
13 #include <cmath>
14 #include <cstdlib>
15 #include <cstring>
16 
17 #include <algorithm>
18 #include <queue>
19 #include <unordered_map>
20 #include <unordered_set>
21 
22 #include <faiss/impl/platform_macros.h>
23 #include <faiss/utils/distances.h>
24 
25 namespace faiss {
26 
27 /********************************************
28  * small utility functions
29  ********************************************/
30 
31 namespace {
32 
sqr(float x)33 inline float sqr(float x) {
34     return x * x;
35 }
36 
/// Flat list of points: consecutive groups of `n` floats form one point.
using point_list_t = std::vector<float>;
38 
/// Table of binomial coefficients C(n, p) for 0 <= n, p < nmax, stored as a
/// row-major nmax x nmax Pascal's triangle (row n starts at tab[n * nmax]).
/// Entries past the uint64_t range wrap around (unsigned arithmetic).
struct Comb {
    std::vector<uint64_t> tab; // Pascal's triangle
    int nmax;

    explicit Comb(int nmax) : tab(nmax * nmax, 0), nmax(nmax) {
        tab[0] = 1;
        for (int row = 1; row < nmax; row++) {
            uint64_t* cur = &tab[row * nmax];
            const uint64_t* prev = cur - nmax;
            cur[0] = 1;
            // C(row, col) = C(row - 1, col) + C(row - 1, col - 1)
            for (int col = 1; col <= row; col++) {
                cur[col] = prev[col] + prev[col - 1];
            }
        }
    }

    /// C(n, p); zero when p > n. Requires n, p < nmax.
    uint64_t operator()(int n, int p) const {
        assert(n < nmax && p < nmax);
        return p > n ? 0 : tab[n * nmax + p];
    }
};

// Shared binomial table used by the combinatorial encoders below.
Comb comb{100};
64 
65 // compute combinations of n integer values <= v that sum up to total (squared)
sum_of_sq(float total,int v,int n,float add=0)66 point_list_t sum_of_sq(float total, int v, int n, float add = 0) {
67     if (total < 0) {
68         return point_list_t();
69     } else if (n == 1) {
70         while (sqr(v + add) > total)
71             v--;
72         if (sqr(v + add) == total) {
73             return point_list_t(1, v + add);
74         } else {
75             return point_list_t();
76         }
77     } else {
78         point_list_t res;
79         while (v >= 0) {
80             point_list_t sub_points =
81                     sum_of_sq(total - sqr(v + add), v, n - 1, add);
82             for (size_t i = 0; i < sub_points.size(); i += n - 1) {
83                 res.push_back(v + add);
84                 for (int j = 0; j < n - 1; j++) {
85                     res.push_back(sub_points[i + j]);
86                 }
87             }
88             v--;
89         }
90         return res;
91     }
92 }
93 
decode_comb_1(uint64_t * n,int k1,int r)94 int decode_comb_1(uint64_t* n, int k1, int r) {
95     while (comb(r, k1) > *n) {
96         r--;
97     }
98     *n -= comb(r, k1);
99     return r;
100 }
101 
102 // optimized version for < 64 bits
repeats_encode_64(const std::vector<Repeat> & repeats,int dim,const float * c)103 uint64_t repeats_encode_64(
104         const std::vector<Repeat>& repeats,
105         int dim,
106         const float* c) {
107     uint64_t coded = 0;
108     int nfree = dim;
109     uint64_t code = 0, shift = 1;
110     for (auto r = repeats.begin(); r != repeats.end(); ++r) {
111         int rank = 0, occ = 0;
112         uint64_t code_comb = 0;
113         uint64_t tosee = ~coded;
114         for (;;) {
115             // directly jump to next available slot.
116             int i = __builtin_ctzll(tosee);
117             tosee &= ~(uint64_t{1} << i);
118             if (c[i] == r->val) {
119                 code_comb += comb(rank, occ + 1);
120                 occ++;
121                 coded |= uint64_t{1} << i;
122                 if (occ == r->n)
123                     break;
124             }
125             rank++;
126         }
127         uint64_t max_comb = comb(nfree, r->n);
128         code += shift * code_comb;
129         shift *= max_comb;
130         nfree -= r->n;
131     }
132     return code;
133 }
134 
repeats_decode_64(const std::vector<Repeat> & repeats,int dim,uint64_t code,float * c)135 void repeats_decode_64(
136         const std::vector<Repeat>& repeats,
137         int dim,
138         uint64_t code,
139         float* c) {
140     uint64_t decoded = 0;
141     int nfree = dim;
142     for (auto r = repeats.begin(); r != repeats.end(); ++r) {
143         uint64_t max_comb = comb(nfree, r->n);
144         uint64_t code_comb = code % max_comb;
145         code /= max_comb;
146 
147         int occ = 0;
148         int rank = nfree;
149         int next_rank = decode_comb_1(&code_comb, r->n, rank);
150         uint64_t tosee = ((uint64_t{1} << dim) - 1) ^ decoded;
151         for (;;) {
152             int i = 63 - __builtin_clzll(tosee);
153             tosee &= ~(uint64_t{1} << i);
154             rank--;
155             if (rank == next_rank) {
156                 decoded |= uint64_t{1} << i;
157                 c[i] = r->val;
158                 occ++;
159                 if (occ == r->n)
160                     break;
161                 next_rank = decode_comb_1(&code_comb, r->n - occ, next_rank);
162             }
163         }
164         nfree -= r->n;
165     }
166 }
167 
168 } // anonymous namespace
169 
Repeats(int dim,const float * c)170 Repeats::Repeats(int dim, const float* c) : dim(dim) {
171     for (int i = 0; i < dim; i++) {
172         int j = 0;
173         for (;;) {
174             if (j == repeats.size()) {
175                 repeats.push_back(Repeat{c[i], 1});
176                 break;
177             }
178             if (repeats[j].val == c[i]) {
179                 repeats[j].n++;
180                 break;
181             }
182             j++;
183         }
184     }
185 }
186 
count() const187 uint64_t Repeats::count() const {
188     uint64_t accu = 1;
189     int remain = dim;
190     for (int i = 0; i < repeats.size(); i++) {
191         accu *= comb(remain, repeats[i].n);
192         remain -= repeats[i].n;
193     }
194     return accu;
195 }
196 
197 // version with a bool vector that works for > 64 dim
encode(const float * c) const198 uint64_t Repeats::encode(const float* c) const {
199     if (dim < 64) {
200         return repeats_encode_64(repeats, dim, c);
201     }
202     std::vector<bool> coded(dim, false);
203     int nfree = dim;
204     uint64_t code = 0, shift = 1;
205     for (auto r = repeats.begin(); r != repeats.end(); ++r) {
206         int rank = 0, occ = 0;
207         uint64_t code_comb = 0;
208         for (int i = 0; i < dim; i++) {
209             if (!coded[i]) {
210                 if (c[i] == r->val) {
211                     code_comb += comb(rank, occ + 1);
212                     occ++;
213                     coded[i] = true;
214                     if (occ == r->n)
215                         break;
216                 }
217                 rank++;
218             }
219         }
220         uint64_t max_comb = comb(nfree, r->n);
221         code += shift * code_comb;
222         shift *= max_comb;
223         nfree -= r->n;
224     }
225     return code;
226 }
227 
decode(uint64_t code,float * c) const228 void Repeats::decode(uint64_t code, float* c) const {
229     if (dim < 64) {
230         repeats_decode_64(repeats, dim, code, c);
231         return;
232     }
233 
234     std::vector<bool> decoded(dim, false);
235     int nfree = dim;
236     for (auto r = repeats.begin(); r != repeats.end(); ++r) {
237         uint64_t max_comb = comb(nfree, r->n);
238         uint64_t code_comb = code % max_comb;
239         code /= max_comb;
240 
241         int occ = 0;
242         int rank = nfree;
243         int next_rank = decode_comb_1(&code_comb, r->n, rank);
244         for (int i = dim - 1; i >= 0; i--) {
245             if (!decoded[i]) {
246                 rank--;
247                 if (rank == next_rank) {
248                     decoded[i] = true;
249                     c[i] = r->val;
250                     occ++;
251                     if (occ == r->n)
252                         break;
253                     next_rank =
254                             decode_comb_1(&code_comb, r->n - occ, next_rank);
255                 }
256             }
257         }
258         nfree -= r->n;
259     }
260 }
261 
262 /********************************************
263  * EnumeratedVectors functions
264  ********************************************/
265 
encode_multi(size_t n,const float * c,uint64_t * codes) const266 void EnumeratedVectors::encode_multi(size_t n, const float* c, uint64_t* codes)
267         const {
268 #pragma omp parallel if (n > 1000)
269     {
270 #pragma omp for
271         for (int i = 0; i < n; i++) {
272             codes[i] = encode(c + i * dim);
273         }
274     }
275 }
276 
decode_multi(size_t n,const uint64_t * codes,float * c) const277 void EnumeratedVectors::decode_multi(size_t n, const uint64_t* codes, float* c)
278         const {
279 #pragma omp parallel if (n > 1000)
280     {
281 #pragma omp for
282         for (int i = 0; i < n; i++) {
283             decode(codes[i], c + i * dim);
284         }
285     }
286 }
287 
find_nn(size_t nc,const uint64_t * codes,size_t nq,const float * xq,int64_t * labels,float * distances)288 void EnumeratedVectors::find_nn(
289         size_t nc,
290         const uint64_t* codes,
291         size_t nq,
292         const float* xq,
293         int64_t* labels,
294         float* distances) {
295     for (size_t i = 0; i < nq; i++) {
296         distances[i] = -1e20;
297         labels[i] = -1;
298     }
299 
300     std::vector<float> c(dim);
301     for (size_t i = 0; i < nc; i++) {
302         uint64_t code = codes[nc];
303         decode(code, c.data());
304         for (size_t j = 0; j < nq; j++) {
305             const float* x = xq + j * dim;
306             float dis = fvec_inner_product(x, c.data(), dim);
307             if (dis > distances[j]) {
308                 distances[j] = dis;
309                 labels[j] = i;
310             }
311         }
312     }
313 }
314 
315 /**********************************************************
316  * ZnSphereSearch
317  **********************************************************/
318 
ZnSphereSearch(int dim,int r2)319 ZnSphereSearch::ZnSphereSearch(int dim, int r2) : dimS(dim), r2(r2) {
320     voc = sum_of_sq(r2, int(ceil(sqrt(r2)) + 1), dim);
321     natom = voc.size() / dim;
322 }
323 
search(const float * x,float * c) const324 float ZnSphereSearch::search(const float* x, float* c) const {
325     std::vector<float> tmp(dimS * 2);
326     std::vector<int> tmp_int(dimS);
327     return search(x, c, tmp.data(), tmp_int.data());
328 }
329 
search(const float * x,float * c,float * tmp,int * tmp_int,int * ibest_out) const330 float ZnSphereSearch::search(
331         const float* x,
332         float* c,
333         float* tmp,   // size 2 *dim
334         int* tmp_int, // size dim
335         int* ibest_out) const {
336     int dim = dimS;
337     assert(natom > 0);
338     int* o = tmp_int;
339     float* xabs = tmp;
340     float* xperm = tmp + dim;
341 
342     // argsort
343     for (int i = 0; i < dim; i++) {
344         o[i] = i;
345         xabs[i] = fabsf(x[i]);
346     }
347     std::sort(o, o + dim, [xabs](int a, int b) { return xabs[a] > xabs[b]; });
348     for (int i = 0; i < dim; i++) {
349         xperm[i] = xabs[o[i]];
350     }
351     // find best
352     int ibest = -1;
353     float dpbest = -100;
354     for (int i = 0; i < natom; i++) {
355         float dp = fvec_inner_product(voc.data() + i * dim, xperm, dim);
356         if (dp > dpbest) {
357             dpbest = dp;
358             ibest = i;
359         }
360     }
361     // revert sort
362     const float* cin = voc.data() + ibest * dim;
363     for (int i = 0; i < dim; i++) {
364         c[o[i]] = copysignf(cin[i], x[o[i]]);
365     }
366     if (ibest_out) {
367         *ibest_out = ibest;
368     }
369     return dpbest;
370 }
371 
search_multi(int n,const float * x,float * c_out,float * dp_out)372 void ZnSphereSearch::search_multi(
373         int n,
374         const float* x,
375         float* c_out,
376         float* dp_out) {
377 #pragma omp parallel if (n > 1000)
378     {
379 #pragma omp for
380         for (int i = 0; i < n; i++) {
381             dp_out[i] = search(x + i * dimS, c_out + i * dimS);
382         }
383     }
384 }
385 
386 /**********************************************************
387  * ZnSphereCodec
388  **********************************************************/
389 
ZnSphereCodec(int dim,int r2)390 ZnSphereCodec::ZnSphereCodec(int dim, int r2)
391         : ZnSphereSearch(dim, r2), EnumeratedVectors(dim) {
392     nv = 0;
393     for (int i = 0; i < natom; i++) {
394         Repeats repeats(dim, &voc[i * dim]);
395         CodeSegment cs(repeats);
396         cs.c0 = nv;
397         Repeat& br = repeats.repeats.back();
398         cs.signbits = br.val == 0 ? dim - br.n : dim;
399         code_segments.push_back(cs);
400         nv += repeats.count() << cs.signbits;
401     }
402 
403     uint64_t nvx = nv;
404     code_size = 0;
405     while (nvx > 0) {
406         nvx >>= 8;
407         code_size++;
408     }
409 }
410 
search_and_encode(const float * x) const411 uint64_t ZnSphereCodec::search_and_encode(const float* x) const {
412     std::vector<float> tmp(dim * 2);
413     std::vector<int> tmp_int(dim);
414     int ano; // atom number
415     std::vector<float> c(dim);
416     search(x, c.data(), tmp.data(), tmp_int.data(), &ano);
417     uint64_t signs = 0;
418     std::vector<float> cabs(dim);
419     int nnz = 0;
420     for (int i = 0; i < dim; i++) {
421         cabs[i] = fabs(c[i]);
422         if (c[i] != 0) {
423             if (c[i] < 0) {
424                 signs |= uint64_t{1} << nnz;
425             }
426             nnz++;
427         }
428     }
429     const CodeSegment& cs = code_segments[ano];
430     assert(nnz == cs.signbits);
431     uint64_t code = cs.c0 + signs;
432     code += cs.encode(cabs.data()) << cs.signbits;
433     return code;
434 }
435 
encode(const float * x) const436 uint64_t ZnSphereCodec::encode(const float* x) const {
437     return search_and_encode(x);
438 }
439 
decode(uint64_t code,float * c) const440 void ZnSphereCodec::decode(uint64_t code, float* c) const {
441     int i0 = 0, i1 = natom;
442     while (i0 + 1 < i1) {
443         int imed = (i0 + i1) / 2;
444         if (code_segments[imed].c0 <= code)
445             i0 = imed;
446         else
447             i1 = imed;
448     }
449     const CodeSegment& cs = code_segments[i0];
450     code -= cs.c0;
451     uint64_t signs = code;
452     code >>= cs.signbits;
453     cs.decode(code, c);
454 
455     int nnz = 0;
456     for (int i = 0; i < dim; i++) {
457         if (c[i] != 0) {
458             if (signs & (1UL << nnz)) {
459                 c[i] = -c[i];
460             }
461             nnz++;
462         }
463     }
464 }
465 
466 /**************************************************************
467  * ZnSphereCodecRec
468  **************************************************************/
469 
get_nv(int ld,int r2a) const470 uint64_t ZnSphereCodecRec::get_nv(int ld, int r2a) const {
471     return all_nv[ld * (r2 + 1) + r2a];
472 }
473 
get_nv_cum(int ld,int r2t,int r2a) const474 uint64_t ZnSphereCodecRec::get_nv_cum(int ld, int r2t, int r2a) const {
475     return all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a];
476 }
477 
set_nv_cum(int ld,int r2t,int r2a,uint64_t cum)478 void ZnSphereCodecRec::set_nv_cum(int ld, int r2t, int r2a, uint64_t cum) {
479     all_nv_cum[(ld * (r2 + 1) + r2t) * (r2 + 1) + r2a] = cum;
480 }
481 
ZnSphereCodecRec(int dim,int r2)482 ZnSphereCodecRec::ZnSphereCodecRec(int dim, int r2)
483         : EnumeratedVectors(dim), r2(r2) {
484     log2_dim = 0;
485     while (dim > (1 << log2_dim)) {
486         log2_dim++;
487     }
488     assert(dim == (1 << log2_dim) || !"dimension must be a power of 2");
489 
490     all_nv.resize((log2_dim + 1) * (r2 + 1));
491     all_nv_cum.resize((log2_dim + 1) * (r2 + 1) * (r2 + 1));
492 
493     for (int r2a = 0; r2a <= r2; r2a++) {
494         int r = int(sqrt(r2a));
495         if (r * r == r2a) {
496             all_nv[r2a] = r == 0 ? 1 : 2;
497         } else {
498             all_nv[r2a] = 0;
499         }
500     }
501 
502     for (int ld = 1; ld <= log2_dim; ld++) {
503         for (int r2sub = 0; r2sub <= r2; r2sub++) {
504             uint64_t nv = 0;
505             for (int r2a = 0; r2a <= r2sub; r2a++) {
506                 int r2b = r2sub - r2a;
507                 set_nv_cum(ld, r2sub, r2a, nv);
508                 nv += get_nv(ld - 1, r2a) * get_nv(ld - 1, r2b);
509             }
510             all_nv[ld * (r2 + 1) + r2sub] = nv;
511         }
512     }
513     nv = get_nv(log2_dim, r2);
514 
515     uint64_t nvx = nv;
516     code_size = 0;
517     while (nvx > 0) {
518         nvx >>= 8;
519         code_size++;
520     }
521 
522     int cache_level = std::min(3, log2_dim - 1);
523     decode_cache_ld = 0;
524     assert(cache_level <= log2_dim);
525     decode_cache.resize((r2 + 1));
526 
527     for (int r2sub = 0; r2sub <= r2; r2sub++) {
528         int ld = cache_level;
529         uint64_t nvi = get_nv(ld, r2sub);
530         std::vector<float>& cache = decode_cache[r2sub];
531         int dimsub = (1 << cache_level);
532         cache.resize(nvi * dimsub);
533         std::vector<float> c(dim);
534         uint64_t code0 = get_nv_cum(cache_level + 1, r2, r2 - r2sub);
535         for (int i = 0; i < nvi; i++) {
536             decode(i + code0, c.data());
537             memcpy(&cache[i * dimsub],
538                    c.data() + dim - dimsub,
539                    dimsub * sizeof(*c.data()));
540         }
541     }
542     decode_cache_ld = cache_level;
543 }
544 
encode(const float * c) const545 uint64_t ZnSphereCodecRec::encode(const float* c) const {
546     return encode_centroid(c);
547 }
548 
encode_centroid(const float * c) const549 uint64_t ZnSphereCodecRec::encode_centroid(const float* c) const {
550     std::vector<uint64_t> codes(dim);
551     std::vector<int> norm2s(dim);
552     for (int i = 0; i < dim; i++) {
553         if (c[i] == 0) {
554             codes[i] = 0;
555             norm2s[i] = 0;
556         } else {
557             int r2i = int(c[i] * c[i]);
558             norm2s[i] = r2i;
559             codes[i] = c[i] >= 0 ? 0 : 1;
560         }
561     }
562     int dim2 = dim / 2;
563     for (int ld = 1; ld <= log2_dim; ld++) {
564         for (int i = 0; i < dim2; i++) {
565             int r2a = norm2s[2 * i];
566             int r2b = norm2s[2 * i + 1];
567 
568             uint64_t code_a = codes[2 * i];
569             uint64_t code_b = codes[2 * i + 1];
570 
571             codes[i] = get_nv_cum(ld, r2a + r2b, r2a) +
572                     code_a * get_nv(ld - 1, r2b) + code_b;
573             norm2s[i] = r2a + r2b;
574         }
575         dim2 /= 2;
576     }
577     return codes[0];
578 }
579 
decode(uint64_t code,float * c) const580 void ZnSphereCodecRec::decode(uint64_t code, float* c) const {
581     std::vector<uint64_t> codes(dim);
582     std::vector<int> norm2s(dim);
583     codes[0] = code;
584     norm2s[0] = r2;
585 
586     int dim2 = 1;
587     for (int ld = log2_dim; ld > decode_cache_ld; ld--) {
588         for (int i = dim2 - 1; i >= 0; i--) {
589             int r2sub = norm2s[i];
590             int i0 = 0, i1 = r2sub + 1;
591             uint64_t codei = codes[i];
592             const uint64_t* cum =
593                     &all_nv_cum[(ld * (r2 + 1) + r2sub) * (r2 + 1)];
594             while (i1 > i0 + 1) {
595                 int imed = (i0 + i1) / 2;
596                 if (cum[imed] <= codei)
597                     i0 = imed;
598                 else
599                     i1 = imed;
600             }
601             int r2a = i0, r2b = r2sub - i0;
602             codei -= cum[r2a];
603             norm2s[2 * i] = r2a;
604             norm2s[2 * i + 1] = r2b;
605 
606             uint64_t code_a = codei / get_nv(ld - 1, r2b);
607             uint64_t code_b = codei % get_nv(ld - 1, r2b);
608 
609             codes[2 * i] = code_a;
610             codes[2 * i + 1] = code_b;
611         }
612         dim2 *= 2;
613     }
614 
615     if (decode_cache_ld == 0) {
616         for (int i = 0; i < dim; i++) {
617             if (norm2s[i] == 0) {
618                 c[i] = 0;
619             } else {
620                 float r = sqrt(norm2s[i]);
621                 assert(r * r == norm2s[i]);
622                 c[i] = codes[i] == 0 ? r : -r;
623             }
624         }
625     } else {
626         int subdim = 1 << decode_cache_ld;
627         assert((dim2 * subdim) == dim);
628 
629         for (int i = 0; i < dim2; i++) {
630             const std::vector<float>& cache = decode_cache[norm2s[i]];
631             assert(codes[i] < cache.size());
632             memcpy(c + i * subdim,
633                    &cache[codes[i] * subdim],
634                    sizeof(*c) * subdim);
635         }
636     }
637 }
638 
// The recursive codec is used only when dim is a power of 2 (use_rec).
// Otherwise, instantiate an arbitrary, harmless znc_rec (dim 8, r2 14) that
// is never used for encoding or decoding.
ZnSphereCodecAlt::ZnSphereCodecAlt(int dim, int r2)
        : ZnSphereCodec(dim, r2),
          use_rec((dim & (dim - 1)) == 0),
          znc_rec(use_rec ? dim : 8, use_rec ? r2 : 14) {}
644 
encode(const float * x) const645 uint64_t ZnSphereCodecAlt::encode(const float* x) const {
646     if (!use_rec) {
647         // it's ok if the vector is not normalized
648         return ZnSphereCodec::encode(x);
649     } else {
650         // find nearest centroid
651         std::vector<float> centroid(dim);
652         search(x, centroid.data());
653         return znc_rec.encode(centroid.data());
654     }
655 }
656 
decode(uint64_t code,float * c) const657 void ZnSphereCodecAlt::decode(uint64_t code, float* c) const {
658     if (!use_rec) {
659         ZnSphereCodec::decode(code, c);
660     } else {
661         znc_rec.decode(code, c);
662     }
663 }
664 
665 } // namespace faiss
666