1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 /*
20  * \file data_type.h
21  * \brief Primitive runtime data type.
22  */
23 // Acknowledgement: This file originates from incubator-tvm
24 // Acknowledgement: MXNetDataType structure design originates from Halide.
25 #ifndef MXNET_RUNTIME_DATA_TYPE_H_
26 #define MXNET_RUNTIME_DATA_TYPE_H_
27 
28 #include <mxnet/runtime/c_runtime_api.h>
29 #include <dmlc/logging.h>
30 #include <type_traits>
31 
32 
33 namespace mxnet {
34 namespace runtime {
35 /*!
36  * \brief Runtime primitive data type.
37  *
 *  This class is a thin wrapper around DLDataType.
 *  MXNetDataType is also used in the compiler to store quick hints.
40  */
41 class MXNetDataType {
42  public:
43   /*! \brief Type code for the MXNetDataType. */
44   enum TypeCode {
45     kInt = kDLInt,
46     kUInt = kDLUInt,
47     kFloat = kDLFloat,
48     kHandle = MXNetTypeCode::kHandle,
49   };
50   /*! \brief default constructor */
MXNetDataType()51   MXNetDataType() {}
52   /*!
53    * \brief Constructor
54    * \param dtype The DLDataType
55    */
MXNetDataType(DLDataType dtype)56   explicit MXNetDataType(DLDataType dtype)
57       : data_(dtype) {}
58   /*!
59    * \brief Constructor
60    * \param code The type code.
61    * \param bits The number of bits in the type.
62    * \param lanes The number of lanes.
63    */
MXNetDataType(int code,int bits,int lanes)64   MXNetDataType(int code, int bits, int lanes) {
65     data_.code = static_cast<uint8_t>(code);
66     data_.bits = static_cast<uint8_t>(bits);
67     data_.lanes = static_cast<uint16_t>(lanes);
68   }
69   /*! \return The type code. */
code()70   int code() const {
71     return static_cast<int>(data_.code);
72   }
73   /*! \return number of bits in the data. */
bits()74   int bits() const {
75     return static_cast<int>(data_.bits);
76   }
77   /*! \return number of bytes to store each scalar. */
bytes()78   int bytes() const {
79     return (bits() + 7) / 8;
80   }
81   /*! \return number of lanes in the data. */
lanes()82   int lanes() const {
83     return static_cast<int>(data_.lanes);
84   }
85   /*! \return whether type is a scalar type. */
is_scalar()86   bool is_scalar() const {
87     return lanes() == 1;
88   }
89   /*! \return whether type is a scalar type. */
is_bool()90   bool is_bool() const {
91     return code() == MXNetDataType::kUInt && bits() == 1;
92   }
93   /*! \return whether type is a float type. */
is_float()94   bool is_float() const {
95     return code() == MXNetDataType::kFloat;
96   }
97   /*! \return whether type is an int type. */
is_int()98   bool is_int() const {
99     return code() == MXNetDataType::kInt;
100   }
101   /*! \return whether type is an uint type. */
is_uint()102   bool is_uint() const {
103     return code() == MXNetDataType::kUInt;
104   }
105   /*! \return whether type is a handle type. */
is_handle()106   bool is_handle() const {
107     return code() == MXNetDataType::kHandle;
108   }
109   /*! \return whether type is a vector type. */
is_vector()110   bool is_vector() const {
111     return lanes() > 1;
112   }
113   /*!
114    * \brief Create a new data type by change lanes to a specified value.
115    * \param lanes The target number of lanes.
116    * \return the result type.
117    */
with_lanes(int lanes)118   MXNetDataType with_lanes(int lanes) const {
119     return MXNetDataType(data_.code, data_.bits, lanes);
120   }
121   /*!
122    * \brief Create a new data type by change bits to a specified value.
123    * \param bits The target number of bits.
124    * \return the result type.
125    */
with_bits(int bits)126   MXNetDataType with_bits(int bits) const {
127     return MXNetDataType(data_.code, bits, data_.lanes);
128   }
129   /*!
130    * \brief Get the scalar version of the type.
131    * \return the result type.
132    */
element_of()133   MXNetDataType element_of() const {
134     return with_lanes(1);
135   }
136   /*!
137    * \brief Equal comparator.
138    * \param other The data type to compre against.
139    * \return The comparison resilt.
140    */
141   bool operator==(const MXNetDataType& other) const {
142     return
143         data_.code == other.data_.code &&
144         data_.bits == other.data_.bits &&
145         data_.lanes == other.data_.lanes;
146   }
147   /*!
148    * \brief NotEqual comparator.
149    * \param other The data type to compre against.
150    * \return The comparison resilt.
151    */
152   bool operator!=(const MXNetDataType& other) const {
153     return !operator==(other);
154   }
155   /*!
156    * \brief Converter to DLDataType
157    * \return the result.
158    */
DLDataType()159   operator DLDataType () const {
160     return data_;
161   }
162 
163   /*!
164    * \brief Construct an int type.
165    * \param bits The number of bits in the type.
166    * \param lanes The number of lanes.
167    * \return The constructed data type.
168    */
169   static MXNetDataType Int(int bits, int lanes = 1) {
170     return MXNetDataType(kDLInt, bits, lanes);
171   }
172   /*!
173    * \brief Construct an uint type.
174    * \param bits The number of bits in the type.
175    * \param lanes The number of lanes
176    * \return The constructed data type.
177    */
178   static MXNetDataType UInt(int bits, int lanes = 1) {
179     return MXNetDataType(kDLUInt, bits, lanes);
180   }
181   /*!
182    * \brief Construct an uint type.
183    * \param bits The number of bits in the type.
184    * \param lanes The number of lanes
185    * \return The constructed data type.
186    */
187   static MXNetDataType Float(int bits, int lanes = 1) {
188     return MXNetDataType(kDLFloat, bits, lanes);
189   }
190   /*!
191    * \brief Construct a bool type.
192    * \param lanes The number of lanes
193    * \return The constructed data type.
194    */
195   static MXNetDataType Bool(int lanes = 1) {
196     return MXNetDataType::UInt(1, lanes);
197   }
198   /*!
199    * \brief Construct a handle type.
200    * \param bits The number of bits in the type.
201    * \param lanes The number of lanes
202    * \return The constructed data type.
203    */
204   static MXNetDataType Handle(int bits = 64, int lanes = 1) {
205     return MXNetDataType(kHandle, bits, lanes);
206   }
207 
208  private:
209   DLDataType data_;
210 };
211 
212 }  // namespace runtime
213 
214 using MXNetDataType = runtime::MXNetDataType;
215 
216 }  // namespace mxnet
217 #endif  //  MXNET_RUNTIME_DATA_TYPE_H_
218