/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file half.h
 * \brief definition of half (float16) type.
 *
 * \author Junyuan Xie
 */
#ifndef MSHADOW_HALF_H_
#define MSHADOW_HALF_H_
#include "./base.h"

// x86 F16C intrinsics (_cvtss_sh / _cvtsh_ss) provide hardware fp32<->fp16 conversion.
#if MSHADOW_USE_F16C
#include <x86intrin.h>
#endif  // MSHADOW_USE_F16C

// This flag dictates rounding for the float2half() routine only (used generally on Windows),
// not the f16c lib or cuda v7.5 (or later) behavior which is fixed at round-to-nearest-even.
#ifndef MSHADOW_HALF_ROUND_TO_NEAREST
#define MSHADOW_HALF_ROUND_TO_NEAREST 1
#endif

// cuda_fp16.h (the __half type plus conversion intrinsics) first shipped with CUDA 7.5.
#if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
#define MSHADOW_CUDA_HALF 1
#include <cuda_fp16.h>
#if defined(__CUDA_ARCH__)
/*!
 * \brief Convert a volatile __half to float.
 *
 * __half2float() takes its argument by value, so a volatile __half must first
 * be copied into a non-volatile temporary; this wrapper performs that copy.
 * NOTE(review): defined (not just declared) in a header without an explicit
 * inline keyword — presumably MSHADOW relies on single inclusion per device
 * TU; confirm no multiple-definition issue with relocatable device code.
 * \param h the possibly volatile-qualified half value to read
 * \return the value widened to float
 */
__host__ __device__ float __half2float_warp(const volatile __half& h) { /* NOLINT(*) */
  __half val;
#if CUDA_VERSION >= 9000
  // CUDA 9.0 made __half's storage private; casting away the cv-qualifiers
  // lets the implicit copy-assignment operator perform the copy.
  val = const_cast<__half&>(h);
#else
  // Pre-9.0 __half is a plain struct with a public 16-bit member 'x'.
  val.x = h.x;
#endif
  return __half2float(val);
}
#endif
#else
#define MSHADOW_CUDA_HALF 0
#endif

/*!
 \brief namespace for mshadow */
namespace mshadow {
/* \brief name space for host/device portable half-precision floats */
namespace half {
// Generates the three overloads of binary operator OP for half_t (half OP half,
// half OP T, T OP half): both operands are widened to float, OP is applied in
// fp32, and the result is narrowed to RTYPE (half_t for arithmetic, bool for
// comparisons).
#define MSHADOW_HALF_OPERATOR(RTYPE, OP) \
  MSHADOW_XINLINE RTYPE operator OP (half_t a, half_t b) { \
    return RTYPE(float(a) OP float(b));  /* NOLINT(*) */ \
  } \
  template<typename T> \
  MSHADOW_XINLINE RTYPE operator OP (half_t a, T b) { \
    return RTYPE(float(a) OP float(b));  /* NOLINT(*) */ \
  } \
  template<typename T> \
  MSHADOW_XINLINE RTYPE operator OP (T a, half_t b) { \
    return RTYPE(float(a) OP float(b));  /* NOLINT(*) */ \
  }

// Generates compound assignment AOP (e.g. +=) in terms of the plain operator
// OP, computed in fp32 and stored back as half; the second overload mirrors
// the first for volatile objects.
#define MSHADOW_HALF_ASSIGNOP(AOP, OP) \
  template<typename T> \
  MSHADOW_XINLINE half_t operator AOP (const T& a) { \
    return *this = half_t(float(*this) OP float(a));  /* NOLINT(*)*/ \
  } \
  template<typename T> \
  MSHADOW_XINLINE half_t operator AOP (const volatile T& a) volatile { \
    return *this = half_t(float(*this) OP float(a));  /* NOLINT(*)*/ \
  }

// Conversion operator to a wider type T, selected at compile time:
// CUDA device code uses the __half2float hardware intrinsic, x86 hosts with
// F16C use _cvtsh_ss, and everything else falls back to the portable
// bit-twiddling half2float() defined below.
#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
#define MSHADOW_HALF_CONVERSIONOP(T) \
  MSHADOW_XINLINE operator T() const { \
    return T(__half2float(cuhalf_));  /* NOLINT(*)*/ \
  } \
  MSHADOW_XINLINE operator T() const volatile { \
    return T(__half2float_warp(cuhalf_));  /* NOLINT(*)*/ \
  }
#elif(MSHADOW_USE_F16C)
#define MSHADOW_HALF_CONVERSIONOP(T) \
  MSHADOW_XINLINE operator T() const { \
    return T(_cvtsh_ss(half_));  /* NOLINT(*)*/ \
  } \
  MSHADOW_XINLINE operator T() const volatile { \
    return T(_cvtsh_ss(half_));  /* NOLINT(*)*/ \
  }
#else
#define MSHADOW_HALF_CONVERSIONOP(T) \
  MSHADOW_XINLINE operator T() const { \
    return T(half2float(half_));  /* NOLINT(*)*/ \
  } \
  MSHADOW_XINLINE operator T() const volatile { \
    return T(half2float(half_));  /* NOLINT(*)*/ \
  }
#endif  // (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))

/*!
 * \brief Host/device portable half-precision (IEEE 754 binary16) float type.
 *
 * Storage is a single 16-bit word, 2-byte aligned, so arrays of half_t are
 * layout-compatible with arrays of uint16_t (or CUDA __half on CUDA builds).
 * All arithmetic converts to fp32, operates there, and narrows back to fp16.
 */
class MSHADOW_ALIGNED(2) half_t {
 public:
  // Raw 16-bit encoding; on CUDA builds the same bits are also viewable as a
  // __half so device intrinsics can operate on the value without copies.
  union {
    uint16_t half_;
#if MSHADOW_CUDA_HALF
    __half cuhalf_;
#endif  // MSHADOW_CUDA_HALF
  };

