1 //===- StandardTypes.cpp - MLIR Standard Type Classes ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/StandardTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "mlir/IR/Dialect.h"
15 #include "llvm/ADT/APFloat.h"
16 #include "llvm/ADT/Twine.h"
17
18 using namespace mlir;
19 using namespace mlir::detail;
20
21 //===----------------------------------------------------------------------===//
22 // Type
23 //===----------------------------------------------------------------------===//
24
isBF16()25 bool Type::isBF16() { return isa<BFloat16Type>(); }
isF16()26 bool Type::isF16() { return isa<Float16Type>(); }
isF32()27 bool Type::isF32() { return isa<Float32Type>(); }
isF64()28 bool Type::isF64() { return isa<Float64Type>(); }
29
isIndex()30 bool Type::isIndex() { return isa<IndexType>(); }
31
32 /// Return true if this is an integer type with the specified width.
isInteger(unsigned width)33 bool Type::isInteger(unsigned width) {
34 if (auto intTy = dyn_cast<IntegerType>())
35 return intTy.getWidth() == width;
36 return false;
37 }
38
isSignlessInteger()39 bool Type::isSignlessInteger() {
40 if (auto intTy = dyn_cast<IntegerType>())
41 return intTy.isSignless();
42 return false;
43 }
44
isSignlessInteger(unsigned width)45 bool Type::isSignlessInteger(unsigned width) {
46 if (auto intTy = dyn_cast<IntegerType>())
47 return intTy.isSignless() && intTy.getWidth() == width;
48 return false;
49 }
50
isSignedInteger()51 bool Type::isSignedInteger() {
52 if (auto intTy = dyn_cast<IntegerType>())
53 return intTy.isSigned();
54 return false;
55 }
56
isSignedInteger(unsigned width)57 bool Type::isSignedInteger(unsigned width) {
58 if (auto intTy = dyn_cast<IntegerType>())
59 return intTy.isSigned() && intTy.getWidth() == width;
60 return false;
61 }
62
isUnsignedInteger()63 bool Type::isUnsignedInteger() {
64 if (auto intTy = dyn_cast<IntegerType>())
65 return intTy.isUnsigned();
66 return false;
67 }
68
isUnsignedInteger(unsigned width)69 bool Type::isUnsignedInteger(unsigned width) {
70 if (auto intTy = dyn_cast<IntegerType>())
71 return intTy.isUnsigned() && intTy.getWidth() == width;
72 return false;
73 }
74
isSignlessIntOrIndex()75 bool Type::isSignlessIntOrIndex() {
76 return isSignlessInteger() || isa<IndexType>();
77 }
78
isSignlessIntOrIndexOrFloat()79 bool Type::isSignlessIntOrIndexOrFloat() {
80 return isSignlessInteger() || isa<IndexType, FloatType>();
81 }
82
isSignlessIntOrFloat()83 bool Type::isSignlessIntOrFloat() {
84 return isSignlessInteger() || isa<FloatType>();
85 }
86
isIntOrIndex()87 bool Type::isIntOrIndex() { return isa<IntegerType>() || isIndex(); }
88
isIntOrFloat()89 bool Type::isIntOrFloat() { return isa<IntegerType, FloatType>(); }
90
isIntOrIndexOrFloat()91 bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }
92
getIntOrFloatBitWidth()93 unsigned Type::getIntOrFloatBitWidth() {
94 assert(isIntOrFloat() && "only integers and floats have a bitwidth");
95 if (auto intType = dyn_cast<IntegerType>())
96 return intType.getWidth();
97 return cast<FloatType>().getWidth();
98 }
99
100 //===----------------------------------------------------------------------===//
// ComplexType
102 //===----------------------------------------------------------------------===//
103
get(Type elementType)104 ComplexType ComplexType::get(Type elementType) {
105 return Base::get(elementType.getContext(), elementType);
106 }
107
getChecked(Type elementType,Location location)108 ComplexType ComplexType::getChecked(Type elementType, Location location) {
109 return Base::getChecked(location, elementType);
110 }
111
112 /// Verify the construction of an integer type.
verifyConstructionInvariants(Location loc,Type elementType)113 LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
114 Type elementType) {
115 if (!elementType.isIntOrFloat())
116 return emitError(loc, "invalid element type for complex");
117 return success();
118 }
119
getElementType()120 Type ComplexType::getElementType() { return getImpl()->elementType; }
121
122 //===----------------------------------------------------------------------===//
123 // Integer Type
124 //===----------------------------------------------------------------------===//
125
126 // static constexpr must have a definition (until in C++17 and inline variable).
127 constexpr unsigned IntegerType::kMaxWidth;
128
129 /// Verify the construction of an integer type.
130 LogicalResult
verifyConstructionInvariants(Location loc,unsigned width,SignednessSemantics signedness)131 IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
132 SignednessSemantics signedness) {
133 if (width > IntegerType::kMaxWidth) {
134 return emitError(loc) << "integer bitwidth is limited to "
135 << IntegerType::kMaxWidth << " bits";
136 }
137 return success();
138 }
139
getWidth() const140 unsigned IntegerType::getWidth() const { return getImpl()->width; }
141
getSignedness() const142 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
143 return getImpl()->signedness;
144 }
145
146 //===----------------------------------------------------------------------===//
147 // Float Type
148 //===----------------------------------------------------------------------===//
149
getWidth()150 unsigned FloatType::getWidth() {
151 if (isa<Float16Type, BFloat16Type>())
152 return 16;
153 if (isa<Float32Type>())
154 return 32;
155 if (isa<Float64Type>())
156 return 64;
157 llvm_unreachable("unexpected float type");
158 }
159
160 /// Returns the floating semantics for the given type.
getFloatSemantics()161 const llvm::fltSemantics &FloatType::getFloatSemantics() {
162 if (isa<BFloat16Type>())
163 return APFloat::BFloat();
164 if (isa<Float16Type>())
165 return APFloat::IEEEhalf();
166 if (isa<Float32Type>())
167 return APFloat::IEEEsingle();
168 if (isa<Float64Type>())
169 return APFloat::IEEEdouble();
170 llvm_unreachable("non-floating point type used");
171 }
172
173 //===----------------------------------------------------------------------===//
174 // ShapedType
175 //===----------------------------------------------------------------------===//
176 constexpr int64_t ShapedType::kDynamicSize;
177 constexpr int64_t ShapedType::kDynamicStrideOrOffset;
178
getElementType() const179 Type ShapedType::getElementType() const {
180 return static_cast<ImplType *>(impl)->elementType;
181 }
182
getElementTypeBitWidth() const183 unsigned ShapedType::getElementTypeBitWidth() const {
184 return getElementType().getIntOrFloatBitWidth();
185 }
186
getNumElements() const187 int64_t ShapedType::getNumElements() const {
188 assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
189 auto shape = getShape();
190 int64_t num = 1;
191 for (auto dim : shape)
192 num *= dim;
193 return num;
194 }
195
getRank() const196 int64_t ShapedType::getRank() const { return getShape().size(); }
197
hasRank() const198 bool ShapedType::hasRank() const {
199 return !isa<UnrankedMemRefType, UnrankedTensorType>();
200 }
201
getDimSize(unsigned idx) const202 int64_t ShapedType::getDimSize(unsigned idx) const {
203 assert(idx < getRank() && "invalid index for shaped type");
204 return getShape()[idx];
205 }
206
isDynamicDim(unsigned idx) const207 bool ShapedType::isDynamicDim(unsigned idx) const {
208 assert(idx < getRank() && "invalid index for shaped type");
209 return isDynamic(getShape()[idx]);
210 }
211
getDynamicDimIndex(unsigned index) const212 unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
213 assert(index < getRank() && "invalid index");
214 assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
215 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
216 }
217
218 /// Get the number of bits require to store a value of the given shaped type.
219 /// Compute the value recursively since tensors are allowed to have vectors as
220 /// elements.
getSizeInBits() const221 int64_t ShapedType::getSizeInBits() const {
222 assert(hasStaticShape() &&
223 "cannot get the bit size of an aggregate with a dynamic shape");
224
225 auto elementType = getElementType();
226 if (elementType.isIntOrFloat())
227 return elementType.getIntOrFloatBitWidth() * getNumElements();
228
229 if (auto complexType = elementType.dyn_cast<ComplexType>()) {
230 elementType = complexType.getElementType();
231 return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
232 }
233
234 // Tensors can have vectors and other tensors as elements, other shaped types
235 // cannot.
236 assert(isa<TensorType>() && "unsupported element type");
237 assert((elementType.isa<VectorType, TensorType>()) &&
238 "unsupported tensor element type");
239 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
240 }
241
getShape() const242 ArrayRef<int64_t> ShapedType::getShape() const {
243 if (auto vectorType = dyn_cast<VectorType>())
244 return vectorType.getShape();
245 if (auto tensorType = dyn_cast<RankedTensorType>())
246 return tensorType.getShape();
247 return cast<MemRefType>().getShape();
248 }
249
getNumDynamicDims() const250 int64_t ShapedType::getNumDynamicDims() const {
251 return llvm::count_if(getShape(), isDynamic);
252 }
253
hasStaticShape() const254 bool ShapedType::hasStaticShape() const {
255 return hasRank() && llvm::none_of(getShape(), isDynamic);
256 }
257
hasStaticShape(ArrayRef<int64_t> shape) const258 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
259 return hasStaticShape() && getShape() == shape;
260 }
261
262 //===----------------------------------------------------------------------===//
263 // VectorType
264 //===----------------------------------------------------------------------===//
265
get(ArrayRef<int64_t> shape,Type elementType)266 VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
267 return Base::get(elementType.getContext(), shape, elementType);
268 }
269
getChecked(ArrayRef<int64_t> shape,Type elementType,Location location)270 VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
271 Location location) {
272 return Base::getChecked(location, shape, elementType);
273 }
274
verifyConstructionInvariants(Location loc,ArrayRef<int64_t> shape,Type elementType)275 LogicalResult VectorType::verifyConstructionInvariants(Location loc,
276 ArrayRef<int64_t> shape,
277 Type elementType) {
278 if (shape.empty())
279 return emitError(loc, "vector types must have at least one dimension");
280
281 if (!isValidElementType(elementType))
282 return emitError(loc, "vector elements must be int or float type");
283
284 if (any_of(shape, [](int64_t i) { return i <= 0; }))
285 return emitError(loc, "vector types must have positive constant sizes");
286
287 return success();
288 }
289
getShape() const290 ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
291
292 //===----------------------------------------------------------------------===//
293 // TensorType
294 //===----------------------------------------------------------------------===//
295
296 // Check if "elementType" can be an element type of a tensor. Emit errors if
297 // location is not nullptr. Returns failure if check failed.
checkTensorElementType(Location location,Type elementType)298 static LogicalResult checkTensorElementType(Location location,
299 Type elementType) {
300 if (!TensorType::isValidElementType(elementType))
301 return emitError(location, "invalid tensor element type: ") << elementType;
302 return success();
303 }
304
305 /// Return true if the specified element type is ok in a tensor.
isValidElementType(Type type)306 bool TensorType::isValidElementType(Type type) {
307 // Note: Non standard/builtin types are allowed to exist within tensor
308 // types. Dialects are expected to verify that tensor types have a valid
309 // element type within that dialect.
310 return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
311 IndexType>() ||
312 !type.getDialect().getNamespace().empty();
313 }
314
315 //===----------------------------------------------------------------------===//
316 // RankedTensorType
317 //===----------------------------------------------------------------------===//
318
get(ArrayRef<int64_t> shape,Type elementType)319 RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
320 Type elementType) {
321 return Base::get(elementType.getContext(), shape, elementType);
322 }
323
getChecked(ArrayRef<int64_t> shape,Type elementType,Location location)324 RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
325 Type elementType,
326 Location location) {
327 return Base::getChecked(location, shape, elementType);
328 }
329
verifyConstructionInvariants(Location loc,ArrayRef<int64_t> shape,Type elementType)330 LogicalResult RankedTensorType::verifyConstructionInvariants(
331 Location loc, ArrayRef<int64_t> shape, Type elementType) {
332 for (int64_t s : shape) {
333 if (s < -1)
334 return emitError(loc, "invalid tensor dimension size");
335 }
336 return checkTensorElementType(loc, elementType);
337 }
338
getShape() const339 ArrayRef<int64_t> RankedTensorType::getShape() const {
340 return getImpl()->getShape();
341 }
342
343 //===----------------------------------------------------------------------===//
344 // UnrankedTensorType
345 //===----------------------------------------------------------------------===//
346
get(Type elementType)347 UnrankedTensorType UnrankedTensorType::get(Type elementType) {
348 return Base::get(elementType.getContext(), elementType);
349 }
350
getChecked(Type elementType,Location location)351 UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
352 Location location) {
353 return Base::getChecked(location, elementType);
354 }
355
356 LogicalResult
verifyConstructionInvariants(Location loc,Type elementType)357 UnrankedTensorType::verifyConstructionInvariants(Location loc,
358 Type elementType) {
359 return checkTensorElementType(loc, elementType);
360 }
361
362 //===----------------------------------------------------------------------===//
363 // BaseMemRefType
364 //===----------------------------------------------------------------------===//
365
getMemorySpace() const366 unsigned BaseMemRefType::getMemorySpace() const {
367 return static_cast<ImplType *>(impl)->memorySpace;
368 }
369
370 //===----------------------------------------------------------------------===//
371 // MemRefType
372 //===----------------------------------------------------------------------===//
373
374 /// Get or create a new MemRefType based on shape, element type, affine
375 /// map composition, and memory space. Assumes the arguments define a
376 /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType
377 /// construction failures.
get(ArrayRef<int64_t> shape,Type elementType,ArrayRef<AffineMap> affineMapComposition,unsigned memorySpace)378 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
379 ArrayRef<AffineMap> affineMapComposition,
380 unsigned memorySpace) {
381 auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
382 /*location=*/llvm::None);
383 assert(result && "Failed to construct instance of MemRefType.");
384 return result;
385 }
386
387 /// Get or create a new MemRefType based on shape, element type, affine
388 /// map composition, and memory space declared at the given location.
389 /// If the location is unknown, the last argument should be an instance of
390 /// UnknownLoc. If the MemRefType defined by the arguments would be
391 /// ill-formed, emits errors (to the handler registered with the context or to
392 /// the error stream) and returns nullptr.
getChecked(ArrayRef<int64_t> shape,Type elementType,ArrayRef<AffineMap> affineMapComposition,unsigned memorySpace,Location location)393 MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType,
394 ArrayRef<AffineMap> affineMapComposition,
395 unsigned memorySpace, Location location) {
396 return getImpl(shape, elementType, affineMapComposition, memorySpace,
397 location);
398 }
399
400 /// Get or create a new MemRefType defined by the arguments. If the resulting
401 /// type would be ill-formed, return nullptr. If the location is provided,
402 /// emit detailed error messages. To emit errors when the location is unknown,
403 /// pass in an instance of UnknownLoc.
getImpl(ArrayRef<int64_t> shape,Type elementType,ArrayRef<AffineMap> affineMapComposition,unsigned memorySpace,Optional<Location> location)404 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
405 ArrayRef<AffineMap> affineMapComposition,
406 unsigned memorySpace,
407 Optional<Location> location) {
408 auto *context = elementType.getContext();
409
410 if (!BaseMemRefType::isValidElementType(elementType))
411 return emitOptionalError(location, "invalid memref element type"),
412 MemRefType();
413
414 for (int64_t s : shape) {
415 // Negative sizes are not allowed except for `-1` that means dynamic size.
416 if (s < -1)
417 return emitOptionalError(location, "invalid memref size"), MemRefType();
418 }
419
420 // Check that the structure of the composition is valid, i.e. that each
421 // subsequent affine map has as many inputs as the previous map has results.
422 // Take the dimensionality of the MemRef for the first map.
423 auto dim = shape.size();
424 unsigned i = 0;
425 for (const auto &affineMap : affineMapComposition) {
426 if (affineMap.getNumDims() != dim) {
427 if (location)
428 emitError(*location)
429 << "memref affine map dimension mismatch between "
430 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
431 << " and affine map" << i + 1 << ": " << dim
432 << " != " << affineMap.getNumDims();
433 return nullptr;
434 }
435
436 dim = affineMap.getNumResults();
437 ++i;
438 }
439
440 // Drop identity maps from the composition.
441 // This may lead to the composition becoming empty, which is interpreted as an
442 // implicit identity.
443 SmallVector<AffineMap, 2> cleanedAffineMapComposition;
444 for (const auto &map : affineMapComposition) {
445 if (map.isIdentity())
446 continue;
447 cleanedAffineMapComposition.push_back(map);
448 }
449
450 return Base::get(context, shape, elementType, cleanedAffineMapComposition,
451 memorySpace);
452 }
453
getShape() const454 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
455
getAffineMaps() const456 ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
457 return getImpl()->getAffineMaps();
458 }
459
460 //===----------------------------------------------------------------------===//
461 // UnrankedMemRefType
462 //===----------------------------------------------------------------------===//
463
get(Type elementType,unsigned memorySpace)464 UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
465 unsigned memorySpace) {
466 return Base::get(elementType.getContext(), elementType, memorySpace);
467 }
468
getChecked(Type elementType,unsigned memorySpace,Location location)469 UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
470 unsigned memorySpace,
471 Location location) {
472 return Base::getChecked(location, elementType, memorySpace);
473 }
474
475 LogicalResult
verifyConstructionInvariants(Location loc,Type elementType,unsigned memorySpace)476 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
477 unsigned memorySpace) {
478 if (!BaseMemRefType::isValidElementType(elementType))
479 return emitError(loc, "invalid memref element type");
480 return success();
481 }
482
483 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
484 // i.e. single term). Accumulate the AffineExpr into the existing one.
extractStridesFromTerm(AffineExpr e,AffineExpr multiplicativeFactor,MutableArrayRef<AffineExpr> strides,AffineExpr & offset)485 static void extractStridesFromTerm(AffineExpr e,
486 AffineExpr multiplicativeFactor,
487 MutableArrayRef<AffineExpr> strides,
488 AffineExpr &offset) {
489 if (auto dim = e.dyn_cast<AffineDimExpr>())
490 strides[dim.getPosition()] =
491 strides[dim.getPosition()] + multiplicativeFactor;
492 else
493 offset = offset + e * multiplicativeFactor;
494 }
495
496 /// Takes a single AffineExpr `e` and populates the `strides` array with the
497 /// strides expressions for each dim position.
498 /// The convention is that the strides for dimensions d0, .. dn appear in
499 /// order to make indexing intuitive into the result.
extractStrides(AffineExpr e,AffineExpr multiplicativeFactor,MutableArrayRef<AffineExpr> strides,AffineExpr & offset)500 static LogicalResult extractStrides(AffineExpr e,
501 AffineExpr multiplicativeFactor,
502 MutableArrayRef<AffineExpr> strides,
503 AffineExpr &offset) {
504 auto bin = e.dyn_cast<AffineBinaryOpExpr>();
505 if (!bin) {
506 extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
507 return success();
508 }
509
510 if (bin.getKind() == AffineExprKind::CeilDiv ||
511 bin.getKind() == AffineExprKind::FloorDiv ||
512 bin.getKind() == AffineExprKind::Mod)
513 return failure();
514
515 if (bin.getKind() == AffineExprKind::Mul) {
516 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
517 if (dim) {
518 strides[dim.getPosition()] =
519 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
520 return success();
521 }
522 // LHS and RHS may both contain complex expressions of dims. Try one path
523 // and if it fails try the other. This is guaranteed to succeed because
524 // only one path may have a `dim`, otherwise this is not an AffineExpr in
525 // the first place.
526 if (bin.getLHS().isSymbolicOrConstant())
527 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
528 strides, offset);
529 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
530 strides, offset);
531 }
532
533 if (bin.getKind() == AffineExprKind::Add) {
534 auto res1 =
535 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
536 auto res2 =
537 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
538 return success(succeeded(res1) && succeeded(res2));
539 }
540
541 llvm_unreachable("unexpected binary operation");
542 }
543
getStridesAndOffset(MemRefType t,SmallVectorImpl<AffineExpr> & strides,AffineExpr & offset)544 LogicalResult mlir::getStridesAndOffset(MemRefType t,
545 SmallVectorImpl<AffineExpr> &strides,
546 AffineExpr &offset) {
547 auto affineMaps = t.getAffineMaps();
548 // For now strides are only computed on a single affine map with a single
549 // result (i.e. the closed subset of linearization maps that are compatible
550 // with striding semantics).
551 // TODO: support more forms on a per-need basis.
552 if (affineMaps.size() > 1)
553 return failure();
554 if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
555 return failure();
556
557 auto zero = getAffineConstantExpr(0, t.getContext());
558 auto one = getAffineConstantExpr(1, t.getContext());
559 offset = zero;
560 strides.assign(t.getRank(), zero);
561
562 AffineMap m;
563 if (!affineMaps.empty()) {
564 m = affineMaps.front();
565 assert(!m.isIdentity() && "unexpected identity map");
566 }
567
568 // Canonical case for empty map.
569 if (!m) {
570 // 0-D corner case, offset is already 0.
571 if (t.getRank() == 0)
572 return success();
573 auto stridedExpr =
574 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
575 if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
576 return success();
577 assert(false && "unexpected failure: extract strides in canonical layout");
578 }
579
580 // Non-canonical case requires more work.
581 auto stridedExpr =
582 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
583 if (failed(extractStrides(stridedExpr, one, strides, offset))) {
584 offset = AffineExpr();
585 strides.clear();
586 return failure();
587 }
588
589 // Simplify results to allow folding to constants and simple checks.
590 unsigned numDims = m.getNumDims();
591 unsigned numSymbols = m.getNumSymbols();
592 offset = simplifyAffineExpr(offset, numDims, numSymbols);
593 for (auto &stride : strides)
594 stride = simplifyAffineExpr(stride, numDims, numSymbols);
595
596 /// In practice, a strided memref must be internally non-aliasing. Test
597 /// against 0 as a proxy.
598 /// TODO: static cases can have more advanced checks.
599 /// TODO: dynamic cases would require a way to compare symbolic
600 /// expressions and would probably need an affine set context propagated
601 /// everywhere.
602 if (llvm::any_of(strides, [](AffineExpr e) {
603 return e == getAffineConstantExpr(0, e.getContext());
604 })) {
605 offset = AffineExpr();
606 strides.clear();
607 return failure();
608 }
609
610 return success();
611 }
612
getStridesAndOffset(MemRefType t,SmallVectorImpl<int64_t> & strides,int64_t & offset)613 LogicalResult mlir::getStridesAndOffset(MemRefType t,
614 SmallVectorImpl<int64_t> &strides,
615 int64_t &offset) {
616 AffineExpr offsetExpr;
617 SmallVector<AffineExpr, 4> strideExprs;
618 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
619 return failure();
620 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
621 offset = cst.getValue();
622 else
623 offset = ShapedType::kDynamicStrideOrOffset;
624 for (auto e : strideExprs) {
625 if (auto c = e.dyn_cast<AffineConstantExpr>())
626 strides.push_back(c.getValue());
627 else
628 strides.push_back(ShapedType::kDynamicStrideOrOffset);
629 }
630 return success();
631 }
632
633 //===----------------------------------------------------------------------===//
// TupleType
635 //===----------------------------------------------------------------------===//
636
637 /// Get or create a new TupleType with the provided element types. Assumes the
638 /// arguments define a well-formed type.
get(TypeRange elementTypes,MLIRContext * context)639 TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
640 return Base::get(context, elementTypes);
641 }
642
643 /// Get or create an empty tuple type.
get(MLIRContext * context)644 TupleType TupleType::get(MLIRContext *context) { return get({}, context); }
645
646 /// Return the elements types for this tuple.
getTypes() const647 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
648
649 /// Accumulate the types contained in this tuple and tuples nested within it.
650 /// Note that this only flattens nested tuples, not any other container type,
651 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
652 /// (i32, tensor<i32>, f32, i64)
getFlattenedTypes(SmallVectorImpl<Type> & types)653 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
654 for (Type type : getTypes()) {
655 if (auto nestedTuple = type.dyn_cast<TupleType>())
656 nestedTuple.getFlattenedTypes(types);
657 else
658 types.push_back(type);
659 }
660 }
661
662 /// Return the number of element types.
size() const663 size_t TupleType::size() const { return getImpl()->size(); }
664
makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,int64_t offset,MLIRContext * context)665 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
666 int64_t offset,
667 MLIRContext *context) {
668 AffineExpr expr;
669 unsigned nSymbols = 0;
670
671 // AffineExpr for offset.
672 // Static case.
673 if (offset != MemRefType::getDynamicStrideOrOffset()) {
674 auto cst = getAffineConstantExpr(offset, context);
675 expr = cst;
676 } else {
677 // Dynamic case, new symbol for the offset.
678 auto sym = getAffineSymbolExpr(nSymbols++, context);
679 expr = sym;
680 }
681
682 // AffineExpr for strides.
683 for (auto en : llvm::enumerate(strides)) {
684 auto dim = en.index();
685 auto stride = en.value();
686 assert(stride != 0 && "Invalid stride specification");
687 auto d = getAffineDimExpr(dim, context);
688 AffineExpr mult;
689 // Static case.
690 if (stride != MemRefType::getDynamicStrideOrOffset())
691 mult = getAffineConstantExpr(stride, context);
692 else
693 // Dynamic case, new symbol for each new stride.
694 mult = getAffineSymbolExpr(nSymbols++, context);
695 expr = expr + d * mult;
696 }
697
698 return AffineMap::get(strides.size(), nSymbols, expr);
699 }
700
701 /// Return a version of `t` with identity layout if it can be determined
702 /// statically that the layout is the canonical contiguous strided layout.
703 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
704 /// `t` with simplified layout.
705 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
canonicalizeStridedLayout(MemRefType t)706 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
707 auto affineMaps = t.getAffineMaps();
708 // Already in canonical form.
709 if (affineMaps.empty())
710 return t;
711
712 // Can't reduce to canonical identity form, return in canonical form.
713 if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
714 return t;
715
716 // If the canonical strided layout for the sizes of `t` is equal to the
717 // simplified layout of `t` we can just return an empty layout. Otherwise,
718 // just simplify the existing layout.
719 AffineExpr expr =
720 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
721 auto m = affineMaps[0];
722 auto simplifiedLayoutExpr =
723 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
724 if (expr != simplifiedLayoutExpr)
725 return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
726 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
727 return MemRefType::Builder(t).setAffineMaps({});
728 }
729
makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,ArrayRef<AffineExpr> exprs,MLIRContext * context)730 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
731 ArrayRef<AffineExpr> exprs,
732 MLIRContext *context) {
733 // Size 0 corner case is useful for canonicalizations.
734 if (llvm::is_contained(sizes, 0))
735 return getAffineConstantExpr(0, context);
736
737 auto maps = AffineMap::inferFromExprList(exprs);
738 assert(!maps.empty() && "Expected one non-empty map");
739 unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();
740
741 AffineExpr expr;
742 bool dynamicPoisonBit = false;
743 int64_t runningSize = 1;
744 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
745 int64_t size = std::get<1>(en);
746 // Degenerate case, no size =-> no stride
747 if (size == 0)
748 continue;
749 AffineExpr dimExpr = std::get<0>(en);
750 AffineExpr stride = dynamicPoisonBit
751 ? getAffineSymbolExpr(nSymbols++, context)
752 : getAffineConstantExpr(runningSize, context);
753 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
754 if (size > 0)
755 runningSize *= size;
756 else
757 dynamicPoisonBit = true;
758 }
759 return simplifyAffineExpr(expr, numDims, nSymbols);
760 }
761
762 /// Return a version of `t` with a layout that has all dynamic offset and
763 /// strides. This is used to erase the static layout.
eraseStridedLayout(MemRefType t)764 MemRefType mlir::eraseStridedLayout(MemRefType t) {
765 auto val = ShapedType::kDynamicStrideOrOffset;
766 return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
767 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
768 }
769
makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,MLIRContext * context)770 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
771 MLIRContext *context) {
772 SmallVector<AffineExpr, 4> exprs;
773 exprs.reserve(sizes.size());
774 for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
775 exprs.push_back(getAffineDimExpr(dim, context));
776 return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
777 }
778
779 /// Return true if the layout for `t` is compatible with strided semantics.
isStrided(MemRefType t)780 bool mlir::isStrided(MemRefType t) {
781 int64_t offset;
782 SmallVector<int64_t, 4> stridesAndOffset;
783 auto res = getStridesAndOffset(t, stridesAndOffset, offset);
784 return succeeded(res);
785 }
786