1 /*
2  * -----------------------------------------------------------------
3  * Programmer(s): Cody J. Balos @ LLNL
4  * -----------------------------------------------------------------
5  * SUNDIALS Copyright Start
6  * Copyright (c) 2002-2021, Lawrence Livermore National Security
7  * and Southern Methodist University.
8  * All rights reserved.
9  *
10  * See the top-level LICENSE and NOTICE files for details.
11  *
12  * SPDX-License-Identifier: BSD-3-Clause
13  * SUNDIALS Copyright End
14  * -----------------------------------------------------------------
15  */
16 
17 #ifndef _SUNDIALS_HIP_KERNELS_HIP_HPP
18 #define _SUNDIALS_HIP_KERNELS_HIP_HPP
19 
20 #include "sundials_hip.h"
21 
/* Grid-stride loop: declares loop variable "iter" of type "type" and walks it
   over [0, max) with stride blockDim.x * gridDim.x, so any 1-D launch
   configuration covers the whole index range. */
#define GRID_STRIDE_XLOOP(type, iter, max)  \
  for (type iter = blockDim.x * blockIdx.x + threadIdx.x; \
       iter < max; \
       iter += blockDim.x * gridDim.x)
26 
/* Warp shuffle-down that works on both back ends: when compiling device code
   for CUDA (__CUDA_ARCH__ defined) the *_sync variant with a full-warp mask is
   used; otherwise (HIP/ROCm path) the plain __shfl_down is used. */
#if defined(__CUDA_ARCH__)
  #define _SHFL_DOWN(val,offset) (__shfl_down_sync(0xFFFFFFFF, val, offset))
#else
  #define _SHFL_DOWN(val,offset) (__shfl_down(val, offset))
