1 //===- StandardTypes.cpp - MLIR Standard Type Classes ---------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/IR/StandardTypes.h"
10 #include "TypeDetail.h"
11 #include "mlir/IR/AffineExpr.h"
12 #include "mlir/IR/AffineMap.h"
13 #include "mlir/IR/Diagnostics.h"
14 #include "llvm/ADT/APFloat.h"
15 #include "llvm/ADT/Twine.h"
16
17 using namespace mlir;
18 using namespace mlir::detail;
19
20 //===----------------------------------------------------------------------===//
21 // Type
22 //===----------------------------------------------------------------------===//
23
isBF16()24 bool Type::isBF16() { return getKind() == StandardTypes::BF16; }
isF16()25 bool Type::isF16() { return getKind() == StandardTypes::F16; }
isF32()26 bool Type::isF32() { return getKind() == StandardTypes::F32; }
isF64()27 bool Type::isF64() { return getKind() == StandardTypes::F64; }
28
isIndex()29 bool Type::isIndex() { return isa<IndexType>(); }
30
31 /// Return true if this is an integer type with the specified width.
isInteger(unsigned width)32 bool Type::isInteger(unsigned width) {
33 if (auto intTy = dyn_cast<IntegerType>())
34 return intTy.getWidth() == width;
35 return false;
36 }
37
isSignlessInteger()38 bool Type::isSignlessInteger() {
39 if (auto intTy = dyn_cast<IntegerType>())
40 return intTy.isSignless();
41 return false;
42 }
43
isSignlessInteger(unsigned width)44 bool Type::isSignlessInteger(unsigned width) {
45 if (auto intTy = dyn_cast<IntegerType>())
46 return intTy.isSignless() && intTy.getWidth() == width;
47 return false;
48 }
49
isSignedInteger()50 bool Type::isSignedInteger() {
51 if (auto intTy = dyn_cast<IntegerType>())
52 return intTy.isSigned();
53 return false;
54 }
55
isSignedInteger(unsigned width)56 bool Type::isSignedInteger(unsigned width) {
57 if (auto intTy = dyn_cast<IntegerType>())
58 return intTy.isSigned() && intTy.getWidth() == width;
59 return false;
60 }
61
isUnsignedInteger()62 bool Type::isUnsignedInteger() {
63 if (auto intTy = dyn_cast<IntegerType>())
64 return intTy.isUnsigned();
65 return false;
66 }
67
isUnsignedInteger(unsigned width)68 bool Type::isUnsignedInteger(unsigned width) {
69 if (auto intTy = dyn_cast<IntegerType>())
70 return intTy.isUnsigned() && intTy.getWidth() == width;
71 return false;
72 }
73
isSignlessIntOrIndex()74 bool Type::isSignlessIntOrIndex() {
75 return isSignlessInteger() || isa<IndexType>();
76 }
77
isSignlessIntOrIndexOrFloat()78 bool Type::isSignlessIntOrIndexOrFloat() {
79 return isSignlessInteger() || isa<IndexType, FloatType>();
80 }
81
isSignlessIntOrFloat()82 bool Type::isSignlessIntOrFloat() {
83 return isSignlessInteger() || isa<FloatType>();
84 }
85
isIntOrIndex()86 bool Type::isIntOrIndex() { return isa<IntegerType>() || isIndex(); }
87
isIntOrFloat()88 bool Type::isIntOrFloat() { return isa<IntegerType, FloatType>(); }
89
isIntOrIndexOrFloat()90 bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }
91
92 //===----------------------------------------------------------------------===//
// ComplexType
94 //===----------------------------------------------------------------------===//
95
get(Type elementType)96 ComplexType ComplexType::get(Type elementType) {
97 return Base::get(elementType.getContext(), StandardTypes::Complex,
98 elementType);
99 }
100
getChecked(Type elementType,Location location)101 ComplexType ComplexType::getChecked(Type elementType, Location location) {
102 return Base::getChecked(location, StandardTypes::Complex, elementType);
103 }
104
105 /// Verify the construction of an integer type.
verifyConstructionInvariants(Location loc,Type elementType)106 LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
107 Type elementType) {
108 if (!elementType.isIntOrFloat())
109 return emitError(loc, "invalid element type for complex");
110 return success();
111 }
112
getElementType()113 Type ComplexType::getElementType() { return getImpl()->elementType; }
114
115 //===----------------------------------------------------------------------===//
116 // Integer Type
117 //===----------------------------------------------------------------------===//
118
// A static constexpr data member must have an out-of-line definition (until
// C++17, where such members are implicitly inline variables).
constexpr unsigned IntegerType::kMaxWidth;
121
122 /// Verify the construction of an integer type.
123 LogicalResult
verifyConstructionInvariants(Location loc,unsigned width,SignednessSemantics signedness)124 IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
125 SignednessSemantics signedness) {
126 if (width > IntegerType::kMaxWidth) {
127 return emitError(loc) << "integer bitwidth is limited to "
128 << IntegerType::kMaxWidth << " bits";
129 }
130 return success();
131 }
132
getWidth() const133 unsigned IntegerType::getWidth() const { return getImpl()->getWidth(); }
134
getSignedness() const135 IntegerType::SignednessSemantics IntegerType::getSignedness() const {
136 return getImpl()->getSignedness();
137 }
138
139 //===----------------------------------------------------------------------===//
140 // Float Type
141 //===----------------------------------------------------------------------===//
142
getWidth()143 unsigned FloatType::getWidth() {
144 switch (getKind()) {
145 case StandardTypes::BF16:
146 case StandardTypes::F16:
147 return 16;
148 case StandardTypes::F32:
149 return 32;
150 case StandardTypes::F64:
151 return 64;
152 default:
153 llvm_unreachable("unexpected type");
154 }
155 }
156
157 /// Returns the floating semantics for the given type.
getFloatSemantics()158 const llvm::fltSemantics &FloatType::getFloatSemantics() {
159 if (isBF16())
160 return APFloat::BFloat();
161 if (isF16())
162 return APFloat::IEEEhalf();
163 if (isF32())
164 return APFloat::IEEEsingle();
165 if (isF64())
166 return APFloat::IEEEdouble();
167 llvm_unreachable("non-floating point type used");
168 }
169
getIntOrFloatBitWidth()170 unsigned Type::getIntOrFloatBitWidth() {
171 assert(isIntOrFloat() && "only integers and floats have a bitwidth");
172 if (auto intType = dyn_cast<IntegerType>())
173 return intType.getWidth();
174 return cast<FloatType>().getWidth();
175 }
176
177 //===----------------------------------------------------------------------===//
178 // ShapedType
179 //===----------------------------------------------------------------------===//
// Out-of-line definitions of the static constexpr members of ShapedType.
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;
182
getElementType() const183 Type ShapedType::getElementType() const {
184 return static_cast<ImplType *>(impl)->elementType;
185 }
186
getElementTypeBitWidth() const187 unsigned ShapedType::getElementTypeBitWidth() const {
188 return getElementType().getIntOrFloatBitWidth();
189 }
190
getNumElements() const191 int64_t ShapedType::getNumElements() const {
192 assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
193 auto shape = getShape();
194 int64_t num = 1;
195 for (auto dim : shape)
196 num *= dim;
197 return num;
198 }
199
getRank() const200 int64_t ShapedType::getRank() const { return getShape().size(); }
201
hasRank() const202 bool ShapedType::hasRank() const {
203 return !isa<UnrankedMemRefType, UnrankedTensorType>();
204 }
205
getDimSize(unsigned idx) const206 int64_t ShapedType::getDimSize(unsigned idx) const {
207 assert(idx < getRank() && "invalid index for shaped type");
208 return getShape()[idx];
209 }
210
isDynamicDim(unsigned idx) const211 bool ShapedType::isDynamicDim(unsigned idx) const {
212 assert(idx < getRank() && "invalid index for shaped type");
213 return isDynamic(getShape()[idx]);
214 }
215
getDynamicDimIndex(unsigned index) const216 unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
217 assert(index < getRank() && "invalid index");
218 assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
219 return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
220 }
221
222 /// Get the number of bits require to store a value of the given shaped type.
223 /// Compute the value recursively since tensors are allowed to have vectors as
224 /// elements.
getSizeInBits() const225 int64_t ShapedType::getSizeInBits() const {
226 assert(hasStaticShape() &&
227 "cannot get the bit size of an aggregate with a dynamic shape");
228
229 auto elementType = getElementType();
230 if (elementType.isIntOrFloat())
231 return elementType.getIntOrFloatBitWidth() * getNumElements();
232
233 // Tensors can have vectors and other tensors as elements, other shaped types
234 // cannot.
235 assert(isa<TensorType>() && "unsupported element type");
236 assert((elementType.isa<VectorType, TensorType>()) &&
237 "unsupported tensor element type");
238 return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
239 }
240
getShape() const241 ArrayRef<int64_t> ShapedType::getShape() const {
242 switch (getKind()) {
243 case StandardTypes::Vector:
244 return cast<VectorType>().getShape();
245 case StandardTypes::RankedTensor:
246 return cast<RankedTensorType>().getShape();
247 case StandardTypes::MemRef:
248 return cast<MemRefType>().getShape();
249 default:
250 llvm_unreachable("not a ShapedType or not ranked");
251 }
252 }
253
getNumDynamicDims() const254 int64_t ShapedType::getNumDynamicDims() const {
255 return llvm::count_if(getShape(), isDynamic);
256 }
257
hasStaticShape() const258 bool ShapedType::hasStaticShape() const {
259 return hasRank() && llvm::none_of(getShape(), isDynamic);
260 }
261
hasStaticShape(ArrayRef<int64_t> shape) const262 bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
263 return hasStaticShape() && getShape() == shape;
264 }
265
266 //===----------------------------------------------------------------------===//
267 // VectorType
268 //===----------------------------------------------------------------------===//
269
get(ArrayRef<int64_t> shape,Type elementType)270 VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
271 return Base::get(elementType.getContext(), StandardTypes::Vector, shape,
272 elementType);
273 }
274
getChecked(ArrayRef<int64_t> shape,Type elementType,Location location)275 VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
276 Location location) {
277 return Base::getChecked(location, StandardTypes::Vector, shape, elementType);
278 }
279
verifyConstructionInvariants(Location loc,ArrayRef<int64_t> shape,Type elementType)280 LogicalResult VectorType::verifyConstructionInvariants(Location loc,
281 ArrayRef<int64_t> shape,
282 Type elementType) {
283 if (shape.empty())
284 return emitError(loc, "vector types must have at least one dimension");
285
286 if (!isValidElementType(elementType))
287 return emitError(loc, "vector elements must be int or float type");
288
289 if (any_of(shape, [](int64_t i) { return i <= 0; }))
290 return emitError(loc, "vector types must have positive constant sizes");
291
292 return success();
293 }
294
getShape() const295 ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }
296
297 //===----------------------------------------------------------------------===//
298 // TensorType
299 //===----------------------------------------------------------------------===//
300
301 // Check if "elementType" can be an element type of a tensor. Emit errors if
302 // location is not nullptr. Returns failure if check failed.
checkTensorElementType(Location location,Type elementType)303 static inline LogicalResult checkTensorElementType(Location location,
304 Type elementType) {
305 if (!TensorType::isValidElementType(elementType))
306 return emitError(location, "invalid tensor element type");
307 return success();
308 }
309
310 //===----------------------------------------------------------------------===//
311 // RankedTensorType
312 //===----------------------------------------------------------------------===//
313
get(ArrayRef<int64_t> shape,Type elementType)314 RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
315 Type elementType) {
316 return Base::get(elementType.getContext(), StandardTypes::RankedTensor, shape,
317 elementType);
318 }
319
getChecked(ArrayRef<int64_t> shape,Type elementType,Location location)320 RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
321 Type elementType,
322 Location location) {
323 return Base::getChecked(location, StandardTypes::RankedTensor, shape,
324 elementType);
325 }
326
verifyConstructionInvariants(Location loc,ArrayRef<int64_t> shape,Type elementType)327 LogicalResult RankedTensorType::verifyConstructionInvariants(
328 Location loc, ArrayRef<int64_t> shape, Type elementType) {
329 for (int64_t s : shape) {
330 if (s < -1)
331 return emitError(loc, "invalid tensor dimension size");
332 }
333 return checkTensorElementType(loc, elementType);
334 }
335
getShape() const336 ArrayRef<int64_t> RankedTensorType::getShape() const {
337 return getImpl()->getShape();
338 }
339
340 //===----------------------------------------------------------------------===//
341 // UnrankedTensorType
342 //===----------------------------------------------------------------------===//
343
get(Type elementType)344 UnrankedTensorType UnrankedTensorType::get(Type elementType) {
345 return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor,
346 elementType);
347 }
348
getChecked(Type elementType,Location location)349 UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
350 Location location) {
351 return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType);
352 }
353
354 LogicalResult
verifyConstructionInvariants(Location loc,Type elementType)355 UnrankedTensorType::verifyConstructionInvariants(Location loc,
356 Type elementType) {
357 return checkTensorElementType(loc, elementType);
358 }
359
360 //===----------------------------------------------------------------------===//
361 // MemRefType
362 //===----------------------------------------------------------------------===//
363
364 /// Get or create a new MemRefType based on shape, element type, affine
365 /// map composition, and memory space. Assumes the arguments define a
366 /// well-formed MemRef type. Use getChecked to gracefully handle MemRefType
367 /// construction failures.
get(ArrayRef<int64_t> shape,Type elementType,ArrayRef<AffineMap> affineMapComposition,unsigned memorySpace)368 MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
369 ArrayRef<AffineMap> affineMapComposition,
370 unsigned memorySpace) {
371 auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
372 /*location=*/llvm::None);
373 assert(result && "Failed to construct instance of MemRefType.");
374 return result;
375 }
376
377 /// Get or create a new MemRefType based on shape, element type, affine
378 /// map composition, and memory space declared at the given location.
379 /// If the location is unknown, the last argument should be an instance of
380 /// UnknownLoc. If the MemRefType defined by the arguments would be
381 /// ill-formed, emits errors (to the handler registered with the context or to
382 /// the error stream) and returns nullptr.
getChecked(ArrayRef<int64_t> shape,Type elementType,ArrayRef<AffineMap> affineMapComposition,unsigned memorySpace,Location location)383 MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType,
384 ArrayRef<AffineMap> affineMapComposition,
385 unsigned memorySpace, Location location) {
386 return getImpl(shape, elementType, affineMapComposition, memorySpace,
387 location);
388 }
389
390 /// Get or create a new MemRefType defined by the arguments. If the resulting
391 /// type would be ill-formed, return nullptr. If the location is provided,
392 /// emit detailed error messages. To emit errors when the location is unknown,
393 /// pass in an instance of UnknownLoc.
getImpl(ArrayRef<int64_t> shape,Type elementType,ArrayRef<AffineMap> affineMapComposition,unsigned memorySpace,Optional<Location> location)394 MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
395 ArrayRef<AffineMap> affineMapComposition,
396 unsigned memorySpace,
397 Optional<Location> location) {
398 auto *context = elementType.getContext();
399
400 // Check that memref is formed from allowed types.
401 if (!elementType.isIntOrFloat() &&
402 !elementType.isa<VectorType, ComplexType>())
403 return emitOptionalError(location, "invalid memref element type"),
404 MemRefType();
405
406 for (int64_t s : shape) {
407 // Negative sizes are not allowed except for `-1` that means dynamic size.
408 if (s < -1)
409 return emitOptionalError(location, "invalid memref size"), MemRefType();
410 }
411
412 // Check that the structure of the composition is valid, i.e. that each
413 // subsequent affine map has as many inputs as the previous map has results.
414 // Take the dimensionality of the MemRef for the first map.
415 auto dim = shape.size();
416 unsigned i = 0;
417 for (const auto &affineMap : affineMapComposition) {
418 if (affineMap.getNumDims() != dim) {
419 if (location)
420 emitError(*location)
421 << "memref affine map dimension mismatch between "
422 << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
423 << " and affine map" << i + 1 << ": " << dim
424 << " != " << affineMap.getNumDims();
425 return nullptr;
426 }
427
428 dim = affineMap.getNumResults();
429 ++i;
430 }
431
432 // Drop identity maps from the composition.
433 // This may lead to the composition becoming empty, which is interpreted as an
434 // implicit identity.
435 SmallVector<AffineMap, 2> cleanedAffineMapComposition;
436 for (const auto &map : affineMapComposition) {
437 if (map.isIdentity())
438 continue;
439 cleanedAffineMapComposition.push_back(map);
440 }
441
442 return Base::get(context, StandardTypes::MemRef, shape, elementType,
443 cleanedAffineMapComposition, memorySpace);
444 }
445
getShape() const446 ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }
447
getAffineMaps() const448 ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
449 return getImpl()->getAffineMaps();
450 }
451
getMemorySpace() const452 unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }
453
454 //===----------------------------------------------------------------------===//
455 // UnrankedMemRefType
456 //===----------------------------------------------------------------------===//
457
get(Type elementType,unsigned memorySpace)458 UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
459 unsigned memorySpace) {
460 return Base::get(elementType.getContext(), StandardTypes::UnrankedMemRef,
461 elementType, memorySpace);
462 }
463
getChecked(Type elementType,unsigned memorySpace,Location location)464 UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
465 unsigned memorySpace,
466 Location location) {
467 return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType,
468 memorySpace);
469 }
470
getMemorySpace() const471 unsigned UnrankedMemRefType::getMemorySpace() const {
472 return getImpl()->memorySpace;
473 }
474
475 LogicalResult
verifyConstructionInvariants(Location loc,Type elementType,unsigned memorySpace)476 UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
477 unsigned memorySpace) {
478 // Check that memref is formed from allowed types.
479 if (!elementType.isIntOrFloat() &&
480 !elementType.isa<VectorType, ComplexType>())
481 return emitError(loc, "invalid memref element type");
482 return success();
483 }
484
485 // Fallback cases for terminal dim/sym/cst that are not part of a binary op (
486 // i.e. single term). Accumulate the AffineExpr into the existing one.
extractStridesFromTerm(AffineExpr e,AffineExpr multiplicativeFactor,MutableArrayRef<AffineExpr> strides,AffineExpr & offset)487 static void extractStridesFromTerm(AffineExpr e,
488 AffineExpr multiplicativeFactor,
489 MutableArrayRef<AffineExpr> strides,
490 AffineExpr &offset) {
491 if (auto dim = e.dyn_cast<AffineDimExpr>())
492 strides[dim.getPosition()] =
493 strides[dim.getPosition()] + multiplicativeFactor;
494 else
495 offset = offset + e * multiplicativeFactor;
496 }
497
498 /// Takes a single AffineExpr `e` and populates the `strides` array with the
499 /// strides expressions for each dim position.
500 /// The convention is that the strides for dimensions d0, .. dn appear in
501 /// order to make indexing intuitive into the result.
extractStrides(AffineExpr e,AffineExpr multiplicativeFactor,MutableArrayRef<AffineExpr> strides,AffineExpr & offset)502 static LogicalResult extractStrides(AffineExpr e,
503 AffineExpr multiplicativeFactor,
504 MutableArrayRef<AffineExpr> strides,
505 AffineExpr &offset) {
506 auto bin = e.dyn_cast<AffineBinaryOpExpr>();
507 if (!bin) {
508 extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
509 return success();
510 }
511
512 if (bin.getKind() == AffineExprKind::CeilDiv ||
513 bin.getKind() == AffineExprKind::FloorDiv ||
514 bin.getKind() == AffineExprKind::Mod)
515 return failure();
516
517 if (bin.getKind() == AffineExprKind::Mul) {
518 auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
519 if (dim) {
520 strides[dim.getPosition()] =
521 strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
522 return success();
523 }
524 // LHS and RHS may both contain complex expressions of dims. Try one path
525 // and if it fails try the other. This is guaranteed to succeed because
526 // only one path may have a `dim`, otherwise this is not an AffineExpr in
527 // the first place.
528 if (bin.getLHS().isSymbolicOrConstant())
529 return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
530 strides, offset);
531 return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
532 strides, offset);
533 }
534
535 if (bin.getKind() == AffineExprKind::Add) {
536 auto res1 =
537 extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
538 auto res2 =
539 extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
540 return success(succeeded(res1) && succeeded(res2));
541 }
542
543 llvm_unreachable("unexpected binary operation");
544 }
545
getStridesAndOffset(MemRefType t,SmallVectorImpl<AffineExpr> & strides,AffineExpr & offset)546 LogicalResult mlir::getStridesAndOffset(MemRefType t,
547 SmallVectorImpl<AffineExpr> &strides,
548 AffineExpr &offset) {
549 auto affineMaps = t.getAffineMaps();
550 // For now strides are only computed on a single affine map with a single
551 // result (i.e. the closed subset of linearization maps that are compatible
552 // with striding semantics).
553 // TODO: support more forms on a per-need basis.
554 if (affineMaps.size() > 1)
555 return failure();
556 if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
557 return failure();
558
559 auto zero = getAffineConstantExpr(0, t.getContext());
560 auto one = getAffineConstantExpr(1, t.getContext());
561 offset = zero;
562 strides.assign(t.getRank(), zero);
563
564 AffineMap m;
565 if (!affineMaps.empty()) {
566 m = affineMaps.front();
567 assert(!m.isIdentity() && "unexpected identity map");
568 }
569
570 // Canonical case for empty map.
571 if (!m) {
572 // 0-D corner case, offset is already 0.
573 if (t.getRank() == 0)
574 return success();
575 auto stridedExpr =
576 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
577 if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
578 return success();
579 assert(false && "unexpected failure: extract strides in canonical layout");
580 }
581
582 // Non-canonical case requires more work.
583 auto stridedExpr =
584 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
585 if (failed(extractStrides(stridedExpr, one, strides, offset))) {
586 offset = AffineExpr();
587 strides.clear();
588 return failure();
589 }
590
591 // Simplify results to allow folding to constants and simple checks.
592 unsigned numDims = m.getNumDims();
593 unsigned numSymbols = m.getNumSymbols();
594 offset = simplifyAffineExpr(offset, numDims, numSymbols);
595 for (auto &stride : strides)
596 stride = simplifyAffineExpr(stride, numDims, numSymbols);
597
598 /// In practice, a strided memref must be internally non-aliasing. Test
599 /// against 0 as a proxy.
600 /// TODO: static cases can have more advanced checks.
601 /// TODO: dynamic cases would require a way to compare symbolic
602 /// expressions and would probably need an affine set context propagated
603 /// everywhere.
604 if (llvm::any_of(strides, [](AffineExpr e) {
605 return e == getAffineConstantExpr(0, e.getContext());
606 })) {
607 offset = AffineExpr();
608 strides.clear();
609 return failure();
610 }
611
612 return success();
613 }
614
getStridesAndOffset(MemRefType t,SmallVectorImpl<int64_t> & strides,int64_t & offset)615 LogicalResult mlir::getStridesAndOffset(MemRefType t,
616 SmallVectorImpl<int64_t> &strides,
617 int64_t &offset) {
618 AffineExpr offsetExpr;
619 SmallVector<AffineExpr, 4> strideExprs;
620 if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
621 return failure();
622 if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
623 offset = cst.getValue();
624 else
625 offset = ShapedType::kDynamicStrideOrOffset;
626 for (auto e : strideExprs) {
627 if (auto c = e.dyn_cast<AffineConstantExpr>())
628 strides.push_back(c.getValue());
629 else
630 strides.push_back(ShapedType::kDynamicStrideOrOffset);
631 }
632 return success();
633 }
634
635 //===----------------------------------------------------------------------===//
// TupleType
637 //===----------------------------------------------------------------------===//
638
639 /// Get or create a new TupleType with the provided element types. Assumes the
640 /// arguments define a well-formed type.
get(ArrayRef<Type> elementTypes,MLIRContext * context)641 TupleType TupleType::get(ArrayRef<Type> elementTypes, MLIRContext *context) {
642 return Base::get(context, StandardTypes::Tuple, elementTypes);
643 }
644
645 /// Return the elements types for this tuple.
getTypes() const646 ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }
647
648 /// Accumulate the types contained in this tuple and tuples nested within it.
649 /// Note that this only flattens nested tuples, not any other container type,
650 /// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
651 /// (i32, tensor<i32>, f32, i64)
getFlattenedTypes(SmallVectorImpl<Type> & types)652 void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
653 for (Type type : getTypes()) {
654 if (auto nestedTuple = type.dyn_cast<TupleType>())
655 nestedTuple.getFlattenedTypes(types);
656 else
657 types.push_back(type);
658 }
659 }
660
661 /// Return the number of element types.
size() const662 size_t TupleType::size() const { return getImpl()->size(); }
663
makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,int64_t offset,MLIRContext * context)664 AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
665 int64_t offset,
666 MLIRContext *context) {
667 AffineExpr expr;
668 unsigned nSymbols = 0;
669
670 // AffineExpr for offset.
671 // Static case.
672 if (offset != MemRefType::getDynamicStrideOrOffset()) {
673 auto cst = getAffineConstantExpr(offset, context);
674 expr = cst;
675 } else {
676 // Dynamic case, new symbol for the offset.
677 auto sym = getAffineSymbolExpr(nSymbols++, context);
678 expr = sym;
679 }
680
681 // AffineExpr for strides.
682 for (auto en : llvm::enumerate(strides)) {
683 auto dim = en.index();
684 auto stride = en.value();
685 assert(stride != 0 && "Invalid stride specification");
686 auto d = getAffineDimExpr(dim, context);
687 AffineExpr mult;
688 // Static case.
689 if (stride != MemRefType::getDynamicStrideOrOffset())
690 mult = getAffineConstantExpr(stride, context);
691 else
692 // Dynamic case, new symbol for each new stride.
693 mult = getAffineSymbolExpr(nSymbols++, context);
694 expr = expr + d * mult;
695 }
696
697 return AffineMap::get(strides.size(), nSymbols, expr);
698 }
699
700 /// Return a version of `t` with identity layout if it can be determined
701 /// statically that the layout is the canonical contiguous strided layout.
702 /// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
703 /// `t` with simplified layout.
704 /// If `t` has multiple layout maps or a multi-result layout, just return `t`.
canonicalizeStridedLayout(MemRefType t)705 MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
706 auto affineMaps = t.getAffineMaps();
707 // Already in canonical form.
708 if (affineMaps.empty())
709 return t;
710
711 // Can't reduce to canonical identity form, return in canonical form.
712 if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
713 return t;
714
715 // If the canonical strided layout for the sizes of `t` is equal to the
716 // simplified layout of `t` we can just return an empty layout. Otherwise,
717 // just simplify the existing layout.
718 AffineExpr expr =
719 makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
720 auto m = affineMaps[0];
721 auto simplifiedLayoutExpr =
722 simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
723 if (expr != simplifiedLayoutExpr)
724 return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
725 m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
726 return MemRefType::Builder(t).setAffineMaps({});
727 }
728
makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,ArrayRef<AffineExpr> exprs,MLIRContext * context)729 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
730 ArrayRef<AffineExpr> exprs,
731 MLIRContext *context) {
732 AffineExpr expr;
733 bool dynamicPoisonBit = false;
734 unsigned numDims = 0;
735 unsigned nSymbols = 0;
736 // Compute the number of symbols and dimensions of the passed exprs.
737 for (AffineExpr expr : exprs) {
738 expr.walk([&numDims, &nSymbols](AffineExpr d) {
739 if (AffineDimExpr dim = d.dyn_cast<AffineDimExpr>())
740 numDims = std::max(numDims, dim.getPosition() + 1);
741 else if (AffineSymbolExpr symbol = d.dyn_cast<AffineSymbolExpr>())
742 nSymbols = std::max(nSymbols, symbol.getPosition() + 1);
743 });
744 }
745 int64_t runningSize = 1;
746 for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
747 int64_t size = std::get<1>(en);
748 // Degenerate case, no size =-> no stride
749 if (size == 0)
750 continue;
751 AffineExpr dimExpr = std::get<0>(en);
752 AffineExpr stride = dynamicPoisonBit
753 ? getAffineSymbolExpr(nSymbols++, context)
754 : getAffineConstantExpr(runningSize, context);
755 expr = expr ? expr + dimExpr * stride : dimExpr * stride;
756 if (size > 0)
757 runningSize *= size;
758 else
759 dynamicPoisonBit = true;
760 }
761 return simplifyAffineExpr(expr, numDims, nSymbols);
762 }
763
764 /// Return a version of `t` with a layout that has all dynamic offset and
765 /// strides. This is used to erase the static layout.
eraseStridedLayout(MemRefType t)766 MemRefType mlir::eraseStridedLayout(MemRefType t) {
767 auto val = ShapedType::kDynamicStrideOrOffset;
768 return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
769 SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
770 }
771
makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,MLIRContext * context)772 AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
773 MLIRContext *context) {
774 SmallVector<AffineExpr, 4> exprs;
775 exprs.reserve(sizes.size());
776 for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
777 exprs.push_back(getAffineDimExpr(dim, context));
778 return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
779 }
780
781 /// Return true if the layout for `t` is compatible with strided semantics.
isStrided(MemRefType t)782 bool mlir::isStrided(MemRefType t) {
783 int64_t offset;
784 SmallVector<int64_t, 4> stridesAndOffset;
785 auto res = getStridesAndOffset(t, stridesAndOffset, offset);
786 return succeeded(res);
787 }
788