1 #pragma once
2 #include <cmath>
3 #include <cstdint>
4 #include <type_traits>
5
6 #include "chainerx/array.h"
7 #include "chainerx/scalar.h"
8
9 namespace chainerx {
10
// Compares two arrays `a` and `b` using the relative tolerance `rtol`, the absolute tolerance `atol`, and
// the `equal_nan` flag (whether NaNs should be treated as equal).
// NOTE(review): only the declaration is visible here. Judging by the name and parameters, this presumably
// mirrors numpy.allclose (|a - b| <= atol + rtol * |b|, element-wise); confirm against the definition.
bool AllClose(const Array& a, const Array& b, double rtol = 1e-5, double atol = 1e-8, bool equal_nan = false);
12
// Fallback NaN check for any type that has no dedicated overload (e.g. integral types and bool):
// such values can never be NaN.
template <class T>
inline auto IsNan(T /*value*/) -> bool {
    return false;
}
17
IsNan(chainerx::Float16 value)18 inline bool IsNan(chainerx::Float16 value) { return value.IsNan(); }
IsNan(float value)19 inline bool IsNan(float value) { return std::isnan(value); }
IsNan(double value)20 inline bool IsNan(double value) { return std::isnan(value); }
21
// Fallback infinity check for any type that has no dedicated overload (e.g. integral types and bool):
// such values can never be infinite.
template <class T>
inline auto IsInf(T /*value*/) -> bool {
    return false;
}
26
IsInf(chainerx::Float16 value)27 inline bool IsInf(chainerx::Float16 value) { return value.IsInf(); }
IsInf(double value)28 inline bool IsInf(double value) { return std::isinf(value); }
IsInf(float value)29 inline bool IsInf(float value) { return std::isinf(value); }
30
// Returns +1, 0 or -1 (converted to T) according to the sign of `x`. A NaN input is returned unchanged
// so that it propagates. Note that both +0 and -0 compare neither above nor below zero and yield 0.
template <typename T>
inline T Sign(T x) {
    if (IsNan(x)) {
        return x;
    }
    const int is_positive = static_cast<int>(T{0} < x);
    const int is_negative = static_cast<int>(x < T{0});
    return static_cast<T>(is_positive - is_negative);
}
35
36 template <>
37 inline chainerx::Float16 Sign<chainerx::Float16>(chainerx::Float16 x) {
38 return IsNan(x) ? x : Float16{static_cast<int>(Float16{0} < x) - static_cast<int>(x < Float16{0})};
39 }
40
// Absolute value via std::abs. Integral types narrower than int (e.g. int8_t, int16_t, bool) are
// promoted to int by std::abs, and the result is converted back to T.
template <typename T>
inline T Abs(T x) {
    const auto magnitude = std::abs(x);
    return static_cast<T>(magnitude);
}

// uint8_t is unsigned, so its absolute value is the value itself.
template <>
inline uint8_t Abs<uint8_t>(uint8_t x) {
    return x;
}
50 template <>
51 inline chainerx::Float16 Abs<chainerx::Float16>(chainerx::Float16 x) {
52 return chainerx::Float16{std::abs(static_cast<float>(x))};
53 }
54
// Defines a unary function template `name` that forwards its argument to `func`, together with a
// Float16 specialization that widens the argument to float, calls `func`, and converts the result back.
#define CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(name, func)             \
    template <typename T>                                                     \
    inline T name(T arg) {                                                    \
        return func(arg);                                                     \
    }                                                                         \
    template <>                                                               \
    inline chainerx::Float16 name<chainerx::Float16>(chainerx::Float16 arg) { \
        const float widened = static_cast<float>(arg);                        \
        return chainerx::Float16{func(widened)};                              \
    }
64
// Rounding.
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Ceil, std::ceil)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Floor, std::floor)

// Trigonometric functions and their inverses.
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Sin, std::sin)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Cos, std::cos)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Tan, std::tan)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Arcsin, std::asin)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Arccos, std::acos)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Arctan, std::atan)

// Hyperbolic functions and their inverses.
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Sinh, std::sinh)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Cosh, std::cosh)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Tanh, std::tanh)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Arcsinh, std::asinh)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Arccosh, std::acosh)

// Exponentials and logarithms.
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Exp, std::exp)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Expm1, std::expm1)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Exp2, std::exp2)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Log, std::log)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Log10, std::log10)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Log2, std::log2)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Log1p, std::log1p)

