/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
7 
#include <faiss/gpu/utils/DeviceUtils.h>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/Select.cuh>

#include <limits>
12 
namespace faiss {
namespace gpu {
15 
// Number of warps that the kernel is instantiated with
constexpr int kWarps = 8;
// Lanes per warp; also the tile width used for the shared-memory tiles below
constexpr int kLanes = kWarpSize;

// Sentinel distance used to pad out-of-bounds (query, vec) slots; no real
// Hamming distance can reach INT_MAX, so padded slots never win selection
constexpr int kMaxDistance = std::numeric_limits<int>::max();
21 
// Performs a binary matrix multiplication, returning the lowest k results in
// `vecs` for each `query` in terms of Hamming distance (a fused kernel).
// Each warp calculates distances for a single query.
//
// Expected launch configuration: block = dim3(kLanes, kWarps),
// grid.x = ceil(query.getSize(0) / kWarps); no dynamic shared memory.
// Handles any reduction (code word) size.
template <int NumWarpQ, int NumThreadQ, typename BinaryType>
__launch_bounds__(kWarps* kLanes) __global__ void binaryDistanceAnySize(
        const Tensor<BinaryType, 2, true> vecs,
        const Tensor<BinaryType, 2, true> query,
        Tensor<int, 2, true> outK,
        Tensor<int, 2, true> outV,
        int k) {
    // A matrix tile (query, k); +1 inner pad to avoid bank conflicts
    __shared__ BinaryType queryTile[kWarps][kLanes + 1];

    // B matrix tile (vec, k); +1 inner pad to avoid bank conflicts
    __shared__ BinaryType vecTile[kLanes][kLanes + 1];

    // Per-warp k-selection of the smallest distances seen so far
    WarpSelect<
            int,
            int,
            false,
            Comparator<int>,
            NumWarpQ,
            NumThreadQ,
            kWarps * kLanes>
            heap(kMaxDistance, -1, k);

    int warpId = threadIdx.y;
    int laneId = threadIdx.x;

    // Each warp handles a single query
    int warpQuery = blockIdx.x * kWarps + warpId;
    bool queryInBounds = warpQuery < query.getSize(0);

    // Each warp loops through the entire chunk of vectors
    for (int blockVec = 0; blockVec < vecs.getSize(0); blockVec += kLanes) {
        int threadDistance = 0;

        // Reduction dimension
        for (int blockK = 0; blockK < vecs.getSize(1); blockK += kLanes) {
            int laneK = blockK + laneId;
            bool kInBounds = laneK < vecs.getSize(1);

            // Out-of-bounds slots load 0; 0 ^ 0 contributes no popcount
            queryTile[warpId][laneId] =
                    queryInBounds && kInBounds ? query[warpQuery][laneK] : 0;

            // The kWarps warps cooperate to load all kLanes vecs of the tile
#pragma unroll
            for (int i = 0; i < kLanes / kWarps; ++i) {
                int warpVec = i * kWarps + warpId;
                int vec = blockVec + warpVec;
                bool vecInBounds = vec < vecs.getSize(0);

                vecTile[warpVec][laneId] =
                        vecInBounds && kInBounds ? vecs[vec][laneK] : 0;
            }

            // Both tiles must be fully populated before any thread reads them
            __syncthreads();

            // Accumulate Hamming distance: popcount of XOR of query word and
            // vector word; each lane handles a different vector of the tile
#pragma unroll
            for (int i = 0; i < kLanes; ++i) {
                threadDistance +=
                        __popc(queryTile[warpId][i] ^ vecTile[laneId][i]);
            }

            // Reads must finish before the next iteration overwrites tiles
            __syncthreads();
        }

        // Lanes within a warp are different vec results against the same query
        // Only submit distances which represent real (query, vec) pairs
        bool valInBounds =
                queryInBounds && (blockVec + laneId < vecs.getSize(0));
        threadDistance = valInBounds ? threadDistance : kMaxDistance;
        int id = valInBounds ? blockVec + laneId : -1;

        heap.add(threadDistance, id);
    }

    heap.reduce();

    if (warpQuery < query.getSize(0)) {
        heap.writeOut(outK[warpQuery].data(), outV[warpQuery].data(), k);
    }
}
106 
// Version of the kernel that avoids a loop over the reduction dimension, and
// thus avoids reloading the query vectors.
//
// Precondition: vecs.getSize(1) <= ReductionLimit <= kLanes, so each query /
// vector fits in a single tile load (enforced by the host-side dispatcher).
// Expected launch configuration: block = dim3(kLanes, kWarps),
// grid.x = ceil(query.getSize(0) / kWarps); no dynamic shared memory.
template <
        int NumWarpQ,
        int NumThreadQ,
        typename BinaryType,
        int ReductionLimit = kLanes>
__global__ void __launch_bounds__(kWarps* kLanes) binaryDistanceLimitSize(
        const Tensor<BinaryType, 2, true> vecs,
        const Tensor<BinaryType, 2, true> query,
        Tensor<int, 2, true> outK,
        Tensor<int, 2, true> outV,
        int k) {
    // A matrix tile (query, k); +1 inner pad to avoid bank conflicts
    __shared__ BinaryType queryTile[kWarps][kLanes + 1];

    // B matrix tile (vec, k); +1 inner pad to avoid bank conflicts
    __shared__ BinaryType vecTile[kLanes][kLanes + 1];

    // Per-warp k-selection of the smallest distances seen so far
    WarpSelect<
            int,
            int,
            false,
            Comparator<int>,
            NumWarpQ,
            NumThreadQ,
            kWarps * kLanes>
            heap(kMaxDistance, -1, k);

    int warpId = threadIdx.y;
    int laneId = threadIdx.x;

    // Each warp handles a single query
    int laneK = laneId;
    int warpQuery = blockIdx.x * kWarps + warpId;
    bool kInBounds = laneK < vecs.getSize(1);
    bool queryInBounds = warpQuery < query.getSize(0);

    // The query tile is loaded exactly once, outside the vector loop;
    // out-of-bounds slots load 0 so they contribute no popcount
    queryTile[warpId][laneId] =
            queryInBounds && kInBounds ? query[warpQuery][laneK] : 0;

    // Each warp loops through the entire chunk of vectors
    for (int blockVec = 0; blockVec < vecs.getSize(0); blockVec += kLanes) {
        int threadDistance = 0;

        // The kWarps warps cooperate to load all kLanes vecs of the tile
#pragma unroll
        for (int i = 0; i < kLanes / kWarps; ++i) {
            int warpVec = i * kWarps + warpId;
            int vec = blockVec + warpVec;
            bool vecInBounds = vec < vecs.getSize(0);

            vecTile[warpVec][laneId] =
                    vecInBounds && kInBounds ? vecs[vec][laneK] : 0;
        }

        // The vec tile must be fully populated before any thread reads it
        __syncthreads();

        // Accumulate Hamming distance: popcount of XOR of query word and
        // vector word; each lane handles a different vector of the tile
#pragma unroll
        for (int i = 0; i < ReductionLimit; ++i) {
            threadDistance += __popc(queryTile[warpId][i] ^ vecTile[laneId][i]);
        }

        // Reads must finish before the next iteration overwrites the tile
        __syncthreads();

        // Lanes within a warp are different vec results against the same query
        // Only submit distances which represent real (query, vec) pairs
        bool valInBounds =
                queryInBounds && (blockVec + laneId < vecs.getSize(0));
        threadDistance = valInBounds ? threadDistance : kMaxDistance;
        int id = valInBounds ? blockVec + laneId : -1;

        heap.add(threadDistance, id);
    }

    heap.reduce();

    if (warpQuery < query.getSize(0)) {
        heap.writeOut(outK[warpQuery].data(), outV[warpQuery].data(), k);
    }
}
189 
// Host-side dispatcher for binaryDistanceAnySize: selects the
// NumWarpQ/NumThreadQ template instantiation appropriate for the requested k
// and launches it on `stream`. Assumes k <= GPU_MAX_SELECTION_K (checked by
// the caller, runBinaryDistance).
template <typename BinaryType>
void runBinaryDistanceAnySize(
        Tensor<BinaryType, 2, true>& vecs,
        Tensor<BinaryType, 2, true>& query,
        Tensor<int, 2, true>& outK,
        Tensor<int, 2, true>& outV,
        int k,
        cudaStream_t stream) {
    // One warp per query, kWarps queries per block
    dim3 grid(utils::divUp(query.getSize(0), kWarps));
    dim3 block(kLanes, kWarps);

    if (k == 1) {
        binaryDistanceAnySize<1, 1, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 32) {
        binaryDistanceAnySize<32, 2, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 64) {
        binaryDistanceAnySize<64, 3, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 128) {
        binaryDistanceAnySize<128, 3, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 256) {
        binaryDistanceAnySize<256, 4, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 512) {
        binaryDistanceAnySize<512, 8, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 1024) {
        binaryDistanceAnySize<1024, 8, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    }
#if GPU_MAX_SELECTION_K >= 2048
    else if (k <= 2048) {
        binaryDistanceAnySize<2048, 8, BinaryType>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    }
#endif
}
230 
// Host-side dispatcher for binaryDistanceLimitSize: selects the
// NumWarpQ/NumThreadQ template instantiation appropriate for the requested k
// and launches it on `stream`. Assumes k <= GPU_MAX_SELECTION_K and
// vecs.getSize(1) <= ReductionLimit (both checked by the caller,
// runBinaryDistance).
template <typename BinaryType, int ReductionLimit>
void runBinaryDistanceLimitSize(
        Tensor<BinaryType, 2, true>& vecs,
        Tensor<BinaryType, 2, true>& query,
        Tensor<int, 2, true>& outK,
        Tensor<int, 2, true>& outV,
        int k,
        cudaStream_t stream) {
    // One warp per query, kWarps queries per block
    dim3 grid(utils::divUp(query.getSize(0), kWarps));
    dim3 block(kLanes, kWarps);

    if (k == 1) {
        binaryDistanceLimitSize<1, 1, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 32) {
        binaryDistanceLimitSize<32, 2, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 64) {
        binaryDistanceLimitSize<64, 3, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 128) {
        binaryDistanceLimitSize<128, 3, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 256) {
        binaryDistanceLimitSize<256, 4, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 512) {
        binaryDistanceLimitSize<512, 8, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    } else if (k <= 1024) {
        binaryDistanceLimitSize<1024, 8, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    }
#if GPU_MAX_SELECTION_K >= 2048
    else if (k <= 2048) {
        binaryDistanceLimitSize<2048, 8, BinaryType, ReductionLimit>
                <<<grid, block, 0, stream>>>(vecs, query, outK, outV, k);
    }
#endif
}
271 
// Entry point: computes the k nearest neighbors of each row of `query`
// among the rows of `vecs` under Hamming distance, writing distances to
// `outK` and indices to `outV` (both of shape (numQueries, k)), launched
// asynchronously on `stream`.
//
// Routes to one of two kernels: an optimized fixed-size kernel when the
// binary code is short enough to fit one shared-memory tile, or a general
// any-size kernel otherwise.
void runBinaryDistance(
        Tensor<unsigned char, 2, true>& vecs,
        Tensor<unsigned char, 2, true>& query,
        Tensor<int, 2, true>& outK,
        Tensor<int, 2, true>& outV,
        int k,
        cudaStream_t stream) {
    FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);
    FAISS_ASSERT(vecs.getSize(1) == query.getSize(1));

