1 #pragma once
2 #include <cmath>
3 #include <cstdint>
4 #include <type_traits>
5 
6 #include "chainerx/array.h"
7 #include "chainerx/scalar.h"
8 
9 namespace chainerx {
10 
// Declared here and defined out of line.
// NOTE(review): judging only by its name and parameters, this presumably reports whether `a` and `b` are
// element-wise equal within relative tolerance `rtol` and absolute tolerance `atol`, with `equal_nan`
// deciding whether NaNs in matching positions compare equal (numpy.allclose-style). Confirm against the definition.
bool AllClose(const Array& a, const Array& b, double rtol = 1e-5, double atol = 1e-8, bool equal_nan = false);
12 
// Generic fallback: no value of T is ever reported as NaN.
// float, double and chainerx::Float16 have dedicated overloads in this header.
template <typename T>
inline bool IsNan(T /*unused*/) { return false; }
17 
IsNan(chainerx::Float16 value)18 inline bool IsNan(chainerx::Float16 value) { return value.IsNan(); }
IsNan(float value)19 inline bool IsNan(float value) { return std::isnan(value); }
IsNan(double value)20 inline bool IsNan(double value) { return std::isnan(value); }
21 
// Generic fallback: no value of T is ever reported as infinite.
// float, double and chainerx::Float16 have dedicated overloads in this header.
template <typename T>
inline bool IsInf(T /*unused*/) { return false; }
26 
IsInf(chainerx::Float16 value)27 inline bool IsInf(chainerx::Float16 value) { return value.IsInf(); }
IsInf(double value)28 inline bool IsInf(double value) { return std::isinf(value); }
IsInf(float value)29 inline bool IsInf(float value) { return std::isinf(value); }
30 
// Returns -1, 0 or +1 (converted to T) according to the sign of `x`.
// A NaN input is returned unchanged.
template <typename T>
inline T Sign(T x) {
    if (IsNan(x)) {
        return x;
    }
    const int positive = static_cast<int>(T{0} < x);
    const int negative = static_cast<int>(x < T{0});
    return static_cast<T>(positive - negative);
}
35 
36 template <>
37 inline chainerx::Float16 Sign<chainerx::Float16>(chainerx::Float16 x) {
38     return IsNan(x) ? x : Float16{static_cast<int>(Float16{0} < x) - static_cast<int>(x < Float16{0})};
39 }
40 
namespace numeric_detail {

// Signed integers: negate in the corresponding unsigned type so the most negative value wraps to
// itself (as NumPy's abs does) instead of hitting the undefined behavior std::abs has for it.
template <typename T>
inline auto AbsImpl(T x) -> std::enable_if_t<std::is_integral<T>::value && std::is_signed<T>::value, T> {
    using Unsigned = std::make_unsigned_t<T>;
    if (x >= T{0}) {
        return x;
    }
    return static_cast<T>(static_cast<Unsigned>(Unsigned{0} - static_cast<Unsigned>(x)));
}

// Unsigned integers (including bool) are already non-negative. Calling std::abs here would be
// ambiguous for unsigned int and wider unsigned types.
template <typename T>
inline auto AbsImpl(T x) -> std::enable_if_t<std::is_unsigned<T>::value, T> {
    return x;
}

// Floating-point and other non-integral types go straight to std::abs.
template <typename T>
inline auto AbsImpl(T x) -> std::enable_if_t<!std::is_integral<T>::value, T> {
    return std::abs(x);
}

}  // namespace numeric_detail

// Absolute value of `x`, returned as T.
// Dispatches on the kind of T (signed integral, unsigned integral, non-integral).
// The explicit specializations that follow in this header still take precedence for their types.
template <typename T>
inline T Abs(T x) {
    return numeric_detail::AbsImpl(x);
}
45 
46 template <>
47 inline uint8_t Abs<uint8_t>(uint8_t x) {
48     return x;
49 }
50 template <>
51 inline chainerx::Float16 Abs<chainerx::Float16>(chainerx::Float16 x) {
52     return chainerx::Float16{std::abs(static_cast<float>(x))};
53 }
54 
// Defines a unary math function `name` that forwards to `func`:
//  - the generic template returns func(value) converted to T;
//  - the chainerx::Float16 specialization widens the argument to float, calls func in single
//    precision and narrows the result back to Float16.
#define CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(name, func)               \
    template <typename T>                                                       \
    inline T name(T value) {                                                    \
        return func(value);                                                     \
    }                                                                           \
                                                                                \
    template <>                                                                 \
    inline chainerx::Float16 name<chainerx::Float16>(chainerx::Float16 value) { \
        return chainerx::Float16{func(static_cast<float>(value))};              \
    }
64 
// Rounding.
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Ceil, std::ceil)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Floor, std::floor)

// Trigonometric functions and their inverses.
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Sin, std::sin)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Cos, std::cos)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Tan, std::tan)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Arcsin, std::asin)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Arccos, std::acos)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Arctan, std::atan)

