/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/utils/DeviceUtils.h>
#include <faiss/gpu/utils/StaticUtils.h>
#include <faiss/gpu/impl/DistanceUtils.cuh>
#include <faiss/gpu/impl/IVFFlatScan.cuh>
#include <faiss/gpu/impl/IVFUtils.cuh>
#include <faiss/gpu/utils/Comparators.cuh>
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/gpu/utils/MathOperators.cuh>
#include <faiss/gpu/utils/PtxUtils.cuh>
#include <faiss/gpu/utils/Reductions.cuh>

#include <algorithm>

namespace faiss {
namespace gpu {

namespace {

/// Sort direction for each metric
inline bool metricToSortDirection(MetricType mt) {
    switch (mt) {
        case MetricType::METRIC_INNER_PRODUCT:
            // highest
            return true;
        case MetricType::METRIC_L2:
            // lowest
            return false;
        default:
            // unhandled metric
            FAISS_ASSERT(false);
            return false;
    }
}

} // namespace

// Number of warps we create per block of IVFFlatScan
constexpr int kIVFFlatScanWarps = 4;

// Works for any dimension size
template <typename Codec, typename Metric>
struct IVFFlatScan {
    static __device__ void scan(
            float* query,
            bool useResidual,
            float* residualBaseSlice,
            void* vecData,
            const Codec& codec,
            const Metric& metric,
            int numVecs,
            int dim,
            float* distanceOut) {
        // How many separate loading points are there for the decoder?
        int limit = utils::divDown(dim, Codec::kDimPerIter);

        // Each warp handles a separate chunk of vectors
        int warpId = threadIdx.x / kWarpSize;
        // FIXME: why does getLaneId() not work when we write out below!?!?!
        int laneId = threadIdx.x % kWarpSize; // getLaneId();

        // Divide the set of vectors among the warps
        int vecsPerWarp = utils::divUp(numVecs, kIVFFlatScanWarps);

        int vecStart = vecsPerWarp * warpId;
        int vecEnd = min(vecsPerWarp * (warpId + 1), numVecs);

        // Walk the list of vectors for this warp
        for (int vec = vecStart; vec < vecEnd; ++vec) {
            Metric dist = metric.zero();

            // Scan the dimensions available that have whole units for the
            // decoder, as the decoder may handle more than one dimension at
            // once (leaving the remainder to be handled separately)
            for (int d = laneId; d < limit; d += kWarpSize) {
                int realDim = d * Codec::kDimPerIter;
                float vecVal[Codec::kDimPerIter];

                // Decode the kDimPerIter dimensions
                codec.decode(vecData, vec, d, vecVal);

#pragma unroll
                for (int j = 0; j < Codec::kDimPerIter; ++j) {
                    vecVal[j] +=
                            useResidual ? residualBaseSlice[realDim + j] : 0.0f;
                }

#pragma unroll
                for (int j = 0; j < Codec::kDimPerIter; ++j) {
                    dist.handle(query[realDim + j], vecVal[j]);
                }
            }
            // Handle any remainder dimensions; not needed if the codec
            // decodes a single dimension per iteration
            if (Codec::kDimPerIter > 1) {
                int realDim = limit * Codec::kDimPerIter;

                // Was there any remainder?
                if (realDim < dim) {
                    // Let the first lanes of the warp each handle one
                    // remaining dimension
                    int remainderDim = realDim + laneId;

                    if (remainderDim < dim) {
                        float vecVal = codec.decodePartial(
                                vecData, vec, limit, laneId);
                        vecVal += useResidual ? residualBaseSlice[remainderDim]
                                              : 0.0f;
                        dist.handle(query[remainderDim], vecVal);
                    }
                }
            }

            // Reduce distance within warp
            auto warpDist = warpReduceAllSum(dist.reduce());

            if (laneId == 0) {
                distanceOut[vec] = warpDist;
            }
        }
    }
};

template <typename Codec, typename Metric>
__global__ void ivfFlatScan(
        Tensor<float, 2, true> queries,
        bool useResidual,
        Tensor<float, 3, true> residualBase,
        Tensor<int, 2, true> listIds,
        void** allListData,
        int* listLengths,
        Codec codec,
        Metric metric,
        Tensor<int, 2, true> prefixSumOffsets,
        Tensor<float, 1, true> distance) {
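    // Dynamic shared memory scratch for the codec; its size is supplied at
    // kernel launch via codec.getSmemSize(dim) and it is initialized below by
    // codec.initKernel()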
    extern __shared__ float smem[];

    auto queryId = blockIdx.y;
    auto probeId = blockIdx.x;

    // This is where we start writing out data
    // We ensure that before the array (at offset -1), there is a 0 value
    int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);

    auto listId = listIds[queryId][probeId];
    // Safety guard in case NaNs in input cause no list ID to be generated
    if (listId == -1) {
        return;
    }

    auto query = queries[queryId].data();
    auto vecs = allListData[listId];
    auto numVecs = listLengths[listId];
    auto dim = queries.getSize(1);
    auto distanceOut = distance[outBase].data();

    auto residualBaseSlice = residualBase[queryId][probeId].data();

    codec.initKernel(smem, dim);
    __syncthreads();

    IVFFlatScan<Codec, Metric>::scan(
            query,
            useResidual,
            residualBaseSlice,
            vecs,
            codec,
            metric,
            numVecs,
            dim,
            distanceOut);
}

void runIVFFlatScanTile(
        GpuResources* res,
        Tensor<float, 2, true>& queries,
        Tensor<int, 2, true>& listIds,
        thrust::device_vector<void*>& listData,
        thrust::device_vector<void*>& listIndices,
        IndicesOptions indicesOptions,
        thrust::device_vector<int>& listLengths,
        Tensor<char, 1, true>& thrustMem,
        Tensor<int, 2, true>& prefixSumOffsets,
        Tensor<float, 1, true>& allDistances,
        Tensor<float, 3, true>& heapDistances,
        Tensor<int, 3, true>& heapIndices,
        int k,
        faiss::MetricType metricType,
        bool useResidual,
        Tensor<float, 3, true>& residualBase,
        GpuScalarQuantizer* scalarQ,
        Tensor<float, 2, true>& outDistances,
        Tensor<Index::idx_t, 2, true>& outIndices,
        cudaStream_t stream) {
    int dim = queries.getSize(1);

    // Check that the amount of shared memory available per block on this
    // device is sufficient for our data type
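    // (the trained QT_8bit / QT_4bit codecs keep two float values per
    // dimension in shared memory, hence the sizeof(float) * 2 below)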
    if (scalarQ &&
        (scalarQ->qtype == ScalarQuantizer::QuantizerType::QT_8bit ||
         scalarQ->qtype == ScalarQuantizer::QuantizerType::QT_4bit)) {
        int maxDim =
                getMaxSharedMemPerBlockCurrentDevice() / (sizeof(float) * 2);

        FAISS_THROW_IF_NOT_FMT(
                dim < maxDim,
                "Insufficient shared memory available on the GPU "
                "for QT_8bit or QT_4bit with %d dimensions; "
                "maximum dimensions possible is %d",
                dim,
                maxDim);
    }

    // Calculate offset lengths, so we know where to write out
    // intermediate results
    runCalcListOffsets(
            res, listIds, listLengths, prefixSumOffsets, thrustMem, stream);

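    // Launch one block per (query, probe) pair: blockIdx.x selects the probe
    // (inverted list) and blockIdx.y the query, matching the indexing in the
    // kernel above; each block runs kIVFFlatScanWarps warps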
    auto grid = dim3(listIds.getSize(1), listIds.getSize(0));
    auto block = dim3(kWarpSize * kIVFFlatScanWarps);

#define RUN_IVF_FLAT                                                  \
    do {                                                              \
        ivfFlatScan<<<grid, block, codec.getSmemSize(dim), stream>>>( \
                queries,                                              \
                useResidual,                                          \
                residualBase,                                         \
                listIds,                                              \
                listData.data().get(),                                \
                listLengths.data().get(),                             \
                codec,                                                \
                metric,                                               \
                prefixSumOffsets,                                     \
                allDistances);                                        \
    } while (0)

#define HANDLE_METRICS                             \
    do {                                           \
        if (metricType == MetricType::METRIC_L2) { \
            L2Distance metric;                     \
            RUN_IVF_FLAT;                          \
        } else {                                   \
            IPDistance metric;                     \
            RUN_IVF_FLAT;                          \
        }                                          \
    } while (0)

    if (!scalarQ) {
        CodecFloat codec(dim * sizeof(float));
        HANDLE_METRICS;
    } else {
        switch (scalarQ->qtype) {
            case ScalarQuantizer::QuantizerType::QT_8bit: {
                Codec<ScalarQuantizer::QuantizerType::QT_8bit, 1> codec(
                        scalarQ->code_size,
                        scalarQ->gpuTrained.data(),
                        scalarQ->gpuTrained.data() + dim);
                HANDLE_METRICS;
            } break;
            case ScalarQuantizer::QuantizerType::QT_8bit_uniform: {
                Codec<ScalarQuantizer::QuantizerType::QT_8bit_uniform, 1> codec(
                        scalarQ->code_size,
                        scalarQ->trained[0],
                        scalarQ->trained[1]);
                HANDLE_METRICS;
            } break;
            case ScalarQuantizer::QuantizerType::QT_fp16: {
                Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> codec(
                        scalarQ->code_size);
                HANDLE_METRICS;
            } break;
            case ScalarQuantizer::QuantizerType::QT_8bit_direct: {
                Codec<ScalarQuantizer::QuantizerType::QT_8bit_direct, 1> codec(
                        scalarQ->code_size);
                HANDLE_METRICS;
            } break;
            case ScalarQuantizer::QuantizerType::QT_4bit: {
                Codec<ScalarQuantizer::QuantizerType::QT_4bit, 1> codec(
                        scalarQ->code_size,
                        scalarQ->gpuTrained.data(),
                        scalarQ->gpuTrained.data() + dim);
                HANDLE_METRICS;
            } break;
            case ScalarQuantizer::QuantizerType::QT_4bit_uniform: {
                Codec<ScalarQuantizer::QuantizerType::QT_4bit_uniform, 1> codec(
                        scalarQ->code_size,
                        scalarQ->trained[0],
                        scalarQ->trained[1]);
                HANDLE_METRICS;
            } break;
            default:
                // unimplemented, should be handled at a higher level
                FAISS_ASSERT(false);
        }
    }

    CUDA_TEST_ERROR();

#undef HANDLE_METRICS
#undef RUN_IVF_FLAT

    // k-select the output in chunks, to increase parallelism
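    // (pass 1 selects the best k within each group of probes per query,
    // writing into heapDistances / heapIndices; pass 2 below merges those
    // partial results into the final k)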
    runPass1SelectLists(
            prefixSumOffsets,
            allDistances,
            listIds.getSize(1),
            k,
            metricToSortDirection(metricType),
            heapDistances,
            heapIndices,
            stream);

    // k-select final output
    auto flatHeapDistances = heapDistances.downcastInner<2>();
    auto flatHeapIndices = heapIndices.downcastInner<2>();

    runPass2SelectLists(
            flatHeapDistances,
            flatHeapIndices,
            listIndices,
            indicesOptions,
            prefixSumOffsets,
            listIds,
            k,
            metricToSortDirection(metricType),
            outDistances,
            outIndices,
            stream);
}

void runIVFFlatScan(
        Tensor<float, 2, true>& queries,
        Tensor<int, 2, true>& listIds,
        thrust::device_vector<void*>& listData,
        thrust::device_vector<void*>& listIndices,
        IndicesOptions indicesOptions,
        thrust::device_vector<int>& listLengths,
        int maxListLength,
        int k,
        faiss::MetricType metric,
        bool useResidual,
        Tensor<float, 3, true>& residualBase,
        GpuScalarQuantizer* scalarQ,
        // output
        Tensor<float, 2, true>& outDistances,
        // output
        Tensor<Index::idx_t, 2, true>& outIndices,
        GpuResources* res) {
    constexpr int kMinQueryTileSize = 8;
    constexpr int kMaxQueryTileSize = 128;
    constexpr int kThrustMemSize = 16384;

    int nprobe = listIds.getSize(1);
    auto stream = res->getDefaultStreamCurrentDevice();

    // Make a reservation for Thrust to do its dirty work (global memory
    // cross-block reduction space); hopefully this is large enough.
    DeviceTensor<char, 1, true> thrustMem1(
            res, makeTempAlloc(AllocType::Other, stream), {kThrustMemSize});
    DeviceTensor<char, 1, true> thrustMem2(
            res, makeTempAlloc(AllocType::Other, stream), {kThrustMemSize});
    DeviceTensor<char, 1, true>* thrustMem[2] = {&thrustMem1, &thrustMem2};

    // How much temporary storage is available?
    // If possible, we'd like to fit within the space available.
    size_t sizeAvailable = res->getTempMemoryAvailableCurrentDevice();

    // We run two passes of heap selection
    // This is the size of the first-level heap passes
    constexpr int kNProbeSplit = 8;
    int pass2Chunks = std::min(nprobe, kNProbeSplit);

    size_t sizeForFirstSelectPass =
            pass2Chunks * k * (sizeof(float) + sizeof(int));

    // How much temporary storage we need per query
    size_t sizePerQuery = 2 *                         // # streams
            ((nprobe * sizeof(int) + sizeof(int)) +   // prefixSumOffsets
             nprobe * maxListLength * sizeof(float) + // allDistances
             sizeForFirstSelectPass);

    int queryTileSize = (int)(sizeAvailable / sizePerQuery);

    if (queryTileSize < kMinQueryTileSize) {
        queryTileSize = kMinQueryTileSize;
    } else if (queryTileSize > kMaxQueryTileSize) {
        queryTileSize = kMaxQueryTileSize;
    }

    // FIXME: we should adjust queryTileSize to deal with this, since
    // indexing is in int32
    FAISS_ASSERT(
            queryTileSize * nprobe * maxListLength <
            std::numeric_limits<int>::max());

    // Temporary memory buffers
    // Make sure there is space prior to the start which will be 0, and
    // will handle the boundary condition without branches
    DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize * nprobe + 1});
    DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize * nprobe + 1});

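    // The 2-d offset views begin at element 1 of the backing buffers, so the
    // kernel can read a zero at relative offset -1 for the first (query,
    // probe) pair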
    DeviceTensor<int, 2, true> prefixSumOffsets1(
            prefixSumOffsetSpace1[1].data(), {queryTileSize, nprobe});
    DeviceTensor<int, 2, true> prefixSumOffsets2(
            prefixSumOffsetSpace2[1].data(), {queryTileSize, nprobe});
    DeviceTensor<int, 2, true>* prefixSumOffsets[2] = {
            &prefixSumOffsets1, &prefixSumOffsets2};

    // Make sure the element before prefixSumOffsets is 0, since we
    // depend upon simple, boundary-less indexing to get proper results
    CUDA_VERIFY(cudaMemsetAsync(
            prefixSumOffsetSpace1.data(), 0, sizeof(int), stream));
    CUDA_VERIFY(cudaMemsetAsync(
            prefixSumOffsetSpace2.data(), 0, sizeof(int), stream));

    DeviceTensor<float, 1, true> allDistances1(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize * nprobe * maxListLength});
    DeviceTensor<float, 1, true> allDistances2(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize * nprobe * maxListLength});
    DeviceTensor<float, 1, true>* allDistances[2] = {
            &allDistances1, &allDistances2};

    DeviceTensor<float, 3, true> heapDistances1(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize, pass2Chunks, k});
    DeviceTensor<float, 3, true> heapDistances2(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize, pass2Chunks, k});
    DeviceTensor<float, 3, true>* heapDistances[2] = {
            &heapDistances1, &heapDistances2};

    DeviceTensor<int, 3, true> heapIndices1(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize, pass2Chunks, k});
    DeviceTensor<int, 3, true> heapIndices2(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize, pass2Chunks, k});
    DeviceTensor<int, 3, true>* heapIndices[2] = {&heapIndices1, &heapIndices2};

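    // Ping-pong between the two alternate streams (and their paired temporary
    // buffers) so that work on consecutive query tiles can overlap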
    auto streams = res->getAlternateStreamsCurrentDevice();
    streamWait(streams, {stream});

    int curStream = 0;

    for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
        int numQueriesInTile =
                std::min(queryTileSize, queries.getSize(0) - query);

        auto prefixSumOffsetsView =
                prefixSumOffsets[curStream]->narrowOutermost(
                        0, numQueriesInTile);

        auto listIdsView = listIds.narrowOutermost(query, numQueriesInTile);
        auto queryView = queries.narrowOutermost(query, numQueriesInTile);
        auto residualBaseView =
                residualBase.narrowOutermost(query, numQueriesInTile);

        auto heapDistancesView =
                heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
        auto heapIndicesView =
                heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

        auto outDistanceView =
                outDistances.narrowOutermost(query, numQueriesInTile);
        auto outIndicesView =
                outIndices.narrowOutermost(query, numQueriesInTile);

        runIVFFlatScanTile(
                res,
                queryView,
                listIdsView,
                listData,
                listIndices,
                indicesOptions,
                listLengths,
                *thrustMem[curStream],
                prefixSumOffsetsView,
                *allDistances[curStream],
                heapDistancesView,
                heapIndicesView,
                k,
                metric,
                useResidual,
                residualBaseView,
                scalarQ,
                outDistanceView,
                outIndicesView,
                streams[curStream]);

        curStream = (curStream + 1) % 2;
    }

    streamWait({stream}, streams);
}

} // namespace gpu
} // namespace faiss