// Miscellaneous.
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Erf, std::erf)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Sqrt, std::sqrt)
87
namespace numeric_detail {

// Computes x1 raised to the power x2 by square-and-multiply, using O(log x2) multiplications.
//
// Only integral T is accepted. A non-positive exponent yields 1, so callers that need different
// semantics for negative exponents (see the signed Power overload) must handle them before calling this.
//
// The base is squared only while exponent bits remain to be consumed. Squaring once more after the
// last bit would be useless, and for signed T it can overflow (undefined behavior) even when the final
// result is representable, e.g. NonNegativePower<int32_t>(2, 30) or NonNegativePower<int32_t>(10, 9).
// Results that themselves do not fit in T still overflow.
template <typename T>
inline T NonNegativePower(T x1, T x2) {
    static_assert(std::is_integral<T>::value, "NonNegativePower is only defined for non-negative integrals.");
    T out{1};

    while (x2 > 0) {
        if (x2 & 1) {
            out *= x1;
        }
        x2 >>= 1;
        // Skip the final squaring: its value would never be used and it may overflow.
        if (x2 > 0) {
            x1 *= x1;
        }
    }

    return out;
}

}  // namespace numeric_detail
107
108 template <typename T>
109 inline auto Power(T x1, T x2) -> std::enable_if_t<std::is_integral<T>::value && std::is_signed<T>::value, T> {
110 if (x2 < 0) {
111 switch (x1) {
112 case -1:
113 return x2 & 1 ? -1 : 1;
114 case 1:
115 return 1;
116 default:
117 return 0;
118 }
119 }
120 return numeric_detail::NonNegativePower(x1, x2);
121 }
122
123 template <typename T>
124 inline auto Power(T x1, T x2) -> std::enable_if_t<std::is_integral<T>::value && std::is_unsigned<T>::value, T> {
125 return numeric_detail::NonNegativePower(x1, x2);
126 }
127
128 template <typename T>
129 inline auto Power(T x1, T x2) -> std::enable_if_t<!std::is_integral<T>::value, T>;
130 template <>
Power(chainerx::Float16 x1,chainerx::Float16 x2)131 inline chainerx::Float16 Power(chainerx::Float16 x1, chainerx::Float16 x2) {
132 return chainerx::Float16{std::pow(static_cast<float>(x1), static_cast<float>(x2))};
133 }
134 template <>
Power(float x1,float x2)135 inline float Power(float x1, float x2) {
136 return std::pow(x1, x2);
137 }
138 template <>
Power(double x1,double x2)139 inline double Power(double x1, double x2) {
140 return std::pow(x1, x2);
141 }
142
// Integral remainder of x1 / x2. The result has the sign of the dividend x1, matching std::fmod; built-in
// `%` truncates toward zero, so it already follows that convention. Floating-point arguments are handled
// by the non-template Fmod overloads, which call std::fmod.
//
// Divisors for which `%` would be undefined behavior are handled explicitly:
//   * x2 == 0 returns 0, instead of triggering a hardware divide-by-zero trap;
//   * x2 == -1 for signed T returns 0, the exact mathematical result. This also avoids the overflow of
//     `min % -1` for int/long operands.
template <typename T>
inline T Fmod(T x1, T x2) {
    if (x2 == T{0}) {
        return T{0};
    }
    if (std::is_signed<T>::value && x2 == static_cast<T>(-1)) {
        return T{0};
    }
    return x1 % x2;
}
147
Fmod(chainerx::Float16 x1,chainerx::Float16 x2)148 inline chainerx::Float16 Fmod(chainerx::Float16 x1, chainerx::Float16 x2) {
149 return chainerx::Float16{std::fmod(static_cast<float>(x1), static_cast<float>(x2))};
150 }
Fmod(float x1,float x2)151 inline float Fmod(float x1, float x2) { return std::fmod(x1, x2); }
Fmod(double x1,double x2)152 inline double Fmod(double x1, double x2) { return std::fmod(x1, x2); }
153
154 #define CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_BINARY(name, func) \
155 template <typename T> \
156 inline T name(T x1, T x2) { \
157 return func(x1, x2); \
158 } \
159 template <> \
160 inline chainerx::Float16 name<chainerx::Float16>(chainerx::Float16 x1, chainerx::Float16 x2) { \
161 return chainerx::Float16{func(static_cast<float>(x1), static_cast<float>(x2))}; \
162 }
163
164 CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_BINARY(Arctan2, std::atan2)
165
166 } // namespace chainerx
167