1 /**
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * This source code is licensed under the MIT license found in the
5  * LICENSE file in the root directory of this source tree.
6  */
7 
8 #pragma once
9 
10 #include <cuda.h>
11 #include <faiss/gpu/utils/StaticUtils.h>
12 #include <faiss/gpu/utils/DeviceDefs.cuh>
13 #include <faiss/gpu/utils/PtxUtils.cuh>
14 #include <faiss/gpu/utils/ReductionOperators.cuh>
15 #include <faiss/gpu/utils/WarpShuffles.cuh>
16 
17 namespace faiss {
18 namespace gpu {
19 
/// Reduces a per-thread register value across a warp (or a
/// power-of-two sub-warp of width ReduceWidth) via a butterfly
/// (XOR) shuffle pattern. Every participating lane returns the
/// fully reduced value. Op must be associative.
template <typename T, typename Op, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAll(T val, Op op) {
#pragma unroll
    for (int laneMask = ReduceWidth / 2; laneMask > 0; laneMask /= 2) {
        // Exchange with the lane whose id differs in this bit
        T other = shfl_xor(val, laneMask);
        val = op(val, other);
    }

    return val;
}
29 
/// Sums a register value across all warp threads
template <typename T, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAllSum(T val) {
    Sum<T> sum;
    return warpReduceAll<T, Sum<T>, ReduceWidth>(val, sum);
}
35 
/// Performs a block-wide reduction of `val` using `op`.
/// `smem` must provide at least utils::divUp(blockDim.x, kWarpSize)
/// elements of scratch space.
/// If BroadcastAll, every thread in the block returns the reduced value;
/// otherwise only warp 0 (lane 0 in particular) holds the final result.
/// If KillWARDependency, a trailing __syncthreads() lets the caller
/// immediately reuse `smem` without a write-after-read hazard.
template <typename T, typename Op, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAll(T val, Op op, T* smem) {
    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;

    // Phase 1: reduce within each warp; lane 0 of each warp publishes
    // its warp's partial result to shared memory
    val = warpReduceAll<T, Op>(val, op);
    if (laneId == 0) {
        smem[warpId] = val;
    }
    __syncthreads();

    // Phase 2: warp 0 reduces the per-warp partials. Lanes beyond the
    // number of resident warps contribute the operator's identity value.
    if (warpId == 0) {
        val = laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId]
                                                           : op.identity();
        val = warpReduceAll<T, Op>(val, op);

        if (BroadcastAll) {
            __threadfence_block();

            // Lane 0 publishes the final value for the whole block
            if (laneId == 0) {
                smem[0] = val;
            }
        }
    }

    // Phase 3 (optional): all threads pick up the broadcast result
    if (BroadcastAll) {
        __syncthreads();
        val = smem[0];
    }

    // Optional barrier so the caller may safely overwrite `smem`
    if (KillWARDependency) {
        __syncthreads();
    }

    return val;
}
73 
/// Performs a block-wide reduction of multiple values simultaneously.
/// Each thread contributes Num values; `smem` must provide at least
/// Num * utils::divUp(blockDim.x, kWarpSize) elements of scratch space.
/// BroadcastAll / KillWARDependency behave as in the single-value
/// blockReduceAll above.
template <
        int Num,
        typename T,
        typename Op,
        bool BroadcastAll,
        bool KillWARDependency>
__device__ inline void blockReduceAll(T val[Num], Op op, T* smem) {
    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;

    // Phase 1: reduce each of the Num values within each warp
#pragma unroll
    for (int i = 0; i < Num; ++i) {
        val[i] = warpReduceAll<T, Op>(val[i], op);
    }

    // Lane 0 of each warp publishes its Num partials, laid out
    // contiguously per warp: smem[warpId * Num .. warpId * Num + Num - 1]
    if (laneId == 0) {
#pragma unroll
        for (int i = 0; i < Num; ++i) {
            smem[warpId * Num + i] = val[i];
        }
    }

    __syncthreads();

    // Phase 2: warp 0 reduces the per-warp partials; lanes beyond the
    // number of resident warps contribute the operator's identity value
    if (warpId == 0) {
#pragma unroll
        for (int i = 0; i < Num; ++i) {
            val[i] = laneId < utils::divUp(blockDim.x, kWarpSize)
                    ? smem[laneId * Num + i]
                    : op.identity();
            val[i] = warpReduceAll<T, Op>(val[i], op);
        }

        if (BroadcastAll) {
            __threadfence_block();

            // Lane 0 publishes the Num final values at smem[0..Num-1]
            if (laneId == 0) {
#pragma unroll
                for (int i = 0; i < Num; ++i) {
                    smem[i] = val[i];
                }
            }
        }
    }

    // Phase 3 (optional): all threads pick up the broadcast results
    if (BroadcastAll) {
        __syncthreads();
#pragma unroll
        for (int i = 0; i < Num; ++i) {
            val[i] = smem[i];
        }
    }

    // Optional barrier so the caller may safely overwrite `smem`
    if (KillWARDependency) {
        __syncthreads();
    }
}
132 
/// Sums a register value across the entire block
template <typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAllSum(T val, T* smem) {
    Sum<T> sum;
    return blockReduceAll<T, Sum<T>, BroadcastAll, KillWARDependency>(
            val, sum, smem);
}
139 
/// Sums Num register values per thread across the entire block
template <int Num, typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAllSum(T vals[Num], T* smem) {
    Sum<T> sum;
    blockReduceAll<Num, T, Sum<T>, BroadcastAll, KillWARDependency>(
            vals, sum, smem);
}
145 
146 } // namespace gpu
147 } // namespace faiss
148