1 #include "chainerx/cuda/cuda_device.h"
2 
3 #include <cstdint>
4 
5 #include <cuda_runtime.h>
6 
7 #include "chainerx/arithmetic_ops.h"
8 #include "chainerx/array.h"
9 #include "chainerx/cuda/cuda_runtime.h"
10 #include "chainerx/cuda/cuda_set_device_scope.h"
11 #include "chainerx/cuda/elementwise.cuh"
12 #include "chainerx/cuda/float16.cuh"
13 #include "chainerx/cuda/kernel_regist.h"
14 #include "chainerx/cuda/numeric.cuh"
15 #include "chainerx/device.h"
16 #include "chainerx/dtype.h"
17 #include "chainerx/kernels/arithmetic.h"
18 #include "chainerx/routines/arithmetic.h"
19 #include "chainerx/scalar.h"
20 
21 namespace chainerx {
22 namespace cuda {
23 namespace {
24 
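// Naming convention for the scalar kernels below: "AS" means array (x1) op scalar (x2) and "SA" means
// scalar (x1) op array (x2). Every Call follows the same pattern: cast the array operand to the output
// dtype if necessary, switch to that array's device, dispatch on the output dtype, and launch an
// elementwise kernel with the scalar captured inside the functor.
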
CHAINERX_CUDA_REGISTER_ELTWISE_BINARY_KERNEL(AddKernel, { out = ArithmeticOps<CudaType>::Add(x1, x2); });

template <typename T>
struct AddASImpl {
    using CudaType = cuda_internal::DataType<T>;
    __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType& out) { out = ArithmeticOps<CudaType>::Add(x1, x2); }
    CudaType x2;
};

class CudaAddASKernel : public AddASKernel {
public:
    void Call(const Array& x1, Scalar x2, const Array& out) override {
        Device& device = x1.device();
        device.CheckDevicesCompatible(x1, out);
        const Array& x1_cast = x1.dtype() == out.dtype() ? x1 : x1.AsType(out.dtype());
        CudaSetDeviceScope scope{device.index()};
        VisitDtype(out.dtype(), [&](auto pt) {
            using T = typename decltype(pt)::type;
            using CudaType = cuda_internal::DataType<T>;
            Elementwise<const T, T>(AddASImpl<T>{static_cast<CudaType>(x2)}, x1_cast, out);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(AddASKernel, CudaAddASKernel);

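// Subtract dispatches with VisitNumericDtype rather than the plain binary-kernel macro used for Add and
// Multiply; VisitNumericDtype is expected to exclude bool, so boolean subtraction is rejected at dispatch time.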
CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_BINARY_KERNEL(SubtractKernel, { out = ArithmeticOps<CudaType>::Subtract(x1, x2); }, VisitNumericDtype);

template <typename T>
struct SubtractASImpl {
    using CudaType = cuda_internal::DataType<T>;
    __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType& out) { out = ArithmeticOps<CudaType>::Subtract(x1, x2); }
    CudaType x2;
};

class CudaSubtractASKernel : public SubtractASKernel {
public:
    void Call(const Array& x1, Scalar x2, const Array& out) override {
        Device& device = x1.device();
        device.CheckDevicesCompatible(x1, out);
        const Array& x1_cast = x1.dtype() == out.dtype() ? x1 : x1.AsType(out.dtype());
        CudaSetDeviceScope scope{device.index()};
        VisitNumericDtype(out.dtype(), [&](auto pt) {
            using T = typename decltype(pt)::type;
            using CudaType = cuda_internal::DataType<T>;
            Elementwise<const T, T>(SubtractASImpl<T>{static_cast<CudaType>(x2)}, x1_cast, out);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(SubtractASKernel, CudaSubtractASKernel);

// TODO(sonots): support stream
CHAINERX_CUDA_REGISTER_ELTWISE_BINARY_KERNEL(MultiplyKernel, { out = ArithmeticOps<CudaType>::Multiply(x1, x2); });

template <typename T>
struct MultiplyASImpl {
    using CudaType = cuda_internal::DataType<T>;
    __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType& out) { out = ArithmeticOps<CudaType>::Multiply(x1, x2); }
    CudaType x2;
};

class CudaMultiplyASKernel : public MultiplyASKernel {
public:
    void Call(const Array& x1, Scalar x2, const Array& out) override {
        Device& device = x1.device();
        device.CheckDevicesCompatible(x1, out);
        const Array& x1_cast = x1.dtype() == out.dtype() ? x1 : x1.AsType(out.dtype());
        CudaSetDeviceScope scope{device.index()};
        VisitDtype(out.dtype(), [&](auto pt) {
            using T = typename decltype(pt)::type;
            using CudaType = cuda_internal::DataType<T>;
            Elementwise<const T, T>(MultiplyASImpl<T>{static_cast<CudaType>(x2)}, x1_cast, out);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(MultiplyASKernel, CudaMultiplyASKernel);

// CUDA does not have std::div, which is used for the native backend.
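// FloorDivide rounds the quotient toward negative infinity, e.g. FloorDivide(-7, 2) == -4 and
// FloorDivide(7, -2) == -4. As implemented below, integer division by zero yields 0 rather than
// trapping, and floating-point division by zero yields NaN (std::fmod(x, 0) is NaN).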
template <typename T>
__device__ T FloorDivideImpl(T x, T y) {
    if (y == 0) {
        return 0;
    }
    return x / y - ((y >= 0 ? x % y : -(x % y)) < 0 ? 1 : 0);
}
__device__ int8_t FloorDivide(int8_t x, int8_t y) { return FloorDivideImpl(x, y); }
__device__ int16_t FloorDivide(int16_t x, int16_t y) { return FloorDivideImpl(x, y); }
__device__ int32_t FloorDivide(int32_t x, int32_t y) { return FloorDivideImpl(x, y); }
__device__ int64_t FloorDivide(int64_t x, int64_t y) { return FloorDivideImpl(x, y); }
__device__ uint8_t FloorDivide(uint8_t x, uint8_t y) {
    if (y == 0) {
        return 0;
    }
    return x / y;
}
__device__ float FloorDivide(float x, float y) {
    float rem = std::fmod(x, y);
    return (x - rem) / y - ((rem < 0 && y > 0) || (rem > 0 && y < 0) ? 1 : 0);
}
__device__ double FloorDivide(double x, double y) {
    double rem = std::fmod(x, y);
    return (x - rem) / y - ((rem < 0 && y > 0) || (rem > 0 && y < 0) ? 1 : 0);
}
__device__ cuda::Float16 FloorDivide(cuda::Float16 x, cuda::Float16 y) {
    return cuda::Float16{FloorDivide(static_cast<float>(x), static_cast<float>(y))};
}

CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_BINARY_KERNEL(FloorDivideKernel, { out = cuda::FloorDivide(x1, x2); }, VisitNumericDtype);

template <typename T>
struct FloorDivideASImpl {
    using CudaType = cuda_internal::DataType<T>;
    __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType& out) { out = cuda::FloorDivide(x1, x2); }
    CudaType x2;
};

class CudaFloorDivideASKernel : public FloorDivideASKernel {
public:
    void Call(const Array& x1, Scalar x2, const Array& out) override {
        Device& device = x1.device();
        device.CheckDevicesCompatible(x1, out);
        const Array& x1_cast = x1.dtype() == out.dtype() ? x1 : x1.AsType(out.dtype());
        CudaSetDeviceScope scope{device.index()};
        VisitNumericDtype(out.dtype(), [&](auto pt) {
            using T = typename decltype(pt)::type;
            using CudaType = cuda_internal::DataType<T>;
            Elementwise<const T, T>(FloorDivideASImpl<T>{static_cast<CudaType>(x2)}, x1_cast, out);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(FloorDivideASKernel, CudaFloorDivideASKernel);

template <typename T>
struct FloorDivideSAImpl {
    using CudaType = cuda_internal::DataType<T>;
    __device__ void operator()(int64_t /*i*/, CudaType x2, CudaType& out) { out = cuda::FloorDivide(x1, x2); }
    CudaType x1;
};

class CudaFloorDivideSAKernel : public FloorDivideSAKernel {
public:
    void Call(Scalar x1, const Array& x2, const Array& out) override {
        Device& device = x2.device();
        device.CheckDevicesCompatible(x2, out);
        const Array& x2_cast = x2.dtype() == out.dtype() ? x2 : x2.AsType(out.dtype());
        CudaSetDeviceScope scope{device.index()};
        VisitNumericDtype(out.dtype(), [&](auto pt) {
            using T = typename decltype(pt)::type;
            using CudaType = cuda_internal::DataType<T>;
            Elementwise<const T, T>(FloorDivideSAImpl<T>{static_cast<CudaType>(x1)}, x2_cast, out);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(FloorDivideSAKernel, CudaFloorDivideSAKernel);

CHAINERX_CUDA_REGISTER_ELTWISE_BINARY_KERNEL(DivideKernel, { out = ArithmeticOps<CudaType>::Divide(x1, x2); });

template <typename T>
struct DivideASImpl {
    using CudaType = cuda_internal::DataType<T>;
    __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType& out) { out = ArithmeticOps<CudaType>::Divide(x1, x2); }
    CudaType x2;
};

class CudaDivideASKernel : public DivideASKernel {
public:
    void Call(const Array& x1, Scalar x2, const Array& out) override {
        Device& device = x1.device();
        device.CheckDevicesCompatible(x1, out);
        const Array& x1_cast = x1.dtype() == out.dtype() ? x1 : x1.AsType(out.dtype());
        CudaSetDeviceScope scope{device.index()};
        VisitDtype(out.dtype(), [&](auto pt) {
            using T = typename decltype(pt)::type;
            using CudaType = cuda_internal::DataType<T>;
            Elementwise<const T, T>(DivideASImpl<T>{static_cast<CudaType>(x2)}, x1_cast, out);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(DivideASKernel, CudaDivideASKernel);

template <typename T>
struct DivideSAImpl {
    using CudaType = cuda_internal::DataType<T>;
    __device__ void operator()(int64_t /*i*/, CudaType x2, CudaType& out) { out = ArithmeticOps<CudaType>::Divide(x1, x2); }
    CudaType x1;
};

class CudaDivideSAKernel : public DivideSAKernel {
public:
    void Call(Scalar x1, const Array& x2, const Array& out) override {
        Device& device = x2.device();
        device.CheckDevicesCompatible(x2, out);
        const Array& x2_cast = x2.dtype() == out.dtype() ? x2 : x2.AsType(out.dtype());
        CudaSetDeviceScope scope{device.index()};
        VisitDtype(out.dtype(), [&](auto pt) {
            using T = typename decltype(pt)::type;
            using CudaType = cuda_internal::DataType<T>;
            Elementwise<const T, T>(DivideSAImpl<T>{static_cast<CudaType>(x1)}, x2_cast, out);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(DivideSAKernel, CudaDivideSAKernel);

CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_BINARY_KERNEL(PowerKernel, { out = cuda::Power(x1, x2); }, VisitNumericDtype);

template <typename T>
struct PowerASImpl {
    using CudaType = cuda_internal::DataType<T>;
    __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType& out) { out = cuda::Power(x1, x2); }
    CudaType x2;
};

class CudaPowerASKernel : public PowerASKernel {
public:
    void Call(const Array& x1, Scalar x2, const Array& out) override {
        Device& device = x1.device();
        device.CheckDevicesCompatible(x1, out);
        const Array& x1_cast = x1.dtype() == out.dtype() ? x1 : x1.AsType(out.dtype());
        CudaSetDeviceScope scope{device.index()};
        VisitNumericDtype(out.dtype(), [&](auto pt) {
            using T = typename decltype(pt)::type;
            using CudaType = cuda_internal::DataType<T>;
            Elementwise<const T, T>(PowerASImpl<T>{static_cast<CudaType>(x2)}, x1_cast, out);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(PowerASKernel, CudaPowerASKernel);

template <typename T>
struct PowerSAImpl {
    using CudaType = cuda_internal::DataType<T>;
    __device__ void operator()(int64_t /*i*/, CudaType x2, CudaType& out) { out = cuda::Power(x1, x2); }
    CudaType x1;
};

class CudaPowerSAKernel : public PowerSAKernel {
public:
    void Call(Scalar x1, const Array& x2, const Array& out) override {
        Device& device = x2.device();
        device.CheckDevicesCompatible(x2, out);
        const Array& x2_cast = x2.dtype() == out.dtype() ? x2 : x2.AsType(out.dtype());
        CudaSetDeviceScope scope{device.index()};
        VisitNumericDtype(out.dtype(), [&](auto pt) {
            using T = typename decltype(pt)::type;
            using CudaType = cuda_internal::DataType<T>;
            Elementwise<const T, T>(PowerSAImpl<T>{static_cast<CudaType>(x1)}, x2_cast, out);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(PowerSAKernel, CudaPowerSAKernel);

// CUDA does not have std::mod, which is used for the native backend.
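// Mod follows Python-style modulo semantics: the result takes the sign of the divisor, e.g.
// Mod(-7, 3) == 2 and Mod(7, -3) == -2. For integer types a zero operand yields 0 (avoiding
// undefined division by zero); for floating-point types a zero divisor yields NaN.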
template <typename T>
__device__ T ModSignedIntegerImpl(T x, T y) {
    if (x == 0 || y == 0) {
        return 0;
    }
    T ret = x % y;
    if ((ret > 0 && y < 0) || (ret < 0 && y > 0)) {
        return y + ret;
    }
    return ret;
}
__device__ int8_t Mod(int8_t x, int8_t y) { return ModSignedIntegerImpl(x, y); }
__device__ int16_t Mod(int16_t x, int16_t y) { return ModSignedIntegerImpl(x, y); }
__device__ int32_t Mod(int32_t x, int32_t y) { return ModSignedIntegerImpl(x, y); }
__device__ int64_t Mod(int64_t x, int64_t y) { return ModSignedIntegerImpl(x, y); }
__device__ uint8_t Mod(uint8_t x, uint8_t y) {
    if (x == 0 || y == 0) {
        return 0;
    }
    return x % y;
}
template <typename T>
__device__ T ModFloatImpl(T x, T y) {
    if (y == 0) {
        return NAN;
    }
    T ret = std::fmod(x, y);
    if ((ret > 0 && y < 0) || (ret < 0 && y > 0)) {
        return y + ret;
    }
    return ret;
}
__device__ double Mod(double x, double y) { return ModFloatImpl(x, y); }
__device__ float Mod(float x, float y) { return ModFloatImpl(x, y); }
__device__ cuda::Float16 Mod(cuda::Float16 x, cuda::Float16 y) { return cuda::Float16{Mod(static_cast<float>(x), static_cast<float>(y))}; }

CHAINERX_CUDA_REGISTER_ELTWISE_DTYPE_BINARY_KERNEL(ModAAKernel, { out = cuda::Mod(x1, x2); }, VisitNumericDtype);

template <typename T>
struct ModASImpl {
    using CudaType = cuda_internal::DataType<T>;
    __device__ void operator()(int64_t /*i*/, CudaType x1, CudaType& out) { out = cuda::Mod(x1, x2); }
    CudaType x2;
};

class CudaModASKernel : public ModASKernel {
public:
    void Call(const Array& x1, Scalar x2, const Array& out) override {
        Device& device = x1.device();
        device.CheckDevicesCompatible(x1, out);
        const Array& x1_cast = x1.dtype() == out.dtype() ? x1 : x1.AsType(out.dtype());
        CudaSetDeviceScope scope{device.index()};
        VisitNumericDtype(out.dtype(), [&](auto pt) {
            using T = typename decltype(pt)::type;
            using CudaType = cuda_internal::DataType<T>;
            Elementwise<const T, T>(ModASImpl<T>{static_cast<CudaType>(x2)}, x1_cast, out);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(ModASKernel, CudaModASKernel);

template <typename T>
struct ModSAImpl {
    using CudaType = cuda_internal::DataType<T>;
    __device__ void operator()(int64_t /*i*/, CudaType x2, CudaType& out) { out = cuda::Mod(x1, x2); }
    CudaType x1;
};

class CudaModSAKernel : public ModSAKernel {
public:
    void Call(Scalar x1, const Array& x2, const Array& out) override {
        Device& device = x2.device();
        device.CheckDevicesCompatible(x2, out);
        const Array& x2_cast = x2.dtype() == out.dtype() ? x2 : x2.AsType(out.dtype());
        CudaSetDeviceScope scope{device.index()};
        VisitNumericDtype(out.dtype(), [&](auto pt) {
            using T = typename decltype(pt)::type;
            using CudaType = cuda_internal::DataType<T>;
            Elementwise<const T, T>(ModSAImpl<T>{static_cast<CudaType>(x1)}, x2_cast, out);
        });
    }
};

CHAINERX_CUDA_REGISTER_KERNEL(ModSAKernel, CudaModSAKernel);

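// Note: cuda::Fmod below presumably follows C fmod semantics (the result takes the sign of the
// dividend x1), in contrast to cuda::Mod above, whose result takes the sign of the divisor.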
CHAINERX_CUDA_REGISTER_ELTWISE_BINARY_KERNEL(FmodKernel, { out = cuda::Fmod(x1, x2); });

}  // namespace
}  // namespace cuda
}  // namespace chainerx