//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/MemRef/IR/MemRef.h"
#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/Dialect/StandardOps/Utils/Utils.h"
#include "mlir/Dialect/Tensor/IR/Tensor.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinTypes.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/TypeUtilities.h"
#include "mlir/Interfaces/InferTypeOpInterface.h"
#include "mlir/Interfaces/ViewLikeInterface.h"
#include "llvm/ADT/STLExtras.h"

using namespace mlir;
using namespace mlir::memref;

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *MemRefDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<mlir::ConstantOp>(loc, type, value);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common class used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref.cast
/// into the root operation directly.
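///
/// For illustration only (`"someop"` stands for any consuming op):
///
/// ```mlir
/// %0 = memref.cast %arg : memref<8xf32> to memref<?xf32>
/// "someop"(%0) : (memref<?xf32>) -> ()
/// ```
///
/// may fold to:
///
/// ```mlir
/// "someop"(%arg) : (memref<8xf32>) -> ()
/// ```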
static LogicalResult foldMemRefCast(Operation *op, Value inner = nullptr) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<CastOp>();
    if (cast && operand.get() != inner &&
        !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// Helpers for GlobalOp
//===----------------------------------------------------------------------===//

static Type getTensorTypeFromMemRefType(Type type) {
  if (auto memref = type.dyn_cast<MemRefType>())
    return RankedTensorType::get(memref.getShape(), memref.getElementType());
  if (auto memref = type.dyn_cast<UnrankedMemRefType>())
    return UnrankedTensorType::get(memref.getElementType());
  return NoneType::get(type.getContext());
}

//===----------------------------------------------------------------------===//
// AllocOp / AllocaOp
//===----------------------------------------------------------------------===//

template <typename AllocLikeOp>
static LogicalResult verifyAllocLikeOp(AllocLikeOp op) {
  static_assert(llvm::is_one_of<AllocLikeOp, AllocOp, AllocaOp>::value,
                "applies to only alloc or alloca");
  auto memRefType = op.getResult().getType().template dyn_cast<MemRefType>();
  if (!memRefType)
    return op.emitOpError("result must be a memref");

  if (static_cast<int64_t>(op.dynamicSizes().size()) !=
      memRefType.getNumDynamicDims())
    return op.emitOpError("dimension operand count does not equal memref "
                          "dynamic dimension count");

  unsigned numSymbols = 0;
  if (!memRefType.getAffineMaps().empty())
    numSymbols = memRefType.getAffineMaps().front().getNumSymbols();
  if (op.symbolOperands().size() != numSymbols)
    return op.emitOpError("symbol operand count does not equal memref symbol "
                          "count: expected ")
           << numSymbols << ", got " << op.symbolOperands().size();

  return success();
}

static LogicalResult verify(AllocOp op) { return verifyAllocLikeOp(op); }

static LogicalResult verify(AllocaOp op) {
  // An alloca op needs to have an ancestor with an allocation scope trait.
  if (!op->getParentWithTrait<OpTrait::AutomaticAllocationScope>())
    return op.emitOpError(
        "requires an ancestor op with AutomaticAllocationScope trait");

  return verifyAllocLikeOp(op);
}

namespace {
/// Fold constant dimensions into an alloc-like operation.
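///
/// A rough sketch of the rewrite (values and types are illustrative):
///
/// ```mlir
/// %c8 = constant 8 : index
/// %0 = memref.alloc(%c8, %n) : memref<?x?xf32>
/// ```
///
/// becomes an allocation of a more static type plus a cast back to the old
/// type:
///
/// ```mlir
/// %0 = memref.alloc(%n) : memref<8x?xf32>
/// %1 = memref.cast %0 : memref<8x?xf32> to memref<?x?xf32>
/// ```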
template <typename AllocLikeOp>
struct SimplifyAllocConst : public OpRewritePattern<AllocLikeOp> {
  using OpRewritePattern<AllocLikeOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AllocLikeOp alloc,
                                PatternRewriter &rewriter) const override {
    // Check to see if any dimension operands are constants. If so, we can
    // substitute and drop them.
    if (llvm::none_of(alloc.dynamicSizes(), [](Value operand) {
          return matchPattern(operand, matchConstantIndex());
        }))
      return failure();

    auto memrefType = alloc.getType();

    // Ok, we have one or more constant operands. Collect the non-constant ones
    // and keep track of the resultant memref type to build.
    SmallVector<int64_t, 4> newShapeConstants;
    newShapeConstants.reserve(memrefType.getRank());
    SmallVector<Value, 4> dynamicSizes;

    unsigned dynamicDimPos = 0;
    for (unsigned dim = 0, e = memrefType.getRank(); dim < e; ++dim) {
      int64_t dimSize = memrefType.getDimSize(dim);
      // If this is already a static dimension, keep it.
      if (dimSize != -1) {
        newShapeConstants.push_back(dimSize);
        continue;
      }
      auto dynamicSize = alloc.dynamicSizes()[dynamicDimPos];
      auto *defOp = dynamicSize.getDefiningOp();
      if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
        // Dynamic shape dimension will be folded.
        newShapeConstants.push_back(constantIndexOp.getValue());
      } else {
        // Dynamic shape dimension not folded; copy dynamicSize from old
        // memref.
        newShapeConstants.push_back(-1);
        dynamicSizes.push_back(dynamicSize);
      }
      dynamicDimPos++;
    }

    // Create new memref type (which will have fewer dynamic dimensions).
    MemRefType newMemRefType =
        MemRefType::Builder(memrefType).setShape(newShapeConstants);
    assert(static_cast<int64_t>(dynamicSizes.size()) ==
           newMemRefType.getNumDynamicDims());

    // Create and insert the alloc op for the new memref.
    auto newAlloc = rewriter.create<AllocLikeOp>(
        alloc.getLoc(), newMemRefType, dynamicSizes, alloc.symbolOperands(),
        alloc.alignmentAttr());
    // Insert a cast so we have the same type as the old alloc.
    auto resultCast =
        rewriter.create<CastOp>(alloc.getLoc(), newAlloc, alloc.getType());

    rewriter.replaceOp(alloc, {resultCast});
    return success();
  }
};

/// Fold alloc operations with no users or only store and dealloc uses.
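///
/// For example (illustrative), an allocation whose only uses are stores into
/// it and a dealloc:
///
/// ```mlir
/// %0 = memref.alloc() : memref<4xf32>
/// memref.store %v, %0[%i] : memref<4xf32>
/// memref.dealloc %0 : memref<4xf32>
/// ```
///
/// can be erased entirely, since its contents are never read.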
template <typename T>
struct SimplifyDeadAlloc : public OpRewritePattern<T> {
  using OpRewritePattern<T>::OpRewritePattern;

  LogicalResult matchAndRewrite(T alloc,
                                PatternRewriter &rewriter) const override {
    if (llvm::any_of(alloc->getUsers(), [&](Operation *op) {
          if (auto storeOp = dyn_cast<StoreOp>(op))
            return storeOp.value() == alloc;
          return !isa<DeallocOp>(op);
        }))
      return failure();

    for (Operation *user : llvm::make_early_inc_range(alloc->getUsers()))
      rewriter.eraseOp(user);

    rewriter.eraseOp(alloc);
    return success();
  }
};
} // end anonymous namespace.

Optional<Operation *> AllocOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> AllocOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult();
}

void AllocOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                          MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocOp>, SimplifyDeadAlloc<AllocOp>>(context);
}

void AllocaOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                           MLIRContext *context) {
  results.add<SimplifyAllocConst<AllocaOp>, SimplifyDeadAlloc<AllocaOp>>(
      context);
}

//===----------------------------------------------------------------------===//
// AllocaScopeOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, AllocaScopeOp &op) {
  bool printBlockTerminators = false;

  p << " ";
  if (!op.results().empty()) {
    p << " -> (" << op.getResultTypes() << ")";
    printBlockTerminators = true;
  }
  p.printRegion(op.bodyRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/printBlockTerminators);
  p.printOptionalAttrDict(op->getAttrs());
}

static ParseResult parseAllocaScopeOp(OpAsmParser &parser,
                                      OperationState &result) {
  // Create a region for the body.
  result.regions.reserve(1);
  Region *bodyRegion = result.addRegion();

  // Parse optional results type list.
  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Parse the body region.
  if (parser.parseRegion(*bodyRegion, /*arguments=*/{}, /*argTypes=*/{}))
    return failure();
  AllocaScopeOp::ensureTerminator(*bodyRegion, parser.getBuilder(),
                                  result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

static LogicalResult verify(AllocaScopeOp op) {
  if (failed(RegionBranchOpInterface::verifyTypes(op)))
    return failure();

  return success();
}

void AllocaScopeOp::getSuccessorRegions(
    Optional<unsigned> index, ArrayRef<Attribute> operands,
    SmallVectorImpl<RegionSuccessor> &regions) {
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  regions.push_back(RegionSuccessor(&bodyRegion()));
}

//===----------------------------------------------------------------------===//
// AssumeAlignmentOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(AssumeAlignmentOp op) {
  unsigned alignment = op.alignment();
  if (!llvm::isPowerOf2_32(alignment))
    return op.emitOpError("alignment must be power of 2");
  return success();
}

//===----------------------------------------------------------------------===//
// BufferCastOp
//===----------------------------------------------------------------------===//

OpFoldResult BufferCastOp::fold(ArrayRef<Attribute>) {
  if (auto tensorLoad = tensor().getDefiningOp<TensorLoadOp>())
    if (tensorLoad.memref().getType() == getType())
      return tensorLoad.memref();
  return {};
}

namespace {
/// Replace tensor_cast + buffer_cast by buffer_cast + memref_cast.
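///
/// Sketch of the intended rewrite (types are illustrative):
///
/// ```mlir
/// %0 = tensor.cast %t : tensor<4xf32> to tensor<?xf32>
/// %1 = memref.buffer_cast %0 : memref<?xf32>
/// ```
///
/// becomes:
///
/// ```mlir
/// %0 = memref.buffer_cast %t : memref<4xf32>
/// %1 = memref.cast %0 : memref<4xf32> to memref<?xf32>
/// ```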
struct BufferCast : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorCastOperand =
        bufferCast.getOperand().getDefiningOp<tensor::CastOp>();
    if (!tensorCastOperand)
      return failure();
    auto srcTensorType =
        tensorCastOperand.getOperand().getType().dyn_cast<RankedTensorType>();
    if (!srcTensorType)
      return failure();
    auto memrefType = MemRefType::get(srcTensorType.getShape(),
                                      srcTensorType.getElementType());
    Value memref = rewriter.create<BufferCastOp>(
        bufferCast.getLoc(), memrefType, tensorCastOperand.getOperand());
    rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                        memref);
    return success();
  }
};

/// Canonicalize memref.tensor_load + memref.buffer_cast to memref.cast when
/// type mismatches prevent `BufferCastOp::fold` from kicking in.
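///
/// For example (illustrative), when the types differ but the cast is
/// guaranteed to be compatible:
///
/// ```mlir
/// %0 = memref.tensor_load %m : memref<?xf32>
/// %1 = memref.buffer_cast %0 : memref<4xf32>
/// ```
///
/// may become:
///
/// ```mlir
/// %1 = memref.cast %m : memref<?xf32> to memref<4xf32>
/// ```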
struct TensorLoadToMemRef : public OpRewritePattern<BufferCastOp> {
  using OpRewritePattern<BufferCastOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(BufferCastOp bufferCast,
                                PatternRewriter &rewriter) const final {
    auto tensorLoad = bufferCast.tensor().getDefiningOp<TensorLoadOp>();
    // Bail unless we have a tensor_load + memref.buffer_cast with different
    // types. `BufferCastOp::fold` handles the same type case.
    if (!tensorLoad || tensorLoad.memref().getType() == bufferCast.getType())
      return failure();
    // If types are definitely not cast-compatible, bail.
    if (!CastOp::areCastCompatible(tensorLoad.memref().getType(),
                                   bufferCast.getType()))
      return failure();

    // We already know that the types are potentially cast-compatible. However
    // in case the affine maps are different, we may need to use a copy if we
    // go from dynamic to static offset or stride (the canonicalization cannot
    // know at this point that it is really cast compatible).
    auto isGuaranteedCastCompatible = [](MemRefType source,
                                         MemRefType target) {
      int64_t sourceOffset, targetOffset;
      SmallVector<int64_t, 4> sourceStrides, targetStrides;
      if (failed(getStridesAndOffset(source, sourceStrides, sourceOffset)) ||
          failed(getStridesAndOffset(target, targetStrides, targetOffset)))
        return false;
      auto dynamicToStatic = [](int64_t a, int64_t b) {
        return a == MemRefType::getDynamicStrideOrOffset() &&
               b != MemRefType::getDynamicStrideOrOffset();
      };
      if (dynamicToStatic(sourceOffset, targetOffset))
        return false;
      for (auto it : zip(sourceStrides, targetStrides))
        if (dynamicToStatic(std::get<0>(it), std::get<1>(it)))
          return false;
      return true;
    };

    auto tensorLoadType = tensorLoad.memref().getType().dyn_cast<MemRefType>();
    auto bufferCastType = bufferCast.getType().dyn_cast<MemRefType>();
    if (tensorLoadType && bufferCastType &&
        !isGuaranteedCastCompatible(tensorLoadType, bufferCastType)) {
      MemRefType resultType = bufferCastType;
      auto loc = bufferCast.getLoc();
      SmallVector<Value, 4> dynamicOperands;
      for (int i = 0; i < resultType.getRank(); ++i) {
        if (resultType.getShape()[i] != ShapedType::kDynamicSize)
          continue;
        auto index = rewriter.createOrFold<ConstantIndexOp>(loc, i);
        Value size = rewriter.create<tensor::DimOp>(loc, tensorLoad, index);
        dynamicOperands.push_back(size);
      }
      auto copy =
          rewriter.create<memref::AllocOp>(loc, resultType, dynamicOperands);
      rewriter.create<CopyOp>(loc, tensorLoad.memref(), copy);
      rewriter.replaceOp(bufferCast, {copy});
    } else
      rewriter.replaceOpWithNewOp<CastOp>(bufferCast, bufferCast.getType(),
                                          tensorLoad.memref());
    return success();
  }
};

} // namespace

void BufferCastOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                               MLIRContext *context) {
  results.add<BufferCast, TensorLoadToMemRef>(context);
}

//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//

/// Determines whether MemRef_CastOp casts to a more dynamic version of the
/// source memref. This is useful to fold a memref.cast into a consuming op
/// and implement canonicalization patterns for ops in different dialects that
/// may consume the results of memref.cast operations. Such foldable memref.cast
/// operations are typically inserted as `view` and `subview` ops are
/// canonicalized, to preserve the type compatibility of their uses.
///
/// Returns true when all conditions are met:
/// 1. source and result are ranked memrefs with strided semantics and same
/// element type and rank.
/// 2. each of the source's size, offset or stride has more static information
/// than the corresponding result's size, offset or stride.
///
/// Example 1:
/// ```mlir
///   %1 = memref.cast %0 : memref<8x16xf32> to memref<?x?xf32>
///   %2 = consumer %1 ... : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   %2 = consumer %0 ... : memref<8x16xf32> ...
/// ```
///
/// Example 2:
/// ```mlir
///   %1 = memref.cast %0 : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
///          to memref<?x?xf32>
///   consumer %1 : memref<?x?xf32> ...
/// ```
///
/// may fold into:
///
/// ```mlir
///   consumer %0 ... : memref<?x16xf32, affine_map<(i, j)->(16 * i + j)>>
/// ```
bool CastOp::canFoldIntoConsumerOp(CastOp castOp) {
  MemRefType sourceType = castOp.source().getType().dyn_cast<MemRefType>();
  MemRefType resultType = castOp.getType().dyn_cast<MemRefType>();

  // Requires ranked MemRefType.
  if (!sourceType || !resultType)
    return false;

  // Requires same elemental type.
  if (sourceType.getElementType() != resultType.getElementType())
    return false;

  // Requires same rank.
  if (sourceType.getRank() != resultType.getRank())
    return false;

  // Only fold casts between strided memref forms.
  int64_t sourceOffset, resultOffset;
  SmallVector<int64_t, 4> sourceStrides, resultStrides;
  if (failed(getStridesAndOffset(sourceType, sourceStrides, sourceOffset)) ||
      failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
    return false;

  // If cast is towards more static sizes along any dimension, don't fold.
  for (auto it : llvm::zip(sourceType.getShape(), resultType.getShape())) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamic(ss) && !MemRefType::isDynamic(st))
        return false;
  }

  // If cast is towards more static offset along any dimension, don't fold.
  if (sourceOffset != resultOffset)
    if (MemRefType::isDynamicStrideOrOffset(sourceOffset) &&
        !MemRefType::isDynamicStrideOrOffset(resultOffset))
      return false;

  // If cast is towards more static strides along any dimension, don't fold.
  for (auto it : llvm::zip(sourceStrides, resultStrides)) {
    auto ss = std::get<0>(it), st = std::get<1>(it);
    if (ss != st)
      if (MemRefType::isDynamicStrideOrOffset(ss) &&
          !MemRefType::isDynamicStrideOrOffset(st))
        return false;
  }

  return true;
}

bool CastOp::areCastCompatible(TypeRange inputs, TypeRange outputs) {
  if (inputs.size() != 1 || outputs.size() != 1)
    return false;
  Type a = inputs.front(), b = outputs.front();
  auto aT = a.dyn_cast<MemRefType>();
  auto bT = b.dyn_cast<MemRefType>();

  auto uaT = a.dyn_cast<UnrankedMemRefType>();
  auto ubT = b.dyn_cast<UnrankedMemRefType>();

  if (aT && bT) {
    if (aT.getElementType() != bT.getElementType())
      return false;
    if (aT.getAffineMaps() != bT.getAffineMaps()) {
      int64_t aOffset, bOffset;
      SmallVector<int64_t, 4> aStrides, bStrides;
      if (failed(getStridesAndOffset(aT, aStrides, aOffset)) ||
          failed(getStridesAndOffset(bT, bStrides, bOffset)) ||
          aStrides.size() != bStrides.size())
        return false;

      // Strides along a dimension/offset are compatible if the value in the
      // source memref is static and the value in the target memref is the
      // same. They are also compatible if either one is dynamic (see
      // description of MemRefCastOp for details).
      auto checkCompatible = [](int64_t a, int64_t b) {
        return (a == MemRefType::getDynamicStrideOrOffset() ||
                b == MemRefType::getDynamicStrideOrOffset() || a == b);
      };
      if (!checkCompatible(aOffset, bOffset))
        return false;
      for (auto aStride : enumerate(aStrides))
        if (!checkCompatible(aStride.value(), bStrides[aStride.index()]))
          return false;
    }
    if (aT.getMemorySpace() != bT.getMemorySpace())
      return false;

    // They must have the same rank, and any specified dimensions must match.
    if (aT.getRank() != bT.getRank())
      return false;

    for (unsigned i = 0, e = aT.getRank(); i != e; ++i) {
      int64_t aDim = aT.getDimSize(i), bDim = bT.getDimSize(i);
      if (aDim != -1 && bDim != -1 && aDim != bDim)
        return false;
    }
    return true;
  } else {
    if (!aT && !uaT)
      return false;
    if (!bT && !ubT)
      return false;
    // Unranked to unranked casting is unsupported.
    if (uaT && ubT)
      return false;

    auto aEltType = (aT) ? aT.getElementType() : uaT.getElementType();
    auto bEltType = (bT) ? bT.getElementType() : ubT.getElementType();
    if (aEltType != bEltType)
      return false;

    auto aMemSpace = (aT) ? aT.getMemorySpace() : uaT.getMemorySpace();
    auto bMemSpace = (bT) ? bT.getMemorySpace() : ubT.getMemorySpace();
    if (aMemSpace != bMemSpace)
      return false;

    return true;
  }

  return false;
}

OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

//===----------------------------------------------------------------------===//
// CloneOp
//===----------------------------------------------------------------------===//

void CloneOp::getEffects(
    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
        &effects) {
  effects.emplace_back(MemoryEffects::Read::get(), input(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Write::get(), output(),
                       SideEffects::DefaultResource::get());
  effects.emplace_back(MemoryEffects::Allocate::get(), output(),
                       SideEffects::DefaultResource::get());
}

namespace {
/// Merge the clone and its source (by converting the clone to a cast) when
/// possible.
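///
/// A rough sketch of the intent (illustrative only): when the clone's buffer
/// is deallocated in the same block as the clone,
///
/// ```mlir
/// %1 = memref.clone %0 : memref<?xf32> to memref<?xf32>
/// ...
/// memref.dealloc %1 : memref<?xf32>
/// ```
///
/// can be rewritten to a cast of %0, dropping the redundant deallocation.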
struct SimplifyClones : public OpRewritePattern<CloneOp> {
  using OpRewritePattern<CloneOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(CloneOp cloneOp,
                                PatternRewriter &rewriter) const override {
    if (cloneOp.use_empty()) {
      rewriter.eraseOp(cloneOp);
      return success();
    }

    Value source = cloneOp.input();

    // This only finds dealloc operations for the immediate value. It should
    // also consider aliases. That would also make the safety check below
    // redundant.
    llvm::Optional<Operation *> maybeCloneDeallocOp =
        findDealloc(cloneOp.output());
    // Skip if either of them has more than one deallocate operation.
    if (!maybeCloneDeallocOp.hasValue())
      return failure();
    llvm::Optional<Operation *> maybeSourceDeallocOp = findDealloc(source);
    if (!maybeSourceDeallocOp.hasValue())
      return failure();
    Operation *cloneDeallocOp = *maybeCloneDeallocOp;
    Operation *sourceDeallocOp = *maybeSourceDeallocOp;

    // If both are deallocated in the same block, their in-block lifetimes
    // might not fully overlap, so we cannot decide which one to drop.
    if (cloneDeallocOp && sourceDeallocOp &&
        cloneDeallocOp->getBlock() == sourceDeallocOp->getBlock())
      return failure();

    Block *currentBlock = cloneOp->getBlock();
    Operation *redundantDealloc = nullptr;
    if (cloneDeallocOp && cloneDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = cloneDeallocOp;
    } else if (sourceDeallocOp && sourceDeallocOp->getBlock() == currentBlock) {
      redundantDealloc = sourceDeallocOp;
    }

    if (!redundantDealloc)
      return failure();

    // Safety check that there are no other deallocations in between
    // cloneOp and redundantDealloc, as otherwise we might deallocate an alias
    // of source before the uses of the clone. With alias information, we could
    // restrict this to only fail if the dealloc's operand is an alias
    // of the source.
    for (Operation *pos = cloneOp->getNextNode(); pos != redundantDealloc;
         pos = pos->getNextNode()) {
      auto effectInterface = dyn_cast<MemoryEffectOpInterface>(pos);
      if (!effectInterface)
        continue;
      if (effectInterface.hasEffect<MemoryEffects::Free>())
        return failure();
    }

    rewriter.replaceOpWithNewOp<memref::CastOp>(cloneOp, cloneOp.getType(),
                                                source);
    rewriter.eraseOp(redundantDealloc);
    return success();
  }
};

} // end anonymous namespace.

void CloneOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                          MLIRContext *context) {
  results.insert<SimplifyClones>(context);
}

OpFoldResult CloneOp::fold(ArrayRef<Attribute> operands) {
  return succeeded(foldMemRefCast(*this)) ? getResult() : Value();
}

Optional<Operation *> CloneOp::buildDealloc(OpBuilder &builder, Value alloc) {
  return builder.create<memref::DeallocOp>(alloc.getLoc(), alloc)
      .getOperation();
}

Optional<Value> CloneOp::buildClone(OpBuilder &builder, Value alloc) {
  return builder.create<memref::CloneOp>(alloc.getLoc(), alloc).getResult();
}

//===----------------------------------------------------------------------===//
// DeallocOp
//===----------------------------------------------------------------------===//

LogicalResult DeallocOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dealloc(memrefcast) -> dealloc
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// DimOp
//===----------------------------------------------------------------------===//

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  int64_t index) {
  auto loc = result.location;
  Value indexValue = builder.create<ConstantIndexOp>(loc, index);
  build(builder, result, source, indexValue);
}

void DimOp::build(OpBuilder &builder, OperationState &result, Value source,
                  Value index) {
  auto indexTy = builder.getIndexType();
  build(builder, result, indexTy, source, index);
}

Optional<int64_t> DimOp::getConstantIndex() {
  if (auto constantOp = index().getDefiningOp<ConstantOp>())
    return constantOp.getValue().cast<IntegerAttr>().getInt();
  return {};
}

static LogicalResult verify(DimOp op) {
  // Assume unknown index to be in range.
  Optional<int64_t> index = op.getConstantIndex();
  if (!index.hasValue())
    return success();

  // Check that constant index is not knowingly out of range.
  auto type = op.source().getType();
  if (auto memrefType = type.dyn_cast<MemRefType>()) {
    if (index.getValue() >= memrefType.getRank())
      return op.emitOpError("index is out of range");
  } else if (type.isa<UnrankedMemRefType>()) {
    // Assume index to be in range.
  } else {
    llvm_unreachable("expected operand with memref type");
  }
  return success();
}

/// Return a map whose keys are the elements of `vals` and whose values are the
/// number of occurrences of each element. Use std::map, since the `vals` here
/// are strides and the dynamic stride value is the same as the tombstone value
/// for `DenseMap<int64_t>`.
static std::map<int64_t, unsigned> getNumOccurences(ArrayRef<int64_t> vals) {
  std::map<int64_t, unsigned> numOccurences;
  for (auto val : vals)
    numOccurences[val]++;
  return numOccurences;
}

/// Given the type of the non-rank-reduced subview result and the rank-reduced
/// result type, computes the dropped dimensions. This accounts for cases where
/// there are multiple unit-dims, but only a subset of those are dropped. For
/// MemRefTypes these can be disambiguated using the strides. If a dimension is
/// dropped the stride must be dropped too.
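///
/// For example (illustrative): reducing `memref<1x4x1xf32>`, which has unit
/// dims at positions 0 and 2, down to `memref<4x1xf32>` keeps one of those
/// unit dims; the strides of the candidate type are what disambiguate which of
/// the two was actually dropped.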
static llvm::Optional<llvm::SmallDenseSet<unsigned>>
computeMemRefRankReductionMask(MemRefType originalType, MemRefType reducedType,
                               ArrayAttr staticSizes) {
  llvm::SmallDenseSet<unsigned> unusedDims;
  if (originalType.getRank() == reducedType.getRank())
    return unusedDims;

  for (auto dim : llvm::enumerate(staticSizes))
    if (dim.value().cast<IntegerAttr>().getInt() == 1)
      unusedDims.insert(dim.index());
  SmallVector<int64_t> originalStrides, candidateStrides;
  int64_t originalOffset, candidateOffset;
  if (failed(
          getStridesAndOffset(originalType, originalStrides, originalOffset)) ||
      failed(
          getStridesAndOffset(reducedType, candidateStrides, candidateOffset)))
    return llvm::None;

  // For memrefs, a dimension is truly dropped if its corresponding stride is
  // also dropped. This is particularly important when more than one of the
  // dims is 1. Track the number of occurrences of the strides in the original
  // type and the candidate type. For each unused dim that stride should not be
  // present in the candidate type. Note that there could be multiple
  // dimensions that have the same size. We don't need to exactly figure out
  // which dim corresponds to which stride, we just need to verify that the
  // number of repetitions of a stride in the original + number of unused dims
  // with that stride == number of repetitions of a stride in the candidate.
  std::map<int64_t, unsigned> currUnaccountedStrides =
      getNumOccurences(originalStrides);
  std::map<int64_t, unsigned> candidateStridesNumOccurences =
      getNumOccurences(candidateStrides);
  llvm::SmallDenseSet<unsigned> prunedUnusedDims;
  for (unsigned dim : unusedDims) {
    int64_t originalStride = originalStrides[dim];
    if (currUnaccountedStrides[originalStride] >
        candidateStridesNumOccurences[originalStride]) {
      // This dim can be treated as dropped.
      currUnaccountedStrides[originalStride]--;
      continue;
    }
    if (currUnaccountedStrides[originalStride] ==
        candidateStridesNumOccurences[originalStride]) {
      // The stride for this is not dropped. Keep as is.
      prunedUnusedDims.insert(dim);
      continue;
    }
    if (currUnaccountedStrides[originalStride] <
        candidateStridesNumOccurences[originalStride]) {
      // This should never happen. Can't have a stride in the reduced rank type
      // that wasn't in the original one.
      return llvm::None;
    }
  }

  for (auto prunedDim : prunedUnusedDims)
    unusedDims.erase(prunedDim);
  if (unusedDims.size() + reducedType.getRank() != originalType.getRank())
    return llvm::None;
  return unusedDims;
}

llvm::SmallDenseSet<unsigned> SubViewOp::getDroppedDims() {
  MemRefType sourceType = getSourceType();
  MemRefType resultType = getType();
  llvm::Optional<llvm::SmallDenseSet<unsigned>> unusedDims =
      computeMemRefRankReductionMask(sourceType, resultType, static_sizes());
  assert(unusedDims && "unable to find unused dims of subview");
  return *unusedDims;
}

OpFoldResult DimOp::fold(ArrayRef<Attribute> operands) {
  // All forms of folding require a known index.
  auto index = operands[1].dyn_cast_or_null<IntegerAttr>();
  if (!index)
    return {};

  // Folding for unranked types (UnrankedMemRefType) is not supported.
  auto memrefType = source().getType().dyn_cast<MemRefType>();
  if (!memrefType)
    return {};

  // Fold if the shape extent along the given index is known.
  if (!memrefType.isDynamicDim(index.getInt())) {
    Builder builder(getContext());
    return builder.getIndexAttr(memrefType.getShape()[index.getInt()]);
  }

  // The size at the given index is now known to be a dynamic size.
  unsigned unsignedIndex = index.getValue().getZExtValue();

  // Fold dim to the size argument for an `AllocOp`, `ViewOp`, or `SubViewOp`.
  Operation *definingOp = source().getDefiningOp();

  if (auto alloc = dyn_cast_or_null<AllocOp>(definingOp))
    return *(alloc.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto alloca = dyn_cast_or_null<AllocaOp>(definingOp))
    return *(alloca.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto view = dyn_cast_or_null<ViewOp>(definingOp))
    return *(view.getDynamicSizes().begin() +
             memrefType.getDynamicDimIndex(unsignedIndex));

  if (auto subview = dyn_cast_or_null<SubViewOp>(definingOp)) {
    llvm::SmallDenseSet<unsigned> unusedDims = subview.getDroppedDims();
    unsigned resultIndex = 0;
    unsigned sourceRank = subview.getSourceType().getRank();
    unsigned sourceIndex = 0;
    for (auto i : llvm::seq<unsigned>(0, sourceRank)) {
      if (unusedDims.count(i))
        continue;
      if (resultIndex == unsignedIndex) {
        sourceIndex = i;
        break;
      }
      resultIndex++;
    }
    assert(subview.isDynamicSize(sourceIndex) &&
           "expected dynamic subview size");
    return subview.getDynamicSize(sourceIndex);
  }

  if (auto sizeInterface =
          dyn_cast_or_null<OffsetSizeAndStrideOpInterface>(definingOp)) {
    assert(sizeInterface.isDynamicSize(unsignedIndex) &&
           "Expected dynamic subview size");
    return sizeInterface.getDynamicSize(unsignedIndex);
  }

  // dim(memrefcast) -> dim
  if (succeeded(foldMemRefCast(*this)))
    return getResult();

  return {};
}

namespace {
/// Fold dim of a memref reshape operation to a load into the reshape's shape
/// operand.
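///
/// A sketch of the rewrite (illustrative):
///
/// ```mlir
/// %r = memref.reshape %src(%shape)
///        : (memref<?xf32>, memref<2xindex>) -> memref<?x?xf32>
/// %d = memref.dim %r, %c1 : memref<?x?xf32>
/// ```
///
/// becomes a load from the shape operand:
///
/// ```mlir
/// %d = memref.load %shape[%c1] : memref<2xindex>
/// ```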
struct DimOfMemRefReshape : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dim,
                                PatternRewriter &rewriter) const override {
    auto reshape = dim.source().getDefiningOp<ReshapeOp>();

    if (!reshape)
      return failure();

    // Place the load directly after the reshape to ensure that the shape
    // memref was not mutated.
    rewriter.setInsertionPointAfter(reshape);
    Location loc = dim.getLoc();
    Value load = rewriter.create<LoadOp>(loc, reshape.shape(), dim.index());
    if (load.getType() != dim.getType())
      load = rewriter.create<IndexCastOp>(loc, dim.getType(), load);
    rewriter.replaceOp(dim, load);
    return success();
  }
};

/// Fold dim of a cast into the dim of the source of the memref cast.
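///
/// For example (illustrative), with `%t` the tensor that was buffer-cast:
///
/// ```mlir
/// %m = memref.buffer_cast %t : memref<?xf32>
/// %d = memref.dim %m, %c0 : memref<?xf32>
/// ```
///
/// folds to a dim on the original tensor:
///
/// ```mlir
/// %d = tensor.dim %t, %c0 : tensor<?xf32>
/// ```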
struct DimOfCastOp : public OpRewritePattern<DimOp> {
  using OpRewritePattern<DimOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(DimOp dimOp,
                                PatternRewriter &rewriter) const override {
    auto castOp = dimOp.source().getDefiningOp<BufferCastOp>();
    if (!castOp)
      return failure();
    Value newSource = castOp.getOperand();
    rewriter.replaceOpWithNewOp<tensor::DimOp>(dimOp, newSource, dimOp.index());
    return success();
  }
};
} // end anonymous namespace.

void DimOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                        MLIRContext *context) {
  results.add<DimOfMemRefReshape, DimOfCastOp>(context);
}

// ---------------------------------------------------------------------------
// DmaStartOp
// ---------------------------------------------------------------------------

void DmaStartOp::build(OpBuilder &builder, OperationState &result,
                       Value srcMemRef, ValueRange srcIndices, Value destMemRef,
                       ValueRange destIndices, Value numElements,
                       Value tagMemRef, ValueRange tagIndices, Value stride,
                       Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addOperands(destIndices);
  result.addOperands({numElements, tagMemRef});
  result.addOperands(tagIndices);
  if (stride)
    result.addOperands({stride, elementsPerStride});
}

static void print(OpAsmPrinter &p, DmaStartOp op) {
  p << " " << op.getSrcMemRef() << '[' << op.getSrcIndices() << "], "
    << op.getDstMemRef() << '[' << op.getDstIndices() << "], "
    << op.getNumElements() << ", " << op.getTagMemRef() << '['
    << op.getTagIndices() << ']';
  if (op.isStrided())
    p << ", " << op.getStride() << ", " << op.getNumElementsPerStride();

  p.printOptionalAttrDict(op->getAttrs());
  p << " : " << op.getSrcMemRef().getType() << ", "
    << op.getDstMemRef().getType() << ", " << op.getTagMemRef().getType();
}

// Parse DmaStartOp.
// Ex:
//   %dma_id = dma_start %src[%i, %j], %dst[%k, %l], %size,
//                       %tag[%index], %stride, %num_elt_per_stride :
//                     : memref<3076 x f32, 0>,
//                       memref<1024 x f32, 2>,
//                       memref<1 x i32>
//
static ParseResult parseDmaStartOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> srcIndexInfos;
  OpAsmParser::OperandType dstMemRefInfo;
  SmallVector<OpAsmParser::OperandType, 4> dstIndexInfos;
  OpAsmParser::OperandType numElementsInfo;
  OpAsmParser::OperandType tagMemrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> tagIndexInfos;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) source memref followed by its indices (in square brackets).
  // *) destination memref followed by its indices (in square brackets).
  // *) dma size (number of elements to transfer).
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseOperandList(srcIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseOperandList(dstIndexInfos, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseComma() || parser.parseOperand(tagMemrefInfo) ||
      parser.parseOperandList(tagIndexInfos, OpAsmParser::Delimiter::Square))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo))
    return failure();

  bool isStrided = strideInfo.size() == 2;
  if (!strideInfo.empty() && !isStrided) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }

  if (parser.parseColonTypeList(types))
    return failure();
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "fewer/more types expected");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcIndexInfos, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstIndexInfos, indexType, result.operands) ||
      // size should be an index.
      parser.resolveOperand(numElementsInfo, indexType, result.operands) ||
      parser.resolveOperand(tagMemrefInfo, types[2], result.operands) ||
      // tag indices should be index.
      parser.resolveOperands(tagIndexInfos, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  return success();
}

static LogicalResult verify(DmaStartOp op) {
  unsigned numOperands = op.getNumOperands();

  // Mandatory non-variadic operands are: src memref, dst memref, tag memref
  // and the number of elements.
  if (numOperands < 4)
    return op.emitOpError("expected at least 4 operands");

  // Check types of operands. The order of these calls is important: the later
  // calls rely on some type properties to compute the operand position.
  // 1. Source memref.
  if (!op.getSrcMemRef().getType().isa<MemRefType>())
    return op.emitOpError("expected source to be of memref type");
  if (numOperands < op.getSrcMemRefRank() + 4)
    return op.emitOpError()
           << "expected at least " << op.getSrcMemRefRank() + 4 << " operands";
  if (!op.getSrcIndices().empty() &&
      !llvm::all_of(op.getSrcIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return op.emitOpError("expected source indices to be of index type");

  // 2. Destination memref.
  if (!op.getDstMemRef().getType().isa<MemRefType>())
    return op.emitOpError("expected destination to be of memref type");
  unsigned numExpectedOperands =
      op.getSrcMemRefRank() + op.getDstMemRefRank() + 4;
  if (numOperands < numExpectedOperands)
    return op.emitOpError()
           << "expected at least " << numExpectedOperands << " operands";
  if (!op.getDstIndices().empty() &&
      !llvm::all_of(op.getDstIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return op.emitOpError("expected destination indices to be of index type");

  // 3. Number of elements.
  if (!op.getNumElements().getType().isIndex())
    return op.emitOpError("expected num elements to be of index type");

  // 4. Tag memref.
  if (!op.getTagMemRef().getType().isa<MemRefType>())
    return op.emitOpError("expected tag to be of memref type");
  numExpectedOperands += op.getTagMemRefRank();
  if (numOperands < numExpectedOperands)
    return op.emitOpError()
           << "expected at least " << numExpectedOperands << " operands";
  if (!op.getTagIndices().empty() &&
      !llvm::all_of(op.getTagIndices().getTypes(),
                    [](Type t) { return t.isIndex(); }))
    return op.emitOpError("expected tag indices to be of index type");

  // Optional stride-related operands must be either both present or both
  // absent.
  if (numOperands != numExpectedOperands &&
      numOperands != numExpectedOperands + 2)
    return op.emitOpError("incorrect number of operands");

  // 5. Strides.
  if (op.isStrided()) {
    if (!op.getStride().getType().isIndex() ||
        !op.getNumElementsPerStride().getType().isIndex())
      return op.emitOpError(
          "expected stride and num elements per stride to be of type index");
  }

  return success();
}

LogicalResult DmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

// ---------------------------------------------------------------------------
// DmaWaitOp
// ---------------------------------------------------------------------------

LogicalResult DmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                              SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

static LogicalResult verify(DmaWaitOp op) {
  // Check that the number of tag indices matches the tagMemRef rank.
  unsigned numTagIndices = op.tagIndices().size();
  unsigned tagMemRefRank = op.getTagMemRefRank();
  if (numTagIndices != tagMemRefRank)
    return op.emitOpError() << "expected tagIndices to have the same number of "
                               "elements as the tagMemRef rank, expected "
                            << tagMemRefRank << ", but got " << numTagIndices;
  return success();
}

//===----------------------------------------------------------------------===//
// GlobalOp
//===----------------------------------------------------------------------===//

static void printGlobalMemrefOpTypeAndInitialValue(OpAsmPrinter &p, GlobalOp op,
                                                   TypeAttr type,
                                                   Attribute initialValue) {
  p << type;
  if (!op.isExternal()) {
    p << " = ";
    if (op.isUninitialized())
      p << "uninitialized";
    else
      p.printAttributeWithoutType(initialValue);
  }
}

static ParseResult
parseGlobalMemrefOpTypeAndInitialValue(OpAsmParser &parser, TypeAttr &typeAttr,
                                       Attribute &initialValue) {
  Type type;
  if (parser.parseType(type))
    return failure();

  auto memrefType = type.dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return parser.emitError(parser.getNameLoc())
           << "type should be static shaped memref, but got " << type;
  typeAttr = TypeAttr::get(type);

  if (parser.parseOptionalEqual())
    return success();

  if (succeeded(parser.parseOptionalKeyword("uninitialized"))) {
    initialValue = UnitAttr::get(parser.getContext());
    return success();
  }

  Type tensorType = getTensorTypeFromMemRefType(memrefType);
  if (parser.parseAttribute(initialValue, tensorType))
    return failure();
  if (!initialValue.isa<ElementsAttr>())
    return parser.emitError(parser.getNameLoc())
           << "initial value should be a unit or elements attribute";
  return success();
}

static LogicalResult verify(GlobalOp op) {
  auto memrefType = op.type().dyn_cast<MemRefType>();
  if (!memrefType || !memrefType.hasStaticShape())
    return op.emitOpError("type should be static shaped memref, but got ")
           << op.type();

  // Verify that the initial value, if present, is either a unit attribute or
  // an elements attribute.
  if (op.initial_value().hasValue()) {
    Attribute initValue = op.initial_value().getValue();
    if (!initValue.isa<UnitAttr>() && !initValue.isa<ElementsAttr>())
      return op.emitOpError("initial value should be a unit or elements "
                            "attribute, but got ")
             << initValue;

    // Check that the type of the initial value is compatible with the type of
    // the global variable.
    if (initValue.isa<ElementsAttr>()) {
      Type initType = initValue.getType();
      Type tensorType = getTensorTypeFromMemRefType(memrefType);
      if (initType != tensorType)
        return op.emitOpError("initial value expected to be of type ")
               << tensorType << ", but was of type " << initType;
    }
  }

  if (Optional<uint64_t> alignAttr = op.alignment()) {
    uint64_t alignment = alignAttr.getValue();

    if (!llvm::isPowerOf2_64(alignment))
      return op->emitError() << "alignment attribute value " << alignment
                             << " is not a power of 2";
  }

  // TODO: verify visibility for declarations.
  return success();
}

//===----------------------------------------------------------------------===//
// GetGlobalOp
//===----------------------------------------------------------------------===//

LogicalResult
GetGlobalOp::verifySymbolUses(SymbolTableCollection &symbolTable) {
  // Verify that the result type is same as the type of the referenced
  // memref.global op.
  auto global =
      symbolTable.lookupNearestSymbolFrom<GlobalOp>(*this, nameAttr());
  if (!global)
    return emitOpError("'")
           << name() << "' does not reference a valid global memref";

  Type resultType = result().getType();
  if (global.type() != resultType)
    return emitOpError("result type ")
           << resultType << " does not match type " << global.type()
           << " of the global memref @" << name();
  return success();
}

//===----------------------------------------------------------------------===//
// LoadOp
//===----------------------------------------------------------------------===//

static LogicalResult verify(LoadOp op) {
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("incorrect number of indices for load");
  return success();
}

OpFoldResult LoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

namespace {
/// Fold a load on a buffer_cast operation into a tensor.extract on the
/// corresponding tensor.
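///
/// A sketch of the rewrite (illustrative):
///
/// ```mlir
/// %m = memref.buffer_cast %t : memref<?xf32>
/// %v = memref.load %m[%i] : memref<?xf32>
/// ```
///
/// becomes:
///
/// ```mlir
/// %v = tensor.extract %t[%i] : tensor<?xf32>
/// ```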
struct LoadOfBufferCast : public OpRewritePattern<LoadOp> {
  using OpRewritePattern<LoadOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(LoadOp load,
                                PatternRewriter &rewriter) const override {
    auto buffercast = load.memref().getDefiningOp<BufferCastOp>();
    if (!buffercast)
      return failure();

    rewriter.replaceOpWithNewOp<tensor::ExtractOp>(load, buffercast.tensor(),
                                                   load.indices());
    return success();
  }
};
} // end anonymous namespace.

void LoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
                                         MLIRContext *context) {
  results.add<LoadOfBufferCast>(context);
}

//===----------------------------------------------------------------------===//
// PrefetchOp
//===----------------------------------------------------------------------===//

static void print(OpAsmPrinter &p, PrefetchOp op) {
  p << " " << op.memref() << '[';
  p.printOperands(op.indices());
  p << ']' << ", " << (op.isWrite() ? "write" : "read");
  p << ", locality<" << op.localityHint();
  p << ">, " << (op.isDataCache() ? "data" : "instr");
  p.printOptionalAttrDict(
      op->getAttrs(),
      /*elidedAttrs=*/{"localityHint", "isWrite", "isDataCache"});
  p << " : " << op.getMemRefType();
}

static ParseResult parsePrefetchOp(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType memrefInfo;
  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
  IntegerAttr localityHint;
  MemRefType type;
  StringRef readOrWrite, cacheType;

  auto indexTy = parser.getBuilder().getIndexType();
  auto i32Type = parser.getBuilder().getIntegerType(32);
  if (parser.parseOperand(memrefInfo) ||
      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(localityHint, i32Type, "localityHint",
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) || parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(indexInfo, indexTy, result.operands))
    return failure();

  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      PrefetchOp::getIsWriteAttrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      PrefetchOp::getIsDataCacheAttrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}

static LogicalResult verify(PrefetchOp op) {
  if (op.getNumOperands() != 1 + op.getMemRefType().getRank())
    return op.emitOpError("too few indices");

  return success();
}

LogicalResult PrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                               SmallVectorImpl<OpFoldResult> &results) {
  // prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// ReinterpretCastOp
//===----------------------------------------------------------------------===//

/// Build a ReinterpretCastOp with all dynamic entries: `staticOffsets`,
/// `staticSizes` and `staticStrides` are automatically filled with
/// source-memref-rank sentinel values that encode dynamic entries.
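///
/// For example (illustrative), building with a static offset of 0, one dynamic
/// size `%n`, and a static stride of 1 produces IR along the lines of:
///
/// ```mlir
/// %0 = memref.reinterpret_cast %src to offset: [0], sizes: [%n], strides: [1]
///        : memref<?xf32> to memref<?xf32>
/// ```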
build(OpBuilder & b,OperationState & result,MemRefType resultType,Value source,OpFoldResult offset,ArrayRef<OpFoldResult> sizes,ArrayRef<OpFoldResult> strides,ArrayRef<NamedAttribute> attrs)1331 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1332 MemRefType resultType, Value source,
1333 OpFoldResult offset, ArrayRef<OpFoldResult> sizes,
1334 ArrayRef<OpFoldResult> strides,
1335 ArrayRef<NamedAttribute> attrs) {
1336 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1337 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1338 dispatchIndexOpFoldResults(offset, dynamicOffsets, staticOffsets,
1339 ShapedType::kDynamicStrideOrOffset);
1340 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1341 ShapedType::kDynamicSize);
1342 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1343 ShapedType::kDynamicStrideOrOffset);
1344 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1345 dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1346 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1347 result.addAttributes(attrs);
1348 }
1349
build(OpBuilder & b,OperationState & result,MemRefType resultType,Value source,int64_t offset,ArrayRef<int64_t> sizes,ArrayRef<int64_t> strides,ArrayRef<NamedAttribute> attrs)1350 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1351 MemRefType resultType, Value source,
1352 int64_t offset, ArrayRef<int64_t> sizes,
1353 ArrayRef<int64_t> strides,
1354 ArrayRef<NamedAttribute> attrs) {
1355 SmallVector<OpFoldResult> sizeValues =
1356 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1357 return b.getI64IntegerAttr(v);
1358 }));
1359 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1360 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1361 return b.getI64IntegerAttr(v);
1362 }));
1363 build(b, result, resultType, source, b.getI64IntegerAttr(offset), sizeValues,
1364 strideValues, attrs);
1365 }
1366
build(OpBuilder & b,OperationState & result,MemRefType resultType,Value source,Value offset,ValueRange sizes,ValueRange strides,ArrayRef<NamedAttribute> attrs)1367 void ReinterpretCastOp::build(OpBuilder &b, OperationState &result,
1368 MemRefType resultType, Value source, Value offset,
1369 ValueRange sizes, ValueRange strides,
1370 ArrayRef<NamedAttribute> attrs) {
1371 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1372 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1373 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1374 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1375 build(b, result, resultType, source, offset, sizeValues, strideValues, attrs);
1376 }
1377
1378 // TODO: ponder whether we want to allow missing trailing sizes/strides that are
1379 // completed automatically, like we have for subview and extract_slice.
1380 static LogicalResult verify(ReinterpretCastOp op) {
1381 // The source and result memrefs should be in the same memory space.
1382 auto srcType = op.source().getType().cast<BaseMemRefType>();
1383 auto resultType = op.getType().cast<MemRefType>();
1384 if (srcType.getMemorySpace() != resultType.getMemorySpace())
1385 return op.emitError("different memory spaces specified for source type ")
1386 << srcType << " and result memref type " << resultType;
1387 if (srcType.getElementType() != resultType.getElementType())
1388 return op.emitError("different element types specified for source type ")
1389 << srcType << " and result memref type " << resultType;
1390
1391 // Match sizes in result memref type and in static_sizes attribute.
1392 for (auto &en :
1393 llvm::enumerate(llvm::zip(resultType.getShape(),
1394 extractFromI64ArrayAttr(op.static_sizes())))) {
1395 int64_t resultSize = std::get<0>(en.value());
1396 int64_t expectedSize = std::get<1>(en.value());
1397 if (resultSize != expectedSize)
1398 return op.emitError("expected result type with size = ")
1399 << expectedSize << " instead of " << resultSize
1400 << " in dim = " << en.index();
1401 }
1402
1403 // Match offset and strides in static_offset and static_strides attributes if
1404 // result memref type has an affine map specified.
1405 if (!resultType.getAffineMaps().empty()) {
1406 int64_t resultOffset;
1407 SmallVector<int64_t, 4> resultStrides;
1408 if (failed(getStridesAndOffset(resultType, resultStrides, resultOffset)))
1409 return failure();
1410
1411 // Match offset in result memref type and in static_offsets attribute.
1412 int64_t expectedOffset =
1413 extractFromI64ArrayAttr(op.static_offsets()).front();
1414 if (resultOffset != expectedOffset)
1415 return op.emitError("expected result type with offset = ")
1416 << expectedOffset << " instead of " << resultOffset;
1417
1418 // Match strides in result memref type and in static_strides attribute.
1419 for (auto &en : llvm::enumerate(llvm::zip(
1420 resultStrides, extractFromI64ArrayAttr(op.static_strides())))) {
1421 int64_t resultStride = std::get<0>(en.value());
1422 int64_t expectedStride = std::get<1>(en.value());
1423 if (resultStride != expectedStride)
1424 return op.emitError("expected result type with stride = ")
1425 << expectedStride << " instead of " << resultStride
1426 << " in dim = " << en.index();
1427 }
1428 }
1429 return success();
1430 }
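// As a sketch of what the verifier above accepts (names and values made up),
// the static offset, sizes and strides must all reappear in the result type:
//
//   %dst = memref.reinterpret_cast %src to
//       offset: [10], sizes: [4, 4], strides: [16, 1]
//       : memref<?x?xf32> to memref<4x4xf32, offset: 10, strides: [16, 1]>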
1431
1432 //===----------------------------------------------------------------------===//
1433 // Reassociative reshape ops
1434 //===----------------------------------------------------------------------===//
1435
1436 SmallVector<AffineMap, 4> CollapseShapeOp::getReassociationMaps() {
1437 return getSymbolLessAffineMaps(getReassociationExprs());
1438 }
1439 SmallVector<ReassociationExprs, 4> CollapseShapeOp::getReassociationExprs() {
1440 return convertReassociationIndicesToExprs(getContext(),
1441 getReassociationIndices());
1442 }
1443
1444 SmallVector<AffineMap, 4> ExpandShapeOp::getReassociationMaps() {
1445 return getSymbolLessAffineMaps(getReassociationExprs());
1446 }
1447 SmallVector<ReassociationExprs, 4> ExpandShapeOp::getReassociationExprs() {
1448 return convertReassociationIndicesToExprs(getContext(),
1449 getReassociationIndices());
1450 }
1451
1452 static void print(OpAsmPrinter &p, ExpandShapeOp op) {
1453 ::mlir::printReshapeOp<ExpandShapeOp>(p, op);
1454 }
1455
1456 static void print(OpAsmPrinter &p, CollapseShapeOp op) {
1457 ::mlir::printReshapeOp<CollapseShapeOp>(p, op);
1458 }
1459
1460 /// Detect whether memref dims [dim, dim + extent) can be reshaped without
1461 /// copies.
1462 static bool isReshapableDimBand(unsigned dim, unsigned extent,
1463 ArrayRef<int64_t> sizes,
1464 ArrayRef<AffineExpr> strides) {
1465 assert(sizes.size() == strides.size() && "mismatched ranks");
1466 // Off-by-one indexing (idx + 1 below) avoids reading past the end of the band.
1468 for (auto idx = dim, e = dim + extent; idx + 1 < e; ++idx) {
1469 // Only bands of static shapes are reshapable. This is due to the fact that
1470 // there is no relation between dynamic sizes and dynamic strides: we do not
1471 // have enough information to know whether a "-1" size corresponds to the
1472 // proper symbol in the AffineExpr of a stride.
1473 if (ShapedType::isDynamic(sizes[dim + 1]))
1474 return false;
1475 // TODO: Refine this by passing the proper nDims and nSymbols so we can
1476 // simplify on the fly and catch more reshapable cases.
1477 if (strides[idx] != strides[idx + 1] * sizes[idx + 1])
1478 return false;
1479 }
1480 return true;
1481 }
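// Worked example (illustrative): for sizes = [4, 8, 16] with row-major
// strides = [128, 16, 1], the band [dim = 1, extent = 2] is reshapable since
// strides[1] == strides[2] * sizes[2] (16 == 1 * 16). With non-contiguous
// strides = [256, 16, 1], the band [dim = 0, extent = 2] is not, because
// 256 != 16 * 8.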
1482
1483 /// Compute the MemRefType obtained by applying the `reassociation` (which is
1484 /// expected to be valid) to `type`.
1485 /// If `type` is a contiguous MemRefType, this always produces a contiguous
1486 /// MemRefType.
1487 static MemRefType
1488 computeReshapeCollapsedType(MemRefType type,
1489 ArrayRef<AffineMap> reassociation) {
1490 auto sizes = type.getShape();
1491 AffineExpr offset;
1492 SmallVector<AffineExpr, 4> strides;
1493 auto status = getStridesAndOffset(type, strides, offset);
1494 (void)status;
1495 assert(succeeded(status) && "expected strided memref");
1496
1497 SmallVector<int64_t, 4> newSizes;
1498 newSizes.reserve(reassociation.size());
1499 SmallVector<AffineExpr, 4> newStrides;
1500 newStrides.reserve(reassociation.size());
1501
1502 // Use the fact that reassociation is valid to simplify the logic: only use
1503 // each map's rank.
1504 assert(isReassociationValid(reassociation) && "invalid reassociation");
1505 unsigned currentDim = 0;
1506 for (AffineMap m : reassociation) {
1507 unsigned dim = m.getNumResults();
1508 int64_t size = 1;
1509 AffineExpr stride = strides[currentDim + dim - 1];
1510 if (!isReshapableDimBand(currentDim, dim, sizes, strides)) {
1511 size = ShapedType::kDynamicSize;
1512 stride = AffineExpr();
1513 } else {
1514 for (unsigned d = 0; d < dim; ++d)
1515 size *= sizes[currentDim + d];
1516 }
1517 newSizes.push_back(size);
1518 newStrides.push_back(stride);
1519 currentDim += dim;
1520 }
1521
1522 // Early-exit: if `type` is contiguous, the result must be contiguous.
1523 if (canonicalizeStridedLayout(type).getAffineMaps().empty())
1524 return MemRefType::Builder(type).setShape(newSizes).setAffineMaps({});
1525
1526 // Convert back to int64_t because we don't have enough information to create
1527 // new strided layouts from AffineExpr only. This corresponds to a case where
1528 // copies may be necessary.
1529 int64_t intOffset = ShapedType::kDynamicStrideOrOffset;
1530 if (auto o = offset.dyn_cast<AffineConstantExpr>())
1531 intOffset = o.getValue();
1532 SmallVector<int64_t, 4> intStrides;
1533 intStrides.reserve(strides.size());
1534 for (auto stride : newStrides) {
1535 if (auto cst = stride.dyn_cast_or_null<AffineConstantExpr>())
1536 intStrides.push_back(cst.getValue());
1537 else
1538 intStrides.push_back(ShapedType::kDynamicStrideOrOffset);
1539 }
1540 auto layout =
1541 makeStridedLinearLayoutMap(intStrides, intOffset, type.getContext());
1542 return canonicalizeStridedLayout(
1543 MemRefType::Builder(type).setShape(newSizes).setAffineMaps({layout}));
1544 }
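// For example (illustrative), collapsing a contiguous memref<4x8x16xf32> with
// reassociation [[0, 1], [2]] yields memref<32x16xf32>; when a band is not
// contiguous, the corresponding result size becomes dynamic and a strided
// layout is reconstructed from the constant strides that remain known.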
1545
1546 void ExpandShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1547 ArrayRef<ReassociationIndices> reassociation,
1548 ArrayRef<NamedAttribute> attrs) {
1549 auto memRefType = src.getType().cast<MemRefType>();
1550 auto resultType = computeReshapeCollapsedType(
1551 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1552 b.getContext(), reassociation)));
1553 build(b, result, resultType, src, attrs);
1554 result.addAttribute(getReassociationAttrName(),
1555 getReassociationIndicesAttribute(b, reassociation));
1556 }
1557
1558 void CollapseShapeOp::build(OpBuilder &b, OperationState &result, Value src,
1559 ArrayRef<ReassociationIndices> reassociation,
1560 ArrayRef<NamedAttribute> attrs) {
1561 auto memRefType = src.getType().cast<MemRefType>();
1562 auto resultType = computeReshapeCollapsedType(
1563 memRefType, getSymbolLessAffineMaps(convertReassociationIndicesToExprs(
1564 b.getContext(), reassociation)));
1565 build(b, result, resultType, src, attrs);
1566 result.addAttribute(getReassociationAttrName(),
1567 getReassociationIndicesAttribute(b, reassociation));
1568 }
1569
1570 template <typename ReshapeOp,
1571 bool isExpansion = std::is_same<ReshapeOp, ExpandShapeOp>::value>
1572 static LogicalResult verifyReshapeOp(ReshapeOp op, MemRefType expandedType,
1573 MemRefType collapsedType) {
1574 if (failed(
1575 verifyReshapeLikeTypes(op, expandedType, collapsedType, isExpansion)))
1576 return failure();
1577 auto maps = op.getReassociationMaps();
1578 MemRefType expectedType = computeReshapeCollapsedType(expandedType, maps);
1579 if (collapsedType != expectedType)
1580 return op.emitOpError("expected collapsed type to be ")
1581 << expectedType << ", but got " << collapsedType;
1582 return success();
1583 }
1584
1585 static LogicalResult verify(ExpandShapeOp op) {
1586 return verifyReshapeOp(op, op.getResultType(), op.getSrcType());
1587 }
1588
1589 void ExpandShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1590 MLIRContext *context) {
1591 results.add<CollapseReshapeOps<ExpandShapeOp>,
1592 CollapseMixedReshapeOps<ExpandShapeOp, CollapseShapeOp>>(context);
1593 }
1594
1595 static LogicalResult verify(CollapseShapeOp op) {
1596 return verifyReshapeOp(op, op.getSrcType(), op.getResultType());
1597 }
1598
1599 struct CollapseShapeOpMemRefCastFolder
1600 : public OpRewritePattern<CollapseShapeOp> {
1601 public:
1602 using OpRewritePattern<CollapseShapeOp>::OpRewritePattern;
1603
1604 LogicalResult matchAndRewrite(CollapseShapeOp op,
1605 PatternRewriter &rewriter) const override {
1606 auto cast = op.getOperand().getDefiningOp<CastOp>();
1607 if (!cast)
1608 return failure();
1609
1610 if (!CastOp::canFoldIntoConsumerOp(cast))
1611 return failure();
1612
1613 Type newResultType = computeReshapeCollapsedType(
1614 cast.getOperand().getType().cast<MemRefType>(),
1615 op.getReassociationMaps());
1616
1617 if (newResultType == op.getResultType()) {
1618 rewriter.updateRootInPlace(
1619 op, [&]() { op.srcMutable().assign(cast.source()); });
1620 } else {
1621 Value newOp = rewriter.create<CollapseShapeOp>(
1622 op->getLoc(), cast.source(), op.getReassociationIndices());
1623 rewriter.replaceOpWithNewOp<CastOp>(op, op.getType(), newOp);
1624 }
1625 return success();
1626 }
1627 };
1628
1629 void CollapseShapeOp::getCanonicalizationPatterns(RewritePatternSet &results,
1630 MLIRContext *context) {
1631 results.add<CollapseReshapeOps<CollapseShapeOp>,
1632 CollapseMixedReshapeOps<CollapseShapeOp, ExpandShapeOp>,
1633 CollapseShapeOpMemRefCastFolder>(context);
1634 }
1635 OpFoldResult ExpandShapeOp::fold(ArrayRef<Attribute> operands) {
1636 if (succeeded(foldMemRefCast(*this)))
1637 return getResult();
1638 return foldReshapeOp<ExpandShapeOp, CollapseShapeOp>(*this, operands);
1639 }
1640 OpFoldResult CollapseShapeOp::fold(ArrayRef<Attribute> operands) {
1641 return foldReshapeOp<CollapseShapeOp, ExpandShapeOp>(*this, operands);
1642 }
1643
1644 //===----------------------------------------------------------------------===//
1645 // ReshapeOp
1646 //===----------------------------------------------------------------------===//
1647
1648 static LogicalResult verify(ReshapeOp op) {
1649 Type operandType = op.source().getType();
1650 Type resultType = op.result().getType();
1651
1652 Type operandElementType = operandType.cast<ShapedType>().getElementType();
1653 Type resultElementType = resultType.cast<ShapedType>().getElementType();
1654 if (operandElementType != resultElementType)
1655 return op.emitOpError("element types of source and destination memref "
1656 "types should be the same");
1657
1658 if (auto operandMemRefType = operandType.dyn_cast<MemRefType>())
1659 if (!operandMemRefType.getAffineMaps().empty())
1660 return op.emitOpError(
1661 "source memref type should have identity affine map");
1662
1663 int64_t shapeSize = op.shape().getType().cast<MemRefType>().getDimSize(0);
1664 auto resultMemRefType = resultType.dyn_cast<MemRefType>();
1665 if (resultMemRefType) {
1666 if (!resultMemRefType.getAffineMaps().empty())
1667 return op.emitOpError(
1668 "result memref type should have identity affine map");
1669 if (shapeSize == ShapedType::kDynamicSize)
1670 return op.emitOpError("cannot use shape operand with dynamic length to "
1671 "reshape to statically-ranked memref type");
1672 if (shapeSize != resultMemRefType.getRank())
1673 return op.emitOpError(
1674 "length of shape operand differs from the result's memref rank");
1675 }
1676 return success();
1677 }
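// A well-formed example of the op verified above (names made up):
//
//   %shape = memref.alloc() : memref<2xindex>
//   %r = memref.reshape %src(%shape)
//       : (memref<4x5xf32>, memref<2xindex>) -> memref<10x2xf32>
//
// The shape operand has static length 2, matching the result rank, and both
// memref types use the identity layout.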
1678
1679 //===----------------------------------------------------------------------===//
1680 // StoreOp
1681 //===----------------------------------------------------------------------===//
1682
1683 static LogicalResult verify(StoreOp op) {
1684 if (op.getNumOperands() != 2 + op.getMemRefType().getRank())
1685 return op.emitOpError("store index operand count not equal to memref rank");
1686
1687 return success();
1688 }
1689
1690 LogicalResult StoreOp::fold(ArrayRef<Attribute> cstOperands,
1691 SmallVectorImpl<OpFoldResult> &results) {
1692 /// store(memrefcast) -> store
1693 return foldMemRefCast(*this, getValueToStore());
1694 }
1695
1696 //===----------------------------------------------------------------------===//
1697 // SubViewOp
1698 //===----------------------------------------------------------------------===//
1699
1700 namespace {
1701 /// Helpers to write more idiomatic operations.
1702 namespace saturated_arith {
1703 struct Wrapper {
1704 explicit Wrapper(int64_t v) : v(v) {}
1705 operator int64_t() { return v; }
1706 int64_t v;
1707 };
1708 Wrapper operator+(Wrapper a, int64_t b) {
1709 if (ShapedType::isDynamicStrideOrOffset(a) ||
1710 ShapedType::isDynamicStrideOrOffset(b))
1711 return Wrapper(ShapedType::kDynamicStrideOrOffset);
1712 return Wrapper(a.v + b);
1713 }
1714 Wrapper operator*(Wrapper a, int64_t b) {
1715 if (ShapedType::isDynamicStrideOrOffset(a) ||
1716 ShapedType::isDynamicStrideOrOffset(b))
1717 return Wrapper(ShapedType::kDynamicStrideOrOffset);
1718 return Wrapper(a.v * b);
1719 }
1720 } // end namespace saturated_arith
1721 } // end namespace
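// The wrapper simply propagates the dynamic sentinel through arithmetic, e.g.:
//   Wrapper(4) * 2                                      // 8
//   Wrapper(ShapedType::kDynamicStrideOrOffset) * 2     // still the sentinel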
1722
1723 /// A subview result type can be fully inferred from the source type and the
1724 /// static representation of offsets, sizes and strides. Special sentinels
1725 /// encode the dynamic case.
1726 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1727 ArrayRef<int64_t> leadingStaticOffsets,
1728 ArrayRef<int64_t> leadingStaticSizes,
1729 ArrayRef<int64_t> leadingStaticStrides) {
1730 // A subview may specify only a leading subset of offset/sizes/strides in
1731 // which case we complete with offset=0, sizes from memref type and strides=1.
1732 unsigned rank = sourceMemRefType.getRank();
1733 assert(leadingStaticOffsets.size() <= rank &&
1734 "unexpected leadingStaticOffsets overflow");
1735 assert(leadingStaticSizes.size() <= rank &&
1736 "unexpected leadingStaticSizes overflow");
1737 assert(leadingStaticStrides.size() <= rank &&
1738 "unexpected leadingStaticStrides overflow");
1739 auto staticOffsets = llvm::to_vector<4>(leadingStaticOffsets);
1740 auto staticSizes = llvm::to_vector<4>(leadingStaticSizes);
1741 auto staticStrides = llvm::to_vector<4>(leadingStaticStrides);
1742 unsigned numTrailingOffsets = rank - staticOffsets.size();
1743 unsigned numTrailingSizes = rank - staticSizes.size();
1744 unsigned numTrailingStrides = rank - staticStrides.size();
1745 staticOffsets.append(numTrailingOffsets, 0);
1746 llvm::append_range(staticSizes,
1747 sourceMemRefType.getShape().take_back(numTrailingSizes));
1748 staticStrides.append(numTrailingStrides, 1);
1749
1750 // Extract source offset and strides.
1751 int64_t sourceOffset;
1752 SmallVector<int64_t, 4> sourceStrides;
1753 auto res = getStridesAndOffset(sourceMemRefType, sourceStrides, sourceOffset);
1754 assert(succeeded(res) && "SubViewOp expected strided memref type");
1755 (void)res;
1756
1757 // Compute target offset whose value is:
1758 // `sourceOffset + sum_i(staticOffset_i * sourceStrides_i)`.
1759 int64_t targetOffset = sourceOffset;
1760 for (auto it : llvm::zip(staticOffsets, sourceStrides)) {
1761 auto staticOffset = std::get<0>(it), targetStride = std::get<1>(it);
1762 using namespace saturated_arith;
1763 targetOffset = Wrapper(targetOffset) + Wrapper(staticOffset) * targetStride;
1764 }
1765
1766 // Compute target stride whose value is:
1767 // `sourceStrides_i * staticStrides_i`.
1768 SmallVector<int64_t, 4> targetStrides;
1769 targetStrides.reserve(staticOffsets.size());
1770 for (auto it : llvm::zip(sourceStrides, staticStrides)) {
1771 auto sourceStride = std::get<0>(it), staticStride = std::get<1>(it);
1772 using namespace saturated_arith;
1773 targetStrides.push_back(Wrapper(sourceStride) * staticStride);
1774 }
1775
1776 // The type is now known.
1777 return MemRefType::get(
1778 staticSizes, sourceMemRefType.getElementType(),
1779 makeStridedLinearLayoutMap(targetStrides, targetOffset,
1780 sourceMemRefType.getContext()),
1781 sourceMemRefType.getMemorySpace());
1782 }
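// Worked example (illustrative): for a source memref<8x16x4xf32> with the
// identity layout (strides [64, 4, 1], offset 0), static offsets [3, 4, 2],
// sizes [4, 6, 3] and strides [1, 1, 1] give
//   memref<4x6x3xf32, offset: 210, strides: [64, 4, 1]>
// since 3 * 64 + 4 * 4 + 2 * 1 = 210 and each target stride is
// sourceStride * staticStride.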
1783
1784 Type SubViewOp::inferResultType(MemRefType sourceMemRefType,
1785 ArrayRef<OpFoldResult> leadingStaticOffsets,
1786 ArrayRef<OpFoldResult> leadingStaticSizes,
1787 ArrayRef<OpFoldResult> leadingStaticStrides) {
1788 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1789 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1790 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1791 staticOffsets, ShapedType::kDynamicStrideOrOffset);
1792 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1793 ShapedType::kDynamicSize);
1794 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1795 staticStrides, ShapedType::kDynamicStrideOrOffset);
1796 return SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1797 staticSizes, staticStrides)
1798 .cast<MemRefType>();
1799 }
1800
1801 Type SubViewOp::inferRankReducedResultType(
1802 unsigned resultRank, MemRefType sourceRankedTensorType,
1803 ArrayRef<int64_t> leadingStaticOffsets,
1804 ArrayRef<int64_t> leadingStaticSizes,
1805 ArrayRef<int64_t> leadingStaticStrides) {
1806 auto inferredType =
1807 inferResultType(sourceRankedTensorType, leadingStaticOffsets,
1808 leadingStaticSizes, leadingStaticStrides)
1809 .cast<MemRefType>();
1810 assert(inferredType.getRank() >= resultRank && "expected inferred rank >= result rank");
1811 int rankDiff = inferredType.getRank() - resultRank;
1812 if (rankDiff > 0) {
1813 auto shape = inferredType.getShape();
1814 llvm::SmallDenseSet<unsigned> dimsToProject;
1815 mlir::getPositionsOfShapeOne(rankDiff, shape, dimsToProject);
1816 SmallVector<int64_t> projectedShape;
1817 for (unsigned pos = 0, e = shape.size(); pos < e; ++pos)
1818 if (!dimsToProject.contains(pos))
1819 projectedShape.push_back(shape[pos]);
1820
1821 AffineMap map;
1822 auto maps = inferredType.getAffineMaps();
1823 if (!maps.empty() && maps.front())
1824 map = getProjectedMap(maps.front(), dimsToProject);
1825 inferredType =
1826 MemRefType::get(projectedShape, inferredType.getElementType(), map,
1827 inferredType.getMemorySpace());
1828 }
1829 return inferredType;
1830 }
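// For instance (illustrative), with an inferred type of memref<1x6x1xf32, ...>
// and resultRank == 1, the two unit dimensions are projected away and the
// layout map is projected accordingly, yielding a memref<6xf32, ...> result.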
1831
1832 Type SubViewOp::inferRankReducedResultType(
1833 unsigned resultRank, MemRefType sourceRankedTensorType,
1834 ArrayRef<OpFoldResult> leadingStaticOffsets,
1835 ArrayRef<OpFoldResult> leadingStaticSizes,
1836 ArrayRef<OpFoldResult> leadingStaticStrides) {
1837 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1838 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1839 dispatchIndexOpFoldResults(leadingStaticOffsets, dynamicOffsets,
1840 staticOffsets, ShapedType::kDynamicStrideOrOffset);
1841 dispatchIndexOpFoldResults(leadingStaticSizes, dynamicSizes, staticSizes,
1842 ShapedType::kDynamicSize);
1843 dispatchIndexOpFoldResults(leadingStaticStrides, dynamicStrides,
1844 staticStrides, ShapedType::kDynamicStrideOrOffset);
1845 return SubViewOp::inferRankReducedResultType(
1846 resultRank, sourceRankedTensorType, staticOffsets, staticSizes,
1847 staticStrides);
1848 }
1849 // Build a SubViewOp with mixed static and dynamic entries and custom result
1850 // type. If the type passed is nullptr, it is inferred.
1851 void SubViewOp::build(OpBuilder &b, OperationState &result,
1852 MemRefType resultType, Value source,
1853 ArrayRef<OpFoldResult> offsets,
1854 ArrayRef<OpFoldResult> sizes,
1855 ArrayRef<OpFoldResult> strides,
1856 ArrayRef<NamedAttribute> attrs) {
1857 SmallVector<int64_t> staticOffsets, staticSizes, staticStrides;
1858 SmallVector<Value> dynamicOffsets, dynamicSizes, dynamicStrides;
1859 dispatchIndexOpFoldResults(offsets, dynamicOffsets, staticOffsets,
1860 ShapedType::kDynamicStrideOrOffset);
1861 dispatchIndexOpFoldResults(sizes, dynamicSizes, staticSizes,
1862 ShapedType::kDynamicSize);
1863 dispatchIndexOpFoldResults(strides, dynamicStrides, staticStrides,
1864 ShapedType::kDynamicStrideOrOffset);
1865 auto sourceMemRefType = source.getType().cast<MemRefType>();
1866 // Structuring implementation this way avoids duplication between builders.
1867 if (!resultType) {
1868 resultType = SubViewOp::inferResultType(sourceMemRefType, staticOffsets,
1869 staticSizes, staticStrides)
1870 .cast<MemRefType>();
1871 }
1872 build(b, result, resultType, source, dynamicOffsets, dynamicSizes,
1873 dynamicStrides, b.getI64ArrayAttr(staticOffsets),
1874 b.getI64ArrayAttr(staticSizes), b.getI64ArrayAttr(staticStrides));
1875 result.addAttributes(attrs);
1876 }
1877
1878 // Build a SubViewOp with mixed static and dynamic entries and inferred result
1879 // type.
1880 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1881 ArrayRef<OpFoldResult> offsets,
1882 ArrayRef<OpFoldResult> sizes,
1883 ArrayRef<OpFoldResult> strides,
1884 ArrayRef<NamedAttribute> attrs) {
1885 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1886 }
1887
1888 // Build a SubViewOp with static entries and inferred result type.
1889 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1890 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1891 ArrayRef<int64_t> strides,
1892 ArrayRef<NamedAttribute> attrs) {
1893 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1894 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1895 return b.getI64IntegerAttr(v);
1896 }));
1897 SmallVector<OpFoldResult> sizeValues =
1898 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1899 return b.getI64IntegerAttr(v);
1900 }));
1901 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1902 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1903 return b.getI64IntegerAttr(v);
1904 }));
1905 build(b, result, source, offsetValues, sizeValues, strideValues, attrs);
1906 }
1907
1908 // Build a SubViewOp with static entries and custom result type. If the
1909 // type passed is nullptr, it is inferred.
1910 void SubViewOp::build(OpBuilder &b, OperationState &result,
1911 MemRefType resultType, Value source,
1912 ArrayRef<int64_t> offsets, ArrayRef<int64_t> sizes,
1913 ArrayRef<int64_t> strides,
1914 ArrayRef<NamedAttribute> attrs) {
1915 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1916 llvm::map_range(offsets, [&](int64_t v) -> OpFoldResult {
1917 return b.getI64IntegerAttr(v);
1918 }));
1919 SmallVector<OpFoldResult> sizeValues =
1920 llvm::to_vector<4>(llvm::map_range(sizes, [&](int64_t v) -> OpFoldResult {
1921 return b.getI64IntegerAttr(v);
1922 }));
1923 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1924 llvm::map_range(strides, [&](int64_t v) -> OpFoldResult {
1925 return b.getI64IntegerAttr(v);
1926 }));
1927 build(b, result, resultType, source, offsetValues, sizeValues, strideValues,
1928 attrs);
1929 }
1930
1931 // Build a SubViewOp with dynamic entries and custom result type. If the type
1932 // passed is nullptr, it is inferred.
1933 void SubViewOp::build(OpBuilder &b, OperationState &result,
1934 MemRefType resultType, Value source, ValueRange offsets,
1935 ValueRange sizes, ValueRange strides,
1936 ArrayRef<NamedAttribute> attrs) {
1937 SmallVector<OpFoldResult> offsetValues = llvm::to_vector<4>(
1938 llvm::map_range(offsets, [](Value v) -> OpFoldResult { return v; }));
1939 SmallVector<OpFoldResult> sizeValues = llvm::to_vector<4>(
1940 llvm::map_range(sizes, [](Value v) -> OpFoldResult { return v; }));
1941 SmallVector<OpFoldResult> strideValues = llvm::to_vector<4>(
1942 llvm::map_range(strides, [](Value v) -> OpFoldResult { return v; }));
1943 build(b, result, resultType, source, offsetValues, sizeValues, strideValues);
1944 }
1945
1946 // Build a SubViewOp with dynamic entries and inferred result type.
1947 void SubViewOp::build(OpBuilder &b, OperationState &result, Value source,
1948 ValueRange offsets, ValueRange sizes, ValueRange strides,
1949 ArrayRef<NamedAttribute> attrs) {
1950 build(b, result, MemRefType(), source, offsets, sizes, strides, attrs);
1951 }
1952
1953 /// For ViewLikeOpInterface.
1954 Value SubViewOp::getViewSource() { return source(); }
1955
1956 enum SubViewVerificationResult {
1957 Success,
1958 RankTooLarge,
1959 SizeMismatch,
1960 ElemTypeMismatch,
1961 MemSpaceMismatch,
1962 AffineMapMismatch
1963 };
1964
1965 /// Checks if the `original` type can be rank-reduced to the `reduced` type.
1966 /// This function is a slight variant of the `is subsequence` algorithm where
1967 /// a non-matching dimension must be 1.
1968 static SubViewVerificationResult
1969 isRankReducedType(Type originalType, Type candidateReducedType,
1970 ArrayAttr staticSizes, std::string *errMsg = nullptr) {
1971 if (originalType == candidateReducedType)
1972 return SubViewVerificationResult::Success;
1973 if (!originalType.isa<MemRefType>())
1974 return SubViewVerificationResult::Success;
1975 if (originalType.isa<MemRefType>() && !candidateReducedType.isa<MemRefType>())
1976 return SubViewVerificationResult::Success;
1977
1978 ShapedType originalShapedType = originalType.cast<ShapedType>();
1979 ShapedType candidateReducedShapedType =
1980 candidateReducedType.cast<ShapedType>();
1981
1982 // Rank and size logic is valid for all ShapedTypes.
1983 ArrayRef<int64_t> originalShape = originalShapedType.getShape();
1984 ArrayRef<int64_t> candidateReducedShape =
1985 candidateReducedShapedType.getShape();
1986 unsigned originalRank = originalShape.size(),
1987 candidateReducedRank = candidateReducedShape.size();
1988 if (candidateReducedRank > originalRank)
1989 return SubViewVerificationResult::RankTooLarge;
1990
1991 MemRefType original = originalType.cast<MemRefType>();
1992 MemRefType candidateReduced = candidateReducedType.cast<MemRefType>();
1993
1994 auto optionalUnusedDimsMask =
1995 computeMemRefRankReductionMask(original, candidateReduced, staticSizes);
1996
1997 // Sizes cannot be matched if no rank-reduction mask could be computed.
1998 if (!optionalUnusedDimsMask.hasValue())
1999 return SubViewVerificationResult::SizeMismatch;
2000
2001 if (originalShapedType.getElementType() !=
2002 candidateReducedShapedType.getElementType())
2003 return SubViewVerificationResult::ElemTypeMismatch;
2004
2005 // Strided layout logic is relevant for MemRefType only.
2006 if (original.getMemorySpace() != candidateReduced.getMemorySpace())
2007 return SubViewVerificationResult::MemSpaceMismatch;
2008 return SubViewVerificationResult::Success;
2009 }
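// Roughly (illustrative): reducing memref<4x1x8xf32> to memref<4x8xf32> can be
// accepted because the dropped dimension has size 1 and a rank-reduction mask
// exists, whereas memref<4x1x8xf32> to memref<4x6xf32> reports SizeMismatch.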
2010
2011 template <typename OpTy>
2012 static LogicalResult produceSubViewErrorMsg(SubViewVerificationResult result,
2013 OpTy op, Type expectedType,
2014 StringRef errMsg = "") {
2015 auto memrefType = expectedType.cast<ShapedType>();
2016 switch (result) {
2017 case SubViewVerificationResult::Success:
2018 return success();
2019 case SubViewVerificationResult::RankTooLarge:
2020 return op.emitError("expected result rank to be smaller or equal to ")
2021 << "the source rank. " << errMsg;
2022 case SubViewVerificationResult::SizeMismatch:
2023 return op.emitError("expected result type to be ")
2024 << expectedType
2025 << " or a rank-reduced version. (mismatch of result sizes) "
2026 << errMsg;
2027 case SubViewVerificationResult::ElemTypeMismatch:
2028 return op.emitError("expected result element type to be ")
2029 << memrefType.getElementType() << errMsg;
2030 case SubViewVerificationResult::MemSpaceMismatch:
2031 return op.emitError("expected result and source memory spaces to match.")
2032 << errMsg;
2033 case SubViewVerificationResult::AffineMapMismatch:
2034 return op.emitError("expected result type to be ")
2035 << expectedType
2036 << " or a rank-reduced version. (mismatch of result affine map) "
2037 << errMsg;
2038 }
2039 llvm_unreachable("unexpected subview verification result");
2040 }
2041
2042 /// Verifier for SubViewOp.
2043 static LogicalResult verify(SubViewOp op) {
2044 MemRefType baseType = op.getSourceType();
2045 MemRefType subViewType = op.getType();
2046
2047 // The base memref and the view memref should be in the same memory space.
2048 if (baseType.getMemorySpace() != subViewType.getMemorySpace())
2049 return op.emitError("different memory spaces specified for base memref "
2050 "type ")
2051 << baseType << " and subview memref type " << subViewType;
2052
2053 // Verify that the base memref type has a strided layout map.
2054 if (!isStrided(baseType))
2055 return op.emitError("base type ") << baseType << " is not strided";
2056
2057 // Verify result type against inferred type.
2058 auto expectedType = SubViewOp::inferResultType(
2059 baseType, extractFromI64ArrayAttr(op.static_offsets()),
2060 extractFromI64ArrayAttr(op.static_sizes()),
2061 extractFromI64ArrayAttr(op.static_strides()));
2062
2063 std::string errMsg;
2064 auto result =
2065 isRankReducedType(expectedType, subViewType, op.static_sizes(), &errMsg);
2066 return produceSubViewErrorMsg(result, op, expectedType, errMsg);
2067 }
2068
2069 raw_ostream &mlir::operator<<(raw_ostream &os, const Range &range) {
2070 return os << "range " << range.offset << ":" << range.size << ":"
2071 << range.stride;
2072 }
2073
2074 /// Return the list of Range (i.e. offset, size, stride). Each Range
2075 /// entry contains either the dynamic value or a ConstantIndexOp constructed
2076 /// with `b` at location `loc`.
2077 SmallVector<Range, 8> mlir::getOrCreateRanges(OffsetSizeAndStrideOpInterface op,
2078 OpBuilder &b, Location loc) {
2079 std::array<unsigned, 3> ranks = op.getArrayAttrMaxRanks();
2080 assert(ranks[0] == ranks[1] && "expected offset and sizes of equal ranks");
2081 assert(ranks[1] == ranks[2] && "expected sizes and strides of equal ranks");
2082 SmallVector<Range, 8> res;
2083 unsigned rank = ranks[0];
2084 res.reserve(rank);
2085 for (unsigned idx = 0; idx < rank; ++idx) {
2086 Value offset =
2087 op.isDynamicOffset(idx)
2088 ? op.getDynamicOffset(idx)
2089 : b.create<ConstantIndexOp>(loc, op.getStaticOffset(idx));
2090 Value size = op.isDynamicSize(idx)
2091 ? op.getDynamicSize(idx)
2092 : b.create<ConstantIndexOp>(loc, op.getStaticSize(idx));
2093 Value stride =
2094 op.isDynamicStride(idx)
2095 ? op.getDynamicStride(idx)
2096 : b.create<ConstantIndexOp>(loc, op.getStaticStride(idx));
2097 res.emplace_back(Range{offset, size, stride});
2098 }
2099 return res;
2100 }
2101
2102 /// Infer the canonical type of the result of a subview operation. Returns the
2103 /// rank-reduced type if it already has rank `resultRank`, and the
2104 /// non-rank-reduced type otherwise.
2105 static MemRefType
2106 getCanonicalSubViewResultType(unsigned resultRank, MemRefType sourceType,
2107 ArrayRef<OpFoldResult> mixedOffsets,
2108 ArrayRef<OpFoldResult> mixedSizes,
2109 ArrayRef<OpFoldResult> mixedStrides) {
2110 auto resultType =
2111 SubViewOp::inferRankReducedResultType(
2112 resultRank, sourceType, mixedOffsets, mixedSizes, mixedStrides)
2113 .cast<MemRefType>();
2114 if (resultType.getRank() != resultRank) {
2115 resultType = SubViewOp::inferResultType(sourceType, mixedOffsets,
2116 mixedSizes, mixedStrides)
2117 .cast<MemRefType>();
2118 }
2119 return resultType;
2120 }
2121
2122 namespace {
2123 /// Pattern to rewrite a subview op with MemRefCast arguments.
2124 /// This essentially pushes memref.cast past its consuming subview when
2125 /// `canFoldIntoConsumerOp` is true.
2126 ///
2127 /// Example:
2128 /// ```
2129 /// %0 = memref.cast %V : memref<16x16xf32> to memref<?x?xf32>
2130 /// %1 = memref.subview %0[0, 0][3, 4][1, 1] :
2131 /// memref<?x?xf32> to memref<3x4xf32, offset:?, strides:[?, 1]>
2132 /// ```
2133 /// is rewritten into:
2134 /// ```
2135 /// %0 = memref.subview %V: memref<16x16xf32> to memref<3x4xf32, #[[map0]]>
2136 /// %1 = memref.cast %0: memref<3x4xf32, offset:0, strides:[16, 1]> to
2137 /// memref<3x4xf32, offset:?, strides:[?, 1]>
2138 /// ```
2139 class SubViewOpMemRefCastFolder final : public OpRewritePattern<SubViewOp> {
2140 public:
2141 using OpRewritePattern<SubViewOp>::OpRewritePattern;
2142
2143 LogicalResult matchAndRewrite(SubViewOp subViewOp,
2144 PatternRewriter &rewriter) const override {
2145 // Any constant operand, just return to let SubViewOpConstantFolder kick in.
2146 if (llvm::any_of(subViewOp.getOperands(), [](Value operand) {
2147 return matchPattern(operand, matchConstantIndex());
2148 }))
2149 return failure();
2150
2151 auto castOp = subViewOp.source().getDefiningOp<CastOp>();
2152 if (!castOp)
2153 return failure();
2154
2155 if (!CastOp::canFoldIntoConsumerOp(castOp))
2156 return failure();
2157
2158 /// Deduce the resultType of the SubViewOp using `getCanonicalSubViewResultType` on
2159 /// the cast source operand type and the SubViewOp static information. This
2160 /// is the resulting type if the MemRefCastOp were folded.
2161 auto resultType = getCanonicalSubViewResultType(
2162 subViewOp.getType().getRank(),
2163 castOp.source().getType().cast<MemRefType>(),
2164 subViewOp.getMixedOffsets(), subViewOp.getMixedSizes(),
2165 subViewOp.getMixedStrides());
2166 Value newSubView = rewriter.create<SubViewOp>(
2167 subViewOp.getLoc(), resultType, castOp.source(), subViewOp.offsets(),
2168 subViewOp.sizes(), subViewOp.strides(), subViewOp.static_offsets(),
2169 subViewOp.static_sizes(), subViewOp.static_strides());
2170 rewriter.replaceOpWithNewOp<CastOp>(subViewOp, subViewOp.getType(),
2171 newSubView);
2172 return success();
2173 }
2174 };
2175 } // namespace
2176
2177 /// Return the canonical type of the result of a subview.
2178 struct SubViewReturnTypeCanonicalizer {
2179 MemRefType operator()(SubViewOp op, ArrayRef<OpFoldResult> mixedOffsets,
2180 ArrayRef<OpFoldResult> mixedSizes,
2181 ArrayRef<OpFoldResult> mixedStrides) {
2182 return getCanonicalSubViewResultType(op.getType().getRank(),
2183 op.getSourceType(), mixedOffsets,
2184 mixedSizes, mixedStrides);
2185 }
2186 };
2187
2188 /// A canonicalizer wrapper to replace SubViewOps.
2189 struct SubViewCanonicalizer {
2190 void operator()(PatternRewriter &rewriter, SubViewOp op, SubViewOp newOp) {
2191 rewriter.replaceOpWithNewOp<CastOp>(op, newOp, op.getType());
2192 }
2193 };
2194
2195 void SubViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2196 MLIRContext *context) {
2197 results
2198 .add<OpWithOffsetSizesAndStridesConstantArgumentFolder<
2199 SubViewOp, SubViewReturnTypeCanonicalizer, SubViewCanonicalizer>,
2200 SubViewOpMemRefCastFolder>(context);
2201 }
2202
2203 OpFoldResult SubViewOp::fold(ArrayRef<Attribute> operands) {
2204 auto resultShapedType = getResult().getType().cast<ShapedType>();
2205 auto sourceShapedType = source().getType().cast<ShapedType>();
2206
2207 if (resultShapedType.hasStaticShape() &&
2208 resultShapedType == sourceShapedType) {
2209 return getViewSource();
2210 }
2211
2212 return {};
2213 }
2214
2215 //===----------------------------------------------------------------------===//
2216 // TensorLoadOp
2217 //===----------------------------------------------------------------------===//
2218
2219 OpFoldResult TensorLoadOp::fold(ArrayRef<Attribute>) {
2220 if (auto bufferCast = memref().getDefiningOp<BufferCastOp>())
2221 // Approximate alias analysis by conservatively folding only when there is
2222 // no interleaved operation.
2223 if (bufferCast->getBlock() == this->getOperation()->getBlock() &&
2224 bufferCast->getNextNode() == this->getOperation())
2225 return bufferCast.tensor();
2226 return {};
2227 }
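// The fold above applies to IR of the form (illustrative):
//   %m = memref.buffer_cast %t : memref<?xf32>
//   %v = memref.tensor_load %m : memref<?xf32>
// when the two ops are adjacent in the same block; %v then folds to %t.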
2228
2229 namespace {
2230 struct DimOfTensorLoadFolder : public OpRewritePattern<tensor::DimOp> {
2231 using OpRewritePattern<tensor::DimOp>::OpRewritePattern;
2232
2233 LogicalResult matchAndRewrite(tensor::DimOp dimOp,
2234 PatternRewriter &rewriter) const override {
2235 auto tensorLoadOp = dimOp.source().getDefiningOp<TensorLoadOp>();
2236 if (!tensorLoadOp)
2237 return failure();
2238
2239 rewriter.replaceOpWithNewOp<DimOp>(dimOp, tensorLoadOp.memref(),
2240 dimOp.index());
2241 return success();
2242 }
2243 };
2244 } // namespace
2245
2246 void TensorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2247 MLIRContext *context) {
2248 results.add<DimOfTensorLoadFolder>(context);
2249 }
2250
2251 //===----------------------------------------------------------------------===//
2252 // TransposeOp
2253 //===----------------------------------------------------------------------===//
2254
2255 /// Build a strided memref type by applying `permutationMap` to `memRefType`.
2256 static MemRefType inferTransposeResultType(MemRefType memRefType,
2257 AffineMap permutationMap) {
2258 auto rank = memRefType.getRank();
2259 auto originalSizes = memRefType.getShape();
2260 // Compute permuted sizes.
2261 SmallVector<int64_t, 4> sizes(rank, 0);
2262 for (auto en : llvm::enumerate(permutationMap.getResults()))
2263 sizes[en.index()] =
2264 originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
2265
2266 // Compute permuted strides.
2267 int64_t offset;
2268 SmallVector<int64_t, 4> strides;
2269 auto res = getStridesAndOffset(memRefType, strides, offset);
2270 assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
2271 (void)res;
2272 auto map =
2273 makeStridedLinearLayoutMap(strides, offset, memRefType.getContext());
2274 map = permutationMap ? map.compose(permutationMap) : map;
2275 return MemRefType::Builder(memRefType).setShape(sizes).setAffineMaps(map);
2276 }
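// Worked example (illustrative): transposing memref<3x4xf32> (strides [4, 1])
// with permutation (d0, d1) -> (d1, d0) permutes the sizes to [4, 3] and
// composes the strided layout with the permutation, giving
//   memref<4x3xf32, affine_map<(d0, d1) -> (d0 + d1 * 4)>>.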
2277
2278 void TransposeOp::build(OpBuilder &b, OperationState &result, Value in,
2279 AffineMapAttr permutation,
2280 ArrayRef<NamedAttribute> attrs) {
2281 auto permutationMap = permutation.getValue();
2282 assert(permutationMap);
2283
2284 auto memRefType = in.getType().cast<MemRefType>();
2285 // Compute result type.
2286 MemRefType resultType = inferTransposeResultType(memRefType, permutationMap);
2287
2288 build(b, result, resultType, in, attrs);
2289 result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
2290 }
2291
2292 // transpose $in $permutation attr-dict : type($in) `to` type(results)
2293 static void print(OpAsmPrinter &p, TransposeOp op) {
2294 p << " " << op.in() << " " << op.permutation();
2295 p.printOptionalAttrDict(op->getAttrs(),
2296 {TransposeOp::getPermutationAttrName()});
2297 p << " : " << op.in().getType() << " to " << op.getType();
2298 }
2299
2300 static ParseResult parseTransposeOp(OpAsmParser &parser,
2301 OperationState &result) {
2302 OpAsmParser::OperandType in;
2303 AffineMap permutation;
2304 MemRefType srcType, dstType;
2305 if (parser.parseOperand(in) || parser.parseAffineMap(permutation) ||
2306 parser.parseOptionalAttrDict(result.attributes) ||
2307 parser.parseColonType(srcType) ||
2308 parser.resolveOperand(in, srcType, result.operands) ||
2309 parser.parseKeywordType("to", dstType) ||
2310 parser.addTypeToList(dstType, result.types))
2311 return failure();
2312
2313 result.addAttribute(TransposeOp::getPermutationAttrName(),
2314 AffineMapAttr::get(permutation));
2315 return success();
2316 }
2317
2318 static LogicalResult verify(TransposeOp op) {
2319 if (!op.permutation().isPermutation())
2320 return op.emitOpError("expected a permutation map");
2321 if (op.permutation().getNumDims() != op.getShapedType().getRank())
2322 return op.emitOpError(
2323 "expected a permutation map of same rank as the input");
2324
2325 auto srcType = op.in().getType().cast<MemRefType>();
2326 auto dstType = op.getType().cast<MemRefType>();
2327 auto transposedType = inferTransposeResultType(srcType, op.permutation());
2328 if (dstType != transposedType)
2329 return op.emitOpError("output type ")
2330 << dstType << " does not match transposed input type " << srcType
2331 << ", " << transposedType;
2332 return success();
2333 }
2334
2335 OpFoldResult TransposeOp::fold(ArrayRef<Attribute>) {
2336 if (succeeded(foldMemRefCast(*this)))
2337 return getResult();
2338 return {};
2339 }
2340
2341 //===----------------------------------------------------------------------===//
2342 // ViewOp
2343 //===----------------------------------------------------------------------===//
2344
2345 static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
2346 OpAsmParser::OperandType srcInfo;
2347 SmallVector<OpAsmParser::OperandType, 1> offsetInfo;
2348 SmallVector<OpAsmParser::OperandType, 4> sizesInfo;
2349 auto indexType = parser.getBuilder().getIndexType();
2350 Type srcType, dstType;
2351 llvm::SMLoc offsetLoc;
2352 if (parser.parseOperand(srcInfo) || parser.getCurrentLocation(&offsetLoc) ||
2353 parser.parseOperandList(offsetInfo, OpAsmParser::Delimiter::Square))
2354 return failure();
2355
2356 if (offsetInfo.size() != 1)
2357 return parser.emitError(offsetLoc) << "expects 1 offset operand";
2358
2359 return failure(
2360 parser.parseOperandList(sizesInfo, OpAsmParser::Delimiter::Square) ||
2361 parser.parseOptionalAttrDict(result.attributes) ||
2362 parser.parseColonType(srcType) ||
2363 parser.resolveOperand(srcInfo, srcType, result.operands) ||
2364 parser.resolveOperands(offsetInfo, indexType, result.operands) ||
2365 parser.resolveOperands(sizesInfo, indexType, result.operands) ||
2366 parser.parseKeywordType("to", dstType) ||
2367 parser.addTypeToList(dstType, result.types));
2368 }
2369
2370 static void print(OpAsmPrinter &p, ViewOp op) {
2371 p << ' ' << op.getOperand(0) << '[';
2372 p.printOperand(op.byte_shift());
2373 p << "][" << op.sizes() << ']';
2374 p.printOptionalAttrDict(op->getAttrs());
2375 p << " : " << op.getOperand(0).getType() << " to " << op.getType();
2376 }
2377
2378 static LogicalResult verify(ViewOp op) {
2379 auto baseType = op.getOperand(0).getType().cast<MemRefType>();
2380 auto viewType = op.getType();
2381
2382 // The base memref should have identity layout map (or none).
2383 if (baseType.getAffineMaps().size() > 1 ||
2384 (baseType.getAffineMaps().size() == 1 &&
2385 !baseType.getAffineMaps()[0].isIdentity()))
2386 return op.emitError("unsupported map for base memref type ") << baseType;
2387
2388 // The result memref should have identity layout map (or none).
2389 if (viewType.getAffineMaps().size() > 1 ||
2390 (viewType.getAffineMaps().size() == 1 &&
2391 !viewType.getAffineMaps()[0].isIdentity()))
2392 return op.emitError("unsupported map for result memref type ") << viewType;
2393
2394 // The base memref and the view memref should be in the same memory space.
2395 if (baseType.getMemorySpace() != viewType.getMemorySpace())
2396 return op.emitError("different memory spaces specified for base memref "
2397 "type ")
2398 << baseType << " and view memref type " << viewType;
2399
2400 // Verify that we have the correct number of sizes for the result type.
2401 unsigned numDynamicDims = viewType.getNumDynamicDims();
2402 if (op.sizes().size() != numDynamicDims)
2403 return op.emitError("incorrect number of size operands for type ")
2404 << viewType;
2405
2406 return success();
2407 }
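// A well-formed example of the op verified above (names made up):
//
//   %buf = memref.alloc() : memref<2048xi8>
//   %v = memref.view %buf[%offset][%n] : memref<2048xi8> to memref<?x4xf32>
//
// There is one byte-shift operand and one size operand per dynamic dimension
// of the result, and both types use the identity layout.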
2408
2409 Value ViewOp::getViewSource() { return source(); }
2410
2411 namespace {
2412
2413 struct ViewOpShapeFolder : public OpRewritePattern<ViewOp> {
2414 using OpRewritePattern<ViewOp>::OpRewritePattern;
2415
2416 LogicalResult matchAndRewrite(ViewOp viewOp,
2417 PatternRewriter &rewriter) const override {
2418 // Return if none of the operands are constants.
2419 if (llvm::none_of(viewOp.getOperands(), [](Value operand) {
2420 return matchPattern(operand, matchConstantIndex());
2421 }))
2422 return failure();
2423
2424 // Get result memref type.
2425 auto memrefType = viewOp.getType();
2426
2427 // Get offset from old memref view type 'memRefType'.
2428 int64_t oldOffset;
2429 SmallVector<int64_t, 4> oldStrides;
2430 if (failed(getStridesAndOffset(memrefType, oldStrides, oldOffset)))
2431 return failure();
2432 assert(oldOffset == 0 && "Expected 0 offset");
2433
2434 SmallVector<Value, 4> newOperands;
2435
2436 // Offset cannot be folded into result type.
2437
2438 // Fold any dynamic dim operands which are produced by a constant.
2439 SmallVector<int64_t, 4> newShapeConstants;
2440 newShapeConstants.reserve(memrefType.getRank());
2441
2442 unsigned dynamicDimPos = 0;
2443 unsigned rank = memrefType.getRank();
2444 for (unsigned dim = 0, e = rank; dim < e; ++dim) {
2445 int64_t dimSize = memrefType.getDimSize(dim);
2446 // If this is already static dimension, keep it.
2447 if (!ShapedType::isDynamic(dimSize)) {
2448 newShapeConstants.push_back(dimSize);
2449 continue;
2450 }
2451 auto *defOp = viewOp.sizes()[dynamicDimPos].getDefiningOp();
2452 if (auto constantIndexOp = dyn_cast_or_null<ConstantIndexOp>(defOp)) {
2453 // Dynamic shape dimension will be folded.
2454 newShapeConstants.push_back(constantIndexOp.getValue());
2455 } else {
2456 // Dynamic shape dimension not folded; copy operand from old memref.
2457 newShapeConstants.push_back(dimSize);
2458 newOperands.push_back(viewOp.sizes()[dynamicDimPos]);
2459 }
2460 dynamicDimPos++;
2461 }
2462
2463 // Create new memref type with constant folded dims.
2464 MemRefType newMemRefType =
2465 MemRefType::Builder(memrefType).setShape(newShapeConstants);
2466 // Nothing new, don't fold.
2467 if (newMemRefType == memrefType)
2468 return failure();
2469
2470 // Create new ViewOp.
2471 auto newViewOp = rewriter.create<ViewOp>(viewOp.getLoc(), newMemRefType,
2472 viewOp.getOperand(0),
2473 viewOp.byte_shift(), newOperands);
2474 // Insert a cast so we have the same type as the old memref type.
2475 rewriter.replaceOpWithNewOp<CastOp>(viewOp, newViewOp, viewOp.getType());
2476 return success();
2477 }
2478 };
2479
2480 struct ViewOpMemrefCastFolder : public OpRewritePattern<ViewOp> {
2481 using OpRewritePattern<ViewOp>::OpRewritePattern;
2482
2483 LogicalResult matchAndRewrite(ViewOp viewOp,
2484 PatternRewriter &rewriter) const override {
2485 Value memrefOperand = viewOp.getOperand(0);
2486 CastOp memrefCastOp = memrefOperand.getDefiningOp<CastOp>();
2487 if (!memrefCastOp)
2488 return failure();
2489 Value allocOperand = memrefCastOp.getOperand();
2490 AllocOp allocOp = allocOperand.getDefiningOp<AllocOp>();
2491 if (!allocOp)
2492 return failure();
2493 rewriter.replaceOpWithNewOp<ViewOp>(viewOp, viewOp.getType(), allocOperand,
2494 viewOp.byte_shift(), viewOp.sizes());
2495 return success();
2496 }
2497 };
2498
2499 } // end anonymous namespace
2500
2501 void ViewOp::getCanonicalizationPatterns(RewritePatternSet &results,
2502 MLIRContext *context) {
2503 results.add<ViewOpShapeFolder, ViewOpMemrefCastFolder>(context);
2504 }
2505
2506 //===----------------------------------------------------------------------===//
2507 // TableGen'd op method definitions
2508 //===----------------------------------------------------------------------===//
2509
2510 #define GET_OP_CLASSES
2511 #include "mlir/Dialect/MemRef/IR/MemRefOps.cpp.inc"
2512