// Hyperbolic functions and their inverses.
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Sinh, std::sinh)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Cosh, std::cosh)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Tanh, std::tanh)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Arcsinh, std::asinh)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Arccosh, std::acosh)

// Exponentials and logarithms.
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Exp, std::exp)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Expm1, std::expm1)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Exp2, std::exp2)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Log, std::log)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Log10, std::log10)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Log2, std::log2)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Log1p, std::log1p)

// Miscellaneous.
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Erf, std::erf)
CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_UNARY(Sqrt, std::sqrt)
87 
namespace numeric_detail {

// Computes x1 raised to the power x2 for integral T by binary exponentiation (square-and-multiply),
// using O(log x2) multiplications. Callers must pass x2 >= 0; a negative x2 returns 1 because the
// loop body never runs.
//
// The base is squared only while exponent bits remain. Squaring after the last bit has been consumed
// is wasted work, and it can overflow (undefined behavior for signed T) even when the final result is
// representable, e.g. NonNegativePower<int32_t>(65536, 1). Overflow of the result itself is still not
// guarded against.
template <typename T>
inline T NonNegativePower(T x1, T x2) {
    static_assert(std::is_integral<T>::value, "NonNegativePower is only defined for non-negative integrals.");
    T out{1};

    while (x2 > 0) {
        if (x2 & 1) {
            out *= x1;
        }
        x2 >>= 1;
        if (x2 > 0) {
            x1 *= x1;
        }
    }

    return out;
}

}  // namespace numeric_detail
107 
108 template <typename T>
109 inline auto Power(T x1, T x2) -> std::enable_if_t<std::is_integral<T>::value && std::is_signed<T>::value, T> {
110     if (x2 < 0) {
111         switch (x1) {
112             case -1:
113                 return x2 & 1 ? -1 : 1;
114             case 1:
115                 return 1;
116             default:
117                 return 0;
118         }
119     }
120     return numeric_detail::NonNegativePower(x1, x2);
121 }
122 
123 template <typename T>
124 inline auto Power(T x1, T x2) -> std::enable_if_t<std::is_integral<T>::value && std::is_unsigned<T>::value, T> {
125     return numeric_detail::NonNegativePower(x1, x2);
126 }
127 
128 template <typename T>
129 inline auto Power(T x1, T x2) -> std::enable_if_t<!std::is_integral<T>::value, T>;
130 template <>
Power(chainerx::Float16 x1,chainerx::Float16 x2)131 inline chainerx::Float16 Power(chainerx::Float16 x1, chainerx::Float16 x2) {
132     return chainerx::Float16{std::pow(static_cast<float>(x1), static_cast<float>(x2))};
133 }
134 template <>
Power(float x1,float x2)135 inline float Power(float x1, float x2) {
136     return std::pow(x1, x2);
137 }
138 template <>
Power(double x1,double x2)139 inline double Power(double x1, double x2) {
140     return std::pow(x1, x2);
141 }
142 
// Integral remainder with C++ truncation semantics: a non-zero result takes the sign of the dividend
// `x1`, matching what std::fmod does for floating-point values. Floating-point types use the
// non-template overloads in this header.
//
// Two inputs for which the built-in `%` is undefined behavior are handled explicitly:
//  - a zero divisor returns 0 instead of trapping (NumPy's integer fmod likewise yields 0 there);
//  - a divisor of -1 for signed T returns 0 (the exact result), avoiding the overflow of e.g. INT_MIN % -1.
template <typename T>
inline T Fmod(T x1, T x2) {
    if (x2 == T{0}) {
        return T{0};
    }
    if (std::is_signed<T>::value && x2 == static_cast<T>(-1)) {
        return T{0};
    }
    return x1 % x2;
}
147 
Fmod(chainerx::Float16 x1,chainerx::Float16 x2)148 inline chainerx::Float16 Fmod(chainerx::Float16 x1, chainerx::Float16 x2) {
149     return chainerx::Float16{std::fmod(static_cast<float>(x1), static_cast<float>(x2))};
150 }
Fmod(float x1,float x2)151 inline float Fmod(float x1, float x2) { return std::fmod(x1, x2); }
Fmod(double x1,double x2)152 inline double Fmod(double x1, double x2) { return std::fmod(x1, x2); }
153 
154 #define CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_BINARY(name, func)                                 \
155     template <typename T>                                                                          \
156     inline T name(T x1, T x2) {                                                                    \
157         return func(x1, x2);                                                                       \
158     }                                                                                              \
159     template <>                                                                                    \
160     inline chainerx::Float16 name<chainerx::Float16>(chainerx::Float16 x1, chainerx::Float16 x2) { \
161         return chainerx::Float16{func(static_cast<float>(x1), static_cast<float>(x2))};            \
162     }
163 
164 CHAINERX_DEFINE_NATIVE_FLOAT16_FALLBACK_BINARY(Arctan2, std::atan2)
165 
166 }  // namespace chainerx
167