/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */
7 
#include <faiss/gpu/impl/IVFInterleaved.cuh>
#include <faiss/gpu/impl/scan/IVFInterleavedImpl.cuh>

#include <cstdint>
#include <limits>
10 
11 namespace faiss {
12 namespace gpu {
13 
// Largest uint32_t value. The second-pass kernel below uses it as the initial
// (sentinel) payload of its selection heap; an output slot that still holds
// this value after selection is reported to the caller as index -1.
constexpr uint32_t kMaxUInt32 = std::numeric_limits<uint32_t>::max();
15 
// Second-pass kernel to further k-select the results from the first pass across
// IVF lists and produce the final results.
//
// Launch layout: one thread block per query (blockIdx.x is the query id). The
// host wrapper in this file launches it with blockDim.x == ThreadsPerBlock.
// Shared memory: statically allocated, kNumWarps * NumWarpQ floats plus the
// same number of uint32_t payload words.
//
// Template parameters:
//   ThreadsPerBlock  threads per block the selection structure is built for
//   NumWarpQ         per-warp queue length handed to BlockSelect
//   NumThreadQ       per-thread queue length handed to BlockSelect
//
// Parameters:
//   distanceIn   first-pass distances, indexed [query][probe][k]
//   indicesIn    first-pass offsets within each probed list, indexed
//                [query][probe][k]
//   listIds      IVF list id for each [query][probe]; -1 marks a probe whose
//                candidates are skipped
//   k            number of results to emit per query
//   listIndices  per-list pointer table used to translate (list id, offset)
//                into a user index; interpretation depends on `opt`
//   opt          INDICES_32_BIT: table entries are int arrays;
//                INDICES_64_BIT: table entries are Index::idx_t arrays;
//                anything else: the result is (list id << 32) | offset
//   dir          if true, select the largest distances instead of the smallest
//   distanceOut  final distances, indexed [query][k]
//   indicesOut   final user indices, indexed [query][k]; -1 where fewer than k
//                results were available
template <int ThreadsPerBlock, int NumWarpQ, int NumThreadQ>
__global__ void ivfInterleavedScan2(
        Tensor<float, 3, true> distanceIn,
        Tensor<int, 3, true> indicesIn,
        Tensor<int, 2, true> listIds,
        int k,
        void** listIndices,
        IndicesOptions opt,
        bool dir,
        Tensor<float, 2, true> distanceOut,
        Tensor<Index::idx_t, 2, true> indicesOut) {
    // We represent a candidate's index as (probe id)(k).
    // Right now, both are limited to a maximum of 2048, but we dedicate each
    // to the high and low 16-bit halves of a uint32_t.
    static_assert(
            GPU_MAX_SELECTION_K <= 65536,
            "k must fit in the low 16 bits of the packed (probe, k) index");

    int queryId = blockIdx.x;

    constexpr int kNumWarps = ThreadsPerBlock / kWarpSize;

    __shared__ float smemK[kNumWarps * NumWarpQ];
    __shared__ uint32_t smemV[kNumWarps * NumWarpQ];

    // To avoid creating excessive specializations, we combine direction
    // kernels, selecting for the smallest element. If `dir` is true, we negate
    // all values being selected (so that we are selecting the largest element).
    BlockSelect<
            float,
            uint32_t,
            false,
            Comparator<float>,
            NumWarpQ,
            NumThreadQ,
            ThreadsPerBlock>
            heap(kFloatMax, kMaxUInt32, smemK, smemV, k);

    // nprobe x k candidates for this query
    int num = distanceIn.getSize(1) * distanceIn.getSize(2);

    auto distanceBase = distanceIn[queryId].data();

    // Candidates below `limit` are handled in the main loop; the remainder
    // (fewer than kWarpSize of them) is handled after it.
    int limit = utils::roundDown(num, kWarpSize);

    // This will keep our negation factor
    float adj = dir ? -1.0f : 1.0f;

    int i = threadIdx.x;
    for (; i < limit; i += blockDim.x) {
        uint32_t curProbe = i / k;
        uint32_t curK = i % k;
        uint32_t index = (curProbe << 16) | (curK & (uint32_t)0xffff);

        int listId = listIds[queryId][curProbe];
        if (listId != -1) {
            // Adjust the value we are selecting based on the sorting order
            heap.addThreadQ(distanceBase[i] * adj, index);
        }

        // Deliberately outside the branch above so that every thread calls it
        // on every iteration.
        // NOTE(review): presumably checkThreadQ is a warp-wide operation that
        // needs all lanes to participate -- confirm against BlockSelect.
        heap.checkThreadQ();
    }

    // Handle warp divergence separately: only some threads have a remaining
    // candidate here, so it is queued without calling checkThreadQ.
    if (i < num) {
        uint32_t curProbe = i / k;
        uint32_t curK = i % k;
        uint32_t index = (curProbe << 16) | (curK & (uint32_t)0xffff);

        int listId = listIds[queryId][curProbe];
        if (listId != -1) {
            heap.addThreadQ(distanceBase[i] * adj, index);
        }
    }

    // Merge all final results
    heap.reduce();

    // The first k entries of smemK / smemV are read back as the final results.
    for (int slot = threadIdx.x; slot < k; slot += blockDim.x) {
        // Re-adjust the value we are selecting based on the sorting order
        distanceOut[queryId][slot] = smemK[slot] * adj;
        auto packedIndex = smemV[slot];

        // We need to remap to the user-provided indices
        Index::idx_t index = -1;

        // We may not have at least k values to return; in this function, max
        // uint32 is our sentinel value
        if (packedIndex != kMaxUInt32) {
            // Undo the (probe id)(k) packing performed above
            uint32_t curProbe = packedIndex >> 16;
            uint32_t curK = packedIndex & 0xffff;

            int listId = listIds[queryId][curProbe];
            int listOffset = indicesIn[queryId][curProbe][curK];

            if (opt == INDICES_32_BIT) {
                index = (Index::idx_t)((int*)listIndices[listId])[listOffset];
            } else if (opt == INDICES_64_BIT) {
                index = ((Index::idx_t*)listIndices[listId])[listOffset];
            } else {
                // No user index table is consulted: encode the location itself
                index = ((Index::idx_t)listId << 32 | (Index::idx_t)listOffset);
            }
        }

        indicesOut[queryId][slot] = index;
    }
}
122 
// Host-side launcher for the second-pass k-selection kernel
// (ivfInterleavedScan2).
//
// Launches one thread block per query (distanceIn.getSize(0) blocks) on
// `stream`, choosing the kernel specialization (threads per block, warp queue
// length, thread queue length) from the bucket that `k` falls into. The launch
// is asynchronous; this function does not synchronize the stream.
//
// Parameters are forwarded unchanged to the kernel; `listIndices` is passed as
// the raw device pointer of the thrust vector.
//
// Preconditions: k <= GPU_MAX_SELECTION_K (asserted).
void runIVFInterleavedScan2(
        Tensor<float, 3, true>& distanceIn,
        Tensor<int, 3, true>& indicesIn,
        Tensor<int, 2, true>& listIds,
        int k,
        thrust::device_vector<void*>& listIndices,
        IndicesOptions indicesOptions,
        bool dir,
        Tensor<float, 2, true>& distanceOut,
        Tensor<Index::idx_t, 2, true>& indicesOut,
        cudaStream_t stream) {
    // Without this check, a k larger than every bucket below would fall
    // through the dispatch chain and silently leave the outputs unwritten.
    // caught for exceptions at a higher level
    FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);

    // A grid of zero blocks is an invalid launch configuration, and there is
    // nothing to produce for an empty query batch.
    if (distanceIn.getSize(0) == 0) {
        return;
    }

#define IVF_SCAN_2(THREADS, NUM_WARP_Q, NUM_THREAD_Q)        \
    ivfInterleavedScan2<THREADS, NUM_WARP_Q, NUM_THREAD_Q>   \
            <<<distanceIn.getSize(0), THREADS, 0, stream>>>( \
                    distanceIn,                              \
                    indicesIn,                               \
                    listIds,                                 \
                    k,                                       \
                    listIndices.data().get(),                \
                    indicesOptions,                          \
                    dir,                                     \
                    distanceOut,                             \
                    indicesOut)

    if (k == 1) {
        IVF_SCAN_2(128, 1, 1);
    } else if (k <= 32) {
        IVF_SCAN_2(128, 32, 2);
    } else if (k <= 64) {
        IVF_SCAN_2(128, 64, 3);
    } else if (k <= 128) {
        IVF_SCAN_2(128, 128, 3);
    } else if (k <= 256) {
        IVF_SCAN_2(128, 256, 4);
    } else if (k <= 512) {
        IVF_SCAN_2(128, 512, 8);
    } else if (k <= 1024) {
        IVF_SCAN_2(128, 1024, 8);
    }
#if GPU_MAX_SELECTION_K >= 2048
    else if (k <= 2048) {
        IVF_SCAN_2(64, 2048, 8);
    }
#endif

    // The helper macro is only meaningful inside this function
#undef IVF_SCAN_2
}
168 
// Entry point for the interleaved IVF scan: picks the implementation
// specialized for the selection-size bucket that `k` falls into.
//
// The buckets are k == 1, then k <= 32, 64, 128, 256, 512, 1024 and (only when
// GPU_MAX_SELECTION_K >= 2048) k <= 2048. Each branch expands
// IVF_INTERLEAVED_CALL with the bucket's upper bound.
// NOTE(review): IVF_INTERLEAVED_CALL is defined outside this file; presumably
// it forwards all of this function's parameters to the matching
// per-bucket implementation -- confirm in IVFInterleavedImpl.cuh.
//
// Outputs are written to `outDistances` and `outIndices`.
//
// Preconditions: k <= GPU_MAX_SELECTION_K (asserted).
void runIVFInterleavedScan(
        Tensor<float, 2, true>& queries,
        Tensor<int, 2, true>& listIds,
        thrust::device_vector<void*>& listData,
        thrust::device_vector<void*>& listIndices,
        IndicesOptions indicesOptions,
        thrust::device_vector<int>& listLengths,
        int k,
        faiss::MetricType metric,
        bool useResidual,
        Tensor<float, 3, true>& residualBase,
        GpuScalarQuantizer* scalarQ,
        // output
        Tensor<float, 2, true>& outDistances,
        // output
        Tensor<Index::idx_t, 2, true>& outIndices,
        GpuResources* res) {
    // caught for exceptions at a higher level
    FAISS_ASSERT(k <= GPU_MAX_SELECTION_K);

    if (k == 1) {
        IVF_INTERLEAVED_CALL(1);
    } else if (k <= 32) {
        IVF_INTERLEAVED_CALL(32);
    } else if (k <= 64) {
        IVF_INTERLEAVED_CALL(64);
    } else if (k <= 128) {
        IVF_INTERLEAVED_CALL(128);
    } else if (k <= 256) {
        IVF_INTERLEAVED_CALL(256);
    } else if (k <= 512) {
        IVF_INTERLEAVED_CALL(512);
    } else if (k <= 1024) {
        IVF_INTERLEAVED_CALL(1024);
    }
#if GPU_MAX_SELECTION_K >= 2048
    else if (k <= 2048) {
        IVF_INTERLEAVED_CALL(2048);
    }
#endif
}
210 
211 } // namespace gpu
212 } // namespace faiss
213