1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #include <faiss/utils/partitioning.h>
9 
10 #include <cassert>
11 #include <cmath>
12 
13 #include <faiss/impl/FaissAssert.h>
14 #include <faiss/utils/AlignedTable.h>
15 #include <faiss/utils/ordered_key_value.h>
16 #include <faiss/utils/simdlib.h>
17 
18 #include <faiss/impl/platform_macros.h>
19 
20 namespace faiss {
21 
22 /******************************************************************
23  * Internal routines
24  ******************************************************************/
25 
26 namespace partitioning {
27 
/// Returns the median of three values, using only operator> for comparisons.
template <typename T>
T median3(T a, T b, T c) {
    // order the first two values so that lo <= hi
    const T lo = (a > b) ? b : a;
    const T hi = (a > b) ? a : b;
    if (c > hi) {
        return hi; // c is the largest, so hi is in the middle
    }
    // c is not above hi: the middle value is the larger of lo and c
    return (c > lo) ? c : lo;
}
41 
/** Picks a candidate threshold strictly between thresh_inf and thresh_sup.
 *
 * Visits vals in the scattered order (i * big_prime) % n and collects the
 * first 3 values v with C::cmp(v, thresh_inf) && C::cmp(thresh_sup, v)
 * (thresh_inf < v < thresh_sup for CMax).
 *
 * @param vals        array of n values
 * @param n           number of values. This is a size_t so that the array
 *                    length passed by callers is not narrowed to int.
 * @param thresh_inf  exclusive bound on the "inf" side
 * @param thresh_sup  exclusive bound on the "sup" side
 * @return the median of the 3 collected values; the first collected value if
 *         only 1 or 2 were found; thresh_inf if none was found (callers use
 *         this as the "nothing left in the interval" signal).
 */
template <class C>
typename C::T sample_threshold_median3(
        const typename C::T* vals,
        size_t n,
        typename C::T thresh_inf,
        typename C::T thresh_sup) {
    using T = typename C::T;
    // Stride of the scattered visit. Since it is prime, i -> (i * p) % n is a
    // permutation of 0..n-1 unless n is a multiple of it.
    const size_t big_prime = 6700417;
    T val3[3];
    int vi = 0;

    for (size_t i = 0; i < n; i++) {
        T v = vals[(i * big_prime) % n];
        // thresh_inf < v < thresh_sup (for CMax)
        if (C::cmp(v, thresh_inf) && C::cmp(thresh_sup, v)) {
            val3[vi++] = v;
            if (vi == 3) {
                break;
            }
        }
    }

    if (vi == 3) {
        return median3(val3[0], val3[1], val3[2]);
    } else if (vi != 0) {
        return val3[0];
    } else {
        return thresh_inf;
        //   FAISS_THROW_MSG("too few values to compute a median");
    }
}
73 
/** Counts, among the n values of vals, those strictly on the kept side of
 * thresh (C::cmp(thresh, v)) and those equal to thresh.
 *
 * @param n_lt  output: number of values v with C::cmp(thresh, v)
 * @param n_eq  output: number of other values equal to thresh
 */
template <class C>
void count_lt_and_eq(
        const typename C::T* vals,
        size_t n,
        typename C::T thresh,
        size_t& n_lt,
        size_t& n_eq) {
    n_lt = 0;
    n_eq = 0;

    const typename C::T* const end = vals + n;
    for (const typename C::T* p = vals; p != end; ++p) {
        const typename C::T v = *p;
        if (C::cmp(thresh, v)) {
            ++n_lt;
        } else if (v == thresh) {
            ++n_eq;
        }
    }
}
92 
/** Compacts (vals, ids) in place, preserving order. Keeps every value v with
 * C::cmp(thresh, v), plus the first n_eq values equal to thresh.
 *
 * @return the number of entries written at the front of vals / ids. Entries
 *         from that position on are not meaningful.
 */
template <class C>
size_t compress_array(
        typename C::T* vals,
        typename C::TI* ids,
        size_t n,
        typename C::T thresh,
        size_t n_eq) {
    size_t out = 0;
    for (size_t in = 0; in < n; in++) {
        const typename C::T v = vals[in];
        bool keep = C::cmp(thresh, v);
        if (!keep && n_eq > 0 && v == thresh) {
            // still some budget for values equal to the threshold
            keep = true;
            n_eq--;
        }
        if (keep) {
            vals[out] = v;
            ids[out] = ids[in];
            out++;
        }
    }
    // the caller must have requested exactly as many equal values as exist
    assert(n_eq == 0);
    return out;
}
116 
117 #define IFV if (false)
118 
119 template <class C>
partition_fuzzy_median3(typename C::T * vals,typename C::TI * ids,size_t n,size_t q_min,size_t q_max,size_t * q_out)120 typename C::T partition_fuzzy_median3(
121         typename C::T* vals,
122         typename C::TI* ids,
123         size_t n,
124         size_t q_min,
125         size_t q_max,
126         size_t* q_out) {
127     if (q_min == 0) {
128         if (q_out) {
129             *q_out = C::Crev::neutral();
130         }
131         return 0;
132     }
133     if (q_max >= n) {
134         if (q_out) {
135             *q_out = q_max;
136         }
137         return C::neutral();
138     }
139 
140     using T = typename C::T;
141 
142     // here we use bissection with a median of 3 to find the threshold and
143     // compress the arrays afterwards. So it's a n*log(n) algoirithm rather than
144     // qselect's O(n) but it avoids shuffling around the array.
145 
146     FAISS_THROW_IF_NOT(n >= 3);
147 
148     T thresh_inf = C::Crev::neutral();
149     T thresh_sup = C::neutral();
150     T thresh = median3(vals[0], vals[n / 2], vals[n - 1]);
151 
152     size_t n_eq = 0, n_lt = 0;
153     size_t q = 0;
154 
155     for (int it = 0; it < 200; it++) {
156         count_lt_and_eq<C>(vals, n, thresh, n_lt, n_eq);
157 
158         IFV printf(
159                 "   thresh=%g [%g %g] n_lt=%ld n_eq=%ld, q=%ld:%ld/%ld\n",
160                 float(thresh),
161                 float(thresh_inf),
162                 float(thresh_sup),
163                 long(n_lt),
164                 long(n_eq),
165                 long(q_min),
166                 long(q_max),
167                 long(n));
168 
169         if (n_lt <= q_min) {
170             if (n_lt + n_eq >= q_min) {
171                 q = q_min;
172                 break;
173             } else {
174                 thresh_inf = thresh;
175             }
176         } else if (n_lt <= q_max) {
177             q = n_lt;
178             break;
179         } else {
180             thresh_sup = thresh;
181         }
182 
183         // FIXME avoid a second pass over the array to sample the threshold
184         IFV printf(
185                 "     sample thresh in [%g %g]\n",
186                 float(thresh_inf),
187                 float(thresh_sup));
188         T new_thresh =
189                 sample_threshold_median3<C>(vals, n, thresh_inf, thresh_sup);
190         if (new_thresh == thresh_inf) {
191             // then there is nothing between thresh_inf and thresh_sup
192             break;
193         }
194         thresh = new_thresh;
195     }
196 
197     int64_t n_eq_1 = q - n_lt;
198 
199     IFV printf("shrink: thresh=%g n_eq_1=%ld\n", float(thresh), long(n_eq_1));
200 
201     if (n_eq_1 < 0) { // happens when > q elements are at lower bound
202         q = q_min;
203         thresh = C::Crev::nextafter(thresh);
204         n_eq_1 = q;
205     } else {
206         assert(n_eq_1 <= n_eq);
207     }
208 
209     int wp = compress_array<C>(vals, ids, n, thresh, n_eq_1);
210 
211     assert(wp == q);
212     if (q_out) {
213         *q_out = q;
214     }
215 
216     return thresh;
217 }
218 
219 } // namespace partitioning
220 
221 /******************************************************************
222  * SIMD routines when vals is an aligned array of uint16_t
223  ******************************************************************/
224 
225 namespace simd_partitioning {
226 
find_minimax(const uint16_t * vals,size_t n,uint16_t & smin,uint16_t & smax)227 void find_minimax(
228         const uint16_t* vals,
229         size_t n,
230         uint16_t& smin,
231         uint16_t& smax) {
232     simd16uint16 vmin(0xffff), vmax(0);
233     for (size_t i = 0; i + 15 < n; i += 16) {
234         simd16uint16 v(vals + i);
235         vmin.accu_min(v);
236         vmax.accu_max(v);
237     }
238 
239     ALIGNED(32) uint16_t tab32[32];
240     vmin.store(tab32);
241     vmax.store(tab32 + 16);
242 
243     smin = tab32[0], smax = tab32[16];
244 
245     for (int i = 1; i < 16; i++) {
246         smin = std::min(smin, tab32[i]);
247         smax = std::max(smax, tab32[i + 16]);
248     }
249 
250     // missing values
251     for (size_t i = (n & ~15); i < n; i++) {
252         smin = std::min(smin, vals[i]);
253         smax = std::max(smax, vals[i]);
254     }
255 }
256 
257 // max func differentiates between CMin and CMax (keep lowest or largest)
258 template <class C>
max_func(simd16uint16 v,simd16uint16 thr16)259 simd16uint16 max_func(simd16uint16 v, simd16uint16 thr16) {
260     constexpr bool is_max = C::is_max;
261     if (is_max) {
262         return max(v, thr16);
263     } else {
264         return min(v, thr16);
265     }
266 }
267 
268 template <class C>
count_lt_and_eq(const uint16_t * vals,int n,uint16_t thresh,size_t & n_lt,size_t & n_eq)269 void count_lt_and_eq(
270         const uint16_t* vals,
271         int n,
272         uint16_t thresh,
273         size_t& n_lt,
274         size_t& n_eq) {
275     n_lt = n_eq = 0;
276     simd16uint16 thr16(thresh);
277 
278     size_t n1 = n / 16;
279 
280     for (size_t i = 0; i < n1; i++) {
281         simd16uint16 v(vals);
282         vals += 16;
283         simd16uint16 eqmask = (v == thr16);
284         simd16uint16 max2 = max_func<C>(v, thr16);
285         simd16uint16 gemask = (v == max2);
286         uint32_t bits = get_MSBs(uint16_to_uint8_saturate(eqmask, gemask));
287         int i_eq = __builtin_popcount(bits & 0x00ff00ff);
288         int i_ge = __builtin_popcount(bits) - i_eq;
289         n_eq += i_eq;
290         n_lt += 16 - i_ge;
291     }
292 
293     for (size_t i = n1 * 16; i < n; i++) {
294         uint16_t v = *vals++;
295         if (C::cmp(thresh, v)) {
296             n_lt++;
297         } else if (v == thresh) {
298             n_eq++;
299         }
300     }
301 }
302 
/* compress separated values and ids table, keeping all values < thresh and at
 * most n_eq equal values */
// More precisely: in one pass, moves to the front of vals / ids (preserving
// order) every value v with C::cmp(thresh, v) (the condition of the scalar
// tail loop; "v < thresh" above is the CMax case), plus the first n_eq values
// equal to thresh. Returns the number of entries written. The final assert
// expects exactly n_eq equal values to have been found.
// NOTE(review): wp, the return type and the tail index are int while n is
// size_t, so this assumes n fits in an int.
template <class C>
int simd_compress_array(
        uint16_t* vals,
        typename C::TI* ids,
        size_t n,
        uint16_t thresh,
        int n_eq) {
    simd16uint16 thr16(thresh);
    // per 16-bit lane: low byte taken from eqmask, high byte from gemask
    simd16uint16 mixmask(0xff00);

    int wp = 0;
    size_t i0;

    // loop while there are eqs to collect
    for (i0 = 0; i0 + 15 < n && n_eq > 0; i0 += 16) {
        simd16uint16 v(vals + i0);
        simd16uint16 max2 = max_func<C>(v, thr16);
        simd16uint16 gemask = (v == max2);
        simd16uint16 eqmask = (v == thr16);
        uint32_t bits = get_MSBs(
                blendv(simd32uint8(eqmask),
                       simd32uint8(gemask),
                       simd32uint8(mixmask)));
        // flip the "ge" bits so that they become "lt" bits
        bits ^= 0xAAAAAAAA;
        // bit 2*i     : eq
        // bit 2*i + 1 : lt

        // visit the lanes that are lt or eq, in increasing lane order; the
        // writes at wp never overtake the lane being read
        while (bits) {
            int j = __builtin_ctz(bits) & (~1);
            bool is_eq = (bits >> j) & 1;
            bool is_lt = (bits >> j) & 2;
            bits &= ~(3 << j);
            j >>= 1;

            if (is_lt) {
                vals[wp] = vals[i0 + j];
                ids[wp] = ids[i0 + j];
                wp++;
            } else if (is_eq && n_eq > 0) {
                vals[wp] = vals[i0 + j];
                ids[wp] = ids[i0 + j];
                wp++;
                n_eq--;
            }
        }
    }

    // handle remaining, only strictly lt ones.
    for (; i0 + 15 < n; i0 += 16) {
        simd16uint16 v(vals + i0);
        simd16uint16 max2 = max_func<C>(v, thr16);
        simd16uint16 gemask = (v == max2);
        // both bits of a lane are set together, so ctz lands on even bits
        uint32_t bits = ~get_MSBs(simd32uint8(gemask));

        while (bits) {
            int j = __builtin_ctz(bits);
            bits &= ~(3 << j);
            j >>= 1;

            vals[wp] = vals[i0 + j];
            ids[wp] = ids[i0 + j];
            wp++;
        }
    }

    // end with scalar
    for (int i = (n & ~15); i < n; i++) {
        if (C::cmp(thresh, vals[i])) {
            vals[wp] = vals[i];
            ids[wp] = ids[i];
            wp++;
        } else if (vals[i] == thresh && n_eq > 0) {
            vals[wp] = vals[i];
            ids[wp] = ids[i];
            wp++;
            n_eq--;
        }
    }
    assert(n_eq == 0);
    return wp;
}
386 
// #define MICRO_BENCHMARK

/// Cycle counter used for the partition_stats timings. Reads the x86
/// time-stamp counter (rdtsc) when MICRO_BENCHMARK is defined. Otherwise it
/// always returns 0, so cycle deltas computed from it are 0.
static uint64_t get_cy() {
#ifndef MICRO_BENCHMARK
    return 0;
#else
    uint32_t lo32, hi32;
    asm volatile("rdtsc \n\t" : "=a"(lo32), "=d"(hi32));
    return (uint64_t(hi32) << 32) | lo32;
#endif
}
398 
399 #define IFV if (false)
400 
401 template <class C>
simd_partition_fuzzy_with_bounds(uint16_t * vals,typename C::TI * ids,size_t n,size_t q_min,size_t q_max,size_t * q_out,uint16_t s0i,uint16_t s1i)402 uint16_t simd_partition_fuzzy_with_bounds(
403         uint16_t* vals,
404         typename C::TI* ids,
405         size_t n,
406         size_t q_min,
407         size_t q_max,
408         size_t* q_out,
409         uint16_t s0i,
410         uint16_t s1i) {
411     if (q_min == 0) {
412         if (q_out) {
413             *q_out = 0;
414         }
415         return 0;
416     }
417     if (q_max >= n) {
418         if (q_out) {
419             *q_out = q_max;
420         }
421         return 0xffff;
422     }
423     if (s0i == s1i) {
424         if (q_out) {
425             *q_out = q_min;
426         }
427         return s0i;
428     }
429     uint64_t t0 = get_cy();
430 
431     // lower bound inclusive, upper exclusive
432     size_t s0 = s0i, s1 = s1i + 1;
433 
434     IFV printf("bounds: %ld %ld\n", s0, s1 - 1);
435 
436     int thresh;
437     size_t n_eq = 0, n_lt = 0;
438     size_t q = 0;
439 
440     for (int it = 0; it < 200; it++) {
441         // while(s0 + 1 < s1) {
442         thresh = (s0 + s1) / 2;
443         count_lt_and_eq<C>(vals, n, thresh, n_lt, n_eq);
444 
445         IFV printf(
446                 "   [%ld %ld] thresh=%d n_lt=%ld n_eq=%ld, q=%ld:%ld/%ld\n",
447                 s0,
448                 s1,
449                 thresh,
450                 n_lt,
451                 n_eq,
452                 q_min,
453                 q_max,
454                 n);
455         if (n_lt <= q_min) {
456             if (n_lt + n_eq >= q_min) {
457                 q = q_min;
458                 break;
459             } else {
460                 if (C::is_max) {
461                     s0 = thresh;
462                 } else {
463                     s1 = thresh;
464                 }
465             }
466         } else if (n_lt <= q_max) {
467             q = n_lt;
468             break;
469         } else {
470             if (C::is_max) {
471                 s1 = thresh;
472             } else {
473                 s0 = thresh;
474             }
475         }
476     }
477 
478     uint64_t t1 = get_cy();
479 
480     // number of equal values to keep
481     int64_t n_eq_1 = q - n_lt;
482 
483     IFV printf("shrink: thresh=%d q=%ld n_eq_1=%ld\n", thresh, q, n_eq_1);
484     if (n_eq_1 < 0) { // happens when > q elements are at lower bound
485         assert(s0 + 1 == s1);
486         q = q_min;
487         if (C::is_max) {
488             thresh--;
489         } else {
490             thresh++;
491         }
492         n_eq_1 = q;
493         IFV printf("  override: thresh=%d n_eq_1=%ld\n", thresh, n_eq_1);
494     } else {
495         assert(n_eq_1 <= n_eq);
496     }
497 
498     size_t wp = simd_compress_array<C>(vals, ids, n, thresh, n_eq_1);
499 
500     IFV printf("wp=%ld\n", wp);
501     assert(wp == q);
502     if (q_out) {
503         *q_out = q;
504     }
505 
506     uint64_t t2 = get_cy();
507 
508     partition_stats.bissect_cycles += t1 - t0;
509     partition_stats.compress_cycles += t2 - t1;
510 
511     return thresh;
512 }
513 
514 template <class C>
simd_partition_fuzzy_with_bounds_histogram(uint16_t * vals,typename C::TI * ids,size_t n,size_t q_min,size_t q_max,size_t * q_out,uint16_t s0i,uint16_t s1i)515 uint16_t simd_partition_fuzzy_with_bounds_histogram(
516         uint16_t* vals,
517         typename C::TI* ids,
518         size_t n,
519         size_t q_min,
520         size_t q_max,
521         size_t* q_out,
522         uint16_t s0i,
523         uint16_t s1i) {
524     if (q_min == 0) {
525         if (q_out) {
526             *q_out = 0;
527         }
528         return 0;
529     }
530     if (q_max >= n) {
531         if (q_out) {
532             *q_out = q_max;
533         }
534         return 0xffff;
535     }
536     if (s0i == s1i) {
537         if (q_out) {
538             *q_out = q_min;
539         }
540         return s0i;
541     }
542 
543     IFV printf(
544             "partition fuzzy, q=%ld:%ld / %ld, bounds=%d %d\n",
545             q_min,
546             q_max,
547             n,
548             s0i,
549             s1i);
550 
551     if (!C::is_max) {
552         IFV printf(
553                 "revert due to CMin, q_min:q_max -> %ld:%ld\n", q_min, q_max);
554         q_min = n - q_min;
555         q_max = n - q_max;
556     }
557 
558     // lower and upper bound of range, inclusive
559     int s0 = s0i, s1 = s1i;
560     // number of values < s0 and > s1
561     size_t n_lt = 0, n_gt = 0;
562 
563     // output of loop:
564     int thresh;          // final threshold
565     uint64_t tot_eq = 0; // total nb of equal values
566     uint64_t n_eq = 0;   // nb of equal values to keep
567     size_t q;            // final quantile
568 
569     // buffer for the histograms
570     int hist[16];
571 
572     for (int it = 0; it < 20; it++) {
573         // otherwise we would be done already
574 
575         int shift = 0;
576 
577         IFV printf(
578                 "  it %d bounds: %d %d n_lt=%ld n_gt=%ld\n",
579                 it,
580                 s0,
581                 s1,
582                 n_lt,
583                 n_gt);
584 
585         int maxval = s1 - s0;
586 
587         while (maxval > 15) {
588             shift++;
589             maxval >>= 1;
590         }
591 
592         IFV printf(
593                 "    histogram shift %d maxval %d ?= %d\n",
594                 shift,
595                 maxval,
596                 int((s1 - s0) >> shift));
597 
598         if (maxval > 7) {
599             simd_histogram_16(vals, n, s0, shift, hist);
600         } else {
601             simd_histogram_8(vals, n, s0, shift, hist);
602         }
603         IFV {
604             int sum = n_lt + n_gt;
605             printf("    n_lt=%ld hist=[", n_lt);
606             for (int i = 0; i <= maxval; i++) {
607                 printf("%d ", hist[i]);
608                 sum += hist[i];
609             }
610             printf("] n_gt=%ld sum=%d\n", n_gt, sum);
611             assert(sum == n);
612         }
613 
614         size_t sum_below = n_lt;
615         int i;
616         for (i = 0; i <= maxval; i++) {
617             sum_below += hist[i];
618             if (sum_below >= q_min) {
619                 break;
620             }
621         }
622         IFV printf("    i=%d sum_below=%ld\n", i, sum_below);
623         if (i <= maxval) {
624             s0 = s0 + (i << shift);
625             s1 = s0 + (1 << shift) - 1;
626             n_lt = sum_below - hist[i];
627             n_gt = n - sum_below;
628         } else {
629             assert(!"not implemented");
630         }
631 
632         IFV printf(
633                 "    new bin: s0=%d s1=%d n_lt=%ld n_gt=%ld\n",
634                 s0,
635                 s1,
636                 n_lt,
637                 n_gt);
638 
639         if (s1 > s0) {
640             if (n_lt >= q_min && q_max >= n_lt) {
641                 IFV printf("    FOUND1\n");
642                 thresh = s0;
643                 q = n_lt;
644                 break;
645             }
646 
647             size_t n_lt_2 = n - n_gt;
648             if (n_lt_2 >= q_min && q_max >= n_lt_2) {
649                 thresh = s1 + 1;
650                 q = n_lt_2;
651                 IFV printf("    FOUND2\n");
652                 break;
653             }
654         } else {
655             thresh = s0;
656             q = q_min;
657             tot_eq = n - n_gt - n_lt;
658             n_eq = q_min - n_lt;
659             IFV printf("    FOUND3\n");
660             break;
661         }
662     }
663 
664     IFV printf("end bissection: thresh=%d q=%ld n_eq=%ld\n", thresh, q, n_eq);
665 
666     if (!C::is_max) {
667         if (n_eq == 0) {
668             thresh--;
669         } else {
670             // thresh unchanged
671             n_eq = tot_eq - n_eq;
672         }
673         q = n - q;
674         IFV printf("revert due to CMin, q->%ld n_eq->%ld\n", q, n_eq);
675     }
676 
677     size_t wp = simd_compress_array<C>(vals, ids, n, thresh, n_eq);
678     IFV printf("wp=%ld ?= %ld\n", wp, q);
679     assert(wp == q);
680     if (q_out) {
681         *q_out = wp;
682     }
683 
684     return thresh;
685 }
686 
687 template <class C>
simd_partition_fuzzy(uint16_t * vals,typename C::TI * ids,size_t n,size_t q_min,size_t q_max,size_t * q_out)688 uint16_t simd_partition_fuzzy(
689         uint16_t* vals,
690         typename C::TI* ids,
691         size_t n,
692         size_t q_min,
693         size_t q_max,
694         size_t* q_out) {
695     assert(is_aligned_pointer(vals));
696 
697     uint16_t s0i, s1i;
698     find_minimax(vals, n, s0i, s1i);
699     // QSelect_stats.t0 += get_cy() - t0;
700 
701     return simd_partition_fuzzy_with_bounds<C>(
702             vals, ids, n, q_min, q_max, q_out, s0i, s1i);
703 }
704 
705 template <class C>
simd_partition(uint16_t * vals,typename C::TI * ids,size_t n,size_t q)706 uint16_t simd_partition(
707         uint16_t* vals,
708         typename C::TI* ids,
709         size_t n,
710         size_t q) {
711     assert(is_aligned_pointer(vals));
712 
713     if (q == 0) {
714         return 0;
715     }
716     if (q >= n) {
717         return 0xffff;
718     }
719 
720     uint16_t s0i, s1i;
721     find_minimax(vals, n, s0i, s1i);
722 
723     return simd_partition_fuzzy_with_bounds<C>(
724             vals, ids, n, q, q, nullptr, s0i, s1i);
725 }
726 
727 template <class C>
simd_partition_with_bounds(uint16_t * vals,typename C::TI * ids,size_t n,size_t q,uint16_t s0i,uint16_t s1i)728 uint16_t simd_partition_with_bounds(
729         uint16_t* vals,
730         typename C::TI* ids,
731         size_t n,
732         size_t q,
733         uint16_t s0i,
734         uint16_t s1i) {
735     return simd_partition_fuzzy_with_bounds<C>(
736             vals, ids, n, q, q, nullptr, s0i, s1i);
737 }
738 
739 } // namespace simd_partitioning
740 
741 /******************************************************************
742  * Driver routine
743  ******************************************************************/
744 
745 template <class C>
partition_fuzzy(typename C::T * vals,typename C::TI * ids,size_t n,size_t q_min,size_t q_max,size_t * q_out)746 typename C::T partition_fuzzy(
747         typename C::T* vals,
748         typename C::TI* ids,
749         size_t n,
750         size_t q_min,
751         size_t q_max,
752         size_t* q_out) {
753     // the code below compiles and runs without AVX2 but it's slower than
754     // the scalar implementation
755 #ifdef __AVX2__
756     constexpr bool is_uint16 = std::is_same<typename C::T, uint16_t>::value;
757     if (is_uint16 && is_aligned_pointer(vals)) {
758         return simd_partitioning::simd_partition_fuzzy<C>(
759                 (uint16_t*)vals, ids, n, q_min, q_max, q_out);
760     }
761 #endif
762     return partitioning::partition_fuzzy_median3<C>(
763             vals, ids, n, q_min, q_max, q_out);
764 }
765 
// explicit template instantiations
//
// partition_fuzzy is defined in this .cpp file, so the (comparator, value, id)
// combinations available to other translation units are the ones instantiated
// here: CMin / CMax over float values with int64_t ids, and over uint16_t
// values with int64_t or int ids.

template float partition_fuzzy<CMin<float, int64_t>>(
        float* vals,
        int64_t* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out);

template float partition_fuzzy<CMax<float, int64_t>>(
        float* vals,
        int64_t* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out);

template uint16_t partition_fuzzy<CMin<uint16_t, int64_t>>(
        uint16_t* vals,
        int64_t* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out);

template uint16_t partition_fuzzy<CMax<uint16_t, int64_t>>(
        uint16_t* vals,
        int64_t* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out);

template uint16_t partition_fuzzy<CMin<uint16_t, int>>(
        uint16_t* vals,
        int* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out);

template uint16_t partition_fuzzy<CMax<uint16_t, int>>(
        uint16_t* vals,
        int* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out);
815 
816 /******************************************************************
817  * Histogram subroutines
818  ******************************************************************/
819 
820 #ifdef __AVX2__
821 /// FIXME when MSB of uint16 is set
822 // this code does not compile properly with GCC 7.4.0
823 
824 namespace {
825 
826 /************************************************************
827  * 8 bins
828  ************************************************************/
829 
accu4to8(simd16uint16 a4)830 simd32uint8 accu4to8(simd16uint16 a4) {
831     simd16uint16 mask4(0x0f0f);
832 
833     simd16uint16 a8_0 = a4 & mask4;
834     simd16uint16 a8_1 = (a4 >> 4) & mask4;
835 
836     return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
837 }
838 
accu8to16(simd32uint8 a8)839 simd16uint16 accu8to16(simd32uint8 a8) {
840     simd16uint16 mask8(0x00ff);
841 
842     simd16uint16 a8_0 = simd16uint16(a8) & mask8;
843     simd16uint16 a8_1 = (simd16uint16(a8) >> 8) & mask8;
844 
845     return simd16uint16(_mm256_hadd_epi16(a8_0.i, a8_1.i));
846 }
847 
848 static const simd32uint8 shifts(_mm256_setr_epi8(
849         1,
850         16,
851         0,
852         0,
853         4,
854         64,
855         0,
856         0,
857         0,
858         0,
859         1,
860         16,
861         0,
862         0,
863         4,
864         64,
865         1,
866         16,
867         0,
868         0,
869         4,
870         64,
871         0,
872         0,
873         0,
874         0,
875         1,
876         16,
877         0,
878         0,
879         4,
880         64));
881 
// 2-bit accumulator: we can add only up to 3 elements
// on output we return 2*4-bit results
// preproc returns either an index in 0..7 or 0xffff
// that yields a 0 when used in the table look-up
//
// Processes N (at most 3) vectors of 16 values from data and advances data by
// 16 * N. The shifts look-up turns bin b of a 16-bit lane into a single
// increment of a 2-bit field: bins 0..3 at bits 0, 4, 8, 12 and bins 4..7 at
// bits 2, 6, 10, 14. At the end, those fields are added as 4-bit fields:
// nibble k of each lane of a4lo counts bin k, and nibble k of a4hi counts
// bin 4 + k.
template <int N, class Preproc>
void compute_accu2(
        const uint16_t*& data,
        Preproc& pp,
        simd16uint16& a4lo,
        simd16uint16& a4hi) {
    simd16uint16 mask2(0x3333);
    simd16uint16 a2((uint16_t)0); // 2-bit accu
    for (int j = 0; j < N; j++) {
        simd16uint16 v(data);
        data += 16;
        v = pp(v);
        // 0x800 -> force second half of table
        // (for v in 0..7: low byte = v, high byte = v + 8; 0xffff stays
        // 0xffff in both bytes)
        simd16uint16 idx = v | (v << 8) | simd16uint16(0x800);
        a2 += simd16uint16(shifts.lookup_2_lanes(simd32uint8(idx)));
    }
    a4lo += a2 & mask2;
    a4hi += (a2 >> 2) & mask2;
}
905 
/** 8-bin histogram of the n_in values in data (n_in must be a multiple of 16,
 * see the assert). pp maps each vector of 16 values to bin indices in 0..7,
 * or to 0xffff for values that must not be counted.
 *
 * Counts are kept in progressively wider accumulators:
 *   - 2-bit fields over at most 3 vectors (compute_accu2);
 *   - 4-bit fields over at most 15 vectors (one outer iteration);
 *   - 8-bit fields (a8lo / a8hi) across outer iterations;
 *   - finally, 16-bit counts.
 *
 * NOTE(review): each outer iteration can add up to 30 to a byte of a8lo /
 * a8hi, and those are not flushed inside the loop. They may therefore wrap
 * for large n_in when values concentrate in one bin; presumably callers
 * bound n_in — confirm.
 */
template <class Preproc>
simd16uint16 histogram_8(const uint16_t* data, Preproc pp, size_t n_in) {
    assert(n_in % 16 == 0);
    int n = n_in / 16;

    simd32uint8 a8lo(0);
    simd32uint8 a8hi(0);

    // blocks of at most 15 vectors so the 4-bit accumulators cannot overflow
    for (int i0 = 0; i0 < n; i0 += 15) {
        simd16uint16 a4lo(0); // 4-bit accus
        simd16uint16 a4hi(0);

        int i1 = std::min(i0 + 15, n);
        int i;
        for (i = i0; i + 2 < i1; i += 3) {
            compute_accu2<3>(data, pp, a4lo, a4hi); // adds 3 max
        }
        // remaining 1 or 2 vectors of the block
        switch (i1 - i) {
            case 2:
                compute_accu2<2>(data, pp, a4lo, a4hi);
                break;
            case 1:
                compute_accu2<1>(data, pp, a4lo, a4hi);
                break;
        }

        a8lo += accu4to8(a4lo);
        a8hi += accu4to8(a4hi);
    }

    // move to 16-bit accu
    simd16uint16 a16lo = accu8to16(a8lo);
    simd16uint16 a16hi = accu8to16(a8hi);

    simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));

    // the 2 lanes must still be combined
    return a16;
}
945 
946 /************************************************************
947  * 16 bins
948  ************************************************************/
949 
950 static const simd32uint8 shifts2(_mm256_setr_epi8(
951         1,
952         2,
953         4,
954         8,
955         16,
956         32,
957         64,
958         (char)128,
959         1,
960         2,
961         4,
962         8,
963         16,
964         32,
965         64,
966         (char)128,
967         1,
968         2,
969         4,
970         8,
971         16,
972         32,
973         64,
974         (char)128,
975         1,
976         2,
977         4,
978         8,
979         16,
980         32,
981         64,
982         (char)128));
983 
shiftr_16(simd32uint8 x,int n)984 simd32uint8 shiftr_16(simd32uint8 x, int n) {
985     return simd32uint8(simd16uint16(x) >> n);
986 }
987 
combine_2x2(simd32uint8 a,simd32uint8 b)988 inline simd32uint8 combine_2x2(simd32uint8 a, simd32uint8 b) {
989     __m256i a1b0 = _mm256_permute2f128_si256(a.i, b.i, 0x21);
990     __m256i a0b1 = _mm256_blend_epi32(a.i, b.i, 0xF0);
991 
992     return simd32uint8(a1b0) + simd32uint8(a0b1);
993 }
994 
// 2-bit accumulator: we can add only up to 3 elements
// on output we return 2*4-bit results
//
// 16-bin counterpart of compute_accu2. It processes N (at most 3) vectors of
// 16 values from data and advances data by 16 * N. For bin b in 0..15, a1
// gets bit b of the element's 16-bit lane set: the shifts2 entry is kept in
// the low byte for b < 8 and in the high byte for b >= 8 (lt8 mask).
// Even bins are accumulated in the 2-bit fields of a2_0 and odd bins in
// a2_1. They are then added to 4-bit fields:
//   - a4_0: bins 0, 4, 8, 12
//   - a4_1: bins 1, 5, 9, 13
//   - a4_2: bins 2, 6, 10, 14
//   - a4_3: bins 3, 7, 11, 15
template <int N, class Preproc>
void compute_accu2_16(
        const uint16_t*& data,
        Preproc pp,
        simd32uint8& a4_0,
        simd32uint8& a4_1,
        simd32uint8& a4_2,
        simd32uint8& a4_3) {
    simd32uint8 mask1(0x55);
    simd32uint8 a2_0; // 2-bit accu
    simd32uint8 a2_1; // 2-bit accu
    a2_0.clear();
    a2_1.clear();

    for (int j = 0; j < N; j++) {
        simd16uint16 v(data);
        data += 16;
        v = pp(v);

        simd16uint16 idx = v | (v << 8);
        simd32uint8 a1 = shifts2.lookup_2_lanes(simd32uint8(idx));
        // contains 0s for out-of-bounds elements

        // 0x00ff for bins < 8 (keep low byte), 0xff00 otherwise (high byte)
        simd16uint16 lt8 = (v >> 3) == simd16uint16(0);
        lt8.i = _mm256_xor_si256(lt8.i, _mm256_set1_epi16(0xff00));

        a1 = a1 & lt8;

        // even bins, and odd bins moved down by one bit
        a2_0 += a1 & mask1;
        a2_1 += shiftr_16(a1, 1) & mask1;
    }
    simd32uint8 mask2(0x33);

    a4_0 += a2_0 & mask2;
    a4_1 += a2_1 & mask2;
    a4_2 += shiftr_16(a2_0, 2) & mask2;
    a4_3 += shiftr_16(a2_1, 2) & mask2;
}
1035 
accu4to8_2(simd32uint8 a4_0,simd32uint8 a4_1)1036 simd32uint8 accu4to8_2(simd32uint8 a4_0, simd32uint8 a4_1) {
1037     simd32uint8 mask4(0x0f);
1038 
1039     simd32uint8 a8_0 = combine_2x2(a4_0 & mask4, shiftr_16(a4_0, 4) & mask4);
1040 
1041     simd32uint8 a8_1 = combine_2x2(a4_1 & mask4, shiftr_16(a4_1, 4) & mask4);
1042 
1043     return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
1044 }
1045 
/** 16-bin histogram of the n_in values in data (n_in must be a multiple of 16,
 * see the assert). pp maps each vector of 16 values to bin indices in 0..15,
 * or to 0xffff for values that must not be counted.
 *
 * Counts are accumulated as follows:
 *   - 2-bit fields over at most 3 vectors (compute_accu2_16);
 *   - 4-bit fields over at most 7 vectors (one outer iteration);
 *   - 8-bit fields across outer iterations;
 *   - finally, 16-bit counts.
 *
 * NOTE(review): each outer iteration can add up to 28 to a byte of a8lo /
 * a8hi, and those are not flushed inside the loop. They may therefore wrap
 * for large n_in when values concentrate in one bin; presumably callers
 * bound n_in — confirm.
 */