  /*! \brief Build a half_t directly from its 16-bit encoding (no conversion). */
  static MSHADOW_XINLINE half_t Binary(uint16_t value) {
    half_t res;
    res.half_ = value;
    return res;
  }

  // Default constructor intentionally leaves the value uninitialized, matching
  // the behavior of the built-in floating-point types.
  MSHADOW_XINLINE half_t() {}

#if MSHADOW_CUDA_HALF
  /*! \brief Wrap an existing CUDA __half value (bitwise copy, no conversion). */
  MSHADOW_XINLINE explicit half_t(const __half& value) {
    cuhalf_ = value;
  }
#endif  // MSHADOW_CUDA_HALF

  // Converting constructors. Only float converts implicitly; every other
  // numeric type requires an explicit cast. All funnel through constructor(),
  // which picks the fastest available fp32 -> fp16 conversion.
  MSHADOW_XINLINE half_t(const float& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const double& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const int8_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const uint8_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const int32_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const uint32_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const int64_t& value) { constructor(value); }
  MSHADOW_XINLINE explicit half_t(const uint64_t& value) { constructor(value); }

  MSHADOW_HALF_CONVERSIONOP(float)

  MSHADOW_HALF_ASSIGNOP(+=, +)
  MSHADOW_HALF_ASSIGNOP(-=, -)
  MSHADOW_HALF_ASSIGNOP(*=, *)
  MSHADOW_HALF_ASSIGNOP(/=, /)

  /*! \brief unary plus: identity */
  MSHADOW_XINLINE half_t operator+() {
    return *this;
  }

  /*! \brief unary minus: negate by round-tripping through fp32 */
  MSHADOW_XINLINE half_t operator-() {
    return half_t(-float(*this));  // NOLINT(*)
  }

  // Assignment returns by value (not by reference) so the volatile overloads
  // below can share the same return type without returning a volatile ref.
  MSHADOW_XINLINE half_t operator=(const half_t& a) {
    half_ = a.half_;
    return a;
  }

  template<typename T>
  MSHADOW_XINLINE half_t operator=(const T& a) {
    return *this = half_t(a);  /* NOLINT(*)*/
  }

  MSHADOW_XINLINE half_t operator=(const half_t& a) volatile {
    half_ = a.half_;
    return a;
  }

  template<typename T>
  MSHADOW_XINLINE half_t operator=(const T& a) volatile {
    return *this = half_t(a);  /* NOLINT(*)*/
  }

 private:
  // Reinterprets float bits as (un)signed integers and back. Type punning via
  // a union is not strictly conforming C++, but is the long-standing idiom
  // supported by the compilers this header targets.
  union Bits {
    float f;
    int32_t si;
    uint32_t ui;
  };

