1 /**
2 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16 #ifndef GLOW_BASE_TYPE_H
17 #define GLOW_BASE_TYPE_H
18
#include "DimType.h"

#include "glow/Support/BFloat16.h"
#include "glow/Support/Compiler.h"
#include "glow/Support/Float16.h"
#include "glow/Support/Memory.h"

#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Hashing.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"

#include <glog/logging.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <functional>
#include <string>
#include <type_traits>
#include <utility>
35
36 namespace llvm {
37 class raw_ostream;
38 }
39
40 namespace glow {
41
// UINT8_MIN is not defined in standard headers (unlike INT8_MIN / UINT8_MAX).
// Define it here so the unsigned 8-bit range can be spelled consistently with
// the signed ranges. The guard avoids a conflicting macro redefinition when a
// platform header or another library already provides it. Note that, being a
// macro, it is not scoped to namespace glow.
#ifndef UINT8_MIN
#define UINT8_MIN 0
#endif
45
46 struct Type;
47
48 using TypeRef = const Type *;
49
50 constexpr unsigned max_tensor_dimensions = 6;
51
52 /// This type is used to implement the Node and Instruction builder's
53 /// MemberType::Unsigned and MemberType::VectorUnsigned. Thus it should be used
54 /// when handling members of these classes, e.g. a convolution Node/Instr's
55 /// getGroup() (Unsigned), or getKernels() (UnsignedVector).
56 using unsigned_t = uint32_t;
57
58 using float16_t = float16;
59 static_assert(sizeof(float16_t) == 2, "Half precision should be 16-bit");
60
61 using bfloat16_t = bfloat16;
62 static_assert(sizeof(bfloat16_t) == 2, "bfloat16 should be 16-bit");
63
64 using ShapeVector = llvm::SmallVector<dim_t, max_tensor_dimensions>;
65
66 struct ShapeNHWC {
67
68 enum {
69 DimN,
70 DimH,
71 DimW,
72 DimC,
73 };
74
75 dim_t n; // Number of samples
76 dim_t h; // Height
77 dim_t w; // Width
78 dim_t c; // Number of Channels
79
ShapeNHWCShapeNHWC80 template <typename T> explicit ShapeNHWC(llvm::ArrayRef<T> shape) {
81 assert(shape.size() == 4 && "Invalid shape");
82 n = shape[DimN];
83 h = shape[DimH];
84 w = shape[DimW];
85 c = shape[DimC];
86 }
87
ShapeNHWCShapeNHWC88 ShapeNHWC(dim_t samples, dim_t height, dim_t width, dim_t channels)
89 : n(samples), h(height), w(width), c(channels) {}
90
equalsShapeNHWC91 bool equals(const ShapeNHWC &other) const {
92 return n == other.n && h == other.h && w == other.w && c == other.c;
93 }
94 };
95
96 struct ShapeNTHWC {
97 dim_t n; // Number of samples
98 dim_t t; // Temporal frames
99 dim_t h; // Height
100 dim_t w; // Width
101 dim_t c; // Number of Channels
102
ShapeNTHWCShapeNTHWC103 template <typename T> explicit ShapeNTHWC(llvm::ArrayRef<T> shape) {
104 assert(shape.size() == 5 && "Invalid shape");
105 n = shape[0];
106 t = shape[1];
107 h = shape[2];
108 w = shape[3];
109 c = shape[4];
110 }
111
ShapeNTHWCShapeNTHWC112 ShapeNTHWC(dim_t samples, dim_t temporal_frames, dim_t height, dim_t width,
113 dim_t channels)
114 : n(samples), t(temporal_frames), h(height), w(width), c(channels) {}
115
equalsShapeNTHWC116 bool equals(const ShapeNTHWC &other) const {
117 return n == other.n && t == other.t && h == other.h && w == other.w &&
118 c == other.c;
119 }
120 };
121
122 struct ShapeNHWTC {
123 dim_t n; // Number of samples
124 dim_t h; // Height
125 dim_t w; // Width
126 dim_t t; // Temporal_frames
127 dim_t c; // Number of Channels
128
ShapeNHWTCShapeNHWTC129 template <typename T> explicit ShapeNHWTC(llvm::ArrayRef<T> shape) {
130 assert(shape.size() == 5 && "Invalid shape");
131 n = shape[0];
132 h = shape[1];
133 w = shape[2];
134 t = shape[3];
135 c = shape[4];
136 }
137
ShapeNHWTCShapeNHWTC138 ShapeNHWTC(size_t samples, size_t height, size_t width,
139 size_t temporal_frames, size_t channels)
140 : n(samples), h(height), w(width), t(temporal_frames), c(channels) {}
141
equalsShapeNHWTC142 bool equals(const ShapeNHWTC &other) const {
143 return n == other.n && h == other.h && w == other.w && t == other.t &&
144 c == other.c;
145 }
146 };
147
148 struct ShapeNCHW {
149
150 enum {
151 DimN,
152 DimC,
153 DimH,
154 DimW,
155 };
156
157 dim_t n; // Number of samples
158 dim_t c; // Number of Channels
159 dim_t h; // Height
160 dim_t w; // Width
161
ShapeNCHWShapeNCHW162 explicit ShapeNCHW(llvm::ArrayRef<dim_t> shape) {
163 assert(shape.size() == 4 && "Invalid shape");
164 n = shape[DimN];
165 c = shape[DimC];
166 h = shape[DimH];
167 w = shape[DimW];
168 }
169
ShapeNCHWShapeNCHW170 ShapeNCHW(dim_t samples, dim_t channels, dim_t height, dim_t width)
171 : n(samples), c(channels), h(height), w(width) {}
172
equalsShapeNCHW173 bool equals(const ShapeNCHW &other) const {
174 return n == other.n && h == other.h && w == other.w && c == other.c;
175 }
176 };
177
178 struct ShapeNCTHW {
179 dim_t n; // Number of samples
180 dim_t c; // Number of Channels
181 dim_t t; // Temporal frames
182 dim_t h; // Height
183 dim_t w; // Width
184
ShapeNCTHWShapeNCTHW185 explicit ShapeNCTHW(llvm::ArrayRef<dim_t> shape) {
186 assert(shape.size() == 5 && "Invalid shape");
187 n = shape[0];
188 c = shape[1];
189 t = shape[2];
190 h = shape[3];
191 w = shape[4];
192 }
193
ShapeNCTHWShapeNCTHW194 ShapeNCTHW(dim_t samples, dim_t channels, dim_t temporal_frames, dim_t height,
195 dim_t width)
196 : n(samples), c(channels), t(temporal_frames), h(height), w(width) {}
197
equalsShapeNCTHW198 bool equals(const ShapeNCTHW &other) const {
199 return n == other.n && t == other.t && h == other.h && w == other.w &&
200 c == other.c;
201 }
202 };
203
204 struct PaddingTLBR {
205 dim_t top;
206 dim_t left;
207 dim_t bottom;
208 dim_t right;
209
PaddingTLBRPaddingTLBR210 template <typename T> explicit PaddingTLBR(llvm::ArrayRef<T> pads) {
211 assert(pads.size() == 4 && "Invalid padding");
212 top = pads[0];
213 left = pads[1];
214 bottom = pads[2];
215 right = pads[3];
216 }
217
equalPaddingPaddingTLBR218 bool equalPadding() const {
219 return top == left && top == bottom && top == right;
220 }
221 };
222
223 struct PaddingTLNBRF {
224 dim_t top;
225 dim_t left;
226 dim_t near;
227 dim_t bottom;
228 dim_t right;
229 dim_t far;
230
PaddingTLNBRFPaddingTLNBRF231 template <typename T> explicit PaddingTLNBRF(llvm::ArrayRef<T> pads) {
232 assert(pads.size() == 6 && "Invalid padding");
233 top = pads[0];
234 left = pads[1];
235 near = pads[2];
236 bottom = pads[3];
237 right = pads[4];
238 far = pads[5];
239 }
240
equalPaddingPaddingTLNBRF241 bool equalPadding() const {
242 return top == left && top == bottom && top == right && top == near &&
243 top == far;
244 }
245 };
246
247 struct PaddingNFTBLR {
248 dim_t near;
249 dim_t far;
250 dim_t top;
251 dim_t bottom;
252 dim_t left;
253 dim_t right;
254
PaddingNFTBLRPaddingNFTBLR255 template <typename T> explicit PaddingNFTBLR(llvm::ArrayRef<T> pads) {
256 assert(pads.size() == 6 && "Invalid padding");
257 near = pads[0];
258 far = pads[1];
259 top = pads[2];
260 bottom = pads[3];
261 left = pads[4];
262 right = pads[5];
263 }
264
equalPaddingPaddingNFTBLR265 bool equalPadding() const {
266 return top == left && top == bottom && top == right && top == near &&
267 top == far;
268 }
269 };
270
271 struct ShapeHW {
272
273 enum {
274 DimH,
275 DimW,
276 };
277
278 dim_t height;
279 dim_t width;
280
ShapeHWShapeHW281 template <typename T> explicit ShapeHW(llvm::ArrayRef<T> shape) {
282 assert(shape.size() == 2 && "Invalid shape");
283 height = shape[DimH];
284 width = shape[DimW];
285 }
286
isSquareShapeHW287 bool isSquare() const { return height == width; }
288 };
289
290 struct ShapeNHW {
291
292 enum {
293 DimN,
294 DimH,
295 DimW,
296 };
297
298 dim_t n; // Number of samples
299 dim_t h; // Height
300 dim_t w; // Width
301
ShapeNHWShapeNHW302 template <typename T> explicit ShapeNHW(llvm::ArrayRef<T> shape) {
303 assert(shape.size() == 3 && "Invalid shape");
304 n = shape[DimN];
305 h = shape[DimH];
306 w = shape[DimW];
307 }
308
isSquareShapeNHW309 bool isSquare() const { return h == w; }
310 };
311
312 struct ShapeHWT {
313 dim_t height;
314 dim_t width;
315 dim_t temporal_frames;
316
ShapeHWTShapeHWT317 template <typename T> explicit ShapeHWT(llvm::ArrayRef<T> shape) {
318 assert(shape.size() == 3 && "Invalid shape");
319 height = shape[0];
320 width = shape[1];
321 temporal_frames = shape[2];
322 }
323
isCubeShapeHWT324 bool isCube() const { return height == width && height == temporal_frames; }
325 };
326
327 struct ShapeTHW {
328 dim_t temporal_frames;
329 dim_t height;
330 dim_t width;
331
ShapeTHWShapeTHW332 template <typename T> explicit ShapeTHW(llvm::ArrayRef<T> shape) {
333 assert(shape.size() == 3 && "Invalid shape");
334 temporal_frames = shape[0];
335 height = shape[1];
336 width = shape[2];
337 }
338
isCubeShapeTHW339 bool isCube() const { return height == width && height == temporal_frames; }
340 };
341
342 /// Collapse a tensor shape into two sizes: the first n dimensions and the size
343 /// of the rest of the dimensions. For example, ([7, 3, 4, 2], 1) -> [7, 24]
344 inline std::pair<dim_t, dim_t> flattenCdr(llvm::ArrayRef<dim_t> dims,
345 unsigned_t n = 1) {
346 assert(1 <= n && n <= dims.size());
347 size_t first = dims[0];
348 for (unsigned_t i = 1; i < n; i++) {
349 first *= dims[i];
350 }
351 size_t rest = 1;
352 for (unsigned_t i = n; i < dims.size(); i++) {
353 rest *= dims[i];
354 }
355
356 return {first, rest};
357 }
358
359 inline bool operator==(const ShapeNHWC &LHS, const ShapeNHWC &RHS) {
360 return LHS.equals(RHS);
361 }
362
363 inline bool operator==(const ShapeNCHW &LHS, const ShapeNCHW &RHS) {
364 return LHS.equals(RHS);
365 }
366
367 inline bool operator==(const ShapeNHWTC &LHS, const ShapeNHWTC &RHS) {
368 return LHS.equals(RHS);
369 }
370
371 inline bool operator==(const ShapeNTHWC &LHS, const ShapeNTHWC &RHS) {
372 return LHS.equals(RHS);
373 }
374
375 inline bool operator==(const ShapeNCTHW &LHS, const ShapeNCTHW &RHS) {
376 return LHS.equals(RHS);
377 }
378
/// The element kinds a tensor can hold. The types of Handles over a tensor
/// must agree with that tensor's element kind.
/// NOTE: these values must stay in sync with the ElemKind definition in
/// Glow/lib/Backends/CPU/libjit/libjit.cpp whenever a kind is added.
enum class ElemKind : unsigned char {
  FloatTy = 0,            ///< 32-bit float (float).
  Float16Ty = 1,          ///< 16-bit float (half, fp16).
  BFloat16Ty = 2,         ///< 16-bit float (bfloat16).
  Int8QTy = 3,            ///< 8-bit quantized (int8_t).
  UInt8QTy = 4,           ///< Unsigned 8-bit quantized (uint8_t).
  Int16QTy = 5,           ///< 16-bit quantized (int16_t).
  Int32QTy = 6,           ///< 32-bit quantized (int32_t).
  Int32ITy = 7,           ///< 32-bit index (int32_t).
  Int64ITy = 8,           ///< 64-bit index (int64_t).
  UInt8FusedQTy = 9,      ///< 8-bit quantized, fused float scale/offset
                          ///< (uint8_t).
  UInt8FusedFP16QTy = 10, ///< 8-bit quantized, fused FP16 scale/offset
                          ///< (uint8_t).
  UInt4FusedFP16QTy = 11, ///< 4-bit quantized, fused FP16 scale/offset; each
                          ///< uint8_t byte packs two 4-bit values.
  BoolTy = 12,            ///< Boolean (bool).
};

/// \returns whether \p e is one of the row-wise fused quantized kinds, i.e.
/// kinds that carry their scale/offset inline with the data.
inline bool isFusedQuantizedElemKind(ElemKind e) {
  switch (e) {
  case ElemKind::UInt8FusedQTy:
  case ElemKind::UInt8FusedFP16QTy:
  case ElemKind::UInt4FusedFP16QTy:
    return true;
  default:
    return false;
  }
}

/// \returns whether \p e is any quantized kind, fused or not.
inline bool isQuantizedElemKind(ElemKind e) {
  switch (e) {
  case ElemKind::Int8QTy:
  case ElemKind::UInt8QTy:
  case ElemKind::Int16QTy:
  case ElemKind::Int32QTy:
    return true;
  default:
    return isFusedQuantizedElemKind(e);
  }
}

/// \returns whether \p e is a floating point kind (fp32, fp16 or bfloat16).
inline bool isFloatElemKind(ElemKind e) {
  switch (e) {
  case ElemKind::FloatTy:
  case ElemKind::Float16Ty:
  case ElemKind::BFloat16Ty:
    return true;
  default:
    return false;
  }
}

/// \returns the kind used to store the scale and offset of the fused kind
/// \p e: FloatTy for UInt8FusedQTy, Float16Ty for the FP16 fused kinds.
inline ElemKind getScaleOffsetElemKindFromFused(ElemKind e) {
  assert(isFusedQuantizedElemKind(e) && "Must pass Fused ElemKind.");
  return e == ElemKind::UInt8FusedQTy ? ElemKind::FloatTy
                                      : ElemKind::Float16Ty;
}
441
442 /// A class that represents a type of a tensor.
443 struct Type final {
444 /// Contains the dimensions (sizes) of the tensor. Ex: [sx, sy, sz, ...].
445 dim_t sizes_[max_tensor_dimensions] = {
446 0,
447 };
448 /// Contains the strides for each dimension (in elements). The order should be
449 /// the same as in sizes_. In more details, suppose that the tensor is laid
450 /// out flat in memory, and some dimensions are aligned. strides_[i] is the
451 /// number of elements that needs to be skipped in order to reach the next
452 /// plane in the i-th dimension. For example, if the tensor has dimensions
453 /// [3, 5, 10] and alignments [3, 32, 1], the strides will be [162, 32, 1].
454 dim_t strides_[max_tensor_dimensions] = {
455 0,
456 };
457
458 /// Contains the number of dimensions used by the tensor.
459 unsigned char numSizes_{0};
460
461 /// On quantized tensors, this represents the scale of the values.
462 float scale_{0};
463 /// On quantized tensors, this represents the offset of the values.
464 int32_t offset_{0};
465
466 /// Specifies the element type of the tensor.
467 ElemKind elementType_{ElemKind::Int64ITy};
468
469 /// Initialize a new quantized type with \p scale and \p offset.
Typefinal470 Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims, float scale, int32_t offset)
471 : scale_(scale), offset_(offset), elementType_(elemTy) {
472 assert(isQuantizedType() && "Only quantized types have a scale and offset");
473 ShapeVector alignments(dims.size(), 1);
474 initDims(dims, llvm::makeArrayRef(alignments));
475 }
476
477 /// Initialize a new non-quantized type.
Typefinal478 Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims) : elementType_(elemTy) {
479 assert(!isQuantizedType() &&
480 "Can't initialize quantized types without scale and offset");
481 ShapeVector alignments(dims.size(), 1);
482 initDims(dims, llvm::makeArrayRef(alignments));
483 }
484
485 /// Initialize a new quantized type with \p scale and \p offset.
Typefinal486 Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims,
487 llvm::ArrayRef<dim_t> alignments, float scale, int32_t offset)
488 : scale_(scale), offset_(offset), elementType_(elemTy) {
489 assert(isQuantizedType() && "Only quantized types have a scale and offset");
490 initDims(dims, alignments);
491 }
492
493 /// Initialize a new non-quantized type.
Typefinal494 Type(ElemKind elemTy, llvm::ArrayRef<dim_t> dims,
495 llvm::ArrayRef<dim_t> alignments)
496 : elementType_(elemTy) {
497 assert(!isQuantizedType() &&
498 "Can't initialize quantized types without scale and offset");
499 initDims(dims, alignments);
500 }
501
502 /// Reshape existing type. This method takes care of quantized types.
newShapefinal503 static Type newShape(const Type &T, llvm::ArrayRef<dim_t> dims) {
504 if (T.isQuantizedType()) {
505 return Type(T.getElementType(), dims, T.getScale(), T.getOffset());
506 } else {
507 return Type(T.getElementType(), dims);
508 }
509 }
510
511 /// Reshape existing type and change alignments.
newShapefinal512 static Type newShape(const Type &T, llvm::ArrayRef<dim_t> dims,
513 llvm::ArrayRef<dim_t> alignments) {
514 if (T.isQuantizedType()) {
515 return Type(T.getElementType(), dims, alignments, T.getScale(),
516 T.getOffset());
517 } else {
518 return Type(T.getElementType(), dims, alignments);
519 }
520 }
521
522 /// Reshape existing type by taking shapes and strides of \p shapeType.
newShapefinal523 static Type newShape(const Type &T, TypeRef shapeType) {
524 Type ty;
525 if (T.isQuantizedType()) {
526 ty = Type(T.getElementType(), shapeType->dims(), T.getScale(),
527 T.getOffset());
528 } else {
529 ty = Type(T.getElementType(), shapeType->dims());
530 }
531 // Copy the stride information.
532 std::copy(&shapeType->strides_[0], &shapeType->strides_[ty.numSizes_],
533 ty.strides_);
534 return ty;
535 }
536
537 /// An empty type.
538 Type() = default;
539
540 /// \returns true if \p other is the same type.
isEqualfinal541 bool isEqual(TypeRef other) const { return isEqual(*other); }
542
543 /// \returns the scale of a quantized type.
getScalefinal544 float getScale() const {
545 assert(isQuantizedType() && "Can't get the scale of a non-quantized type");
546 return scale_;
547 }
548
549 /// \returns the offset of a quantized type.
getOffsetfinal550 int32_t getOffset() const {
551 assert(isQuantizedType() && "Can't get the offset of a non-quantized type");
552 return offset_;
553 }
554
555 /// \returns the floating point value range that covers a quantized type (min
556 /// first, max second).
getQuantizedValueRangefinal557 std::pair<float, float> getQuantizedValueRange() const {
558 assert(isQuantizedType() &&
559 "Can't get the quantized value range of a non-quantized type");
560
561 int64_t low = 0, high = 0;
562 switch (elementType_) {
563 case ElemKind::Int32QTy: {
564 low = INT32_MIN;
565 high = INT32_MAX;
566 break;
567 }
568 case ElemKind::Int16QTy: {
569 low = INT16_MIN;
570 high = INT16_MAX;
571 break;
572 }
573 case ElemKind::Int8QTy: {
574 low = INT8_MIN;
575 high = INT8_MAX;
576 break;
577 }
578 case ElemKind::UInt8QTy: {
579 low = UINT8_MIN;
580 high = UINT8_MAX;
581 break;
582 }
583 default:;
584 }
585
586 float lowFloat = (low - offset_) * scale_;
587 float highFloat = (high - offset_) * scale_;
588 return std::make_pair(lowFloat, highFloat);
589 }
590
591 /// \returns true if \p other is the same type. If \p allowDifferentShape then
592 /// shapes will not be considered as part of the equal comparison.
593 bool isEqual(const Type &other, bool allowDifferentShape = false,
594 bool allowDifferentStrides = false) const {
595 // Element type must be the same.
596 if (elementType_ != other.elementType_) {
597 return false;
598 }
599 // Must have the same number of sizes.
600 if (numSizes_ != other.numSizes_) {
601 return false;
602 }
603 // Sizes must be the same.
604 if (!allowDifferentShape) {
605 for (size_t i = 0; i < numSizes_; i++) {
606 if (sizes_[i] != other.sizes_[i]) {
607 return false;
608 }
609 }
610 if (!allowDifferentStrides) {
611 // Strides must be the same.
612 for (size_t i = 0; i < numSizes_; i++) {
613 if (strides_[i] != other.strides_[i]) {
614 return false;
615 }
616 }
617 }
618 }
619
620 // Compare the scale and offset of integers.
621 if (isQuantizedType()) {
622 if (scale_ != other.scale_ || offset_ != other.offset_) {
623 return false;
624 }
625 }
626
627 return true;
628 }
629
630 /// \returns a hash value for this Type. Hashes for Ty1 and Ty2 are equal if
631 /// Ty1.isEqual(Ty2).
equals_hashfinal632 llvm::hash_code equals_hash() const {
633 return llvm::hash_combine(
634 elementType_, dims(),
635 // hashing floats is tricky, fall back to std::hash
636 std::hash<float>{}(scale_), offset_);
637 }
638
getElementTypefinal639 ElemKind getElementType() const { return elementType_; }
640
641 /// \returns the shape of the tensor.
dimsfinal642 llvm::ArrayRef<dim_t> dims() const { return {sizes_, numSizes_}; }
643
644 /// \returns the strides of the tensor.
stridesfinal645 llvm::ArrayRef<dim_t> strides() const { return {strides_, numSizes_}; }
646
647 /// \returns the number of elements in the tensor.
sizefinal648 dim_t size() const {
649 dim_t s = 1;
650 for (unsigned char i = 0; i < numSizes_; i++) {
651 s *= dim_t(sizes_[i]);
652 }
653
654 return s;
655 }
656
657 /// \returns the number of elements in a slice in the tensor. Calculate the
658 /// size of the slice starting at \p startDim. For example, the tensor with
659 /// the shape [10, 10, 3] and startDim 1 would have the size 30, because this
660 /// is the size of the slice [10, 3] that starts at index 1.
getSliceSizefinal661 dim_t getSliceSize(unsigned char startDim) const {
662 assert(startDim <= numSizes_ && "Invalid start dim");
663 dim_t s = 1;
664 for (unsigned char i = startDim; i < numSizes_; i++) {
665 s *= dim_t(sizes_[i]);
666 }
667 return s;
668 }
669
670 /// \returns true if the templated parameter \p ElemTy matches this type.
isTypefinal671 template <class ElemTy> bool isType() const {
672 return isType<ElemTy>(elementType_);
673 }
674
675 /// \returns true if the templated parameter \p ElemTy matches the type that's
676 /// specified by the parameter \p Ty.
isTypefinal677 template <class ElemTy> static bool isType(ElemKind Ty) {
678 switch (Ty) {
679 case ElemKind::FloatTy:
680 return std::is_same<ElemTy, float>::value;
681 case ElemKind::Float16Ty:
682 return std::is_same<ElemTy, float16_t>::value;
683 case ElemKind::BFloat16Ty:
684 return std::is_same<ElemTy, bfloat16_t>::value;
685 case ElemKind::Int8QTy:
686 return std::is_same<ElemTy, int8_t>::value;
687 case ElemKind::UInt8QTy:
688 return std::is_same<ElemTy, uint8_t>::value;
689 case ElemKind::Int16QTy:
690 return std::is_same<ElemTy, int16_t>::value;
691 case ElemKind::Int32QTy:
692 return std::is_same<ElemTy, int32_t>::value;
693 case ElemKind::Int32ITy:
694 return std::is_same<ElemTy, int32_t>::value;
695 case ElemKind::Int64ITy:
696 return std::is_same<ElemTy, int64_t>::value;
697 case ElemKind::UInt8FusedQTy:
698 return std::is_same<ElemTy, uint8_t>::value;
699 case ElemKind::UInt8FusedFP16QTy:
700 return std::is_same<ElemTy, uint8_t>::value;
701 case ElemKind::UInt4FusedFP16QTy:
702 return std::is_same<ElemTy, uint8_t>::value;
703 case ElemKind::BoolTy:
704 return std::is_same<ElemTy, bool>::value;
705 }
706 LOG(FATAL) << "Invalid type: " << getElementName(Ty).str();
707 }
708
709 /// \returns true if the type of this Tensor is one of the quantized types.
isQuantizedTypefinal710 bool isQuantizedType() const { return isQuantizedElemKind(elementType_); }
711
712 /// \returns true if the type of this Tensor is one of the floating point
713 /// types.
isFPTypefinal714 bool isFPType() const { return isFloatElemKind(getElementType()); }
715
716 /// \return the size of the type element.
getElementSizefinal717 unsigned getElementSize() const { return getElementSize(elementType_); }
718
719 /// \returns the size in bytes for this Tensor.
getSizeInBytesfinal720 size_t getSizeInBytes() const {
721 size_t s = getElementSize();
722 for (unsigned char i = 0; i < numSizes_; i++) {
723 // If any dimensions are 0 then the entire size is 0, so early return.
724 if (sizes_[i] == 0) {
725 return 0;
726 }
727 s = std::max<dim_t>(s,
728 size_t(sizes_[i]) * getElementSize() * strides_[i]);
729 }
730 return s;
731 }
732
733 /// \returns the actual number of elements in the tensor taking striding into
734 /// account. Since size() does not take striding into account, size() is
735 /// always <= actualSize().
actualSizefinal736 size_t actualSize() const { return getSizeInBytes() / getElementSize(); }
737
738 /// \return the size of the element \p Ty.
getElementSizefinal739 static unsigned getElementSize(ElemKind Ty) {
740 switch (Ty) {
741 case ElemKind::FloatTy:
742 return sizeof(float);
743 case ElemKind::Float16Ty:
744 return sizeof(float16_t);
745 case ElemKind::BFloat16Ty:
746 return sizeof(bfloat16_t);
747 case ElemKind::Int8QTy:
748 return sizeof(int8_t);
749 case ElemKind::UInt8QTy:
750 return sizeof(uint8_t);
751 case ElemKind::Int16QTy:
752 return sizeof(int16_t);
753 case ElemKind::Int32QTy:
754 return sizeof(int32_t);
755 case ElemKind::Int32ITy:
756 return sizeof(int32_t);
757 case ElemKind::Int64ITy:
758 return sizeof(int64_t);
759 case ElemKind::UInt8FusedQTy:
760 return sizeof(uint8_t);
761 case ElemKind::UInt8FusedFP16QTy:
762 return sizeof(uint8_t);
763 case ElemKind::UInt4FusedFP16QTy:
764 return sizeof(uint8_t);
765 case ElemKind::BoolTy:
766 return sizeof(bool);
767 }
768 LOG(FATAL) << "Invalid type: " << getElementName(Ty).str();
769 }
770
771 /// \return the textual name of the element.
getElementNamefinal772 llvm::StringRef getElementName() const {
773 return getElementName(elementType_);
774 }
775
776 /// \return the textual name of the element \p Ty.
getElementNamefinal777 static llvm::StringRef getElementName(ElemKind Ty) {
778 static const char *names[] = {
779 "float", "float16", "bfloat16", "i8", "ui8",
780 "i16", "i32", "index32", "index64", "ui8fused",
781 "ui8fusedfp16", "ui4fusedfp16", "bool",
782 };
783 return names[(int)Ty];
784 }
785
786 /// Given a string \p str containing the name of an ElemKind from
787 /// Type::getElementName, returns the corresponding ElemKind or Error if a
788 /// mapping couldn't be found.
getElementKindFromNamefinal789 static ElemKind getElementKindFromName(llvm::StringRef str) {
790 if (str == Type::getElementName(ElemKind::FloatTy)) {
791 return ElemKind::FloatTy;
792 } else if (str == Type::getElementName(ElemKind::Float16Ty)) {
793 return ElemKind::Float16Ty;
794 } else if (str == Type::getElementName(ElemKind::BFloat16Ty)) {
795 return ElemKind::BFloat16Ty;
796 } else if (str == Type::getElementName(ElemKind::Int8QTy)) {
797 return ElemKind::Int8QTy;
798 } else if (str == Type::getElementName(ElemKind::UInt8QTy)) {
799 return ElemKind::UInt8QTy;
800 } else if (str == Type::getElementName(ElemKind::Int16QTy)) {
801 return ElemKind::Int16QTy;
802 } else if (str == Type::getElementName(ElemKind::Int32QTy)) {
803 return ElemKind::Int32QTy;
804 } else if (str == Type::getElementName(ElemKind::Int32ITy)) {
805 return ElemKind::Int32ITy;
806 } else if (str == Type::getElementName(ElemKind::Int64ITy)) {
807 return ElemKind::Int64ITy;
808 } else if (str == Type::getElementName(ElemKind::UInt8FusedQTy)) {
809 return ElemKind::UInt8FusedQTy;
810 } else if (str == Type::getElementName(ElemKind::UInt8FusedFP16QTy)) {
811 return ElemKind::UInt8FusedFP16QTy;
812 } else if (str == Type::getElementName(ElemKind::UInt4FusedFP16QTy)) {
813 return ElemKind::UInt4FusedFP16QTy;
814 } else if (str == Type::getElementName(ElemKind::BoolTy)) {
815 return ElemKind::BoolTy;
816 } else {
817 LOG(DFATAL) << "Invalid ElemKind string: " << str.str();
818 return ElemKind::FloatTy;
819 }
820 }
821
822 /// Dump a textual representation of the Type into provided output stream.
823 void dump(llvm::raw_ostream &out) const;
824
825 /// Dump a textual representation of the Type into default output stream.
826 void dump() const;
827
828 /// Dump a textual representation of the Type to std::string.
829 std::string toString() const;
830
831 /// Load a Type object from a textual representation \p str. This method is
832 /// paired and should be used together with \ref toString.
833 static Type fromString(llvm::StringRef str);
834
835 private:
836 /// Setup the internals of type that store the dimensions. This method is
837 /// used by the constructor.
838 /// \param dims of the tensor (in elements).
839 /// \param alignments of the tensor (in bytes).
initDimsfinal840 void initDims(llvm::ArrayRef<dim_t> dims, llvm::ArrayRef<dim_t> alignments) {
841 assert(dims.size() <= max_tensor_dimensions && "Too many dimensions.");
842 assert(dims.size() == alignments.size() &&
843 "The number of dimensions and alignments should be the same");
844 // Update the tensor strides and sizes based on given dims and alignments.
845 // Sizes are simply assigned to dims. And strides are computed as partial
846 // product of dims, making sure that each dimension is aligned as required.
847 numSizes_ = dims.size();
848 if (numSizes_ > 0) {
849 // Stride of the last dimension is always 1.
850 assert(alignments[numSizes_ - 1] == 1 &&
851 "Last dimension must always be aligned.");
852 strides_[numSizes_ - 1] = 1;
853 sizes_[numSizes_ - 1] = dims[numSizes_ - 1];
854 }
855 for (int i = numSizes_ - 2; i >= 0; i--) {
856 dim_t alignment = alignments[i];
857 if (alignment != 1) {
858 assert(alignment % getElementSize() == 0 &&
859 "Alignment should be a multiple of element size");
860 alignment /= getElementSize();
861 }
862 // All the strides (except for last one) depend on the previous dimension.
863 strides_[i] = alignedSize(dims[i + 1] * strides_[i + 1], alignment);
864 sizes_[i] = dims[i];
865 }
866 }
867
initDimsfinal868 void initDims(llvm::ArrayRef<dim_t> dims) {
869 assert(dims.size() <= max_tensor_dimensions && "Too many dimensions.");
870 // Update the tensor sizes.
871 for (size_t i = 0, e = dims.size(); i < e; i++) {
872 sizes_[i] = dims[i];
873 }
874 numSizes_ = dims.size();
875 }
876 };
877
878 inline bool operator==(const Type &LHS, const Type &RHS) {
879 return LHS.isEqual(RHS);
880 }
881
882 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const Type &type);
883 llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const TypeRef &type);
884
885 } // namespace glow
886
887 #endif // GLOW_BASE_TYPE_H
888