/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#include <faiss/gpu/GpuResources.h>
#include <faiss/gpu/utils/DeviceUtils.h>
#include <faiss/gpu/utils/StaticUtils.h>
#include <faiss/gpu/impl/DistanceUtils.cuh>
#include <faiss/gpu/impl/IVFFlatScan.cuh>
#include <faiss/gpu/impl/IVFUtils.cuh>
#include <faiss/gpu/utils/Comparators.cuh>
#include <faiss/gpu/utils/ConversionOperators.cuh>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/DeviceTensor.cuh>
#include <faiss/gpu/utils/Float16.cuh>
#include <faiss/gpu/utils/MathOperators.cuh>
#include <faiss/gpu/utils/PtxUtils.cuh>
#include <faiss/gpu/utils/Reductions.cuh>

#include <algorithm>

namespace faiss {
namespace gpu {

namespace {

/// Sort direction for each metric
inline bool metricToSortDirection(MetricType mt) {
    switch (mt) {
        case MetricType::METRIC_INNER_PRODUCT:
            // highest
            return true;
        case MetricType::METRIC_L2:
            // lowest
            return false;
        default:
            // unhandled metric
            FAISS_ASSERT(false);
            return false;
    }
}

} // namespace

// Number of warps we create per block of IVFFlatScan
constexpr int kIVFFlatScanWarps = 4;
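// Note: with 32-wide warps this gives kWarpSize * kIVFFlatScanWarps = 128
// threads per block, matching the dim3 block size used in the kernel launch
// in runIVFFlatScanTile below.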

// Works for any dimension size
template <typename Codec, typename Metric>
struct IVFFlatScan {
    static __device__ void scan(
            float* query,
            bool useResidual,
            float* residualBaseSlice,
            void* vecData,
            const Codec& codec,
            const Metric& metric,
            int numVecs,
            int dim,
            float* distanceOut) {
        // How many separate loading points are there for the decoder?
        int limit = utils::divDown(dim, Codec::kDimPerIter);

        // Each warp handles a separate chunk of vectors
        int warpId = threadIdx.x / kWarpSize;
        // FIXME: why does getLaneId() not work when we write out below!?!?!
        int laneId = threadIdx.x % kWarpSize; // getLaneId();

        // Divide the set of vectors among the warps
        int vecsPerWarp = utils::divUp(numVecs, kIVFFlatScanWarps);

        int vecStart = vecsPerWarp * warpId;
        int vecEnd = min(vecsPerWarp * (warpId + 1), numVecs);
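
        // Worked example (for illustration): with numVecs = 100 and 4 warps,
        // vecsPerWarp = divUp(100, 4) = 25, so warp 0 scans vectors [0, 25),
        // warp 1 [25, 50), warp 2 [50, 75), and warp 3 [75, 100).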

        // Walk the list of vectors for this warp
        for (int vec = vecStart; vec < vecEnd; ++vec) {
            Metric dist = metric.zero();

            // Scan the dimensions available that have whole units for the
            // decoder, as the decoder may handle more than one dimension at
            // once (leaving the remainder to be handled separately)
            for (int d = laneId; d < limit; d += kWarpSize) {
                int realDim = d * Codec::kDimPerIter;
                float vecVal[Codec::kDimPerIter];

                // Decode the kDimPerIter dimensions
                codec.decode(vecData, vec, d, vecVal);

#pragma unroll
                for (int j = 0; j < Codec::kDimPerIter; ++j) {
                    vecVal[j] +=
                            useResidual ? residualBaseSlice[realDim + j] : 0.0f;
                }

#pragma unroll
                for (int j = 0; j < Codec::kDimPerIter; ++j) {
                    dist.handle(query[realDim + j], vecVal[j]);
                }
            }

            // Handle remainder by a single thread, if any
            // Not needed if we decode 1 dim per iteration
            if (Codec::kDimPerIter > 1) {
                int realDim = limit * Codec::kDimPerIter;

                // Was there any remainder?
                if (realDim < dim) {
                    // Let the first threads in the block sequentially perform
                    // it
                    int remainderDim = realDim + laneId;

                    if (remainderDim < dim) {
                        float vecVal = codec.decodePartial(
                                vecData, vec, limit, laneId);
                        vecVal += useResidual ? residualBaseSlice[remainderDim]
                                              : 0.0f;
                        dist.handle(query[remainderDim], vecVal);
                    }
                }
            }
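
            // Remainder example (for illustration): if dim = 101 and
            // Codec::kDimPerIter = 2, then limit = 50 and the main loop
            // covers dims [0, 100); only lane 0 satisfies remainderDim < dim
            // and handles the final dimension 100 via decodePartial.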

            // Reduce distance within warp
            auto warpDist = warpReduceAllSum(dist.reduce());

            if (laneId == 0) {
                distanceOut[vec] = warpDist;
            }
        }
    }
};

template <typename Codec, typename Metric>
__global__ void ivfFlatScan(
        Tensor<float, 2, true> queries,
        bool useResidual,
        Tensor<float, 3, true> residualBase,
        Tensor<int, 2, true> listIds,
        void** allListData,
        int* listLengths,
        Codec codec,
        Metric metric,
        Tensor<int, 2, true> prefixSumOffsets,
        Tensor<float, 1, true> distance) {
    extern __shared__ float smem[];

    auto queryId = blockIdx.y;
    auto probeId = blockIdx.x;

    // This is where we start writing out data
    // We ensure that before the array (at offset -1), there is a 0 value
    int outBase = *(prefixSumOffsets[queryId][probeId].data() - 1);
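    // This read at offset -1 is safe: runIVFFlatScan allocates the prefix sum
    // space with one extra leading element, memsets that element to 0, and
    // starts the prefixSumOffsets view at element 1 (see the allocation
    // there).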

    auto listId = listIds[queryId][probeId];
    // Safety guard in case NaNs in input cause no list ID to be generated
    if (listId == -1) {
        return;
    }

    auto query = queries[queryId].data();
    auto vecs = allListData[listId];
    auto numVecs = listLengths[listId];
    auto dim = queries.getSize(1);
    auto distanceOut = distance[outBase].data();

    auto residualBaseSlice = residualBase[queryId][probeId].data();

    codec.initKernel(smem, dim);
    __syncthreads();

    IVFFlatScan<Codec, Metric>::scan(
            query,
            useResidual,
            residualBaseSlice,
            vecs,
            codec,
            metric,
            numVecs,
            dim,
            distanceOut);
}

void runIVFFlatScanTile(
        GpuResources* res,
        Tensor<float, 2, true>& queries,
        Tensor<int, 2, true>& listIds,
        thrust::device_vector<void*>& listData,
        thrust::device_vector<void*>& listIndices,
        IndicesOptions indicesOptions,
        thrust::device_vector<int>& listLengths,
        Tensor<char, 1, true>& thrustMem,
        Tensor<int, 2, true>& prefixSumOffsets,
        Tensor<float, 1, true>& allDistances,
        Tensor<float, 3, true>& heapDistances,
        Tensor<int, 3, true>& heapIndices,
        int k,
        faiss::MetricType metricType,
        bool useResidual,
        Tensor<float, 3, true>& residualBase,
        GpuScalarQuantizer* scalarQ,
        Tensor<float, 2, true>& outDistances,
        Tensor<Index::idx_t, 2, true>& outIndices,
        cudaStream_t stream) {
    int dim = queries.getSize(1);

    // Check that the amount of shared memory per block available for our
    // type is sufficient
    if (scalarQ &&
        (scalarQ->qtype == ScalarQuantizer::QuantizerType::QT_8bit ||
         scalarQ->qtype == ScalarQuantizer::QuantizerType::QT_4bit)) {
        int maxDim =
                getMaxSharedMemPerBlockCurrentDevice() / (sizeof(float) * 2);

        FAISS_THROW_IF_NOT_FMT(
                dim < maxDim,
                "Insufficient shared memory available on the GPU "
                "for QT_8bit or QT_4bit with %d dimensions; "
                "maximum dimensions possible is %d",
                dim,
                maxDim);
    }

    // Calculate offset lengths, so we know where to write out
    // intermediate results
    runCalcListOffsets(
            res, listIds, listLengths, prefixSumOffsets, thrustMem, stream);

    auto grid = dim3(listIds.getSize(1), listIds.getSize(0));
    auto block = dim3(kWarpSize * kIVFFlatScanWarps);
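    // Launch layout: one block per (query, probe) pair in this tile, i.e.
    // grid.x = nprobe and grid.y = the number of queries; the kernel reads
    // blockIdx.x as probeId and blockIdx.y as queryId.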

#define RUN_IVF_FLAT                                                  \
    do {                                                              \
        ivfFlatScan<<<grid, block, codec.getSmemSize(dim), stream>>>( \
                queries,                                              \
                useResidual,                                          \
                residualBase,                                         \
                listIds,                                              \
                listData.data().get(),                                \
                listLengths.data().get(),                             \
                codec,                                                \
                metric,                                               \
                prefixSumOffsets,                                     \
                allDistances);                                        \
    } while (0)

#define HANDLE_METRICS                             \
    do {                                           \
        if (metricType == MetricType::METRIC_L2) { \
            L2Distance metric;                     \
            RUN_IVF_FLAT;                          \
        } else {                                   \
            IPDistance metric;                     \
            RUN_IVF_FLAT;                          \
        }                                          \
    } while (0)
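
    // The macros above let each (Codec, Metric) combination below instantiate
    // a fully specialized ivfFlatScan kernel at compile time; the codec and
    // metric objects are passed by value into the kernel launch.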

    if (!scalarQ) {
        CodecFloat codec(dim * sizeof(float));
        HANDLE_METRICS;
    } else {
        switch (scalarQ->qtype) {
            case ScalarQuantizer::QuantizerType::QT_8bit: {
                Codec<ScalarQuantizer::QuantizerType::QT_8bit, 1> codec(
                        scalarQ->code_size,
                        scalarQ->gpuTrained.data(),
                        scalarQ->gpuTrained.data() + dim);
                HANDLE_METRICS;
            } break;
            case ScalarQuantizer::QuantizerType::QT_8bit_uniform: {
                Codec<ScalarQuantizer::QuantizerType::QT_8bit_uniform, 1>
                        codec(scalarQ->code_size,
                              scalarQ->trained[0],
                              scalarQ->trained[1]);
                HANDLE_METRICS;
            } break;
            case ScalarQuantizer::QuantizerType::QT_fp16: {
                Codec<ScalarQuantizer::QuantizerType::QT_fp16, 1> codec(
                        scalarQ->code_size);
                HANDLE_METRICS;
            } break;
            case ScalarQuantizer::QuantizerType::QT_8bit_direct: {
                Codec<ScalarQuantizer::QuantizerType::QT_8bit_direct, 1> codec(
                        scalarQ->code_size);
                HANDLE_METRICS;
            } break;
            case ScalarQuantizer::QuantizerType::QT_4bit: {
                Codec<ScalarQuantizer::QuantizerType::QT_4bit, 1> codec(
                        scalarQ->code_size,
                        scalarQ->gpuTrained.data(),
                        scalarQ->gpuTrained.data() + dim);
                HANDLE_METRICS;
            } break;
            case ScalarQuantizer::QuantizerType::QT_4bit_uniform: {
                Codec<ScalarQuantizer::QuantizerType::QT_4bit_uniform, 1>
                        codec(scalarQ->code_size,
                              scalarQ->trained[0],
                              scalarQ->trained[1]);
                HANDLE_METRICS;
            } break;
            default:
                // unimplemented, should be handled at a higher level
                FAISS_ASSERT(false);
        }
    }

    CUDA_TEST_ERROR();

#undef HANDLE_METRICS
#undef RUN_IVF_FLAT

    // k-select the output in chunks, to increase parallelism
    runPass1SelectLists(
            prefixSumOffsets,
            allDistances,
            listIds.getSize(1),
            k,
            metricToSortDirection(metricType),
            heapDistances,
            heapIndices,
            stream);

    // k-select final output
    auto flatHeapDistances = heapDistances.downcastInner<2>();
    auto flatHeapIndices = heapIndices.downcastInner<2>();

    runPass2SelectLists(
            flatHeapDistances,
            flatHeapIndices,
            listIndices,
            indicesOptions,
            prefixSumOffsets,
            listIds,
            k,
            metricToSortDirection(metricType),
            outDistances,
            outIndices,
            stream);
}

void runIVFFlatScan(
        Tensor<float, 2, true>& queries,
        Tensor<int, 2, true>& listIds,
        thrust::device_vector<void*>& listData,
        thrust::device_vector<void*>& listIndices,
        IndicesOptions indicesOptions,
        thrust::device_vector<int>& listLengths,
        int maxListLength,
        int k,
        faiss::MetricType metric,
        bool useResidual,
        Tensor<float, 3, true>& residualBase,
        GpuScalarQuantizer* scalarQ,
        // output
        Tensor<float, 2, true>& outDistances,
        // output
        Tensor<Index::idx_t, 2, true>& outIndices,
        GpuResources* res) {
    constexpr int kMinQueryTileSize = 8;
    constexpr int kMaxQueryTileSize = 128;
    constexpr int kThrustMemSize = 16384;

    int nprobe = listIds.getSize(1);
    auto stream = res->getDefaultStreamCurrentDevice();

    // Make a reservation for Thrust to do its dirty work (global memory
    // cross-block reduction space); hopefully this is large enough.
    DeviceTensor<char, 1, true> thrustMem1(
            res, makeTempAlloc(AllocType::Other, stream), {kThrustMemSize});
    DeviceTensor<char, 1, true> thrustMem2(
            res, makeTempAlloc(AllocType::Other, stream), {kThrustMemSize});
    DeviceTensor<char, 1, true>* thrustMem[2] = {&thrustMem1, &thrustMem2};

    // How much temporary storage is available?
    // If possible, we'd like to fit within the space available.
    size_t sizeAvailable = res->getTempMemoryAvailableCurrentDevice();

    // We run two passes of heap selection
    // This is the size of the first-level heap passes
    constexpr int kNProbeSplit = 8;
    int pass2Chunks = std::min(nprobe, kNProbeSplit);

    size_t sizeForFirstSelectPass =
            pass2Chunks * k * (sizeof(float) + sizeof(int));

    // How much temporary storage we need for each query
    size_t sizePerQuery = 2 * // # streams
            ((nprobe * sizeof(int) + sizeof(int)) + // prefixSumOffsets
             nprobe * maxListLength * sizeof(float) + // allDistances
             sizeForFirstSelectPass);
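
    // Worked example (for illustration, with 4-byte int and float): with
    // nprobe = 32, maxListLength = 1024 and k = 10, pass2Chunks = 8, so
    // sizePerQuery = 2 * ((32 * 4 + 4) + 32 * 1024 * 4 + 8 * 10 * 8)
    // = 263688 bytes, roughly 257 KiB per query in the tile.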

    int queryTileSize = (int)(sizeAvailable / sizePerQuery);

    if (queryTileSize < kMinQueryTileSize) {
        queryTileSize = kMinQueryTileSize;
    } else if (queryTileSize > kMaxQueryTileSize) {
        queryTileSize = kMaxQueryTileSize;
    }
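    // Continuing the example above: assuming 512 MiB of temporary memory is
    // available, 512 MiB / 263688 bytes is about 2036 queries, which is then
    // clamped to kMaxQueryTileSize = 128.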

    // FIXME: we should adjust queryTileSize to deal with this, since
    // indexing is in int32
    FAISS_ASSERT(
            queryTileSize * nprobe * maxListLength <
            std::numeric_limits<int>::max());

    // Temporary memory buffers
    // Make sure there is space prior to the start which will be 0, and
    // will handle the boundary condition without branches
    DeviceTensor<int, 1, true> prefixSumOffsetSpace1(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize * nprobe + 1});
    DeviceTensor<int, 1, true> prefixSumOffsetSpace2(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize * nprobe + 1});

    DeviceTensor<int, 2, true> prefixSumOffsets1(
            prefixSumOffsetSpace1[1].data(), {queryTileSize, nprobe});
    DeviceTensor<int, 2, true> prefixSumOffsets2(
            prefixSumOffsetSpace2[1].data(), {queryTileSize, nprobe});
    DeviceTensor<int, 2, true>* prefixSumOffsets[2] = {
            &prefixSumOffsets1, &prefixSumOffsets2};

    // Make sure the element before prefixSumOffsets is 0, since we
    // depend upon simple, boundary-less indexing to get proper results
    CUDA_VERIFY(cudaMemsetAsync(
            prefixSumOffsetSpace1.data(), 0, sizeof(int), stream));
    CUDA_VERIFY(cudaMemsetAsync(
            prefixSumOffsetSpace2.data(), 0, sizeof(int), stream));

    DeviceTensor<float, 1, true> allDistances1(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize * nprobe * maxListLength});
    DeviceTensor<float, 1, true> allDistances2(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize * nprobe * maxListLength});
    DeviceTensor<float, 1, true>* allDistances[2] = {
            &allDistances1, &allDistances2};

    DeviceTensor<float, 3, true> heapDistances1(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize, pass2Chunks, k});
    DeviceTensor<float, 3, true> heapDistances2(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize, pass2Chunks, k});
    DeviceTensor<float, 3, true>* heapDistances[2] = {
            &heapDistances1, &heapDistances2};

    DeviceTensor<int, 3, true> heapIndices1(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize, pass2Chunks, k});
    DeviceTensor<int, 3, true> heapIndices2(
            res,
            makeTempAlloc(AllocType::Other, stream),
            {queryTileSize, pass2Chunks, k});
    DeviceTensor<int, 3, true>* heapIndices[2] = {&heapIndices1, &heapIndices2};

    auto streams = res->getAlternateStreamsCurrentDevice();
    streamWait(streams, {stream});

    int curStream = 0;
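
    // All temporary buffers above exist in duplicate so that consecutive
    // query tiles can be double-buffered: even tiles run on streams[0] and
    // odd tiles on streams[1], letting one tile's work overlap the next's.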

    for (int query = 0; query < queries.getSize(0); query += queryTileSize) {
        int numQueriesInTile =
                std::min(queryTileSize, queries.getSize(0) - query);

        auto prefixSumOffsetsView =
                prefixSumOffsets[curStream]->narrowOutermost(
                        0, numQueriesInTile);

        auto listIdsView = listIds.narrowOutermost(query, numQueriesInTile);
        auto queryView = queries.narrowOutermost(query, numQueriesInTile);
        auto residualBaseView =
                residualBase.narrowOutermost(query, numQueriesInTile);

        auto heapDistancesView =
                heapDistances[curStream]->narrowOutermost(0, numQueriesInTile);
        auto heapIndicesView =
                heapIndices[curStream]->narrowOutermost(0, numQueriesInTile);

        auto outDistanceView =
                outDistances.narrowOutermost(query, numQueriesInTile);
        auto outIndicesView =
                outIndices.narrowOutermost(query, numQueriesInTile);

        runIVFFlatScanTile(
                res,
                queryView,
                listIdsView,
                listData,
                listIndices,
                indicesOptions,
                listLengths,
                *thrustMem[curStream],
                prefixSumOffsetsView,
                *allDistances[curStream],
                heapDistancesView,
                heapIndicesView,
                k,
                metric,
                useResidual,
                residualBaseView,
                scalarQ,
                outDistanceView,
                outIndicesView,
                streams[curStream]);

        curStream = (curStream + 1) % 2;
    }

    streamWait({stream}, streams);
}

} // namespace gpu
} // namespace faiss