/**
 * Copyright (c) Facebook, Inc. and its affiliates.
 *
 * This source code is licensed under the MIT license found in the
 * LICENSE file in the root directory of this source tree.
 */

#pragma once

#include <cuda.h>
#include <faiss/gpu/utils/StaticUtils.h>
#include <faiss/impl/FaissAssert.h>
#include <faiss/gpu/utils/DeviceDefs.cuh>
#include <faiss/gpu/utils/MergeNetworkUtils.cuh>
#include <faiss/gpu/utils/PtxUtils.cuh>
#include <faiss/gpu/utils/WarpShuffles.cuh>

namespace faiss {
namespace gpu {

// Merge pairs of lists smaller than blockDim.x (NumThreads)
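// Each pair consists of two contiguous, sorted lists of length L, so listK /
// listV hold N * 2 * L elements in total. Dir selects the sort direction via
// Comp (see the swap predicate below); AllThreads indicates whether every
// thread in the block has an element pair to handle, or only the first N * L.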
template <
        int NumThreads,
        typename K,
        typename V,
        int N,
        int L,
        bool AllThreads,
        bool Dir,
        typename Comp,
        bool FullMerge>
inline __device__ void blockMergeSmall(K* listK, V* listV) {
    static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
    static_assert(
            utils::isPowerOf2(NumThreads), "NumThreads must be a power-of-2");
    static_assert(L <= NumThreads, "merge list size must be <= NumThreads");

    // Which pair of lists we are merging
    int mergeId = threadIdx.x / L;

    // Which thread we are within the merge
    int tid = threadIdx.x % L;

    // listK points to a region of size N * 2 * L
    listK += 2 * L * mergeId;
    listV += 2 * L * mergeId;

    // This is not a bitonic merge: both lists are sorted in the same
    // direction, so handle the first swap as if the second list were reversed
    int pos = L - 1 - tid;
    int stride = 2 * tid + 1;
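    // Thread tid therefore compares listK[L - 1 - tid] against
    // listK[(L - 1 - tid) + (2 * tid + 1)] == listK[L + tid]: element tid of
    // the second list is paired with element L - 1 - tid of the first, i.e.
    // the second list is traversed in reverse to form a bitonic sequence.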

    if (AllThreads || (threadIdx.x < N * L)) {
        K ka = listK[pos];
        K kb = listK[pos + stride];

        bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
        listK[pos] = swap ? kb : ka;
        listK[pos + stride] = swap ? ka : kb;

        V va = listV[pos];
        V vb = listV[pos + stride];
        listV[pos] = swap ? vb : va;
        listV[pos + stride] = swap ? va : vb;

        // FIXME: is this a CUDA 9 compiler bug?
        // K& ka = listK[pos];
        // K& kb = listK[pos + stride];

        // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
        // swap(s, ka, kb);

        // V& va = listV[pos];
        // V& vb = listV[pos + stride];
        // swap(s, va, vb);
    }

    __syncthreads();

#pragma unroll
    for (int stride = L / 2; stride > 0; stride /= 2) {
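        // Standard merge-network indexing: pos = 2 * tid - (tid % stride)
        // places thread tid in the lower half of its 2 * stride wide group;
        // it compare-exchanges with the element stride positions away.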
        int pos = 2 * tid - (tid & (stride - 1));

        if (AllThreads || (threadIdx.x < N * L)) {
            K ka = listK[pos];
            K kb = listK[pos + stride];

            bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            listK[pos] = swap ? kb : ka;
            listK[pos + stride] = swap ? ka : kb;

            V va = listV[pos];
            V vb = listV[pos + stride];
            listV[pos] = swap ? vb : va;
            listV[pos + stride] = swap ? va : vb;

            // FIXME: is this a CUDA 9 compiler bug?
            // K& ka = listK[pos];
            // K& kb = listK[pos + stride];

            // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            // swap(s, ka, kb);

            // V& va = listV[pos];
            // V& vb = listV[pos + stride];
            // swap(s, va, vb);
        }

        __syncthreads();
    }
}

// Merge pairs of sorted lists larger than blockDim.x (NumThreads)
template <
        int NumThreads,
        typename K,
        typename V,
        int L,
        bool Dir,
        typename Comp,
        bool FullMerge>
inline __device__ void blockMergeLarge(K* listK, V* listV) {
    static_assert(utils::isPowerOf2(L), "L must be a power-of-2");
    static_assert(L >= kWarpSize, "merge list size must be >= 32");
    static_assert(
            utils::isPowerOf2(NumThreads), "NumThreads must be a power-of-2");
    static_assert(L >= NumThreads, "merge list size must be >= NumThreads");

    // For L > NumThreads, each thread has to perform multiple
    // compare-exchanges per stride.
    constexpr int kLoopPerThread = L / NumThreads;
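    // e.g. with L = 128 and NumThreads = 64, each thread performs
    // kLoopPerThread = 2 compare-exchanges on every pass below.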

    // This is not a bitonic merge: both lists are sorted in the same
    // direction, so handle the first swap as if the second list were reversed
#pragma unroll
    for (int loop = 0; loop < kLoopPerThread; ++loop) {
        int tid = loop * NumThreads + threadIdx.x;
        int pos = L - 1 - tid;
        int stride = 2 * tid + 1;

        K ka = listK[pos];
        K kb = listK[pos + stride];

        bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
        listK[pos] = swap ? kb : ka;
        listK[pos + stride] = swap ? ka : kb;

        V va = listV[pos];
        V vb = listV[pos + stride];
        listV[pos] = swap ? vb : va;
        listV[pos + stride] = swap ? va : vb;

        // FIXME: is this a CUDA 9 compiler bug?
        // K& ka = listK[pos];
        // K& kb = listK[pos + stride];

        // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
        // swap(s, ka, kb);

        // V& va = listV[pos];
        // V& vb = listV[pos + stride];
        // swap(s, va, vb);
    }

    __syncthreads();

    constexpr int kSecondLoopPerThread =
            FullMerge ? kLoopPerThread : kLoopPerThread / 2;
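    // When FullMerge is false, the remaining passes only cover tids in
    // [0, L / 2), i.e. positions [0, L): only the first L elements of each
    // 2 * L region are left fully merged, which is sufficient when the
    // caller only consumes the first L results.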

#pragma unroll
    for (int stride = L / 2; stride > 0; stride /= 2) {
#pragma unroll
        for (int loop = 0; loop < kSecondLoopPerThread; ++loop) {
            int tid = loop * NumThreads + threadIdx.x;
            int pos = 2 * tid - (tid & (stride - 1));

            K ka = listK[pos];
            K kb = listK[pos + stride];

            bool swap = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            listK[pos] = swap ? kb : ka;
            listK[pos + stride] = swap ? ka : kb;

            V va = listV[pos];
            V vb = listV[pos + stride];
            listV[pos] = swap ? vb : va;
            listV[pos + stride] = swap ? va : vb;

            // FIXME: is this a CUDA 9 compiler bug?
            // K& ka = listK[pos];
            // K& kb = listK[pos + stride];

            // bool s = Dir ? Comp::gt(ka, kb) : Comp::lt(ka, kb);
            // swap(s, ka, kb);

            // V& va = listV[pos];
            // V& vb = listV[pos + stride];
            // swap(s, va, vb);
        }

        __syncthreads();
    }
}

/// Class template to prevent static_assert from firing for
/// mixing smaller/larger than block cases
template <
        int NumThreads,
        typename K,
        typename V,
        int N,
        int L,
        bool Dir,
        typename Comp,
        bool SmallerThanBlock,
        bool FullMerge>
struct BlockMerge {};

/// Merging lists smaller than a block
template <
        int NumThreads,
        typename K,
        typename V,
        int N,
        int L,
        bool Dir,
        typename Comp,
        bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, true, FullMerge> {
    static inline __device__ void merge(K* listK, V* listV) {
        constexpr int kNumParallelMerges = NumThreads / L;
        constexpr int kNumIterations = N / kNumParallelMerges;

        static_assert(L <= NumThreads, "list must be <= NumThreads");
        static_assert(
                (N < kNumParallelMerges) ||
                        (kNumIterations * kNumParallelMerges == N),
                "improper selection of N and L");

        if (N < kNumParallelMerges) {
            // We only need L threads per list to perform the merge
            blockMergeSmall<
                    NumThreads,
                    K,
                    V,
                    N,
                    L,
                    false,
                    Dir,
                    Comp,
                    FullMerge>(listK, listV);
        } else {
            // All threads participate
#pragma unroll
            for (int i = 0; i < kNumIterations; ++i) {
                int start = i * kNumParallelMerges * 2 * L;

                blockMergeSmall<
                        NumThreads,
                        K,
                        V,
                        N,
                        L,
                        true,
                        Dir,
                        Comp,
                        FullMerge>(listK + start, listV + start);
            }
        }
    }
};

/// Merging lists larger than a block
template <
        int NumThreads,
        typename K,
        typename V,
        int N,
        int L,
        bool Dir,
        typename Comp,
        bool FullMerge>
struct BlockMerge<NumThreads, K, V, N, L, Dir, Comp, false, FullMerge> {
    static inline __device__ void merge(K* listK, V* listV) {
        // Each pair of lists is merged sequentially
#pragma unroll
        for (int i = 0; i < N; ++i) {
            int start = i * 2 * L;

            blockMergeLarge<NumThreads, K, V, L, Dir, Comp, FullMerge>(
                    listK + start, listV + start);
        }
    }
};

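/// Merges N contiguous pairs of sorted lists of length L held in listK /
/// listV, dispatching on whether a list fits within the thread block
/// (L <= NumThreads) or not.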
template <
        int NumThreads,
        typename K,
        typename V,
        int N,
        int L,
        bool Dir,
        typename Comp,
        bool FullMerge = true>
inline __device__ void blockMerge(K* listK, V* listV) {
    constexpr bool kSmallerThanBlock = (L <= NumThreads);

    BlockMerge<
            NumThreads,
            K,
            V,
            N,
            L,
            Dir,
            Comp,
            kSmallerThanBlock,
            FullMerge>::merge(listK, listV);
}
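
// Illustrative usage sketch (not part of this header): a 128-thread block
// merging four pairs of ascending 32-element runs held in shared memory.
// The comparator and buffer names below are hypothetical; Comp only needs
// static __device__ bool lt() / gt() members.
//
//   struct LessComp {
//       static __device__ bool lt(float a, float b) { return a < b; }
//       static __device__ bool gt(float a, float b) { return a > b; }
//   };
//
//   __shared__ float smemK[4 * 2 * 32];
//   __shared__ int smemV[4 * 2 * 32];
//   // ... each 32-element run written in ascending order ...
//   __syncthreads();
//   // Dir = true: the swap predicate keeps the smaller key at the lower index
//   blockMerge<128, float, int, 4, 32, true, LessComp>(smemK, smemV);
//   // smemK / smemV now hold four ascending 64-element runs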

} // namespace gpu
} // namespace faiss