#endif
33 
34 namespace sundials
35 {
36 namespace hip
37 {
38 
/* The atomic functions below are implemented with the atomic compare-and-swap
   function atomicCAS, which atomically does: read old = *address; if
   (old == assumed) store the supplied new value; return old. Since *address
   could change between when the value is loaded and the atomicCAS call, the
   operation is repeated until *address does not change between the read and
   the compare-and-swap operation. */
44 
45 typedef enum { RSUM, RMAX, RMIN } BinaryReductionOp;
46 
#if defined(__CUDA_ARCH__) and __CUDA_ARCH__ < 600
/*
 * Software double-precision atomicAdd for pre-Pascal (SM < 60) CUDA devices,
 * built on the 64-bit integer atomicCAS. Returns the value that was stored at
 * "address" before the addition.
 */
__forceinline__ __device__
double atomicAdd(double* address, double val)
{
  unsigned long long int* target  = (unsigned long long int*)address;
  unsigned long long int current  = *target;
  unsigned long long int expected;

  do {
    expected = current;
    const double updated = val + __longlong_as_double(expected);
    current = atomicCAS(target, expected, __double_as_longlong(updated));
    /* Compare as integers so a stored NaN (NaN != NaN) cannot hang the loop */
  } while (expected != current);

  return __longlong_as_double(current);
}
#endif
65 
66 /*
67  * Compute the maximum of 2 double-precision floating point values using an atomic operation
68  * "address" is the address of the reference value which might get updated with the maximum
69  * "value" is the value that is compared to the reference in order to determine the maximum
70  */
71 __forceinline__ __device__
AtomicMax(double * const address,const double value)72 void AtomicMax(double* const address, const double value)
73 {
74   if (*address >= value)
75   {
76     return;
77   }
78 
79   unsigned long long * const address_as_i = (unsigned long long *)address;
80   unsigned long long old = * address_as_i, assumed;
81 
82   do
83   {
84     assumed = old;
85     if (__longlong_as_double(assumed) >= value)
86     {
87       break;
88     }
89     old = atomicCAS(address_as_i, assumed, __double_as_longlong(value));
90   } while (assumed != old);
91 }
92 
93 /*
94  * Compute the maximum of 2 single-precision floating point values using an atomic operation
95  * "address" is the address of the reference value which might get updated with the maximum
96  * "value" is the value that is compared to the reference in order to determine the maximum
97  */
98  __forceinline__ __device__
AtomicMax(float * const address,const float value)99 void AtomicMax(float* const address, const float value)
100 {
101   if (*address >= value)
102   {
103     return;
104   }
105 
106   unsigned int* const address_as_i = (unsigned int *)address;
107   unsigned int old = *address_as_i, assumed;
108 
109   do
110   {
111     assumed = old;
112     if (__int_as_float(assumed) >= value)
113     {
114       break;
115     }
116     old = atomicCAS(address_as_i, assumed, __float_as_int(value));
117   } while (assumed != old);
118 }
119 
120 /*
121  * Compute the minimum of 2 double-precision floating point values using an atomic operation
122  * "address" is the address of the reference value which might get updated with the minimum
123  * "value" is the value that is compared to the reference in order to determine the minimum
124  */
125 __forceinline__ __device__
AtomicMin(double * const address,const double value)126 void AtomicMin(double* const address, const double value)
127 {
128   if (*address <= value)
129   {
130     return;
131   }
132 
133   unsigned long long* const address_as_i = (unsigned long long *)address;
134   unsigned long long old = *address_as_i, assumed;
135 
136   do
137   {
138     assumed = old;
139     if (__longlong_as_double(assumed) <= value)
140     {
141       break;
142     }
143     old = atomicCAS(address_as_i, assumed, __double_as_longlong(value));
144   } while (assumed != old);
145 }
146 
147 /*
148  * Compute the minimum of 2 single-precision floating point values using an atomic operation
149  * "address" is the address of the reference value which might get updated with the minimum
150  * "value" is the value that is compared to the reference in order to determine the minimum
151  */
152 __forceinline__ __device__
AtomicMin(float * const address,const float value)153 void AtomicMin(float* const address, const float value)
154 {
155   if (*address <= value)
156   {
157     return;
158   }
159 
160   unsigned int* const address_as_i = (unsigned int *)address;
161   unsigned int old = *address_as_i, assumed;
162 
163   do
164   {
165     assumed = old;
166     if (__int_as_float(assumed) <= value)
167     {
168       break;
169     }
170     old = atomicCAS(address_as_i, assumed, __float_as_int(value));
171   } while (assumed != old);
172 }
173 
174 /*
175  * Perform a reduce on the warp to get the sum.
176  */
177 template <typename T>
178 __inline__ __device__
warpReduceSum(T val)179 T warpReduceSum(T val)
180 {
181   for (int offset = warpSize/2; offset > 0; offset /= 2)
182     val += _SHFL_DOWN(val, offset);
183   return val;
184 }
185 
186 /*
187  * Perform a reduce on the warp to get the maximum value.
188  */
189 template<typename T>
190 __inline__ __device__
warpReduceMax(T val)191 T warpReduceMax(T val)
192 {
193   for (int offset = warpSize/2; offset > 0; offset /= 2)
194     val = max(_SHFL_DOWN(val, offset), val);
195   return val;
196 }
197 
198 /*
199  * Perform a reduce on the warp to get the minimum value.
200  */
201 template<typename T>
202 __inline__ __device__
warpReduceMin(T val)203 T warpReduceMin(T val)
204 {
205   for (int offset = warpSize/2; offset > 0; offset /= 2)
206     val = min(_SHFL_DOWN(val, offset), val);
207   return val;
208 }
209 
210 /*
211  * Reduce value across the thread block.
212  */
213 template <typename T, BinaryReductionOp op>
214 __inline__ __device__
blockReduce(T val,T default_value)215 T blockReduce(T val, T default_value)
216 {
217   // Shared memory for the partial sums
218   static __shared__ T shared[warpSize];
219 
220   int lane = threadIdx.x % warpSize; // thread lane within warp
221   int wid = threadIdx.x / warpSize;  // warp ID
222 
223   // Each warp performs partial reduction
224   switch(op)
225   {
226     case RSUM:
227       val = warpReduceSum<T>(val);
228       break;
229     case RMAX:
230       val = warpReduceMax<T>(val);
231       break;
232     case RMIN:
233       val = warpReduceMin<T>(val);
234       break;
235     default:
236       asm("trap;"); // illegal instruction
237       break;
238   }
239 
240   // Write reduced value from each warp to shared memory
241   if (lane == 0) shared[wid] = val;
242 
243   // Wait for all partial reductions to complete
244   __syncthreads();
245 
246   // Read from shared memory only if that warp existed
247   val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : default_value;
248 
249   // Final reduce within first warp
250   if (wid == 0)
251   {
252     switch(op)
253     {
254       case RSUM:
255         val = warpReduceSum<T>(val);
256         break;
257       case RMAX:
258         val = warpReduceMax<T>(val);
259         break;
260       case RMIN:
261         val = warpReduceMin<T>(val);
262         break;
263       default:
264         asm("trap;"); // illegal instruction
265         break;
266     }
267   }
268 
269   return val;
270 }
271 
272 } // namespace hip
273 } // namespace sundials
274 
275 #endif // _SUNDIALS_HIP_KERNELS_HIP_HPP
276