1 /**
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8 #include <faiss/impl/pq4_fast_scan.h>
9
10 #include <faiss/impl/FaissAssert.h>
11 #include <faiss/impl/simd_result_handlers.h>
12
13 namespace faiss {
14
15 using namespace simd_result_handlers;
16
17 /***************************************************************
18 * accumulation functions
19 ***************************************************************/
20
21 namespace {
22
23 /*
24 * The computation kernel
25 * It accumulates results for NQ queries and BB * 32 database elements
26 * writes results in a ResultHandler
27 */
28
29 template <int NQ, int BB, class ResultHandler>
kernel_accumulate_block(int nsq,const uint8_t * codes,const uint8_t * LUT,ResultHandler & res)30 void kernel_accumulate_block(
31 int nsq,
32 const uint8_t* codes,
33 const uint8_t* LUT,
34 ResultHandler& res) {
35 // distance accumulators
36 simd16uint16 accu[NQ][BB][4];
37
38 for (int q = 0; q < NQ; q++) {
39 for (int b = 0; b < BB; b++) {
40 accu[q][b][0].clear();
41 accu[q][b][1].clear();
42 accu[q][b][2].clear();
43 accu[q][b][3].clear();
44 }
45 }
46
47 for (int sq = 0; sq < nsq; sq += 2) {
48 simd32uint8 lut_cache[NQ];
49 for (int q = 0; q < NQ; q++) {
50 lut_cache[q] = simd32uint8(LUT);
51 LUT += 32;
52 }
53
54 for (int b = 0; b < BB; b++) {
55 simd32uint8 c = simd32uint8(codes);
56 codes += 32;
57 simd32uint8 mask(15);
58 simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
59 simd32uint8 clo = c & mask;
60
61 for (int q = 0; q < NQ; q++) {
62 simd32uint8 lut = lut_cache[q];
63 simd32uint8 res0 = lut.lookup_2_lanes(clo);
64 simd32uint8 res1 = lut.lookup_2_lanes(chi);
65
66 accu[q][b][0] += simd16uint16(res0);
67 accu[q][b][1] += simd16uint16(res0) >> 8;
68
69 accu[q][b][2] += simd16uint16(res1);
70 accu[q][b][3] += simd16uint16(res1) >> 8;
71 }
72 }
73 }
74
75 for (int q = 0; q < NQ; q++) {
76 for (int b = 0; b < BB; b++) {
77 accu[q][b][0] -= accu[q][b][1] << 8;
78 simd16uint16 dis0 = combine2x2(accu[q][b][0], accu[q][b][1]);
79
80 accu[q][b][2] -= accu[q][b][3] << 8;
81 simd16uint16 dis1 = combine2x2(accu[q][b][2], accu[q][b][3]);
82
83 res.handle(q, b, dis0, dis1);
84 }
85 }
86 }
87
88 template <int NQ, int BB, class ResultHandler>
accumulate_fixed_blocks(size_t nb,int nsq,const uint8_t * codes,const uint8_t * LUT,ResultHandler & res)89 void accumulate_fixed_blocks(
90 size_t nb,
91 int nsq,
92 const uint8_t* codes,
93 const uint8_t* LUT,
94 ResultHandler& res) {
95 constexpr int bbs = 32 * BB;
96 for (int64_t j0 = 0; j0 < nb; j0 += bbs) {
97 FixedStorageHandler<NQ, 2 * BB> res2;
98 kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2);
99 res.set_block_origin(0, j0);
100 res2.to_other_handler(res);
101 codes += bbs * nsq / 2;
102 }
103 }
104
105 } // anonymous namespace
106
107 template <class ResultHandler>
pq4_accumulate_loop(int nq,size_t nb,int bbs,int nsq,const uint8_t * codes,const uint8_t * LUT,ResultHandler & res)108 void pq4_accumulate_loop(
109 int nq,
110 size_t nb,
111 int bbs,
112 int nsq,
113 const uint8_t* codes,
114 const uint8_t* LUT,
115 ResultHandler& res) {
116 FAISS_THROW_IF_NOT(is_aligned_pointer(codes));
117 FAISS_THROW_IF_NOT(is_aligned_pointer(LUT));
118 FAISS_THROW_IF_NOT(bbs % 32 == 0);
119 FAISS_THROW_IF_NOT(nb % bbs == 0);
120
121 #define DISPATCH(NQ, BB) \
122 case NQ * 1000 + BB: \
123 accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res); \
124 break
125
126 switch (nq * 1000 + bbs / 32) {
127 DISPATCH(1, 1);
128 DISPATCH(1, 2);
129 DISPATCH(1, 3);
130 DISPATCH(1, 4);
131 DISPATCH(1, 5);
132 DISPATCH(2, 1);
133 DISPATCH(2, 2);
134 DISPATCH(3, 1);
135 DISPATCH(4, 1);
136 default:
137 FAISS_THROW_FMT("nq=%d bbs=%d not instantiated", nq, bbs);
138 }
139 #undef DISPATCH
140 }
141
142 // explicit template instantiations
143
144 #define INSTANTIATE_ACCUMULATE(TH, C, with_id_map) \
145 template void pq4_accumulate_loop<TH<C, with_id_map>>( \
146 int, \
147 size_t, \
148 int, \
149 int, \
150 const uint8_t*, \
151 const uint8_t*, \
152 TH<C, with_id_map>&);
153
154 #define INSTANTIATE_3(C, with_id_map) \
155 INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map) \
156 INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map) \
157 INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map)
158
159 using Csi = CMax<uint16_t, int>;
160 INSTANTIATE_3(Csi, false);
161 using CsiMin = CMin<uint16_t, int>;
162 INSTANTIATE_3(CsiMin, false);
163
164 using Csl = CMax<uint16_t, int64_t>;
165 INSTANTIATE_3(Csl, true);
166 using CslMin = CMin<uint16_t, int64_t>;
167 INSTANTIATE_3(CslMin, true);
168
169 } // namespace faiss
170