//===- StandardTypes.cpp - MLIR Standard Type Classes ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/StandardTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Diagnostics.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Twine.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//

bool Type::isBF16() { return getKind() == StandardTypes::BF16; }
bool Type::isF16() { return getKind() == StandardTypes::F16; }
bool Type::isF32() { return getKind() == StandardTypes::F32; }
bool Type::isF64() { return getKind() == StandardTypes::F64; }

bool Type::isIndex() { return isa<IndexType>(); }

/// Return true if this is an integer type with the specified width.
bool Type::isInteger(unsigned width) {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.getWidth() == width;
  return false;
}

bool Type::isSignlessInteger() {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isSignless();
  return false;
}

bool Type::isSignlessInteger(unsigned width) {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isSignless() && intTy.getWidth() == width;
  return false;
}

bool Type::isSignedInteger() {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isSigned();
  return false;
}

bool Type::isSignedInteger(unsigned width) {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isSigned() && intTy.getWidth() == width;
  return false;
}

bool Type::isUnsignedInteger() {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isUnsigned();
  return false;
}

bool Type::isUnsignedInteger(unsigned width) {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isUnsigned() && intTy.getWidth() == width;
  return false;
}

bool Type::isSignlessIntOrIndex() {
  return isSignlessInteger() || isa<IndexType>();
}

bool Type::isSignlessIntOrIndexOrFloat() {
  return isSignlessInteger() || isa<IndexType, FloatType>();
}

bool Type::isSignlessIntOrFloat() {
  return isSignlessInteger() || isa<FloatType>();
}

bool Type::isIntOrIndex() { return isa<IntegerType>() || isIndex(); }

bool Type::isIntOrFloat() { return isa<IntegerType, FloatType>(); }

bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

ComplexType ComplexType::get(Type elementType) {
  return Base::get(elementType.getContext(), StandardTypes::Complex,
                   elementType);
}

ComplexType ComplexType::getChecked(Type elementType, Location location) {
  return Base::getChecked(location, StandardTypes::Complex, elementType);
}

/// Verify the construction of a complex type.
LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
                                                        Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError(loc, "invalid element type for complex");
  return success();
}

Type ComplexType::getElementType() { return getImpl()->elementType; }

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// A static constexpr member needs an out-of-class definition (no longer
// required in C++17, where such members are implicitly inline).
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult
IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
                                          SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError(loc) << "integer bitwidth is limited to "
                          << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->getWidth(); }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->getSignedness();
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  switch (getKind()) {
  case StandardTypes::BF16:
  case StandardTypes::F16:
    return 16;
  case StandardTypes::F32:
    return 32;
  case StandardTypes::F64:
    return 64;
  default:
    llvm_unreachable("unexpected type");
  }
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isBF16())
    return APFloat::BFloat();
  if (isF16())
    return APFloat::IEEEhalf();
  if (isF32())
    return APFloat::IEEEsingle();
  if (isF64())
    return APFloat::IEEEdouble();
  llvm_unreachable("non-floating point type used");
}

unsigned Type::getIntOrFloatBitWidth() {
  assert(isIntOrFloat() && "only integers and floats have a bitwidth");
  if (auto intType = dyn_cast<IntegerType>())
    return intType.getWidth();
  return cast<FloatType>().getWidth();
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

Type ShapedType::getElementType() const {
  return static_cast<ImplType *>(impl)->elementType;
}

unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}

int64_t ShapedType::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  auto shape = getShape();
  int64_t num = 1;
  for (auto dim : shape)
    num *= dim;
  return num;
}

int64_t ShapedType::getRank() const { return getShape().size(); }

bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}

int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}

bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

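/// Returns the position of the given dynamic dimension among the dynamic
/// dimensions only. For example, for a shape (2, ?, 4, ?),
/// getDynamicDimIndex(3) returns 1, since exactly one dynamic dimension
/// precedes index 3.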
unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}

/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
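/// For example, a hypothetical tensor<4xvector<8xf32>> stores
/// 4 * (8 * 32) = 1024 bits.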
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType, TensorType>()) &&
         "unsupported tensor element type");
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}

ArrayRef<int64_t> ShapedType::getShape() const {
  switch (getKind()) {
  case StandardTypes::Vector:
    return cast<VectorType>().getShape();
  case StandardTypes::RankedTensor:
    return cast<RankedTensorType>().getShape();
  case StandardTypes::MemRef:
    return cast<MemRefType>().getShape();
  default:
    llvm_unreachable("not a ShapedType or not ranked");
  }
}

int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
  return Base::get(elementType.getContext(), StandardTypes::Vector, shape,
                   elementType);
}

VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
                                  Location location) {
  return Base::getChecked(location, StandardTypes::Vector, shape, elementType);
}

LogicalResult VectorType::verifyConstructionInvariants(Location loc,
                                                       ArrayRef<int64_t> shape,
                                                       Type elementType) {
  if (shape.empty())
    return emitError(loc, "vector types must have at least one dimension");

  if (!isValidElementType(elementType))
    return emitError(loc, "vector elements must be int or float type");

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError(loc, "vector types must have positive constant sizes");

  return success();
}

ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

// Check if "elementType" can be an element type of a tensor. Emits errors at
// the given location and returns failure if it cannot.
static inline LogicalResult checkTensorElementType(Location location,
                                                   Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError(location, "invalid tensor element type");
  return success();
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
                                       Type elementType) {
  return Base::get(elementType.getContext(), StandardTypes::RankedTensor, shape,
                   elementType);
}

RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
                                              Type elementType,
                                              Location location) {
  return Base::getChecked(location, StandardTypes::RankedTensor, shape,
                          elementType);
}

LogicalResult RankedTensorType::verifyConstructionInvariants(
    Location loc, ArrayRef<int64_t> shape, Type elementType) {
  for (int64_t s : shape) {
    if (s < -1)
      return emitError(loc, "invalid tensor dimension size");
  }
  return checkTensorElementType(loc, elementType);
}

ArrayRef<int64_t> RankedTensorType::getShape() const {
  return getImpl()->getShape();
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

UnrankedTensorType UnrankedTensorType::get(Type elementType) {
  return Base::get(elementType.getContext(), StandardTypes::UnrankedTensor,
                   elementType);
}

UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
                                                  Location location) {
  return Base::getChecked(location, StandardTypes::UnrankedTensor, elementType);
}

LogicalResult
UnrankedTensorType::verifyConstructionInvariants(Location loc,
                                                 Type elementType) {
  return checkTensorElementType(loc, elementType);
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space.  Assumes the arguments define a
/// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
/// construction failures.
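/// For example, assuming an f32 type from the same context and the default
/// arguments declared in the header, MemRefType::get({2, 4}, f32) yields the
/// memref<2x4xf32> type with the implicit identity layout in memory space 0.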
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           ArrayRef<AffineMap> affineMapComposition,
                           unsigned memorySpace) {
  auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
                        /*location=*/llvm::None);
  assert(result && "Failed to construct instance of MemRefType.");
  return result;
}

/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space declared at the given location.
/// If the location is unknown, the last argument should be an instance of
/// UnknownLoc.  If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr.
MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType,
                                  ArrayRef<AffineMap> affineMapComposition,
                                  unsigned memorySpace, Location location) {
  return getImpl(shape, elementType, affineMapComposition, memorySpace,
                 location);
}

/// Get or create a new MemRefType defined by the arguments.  If the resulting
/// type would be ill-formed, return nullptr.  If the location is provided,
/// emit detailed error messages.  To emit errors when the location is unknown,
/// pass in an instance of UnknownLoc.
MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
                               ArrayRef<AffineMap> affineMapComposition,
                               unsigned memorySpace,
                               Optional<Location> location) {
  auto *context = elementType.getContext();

  // Check that memref is formed from allowed types.
  if (!elementType.isIntOrFloat() &&
      !elementType.isa<VectorType, ComplexType>())
    return emitOptionalError(location, "invalid memref element type"),
           MemRefType();

  for (int64_t s : shape) {
    // Negative sizes are not allowed except for `-1` that means dynamic size.
    if (s < -1)
      return emitOptionalError(location, "invalid memref size"), MemRefType();
  }

  // Check that the structure of the composition is valid, i.e. that each
  // subsequent affine map has as many inputs as the previous map has results.
  // Take the dimensionality of the MemRef for the first map.
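  // For example, with a rank-2 memref, the composition
  //   [(d0, d1) -> (d0 * 4 + d1), (d0) -> (d0 floordiv 2)]
  // is structurally valid: the first map consumes the two memref dimensions
  // and produces one result, which matches the single input of the second.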
  auto dim = shape.size();
  unsigned i = 0;
  for (const auto &affineMap : affineMapComposition) {
    if (affineMap.getNumDims() != dim) {
      if (location)
        emitError(*location)
            << "memref affine map dimension mismatch between "
            << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
            << " and affine map " << i + 1 << ": " << dim
            << " != " << affineMap.getNumDims();
      return nullptr;
    }

    dim = affineMap.getNumResults();
    ++i;
  }

  // Drop identity maps from the composition.
  // This may lead to the composition becoming empty, which is interpreted as an
  // implicit identity.
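  // For example, [(d0, d1) -> (d0, d1), someMap] is cleaned up to [someMap],
  // and a composition containing only identity maps becomes empty.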
  SmallVector<AffineMap, 2> cleanedAffineMapComposition;
  for (const auto &map : affineMapComposition) {
    if (map.isIdentity())
      continue;
    cleanedAffineMapComposition.push_back(map);
  }

  return Base::get(context, StandardTypes::MemRef, shape, elementType,
                   cleanedAffineMapComposition, memorySpace);
}

ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }

ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
  return getImpl()->getAffineMaps();
}

unsigned MemRefType::getMemorySpace() const { return getImpl()->memorySpace; }

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
                                           unsigned memorySpace) {
  return Base::get(elementType.getContext(), StandardTypes::UnrankedMemRef,
                   elementType, memorySpace);
}

UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
                                                  unsigned memorySpace,
                                                  Location location) {
  return Base::getChecked(location, StandardTypes::UnrankedMemRef, elementType,
                          memorySpace);
}

unsigned UnrankedMemRefType::getMemorySpace() const {
  return getImpl()->memorySpace;
}

LogicalResult
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
                                                 unsigned memorySpace) {
  // Check that memref is formed from allowed types.
  if (!elementType.isIntOrFloat() &&
      !elementType.isa<VectorType, ComplexType>())
    return emitError(loc, "invalid memref element type");
  return success();
}

// Fallback case for a terminal dim/sym/cst that is not part of a binary op
// (i.e. a single term). Accumulates the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, ..., dn appear in
/// order, to make indexing into the result intuitive.
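/// For example, assuming `strides` is zero-initialized (as in the caller
/// below), the expression d0 * 42 + d1 + 7 yields strides [42, 1] and
/// offset 7 for a rank-2 memref.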
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

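/// For example, a memref<3x4xf32> with no explicit layout has the canonical
/// strided form (d0, d1) -> (d0 * 4 + d1) and yields strides [4, 1] and
/// offset 0.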
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  auto affineMaps = t.getAffineMaps();
  // For now strides are only computed on a single affine map with a single
  // result (i.e. the closed subset of linearization maps that are compatible
  // with striding semantics).
  // TODO: support more forms on a per-need basis.
  if (affineMaps.size() > 1)
    return failure();
  if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  AffineMap m;
  if (!affineMaps.empty()) {
    m = affineMaps.front();
    assert(!m.isIdentity() && "unexpected identity map");
  }

  // Canonical case for empty map.
  if (!m) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    assert(false && "unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  /// In practice, a strided memref must be internally non-aliasing. Test
  /// against 0 as a proxy.
  /// TODO: static cases can have more advanced checks.
  /// TODO: dynamic cases would require a way to compare symbolic
  /// expressions and would probably need an affine set context propagated
  /// everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type.
TupleType TupleType::get(ArrayRef<Type> elementTypes, MLIRContext *context) {
  return Base::get(context, StandardTypes::Tuple, elementTypes);
}

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64)
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

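/// Build the strided linear layout map for the given strides and offset. For
/// example, strides [4, 1] with offset 0 produce (d0, d1) -> (d0 * 4 + d1);
/// each dynamic stride or offset instead introduces a fresh symbol, so a
/// dynamic offset with strides [?, 1] produces a map equivalent to
/// (d0, d1)[s0, s1] -> (d0 * s1 + d1 + s0).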
AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (auto en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
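/// For example, memref<4x5xf32, (d0, d1) -> (d0 * 5 + d1)> canonicalizes to
/// the plain memref<4x5xf32>, since its layout is the canonical contiguous
/// strided form for the 4x5 shape.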
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  auto affineMaps = t.getAffineMaps();
  // Already in canonical form.
  if (affineMaps.empty())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto m = affineMaps[0];
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
  return MemRefType::Builder(t).setAffineMaps({});
}

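/// Build the canonical strided layout expression for the given sizes, using
/// `exprs` as the per-dimension index expressions. Strides are the running
/// products of the sizes from the innermost dimension outwards; once a
/// dynamic size is encountered, the remaining (outer) strides become
/// symbolic. For example, sizes [3, 4, 5] yield d0 * 20 + d1 * 5 + d2.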
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  AffineExpr expr;
  bool dynamicPoisonBit = false;
  unsigned numDims = 0;
  unsigned nSymbols = 0;
  // Compute the number of symbols and dimensions of the passed exprs.
  for (AffineExpr expr : exprs) {
    expr.walk([&numDims, &nSymbols](AffineExpr d) {
      if (AffineDimExpr dim = d.dyn_cast<AffineDimExpr>())
        numDims = std::max(numDims, dim.getPosition() + 1);
      else if (AffineSymbolExpr symbol = d.dyn_cast<AffineSymbolExpr>())
        nSymbols = std::max(nSymbols, symbol.getPosition() + 1);
    });
  }
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case, no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0)
      runningSize *= size;
    else
      dynamicPoisonBit = true;
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
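/// For example, applied to memref<4x5xf32> this produces a layout map
/// equivalent to (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0).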
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
      SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> stridesAndOffset;
  auto res = getStridesAndOffset(t, stridesAndOffset, offset);
  return succeeded(res);
}