1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one 3 * or more contributor license agreements. See the NOTICE file 4 * distributed with this work for additional information 5 * regarding copyright ownership. The ASF licenses this file 6 * to you under the Apache License, Version 2.0 (the 7 * "License"); you may not use this file except in compliance 8 * with the License. You may obtain a copy of the License at 9 * 10 * http://www.apache.org/licenses/LICENSE-2.0 11 * 12 * Unless required by applicable law or agreed to in writing, 13 * software distributed under the License is distributed on an 14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY 15 * KIND, either express or implied. See the License for the 16 * specific language governing permissions and limitations 17 * under the License. 18 */ 19 /* 20 * \file data_type.h 21 * \brief Primitive runtime data type. 22 */ 23 // Acknowledgement: This file originates from incubator-tvm 24 // Acknowledgement: MXNetDataType structure design originates from Halide. 25 #ifndef MXNET_RUNTIME_DATA_TYPE_H_ 26 #define MXNET_RUNTIME_DATA_TYPE_H_ 27 28 #include <mxnet/runtime/c_runtime_api.h> 29 #include <dmlc/logging.h> 30 #include <type_traits> 31 32 33 namespace mxnet { 34 namespace runtime { 35 /*! 36 * \brief Runtime primitive data type. 37 * 38 * This class is a thin wrapper of DLDataType. 39 * We also make use of MXNetDataType in compiler to store quick hint 40 */ 41 class MXNetDataType { 42 public: 43 /*! \brief Type code for the MXNetDataType. */ 44 enum TypeCode { 45 kInt = kDLInt, 46 kUInt = kDLUInt, 47 kFloat = kDLFloat, 48 kHandle = MXNetTypeCode::kHandle, 49 }; 50 /*! \brief default constructor */ MXNetDataType()51 MXNetDataType() {} 52 /*! 53 * \brief Constructor 54 * \param dtype The DLDataType 55 */ MXNetDataType(DLDataType dtype)56 explicit MXNetDataType(DLDataType dtype) 57 : data_(dtype) {} 58 /*! 59 * \brief Constructor 60 * \param code The type code. 61 * \param bits The number of bits in the type. 62 * \param lanes The number of lanes. 63 */ MXNetDataType(int code,int bits,int lanes)64 MXNetDataType(int code, int bits, int lanes) { 65 data_.code = static_cast<uint8_t>(code); 66 data_.bits = static_cast<uint8_t>(bits); 67 data_.lanes = static_cast<uint16_t>(lanes); 68 } 69 /*! \return The type code. */ code()70 int code() const { 71 return static_cast<int>(data_.code); 72 } 73 /*! \return number of bits in the data. */ bits()74 int bits() const { 75 return static_cast<int>(data_.bits); 76 } 77 /*! \return number of bytes to store each scalar. */ bytes()78 int bytes() const { 79 return (bits() + 7) / 8; 80 } 81 /*! \return number of lanes in the data. */ lanes()82 int lanes() const { 83 return static_cast<int>(data_.lanes); 84 } 85 /*! \return whether type is a scalar type. */ is_scalar()86 bool is_scalar() const { 87 return lanes() == 1; 88 } 89 /*! \return whether type is a scalar type. */ is_bool()90 bool is_bool() const { 91 return code() == MXNetDataType::kUInt && bits() == 1; 92 } 93 /*! \return whether type is a float type. */ is_float()94 bool is_float() const { 95 return code() == MXNetDataType::kFloat; 96 } 97 /*! \return whether type is an int type. */ is_int()98 bool is_int() const { 99 return code() == MXNetDataType::kInt; 100 } 101 /*! \return whether type is an uint type. */ is_uint()102 bool is_uint() const { 103 return code() == MXNetDataType::kUInt; 104 } 105 /*! \return whether type is a handle type. */ is_handle()106 bool is_handle() const { 107 return code() == MXNetDataType::kHandle; 108 } 109 /*! \return whether type is a vector type. */ is_vector()110 bool is_vector() const { 111 return lanes() > 1; 112 } 113 /*! 114 * \brief Create a new data type by change lanes to a specified value. 115 * \param lanes The target number of lanes. 116 * \return the result type. 117 */ with_lanes(int lanes)118 MXNetDataType with_lanes(int lanes) const { 119 return MXNetDataType(data_.code, data_.bits, lanes); 120 } 121 /*! 122 * \brief Create a new data type by change bits to a specified value. 123 * \param bits The target number of bits. 124 * \return the result type. 125 */ with_bits(int bits)126 MXNetDataType with_bits(int bits) const { 127 return MXNetDataType(data_.code, bits, data_.lanes); 128 } 129 /*! 130 * \brief Get the scalar version of the type. 131 * \return the result type. 132 */ element_of()133 MXNetDataType element_of() const { 134 return with_lanes(1); 135 } 136 /*! 137 * \brief Equal comparator. 138 * \param other The data type to compre against. 139 * \return The comparison resilt. 140 */ 141 bool operator==(const MXNetDataType& other) const { 142 return 143 data_.code == other.data_.code && 144 data_.bits == other.data_.bits && 145 data_.lanes == other.data_.lanes; 146 } 147 /*! 148 * \brief NotEqual comparator. 149 * \param other The data type to compre against. 150 * \return The comparison resilt. 151 */ 152 bool operator!=(const MXNetDataType& other) const { 153 return !operator==(other); 154 } 155 /*! 156 * \brief Converter to DLDataType 157 * \return the result. 158 */ DLDataType()159 operator DLDataType () const { 160 return data_; 161 } 162 163 /*! 164 * \brief Construct an int type. 165 * \param bits The number of bits in the type. 166 * \param lanes The number of lanes. 167 * \return The constructed data type. 168 */ 169 static MXNetDataType Int(int bits, int lanes = 1) { 170 return MXNetDataType(kDLInt, bits, lanes); 171 } 172 /*! 173 * \brief Construct an uint type. 174 * \param bits The number of bits in the type. 175 * \param lanes The number of lanes 176 * \return The constructed data type. 177 */ 178 static MXNetDataType UInt(int bits, int lanes = 1) { 179 return MXNetDataType(kDLUInt, bits, lanes); 180 } 181 /*! 182 * \brief Construct an uint type. 183 * \param bits The number of bits in the type. 184 * \param lanes The number of lanes 185 * \return The constructed data type. 186 */ 187 static MXNetDataType Float(int bits, int lanes = 1) { 188 return MXNetDataType(kDLFloat, bits, lanes); 189 } 190 /*! 191 * \brief Construct a bool type. 192 * \param lanes The number of lanes 193 * \return The constructed data type. 194 */ 195 static MXNetDataType Bool(int lanes = 1) { 196 return MXNetDataType::UInt(1, lanes); 197 } 198 /*! 199 * \brief Construct a handle type. 200 * \param bits The number of bits in the type. 201 * \param lanes The number of lanes 202 * \return The constructed data type. 203 */ 204 static MXNetDataType Handle(int bits = 64, int lanes = 1) { 205 return MXNetDataType(kHandle, bits, lanes); 206 } 207 208 private: 209 DLDataType data_; 210 }; 211 212 } // namespace runtime 213 214 using MXNetDataType = runtime::MXNetDataType; 215 216 } // namespace mxnet 217 #endif // MXNET_RUNTIME_DATA_TYPE_H_ 218