/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cuda.h>
#include <faiss/gpu/utils/StaticUtils.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/MergeNetworkUtils.cuh>
#include <faiss/gpu/utils/PtxUtils.cuh>
#include <faiss/gpu/utils/WarpShuffles.cuh>

namespace faiss {
namespace gpu {
// Merge pairs of lists smaller than blockDim.x (NumThreads)
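//
// Template parameters, as used throughout this file:
//   NumThreads - number of threads in the block (blockDim.x)
//   K, V       - key and value types being merged
//   N          - number of pairs of sorted lists; listK/listV cover
//                N * 2 * L contiguous elements
//   L          - length of each sorted input list (a power of 2)
//   AllThreads - true if every thread in the block has a valid
//                compare-exchange to perform, so no bounds check is needed
//   Dir        - selects the comparison: pairs are exchanged on Comp::gt
//                when true, on Comp::lt when false
//   Comp       - comparator providing static gt() / lt()
//   FullMerge  - if false, only the lowest L outputs of each 2L-element
//                pair need to be fully merged (see blockMergeLarge)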
template <
        int NumThreads,
        typename K,
        typename V,
        int N,
        int L,
        bool AllThreads,
        bool Dir,
        typename Comp,
        bool FullMerge>
inline __device__ void blockMergeSmall(K* listK, V* listV) {
    static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
    static_assert(
            utils::isPowerOf2(NumThreads), "NumThreads must be a power-of-2");
    static_assert(L <= NumThreads, "merge list size must be <= NumThreads");

    // Which pair of lists we are merging
    int mergeId = threadIdx.x / L;

    // Which thread we are within the merge
    int tid = threadIdx.x % L;

    // listK points to a region of size N * 2 * L
    listK += 2 * L * mergeId;
    listV += 2 * L * mergeId;

    // This is not a bitonic merge: both input lists are sorted in the
    // same direction, so handle the first swap as if the second list
    // were reversed
    int pos = L - 1 - tid;
    int stride = 2 * tid + 1;
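    // Thread tid therefore compares listK[L - 1 - tid] (element
    // L - 1 - tid of the first list) against listK[L + tid] (element
    // tid of the second list): each element of the first list is
    // paired with its mirror in the second.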

    if (AllThreads || (threadIdx.x < N * L)) {
        K ka = listK[pos];
        K kb = listK[pos + stride];

        bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
        listK[pos] = swap ? kb : ka;
        listK[pos + stride] = swap ? ka : kb;

        V va = listV[pos];
        V vb = listV[pos + stride];
        listV[pos] = swap ? vb : va;
        listV[pos + stride] = swap ? va : vb;

        // FIXME: is this a CUDA 9 compiler bug?
        // K& ka = listK[pos];
        // K& kb = listK[pos + stride];

        // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
        // swap(s, ka, kb);

        // V& va = listV[pos];
        // V& vb = listV[pos + stride];
        // swap(s, va, vb);
    }

    __syncthreads();

#pragma unroll
    for (int stride = L / 2; stride > 0; stride /= 2) {
        int pos = 2 * tid - (tid & (stride - 1));
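        // Maps thread tid to the lower index of its compare-exchange
        // pair: threads in the same stride-sized group take consecutive
        // positions, and each group skips past the upper stride
        // elements it compares against.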

        if (AllThreads || (threadIdx.x < N * L)) {
            K ka = listK[pos];
            K kb = listK[pos + stride];

            bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            listK[pos] = swap ? kb : ka;
            listK[pos + stride] = swap ? ka : kb;

            V va = listV[pos];
            V vb = listV[pos + stride];
            listV[pos] = swap ? vb : va;
            listV[pos + stride] = swap ? va : vb;

            // FIXME: is this a CUDA 9 compiler bug?
            // K& ka = listK[pos];
            // K& kb = listK[pos + stride];

            // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            // swap(s, ka, kb);

            // V& va = listV[pos];
            // V& vb = listV[pos + stride];
            // swap(s, va, vb);
        }

        __syncthreads();
    }
}

// Merge pairs of sorted lists larger than blockDim.x (NumThreads)
template <
        int NumThreads,
        typename K,
        typename V,
        int L,
        bool Dir,
        typename Comp,
        bool FullMerge>
inline __device__ void blockMergeLarge(K* listK, V* listV) {
    static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
    static_assert(L >= kWarpSize, "merge list size must be >= the warp size");
    static_assert(
            utils::isPowerOf2(NumThreads), "NumThreads must be a power-of-2");
    static_assert(L >= NumThreads, "merge list size must be >= NumThreads");

    // For L > NumThreads, each thread has to perform more work
    // per stride.
    constexpr int kLoopPerThread = L / NumThreads;
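    // For example (illustrative numbers only), L = 1024 with
    // NumThreads = 128 gives kLoopPerThread = 8 compare-exchanges per
    // thread per merge step.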

    // This is not a bitonic merge: both input lists are sorted in the
    // same direction, so handle the first swap as if the second list
    // were reversed
#pragma unroll
    for (int loop = 0; loop < kLoopPerThread; ++loop) {
        int tid = loop * NumThreads + threadIdx.x;
        int pos = L - 1 - tid;
        int stride = 2 * tid + 1;
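        // Across the unrolled loops, tid covers [0, L): this is the
        // same reversed-second-list pairing as in blockMergeSmall, with
        // each thread handling kLoopPerThread of the L pairs.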

        K ka = listK[pos];
        K kb = listK[pos + stride];

        bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
        listK[pos] = swap ? kb : ka;
        listK[pos + stride] = swap ? ka : kb;

        V va = listV[pos];
        V vb = listV[pos + stride];
        listV[pos] = swap ? vb : va;
        listV[pos + stride] = swap ? va : vb;

        // FIXME: is this a CUDA 9 compiler bug?
        // K& ka = listK[pos];
        // K& kb = listK[pos + stride];

        // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
        // swap(s, ka, kb);

        // V& va = listV[pos];
        // V& vb = listV[pos + stride];
        // swap(s, va, vb);
    }

    __syncthreads();

    constexpr int kSecondLoopPerThread =
            FullMerge ? kLoopPerThread : kLoopPerThread / 2;
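    // If FullMerge is false, only elements [0, L) of each pair must end
    // up sorted; the later stages then only need threads covering tid
    // in [0, L / 2), whose compare-exchanges stay within the first L
    // positions, so half of the per-thread loops can be skipped.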

#pragma unroll
    for (int stride = L / 2; stride > 0; stride /= 2) {
#pragma unroll
        for (int loop = 0; loop < kSecondLoopPerThread; ++loop) {
            int tid = loop * NumThreads + threadIdx.x;
            int pos = 2 * tid - (tid & (stride - 1));

            K ka = listK[pos];
            K kb = listK[pos + stride];

            bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            listK[pos] = swap ? kb : ka;
            listK[pos + stride] = swap ? ka : kb;

            V va = listV[pos];
            V vb = listV[pos + stride];
            listV[pos] = swap ? vb : va;
            listV[pos + stride] = swap ? va : vb;

            // FIXME: is this a CUDA 9 compiler bug?
            // K& ka = listK[pos];
            // K& kb = listK[pos + stride];

            // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            // swap(s, ka, kb);

            // V& va = listV[pos];
            // V& vb = listV[pos + stride];
            // swap(s, va, vb);
        }

        __syncthreads();
    }
}

/// Class template to prevent a static_assert from firing when mixing
/// the smaller-than-block and larger-than-block cases
template <
        int NumThreads,
        typename K,
        typename V,
        int N,
        int L,
        bool Dir,
        typename Comp,
        bool SmallerThanBlock,
        bool FullMerge>
struct BlockMerge {};

/// Merging lists smaller than a block
template <
        int NumThreads,
        typename K,
        typename V,
        int N,
        int L,
        bool Dir,
        typename Comp,
        bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, true, FullMerge> {
    static inline __device__ void merge(K* listK, V* listV) {
        constexpr int kNumParallelMerges = NumThreads / L;
        constexpr int kNumIterations = N / kNumParallelMerges;
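        // kNumParallelMerges is how many pairs of lists the block can
        // merge concurrently (each pair needs L threads); if N is at
        // least that, every thread is busy and we iterate over the
        // pairs in groups.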

        static_assert(L <= NumThreads, "list must be <= NumThreads");
        static_assert(
                (N < kNumParallelMerges) ||
                        (kNumIterations * kNumParallelMerges == N),
                "improper selection of N and L");

        if (N < kNumParallelMerges) {
            // We only need L threads per list to perform the merge
            blockMergeSmall<
                    NumThreads,
                    K,
                    V,
                    N,
                    L,
                    false,
                    Dir,
                    Comp,
                    FullMerge>(listK, listV);
        } else {
            // All threads participate
#pragma unroll
            for (int i = 0; i < kNumIterations; ++i) {
                int start = i * kNumParallelMerges * 2 * L;

                blockMergeSmall<
                        NumThreads,
                        K,
                        V,
                        N,
                        L,
                        true,
                        Dir,
                        Comp,
                        FullMerge>(listK + start, listV + start);
            }
        }
    }
};

/// Merging lists larger than a block
template <
        int NumThreads,
        typename K,
        typename V,
        int N,
        int L,
        bool Dir,
        typename Comp,
        bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, false, FullMerge> {
    static inline __device__ void merge(K* listK, V* listV) {
        // Each pair of lists is merged sequentially
#pragma unroll
        for (int i = 0; i < N; ++i) {
            int start = i * 2 * L;

            blockMergeLarge<NumThreads, K, V, L, Dir, Comp, FullMerge>(
                    listK + start, listV + start);
        }
    }
};

template <
        int NumThreads,
        typename K,
        typename V,
        int N,
        int L,
        bool Dir,
        typename Comp,
        bool FullMerge = true>
inline __device__ void blockMerge(K* listK, V* listV) {
    constexpr bool kSmallerThanBlock = (L <= NumThreads);
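    // Dispatch through the BlockMerge partial specializations rather
    // than branching here, so that only the applicable variant (and its
    // static_asserts) is instantiated for a given L.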

    BlockMerge<
            NumThreads,
            K,
            V,
            N,
            L,
            Dir,
            Comp,
            kSmallerThanBlock,
            FullMerge>::merge(listK, listV);
}

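// Example usage: an illustrative sketch only, not part of this header.
// It assumes faiss's Comparator<T> (from Comparators.cuh) for Comp, and a
// hypothetical kernel in which 8 sorted 32-element runs already sit in
// shared memory as 4 adjacent pairs. With Dir = true, pairs are exchanged
// on Comp::gt, so each merged 64-element list comes out smallest-first.
//
// #include <faiss/gpu/utils/Comparators.cuh>
//
// __global__ void mergeRunsKernel(float* outK, int* outV) {
//     __shared__ float sK[4 * 2 * 32];
//     __shared__ int sV[4 * 2 * 32];
//     // ... fill sK/sV with 8 sorted 32-element runs, then:
//     __syncthreads();
//
//     // NumThreads = 128 (the kernel must be launched with
//     // blockDim.x == 128), N = 4 pairs, L = 32: the small-list path
//     // is taken (L <= NumThreads) and all 128 threads participate.
//     blockMerge<128, float, int, 4, 32, true, Comparator<float>>(sK, sV);
//
//     // ... copy the merged 64-element lists out to outK/outV ...
// }
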
} // namespace gpu
} // namespace faiss