1 //===----------------------------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/MemRef/IR/MemRef.h"
10 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/StandardOps/Utils/Utils.h"
13 #include "mlir/Dialect/Tensor/IR/Tensor.h"
14 #include "mlir/Dialect/Utils/StaticValueUtils.h"
15 #include "mlir/IR/AffineMap.h"
16 #include "mlir/IR/Builders.h"
17 #include "mlir/IR/BuiltinTypes.h"
18 #include "mlir/IR/Matchers.h"
19 #include "mlir/IR/PatternMatch.h"
20 #include "mlir/IR/TypeUtilities.h"
21 #include "mlir/Interfaces/InferTypeOpInterface.h"
22 #include "mlir/Interfaces/ViewLikeInterface.h"
23 #include "llvm/ADT/STLExtras.h"
24 
25 using namespace mlir;
26 using namespace mlir::memref;
27 
28 /// Materialize a single constant operation from a given attribute value with
29 /// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
31                                               Attribute value, Type type,
32                                               Location loc) {
33   return builder.create<mlir::ConstantOp>(loc, type, value);
34 }
35 
36 //===----------------------------------------------------------------------===//
37 // Common canonicalization pattern support logic
38 //===----------------------------------------------------------------------===//
39 
40 /// This is a common class used for patterns of the form
41 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
42 /// into the root operation directly.
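///
/// As a rough illustration (the IR below is schematic, not taken from a
/// test), this lets a user such as `memref.dealloc` fold:
///
/// ```mlir
///   %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
///   memref.dealloc %0 : memref<?xf32>
/// ```
///
/// into `memref.dealloc %arg : memref<8xf32>`, as long as the cast source is
/// a ranked memref.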
static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
44   bool folded = false;
45   for (OpOperand &operand : op->getOpOperands()) {
46     auto cast = operand.get().getDefiningOp<CastOp>();
47     if (cast && operand.get() != inner &&
48         !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
49       operand.set(cast.getOperand());
50       folded = true;
51     }
52   }
53   return success(folded);
54 }
55 
56 //===----------------------------------------------------------------------===//
57 // Helpers for GlobalOp
58 //===----------------------------------------------------------------------===//
59 
static Type getTensorTypeFromMemRefType(Type type) {
61   if (auto memref = type.dyn_cast<MemRefType>())
62     return RankedTensorType::get(memref.getShape(), memref.getElementType());
63   if (auto memref = type.dyn_cast<UnrankedMemRefType>())
64     return UnrankedTensorType::get(memref.getElementType());
65   return NoneType::get(type.getContext());
66 }
67 
68 //===----------------------------------------------------------------------===//
69 // AllocOp / AllocaOp
70 //===----------------------------------------------------------------------===//
71 
72 template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
74   static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
75                 "applies to only alloc or alloca");
76   auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
77   if (!memRefType)
78     return op.emitOpError("result must be a memref");
79 
80   if (static_cast<int64_t>(op.dynamicSizes().size()) !=
81       memRefType.getNumDynamicDims())
82     return op.emitOpError("dimension operand count does not equal memref "
83                           "dynamic dimension count");
84 
85   unsigned numSymbols = 0;
86   if (!memRefType.getAffineMaps().empty())
87     numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
88   if (op.symbolOperands().size() != numSymbols)
89     return op.emitOpError("symbol operand count does not equal memref symbol "
90                           "count: expected ")
91            << numSymbols << ", got " << op.symbolOperands().size();
92 
93   return success();
94 }
95 
static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }
97 
static LogicalResult verify(AllocaOp op) {
99   // An alloca op needs to have an ancestor with an allocation scope trait.
100   if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
101     return op.emitOpError(
102         "requires an ancestor op with AutomaticAllocationScope trait");
103 
104   return verifyAllocLikeOp(op);
105 }
106 
107 namespace {
108 /// Fold constant dimensions into an alloc like operation.
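///
/// A minimal sketch of the rewrite (values and types are illustrative):
///
/// ```mlir
///   %c8 = constant 8 : index
///   %0 = memref.alloc(%c8, %n) : memref<?x?xf32>
/// ```
///
/// becomes
///
/// ```mlir
///   %0 = memref.alloc(%n) : memref<8x?xf32>
///   %1 = memref.cast %0 : memref<8x?xf32> to memref<?x?xf32>
/// ```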
109 template <typename AllocLikeOp>
110 struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
111   using OpRewritePattern<AllocLikeOp>::OpRewritePattern;
112 
  LogicalResult matchAndRewrite(AllocLikeOp alloc,
114                                 PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants.  If so, we can
    // substitute and drop them.
117     if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
118           return matchPattern(operand, matchConstantIndex());
119         }))
120       return failure();
121 
122     auto memrefType = alloc.getType();
123 
124     // Ok, we have one or more constant operands.  Collect the non-constant ones
125     // and keep track of the resultant memref type to build.
126     SmallVector<int64_t, 4> newShapeConstants;
127     newShapeConstants.reserve(memrefType.getRank());
128     SmallVector<Value, 4> dynamicSizes;
129 
130     unsigned dynamicDimPos = 0;
131     for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
132       int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
134       if (dimSize != -1) {
135         newShapeConstants.push_back(dimSize);
136         continue;
137       }
138       auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
139       auto *defOp = dynamicSize.getDefiningOp();
140       if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
141         // Dynamic shape dimension will be folded.
142         newShapeConstants.push_back(constantIndexOp.getValue());
143       } else {
144         // Dynamic shape dimension not folded; copy dynamicSize from old memref.
145         newShapeConstants.push_back(-1);
146         dynamicSizes.push_back(dynamicSize);
147       }
148       dynamicDimPos++;
149     }
150 
151     // Create new memref type (which will have fewer dynamic dimensions).
152     MemRefType newMemRefType =
153         MemRefType::Builder(memrefType).setShape(newShapeConstants);
154     assert(static_cast<int64_t>(dynamicSizes.size()) ==
155            newMemRefType.getNumDynamicDims());
156 
157     // Create and insert the alloc op for the new memref.
158     auto newAlloc = rewriter.create<AllocLikeOp>(
159         alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
160         alloc.alignmentAttr());
161     // Insert a cast so we have the same type as the old alloc.
162     auto resultCast =
163         rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());
164 
165     rewriter.replaceOp(alloc, {resultCast});
166     return success();
167   }
168 };
169 
170 /// Fold alloc operations with no users or only store and dealloc uses.
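///
/// For instance (illustrative), the following allocation has no observable
/// effect and is erased together with its store and dealloc uses:
///
/// ```mlir
///   %0 = memref.alloc() : memref<4xf32>
///   memref.store %v, %0[%i] : memref<4xf32>
///   memref.dealloc %0 : memref<4xf32>
/// ```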
171 template <typename T>
172 struct SimplifyDeadAlloc : public OpRewritePattern<T> {
173   using OpRewritePattern<T>::OpRewritePattern;
174 
  LogicalResult matchAndRewrite(T alloc,
176                                 PatternRewriter &rewriter) const override {
177     if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
178           if (auto storeOp = dyn_cast<StoreOp>(op))
179             return storeOp.value() == alloc;
180           return !isa<DeallocOp>(op);
181         }))
182       return failure();
183 
184     for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
185       rewriter.eraseOp(user);
186 
187     rewriter.eraseOp(alloc);
188     return success();
189   }
190 };
191 } // end anonymous namespace.
192 
Optional<Operation *> AllocOp::buildDealloc(OpBuilder &builder, Value alloc) {
194   return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
195       .getOperation();
196 }
197 
Optional<Value> AllocOp::buildClone(OpBuilder &builder, Value alloc) {
199   return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult();
200 }
201 
void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
203                                           MLIRContext *context) {
204   results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
205 }
206 
void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
208                                            MLIRContext *context) {
209   results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
210       context);
211 }
212 
213 //===----------------------------------------------------------------------===//
214 // AllocaScopeOp
215 //===----------------------------------------------------------------------===//
216 
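/// The custom form printed/parsed below looks roughly like (illustrative):
///
/// ```mlir
///   memref.alloca_scope {
///     %0 = memref.alloca() : memref<4xf32>
///     ...
///   }
/// ```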
static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
218   bool printBlockTerminators = false;
219 
220   p << " ";
221   if (!op.results().empty()) {
222     p << " -> (" << op.getResultTypes() << ")";
223     printBlockTerminators = true;
224   }
225   p.printRegion(op.bodyRegion(),
226                 /*printEntryBlockArgs=*/false,
227                 /*printBlockTerminators=*/printBlockTerminators);
228   p.printOptionalAttrDict(op->getAttrs());
229 }
230 
static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
232                                       OperationState &result) {
233   // Create a region for the body.
234   result.regions.reserve(1);
235   Region *bodyRegion = result.addRegion();
236 
237   // Parse optional results type list.
238   if (parser.parseOptionalArrowTypeList(result.types))
239     return failure();
240 
241   // Parse the body region.
242   if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
243     return failure();
244   AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
245                                   result.location);
246 
247   // Parse the optional attribute list.
248   if (parser.parseOptionalAttrDict(result.attributes))
249     return failure();
250 
251   return success();
252 }
253 
static LogicalResult verify(AllocaScopeOp op) {
255   if (failed(RegionBranchOpInterface::verifyTypes(op)))
256     return failure();
257 
258   return success();
259 }
260 
void AllocaScopeOp::getSuccessorRegions(
262     Optional<unsigned> index, ArrayRef<Attribute> operands,
263     SmallVectorImpl<RegionSuccessor> &regions) {
264   if (index.hasValue()) {
265     regions.push_back(RegionSuccessor(getResults()));
266     return;
267   }
268 
269   regions.push_back(RegionSuccessor(&bodyRegion()));
270 }
271 
272 //===----------------------------------------------------------------------===//
273 // AssumeAlignmentOp
274 //===----------------------------------------------------------------------===//
275 
static LogicalResult verify(AssumeAlignmentOp op) {
277   unsigned alignment = op.alignment();
278   if (!llvm::isPowerOf2_32(alignment))
279     return op.emitOpError("alignment must be power of 2");
280   return success();
281 }
282 
283 //===----------------------------------------------------------------------===//
284 // BufferCastOp
285 //===----------------------------------------------------------------------===//
286 
OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
288   if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
289     if (tensorLoad.memref().getType() == getType())
290       return tensorLoad.memref();
291   return {};
292 }
293 
294 namespace {
295 /// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
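///
/// Sketch of the intended rewrite (types are illustrative):
///
/// ```mlir
///   %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
///   %1 = memref.buffer_cast %0 : memref<?xf32>
/// ```
///
/// becomes
///
/// ```mlir
///   %0 = memref.buffer_cast %t : memref<4xf32>
///   %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
/// ```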
296 struct BufferCast : public OpRewritePattern<BufferCastOp> {
297   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
298 
  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
300                                 PatternRewriter &rewriter) const final {
301     auto tensorCastOperand =
302         bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
303     if (!tensorCastOperand)
304       return failure();
305     auto srcTensorType =
306         tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
307     if (!srcTensorType)
308       return failure();
309     auto memrefType = MemRefType::get(srcTensorType.getShape(),
310                                       srcTensorType.getElementType());
311     Value memref = rewriter.create<BufferCastOp>(
312         bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
313     rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
314                                         memref);
315     return success();
316   }
317 };
318 
/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` from kicking in.
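///
/// Roughly (illustrative types):
///
/// ```mlir
///   %0 = memref.tensor_load %m : memref<?xf32>
///   %1 = memref.buffer_cast %0 : memref<4xf32>
/// ```
///
/// becomes a single `memref.cast %m`, or an alloc + copy when the cast cannot
/// be guaranteed to be valid (e.g. a dynamic-to-static stride change).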
321 struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
322   using OpRewritePattern<BufferCastOp>::OpRewritePattern;
323 
  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
325                                 PatternRewriter &rewriter) const final {
326     auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
327     // Bail unless we have a tensor_load + memref.buffer_cast with different
328     // types. `BufferCastOp::fold` handles the same type case.
329     if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
330       return failure();
331     // If types are definitely not cast-compatible, bail.
332     if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
333                                    bufferCast.getType()))
334       return failure();
335 
    // We already know that the types are potentially cast-compatible. However,
    // if the affine maps differ, we may need to use a copy when going from a
    // dynamic to a static offset or stride (the canonicalization cannot know
    // at this point whether the cast is really valid).
340     auto isGuaranteedCastCompatible = [](MemRefType source, MemRefType target) {
341       int64_t sourceOffset, targetOffset;
342       SmallVector<int64_t, 4> sourceStrides, targetStrides;
343       if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
344           failed(getStridesAndOffset(target, targetStrides, targetOffset)))
345         return false;
346       auto dynamicToStatic = [](int64_t a, int64_t b) {
347         return a == MemRefType::getDynamicStrideOrOffset() &&
348                b != MemRefType::getDynamicStrideOrOffset();
349       };
350       if (dynamicToStatic(sourceOffset, targetOffset))
351         return false;
352       for (auto it : zip(sourceStrides, targetStrides))
353         if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
354           return false;
355       return true;
356     };
357 
358     auto tensorLoadType = tensorLoad.memref().getType().dyn_cast<MemRefType>();
359     auto bufferCastType = bufferCast.getType().dyn_cast<MemRefType>();
360     if (tensorLoadType && bufferCastType &&
361         !isGuaranteedCastCompatible(tensorLoadType, bufferCastType)) {
362       MemRefType resultType = bufferCastType;
363       auto loc = bufferCast.getLoc();
364       SmallVector<Value, 4> dynamicOperands;
365       for (int i = 0; i < resultType.getRank(); ++i) {
366         if (resultType.getShape()[i] != ShapedType::kDynamicSize)
367           continue;
368         auto index = rewriter.createOrFold<ConstantIndexOp>(loc, i);
369         Value size = rewriter.create<tensor::DimOp>(loc, tensorLoad, index);
370         dynamicOperands.push_back(size);
371       }
372       auto copy =
373           rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
374       rewriter.create<CopyOp>(loc, tensorLoad.memref(), copy);
375       rewriter.replaceOp(bufferCast, {copy});
376     } else
377       rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
378                                           tensorLoad.memref());
379     return success();
380   }
381 };
382 
383 } // namespace
384 
void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
386                                                MLIRContext *context) {
387   results.add<BufferCast, TensorLoadToMemRef>(context);
388 }
389 
390 //===----------------------------------------------------------------------===//
391 // CastOp
392 //===----------------------------------------------------------------------===//
393 
394 /// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
396 /// and implement canonicalization patterns for ops in different dialects that
397 /// may consume the results of memref.cast operations. Such foldable memref.cast
398 /// operations are typically inserted as `view` and `subview` ops are
399 /// canonicalized, to preserve the type compatibility of their uses.
400 ///
401 /// Returns true when all conditions are met:
402 /// 1. source and result are ranked memrefs with strided semantics and same
403 /// element type and rank.
404 /// 2. each of the source's size, offset or stride has more static information
405 /// than the corresponding result's size, offset or stride.
406 ///
407 /// Example 1:
408 /// ```mlir
409 ///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
410 ///   %2 = consumer %1 ... : memref<?x?xf32> ...
411 /// ```
412 ///
413 /// may fold into:
414 ///
415 /// ```mlir
416 ///   %2 = consumer %0 ... : memref<8x16xf32> ...
417 /// ```
418 ///
419 /// Example 2:
420 /// ```
421 ///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
422 ///          to memref<?x?xf32>
423 ///   consumer %1 : memref<?x?xf32> ...
424 /// ```
425 ///
426 /// may fold into:
427 ///
428 /// ```
429 ///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
430 /// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
432   MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
433   MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();
434 
435   // Requires ranked MemRefType.
436   if (!sourceType || !resultType)
437     return false;
438 
439   // Requires same elemental type.
440   if (sourceType.getElementType() != resultType.getElementType())
441     return false;
442 
443   // Requires same rank.
444   if (sourceType.getRank() != resultType.getRank())
445     return false;
446 
447   // Only fold casts between strided memref forms.
448   int64_t sourceOffset, resultOffset;
449   SmallVector<int64_t, 4> sourceStrides, resultStrides;
450   if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
451       failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
452     return false;
453 
454   // If cast is towards more static sizes along any dimension, don't fold.
455   for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
456     auto ss = std::get<0>(it), st = std::get<1>(it);
457     if (ss != st)
458       if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
459         return false;
460   }
461 
462   // If cast is towards more static offset along any dimension, don't fold.
463   if (sourceOffset != resultOffset)
464     if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
465         !MemRefType::isDynamicStrideOrOffset(resultOffset))
466       return false;
467 
468   // If cast is towards more static strides along any dimension, don't fold.
469   for (auto it : llvm::zip(sourceStrides, resultStrides)) {
470     auto ss = std::get<0>(it), st = std::get<1>(it);
471     if (ss != st)
472       if (MemRefType::isDynamicStrideOrOffset(ss) &&
473           !MemRefType::isDynamicStrideOrOffset(st))
474         return false;
475   }
476 
477   return true;
478 }
479 
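/// Two ranked memref types are considered cast-compatible when they have the
/// same element type and memory space, the same rank, dimensions that agree
/// wherever both are static, and offsets/strides that agree wherever both are
/// static. For instance (illustrative), memref<4x?xf32> and memref<?x8xf32>
/// are compatible, while memref<4xf32> and memref<8xf32> are not. A ranked and
/// an unranked memref are compatible when element type and memory space match;
/// unranked-to-unranked casts are not supported.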
bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
481   if (inputs.size() != 1 || outputs.size() != 1)
482     return false;
483   Type a = inputs.front(), b = outputs.front();
484   auto aT = a.dyn_cast<MemRefType>();
485   auto bT = b.dyn_cast<MemRefType>();
486 
487   auto uaT = a.dyn_cast<UnrankedMemRefType>();
488   auto ubT = b.dyn_cast<UnrankedMemRefType>();
489 
490   if (aT && bT) {
491     if (aT.getElementType() != bT.getElementType())
492       return false;
493     if (aT.getAffineMaps() != bT.getAffineMaps()) {
494       int64_t aOffset, bOffset;
495       SmallVector<int64_t, 4> aStrides, bStrides;
496       if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
497           failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
498           aStrides.size() != bStrides.size())
499         return false;
500 
501       // Strides along a dimension/offset are compatible if the value in the
502       // source memref is static and the value in the target memref is the
503       // same. They are also compatible if either one is dynamic (see
504       // description of MemRefCastOp for details).
505       auto checkCompatible = [](int64_t a, int64_t b) {
506         return (a == MemRefType::getDynamicStrideOrOffset() ||
507                 b == MemRefType::getDynamicStrideOrOffset() || a == b);
508       };
509       if (!checkCompatible(aOffset, bOffset))
510         return false;
511       for (auto aStride : enumerate(aStrides))
512         if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
513           return false;
514     }
515     if (aT.getMemorySpace() != bT.getMemorySpace())
516       return false;
517 
518     // They must have the same rank, and any specified dimensions must match.
519     if (aT.getRank() != bT.getRank())
520       return false;
521 
522     for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
523       int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
524       if (aDim != -1 && bDim != -1 && aDim != bDim)
525         return false;
526     }
527     return true;
528   } else {
529     if (!aT && !uaT)
530       return false;
531     if (!bT && !ubT)
532       return false;
533     // Unranked to unranked casting is unsupported
534     if (uaT && ubT)
535       return false;
536 
537     auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
538     auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
539     if (aEltType != bEltType)
540       return false;
541 
542     auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
543     auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
544     if (aMemSpace != bMemSpace)
545       return false;
546 
547     return true;
548   }
549 
550   return false;
551 }
552 
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
554   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
555 }
556 
557 //===----------------------------------------------------------------------===//
558 // CloneOp
559 //===----------------------------------------------------------------------===//
560 
void CloneOp::getEffects(
562     SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
563         &effects) {
564   effects.emplace_back(MemoryEffects::Read::get(), input(),
565                        SideEffects::DefaultResource::get());
566   effects.emplace_back(MemoryEffects::Write::get(), output(),
567                        SideEffects::DefaultResource::get());
568   effects.emplace_back(MemoryEffects::Allocate::get(), output(),
569                        SideEffects::DefaultResource::get());
570 }
571 
572 namespace {
573 /// Merge the clone and its source (by converting the clone to a cast) when
574 /// possible.
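///
/// A hedged sketch of the intended rewrite:
///
/// ```mlir
///   %1 = memref.clone %0 : memref<?xf32> to memref<?xf32>
///   // ... reads of %1 ...
///   memref.dealloc %1 : memref<?xf32>
/// ```
///
/// becomes a `memref.cast %0` feeding the former users of %1, and the now
/// redundant dealloc in this block is erased, provided no other deallocation
/// may run in between.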
575 struct SimplifyClones : public OpRewritePattern<CloneOp> {
576   using OpRewritePattern<CloneOp>::OpRewritePattern;
577 
  LogicalResult matchAndRewrite(CloneOp cloneOp,
579                                 PatternRewriter &rewriter) const override {
580     if (cloneOp.use_empty()) {
581       rewriter.eraseOp(cloneOp);
582       return success();
583     }
584 
585     Value source = cloneOp.input();
586 
587     // This only finds dealloc operations for the immediate value. It should
588     // also consider aliases. That would also make the safety check below
589     // redundant.
590     llvm::Optional<Operation *> maybeCloneDeallocOp =
591         findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocate operation.
593     if (!maybeCloneDeallocOp.hasValue())
594       return failure();
595     llvm::Optional<Operation *> maybeSourceDeallocOp = findDealloc(source);
596     if (!maybeSourceDeallocOp.hasValue())
597       return failure();
598     Operation *cloneDeallocOp = *maybeCloneDeallocOp;
599     Operation *sourceDeallocOp = *maybeSourceDeallocOp;
600 
601     // If both are deallocated in the same block, their in-block lifetimes
602     // might not fully overlap, so we cannot decide which one to drop.
603     if (cloneDeallocOp && sourceDeallocOp &&
604         cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
605       return failure();
606 
607     Block *currentBlock = cloneOp->getBlock();
608     Operation *redundantDealloc = nullptr;
609     if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
610       redundantDealloc = cloneDeallocOp;
611     } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
612       redundantDealloc = sourceDeallocOp;
613     }
614 
615     if (!redundantDealloc)
616       return failure();
617 
    // Safety check that there are no other deallocations in between cloneOp
    // and redundantDealloc, as otherwise we might deallocate an alias of the
    // source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias of the
    // source.
623     for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
624          pos = pos->getNextNode()) {
625       auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
626       if (!effectInterface)
627         continue;
628       if (effectInterface.hasEffect<MemoryEffects::Free>())
629         return failure();
630     }
631 
632     rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
633                                                 source);
634     rewriter.eraseOp(redundantDealloc);
635     return success();
636   }
637 };
638 
639 } // end anonymous namespace.
640 
void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
642                                           MLIRContext *context) {
643   results.insert<SimplifyClones>(context);
644 }
645 
OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
647   return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
648 }
649 
Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
651   return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
652       .getOperation();
653 }
654 
Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
656   return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult();
657 }
658 
659 //===----------------------------------------------------------------------===//
660 // DeallocOp
661 //===----------------------------------------------------------------------===//
662 
LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
664                               SmallVectorImpl<OpFoldResult> &results) {
665   /// dealloc(memrefcast) -> dealloc
666   return foldMemRefCast(*this);
667 }
668 
669 //===----------------------------------------------------------------------===//
670 // DimOp
671 //===----------------------------------------------------------------------===//
672 
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
674                   int64_t index) {
675   auto loc = result.location;
676   Value indexValue = builder.create<ConstantIndexOp>(loc, index);
677   build(builder, result, source, indexValue);
678 }
679 
void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
681                   Value index) {
682   auto indexTy = builder.getIndexType();
683   build(builder, result, indexTy, source, index);
684 }
685 
Optional<int64_t> DimOp::getConstantIndex() {
687   if (auto constantOp = index().getDefiningOp<ConstantOp>())
688     return constantOp.getValue().cast<IntegerAttr>().getInt();
689   return {};
690 }
691 
static LogicalResult verify(DimOp op) {
693   // Assume unknown index to be in range.
694   Optional<int64_t> index = op.getConstantIndex();
695   if (!index.hasValue())
696     return success();
697 
698   // Check that constant index is not knowingly out of range.
699   auto type = op.source().getType();
700   if (auto memrefType = type.dyn_cast<MemRefType>()) {
701     if (index.getValue() >= memrefType.getRank())
702       return op.emitOpError("index is out of range");
703   } else if (type.isa<UnrankedMemRefType>()) {
704     // Assume index to be in range.
705   } else {
706     llvm_unreachable("expected operand with memref type");
707   }
708   return success();
709 }
710 
/// Return a map with keys being the elements in `vals` and values being the
/// number of occurrences of each element. Use std::map, since the `vals` here
/// are strides and the dynamic stride value is the same as the tombstone value
/// for `DenseMap<int64_t>`.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
716   std::map<int64_t, unsigned> numOccurences;
717   for (auto val : vals)
718     numOccurences[val]++;
719   return numOccurences;
720 }
721 
/// Given the original (non-rank-reduced) subview result type and the
/// rank-reduced result type, computes the dropped dimensions. This accounts for
/// cases where there are multiple unit dims, but only a subset of them are
/// dropped. For MemRefTypes these can be disambiguated using the strides: if a
/// dimension is dropped, its stride must be dropped too.
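///
/// For example (illustrative): reducing a subview result of type
/// memref<1x1x4xf32> with strides [4, 4, 1] to memref<1x4xf32> with strides
/// [4, 1] drops exactly one of the two leading unit dims; since only one
/// stride-4 entry disappears, the strides tell us that a single unit dim was
/// dropped even though both have size 1.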
727 static llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
729                                ArrayAttr staticSizes) {
730   llvm::SmallDenseSet<unsigned> unusedDims;
731   if (originalType.getRank() == reducedType.getRank())
732     return unusedDims;
733 
734   for (auto dim : llvm::enumerate(staticSizes))
735     if (dim.value().cast<IntegerAttr>().getInt() == 1)
736       unusedDims.insert(dim.index());
737   SmallVector<int64_t> originalStrides, candidateStrides;
738   int64_t originalOffset, candidateOffset;
739   if (failed(
740           getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
741       failed(
742           getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
743     return llvm::None;
744 
  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the dims
  // is 1. Track the number of occurrences of the strides in the original type
  // and the candidate type. For each dropped dim, one occurrence of its stride
  // disappears from the candidate type. Note that multiple dimensions may have
  // the same size. We don't need to figure out exactly which dim corresponds to
  // which stride; we just need to verify that the number of repetitions of a
  // stride in the original type equals the number of repetitions of that stride
  // in the candidate type plus the number of dropped dims with that stride.
754   std::map<int64_t, unsigned> currUnaccountedStrides =
755       getNumOccurences(originalStrides);
756   std::map<int64_t, unsigned> candidateStridesNumOccurences =
757       getNumOccurences(candidateStrides);
758   llvm::SmallDenseSet<unsigned> prunedUnusedDims;
759   for (unsigned dim : unusedDims) {
760     int64_t originalStride = originalStrides[dim];
761     if (currUnaccountedStrides[originalStride] >
762         candidateStridesNumOccurences[originalStride]) {
763       // This dim can be treated as dropped.
764       currUnaccountedStrides[originalStride]--;
765       continue;
766     }
767     if (currUnaccountedStrides[originalStride] ==
768         candidateStridesNumOccurences[originalStride]) {
769       // The stride for this is not dropped. Keep as is.
770       prunedUnusedDims.insert(dim);
771       continue;
772     }
773     if (currUnaccountedStrides[originalStride] <
774         candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Can't have a stride in the reduced-rank type
      // that wasn't in the original one.
777       return llvm::None;
778     }
779   }
780 
781   for (auto prunedDim : prunedUnusedDims)
782     unusedDims.erase(prunedDim);
783   if (unusedDims.size() + reducedType.getRank() != originalType.getRank())
784     return llvm::None;
785   return unusedDims;
786 }
787 
llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
789   MemRefType sourceType = getSourceType();
790   MemRefType resultType = getType();
791   llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
792       computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
793   assert(unusedDims && "unable to find unused dims of subview");
794   return *unusedDims;
795 }
796 
OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
798   // All forms of folding require a known index.
799   auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
800   if (!index)
801     return {};
802 
803   // Folding for unranked types (UnrankedMemRefType) is not supported.
804   auto memrefType = source().getType().dyn_cast<MemRefType>();
805   if (!memrefType)
806     return {};
807 
808   // Fold if the shape extent along the given index is known.
809   if (!memrefType.isDynamicDim(index.getInt())) {
810     Builder builder(getContext());
811     return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
812   }
813 
814   // The size at the given index is now known to be a dynamic size.
815   unsigned unsignedIndex = index.getValue().getZExtValue();
816 
817   // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
818   Operation *definingOp = source().getDefiningOp();
819 
820   if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
821     return *(alloc.getDynamicSizes().begin() +
822              memrefType.getDynamicDimIndex(unsignedIndex));
823 
824   if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
825     return *(alloca.getDynamicSizes().begin() +
826              memrefType.getDynamicDimIndex(unsignedIndex));
827 
828   if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
829     return *(view.getDynamicSizes().begin() +
830              memrefType.getDynamicDimIndex(unsignedIndex));
831 
832   if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
833     llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims();
834     unsigned resultIndex = 0;
835     unsigned sourceRank = subview.getSourceType().getRank();
836     unsigned sourceIndex = 0;
837     for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
838       if (unusedDims.count(i))
839         continue;
840       if (resultIndex == unsignedIndex) {
841         sourceIndex = i;
842         break;
843       }
844       resultIndex++;
845     }
846     assert(subview.isDynamicSize(sourceIndex) &&
847            "expected dynamic subview size");
848     return subview.getDynamicSize(sourceIndex);
849   }
850 
851   if (auto sizeInterface =
852           dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
853     assert(sizeInterface.isDynamicSize(unsignedIndex) &&
854            "Expected dynamic subview size");
855     return sizeInterface.getDynamicSize(unsignedIndex);
856   }
857 
858   // dim(memrefcast) -> dim
859   if (succeeded(foldMemRefCast(*this)))
860     return getResult();
861 
862   return {};
863 }
864 
865 namespace {
/// Fold a dim of a memref.reshape operation into a load from the reshape's
/// shape operand.
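///
/// Illustrative sketch:
///
/// ```mlir
///   %0 = memref.reshape %src(%shape)
///       : (memref<*xf32>, memref<2xindex>) -> memref<?x?xf32>
///   %d = memref.dim %0, %c1 : memref<?x?xf32>
/// ```
///
/// folds %d into a `memref.load %shape[%c1]` inserted right after the reshape
/// (with an index_cast if the shape element type is not index).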
868 struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
869   using OpRewritePattern<DimOp>::OpRewritePattern;
870 
  LogicalResult matchAndRewrite(DimOp dim,
872                                 PatternRewriter &rewriter) const override {
873     auto reshape = dim.source().getDefiningOp<ReshapeOp>();
874 
875     if (!reshape)
876       return failure();
877 
878     // Place the load directly after the reshape to ensure that the shape memref
879     // was not mutated.
880     rewriter.setInsertionPointAfter(reshape);
881     Location loc = dim.getLoc();
882     Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
883     if (load.getType() != dim.getType())
884       load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
885     rewriter.replaceOp(dim, load);
886     return success();
887   }
888 };
889 
/// Fold a dim of a buffer_cast operation into a tensor.dim on the source
/// tensor.
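///
/// Rough sketch: `memref.dim (memref.buffer_cast %t), %c0` becomes
/// `tensor.dim %t, %c0`.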
891 struct DimOfCastOp : public OpRewritePattern<DimOp> {
892   using OpRewritePattern<DimOp>::OpRewritePattern;
893 
  LogicalResult matchAndRewrite(DimOp dimOp,
895                                 PatternRewriter &rewriter) const override {
896     auto castOp = dimOp.source().getDefiningOp<BufferCastOp>();
897     if (!castOp)
898       return failure();
899     Value newSource = castOp.getOperand();
900     rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
901     return success();
902   }
903 };
904 } // end anonymous namespace.
905 
void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
907                                         MLIRContext *context) {
908   results.add<DimOfMemRefReshape, DimOfCastOp>(context);
909 }
910 
911 // ---------------------------------------------------------------------------
912 // DmaStartOp
913 // ---------------------------------------------------------------------------
914 
void DmaStartOp::build(OpBuilder &builder, OperationState &result,
916                        Value srcMemRef, ValueRange srcIndices, Value destMemRef,
917                        ValueRange destIndices, Value numElements,
918                        Value tagMemRef, ValueRange tagIndices, Value stride,
919                        Value elementsPerStride) {
920   result.addOperands(srcMemRef);
921   result.addOperands(srcIndices);
922   result.addOperands(destMemRef);
923   result.addOperands(destIndices);
924   result.addOperands({numElements, tagMemRef});
925   result.addOperands(tagIndices);
926   if (stride)
927     result.addOperands({stride, elementsPerStride});
928 }
929 
static void print(OpAsmPrinter &p, DmaStartOp op) {
931   p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], "
932     << op.getDstMemRef() << '[' << op.getDstIndices() << "], "
933     << op.getNumElements() << ", " << op.getTagMemRef() << '['
934     << op.getTagIndices() << ']';
935   if (op.isStrided())
936     p << ", " << op.getStride() << ", " << op.getNumElementsPerStride();
937 
938   p.printOptionalAttrDict(op->getAttrs());
939   p << " : " << op.getSrcMemRef().getType() << ", "
940     << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType();
941 }
942 
943 // Parse DmaStartOp.
944 // Ex:
945 //   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
946 //                       %tag[%index], %stride, %num_elt_per_stride :
947 //                     : memref<3076 x f32, 0>,
948 //                       memref<1024 x f32, 2>,
949 //                       memref<1 x i32>
950 //
static ParseResult parseDmaStartOp(OpAsmParser &parser,
952                                    OperationState &result) {
953   OpAsmParser::OperandType srcMemRefInfo;
954   SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
955   OpAsmParser::OperandType dstMemRefInfo;
956   SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
957   OpAsmParser::OperandType numElementsInfo;
958   OpAsmParser::OperandType tagMemrefInfo;
959   SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
960   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
961 
962   SmallVector<Type, 3> types;
963   auto indexType = parser.getBuilder().getIndexType();
964 
965   // Parse and resolve the following list of operands:
966   // *) source memref followed by its indices (in square brackets).
967   // *) destination memref followed by its indices (in square brackets).
  // *) number of elements to transfer.
969   if (parser.parseOperand(srcMemRefInfo) ||
970       parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
971       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
972       parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
973       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
974       parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
975       parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
976     return failure();
977 
978   // Parse optional stride and elements per stride.
979   if (parser.parseTrailingOperandList(strideInfo))
980     return failure();
981 
982   bool isStrided = strideInfo.size() == 2;
983   if (!strideInfo.empty() && !isStrided) {
984     return parser.emitError(parser.getNameLoc(),
985                             "expected two stride related operands");
986   }
987 
988   if (parser.parseColonTypeList(types))
989     return failure();
990   if (types.size() != 3)
991     return parser.emitError(parser.getNameLoc(), "fewer/more types expected");
992 
993   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
994       parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
995       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
996       parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
997       // size should be an index.
998       parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
999       parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
1000       // tag indices should be index.
1001       parser.resolveOperands(tagIndexInfos, indexType, result.operands))
1002     return failure();
1003 
1004   if (isStrided) {
1005     if (parser.resolveOperands(strideInfo, indexType, result.operands))
1006       return failure();
1007   }
1008 
1009   return success();
1010 }
1011 
static LogicalResult verify(DmaStartOp op) {
1013   unsigned numOperands = op.getNumOperands();
1014 
1015   // Mandatory non-variadic operands are: src memref, dst memref, tag memref and
1016   // the number of elements.
1017   if (numOperands < 4)
1018     return op.emitOpError("expected at least 4 operands");
1019 
1020   // Check types of operands. The order of these calls is important: the later
1021   // calls rely on some type properties to compute the operand position.
1022   // 1. Source memref.
1023   if (!op.getSrcMemRef().getType().isa<MemRefType>())
1024     return op.emitOpError("expected source to be of memref type");
1025   if (numOperands < op.getSrcMemRefRank() + 4)
1026     return op.emitOpError()
1027            << "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
1028   if (!op.getSrcIndices().empty() &&
1029       !llvm::all_of(op.getSrcIndices().getTypes(),
1030                     [](Type t) { return t.isIndex(); }))
1031     return op.emitOpError("expected source indices to be of index type");
1032 
1033   // 2. Destination memref.
1034   if (!op.getDstMemRef().getType().isa<MemRefType>())
1035     return op.emitOpError("expected destination to be of memref type");
1036   unsigned numExpectedOperands =
1037       op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
1038   if (numOperands < numExpectedOperands)
1039     return op.emitOpError()
1040            << "expected at least " << numExpectedOperands << " operands";
1041   if (!op.getDstIndices().empty() &&
1042       !llvm::all_of(op.getDstIndices().getTypes(),
1043                     [](Type t) { return t.isIndex(); }))
1044     return op.emitOpError("expected destination indices to be of index type");
1045 
1046   // 3. Number of elements.
1047   if (!op.getNumElements().getType().isIndex())
1048     return op.emitOpError("expected num elements to be of index type");
1049 
1050   // 4. Tag memref.
1051   if (!op.getTagMemRef().getType().isa<MemRefType>())
1052     return op.emitOpError("expected tag to be of memref type");
1053   numExpectedOperands += op.getTagMemRefRank();
1054   if (numOperands < numExpectedOperands)
1055     return op.emitOpError()
1056            << "expected at least " << numExpectedOperands << " operands";
1057   if (!op.getTagIndices().empty() &&
1058       !llvm::all_of(op.getTagIndices().getTypes(),
1059                     [](Type t) { return t.isIndex(); }))
1060     return op.emitOpError("expected tag indices to be of index type");
1061 
1062   // Optional stride-related operands must be either both present or both
1063   // absent.
1064   if (numOperands != numExpectedOperands &&
1065       numOperands != numExpectedOperands + 2)
1066     return op.emitOpError("incorrect number of operands");
1067 
1068   // 5. Strides.
1069   if (op.isStrided()) {
1070     if (!op.getStride().getType().isIndex() ||
1071         !op.getNumElementsPerStride().getType().isIndex())
1072       return op.emitOpError(
1073           "expected stride and num elements per stride to be of type index");
1074   }
1075 
1076   return success();
1077 }
1078 
LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1080                                SmallVectorImpl<OpFoldResult> &results) {
1081   /// dma_start(memrefcast) -> dma_start
1082   return foldMemRefCast(*this);
1083 }
1084 
1085 // ---------------------------------------------------------------------------
1086 // DmaWaitOp
1087 // ---------------------------------------------------------------------------
1088 
LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1090                               SmallVectorImpl<OpFoldResult> &results) {
1091   /// dma_wait(memrefcast) -> dma_wait
1092   return foldMemRefCast(*this);
1093 }
1094 
static LogicalResult verify(DmaWaitOp op) {
1096   // Check that the number of tag indices matches the tagMemRef rank.
1097   unsigned numTagIndices = op.tagIndices().size();
1098   unsigned tagMemRefRank = op.getTagMemRefRank();
1099   if (numTagIndices != tagMemRefRank)
1100     return op.emitOpError() << "expected tagIndices to have the same number of "
1101                                "elements as the tagMemRef rank, expected "
1102                             << tagMemRefRank << ", but got " << numTagIndices;
1103   return success();
1104 }
1105 
1106 //===----------------------------------------------------------------------===//
1107 // GlobalOp
1108 //===----------------------------------------------------------------------===//
1109 
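/// The custom type-and-initial-value syntax printed/parsed below covers forms
/// like (illustrative):
///
/// ```mlir
///   memref.global "private" @x : memref<2xf32> = dense<0.0>
///   memref.global @y : memref<4xi32> = uninitialized
///   memref.global @z : memref<8xf32>
/// ```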
static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
1111                                                    TypeAttr type,
1112                                                    Attribute initialValue) {
1113   p << type;
1114   if (!op.isExternal()) {
1115     p << " = ";
1116     if (op.isUninitialized())
1117       p << "uninitialized";
1118     else
1119       p.printAttributeWithoutType(initialValue);
1120   }
1121 }
1122 
1123 static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
1125                                        Attribute &initialValue) {
1126   Type type;
1127   if (parser.parseType(type))
1128     return failure();
1129 
1130   auto memrefType = type.dyn_cast<MemRefType>();
1131   if (!memrefType || !memrefType.hasStaticShape())
1132     return parser.emitError(parser.getNameLoc())
1133            << "type should be static shaped memref, but got " << type;
1134   typeAttr = TypeAttr::get(type);
1135 
1136   if (parser.parseOptionalEqual())
1137     return success();
1138 
1139   if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
1140     initialValue = UnitAttr::get(parser.getContext());
1141     return success();
1142   }
1143 
1144   Type tensorType = getTensorTypeFromMemRefType(memrefType);
1145   if (parser.parseAttribute(initialValue, tensorType))
1146     return failure();
1147   if (!initialValue.isa<ElementsAttr>())
1148     return parser.emitError(parser.getNameLoc())
1149            << "initial value should be a unit or elements attribute";
1150   return success();
1151 }
1152 
static LogicalResult verify(GlobalOp op) {
1154   auto memrefType = op.type().dyn_cast<MemRefType>();
1155   if (!memrefType || !memrefType.hasStaticShape())
1156     return op.emitOpError("type should be static shaped memref, but got ")
1157            << op.type();
1158 
1159   // Verify that the initial value, if present, is either a unit attribute or
1160   // an elements attribute.
1161   if (op.initial_value().hasValue()) {
1162     Attribute initValue = op.initial_value().getValue();
1163     if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
1164       return op.emitOpError("initial value should be a unit or elements "
1165                             "attribute, but got ")
1166              << initValue;
1167 
1168     // Check that the type of the initial value is compatible with the type of
1169     // the global variable.
1170     if (initValue.isa<ElementsAttr>()) {
1171       Type initType = initValue.getType();
1172       Type tensorType = getTensorTypeFromMemRefType(memrefType);
1173       if (initType != tensorType)
1174         return op.emitOpError("initial value expected to be of type ")
1175                << tensorType << ", but was of type " << initType;
1176     }
1177   }
1178 
1179   if (Optional<uint64_t> alignAttr = op.alignment()) {
1180     uint64_t alignment = alignAttr.getValue();
1181 
1182     if (!llvm::isPowerOf2_64(alignment))
1183       return op->emitError() << "alignment attribute value " << alignment
1184                              << " is not a power of 2";
1185   }
1186 
1187   // TODO: verify visibility for declarations.
1188   return success();
1189 }
1190 
1191 //===----------------------------------------------------------------------===//
1192 // GetGlobalOp
1193 //===----------------------------------------------------------------------===//
1194 
1195 LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is the same as the type of the referenced
  // memref.global op.
1199   auto global =
1200       symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
1201   if (!global)
1202     return emitOpError("'")
1203            << name() << "' does not reference a valid global memref";
1204 
1205   Type resultType = result().getType();
1206   if (global.type() != resultType)
1207     return emitOpError("result type ")
1208            << resultType << " does not match type " << global.type()
1209            << " of the global memref @" << name();
1210   return success();
1211 }
1212 
1213 //===----------------------------------------------------------------------===//
1214 // LoadOp
1215 //===----------------------------------------------------------------------===//
1216 
static LogicalResult verify(LoadOp op) {
1218   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1219     return op.emitOpError("incorrect number of indices for load");
1220   return success();
1221 }
1222 
OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
1224   /// load(memrefcast) -> load
1225   if (succeeded(foldMemRefCast(*this)))
1226     return getResult();
1227   return OpFoldResult();
1228 }
1229 
1230 namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
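///
/// Roughly (illustrative): `memref.load (memref.buffer_cast %t)[%i]` becomes
/// `tensor.extract %t[%i]`.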
1233 struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
1234   using OpRewritePattern<LoadOp>::OpRewritePattern;
1235 
  LogicalResult matchAndRewrite(LoadOp load,
1237                                 PatternRewriter &rewriter) const override {
1238     auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
1239     if (!buffercast)
1240       return failure();
1241 
1242     rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
1243                                                    load.indices());
1244     return success();
1245   }
1246 };
1247 } // end anonymous namespace.
1248 
void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
1250                                          MLIRContext *context) {
1251   results.add<LoadOfBufferCast>(context);
1252 }
1253 
1254 //===----------------------------------------------------------------------===//
1255 // PrefetchOp
1256 //===----------------------------------------------------------------------===//
1257 
static void print(OpAsmPrinter &p, PrefetchOp op) {
1259   p << " " << op.memref() << '[';
1260   p.printOperands(op.indices());
1261   p << ']' << ", " << (op.isWrite() ? "write" : "read");
1262   p << ", locality<" << op.localityHint();
1263   p << ">, " << (op.isDataCache() ? "data" : "instr");
1264   p.printOptionalAttrDict(
1265       op->getAttrs(),
1266       /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
1267   p << " : " << op.getMemRefType();
1268 }
1269 
1270 static ParseResult parsePrefetchOp(OpAsmParser &parser,
1271                                    OperationState &result) {
1272   OpAsmParser::OperandType memrefInfo;
1273   SmallVector<OpAsmParser::OperandType, 4> indexInfo;
1274   IntegerAttr localityHint;
1275   MemRefType type;
1276   StringRef readOrWrite, cacheType;
1277 
1278   auto indexTy = parser.getBuilder().getIndexType();
1279   auto i32Type = parser.getBuilder().getIntegerType(32);
1280   if (parser.parseOperand(memrefInfo) ||
1281       parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
1282       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
1283       parser.parseComma() || parser.parseKeyword("locality") ||
1284       parser.parseLess() ||
1285       parser.parseAttribute(localityHint, i32Type, "localityHint",
1286                             result.attributes) ||
1287       parser.parseGreater() || parser.parseComma() ||
1288       parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
1289       parser.resolveOperand(memrefInfo, type, result.operands) ||
1290       parser.resolveOperands(indexInfo, indexTy, result.operands))
1291     return failure();
1292 
1293   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
1294     return parser.emitError(parser.getNameLoc(),
1295                             "rw specifier has to be 'read' or 'write'");
1296   result.addAttribute(
1297       PrefetchOp::getIsWriteAttrName(),
1298       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
1299 
1300   if (!cacheType.equals("data") && !cacheType.equals("instr"))
1301     return parser.emitError(parser.getNameLoc(),
1302                             "cache type has to be 'data' or 'instr'");
1303 
1304   result.addAttribute(
1305       PrefetchOp::getIsDataCacheAttrName(),
1306       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
1307 
1308   return success();
1309 }
1310 
1311 static LogicalResult verify(PrefetchOp op) {
1312   if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
1313     return op.emitOpError("too few indices");
1314 
1315   return success();
1316 }
1317 
1318 LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
1319                                SmallVectorImpl<OpFoldResult> &results) {
1320   // prefetch(memrefcast) -> prefetch
1321   return foldMemRefCast(*this);
1322 }
1323 
1324 //===----------------------------------------------------------------------===//
1325 // ReinterpretCastOp
1326 //===----------------------------------------------------------------------===//
1327 
1328 /// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
1329 /// `staticSizes` and `staticStrides` are automatically filled with
1330 /// source-memref-rank sentinel values that encode dynamic entries.
1331 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1332                               MemRefType resultType, Value source,
1333                               OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1334                               ArrayRef<OpFoldResult> strides,
1335                               ArrayRef<NamedAttribute> attrs) {
1336   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1337   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1338   dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1339                              ShapedType::kDynamicStrideOrOffset);
1340   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1341                              ShapedType::kDynamicSize);
1342   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1343                              ShapedType::kDynamicStrideOrOffset);
1344   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1345         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1346         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1347   result.addAttributes(attrs);
1348 }
1349 
1350 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1351                               MemRefType resultType, Value source,
1352                               int64_t offset, ArrayRef<int64_t> sizes,
1353                               ArrayRef<int64_t> strides,
1354                               ArrayRef<NamedAttribute> attrs) {
1355   SmallVector<OpFoldResult> sizeValues =
1356       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1357         return b.getI64IntegerAttr(v);
1358       }));
1359   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1360       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1361         return b.getI64IntegerAttr(v);
1362       }));
1363   build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1364         strideValues, attrs);
1365 }
1366 
1367 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1368                               MemRefType resultType, Value source, Value offset,
1369                               ValueRange sizes, ValueRange strides,
1370                               ArrayRef<NamedAttribute> attrs) {
1371   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1372       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1373   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1374       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1375   build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1376 }
1377 
1378 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1379 // completed automatically, like we have for subview and extract_slice.
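// For example (an illustrative sketch, values invented):
//
//   memref.reinterpret_cast %src to
//     offset: [0], sizes: [%size0, 10], strides: [1, %stride1]
//     : memref<?x?xf32> to memref<?x10xf32, offset: 0, strides: [1, ?]>
//
// where the static entries of the offset/sizes/strides attributes must agree
// with the result type, as checked below.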
1380 static LogicalResult verify(ReinterpretCastOp op) {
1381   // The source and result memrefs should be in the same memory space.
1382   auto srcType = op.source().getType().cast<BaseMemRefType>();
1383   auto resultType = op.getType().cast<MemRefType>();
1384   if (srcType.getMemorySpace() != resultType.getMemorySpace())
1385     return op.emitError("different memory spaces specified for source type ")
1386            << srcType << " and result memref type " << resultType;
1387   if (srcType.getElementType() != resultType.getElementType())
1388     return op.emitError("different element types specified for source type ")
1389            << srcType << " and result memref type " << resultType;
1390 
1391   // Match sizes in result memref type and in static_sizes attribute.
1392   for (auto &en :
1393        llvm::enumerate(llvm::zip(resultType.getShape(),
1394                                  extractFromI64ArrayAttr(op.static_sizes())))) {
1395     int64_t resultSize = std::get<0>(en.value());
1396     int64_t expectedSize = std::get<1>(en.value());
1397     if (resultSize != expectedSize)
1398       return op.emitError("expected result type with size = ")
1399              << expectedSize << " instead of " << resultSize
1400              << " in dim = " << en.index();
1401   }
1402 
1403   // Match offset and strides in static_offset and static_strides attributes if
1404   // result memref type has an affine map specified.
1405   if (!resultType.getAffineMaps().empty()) {
1406     int64_t resultOffset;
1407     SmallVector<int64_t, 4> resultStrides;
1408     if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1409       return failure();
1410 
1411     // Match offset in result memref type and in static_offsets attribute.
1412     int64_t expectedOffset =
1413         extractFromI64ArrayAttr(op.static_offsets()).front();
1414     if (resultOffset != expectedOffset)
1415       return op.emitError("expected result type with offset = ")
1416              << resultOffset << " instead of " << expectedOffset;
1417 
1418     // Match strides in result memref type and in static_strides attribute.
1419     for (auto &en : llvm::enumerate(llvm::zip(
1420              resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
1421       int64_t resultStride = std::get<0>(en.value());
1422       int64_t expectedStride = std::get<1>(en.value());
1423       if (resultStride != expectedStride)
1424         return op.emitError("expected result type with stride = ")
1425                << expectedStride << " instead of " << resultStride
1426                << " in dim = " << en.index();
1427     }
1428   }
1429   return success();
1430 }
1431 
1432 //===----------------------------------------------------------------------===//
1433 // Reassociative reshape ops
1434 //===----------------------------------------------------------------------===//
1435 
1436 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1437   return getSymbolLessAffineMaps(getReassociationExprs());
1438 }
1439 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1440   return convertReassociationIndicesToExprs(getContext(),
1441                                             getReassociationIndices());
1442 }
1443 
1444 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1445   return getSymbolLessAffineMaps(getReassociationExprs());
1446 }
1447 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1448   return convertReassociationIndicesToExprs(getContext(),
1449                                             getReassociationIndices());
1450 }
1451 
1452 static void print(OpAsmPrinter &p, ExpandShapeOp op) {
1453   ::mlir::printReshapeOp<ExpandShapeOp>(p, op);
1454 }
1455 
1456 static void print(OpAsmPrinter &p, CollapseShapeOp op) {
1457   ::mlir::printReshapeOp<CollapseShapeOp>(p, op);
1458 }
1459 
1460 /// Detect whether memref dims [dim, dim + extent) can be reshaped without
1461 /// copies.
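/// For example (illustrative): sizes [3, 4, 5] with strides [20, 5, 1] form a
/// reshapable band since 20 == 5 * 4 and 5 == 1 * 5, whereas strides
/// [24, 5, 1] do not (24 != 5 * 4).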
1462 static bool isReshapableDimBand(unsigned dim, unsigned extent,
1463                                 ArrayRef<int64_t> sizes,
1464                                 ArrayRef<AffineExpr> strides) {
1465   assert(sizes.size() == strides.size() && "mismatched ranks");
1466   // off by 1 indexing to avoid out of bounds
1467   //                       V
1468   for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
1469     // Only bands of static shapes are reshapable. This is due to the fact that
1470     // there is no relation between dynamic sizes and dynamic strides: we do not
1471     // have enough information to know whether a "-1" size corresponds to the
1472     // proper symbol in the AffineExpr of a stride.
1473     if (ShapedType::isDynamic(sizes[idx + 1]))
1474       return false;
1475     // TODO: Refine this by passing the proper nDims and nSymbols so we can
1476     // simplify on the fly and catch more reshapable cases.
1477     if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
1478       return false;
1479   }
1480   return true;
1481 }
1482 
1483 /// Compute the MemRefType obtained by applying the `reassociation` (which is
1484 /// expected to be valid) to `type`.
1485 /// If `type` is a contiguous MemRefType, this always produces a contiguous
1486 /// MemRefType.
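/// For example (illustrative): collapsing a contiguous memref<3x4x5xf32> with
/// reassociation [[0, 1], [2]] yields memref<12x5xf32>; a band that is not
/// contiguous instead collapses to a dynamic size.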
1487 static MemRefType
1488 computeReshapeCollapsedType(MemRefType type,
1489                             ArrayRef<AffineMap> reassociation) {
1490   auto sizes = type.getShape();
1491   AffineExpr offset;
1492   SmallVector<AffineExpr, 4> strides;
1493   auto status = getStridesAndOffset(type, strides, offset);
1494   (void)status;
1495   assert(succeeded(status) && "expected strided memref");
1496 
1497   SmallVector<int64_t, 4> newSizes;
1498   newSizes.reserve(reassociation.size());
1499   SmallVector<AffineExpr, 4> newStrides;
1500   newStrides.reserve(reassociation.size());
1501 
1502   // Use the fact that reassociation is valid to simplify the logic: only use
1503   // each map's rank.
1504   assert(isReassociationValid(reassociation) && "invalid reassociation");
1505   unsigned currentDim = 0;
1506   for (AffineMap m : reassociation) {
1507     unsigned dim = m.getNumResults();
1508     int64_t size = 1;
1509     AffineExpr stride = strides[currentDim + dim - 1];
1510     if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
1511       size = ShapedType::kDynamicSize;
1512       stride = AffineExpr();
1513     } else {
1514       for (unsigned d = 0; d < dim; ++d)
1515         size *= sizes[currentDim + d];
1516     }
1517     newSizes.push_back(size);
1518     newStrides.push_back(stride);
1519     currentDim += dim;
1520   }
1521 
1522   // Early-exit: if `type` is contiguous, the result must be contiguous.
1523   if (canonicalizeStridedLayout(type).getAffineMaps().empty())
1524     return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
1525 
1526   // Convert back to int64_t because we don't have enough information to create
1527   // new strided layouts from AffineExpr only. This corresponds to a case where
1528   // copies may be necessary.
1529   int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
1530   if (auto o = offset.dyn_cast<AffineConstantExpr>())
1531     intOffset = o.getValue();
1532   SmallVector<int64_t, 4> intStrides;
1533   intStrides.reserve(strides.size());
1534   for (auto stride : newStrides) {
1535     if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
1536       intStrides.push_back(cst.getValue());
1537     else
1538       intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
1539   }
1540   auto layout =
1541       makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
1542   return canonicalizeStridedLayout(
1543       MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
1544 }
1545 
1546 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1547                           ArrayRef<ReassociationIndices> reassociation,
1548                           ArrayRef<NamedAttribute> attrs) {
1549   auto memRefType = src.getType().cast<MemRefType>();
1550   auto resultType = computeReshapeCollapsedType(
1551       memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1552                       b.getContext(), reassociation)));
1553   build(b, result, resultType, src, attrs);
1554   result.addAttribute(getReassociationAttrName(),
1555                       getReassociationIndicesAttribute(b, reassociation));
1556 }
1557 
1558 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1559                             ArrayRef<ReassociationIndices> reassociation,
1560                             ArrayRef<NamedAttribute> attrs) {
1561   auto memRefType = src.getType().cast<MemRefType>();
1562   auto resultType = computeReshapeCollapsedType(
1563       memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1564                       b.getContext(), reassociation)));
1565   build(b, result, resultType, src, attrs);
1566   result.addAttribute(getReassociationAttrName(),
1567                       getReassociationIndicesAttribute(b, reassociation));
1568 }
1569 
1570 template <typename ReshapeOp,
1571           bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
1572 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
1573                                      MemRefType collapsedType) {
1574   if (failed(
1575           verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1576     return failure();
1577   auto maps = op.getReassociationMaps();
1578   MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
1579   if (collapsedType != expectedType)
1580     return op.emitOpError("expected collapsed type to be ")
1581            << expectedType << ", but got " << collapsedType;
1582   return success();
1583 }
1584 
1585 static LogicalResult verify(ExpandShapeOp op) {
1586   return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
1587 }
1588 
1589 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1590                                                 MLIRContext *context) {
1591   results.add<CollapseReshapeOps<ExpandShapeOp>,
1592               CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
1593 }
1594 
1595 static LogicalResult verify(CollapseShapeOp op) {
1596   return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
1597 }
1598 
1599 struct CollapseShapeOpMemRefCastFolder
1600     : public OpRewritePattern<CollapseShapeOp> {
1601 public:
1602   using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1603 
1604   LogicalResult matchAndRewrite(CollapseShapeOp op,
1605                                 PatternRewriter &rewriter) const override {
1606     auto cast = op.getOperand().getDefiningOp<CastOp>();
1607     if (!cast)
1608       return failure();
1609 
1610     if (!CastOp::canFoldIntoConsumerOp(cast))
1611       return failure();
1612 
1613     Type newResultType = computeReshapeCollapsedType(
1614         cast.getOperand().getType().cast<MemRefType>(),
1615         op.getReassociationMaps());
1616 
1617     if (newResultType == op.getResultType()) {
1618       rewriter.updateRootInPlace(
1619           op, [&]() { op.srcMutable().assign(cast.source()); });
1620     } else {
1621       Value newOp = rewriter.create<CollapseShapeOp>(
1622           op->getLoc(), cast.source(), op.getReassociationIndices());
1623       rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
1624     }
1625     return success();
1626   }
1627 };
1628 
1629 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1630                                                   MLIRContext *context) {
1631   results.add<CollapseReshapeOps<CollapseShapeOp>,
1632               CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
1633               CollapseShapeOpMemRefCastFolder>(context);
1634 }
1635 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
1636   if (succeeded(foldMemRefCast(*this)))
1637     return getResult();
1638   return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
1639 }
1640 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
1641   return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
1642 }
1643 
1644 //===----------------------------------------------------------------------===//
1645 // ReshapeOp
1646 //===----------------------------------------------------------------------===//
1647 
1648 static LogicalResult verify(ReshapeOp op) {
1649   Type operandType = op.source().getType();
1650   Type resultType = op.result().getType();
1651 
1652   Type operandElementType = operandType.cast<ShapedType>().getElementType();
1653   Type resultElementType = resultType.cast<ShapedType>().getElementType();
1654   if (operandElementType != resultElementType)
1655     return op.emitOpError("element types of source and destination memref "
1656                           "types should be the same");
1657 
1658   if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1659     if (!operandMemRefType.getAffineMaps().empty())
1660       return op.emitOpError(
1661           "source memref type should have identity affine map");
1662 
1663   int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
1664   auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1665   if (resultMemRefType) {
1666     if (!resultMemRefType.getAffineMaps().empty())
1667       return op.emitOpError(
1668           "result memref type should have identity affine map");
1669     if (shapeSize == ShapedType::kDynamicSize)
1670       return op.emitOpError("cannot use shape operand with dynamic length to "
1671                             "reshape to statically-ranked memref type");
1672     if (shapeSize != resultMemRefType.getRank())
1673       return op.emitOpError(
1674           "length of shape operand differs from the result's memref rank");
1675   }
1676   return success();
1677 }
1678 
1679 //===----------------------------------------------------------------------===//
1680 // StoreOp
1681 //===----------------------------------------------------------------------===//
1682 
1683 static LogicalResult verify(StoreOp op) {
1684   if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
1685     return op.emitOpError("store index operand count not equal to memref rank");
1686 
1687   return success();
1688 }
1689 
1690 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1691                             SmallVectorImpl<OpFoldResult> &results) {
1692   /// store(memrefcast) -> store
1693   return foldMemRefCast(*this, getValueToStore());
1694 }
1695 
1696 //===----------------------------------------------------------------------===//
1697 // SubViewOp
1698 //===----------------------------------------------------------------------===//
1699 
1700 namespace {
1701 /// Helpers to write more idiomatic operations.
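/// `Wrapper` provides saturating + and * on offsets/strides: if either operand
/// is ShapedType::kDynamicStrideOrOffset, the result is the dynamic sentinel
/// as well, e.g. (illustrative) Wrapper(4) * ShapedType::kDynamicStrideOrOffset
/// yields kDynamicStrideOrOffset rather than an arbitrary product.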
1702 namespace saturated_arith {
1703 struct Wrapper {
1704   explicit Wrapper(int64_t v) : v(v) {}
1705   operator int64_t() { return v; }
1706   int64_t v;
1707 };
1708 Wrapper operator+(Wrapper a, int64_t b) {
1709   if (ShapedType::isDynamicStrideOrOffset(a) ||
1710       ShapedType::isDynamicStrideOrOffset(b))
1711     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1712   return Wrapper(a.v + b);
1713 }
1714 Wrapper operator*(Wrapper a, int64_t b) {
1715   if (ShapedType::isDynamicStrideOrOffset(a) ||
1716       ShapedType::isDynamicStrideOrOffset(b))
1717     return Wrapper(ShapedType::kDynamicStrideOrOffset);
1718   return Wrapper(a.v * b);
1719 }
1720 } // end namespace saturated_arith
1721 } // end namespace
1722 
1723 /// A subview result type can be fully inferred from the source type and the
1724 /// static representation of offsets, sizes and strides. Special sentinels
1725 /// encode the dynamic case.
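/// For example (a worked sketch): a subview of memref<8x16x4xf32> (strides
/// [64, 4, 1], offset 0) with offsets [3, 4, 2], sizes [4, 4, 4] and strides
/// [1, 1, 1] infers memref<4x4x4xf32, offset: 210, strides: [64, 4, 1]>,
/// since 3 * 64 + 4 * 4 + 2 * 1 = 210.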
1726 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1727                                 ArrayRef<int64_t> leadingStaticOffsets,
1728                                 ArrayRef<int64_t> leadingStaticSizes,
1729                                 ArrayRef<int64_t> leadingStaticStrides) {
1730   // A subview may specify only a leading subset of offset/sizes/strides in
1731   // which case we complete with offset=0, sizes from memref type and strides=1.
1732   unsigned rank = sourceMemRefType.getRank();
1733   assert(leadingStaticOffsets.size() <= rank &&
1734          "unexpected leadingStaticOffsets overflow");
1735   assert(leadingStaticSizes.size() <= rank &&
1736          "unexpected leadingStaticSizes overflow");
1737   assert(leadingStaticStrides.size() <= rank &&
1738          "unexpected leadingStaticStrides overflow");
1739   auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
1740   auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
1741   auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
1742   unsigned numTrailingOffsets = rank - staticOffsets.size();
1743   unsigned numTrailingSizes = rank - staticSizes.size();
1744   unsigned numTrailingStrides = rank - staticStrides.size();
1745   staticOffsets.append(numTrailingOffsets, 0);
1746   llvm::append_range(staticSizes,
1747                      sourceMemRefType.getShape().take_back(numTrailingSizes));
1748   staticStrides.append(numTrailingStrides, 1);
1749 
1750   // Extract source offset and strides.
1751   int64_t sourceOffset;
1752   SmallVector<int64_t, 4> sourceStrides;
1753   auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1754   assert(succeeded(res) && "SubViewOp expected strided memref type");
1755   (void)res;
1756 
1757   // Compute target offset whose value is:
1758   //   `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1759   int64_t targetOffset = sourceOffset;
1760   for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
1761     auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
1762     using namespace saturated_arith;
1763     targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
1764   }
1765 
1766   // Compute target stride whose value is:
1767   //   `sourceStrides_i * staticStrides_i`.
1768   SmallVector<int64_t, 4> targetStrides;
1769   targetStrides.reserve(staticOffsets.size());
1770   for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1771     auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1772     using namespace saturated_arith;
1773     targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1774   }
1775 
1776   // The type is now known.
1777   return MemRefType::get(
1778       staticSizes, sourceMemRefType.getElementType(),
1779       makeStridedLinearLayoutMap(targetStrides, targetOffset,
1780                                  sourceMemRefType.getContext()),
1781       sourceMemRefType.getMemorySpace());
1782 }
1783 
1784 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1785                                 ArrayRef<OpFoldResult> leadingStaticOffsets,
1786                                 ArrayRef<OpFoldResult> leadingStaticSizes,
1787                                 ArrayRef<OpFoldResult> leadingStaticStrides) {
1788   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1789   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1790   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1791                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1792   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1793                              ShapedType::kDynamicSize);
1794   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1795                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1796   return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1797                                     staticSizes, staticStrides)
1798       .cast<MemRefType>();
1799 }
1800 
1801 Type SubViewOp::inferRankReducedResultType(
1802     unsigned resultRank, MemRefType sourceRankedTensorType,
1803     ArrayRef<int64_t> leadingStaticOffsets,
1804     ArrayRef<int64_t> leadingStaticSizes,
1805     ArrayRef<int64_t> leadingStaticStrides) {
1806   auto inferredType =
1807       inferResultType(sourceRankedTensorType, leadingStaticOffsets,
1808                       leadingStaticSizes, leadingStaticStrides)
1809           .cast<MemRefType>();
1810   assert(inferredType.getRank() >= resultRank && "expected rank >= resultRank");
1811   int rankDiff = inferredType.getRank() - resultRank;
1812   if (rankDiff > 0) {
1813     auto shape = inferredType.getShape();
1814     llvm::SmallDenseSet<unsigned> dimsToProject;
1815     mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
1816     SmallVector<int64_t> projectedShape;
1817     for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1818       if (!dimsToProject.contains(pos))
1819         projectedShape.push_back(shape[pos]);
1820 
1821     AffineMap map;
1822     auto maps = inferredType.getAffineMaps();
1823     if (!maps.empty() && maps.front())
1824       map = getProjectedMap(maps.front(), dimsToProject);
1825     inferredType =
1826         MemRefType::get(projectedShape, inferredType.getElementType(), map,
1827                         inferredType.getMemorySpace());
1828   }
1829   return inferredType;
1830 }
1831 
1832 Type SubViewOp::inferRankReducedResultType(
1833     unsigned resultRank, MemRefType sourceRankedTensorType,
1834     ArrayRef<OpFoldResult> leadingStaticOffsets,
1835     ArrayRef<OpFoldResult> leadingStaticSizes,
1836     ArrayRef<OpFoldResult> leadingStaticStrides) {
1837   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1838   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1839   dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1840                              staticOffsets, ShapedType::kDynamicStrideOrOffset);
1841   dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1842                              ShapedType::kDynamicSize);
1843   dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1844                              staticStrides, ShapedType::kDynamicStrideOrOffset);
1845   return SubViewOp::inferRankReducedResultType(
1846       resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1847       staticStrides);
1848 }
1849 // Build a SubViewOp with mixed static and dynamic entries and custom result
1850 // type. If the type passed is nullptr, it is inferred.
1851 void SubViewOp::build(OpBuilder &b, OperationState &result,
1852                       MemRefType resultType, Value source,
1853                       ArrayRef<OpFoldResult> offsets,
1854                       ArrayRef<OpFoldResult> sizes,
1855                       ArrayRef<OpFoldResult> strides,
1856                       ArrayRef<NamedAttribute> attrs) {
1857   SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1858   SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1859   dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1860                              ShapedType::kDynamicStrideOrOffset);
1861   dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1862                              ShapedType::kDynamicSize);
1863   dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1864                              ShapedType::kDynamicStrideOrOffset);
1865   auto sourceMemRefType = source.getType().cast<MemRefType>();
1866   // Structuring implementation this way avoids duplication between builders.
1867   if (!resultType) {
1868     resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1869                                             staticSizes, staticStrides)
1870                      .cast<MemRefType>();
1871   }
1872   build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1873         dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1874         b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1875   result.addAttributes(attrs);
1876 }
1877 
1878 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1879 // type.
1880 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1881                       ArrayRef<OpFoldResult> offsets,
1882                       ArrayRef<OpFoldResult> sizes,
1883                       ArrayRef<OpFoldResult> strides,
1884                       ArrayRef<NamedAttribute> attrs) {
1885   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1886 }
1887 
1888 // Build a SubViewOp with static entries and inferred result type.
1889 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1890                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1891                       ArrayRef<int64_t> strides,
1892                       ArrayRef<NamedAttribute> attrs) {
1893   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1894       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1895         return b.getI64IntegerAttr(v);
1896       }));
1897   SmallVector<OpFoldResult> sizeValues =
1898       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1899         return b.getI64IntegerAttr(v);
1900       }));
1901   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1902       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1903         return b.getI64IntegerAttr(v);
1904       }));
1905   build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1906 }
1907 
1908 // Build a SubViewOp with static entries and custom result type. If the
1909 // type passed is nullptr, it is inferred.
1910 void SubViewOp::build(OpBuilder &b, OperationState &result,
1911                       MemRefType resultType, Value source,
1912                       ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1913                       ArrayRef<int64_t> strides,
1914                       ArrayRef<NamedAttribute> attrs) {
1915   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1916       llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1917         return b.getI64IntegerAttr(v);
1918       }));
1919   SmallVector<OpFoldResult> sizeValues =
1920       llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1921         return b.getI64IntegerAttr(v);
1922       }));
1923   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1924       llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1925         return b.getI64IntegerAttr(v);
1926       }));
1927   build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1928         attrs);
1929 }
1930 
1931 // Build a SubViewOp with dynamic entries and custom result type. If the type
1932 // passed is nullptr, it is inferred.
1933 void SubViewOp::build(OpBuilder &b, OperationState &result,
1934                       MemRefType resultType, Value source, ValueRange offsets,
1935                       ValueRange sizes, ValueRange strides,
1936                       ArrayRef<NamedAttribute> attrs) {
1937   SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1938       llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1939   SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1940       llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1941   SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1942       llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1943   build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
1944 }
1945 
1946 // Build a SubViewOp with dynamic entries and inferred result type.
1947 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1948                       ValueRange offsets, ValueRange sizes, ValueRange strides,
1949                       ArrayRef<NamedAttribute> attrs) {
1950   build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1951 }
1952 
1953 /// For ViewLikeOpInterface.
1954 Value SubViewOp::getViewSource() { return source(); }
1955 
1956 enum SubViewVerificationResult {
1957   Success,
1958   RankTooLarge,
1959   SizeMismatch,
1960   ElemTypeMismatch,
1961   MemSpaceMismatch,
1962   AffineMapMismatch
1963 };
1964 
1965 /// Checks if the `original` type can be rank-reduced to the `reduced` type.
1966 /// This function is a slight variant of the `is subsequence` algorithm,
1967 /// where a non-matching dimension must be 1.
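/// For example (illustrative), memref<1x3x1x4xf32> can be rank-reduced to
/// memref<3x4xf32> by dropping the two unit dimensions, but not to
/// memref<3x3xf32>.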
1968 static SubViewVerificationResult
1969 isRankReducedType(Type originalType, Type candidateReducedType,
1970                   ArrayAttr staticSizes, std::string *errMsg = nullptr) {
1971   if (originalType == candidateReducedType)
1972     return SubViewVerificationResult::Success;
1973   if (!originalType.isa<MemRefType>())
1974     return SubViewVerificationResult::Success;
1975   if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
1976     return SubViewVerificationResult::Success;
1977 
1978   ShapedType originalShapedType = originalType.cast<ShapedType>();
1979   ShapedType candidateReducedShapedType =
1980       candidateReducedType.cast<ShapedType>();
1981 
1982   // Rank and size logic is valid for all ShapedTypes.
1983   ArrayRef<int64_t> originalShape = originalShapedType.getShape();
1984   ArrayRef<int64_t> candidateReducedShape =
1985       candidateReducedShapedType.getShape();
1986   unsigned originalRank = originalShape.size(),
1987            candidateReducedRank = candidateReducedShape.size();
1988   if (candidateReducedRank > originalRank)
1989     return SubViewVerificationResult::RankTooLarge;
1990 
1991   MemRefType original = originalType.cast<MemRefType>();
1992   MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
1993 
1994   auto optionalUnusedDimsMask =
1995       computeMemRefRankReductionMask(original, candidateReduced, staticSizes);
1996 
1997   // If no rank-reduction mask could be computed, the sizes do not match.
1998   if (!optionalUnusedDimsMask.hasValue())
1999     return SubViewVerificationResult::SizeMismatch;
2000 
2001   if (originalShapedType.getElementType() !=
2002       candidateReducedShapedType.getElementType())
2003     return SubViewVerificationResult::ElemTypeMismatch;
2004 
2005   // Strided layout logic is relevant for MemRefType only.
2006   if (original.getMemorySpace() != candidateReduced.getMemorySpace())
2007     return SubViewVerificationResult::MemSpaceMismatch;
2008   return SubViewVerificationResult::Success;
2009 }
2010 
2011 template <typename OpTy>
2012 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
2013                                             OpTy op, Type expectedType,
2014                                             StringRef errMsg = "") {
2015   auto memrefType = expectedType.cast<ShapedType>();
2016   switch (result) {
2017   case SubViewVerificationResult::Success:
2018     return success();
2019   case SubViewVerificationResult::RankTooLarge:
2020     return op.emitError("expected result rank to be smaller or equal to ")
2021            << "the source rank. " << errMsg;
2022   case SubViewVerificationResult::SizeMismatch:
2023     return op.emitError("expected result type to be ")
2024            << expectedType
2025            << " or a rank-reduced version. (mismatch of result sizes) "
2026            << errMsg;
2027   case SubViewVerificationResult::ElemTypeMismatch:
2028     return op.emitError("expected result element type to be ")
2029            << memrefType.getElementType() << errMsg;
2030   case SubViewVerificationResult::MemSpaceMismatch:
2031     return op.emitError("expected result and source memory spaces to match.")
2032            << errMsg;
2033   case SubViewVerificationResult::AffineMapMismatch:
2034     return op.emitError("expected result type to be ")
2035            << expectedType
2036            << " or a rank-reduced version. (mismatch of result affine map) "
2037            << errMsg;
2038   }
2039   llvm_unreachable("unexpected subview verification result");
2040 }
2041 
2042 /// Verifier for SubViewOp.
2043 static LogicalResult verify(SubViewOp op) {
2044   MemRefType baseType = op.getSourceType();
2045   MemRefType subViewType = op.getType();
2046 
2047   // The base memref and the view memref should be in the same memory space.
2048   if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2049     return op.emitError("different memory spaces specified for base memref "
2050                         "type ")
2051            << baseType << " and subview memref type " << subViewType;
2052 
2053   // Verify that the base memref type has a strided layout map.
2054   if (!isStrided(baseType))
2055     return op.emitError("base type ") << baseType << " is not strided";
2056 
2057   // Verify result type against inferred type.
2058   auto expectedType = SubViewOp::inferResultType(
2059       baseType, extractFromI64ArrayAttr(op.static_offsets()),
2060       extractFromI64ArrayAttr(op.static_sizes()),
2061       extractFromI64ArrayAttr(op.static_strides()));
2062 
2063   std::string errMsg;
2064   auto result =
2065       isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg);
2066   return produceSubViewErrorMsg(result, op, expectedType, errMsg);
2067 }
2068 
2069 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
2070   return os << "range " << range.offset << ":" << range.size << ":"
2071             << range.stride;
2072 }
2073 
2074 /// Return the list of Range (i.e. offset, size, stride). Each Range
2075 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2076 /// with `b` at location `loc`.
2077 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
2078                                               OpBuilder &b, Location loc) {
2079   std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
2080   assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
2081   assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
2082   SmallVector<Range, 8> res;
2083   unsigned rank = ranks[0];
2084   res.reserve(rank);
2085   for (unsigned idx = 0; idx < rank; ++idx) {
2086     Value offset =
2087         op.isDynamicOffset(idx)
2088             ? op.getDynamicOffset(idx)
2089             : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
2090     Value size = op.isDynamicSize(idx)
2091                      ? op.getDynamicSize(idx)
2092                      : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
2093     Value stride =
2094         op.isDynamicStride(idx)
2095             ? op.getDynamicStride(idx)
2096             : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
2097     res.emplace_back(Range{offset, size, stride});
2098   }
2099   return res;
2100 }
2101 
2102 /// Infer the canonical type of the result of a subview operation. Returns a
2103 /// type with rank `resultRank` that is either the rank of the rank-reduced
2104 /// type, or the non-rank-reduced type.
2105 static MemRefType
2106 getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
2107                               ArrayRef<OpFoldResult> mixedOffsets,
2108                               ArrayRef<OpFoldResult> mixedSizes,
2109                               ArrayRef<OpFoldResult> mixedStrides) {
2110   auto resultType =
2111       SubViewOp::inferRankReducedResultType(
2112           resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
2113           .cast<MemRefType>();
2114   if (resultType.getRank() != resultRank) {
2115     resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
2116                                             mixedSizes, mixedStrides)
2117                      .cast<MemRefType>();
2118   }
2119   return resultType;
2120 }
2121 
2122 namespace {
2123 /// Pattern to rewrite a subview op with MemRefCast arguments.
2124 /// This essentially pushes memref.cast past its consuming subview when
2125 /// `canFoldIntoConsumerOp` is true.
2126 ///
2127 /// Example:
2128 /// ```
2129 ///   %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
2130 ///   %1 = memref.subview %0[0, 0][3, 4][1, 1] :
2131 ///     memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
2132 /// ```
2133 /// is rewritten into:
2134 /// ```
2135 ///   %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
2136 ///   %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
2137 ///     memref<3x4xf32, offset:?, strides:[?, 1]>
2138 /// ```
2139 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
2140 public:
2141   using OpRewritePattern<SubViewOp>::OpRewritePattern;
2142 
2143   LogicalResult matchAndRewrite(SubViewOp subViewOp,
2144                                 PatternRewriter &rewriter) const override {
2145     // Any constant operand, just return to let SubViewOpConstantFolder kick in.
2146     if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
2147           return matchPattern(operand, matchConstantIndex());
2148         }))
2149       return failure();
2150 
2151     auto castOp = subViewOp.source().getDefiningOp<CastOp>();
2152     if (!castOp)
2153       return failure();
2154 
2155     if (!CastOp::canFoldIntoConsumerOp(castOp))
2156       return failure();
2157 
2158     /// Deduce the resultType of the SubViewOp using `inferSubViewResultType` on
2159     /// the cast source operand type and the SubViewOp static information. This
2160     /// is the resulting type if the MemRefCastOp were folded.
2161     auto resultType = getCanonicalSubViewResultType(
2162         subViewOp.getType().getRank(),
2163         castOp.source().getType().cast<MemRefType>(),
2164         subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
2165         subViewOp.getMixedStrides());
2166     Value newSubView = rewriter.create<SubViewOp>(
2167         subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
2168         subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
2169         subViewOp.static_sizes(), subViewOp.static_strides());
2170     rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2171                                         newSubView);
2172     return success();
2173   }
2174 };
2175 } // namespace
2176 
2177 /// Return the canonical type of the result of a subview.
2178 struct SubViewReturnTypeCanonicalizer {
2179   MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
2180                         ArrayRef<OpFoldResult> mixedSizes,
2181                         ArrayRef<OpFoldResult> mixedStrides) {
2182     return getCanonicalSubViewResultType(op.getType().getRank(),
2183                                          op.getSourceType(), mixedOffsets,
2184                                          mixedSizes, mixedStrides);
2185   }
2186 };
2187 
2188 /// A canonicalizer wrapper to replace SubViewOps.
2189 struct SubViewCanonicalizer {
2190   void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
2191     rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
2192   }
2193 };
2194 
2195 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2196                                             MLIRContext *context) {
2197   results
2198       .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
2199                SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
2200            SubViewOpMemRefCastFolder>(context);
2201 }
2202 
2203 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
2204   auto resultShapedType = getResult().getType().cast<ShapedType>();
2205   auto sourceShapedType = source().getType().cast<ShapedType>();
2206 
2207   if (resultShapedType.hasStaticShape() &&
2208       resultShapedType == sourceShapedType) {
2209     return getViewSource();
2210   }
2211 
2212   return {};
2213 }
2214 
2215 //===----------------------------------------------------------------------===//
2216 // TensorLoadOp
2217 //===----------------------------------------------------------------------===//
2218 
2219 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
2220   if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
2221     // Approximate alias analysis by conservatively folding only when there
2222     // is no interleaved operation.
2223     if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
2224         bufferCast->getNextNode() == this->getOperation())
2225       return bufferCast.tensor();
2226   return {};
2227 }
2228 
2229 namespace {
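/// Fold tensor.dim of a tensor_load into memref.dim on the underlying memref,
/// e.g. (an illustrative sketch, names invented):
///
///   %t = memref.tensor_load %m : memref<?xf32>
///   %d = tensor.dim %t, %c0 : tensor<?xf32>
///
/// becomes:
///
///   %d = memref.dim %m, %c0 : memref<?xf32>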
2230 struct DimOfTensorLoadFolder : public OpRewritePattern<tensor::DimOp> {
2231   using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
2232 
2233   LogicalResult matchAndRewrite(tensor::DimOp dimOp,
2234                                 PatternRewriter &rewriter) const override {
2235     auto tensorLoadOp = dimOp.source().getDefiningOp<TensorLoadOp>();
2236     if (!tensorLoadOp)
2237       return failure();
2238 
2239     rewriter.replaceOpWithNewOp<DimOp>(dimOp, tensorLoadOp.memref(),
2240                                        dimOp.index());
2241     return success();
2242   }
2243 };
2244 } // namespace
2245 
2246 void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2247                                                MLIRContext *context) {
2248   results.add<DimOfTensorLoadFolder>(context);
2249 }
2250 
2251 //===----------------------------------------------------------------------===//
2252 // TransposeOp
2253 //===----------------------------------------------------------------------===//
2254 
2255 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
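/// For example (a worked sketch): permuting memref<3x4xf32> (strides [4, 1])
/// with (d0, d1) -> (d1, d0) yields shape [4, 3] with strides [1, 4], i.e.
/// memref<4x3xf32, (d0, d1) -> (d0 + d1 * 4)>.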
2256 static MemRefType inferTransposeResultType(MemRefType memRefType,
2257                                            AffineMap permutationMap) {
2258   auto rank = memRefType.getRank();
2259   auto originalSizes = memRefType.getShape();
2260   // Compute permuted sizes.
2261   SmallVector<int64_t, 4> sizes(rank, 0);
2262   for (auto en : llvm::enumerate(permutationMap.getResults()))
2263     sizes[en.index()] =
2264         originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
2265 
2266   // Compute permuted strides.
2267   int64_t offset;
2268   SmallVector<int64_t, 4> strides;
2269   auto res = getStridesAndOffset(memRefType, strides, offset);
2270   assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2271   (void)res;
2272   auto map =
2273       makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2274   map = permutationMap ? map.compose(permutationMap) : map;
2275   return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
2276 }
2277 
2278 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2279                         AffineMapAttr permutation,
2280                         ArrayRef<NamedAttribute> attrs) {
2281   auto permutationMap = permutation.getValue();
2282   assert(permutationMap);
2283 
2284   auto memRefType = in.getType().cast<MemRefType>();
2285   // Compute result type.
2286   MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2287 
2288   build(b, result, resultType, in, attrs);
2289   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2290 }
2291 
2292 // transpose $in $permutation attr-dict : type($in) `to` type(results)
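// e.g. (illustrative):
//   %1 = memref.transpose %0 (i, j) -> (j, i) : memref<?x?xf32> to
//        memref<?x?xf32, affine_map<(d0, d1)[s0] -> (d1 * s0 + d0)>>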
2293 static void print(OpAsmPrinter &p, TransposeOp op) {
2294   p << " " << op.in() << " " << op.permutation();
2295   p.printOptionalAttrDict(op->getAttrs(),
2296                           {TransposeOp::getPermutationAttrName()});
2297   p << " : " << op.in().getType() << " to " << op.getType();
2298 }
2299 
2300 static ParseResult parseTransposeOp(OpAsmParser &parser,
2301                                     OperationState &result) {
2302   OpAsmParser::OperandType in;
2303   AffineMap permutation;
2304   MemRefType srcType, dstType;
2305   if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2306       parser.parseOptionalAttrDict(result.attributes) ||
2307       parser.parseColonType(srcType) ||
2308       parser.resolveOperand(in, srcType, result.operands) ||
2309       parser.parseKeywordType("to", dstType) ||
2310       parser.addTypeToList(dstType, result.types))
2311     return failure();
2312 
2313   result.addAttribute(TransposeOp::getPermutationAttrName(),
2314                       AffineMapAttr::get(permutation));
2315   return success();
2316 }
2317 
2318 static LogicalResult verify(TransposeOp op) {
2319   if (!op.permutation().isPermutation())
2320     return op.emitOpError("expected a permutation map");
2321   if (op.permutation().getNumDims() != op.getShapedType().getRank())
2322     return op.emitOpError(
2323         "expected a permutation map of same rank as the input");
2324 
2325   auto srcType = op.in().getType().cast<MemRefType>();
2326   auto dstType = op.getType().cast<MemRefType>();
2327   auto transposedType = inferTransposeResultType(srcType, op.permutation());
2328   if (dstType != transposedType)
2329     return op.emitOpError("output type ")
2330            << dstType << " does not match transposed input type " << srcType
2331            << ", " << transposedType;
2332   return success();
2333 }

OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return {};
}
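
// E.g. (illustrative): a transpose whose operand is produced by a memref.cast
// from a ranked memref is folded in place to use the cast's source, via
// foldMemRefCast above.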

//===----------------------------------------------------------------------===//
// ViewOp
//===----------------------------------------------------------------------===//

static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
  OpAsmParser::OperandType srcInfo;
  SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
  SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
  auto indexType = parser.getBuilder().getIndexType();
  Type srcType, dstType;
  llvm::SMLoc offsetLoc;
  if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
      parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
    return failure();

  if (offsetInfo.size() != 1)
    return parser.emitError(offsetLoc) << "expects 1 offset operand";

  return failure(
      parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(srcType) ||
      parser.resolveOperand(srcInfo, srcType, result.operands) ||
      parser.resolveOperands(offsetInfo, indexType, result.operands) ||
      parser.resolveOperands(sizesInfo, indexType, result.operands) ||
      parser.parseKeywordType("to", dstType) ||
      parser.addTypeToList(dstType, result.types));
}

static void print(OpAsmPrinter &p, ViewOp op) {
  p << ' ' << op.getOperand(0) << '[';
  p.printOperand(op.byte_shift());
  p << "][" << op.sizes() << ']';
  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getOperand(0).getType() << " to " << op.getType();
}
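
// Illustrative textual form (names assumed):
//   %1 = memref.view %0[%byte_shift][%sz0, %sz1]
//       : memref<2048xi8> to memref<?x?xf32>
// where %byte_shift and the size operands are of index type.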

static LogicalResult verify(ViewOp op) {
  auto baseType = op.getOperand(0).getType().cast<MemRefType>();
  auto viewType = op.getType();

  // The base memref should have an identity layout map (or none).
  if (baseType.getAffineMaps().size() > 1 ||
      (baseType.getAffineMaps().size() == 1 &&
       !baseType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for base memref type ") << baseType;

  // The result memref should have an identity layout map (or none).
  if (viewType.getAffineMaps().size() > 1 ||
      (viewType.getAffineMaps().size() == 1 &&
       !viewType.getAffineMaps()[0].isIdentity()))
    return op.emitError("unsupported map for result memref type ") << viewType;

  // The base memref and the view memref should be in the same memory space.
  if (baseType.getMemorySpace() != viewType.getMemorySpace())
    return op.emitError("different memory spaces specified for base memref "
                        "type ")
           << baseType << " and view memref type " << viewType;

  // Verify that we have the correct number of sizes for the result type.
  unsigned numDynamicDims = viewType.getNumDynamicDims();
  if (op.sizes().size() != numDynamicDims)
    return op.emitError("incorrect number of size operands for type ")
           << viewType;

  return success();
}
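
// For example (illustrative), a result type memref<?x?xf32> supplied with a
// single size operand fails with "incorrect number of size operands", and a
// base or result memref carrying a non-identity layout map is reported as an
// unsupported map.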

Value ViewOp::getViewSource() { return source(); }

namespace {

struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    // Return if none of the operands are constants.
    if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    // Get result memref type.
    auto memrefType = viewOp.getType();

    // Get offset from old memref view type 'memrefType'.
    int64_t oldOffset;
    SmallVector<int64_t, 4> oldStrides;
    if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
      return failure();
    assert(oldOffset == 0 && "Expected 0 offset");

    SmallVector<Value, 4> newOperands;

    // The offset cannot be folded into the result type.

    // Fold any dynamic dim operands which are produced by a constant.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());

    unsigned dynamicDimPos = 0;
    unsigned rank = memrefType.getRank();
    for (unsigned dim = 0, e = rank; dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (!ShapedType::isDynamic(dimSize)) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy operand from old memref.
        newShapeConstants.push_back(dimSize);
        newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
      }
      dynamicDimPos++;
    }

    // Create new memref type with constant folded dims.
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    // Nothing new, don't fold.
    if (newMemRefType == memrefType)
      return failure();

    // Create new ViewOp.
    auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
                                             viewOp.getOperand(0),
                                             viewOp.byte_shift(), newOperands);
    // Insert a cast so we have the same type as the old memref type.
    rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
    return success();
  }
};
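
// Illustrative rewrite performed by ViewOpShapeFolder (names assumed):
//   %c64 = constant 64 : index
//   %v = memref.view %buf[%off][%c64, %n] : memref<2048xi8> to memref<?x?xf32>
// becomes a view with a more static result type plus a cast back to the
// original type:
//   %0 = memref.view %buf[%off][%n] : memref<2048xi8> to memref<64x?xf32>
//   %1 = memref.cast %0 : memref<64x?xf32> to memref<?x?xf32>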

struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
  using OpRewritePattern<ViewOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ViewOp viewOp,
                                PatternRewriter &rewriter) const override {
    Value memrefOperand = viewOp.getOperand(0);
    CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
    if (!memrefCastOp)
      return failure();
    Value allocOperand = memrefCastOp.getOperand();
    AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
    if (!allocOp)
      return failure();
    rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
                                        viewOp.byte_shift(), viewOp.sizes());
    return success();
  }
};
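
// Illustrative rewrite performed by ViewOpMemrefCastFolder (names assumed):
//   %a = memref.alloc(%n) : memref<?xi8>
//   %c = memref.cast %a : memref<?xi8> to memref<2048xi8>
//   %v = memref.view %c[%off][%sz] : memref<2048xi8> to memref<?xf32>
// The view is re-created directly on %a with its original result type,
// bypassing the cast; the pattern only fires when the cast source comes from a
// memref.alloc.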

} // end anonymous namespace

void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"