  static int const fp16FractionBits = 10;
  static int const fp32FractionBits = 23;
  static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits);   // == 0x7fffff
  static int32_t const fp32HiddenBit = 1 << fp32FractionBits;   // == 0x800000
  static int const shift = fp32FractionBits - fp16FractionBits;   // == 13
  static int const shiftSign = 16;
  static int32_t const expAdjust = 127 - 15;  // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)

  static int32_t const infN = 0x7F800000;   // flt32 infinity
  static int32_t const maxN = 0x477FFFFF;   // max flt32 that's a flt16 normal after >> by shift
  static int32_t const minN = 0x38800000;   // min flt16 normal as a flt32
  static int32_t const maxZ = 0x33000000;   // max fp32 number that's still rounded to zero in fp16
  // NOTE(review): 0x80000000 does not fit in int32_t; the conversion is
  // implementation-defined before C++20 (all targeted compilers wrap to INT32_MIN).
  static int32_t const signN = 0x80000000;  // flt32 sign bit

  static int32_t const infC = infN >> shift;
  static int32_t const nanN = (infC + 1) << shift;   // minimum flt16 nan as a flt32
  static int32_t const maxC = maxN >> shift;
  static int32_t const minC = minN >> shift;
  static int32_t const signC = signN >> shiftSign;   // flt16 sign bit

  // NOTE(review): mulN is not referenced anywhere in this header; possibly
  // retained for symmetry with mulC — confirm before removing.
  static int32_t const mulN = 0x52000000;  // (1 << 23) / minN
  static int32_t const mulC = 0x33800000;  // minN / (1 << (23 - shift))

  static int32_t const subC = 0x003FF;  // max flt32 subnormal down shifted
  static int32_t const norC = 0x00400;  // min flt32 normal down shifted

  static int32_t const maxD = infC - maxC - 1;
  static int32_t const minD = minC - subC - 1;

  /*!
   * \brief Software fp32 -> fp16 conversion, used when neither the CUDA
   * intrinsic nor the F16C instruction is available.
   *
   * Explicitly handles four ranges of |x|: values that round to zero,
   * fp16 denormals, fp16 normals, and overflow to inf / NaN. Rounding is
   * round-to-nearest-even when MSHADOW_HALF_ROUND_TO_NEAREST is 1,
   * otherwise truncation.
   */
  MSHADOW_XINLINE uint16_t float2half(const float& value) const {
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;    // grab sign bit
    v.si ^= sign;                    // clear sign bit from v
    sign >>= shiftSign;              // logical shift sign to fp16 position

    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significand should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
      // The only time it's *not* OK to add 0x1000 (i.e. half the flt16 fraction lsb) is
      // when the lsb of the flt16 fraction == 0 (so not rounding up to even) and the additional
      // bits to the right of the lsb are 1000... (including flt32 significand bits
      // that may be lost during the above vshift). The first term below will always
      // be true for vshift >=12 (since even the 'hidden bit' has been shifted to the
      // right of the '1' bit in 0x1000). And when vshift <= 11, both terms combine to make
      // the proper test of the flt32 significand bits, including those lost during the vshift.
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
      // Rounding may increase the exponent to 1, but that's OK.
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
#endif
    } else if (v.si <= maxN) {
      // Handle norms
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
      // Rounding may increase the exponent, possibly creating an inf, but that's OK.
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
#endif
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
    } else if (v.si < nanN) {
      v.si = nanN;
    }

    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  // Same as above routine, except for addition of volatile keyword
  MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile {  // NOLINT (*)
    Bits v;
    v.f = value;
    uint32_t sign = v.si & signN;    // grab sign bit
    v.si ^= sign;                    // clear sign bit from v
    sign >>= shiftSign;              // logical shift sign to fp16 position

    if (v.si <= maxZ) {
      // Handle eventual zeros here to ensure vshift will not exceed 32 below.
      v.ui = 0;
    } else if (v.si < minN) {
      // Handle denorms
      uint32_t exp32 = v.ui >> fp32FractionBits;
      int32_t exp16 = exp32 - expAdjust;
      // If exp16 == 0 (just into the denorm range), then significand should be shifted right 1.
      // Smaller (so negative) exp16 values should result in greater right shifts.
      uint32_t vshift = 1 - exp16;
      uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
      v.ui = significand >> vshift;
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
      // Rounding may increase the exponent to 1, but that's OK.
      v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
#endif
    } else if (v.si <= maxN) {
      // Handle norms
#if MSHADOW_HALF_ROUND_TO_NEAREST == 1
      // Rounding may increase the exponent, possibly creating an inf, but that's OK.
      v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
#endif
      v.ui -= expAdjust << fp32FractionBits;
    } else if (v.si <= infN) {
      v.si = infN;
    } else if (v.si < nanN) {
      v.si = nanN;
    }

    v.ui >>= shift;
    return sign | (v.ui & 0x7fff);
  }