template <class Preproc>
simd16uint16 histogram_16(const uint16_t* data, Preproc pp, size_t n_in) {
    assert(n_in % 16 == 0);
    int n = n_in / 16;

    simd32uint8 a8lo((uint8_t)0);
    simd32uint8 a8hi((uint8_t)0);

    // blocks of at most 7 vectors of 16 values
    for (int i0 = 0; i0 < n; i0 += 7) {
        simd32uint8 a4_0(0); // 0, 4, 8, 12
        simd32uint8 a4_1(0); // 1, 5, 9, 13
        simd32uint8 a4_2(0); // 2, 6, 10, 14
        simd32uint8 a4_3(0); // 3, 7, 11, 15

        int i1 = std::min(i0 + 7, n);
        int i;
        for (i = i0; i + 2 < i1; i += 3) {
            compute_accu2_16<3>(data, pp, a4_0, a4_1, a4_2, a4_3);
        }
        // remaining 1 or 2 vectors of the block
        switch (i1 - i) {
            case 2:
                compute_accu2_16<2>(data, pp, a4_0, a4_1, a4_2, a4_3);
                break;
            case 1:
                compute_accu2_16<1>(data, pp, a4_0, a4_1, a4_2, a4_3);
                break;
        }

        a8lo += accu4to8_2(a4_0, a4_1);
        a8hi += accu4to8_2(a4_2, a4_3);
    }

    // move to 16-bit accu
    simd16uint16 a16lo = accu8to16(a8lo);
    simd16uint16 a16hi = accu8to16(a8hi);

    simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));

    // reorder the 32-bit pairs of counts across lanes (presumably into bin
    // order — TODO confirm against the caller)
    __m256i perm32 = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
    a16.i = _mm256_permutevar8x32_epi32(a16.i, perm32);

    return a16;
}
1089 
1090 struct PreprocNOP {
operator ()faiss::__anon75a0fc1d0111::PreprocNOP1091     simd16uint16 operator()(simd16uint16 x) {
1092         return x;
1093     }
1094 };
1095 
1096 template <int shift, int nbin>
1097 struct PreprocMinShift {
1098     simd16uint16 min16;
1099     simd16uint16 max16;
1100 
PreprocMinShiftfaiss::__anon75a0fc1d0111::PreprocMinShift1101     explicit PreprocMinShift(uint16_t min) {
1102         min16.set1(min);
1103         int vmax0 = std::min((nbin << shift) + min, 65536);
1104         uint16_t vmax = uint16_t(vmax0 - 1 - min);
1105         max16.set1(vmax); // vmax inclusive
1106     }
1107 
operator ()faiss::__anon75a0fc1d0111::PreprocMinShift1108     simd16uint16 operator()(simd16uint16 x) {
1109         x = x - min16;
1110         simd16uint16 mask = (x == max(x, max16)) - (x == max16);
1111         return (x >> shift) | mask;
1112     }
1113 };
1114 
1115 /* unbounded versions of the functions */
1116 
simd_histogram_8_unbounded(const uint16_t * data,int n,int * hist)1117 void simd_histogram_8_unbounded(const uint16_t* data, int n, int* hist) {
1118     PreprocNOP pp;
1119     simd16uint16 a16 = histogram_8(data, pp, (n & ~15));
1120 
1121     ALIGNED(32) uint16_t a16_tab[16];
1122     a16.store(a16_tab);
1123 
1124     for (int i = 0; i < 8; i++) {
1125         hist[i] = a16_tab[i] + a16_tab[i + 8];
1126     }
1127 
1128     for (int i = (n & ~15); i < n; i++) {
1129         hist[data[i]]++;
1130     }
1131 }
1132 
simd_histogram_16_unbounded(const uint16_t * data,int n,int * hist)1133 void simd_histogram_16_unbounded(const uint16_t* data, int n, int* hist) {
1134     simd16uint16 a16 = histogram_16(data, PreprocNOP(), (n & ~15));
1135 
1136     ALIGNED(32) uint16_t a16_tab[16];
1137     a16.store(a16_tab);
1138 
1139     for (int i = 0; i < 16; i++) {
1140         hist[i] = a16_tab[i];
1141     }
1142 
1143     for (int i = (n & ~15); i < n; i++) {
1144         hist[data[i]]++;
1145     }
1146 }
1147 
1148 } // anonymous namespace
1149 
1150 /************************************************************
1151  * Driver routines
1152  ************************************************************/
1153 
simd_histogram_8(const uint16_t * data,int n,uint16_t min,int shift,int * hist)1154 void simd_histogram_8(
1155         const uint16_t* data,
1156         int n,
1157         uint16_t min,
1158         int shift,
1159         int* hist) {
1160     if (shift < 0) {
1161         simd_histogram_8_unbounded(data, n, hist);
1162         return;
1163     }
1164 
1165     simd16uint16 a16;
1166 
1167 #define DISPATCH(s)                                                     \
1168     case s:                                                             \
1169         a16 = histogram_8(data, PreprocMinShift<s, 8>(min), (n & ~15)); \
1170         break
1171 
1172     switch (shift) {
1173         DISPATCH(0);
1174         DISPATCH(1);
1175         DISPATCH(2);
1176         DISPATCH(3);
1177         DISPATCH(4);
1178         DISPATCH(5);
1179         DISPATCH(6);
1180         DISPATCH(7);
1181         DISPATCH(8);
1182         DISPATCH(9);
1183         DISPATCH(10);
1184         DISPATCH(11);
1185         DISPATCH(12);
1186         DISPATCH(13);
1187         default:
1188             FAISS_THROW_FMT("dispatch for shift=%d not instantiated", shift);
1189     }
1190 #undef DISPATCH
1191 
1192     ALIGNED(32) uint16_t a16_tab[16];
1193     a16.store(a16_tab);
1194 
1195     for (int i = 0; i < 8; i++) {
1196         hist[i] = a16_tab[i] + a16_tab[i + 8];
1197     }
1198 
1199     // complete with remaining bins
1200     for (int i = (n & ~15); i < n; i++) {
1201         if (data[i] < min)
1202             continue;
1203         uint16_t v = data[i] - min;
1204         v >>= shift;
1205         if (v < 8)
1206             hist[v]++;
1207     }
1208 }
1209 
simd_histogram_16(const uint16_t * data,int n,uint16_t min,int shift,int * hist)1210 void simd_histogram_16(
1211         const uint16_t* data,
1212         int n,
1213         uint16_t min,
1214         int shift,
1215         int* hist) {
1216     if (shift < 0) {
1217         simd_histogram_16_unbounded(data, n, hist);
1218         return;
1219     }
1220 
1221     simd16uint16 a16;
1222 
1223 #define DISPATCH(s)                                                       \
1224     case s:                                                               \
1225         a16 = histogram_16(data, PreprocMinShift<s, 16>(min), (n & ~15)); \
1226         break
1227 
1228     switch (shift) {
1229         DISPATCH(0);
1230         DISPATCH(1);
1231         DISPATCH(2);
1232         DISPATCH(3);
1233         DISPATCH(4);
1234         DISPATCH(5);
1235         DISPATCH(6);
1236         DISPATCH(7);
1237         DISPATCH(8);
1238         DISPATCH(9);
1239         DISPATCH(10);
1240         DISPATCH(11);
1241         DISPATCH(12);
1242         default:
1243             FAISS_THROW_FMT("dispatch for shift=%d not instantiated", shift);
1244     }
1245 #undef DISPATCH
1246 
1247     ALIGNED(32) uint16_t a16_tab[16];
1248     a16.store(a16_tab);
1249 
1250     for (int i = 0; i < 16; i++) {
1251         hist[i] = a16_tab[i];
1252     }
1253 
1254     for (int i = (n & ~15); i < n; i++) {
1255         if (data[i] < min)
1256             continue;
1257         uint16_t v = data[i] - min;
1258         v >>= shift;
1259         if (v < 16)
1260             hist[v]++;
1261     }
1262 }
1263 
1264 // no AVX2
1265 #else
1266 
// Scalar 16-bin histogram of data[0..n) (fallback when AVX2 is unavailable).
// hist[0..16) is zeroed first.
// - shift < 0: raw values are bin indices (each data[i] must be < 16).
// - otherwise: x falls in bin (x - min) >> shift when x >= min and that bin
//   is < 16; other values are ignored.
// A non-positive n yields an all-zero histogram.
void simd_histogram_16(
        const uint16_t* data,
        int n,
        uint16_t min,
        int shift,
        int* hist) {
    memset(hist, 0, sizeof(*hist) * 16);
    if (shift < 0) {
        for (int i = 0; i < n; i++) {
            hist[data[i]]++;
        }
        return;
    }

    // Offsets x - min are < 2^16, so every shift >= 16 behaves like 16 (all
    // accepted values go to bin 0). Clamping keeps `16 << s` from overflowing
    // int (shift >= 27) and `v >> s` from exceeding the operand width.
    const int s = std::min(shift, 16);

    // Inclusive upper bound on the accepted offset x - min.
    const int vmax0 = std::min((16 << s) + min, 65536);
    const uint16_t vmax = uint16_t(vmax0 - 1 - min);

    for (int i = 0; i < n; i++) {
        // Values below min wrap around to offsets > vmax, so one unsigned
        // comparison rejects both too-small and too-large values.
        const uint16_t v = uint16_t(data[i] - min);
        if (v > vmax) {
            continue;
        }
        hist[v >> s]++;
    }
}
1299 
// Scalar 8-bin histogram of data[0..n) (fallback when AVX2 is unavailable).
// hist[0..8) is zeroed first.
// - shift < 0: raw values are bin indices (each data[i] must be < 8).
// - otherwise: x falls in bin (x - min) >> shift when x >= min and that bin
//   is < 8; other values are ignored.
// A non-positive n yields an all-zero histogram.
void simd_histogram_8(
        const uint16_t* data,
        int n,
        uint16_t min,
        int shift,
        int* hist) {
    memset(hist, 0, sizeof(*hist) * 8);
    if (shift < 0) {
        for (int i = 0; i < n; i++) {
            hist[data[i]]++;
        }
        return;
    }

    // Offsets x - min are < 2^16, so every shift >= 16 behaves like 16 (all
    // accepted values go to bin 0). Clamping avoids an undefined shift by
    // >= the promoted operand width.
    const int s = std::min(shift, 16);

    for (int i = 0; i < n; i++) {
        if (data[i] < min) {
            continue;
        }
        const int bin = uint16_t(data[i] - min) >> s;
        if (bin < 8) {
            hist[bin]++;
        }
    }
}
1322 
1323 #endif
1324 
reset()1325 void PartitionStats::reset() {
1326     memset(this, 0, sizeof(*this));
1327 }
1328 
1329 PartitionStats partition_stats;
1330 
1331 } // namespace faiss
1332