//===- StandardTypes.cpp - MLIR Standard Type Classes ---------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/IR/StandardTypes.h"
#include "TypeDetail.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Dialect.h"
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/Twine.h"

using namespace mlir;
using namespace mlir::detail;

//===----------------------------------------------------------------------===//
// Type
//===----------------------------------------------------------------------===//

bool Type::isBF16() { return isa<BFloat16Type>(); }
bool Type::isF16() { return isa<Float16Type>(); }
bool Type::isF32() { return isa<Float32Type>(); }
bool Type::isF64() { return isa<Float64Type>(); }

bool Type::isIndex() { return isa<IndexType>(); }

/// Return true if this is an integer type with the specified width.
bool Type::isInteger(unsigned width) {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.getWidth() == width;
  return false;
}

bool Type::isSignlessInteger() {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isSignless();
  return false;
}

bool Type::isSignlessInteger(unsigned width) {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isSignless() && intTy.getWidth() == width;
  return false;
}

bool Type::isSignedInteger() {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isSigned();
  return false;
}

bool Type::isSignedInteger(unsigned width) {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isSigned() && intTy.getWidth() == width;
  return false;
}

bool Type::isUnsignedInteger() {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isUnsigned();
  return false;
}

bool Type::isUnsignedInteger(unsigned width) {
  if (auto intTy = dyn_cast<IntegerType>())
    return intTy.isUnsigned() && intTy.getWidth() == width;
  return false;
}

bool Type::isSignlessIntOrIndex() {
  return isSignlessInteger() || isa<IndexType>();
}

bool Type::isSignlessIntOrIndexOrFloat() {
  return isSignlessInteger() || isa<IndexType, FloatType>();
}

bool Type::isSignlessIntOrFloat() {
  return isSignlessInteger() || isa<FloatType>();
}

bool Type::isIntOrIndex() { return isa<IntegerType>() || isIndex(); }

bool Type::isIntOrFloat() { return isa<IntegerType, FloatType>(); }

bool Type::isIntOrIndexOrFloat() { return isIntOrFloat() || isIndex(); }

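// Illustrative usage of the predicates above and the bitwidth query below (a
// sketch only; `ctx` is an assumed MLIRContext *):
//   Type i32 = IntegerType::get(32, ctx); // signless by default
//   assert(i32.isSignlessInteger(32) && i32.getIntOrFloatBitWidth() == 32);
//   assert(FloatType::getF64(ctx).getIntOrFloatBitWidth() == 64);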
unsigned Type::getIntOrFloatBitWidth() {
  assert(isIntOrFloat() && "only integers and floats have a bitwidth");
  if (auto intType = dyn_cast<IntegerType>())
    return intType.getWidth();
  return cast<FloatType>().getWidth();
}

//===----------------------------------------------------------------------===//
// ComplexType
//===----------------------------------------------------------------------===//

ComplexType ComplexType::get(Type elementType) {
  return Base::get(elementType.getContext(), elementType);
}

ComplexType ComplexType::getChecked(Type elementType, Location location) {
  return Base::getChecked(location, elementType);
}

/// Verify the construction of a complex type.
LogicalResult ComplexType::verifyConstructionInvariants(Location loc,
                                                        Type elementType) {
  if (!elementType.isIntOrFloat())
    return emitError(loc, "invalid element type for complex");
  return success();
}

Type ComplexType::getElementType() { return getImpl()->elementType; }

//===----------------------------------------------------------------------===//
// Integer Type
//===----------------------------------------------------------------------===//

// A static constexpr member needs an out-of-line definition until C++17's
// inline variables.
constexpr unsigned IntegerType::kMaxWidth;

/// Verify the construction of an integer type.
LogicalResult
IntegerType::verifyConstructionInvariants(Location loc, unsigned width,
                                          SignednessSemantics signedness) {
  if (width > IntegerType::kMaxWidth) {
    return emitError(loc) << "integer bitwidth is limited to "
                          << IntegerType::kMaxWidth << " bits";
  }
  return success();
}

unsigned IntegerType::getWidth() const { return getImpl()->width; }

IntegerType::SignednessSemantics IntegerType::getSignedness() const {
  return getImpl()->signedness;
}

//===----------------------------------------------------------------------===//
// Float Type
//===----------------------------------------------------------------------===//

unsigned FloatType::getWidth() {
  if (isa<Float16Type, BFloat16Type>())
    return 16;
  if (isa<Float32Type>())
    return 32;
  if (isa<Float64Type>())
    return 64;
  llvm_unreachable("unexpected float type");
}

/// Returns the floating semantics for the given type.
const llvm::fltSemantics &FloatType::getFloatSemantics() {
  if (isa<BFloat16Type>())
    return APFloat::BFloat();
  if (isa<Float16Type>())
    return APFloat::IEEEhalf();
  if (isa<Float32Type>())
    return APFloat::IEEEsingle();
  if (isa<Float64Type>())
    return APFloat::IEEEdouble();
  llvm_unreachable("non-floating point type used");
}

//===----------------------------------------------------------------------===//
// ShapedType
//===----------------------------------------------------------------------===//
constexpr int64_t ShapedType::kDynamicSize;
constexpr int64_t ShapedType::kDynamicStrideOrOffset;

Type ShapedType::getElementType() const {
  return static_cast<ImplType *>(impl)->elementType;
}

unsigned ShapedType::getElementTypeBitWidth() const {
  return getElementType().getIntOrFloatBitWidth();
}

int64_t ShapedType::getNumElements() const {
  assert(hasStaticShape() && "cannot get element count of dynamic shaped type");
  auto shape = getShape();
  int64_t num = 1;
  for (auto dim : shape)
    num *= dim;
  return num;
}

int64_t ShapedType::getRank() const { return getShape().size(); }

bool ShapedType::hasRank() const {
  return !isa<UnrankedMemRefType, UnrankedTensorType>();
}

int64_t ShapedType::getDimSize(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return getShape()[idx];
}

bool ShapedType::isDynamicDim(unsigned idx) const {
  assert(idx < getRank() && "invalid index for shaped type");
  return isDynamic(getShape()[idx]);
}

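// Illustrative semantics (a sketch, writing `?` for a dynamic dimension): for
// a type shaped <?x3x?x5>, getDynamicDimIndex(2) returns 1, because exactly
// one dynamic dimension (index 0) precedes index 2.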
unsigned ShapedType::getDynamicDimIndex(unsigned index) const {
  assert(index < getRank() && "invalid index");
  assert(ShapedType::isDynamic(getDimSize(index)) && "invalid index");
  return llvm::count_if(getShape().take_front(index), ShapedType::isDynamic);
}

/// Get the number of bits required to store a value of the given shaped type.
/// Compute the value recursively since tensors are allowed to have vectors as
/// elements.
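/// For example (illustrative): a tensor<4x8xvector<16xf32>> takes
/// 4 * 8 * (16 * 32) bits, the vector contribution coming from the recursive
/// getSizeInBits() call on the element type.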
int64_t ShapedType::getSizeInBits() const {
  assert(hasStaticShape() &&
         "cannot get the bit size of an aggregate with a dynamic shape");

  auto elementType = getElementType();
  if (elementType.isIntOrFloat())
    return elementType.getIntOrFloatBitWidth() * getNumElements();

  if (auto complexType = elementType.dyn_cast<ComplexType>()) {
    elementType = complexType.getElementType();
    return elementType.getIntOrFloatBitWidth() * getNumElements() * 2;
  }

  // Tensors can have vectors and other tensors as elements, other shaped types
  // cannot.
  assert(isa<TensorType>() && "unsupported element type");
  assert((elementType.isa<VectorType, TensorType>()) &&
         "unsupported tensor element type");
  return getNumElements() * elementType.cast<ShapedType>().getSizeInBits();
}

ArrayRef<int64_t> ShapedType::getShape() const {
  if (auto vectorType = dyn_cast<VectorType>())
    return vectorType.getShape();
  if (auto tensorType = dyn_cast<RankedTensorType>())
    return tensorType.getShape();
  return cast<MemRefType>().getShape();
}

int64_t ShapedType::getNumDynamicDims() const {
  return llvm::count_if(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape() const {
  return hasRank() && llvm::none_of(getShape(), isDynamic);
}

bool ShapedType::hasStaticShape(ArrayRef<int64_t> shape) const {
  return hasStaticShape() && getShape() == shape;
}

//===----------------------------------------------------------------------===//
// VectorType
//===----------------------------------------------------------------------===//

VectorType VectorType::get(ArrayRef<int64_t> shape, Type elementType) {
  return Base::get(elementType.getContext(), shape, elementType);
}

VectorType VectorType::getChecked(ArrayRef<int64_t> shape, Type elementType,
                                  Location location) {
  return Base::getChecked(location, shape, elementType);
}

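// The checks below reject, e.g. (illustrative), an empty shape, a shape
// containing 0 or -1, and an index element type: vectors must have a static,
// positive shape with an int or float element type.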
LogicalResult VectorType::verifyConstructionInvariants(Location loc,
                                                       ArrayRef<int64_t> shape,
                                                       Type elementType) {
  if (shape.empty())
    return emitError(loc, "vector types must have at least one dimension");

  if (!isValidElementType(elementType))
    return emitError(loc, "vector elements must be int or float type");

  if (any_of(shape, [](int64_t i) { return i <= 0; }))
    return emitError(loc, "vector types must have positive constant sizes");

  return success();
}

ArrayRef<int64_t> VectorType::getShape() const { return getImpl()->getShape(); }

//===----------------------------------------------------------------------===//
// TensorType
//===----------------------------------------------------------------------===//

// Check if "elementType" can be an element type of a tensor; emit an error at
// the given location and return failure if not.
static LogicalResult checkTensorElementType(Location location,
                                            Type elementType) {
  if (!TensorType::isValidElementType(elementType))
    return emitError(location, "invalid tensor element type: ") << elementType;
  return success();
}

/// Return true if the specified element type is ok in a tensor.
bool TensorType::isValidElementType(Type type) {
  // Note: Non standard/builtin types are allowed to exist within tensor
  // types. Dialects are expected to verify that tensor types have a valid
  // element type within that dialect.
  return type.isa<ComplexType, FloatType, IntegerType, OpaqueType, VectorType,
                  IndexType>() ||
         !type.getDialect().getNamespace().empty();
}

//===----------------------------------------------------------------------===//
// RankedTensorType
//===----------------------------------------------------------------------===//

RankedTensorType RankedTensorType::get(ArrayRef<int64_t> shape,
                                       Type elementType) {
  return Base::get(elementType.getContext(), shape, elementType);
}

RankedTensorType RankedTensorType::getChecked(ArrayRef<int64_t> shape,
                                              Type elementType,
                                              Location location) {
  return Base::getChecked(location, shape, elementType);
}

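// Note that -1 encodes a dynamic dimension: a shape of {-1, 4} corresponds to
// tensor<?x4xf32> (illustrative), while sizes below -1 are rejected.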
LogicalResult RankedTensorType::verifyConstructionInvariants(
    Location loc, ArrayRef<int64_t> shape, Type elementType) {
  for (int64_t s : shape) {
    if (s < -1)
      return emitError(loc, "invalid tensor dimension size");
  }
  return checkTensorElementType(loc, elementType);
}

ArrayRef<int64_t> RankedTensorType::getShape() const {
  return getImpl()->getShape();
}

//===----------------------------------------------------------------------===//
// UnrankedTensorType
//===----------------------------------------------------------------------===//

UnrankedTensorType UnrankedTensorType::get(Type elementType) {
  return Base::get(elementType.getContext(), elementType);
}

UnrankedTensorType UnrankedTensorType::getChecked(Type elementType,
                                                  Location location) {
  return Base::getChecked(location, elementType);
}

LogicalResult
UnrankedTensorType::verifyConstructionInvariants(Location loc,
                                                 Type elementType) {
  return checkTensorElementType(loc, elementType);
}

//===----------------------------------------------------------------------===//
// BaseMemRefType
//===----------------------------------------------------------------------===//

unsigned BaseMemRefType::getMemorySpace() const {
  return static_cast<ImplType *>(impl)->memorySpace;
}

//===----------------------------------------------------------------------===//
// MemRefType
//===----------------------------------------------------------------------===//

/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space.  Assumes the arguments define a
/// well-formed MemRef type.  Use getChecked to gracefully handle MemRefType
/// construction failures.
MemRefType MemRefType::get(ArrayRef<int64_t> shape, Type elementType,
                           ArrayRef<AffineMap> affineMapComposition,
                           unsigned memorySpace) {
  auto result = getImpl(shape, elementType, affineMapComposition, memorySpace,
                        /*location=*/llvm::None);
  assert(result && "Failed to construct instance of MemRefType.");
  return result;
}

/// Get or create a new MemRefType based on shape, element type, affine
/// map composition, and memory space declared at the given location.
/// If the location is unknown, the last argument should be an instance of
/// UnknownLoc.  If the MemRefType defined by the arguments would be
/// ill-formed, emits errors (to the handler registered with the context or to
/// the error stream) and returns nullptr.
MemRefType MemRefType::getChecked(ArrayRef<int64_t> shape, Type elementType,
                                  ArrayRef<AffineMap> affineMapComposition,
                                  unsigned memorySpace, Location location) {
  return getImpl(shape, elementType, affineMapComposition, memorySpace,
                 location);
}

/// Get or create a new MemRefType defined by the arguments.  If the resulting
/// type would be ill-formed, return nullptr.  If the location is provided,
/// emit detailed error messages.  To emit errors when the location is unknown,
/// pass in an instance of UnknownLoc.
MemRefType MemRefType::getImpl(ArrayRef<int64_t> shape, Type elementType,
                               ArrayRef<AffineMap> affineMapComposition,
                               unsigned memorySpace,
                               Optional<Location> location) {
  auto *context = elementType.getContext();

  if (!BaseMemRefType::isValidElementType(elementType))
    return emitOptionalError(location, "invalid memref element type"),
           MemRefType();

  for (int64_t s : shape) {
    // Negative sizes are not allowed except for `-1` that means dynamic size.
    if (s < -1)
      return emitOptionalError(location, "invalid memref size"), MemRefType();
  }

  // Check that the structure of the composition is valid, i.e. that each
  // subsequent affine map has as many inputs as the previous map has results.
  // Take the dimensionality of the MemRef for the first map.
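  // For instance (illustrative): with shape 4x8 (rank 2), the composition
  //   [(d0, d1) -> (d0 * 8 + d1), (d0) -> (d0 floordiv 2)]
  // chains correctly: the first map consumes 2 dims and produces 1 result,
  // matching the single input of the second map.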
  auto dim = shape.size();
  unsigned i = 0;
  for (const auto &affineMap : affineMapComposition) {
    if (affineMap.getNumDims() != dim) {
      if (location)
        emitError(*location)
            << "memref affine map dimension mismatch between "
            << (i == 0 ? Twine("memref rank") : "affine map " + Twine(i))
            << " and affine map " << i + 1 << ": " << dim
            << " != " << affineMap.getNumDims();
      return nullptr;
    }

    dim = affineMap.getNumResults();
    ++i;
  }

  // Drop identity maps from the composition.
  // This may lead to the composition becoming empty, which is interpreted as an
  // implicit identity.
  SmallVector<AffineMap, 2> cleanedAffineMapComposition;
  for (const auto &map : affineMapComposition) {
    if (map.isIdentity())
      continue;
    cleanedAffineMapComposition.push_back(map);
  }

  return Base::get(context, shape, elementType, cleanedAffineMapComposition,
                   memorySpace);
}

ArrayRef<int64_t> MemRefType::getShape() const { return getImpl()->getShape(); }

ArrayRef<AffineMap> MemRefType::getAffineMaps() const {
  return getImpl()->getAffineMaps();
}

//===----------------------------------------------------------------------===//
// UnrankedMemRefType
//===----------------------------------------------------------------------===//

UnrankedMemRefType UnrankedMemRefType::get(Type elementType,
                                           unsigned memorySpace) {
  return Base::get(elementType.getContext(), elementType, memorySpace);
}

UnrankedMemRefType UnrankedMemRefType::getChecked(Type elementType,
                                                  unsigned memorySpace,
                                                  Location location) {
  return Base::getChecked(location, elementType, memorySpace);
}

LogicalResult
UnrankedMemRefType::verifyConstructionInvariants(Location loc, Type elementType,
                                                 unsigned memorySpace) {
  if (!BaseMemRefType::isValidElementType(elementType))
    return emitError(loc, "invalid memref element type");
  return success();
}

// Fallback cases for terminal dim/sym/cst that are not part of a binary op
// (i.e. a single term). Accumulate the AffineExpr into the existing one.
static void extractStridesFromTerm(AffineExpr e,
                                   AffineExpr multiplicativeFactor,
                                   MutableArrayRef<AffineExpr> strides,
                                   AffineExpr &offset) {
  if (auto dim = e.dyn_cast<AffineDimExpr>())
    strides[dim.getPosition()] =
        strides[dim.getPosition()] + multiplicativeFactor;
  else
    offset = offset + e * multiplicativeFactor;
}

/// Takes a single AffineExpr `e` and populates the `strides` array with the
/// stride expressions for each dim position.
/// The convention is that the strides for dimensions d0, .. dn appear in
/// order to make indexing intuitive into the result.
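/// For example (illustrative): for e = d0 * 8 + d1 + 42 over a rank-2 shape
/// and multiplicativeFactor = 1, this computes strides = [8, 1] and
/// offset = 42.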
static LogicalResult extractStrides(AffineExpr e,
                                    AffineExpr multiplicativeFactor,
                                    MutableArrayRef<AffineExpr> strides,
                                    AffineExpr &offset) {
  auto bin = e.dyn_cast<AffineBinaryOpExpr>();
  if (!bin) {
    extractStridesFromTerm(e, multiplicativeFactor, strides, offset);
    return success();
  }

  if (bin.getKind() == AffineExprKind::CeilDiv ||
      bin.getKind() == AffineExprKind::FloorDiv ||
      bin.getKind() == AffineExprKind::Mod)
    return failure();

  if (bin.getKind() == AffineExprKind::Mul) {
    auto dim = bin.getLHS().dyn_cast<AffineDimExpr>();
    if (dim) {
      strides[dim.getPosition()] =
          strides[dim.getPosition()] + bin.getRHS() * multiplicativeFactor;
      return success();
    }
    // LHS and RHS may both contain complex expressions of dims. Try one path
    // and if it fails try the other. This is guaranteed to succeed because
    // only one path may have a `dim`, otherwise this is not an AffineExpr in
    // the first place.
    if (bin.getLHS().isSymbolicOrConstant())
      return extractStrides(bin.getRHS(), multiplicativeFactor * bin.getLHS(),
                            strides, offset);
    return extractStrides(bin.getLHS(), multiplicativeFactor * bin.getRHS(),
                          strides, offset);
  }

  if (bin.getKind() == AffineExprKind::Add) {
    auto res1 =
        extractStrides(bin.getLHS(), multiplicativeFactor, strides, offset);
    auto res2 =
        extractStrides(bin.getRHS(), multiplicativeFactor, strides, offset);
    return success(succeeded(res1) && succeeded(res2));
  }

  llvm_unreachable("unexpected binary operation");
}

LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<AffineExpr> &strides,
                                        AffineExpr &offset) {
  auto affineMaps = t.getAffineMaps();
  // For now strides are only computed on a single affine map with a single
  // result (i.e. the closed subset of linearization maps that are compatible
  // with striding semantics).
  // TODO: support more forms on a per-need basis.
  if (affineMaps.size() > 1)
    return failure();
  if (affineMaps.size() == 1 && affineMaps[0].getNumResults() != 1)
    return failure();

  auto zero = getAffineConstantExpr(0, t.getContext());
  auto one = getAffineConstantExpr(1, t.getContext());
  offset = zero;
  strides.assign(t.getRank(), zero);

  AffineMap m;
  if (!affineMaps.empty()) {
    m = affineMaps.front();
    assert(!m.isIdentity() && "unexpected identity map");
  }

  // Canonical case for empty map.
  if (!m) {
    // 0-D corner case, offset is already 0.
    if (t.getRank() == 0)
      return success();
    auto stridedExpr =
        makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
    if (succeeded(extractStrides(stridedExpr, one, strides, offset)))
      return success();
    llvm_unreachable("unexpected failure: extract strides in canonical layout");
  }

  // Non-canonical case requires more work.
  auto stridedExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (failed(extractStrides(stridedExpr, one, strides, offset))) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  // Simplify results to allow folding to constants and simple checks.
  unsigned numDims = m.getNumDims();
  unsigned numSymbols = m.getNumSymbols();
  offset = simplifyAffineExpr(offset, numDims, numSymbols);
  for (auto &stride : strides)
    stride = simplifyAffineExpr(stride, numDims, numSymbols);

  /// In practice, a strided memref must be internally non-aliasing. Test
  /// against 0 as a proxy.
  /// TODO: static cases can have more advanced checks.
  /// TODO: dynamic cases would require a way to compare symbolic
  /// expressions and would probably need an affine set context propagated
  /// everywhere.
  if (llvm::any_of(strides, [](AffineExpr e) {
        return e == getAffineConstantExpr(0, e.getContext());
      })) {
    offset = AffineExpr();
    strides.clear();
    return failure();
  }

  return success();
}

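// Sketch of the expected outcome (illustrative): for a type such as
//   memref<4x?xf32, (d0, d1)[s0] -> (d0 * s0 + d1 + 4)>
// this overload yields strides = {ShapedType::kDynamicStrideOrOffset, 1} and
// offset = 4.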
LogicalResult mlir::getStridesAndOffset(MemRefType t,
                                        SmallVectorImpl<int64_t> &strides,
                                        int64_t &offset) {
  AffineExpr offsetExpr;
  SmallVector<AffineExpr, 4> strideExprs;
  if (failed(::getStridesAndOffset(t, strideExprs, offsetExpr)))
    return failure();
  if (auto cst = offsetExpr.dyn_cast<AffineConstantExpr>())
    offset = cst.getValue();
  else
    offset = ShapedType::kDynamicStrideOrOffset;
  for (auto e : strideExprs) {
    if (auto c = e.dyn_cast<AffineConstantExpr>())
      strides.push_back(c.getValue());
    else
      strides.push_back(ShapedType::kDynamicStrideOrOffset);
  }
  return success();
}

//===----------------------------------------------------------------------===//
// TupleType
//===----------------------------------------------------------------------===//

/// Get or create a new TupleType with the provided element types. Assumes the
/// arguments define a well-formed type.
TupleType TupleType::get(TypeRange elementTypes, MLIRContext *context) {
  return Base::get(context, elementTypes);
}

/// Get or create an empty tuple type.
TupleType TupleType::get(MLIRContext *context) { return get({}, context); }

/// Return the element types for this tuple.
ArrayRef<Type> TupleType::getTypes() const { return getImpl()->getTypes(); }

/// Accumulate the types contained in this tuple and tuples nested within it.
/// Note that this only flattens nested tuples, not any other container type,
/// e.g. a tuple<i32, tensor<i32>, tuple<f32, tuple<i64>>> is flattened to
/// (i32, tensor<i32>, f32, i64).
void TupleType::getFlattenedTypes(SmallVectorImpl<Type> &types) {
  for (Type type : getTypes()) {
    if (auto nestedTuple = type.dyn_cast<TupleType>())
      nestedTuple.getFlattenedTypes(types);
    else
      types.push_back(type);
  }
}

/// Return the number of element types.
size_t TupleType::size() const { return getImpl()->size(); }

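// For example (illustrative): strides = [8, 1] with offset = 0 yields a map
// equivalent to (d0, d1) -> (d0 * 8 + d1), while strides =
// [kDynamicStrideOrOffset, 1] with offset = 4 yields
// (d0, d1)[s0] -> (d0 * s0 + d1 + 4).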
AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef<int64_t> strides,
                                           int64_t offset,
                                           MLIRContext *context) {
  AffineExpr expr;
  unsigned nSymbols = 0;

  // AffineExpr for offset.
  // Static case.
  if (offset != MemRefType::getDynamicStrideOrOffset()) {
    auto cst = getAffineConstantExpr(offset, context);
    expr = cst;
  } else {
    // Dynamic case, new symbol for the offset.
    auto sym = getAffineSymbolExpr(nSymbols++, context);
    expr = sym;
  }

  // AffineExpr for strides.
  for (auto en : llvm::enumerate(strides)) {
    auto dim = en.index();
    auto stride = en.value();
    assert(stride != 0 && "Invalid stride specification");
    auto d = getAffineDimExpr(dim, context);
    AffineExpr mult;
    // Static case.
    if (stride != MemRefType::getDynamicStrideOrOffset())
      mult = getAffineConstantExpr(stride, context);
    else
      // Dynamic case, new symbol for each new stride.
      mult = getAffineSymbolExpr(nSymbols++, context);
    expr = expr + d * mult;
  }

  return AffineMap::get(strides.size(), nSymbols, expr);
}

/// Return a version of `t` with identity layout if it can be determined
/// statically that the layout is the canonical contiguous strided layout.
/// Otherwise pass `t`'s layout into `simplifyAffineMap` and return a copy of
/// `t` with simplified layout.
/// If `t` has multiple layout maps or a multi-result layout, just return `t`.
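/// For example (illustrative): memref<4x8xf32, (d0, d1) -> (d0 * 8 + d1)>
/// canonicalizes to a plain memref<4x8xf32>, since its layout equals the
/// canonical strided layout for shape 4x8.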
MemRefType mlir::canonicalizeStridedLayout(MemRefType t) {
  auto affineMaps = t.getAffineMaps();
  // Already in canonical form.
  if (affineMaps.empty())
    return t;

  // Can't reduce to canonical identity form, return in canonical form.
  if (affineMaps.size() > 1 || affineMaps[0].getNumResults() > 1)
    return t;

  // If the canonical strided layout for the sizes of `t` is equal to the
  // simplified layout of `t` we can just return an empty layout. Otherwise,
  // just simplify the existing layout.
  AffineExpr expr =
      makeCanonicalStridedLayoutExpr(t.getShape(), t.getContext());
  auto m = affineMaps[0];
  auto simplifiedLayoutExpr =
      simplifyAffineExpr(m.getResult(0), m.getNumDims(), m.getNumSymbols());
  if (expr != simplifiedLayoutExpr)
    return MemRefType::Builder(t).setAffineMaps({AffineMap::get(
        m.getNumDims(), m.getNumSymbols(), simplifiedLayoutExpr)});
  return MemRefType::Builder(t).setAffineMaps({});
}

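// Illustrative expectation: sizes = [4, 8] with exprs = [d0, d1] produces an
// expression equivalent to d0 * 8 + d1; a dynamic size (-1) turns the strides
// of all dimensions to its left into fresh symbols.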
AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                ArrayRef<AffineExpr> exprs,
                                                MLIRContext *context) {
  // Size 0 corner case is useful for canonicalizations.
  if (llvm::is_contained(sizes, 0))
    return getAffineConstantExpr(0, context);

  auto maps = AffineMap::inferFromExprList(exprs);
  assert(!maps.empty() && "Expected one non-empty map");
  unsigned numDims = maps[0].getNumDims(), nSymbols = maps[0].getNumSymbols();

  AffineExpr expr;
  bool dynamicPoisonBit = false;
  int64_t runningSize = 1;
  for (auto en : llvm::zip(llvm::reverse(exprs), llvm::reverse(sizes))) {
    int64_t size = std::get<1>(en);
    // Degenerate case, no size -> no stride.
    if (size == 0)
      continue;
    AffineExpr dimExpr = std::get<0>(en);
    AffineExpr stride = dynamicPoisonBit
                            ? getAffineSymbolExpr(nSymbols++, context)
                            : getAffineConstantExpr(runningSize, context);
    expr = expr ? expr + dimExpr * stride : dimExpr * stride;
    if (size > 0)
      runningSize *= size;
    else
      dynamicPoisonBit = true;
  }
  return simplifyAffineExpr(expr, numDims, nSymbols);
}

/// Return a version of `t` with a layout that has all dynamic offset and
/// strides. This is used to erase the static layout.
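/// For example (illustrative): memref<4x8xf32> becomes roughly
/// memref<4x8xf32, (d0, d1)[s0, s1, s2] -> (d0 * s1 + d1 * s2 + s0)>, with s0
/// the erased offset and s1, s2 the erased strides.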
MemRefType mlir::eraseStridedLayout(MemRefType t) {
  auto val = ShapedType::kDynamicStrideOrOffset;
  return MemRefType::Builder(t).setAffineMaps(makeStridedLinearLayoutMap(
      SmallVector<int64_t, 4>(t.getRank(), val), val, t.getContext()));
}

AffineExpr mlir::makeCanonicalStridedLayoutExpr(ArrayRef<int64_t> sizes,
                                                MLIRContext *context) {
  SmallVector<AffineExpr, 4> exprs;
  exprs.reserve(sizes.size());
  for (auto dim : llvm::seq<unsigned>(0, sizes.size()))
    exprs.push_back(getAffineDimExpr(dim, context));
  return makeCanonicalStridedLayoutExpr(sizes, exprs, context);
}

/// Return true if the layout for `t` is compatible with strided semantics.
bool mlir::isStrided(MemRefType t) {
  int64_t offset;
  SmallVector<int64_t, 4> stridesAndOffset;
  auto res = getStridesAndOffset(t, stridesAndOffset, offset);
  return succeeded(res);
}