#pragma once

#include <cmath>
#include <cstdint>

#include <cuda_fp16.h>

#include "chainerx/float16.h"
#include "chainerx/scalar.h"

namespace chainerx {
namespace cuda {

// Float16 for CUDA devices.
// Used from the device, it supports full arithmetic just like the other C++ numeric types.
// Used from the host, it is only 16-bit data storage.
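//
// Illustrative device-side usage (a sketch only; `AddHalf` is a hypothetical kernel, not part of this header):
//
//     __global__ void AddHalf(cuda::Float16* out, const cuda::Float16* a, const cuda::Float16* b, size_t n) {
//         size_t i = blockIdx.x * blockDim.x + threadIdx.x;
//         if (i < n) {
//             out[i] = a[i] + b[i];  // arithmetic operators are usable in device code
//         }
//     }
//
// On the host, values can only be constructed from raw bits (e.g. cuda::Float16::FromData(0x3c00U) is 1.0)
// or from chainerx::Float16 / Scalar, and then copied to the device as 16-bit data.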
class Float16 {
private:
    struct FromDataTag {};

public:
    __device__ Float16() : data_{0} {}
    explicit __device__ Float16(bool v) : Float16{static_cast<int>(v)} {}
    explicit __device__ Float16(int8_t v) : Float16{static_cast<int16_t>(v)} {}
    explicit __device__ Float16(uint8_t v) : Float16{static_cast<uint16_t>(v)} {}
    explicit __device__ Float16(int64_t v) : Float16{static_cast<int32_t>(v)} {}
    explicit __device__ Float16(uint64_t v) : Float16{static_cast<uint32_t>(v)} {}
    template <typename T>
    explicit __device__ Float16(T v) : Float16{::__half{v}} {}
    // It is assumed that chainerx::Float16 and cuda::Float16 have a common representation.
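    // (Both store the raw IEEE 754 binary16 bits in a uint16_t, which is why v.data() below can be
    // passed straight to the FromDataTag constructor without any conversion.)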
    explicit __host__ Float16(Scalar v) : Float16{static_cast<chainerx::Float16>(v)} {}
    explicit __host__ Float16(chainerx::Float16 v) : Float16{v.data(), FromDataTag{}} {}

    explicit __device__ operator bool() const { return *this != Float16{0}; }
    // int8 conversion is not implemented in cuda_fp16
    explicit __device__ operator int8_t() const { return static_cast<int8_t>(static_cast<int16_t>(*this)); }
    explicit __device__ operator uint8_t() const { return static_cast<uint8_t>(static_cast<uint16_t>(*this)); }
    explicit __device__ operator int16_t() const { return static_cast<int16_t>(cuda_half()); }
    explicit __device__ operator uint16_t() const { return static_cast<uint16_t>(cuda_half()); }
    explicit __device__ operator int32_t() const { return static_cast<int32_t>(cuda_half()); }
    explicit __device__ operator uint32_t() const { return static_cast<uint32_t>(cuda_half()); }
    // int64 conversion is not implemented in cuda_fp16
    explicit __device__ operator int64_t() const { return static_cast<int32_t>(cuda_half()); }
    explicit __device__ operator uint64_t() const { return static_cast<uint32_t>(cuda_half()); }
    explicit __device__ operator float() const { return static_cast<float>(cuda_half()); }
    // double conversion is not implemented in cuda_fp16
    explicit __device__ operator double() const { return float{*this}; }

    // TODO(imanishi): Use cuda_half()
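    // The operators below currently round-trip through float rather than using native __half arithmetic.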
    __device__ Float16 operator-() const { return Float16{-static_cast<float>(*this)}; }
    __device__ bool operator!() const { return !static_cast<float>(*this); }
    __device__ Float16 operator+(Float16 r) const { return Float16{static_cast<float>(*this) + static_cast<float>(r)}; }
    __device__ Float16 operator-(Float16 r) const { return Float16{static_cast<float>(*this) - static_cast<float>(r)}; }
    __device__ Float16 operator*(Float16 r) const { return Float16{static_cast<float>(*this) * static_cast<float>(r)}; }
    __device__ Float16 operator/(Float16 r) const { return Float16{static_cast<float>(*this) / static_cast<float>(r)}; }
    __device__ Float16 operator+=(Float16 r) { return *this = Float16{*this + r}; }
    __device__ Float16 operator-=(Float16 r) { return *this = Float16{*this - r}; }
    __device__ Float16 operator*=(Float16 r) { return *this = Float16{*this * r}; }
    __device__ Float16 operator/=(Float16 r) { return *this = Float16{*this / r}; }
    __device__ bool operator==(Float16 r) const { return static_cast<float>(*this) == static_cast<float>(r); }
    __device__ bool operator!=(Float16 r) const { return !(*this == r); }
    __device__ bool operator<(Float16 r) const { return static_cast<float>(*this) < static_cast<float>(r); }
    __device__ bool operator>(Float16 r) const { return static_cast<float>(*this) > static_cast<float>(r); }
    __device__ bool operator<=(Float16 r) const { return static_cast<float>(*this) <= static_cast<float>(r); }
    __device__ bool operator>=(Float16 r) const { return static_cast<float>(*this) >= static_cast<float>(r); }

    __host__ __device__ static constexpr Float16 FromData(uint16_t data) { return cuda::Float16{data, FromDataTag{}}; }

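    // IEEE 754 binary16 layout: 1 sign bit, 5 exponent bits, 10 mantissa bits. An all-ones exponent
    // (0x7c00) with a zero mantissa encodes infinity; with a nonzero mantissa, NaN.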
    __host__ __device__ static constexpr Float16 Inf() { return FromData(0x7c00U); }
    __host__ __device__ static constexpr Float16 NegInf() { return FromData(0xfc00U); }

    __device__ bool IsNan() const { return (data_ & 0x7c00U) == 0x7c00U && (data_ & 0x03ffU) != 0; }
    __device__ bool IsInf() const { return (data_ & 0x7c00U) == 0x7c00U && (data_ & 0x03ffU) == 0; }
    __device__ Float16 Exp() const { return Float16{std::exp(static_cast<float>(*this))}; }
    __device__ Float16 Log() const { return Float16{std::log(static_cast<float>(*this))}; }
    __device__ Float16 Log10() const { return Float16{std::log10(static_cast<float>(*this))}; }
    __device__ Float16 Log2() const { return Float16{std::log2f(static_cast<float>(*this))}; }
    __device__ Float16 Log1p() const { return Float16{std::log1pf(static_cast<float>(*this))}; }
    __device__ Float16 Sqrt() const { return Float16{std::sqrt(static_cast<float>(*this))}; }
    __device__ Float16 Floor() const { return Float16{std::floor(static_cast<float>(*this))}; }

private:
    explicit __device__ Float16(::__half x) : data_{__half_as_ushort(x)} {}
    explicit __host__ __device__ constexpr Float16(uint16_t data, FromDataTag) : data_{data} {}

    __device__ ::__half cuda_half() const { return ::__half{__ushort_as_half(data_)}; }

    uint16_t data_;
};

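// Comparisons with other numeric types convert the Float16 operand to float.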
template <typename T>
__device__ inline bool operator==(const T& l, Float16 r) {
    return l == static_cast<float>(r);
}

template <typename T>
__device__ inline bool operator==(Float16 l, const T& r) {
    return static_cast<float>(l) == r;
}

template <typename T>
__device__ inline bool operator!=(const T& l, Float16 r) {
    return !(l == r);
}

template <typename T>
__device__ inline bool operator!=(Float16 l, const T& r) {
    return !(l == r);
}

}  // namespace cuda
}  // namespace chainerx