1 #include "chainerx/float16.h" 2 3 #include <cstdint> 4 5 namespace chainerx { 6 namespace { 7 8 union UnionFloatUint { 9 public: UnionFloatUint(float v)10 explicit UnionFloatUint(float v) : f{v} {} UnionFloatUint(uint32_t v)11 explicit UnionFloatUint(uint32_t v) : i{v} {} 12 float f; 13 uint32_t i; 14 }; 15 16 union UnionDoubleUint { 17 public: UnionDoubleUint(double v)18 explicit UnionDoubleUint(double v) : f{v} {} UnionDoubleUint(uint64_t v)19 explicit UnionDoubleUint(uint64_t v) : i{v} {} 20 double f; 21 uint64_t i; 22 }; 23 24 // Borrowed from npy_floatbits_to_halfbits 25 // 26 // See LICENSE.txt of ChainerX. FloatbitsToHalfbits(uint32_t f)27uint16_t FloatbitsToHalfbits(uint32_t f) { 28 uint16_t h_sgn = static_cast<uint16_t>((f & 0x80000000U) >> 16); 29 uint32_t f_exp = (f & 0x7f800000U); 30 31 // Exponent overflow/NaN converts to signed inf/NaN 32 if (f_exp >= 0x47800000U) { 33 if (f_exp != 0x7f800000U) { 34 // Overflow to signed inf 35 return h_sgn + 0x7c00U; 36 } 37 38 uint32_t f_sig = (f & 0x007fffffU); 39 if (f_sig == 0) { 40 // Signed inf 41 return h_sgn + 0x7c00U; 42 } 43 44 // NaN - propagate the flag in the significand... 45 uint16_t ret = static_cast<uint16_t>(0x7c00U + (f_sig >> 13)); 46 47 // ...but make sure it stays a NaN 48 if (ret == 0x7c00U) { 49 ++ret; 50 } 51 return h_sgn + ret; 52 } 53 54 // Exponent underflow converts to a subnormal half or signed zero 55 if (f_exp <= 0x38000000U) { 56 if (f_exp < 0x33000000U) { 57 // Signed zero 58 return h_sgn; 59 } 60 61 // Make the subnormal significand 62 f_exp >>= 23; 63 uint32_t f_sig = (0x00800000U + (f & 0x007fffffU)) >> (113 - f_exp); 64 65 // Handle rounding by adding 1 to the bit beyond half precision 66 if (((f_sig & 0x00003fffU) != 0x00001000U) || ((f & 0x000007ffU) > 0)) { 67 f_sig += 0x00001000U; 68 } 69 uint16_t h_sig = static_cast<uint16_t>(f_sig >> 13); 70 71 // If the rounding causes a bit to spill into h_exp, it will increment h_exp from zero to one and h_sig will be zero. This is the 72 // correct result. 73 return h_sgn + h_sig; 74 } 75 76 // Regular case with no overflow or underflow 77 uint16_t h_exp = static_cast<uint16_t>((f_exp - 0x38000000U) >> 13); 78 79 // Handle rounding by adding 1 to the bit beyond half precision 80 uint32_t f_sig = f & 0x007fffffU; 81 if ((f_sig & 0x00003fffU) != 0x00001000U) { 82 f_sig += 0x00001000U; 83 } 84 uint16_t h_sig = static_cast<uint16_t>(f_sig >> 13); 85 86 // If the rounding causes a bit to spill into h_exp, it will increment h_exp by one and h_sig will be zero. This is the correct result. 87 // h_exp may increment to 15, at greatest, in which case the result overflows to a signed inf. 88 return h_sgn + h_exp + h_sig; 89 } 90 91 // Borrowed from npy_doublebits_to_halfbits 92 // 93 // See LICENSE.txt of ChainerX. DoublebitsToHalfbits(uint64_t d)94uint16_t DoublebitsToHalfbits(uint64_t d) { 95 uint16_t h_sgn = (d & 0x8000000000000000ULL) >> 48; 96 uint64_t d_exp = (d & 0x7ff0000000000000ULL); 97 98 // Exponent overflow/NaN converts to signed inf/NaN 99 if (d_exp >= 0x40f0000000000000ULL) { 100 if (d_exp != 0x7ff0000000000000ULL) { 101 // Overflow to signed inf 102 return h_sgn + 0x7c00U; 103 } 104 105 uint64_t d_sig = (d & 0x000fffffffffffffULL); 106 if (d_sig == 0) { 107 // Signed inf 108 return h_sgn + 0x7c00U; 109 } 110 111 // NaN - propagate the flag in the significand... 112 uint16_t ret = static_cast<uint16_t>(0x7c00U + (d_sig >> 42)); 113 114 // ...but make sure it stays a NaN 115 if (ret == 0x7c00U) { 116 ++ret; 117 } 118 return h_sgn + ret; 119 } 120 121 // Exponent underflow converts to subnormal half or signed zero 122 if (d_exp <= 0x3f00000000000000ULL) { 123 if (d_exp < 0x3e60000000000000ULL) { 124 // Signed zero 125 return h_sgn; 126 } 127 128 // Make the subnormal significand 129 d_exp >>= 52; 130 uint64_t d_sig = (0x0010000000000000ULL + (d & 0x000fffffffffffffULL)); 131 d_sig <<= (d_exp - 998); 132 // Handle rounding by adding 1 to the bit beyond half precision 133 if ((d_sig & 0x003fffffffffffffULL) != 0x0010000000000000ULL) { 134 d_sig += 0x0010000000000000ULL; 135 } 136 uint16_t h_sig = static_cast<uint16_t>(d_sig >> 53); 137 138 // If the rounding causes a bit to spill into h_exp, it will increment h_exp from zero to one and h_sig will be zero. This is the 139 // correct result. 140 return h_sgn + h_sig; 141 } 142 143 // Regular case with no overflow or underflow 144 uint16_t h_exp = static_cast<uint16_t>((d_exp - 0x3f00000000000000ULL) >> 42); 145 146 // Handle rounding by adding 1 to the bit beyond half precision 147 uint64_t d_sig = d & 0x000fffffffffffffULL; 148 if ((d_sig & 0x000007ffffffffffULL) != 0x0000020000000000ULL) { 149 d_sig += 0x0000020000000000ULL; 150 } 151 uint16_t h_sig = static_cast<uint16_t>(d_sig >> 42); 152 153 // If the rounding causes a bit to spill into h_exp, it will increment h_exp by one and h_sig will be zero. This is the correct result. 154 // h_exp may increment to 15, at greatest, in which case the result overflows to a signed inf. 155 return h_sgn + h_exp + h_sig; 156 } 157 158 // Borrowed from npy_halfbits_to_floatbits 159 // 160 // See LICENSE.txt of ChainerX. HalfbitsToFloatbits(uint16_t h)161uint32_t HalfbitsToFloatbits(uint16_t h) { 162 uint16_t h_exp = (h & 0x7c00U); 163 uint32_t f_sgn = (static_cast<uint32_t>(h) & 0x8000U) << 16; 164 switch (h_exp) { 165 case 0x0000U: { // 0 or subnormal 166 uint16_t h_sig = (h & 0x03ffU); 167 168 // Signed zero 169 if (h_sig == 0) { 170 return f_sgn; 171 } 172 173 // Subnormal 174 h_sig <<= 1; 175 while ((h_sig & 0x0400U) == 0) { 176 h_sig <<= 1; 177 ++h_exp; 178 } 179 180 uint32_t f_exp = (static_cast<uint32_t>(127 - 15 - h_exp)) << 23; 181 uint32_t f_sig = (static_cast<uint32_t>(h_sig & 0x03ffU)) << 13; 182 return f_sgn + f_exp + f_sig; 183 } 184 case 0x7c00U: { // inf or NaN 185 // All-ones exponent and a copy of the significand 186 return f_sgn + 0x7f800000U + ((static_cast<uint32_t>(h & 0x03ffU)) << 13); 187 } 188 default: { // normalized 189 // Just need to adjust the exponent and shift 190 return f_sgn + ((static_cast<uint32_t>(h & 0x7fffU) + 0x1c000U) << 13); 191 } 192 } 193 } 194 195 // Borrowed from npy_halfbits_to_doublebits 196 // 197 // See LICENSE.txt of ChainerX. HalfbitsToDoublebits(uint16_t h)198uint64_t HalfbitsToDoublebits(uint16_t h) { 199 uint16_t h_exp = (h & 0x7c00U); 200 uint64_t d_sgn = (static_cast<uint64_t>(h) & 0x8000U) << 48; 201 switch (h_exp) { 202 case 0x0000U: { // 0 or subnormal 203 uint16_t h_sig = (h & 0x03ffU); 204 205 // Signed zero 206 if (h_sig == 0) { 207 return d_sgn; 208 } 209 210 // Subnormal 211 h_sig <<= 1; 212 while ((h_sig & 0x0400U) == 0) { 213 h_sig <<= 1; 214 ++h_exp; 215 } 216 217 uint64_t d_exp = (static_cast<uint64_t>(1023 - 15 - h_exp)) << 52; 218 uint64_t d_sig = (static_cast<uint64_t>(h_sig & 0x03ffU)) << 42; 219 return d_sgn + d_exp + d_sig; 220 } 221 case 0x7c00U: { // inf or NaN 222 // All-ones exponent and a copy of the significand 223 return d_sgn + 0x7ff0000000000000ULL + ((static_cast<uint64_t>(h & 0x03ffU)) << 42); 224 } 225 default: { // normalized 226 // Just need to adjust the exponent and shift 227 return d_sgn + ((static_cast<uint64_t>(h & 0x7fffU) + 0xfc000U) << 42); 228 } 229 } 230 } 231 FloatToHalf(float v)232uint16_t FloatToHalf(float v) { 233 return FloatbitsToHalfbits(UnionFloatUint(v).i); // NOLINT(cppcoreguidelines-pro-type-union-access) 234 } DoubleToHalf(double v)235uint16_t DoubleToHalf(double v) { 236 return DoublebitsToHalfbits(UnionDoubleUint(v).i); // NOLINT(cppcoreguidelines-pro-type-union-access) 237 } HalfToFloat(uint16_t v)238float HalfToFloat(uint16_t v) { 239 return UnionFloatUint(HalfbitsToFloatbits(v)).f; // NOLINT(cppcoreguidelines-pro-type-union-access) 240 } HalfToDouble(uint16_t v)241double HalfToDouble(uint16_t v) { 242 return UnionDoubleUint(HalfbitsToDoublebits(v)).f; // NOLINT(cppcoreguidelines-pro-type-union-access) 243 } 244 245 } // namespace 246 Float16(float v)247Float16::Float16(float v) : data_{FloatToHalf(v)} {} Float16(double v)248Float16::Float16(double v) : data_{DoubleToHalf(v)} {} 249 operator float() const250Float16::operator float() const { return HalfToFloat(data_); } operator double() const251Float16::operator double() const { return HalfToDouble(data_); } 252 253 } // namespace chainerx 254