1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
/*!
 * \file tvm/dtype.h
 * \brief Data type used in IR.
 */
23 // Acknowledgement: DataType structure design originates from Halide.
24 #ifndef TVM_DTYPE_H_
25 #define TVM_DTYPE_H_
26 
27 #include "runtime/packed_func.h"
28 
29 namespace tvm {
30 class Expr;
31 
32 /*!
33  * \brief Primitive data types in tvm.
34  */
35 class DataType {
36  public:
37   /*! \brief default constructor */
DataType()38   DataType() {}
39   /*!
40    * \brief Constructor
41    * \param dtype The DLDataType
42    */
DataType(DLDataType dtype)43   explicit DataType(DLDataType dtype)
44       : data_(dtype) {}
45   /*!
46    * \brief Constructor
47    * \param code The type code.
48    * \param bits The number of bits in the type.
49    * \param lanes The number of lanes.
50    */
DataType(int code,int bits,int lanes)51   DataType(int code, int bits, int lanes) {
52     data_.code = static_cast<uint8_t>(code);
53     data_.bits = static_cast<uint8_t>(bits);
54     data_.lanes = static_cast<uint16_t>(lanes);
55   }
56   /*! \return The type code. */
code()57   int code() const {
58     return static_cast<int>(data_.code);
59   }
60   /*! \return number of bits in the data. */
bits()61   int bits() const {
62     return static_cast<int>(data_.bits);
63   }
64   /*! \return number of bytes to store each scalar. */
bytes()65   int bytes() const {
66     return (bits() + 7) / 8;
67   }
68   /*! \return number of lanes in the data. */
lanes()69   int lanes() const {
70     return static_cast<int>(data_.lanes);
71   }
72   /*! \return whether type is a scalar type. */
is_scalar()73   bool is_scalar() const {
74     return lanes() == 1;
75   }
76   /*! \return whether type is a scalar type. */
is_bool()77   bool is_bool() const {
78     return code() == kDLUInt && bits() == 1;
79   }
80   /*! \return whether type is a float type. */
is_float()81   bool is_float() const {
82     return code() == kDLFloat;
83   }
84   /*! \return whether type is an int type. */
is_int()85   bool is_int() const {
86     return code() == kDLInt;
87   }
88   /*! \return whether type is an uint type. */
is_uint()89   bool is_uint() const {
90     return code() == kDLUInt;
91   }
92   /*! \return whether type is a handle type. */
is_handle()93   bool is_handle() const {
94     return code() == kHandle;
95   }
96   /*! \return whether type is a vector type. */
is_vector()97   bool is_vector() const {
98     return lanes() > 1;
99   }
100   /*!
101    * \brief Create a new data type by change lanes to a specified value.
102    * \param lanes The target number of lanes.
103    * \return the result type.
104    */
with_lanes(int lanes)105   DataType with_lanes(int lanes) const {
106     return DataType(data_.code, data_.bits, lanes);
107   }
108   /*!
109    * \brief Create a new data type by change bits to a specified value.
110    * \param bits The target number of bits.
111    * \return the result type.
112    */
with_bits(int bits)113   DataType with_bits(int bits) const {
114     return DataType(data_.code, bits, data_.lanes);
115   }
116   /*!
117    * \brief Get the scalar version of the type.
118    * \return the result type.
119    */
element_of()120   DataType element_of() const {
121     return with_lanes(1);
122   }
123   // operator overloadings
124   bool operator==(const DataType& other) const {
125     return
126         data_.code == other.data_.code &&
127         data_.bits == other.data_.bits &&
128         data_.lanes == other.data_.lanes;
129   }
130   bool operator!=(const DataType& other) const {
131     return !operator==(other);
132   }
DLDataType()133   operator DLDataType () const {
134     return data_;
135   }
136   /*! \return the maximum possible value in this format. */
137   TVM_DLL Expr max() const;
138   /*! \return the minimum possible value in this format. */
139   TVM_DLL Expr min() const;
140 
141  private:
142   DLDataType data_;
143 };
144 
145 /*!
146  * \brief Construct an int type.
147  * \param bits The number of bits in the type.
148  * \param lanes The number of lanes.
149  * \return The constructed data type.
150  */
151 inline DataType Int(int bits, int lanes = 1) {
152   return DataType(kDLInt, bits, lanes);
153 }
154 
155 /*!
156  * \brief Construct an uint type.
157  * \param bits The number of bits in the type.
158  * \param lanes The number of lanes
159  * \return The constructed data type.
160  */
161 inline DataType UInt(int bits, int lanes = 1) {
162   return DataType(kDLUInt, bits, lanes);
163 }
164 
165 /*!
166  * \brief Construct a bool type.
167  * \param lanes The number of lanes
168  * \return The constructed data type.
169  */
170 inline DataType Bool(int lanes = 1) {
171   return UInt(1, lanes);
172 }
173 
174 /*!
175  * \brief Construct an uint type.
176  * \param bits The number of bits in the type.
177  * \param lanes The number of lanes
178  * \return The constructed data type.
179  */
180 inline DataType Float(int bits, int lanes = 1) {
181   return DataType(kDLFloat, bits, lanes);
182 }
183 
184 /*!
185  * \brief Construct a handle type.
186  * \param bits The number of bits in the type.
187  * \param lanes The number of lanes
188  * \return The constructed data type.
189  */
190 inline DataType Handle(int bits = 64, int lanes = 1) {
191   return DataType(kHandle, bits, lanes);
192 }
193 
194 /*!
195  * \brief Get the corresponding type of TVMShapeIndex.
196  * \return The type of TVM shape index.
197  */
TVMShapeIndexType()198 inline DataType TVMShapeIndexType() {
199   if (std::is_signed<tvm_index_t>::value) {
200     return Int(sizeof(tvm_index_t) * 8);
201   } else {
202     return UInt(sizeof(tvm_index_t) * 8);
203   }
204 }
205 
206 /*!
207  * \brief Convert DLDataType to DataType.
208  * \param t The original type.
209  * \return The conversion result.
210  */
TVMType2Type(DLDataType t)211 inline DataType TVMType2Type(DLDataType t) {
212   return DataType(t.code, t.bits, t.lanes);
213 }
214 
215 /*!
216  * \brief Convert DataType to DataType.
217  * \param t The original type.
218  * \return The conversion result.
219  */
Type2TVMType(DataType t)220 inline DLDataType Type2TVMType(DataType t) {
221   return t.operator DLDataType();
222 }
223 
224 /*!
225  * \brief Get the number of bytes needed in a vector.
226  * \param dtype The data type.
227  * \return Number of bytes needed.
228  */
GetVectorBytes(DataType dtype)229 inline int GetVectorBytes(DataType dtype) {
230   int data_bits = dtype.bits() * dtype.lanes();
231   // allow bool to exist
232   if (dtype == Bool()) return 1;
233   CHECK_EQ(data_bits % 8, 0U)
234       << "Need to load/store by multiple of bytes";
235   return data_bits / 8;
236 }
237 
238 // Overload print function.
239 inline std::ostream& operator<<(std::ostream& os, DataType dtype) { // NOLINT(*)
240   using namespace tvm::runtime;
241   return os << dtype.operator DLDataType();
242 }
243 
// Backward compatibility: `Type` is kept as an alias of DataType so that
// code written against the older name continues to compile.
using Type = DataType;
246 }  // namespace tvm
247 #endif  //  TVM_DTYPE_H_
248