1 #pragma once
2 #include <type_traits>
3
4 #include "chainerx/cuda/float16.cuh"
5 #include "chainerx/numeric.h"
6
7 namespace chainerx {
8 namespace cuda {
9
// Returns true if the value is NaN. Non-floating-point types can never hold
// NaN, so the generic overload is a constant false; floating-point types get
// dedicated overloads below.
template <typename T>
__device__ inline bool IsNan(T /*value*/) {
    return false;
}
__device__ inline bool IsNan(float x) { return isnan(x); }
__device__ inline bool IsNan(double x) { return isnan(x); }
__device__ inline bool IsNan(cuda::Float16 x) { return x.IsNan(); }
17
// Returns true if the value is +/- infinity. Non-floating-point types can
// never be infinite, so the generic overload is a constant false;
// floating-point types get dedicated overloads below.
template <typename T>
__device__ inline bool IsInf(T /*value*/) {
    return false;
}
__device__ inline bool IsInf(float x) { return isinf(x); }
__device__ inline bool IsInf(double x) { return isinf(x); }
__device__ inline bool IsInf(cuda::Float16 x) { return x.IsInf(); }
25
// Quadrant-aware arctangent of x1 / x2 (atan2 semantics).
template <typename T>
__device__ inline T Arctan2(T x1, T x2) {
    return std::atan2(x1, x2);
}
// Half precision: promote both operands to float, compute, round back to half.
template <>
__device__ inline cuda::Float16 Arctan2<cuda::Float16>(cuda::Float16 x1, cuda::Float16 x2) {
    float result = std::atan2(static_cast<float>(x1), static_cast<float>(x2));
    return cuda::Float16{result};
}
34
Arcsinh(double x)35 __device__ inline double Arcsinh(double x) { return std::asinh(x); }
36
Arcsinh(float x)37 __device__ inline float Arcsinh(float x) { return std::asinhf(x); }
38
Arcsinh(cuda::Float16 x)39 __device__ inline cuda::Float16 Arcsinh(cuda::Float16 x) { return cuda::Float16{std::asinhf(static_cast<float>(x))}; }
40
Arccosh(double x)41 __device__ inline double Arccosh(double x) { return std::acosh(x); }
42
Arccosh(float x)43 __device__ inline float Arccosh(float x) { return std::acoshf(x); }
44
Arccosh(cuda::Float16 x)45 __device__ inline cuda::Float16 Arccosh(cuda::Float16 x) { return cuda::Float16{std::acoshf(static_cast<float>(x))}; }
46
Log2(double x)47 __device__ inline double Log2(double x) { return std::log2(x); }
48
Log2(float x)49 __device__ inline float Log2(float x) { return std::log2f(x); }
50
Log2(cuda::Float16 x)51 __device__ inline cuda::Float16 Log2(cuda::Float16 x) { return cuda::Float16{std::log2f(static_cast<float>(x))}; }
52
Log1p(double x)53 __device__ inline double Log1p(double x) { return std::log1p(x); }
54
Log1p(float x)55 __device__ inline float Log1p(float x) { return std::log1pf(x); }
56
Log1p(cuda::Float16 x)57 __device__ inline cuda::Float16 Log1p(cuda::Float16 x) { return cuda::Float16{std::log1pf(static_cast<float>(x))}; }
58
// Sign function: returns -1, 0, or +1 with the same type as the input.
// NaN inputs propagate unchanged (IsNan is a constant false for non-float
// types, so the check is free there).
template <typename T>
__device__ inline T Sign(T x) {
    if (IsNan(x)) {
        return x;
    }
    int positive = static_cast<int>(T{0} < x);
    int negative = static_cast<int>(x < T{0});
    return static_cast<T>(positive - negative);
}

// uint8_t is never negative and never NaN, so the result is just 0 or 1.
template <>
__device__ inline uint8_t Sign(uint8_t x) {
    return static_cast<uint8_t>(x > 0);
}
// Float16: comparisons happen in half precision; NaN propagates unchanged.
template <>
__device__ inline cuda::Float16 Sign(cuda::Float16 x) {
    if (IsNan(x)) {
        return x;
    }
    int positive = static_cast<int>(cuda::Float16{0} < x);
    int negative = static_cast<int>(x < cuda::Float16{0});
    return cuda::Float16{positive - negative};
}
72
Erf(double x)73 __device__ inline double Erf(double x) { return std::erf(x); }
74
Erf(float x)75 __device__ inline float Erf(float x) { return std::erff(x); }
76
Erf(cuda::Float16 x)77 __device__ inline cuda::Float16 Erf(cuda::Float16 x) { return cuda::Float16{std::erff(static_cast<float>(x))}; }
78
Expm1(double x)79 __device__ inline double Expm1(double x) { return std::expm1(x); }
80
Expm1(float x)81 __device__ inline float Expm1(float x) { return std::expm1f(x); }
82
Expm1(cuda::Float16 x)83 __device__ inline cuda::Float16 Expm1(cuda::Float16 x) { return cuda::Float16{std::expm1f(static_cast<float>(x))}; }
84
Exp2(double x)85 __device__ inline double Exp2(double x) { return std::exp2(x); }
86
Exp2(float x)87 __device__ inline float Exp2(float x) { return std::exp2f(x); }
88
Exp2(cuda::Float16 x)89 __device__ inline cuda::Float16 Exp2(cuda::Float16 x) { return cuda::Float16{std::exp2f(static_cast<float>(x))}; }
90
Abs(uint8_t x)91 __device__ inline uint8_t Abs(uint8_t x) { return x; }
Abs(int8_t x)92 __device__ inline int8_t Abs(int8_t x) { return std::labs(x); }
Abs(int16_t x)93 __device__ inline int16_t Abs(int16_t x) { return std::labs(x); }
Abs(int32_t x)94 __device__ inline int32_t Abs(int32_t x) { return std::labs(x); }
Abs(int64_t x)95 __device__ inline int64_t Abs(int64_t x) { return std::llabs(x); }
Abs(double x)96 __device__ inline double Abs(double x) { return std::fabs(x); }
Abs(float x)97 __device__ inline float Abs(float x) { return std::fabs(x); }
Abs(cuda::Float16 x)98 __device__ inline cuda::Float16 Abs(cuda::Float16 x) { return static_cast<cuda::Float16>(std::fabs(static_cast<float>(x))); }
99
// Element-wise remainder. The generic template covers integral types with
// the built-in % operator; floating-point and Float16 inputs dispatch to
// std::fmod below.
template <typename T>
__device__ inline T Fmod(T x1, T x2) {
    return x1 % x2;
}

__device__ inline float Fmod(float x1, float x2) { return std::fmod(x1, x2); }
__device__ inline double Fmod(double x1, double x2) { return std::fmod(x1, x2); }
// Half precision: promote to float, compute, round back to half.
__device__ inline cuda::Float16 Fmod(cuda::Float16 x1, cuda::Float16 x2) {
    float result = std::fmod(static_cast<float>(x1), static_cast<float>(x2));
    return cuda::Float16{result};
}
110
// Defines a unary device function `name`: a generic template that forwards
// directly to `func` (used for float and double), plus a cuda::Float16
// overload that promotes the argument to float, applies `func`, and rounds
// the result back to half precision.
#define CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(name, func) \
    template <typename T> \
    __device__ inline T name(T x) { \
        return func(x); \
    } \
    __device__ inline cuda::Float16 name(cuda::Float16 x) { return cuda::Float16{func(static_cast<float>(x))}; }

// Unary math functions that have no dedicated Float16 implementation and
// therefore use the float fallback defined by the macro above.
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Ceil, std::ceil)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Floor, std::floor)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Sinh, std::sinh)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Cosh, std::cosh)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Tanh, std::tanh)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Sin, std::sin)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Cos, std::cos)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Tan, std::tan)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Arcsin, std::asin)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Arccos, std::acos)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Arctan, std::atan)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Exp, std::exp)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Log, std::log)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Log10, std::log10)
CHAINERX_DEFINE_CUDA_FLOAT16_FALLBACK_UNARY(Sqrt, std::sqrt)
133
namespace numeric_detail {

// Integer power by repeated squaring, assuming a non-negative exponent x2.
// Takes O(log x2) multiplications; overflow behaves as the type T's
// multiplication does.
template <typename T>
__device__ inline T NonNegativePower(T x1, T x2) {
    static_assert(std::is_integral<T>::value, "NonNegativePower is only defined for non-negative integrals.");
    T result{1};
    // Consume the exponent bit by bit, squaring the base each step and
    // multiplying it in whenever the current low bit is set.
    for (; x2 > 0; x2 >>= 1) {
        if ((x2 & 1) != 0) {
            result *= x1;
        }
        x1 *= x1;
    }
    return result;
}

}  // namespace numeric_detail
153
// Integer power for signed integral types. A negative exponent would be a
// fractional result, which truncates to 0 for every base except +/-1:
// 1**x2 is always 1, and (-1)**x2 is +/-1 depending on the exponent's parity.
template <typename T>
__device__ inline auto Power(T x1, T x2) -> std::enable_if_t<std::is_integral<T>::value && std::is_signed<T>::value, T> {
    if (x2 >= 0) {
        return numeric_detail::NonNegativePower(x1, x2);
    }
    if (x1 == 1) {
        return 1;
    }
    if (x1 == -1) {
        // x2 & 1 tests parity; it is 1 for odd exponents (including negative
        // odd ones in two's complement).
        return x2 & 1 ? -1 : 1;
    }
    return 0;
}
168
// Integer power for unsigned integral types; the exponent can never be
// negative, so this forwards straight to the non-negative implementation.
template <typename T>
__device__ inline auto Power(T x1, T x2) -> std::enable_if_t<std::is_integral<T>::value && std::is_unsigned<T>::value, T> {
    return numeric_detail::NonNegativePower(x1, x2);
}
173
// Power for non-integral types. The primary template is declared but has no
// definition; only the explicit specializations below (double, float,
// Float16) are supported.
template <typename T>
__device__ inline auto Power(T x1, T x2) -> std::enable_if_t<!std::is_integral<T>::value, T>;

template <>
__device__ inline double Power(double x1, double x2) {
    return pow(x1, x2);
}
template <>
__device__ inline float Power(float x1, float x2) {
    return powf(x1, x2);
}
// Half precision: promote to float, compute with powf, round back to half.
template <>
__device__ inline cuda::Float16 Power(cuda::Float16 x1, cuda::Float16 x2) {
    return cuda::Float16{powf(static_cast<float>(x1), static_cast<float>(x2))};
}
188
189 } // namespace cuda
190 } // namespace chainerx
191