1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
/*!
 * \file tvm/dtype.h
 * \brief Data type used in IR.
 */
23 // Acknowledgement: DataType structure design originates from Halide.
24 #ifndef TVM_DTYPE_H_
25 #define TVM_DTYPE_H_
26
27 #include "runtime/packed_func.h"
28
29 namespace tvm {
30 class Expr;
31
32 /*!
33 * \brief Primitive data types in tvm.
34 */
35 class DataType {
36 public:
37 /*! \brief default constructor */
DataType()38 DataType() {}
39 /*!
40 * \brief Constructor
41 * \param dtype The DLDataType
42 */
DataType(DLDataType dtype)43 explicit DataType(DLDataType dtype)
44 : data_(dtype) {}
45 /*!
46 * \brief Constructor
47 * \param code The type code.
48 * \param bits The number of bits in the type.
49 * \param lanes The number of lanes.
50 */
DataType(int code,int bits,int lanes)51 DataType(int code, int bits, int lanes) {
52 data_.code = static_cast<uint8_t>(code);
53 data_.bits = static_cast<uint8_t>(bits);
54 data_.lanes = static_cast<uint16_t>(lanes);
55 }
56 /*! \return The type code. */
code()57 int code() const {
58 return static_cast<int>(data_.code);
59 }
60 /*! \return number of bits in the data. */
bits()61 int bits() const {
62 return static_cast<int>(data_.bits);
63 }
64 /*! \return number of bytes to store each scalar. */
bytes()65 int bytes() const {
66 return (bits() + 7) / 8;
67 }
68 /*! \return number of lanes in the data. */
lanes()69 int lanes() const {
70 return static_cast<int>(data_.lanes);
71 }
72 /*! \return whether type is a scalar type. */
is_scalar()73 bool is_scalar() const {
74 return lanes() == 1;
75 }
76 /*! \return whether type is a scalar type. */
is_bool()77 bool is_bool() const {
78 return code() == kDLUInt && bits() == 1;
79 }
80 /*! \return whether type is a float type. */
is_float()81 bool is_float() const {
82 return code() == kDLFloat;
83 }
84 /*! \return whether type is an int type. */
is_int()85 bool is_int() const {
86 return code() == kDLInt;
87 }
88 /*! \return whether type is an uint type. */
is_uint()89 bool is_uint() const {
90 return code() == kDLUInt;
91 }
92 /*! \return whether type is a handle type. */
is_handle()93 bool is_handle() const {
94 return code() == kHandle;
95 }
96 /*! \return whether type is a vector type. */
is_vector()97 bool is_vector() const {
98 return lanes() > 1;
99 }
100 /*!
101 * \brief Create a new data type by change lanes to a specified value.
102 * \param lanes The target number of lanes.
103 * \return the result type.
104 */
with_lanes(int lanes)105 DataType with_lanes(int lanes) const {
106 return DataType(data_.code, data_.bits, lanes);
107 }
108 /*!
109 * \brief Create a new data type by change bits to a specified value.
110 * \param bits The target number of bits.
111 * \return the result type.
112 */
with_bits(int bits)113 DataType with_bits(int bits) const {
114 return DataType(data_.code, bits, data_.lanes);
115 }
116 /*!
117 * \brief Get the scalar version of the type.
118 * \return the result type.
119 */
element_of()120 DataType element_of() const {
121 return with_lanes(1);
122 }
123 // operator overloadings
124 bool operator==(const DataType& other) const {
125 return
126 data_.code == other.data_.code &&
127 data_.bits == other.data_.bits &&
128 data_.lanes == other.data_.lanes;
129 }
130 bool operator!=(const DataType& other) const {
131 return !operator==(other);
132 }
DLDataType()133 operator DLDataType () const {
134 return data_;
135 }
136 /*! \return the maximum possible value in this format. */
137 TVM_DLL Expr max() const;
138 /*! \return the minimum possible value in this format. */
139 TVM_DLL Expr min() const;
140
141 private:
142 DLDataType data_;
143 };
144
145 /*!
146 * \brief Construct an int type.
147 * \param bits The number of bits in the type.
148 * \param lanes The number of lanes.
149 * \return The constructed data type.
150 */
151 inline DataType Int(int bits, int lanes = 1) {
152 return DataType(kDLInt, bits, lanes);
153 }
154
155 /*!
156 * \brief Construct an uint type.
157 * \param bits The number of bits in the type.
158 * \param lanes The number of lanes
159 * \return The constructed data type.
160 */
161 inline DataType UInt(int bits, int lanes = 1) {
162 return DataType(kDLUInt, bits, lanes);
163 }
164
165 /*!
166 * \brief Construct a bool type.
167 * \param lanes The number of lanes
168 * \return The constructed data type.
169 */
170 inline DataType Bool(int lanes = 1) {
171 return UInt(1, lanes);
172 }
173
174 /*!
175 * \brief Construct an uint type.
176 * \param bits The number of bits in the type.
177 * \param lanes The number of lanes
178 * \return The constructed data type.
179 */
180 inline DataType Float(int bits, int lanes = 1) {
181 return DataType(kDLFloat, bits, lanes);
182 }
183
184 /*!
185 * \brief Construct a handle type.
186 * \param bits The number of bits in the type.
187 * \param lanes The number of lanes
188 * \return The constructed data type.
189 */
190 inline DataType Handle(int bits = 64, int lanes = 1) {
191 return DataType(kHandle, bits, lanes);
192 }
193
194 /*!
195 * \brief Get the corresponding type of TVMShapeIndex.
196 * \return The type of TVM shape index.
197 */
TVMShapeIndexType()198 inline DataType TVMShapeIndexType() {
199 if (std::is_signed<tvm_index_t>::value) {
200 return Int(sizeof(tvm_index_t) * 8);
201 } else {
202 return UInt(sizeof(tvm_index_t) * 8);
203 }
204 }
205
206 /*!
207 * \brief Convert DLDataType to DataType.
208 * \param t The original type.
209 * \return The conversion result.
210 */
TVMType2Type(DLDataType t)211 inline DataType TVMType2Type(DLDataType t) {
212 return DataType(t.code, t.bits, t.lanes);
213 }
214
215 /*!
216 * \brief Convert DataType to DataType.
217 * \param t The original type.
218 * \return The conversion result.
219 */
Type2TVMType(DataType t)220 inline DLDataType Type2TVMType(DataType t) {
221 return t.operator DLDataType();
222 }
223
224 /*!
225 * \brief Get the number of bytes needed in a vector.
226 * \param dtype The data type.
227 * \return Number of bytes needed.
228 */
GetVectorBytes(DataType dtype)229 inline int GetVectorBytes(DataType dtype) {
230 int data_bits = dtype.bits() * dtype.lanes();
231 // allow bool to exist
232 if (dtype == Bool()) return 1;
233 CHECK_EQ(data_bits % 8, 0U)
234 << "Need to load/store by multiple of bytes";
235 return data_bits / 8;
236 }
237
238 // Overload print function.
239 inline std::ostream& operator<<(std::ostream& os, DataType dtype) { // NOLINT(*)
240 using namespace tvm::runtime;
241 return os << dtype.operator DLDataType();
242 }
243
// Backward compatibility: Type is an alias of DataType, so both names
// refer to the same class.
using Type = DataType;
246 } // namespace tvm
247 #endif // TVM_DTYPE_H_
248