1 /**
2  * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #ifndef GLOW_BASE_TYPE_H
17 #define GLOW_BASE_TYPE_H
18 
#include "DimType.h"

#include "glow/Support/BFloat16.h"
#include "glow/Support/Compiler.h"
#include "glow/Support/Float16.h"
#include "glow/Support/Memory.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

#include <glog/logging.h>

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <type_traits>
#include <utility>
35 
// Forward declaration: raw_ostream is only used by reference in the dump() and
// operator<< declarations of this header.
namespace llvm {
class raw_ostream;
}
39 
40 namespace glow {
41 
42 // UINT8_MIN is not defined in standard headers.
43 // Define it here for using these definitions consistently.
44 #define UINT8_MIN 0
45 
46 struct Type;
47 
48 using TypeRef = const Type *;
49 
50 constexpr unsigned max_tensor_dimensions = 6;
51 
52 /// This type is used to implement the Node and Instruction builder's
53 /// MemberType::Unsigned and MemberType::VectorUnsigned. Thus it should be used
54 /// when handling members of these classes, e.g. a convolution Node/Instr's
55 /// getGroup() (Unsigned), or getKernels() (UnsignedVector).
56 using unsigned_t = uint32_t;
57 
58 using float16_t = float16;
59 static_assert(sizeof(float16_t) == 2, "Half precision should be 16-bit");
60 
61 using bfloat16_t = bfloat16;
62 static_assert(sizeof(bfloat16_t) == 2, "bfloat16 should be 16-bit");
63 
64 using ShapeVector = llvm::SmallVector<dim_t, max_tensor_dimensions>;
65 
66 struct ShapeNHWC {
67 
68   enum {
69     DimN,
70     DimH,
71     DimW,
72     DimC,
73   };
74 
75   dim_t n; // Number of samples
76   dim_t h; // Height
77   dim_t w; // Width
78   dim_t c; // Number of Channels
79 
ShapeNHWCShapeNHWC80   template <typename T> explicit ShapeNHWC(llvm::ArrayRef<T> shape) {
81     assert(shape.size() == 4 && "Invalid shape");
82     n = shape[DimN];
83     h = shape[DimH];
84     w = shape[DimW];
85     c = shape[DimC];
86   }
87 
ShapeNHWCShapeNHWC88   ShapeNHWC(dim_t samples, dim_t height, dim_t width, dim_t channels)
89       : n(samples), h(height), w(width), c(channels) {}
90 
equalsShapeNHWC91   bool equals(const ShapeNHWC &other) const {
92     return n == other.n && h == other.h && w == other.w && c == other.c;
93   }
94 };
95 
96 struct ShapeNTHWC {
97   dim_t n; // Number of samples
98   dim_t t; // Temporal frames
99   dim_t h; // Height
100   dim_t w; // Width
101   dim_t c; // Number of Channels
102 
ShapeNTHWCShapeNTHWC103   template <typename T> explicit ShapeNTHWC(llvm::ArrayRef<T> shape) {
104     assert(shape.size() == 5 && "Invalid shape");
105     n = shape[0];
106     t = shape[1];
107     h = shape[2];
108     w = shape[3];
109     c = shape[4];
110   }
111 
ShapeNTHWCShapeNTHWC112   ShapeNTHWC(dim_t samples, dim_t temporal_frames, dim_t height, dim_t width,
113              dim_t channels)
114       : n(samples), t(temporal_frames), h(height), w(width), c(channels) {}
115 
equalsShapeNTHWC116   bool equals(const ShapeNTHWC &other) const {
117     return n == other.n && t == other.t && h == other.h && w == other.w &&
118            c == other.c;
119   }
120 };
121 
122 struct ShapeNHWTC {
123   dim_t n; // Number of samples
124   dim_t h; // Height
125   dim_t w; // Width
126   dim_t t; // Temporal_frames
127   dim_t c; // Number of Channels
128 
ShapeNHWTCShapeNHWTC129   template <typename T> explicit ShapeNHWTC(llvm::ArrayRef<T> shape) {
130     assert(shape.size() == 5 && "Invalid shape");
131     n = shape[0];
132     h = shape[1];
133     w = shape[2];
134     t = shape[3];
135     c = shape[4];
136   }
137 
ShapeNHWTCShapeNHWTC138   ShapeNHWTC(size_t samples, size_t height, size_t width,
139              size_t temporal_frames, size_t channels)
140       : n(samples), h(height), w(width), t(temporal_frames), c(channels) {}
141 
equalsShapeNHWTC142   bool equals(const ShapeNHWTC &other) const {
143     return n == other.n && h == other.h && w == other.w && t == other.t &&
144            c == other.c;
145   }
146 };
147 
148 struct ShapeNCHW {
149 
150   enum {
151     DimN,
152     DimC,
153     DimH,
154     DimW,
155   };
156 
157   dim_t n; // Number of samples
158   dim_t c; // Number of Channels
159   dim_t h; // Height
160   dim_t w; // Width
161 
ShapeNCHWShapeNCHW162   explicit ShapeNCHW(llvm::ArrayRef<dim_t> shape) {
163     assert(shape.size() == 4 && "Invalid shape");
164     n = shape[DimN];
165     c = shape[DimC];
166     h = shape[DimH];
167     w = shape[DimW];
168   }
169 
ShapeNCHWShapeNCHW170   ShapeNCHW(dim_t samples, dim_t channels, dim_t height, dim_t width)
171       : n(samples), c(channels), h(height), w(width) {}
172 
equalsShapeNCHW173   bool equals(const ShapeNCHW &other) const {
174     return n == other.n && h == other.h && w == other.w && c == other.c;
175   }
176 };
177 
178 struct ShapeNCTHW {
179   dim_t n; // Number of samples
180   dim_t c; // Number of Channels
181   dim_t t; // Temporal frames
182   dim_t h; // Height
183   dim_t w; // Width
184 
ShapeNCTHWShapeNCTHW185   explicit ShapeNCTHW(llvm::ArrayRef<dim_t> shape) {
186     assert(shape.size() == 5 && "Invalid shape");
187     n = shape[0];
188     c = shape[1];
189     t = shape[2];
190     h = shape[3];
191     w = shape[4];
192   }
193 
ShapeNCTHWShapeNCTHW194   ShapeNCTHW(dim_t samples, dim_t channels, dim_t temporal_frames, dim_t height,
195              dim_t width)
196       : n(samples), c(channels), t(temporal_frames), h(height), w(width) {}
197 
equalsShapeNCTHW198   bool equals(const ShapeNCTHW &other) const {
199     return n == other.n && t == other.t && h == other.h && w == other.w &&
200            c == other.c;
201   }
202 };
203 
204 struct PaddingTLBR {
205   dim_t top;
206   dim_t left;
207   dim_t bottom;
208   dim_t right;
209 
PaddingTLBRPaddingTLBR210   template <typename T> explicit PaddingTLBR(llvm::ArrayRef<T> pads) {
211     assert(pads.size() == 4 && "Invalid padding");
212     top = pads[0];
213     left = pads[1];
214     bottom = pads[2];
215     right = pads[3];
216   }
217 
equalPaddingPaddingTLBR218   bool equalPadding() const {
219     return top == left && top == bottom && top == right;
220   }
221 };
222 
223 struct PaddingTLNBRF {
224   dim_t top;
225   dim_t left;
226   dim_t near;
227   dim_t bottom;
228   dim_t right;
229   dim_t far;
230 
PaddingTLNBRFPaddingTLNBRF231   template <typename T> explicit PaddingTLNBRF(llvm::ArrayRef<T> pads) {
232     assert(pads.size() == 6 && "Invalid padding");
233     top = pads[0];
234     left = pads[1];
235     near = pads[2];
236     bottom = pads[3];
237     right = pads[4];
238     far = pads[5];
239   }
240 
equalPaddingPaddingTLNBRF241   bool equalPadding() const {
242     return top == left && top == bottom && top == right && top == near &&
243            top == far;
244   }
245 };
246 
247 struct PaddingNFTBLR {
248   dim_t near;
249   dim_t far;
250   dim_t top;
251   dim_t bottom;
252   dim_t left;
253   dim_t right;
254 
PaddingNFTBLRPaddingNFTBLR255   template <typename T> explicit PaddingNFTBLR(llvm::ArrayRef<T> pads) {
256     assert(pads.size() == 6 && "Invalid padding");
257     near = pads[0];
258     far = pads[1];
259     top = pads[2];
260     bottom = pads[3];
261     left = pads[4];
262     right = pads[5];
263   }
264 
equalPaddingPaddingNFTBLR265   bool equalPadding() const {
266     return top == left && top == bottom && top == right && top == near &&
267            top == far;
268   }
269 };
270 
271 struct ShapeHW {
272 
273   enum {
274     DimH,
275     DimW,
276   };
277 
278   dim_t height;
279   dim_t width;
280 
ShapeHWShapeHW281   template <typename T> explicit ShapeHW(llvm::ArrayRef<T> shape) {
282     assert(shape.size() == 2 && "Invalid shape");
283     height = shape[DimH];
284     width = shape[DimW];
285   }
286 
isSquareShapeHW287   bool isSquare() const { return height == width; }
288 };
289 
290 struct ShapeNHW {
291 
292   enum {
293     DimN,
294     DimH,
295     DimW,
296   };
297 
298   dim_t n; // Number of samples
299   dim_t h; // Height
300   dim_t w; // Width
301 
ShapeNHWShapeNHW302   template <typename T> explicit ShapeNHW(llvm::ArrayRef<T> shape) {
303     assert(shape.size() == 3 && "Invalid shape");
304     n = shape[DimN];
305     h = shape[DimH];
306     w = shape[DimW];
307   }
308 
isSquareShapeNHW309   bool isSquare() const { return h == w; }
310 };
311 
312 struct ShapeHWT {
313   dim_t height;
314   dim_t width;
315   dim_t temporal_frames;
316 
ShapeHWTShapeHWT317   template <typename T> explicit ShapeHWT(llvm::ArrayRef<T> shape) {
318     assert(shape.size() == 3 && "Invalid shape");
319     height = shape[0];
320     width = shape[1];
321     temporal_frames = shape[2];
322   }
323 
isCubeShapeHWT324   bool isCube() const { return height == width && height == temporal_frames; }
325 };
326 
327 struct ShapeTHW {
328   dim_t temporal_frames;
329   dim_t height;
330   dim_t width;
331 
ShapeTHWShapeTHW332   template <typename T> explicit ShapeTHW(llvm::ArrayRef<T> shape) {
333     assert(shape.size() == 3 && "Invalid shape");
334     temporal_frames = shape[0];
335     height = shape[1];
336     width = shape[2];
337   }
338 
isCubeShapeTHW339   bool isCube() const { return height == width && height == temporal_frames; }
340 };
341 
342 /// Collapse a tensor shape into two sizes: the first n dimensions and the size
343 /// of the rest of the dimensions. For example, ([7, 3, 4, 2], 1) -> [7, 24]
344 inline std::pair<dim_t, dim_t> flattenCdr(llvm::ArrayRef<dim_t> dims,
345                                           unsigned_t n = 1) {
346   assert(1 <= n && n <= dims.size());
347   size_t first = dims[0];
348   for (unsigned_t i = 1; i < n; i++) {
349     first *= dims[i];
350   }
351   size_t rest = 1;
352   for (unsigned_t i = n; i < dims.size(); i++) {
353     rest *= dims[i];
354   }
355 
356   return {first, rest};
357 }
358 
359 inline bool operator==(const ShapeNHWC &LHS, const ShapeNHWC &RHS) {
360   return LHS.equals(RHS);
361 }
362 
363 inline bool operator==(const ShapeNCHW &LHS, const ShapeNCHW &RHS) {
364   return LHS.equals(RHS);
365 }
366 
367 inline bool operator==(const ShapeNHWTC &LHS, const ShapeNHWTC &RHS) {
368   return LHS.equals(RHS);
369 }
370 
371 inline bool operator==(const ShapeNTHWC &LHS, const ShapeNTHWC &RHS) {
372   return LHS.equals(RHS);
373 }
374 
375 inline bool operator==(const ShapeNCTHW &LHS, const ShapeNCTHW &RHS) {
376   return LHS.equals(RHS);
377 }
378 
/// An enum representing the type used by the elements of a tensor. The types of
/// Handles for these tensors should match the element kind.
/// When adding a new type, note that this enum definition must match the
/// ElemKind definition in Glow/lib/Backends/CPU/libjit/libjit.cpp.
/// The enumerator order also matters within this file: Type::getElementName()
/// indexes its name table by the enumerator value, so new kinds need a name
/// entry at the matching position.
enum class ElemKind : unsigned char {
  // 32-bit float type (float)
  FloatTy,
  // 16-bit float type (half, fp16)
  Float16Ty,
  // 16-bit float type (bfloat16)
  BFloat16Ty,
  // 8-bit quantized type (int8_t)
  Int8QTy,
  // unsigned 8-bit quantized type (uint8_t)
  UInt8QTy,
  // 16-bit quantized type (int16_t)
  Int16QTy,
  // 32-bit quantized type (int32_t)
  Int32QTy,
  // 32-bit index type (int32_t)
  Int32ITy,
  // 64-bit index type (int64_t)
  Int64ITy,
  // 8-bit quantized type with fused scale/offset (uint8_t); the fused
  // scale/offset kind is FloatTy (see getScaleOffsetElemKindFromFused()).
  UInt8FusedQTy,
  // 8-bit quantized type with fused FP16 scale/offset (uint8_t)
  UInt8FusedFP16QTy,
  // 4-bit quantized type with fused FP16 scale/offset (uint8_t, each byte
  // represents 2 4-bit quantized data)
  UInt4FusedFP16QTy,
  // Bool type (bool)
  BoolTy,
};
412 
413 /// \returns whether \p e is a quantized ElemKind.
isQuantizedElemKind(ElemKind e)414 inline bool isQuantizedElemKind(ElemKind e) {
415   return e == ElemKind::Int8QTy || e == ElemKind::UInt8QTy ||
416          e == ElemKind::Int16QTy || e == ElemKind::Int32QTy ||
417          e == ElemKind::UInt8FusedQTy || e == ElemKind::UInt8FusedFP16QTy ||
418          e == ElemKind::UInt4FusedFP16QTy;
419 }
420 
421 /// \returns whether \p e is a float ElemKind.
isFloatElemKind(ElemKind e)422 inline bool isFloatElemKind(ElemKind e) {
423   return e == ElemKind::FloatTy || e == ElemKind::Float16Ty ||
424          e == ElemKind::BFloat16Ty;
425 }
426 
427 /// \returns whether \p e is a fused quantized ElemKind.
isFusedQuantizedElemKind(ElemKind e)428 inline bool isFusedQuantizedElemKind(ElemKind e) {
429   return e == ElemKind::UInt8FusedQTy || e == ElemKind::UInt8FusedFP16QTy ||
430          e == ElemKind::UInt4FusedFP16QTy;
431 }
432 
433 /// \returns the scale and offset ElemKind used by the fused ElemKind \p e.
getScaleOffsetElemKindFromFused(ElemKind e)434 inline ElemKind getScaleOffsetElemKindFromFused(ElemKind e) {
435   assert(isFusedQuantizedElemKind(e) && "Must pass Fused ElemKind.");
436   if (e == ElemKind::UInt8FusedQTy) {
437     return ElemKind::FloatTy;
438   }
439   return ElemKind::Float16Ty;
440 }
441 
442 /// A class that represents a type of a tensor.
443 struct Type final {
444   /// Contains the dimensions (sizes) of the tensor. Ex: [sx, sy, sz, ...].
445   dim_t sizes_[max_tensor_dimensions] = {
446       0,
447   };
448   /// Contains the strides for each dimension (in elements). The order should be
449   /// the same as in sizes_. In more details, suppose that the tensor is laid
450   /// out flat in memory, and some dimensions are aligned. strides_[i] is the
451   /// number of elements that needs to be skipped in order to reach the next
452   /// plane in the i-th dimension. For example, if the tensor has dimensions
453   /// [3, 5, 10] and alignments [3, 32, 1], the strides will be [162, 32, 1].
454   dim_t strides_[max_tensor_dimensions] = {
455       0,
456   };
457 
458   /// Contains the number of dimensions used by the tensor.
459   unsigned char numSizes_{0};
460 
461   /// On quantized tensors, this represents the scale of the values.
462   float scale_{0};
463   /// On quantized tensors, this represents the offset of the values.
464   int32_t offset_{0};
465 
466   /// Specifies the element type of the tensor.
467   ElemKind elementType_{ElemKind::Int64ITy};
468 
469   /// Initialize a new quantized type with \p scale and \p offset.
Typefinal470   Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims, float scale, int32_t offset)
471       : scale_(scale), offset_(offset), elementType_(elemTy) {
472     assert(isQuantizedType() && "Only quantized types have a scale and offset");
473     ShapeVector alignments(dims.size(), 1);
474     initDims(dims, llvm::makeArrayRef(alignments));
475   }
476 
477   /// Initialize a new non-quantized type.
Typefinal478   Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims) : elementType_(elemTy) {
479     assert(!isQuantizedType() &&
480            "Can't initialize quantized types without scale and offset");
481     ShapeVector alignments(dims.size(), 1);
482     initDims(dims, llvm::makeArrayRef(alignments));
483   }
484 
485   /// Initialize a new quantized type with \p scale and \p offset.
Typefinal486   Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims,
487        llvm::ArrayRef<dim_t> alignments, float scale, int32_t offset)
488       : scale_(scale), offset_(offset), elementType_(elemTy) {
489     assert(isQuantizedType() && "Only quantized types have a scale and offset");
490     initDims(dims, alignments);
491   }
492 
493   /// Initialize a new non-quantized type.
Typefinal494   Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims,
495        llvm::ArrayRef<dim_t> alignments)
496       : elementType_(elemTy) {
497     assert(!isQuantizedType() &&
498            "Can't initialize quantized types without scale and offset");
499     initDims(dims, alignments);
500   }
501 
502   /// Reshape existing type. This method takes care of quantized types.
newShapefinal503   static Type newShape(const Type &T, llvm::ArrayRef<dim_t> dims) {
504     if (T.isQuantizedType()) {
505       return Type(T.getElementType(), dims, T.getScale(), T.getOffset());
506     } else {
507       return Type(T.getElementType(), dims);
508     }
509   }
510 
511   /// Reshape existing type and change alignments.
newShapefinal512   static Type newShape(const Type &T, llvm::ArrayRef<dim_t> dims,
513                        llvm::ArrayRef<dim_t> alignments) {
514     if (T.isQuantizedType()) {
515       return Type(T.getElementType(), dims, alignments, T.getScale(),
516                   T.getOffset());
517     } else {
518       return Type(T.getElementType(), dims, alignments);
519     }
520   }
521 
522   /// Reshape existing type by taking shapes and strides of \p shapeType.
newShapefinal523   static Type newShape(const Type &T, TypeRef shapeType) {
524     Type ty;
525     if (T.isQuantizedType()) {
526       ty = Type(T.getElementType(), shapeType->dims(), T.getScale(),
527                 T.getOffset());
528     } else {
529       ty = Type(T.getElementType(), shapeType->dims());
530     }
531     // Copy the stride information.
532     std::copy(&shapeType->strides_[0], &shapeType->strides_[ty.numSizes_],
533               ty.strides_);
534     return ty;
535   }
536 
537   /// An empty type.
538   Type() = default;
539 
540   /// \returns true if \p other is the same type.
isEqualfinal541   bool isEqual(TypeRef other) const { return isEqual(*other); }
542 
543   /// \returns the scale of a quantized type.
getScalefinal544   float getScale() const {
545     assert(isQuantizedType() && "Can't get the scale of a non-quantized type");
546     return scale_;
547   }
548 
549   /// \returns the offset of a quantized type.
getOffsetfinal550   int32_t getOffset() const {
551     assert(isQuantizedType() && "Can't get the offset of a non-quantized type");
552     return offset_;
553   }
554 
555   /// \returns the floating point value range that covers a quantized type (min
556   /// first, max second).
getQuantizedValueRangefinal557   std::pair<float, float> getQuantizedValueRange() const {
558     assert(isQuantizedType() &&
559            "Can't get the quantized value range of a non-quantized type");
560 
561     int64_t low = 0, high = 0;
562     switch (elementType_) {
563     case ElemKind::Int32QTy: {
564       low = INT32_MIN;
565       high = INT32_MAX;
566       break;
567     }
568     case ElemKind::Int16QTy: {
569       low = INT16_MIN;
570       high = INT16_MAX;
571       break;
572     }
573     case ElemKind::Int8QTy: {
574       low = INT8_MIN;
575       high = INT8_MAX;
576       break;
577     }
578     case ElemKind::UInt8QTy: {
579       low = UINT8_MIN;
580       high = UINT8_MAX;
581       break;
582     }
583     default:;
584     }
585 
586     float lowFloat = (low - offset_) * scale_;
587     float highFloat = (high - offset_) * scale_;
588     return std::make_pair(lowFloat, highFloat);
589   }
590 
591   /// \returns true if \p other is the same type. If \p allowDifferentShape then
592   /// shapes will not be considered as part of the equal comparison.
593   bool isEqual(const Type &other, bool allowDifferentShape = false,
594                bool allowDifferentStrides = false) const {
595     // Element type must be the same.
596     if (elementType_ != other.elementType_) {
597       return false;
598     }
599     // Must have the same number of sizes.
600     if (numSizes_ != other.numSizes_) {
601       return false;
602     }
603     // Sizes must be the same.
604     if (!allowDifferentShape) {
605       for (size_t i = 0; i < numSizes_; i++) {
606         if (sizes_[i] != other.sizes_[i]) {
607           return false;
608         }
609       }
610       if (!allowDifferentStrides) {
611         // Strides must be the same.
612         for (size_t i = 0; i < numSizes_; i++) {
613           if (strides_[i] != other.strides_[i]) {
614             return false;
615           }
616         }
617       }
618     }
619 
620     // Compare the scale and offset of integers.
621     if (isQuantizedType()) {
622       if (scale_ != other.scale_ || offset_ != other.offset_) {
623         return false;
624       }
625     }
626 
627     return true;
628   }
629 
630   /// \returns a hash value for this Type. Hashes for Ty1 and Ty2 are equal if
631   /// Ty1.isEqual(Ty2).
equals_hashfinal632   llvm::hash_code equals_hash() const {
633     return llvm::hash_combine(
634         elementType_, dims(),
635         // hashing floats is tricky, fall back to std::hash
636         std::hash<float>{}(scale_), offset_);
637   }
638 
getElementTypefinal639   ElemKind getElementType() const { return elementType_; }
640 
641   /// \returns the shape of the tensor.
dimsfinal642   llvm::ArrayRef<dim_t> dims() const { return {sizes_, numSizes_}; }
643 
644   /// \returns the strides of the tensor.
stridesfinal645   llvm::ArrayRef<dim_t> strides() const { return {strides_, numSizes_}; }
646 
647   /// \returns the number of elements in the tensor.
sizefinal648   dim_t size() const {
649     dim_t s = 1;
650     for (unsigned char i = 0; i < numSizes_; i++) {
651       s *= dim_t(sizes_[i]);
652     }
653 
654     return s;
655   }
656 
657   /// \returns the number of elements in a slice in the tensor. Calculate the
658   /// size of the slice starting at \p startDim. For example, the tensor with
659   /// the shape [10, 10, 3] and startDim 1 would have the size 30, because this
660   /// is the size of the slice [10, 3] that starts at index 1.
getSliceSizefinal661   dim_t getSliceSize(unsigned char startDim) const {
662     assert(startDim <= numSizes_ && "Invalid start dim");
663     dim_t s = 1;
664     for (unsigned char i = startDim; i < numSizes_; i++) {
665       s *= dim_t(sizes_[i]);
666     }
667     return s;
668   }
669 
670   /// \returns true if the templated parameter \p ElemTy matches this type.
isTypefinal671   template <class ElemTy> bool isType() const {
672     return isType<ElemTy>(elementType_);
673   }
674 
675   /// \returns true if the templated parameter \p ElemTy matches the type that's
676   /// specified by the parameter \p Ty.
isTypefinal677   template <class ElemTy> static bool isType(ElemKind Ty) {
678     switch (Ty) {
679     case ElemKind::FloatTy:
680       return std::is_same<ElemTy, float>::value;
681     case ElemKind::Float16Ty:
682       return std::is_same<ElemTy, float16_t>::value;
683     case ElemKind::BFloat16Ty:
684       return std::is_same<ElemTy, bfloat16_t>::value;
685     case ElemKind::Int8QTy:
686       return std::is_same<ElemTy, int8_t>::value;
687     case ElemKind::UInt8QTy:
688       return std::is_same<ElemTy, uint8_t>::value;
689     case ElemKind::Int16QTy:
690       return std::is_same<ElemTy, int16_t>::value;
691     case ElemKind::Int32QTy:
692       return std::is_same<ElemTy, int32_t>::value;
693     case ElemKind::Int32ITy:
694       return std::is_same<ElemTy, int32_t>::value;
695     case ElemKind::Int64ITy:
696       return std::is_same<ElemTy, int64_t>::value;
697     case ElemKind::UInt8FusedQTy:
698       return std::is_same<ElemTy, uint8_t>::value;
699     case ElemKind::UInt8FusedFP16QTy:
700       return std::is_same<ElemTy, uint8_t>::value;
701     case ElemKind::UInt4FusedFP16QTy:
702       return std::is_same<ElemTy, uint8_t>::value;
703     case ElemKind::BoolTy:
704       return std::is_same<ElemTy, bool>::value;
705     }
706     LOG(FATAL) << "Invalid type: " << getElementName(Ty).str();
707   }
708 
709   /// \returns true if the type of this Tensor is one of the quantized types.
isQuantizedTypefinal710   bool isQuantizedType() const { return isQuantizedElemKind(elementType_); }
711 
712   /// \returns true if the type of this Tensor is one of the floating point
713   /// types.
isFPTypefinal714   bool isFPType() const { return isFloatElemKind(getElementType()); }
715 
716   /// \return the size of the type element.
getElementSizefinal717   unsigned getElementSize() const { return getElementSize(elementType_); }
718 
719   /// \returns the size in bytes for this Tensor.
getSizeInBytesfinal720   size_t getSizeInBytes() const {
721     size_t s = getElementSize();
722     for (unsigned char i = 0; i < numSizes_; i++) {
723       // If any dimensions are 0 then the entire size is 0, so early return.
724       if (sizes_[i] == 0) {
725         return 0;
726       }
727       s = std::max<dim_t>(s,
728                           size_t(sizes_[i]) * getElementSize() * strides_[i]);
729     }
730     return s;
731   }
732 
733   /// \returns the actual number of elements in the tensor taking striding into
734   /// account. Since size() does not take striding into account, size() is
735   /// always <= actualSize().
actualSizefinal736   size_t actualSize() const { return getSizeInBytes() / getElementSize(); }
737 
738   /// \return the size of the element \p Ty.
getElementSizefinal739   static unsigned getElementSize(ElemKind Ty) {
740     switch (Ty) {
741     case ElemKind::FloatTy:
742       return sizeof(float);
743     case ElemKind::Float16Ty:
744       return sizeof(float16_t);
745     case ElemKind::BFloat16Ty:
746       return sizeof(bfloat16_t);
747     case ElemKind::Int8QTy:
748       return sizeof(int8_t);
749     case ElemKind::UInt8QTy:
750       return sizeof(uint8_t);
751     case ElemKind::Int16QTy:
752       return sizeof(int16_t);
753     case ElemKind::Int32QTy:
754       return sizeof(int32_t);
755     case ElemKind::Int32ITy:
756       return sizeof(int32_t);
757     case ElemKind::Int64ITy:
758       return sizeof(int64_t);
759     case ElemKind::UInt8FusedQTy:
760       return sizeof(uint8_t);
761     case ElemKind::UInt8FusedFP16QTy:
762       return sizeof(uint8_t);
763     case ElemKind::UInt4FusedFP16QTy:
764       return sizeof(uint8_t);
765     case ElemKind::BoolTy:
766       return sizeof(bool);
767     }
768     LOG(FATAL) << "Invalid type: " << getElementName(Ty).str();
769   }
770 
771   /// \return the textual name of the element.
getElementNamefinal772   llvm::StringRef getElementName() const {
773     return getElementName(elementType_);
774   }
775 
776   /// \return the textual name of the element \p Ty.
getElementNamefinal777   static llvm::StringRef getElementName(ElemKind Ty) {
778     static const char *names[] = {
779         "float",        "float16",      "bfloat16", "i8",      "ui8",
780         "i16",          "i32",          "index32",  "index64", "ui8fused",
781         "ui8fusedfp16", "ui4fusedfp16", "bool",
782     };
783     return names[(int)Ty];
784   }
785 
786   /// Given a string \p str containing the name of an ElemKind from
787   /// Type::getElementName, returns the corresponding ElemKind or Error if a
788   /// mapping couldn't be found.
getElementKindFromNamefinal789   static ElemKind getElementKindFromName(llvm::StringRef str) {
790     if (str == Type::getElementName(ElemKind::FloatTy)) {
791       return ElemKind::FloatTy;
792     } else if (str == Type::getElementName(ElemKind::Float16Ty)) {
793       return ElemKind::Float16Ty;
794     } else if (str == Type::getElementName(ElemKind::BFloat16Ty)) {
795       return ElemKind::BFloat16Ty;
796     } else if (str == Type::getElementName(ElemKind::Int8QTy)) {
797       return ElemKind::Int8QTy;
798     } else if (str == Type::getElementName(ElemKind::UInt8QTy)) {
799       return ElemKind::UInt8QTy;
800     } else if (str == Type::getElementName(ElemKind::Int16QTy)) {
801       return ElemKind::Int16QTy;
802     } else if (str == Type::getElementName(ElemKind::Int32QTy)) {
803       return ElemKind::Int32QTy;
804     } else if (str == Type::getElementName(ElemKind::Int32ITy)) {
805       return ElemKind::Int32ITy;
806     } else if (str == Type::getElementName(ElemKind::Int64ITy)) {
807       return ElemKind::Int64ITy;
808     } else if (str == Type::getElementName(ElemKind::UInt8FusedQTy)) {
809       return ElemKind::UInt8FusedQTy;
810     } else if (str == Type::getElementName(ElemKind::UInt8FusedFP16QTy)) {
811       return ElemKind::UInt8FusedFP16QTy;
812     } else if (str == Type::getElementName(ElemKind::UInt4FusedFP16QTy)) {
813       return ElemKind::UInt4FusedFP16QTy;
814     } else if (str == Type::getElementName(ElemKind::BoolTy)) {
815       return ElemKind::BoolTy;
816     } else {
817       LOG(DFATAL) << "Invalid ElemKind string: " << str.str();
818       return ElemKind::FloatTy;
819     }
820   }
821 
822   /// Dump a textual representation of the Type into provided output stream.
823   void dump(llvm::raw_ostream &out) const;
824 
825   /// Dump a textual representation of the Type into default output stream.
826   void dump() const;
827 
828   /// Dump a textual representation of the Type to std::string.
829   std::string toString() const;
830 
831   /// Load a Type object from a textual representation \p str. This method is
832   /// paired and should be used together with \ref toString.
833   static Type fromString(llvm::StringRef str);
834 
835 private:
836   /// Setup the internals of type that store the dimensions. This method is
837   /// used by the constructor.
838   /// \param dims of the tensor (in elements).
839   /// \param alignments of the tensor (in bytes).
initDimsfinal840   void initDims(llvm::ArrayRef<dim_t> dims, llvm::ArrayRef<dim_t> alignments) {
841     assert(dims.size() <= max_tensor_dimensions && "Too many dimensions.");
842     assert(dims.size() == alignments.size() &&
843            "The number of dimensions and alignments should be the same");
844     // Update the tensor strides and sizes based on given dims and alignments.
845     // Sizes are simply assigned to dims. And strides are computed as partial
846     // product of dims, making sure that each dimension is aligned as required.
847     numSizes_ = dims.size();
848     if (numSizes_ > 0) {
849       // Stride of the last dimension is always 1.
850       assert(alignments[numSizes_ - 1] == 1 &&
851              "Last dimension must always be aligned.");
852       strides_[numSizes_ - 1] = 1;
853       sizes_[numSizes_ - 1] = dims[numSizes_ - 1];
854     }
855     for (int i = numSizes_ - 2; i >= 0; i--) {
856       dim_t alignment = alignments[i];
857       if (alignment != 1) {
858         assert(alignment % getElementSize() == 0 &&
859                "Alignment should be a multiple of element size");
860         alignment /= getElementSize();
861       }
862       // All the strides (except for last one) depend on the previous dimension.
863       strides_[i] = alignedSize(dims[i + 1] * strides_[i + 1], alignment);
864       sizes_[i] = dims[i];
865     }
866   }
867 
initDimsfinal868   void initDims(llvm::ArrayRef<dim_t> dims) {
869     assert(dims.size() <= max_tensor_dimensions && "Too many dimensions.");
870     // Update the tensor sizes.
871     for (size_t i = 0, e = dims.size(); i < e; i++) {
872       sizes_[i] = dims[i];
873     }
874     numSizes_ = dims.size();
875   }
876 };
877 
878 inline bool operator==(const Type &LHS, const Type &RHS) {
879   return LHS.isEqual(RHS);
880 }
881 
882 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Type &type);
883 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const TypeRef &type);
884 
885 } // namespace glow
886 
887 #endif // GLOW_BASE_TYPE_H
888