1 /**
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
#include <faiss/gpu/utils/DeviceUtils.h>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/Select.cuh>

#include <limits>
12
namespace faiss {
namespace gpu {

// Number of warps that the kernel is instantiated with
constexpr int kWarps = 8;
// Lanes per warp (kWarpSize comes from DeviceDefs.cuh; 32 on current hardware)
constexpr int kLanes = kWarpSize;

// Sentinel distance used to pad out-of-bounds (query, vec) slots so they sort
// after every real Hamming distance in the k-selection
constexpr int kMaxDistance = std::numeric_limits<int>::max();
21
// Performs a binary matrix multiplication, returning the lowest k results in
// `vecs` for each `query` in terms of Hamming distance (a fused kernel)
// Each warp calculates distance for a single query
//
// Launch expectations: block = (kLanes, kWarps); grid.x =
// divUp(query.getSize(0), kWarps). Handles any reduction (dimension) size by
// looping over kLanes-wide tiles of the reduction dimension.
template <int NumWarpQ, int NumThreadQ, typename BinaryType>
__launch_bounds__(kWarps* kLanes) __global__ void binaryDistanceAnySize(
        const Tensor<BinaryType, 2, true> vecs,
        const Tensor<BinaryType, 2, true> query,
        Tensor<int, 2, true> outK,
        Tensor<int, 2, true> outV,
        int k) {
    // A matrix tile (query, k)
    __shared__ BinaryType queryTile[kWarps][kLanes + 1]; // avoid bank conflict

    // B matrix tile (vec, k)
    __shared__ BinaryType vecTile[kLanes][kLanes + 1]; // avoid bank conflict

    // Per-warp k-selection heap, seeded with the sentinel distance so that
    // unfilled slots compare after every real result
    WarpSelect<
            int,
            int,
            false,
            Comparator<int>,
            NumWarpQ,
            NumThreadQ,
            kWarps * kLanes>
            heap(kMaxDistance, -1, k);

    int warpId = threadIdx.y;
    int laneId = threadIdx.x;

    // Each warp handles a single query
    int warpQuery = blockIdx.x * kWarps + warpId;
    bool queryInBounds = warpQuery < query.getSize(0);

    // Each warp loops through the entire chunk of vectors
    for (int blockVec = 0; blockVec < vecs.getSize(0); blockVec += kLanes) {
        int threadDistance = 0;

        // Reduction dimension
        for (int blockK = 0; blockK < vecs.getSize(1); blockK += kLanes) {
            int laneK = blockK + laneId;
            bool kInBounds = laneK < vecs.getSize(1);

            // Out-of-bounds slots are zero-filled: XOR with 0 contributes no
            // popcount, so padding never affects the distance
            queryTile[warpId][laneId] =
                    queryInBounds && kInBounds ? query[warpQuery][laneK] : 0;

            // kWarps warps are responsible for loading 32 vecs
#pragma unroll
            for (int i = 0; i < kLanes / kWarps; ++i) {
                int warpVec = i * kWarps + warpId;
                int vec = blockVec + warpVec;
                bool vecInBounds = vec < vecs.getSize(0);

                vecTile[warpVec][laneId] =
                        vecInBounds && kInBounds ? vecs[vec][laneK] : 0;
            }

            // Both tiles must be fully written before any thread reads them
            __syncthreads();

            // Compare distances
#pragma unroll
            for (int i = 0; i < kLanes; ++i) {
                threadDistance +=
                        __popc(queryTile[warpId][i] ^ vecTile[laneId][i]);
            }

            // All reads must complete before the tiles are overwritten on the
            // next reduction iteration
            __syncthreads();
        }

        // Lanes within a warp are different vec results against the same query
        // Only submit distances which represent real (query, vec) pairs
        bool valInBounds =
                queryInBounds && (blockVec + laneId < vecs.getSize(0));
        threadDistance = valInBounds ? threadDistance : kMaxDistance;
        int id = valInBounds ? blockVec + laneId : -1;

        heap.add(threadDistance, id);
    }

    // Merge the per-lane thread queues into the final warp-wide top-k
    heap.reduce();

    if (warpQuery < query.getSize(0)) {
        heap.writeOut(outK[warpQuery].data(), outV[warpQuery].data(), k);
    }
}
106
// Version of the kernel that avoids a loop over the reduction dimension, and
// thus avoids reloading the query vectors
//
// Precondition: vecs.getSize(1) <= ReductionLimit (<= kLanes), so the whole
// reduction dimension fits in a single tile and each query row is staged into
// shared memory exactly once. Launch layout matches binaryDistanceAnySize.
template <
        int NumWarpQ,
        int NumThreadQ,
        typename BinaryType,
        int ReductionLimit = kLanes>
__global__ void __launch_bounds__(kWarps* kLanes) binaryDistanceLimitSize(
        const Tensor<BinaryType, 2, true> vecs,
        const Tensor<BinaryType, 2, true> query,
        Tensor<int, 2, true> outK,
        Tensor<int, 2, true> outV,
        int k) {
    // A matrix tile (query, k)
    __shared__ BinaryType queryTile[kWarps][kLanes + 1]; // avoid bank conflict

    // B matrix tile (vec, k)
    __shared__ BinaryType vecTile[kLanes][kLanes + 1]; // avoid bank conflict

    // Per-warp k-selection heap, seeded with the sentinel distance so that
    // unfilled slots compare after every real result
    WarpSelect<
            int,
            int,
            false,
            Comparator<int>,
            NumWarpQ,
            NumThreadQ,
            kWarps * kLanes>
            heap(kMaxDistance, -1, k);

    int warpId = threadIdx.y;
    int laneId = threadIdx.x;

    // Each warp handles a single query
    int laneK = laneId;
    int warpQuery = blockIdx.x * kWarps + warpId;
    bool kInBounds = laneK < vecs.getSize(1);
    bool queryInBounds = warpQuery < query.getSize(0);

    // Query tile is loaded once, outside the vector loop; zero padding XORs
    // to no popcount so it never affects the distance
    queryTile[warpId][laneId] =
            queryInBounds && kInBounds ? query[warpQuery][laneK] : 0;

    // Each warp loops through the entire chunk of vectors
    for (int blockVec = 0; blockVec < vecs.getSize(0); blockVec += kLanes) {
        int threadDistance = 0;

        // kWarps warps are responsible for loading 32 vecs
#pragma unroll
        for (int i = 0; i < kLanes / kWarps; ++i) {
            int warpVec = i * kWarps + warpId;
            int vec = blockVec + warpVec;
            bool vecInBounds = vec < vecs.getSize(0);

            vecTile[warpVec][laneId] =
                    vecInBounds && kInBounds ? vecs[vec][laneK] : 0;
        }

        // Tiles (query tile on the first iteration, vec tile always) must be
        // fully written before any thread reads them
        __syncthreads();

        // Compare distances
#pragma unroll
        for (int i = 0; i < ReductionLimit; ++i) {
            threadDistance += __popc(queryTile[warpId][i] ^ vecTile[laneId][i]);
        }

        // All reads must complete before vecTile is overwritten next iteration
        __syncthreads();

        // Lanes within a warp are different vec results against the same query
        // Only submit distances which represent real (query, vec) pairs
        bool valInBounds =
                queryInBounds && (blockVec + laneId < vecs.getSize(0));
        threadDistance = valInBounds ? threadDistance : kMaxDistance;
        int id = valInBounds ? blockVec + laneId : -1;

        heap.add(threadDistance, id);
    }

    // Merge the per-lane thread queues into the final warp-wide top-k
    heap.reduce();

    if (warpQuery < query.getSize(0)) {
        heap.writeOut(outK[warpQuery].data(), outV[warpQuery].data(), k);
    }
}
189
// Host-side dispatcher for binaryDistanceAnySize: picks the (warp queue,
// thread queue) template instantiation matching k and launches it on `stream`.
// k must already be validated against GPU_MAX_SELECTION_K by the caller;
// outK/outV are (numQueries, k).
template <typename BinaryType>
void runBinaryDistanceAnySize(
        Tensor<BinaryType, 2, true>& vecs,
        Tensor<BinaryType, 2, true>& query,
        Tensor<int, 2, true>& outK,
        Tensor<int, 2, true>& outV,
        int k,
        cudaStream_t stream) {
    // One warp per query; kWarps queries handled per block
    dim3 grid(utils::divUp(query.getSize(0), kWarps));
    dim3 block(kLanes, kWarps);

    // Template parameters are compile-time, so the k dispatch must be an
    // if/else ladder; factor the identical launch expression into a macro
#define RUN_BINARY_ANY_SIZE(WARP_Q, THREAD_Q)               \
    binaryDistanceAnySize<WARP_Q, THREAD_Q, BinaryType>     \
            <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k)

    if (k == 1) {
        RUN_BINARY_ANY_SIZE(1, 1);
    } else if (k <= 32) {
        RUN_BINARY_ANY_SIZE(32, 2);
    } else if (k <= 64) {
        RUN_BINARY_ANY_SIZE(64, 3);
    } else if (k <= 128) {
        RUN_BINARY_ANY_SIZE(128, 3);
    } else if (k <= 256) {
        RUN_BINARY_ANY_SIZE(256, 4);
    } else if (k <= 512) {
        RUN_BINARY_ANY_SIZE(512, 8);
    } else if (k <= 1024) {
        RUN_BINARY_ANY_SIZE(1024, 8);
    }
#if GPU_MAX_SELECTION_K >= 2048
    else if (k <= 2048) {
        RUN_BINARY_ANY_SIZE(2048, 8);
    }
#endif

#undef RUN_BINARY_ANY_SIZE
}
230
// Host-side dispatcher for binaryDistanceLimitSize: picks the (warp queue,
// thread queue) template instantiation matching k and launches it on `stream`.
// Caller guarantees vecs.getSize(1) <= ReductionLimit and that k has been
// validated against GPU_MAX_SELECTION_K; outK/outV are (numQueries, k).
template <typename BinaryType, int ReductionLimit>
void runBinaryDistanceLimitSize(
        Tensor<BinaryType, 2, true>& vecs,
        Tensor<BinaryType, 2, true>& query,
        Tensor<int, 2, true>& outK,
        Tensor<int, 2, true>& outV,
        int k,
        cudaStream_t stream) {
    // One warp per query; kWarps queries handled per block
    dim3 grid(utils::divUp(query.getSize(0), kWarps));
    dim3 block(kLanes, kWarps);

    // Template parameters are compile-time, so the k dispatch must be an
    // if/else ladder; factor the identical launch expression into a macro
#define RUN_BINARY_LIMIT_SIZE(WARP_Q, THREAD_Q)                             \
    binaryDistanceLimitSize<WARP_Q, THREAD_Q, BinaryType, ReductionLimit>   \
            <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k)

    if (k == 1) {
        RUN_BINARY_LIMIT_SIZE(1, 1);
    } else if (k <= 32) {
        RUN_BINARY_LIMIT_SIZE(32, 2);
    } else if (k <= 64) {
        RUN_BINARY_LIMIT_SIZE(64, 3);
    } else if (k <= 128) {
        RUN_BINARY_LIMIT_SIZE(128, 3);
    } else if (k <= 256) {
        RUN_BINARY_LIMIT_SIZE(256, 4);
    } else if (k <= 512) {
        RUN_BINARY_LIMIT_SIZE(512, 8);
    } else if (k <= 1024) {
        RUN_BINARY_LIMIT_SIZE(1024, 8);
    }
#if GPU_MAX_SELECTION_K >= 2048
    else if (k <= 2048) {
        RUN_BINARY_LIMIT_SIZE(2048, 8);
    }
#endif

#undef RUN_BINARY_LIMIT_SIZE
}
271
// Entry point: k-nearest-neighbor search over binary (bit-packed uint8)
// vectors by Hamming distance. Writes the k smallest distances per query to
// outK and the corresponding vector ids to outV, launched on `stream`.
// Routes to a fixed-reduction-size kernel when the dimensionality permits,
// otherwise to the general any-size kernel.
void runBinaryDistance(
        Tensor<unsigned char, 2, true>& vecs,
        Tensor<unsigned char, 2, true>& query,
        Tensor<int, 2, true>& outK,
        Tensor<int, 2, true>& outV,
        int k,
        cudaStream_t stream) {
    FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);
    FAISS_ASSERT(vecs.getSize(1) == query.getSize(1));

    FAISS_ASSERT(outK.getSize(1) == k);
    FAISS_ASSERT(outV.getSize(1) == k);

    // For the optimized uint32 kernel, we handle 32 * 8 = 256 max dims
    constexpr int kReductionLimit32 = 8;

    // For the optimized uint8 kernel, we handle 8 * 16 = 128 max dims
    constexpr int kReductionLimit8 = 16;

    // Dimensionality in bytes (each byte packs 8 binary dimensions)
    int dims = vecs.getSize(1);

    // Optimize for vectors whose byte size is a multiple of 4 and that span
    // at most 32 * kReductionLimit32 (256) binary dimensions: reinterpret the
    // data as uint32 words for wider loads and popcounts
    if (dims % sizeof(unsigned int) == 0 &&
        (dims / sizeof(unsigned int)) <= kReductionLimit32) {
        auto vecs32 = vecs.castResize<unsigned int>();
        auto query32 = query.castResize<unsigned int>();

        runBinaryDistanceLimitSize<unsigned int, kReductionLimit32>(
                vecs32, query32, outK, outV, k, stream);
        return;
    }

    // Optimize for small vectors of at most kReductionLimit8 (16) bytes,
    // i.e. 8 * 16 = 128 binary dimensions, processed byte-wise
    if (dims <= kReductionLimit8) {
        runBinaryDistanceLimitSize<unsigned char, kReductionLimit8>(
                vecs, query, outK, outV, k, stream);
        return;
    }

    // All other cases (large or otherwise irregular sizes) go through the
    // general kernel
    runBinaryDistanceAnySize<unsigned char>(vecs, query, outK, outV, k, stream);
}

} // namespace gpu
} // namespace faiss
317