1 /**
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8 #include <faiss/gpu/utils/DeviceUtils.h>
9 #include <faiss/gpu/utils/StaticUtils.h>
10 #include <faiss/gpu/impl/IVFUtils.cuh>
11 #include <faiss/gpu/utils/DeviceDefs.cuh>
12 #include <faiss/gpu/utils/Limits.cuh>
13 #include <faiss/gpu/utils/Select.cuh>
14 #include <faiss/gpu/utils/Tensor.cuh>
15
16 //
17 // This kernel is split into a separate compilation unit to cut down
18 // on compile time
19 //
20
21 namespace faiss {
22 namespace gpu {
23
// Bisects `prefixSumOffsets` (`size` entries) and returns the smallest
// index whose entry is strictly greater than `val`.
//
// The bisection is only meaningful if the entries are in non-decreasing
// order -- presumably they are running totals of list lengths; confirm
// against the code that fills the array.
//
// Threads calling this with different `val` follow different paths through
// the loop, so a warp will diverge here. That is accepted because this is a
// final step that happens only a small number of times.
inline __device__ int binarySearchForBucket(
        int* prefixSumOffsets,
        int size,
        int val) {
    // Invariant: every entry before `lo` is <= val, and every entry at or
    // after `hi` is > val.
    int lo = 0;
    int hi = size;

    while (lo < hi) {
        const int probe = lo + (hi - lo) / 2;

        if (prefixSumOffsets[probe] > val) {
            // The answer is `probe` or something before it
            hi = probe;
        } else {
            // `probe` and everything before it is too small
            lo = probe + 1;
        }
    }

    // `val` must land inside one of the buckets; running off the end means
    // it was not smaller than any entry
    assert(lo != size);

    return lo;
}
51
// Second-pass selection kernel. Block `blockIdx.x` handles one query: it
// runs a BlockSelect over that query's row of `heapDistances`, keeps `k`
// entries, and writes them to `outDistances` / `outIndices` after turning
// each selected entry into a final index.
//
// Template parameters:
//   ThreadsPerBlock      - block size the kernel is instantiated for; it
//                          sizes the shared-memory scratch arrays
//   NumWarpQ, NumThreadQ - queue sizes handed to BlockSelect
//   Dir                  - handed to BlockSelect; also selects the seed key
//                          (kFloatMin when true, kFloatMax when false)
//
// Arguments:
//   heapDistances      - candidate distances, one row per query
//   heapIndices        - per-candidate offset recorded by the earlier scan
//   listIndices        - per-list index arrays, interpreted according to
//                        `opt`
//   prefixSumOffsets   - per-query offsets used to map a scan offset back to
//                        a probe; the element just before each row is read
//                        and is expected to be 0
//   topQueryToCentroid - maps (query, probe) to a list id
//   k                  - number of results written per query
//   opt                - how the final index is produced (see the switch)
//   outDistances, outIndices - results, one row per query
template <int ThreadsPerBlock, int NumWarpQ, int NumThreadQ, bool Dir>
__global__ void pass2SelectLists(
        Tensor<float, 2, true> heapDistances,
        Tensor<int, 2, true> heapIndices,
        void** listIndices,
        Tensor<int, 2, true> prefixSumOffsets,
        Tensor<int, 2, true> topQueryToCentroid,
        int k,
        IndicesOptions opt,
        Tensor<float, 2, true> outDistances,
        Tensor<Index::idx_t, 2, true> outIndices) {
    constexpr int kWarpsInBlock = ThreadsPerBlock / kWarpSize;
    constexpr auto kSeedKey = Dir ? kFloatMin : kFloatMax;

    // Scratch space for BlockSelect; after reduce() the first k slots are
    // read back as the final keys / values
    __shared__ float selectedKeys[kWarpsInBlock * NumWarpQ];
    __shared__ int selectedVals[kWarpsInBlock * NumWarpQ];

    // Values are seeded with -1, which the output loop treats as "no
    // candidate"
    BlockSelect<
            float,
            int,
            Dir,
            Comparator<float>,
            NumWarpQ,
            NumThreadQ,
            ThreadsPerBlock>
            heap(kSeedKey, -1, selectedKeys, selectedVals, k);

    auto queryId = blockIdx.x;
    auto queryCandidates = heapDistances[queryId];

    int numCandidates = heapDistances.getSize(1);
    // Largest multiple of the warp size that is <= numCandidates
    int warpAlignedEnd = utils::roundDown(numCandidates, kWarpSize);

    // heap.add() must not be called from a partially active warp, so it is
    // only used for the warp-aligned prefix of the row
    int cand = threadIdx.x;
    while (cand < warpAlignedEnd) {
        heap.add(queryCandidates[cand], cand);
        cand += blockDim.x;
    }

    // The leftover tail is fed through the per-thread queue instead
    if (cand < numCandidates) {
        heap.addThreadQ(queryCandidates[cand], cand);
    }

    // Combine everything into the final result in shared memory
    heap.reduce();

    for (int rank = threadIdx.x; rank < k; rank += blockDim.x) {
        outDistances[queryId][rank] = selectedKeys[rank];

        // The selected value is a position within this query's row of
        // `heapIndices`, not yet a user index. Intermediate results are
        // kept in that form to cut down temporary memory and global write
        // traffic during list scanning. The translation below diverges
        // heavily, which is tolerated since it runs only (#queries x k)
        // times as the very last step.
        int heapSlot = selectedVals[rank];
        Index::idx_t userIndex = -1;

        if (heapSlot != -1) {
            // Offset of the intermediate result as computed by the scan
            int scanOffset = heapIndices[queryId][heapSlot];

            // Which probe (list) of this query does that offset fall in?
            int probe = binarySearchForBucket(
                    prefixSumOffsets[queryId].data(),
                    prefixSumOffsets.getSize(1),
                    scanOffset);

            // The probe number gives us the list id
            int listId = topQueryToCentroid[queryId][probe];

            // Subtract where the list begins, i.e. the entry preceding
            // `probe`. For probe 0 this reads the element at offset -1 of
            // the row, which is guaranteed to hold 0.
            int listOffset = scanOffset -
                    *(prefixSumOffsets[queryId][probe].data() - 1);

            switch (opt) {
                case INDICES_32_BIT:
                    // Stored indices are int; widen to idx_t
                    userIndex = (Index::idx_t)(
                            (int*)listIndices[listId])[listOffset];
                    break;
                case INDICES_64_BIT:
                    // Stored indices are already idx_t
                    userIndex =
                            ((Index::idx_t*)listIndices[listId])[listOffset];
                    break;
                default:
                    // No stored index is read: pack the list id shifted
                    // left by 32 together with the offset in the list
                    userIndex =
                            ((Index::idx_t)listId << 32 |
                             (Index::idx_t)listOffset);
                    break;
            }
        }

        outIndices[queryId][rank] = userIndex;
    }
}
151
// Host-side launcher for the pass2SelectLists kernel.
//
// Launches one block per row of `topQueryToCentroid` on `stream`, choosing
// the kernel instantiation (block size and BlockSelect queue sizes) from
// `k`, and the selection direction from `chooseLargest`. Every successful
// launch returns from this function immediately; if no instantiation
// covers `k`, the assertion at the bottom fires.
//
// Supported k: up to 1024, or up to 2048 when GPU_MAX_SELECTION_K >= 2048.
// Any k <= 32 other than 1 (including non-positive values) takes the
// `k <= 32` instantiation.
void runPass2SelectLists(
        Tensor<float, 2, true>& heapDistances,
        Tensor<int, 2, true>& heapIndices,
        thrust::device_vector<void*>& listIndices,
        IndicesOptions indicesOptions,
        Tensor<int, 2, true>& prefixSumOffsets,
        Tensor<int, 2, true>& topQueryToCentroid,
        int k,
        bool chooseLargest,
        Tensor<float, 2, true>& outDistances,
        Tensor<Index::idx_t, 2, true>& outIndices,
        cudaStream_t stream) {
    // One block per query
    dim3 queryGrid(topQueryToCentroid.getSize(0));

// Launches a single instantiation, checks for errors and leaves the function
#define LAUNCH_PASS2(BLOCK, NUM_WARP_Q, NUM_THREAD_Q, DIR)            \
    do {                                                              \
        pass2SelectLists<BLOCK, NUM_WARP_Q, NUM_THREAD_Q, DIR>        \
                <<<queryGrid, BLOCK, 0, stream>>>(                    \
                        heapDistances,                                \
                        heapIndices,                                  \
                        listIndices.data().get(),                     \
                        prefixSumOffsets,                             \
                        topQueryToCentroid,                           \
                        k,                                            \
                        indicesOptions,                               \
                        outDistances,                                 \
                        outIndices);                                  \
        CUDA_TEST_ERROR();                                            \
        return; /* success */                                         \
    } while (0)

