1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file half.h
22  * \brief definition of half (float16) type.
23  *
24  * \author Junyuan Xie
25  */
26 #ifndef MSHADOW_HALF_H_
27 #define MSHADOW_HALF_H_
28 #include "./base.h"
29 
30 #if MSHADOW_USE_F16C
31   #include <x86intrin.h>
32 #endif  // MSHADOW_USE_F16C
33 
// MSHADOW_HALF_ROUND_TO_NEAREST selects the rounding of the portable software
// float2half() routine only (the fallback used when neither CUDA device code nor
// F16C is available, e.g. generally on Windows builds).
//   1 (default): round-to-nearest-even.
//   any other value: the rounding increment is skipped and excess fraction bits
//                    are simply truncated.
// The F16C and CUDA >= 7.5 paths always round to nearest even, whatever this says.
#if !defined(MSHADOW_HALF_ROUND_TO_NEAREST)
#define MSHADOW_HALF_ROUND_TO_NEAREST 1
#endif  // !defined(MSHADOW_HALF_ROUND_TO_NEAREST)
39 
#if (MSHADOW_USE_CUDA && CUDA_VERSION >= 7050)
  #define MSHADOW_CUDA_HALF 1
  #include <cuda_fp16.h>
  #if defined(__CUDA_ARCH__)
    /*!
     * \brief convert a volatile-qualified __half to float.
     *
     * __half2float() only accepts a non-volatile __half, so the value is first
     * copied into a local. The "warp" name suggests use on volatile shared memory
     * in warp-synchronous code; confirm at call sites.
     *
     * Declared inline because it is a non-template function defined in a header:
     * without inline, every translation unit that includes this file emits its own
     * definition, which collides when device code is linked with -rdc.
     *
     * \param h volatile half value to read
     * \return h widened to float
     */
    inline __host__ __device__ float __half2float_warp(const volatile __half& h) { /* NOLINT(*) */
      __half val;
#if CUDA_VERSION >= 9000
      // CUDA 9+ no longer exposes __half's raw storage as a public member, so the
      // whole object is copied. NOTE(review): the const_cast also discards
      // volatile, so this is an ordinary (non-volatile) read.
      val = const_cast<__half&>(h);
#else
      val.x = h.x;
#endif
      return __half2float(val);
    }
  #endif
#else
  #define MSHADOW_CUDA_HALF 0