    FAISS_ASSERT(outK.getSize(1) == k);
    FAISS_ASSERT(outV.getSize(1) == k);

    // For the optimized uint32 kernel, we handle 32 * 8 = 256 max dims
    constexpr int kReductionLimit32 = 8;

    // For the optimized uint8 kernel, we handle 8 * 16 = 128 max dims
    constexpr int kReductionLimit8 = 16;

    // All other cases (large or small) go through the general kernel

    if (vecs.getSize(1) % sizeof(unsigned int) == 0 &&
        (vecs.getSize(1) / sizeof(unsigned int)) <= kReductionLimit32) {
        // Reinterpret the byte data as 32-bit words for wider loads/popcounts
        auto vecs32 = vecs.castResize<unsigned int>();
        auto query32 = query.castResize<unsigned int>();

        // Optimize for vectors with dimensions a multiple of 32 that are less
        // than 32 * kReductionLimit32 (256) dimensions in size
        runBinaryDistanceLimitSize<unsigned int, kReductionLimit32>(
                vecs32, query32, outK, outV, k, stream);

    } else if (vecs.getSize(1) <= kReductionLimit8) {
        // Optimize for short vectors of at most kReductionLimit8 (16) bytes
        // (128 dimensions), processed a byte at a time
        runBinaryDistanceLimitSize<unsigned char, kReductionLimit8>(
                vecs, query, outK, outV, k, stream);
    } else {
        // Arbitrary size kernel
        runBinaryDistanceAnySize<unsigned char>(
                vecs, query, outK, outV, k, stream);
    }
}
314 
} // namespace gpu
} // namespace faiss