/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <faiss/utils/partitioning.h>

#include <cassert>
#include <cmath>
#include <cstring>

#include <faiss/impl/FaissAssert.h>
#include <faiss/utils/AlignedTable.h>
#include <faiss/utils/ordered_key_value.h>
#include <faiss/utils/simdlib.h>

#include <faiss/impl/platform_macros.h>

namespace faiss {

/******************************************************************
 * Internal routines
 ******************************************************************/

namespace partitioning {

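// return the median of 3 values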
template <typename T>
T median3(T a, T b, T c) {
    if (a > b) {
        std::swap(a, b);
    }
    if (c > b) {
        return b;
    }
    if (c > a) {
        return c;
    }
    return a;
}

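/* sample up to 3 values that compare strictly between thresh_inf and
 * thresh_sup (in the sense of C), striding over the array with a large
 * prime, and return their median; fall back to a single sampled value,
 * or to thresh_inf when no value qualifies */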
42 template <class C>
sample_threshold_median3(const typename C::T * vals,int n,typename C::T thresh_inf,typename C::T thresh_sup)43 typename C::T sample_threshold_median3(
44 const typename C::T* vals,
45 int n,
46 typename C::T thresh_inf,
47 typename C::T thresh_sup) {
48 using T = typename C::T;
49 size_t big_prime = 6700417;
50 T val3[3];
51 int vi = 0;
52
53 for (size_t i = 0; i < n; i++) {
54 T v = vals[(i * big_prime) % n];
55 // thresh_inf < v < thresh_sup (for CMax)
56 if (C::cmp(v, thresh_inf) && C::cmp(thresh_sup, v)) {
57 val3[vi++] = v;
58 if (vi == 3) {
59 break;
60 }
61 }
62 }
63
64 if (vi == 3) {
65 return median3(val3[0], val3[1], val3[2]);
66 } else if (vi != 0) {
67 return val3[0];
68 } else {
69 return thresh_inf;
70 // FAISS_THROW_MSG("too few values to compute a median");
71 }
72 }
73
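/* single pass over vals counting the elements for which C::cmp(thresh, v)
 * holds (v < thresh for CMax, v > thresh for CMin) in n_lt, and the
 * elements equal to thresh in n_eq */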
template <class C>
void count_lt_and_eq(
        const typename C::T* vals,
        size_t n,
        typename C::T thresh,
        size_t& n_lt,
        size_t& n_eq) {
    n_lt = n_eq = 0;

    for (size_t i = 0; i < n; i++) {
        typename C::T v = *vals++;
        if (C::cmp(thresh, v)) {
            n_lt++;
        } else if (v == thresh) {
            n_eq++;
        }
    }
}

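/* in-place compaction of the (vals, ids) table: keep every element for
 * which C::cmp(thresh, v) holds, plus exactly n_eq elements equal to
 * thresh; returns the number of elements written */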
template <class C>
size_t compress_array(
        typename C::T* vals,
        typename C::TI* ids,
        size_t n,
        typename C::T thresh,
        size_t n_eq) {
    size_t wp = 0;
    for (size_t i = 0; i < n; i++) {
        if (C::cmp(thresh, vals[i])) {
            vals[wp] = vals[i];
            ids[wp] = ids[i];
            wp++;
        } else if (n_eq > 0 && vals[i] == thresh) {
            vals[wp] = vals[i];
            ids[wp] = ids[i];
            wp++;
            n_eq--;
        }
    }
    assert(n_eq == 0);
    return wp;
}

#define IFV if (false)

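/* partition (vals, ids) in-place so that the q best elements according to C
 * end up (unordered) at the start of the arrays, with q chosen in
 * [q_min, q_max]; returns the threshold and writes q to *q_out if non-null */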
template <class C>
typename C::T partition_fuzzy_median3(
        typename C::T* vals,
        typename C::TI* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out) {
    if (q_min == 0) {
        if (q_out) {
            *q_out = C::Crev::neutral();
        }
        return 0;
    }
    if (q_max >= n) {
        if (q_out) {
            *q_out = q_max;
        }
        return C::neutral();
    }

    using T = typename C::T;

    // here we use bisection with a median of 3 to find the threshold and
    // compress the arrays afterwards. So it's an O(n log n) algorithm rather
    // than quickselect's O(n), but it avoids shuffling the array around.

    FAISS_THROW_IF_NOT(n >= 3);

    T thresh_inf = C::Crev::neutral();
    T thresh_sup = C::neutral();
    T thresh = median3(vals[0], vals[n / 2], vals[n - 1]);

    size_t n_eq = 0, n_lt = 0;
    size_t q = 0;

    for (int it = 0; it < 200; it++) {
        count_lt_and_eq<C>(vals, n, thresh, n_lt, n_eq);

        IFV printf(
                " thresh=%g [%g %g] n_lt=%ld n_eq=%ld, q=%ld:%ld/%ld\n",
                float(thresh),
                float(thresh_inf),
                float(thresh_sup),
                long(n_lt),
                long(n_eq),
                long(q_min),
                long(q_max),
                long(n));

        if (n_lt <= q_min) {
            if (n_lt + n_eq >= q_min) {
                q = q_min;
                break;
            } else {
                thresh_inf = thresh;
            }
        } else if (n_lt <= q_max) {
            q = n_lt;
            break;
        } else {
            thresh_sup = thresh;
        }

        // FIXME avoid a second pass over the array to sample the threshold
        IFV printf(
                " sample thresh in [%g %g]\n",
                float(thresh_inf),
                float(thresh_sup));
        T new_thresh =
                sample_threshold_median3<C>(vals, n, thresh_inf, thresh_sup);
        if (new_thresh == thresh_inf) {
            // then there is nothing between thresh_inf and thresh_sup
            break;
        }
        thresh = new_thresh;
    }

    int64_t n_eq_1 = q - n_lt;

    IFV printf("shrink: thresh=%g n_eq_1=%ld\n", float(thresh), long(n_eq_1));

    if (n_eq_1 < 0) { // happens when > q elements are at lower bound
        q = q_min;
        thresh = C::Crev::nextafter(thresh);
        n_eq_1 = q;
    } else {
        assert(n_eq_1 <= n_eq);
    }

    int wp = compress_array<C>(vals, ids, n, thresh, n_eq_1);

    assert(wp == q);
    if (q_out) {
        *q_out = q;
    }

    return thresh;
}

} // namespace partitioning

/******************************************************************
 * SIMD routines when vals is an aligned array of uint16_t
 ******************************************************************/

namespace simd_partitioning {

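/* SIMD min/max of an array of uint16_t; the tail that is not a multiple of
 * 16 elements is handled with scalar code */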
void find_minimax(
        const uint16_t* vals,
        size_t n,
        uint16_t& smin,
        uint16_t& smax) {
    simd16uint16 vmin(0xffff), vmax(0);
    for (size_t i = 0; i + 15 < n; i += 16) {
        simd16uint16 v(vals + i);
        vmin.accu_min(v);
        vmax.accu_max(v);
    }

    ALIGNED(32) uint16_t tab32[32];
    vmin.store(tab32);
    vmax.store(tab32 + 16);

    smin = tab32[0], smax = tab32[16];

    for (int i = 1; i < 16; i++) {
        smin = std::min(smin, tab32[i]);
        smax = std::max(smax, tab32[i + 16]);
    }

    // missing values
    for (size_t i = (n & ~15); i < n; i++) {
        smin = std::min(smin, vals[i]);
        smax = std::max(smax, vals[i]);
    }
}

// max func differentiates between CMin and CMax (keep lowest or largest)
template <class C>
simd16uint16 max_func(simd16uint16 v, simd16uint16 thr16) {
    constexpr bool is_max = C::is_max;
    if (is_max) {
        return max(v, thr16);
    } else {
        return min(v, thr16);
    }
}

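/* SIMD version of count_lt_and_eq: for each block of 16 values, build
 * equality and threshold-comparison masks, pack their MSBs into a 32-bit
 * word and popcount it to update the counters; scalar code for the tail */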
template <class C>
void count_lt_and_eq(
        const uint16_t* vals,
        int n,
        uint16_t thresh,
        size_t& n_lt,
        size_t& n_eq) {
    n_lt = n_eq = 0;
    simd16uint16 thr16(thresh);

    size_t n1 = n / 16;

    for (size_t i = 0; i < n1; i++) {
        simd16uint16 v(vals);
        vals += 16;
        simd16uint16 eqmask = (v == thr16);
        simd16uint16 max2 = max_func<C>(v, thr16);
        simd16uint16 gemask = (v == max2);
        uint32_t bits = get_MSBs(uint16_to_uint8_saturate(eqmask, gemask));
        int i_eq = __builtin_popcount(bits & 0x00ff00ff);
        int i_ge = __builtin_popcount(bits) - i_eq;
        n_eq += i_eq;
        n_lt += 16 - i_ge;
    }

    for (size_t i = n1 * 16; i < n; i++) {
        uint16_t v = *vals++;
        if (C::cmp(thresh, v)) {
            n_lt++;
        } else if (v == thresh) {
            n_eq++;
        }
    }
}

/* compress separated values and ids table, keeping all values < thresh and at
 * most n_eq equal values */
template <class C>
int simd_compress_array(
        uint16_t* vals,
        typename C::TI* ids,
        size_t n,
        uint16_t thresh,
        int n_eq) {
    simd16uint16 thr16(thresh);
    simd16uint16 mixmask(0xff00);

    int wp = 0;
    size_t i0;

    // loop while there are eqs to collect
    for (i0 = 0; i0 + 15 < n && n_eq > 0; i0 += 16) {
        simd16uint16 v(vals + i0);
        simd16uint16 max2 = max_func<C>(v, thr16);
        simd16uint16 gemask = (v == max2);
        simd16uint16 eqmask = (v == thr16);
        uint32_t bits = get_MSBs(
                blendv(simd32uint8(eqmask),
                       simd32uint8(gemask),
                       simd32uint8(mixmask)));
        bits ^= 0xAAAAAAAA;
        // bit 2*i     : eq
        // bit 2*i + 1 : lt

        while (bits) {
            int j = __builtin_ctz(bits) & (~1);
            bool is_eq = (bits >> j) & 1;
            bool is_lt = (bits >> j) & 2;
            bits &= ~(3 << j);
            j >>= 1;

            if (is_lt) {
                vals[wp] = vals[i0 + j];
                ids[wp] = ids[i0 + j];
                wp++;
            } else if (is_eq && n_eq > 0) {
                vals[wp] = vals[i0 + j];
                ids[wp] = ids[i0 + j];
                wp++;
                n_eq--;
            }
        }
    }

    // handle remaining, only strictly lt ones.
    for (; i0 + 15 < n; i0 += 16) {
        simd16uint16 v(vals + i0);
        simd16uint16 max2 = max_func<C>(v, thr16);
        simd16uint16 gemask = (v == max2);
        uint32_t bits = ~get_MSBs(simd32uint8(gemask));

        while (bits) {
            int j = __builtin_ctz(bits);
            bits &= ~(3 << j);
            j >>= 1;

            vals[wp] = vals[i0 + j];
            ids[wp] = ids[i0 + j];
            wp++;
        }
    }

    // end with scalar
    for (int i = (n & ~15); i < n; i++) {
        if (C::cmp(thresh, vals[i])) {
            vals[wp] = vals[i];
            ids[wp] = ids[i];
            wp++;
        } else if (vals[i] == thresh && n_eq > 0) {
            vals[wp] = vals[i];
            ids[wp] = ids[i];
            wp++;
            n_eq--;
        }
    }
    assert(n_eq == 0);
    return wp;
}

// #define MICRO_BENCHMARK

static uint64_t get_cy() {
#ifdef MICRO_BENCHMARK
    uint32_t high, low;
    asm volatile("rdtsc \n\t" : "=a"(low), "=d"(high));
    return ((uint64_t)high << 32) | (low);
#else
    return 0;
#endif
}

#define IFV if (false)

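/* fuzzy partitioning by bisection on the threshold value: the candidate
 * range [s0i, s1i] (the data min/max) is halved at each iteration until the
 * number of elements below the threshold falls within [q_min, q_max] */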
template <class C>
uint16_t simd_partition_fuzzy_with_bounds(
        uint16_t* vals,
        typename C::TI* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out,
        uint16_t s0i,
        uint16_t s1i) {
    if (q_min == 0) {
        if (q_out) {
            *q_out = 0;
        }
        return 0;
    }
    if (q_max >= n) {
        if (q_out) {
            *q_out = q_max;
        }
        return 0xffff;
    }
    if (s0i == s1i) {
        if (q_out) {
            *q_out = q_min;
        }
        return s0i;
    }
    uint64_t t0 = get_cy();

    // lower bound inclusive, upper exclusive
    size_t s0 = s0i, s1 = s1i + 1;

    IFV printf("bounds: %ld %ld\n", s0, s1 - 1);

    int thresh;
    size_t n_eq = 0, n_lt = 0;
    size_t q = 0;

    for (int it = 0; it < 200; it++) {
        // while(s0 + 1 < s1) {
        thresh = (s0 + s1) / 2;
        count_lt_and_eq<C>(vals, n, thresh, n_lt, n_eq);

        IFV printf(
                " [%ld %ld] thresh=%d n_lt=%ld n_eq=%ld, q=%ld:%ld/%ld\n",
                s0,
                s1,
                thresh,
                n_lt,
                n_eq,
                q_min,
                q_max,
                n);
        if (n_lt <= q_min) {
            if (n_lt + n_eq >= q_min) {
                q = q_min;
                break;
            } else {
                if (C::is_max) {
                    s0 = thresh;
                } else {
                    s1 = thresh;
                }
            }
        } else if (n_lt <= q_max) {
            q = n_lt;
            break;
        } else {
            if (C::is_max) {
                s1 = thresh;
            } else {
                s0 = thresh;
            }
        }
    }

    uint64_t t1 = get_cy();

    // number of equal values to keep
    int64_t n_eq_1 = q - n_lt;

    IFV printf("shrink: thresh=%d q=%ld n_eq_1=%ld\n", thresh, q, n_eq_1);
    if (n_eq_1 < 0) { // happens when > q elements are at lower bound
        assert(s0 + 1 == s1);
        q = q_min;
        if (C::is_max) {
            thresh--;
        } else {
            thresh++;
        }
        n_eq_1 = q;
        IFV printf(" override: thresh=%d n_eq_1=%ld\n", thresh, n_eq_1);
    } else {
        assert(n_eq_1 <= n_eq);
    }

    size_t wp = simd_compress_array<C>(vals, ids, n, thresh, n_eq_1);

    IFV printf("wp=%ld\n", wp);
    assert(wp == q);
    if (q_out) {
        *q_out = q;
    }

    uint64_t t2 = get_cy();

    partition_stats.bissect_cycles += t1 - t0;
    partition_stats.compress_cycles += t2 - t1;

    return thresh;
}

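/* variant of simd_partition_fuzzy_with_bounds that narrows down the
 * threshold with 8- or 16-bin SIMD histograms of the current [s0, s1] range
 * instead of counting around a single midpoint */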
template <class C>
uint16_t simd_partition_fuzzy_with_bounds_histogram(
        uint16_t* vals,
        typename C::TI* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out,
        uint16_t s0i,
        uint16_t s1i) {
    if (q_min == 0) {
        if (q_out) {
            *q_out = 0;
        }
        return 0;
    }
    if (q_max >= n) {
        if (q_out) {
            *q_out = q_max;
        }
        return 0xffff;
    }
    if (s0i == s1i) {
        if (q_out) {
            *q_out = q_min;
        }
        return s0i;
    }

    IFV printf(
            "partition fuzzy, q=%ld:%ld / %ld, bounds=%d %d\n",
            q_min,
            q_max,
            n,
            s0i,
            s1i);

    if (!C::is_max) {
        IFV printf(
                "revert due to CMin, q_min:q_max -> %ld:%ld\n", q_min, q_max);
        q_min = n - q_min;
        q_max = n - q_max;
    }

    // lower and upper bound of range, inclusive
    int s0 = s0i, s1 = s1i;
    // number of values < s0 and > s1
    size_t n_lt = 0, n_gt = 0;

    // output of loop:
    int thresh;          // final threshold
    uint64_t tot_eq = 0; // total nb of equal values
    uint64_t n_eq = 0;   // nb of equal values to keep
    size_t q;            // final quantile

    // buffer for the histograms
    int hist[16];

    for (int it = 0; it < 20; it++) {
        // otherwise we would be done already

        int shift = 0;

        IFV printf(
                " it %d bounds: %d %d n_lt=%ld n_gt=%ld\n",
                it,
                s0,
                s1,
                n_lt,
                n_gt);

        int maxval = s1 - s0;

        while (maxval > 15) {
            shift++;
            maxval >>= 1;
        }

        IFV printf(
                " histogram shift %d maxval %d ?= %d\n",
                shift,
                maxval,
                int((s1 - s0) >> shift));

        if (maxval > 7) {
            simd_histogram_16(vals, n, s0, shift, hist);
        } else {
            simd_histogram_8(vals, n, s0, shift, hist);
        }
        IFV {
            int sum = n_lt + n_gt;
            printf(" n_lt=%ld hist=[", n_lt);
            for (int i = 0; i <= maxval; i++) {
                printf("%d ", hist[i]);
                sum += hist[i];
            }
            printf("] n_gt=%ld sum=%d\n", n_gt, sum);
            assert(sum == n);
        }

        size_t sum_below = n_lt;
        int i;
        for (i = 0; i <= maxval; i++) {
            sum_below += hist[i];
            if (sum_below >= q_min) {
                break;
            }
        }
        IFV printf(" i=%d sum_below=%ld\n", i, sum_below);
        if (i <= maxval) {
            s0 = s0 + (i << shift);
            s1 = s0 + (1 << shift) - 1;
            n_lt = sum_below - hist[i];
            n_gt = n - sum_below;
        } else {
            assert(!"not implemented");
        }

        IFV printf(
                " new bin: s0=%d s1=%d n_lt=%ld n_gt=%ld\n",
                s0,
                s1,
                n_lt,
                n_gt);

        if (s1 > s0) {
            if (n_lt >= q_min && q_max >= n_lt) {
                IFV printf(" FOUND1\n");
                thresh = s0;
                q = n_lt;
                break;
            }

            size_t n_lt_2 = n - n_gt;
            if (n_lt_2 >= q_min && q_max >= n_lt_2) {
                thresh = s1 + 1;
                q = n_lt_2;
                IFV printf(" FOUND2\n");
                break;
            }
        } else {
            thresh = s0;
            q = q_min;
            tot_eq = n - n_gt - n_lt;
            n_eq = q_min - n_lt;
            IFV printf(" FOUND3\n");
            break;
        }
    }

    IFV printf("end bisection: thresh=%d q=%ld n_eq=%ld\n", thresh, q, n_eq);

    if (!C::is_max) {
        if (n_eq == 0) {
            thresh--;
        } else {
            // thresh unchanged
            n_eq = tot_eq - n_eq;
        }
        q = n - q;
        IFV printf("revert due to CMin, q->%ld n_eq->%ld\n", q, n_eq);
    }

    size_t wp = simd_compress_array<C>(vals, ids, n, thresh, n_eq);
    IFV printf("wp=%ld ?= %ld\n", wp, q);
    assert(wp == q);
    if (q_out) {
        *q_out = wp;
    }

    return thresh;
}

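/* entry point of the SIMD code path: compute the min/max bounds of vals,
 * then bisect on the threshold */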
template <class C>
uint16_t simd_partition_fuzzy(
        uint16_t* vals,
        typename C::TI* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out) {
    assert(is_aligned_pointer(vals));

    uint16_t s0i, s1i;
    find_minimax(vals, n, s0i, s1i);
    // QSelect_stats.t0 += get_cy() - t0;

    return simd_partition_fuzzy_with_bounds<C>(
            vals, ids, n, q_min, q_max, q_out, s0i, s1i);
}

template <class C>
uint16_t simd_partition(
        uint16_t* vals,
        typename C::TI* ids,
        size_t n,
        size_t q) {
    assert(is_aligned_pointer(vals));

    if (q == 0) {
        return 0;
    }
    if (q >= n) {
        return 0xffff;
    }

    uint16_t s0i, s1i;
    find_minimax(vals, n, s0i, s1i);

    return simd_partition_fuzzy_with_bounds<C>(
            vals, ids, n, q, q, nullptr, s0i, s1i);
}

template <class C>
uint16_t simd_partition_with_bounds(
        uint16_t* vals,
        typename C::TI* ids,
        size_t n,
        size_t q,
        uint16_t s0i,
        uint16_t s1i) {
    return simd_partition_fuzzy_with_bounds<C>(
            vals, ids, n, q, q, nullptr, s0i, s1i);
}

} // namespace simd_partitioning

/******************************************************************
 * Driver routine
 ******************************************************************/

template <class C>
typename C::T partition_fuzzy(
        typename C::T* vals,
        typename C::TI* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out) {
    // the code below compiles and runs without AVX2 but it's slower than
    // the scalar implementation
#ifdef __AVX2__
    constexpr bool is_uint16 = std::is_same<typename C::T, uint16_t>::value;
    if (is_uint16 && is_aligned_pointer(vals)) {
        return simd_partitioning::simd_partition_fuzzy<C>(
                (uint16_t*)vals, ids, n, q_min, q_max, q_out);
    }
#endif
    return partitioning::partition_fuzzy_median3<C>(
            vals, ids, n, q_min, q_max, q_out);
}

// explicit template instantiations

template float partition_fuzzy<CMin<float, int64_t>>(
        float* vals,
        int64_t* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out);

template float partition_fuzzy<CMax<float, int64_t>>(
        float* vals,
        int64_t* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out);

template uint16_t partition_fuzzy<CMin<uint16_t, int64_t>>(
        uint16_t* vals,
        int64_t* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out);

template uint16_t partition_fuzzy<CMax<uint16_t, int64_t>>(
        uint16_t* vals,
        int64_t* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out);

template uint16_t partition_fuzzy<CMin<uint16_t, int>>(
        uint16_t* vals,
        int* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out);

template uint16_t partition_fuzzy<CMax<uint16_t, int>>(
        uint16_t* vals,
        int* ids,
        size_t n,
        size_t q_min,
        size_t q_max,
        size_t* q_out);

/******************************************************************
 * Histogram subroutines
 ******************************************************************/

#ifdef __AVX2__
/// FIXME when MSB of uint16 is set
// this code does not compile properly with GCC 7.4.0

namespace {

/************************************************************
 * 8 bins
 ************************************************************/

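// widen the packed 4-bit counters of a4 into packed 8-bit counters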
simd32uint8 accu4to8(simd16uint16 a4) {
    simd16uint16 mask4(0x0f0f);

    simd16uint16 a8_0 = a4 & mask4;
    simd16uint16 a8_1 = (a4 >> 4) & mask4;

    return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
}

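// widen the packed 8-bit counters of a8 into packed 16-bit counters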
simd16uint16 accu8to16(simd32uint8 a8) {
    simd16uint16 mask8(0x00ff);

    simd16uint16 a8_0 = simd16uint16(a8) & mask8;
    simd16uint16 a8_1 = (simd16uint16(a8) >> 8) & mask8;

    return simd16uint16(_mm256_hadd_epi16(a8_0.i, a8_1.i));
}

static const simd32uint8 shifts(_mm256_setr_epi8(
        1,
        16,
        0,
        0,
        4,
        64,
        0,
        0,
        0,
        0,
        1,
        16,
        0,
        0,
        4,
        64,
        1,
        16,
        0,
        0,
        4,
        64,
        0,
        0,
        0,
        0,
        1,
        16,
        0,
        0,
        4,
        64));

// 2-bit accumulator: we can add only up to 3 elements
// on output we return 2*4-bit results
// preproc returns either an index in 0..7 or 0xffff
// that yields a 0 when used in the table look-up
template <int N, class Preproc>
void compute_accu2(
        const uint16_t*& data,
        Preproc& pp,
        simd16uint16& a4lo,
        simd16uint16& a4hi) {
    simd16uint16 mask2(0x3333);
    simd16uint16 a2((uint16_t)0); // 2-bit accu
    for (int j = 0; j < N; j++) {
        simd16uint16 v(data);
        data += 16;
        v = pp(v);
        // 0x800 -> force second half of table
        simd16uint16 idx = v | (v << 8) | simd16uint16(0x800);
        a2 += simd16uint16(shifts.lookup_2_lanes(simd32uint8(idx)));
    }
    a4lo += a2 & mask2;
    a4hi += (a2 >> 2) & mask2;
}

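/* 8-bin histogram of the (preprocessed) values; n_in must be a multiple of
 * 16. The data is processed in chunks so that the narrow packed counters
 * cannot overflow; the two 128-bit lanes of the result remain to be summed */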
template <class Preproc>
simd16uint16 histogram_8(const uint16_t* data, Preproc pp, size_t n_in) {
    assert(n_in % 16 == 0);
    int n = n_in / 16;

    simd32uint8 a8lo(0);
    simd32uint8 a8hi(0);

    for (int i0 = 0; i0 < n; i0 += 15) {
        simd16uint16 a4lo(0); // 4-bit accus
        simd16uint16 a4hi(0);

        int i1 = std::min(i0 + 15, n);
        int i;
        for (i = i0; i + 2 < i1; i += 3) {
            compute_accu2<3>(data, pp, a4lo, a4hi); // adds 3 max
        }
        switch (i1 - i) {
            case 2:
                compute_accu2<2>(data, pp, a4lo, a4hi);
                break;
            case 1:
                compute_accu2<1>(data, pp, a4lo, a4hi);
                break;
        }

        a8lo += accu4to8(a4lo);
        a8hi += accu4to8(a4hi);
    }

    // move to 16-bit accu
    simd16uint16 a16lo = accu8to16(a8lo);
    simd16uint16 a16hi = accu8to16(a8hi);

    simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));

    // the 2 lanes must still be combined
    return a16;
}

/************************************************************
 * 16 bins
 ************************************************************/

static const simd32uint8 shifts2(_mm256_setr_epi8(
        1,
        2,
        4,
        8,
        16,
        32,
        64,
        (char)128,
        1,
        2,
        4,
        8,
        16,
        32,
        64,
        (char)128,
        1,
        2,
        4,
        8,
        16,
        32,
        64,
        (char)128,
        1,
        2,
        4,
        8,
        16,
        32,
        64,
        (char)128));

simd32uint8 shiftr_16(simd32uint8 x, int n) {
    return simd32uint8(simd16uint16(x) >> n);
}

inline simd32uint8 combine_2x2(simd32uint8 a, simd32uint8 b) {
    __m256i a1b0 = _mm256_permute2f128_si256(a.i, b.i, 0x21);
    __m256i a0b1 = _mm256_blend_epi32(a.i, b.i, 0xF0);

    return simd32uint8(a1b0) + simd32uint8(a0b1);
}

// 2-bit accumulator: we can add only up to 3 elements
// on output we return 2*4-bit results
template <int N, class Preproc>
void compute_accu2_16(
        const uint16_t*& data,
        Preproc pp,
        simd32uint8& a4_0,
        simd32uint8& a4_1,
        simd32uint8& a4_2,
        simd32uint8& a4_3) {
    simd32uint8 mask1(0x55);
    simd32uint8 a2_0; // 2-bit accu
    simd32uint8 a2_1; // 2-bit accu
    a2_0.clear();
    a2_1.clear();

    for (int j = 0; j < N; j++) {
        simd16uint16 v(data);
        data += 16;
        v = pp(v);

        simd16uint16 idx = v | (v << 8);
        simd32uint8 a1 = shifts2.lookup_2_lanes(simd32uint8(idx));
        // contains 0s for out-of-bounds elements

        simd16uint16 lt8 = (v >> 3) == simd16uint16(0);
        lt8.i = _mm256_xor_si256(lt8.i, _mm256_set1_epi16(0xff00));

        a1 = a1 & lt8;

        a2_0 += a1 & mask1;
        a2_1 += shiftr_16(a1, 1) & mask1;
    }
    simd32uint8 mask2(0x33);

    a4_0 += a2_0 & mask2;
    a4_1 += a2_1 & mask2;
    a4_2 += shiftr_16(a2_0, 2) & mask2;
    a4_3 += shiftr_16(a2_1, 2) & mask2;
}

simd32uint8 accu4to8_2(simd32uint8 a4_0, simd32uint8 a4_1) {
    simd32uint8 mask4(0x0f);

    simd32uint8 a8_0 = combine_2x2(a4_0 & mask4, shiftr_16(a4_0, 4) & mask4);

    simd32uint8 a8_1 = combine_2x2(a4_1 & mask4, shiftr_16(a4_1, 4) & mask4);

    return simd32uint8(_mm256_hadd_epi16(a8_0.i, a8_1.i));
}

template <class Preproc>
simd16uint16 histogram_16(const uint16_t* data, Preproc pp, size_t n_in) {
    assert(n_in % 16 == 0);
    int n = n_in / 16;

    simd32uint8 a8lo((uint8_t)0);
    simd32uint8 a8hi((uint8_t)0);

    for (int i0 = 0; i0 < n; i0 += 7) {
        simd32uint8 a4_0(0); // 0, 4, 8, 12
        simd32uint8 a4_1(0); // 1, 5, 9, 13
        simd32uint8 a4_2(0); // 2, 6, 10, 14
        simd32uint8 a4_3(0); // 3, 7, 11, 15

        int i1 = std::min(i0 + 7, n);
        int i;
        for (i = i0; i + 2 < i1; i += 3) {
            compute_accu2_16<3>(data, pp, a4_0, a4_1, a4_2, a4_3);
        }
        switch (i1 - i) {
            case 2:
                compute_accu2_16<2>(data, pp, a4_0, a4_1, a4_2, a4_3);
                break;
            case 1:
                compute_accu2_16<1>(data, pp, a4_0, a4_1, a4_2, a4_3);
                break;
        }

        a8lo += accu4to8_2(a4_0, a4_1);
        a8hi += accu4to8_2(a4_2, a4_3);
    }

    // move to 16-bit accu
    simd16uint16 a16lo = accu8to16(a8lo);
    simd16uint16 a16hi = accu8to16(a8hi);

    simd16uint16 a16 = simd16uint16(_mm256_hadd_epi16(a16lo.i, a16hi.i));

    __m256i perm32 = _mm256_setr_epi32(0, 2, 4, 6, 1, 3, 5, 7);
    a16.i = _mm256_permutevar8x32_epi32(a16.i, perm32);

    return a16;
}

struct PreprocNOP {
    simd16uint16 operator()(simd16uint16 x) {
        return x;
    }
};

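/* preprocessor that maps a value v to its bin index (v - min) >> shift, or
 * to 0xffff when v falls outside [min, min + (nbin << shift)), so that the
 * table lookup yields 0 for out-of-range values */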
template <int shift, int nbin>
struct PreprocMinShift {
    simd16uint16 min16;
    simd16uint16 max16;

    explicit PreprocMinShift(uint16_t min) {
        min16.set1(min);
        int vmax0 = std::min((nbin << shift) + min, 65536);
        uint16_t vmax = uint16_t(vmax0 - 1 - min);
        max16.set1(vmax); // vmax inclusive
    }

    simd16uint16 operator()(simd16uint16 x) {
        x = x - min16;
        simd16uint16 mask = (x == max(x, max16)) - (x == max16);
        return (x >> shift) | mask;
    }
};

/* unbounded versions of the functions */

void simd_histogram_8_unbounded(const uint16_t* data, int n, int* hist) {
    PreprocNOP pp;
    simd16uint16 a16 = histogram_8(data, pp, (n & ~15));

    ALIGNED(32) uint16_t a16_tab[16];
    a16.store(a16_tab);

    for (int i = 0; i < 8; i++) {
        hist[i] = a16_tab[i] + a16_tab[i + 8];
    }

    for (int i = (n & ~15); i < n; i++) {
        hist[data[i]]++;
    }
}

void simd_histogram_16_unbounded(const uint16_t* data, int n, int* hist) {
    simd16uint16 a16 = histogram_16(data, PreprocNOP(), (n & ~15));

    ALIGNED(32) uint16_t a16_tab[16];
    a16.store(a16_tab);

    for (int i = 0; i < 16; i++) {
        hist[i] = a16_tab[i];
    }

    for (int i = (n & ~15); i < n; i++) {
        hist[data[i]]++;
    }
}

} // anonymous namespace

/************************************************************
 * Driver routines
 ************************************************************/

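/* 8-bin histogram of (data[i] - min) >> shift, ignoring out-of-range values;
 * shift < 0 selects the unbounded variant, which assumes the values are
 * already valid bin indices; shift is dispatched to a compile-time constant */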
void simd_histogram_8(
        const uint16_t* data,
        int n,
        uint16_t min,
        int shift,
        int* hist) {
    if (shift < 0) {
        simd_histogram_8_unbounded(data, n, hist);
        return;
    }

    simd16uint16 a16;

#define DISPATCH(s)                                                     \
    case s:                                                             \
        a16 = histogram_8(data, PreprocMinShift<s, 8>(min), (n & ~15)); \
        break

    switch (shift) {
        DISPATCH(0);
        DISPATCH(1);
        DISPATCH(2);
        DISPATCH(3);
        DISPATCH(4);
        DISPATCH(5);
        DISPATCH(6);
        DISPATCH(7);
        DISPATCH(8);
        DISPATCH(9);
        DISPATCH(10);
        DISPATCH(11);
        DISPATCH(12);
        DISPATCH(13);
        default:
            FAISS_THROW_FMT("dispatch for shift=%d not instantiated", shift);
    }
#undef DISPATCH

    ALIGNED(32) uint16_t a16_tab[16];
    a16.store(a16_tab);

    for (int i = 0; i < 8; i++) {
        hist[i] = a16_tab[i] + a16_tab[i + 8];
    }

    // complete with the remaining (scalar) elements
    for (int i = (n & ~15); i < n; i++) {
        if (data[i] < min)
            continue;
        uint16_t v = data[i] - min;
        v >>= shift;
        if (v < 8)
            hist[v]++;
    }
}

void simd_histogram_16(
        const uint16_t* data,
        int n,
        uint16_t min,
        int shift,
        int* hist) {
    if (shift < 0) {
        simd_histogram_16_unbounded(data, n, hist);
        return;
    }

    simd16uint16 a16;

#define DISPATCH(s)                                                       \
    case s:                                                               \
        a16 = histogram_16(data, PreprocMinShift<s, 16>(min), (n & ~15)); \
        break

    switch (shift) {
        DISPATCH(0);
        DISPATCH(1);
        DISPATCH(2);
        DISPATCH(3);
        DISPATCH(4);
        DISPATCH(5);
        DISPATCH(6);
        DISPATCH(7);
        DISPATCH(8);
        DISPATCH(9);
        DISPATCH(10);
        DISPATCH(11);
        DISPATCH(12);
        default:
            FAISS_THROW_FMT("dispatch for shift=%d not instantiated", shift);
    }
#undef DISPATCH

    ALIGNED(32) uint16_t a16_tab[16];
    a16.store(a16_tab);

    for (int i = 0; i < 16; i++) {
        hist[i] = a16_tab[i];
    }

    for (int i = (n & ~15); i < n; i++) {
        if (data[i] < min)
            continue;
        uint16_t v = data[i] - min;
        v >>= shift;
        if (v < 16)
            hist[v]++;
    }
}

// no AVX2
#else

void simd_histogram_16(
        const uint16_t* data,
        int n,
        uint16_t min,
        int shift,
        int* hist) {
    memset(hist, 0, sizeof(*hist) * 16);
    if (shift < 0) {
        for (size_t i = 0; i < n; i++) {
            hist[data[i]]++;
        }
    } else {
        int vmax0 = std::min((16 << shift) + min, 65536);
        uint16_t vmax = uint16_t(vmax0 - 1 - min);

        for (size_t i = 0; i < n; i++) {
            uint16_t v = data[i];
            v -= min;
            if (!(v <= vmax))
                continue;
            v >>= shift;
            hist[v]++;

            /*
            if (data[i] < min) continue;
            uint16_t v = data[i] - min;
            v >>= shift;
            if (v < 16) hist[v]++;
            */
        }
    }
}

void simd_histogram_8(
        const uint16_t* data,
        int n,
        uint16_t min,
        int shift,
        int* hist) {
    memset(hist, 0, sizeof(*hist) * 8);
    if (shift < 0) {
        for (size_t i = 0; i < n; i++) {
            hist[data[i]]++;
        }
    } else {
        for (size_t i = 0; i < n; i++) {
            if (data[i] < min)
                continue;
            uint16_t v = data[i] - min;
            v >>= shift;
            if (v < 8)
                hist[v]++;
        }
    }
}

#endif

void PartitionStats::reset() {
    memset(this, 0, sizeof(*this));
}

PartitionStats partition_stats;

} // namespace faiss