  /*!
   * \brief Software fp16 -> fp32 conversion.
   *
   * Branch-free: the xor/mask expressions select between the denormal path
   * (rescaled via the mulC multiply) and the normal path (exponent rebias via
   * the minD/maxD offsets) without conditional jumps.
   */
  MSHADOW_XINLINE float half2float(const uint16_t& value) const {
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;   // extract fp16 sign bit
    v.si ^= sign;                  // clear it
    sign <<= shiftSign;            // move sign to the fp32 position
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);  // rebias if above subnormal range
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);  // extra bias for inf/NaN encodings
    Bits s;
    s.si = mulC;
    s.f *= v.si;                   // denormal value rescaled by float multiply
    int32_t mask = -(norC > v.si);  // all-ones when the input was subnormal
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;  // pick denormal result where mask is set
    v.si |= sign;                  // reattach sign
    return v.f;
  }

  // Same as above routine, except for addition of volatile keyword
  MSHADOW_XINLINE float half2float(const volatile uint16_t& value) const volatile {  // NOLINT(*)
    Bits v;
    v.ui = value;
    int32_t sign = v.si & signC;
    v.si ^= sign;
    sign <<= shiftSign;
    v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
    v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
    Bits s;
    s.si = mulC;
    s.f *= v.si;
    int32_t mask = -(norC > v.si);
    v.si <<= shift;
    v.si ^= (s.si ^ v.si) & mask;
    v.si |= sign;
    return v.f;
  }

  // Common constructor body: widen the input to fp32, then narrow to fp16 with
  // the fastest available conversion (CUDA intrinsic, F16C, or software).
  template<typename T>
  MSHADOW_XINLINE void constructor(const T& value) {
#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
    cuhalf_ = __float2half(float(value));  // NOLINT(*)
#elif(MSHADOW_USE_F16C)
    // Second argument 0 selects round-to-nearest-even.
    half_ = _cvtss_sh(static_cast<float>(value), 0);
#else /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */
    half_ = float2half(float(value));  // NOLINT(*)
#endif /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */
  }
};

/*! \brief overloaded + operator for half_t */
MSHADOW_HALF_OPERATOR(half_t, +)
/*! \brief overloaded - operator for half_t */
MSHADOW_HALF_OPERATOR(half_t, -)
/*!
 \brief overloaded * operator for half_t */
MSHADOW_HALF_OPERATOR(half_t, *)
/*! \brief overloaded / operator for half_t */
MSHADOW_HALF_OPERATOR(half_t, /)
/*! \brief overloaded > operator for half_t */
MSHADOW_HALF_OPERATOR(bool, >)
/*! \brief overloaded < operator for half_t */
MSHADOW_HALF_OPERATOR(bool, <)
/*! \brief overloaded >= operator for half_t */
MSHADOW_HALF_OPERATOR(bool, >=)
/*! \brief overloaded <= operator for half_t */
MSHADOW_HALF_OPERATOR(bool, <=)

// Extreme finite fp16 values by bit pattern: 0xFBFF is the most negative
// finite half, 0x7BFF the most positive.
// NOTE(review): the trailing ';' baked into MSHADOW_HALF_MIN/MAX means each
// expands to "expr;", which works in statement position but breaks use inside
// a larger expression; confirm no caller relies on the embedded semicolon
// before removing it.
#define MSHADOW_HALF_MIN mshadow::half::half_t::Binary(0xFBFF);
#define MSHADOW_HALF_MAX mshadow::half::half_t::Binary(0x7BFF);
// Bit masks over the 16-bit encoding: sign bit and the 5 exponent bits.
#define MSHADOW_HALF_SIGN_BIT 0x8000
#define MSHADOW_HALF_EXPONENT_BITS 0x7c00
}  // namespace half
}  // namespace mshadow
#endif  // MSHADOW_HALF_H_