1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #include <faiss/gpu/utils/DeviceUtils.h>
9 #include <faiss/gpu/utils/StaticUtils.h>
10 #include <faiss/gpu/impl/IVFUtils.cuh>
11 #include <faiss/gpu/utils/DeviceDefs.cuh>
12 #include <faiss/gpu/utils/Limits.cuh>
13 #include <faiss/gpu/utils/Select.cuh>
14 #include <faiss/gpu/utils/Tensor.cuh>
15 
16 //
17 // This kernel is split into a separate compilation unit to cut down
18 // on compile time
19 //
20 
21 namespace faiss {
22 namespace gpu {
23 
24 // This is warp divergence central, but this is really a final step
25 // and happening a small number of times
// Warp-divergent binary search over a monotonically increasing prefix-sum
// array. Returns the index of the first bucket whose exclusive upper bound
// (prefixSumOffsets[idx]) exceeds `val`, i.e. the bucket containing `val`.
// Divergence is acceptable here: this runs only once per final result.
inline __device__ int binarySearchForBucket(
        int* prefixSumOffsets,
        int size,
        int val) {
    int lo = 0;
    int hi = size;

    // Invariant: the answer lies in [lo, hi)
    while (lo < hi) {
        int probe = lo + (hi - lo) / 2;

        if (prefixSumOffsets[probe] <= val) {
            // Bucket `probe` ends at or before `val`; answer is to the right
            lo = probe + 1;
        } else {
            hi = probe;
        }
    }

    // `val` must fall within one of the buckets
    assert(lo != size);

    return lo;
}
51 
// Second-pass k-selection kernel for IVF list scanning.
//
// Pass 1 wrote, per query, a concatenated set of per-probe candidate
// (distance, offset) pairs into heapDistances / heapIndices. This kernel
// performs a final block-wide k-selection over those candidates and
// translates each winning candidate's intermediate offset back into the
// user's original vector index (or an encoded (listId, listOffset) pair).
//
// Launch layout: grid.x = #queries, one block per query, blockDim.x =
// ThreadsPerBlock. Shared memory is the static smemK/smemV scratch below.
// Dir = true selects largest values, false selects smallest. Requires
// k <= NumWarpQ (guaranteed by the host-side dispatch in
// runPass2SelectLists).
template <int ThreadsPerBlock, int NumWarpQ, int NumThreadQ, bool Dir>
__global__ void pass2SelectLists(
        Tensor<float, 2, true> heapDistances,
        Tensor<int, 2, true> heapIndices,
        void** listIndices,
        Tensor<int, 2, true> prefixSumOffsets,
        Tensor<int, 2, true> topQueryToCentroid,
        int k,
        IndicesOptions opt,
        Tensor<float, 2, true> outDistances,
        Tensor<Index::idx_t, 2, true> outIndices) {
    constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

    // Per-warp selection scratch used by BlockSelect (keys and payloads)
    __shared__ float smemK[kNumWarps * NumWarpQ];
    __shared__ int smemV[kNumWarps * NumWarpQ];

    // Sentinel initializer: worst possible value for the chosen direction
    constexpr auto kInit = Dir ? kFloatMin : kFloatMax;
    BlockSelect<
            float,
            int,
            Dir,
            Comparator<float>,
            NumWarpQ,
            NumThreadQ,
            ThreadsPerBlock>
            heap(kInit, -1, smemK, smemV, k);

    auto queryId = blockIdx.x;
    int num = heapDistances.getSize(1);
    // Largest multiple of kWarpSize <= num; the loop below can then run
    // with every lane of a warp active
    int limit = utils::roundDown(num, kWarpSize);

    int i = threadIdx.x;
    auto heapDistanceStart = heapDistances[queryId];

    // BlockSelect add cannot be used in a warp divergent circumstance; we
    // handle the remainder warp below
    for (; i < limit; i += blockDim.x) {
        heap.add(heapDistanceStart[i], i);
    }

    // Handle warp divergence separately
    if (i < num) {
        heap.addThreadQ(heapDistanceStart[i], i);
    }

    // Merge all final results
    heap.reduce();

    // After reduce(), the block's final top-k (sorted) lives in
    // smemK/smemV[0..k); write it out and translate payloads to user indices
    for (int i = threadIdx.x; i < k; i += blockDim.x) {
        outDistances[queryId][i] = smemK[i];

        // `v` is the index in `heapIndices`
        // We need to translate this into an original user index. The
        // reason why we don't maintain intermediate results in terms of
        // user indices is to substantially reduce temporary memory
        // requirements and global memory write traffic for the list
        // scanning.
        // This code is highly divergent, but it's probably ok, since this
        // is the very last step and it is happening a small number of
        // times (#queries x k).
        int v = smemV[i];
        Index::idx_t index = -1;

        if (v != -1) {
            // `offset` is the offset of the intermediate result, as
            // calculated by the original scan.
            int offset = heapIndices[queryId][v];

            // In order to determine the actual user index, we need to first
            // determine what list it was in.
            // We do this by binary search in the prefix sum list.
            int probe = binarySearchForBucket(
                    prefixSumOffsets[queryId].data(),
                    prefixSumOffsets.getSize(1),
                    offset);

            // This is then the probe for the query; we can find the actual
            // list ID from this
            int listId = topQueryToCentroid[queryId][probe];

            // Now, we need to know the offset within the list
            // We ensure that before the array (at offset -1), there is a 0
            // value
            int listStart = *(prefixSumOffsets[queryId][probe].data() - 1);
            int listOffset = offset - listStart;

            // This gives us our final index
            if (opt == INDICES_32_BIT) {
                // Indices stored on GPU as 32-bit ints; widen to idx_t
                index = (Index::idx_t)((int*)listIndices[listId])[listOffset];
            } else if (opt == INDICES_64_BIT) {
                index = ((Index::idx_t*)listIndices[listId])[listOffset];
            } else {
                // Indices kept CPU-side: pack (listId, listOffset) into one
                // 64-bit value so the host can resolve the user index later
                index = ((Index::idx_t)listId << 32 | (Index::idx_t)listOffset);
            }
        }

        outIndices[queryId][i] = index;
    }
}
151 
// Host-side launcher for pass2SelectLists: picks compile-time selection
// parameters (block size, warp-queue and thread-queue lengths, direction)
// based on the runtime k and chooseLargest, then launches one block per
// query on `stream`. Each RUN_PASS expansion returns from this function on
// success; falling through all k ranges means k exceeds the compiled
// GPU_MAX_SELECTION_K limit and we assert below.
void runPass2SelectLists(
        Tensor<float, 2, true>& heapDistances,
        Tensor<int, 2, true>& heapIndices,
        thrust::device_vector<void*>& listIndices,
        IndicesOptions indicesOptions,
        Tensor<int, 2, true>& prefixSumOffsets,
        Tensor<int, 2, true>& topQueryToCentroid,
        int k,
        bool chooseLargest,
        Tensor<float, 2, true>& outDistances,
        Tensor<Index::idx_t, 2, true>& outIndices,
        cudaStream_t stream) {
    // One block per query
    auto grid = dim3(topQueryToCentroid.getSize(0));

// Launches the kernel instantiation for the given template arguments,
// checks for launch errors, and returns from runPass2SelectLists
#define RUN_PASS(BLOCK, NUM_WARP_Q, NUM_THREAD_Q, DIR)         \
    do {                                                       \
        pass2SelectLists<BLOCK, NUM_WARP_Q, NUM_THREAD_Q, DIR> \
                <<<grid, BLOCK, 0, stream>>>(                  \
                        heapDistances,                         \
                        heapIndices,                           \
                        listIndices.data().get(),              \
                        prefixSumOffsets,                      \
                        topQueryToCentroid,                    \
                        k,                                     \
                        indicesOptions,                        \
                        outDistances,                          \
                        outIndices);                           \
        CUDA_TEST_ERROR();                                     \
        return; /* success */                                  \
    } while (0)

#if GPU_MAX_SELECTION_K >= 2048

    // block size 128 for k <= 1024, 64 for k = 2048
#define RUN_PASS_DIR(DIR)                \
    do {                                 \
        if (k == 1) {                    \
            RUN_PASS(128, 1, 1, DIR);    \
        } else if (k <= 32) {            \
            RUN_PASS(128, 32, 2, DIR);   \
        } else if (k <= 64) {            \
            RUN_PASS(128, 64, 3, DIR);   \
        } else if (k <= 128) {           \
            RUN_PASS(128, 128, 3, DIR);  \
        } else if (k <= 256) {           \
            RUN_PASS(128, 256, 4, DIR);  \
        } else if (k <= 512) {           \
            RUN_PASS(128, 512, 8, DIR);  \
        } else if (k <= 1024) {          \
            RUN_PASS(128, 1024, 8, DIR); \
        } else if (k <= 2048) {          \
            RUN_PASS(64, 2048, 8, DIR);  \
        }                                \
    } while (0)

#else

// Same dispatch table without the k = 2048 case, for builds where
// GPU_MAX_SELECTION_K caps selection at 1024
#define RUN_PASS_DIR(DIR)                \
    do {                                 \
        if (k == 1) {                    \
            RUN_PASS(128, 1, 1, DIR);    \
        } else if (k <= 32) {            \
            RUN_PASS(128, 32, 2, DIR);   \
        } else if (k <= 64) {            \
            RUN_PASS(128, 64, 3, DIR);   \
        } else if (k <= 128) {           \
            RUN_PASS(128, 128, 3, DIR);  \
        } else if (k <= 256) {           \
            RUN_PASS(128, 256, 4, DIR);  \
        } else if (k <= 512) {           \
            RUN_PASS(128, 512, 8, DIR);  \
        } else if (k <= 1024) {          \
            RUN_PASS(128, 1024, 8, DIR); \
        }                                \
    } while (0)

#endif // GPU_MAX_SELECTION_K

    if (chooseLargest) {
        RUN_PASS_DIR(true);
    } else {
        RUN_PASS_DIR(false);
    }

    // unimplemented / too many resources
    FAISS_ASSERT_FMT(false, "unimplemented k value (%d)", k);

#undef RUN_PASS_DIR
#undef RUN_PASS
}
242 
243 } // namespace gpu
244 } // namespace faiss
245