#endif
58 
59 /*! \brief namespace for mshadow */
60 namespace mshadow {
/*! \brief namespace for host/device portable half-precision floats */
62 namespace half {
/*!
 * \brief define binary operator OP for half_t, yielding RTYPE.
 *
 * Both operands are widened to float, OP is evaluated in float, and the float
 * result is converted to RTYPE. Three overloads are generated: half_t OP half_t,
 * half_t OP T and T OP half_t, where T is any type convertible to float.
 */
#define MSHADOW_HALF_OPERATOR(RTYPE, OP)                                  \
  MSHADOW_XINLINE RTYPE operator OP (half_t lhs, half_t rhs) {            \
    const float l = static_cast<float>(lhs);                              \
    const float r = static_cast<float>(rhs);                              \
    return static_cast<RTYPE>(l OP r);                                    \
  }                                                                       \
  template<typename T>                                                    \
  MSHADOW_XINLINE RTYPE operator OP (half_t lhs, T rhs) {                 \
    const float l = static_cast<float>(lhs);                              \
    const float r = static_cast<float>(rhs);                              \
    return static_cast<RTYPE>(l OP r);                                    \
  }                                                                       \
  template<typename T>                                                    \
  MSHADOW_XINLINE RTYPE operator OP (T lhs, half_t rhs) {                 \
    const float l = static_cast<float>(lhs);                              \
    const float r = static_cast<float>(rhs);                              \
    return static_cast<RTYPE>(l OP r);                                    \
  }
75 
/*!
 * \brief define compound assignment AOP (e.g. +=) as a half_t member.
 *
 * The current value and the right-hand side are widened to float and combined
 * with OP. The result is rounded back to half_t, stored, and returned by value.
 * A volatile-qualified overload is generated as well.
 */
#define MSHADOW_HALF_ASSIGNOP(AOP, OP)                                    \
  template<typename T>                                                    \
  MSHADOW_XINLINE half_t operator AOP (const T& rhs) {                    \
    const float lhs_f = static_cast<float>(*this);                        \
    const float rhs_f = static_cast<float>(rhs);                          \
    return *this = half_t(lhs_f OP rhs_f);                                \
  }                                                                       \
  template<typename T>                                                    \
  MSHADOW_XINLINE half_t operator AOP (const volatile T& rhs) volatile {  \
    const float lhs_f = static_cast<float>(*this);                        \
    const float rhs_f = static_cast<float>(rhs);                          \
    return *this = half_t(lhs_f OP rhs_f);                                \
  }
85 
/*!
 * \brief define conversion operators from half_t to T (plain and volatile).
 *
 * The backend is picked at preprocessing time:
 *  - CUDA device code: CUDA __half intrinsics on the cuhalf_ view of the storage;
 *  - F16C: the _cvtsh_ss intrinsic on the raw bits;
 *  - otherwise: the portable software half2float() routine.
 */
#if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
#define MSHADOW_HALF_CONVERSIONOP(T)                                      \
  MSHADOW_XINLINE operator T() const {                                    \
    const float widened = __half2float(cuhalf_);                          \
    return static_cast<T>(widened);                                       \
  }                                                                       \
  MSHADOW_XINLINE operator T() const volatile {                           \
    const float widened = __half2float_warp(cuhalf_);                     \
    return static_cast<T>(widened);                                       \
  }
#elif (MSHADOW_USE_F16C)
#define MSHADOW_HALF_CONVERSIONOP(T)                                      \
  MSHADOW_XINLINE operator T() const {                                    \
    const float widened = _cvtsh_ss(half_);                               \
    return static_cast<T>(widened);                                       \
  }                                                                       \
  MSHADOW_XINLINE operator T() const volatile {                           \
    const float widened = _cvtsh_ss(half_);                               \
    return static_cast<T>(widened);                                       \
  }
#else
#define MSHADOW_HALF_CONVERSIONOP(T)                                      \
  MSHADOW_XINLINE operator T() const {                                    \
    const float widened = half2float(half_);                              \
    return static_cast<T>(widened);                                       \
  }                                                                       \
  MSHADOW_XINLINE operator T() const volatile {                           \
    const float widened = half2float(half_);                              \
    return static_cast<T>(widened);                                       \
  }
#endif  // (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
111 
112 class MSHADOW_ALIGNED(2) half_t {
113  public:
114   union {
115     uint16_t half_;
116 #if MSHADOW_CUDA_HALF
117     __half cuhalf_;
118 #endif  // MSHADOW_CUDA_HALF
119   };
120 
Binary(uint16_t value)121   static MSHADOW_XINLINE half_t Binary(uint16_t value) {
122     half_t res;
123     res.half_ = value;
124     return res;
125   }
126 
half_t()127   MSHADOW_XINLINE half_t() {}
128 
129 #if MSHADOW_CUDA_HALF
half_t(const __half & value)130   MSHADOW_XINLINE explicit half_t(const __half& value) {
131     cuhalf_ = value;
132   }
133 #endif  // MSHADOW_CUDA_HALF
134 
half_t(const float & value)135   MSHADOW_XINLINE half_t(const float& value) { constructor(value); }
half_t(const double & value)136   MSHADOW_XINLINE explicit half_t(const double& value) { constructor(value); }
half_t(const int8_t & value)137   MSHADOW_XINLINE explicit half_t(const int8_t& value) { constructor(value); }
half_t(const uint8_t & value)138   MSHADOW_XINLINE explicit half_t(const uint8_t& value) { constructor(value); }
half_t(const int32_t & value)139   MSHADOW_XINLINE explicit half_t(const int32_t& value) { constructor(value); }
half_t(const uint32_t & value)140   MSHADOW_XINLINE explicit half_t(const uint32_t& value) { constructor(value); }
half_t(const int64_t & value)141   MSHADOW_XINLINE explicit half_t(const int64_t& value) { constructor(value); }
half_t(const uint64_t & value)142   MSHADOW_XINLINE explicit half_t(const uint64_t& value) { constructor(value); }
143 
144   MSHADOW_HALF_CONVERSIONOP(float)
145 
146   MSHADOW_HALF_ASSIGNOP(+=, +)
147   MSHADOW_HALF_ASSIGNOP(-=, -)
148   MSHADOW_HALF_ASSIGNOP(*=, *)
149   MSHADOW_HALF_ASSIGNOP(/=, /)
150 
151   MSHADOW_XINLINE half_t operator+() {
152     return *this;
153   }
154 
155   MSHADOW_XINLINE half_t operator-() {
156     return half_t(-float(*this));  // NOLINT(*)
157   }
158 
159   MSHADOW_XINLINE half_t operator=(const half_t& a) {
160     half_ = a.half_;
161     return a;
162   }
163 
164   template<typename T>
165   MSHADOW_XINLINE half_t operator=(const T& a) {
166     return *this = half_t(a);  /* NOLINT(*)*/
167   }
168 
169   MSHADOW_XINLINE half_t operator=(const half_t& a) volatile {
170     half_ = a.half_;
171     return a;
172   }
173 
174   template<typename T>
175   MSHADOW_XINLINE half_t operator=(const T& a) volatile {
176     return *this = half_t(a);  /* NOLINT(*)*/
177   }
178 
179  private:
180   union Bits {
181     float f;
182     int32_t si;
183     uint32_t ui;
184   };
185 
186   static int const fp16FractionBits = 10;
187   static int const fp32FractionBits = 23;
188   static int32_t const fp32FractionMask = ~(~0u << fp32FractionBits);  // == 0x7fffff
189   static int32_t const fp32HiddenBit = 1 << fp32FractionBits;         // == 0x800000
190   static int const shift = fp32FractionBits - fp16FractionBits;       // == 13
191   static int const shiftSign = 16;
192   static int32_t const expAdjust = 127 - 15;    // exp32-127 = exp16-15, so exp16 = exp32 - (127-15)
193 
194   static int32_t const infN = 0x7F800000;  // flt32 infinity
195   static int32_t const maxN = 0x477FFFFF;  // max flt32 that's a flt16 normal after >> by shift
196   static int32_t const minN = 0x38800000;  // min flt16 normal as a flt32
197   static int32_t const maxZ = 0x33000000;  // max fp32 number that's still rounded to zero in fp16
198   static int32_t const signN = 0x80000000;  // flt32 sign bit
199 
200   static int32_t const infC = infN >> shift;
201   static int32_t const nanN = (infC + 1) << shift;  // minimum flt16 nan as a flt32
202   static int32_t const maxC = maxN >> shift;
203   static int32_t const minC = minN >> shift;
204   static int32_t const signC = signN >> shiftSign;  // flt16 sign bit
205 
206   static int32_t const mulN = 0x52000000;  // (1 << 23) / minN
207   static int32_t const mulC = 0x33800000;  // minN / (1 << (23 - shift))
208 
209   static int32_t const subC = 0x003FF;  // max flt32 subnormal down shifted
210   static int32_t const norC = 0x00400;  // min flt32 normal down shifted
211 
212   static int32_t const maxD = infC - maxC - 1;
213   static int32_t const minD = minC - subC - 1;
214 
float2half(const float & value)215   MSHADOW_XINLINE uint16_t float2half(const float& value) const {
216     Bits v;
217     v.f = value;
218     uint32_t sign = v.si & signN;    // grab sign bit
219     v.si ^= sign;                    // clear sign bit from v
220     sign >>= shiftSign;              // logical shift sign to fp16 position
221 
222     if (v.si <= maxZ) {
223       // Handle eventual zeros here to ensure vshift will not exceed 32 below.
224       v.ui = 0;
225     } else if (v.si < minN) {
226       // Handle denorms
227       uint32_t exp32 = v.ui >> fp32FractionBits;
228       int32_t exp16 = exp32 - expAdjust;
229       // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
230       // Smaller (so negative) exp16 values should result in greater right shifts.
231       uint32_t vshift = 1 - exp16;
232       uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
233       v.ui = significand >> vshift;
234       // The only time it's *not* OK to add 0x1000 (i.e. half the flt16 fraction lsb) is
235       // when the lsb of the flt16 fraction == 0 (so not rounding up to even) and the additional
236       // bits to the right of the lsb are 1000... (including flt32 significand bits
237       // that may be lost during the above vshift).  The first term below will always
238       // be true for vshift >=12 (since even the 'hidden bit' has been shifted to the
239       // right of the '1' bit in 0x1000). And when vshift <= 11, both terms combine to make
240       // the proper test of the flt32 significand bits, including those lost during the vshift.
241 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
242       // Rounding may increase the exponent to 1, but that's OK.
243       v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
244 #endif
245     } else if (v.si <= maxN) {
246       // Handle norms
247 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
248       // Rounding may increase the exponent, possibly creating an inf, but that's OK.
249       v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
250 #endif
251       v.ui -= expAdjust << fp32FractionBits;
252     } else if (v.si <= infN) {
253       v.si = infN;
254     } else if (v.si < nanN) {
255       v.si = nanN;
256     }
257 
258     v.ui >>= shift;
259     return sign | (v.ui & 0x7fff);
260   }
261 
262   // Same as above routine, except for addition of volatile keyword
float2half(const volatile float & value)263   MSHADOW_XINLINE uint16_t float2half(const volatile float& value) const volatile {  // NOLINT (*)
264     Bits v;
265     v.f = value;
266     uint32_t sign = v.si & signN;    // grab sign bit
267     v.si ^= sign;                    // clear sign bit from v
268     sign >>= shiftSign;              // logical shift sign to fp16 position
269 
270     if (v.si <= maxZ) {
271       // Handle eventual zeros here to ensure vshift will not exceed 32 below.
272       v.ui = 0;
273     } else if (v.si < minN) {
274       // Handle denorms
275       uint32_t exp32 = v.ui >> fp32FractionBits;
276       int32_t exp16 = exp32 - expAdjust;
277       // If exp16 == 0 (just into the denorm range), then significant should be shifted right 1.
278       // Smaller (so negative) exp16 values should result in greater right shifts.
279       uint32_t vshift = 1 - exp16;
280       uint32_t significand = fp32HiddenBit | (v.ui & fp32FractionMask);
281       v.ui = significand >> vshift;
282 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
283       // Rounding may increase the exponent to 1, but that's OK.
284       v.ui += (v.ui & 0x3fff) != 0x1000 || (significand & 0x7ff) ? 0x1000 : 0;
285 #endif
286     } else if (v.si <= maxN) {
287       // Handle norms
288 #if MSHADOW_HALF_ROUND_TO_NEAREST == 1
289       // Rounding may increase the exponent, possibly creating an inf, but that's OK.
290       v.ui += (v.ui & 0x3fff) != 0x1000 ? 0x1000 : 0;
291 #endif
292       v.ui -= expAdjust << fp32FractionBits;
293     } else if (v.si <= infN) {
294       v.si = infN;
295     } else if (v.si < nanN) {
296       v.si = nanN;
297     }
298 
299     v.ui >>= shift;
300     return sign | (v.ui & 0x7fff);
301   }
302 
half2float(const uint16_t & value)303   MSHADOW_XINLINE float half2float(const uint16_t& value) const {
304     Bits v;
305     v.ui = value;
306     int32_t sign = v.si & signC;
307     v.si ^= sign;
308     sign <<= shiftSign;
309     v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
310     v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
311     Bits s;
312     s.si = mulC;
313     s.f *= v.si;
314     int32_t mask = -(norC > v.si);
315     v.si <<= shift;
316     v.si ^= (s.si ^ v.si) & mask;
317     v.si |= sign;
318     return v.f;
319   }
320 
half2float(const volatile uint16_t & value)321   MSHADOW_XINLINE float half2float(const volatile uint16_t& value) const volatile {  // NOLINT(*)
322     Bits v;
323     v.ui = value;
324     int32_t sign = v.si & signC;
325     v.si ^= sign;
326     sign <<= shiftSign;
327     v.si ^= ((v.si + minD) ^ v.si) & -(v.si > subC);
328     v.si ^= ((v.si + maxD) ^ v.si) & -(v.si > maxC);
329     Bits s;
330     s.si = mulC;
331     s.f *= v.si;
332     int32_t mask = -(norC > v.si);
333     v.si <<= shift;
334     v.si ^= (s.si ^ v.si) & mask;
335     v.si |= sign;
336     return v.f;
337   }
338 
339   template<typename T>
constructor(const T & value)340   MSHADOW_XINLINE void constructor(const T& value) {
341 #if (MSHADOW_CUDA_HALF && defined(__CUDA_ARCH__))
342     cuhalf_ = __float2half(float(value));  // NOLINT(*)
343 #elif(MSHADOW_USE_F16C)
344     half_ = _cvtss_sh(static_cast<float>(value), 0);
345 #else /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */
346     half_ = float2half(float(value));  // NOLINT(*)
347 #endif /* !MSHADOW_CUDA_HALF && !MSHADOW_USE_F16C */
348   }
349 };
350 
351 /*! \brief overloaded + operator for half_t */
352 MSHADOW_HALF_OPERATOR(half_t, +)
353 /*! \brief overloaded - operator for half_t */
354 MSHADOW_HALF_OPERATOR(half_t, -)
355 /*! \brief overloaded * operator for half_t */
356 MSHADOW_HALF_OPERATOR(half_t, *)
357 /*! \brief overloaded / operator for half_t */
358 MSHADOW_HALF_OPERATOR(half_t, /)
359 /*! \brief overloaded > operator for half_t */
360 MSHADOW_HALF_OPERATOR(bool, >)
361 /*! \brief overloaded < operator for half_t */
362 MSHADOW_HALF_OPERATOR(bool, <)
363 /*! \brief overloaded >= operator for half_t */
364 MSHADOW_HALF_OPERATOR(bool, >=)
365 /*! \brief overloaded <= operator for half_t */
366 MSHADOW_HALF_OPERATOR(bool, <=)
367 
// The limit macros expand to plain expressions with no trailing semicolon, so
// they can appear inside larger expressions and argument lists.
/*! \brief lowest finite half_t value, -65504 (bit pattern 0xFBFF) */
#define MSHADOW_HALF_MIN mshadow::half::half_t::Binary(0xFBFF)
/*! \brief largest finite half_t value, 65504 (bit pattern 0x7BFF) */
#define MSHADOW_HALF_MAX mshadow::half::half_t::Binary(0x7BFF)
/*! \brief mask selecting the sign bit of a half_t bit pattern */
#define MSHADOW_HALF_SIGN_BIT 0x8000
/*! \brief mask selecting the 5 exponent bits of a half_t bit pattern */
#define MSHADOW_HALF_EXPONENT_BITS 0x7c00
372 }  // namespace half
373 }  // namespace mshadow
374 #endif  // MSHADOW_HALF_H_
375