1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/StandardOps/IR/Ops.h"
13 #include "mlir/IR/BlockAndValueMapping.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Transforms/InliningUtils.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallBitVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24
25 using namespace mlir;
26
27 #define DEBUG_TYPE "affine-analysis"
28
29 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
30
31 /// A utility function to check if a value is defined at the top level of
32 /// `region` or is an argument of `region`. A value of index type defined at the
33 /// top level of a `AffineScope` region is always a valid symbol for all
34 /// uses in that region.
isTopLevelValue(Value value,Region * region)35 static bool isTopLevelValue(Value value, Region *region) {
36 if (auto arg = value.dyn_cast<BlockArgument>())
37 return arg.getParentRegion() == region;
38 return value.getDefiningOp()->getParentRegion() == region;
39 }
40
41 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
42 /// region remains legal if the operation that uses it is inlined into `dest`
43 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
44 /// `isValidSymbol`, depending on the value being required to remain a valid
45 /// dimension or symbol.
46 static bool
remainsLegalAfterInline(Value value,Region * src,Region * dest,const BlockAndValueMapping & mapping,function_ref<bool (Value,Region *)> legalityCheck)47 remainsLegalAfterInline(Value value, Region *src, Region *dest,
48 const BlockAndValueMapping &mapping,
49 function_ref<bool(Value, Region *)> legalityCheck) {
50 // If the value is a valid dimension for any other reason than being
51 // a top-level value, it will remain valid: constants get inlined
52 // with the function, transitive affine applies also get inlined and
53 // will be checked themselves, etc.
54 if (!isTopLevelValue(value, src))
55 return true;
56
57 // If it's a top-level value because it's a block operand, i.e. a
58 // function argument, check whether the value replacing it after
59 // inlining is a valid dimension in the new region.
60 if (value.isa<BlockArgument>())
61 return legalityCheck(mapping.lookup(value), dest);
62
63 // If it's a top-level value because it's defined in the region,
64 // it can only be inlined if the defining op is a constant or a
65 // `dim`, which can appear anywhere and be valid, since the defining
66 // op won't be top-level anymore after inlining.
67 Attribute operandCst;
68 return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
69 value.getDefiningOp<memref::DimOp>() ||
70 value.getDefiningOp<tensor::DimOp>();
71 }
72
73 /// Checks if all values known to be legal affine dimensions or symbols in `src`
74 /// remain so if their respective users are inlined into `dest`.
75 static bool
remainsLegalAfterInline(ValueRange values,Region * src,Region * dest,const BlockAndValueMapping & mapping,function_ref<bool (Value,Region *)> legalityCheck)76 remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
77 const BlockAndValueMapping &mapping,
78 function_ref<bool(Value, Region *)> legalityCheck) {
79 return llvm::all_of(values, [&](Value v) {
80 return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
81 });
82 }
83
84 /// Checks if an affine read or write operation remains legal after inlining
85 /// from `src` to `dest`.
86 template <typename OpTy>
remainsLegalAfterInline(OpTy op,Region * src,Region * dest,const BlockAndValueMapping & mapping)87 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
88 const BlockAndValueMapping &mapping) {
89 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
90 AffineWriteOpInterface>::value,
91 "only ops with affine read/write interface are supported");
92
93 AffineMap map = op.getAffineMap();
94 ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
95 ValueRange symbolOperands =
96 op.getMapOperands().take_back(map.getNumSymbols());
97 if (!remainsLegalAfterInline(
98 dimOperands, src, dest, mapping,
99 static_cast<bool (*)(Value, Region *)>(isValidDim)))
100 return false;
101 if (!remainsLegalAfterInline(
102 symbolOperands, src, dest, mapping,
103 static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
104 return false;
105 return true;
106 }
107
108 /// Checks if an affine apply operation remains legal after inlining from `src`
109 /// to `dest`.
110 template <>
remainsLegalAfterInline(AffineApplyOp op,Region * src,Region * dest,const BlockAndValueMapping & mapping)111 bool remainsLegalAfterInline(AffineApplyOp op, Region *src, Region *dest,
112 const BlockAndValueMapping &mapping) {
113 // If it's a valid dimension, we need to check that it remains so.
114 if (isValidDim(op.getResult(), src))
115 return remainsLegalAfterInline(
116 op.getMapOperands(), src, dest, mapping,
117 static_cast<bool (*)(Value, Region *)>(isValidDim));
118
119 // Otherwise it must be a valid symbol, check that it remains so.
120 return remainsLegalAfterInline(
121 op.getMapOperands(), src, dest, mapping,
122 static_cast<bool (*)(Value, Region *)>(isValidSymbol));
123 }
124
125 //===----------------------------------------------------------------------===//
126 // AffineDialect Interfaces
127 //===----------------------------------------------------------------------===//
128
129 namespace {
130 /// This class defines the interface for handling inlining with affine
131 /// operations.
132 struct AffineInlinerInterface : public DialectInlinerInterface {
133 using DialectInlinerInterface::DialectInlinerInterface;
134
135 //===--------------------------------------------------------------------===//
136 // Analysis Hooks
137 //===--------------------------------------------------------------------===//
138
139 /// Returns true if the given region 'src' can be inlined into the region
140 /// 'dest' that is attached to an operation registered to the current dialect.
141 /// 'wouldBeCloned' is set if the region is cloned into its new location
142 /// rather than moved, indicating there may be other users.
isLegalToInline__anon4d71c4590211::AffineInlinerInterface143 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
144 BlockAndValueMapping &valueMapping) const final {
145 // We can inline into affine loops and conditionals if this doesn't break
146 // affine value categorization rules.
147 Operation *destOp = dest->getParentOp();
148 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
149 return false;
150
151 // Multi-block regions cannot be inlined into affine constructs, all of
152 // which require single-block regions.
153 if (!llvm::hasSingleElement(*src))
154 return false;
155
156 // Side-effecting operations that the affine dialect cannot understand
157 // should not be inlined.
158 Block &srcBlock = src->front();
159 for (Operation &op : srcBlock) {
160 // Ops with no side effects are fine,
161 if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
162 if (iface.hasNoEffect())
163 continue;
164 }
165
166 // Assuming the inlined region is valid, we only need to check if the
167 // inlining would change it.
168 bool remainsValid =
169 llvm::TypeSwitch<Operation *, bool>(&op)
170 .Case<AffineApplyOp, AffineReadOpInterface,
171 AffineWriteOpInterface>([&](auto op) {
172 return remainsLegalAfterInline(op, src, dest, valueMapping);
173 })
174 .Default([](Operation *) {
175 // Conservatively disallow inlining ops we cannot reason about.
176 return false;
177 });
178
179 if (!remainsValid)
180 return false;
181 }
182
183 return true;
184 }
185
186 /// Returns true if the given operation 'op', that is registered to this
187 /// dialect, can be inlined into the given region, false otherwise.
isLegalToInline__anon4d71c4590211::AffineInlinerInterface188 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
189 BlockAndValueMapping &valueMapping) const final {
190 // Always allow inlining affine operations into a region that is marked as
191 // affine scope, or into affine loops and conditionals. There are some edge
192 // cases when inlining *into* affine structures, but that is handled in the
193 // other 'isLegalToInline' hook above.
194 Operation *parentOp = region->getParentOp();
195 return parentOp->hasTrait<OpTrait::AffineScope>() ||
196 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
197 }
198
199 /// Affine regions should be analyzed recursively.
shouldAnalyzeRecursively__anon4d71c4590211::AffineInlinerInterface200 bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
201 };
202 } // end anonymous namespace
203
204 //===----------------------------------------------------------------------===//
205 // AffineDialect
206 //===----------------------------------------------------------------------===//
207
/// Registers the affine dialect's operations and its inliner interface.
void AffineDialect::initialize() {
  // The two DMA ops are listed by hand; the remaining ops come from the list
  // included from AffineOps.cpp.inc under GET_OP_LIST (presumably
  // TableGen-generated).
  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
                >();
  addInterfaces<AffineInlinerInterface>();
}
215
216 /// Materialize a single constant operation from a given attribute value with
217 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)218 Operation *AffineDialect::materializeConstant(OpBuilder &builder,
219 Attribute value, Type type,
220 Location loc) {
221 return builder.create<ConstantOp>(loc, type, value);
222 }
223
224 /// A utility function to check if a value is defined at the top level of an
225 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
226 /// conservatively assume it is not top-level. A value of index type defined at
227 /// the top level is always a valid symbol.
isTopLevelValue(Value value)228 bool mlir::isTopLevelValue(Value value) {
229 if (auto arg = value.dyn_cast<BlockArgument>()) {
230 // The block owning the argument may be unlinked, e.g. when the surrounding
231 // region has not yet been attached to an Op, at which point the parent Op
232 // is null.
233 Operation *parentOp = arg.getOwner()->getParentOp();
234 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
235 }
236 // The defining Op may live in an unlinked block so its parent Op may be null.
237 Operation *parentOp = value.getDefiningOp()->getParentOp();
238 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
239 }
240
241 /// Returns the closest region enclosing `op` that is held by an operation with
242 /// trait `AffineScope`; `nullptr` if there is no such region.
243 // TODO: getAffineScope should be publicly exposed for affine passes/utilities.
getAffineScope(Operation * op)244 static Region *getAffineScope(Operation *op) {
245 auto *curOp = op;
246 while (auto *parentOp = curOp->getParentOp()) {
247 if (parentOp->hasTrait<OpTrait::AffineScope>())
248 return curOp->getParentRegion();
249 curOp = parentOp;
250 }
251 return nullptr;
252 }
253
254 // A Value can be used as a dimension id iff it meets one of the following
255 // conditions:
256 // *) It is valid as a symbol.
257 // *) It is an induction variable.
258 // *) It is the result of affine apply operation with dimension id arguments.
isValidDim(Value value)259 bool mlir::isValidDim(Value value) {
260 // The value must be an index type.
261 if (!value.getType().isIndex())
262 return false;
263
264 if (auto *defOp = value.getDefiningOp())
265 return isValidDim(value, getAffineScope(defOp));
266
267 // This value has to be a block argument for an op that has the
268 // `AffineScope` trait or for an affine.for or affine.parallel.
269 auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
270 return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
271 isa<AffineForOp, AffineParallelOp>(parentOp));
272 }
273
274 // Value can be used as a dimension id iff it meets one of the following
275 // conditions:
276 // *) It is valid as a symbol.
277 // *) It is an induction variable.
278 // *) It is the result of an affine apply operation with dimension id operands.
isValidDim(Value value,Region * region)279 bool mlir::isValidDim(Value value, Region *region) {
280 // The value must be an index type.
281 if (!value.getType().isIndex())
282 return false;
283
284 // All valid symbols are okay.
285 if (isValidSymbol(value, region))
286 return true;
287
288 auto *op = value.getDefiningOp();
289 if (!op) {
290 // This value has to be a block argument for an affine.for or an
291 // affine.parallel.
292 auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
293 return isa<AffineForOp, AffineParallelOp>(parentOp);
294 }
295
296 // Affine apply operation is ok if all of its operands are ok.
297 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
298 return applyOp.isValidDim(region);
299 // The dim op is okay if its operand memref/tensor is defined at the top
300 // level.
301 if (auto dimOp = dyn_cast<memref::DimOp>(op))
302 return isTopLevelValue(dimOp.source());
303 if (auto dimOp = dyn_cast<tensor::DimOp>(op))
304 return isTopLevelValue(dimOp.source());
305 return false;
306 }
307
308 /// Returns true if the 'index' dimension of the `memref` defined by
309 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol
310 /// for `region`.
311 template <typename AnyMemRefDefOp>
isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp,unsigned index,Region * region)312 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
313 Region *region) {
314 auto memRefType = memrefDefOp.getType();
315 // Statically shaped.
316 if (!memRefType.isDynamicDim(index))
317 return true;
318 // Get the position of the dimension among dynamic dimensions;
319 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
320 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
321 region);
322 }
323
324 /// Returns true if the result of the dim op is a valid symbol for `region`.
325 template <typename OpTy>
isDimOpValidSymbol(OpTy dimOp,Region * region)326 static bool isDimOpValidSymbol(OpTy dimOp, Region *region) {
327 // The dim op is okay if its source is defined at the top level.
328 if (isTopLevelValue(dimOp.source()))
329 return true;
330
331 // Conservatively handle remaining BlockArguments as non-valid symbols.
332 // E.g. scf.for iterArgs.
333 if (dimOp.source().template isa<BlockArgument>())
334 return false;
335
336 // The dim op is also okay if its operand memref is a view/subview whose
337 // corresponding size is a valid symbol.
338 Optional<int64_t> index = dimOp.getConstantIndex();
339 assert(index.hasValue() &&
340 "expect only `dim` operations with a constant index");
341 int64_t i = index.getValue();
342 return TypeSwitch<Operation *, bool>(dimOp.source().getDefiningOp())
343 .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
344 [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
345 .Default([](Operation *) { return false; });
346 }
347
348 // A value can be used as a symbol (at all its use sites) iff it meets one of
349 // the following conditions:
350 // *) It is a constant.
351 // *) Its defining op or block arg appearance is immediately enclosed by an op
352 // with `AffineScope` trait.
353 // *) It is the result of an affine.apply operation with symbol operands.
354 // *) It is a result of the dim op on a memref whose corresponding size is a
355 // valid symbol.
isValidSymbol(Value value)356 bool mlir::isValidSymbol(Value value) {
357 // The value must be an index type.
358 if (!value.getType().isIndex())
359 return false;
360
361 // Check that the value is a top level value.
362 if (isTopLevelValue(value))
363 return true;
364
365 if (auto *defOp = value.getDefiningOp())
366 return isValidSymbol(value, getAffineScope(defOp));
367
368 return false;
369 }
370
371 /// A value can be used as a symbol for `region` iff it meets one of the
372 /// following conditions:
373 /// *) It is a constant.
374 /// *) It is the result of an affine apply operation with symbol arguments.
375 /// *) It is a result of the dim op on a memref whose corresponding size is
376 /// a valid symbol.
377 /// *) It is defined at the top level of 'region' or is its argument.
378 /// *) It dominates `region`'s parent op.
379 /// If `region` is null, conservatively assume the symbol definition scope does
380 /// not exist and only accept the values that would be symbols regardless of
381 /// the surrounding region structure, i.e. the first three cases above.
isValidSymbol(Value value,Region * region)382 bool mlir::isValidSymbol(Value value, Region *region) {
383 // The value must be an index type.
384 if (!value.getType().isIndex())
385 return false;
386
387 // A top-level value is a valid symbol.
388 if (region && ::isTopLevelValue(value, region))
389 return true;
390
391 auto *defOp = value.getDefiningOp();
392 if (!defOp) {
393 // A block argument that is not a top-level value is a valid symbol if it
394 // dominates region's parent op.
395 Operation *regionOp = region ? region->getParentOp() : nullptr;
396 if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
397 if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
398 return isValidSymbol(value, parentOpRegion);
399 return false;
400 }
401
402 // Constant operation is ok.
403 Attribute operandCst;
404 if (matchPattern(defOp, m_Constant(&operandCst)))
405 return true;
406
407 // Affine apply operation is ok if all of its operands are ok.
408 if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
409 return applyOp.isValidSymbol(region);
410
411 // Dim op results could be valid symbols at any level.
412 if (auto dimOp = dyn_cast<memref::DimOp>(defOp))
413 return isDimOpValidSymbol(dimOp, region);
414 if (auto dimOp = dyn_cast<tensor::DimOp>(defOp))
415 return isDimOpValidSymbol(dimOp, region);
416
417 // Check for values dominating `region`'s parent op.
418 Operation *regionOp = region ? region->getParentOp() : nullptr;
419 if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
420 if (auto *parentRegion = region->getParentOp()->getParentRegion())
421 return isValidSymbol(value, parentRegion);
422
423 return false;
424 }
425
426 // Returns true if 'value' is a valid index to an affine operation (e.g.
427 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
428 // `region` provides the polyhedral symbol scope. Returns false otherwise.
isValidAffineIndexOperand(Value value,Region * region)429 static bool isValidAffineIndexOperand(Value value, Region *region) {
430 return isValidDim(value, region) || isValidSymbol(value, region);
431 }
432
433 /// Prints dimension and symbol list.
printDimAndSymbolList(Operation::operand_iterator begin,Operation::operand_iterator end,unsigned numDims,OpAsmPrinter & printer)434 static void printDimAndSymbolList(Operation::operand_iterator begin,
435 Operation::operand_iterator end,
436 unsigned numDims, OpAsmPrinter &printer) {
437 OperandRange operands(begin, end);
438 printer << '(' << operands.take_front(numDims) << ')';
439 if (operands.size() > numDims)
440 printer << '[' << operands.drop_front(numDims) << ']';
441 }
442
443 /// Parses dimension and symbol list and returns true if parsing failed.
parseDimAndSymbolList(OpAsmParser & parser,SmallVectorImpl<Value> & operands,unsigned & numDims)444 ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
445 SmallVectorImpl<Value> &operands,
446 unsigned &numDims) {
447 SmallVector<OpAsmParser::OperandType, 8> opInfos;
448 if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
449 return failure();
450 // Store number of dimensions for validation by caller.
451 numDims = opInfos.size();
452
453 // Parse the optional symbol operands.
454 auto indexTy = parser.getBuilder().getIndexType();
455 return failure(parser.parseOperandList(
456 opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
457 parser.resolveOperands(opInfos, indexTy, operands));
458 }
459
460 /// Utility function to verify that a set of operands are valid dimension and
461 /// symbol identifiers. The operands should be laid out such that the dimension
462 /// operands are before the symbol operands. This function returns failure if
463 /// there was an invalid operand. An operation is provided to emit any necessary
464 /// errors.
465 template <typename OpTy>
466 static LogicalResult
verifyDimAndSymbolIdentifiers(OpTy & op,Operation::operand_range operands,unsigned numDims)467 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
468 unsigned numDims) {
469 unsigned opIt = 0;
470 for (auto operand : operands) {
471 if (opIt++ < numDims) {
472 if (!isValidDim(operand, getAffineScope(op)))
473 return op.emitOpError("operand cannot be used as a dimension id");
474 } else if (!isValidSymbol(operand, getAffineScope(op))) {
475 return op.emitOpError("operand cannot be used as a symbol");
476 }
477 }
478 return success();
479 }
480
481 //===----------------------------------------------------------------------===//
482 // AffineApplyOp
483 //===----------------------------------------------------------------------===//
484
getAffineValueMap()485 AffineValueMap AffineApplyOp::getAffineValueMap() {
486 return AffineValueMap(getAffineMap(), getOperands(), getResult());
487 }
488
parseAffineApplyOp(OpAsmParser & parser,OperationState & result)489 static ParseResult parseAffineApplyOp(OpAsmParser &parser,
490 OperationState &result) {
491 auto &builder = parser.getBuilder();
492 auto indexTy = builder.getIndexType();
493
494 AffineMapAttr mapAttr;
495 unsigned numDims;
496 if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
497 parseDimAndSymbolList(parser, result.operands, numDims) ||
498 parser.parseOptionalAttrDict(result.attributes))
499 return failure();
500 auto map = mapAttr.getValue();
501
502 if (map.getNumDims() != numDims ||
503 numDims + map.getNumSymbols() != result.operands.size()) {
504 return parser.emitError(parser.getNameLoc(),
505 "dimension or symbol index mismatch");
506 }
507
508 result.types.append(map.getNumResults(), indexTy);
509 return success();
510 }
511
print(OpAsmPrinter & p,AffineApplyOp op)512 static void print(OpAsmPrinter &p, AffineApplyOp op) {
513 p << AffineApplyOp::getOperationName() << " " << op.mapAttr();
514 printDimAndSymbolList(op.operand_begin(), op.operand_end(),
515 op.getAffineMap().getNumDims(), p);
516 p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"map"});
517 }
518
verify(AffineApplyOp op)519 static LogicalResult verify(AffineApplyOp op) {
520 // Check input and output dimensions match.
521 auto map = op.map();
522
523 // Verify that operand count matches affine map dimension and symbol count.
524 if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols())
525 return op.emitOpError(
526 "operand count and affine map dimension and symbol count must match");
527
528 // Verify that the map only produces one result.
529 if (map.getNumResults() != 1)
530 return op.emitOpError("mapping must produce one value");
531
532 return success();
533 }
534
535 // The result of the affine apply operation can be used as a dimension id if all
536 // its operands are valid dimension ids.
isValidDim()537 bool AffineApplyOp::isValidDim() {
538 return llvm::all_of(getOperands(),
539 [](Value op) { return mlir::isValidDim(op); });
540 }
541
542 // The result of the affine apply operation can be used as a dimension id if all
543 // its operands are valid dimension ids with the parent operation of `region`
544 // defining the polyhedral scope for symbols.
isValidDim(Region * region)545 bool AffineApplyOp::isValidDim(Region *region) {
546 return llvm::all_of(getOperands(),
547 [&](Value op) { return ::isValidDim(op, region); });
548 }
549
550 // The result of the affine apply operation can be used as a symbol if all its
551 // operands are symbols.
isValidSymbol()552 bool AffineApplyOp::isValidSymbol() {
553 return llvm::all_of(getOperands(),
554 [](Value op) { return mlir::isValidSymbol(op); });
555 }
556
557 // The result of the affine apply operation can be used as a symbol in `region`
558 // if all its operands are symbols in `region`.
isValidSymbol(Region * region)559 bool AffineApplyOp::isValidSymbol(Region *region) {
560 return llvm::all_of(getOperands(), [&](Value operand) {
561 return mlir::isValidSymbol(operand, region);
562 });
563 }
564
fold(ArrayRef<Attribute> operands)565 OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
566 auto map = getAffineMap();
567
568 // Fold dims and symbols to existing values.
569 auto expr = map.getResult(0);
570 if (auto dim = expr.dyn_cast<AffineDimExpr>())
571 return getOperand(dim.getPosition());
572 if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
573 return getOperand(map.getNumDims() + sym.getPosition());
574
575 // Otherwise, default to folding the map.
576 SmallVector<Attribute, 1> result;
577 if (failed(map.constantFold(operands, result)))
578 return {};
579 return result[0];
580 }
581
/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
/// When `dimOrSymbolPosition >= dims.size()`,
/// AffineSymbolExpr@[pos - dims.size()] is replaced.
/// Mutate `map`,`dims` and `syms` in place as follows:
///   1. `dims` and `syms` are only appended to.
///   2. `map` dim and symbols are gradually shifted to higher positions.
///   3. Old `dim` and `sym` entries are replaced by nullptr
/// This avoids the need for any bookkeeping.
/// Returns failure, and leaves `map`, `dims` and `syms` untouched, if the
/// selected entry is null (already replaced) or is not produced by an
/// AffineApplyOp.
static LogicalResult replaceDimOrSym(AffineMap *map,
                                     unsigned dimOrSymbolPosition,
                                     SmallVectorImpl<Value> &dims,
                                     SmallVectorImpl<Value> &syms) {
  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
  unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                  : dimOrSymbolPosition - dims.size();
  Value &v = isDimReplacement ? dims[pos] : syms[pos];
  // Null entries were already replaced by an earlier call.
  if (!v)
    return failure();

  auto affineApply = v.getDefiningOp<AffineApplyOp>();
  if (!affineApply)
    return failure();

  // At this point we will perform a replacement of `v`, set the entry in `dim`
  // or `sym` to nullptr immediately.
  v = nullptr;

  // Compute the map, dims and symbols coming from the AffineApplyOp.
  AffineMap composeMap = affineApply.getAffineMap();
  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
  // Shift the apply's dims/symbols past the current entries so that they
  // refer to the operands appended at the end of this function.
  AffineExpr composeExpr =
      composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
  ValueRange composeDims =
      affineApply.getMapOperands().take_front(composeMap.getNumDims());
  ValueRange composeSyms =
      affineApply.getMapOperands().take_back(composeMap.getNumSymbols());

  // Perform the replacement and append the dims and symbols where relevant.
  // NOTE(review): the new map is built with the dim/symbol counts from
  // *before* the appends below, while `composeExpr` may reference the
  // appended positions. The final counts appear to be fixed up by the
  // caller's replaceDimsAndSymbols; confirm this ordering is intended.
  MLIRContext *ctx = map->getContext();
  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
                                          : getAffineSymbolExpr(pos, ctx);
  *map = map->replace(toReplace, composeExpr, dims.size(), syms.size());
  dims.append(composeDims.begin(), composeDims.end());
  syms.append(composeSyms.begin(), composeSyms.end());

  return success();
}
631
632 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
633 /// iteratively. Perform canonicalization of map and operands as well as
634 /// AffineMap simplification. `map` and `operands` are mutated in place.
composeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)635 static void composeAffineMapAndOperands(AffineMap *map,
636 SmallVectorImpl<Value> *operands) {
637 if (map->getNumResults() == 0) {
638 canonicalizeMapAndOperands(map, operands);
639 *map = simplifyAffineMap(*map);
640 return;
641 }
642
643 MLIRContext *ctx = map->getContext();
644 SmallVector<Value, 4> dims(operands->begin(),
645 operands->begin() + map->getNumDims());
646 SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
647 operands->end());
648
649 // Iterate over dims and symbols coming from AffineApplyOp and replace until
650 // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
651 // and `syms` can only increase by construction.
652 // The implementation uses a `while` loop to support the case of symbols
653 // that may be constructed from dims ;this may be overkill.
654 while (true) {
655 bool changed = false;
656 for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
657 if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
658 break;
659 if (!changed)
660 break;
661 }
662
663 // Clear operands so we can fill them anew.
664 operands->clear();
665
666 // At this point we may have introduced null operands, prune them out before
667 // canonicalizing map and operands.
668 unsigned nDims = 0, nSyms = 0;
669 SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
670 dimReplacements.reserve(dims.size());
671 symReplacements.reserve(syms.size());
672 for (auto *container : {&dims, &syms}) {
673 bool isDim = (container == &dims);
674 auto &repls = isDim ? dimReplacements : symReplacements;
675 for (auto en : llvm::enumerate(*container)) {
676 Value v = en.value();
677 if (!v) {
678 assert(isDim ? !map->isFunctionOfDim(en.index())
679 : !map->isFunctionOfSymbol(en.index()) &&
680 "map is function of unexpected expr@pos");
681 repls.push_back(getAffineConstantExpr(0, ctx));
682 continue;
683 }
684 repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
685 : getAffineSymbolExpr(nSyms++, ctx));
686 operands->push_back(v);
687 }
688 }
689 *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
690 nSyms);
691
692 // Canonicalize and simplify before returning.
693 canonicalizeMapAndOperands(map, operands);
694 *map = simplifyAffineMap(*map);
695 }
696
fullyComposeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)697 void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
698 SmallVectorImpl<Value> *operands) {
699 while (llvm::any_of(*operands, [](Value v) {
700 return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
701 })) {
702 composeAffineMapAndOperands(map, operands);
703 }
704 }
705
makeComposedAffineApply(OpBuilder & b,Location loc,AffineMap map,ValueRange operands)706 AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
707 AffineMap map,
708 ValueRange operands) {
709 AffineMap normalizedMap = map;
710 SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end());
711 composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
712 assert(normalizedMap);
713 return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
714 }
715
makeComposedAffineApply(OpBuilder & b,Location loc,AffineExpr e,ValueRange values)716 AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
717 AffineExpr e, ValueRange values) {
718 return makeComposedAffineApply(
719 b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}).front(),
720 values);
721 }
722
723 // A symbol may appear as a dim in affine.apply operations. This function
724 // canonicalizes dims that are valid symbols into actual symbols.
725 template <class MapOrSet>
canonicalizePromotedSymbols(MapOrSet * mapOrSet,SmallVectorImpl<Value> * operands)726 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
727 SmallVectorImpl<Value> *operands) {
728 if (!mapOrSet || operands->empty())
729 return;
730
731 assert(mapOrSet->getNumInputs() == operands->size() &&
732 "map/set inputs must match number of operands");
733
734 auto *context = mapOrSet->getContext();
735 SmallVector<Value, 8> resultOperands;
736 resultOperands.reserve(operands->size());
737 SmallVector<Value, 8> remappedSymbols;
738 remappedSymbols.reserve(operands->size());
739 unsigned nextDim = 0;
740 unsigned nextSym = 0;
741 unsigned oldNumSyms = mapOrSet->getNumSymbols();
742 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
743 for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
744 if (i < mapOrSet->getNumDims()) {
745 if (isValidSymbol((*operands)[i])) {
746 // This is a valid symbol that appears as a dim, canonicalize it.
747 dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
748 remappedSymbols.push_back((*operands)[i]);
749 } else {
750 dimRemapping[i] = getAffineDimExpr(nextDim++, context);
751 resultOperands.push_back((*operands)[i]);
752 }
753 } else {
754 resultOperands.push_back((*operands)[i]);
755 }
756 }
757
758 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
759 *operands = resultOperands;
760 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
761 oldNumSyms + nextSym);
762
763 assert(mapOrSet->getNumInputs() == operands->size() &&
764 "map/set inputs must match number of operands");
765 }
766
767 // Works for either an affine map or an integer set.
768 template <class MapOrSet>
canonicalizeMapOrSetAndOperands(MapOrSet * mapOrSet,SmallVectorImpl<Value> * operands)769 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
770 SmallVectorImpl<Value> *operands) {
771 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
772 "Argument must be either of AffineMap or IntegerSet type");
773
774 if (!mapOrSet || operands->empty())
775 return;
776
777 assert(mapOrSet->getNumInputs() == operands->size() &&
778 "map/set inputs must match number of operands");
779
780 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
781
782 // Check to see what dims are used.
783 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
784 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
785 mapOrSet->walkExprs([&](AffineExpr expr) {
786 if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
787 usedDims[dimExpr.getPosition()] = true;
788 else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
789 usedSyms[symExpr.getPosition()] = true;
790 });
791
792 auto *context = mapOrSet->getContext();
793
794 SmallVector<Value, 8> resultOperands;
795 resultOperands.reserve(operands->size());
796
797 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
798 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
799 unsigned nextDim = 0;
800 for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
801 if (usedDims[i]) {
802 // Remap dim positions for duplicate operands.
803 auto it = seenDims.find((*operands)[i]);
804 if (it == seenDims.end()) {
805 dimRemapping[i] = getAffineDimExpr(nextDim++, context);
806 resultOperands.push_back((*operands)[i]);
807 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
808 } else {
809 dimRemapping[i] = it->second;
810 }
811 }
812 }
813 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
814 SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
815 unsigned nextSym = 0;
816 for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
817 if (!usedSyms[i])
818 continue;
819 // Handle constant operands (only needed for symbolic operands since
820 // constant operands in dimensional positions would have already been
821 // promoted to symbolic positions above).
822 IntegerAttr operandCst;
823 if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
824 m_Constant(&operandCst))) {
825 symRemapping[i] =
826 getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
827 continue;
828 }
829 // Remap symbol positions for duplicate operands.
830 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
831 if (it == seenSymbols.end()) {
832 symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
833 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
834 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
835 symRemapping[i]));
836 } else {
837 symRemapping[i] = it->second;
838 }
839 }
840 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
841 nextDim, nextSym);
842 *operands = resultOperands;
843 }
844
canonicalizeMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)845 void mlir::canonicalizeMapAndOperands(AffineMap *map,
846 SmallVectorImpl<Value> *operands) {
847 canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
848 }
849
canonicalizeSetAndOperands(IntegerSet * set,SmallVectorImpl<Value> * operands)850 void mlir::canonicalizeSetAndOperands(IntegerSet *set,
851 SmallVectorImpl<Value> *operands) {
852 canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
853 }
854
855 namespace {
856 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
857 /// maps that supply results into them.
858 ///
859 template <typename AffineOpTy>
860 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
861 using OpRewritePattern<AffineOpTy>::OpRewritePattern;
862
863 /// Replace the affine op with another instance of it with the supplied
864 /// map and mapOperands.
865 void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
866 AffineMap map, ArrayRef<Value> mapOperands) const;
867
matchAndRewrite__anon4d71c4590d11::SimplifyAffineOp868 LogicalResult matchAndRewrite(AffineOpTy affineOp,
869 PatternRewriter &rewriter) const override {
870 static_assert(
871 llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
872 AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
873 AffineVectorStoreOp, AffineVectorLoadOp>::value,
874 "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
875 "expected");
876 auto map = affineOp.getAffineMap();
877 AffineMap oldMap = map;
878 auto oldOperands = affineOp.getMapOperands();
879 SmallVector<Value, 8> resultOperands(oldOperands);
880 composeAffineMapAndOperands(&map, &resultOperands);
881 canonicalizeMapAndOperands(&map, &resultOperands);
882 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
883 resultOperands.begin()))
884 return failure();
885
886 replaceAffineOp(rewriter, affineOp, map, resultOperands);
887 return success();
888 }
889 };
890
891 // Specialize the template to account for the different build signatures for
892 // affine load, store, and apply ops.
893 template <>
replaceAffineOp(PatternRewriter & rewriter,AffineLoadOp load,AffineMap map,ArrayRef<Value> mapOperands) const894 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
895 PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
896 ArrayRef<Value> mapOperands) const {
897 rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
898 mapOperands);
899 }
900 template <>
replaceAffineOp(PatternRewriter & rewriter,AffinePrefetchOp prefetch,AffineMap map,ArrayRef<Value> mapOperands) const901 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
902 PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
903 ArrayRef<Value> mapOperands) const {
904 rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
905 prefetch, prefetch.memref(), map, mapOperands, prefetch.localityHint(),
906 prefetch.isWrite(), prefetch.isDataCache());
907 }
908 template <>
replaceAffineOp(PatternRewriter & rewriter,AffineStoreOp store,AffineMap map,ArrayRef<Value> mapOperands) const909 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
910 PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
911 ArrayRef<Value> mapOperands) const {
912 rewriter.replaceOpWithNewOp<AffineStoreOp>(
913 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
914 }
915 template <>
replaceAffineOp(PatternRewriter & rewriter,AffineVectorLoadOp vectorload,AffineMap map,ArrayRef<Value> mapOperands) const916 void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
917 PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
918 ArrayRef<Value> mapOperands) const {
919 rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
920 vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
921 mapOperands);
922 }
923 template <>
replaceAffineOp(PatternRewriter & rewriter,AffineVectorStoreOp vectorstore,AffineMap map,ArrayRef<Value> mapOperands) const924 void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
925 PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
926 ArrayRef<Value> mapOperands) const {
927 rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
928 vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
929 mapOperands);
930 }
931
932 // Generic version for ops that don't have extra operands.
933 template <typename AffineOpTy>
replaceAffineOp(PatternRewriter & rewriter,AffineOpTy op,AffineMap map,ArrayRef<Value> mapOperands) const934 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
935 PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
936 ArrayRef<Value> mapOperands) const {
937 rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
938 }
939 } // end anonymous namespace.
940
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)941 void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
942 MLIRContext *context) {
943 results.add<SimplifyAffineOp<AffineApplyOp>>(context);
944 }
945
946 //===----------------------------------------------------------------------===//
947 // Common canonicalization pattern support logic
948 //===----------------------------------------------------------------------===//
949
950 /// This is a common class used for patterns of the form
951 /// "someop(memrefcast) -> someop". It folds the source of any memref.cast
952 /// into the root operation directly.
foldMemRefCast(Operation * op,Value ignore=nullptr)953 static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
954 bool folded = false;
955 for (OpOperand &operand : op->getOpOperands()) {
956 auto cast = operand.get().getDefiningOp<memref::CastOp>();
957 if (cast && operand.get() != ignore &&
958 !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
959 operand.set(cast.getOperand());
960 folded = true;
961 }
962 }
963 return success(folded);
964 }
965
966 //===----------------------------------------------------------------------===//
967 // AffineDmaStartOp
968 //===----------------------------------------------------------------------===//
969
970 // TODO: Check that map operands are loop IVs or symbols.
build(OpBuilder & builder,OperationState & result,Value srcMemRef,AffineMap srcMap,ValueRange srcIndices,Value destMemRef,AffineMap dstMap,ValueRange destIndices,Value tagMemRef,AffineMap tagMap,ValueRange tagIndices,Value numElements,Value stride,Value elementsPerStride)971 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
972 Value srcMemRef, AffineMap srcMap,
973 ValueRange srcIndices, Value destMemRef,
974 AffineMap dstMap, ValueRange destIndices,
975 Value tagMemRef, AffineMap tagMap,
976 ValueRange tagIndices, Value numElements,
977 Value stride, Value elementsPerStride) {
978 result.addOperands(srcMemRef);
979 result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap));
980 result.addOperands(srcIndices);
981 result.addOperands(destMemRef);
982 result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap));
983 result.addOperands(destIndices);
984 result.addOperands(tagMemRef);
985 result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
986 result.addOperands(tagIndices);
987 result.addOperands(numElements);
988 if (stride) {
989 result.addOperands({stride, elementsPerStride});
990 }
991 }
992
print(OpAsmPrinter & p)993 void AffineDmaStartOp::print(OpAsmPrinter &p) {
994 p << "affine.dma_start " << getSrcMemRef() << '[';
995 p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
996 p << "], " << getDstMemRef() << '[';
997 p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
998 p << "], " << getTagMemRef() << '[';
999 p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1000 p << "], " << getNumElements();
1001 if (isStrided()) {
1002 p << ", " << getStride();
1003 p << ", " << getNumElementsPerStride();
1004 }
1005 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1006 << getTagMemRefType();
1007 }
1008
1009 // Parse AffineDmaStartOp.
1010 // Ex:
1011 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1012 // %stride, %num_elt_per_stride
1013 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1014 //
parse(OpAsmParser & parser,OperationState & result)1015 ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
1016 OperationState &result) {
1017 OpAsmParser::OperandType srcMemRefInfo;
1018 AffineMapAttr srcMapAttr;
1019 SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
1020 OpAsmParser::OperandType dstMemRefInfo;
1021 AffineMapAttr dstMapAttr;
1022 SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
1023 OpAsmParser::OperandType tagMemRefInfo;
1024 AffineMapAttr tagMapAttr;
1025 SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
1026 OpAsmParser::OperandType numElementsInfo;
1027 SmallVector<OpAsmParser::OperandType, 2> strideInfo;
1028
1029 SmallVector<Type, 3> types;
1030 auto indexType = parser.getBuilder().getIndexType();
1031
1032 // Parse and resolve the following list of operands:
1033 // *) dst memref followed by its affine maps operands (in square brackets).
1034 // *) src memref followed by its affine map operands (in square brackets).
1035 // *) tag memref followed by its affine map operands (in square brackets).
1036 // *) number of elements transferred by DMA operation.
1037 if (parser.parseOperand(srcMemRefInfo) ||
1038 parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1039 getSrcMapAttrName(), result.attributes) ||
1040 parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1041 parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1042 getDstMapAttrName(), result.attributes) ||
1043 parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1044 parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1045 getTagMapAttrName(), result.attributes) ||
1046 parser.parseComma() || parser.parseOperand(numElementsInfo))
1047 return failure();
1048
1049 // Parse optional stride and elements per stride.
1050 if (parser.parseTrailingOperandList(strideInfo)) {
1051 return failure();
1052 }
1053 if (!strideInfo.empty() && strideInfo.size() != 2) {
1054 return parser.emitError(parser.getNameLoc(),
1055 "expected two stride related operands");
1056 }
1057 bool isStrided = strideInfo.size() == 2;
1058
1059 if (parser.parseColonTypeList(types))
1060 return failure();
1061
1062 if (types.size() != 3)
1063 return parser.emitError(parser.getNameLoc(), "expected three types");
1064
1065 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1066 parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1067 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1068 parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1069 parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1070 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1071 parser.resolveOperand(numElementsInfo, indexType, result.operands))
1072 return failure();
1073
1074 if (isStrided) {
1075 if (parser.resolveOperands(strideInfo, indexType, result.operands))
1076 return failure();
1077 }
1078
1079 // Check that src/dst/tag operand counts match their map.numInputs.
1080 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1081 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1082 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1083 return parser.emitError(parser.getNameLoc(),
1084 "memref operand count not equal to map.numInputs");
1085 return success();
1086 }
1087
verify()1088 LogicalResult AffineDmaStartOp::verify() {
1089 if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
1090 return emitOpError("expected DMA source to be of memref type");
1091 if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
1092 return emitOpError("expected DMA destination to be of memref type");
1093 if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
1094 return emitOpError("expected DMA tag to be of memref type");
1095
1096 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1097 getDstMap().getNumInputs() +
1098 getTagMap().getNumInputs();
1099 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1100 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1101 return emitOpError("incorrect number of operands");
1102 }
1103
1104 Region *scope = getAffineScope(*this);
1105 for (auto idx : getSrcIndices()) {
1106 if (!idx.getType().isIndex())
1107 return emitOpError("src index to dma_start must have 'index' type");
1108 if (!isValidAffineIndexOperand(idx, scope))
1109 return emitOpError("src index must be a dimension or symbol identifier");
1110 }
1111 for (auto idx : getDstIndices()) {
1112 if (!idx.getType().isIndex())
1113 return emitOpError("dst index to dma_start must have 'index' type");
1114 if (!isValidAffineIndexOperand(idx, scope))
1115 return emitOpError("dst index must be a dimension or symbol identifier");
1116 }
1117 for (auto idx : getTagIndices()) {
1118 if (!idx.getType().isIndex())
1119 return emitOpError("tag index to dma_start must have 'index' type");
1120 if (!isValidAffineIndexOperand(idx, scope))
1121 return emitOpError("tag index must be a dimension or symbol identifier");
1122 }
1123 return success();
1124 }
1125
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)1126 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1127 SmallVectorImpl<OpFoldResult> &results) {
1128 /// dma_start(memrefcast) -> dma_start
1129 return foldMemRefCast(*this);
1130 }
1131
1132 //===----------------------------------------------------------------------===//
1133 // AffineDmaWaitOp
1134 //===----------------------------------------------------------------------===//
1135
1136 // TODO: Check that map operands are loop IVs or symbols.
build(OpBuilder & builder,OperationState & result,Value tagMemRef,AffineMap tagMap,ValueRange tagIndices,Value numElements)1137 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1138 Value tagMemRef, AffineMap tagMap,
1139 ValueRange tagIndices, Value numElements) {
1140 result.addOperands(tagMemRef);
1141 result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
1142 result.addOperands(tagIndices);
1143 result.addOperands(numElements);
1144 }
1145
print(OpAsmPrinter & p)1146 void AffineDmaWaitOp::print(OpAsmPrinter &p) {
1147 p << "affine.dma_wait " << getTagMemRef() << '[';
1148 SmallVector<Value, 2> operands(getTagIndices());
1149 p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1150 p << "], ";
1151 p.printOperand(getNumElements());
1152 p << " : " << getTagMemRef().getType();
1153 }
1154
1155 // Parse AffineDmaWaitOp.
1156 // Eg:
1157 // affine.dma_wait %tag[%index], %num_elements
1158 // : memref<1 x i32, (d0) -> (d0), 4>
1159 //
parse(OpAsmParser & parser,OperationState & result)1160 ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
1161 OperationState &result) {
1162 OpAsmParser::OperandType tagMemRefInfo;
1163 AffineMapAttr tagMapAttr;
1164 SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
1165 Type type;
1166 auto indexType = parser.getBuilder().getIndexType();
1167 OpAsmParser::OperandType numElementsInfo;
1168
1169 // Parse tag memref, its map operands, and dma size.
1170 if (parser.parseOperand(tagMemRefInfo) ||
1171 parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1172 getTagMapAttrName(), result.attributes) ||
1173 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1174 parser.parseColonType(type) ||
1175 parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1176 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1177 parser.resolveOperand(numElementsInfo, indexType, result.operands))
1178 return failure();
1179
1180 if (!type.isa<MemRefType>())
1181 return parser.emitError(parser.getNameLoc(),
1182 "expected tag to be of memref type");
1183
1184 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1185 return parser.emitError(parser.getNameLoc(),
1186 "tag memref operand count != to map.numInputs");
1187 return success();
1188 }
1189
verify()1190 LogicalResult AffineDmaWaitOp::verify() {
1191 if (!getOperand(0).getType().isa<MemRefType>())
1192 return emitOpError("expected DMA tag to be of memref type");
1193 Region *scope = getAffineScope(*this);
1194 for (auto idx : getTagIndices()) {
1195 if (!idx.getType().isIndex())
1196 return emitOpError("index to dma_wait must have 'index' type");
1197 if (!isValidAffineIndexOperand(idx, scope))
1198 return emitOpError("index must be a dimension or symbol identifier");
1199 }
1200 return success();
1201 }
1202
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)1203 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1204 SmallVectorImpl<OpFoldResult> &results) {
1205 /// dma_wait(memrefcast) -> dma_wait
1206 return foldMemRefCast(*this);
1207 }
1208
1209 //===----------------------------------------------------------------------===//
1210 // AffineForOp
1211 //===----------------------------------------------------------------------===//
1212
1213 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1214 /// bodyBuilder are empty/null, we include default terminator op.
build(OpBuilder & builder,OperationState & result,ValueRange lbOperands,AffineMap lbMap,ValueRange ubOperands,AffineMap ubMap,int64_t step,ValueRange iterArgs,BodyBuilderFn bodyBuilder)1215 void AffineForOp::build(OpBuilder &builder, OperationState &result,
1216 ValueRange lbOperands, AffineMap lbMap,
1217 ValueRange ubOperands, AffineMap ubMap, int64_t step,
1218 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1219 assert(((!lbMap && lbOperands.empty()) ||
1220 lbOperands.size() == lbMap.getNumInputs()) &&
1221 "lower bound operand count does not match the affine map");
1222 assert(((!ubMap && ubOperands.empty()) ||
1223 ubOperands.size() == ubMap.getNumInputs()) &&
1224 "upper bound operand count does not match the affine map");
1225 assert(step > 0 && "step has to be a positive integer constant");
1226
1227 for (Value val : iterArgs)
1228 result.addTypes(val.getType());
1229
1230 // Add an attribute for the step.
1231 result.addAttribute(getStepAttrName(),
1232 builder.getIntegerAttr(builder.getIndexType(), step));
1233
1234 // Add the lower bound.
1235 result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap));
1236 result.addOperands(lbOperands);
1237
1238 // Add the upper bound.
1239 result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap));
1240 result.addOperands(ubOperands);
1241
1242 result.addOperands(iterArgs);
1243 // Create a region and a block for the body. The argument of the region is
1244 // the loop induction variable.
1245 Region *bodyRegion = result.addRegion();
1246 bodyRegion->push_back(new Block);
1247 Block &bodyBlock = bodyRegion->front();
1248 Value inductionVar = bodyBlock.addArgument(builder.getIndexType());
1249 for (Value val : iterArgs)
1250 bodyBlock.addArgument(val.getType());
1251
1252 // Create the default terminator if the builder is not provided and if the
1253 // iteration arguments are not provided. Otherwise, leave this to the caller
1254 // because we don't know which values to return from the loop.
1255 if (iterArgs.empty() && !bodyBuilder) {
1256 ensureTerminator(*bodyRegion, builder, result.location);
1257 } else if (bodyBuilder) {
1258 OpBuilder::InsertionGuard guard(builder);
1259 builder.setInsertionPointToStart(&bodyBlock);
1260 bodyBuilder(builder, result.location, inductionVar,
1261 bodyBlock.getArguments().drop_front());
1262 }
1263 }
1264
build(OpBuilder & builder,OperationState & result,int64_t lb,int64_t ub,int64_t step,ValueRange iterArgs,BodyBuilderFn bodyBuilder)1265 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1266 int64_t ub, int64_t step, ValueRange iterArgs,
1267 BodyBuilderFn bodyBuilder) {
1268 auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1269 auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1270 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1271 bodyBuilder);
1272 }
1273
verify(AffineForOp op)1274 static LogicalResult verify(AffineForOp op) {
1275 // Check that the body defines as single block argument for the induction
1276 // variable.
1277 auto *body = op.getBody();
1278 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1279 return op.emitOpError(
1280 "expected body to have a single index argument for the "
1281 "induction variable");
1282
1283 // Verify that the bound operands are valid dimension/symbols.
1284 /// Lower bound.
1285 if (op.getLowerBoundMap().getNumInputs() > 0)
1286 if (failed(
1287 verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
1288 op.getLowerBoundMap().getNumDims())))
1289 return failure();
1290 /// Upper bound.
1291 if (op.getUpperBoundMap().getNumInputs() > 0)
1292 if (failed(
1293 verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
1294 op.getUpperBoundMap().getNumDims())))
1295 return failure();
1296
1297 unsigned opNumResults = op.getNumResults();
1298 if (opNumResults == 0)
1299 return success();
1300
1301 // If ForOp defines values, check that the number and types of the defined
1302 // values match ForOp initial iter operands and backedge basic block
1303 // arguments.
1304 if (op.getNumIterOperands() != opNumResults)
1305 return op.emitOpError(
1306 "mismatch between the number of loop-carried values and results");
1307 if (op.getNumRegionIterArgs() != opNumResults)
1308 return op.emitOpError(
1309 "mismatch between the number of basic block args and results");
1310
1311 return success();
1312 }
1313
1314 /// Parse a for operation loop bounds.
parseBound(bool isLower,OperationState & result,OpAsmParser & p)1315 static ParseResult parseBound(bool isLower, OperationState &result,
1316 OpAsmParser &p) {
1317 // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1318 // the map has multiple results.
1319 bool failedToParsedMinMax =
1320 failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
1321
1322 auto &builder = p.getBuilder();
1323 auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
1324 : AffineForOp::getUpperBoundAttrName();
1325
1326 // Parse ssa-id as identity map.
1327 SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
1328 if (p.parseOperandList(boundOpInfos))
1329 return failure();
1330
1331 if (!boundOpInfos.empty()) {
1332 // Check that only one operand was parsed.
1333 if (boundOpInfos.size() > 1)
1334 return p.emitError(p.getNameLoc(),
1335 "expected only one loop bound operand");
1336
1337 // TODO: improve error message when SSA value is not of index type.
1338 // Currently it is 'use of value ... expects different type than prior uses'
1339 if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1340 result.operands))
1341 return failure();
1342
1343 // Create an identity map using symbol id. This representation is optimized
1344 // for storage. Analysis passes may expand it into a multi-dimensional map
1345 // if desired.
1346 AffineMap map = builder.getSymbolIdentityMap();
1347 result.addAttribute(boundAttrName, AffineMapAttr::get(map));
1348 return success();
1349 }
1350
1351 // Get the attribute location.
1352 llvm::SMLoc attrLoc = p.getCurrentLocation();
1353
1354 Attribute boundAttr;
1355 if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
1356 result.attributes))
1357 return failure();
1358
1359 // Parse full form - affine map followed by dim and symbol list.
1360 if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
1361 unsigned currentNumOperands = result.operands.size();
1362 unsigned numDims;
1363 if (parseDimAndSymbolList(p, result.operands, numDims))
1364 return failure();
1365
1366 auto map = affineMapAttr.getValue();
1367 if (map.getNumDims() != numDims)
1368 return p.emitError(
1369 p.getNameLoc(),
1370 "dim operand count and affine map dim count must match");
1371
1372 unsigned numDimAndSymbolOperands =
1373 result.operands.size() - currentNumOperands;
1374 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1375 return p.emitError(
1376 p.getNameLoc(),
1377 "symbol operand count and affine map symbol count must match");
1378
1379 // If the map has multiple results, make sure that we parsed the min/max
1380 // prefix.
1381 if (map.getNumResults() > 1 && failedToParsedMinMax) {
1382 if (isLower) {
1383 return p.emitError(attrLoc, "lower loop bound affine map with "
1384 "multiple results requires 'max' prefix");
1385 }
1386 return p.emitError(attrLoc, "upper loop bound affine map with multiple "
1387 "results requires 'min' prefix");
1388 }
1389 return success();
1390 }
1391
1392 // Parse custom assembly form.
1393 if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
1394 result.attributes.pop_back();
1395 result.addAttribute(
1396 boundAttrName,
1397 AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
1398 return success();
1399 }
1400
1401 return p.emitError(
1402 p.getNameLoc(),
1403 "expected valid affine map representation for loop bounds");
1404 }
1405
parseAffineForOp(OpAsmParser & parser,OperationState & result)1406 static ParseResult parseAffineForOp(OpAsmParser &parser,
1407 OperationState &result) {
1408 auto &builder = parser.getBuilder();
1409 OpAsmParser::OperandType inductionVariable;
1410 // Parse the induction variable followed by '='.
1411 if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
1412 return failure();
1413
1414 // Parse loop bounds.
1415 if (parseBound(/*isLower=*/true, result, parser) ||
1416 parser.parseKeyword("to", " between bounds") ||
1417 parseBound(/*isLower=*/false, result, parser))
1418 return failure();
1419
1420 // Parse the optional loop step, we default to 1 if one is not present.
1421 if (parser.parseOptionalKeyword("step")) {
1422 result.addAttribute(
1423 AffineForOp::getStepAttrName(),
1424 builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
1425 } else {
1426 llvm::SMLoc stepLoc = parser.getCurrentLocation();
1427 IntegerAttr stepAttr;
1428 if (parser.parseAttribute(stepAttr, builder.getIndexType(),
1429 AffineForOp::getStepAttrName().data(),
1430 result.attributes))
1431 return failure();
1432
1433 if (stepAttr.getValue().getSExtValue() < 0)
1434 return parser.emitError(
1435 stepLoc,
1436 "expected step to be representable as a positive signed integer");
1437 }
1438
1439 // Parse the optional initial iteration arguments.
1440 SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
1441 SmallVector<Type, 4> argTypes;
1442 regionArgs.push_back(inductionVariable);
1443
1444 if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
1445 // Parse assignment list and results type list.
1446 if (parser.parseAssignmentList(regionArgs, operands) ||
1447 parser.parseArrowTypeList(result.types))
1448 return failure();
1449 // Resolve input operands.
1450 for (auto operandType : llvm::zip(operands, result.types))
1451 if (parser.resolveOperand(std::get<0>(operandType),
1452 std::get<1>(operandType), result.operands))
1453 return failure();
1454 }
1455 // Induction variable.
1456 Type indexType = builder.getIndexType();
1457 argTypes.push_back(indexType);
1458 // Loop carried variables.
1459 argTypes.append(result.types.begin(), result.types.end());
1460 // Parse the body region.
1461 Region *body = result.addRegion();
1462 if (regionArgs.size() != argTypes.size())
1463 return parser.emitError(
1464 parser.getNameLoc(),
1465 "mismatch between the number of loop-carried values and results");
1466 if (parser.parseRegion(*body, regionArgs, argTypes))
1467 return failure();
1468
1469 AffineForOp::ensureTerminator(*body, builder, result.location);
1470
1471 // Parse the optional attribute list.
1472 return parser.parseOptionalAttrDict(result.attributes);
1473 }
1474
printBound(AffineMapAttr boundMap,Operation::operand_range boundOperands,const char * prefix,OpAsmPrinter & p)1475 static void printBound(AffineMapAttr boundMap,
1476 Operation::operand_range boundOperands,
1477 const char *prefix, OpAsmPrinter &p) {
1478 AffineMap map = boundMap.getValue();
1479
1480 // Check if this bound should be printed using custom assembly form.
1481 // The decision to restrict printing custom assembly form to trivial cases
1482 // comes from the will to roundtrip MLIR binary -> text -> binary in a
1483 // lossless way.
1484 // Therefore, custom assembly form parsing and printing is only supported for
1485 // zero-operand constant maps and single symbol operand identity maps.
1486 if (map.getNumResults() == 1) {
1487 AffineExpr expr = map.getResult(0);
1488
1489 // Print constant bound.
1490 if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
1491 if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
1492 p << constExpr.getValue();
1493 return;
1494 }
1495 }
1496
1497 // Print bound that consists of a single SSA symbol if the map is over a
1498 // single symbol.
1499 if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
1500 if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
1501 p.printOperand(*boundOperands.begin());
1502 return;
1503 }
1504 }
1505 } else {
1506 // Map has multiple results. Print 'min' or 'max' prefix.
1507 p << prefix << ' ';
1508 }
1509
1510 // Print the map and its operands.
1511 p << boundMap;
1512 printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
1513 map.getNumDims(), p);
1514 }
1515
getNumIterOperands()1516 unsigned AffineForOp::getNumIterOperands() {
1517 AffineMap lbMap = getLowerBoundMapAttr().getValue();
1518 AffineMap ubMap = getUpperBoundMapAttr().getValue();
1519
1520 return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
1521 }
1522
print(OpAsmPrinter & p,AffineForOp op)1523 static void print(OpAsmPrinter &p, AffineForOp op) {
1524 p << op.getOperationName() << ' ';
1525 p.printOperand(op.getBody()->getArgument(0));
1526 p << " = ";
1527 printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
1528 p << " to ";
1529 printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);
1530
1531 if (op.getStep() != 1)
1532 p << " step " << op.getStep();
1533
1534 bool printBlockTerminators = false;
1535 if (op.getNumIterOperands() > 0) {
1536 p << " iter_args(";
1537 auto regionArgs = op.getRegionIterArgs();
1538 auto operands = op.getIterOperands();
1539
1540 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
1541 p << std::get<0>(it) << " = " << std::get<1>(it);
1542 });
1543 p << ") -> (" << op.getResultTypes() << ")";
1544 printBlockTerminators = true;
1545 }
1546
1547 p.printRegion(op.region(),
1548 /*printEntryBlockArgs=*/false, printBlockTerminators);
1549 p.printOptionalAttrDict(op->getAttrs(),
1550 /*elidedAttrs=*/{op.getLowerBoundAttrName(),
1551 op.getUpperBoundAttrName(),
1552 op.getStepAttrName()});
1553 }
1554
1555 /// Fold the constant bounds of a loop.
foldLoopBounds(AffineForOp forOp)1556 static LogicalResult foldLoopBounds(AffineForOp forOp) {
1557 auto foldLowerOrUpperBound = [&forOp](bool lower) {
1558 // Check to see if each of the operands is the result of a constant. If
1559 // so, get the value. If not, ignore it.
1560 SmallVector<Attribute, 8> operandConstants;
1561 auto boundOperands =
1562 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
1563 for (auto operand : boundOperands) {
1564 Attribute operandCst;
1565 matchPattern(operand, m_Constant(&operandCst));
1566 operandConstants.push_back(operandCst);
1567 }
1568
1569 AffineMap boundMap =
1570 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
1571 assert(boundMap.getNumResults() >= 1 &&
1572 "bound maps should have at least one result");
1573 SmallVector<Attribute, 4> foldedResults;
1574 if (failed(boundMap.constantFold(operandConstants, foldedResults)))
1575 return failure();
1576
1577 // Compute the max or min as applicable over the results.
1578 assert(!foldedResults.empty() && "bounds should have at least one result");
1579 auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
1580 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
1581 auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
1582 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
1583 : llvm::APIntOps::smin(maxOrMin, foldedResult);
1584 }
1585 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
1586 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
1587 return success();
1588 };
1589
1590 // Try to fold the lower bound.
1591 bool folded = false;
1592 if (!forOp.hasConstantLowerBound())
1593 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
1594
1595 // Try to fold the upper bound.
1596 if (!forOp.hasConstantUpperBound())
1597 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
1598 return success(folded);
1599 }
1600
1601 /// Canonicalize the bounds of the given loop.
canonicalizeLoopBounds(AffineForOp forOp)1602 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
1603 SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
1604 SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
1605
1606 auto lbMap = forOp.getLowerBoundMap();
1607 auto ubMap = forOp.getUpperBoundMap();
1608 auto prevLbMap = lbMap;
1609 auto prevUbMap = ubMap;
1610
1611 composeAffineMapAndOperands(&lbMap, &lbOperands);
1612 canonicalizeMapAndOperands(&lbMap, &lbOperands);
1613 lbMap = removeDuplicateExprs(lbMap);
1614
1615 composeAffineMapAndOperands(&ubMap, &ubOperands);
1616 canonicalizeMapAndOperands(&ubMap, &ubOperands);
1617 ubMap = removeDuplicateExprs(ubMap);
1618
1619 // Any canonicalization change always leads to updated map(s).
1620 if (lbMap == prevLbMap && ubMap == prevUbMap)
1621 return failure();
1622
1623 if (lbMap != prevLbMap)
1624 forOp.setLowerBound(lbOperands, lbMap);
1625 if (ubMap != prevUbMap)
1626 forOp.setUpperBound(ubOperands, ubMap);
1627 return success();
1628 }
1629
1630 namespace {
1631 /// This is a pattern to fold trivially empty loop bodies.
1632 /// TODO: This should be moved into the folding hook.
1633 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
1634 using OpRewritePattern<AffineForOp>::OpRewritePattern;
1635
matchAndRewrite__anon4d71c4591011::AffineForEmptyLoopFolder1636 LogicalResult matchAndRewrite(AffineForOp forOp,
1637 PatternRewriter &rewriter) const override {
1638 // Check that the body only contains a yield.
1639 if (!llvm::hasSingleElement(*forOp.getBody()))
1640 return failure();
1641 // The initial values of the iteration arguments would be the op's results.
1642 rewriter.replaceOp(forOp, forOp.getIterOperands());
1643 return success();
1644 }
1645 };
1646 } // end anonymous namespace
1647
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1648 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1649 MLIRContext *context) {
1650 results.add<AffineForEmptyLoopFolder>(context);
1651 }
1652
1653 /// Returns true if the affine.for has zero iterations in trivial cases.
hasTrivialZeroTripCount(AffineForOp op)1654 static bool hasTrivialZeroTripCount(AffineForOp op) {
1655 if (!op.hasConstantBounds())
1656 return false;
1657 int64_t lb = op.getConstantLowerBound();
1658 int64_t ub = op.getConstantUpperBound();
1659 return ub - lb <= 0;
1660 }
1661
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1662 LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
1663 SmallVectorImpl<OpFoldResult> &results) {
1664 bool folded = succeeded(foldLoopBounds(*this));
1665 folded |= succeeded(canonicalizeLoopBounds(*this));
1666 if (hasTrivialZeroTripCount(*this)) {
1667 // The initial values of the loop-carried variables (iter_args) are the
1668 // results of the op.
1669 results.assign(getIterOperands().begin(), getIterOperands().end());
1670 folded = true;
1671 }
1672 return success(folded);
1673 }
1674
getLowerBound()1675 AffineBound AffineForOp::getLowerBound() {
1676 auto lbMap = getLowerBoundMap();
1677 return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
1678 }
1679
getUpperBound()1680 AffineBound AffineForOp::getUpperBound() {
1681 auto lbMap = getLowerBoundMap();
1682 auto ubMap = getUpperBoundMap();
1683 return AffineBound(AffineForOp(*this), lbMap.getNumInputs(),
1684 lbMap.getNumInputs() + ubMap.getNumInputs(), ubMap);
1685 }
1686
setLowerBound(ValueRange lbOperands,AffineMap map)1687 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
1688 assert(lbOperands.size() == map.getNumInputs());
1689 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1690
1691 SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());
1692
1693 auto ubOperands = getUpperBoundOperands();
1694 newOperands.append(ubOperands.begin(), ubOperands.end());
1695 auto iterOperands = getIterOperands();
1696 newOperands.append(iterOperands.begin(), iterOperands.end());
1697 (*this)->setOperands(newOperands);
1698
1699 (*this)->setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
1700 }
1701
setUpperBound(ValueRange ubOperands,AffineMap map)1702 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
1703 assert(ubOperands.size() == map.getNumInputs());
1704 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1705
1706 SmallVector<Value, 4> newOperands(getLowerBoundOperands());
1707 newOperands.append(ubOperands.begin(), ubOperands.end());
1708 auto iterOperands = getIterOperands();
1709 newOperands.append(iterOperands.begin(), iterOperands.end());
1710 (*this)->setOperands(newOperands);
1711
1712 (*this)->setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
1713 }
1714
setLowerBoundMap(AffineMap map)1715 void AffineForOp::setLowerBoundMap(AffineMap map) {
1716 auto lbMap = getLowerBoundMap();
1717 assert(lbMap.getNumDims() == map.getNumDims() &&
1718 lbMap.getNumSymbols() == map.getNumSymbols());
1719 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1720 (void)lbMap;
1721 (*this)->setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
1722 }
1723
setUpperBoundMap(AffineMap map)1724 void AffineForOp::setUpperBoundMap(AffineMap map) {
1725 auto ubMap = getUpperBoundMap();
1726 assert(ubMap.getNumDims() == map.getNumDims() &&
1727 ubMap.getNumSymbols() == map.getNumSymbols());
1728 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1729 (void)ubMap;
1730 (*this)->setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
1731 }
1732
hasConstantLowerBound()1733 bool AffineForOp::hasConstantLowerBound() {
1734 return getLowerBoundMap().isSingleConstant();
1735 }
1736
hasConstantUpperBound()1737 bool AffineForOp::hasConstantUpperBound() {
1738 return getUpperBoundMap().isSingleConstant();
1739 }
1740
getConstantLowerBound()1741 int64_t AffineForOp::getConstantLowerBound() {
1742 return getLowerBoundMap().getSingleConstantResult();
1743 }
1744
getConstantUpperBound()1745 int64_t AffineForOp::getConstantUpperBound() {
1746 return getUpperBoundMap().getSingleConstantResult();
1747 }
1748
setConstantLowerBound(int64_t value)1749 void AffineForOp::setConstantLowerBound(int64_t value) {
1750 setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
1751 }
1752
setConstantUpperBound(int64_t value)1753 void AffineForOp::setConstantUpperBound(int64_t value) {
1754 setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
1755 }
1756
getLowerBoundOperands()1757 AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
1758 return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
1759 }
1760
getUpperBoundOperands()1761 AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
1762 return {operand_begin() + getLowerBoundMap().getNumInputs(),
1763 operand_begin() + getLowerBoundMap().getNumInputs() +
1764 getUpperBoundMap().getNumInputs()};
1765 }
1766
matchingBoundOperandList()1767 bool AffineForOp::matchingBoundOperandList() {
1768 auto lbMap = getLowerBoundMap();
1769 auto ubMap = getUpperBoundMap();
1770 if (lbMap.getNumDims() != ubMap.getNumDims() ||
1771 lbMap.getNumSymbols() != ubMap.getNumSymbols())
1772 return false;
1773
1774 unsigned numOperands = lbMap.getNumInputs();
1775 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
1776 // Compare Value 's.
1777 if (getOperand(i) != getOperand(numOperands + i))
1778 return false;
1779 }
1780 return true;
1781 }
1782
getLoopBody()1783 Region &AffineForOp::getLoopBody() { return region(); }
1784
isDefinedOutsideOfLoop(Value value)1785 bool AffineForOp::isDefinedOutsideOfLoop(Value value) {
1786 return !region().isAncestor(value.getParentRegion());
1787 }
1788
moveOutOfLoop(ArrayRef<Operation * > ops)1789 LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
1790 for (auto *op : ops)
1791 op->moveBefore(*this);
1792 return success();
1793 }
1794
1795 /// Returns true if the provided value is the induction variable of a
1796 /// AffineForOp.
isForInductionVar(Value val)1797 bool mlir::isForInductionVar(Value val) {
1798 return getForInductionVarOwner(val) != AffineForOp();
1799 }
1800
1801 /// Returns the loop parent of an induction variable. If the provided value is
1802 /// not an induction variable, then return nullptr.
getForInductionVarOwner(Value val)1803 AffineForOp mlir::getForInductionVarOwner(Value val) {
1804 auto ivArg = val.dyn_cast<BlockArgument>();
1805 if (!ivArg || !ivArg.getOwner())
1806 return AffineForOp();
1807 auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
1808 return dyn_cast<AffineForOp>(containingInst);
1809 }
1810
1811 /// Extracts the induction variables from a list of AffineForOps and returns
1812 /// them.
extractForInductionVars(ArrayRef<AffineForOp> forInsts,SmallVectorImpl<Value> * ivs)1813 void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
1814 SmallVectorImpl<Value> *ivs) {
1815 ivs->reserve(forInsts.size());
1816 for (auto forInst : forInsts)
1817 ivs->push_back(forInst.getInductionVar());
1818 }
1819
1820 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
1821 /// operations.
1822 template <typename BoundListTy, typename LoopCreatorTy>
buildAffineLoopNestImpl(OpBuilder & builder,Location loc,BoundListTy lbs,BoundListTy ubs,ArrayRef<int64_t> steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn,LoopCreatorTy && loopCreatorFn)1823 static void buildAffineLoopNestImpl(
1824 OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
1825 ArrayRef<int64_t> steps,
1826 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
1827 LoopCreatorTy &&loopCreatorFn) {
1828 assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
1829 assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
1830
1831 // If there are no loops to be constructed, construct the body anyway.
1832 OpBuilder::InsertionGuard guard(builder);
1833 if (lbs.empty()) {
1834 if (bodyBuilderFn)
1835 bodyBuilderFn(builder, loc, ValueRange());
1836 return;
1837 }
1838
1839 // Create the loops iteratively and store the induction variables.
1840 SmallVector<Value, 4> ivs;
1841 ivs.reserve(lbs.size());
1842 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
1843 // Callback for creating the loop body, always creates the terminator.
1844 auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
1845 ValueRange iterArgs) {
1846 ivs.push_back(iv);
1847 // In the innermost loop, call the body builder.
1848 if (i == e - 1 && bodyBuilderFn) {
1849 OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
1850 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
1851 }
1852 nestedBuilder.create<AffineYieldOp>(nestedLoc);
1853 };
1854
1855 // Delegate actual loop creation to the callback in order to dispatch
1856 // between constant- and variable-bound loops.
1857 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
1858 builder.setInsertionPointToStart(loop.getBody());
1859 }
1860 }
1861
1862 /// Creates an affine loop from the bounds known to be constants.
1863 static AffineForOp
buildAffineLoopFromConstants(OpBuilder & builder,Location loc,int64_t lb,int64_t ub,int64_t step,AffineForOp::BodyBuilderFn bodyBuilderFn)1864 buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
1865 int64_t ub, int64_t step,
1866 AffineForOp::BodyBuilderFn bodyBuilderFn) {
1867 return builder.create<AffineForOp>(loc, lb, ub, step, /*iterArgs=*/llvm::None,
1868 bodyBuilderFn);
1869 }
1870
1871 /// Creates an affine loop from the bounds that may or may not be constants.
1872 static AffineForOp
buildAffineLoopFromValues(OpBuilder & builder,Location loc,Value lb,Value ub,int64_t step,AffineForOp::BodyBuilderFn bodyBuilderFn)1873 buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
1874 int64_t step,
1875 AffineForOp::BodyBuilderFn bodyBuilderFn) {
1876 auto lbConst = lb.getDefiningOp<ConstantIndexOp>();
1877 auto ubConst = ub.getDefiningOp<ConstantIndexOp>();
1878 if (lbConst && ubConst)
1879 return buildAffineLoopFromConstants(builder, loc, lbConst.getValue(),
1880 ubConst.getValue(), step,
1881 bodyBuilderFn);
1882 return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
1883 builder.getDimIdentityMap(), step,
1884 /*iterArgs=*/llvm::None, bodyBuilderFn);
1885 }
1886
buildAffineLoopNest(OpBuilder & builder,Location loc,ArrayRef<int64_t> lbs,ArrayRef<int64_t> ubs,ArrayRef<int64_t> steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn)1887 void mlir::buildAffineLoopNest(
1888 OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
1889 ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
1890 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1891 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
1892 buildAffineLoopFromConstants);
1893 }
1894
buildAffineLoopNest(OpBuilder & builder,Location loc,ValueRange lbs,ValueRange ubs,ArrayRef<int64_t> steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn)1895 void mlir::buildAffineLoopNest(
1896 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
1897 ArrayRef<int64_t> steps,
1898 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1899 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
1900 buildAffineLoopFromValues);
1901 }
1902
replaceForOpWithNewYields(OpBuilder & b,AffineForOp loop,ValueRange newIterOperands,ValueRange newYieldedValues,ValueRange newIterArgs,bool replaceLoopResults)1903 AffineForOp mlir::replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop,
1904 ValueRange newIterOperands,
1905 ValueRange newYieldedValues,
1906 ValueRange newIterArgs,
1907 bool replaceLoopResults) {
1908 assert(newIterOperands.size() == newYieldedValues.size() &&
1909 "newIterOperands must be of the same size as newYieldedValues");
1910 // Create a new loop before the existing one, with the extra operands.
1911 OpBuilder::InsertionGuard g(b);
1912 b.setInsertionPoint(loop);
1913 auto operands = llvm::to_vector<4>(loop.getIterOperands());
1914 operands.append(newIterOperands.begin(), newIterOperands.end());
1915 SmallVector<Value, 4> lbOperands(loop.getLowerBoundOperands());
1916 SmallVector<Value, 4> ubOperands(loop.getUpperBoundOperands());
1917 SmallVector<Value, 4> steps(loop.getStep());
1918 auto lbMap = loop.getLowerBoundMap();
1919 auto ubMap = loop.getUpperBoundMap();
1920 AffineForOp newLoop =
1921 b.create<AffineForOp>(loop.getLoc(), lbOperands, lbMap, ubOperands, ubMap,
1922 loop.getStep(), operands);
1923 // Take the body of the original parent loop.
1924 newLoop.getLoopBody().takeBody(loop.getLoopBody());
1925 for (Value val : newIterArgs)
1926 newLoop.getLoopBody().addArgument(val.getType());
1927
1928 // Update yield operation with new values to be added.
1929 if (!newYieldedValues.empty()) {
1930 auto yield = cast<AffineYieldOp>(newLoop.getBody()->getTerminator());
1931 b.setInsertionPoint(yield);
1932 auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
1933 yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
1934 b.create<AffineYieldOp>(yield.getLoc(), yieldOperands);
1935 yield.erase();
1936 }
1937 if (replaceLoopResults) {
1938 for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1939 loop.getNumResults()))) {
1940 std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
1941 }
1942 }
1943 return newLoop;
1944 }
1945
1946 //===----------------------------------------------------------------------===//
1947 // AffineIfOp
1948 //===----------------------------------------------------------------------===//
1949
1950 namespace {
1951 /// Remove else blocks that have nothing other than a zero value yield.
1952 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
1953 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
1954
matchAndRewrite__anon4d71c4591211::SimplifyDeadElse1955 LogicalResult matchAndRewrite(AffineIfOp ifOp,
1956 PatternRewriter &rewriter) const override {
1957 if (ifOp.elseRegion().empty() ||
1958 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
1959 return failure();
1960
1961 rewriter.startRootUpdate(ifOp);
1962 rewriter.eraseBlock(ifOp.getElseBlock());
1963 rewriter.finalizeRootUpdate(ifOp);
1964 return success();
1965 }
1966 };
1967
1968 /// Removes affine.if cond if the condition is always true or false in certain
1969 /// trivial cases. Promotes the then/else block in the parent operation block.
1970 struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
1971 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
1972
matchAndRewrite__anon4d71c4591211::AlwaysTrueOrFalseIf1973 LogicalResult matchAndRewrite(AffineIfOp op,
1974 PatternRewriter &rewriter) const override {
1975
1976 auto isTriviallyFalse = [](IntegerSet iSet) {
1977 return iSet.isEmptyIntegerSet();
1978 };
1979
1980 auto isTriviallyTrue = [](IntegerSet iSet) {
1981 return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
1982 iSet.getConstraint(0) == 0);
1983 };
1984
1985 IntegerSet affineIfConditions = op.getIntegerSet();
1986 Block *blockToMove;
1987 if (isTriviallyFalse(affineIfConditions)) {
1988 // The absence, or equivalently, the emptiness of the else region need not
1989 // be checked when affine.if is returning results because if an affine.if
1990 // operation is returning results, it always has a non-empty else region.
1991 if (op.getNumResults() == 0 && !op.hasElse()) {
1992 // If the else region is absent, or equivalently, empty, remove the
1993 // affine.if operation (which is not returning any results).
1994 rewriter.eraseOp(op);
1995 return success();
1996 }
1997 blockToMove = op.getElseBlock();
1998 } else if (isTriviallyTrue(affineIfConditions)) {
1999 blockToMove = op.getThenBlock();
2000 } else {
2001 return failure();
2002 }
2003 Operation *blockToMoveTerminator = blockToMove->getTerminator();
2004 // Promote the "blockToMove" block to the parent operation block between the
2005 // prologue and epilogue of "op".
2006 rewriter.mergeBlockBefore(blockToMove, op);
2007 // Replace the "op" operation with the operands of the
2008 // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
2009 // the affine.yield operation present in the "blockToMove" block. It has no
2010 // operands when affine.if is not returning results and therefore, in that
2011 // case, replaceOp just erases "op". When affine.if is not returning
2012 // results, the affine.yield operation can be omitted. It gets inserted
2013 // implicitly.
2014 rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
2015 // Erase the "blockToMoveTerminator" operation since it is now in the parent
2016 // operation block, which already has its own terminator.
2017 rewriter.eraseOp(blockToMoveTerminator);
2018 return success();
2019 }
2020 };
2021 } // end anonymous namespace.
2022
verify(AffineIfOp op)2023 static LogicalResult verify(AffineIfOp op) {
2024 // Verify that we have a condition attribute.
2025 auto conditionAttr =
2026 op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
2027 if (!conditionAttr)
2028 return op.emitOpError(
2029 "requires an integer set attribute named 'condition'");
2030
2031 // Verify that there are enough operands for the condition.
2032 IntegerSet condition = conditionAttr.getValue();
2033 if (op.getNumOperands() != condition.getNumInputs())
2034 return op.emitOpError(
2035 "operand count and condition integer set dimension and "
2036 "symbol count must match");
2037
2038 // Verify that the operands are valid dimension/symbols.
2039 if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(),
2040 condition.getNumDims())))
2041 return failure();
2042
2043 return success();
2044 }
2045
parseAffineIfOp(OpAsmParser & parser,OperationState & result)2046 static ParseResult parseAffineIfOp(OpAsmParser &parser,
2047 OperationState &result) {
2048 // Parse the condition attribute set.
2049 IntegerSetAttr conditionAttr;
2050 unsigned numDims;
2051 if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(),
2052 result.attributes) ||
2053 parseDimAndSymbolList(parser, result.operands, numDims))
2054 return failure();
2055
2056 // Verify the condition operands.
2057 auto set = conditionAttr.getValue();
2058 if (set.getNumDims() != numDims)
2059 return parser.emitError(
2060 parser.getNameLoc(),
2061 "dim operand count and integer set dim count must match");
2062 if (numDims + set.getNumSymbols() != result.operands.size())
2063 return parser.emitError(
2064 parser.getNameLoc(),
2065 "symbol operand count and integer set symbol count must match");
2066
2067 if (parser.parseOptionalArrowTypeList(result.types))
2068 return failure();
2069
2070 // Create the regions for 'then' and 'else'. The latter must be created even
2071 // if it remains empty for the validity of the operation.
2072 result.regions.reserve(2);
2073 Region *thenRegion = result.addRegion();
2074 Region *elseRegion = result.addRegion();
2075
2076 // Parse the 'then' region.
2077 if (parser.parseRegion(*thenRegion, {}, {}))
2078 return failure();
2079 AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
2080 result.location);
2081
2082 // If we find an 'else' keyword then parse the 'else' region.
2083 if (!parser.parseOptionalKeyword("else")) {
2084 if (parser.parseRegion(*elseRegion, {}, {}))
2085 return failure();
2086 AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
2087 result.location);
2088 }
2089
2090 // Parse the optional attribute list.
2091 if (parser.parseOptionalAttrDict(result.attributes))
2092 return failure();
2093
2094 return success();
2095 }
2096
print(OpAsmPrinter & p,AffineIfOp op)2097 static void print(OpAsmPrinter &p, AffineIfOp op) {
2098 auto conditionAttr =
2099 op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
2100 p << "affine.if " << conditionAttr;
2101 printDimAndSymbolList(op.operand_begin(), op.operand_end(),
2102 conditionAttr.getValue().getNumDims(), p);
2103 p.printOptionalArrowTypeList(op.getResultTypes());
2104 p.printRegion(op.thenRegion(),
2105 /*printEntryBlockArgs=*/false,
2106 /*printBlockTerminators=*/op.getNumResults());
2107
2108 // Print the 'else' regions if it has any blocks.
2109 auto &elseRegion = op.elseRegion();
2110 if (!elseRegion.empty()) {
2111 p << " else";
2112 p.printRegion(elseRegion,
2113 /*printEntryBlockArgs=*/false,
2114 /*printBlockTerminators=*/op.getNumResults());
2115 }
2116
2117 // Print the attribute list.
2118 p.printOptionalAttrDict(op->getAttrs(),
2119 /*elidedAttrs=*/op.getConditionAttrName());
2120 }
2121
getIntegerSet()2122 IntegerSet AffineIfOp::getIntegerSet() {
2123 return (*this)
2124 ->getAttrOfType<IntegerSetAttr>(getConditionAttrName())
2125 .getValue();
2126 }
2127
setIntegerSet(IntegerSet newSet)2128 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2129 (*this)->setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
2130 }
2131
setConditional(IntegerSet set,ValueRange operands)2132 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
2133 setIntegerSet(set);
2134 (*this)->setOperands(operands);
2135 }
2136
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,IntegerSet set,ValueRange args,bool withElseRegion)2137 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2138 TypeRange resultTypes, IntegerSet set, ValueRange args,
2139 bool withElseRegion) {
2140 assert(resultTypes.empty() || withElseRegion);
2141 result.addTypes(resultTypes);
2142 result.addOperands(args);
2143 result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set));
2144
2145 Region *thenRegion = result.addRegion();
2146 thenRegion->push_back(new Block());
2147 if (resultTypes.empty())
2148 AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2149
2150 Region *elseRegion = result.addRegion();
2151 if (withElseRegion) {
2152 elseRegion->push_back(new Block());
2153 if (resultTypes.empty())
2154 AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2155 }
2156 }
2157
build(OpBuilder & builder,OperationState & result,IntegerSet set,ValueRange args,bool withElseRegion)2158 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2159 IntegerSet set, ValueRange args, bool withElseRegion) {
2160 AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
2161 withElseRegion);
2162 }
2163
2164 /// Canonicalize an affine if op's conditional (integer set + operands).
fold(ArrayRef<Attribute>,SmallVectorImpl<OpFoldResult> &)2165 LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
2166 SmallVectorImpl<OpFoldResult> &) {
2167 auto set = getIntegerSet();
2168 SmallVector<Value, 4> operands(getOperands());
2169 canonicalizeSetAndOperands(&set, &operands);
2170
2171 // Any canonicalization change always leads to either a reduction in the
2172 // number of operands or a change in the number of symbolic operands
2173 // (promotion of dims to symbols).
2174 if (operands.size() < getIntegerSet().getNumInputs() ||
2175 set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
2176 setConditional(set, operands);
2177 return success();
2178 }
2179
2180 return failure();
2181 }
2182
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2183 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2184 MLIRContext *context) {
2185 results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2186 }
2187
2188 //===----------------------------------------------------------------------===//
2189 // AffineLoadOp
2190 //===----------------------------------------------------------------------===//
2191
build(OpBuilder & builder,OperationState & result,AffineMap map,ValueRange operands)2192 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2193 AffineMap map, ValueRange operands) {
2194 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2195 result.addOperands(operands);
2196 if (map)
2197 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2198 auto memrefType = operands[0].getType().cast<MemRefType>();
2199 result.types.push_back(memrefType.getElementType());
2200 }
2201
build(OpBuilder & builder,OperationState & result,Value memref,AffineMap map,ValueRange mapOperands)2202 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2203 Value memref, AffineMap map, ValueRange mapOperands) {
2204 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2205 result.addOperands(memref);
2206 result.addOperands(mapOperands);
2207 auto memrefType = memref.getType().cast<MemRefType>();
2208 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2209 result.types.push_back(memrefType.getElementType());
2210 }
2211
build(OpBuilder & builder,OperationState & result,Value memref,ValueRange indices)2212 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2213 Value memref, ValueRange indices) {
2214 auto memrefType = memref.getType().cast<MemRefType>();
2215 int64_t rank = memrefType.getRank();
2216 // Create identity map for memrefs with at least one dimension or () -> ()
2217 // for zero-dimensional memrefs.
2218 auto map =
2219 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2220 build(builder, result, memref, map, indices);
2221 }
2222
parseAffineLoadOp(OpAsmParser & parser,OperationState & result)2223 static ParseResult parseAffineLoadOp(OpAsmParser &parser,
2224 OperationState &result) {
2225 auto &builder = parser.getBuilder();
2226 auto indexTy = builder.getIndexType();
2227
2228 MemRefType type;
2229 OpAsmParser::OperandType memrefInfo;
2230 AffineMapAttr mapAttr;
2231 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2232 return failure(
2233 parser.parseOperand(memrefInfo) ||
2234 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2235 AffineLoadOp::getMapAttrName(),
2236 result.attributes) ||
2237 parser.parseOptionalAttrDict(result.attributes) ||
2238 parser.parseColonType(type) ||
2239 parser.resolveOperand(memrefInfo, type, result.operands) ||
2240 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
2241 parser.addTypeToList(type.getElementType(), result.types));
2242 }
2243
print(OpAsmPrinter & p,AffineLoadOp op)2244 static void print(OpAsmPrinter &p, AffineLoadOp op) {
2245 p << "affine.load " << op.getMemRef() << '[';
2246 if (AffineMapAttr mapAttr =
2247 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2248 p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2249 p << ']';
2250 p.printOptionalAttrDict(op->getAttrs(),
2251 /*elidedAttrs=*/{op.getMapAttrName()});
2252 p << " : " << op.getMemRefType();
2253 }
2254
2255 /// Verify common indexing invariants of affine.load, affine.store,
2256 /// affine.vector_load and affine.vector_store.
2257 static LogicalResult
verifyMemoryOpIndexing(Operation * op,AffineMapAttr mapAttr,Operation::operand_range mapOperands,MemRefType memrefType,unsigned numIndexOperands)2258 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
2259 Operation::operand_range mapOperands,
2260 MemRefType memrefType, unsigned numIndexOperands) {
2261 if (mapAttr) {
2262 AffineMap map = mapAttr.getValue();
2263 if (map.getNumResults() != memrefType.getRank())
2264 return op->emitOpError("affine map num results must equal memref rank");
2265 if (map.getNumInputs() != numIndexOperands)
2266 return op->emitOpError("expects as many subscripts as affine map inputs");
2267 } else {
2268 if (memrefType.getRank() != numIndexOperands)
2269 return op->emitOpError(
2270 "expects the number of subscripts to be equal to memref rank");
2271 }
2272
2273 Region *scope = getAffineScope(op);
2274 for (auto idx : mapOperands) {
2275 if (!idx.getType().isIndex())
2276 return op->emitOpError("index to load must have 'index' type");
2277 if (!isValidAffineIndexOperand(idx, scope))
2278 return op->emitOpError("index must be a dimension or symbol identifier");
2279 }
2280
2281 return success();
2282 }
2283
verify(AffineLoadOp op)2284 LogicalResult verify(AffineLoadOp op) {
2285 auto memrefType = op.getMemRefType();
2286 if (op.getType() != memrefType.getElementType())
2287 return op.emitOpError("result type must match element type of memref");
2288
2289 if (failed(verifyMemoryOpIndexing(
2290 op.getOperation(),
2291 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2292 op.getMapOperands(), memrefType,
2293 /*numIndexOperands=*/op.getNumOperands() - 1)))
2294 return failure();
2295
2296 return success();
2297 }
2298
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2299 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2300 MLIRContext *context) {
2301 results.add<SimplifyAffineOp<AffineLoadOp>>(context);
2302 }
2303
fold(ArrayRef<Attribute> cstOperands)2304 OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
2305 /// load(memrefcast) -> load
2306 if (succeeded(foldMemRefCast(*this)))
2307 return getResult();
2308 return OpFoldResult();
2309 }
2310
2311 //===----------------------------------------------------------------------===//
2312 // AffineStoreOp
2313 //===----------------------------------------------------------------------===//
2314
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,AffineMap map,ValueRange mapOperands)2315 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
2316 Value valueToStore, Value memref, AffineMap map,
2317 ValueRange mapOperands) {
2318 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2319 result.addOperands(valueToStore);
2320 result.addOperands(memref);
2321 result.addOperands(mapOperands);
2322 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2323 }
2324
2325 // Use identity map.
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,ValueRange indices)2326 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
2327 Value valueToStore, Value memref,
2328 ValueRange indices) {
2329 auto memrefType = memref.getType().cast<MemRefType>();
2330 int64_t rank = memrefType.getRank();
2331 // Create identity map for memrefs with at least one dimension or () -> ()
2332 // for zero-dimensional memrefs.
2333 auto map =
2334 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2335 build(builder, result, valueToStore, memref, map, indices);
2336 }
2337
parseAffineStoreOp(OpAsmParser & parser,OperationState & result)2338 static ParseResult parseAffineStoreOp(OpAsmParser &parser,
2339 OperationState &result) {
2340 auto indexTy = parser.getBuilder().getIndexType();
2341
2342 MemRefType type;
2343 OpAsmParser::OperandType storeValueInfo;
2344 OpAsmParser::OperandType memrefInfo;
2345 AffineMapAttr mapAttr;
2346 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2347 return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
2348 parser.parseOperand(memrefInfo) ||
2349 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2350 AffineStoreOp::getMapAttrName(),
2351 result.attributes) ||
2352 parser.parseOptionalAttrDict(result.attributes) ||
2353 parser.parseColonType(type) ||
2354 parser.resolveOperand(storeValueInfo, type.getElementType(),
2355 result.operands) ||
2356 parser.resolveOperand(memrefInfo, type, result.operands) ||
2357 parser.resolveOperands(mapOperands, indexTy, result.operands));
2358 }
2359
print(OpAsmPrinter & p,AffineStoreOp op)2360 static void print(OpAsmPrinter &p, AffineStoreOp op) {
2361 p << "affine.store " << op.getValueToStore();
2362 p << ", " << op.getMemRef() << '[';
2363 if (AffineMapAttr mapAttr =
2364 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2365 p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2366 p << ']';
2367 p.printOptionalAttrDict(op->getAttrs(),
2368 /*elidedAttrs=*/{op.getMapAttrName()});
2369 p << " : " << op.getMemRefType();
2370 }
2371
verify(AffineStoreOp op)2372 LogicalResult verify(AffineStoreOp op) {
2373 // First operand must have same type as memref element type.
2374 auto memrefType = op.getMemRefType();
2375 if (op.getValueToStore().getType() != memrefType.getElementType())
2376 return op.emitOpError(
2377 "first operand must have same type memref element type");
2378
2379 if (failed(verifyMemoryOpIndexing(
2380 op.getOperation(),
2381 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2382 op.getMapOperands(), memrefType,
2383 /*numIndexOperands=*/op.getNumOperands() - 2)))
2384 return failure();
2385
2386 return success();
2387 }
2388
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2389 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
2390 MLIRContext *context) {
2391 results.add<SimplifyAffineOp<AffineStoreOp>>(context);
2392 }
2393
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)2394 LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
2395 SmallVectorImpl<OpFoldResult> &results) {
2396 /// store(memrefcast) -> store
2397 return foldMemRefCast(*this, getValueToStore());
2398 }
2399
2400 //===----------------------------------------------------------------------===//
2401 // AffineMinMaxOpBase
2402 //===----------------------------------------------------------------------===//
2403
2404 template <typename T>
verifyAffineMinMaxOp(T op)2405 static LogicalResult verifyAffineMinMaxOp(T op) {
2406 // Verify that operand count matches affine map dimension and symbol count.
2407 if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
2408 return op.emitOpError(
2409 "operand count and affine map dimension and symbol count must match");
2410 return success();
2411 }
2412
2413 template <typename T>
printAffineMinMaxOp(OpAsmPrinter & p,T op)2414 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
2415 p << op.getOperationName() << ' ' << op->getAttr(T::getMapAttrName());
2416 auto operands = op.getOperands();
2417 unsigned numDims = op.map().getNumDims();
2418 p << '(' << operands.take_front(numDims) << ')';
2419
2420 if (operands.size() != numDims)
2421 p << '[' << operands.drop_front(numDims) << ']';
2422 p.printOptionalAttrDict(op->getAttrs(),
2423 /*elidedAttrs=*/{T::getMapAttrName()});
2424 }
2425
2426 template <typename T>
parseAffineMinMaxOp(OpAsmParser & parser,OperationState & result)2427 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
2428 OperationState &result) {
2429 auto &builder = parser.getBuilder();
2430 auto indexType = builder.getIndexType();
2431 SmallVector<OpAsmParser::OperandType, 8> dimInfos;
2432 SmallVector<OpAsmParser::OperandType, 8> symInfos;
2433 AffineMapAttr mapAttr;
2434 return failure(
2435 parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) ||
2436 parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
2437 parser.parseOperandList(symInfos,
2438 OpAsmParser::Delimiter::OptionalSquare) ||
2439 parser.parseOptionalAttrDict(result.attributes) ||
2440 parser.resolveOperands(dimInfos, indexType, result.operands) ||
2441 parser.resolveOperands(symInfos, indexType, result.operands) ||
2442 parser.addTypeToList(indexType, result.types));
2443 }
2444
2445 /// Fold an affine min or max operation with the given operands. The operand
2446 /// list may contain nulls, which are interpreted as the operand not being a
2447 /// constant.
2448 template <typename T>
foldMinMaxOp(T op,ArrayRef<Attribute> operands)2449 static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
2450 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
2451 "expected affine min or max op");
2452
2453 // Fold the affine map.
2454 // TODO: Fold more cases:
2455 // min(some_affine, some_affine + constant, ...), etc.
2456 SmallVector<int64_t, 2> results;
2457 auto foldedMap = op.map().partialConstantFold(operands, &results);
2458
2459 // If some of the map results are not constant, try changing the map in-place.
2460 if (results.empty()) {
2461 // If the map is the same, report that folding did not happen.
2462 if (foldedMap == op.map())
2463 return {};
2464 op->setAttr("map", AffineMapAttr::get(foldedMap));
2465 return op.getResult();
2466 }
2467
2468 // Otherwise, completely fold the op into a constant.
2469 auto resultIt = std::is_same<T, AffineMinOp>::value
2470 ? std::min_element(results.begin(), results.end())
2471 : std::max_element(results.begin(), results.end());
2472 if (resultIt == results.end())
2473 return {};
2474 return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
2475 }
2476
2477 /// Remove duplicated expressions in affine min/max ops.
2478 template <typename T>
2479 struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
2480 using OpRewritePattern<T>::OpRewritePattern;
2481
matchAndRewriteDeduplicateAffineMinMaxExpressions2482 LogicalResult matchAndRewrite(T affineOp,
2483 PatternRewriter &rewriter) const override {
2484 AffineMap oldMap = affineOp.getAffineMap();
2485
2486 SmallVector<AffineExpr, 4> newExprs;
2487 for (AffineExpr expr : oldMap.getResults()) {
2488 // This is a linear scan over newExprs, but it should be fine given that
2489 // we typically just have a few expressions per op.
2490 if (!llvm::is_contained(newExprs, expr))
2491 newExprs.push_back(expr);
2492 }
2493
2494 if (newExprs.size() == oldMap.getNumResults())
2495 return failure();
2496
2497 auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
2498 newExprs, rewriter.getContext());
2499 rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
2500
2501 return success();
2502 }
2503 };
2504
2505 /// Merge an affine min/max op to its consumers if its consumer is also an
2506 /// affine min/max op.
2507 ///
2508 /// This pattern requires the producer affine min/max op is bound to a
2509 /// dimension/symbol that is used as a standalone expression in the consumer
2510 /// affine op's map.
2511 ///
2512 /// For example, a pattern like the following:
2513 ///
2514 /// %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
2515 /// %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
2516 ///
2517 /// Can be turned into:
2518 ///
2519 /// %1 = affine.min affine_map<
2520 /// ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
2521 template <typename T>
2522 struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
2523 using OpRewritePattern<T>::OpRewritePattern;
2524
matchAndRewriteMergeAffineMinMaxOp2525 LogicalResult matchAndRewrite(T affineOp,
2526 PatternRewriter &rewriter) const override {
2527 AffineMap oldMap = affineOp.getAffineMap();
2528 ValueRange dimOperands =
2529 affineOp.getMapOperands().take_front(oldMap.getNumDims());
2530 ValueRange symOperands =
2531 affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
2532
2533 auto newDimOperands = llvm::to_vector<8>(dimOperands);
2534 auto newSymOperands = llvm::to_vector<8>(symOperands);
2535 SmallVector<AffineExpr, 4> newExprs;
2536 SmallVector<T, 4> producerOps;
2537
2538 // Go over each expression to see whether it's a single dimension/symbol
2539 // with the corresponding operand which is the result of another affine
2540 // min/max op. If So it can be merged into this affine op.
2541 for (AffineExpr expr : oldMap.getResults()) {
2542 if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
2543 Value symValue = symOperands[symExpr.getPosition()];
2544 if (auto producerOp = symValue.getDefiningOp<T>()) {
2545 producerOps.push_back(producerOp);
2546 continue;
2547 }
2548 } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
2549 Value dimValue = dimOperands[dimExpr.getPosition()];
2550 if (auto producerOp = dimValue.getDefiningOp<T>()) {
2551 producerOps.push_back(producerOp);
2552 continue;
2553 }
2554 }
2555 // For the above cases we will remove the expression by merging the
2556 // producer affine min/max's affine expressions. Otherwise we need to
2557 // keep the existing expression.
2558 newExprs.push_back(expr);
2559 }
2560
2561 if (producerOps.empty())
2562 return failure();
2563
2564 unsigned numUsedDims = oldMap.getNumDims();
2565 unsigned numUsedSyms = oldMap.getNumSymbols();
2566
2567 // Now go over all producer affine ops and merge their expressions.
2568 for (T producerOp : producerOps) {
2569 AffineMap producerMap = producerOp.getAffineMap();
2570 unsigned numProducerDims = producerMap.getNumDims();
2571 unsigned numProducerSyms = producerMap.getNumSymbols();
2572
2573 // Collect all dimension/symbol values.
2574 ValueRange dimValues =
2575 producerOp.getMapOperands().take_front(numProducerDims);
2576 ValueRange symValues =
2577 producerOp.getMapOperands().take_back(numProducerSyms);
2578 newDimOperands.append(dimValues.begin(), dimValues.end());
2579 newSymOperands.append(symValues.begin(), symValues.end());
2580
2581 // For expressions we need to shift to avoid overlap.
2582 for (AffineExpr expr : producerMap.getResults()) {
2583 newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
2584 .shiftSymbols(numProducerSyms, numUsedSyms));
2585 }
2586
2587 numUsedDims += numProducerDims;
2588 numUsedSyms += numProducerSyms;
2589 }
2590
2591 auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
2592 rewriter.getContext());
2593 auto newOperands =
2594 llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
2595 rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
2596
2597 return success();
2598 }
2599 };
2600
2601 template <typename T>
2602 struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
2603 using OpRewritePattern<T>::OpRewritePattern;
2604
matchAndRewriteCanonicalizeSingleResultAffineMinMaxOp2605 LogicalResult matchAndRewrite(T affineOp,
2606 PatternRewriter &rewriter) const override {
2607 if (affineOp.map().getNumResults() != 1)
2608 return failure();
2609 rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.map(),
2610 affineOp.getOperands());
2611 return success();
2612 }
2613 };
2614
2615 //===----------------------------------------------------------------------===//
2616 // AffineMinOp
2617 //===----------------------------------------------------------------------===//
2618 //
// %0 = affine.min affine_map<(d0) -> (1000, d0 + 512)> (%i0)
2620 //
2621
fold(ArrayRef<Attribute> operands)2622 OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
2623 return foldMinMaxOp(*this, operands);
2624 }
2625
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)2626 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2627 MLIRContext *context) {
2628 patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
2629 DeduplicateAffineMinMaxExpressions<AffineMinOp>,
2630 MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>>(
2631 context);
2632 }
2633
2634 //===----------------------------------------------------------------------===//
2635 // AffineMaxOp
2636 //===----------------------------------------------------------------------===//
2637 //
// %0 = affine.max affine_map<(d0) -> (1000, d0 + 512)> (%i0)
2639 //
2640
fold(ArrayRef<Attribute> operands)2641 OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
2642 return foldMinMaxOp(*this, operands);
2643 }
2644
getCanonicalizationPatterns(RewritePatternSet & patterns,MLIRContext * context)2645 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2646 MLIRContext *context) {
2647 patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
2648 DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
2649 MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>>(
2650 context);
2651 }
2652
2653 //===----------------------------------------------------------------------===//
2654 // AffinePrefetchOp
2655 //===----------------------------------------------------------------------===//
2656
2657 //
2658 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
2659 //
parseAffinePrefetchOp(OpAsmParser & parser,OperationState & result)2660 static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
2661 OperationState &result) {
2662 auto &builder = parser.getBuilder();
2663 auto indexTy = builder.getIndexType();
2664
2665 MemRefType type;
2666 OpAsmParser::OperandType memrefInfo;
2667 IntegerAttr hintInfo;
2668 auto i32Type = parser.getBuilder().getIntegerType(32);
2669 StringRef readOrWrite, cacheType;
2670
2671 AffineMapAttr mapAttr;
2672 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2673 if (parser.parseOperand(memrefInfo) ||
2674 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2675 AffinePrefetchOp::getMapAttrName(),
2676 result.attributes) ||
2677 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
2678 parser.parseComma() || parser.parseKeyword("locality") ||
2679 parser.parseLess() ||
2680 parser.parseAttribute(hintInfo, i32Type,
2681 AffinePrefetchOp::getLocalityHintAttrName(),
2682 result.attributes) ||
2683 parser.parseGreater() || parser.parseComma() ||
2684 parser.parseKeyword(&cacheType) ||
2685 parser.parseOptionalAttrDict(result.attributes) ||
2686 parser.parseColonType(type) ||
2687 parser.resolveOperand(memrefInfo, type, result.operands) ||
2688 parser.resolveOperands(mapOperands, indexTy, result.operands))
2689 return failure();
2690
2691 if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
2692 return parser.emitError(parser.getNameLoc(),
2693 "rw specifier has to be 'read' or 'write'");
2694 result.addAttribute(
2695 AffinePrefetchOp::getIsWriteAttrName(),
2696 parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
2697
2698 if (!cacheType.equals("data") && !cacheType.equals("instr"))
2699 return parser.emitError(parser.getNameLoc(),
2700 "cache type has to be 'data' or 'instr'");
2701
2702 result.addAttribute(
2703 AffinePrefetchOp::getIsDataCacheAttrName(),
2704 parser.getBuilder().getBoolAttr(cacheType.equals("data")));
2705
2706 return success();
2707 }
2708
print(OpAsmPrinter & p,AffinePrefetchOp op)2709 static void print(OpAsmPrinter &p, AffinePrefetchOp op) {
2710 p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
2711 AffineMapAttr mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2712 if (mapAttr) {
2713 SmallVector<Value, 2> operands(op.getMapOperands());
2714 p.printAffineMapOfSSAIds(mapAttr, operands);
2715 }
2716 p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", "
2717 << "locality<" << op.localityHint() << ">, "
2718 << (op.isDataCache() ? "data" : "instr");
2719 p.printOptionalAttrDict(
2720 op->getAttrs(),
2721 /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(),
2722 op.getIsDataCacheAttrName(), op.getIsWriteAttrName()});
2723 p << " : " << op.getMemRefType();
2724 }
2725
verify(AffinePrefetchOp op)2726 static LogicalResult verify(AffinePrefetchOp op) {
2727 auto mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2728 if (mapAttr) {
2729 AffineMap map = mapAttr.getValue();
2730 if (map.getNumResults() != op.getMemRefType().getRank())
2731 return op.emitOpError("affine.prefetch affine map num results must equal"
2732 " memref rank");
2733 if (map.getNumInputs() + 1 != op.getNumOperands())
2734 return op.emitOpError("too few operands");
2735 } else {
2736 if (op.getNumOperands() != 1)
2737 return op.emitOpError("too few operands");
2738 }
2739
2740 Region *scope = getAffineScope(op);
2741 for (auto idx : op.getMapOperands()) {
2742 if (!isValidAffineIndexOperand(idx, scope))
2743 return op.emitOpError("index must be a dimension or symbol identifier");
2744 }
2745 return success();
2746 }
2747
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)2748 void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
2749 MLIRContext *context) {
2750 // prefetch(memrefcast) -> prefetch
2751 results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
2752 }
2753
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)2754 LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
2755 SmallVectorImpl<OpFoldResult> &results) {
2756 /// prefetch(memrefcast) -> prefetch
2757 return foldMemRefCast(*this);
2758 }
2759
2760 //===----------------------------------------------------------------------===//
2761 // AffineParallelOp
2762 //===----------------------------------------------------------------------===//
2763
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ArrayRef<AtomicRMWKind> reductions,ArrayRef<int64_t> ranges)2764 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2765 TypeRange resultTypes,
2766 ArrayRef<AtomicRMWKind> reductions,
2767 ArrayRef<int64_t> ranges) {
2768 SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
2769 auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
2770 return builder.getConstantAffineMap(value);
2771 }));
2772 SmallVector<int64_t> steps(ranges.size(), 1);
2773 build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
2774 /*ubArgs=*/{}, steps);
2775 }
2776
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ArrayRef<AtomicRMWKind> reductions,ArrayRef<AffineMap> lbMaps,ValueRange lbArgs,ArrayRef<AffineMap> ubMaps,ValueRange ubArgs,ArrayRef<int64_t> steps)2777 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2778 TypeRange resultTypes,
2779 ArrayRef<AtomicRMWKind> reductions,
2780 ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
2781 ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
2782 ArrayRef<int64_t> steps) {
2783 assert(llvm::all_of(lbMaps,
2784 [lbMaps](AffineMap m) {
2785 return m.getNumDims() == lbMaps[0].getNumDims() &&
2786 m.getNumSymbols() == lbMaps[0].getNumSymbols();
2787 }) &&
2788 "expected all lower bounds maps to have the same number of dimensions "
2789 "and symbols");
2790 assert(llvm::all_of(ubMaps,
2791 [ubMaps](AffineMap m) {
2792 return m.getNumDims() == ubMaps[0].getNumDims() &&
2793 m.getNumSymbols() == ubMaps[0].getNumSymbols();
2794 }) &&
2795 "expected all upper bounds maps to have the same number of dimensions "
2796 "and symbols");
2797 assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
2798 "expected lower bound maps to have as many inputs as lower bound "
2799 "operands");
2800 assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
2801 "expected upper bound maps to have as many inputs as upper bound "
2802 "operands");
2803
2804 result.addTypes(resultTypes);
2805
2806 // Convert the reductions to integer attributes.
2807 SmallVector<Attribute, 4> reductionAttrs;
2808 for (AtomicRMWKind reduction : reductions)
2809 reductionAttrs.push_back(
2810 builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
2811 result.addAttribute(getReductionsAttrName(),
2812 builder.getArrayAttr(reductionAttrs));
2813
2814 // Concatenates maps defined in the same input space (same dimensions and
2815 // symbols), assumes there is at least one map.
2816 auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
2817 SmallVectorImpl<int32_t> &groups) {
2818 if (maps.empty())
2819 return AffineMap::get(builder.getContext());
2820 SmallVector<AffineExpr> exprs;
2821 groups.reserve(groups.size() + maps.size());
2822 exprs.reserve(maps.size());
2823 for (AffineMap m : maps) {
2824 llvm::append_range(exprs, m.getResults());
2825 groups.push_back(m.getNumResults());
2826 }
2827 return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
2828 maps[0].getContext());
2829 };
2830
2831 // Set up the bounds.
2832 SmallVector<int32_t> lbGroups, ubGroups;
2833 AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
2834 AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
2835 result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap));
2836 result.addAttribute(getLowerBoundsGroupsAttrName(),
2837 builder.getI32TensorAttr(lbGroups));
2838 result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap));
2839 result.addAttribute(getUpperBoundsGroupsAttrName(),
2840 builder.getI32TensorAttr(ubGroups));
2841 result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps));
2842 result.addOperands(lbArgs);
2843 result.addOperands(ubArgs);
2844
2845 // Create a region and a block for the body.
2846 auto *bodyRegion = result.addRegion();
2847 auto *body = new Block();
2848 // Add all the block arguments.
2849 for (unsigned i = 0, e = steps.size(); i < e; ++i)
2850 body->addArgument(IndexType::get(builder.getContext()));
2851 bodyRegion->push_back(body);
2852 if (resultTypes.empty())
2853 ensureTerminator(*bodyRegion, builder, result.location);
2854 }
2855
getLoopBody()2856 Region &AffineParallelOp::getLoopBody() { return region(); }
2857
isDefinedOutsideOfLoop(Value value)2858 bool AffineParallelOp::isDefinedOutsideOfLoop(Value value) {
2859 return !region().isAncestor(value.getParentRegion());
2860 }
2861
moveOutOfLoop(ArrayRef<Operation * > ops)2862 LogicalResult AffineParallelOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
2863 for (Operation *op : ops)
2864 op->moveBefore(*this);
2865 return success();
2866 }
2867
getNumDims()2868 unsigned AffineParallelOp::getNumDims() { return steps().size(); }
2869
getLowerBoundsOperands()2870 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
2871 return getOperands().take_front(lowerBoundsMap().getNumInputs());
2872 }
2873
getUpperBoundsOperands()2874 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
2875 return getOperands().drop_front(lowerBoundsMap().getNumInputs());
2876 }
2877
getLowerBoundMap(unsigned pos)2878 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
2879 unsigned start = 0;
2880 for (unsigned i = 0; i < pos; ++i)
2881 start += lowerBoundsGroups().getValue<int32_t>(i);
2882 return lowerBoundsMap().getSliceMap(
2883 start, lowerBoundsGroups().getValue<int32_t>(pos));
2884 }
2885
getUpperBoundMap(unsigned pos)2886 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
2887 unsigned start = 0;
2888 for (unsigned i = 0; i < pos; ++i)
2889 start += upperBoundsGroups().getValue<int32_t>(i);
2890 return upperBoundsMap().getSliceMap(
2891 start, upperBoundsGroups().getValue<int32_t>(pos));
2892 }
2893
getLowerBoundsValueMap()2894 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
2895 return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands());
2896 }
2897
getUpperBoundsValueMap()2898 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
2899 return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands());
2900 }
2901
getConstantRanges()2902 Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
2903 if (hasMinMaxBounds())
2904 return llvm::None;
2905
2906 // Try to convert all the ranges to constant expressions.
2907 SmallVector<int64_t, 8> out;
2908 AffineValueMap rangesValueMap;
2909 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
2910 &rangesValueMap);
2911 out.reserve(rangesValueMap.getNumResults());
2912 for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
2913 auto expr = rangesValueMap.getResult(i);
2914 auto cst = expr.dyn_cast<AffineConstantExpr>();
2915 if (!cst)
2916 return llvm::None;
2917 out.push_back(cst.getValue());
2918 }
2919 return out;
2920 }
2921
getBody()2922 Block *AffineParallelOp::getBody() { return ®ion().front(); }
2923
getBodyBuilder()2924 OpBuilder AffineParallelOp::getBodyBuilder() {
2925 return OpBuilder(getBody(), std::prev(getBody()->end()));
2926 }
2927
setLowerBounds(ValueRange lbOperands,AffineMap map)2928 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
2929 assert(lbOperands.size() == map.getNumInputs() &&
2930 "operands to map must match number of inputs");
2931
2932 auto ubOperands = getUpperBoundsOperands();
2933
2934 SmallVector<Value, 4> newOperands(lbOperands);
2935 newOperands.append(ubOperands.begin(), ubOperands.end());
2936 (*this)->setOperands(newOperands);
2937
2938 lowerBoundsMapAttr(AffineMapAttr::get(map));
2939 }
2940
setUpperBounds(ValueRange ubOperands,AffineMap map)2941 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
2942 assert(ubOperands.size() == map.getNumInputs() &&
2943 "operands to map must match number of inputs");
2944
2945 SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
2946 newOperands.append(ubOperands.begin(), ubOperands.end());
2947 (*this)->setOperands(newOperands);
2948
2949 upperBoundsMapAttr(AffineMapAttr::get(map));
2950 }
2951
setLowerBoundsMap(AffineMap map)2952 void AffineParallelOp::setLowerBoundsMap(AffineMap map) {
2953 AffineMap lbMap = lowerBoundsMap();
2954 assert(lbMap.getNumDims() == map.getNumDims() &&
2955 lbMap.getNumSymbols() == map.getNumSymbols());
2956 (void)lbMap;
2957 lowerBoundsMapAttr(AffineMapAttr::get(map));
2958 }
2959
setUpperBoundsMap(AffineMap map)2960 void AffineParallelOp::setUpperBoundsMap(AffineMap map) {
2961 AffineMap ubMap = upperBoundsMap();
2962 assert(ubMap.getNumDims() == map.getNumDims() &&
2963 ubMap.getNumSymbols() == map.getNumSymbols());
2964 (void)ubMap;
2965 upperBoundsMapAttr(AffineMapAttr::get(map));
2966 }
2967
getSteps()2968 SmallVector<int64_t, 8> AffineParallelOp::getSteps() {
2969 SmallVector<int64_t, 8> result;
2970 for (Attribute attr : steps()) {
2971 result.push_back(attr.cast<IntegerAttr>().getInt());
2972 }
2973 return result;
2974 }
2975
setSteps(ArrayRef<int64_t> newSteps)2976 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
2977 stepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
2978 }
2979
verify(AffineParallelOp op)2980 static LogicalResult verify(AffineParallelOp op) {
2981 auto numDims = op.getNumDims();
2982 if (op.lowerBoundsGroups().getNumElements() != numDims ||
2983 op.upperBoundsGroups().getNumElements() != numDims ||
2984 op.steps().size() != numDims ||
2985 op.getBody()->getNumArguments() != numDims) {
2986 return op.emitOpError()
2987 << "the number of region arguments ("
2988 << op.getBody()->getNumArguments()
2989 << ") and the number of map groups for lower ("
2990 << op.lowerBoundsGroups().getNumElements() << ") and upper bound ("
2991 << op.upperBoundsGroups().getNumElements()
2992 << "), and the number of steps (" << op.steps().size()
2993 << ") must all match";
2994 }
2995
2996 unsigned expectedNumLBResults = 0;
2997 for (APInt v : op.lowerBoundsGroups())
2998 expectedNumLBResults += v.getZExtValue();
2999 if (expectedNumLBResults != op.lowerBoundsMap().getNumResults())
3000 return op.emitOpError() << "expected lower bounds map to have "
3001 << expectedNumLBResults << " results";
3002 unsigned expectedNumUBResults = 0;
3003 for (APInt v : op.upperBoundsGroups())
3004 expectedNumUBResults += v.getZExtValue();
3005 if (expectedNumUBResults != op.upperBoundsMap().getNumResults())
3006 return op.emitOpError() << "expected upper bounds map to have "
3007 << expectedNumUBResults << " results";
3008
3009 if (op.reductions().size() != op.getNumResults())
3010 return op.emitOpError("a reduction must be specified for each output");
3011
3012 // Verify reduction ops are all valid
3013 for (Attribute attr : op.reductions()) {
3014 auto intAttr = attr.dyn_cast<IntegerAttr>();
3015 if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt()))
3016 return op.emitOpError("invalid reduction attribute");
3017 }
3018
3019 // Verify that the bound operands are valid dimension/symbols.
3020 /// Lower bounds.
3021 if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(),
3022 op.lowerBoundsMap().getNumDims())))
3023 return failure();
3024 /// Upper bounds.
3025 if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(),
3026 op.upperBoundsMap().getNumDims())))
3027 return failure();
3028 return success();
3029 }
3030
canonicalize()3031 LogicalResult AffineValueMap::canonicalize() {
3032 SmallVector<Value, 4> newOperands{operands};
3033 auto newMap = getAffineMap();
3034 composeAffineMapAndOperands(&newMap, &newOperands);
3035 if (newMap == getAffineMap() && newOperands == operands)
3036 return failure();
3037 reset(newMap, newOperands);
3038 return success();
3039 }
3040
3041 /// Canonicalize the bounds of the given loop.
canonicalizeLoopBounds(AffineParallelOp op)3042 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
3043 AffineValueMap lb = op.getLowerBoundsValueMap();
3044 bool lbCanonicalized = succeeded(lb.canonicalize());
3045
3046 AffineValueMap ub = op.getUpperBoundsValueMap();
3047 bool ubCanonicalized = succeeded(ub.canonicalize());
3048
3049 // Any canonicalization change always leads to updated map(s).
3050 if (!lbCanonicalized && !ubCanonicalized)
3051 return failure();
3052
3053 if (lbCanonicalized)
3054 op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
3055 if (ubCanonicalized)
3056 op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
3057
3058 return success();
3059 }
3060
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)3061 LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands,
3062 SmallVectorImpl<OpFoldResult> &results) {
3063 return canonicalizeLoopBounds(*this);
3064 }
3065
3066 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
3067 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
3068 /// identifies which of the those expressions form max/min groups. `operands`
3069 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
3070 /// or "max".
printMinMaxBound(OpAsmPrinter & p,AffineMapAttr mapAttr,DenseIntElementsAttr group,ValueRange operands,StringRef keyword)3071 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
3072 DenseIntElementsAttr group, ValueRange operands,
3073 StringRef keyword) {
3074 AffineMap map = mapAttr.getValue();
3075 unsigned numDims = map.getNumDims();
3076 ValueRange dimOperands = operands.take_front(numDims);
3077 ValueRange symOperands = operands.drop_front(numDims);
3078 unsigned start = 0;
3079 for (llvm::APInt groupSize : group) {
3080 if (start != 0)
3081 p << ", ";
3082
3083 unsigned size = groupSize.getZExtValue();
3084 if (size == 1) {
3085 p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
3086 ++start;
3087 } else {
3088 p << keyword << '(';
3089 AffineMap submap = map.getSliceMap(start, size);
3090 p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
3091 p << ')';
3092 start += size;
3093 }
3094 }
3095 }
3096
print(OpAsmPrinter & p,AffineParallelOp op)3097 static void print(OpAsmPrinter &p, AffineParallelOp op) {
3098 p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (";
3099 printMinMaxBound(p, op.lowerBoundsMapAttr(), op.lowerBoundsGroupsAttr(),
3100 op.getLowerBoundsOperands(), "max");
3101 p << ") to (";
3102 printMinMaxBound(p, op.upperBoundsMapAttr(), op.upperBoundsGroupsAttr(),
3103 op.getUpperBoundsOperands(), "min");
3104 p << ')';
3105 SmallVector<int64_t, 8> steps = op.getSteps();
3106 bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
3107 if (!elideSteps) {
3108 p << " step (";
3109 llvm::interleaveComma(steps, p);
3110 p << ')';
3111 }
3112 if (op.getNumResults()) {
3113 p << " reduce (";
3114 llvm::interleaveComma(op.reductions(), p, [&](auto &attr) {
3115 AtomicRMWKind sym =
3116 *symbolizeAtomicRMWKind(attr.template cast<IntegerAttr>().getInt());
3117 p << "\"" << stringifyAtomicRMWKind(sym) << "\"";
3118 });
3119 p << ") -> (" << op.getResultTypes() << ")";
3120 }
3121
3122 p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
3123 /*printBlockTerminators=*/op.getNumResults());
3124 p.printOptionalAttrDict(
3125 op->getAttrs(),
3126 /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(),
3127 AffineParallelOp::getLowerBoundsMapAttrName(),
3128 AffineParallelOp::getLowerBoundsGroupsAttrName(),
3129 AffineParallelOp::getUpperBoundsMapAttrName(),
3130 AffineParallelOp::getUpperBoundsGroupsAttrName(),
3131 AffineParallelOp::getStepsAttrName()});
3132 }
3133
3134 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
3135 /// unique operands. Also populates `replacements with affine expressions of
3136 /// `kind` that can be used to update affine maps previously accepting a
3137 /// `operands` to accept `uniqueOperands` instead.
deduplicateAndResolveOperands(OpAsmParser & parser,ArrayRef<SmallVector<OpAsmParser::OperandType>> operands,SmallVectorImpl<Value> & uniqueOperands,SmallVectorImpl<AffineExpr> & replacements,AffineExprKind kind)3138 static void deduplicateAndResolveOperands(
3139 OpAsmParser &parser,
3140 ArrayRef<SmallVector<OpAsmParser::OperandType>> operands,
3141 SmallVectorImpl<Value> &uniqueOperands,
3142 SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
3143 assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
3144 "expected operands to be dim or symbol expression");
3145
3146 Type indexType = parser.getBuilder().getIndexType();
3147 for (const auto &list : operands) {
3148 SmallVector<Value> valueOperands;
3149 parser.resolveOperands(list, indexType, valueOperands);
3150 for (Value operand : valueOperands) {
3151 unsigned pos = std::distance(uniqueOperands.begin(),
3152 llvm::find(uniqueOperands, operand));
3153 if (pos == uniqueOperands.size())
3154 uniqueOperands.push_back(operand);
3155 replacements.push_back(
3156 kind == AffineExprKind::DimId
3157 ? getAffineDimExpr(pos, parser.getBuilder().getContext())
3158 : getAffineSymbolExpr(pos, parser.getBuilder().getContext()));
3159 }
3160 }
3161 }
3162
3163 namespace {
3164 enum class MinMaxKind { Min, Max };
3165 } // namespace
3166
3167 /// Parses an affine map that can contain a min/max for groups of its results,
3168 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
3169 /// `result` attributes with the map (flat list of expressions) and the grouping
3170 /// (list of integers that specify how many expressions to put into each
3171 /// min/max) attributes. Deduplicates repeated operands.
3172 ///
3173 /// parallel-bound ::= `(` parallel-group-list `)`
3174 /// parallel-group-list ::= parallel-group (`,` parallel-group-list)?
3175 /// parallel-group ::= simple-group | min-max-group
3176 /// simple-group ::= expr-of-ssa-ids
3177 /// min-max-group ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
3178 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-id-list)?
3179 ///
3180 /// Examples:
3181 /// (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
3182 /// (%0, max(%1 - 2 * %2))
parseAffineMapWithMinMax(OpAsmParser & parser,OperationState & result,MinMaxKind kind)3183 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
3184 OperationState &result,
3185 MinMaxKind kind) {
3186 constexpr llvm::StringLiteral tmpAttrName = "__pseudo_bound_map";
3187
3188 StringRef mapName = kind == MinMaxKind::Min
3189 ? AffineParallelOp::getUpperBoundsMapAttrName()
3190 : AffineParallelOp::getLowerBoundsMapAttrName();
3191 StringRef groupsName = kind == MinMaxKind::Min
3192 ? AffineParallelOp::getUpperBoundsGroupsAttrName()
3193 : AffineParallelOp::getLowerBoundsGroupsAttrName();
3194
3195 if (failed(parser.parseLParen()))
3196 return failure();
3197
3198 if (succeeded(parser.parseOptionalRParen())) {
3199 result.addAttribute(
3200 mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
3201 result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
3202 return success();
3203 }
3204
3205 SmallVector<AffineExpr> flatExprs;
3206 SmallVector<SmallVector<OpAsmParser::OperandType>> flatDimOperands;
3207 SmallVector<SmallVector<OpAsmParser::OperandType>> flatSymOperands;
3208 SmallVector<int32_t> numMapsPerGroup;
3209 SmallVector<OpAsmParser::OperandType> mapOperands;
3210 do {
3211 if (succeeded(parser.parseOptionalKeyword(
3212 kind == MinMaxKind::Min ? "min" : "max"))) {
3213 mapOperands.clear();
3214 AffineMapAttr map;
3215 if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrName,
3216 result.attributes,
3217 OpAsmParser::Delimiter::Paren)))
3218 return failure();
3219 result.attributes.erase(tmpAttrName);
3220 llvm::append_range(flatExprs, map.getValue().getResults());
3221 auto operandsRef = llvm::makeArrayRef(mapOperands);
3222 auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
3223 SmallVector<OpAsmParser::OperandType> dims(dimsRef.begin(),
3224 dimsRef.end());
3225 auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
3226 SmallVector<OpAsmParser::OperandType> syms(symsRef.begin(),
3227 symsRef.end());
3228 flatDimOperands.append(map.getValue().getNumResults(), dims);
3229 flatSymOperands.append(map.getValue().getNumResults(), syms);
3230 numMapsPerGroup.push_back(map.getValue().getNumResults());
3231 } else {
3232 if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
3233 flatSymOperands.emplace_back(),
3234 flatExprs.emplace_back())))
3235 return failure();
3236 numMapsPerGroup.push_back(1);
3237 }
3238 } while (succeeded(parser.parseOptionalComma()));
3239
3240 if (failed(parser.parseRParen()))
3241 return failure();
3242
3243 unsigned totalNumDims = 0;
3244 unsigned totalNumSyms = 0;
3245 for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
3246 unsigned numDims = flatDimOperands[i].size();
3247 unsigned numSyms = flatSymOperands[i].size();
3248 flatExprs[i] = flatExprs[i]
3249 .shiftDims(numDims, totalNumDims)
3250 .shiftSymbols(numSyms, totalNumSyms);
3251 totalNumDims += numDims;
3252 totalNumSyms += numSyms;
3253 }
3254
3255 // Deduplicate map operands.
3256 SmallVector<Value> dimOperands, symOperands;
3257 SmallVector<AffineExpr> dimRplacements, symRepacements;
3258 deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
3259 dimRplacements, AffineExprKind::DimId);
3260 deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
3261 symRepacements, AffineExprKind::SymbolId);
3262
3263 result.operands.append(dimOperands.begin(), dimOperands.end());
3264 result.operands.append(symOperands.begin(), symOperands.end());
3265
3266 Builder &builder = parser.getBuilder();
3267 auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
3268 parser.getBuilder().getContext());
3269 flatMap = flatMap.replaceDimsAndSymbols(
3270 dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
3271
3272 result.addAttribute(mapName, AffineMapAttr::get(flatMap));
3273 result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
3274 return success();
3275 }
3276
3277 //
3278 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
3279 // `to` parallel-bound steps? region attr-dict?
3280 // steps ::= `steps` `(` integer-literals `)`
3281 //
parseAffineParallelOp(OpAsmParser & parser,OperationState & result)3282 static ParseResult parseAffineParallelOp(OpAsmParser &parser,
3283 OperationState &result) {
3284 auto &builder = parser.getBuilder();
3285 auto indexType = builder.getIndexType();
3286 SmallVector<OpAsmParser::OperandType, 4> ivs;
3287 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
3288 OpAsmParser::Delimiter::Paren) ||
3289 parser.parseEqual() ||
3290 parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
3291 parser.parseKeyword("to") ||
3292 parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
3293 return failure();
3294
3295 AffineMapAttr stepsMapAttr;
3296 NamedAttrList stepsAttrs;
3297 SmallVector<OpAsmParser::OperandType, 4> stepsMapOperands;
3298 if (failed(parser.parseOptionalKeyword("step"))) {
3299 SmallVector<int64_t, 4> steps(ivs.size(), 1);
3300 result.addAttribute(AffineParallelOp::getStepsAttrName(),
3301 builder.getI64ArrayAttr(steps));
3302 } else {
3303 if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
3304 AffineParallelOp::getStepsAttrName(),
3305 stepsAttrs,
3306 OpAsmParser::Delimiter::Paren))
3307 return failure();
3308
3309 // Convert steps from an AffineMap into an I64ArrayAttr.
3310 SmallVector<int64_t, 4> steps;
3311 auto stepsMap = stepsMapAttr.getValue();
3312 for (const auto &result : stepsMap.getResults()) {
3313 auto constExpr = result.dyn_cast<AffineConstantExpr>();
3314 if (!constExpr)
3315 return parser.emitError(parser.getNameLoc(),
3316 "steps must be constant integers");
3317 steps.push_back(constExpr.getValue());
3318 }
3319 result.addAttribute(AffineParallelOp::getStepsAttrName(),
3320 builder.getI64ArrayAttr(steps));
3321 }
3322
3323 // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
3324 // quoted strings are a member of the enum AtomicRMWKind.
3325 SmallVector<Attribute, 4> reductions;
3326 if (succeeded(parser.parseOptionalKeyword("reduce"))) {
3327 if (parser.parseLParen())
3328 return failure();
3329 do {
3330 // Parse a single quoted string via the attribute parsing, and then
3331 // verify it is a member of the enum and convert to it's integer
3332 // representation.
3333 StringAttr attrVal;
3334 NamedAttrList attrStorage;
3335 auto loc = parser.getCurrentLocation();
3336 if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
3337 attrStorage))
3338 return failure();
3339 llvm::Optional<AtomicRMWKind> reduction =
3340 symbolizeAtomicRMWKind(attrVal.getValue());
3341 if (!reduction)
3342 return parser.emitError(loc, "invalid reduction value: ") << attrVal;
3343 reductions.push_back(builder.getI64IntegerAttr(
3344 static_cast<int64_t>(reduction.getValue())));
3345 // While we keep getting commas, keep parsing.
3346 } while (succeeded(parser.parseOptionalComma()));
3347 if (parser.parseRParen())
3348 return failure();
3349 }
3350 result.addAttribute(AffineParallelOp::getReductionsAttrName(),
3351 builder.getArrayAttr(reductions));
3352
3353 // Parse return types of reductions (if any)
3354 if (parser.parseOptionalArrowTypeList(result.types))
3355 return failure();
3356
3357 // Now parse the body.
3358 Region *body = result.addRegion();
3359 SmallVector<Type, 4> types(ivs.size(), indexType);
3360 if (parser.parseRegion(*body, ivs, types) ||
3361 parser.parseOptionalAttrDict(result.attributes))
3362 return failure();
3363
3364 // Add a terminator if none was parsed.
3365 AffineParallelOp::ensureTerminator(*body, builder, result.location);
3366 return success();
3367 }
3368
3369 //===----------------------------------------------------------------------===//
3370 // AffineYieldOp
3371 //===----------------------------------------------------------------------===//
3372
verify(AffineYieldOp op)3373 static LogicalResult verify(AffineYieldOp op) {
3374 auto *parentOp = op->getParentOp();
3375 auto results = parentOp->getResults();
3376 auto operands = op.getOperands();
3377
3378 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
3379 return op.emitOpError() << "only terminates affine.if/for/parallel regions";
3380 if (parentOp->getNumResults() != op.getNumOperands())
3381 return op.emitOpError() << "parent of yield must have same number of "
3382 "results as the yield operands";
3383 for (auto it : llvm::zip(results, operands)) {
3384 if (std::get<0>(it).getType() != std::get<1>(it).getType())
3385 return op.emitOpError()
3386 << "types mismatch between yield op and its parent";
3387 }
3388
3389 return success();
3390 }
3391
3392 //===----------------------------------------------------------------------===//
3393 // AffineVectorLoadOp
3394 //===----------------------------------------------------------------------===//
3395
build(OpBuilder & builder,OperationState & result,VectorType resultType,AffineMap map,ValueRange operands)3396 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
3397 VectorType resultType, AffineMap map,
3398 ValueRange operands) {
3399 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3400 result.addOperands(operands);
3401 if (map)
3402 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
3403 result.types.push_back(resultType);
3404 }
3405
build(OpBuilder & builder,OperationState & result,VectorType resultType,Value memref,AffineMap map,ValueRange mapOperands)3406 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
3407 VectorType resultType, Value memref,
3408 AffineMap map, ValueRange mapOperands) {
3409 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3410 result.addOperands(memref);
3411 result.addOperands(mapOperands);
3412 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
3413 result.types.push_back(resultType);
3414 }
3415
build(OpBuilder & builder,OperationState & result,VectorType resultType,Value memref,ValueRange indices)3416 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
3417 VectorType resultType, Value memref,
3418 ValueRange indices) {
3419 auto memrefType = memref.getType().cast<MemRefType>();
3420 int64_t rank = memrefType.getRank();
3421 // Create identity map for memrefs with at least one dimension or () -> ()
3422 // for zero-dimensional memrefs.
3423 auto map =
3424 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3425 build(builder, result, resultType, memref, map, indices);
3426 }
3427
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)3428 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3429 MLIRContext *context) {
3430 results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
3431 }
3432
parseAffineVectorLoadOp(OpAsmParser & parser,OperationState & result)3433 static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser,
3434 OperationState &result) {
3435 auto &builder = parser.getBuilder();
3436 auto indexTy = builder.getIndexType();
3437
3438 MemRefType memrefType;
3439 VectorType resultType;
3440 OpAsmParser::OperandType memrefInfo;
3441 AffineMapAttr mapAttr;
3442 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
3443 return failure(
3444 parser.parseOperand(memrefInfo) ||
3445 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3446 AffineVectorLoadOp::getMapAttrName(),
3447 result.attributes) ||
3448 parser.parseOptionalAttrDict(result.attributes) ||
3449 parser.parseColonType(memrefType) || parser.parseComma() ||
3450 parser.parseType(resultType) ||
3451 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
3452 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3453 parser.addTypeToList(resultType, result.types));
3454 }
3455
print(OpAsmPrinter & p,AffineVectorLoadOp op)3456 static void print(OpAsmPrinter &p, AffineVectorLoadOp op) {
3457 p << "affine.vector_load " << op.getMemRef() << '[';
3458 if (AffineMapAttr mapAttr =
3459 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
3460 p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
3461 p << ']';
3462 p.printOptionalAttrDict(op->getAttrs(),
3463 /*elidedAttrs=*/{op.getMapAttrName()});
3464 p << " : " << op.getMemRefType() << ", " << op.getType();
3465 }
3466
3467 /// Verify common invariants of affine.vector_load and affine.vector_store.
verifyVectorMemoryOp(Operation * op,MemRefType memrefType,VectorType vectorType)3468 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
3469 VectorType vectorType) {
3470 // Check that memref and vector element types match.
3471 if (memrefType.getElementType() != vectorType.getElementType())
3472 return op->emitOpError(
3473 "requires memref and vector types of the same elemental type");
3474 return success();
3475 }
3476
verify(AffineVectorLoadOp op)3477 static LogicalResult verify(AffineVectorLoadOp op) {
3478 MemRefType memrefType = op.getMemRefType();
3479 if (failed(verifyMemoryOpIndexing(
3480 op.getOperation(),
3481 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
3482 op.getMapOperands(), memrefType,
3483 /*numIndexOperands=*/op.getNumOperands() - 1)))
3484 return failure();
3485
3486 if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
3487 op.getVectorType())))
3488 return failure();
3489
3490 return success();
3491 }
3492
3493 //===----------------------------------------------------------------------===//
3494 // AffineVectorStoreOp
3495 //===----------------------------------------------------------------------===//
3496
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,AffineMap map,ValueRange mapOperands)3497 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
3498 Value valueToStore, Value memref, AffineMap map,
3499 ValueRange mapOperands) {
3500 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3501 result.addOperands(valueToStore);
3502 result.addOperands(memref);
3503 result.addOperands(mapOperands);
3504 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
3505 }
3506
3507 // Use identity map.
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,ValueRange indices)3508 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
3509 Value valueToStore, Value memref,
3510 ValueRange indices) {
3511 auto memrefType = memref.getType().cast<MemRefType>();
3512 int64_t rank = memrefType.getRank();
3513 // Create identity map for memrefs with at least one dimension or () -> ()
3514 // for zero-dimensional memrefs.
3515 auto map =
3516 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3517 build(builder, result, valueToStore, memref, map, indices);
3518 }
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)3519 void AffineVectorStoreOp::getCanonicalizationPatterns(
3520 RewritePatternSet &results, MLIRContext *context) {
3521 results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
3522 }
3523
parseAffineVectorStoreOp(OpAsmParser & parser,OperationState & result)3524 static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser,
3525 OperationState &result) {
3526 auto indexTy = parser.getBuilder().getIndexType();
3527
3528 MemRefType memrefType;
3529 VectorType resultType;
3530 OpAsmParser::OperandType storeValueInfo;
3531 OpAsmParser::OperandType memrefInfo;
3532 AffineMapAttr mapAttr;
3533 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
3534 return failure(
3535 parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3536 parser.parseOperand(memrefInfo) ||
3537 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3538 AffineVectorStoreOp::getMapAttrName(),
3539 result.attributes) ||
3540 parser.parseOptionalAttrDict(result.attributes) ||
3541 parser.parseColonType(memrefType) || parser.parseComma() ||
3542 parser.parseType(resultType) ||
3543 parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
3544 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
3545 parser.resolveOperands(mapOperands, indexTy, result.operands));
3546 }
3547
print(OpAsmPrinter & p,AffineVectorStoreOp op)3548 static void print(OpAsmPrinter &p, AffineVectorStoreOp op) {
3549 p << "affine.vector_store " << op.getValueToStore();
3550 p << ", " << op.getMemRef() << '[';
3551 if (AffineMapAttr mapAttr =
3552 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
3553 p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
3554 p << ']';
3555 p.printOptionalAttrDict(op->getAttrs(),
3556 /*elidedAttrs=*/{op.getMapAttrName()});
3557 p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType();
3558 }
3559
verify(AffineVectorStoreOp op)3560 static LogicalResult verify(AffineVectorStoreOp op) {
3561 MemRefType memrefType = op.getMemRefType();
3562 if (failed(verifyMemoryOpIndexing(
3563 op.getOperation(),
3564 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
3565 op.getMapOperands(), memrefType,
3566 /*numIndexOperands=*/op.getNumOperands() - 2)))
3567 return failure();
3568
3569 if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
3570 op.getVectorType())))
3571 return failure();
3572
3573 return success();
3574 }
3575
3576 //===----------------------------------------------------------------------===//
3577 // TableGen'd op method definitions
3578 //===----------------------------------------------------------------------===//
3579
3580 #define GET_OP_CLASSES
3581 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
3582