1 /**
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * This source code is licensed under the MIT license found in the
5 * LICENSE file in the root directory of this source tree.
6 */
7
8 #pragma once
9
10 #include <cuda.h>
11 #include <faiss/gpu/utils/StaticUtils.h>
12 #include <faiss/gpu/utils/DeviceDefs.cuh>
13 #include <faiss/gpu/utils/PtxUtils.cuh>
14 #include <faiss/gpu/utils/ReductionOperators.cuh>
15 #include <faiss/gpu/utils/WarpShuffles.cuh>
16
17 namespace faiss {
18 namespace gpu {
19
/// Reduces a per-lane register value across ReduceWidth lanes of a warp
/// using a butterfly (xor) shuffle pattern. Every participating lane ends
/// up holding the fully reduced result, so no broadcast step is needed.
/// ReduceWidth must be a power of two (defaults to the full warp).
template <typename T, typename Op, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAll(T val, Op op) {
#pragma unroll
    for (int offset = ReduceWidth / 2; offset > 0; offset >>= 1) {
        // Exchange with the lane whose id differs by `offset`, then combine
        T other = shfl_xor(val, offset);
        val = op(val, other);
    }

    return val;
}
29
/// Sums a register value across all warp threads; every lane receives the
/// total. Convenience wrapper around warpReduceAll with the Sum operator.
template <typename T, int ReduceWidth = kWarpSize>
__device__ inline T warpReduceAllSum(T val) {
    Sum<T> sum;
    return warpReduceAll<T, Sum<T>, ReduceWidth>(val, sum);
}
35
/// Performs a block-wide reduction
///
/// `smem` must provide at least ceil(blockDim.x / kWarpSize) elements of T.
/// If BroadcastAll is true, every thread in the block returns the reduced
/// value; otherwise only warp 0 holds a meaningful result (and only lane 0
/// of warp 0 holds the final reduction).
/// If KillWARDependency is true, a trailing barrier is issued so the caller
/// may immediately reuse `smem` without a write-after-read hazard.
template <typename T, typename Op, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAll(T val, Op op, T* smem) {
    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;

    // Stage 1: reduce within each warp via shuffles; lane 0 of each warp
    // publishes its warp's partial result to shared memory.
    val = warpReduceAll<T, Op>(val, op);
    if (laneId == 0) {
        smem[warpId] = val;
    }
    __syncthreads();

    // Stage 2: warp 0 reduces the per-warp partials. Lanes beyond the
    // number of resident warps contribute the operator's identity element.
    if (warpId == 0) {
        val = laneId < utils::divUp(blockDim.x, kWarpSize) ? smem[laneId]
                                                           : op.identity();
        val = warpReduceAll<T, Op>(val, op);

        if (BroadcastAll) {
            // Fence before publishing so the final value is visible to the
            // whole block once the barrier below is passed
            __threadfence_block();

            if (laneId == 0) {
                smem[0] = val;
            }
        }
    }

    // Stage 3 (optional): all threads read back the final reduced value.
    if (BroadcastAll) {
        __syncthreads();
        val = smem[0];
    }

    // Optional barrier so `smem` can be reused by the caller right away.
    if (KillWARDependency) {
        __syncthreads();
    }

    return val;
}
73
/// Performs a block-wide reduction of multiple values simultaneously
///
/// Reduces `Num` independent register values per thread in one pass,
/// amortizing the barrier costs across all of them.
/// `smem` must provide at least Num * ceil(blockDim.x / kWarpSize)
/// elements of T; partials are stored warp-major (warpId * Num + i).
/// If BroadcastAll is true, every thread receives all Num reduced values;
/// otherwise only warp 0 holds meaningful results.
/// If KillWARDependency is true, a trailing barrier allows the caller to
/// reuse `smem` immediately.
template <
        int Num,
        typename T,
        typename Op,
        bool BroadcastAll,
        bool KillWARDependency>
__device__ inline void blockReduceAll(T val[Num], Op op, T* smem) {
    int laneId = getLaneId();
    int warpId = threadIdx.x / kWarpSize;

    // Stage 1: each of the Num values is reduced within the warp
#pragma unroll
    for (int i = 0; i < Num; ++i) {
        val[i] = warpReduceAll<T, Op>(val[i], op);
    }

    // Lane 0 of each warp publishes its Num partial results contiguously
    if (laneId == 0) {
#pragma unroll
        for (int i = 0; i < Num; ++i) {
            smem[warpId * Num + i] = val[i];
        }
    }

    __syncthreads();

    // Stage 2: warp 0 reduces the per-warp partials for each of the Num
    // values; lanes past the warp count contribute the identity element.
    if (warpId == 0) {
#pragma unroll
        for (int i = 0; i < Num; ++i) {
            val[i] = laneId < utils::divUp(blockDim.x, kWarpSize)
                    ? smem[laneId * Num + i]
                    : op.identity();
            val[i] = warpReduceAll<T, Op>(val[i], op);
        }

        if (BroadcastAll) {
            // Fence before publishing the final values for the whole block
            __threadfence_block();

            if (laneId == 0) {
#pragma unroll
                for (int i = 0; i < Num; ++i) {
                    smem[i] = val[i];
                }
            }
        }
    }

    // Stage 3 (optional): all threads read back the Num final values
    if (BroadcastAll) {
        __syncthreads();
#pragma unroll
        for (int i = 0; i < Num; ++i) {
            val[i] = smem[i];
        }
    }

    // Optional barrier so `smem` can be reused by the caller right away
    if (KillWARDependency) {
        __syncthreads();
    }
}
132
/// Sums a register value across the entire block. Convenience wrapper
/// around blockReduceAll with the Sum operator; see that function for the
/// shared-memory and broadcast requirements.
template <typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline T blockReduceAllSum(T val, T* smem) {
    Sum<T> sum;
    return blockReduceAll<T, Sum<T>, BroadcastAll, KillWARDependency>(
            val, sum, smem);
}
139
/// Sums an array of Num register values across the entire block.
/// Convenience wrapper around the multi-value blockReduceAll with the Sum
/// operator; `smem` must hold Num * ceil(blockDim.x / kWarpSize) elements.
template <int Num, typename T, bool BroadcastAll, bool KillWARDependency>
__device__ inline void blockReduceAllSum(T vals[Num], T* smem) {
    Sum<T> sum;
    blockReduceAll<Num, T, Sum<T>, BroadcastAll, KillWARDependency>(
            vals, sum, smem);
}
145
146 } // namespace gpu
147 } // namespace faiss
148