1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #include <faiss/impl/pq4_fast_scan.h>
9 
10 #include <faiss/impl/FaissAssert.h>
11 #include <faiss/impl/simd_result_handlers.h>
12 
13 namespace faiss {
14 
// Make the names of simd_result_handlers visible without qualification.
// NOTE(review): presumably this is where the handler templates used below
// (FixedStorageHandler, SingleResultHandler, HeapHandler, ReservoirHandler)
// are declared -- confirm against simd_result_handlers.h.
using namespace simd_result_handlers;
16 
17 /***************************************************************
18  * accumulation functions
19  ***************************************************************/
20 
21 namespace {
22 
/*
 * The computation kernel.
 * It accumulates results for NQ queries and BB * 32 database elements,
 * and writes the results to a ResultHandler.
 */
28 
29 template <int NQ, int BB, class ResultHandler>
kernel_accumulate_block(int nsq,const uint8_t * codes,const uint8_t * LUT,ResultHandler & res)30 void kernel_accumulate_block(
31         int nsq,
32         const uint8_t* codes,
33         const uint8_t* LUT,
34         ResultHandler& res) {
35     // distance accumulators
36     simd16uint16 accu[NQ][BB][4];
37 
38     for (int q = 0; q < NQ; q++) {
39         for (int b = 0; b < BB; b++) {
40             accu[q][b][0].clear();
41             accu[q][b][1].clear();
42             accu[q][b][2].clear();
43             accu[q][b][3].clear();
44         }
45     }
46 
47     for (int sq = 0; sq < nsq; sq += 2) {
48         simd32uint8 lut_cache[NQ];
49         for (int q = 0; q < NQ; q++) {
50             lut_cache[q] = simd32uint8(LUT);
51             LUT += 32;
52         }
53 
54         for (int b = 0; b < BB; b++) {
55             simd32uint8 c = simd32uint8(codes);
56             codes += 32;
57             simd32uint8 mask(15);
58             simd32uint8 chi = simd32uint8(simd16uint16(c) >> 4) & mask;
59             simd32uint8 clo = c & mask;
60 
61             for (int q = 0; q < NQ; q++) {
62                 simd32uint8 lut = lut_cache[q];
63                 simd32uint8 res0 = lut.lookup_2_lanes(clo);
64                 simd32uint8 res1 = lut.lookup_2_lanes(chi);
65 
66                 accu[q][b][0] += simd16uint16(res0);
67                 accu[q][b][1] += simd16uint16(res0) >> 8;
68 
69                 accu[q][b][2] += simd16uint16(res1);
70                 accu[q][b][3] += simd16uint16(res1) >> 8;
71             }
72         }
73     }
74 
75     for (int q = 0; q < NQ; q++) {
76         for (int b = 0; b < BB; b++) {
77             accu[q][b][0] -= accu[q][b][1] << 8;
78             simd16uint16 dis0 = combine2x2(accu[q][b][0], accu[q][b][1]);
79 
80             accu[q][b][2] -= accu[q][b][3] << 8;
81             simd16uint16 dis1 = combine2x2(accu[q][b][2], accu[q][b][3]);
82 
83             res.handle(q, b, dis0, dis1);
84         }
85     }
86 }
87 
88 template <int NQ, int BB, class ResultHandler>
accumulate_fixed_blocks(size_t nb,int nsq,const uint8_t * codes,const uint8_t * LUT,ResultHandler & res)89 void accumulate_fixed_blocks(
90         size_t nb,
91         int nsq,
92         const uint8_t* codes,
93         const uint8_t* LUT,
94         ResultHandler& res) {
95     constexpr int bbs = 32 * BB;
96     for (int64_t j0 = 0; j0 < nb; j0 += bbs) {
97         FixedStorageHandler<NQ, 2 * BB> res2;
98         kernel_accumulate_block<NQ, BB>(nsq, codes, LUT, res2);
99         res.set_block_origin(0, j0);
100         res2.to_other_handler(res);
101         codes += bbs * nsq / 2;
102     }
103 }
104 
105 } // anonymous namespace
106 
107 template <class ResultHandler>
pq4_accumulate_loop(int nq,size_t nb,int bbs,int nsq,const uint8_t * codes,const uint8_t * LUT,ResultHandler & res)108 void pq4_accumulate_loop(
109         int nq,
110         size_t nb,
111         int bbs,
112         int nsq,
113         const uint8_t* codes,
114         const uint8_t* LUT,
115         ResultHandler& res) {
116     FAISS_THROW_IF_NOT(is_aligned_pointer(codes));
117     FAISS_THROW_IF_NOT(is_aligned_pointer(LUT));
118     FAISS_THROW_IF_NOT(bbs % 32 == 0);
119     FAISS_THROW_IF_NOT(nb % bbs == 0);
120 
121 #define DISPATCH(NQ, BB)                                           \
122     case NQ * 1000 + BB:                                           \
123         accumulate_fixed_blocks<NQ, BB>(nb, nsq, codes, LUT, res); \
124         break
125 
126     switch (nq * 1000 + bbs / 32) {
127         DISPATCH(1, 1);
128         DISPATCH(1, 2);
129         DISPATCH(1, 3);
130         DISPATCH(1, 4);
131         DISPATCH(1, 5);
132         DISPATCH(2, 1);
133         DISPATCH(2, 2);
134         DISPATCH(3, 1);
135         DISPATCH(4, 1);
136         default:
137             FAISS_THROW_FMT("nq=%d bbs=%d not instantiated", nq, bbs);
138     }
139 #undef DISPATCH
140 }
141 
// explicit template instantiations

// Emits one explicit instantiation of pq4_accumulate_loop for the handler
// type TH<C, with_id_map>, where TH is a handler template, C a comparator
// type and with_id_map a boolean template argument.
#define INSTANTIATE_ACCUMULATE(TH, C, with_id_map)         \
    template void pq4_accumulate_loop<TH<C, with_id_map>>( \
            int,                                           \
            size_t,                                        \
            int,                                           \
            int,                                           \
            const uint8_t*,                                \
            const uint8_t*,                                \
            TH<C, with_id_map>&);

// Instantiates pq4_accumulate_loop for the three handler templates
// (SingleResultHandler, HeapHandler, ReservoirHandler) with the same
// comparator and with_id_map value.
#define INSTANTIATE_3(C, with_id_map)                           \
    INSTANTIATE_ACCUMULATE(SingleResultHandler, C, with_id_map) \
    INSTANTIATE_ACCUMULATE(HeapHandler, C, with_id_map)         \
    INSTANTIATE_ACCUMULATE(ReservoirHandler, C, with_id_map)

// Comparators over uint16_t values with int labels, instantiated with
// with_id_map = false.
using Csi = CMax<uint16_t, int>;
INSTANTIATE_3(Csi, false);
using CsiMin = CMin<uint16_t, int>;
INSTANTIATE_3(CsiMin, false);

// Comparators over uint16_t values with int64_t labels, instantiated with
// with_id_map = true.
using Csl = CMax<uint16_t, int64_t>;
INSTANTIATE_3(Csl, true);
using CslMin = CMin<uint16_t, int64_t>;
INSTANTIATE_3(CslMin, true);
168 
169 } // namespace faiss
170