// Handling for k > 1024: only builds with GPU_MAX_SELECTION_K >= 2048 have
// an instantiation for it (block size 64 instead of 128)
#if GPU_MAX_SELECTION_K >= 2048

#define LAUNCH_PASS2_OVER_1024(DIR)             \
    do {                                        \
        if (k <= 2048) {                        \
            LAUNCH_PASS2(64, 2048, 8, DIR);     \
        }                                       \
    } while (0)

#else

#define LAUNCH_PASS2_OVER_1024(DIR) \
    do {                            \
    } while (0)

#endif // GPU_MAX_SELECTION_K

// Picks the instantiation for the current k; block size 128 up to k = 1024
#define DISPATCH_ON_K(DIR)                      \
    do {                                        \
        if (k == 1) {                           \
            LAUNCH_PASS2(128, 1, 1, DIR);       \
        } else if (k <= 32) {                   \
            LAUNCH_PASS2(128, 32, 2, DIR);      \
        } else if (k <= 64) {                   \
            LAUNCH_PASS2(128, 64, 3, DIR);      \
        } else if (k <= 128) {                  \
            LAUNCH_PASS2(128, 128, 3, DIR);     \
        } else if (k <= 256) {                  \
            LAUNCH_PASS2(128, 256, 4, DIR);     \
        } else if (k <= 512) {                  \
            LAUNCH_PASS2(128, 512, 8, DIR);     \
        } else if (k <= 1024) {                 \
            LAUNCH_PASS2(128, 1024, 8, DIR);    \
        } else {                                \
            LAUNCH_PASS2_OVER_1024(DIR);        \
        }                                       \
    } while (0)

    if (chooseLargest) {
        DISPATCH_ON_K(true);
    } else {
        DISPATCH_ON_K(false);
    }

    // Only reached when nothing above launched: k has no instantiation
    // (unimplemented / too many resources)
    FAISS_ASSERT_FMT(false, "unimplemented k value (%d)", k);

#undef DISPATCH_ON_K
#undef LAUNCH_PASS2_OVER_1024
#undef LAUNCH_PASS2
}
242
243 } // namespace gpu
244 } // namespace faiss
245