1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/BlockAndValueMapping.h"
13 #include "mlir/IR/BuiltinOps.h"
14 #include "mlir/IR/IntegerSet.h"
15 #include "mlir/IR/Matchers.h"
16 #include "mlir/IR/OpImplementation.h"
17 #include "mlir/IR/PatternMatch.h"
18 #include "mlir/Transforms/InliningUtils.h"
19 #include "llvm/ADT/SetVector.h"
20 #include "llvm/ADT/SmallBitVector.h"
21 #include "llvm/ADT/TypeSwitch.h"
22 #include "llvm/Support/Debug.h"
23
24 using namespace mlir;
25 using llvm::dbgs;
26
27 #define DEBUG_TYPE "affine-analysis"
28
29 /// A utility function to check if a value is defined at the top level of
30 /// `region` or is an argument of `region`. A value of index type defined at the
31 /// top level of a `AffineScope` region is always a valid symbol for all
32 /// uses in that region.
isTopLevelValue(Value value,Region * region)33 static bool isTopLevelValue(Value value, Region *region) {
34 if (auto arg = value.dyn_cast<BlockArgument>())
35 return arg.getParentRegion() == region;
36 return value.getDefiningOp()->getParentRegion() == region;
37 }
38
39 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
40 /// region remains legal if the operation that uses it is inlined into `dest`
41 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
42 /// `isValidSymbol`, depending on the value being required to remain a valid
43 /// dimension or symbol.
44 static bool
remainsLegalAfterInline(Value value,Region * src,Region * dest,const BlockAndValueMapping & mapping,function_ref<bool (Value,Region *)> legalityCheck)45 remainsLegalAfterInline(Value value, Region *src, Region *dest,
46 const BlockAndValueMapping &mapping,
47 function_ref<bool(Value, Region *)> legalityCheck) {
48 // If the value is a valid dimension for any other reason than being
49 // a top-level value, it will remain valid: constants get inlined
50 // with the function, transitive affine applies also get inlined and
51 // will be checked themselves, etc.
52 if (!isTopLevelValue(value, src))
53 return true;
54
55 // If it's a top-level value because it's a block operand, i.e. a
56 // function argument, check whether the value replacing it after
57 // inlining is a valid dimension in the new region.
58 if (value.isa<BlockArgument>())
59 return legalityCheck(mapping.lookup(value), dest);
60
61 // If it's a top-level value beacuse it's defined in the region,
62 // it can only be inlined if the defining op is a constant or a
63 // `dim`, which can appear anywhere and be valid, since the defining
64 // op won't be top-level anymore after inlining.
65 Attribute operandCst;
66 return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
67 value.getDefiningOp<DimOp>();
68 }
69
70 /// Checks if all values known to be legal affine dimensions or symbols in `src`
71 /// remain so if their respective users are inlined into `dest`.
72 static bool
remainsLegalAfterInline(ValueRange values,Region * src,Region * dest,const BlockAndValueMapping & mapping,function_ref<bool (Value,Region *)> legalityCheck)73 remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
74 const BlockAndValueMapping &mapping,
75 function_ref<bool(Value, Region *)> legalityCheck) {
76 return llvm::all_of(values, [&](Value v) {
77 return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
78 });
79 }
80
81 /// Checks if an affine read or write operation remains legal after inlining
82 /// from `src` to `dest`.
83 template <typename OpTy>
remainsLegalAfterInline(OpTy op,Region * src,Region * dest,const BlockAndValueMapping & mapping)84 static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
85 const BlockAndValueMapping &mapping) {
86 static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
87 AffineWriteOpInterface>::value,
88 "only ops with affine read/write interface are supported");
89
90 AffineMap map = op.getAffineMap();
91 ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
92 ValueRange symbolOperands =
93 op.getMapOperands().take_back(map.getNumSymbols());
94 if (!remainsLegalAfterInline(
95 dimOperands, src, dest, mapping,
96 static_cast<bool (*)(Value, Region *)>(isValidDim)))
97 return false;
98 if (!remainsLegalAfterInline(
99 symbolOperands, src, dest, mapping,
100 static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
101 return false;
102 return true;
103 }
104
105 /// Checks if an affine apply operation remains legal after inlining from `src`
106 /// to `dest`.
107 template <>
remainsLegalAfterInline(AffineApplyOp op,Region * src,Region * dest,const BlockAndValueMapping & mapping)108 bool remainsLegalAfterInline(AffineApplyOp op, Region *src, Region *dest,
109 const BlockAndValueMapping &mapping) {
110 // If it's a valid dimension, we need to check that it remains so.
111 if (isValidDim(op.getResult(), src))
112 return remainsLegalAfterInline(
113 op.getMapOperands(), src, dest, mapping,
114 static_cast<bool (*)(Value, Region *)>(isValidDim));
115
116 // Otherwise it must be a valid symbol, check that it remains so.
117 return remainsLegalAfterInline(
118 op.getMapOperands(), src, dest, mapping,
119 static_cast<bool (*)(Value, Region *)>(isValidSymbol));
120 }
121
122 //===----------------------------------------------------------------------===//
123 // AffineDialect Interfaces
124 //===----------------------------------------------------------------------===//
125
126 namespace {
127 /// This class defines the interface for handling inlining with affine
128 /// operations.
129 struct AffineInlinerInterface : public DialectInlinerInterface {
130 using DialectInlinerInterface::DialectInlinerInterface;
131
132 //===--------------------------------------------------------------------===//
133 // Analysis Hooks
134 //===--------------------------------------------------------------------===//
135
136 /// Returns true if the given region 'src' can be inlined into the region
137 /// 'dest' that is attached to an operation registered to the current dialect.
138 /// 'wouldBeCloned' is set if the region is cloned into its new location
139 /// rather than moved, indicating there may be other users.
isLegalToInline__anone6b2d6170211::AffineInlinerInterface140 bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
141 BlockAndValueMapping &valueMapping) const final {
142 // We can inline into affine loops and conditionals if this doesn't break
143 // affine value categorization rules.
144 Operation *destOp = dest->getParentOp();
145 if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
146 return false;
147
148 // Multi-block regions cannot be inlined into affine constructs, all of
149 // which require single-block regions.
150 if (!llvm::hasSingleElement(*src))
151 return false;
152
153 // Side-effecting operations that the affine dialect cannot understand
154 // should not be inlined.
155 Block &srcBlock = src->front();
156 for (Operation &op : srcBlock) {
157 // Ops with no side effects are fine,
158 if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
159 if (iface.hasNoEffect())
160 continue;
161 }
162
163 // Assuming the inlined region is valid, we only need to check if the
164 // inlining would change it.
165 bool remainsValid =
166 llvm::TypeSwitch<Operation *, bool>(&op)
167 .Case<AffineApplyOp, AffineReadOpInterface,
168 AffineWriteOpInterface>([&](auto op) {
169 return remainsLegalAfterInline(op, src, dest, valueMapping);
170 })
171 .Default([](Operation *) {
172 // Conservatively disallow inlining ops we cannot reason about.
173 return false;
174 });
175
176 if (!remainsValid)
177 return false;
178 }
179
180 return true;
181 }
182
183 /// Returns true if the given operation 'op', that is registered to this
184 /// dialect, can be inlined into the given region, false otherwise.
isLegalToInline__anone6b2d6170211::AffineInlinerInterface185 bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
186 BlockAndValueMapping &valueMapping) const final {
187 // Always allow inlining affine operations into a region that is marked as
188 // affine scope, or into affine loops and conditionals. There are some edge
189 // cases when inlining *into* affine structures, but that is handled in the
190 // other 'isLegalToInline' hook above.
191 Operation *parentOp = region->getParentOp();
192 return parentOp->hasTrait<OpTrait::AffineScope>() ||
193 isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
194 }
195
196 /// Affine regions should be analyzed recursively.
shouldAnalyzeRecursively__anone6b2d6170211::AffineInlinerInterface197 bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
198 };
199 } // end anonymous namespace
200
201 //===----------------------------------------------------------------------===//
202 // AffineDialect
203 //===----------------------------------------------------------------------===//
204
initialize()205 void AffineDialect::initialize() {
206 addOperations<AffineDmaStartOp, AffineDmaWaitOp,
207 #define GET_OP_LIST
208 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
209 >();
210 addInterfaces<AffineInlinerInterface>();
211 }
212
213 /// Materialize a single constant operation from a given attribute value with
214 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)215 Operation *AffineDialect::materializeConstant(OpBuilder &builder,
216 Attribute value, Type type,
217 Location loc) {
218 return builder.create<ConstantOp>(loc, type, value);
219 }
220
221 /// A utility function to check if a value is defined at the top level of an
222 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
223 /// conservatively assume it is not top-level. A value of index type defined at
224 /// the top level is always a valid symbol.
isTopLevelValue(Value value)225 bool mlir::isTopLevelValue(Value value) {
226 if (auto arg = value.dyn_cast<BlockArgument>()) {
227 // The block owning the argument may be unlinked, e.g. when the surrounding
228 // region has not yet been attached to an Op, at which point the parent Op
229 // is null.
230 Operation *parentOp = arg.getOwner()->getParentOp();
231 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
232 }
233 // The defining Op may live in an unlinked block so its parent Op may be null.
234 Operation *parentOp = value.getDefiningOp()->getParentOp();
235 return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
236 }
237
238 /// Returns the closest region enclosing `op` that is held by an operation with
239 /// trait `AffineScope`; `nullptr` if there is no such region.
240 // TODO: getAffineScope should be publicly exposed for affine passes/utilities.
getAffineScope(Operation * op)241 static Region *getAffineScope(Operation *op) {
242 auto *curOp = op;
243 while (auto *parentOp = curOp->getParentOp()) {
244 if (parentOp->hasTrait<OpTrait::AffineScope>())
245 return curOp->getParentRegion();
246 curOp = parentOp;
247 }
248 return nullptr;
249 }
250
251 // A Value can be used as a dimension id iff it meets one of the following
252 // conditions:
253 // *) It is valid as a symbol.
254 // *) It is an induction variable.
255 // *) It is the result of affine apply operation with dimension id arguments.
isValidDim(Value value)256 bool mlir::isValidDim(Value value) {
257 // The value must be an index type.
258 if (!value.getType().isIndex())
259 return false;
260
261 if (auto *defOp = value.getDefiningOp())
262 return isValidDim(value, getAffineScope(defOp));
263
264 // This value has to be a block argument for an op that has the
265 // `AffineScope` trait or for an affine.for or affine.parallel.
266 auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
267 return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
268 isa<AffineForOp, AffineParallelOp>(parentOp));
269 }
270
271 // Value can be used as a dimension id iff it meets one of the following
272 // conditions:
273 // *) It is valid as a symbol.
274 // *) It is an induction variable.
275 // *) It is the result of an affine apply operation with dimension id operands.
isValidDim(Value value,Region * region)276 bool mlir::isValidDim(Value value, Region *region) {
277 // The value must be an index type.
278 if (!value.getType().isIndex())
279 return false;
280
281 // All valid symbols are okay.
282 if (isValidSymbol(value, region))
283 return true;
284
285 auto *op = value.getDefiningOp();
286 if (!op) {
287 // This value has to be a block argument for an affine.for or an
288 // affine.parallel.
289 auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
290 return isa<AffineForOp, AffineParallelOp>(parentOp);
291 }
292
293 // Affine apply operation is ok if all of its operands are ok.
294 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
295 return applyOp.isValidDim(region);
296 // The dim op is okay if its operand memref/tensor is defined at the top
297 // level.
298 if (auto dimOp = dyn_cast<DimOp>(op))
299 return isTopLevelValue(dimOp.memrefOrTensor());
300 return false;
301 }
302
303 /// Returns true if the 'index' dimension of the `memref` defined by
304 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol
305 /// for `region`.
306 template <typename AnyMemRefDefOp>
isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp,unsigned index,Region * region)307 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
308 Region *region) {
309 auto memRefType = memrefDefOp.getType();
310 // Statically shaped.
311 if (!memRefType.isDynamicDim(index))
312 return true;
313 // Get the position of the dimension among dynamic dimensions;
314 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
315 return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
316 region);
317 }
318
319 /// Returns true if the result of the dim op is a valid symbol for `region`.
isDimOpValidSymbol(DimOp dimOp,Region * region)320 static bool isDimOpValidSymbol(DimOp dimOp, Region *region) {
321 // The dim op is okay if its operand memref/tensor is defined at the top
322 // level.
323 if (isTopLevelValue(dimOp.memrefOrTensor()))
324 return true;
325
326 // Conservatively handle remaining BlockArguments as non-valid symbols.
327 // E.g. scf.for iterArgs.
328 if (dimOp.memrefOrTensor().isa<BlockArgument>())
329 return false;
330
331 // The dim op is also okay if its operand memref/tensor is a view/subview
332 // whose corresponding size is a valid symbol.
333 Optional<int64_t> index = dimOp.getConstantIndex();
334 assert(index.hasValue() &&
335 "expect only `dim` operations with a constant index");
336 int64_t i = index.getValue();
337 return TypeSwitch<Operation *, bool>(dimOp.memrefOrTensor().getDefiningOp())
338 .Case<ViewOp, SubViewOp, AllocOp>(
339 [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
340 .Default([](Operation *) { return false; });
341 }
342
343 // A value can be used as a symbol (at all its use sites) iff it meets one of
344 // the following conditions:
345 // *) It is a constant.
346 // *) Its defining op or block arg appearance is immediately enclosed by an op
347 // with `AffineScope` trait.
348 // *) It is the result of an affine.apply operation with symbol operands.
349 // *) It is a result of the dim op on a memref whose corresponding size is a
350 // valid symbol.
isValidSymbol(Value value)351 bool mlir::isValidSymbol(Value value) {
352 // The value must be an index type.
353 if (!value.getType().isIndex())
354 return false;
355
356 // Check that the value is a top level value.
357 if (isTopLevelValue(value))
358 return true;
359
360 if (auto *defOp = value.getDefiningOp())
361 return isValidSymbol(value, getAffineScope(defOp));
362
363 return false;
364 }
365
366 /// A value can be used as a symbol for `region` iff it meets onf of the the
367 /// following conditions:
368 /// *) It is a constant.
369 /// *) It is the result of an affine apply operation with symbol arguments.
370 /// *) It is a result of the dim op on a memref whose corresponding size is
371 /// a valid symbol.
372 /// *) It is defined at the top level of 'region' or is its argument.
373 /// *) It dominates `region`'s parent op.
374 /// If `region` is null, conservatively assume the symbol definition scope does
375 /// not exist and only accept the values that would be symbols regardless of
376 /// the surrounding region structure, i.e. the first three cases above.
isValidSymbol(Value value,Region * region)377 bool mlir::isValidSymbol(Value value, Region *region) {
378 // The value must be an index type.
379 if (!value.getType().isIndex())
380 return false;
381
382 // A top-level value is a valid symbol.
383 if (region && ::isTopLevelValue(value, region))
384 return true;
385
386 auto *defOp = value.getDefiningOp();
387 if (!defOp) {
388 // A block argument that is not a top-level value is a valid symbol if it
389 // dominates region's parent op.
390 if (region && !region->getParentOp()->isKnownIsolatedFromAbove())
391 if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
392 return isValidSymbol(value, parentOpRegion);
393 return false;
394 }
395
396 // Constant operation is ok.
397 Attribute operandCst;
398 if (matchPattern(defOp, m_Constant(&operandCst)))
399 return true;
400
401 // Affine apply operation is ok if all of its operands are ok.
402 if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
403 return applyOp.isValidSymbol(region);
404
405 // Dim op results could be valid symbols at any level.
406 if (auto dimOp = dyn_cast<DimOp>(defOp))
407 return isDimOpValidSymbol(dimOp, region);
408
409 // Check for values dominating `region`'s parent op.
410 if (region && !region->getParentOp()->isKnownIsolatedFromAbove())
411 if (auto *parentRegion = region->getParentOp()->getParentRegion())
412 return isValidSymbol(value, parentRegion);
413
414 return false;
415 }
416
417 // Returns true if 'value' is a valid index to an affine operation (e.g.
418 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
419 // `region` provides the polyhedral symbol scope. Returns false otherwise.
isValidAffineIndexOperand(Value value,Region * region)420 static bool isValidAffineIndexOperand(Value value, Region *region) {
421 return isValidDim(value, region) || isValidSymbol(value, region);
422 }
423
424 /// Prints dimension and symbol list.
printDimAndSymbolList(Operation::operand_iterator begin,Operation::operand_iterator end,unsigned numDims,OpAsmPrinter & printer)425 static void printDimAndSymbolList(Operation::operand_iterator begin,
426 Operation::operand_iterator end,
427 unsigned numDims, OpAsmPrinter &printer) {
428 OperandRange operands(begin, end);
429 printer << '(' << operands.take_front(numDims) << ')';
430 if (operands.size() > numDims)
431 printer << '[' << operands.drop_front(numDims) << ']';
432 }
433
434 /// Parses dimension and symbol list and returns true if parsing failed.
parseDimAndSymbolList(OpAsmParser & parser,SmallVectorImpl<Value> & operands,unsigned & numDims)435 ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
436 SmallVectorImpl<Value> &operands,
437 unsigned &numDims) {
438 SmallVector<OpAsmParser::OperandType, 8> opInfos;
439 if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
440 return failure();
441 // Store number of dimensions for validation by caller.
442 numDims = opInfos.size();
443
444 // Parse the optional symbol operands.
445 auto indexTy = parser.getBuilder().getIndexType();
446 return failure(parser.parseOperandList(
447 opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
448 parser.resolveOperands(opInfos, indexTy, operands));
449 }
450
451 /// Utility function to verify that a set of operands are valid dimension and
452 /// symbol identifiers. The operands should be laid out such that the dimension
453 /// operands are before the symbol operands. This function returns failure if
454 /// there was an invalid operand. An operation is provided to emit any necessary
455 /// errors.
456 template <typename OpTy>
457 static LogicalResult
verifyDimAndSymbolIdentifiers(OpTy & op,Operation::operand_range operands,unsigned numDims)458 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
459 unsigned numDims) {
460 unsigned opIt = 0;
461 for (auto operand : operands) {
462 if (opIt++ < numDims) {
463 if (!isValidDim(operand, getAffineScope(op)))
464 return op.emitOpError("operand cannot be used as a dimension id");
465 } else if (!isValidSymbol(operand, getAffineScope(op))) {
466 return op.emitOpError("operand cannot be used as a symbol");
467 }
468 }
469 return success();
470 }
471
472 //===----------------------------------------------------------------------===//
473 // AffineApplyOp
474 //===----------------------------------------------------------------------===//
475
getAffineValueMap()476 AffineValueMap AffineApplyOp::getAffineValueMap() {
477 return AffineValueMap(getAffineMap(), getOperands(), getResult());
478 }
479
parseAffineApplyOp(OpAsmParser & parser,OperationState & result)480 static ParseResult parseAffineApplyOp(OpAsmParser &parser,
481 OperationState &result) {
482 auto &builder = parser.getBuilder();
483 auto indexTy = builder.getIndexType();
484
485 AffineMapAttr mapAttr;
486 unsigned numDims;
487 if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
488 parseDimAndSymbolList(parser, result.operands, numDims) ||
489 parser.parseOptionalAttrDict(result.attributes))
490 return failure();
491 auto map = mapAttr.getValue();
492
493 if (map.getNumDims() != numDims ||
494 numDims + map.getNumSymbols() != result.operands.size()) {
495 return parser.emitError(parser.getNameLoc(),
496 "dimension or symbol index mismatch");
497 }
498
499 result.types.append(map.getNumResults(), indexTy);
500 return success();
501 }
502
print(OpAsmPrinter & p,AffineApplyOp op)503 static void print(OpAsmPrinter &p, AffineApplyOp op) {
504 p << AffineApplyOp::getOperationName() << " " << op.mapAttr();
505 printDimAndSymbolList(op.operand_begin(), op.operand_end(),
506 op.getAffineMap().getNumDims(), p);
507 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
508 }
509
verify(AffineApplyOp op)510 static LogicalResult verify(AffineApplyOp op) {
511 // Check input and output dimensions match.
512 auto map = op.map();
513
514 // Verify that operand count matches affine map dimension and symbol count.
515 if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols())
516 return op.emitOpError(
517 "operand count and affine map dimension and symbol count must match");
518
519 // Verify that the map only produces one result.
520 if (map.getNumResults() != 1)
521 return op.emitOpError("mapping must produce one value");
522
523 return success();
524 }
525
526 // The result of the affine apply operation can be used as a dimension id if all
527 // its operands are valid dimension ids.
isValidDim()528 bool AffineApplyOp::isValidDim() {
529 return llvm::all_of(getOperands(),
530 [](Value op) { return mlir::isValidDim(op); });
531 }
532
533 // The result of the affine apply operation can be used as a dimension id if all
534 // its operands are valid dimension ids with the parent operation of `region`
535 // defining the polyhedral scope for symbols.
isValidDim(Region * region)536 bool AffineApplyOp::isValidDim(Region *region) {
537 return llvm::all_of(getOperands(),
538 [&](Value op) { return ::isValidDim(op, region); });
539 }
540
541 // The result of the affine apply operation can be used as a symbol if all its
542 // operands are symbols.
isValidSymbol()543 bool AffineApplyOp::isValidSymbol() {
544 return llvm::all_of(getOperands(),
545 [](Value op) { return mlir::isValidSymbol(op); });
546 }
547
548 // The result of the affine apply operation can be used as a symbol in `region`
549 // if all its operands are symbols in `region`.
isValidSymbol(Region * region)550 bool AffineApplyOp::isValidSymbol(Region *region) {
551 return llvm::all_of(getOperands(), [&](Value operand) {
552 return mlir::isValidSymbol(operand, region);
553 });
554 }
555
fold(ArrayRef<Attribute> operands)556 OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
557 auto map = getAffineMap();
558
559 // Fold dims and symbols to existing values.
560 auto expr = map.getResult(0);
561 if (auto dim = expr.dyn_cast<AffineDimExpr>())
562 return getOperand(dim.getPosition());
563 if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
564 return getOperand(map.getNumDims() + sym.getPosition());
565
566 // Otherwise, default to folding the map.
567 SmallVector<Attribute, 1> result;
568 if (failed(map.constantFold(operands, result)))
569 return {};
570 return result[0];
571 }
572
573 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
574 /// defining AffineApplyOp expression and operands.
575 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
576 /// When `dimOrSymbolPosition >= dims.size()`,
577 /// AffineSymbolExpr@[pos - dims.size()] is replaced.
578 /// Mutate `map`,`dims` and `syms` in place as follows:
579 /// 1. `dims` and `syms` are only appended to.
580 /// 2. `map` dim and symbols are gradually shifted to higer positions.
581 /// 3. Old `dim` and `sym` entries are replaced by nullptr
582 /// This avoids the need for any bookkeeping.
replaceDimOrSym(AffineMap * map,unsigned dimOrSymbolPosition,SmallVectorImpl<Value> & dims,SmallVectorImpl<Value> & syms)583 static LogicalResult replaceDimOrSym(AffineMap *map,
584 unsigned dimOrSymbolPosition,
585 SmallVectorImpl<Value> &dims,
586 SmallVectorImpl<Value> &syms) {
587 bool isDimReplacement = (dimOrSymbolPosition < dims.size());
588 unsigned pos = isDimReplacement ? dimOrSymbolPosition
589 : dimOrSymbolPosition - dims.size();
590 Value &v = isDimReplacement ? dims[pos] : syms[pos];
591 if (!v)
592 return failure();
593
594 auto affineApply = v.getDefiningOp<AffineApplyOp>();
595 if (!affineApply)
596 return failure();
597
598 // At this point we will perform a replacement of `v`, set the entry in `dim`
599 // or `sym` to nullptr immediately.
600 v = nullptr;
601
602 // Compute the map, dims and symbols coming from the AffineApplyOp.
603 AffineMap composeMap = affineApply.getAffineMap();
604 assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
605 AffineExpr composeExpr =
606 composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
607 ValueRange composeDims =
608 affineApply.getMapOperands().take_front(composeMap.getNumDims());
609 ValueRange composeSyms =
610 affineApply.getMapOperands().take_back(composeMap.getNumSymbols());
611
612 // Perform the replacement and append the dims and symbols where relevant.
613 MLIRContext *ctx = map->getContext();
614 AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
615 : getAffineSymbolExpr(pos, ctx);
616 *map = map->replace(toReplace, composeExpr, dims.size(), syms.size());
617 dims.append(composeDims.begin(), composeDims.end());
618 syms.append(composeSyms.begin(), composeSyms.end());
619
620 return success();
621 }
622
623 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
624 /// iteratively. Perform canonicalization of map and operands as well as
625 /// AffineMap simplification. `map` and `operands` are mutated in place.
composeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)626 static void composeAffineMapAndOperands(AffineMap *map,
627 SmallVectorImpl<Value> *operands) {
628 if (map->getNumResults() == 0) {
629 canonicalizeMapAndOperands(map, operands);
630 *map = simplifyAffineMap(*map);
631 return;
632 }
633
634 MLIRContext *ctx = map->getContext();
635 SmallVector<Value, 4> dims(operands->begin(),
636 operands->begin() + map->getNumDims());
637 SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
638 operands->end());
639
640 // Iterate over dims and symbols coming from AffineApplyOp and replace until
641 // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
642 // and `syms` can only increase by construction.
643 // The implementation uses a `while` loop to support the case of symbols
644 // that may be constructed from dims ;this may be overkill.
645 while (true) {
646 bool changed = false;
647 for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
648 if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
649 break;
650 if (!changed)
651 break;
652 }
653
654 // Clear operands so we can fill them anew.
655 operands->clear();
656
657 // At this point we may have introduced null operands, prune them out before
658 // canonicalizing map and operands.
659 unsigned nDims = 0, nSyms = 0;
660 SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
661 dimReplacements.reserve(dims.size());
662 symReplacements.reserve(syms.size());
663 for (auto *container : {&dims, &syms}) {
664 bool isDim = (container == &dims);
665 auto &repls = isDim ? dimReplacements : symReplacements;
666 for (auto en : llvm::enumerate(*container)) {
667 Value v = en.value();
668 if (!v) {
669 assert(isDim ? !map->isFunctionOfDim(en.index())
670 : !map->isFunctionOfSymbol(en.index()) &&
671 "map is function of unexpected expr@pos");
672 repls.push_back(getAffineConstantExpr(0, ctx));
673 continue;
674 }
675 repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
676 : getAffineSymbolExpr(nSyms++, ctx));
677 operands->push_back(v);
678 }
679 }
680 *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
681 nSyms);
682
683 // Canonicalize and simplify before returning.
684 canonicalizeMapAndOperands(map, operands);
685 *map = simplifyAffineMap(*map);
686 }
687
fullyComposeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)688 void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
689 SmallVectorImpl<Value> *operands) {
690 while (llvm::any_of(*operands, [](Value v) {
691 return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
692 })) {
693 composeAffineMapAndOperands(map, operands);
694 }
695 }
696
makeComposedAffineApply(OpBuilder & b,Location loc,AffineMap map,ArrayRef<Value> operands)697 AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
698 AffineMap map,
699 ArrayRef<Value> operands) {
700 AffineMap normalizedMap = map;
701 SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end());
702 composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
703 assert(normalizedMap);
704 return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
705 }
706
707 // A symbol may appear as a dim in affine.apply operations. This function
708 // canonicalizes dims that are valid symbols into actual symbols.
709 template <class MapOrSet>
canonicalizePromotedSymbols(MapOrSet * mapOrSet,SmallVectorImpl<Value> * operands)710 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
711 SmallVectorImpl<Value> *operands) {
712 if (!mapOrSet || operands->empty())
713 return;
714
715 assert(mapOrSet->getNumInputs() == operands->size() &&
716 "map/set inputs must match number of operands");
717
718 auto *context = mapOrSet->getContext();
719 SmallVector<Value, 8> resultOperands;
720 resultOperands.reserve(operands->size());
721 SmallVector<Value, 8> remappedSymbols;
722 remappedSymbols.reserve(operands->size());
723 unsigned nextDim = 0;
724 unsigned nextSym = 0;
725 unsigned oldNumSyms = mapOrSet->getNumSymbols();
726 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
727 for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
728 if (i < mapOrSet->getNumDims()) {
729 if (isValidSymbol((*operands)[i])) {
730 // This is a valid symbol that appears as a dim, canonicalize it.
731 dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
732 remappedSymbols.push_back((*operands)[i]);
733 } else {
734 dimRemapping[i] = getAffineDimExpr(nextDim++, context);
735 resultOperands.push_back((*operands)[i]);
736 }
737 } else {
738 resultOperands.push_back((*operands)[i]);
739 }
740 }
741
742 resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
743 *operands = resultOperands;
744 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
745 oldNumSyms + nextSym);
746
747 assert(mapOrSet->getNumInputs() == operands->size() &&
748 "map/set inputs must match number of operands");
749 }
750
751 // Works for either an affine map or an integer set.
752 template <class MapOrSet>
canonicalizeMapOrSetAndOperands(MapOrSet * mapOrSet,SmallVectorImpl<Value> * operands)753 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
754 SmallVectorImpl<Value> *operands) {
755 static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
756 "Argument must be either of AffineMap or IntegerSet type");
757
758 if (!mapOrSet || operands->empty())
759 return;
760
761 assert(mapOrSet->getNumInputs() == operands->size() &&
762 "map/set inputs must match number of operands");
763
764 canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
765
766 // Check to see what dims are used.
767 llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
768 llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
769 mapOrSet->walkExprs([&](AffineExpr expr) {
770 if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
771 usedDims[dimExpr.getPosition()] = true;
772 else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
773 usedSyms[symExpr.getPosition()] = true;
774 });
775
776 auto *context = mapOrSet->getContext();
777
778 SmallVector<Value, 8> resultOperands;
779 resultOperands.reserve(operands->size());
780
781 llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
782 SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
783 unsigned nextDim = 0;
784 for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
785 if (usedDims[i]) {
786 // Remap dim positions for duplicate operands.
787 auto it = seenDims.find((*operands)[i]);
788 if (it == seenDims.end()) {
789 dimRemapping[i] = getAffineDimExpr(nextDim++, context);
790 resultOperands.push_back((*operands)[i]);
791 seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
792 } else {
793 dimRemapping[i] = it->second;
794 }
795 }
796 }
797 llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
798 SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
799 unsigned nextSym = 0;
800 for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
801 if (!usedSyms[i])
802 continue;
803 // Handle constant operands (only needed for symbolic operands since
804 // constant operands in dimensional positions would have already been
805 // promoted to symbolic positions above).
806 IntegerAttr operandCst;
807 if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
808 m_Constant(&operandCst))) {
809 symRemapping[i] =
810 getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
811 continue;
812 }
813 // Remap symbol positions for duplicate operands.
814 auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
815 if (it == seenSymbols.end()) {
816 symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
817 resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
818 seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
819 symRemapping[i]));
820 } else {
821 symRemapping[i] = it->second;
822 }
823 }
824 *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
825 nextDim, nextSym);
826 *operands = resultOperands;
827 }
828
canonicalizeMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)829 void mlir::canonicalizeMapAndOperands(AffineMap *map,
830 SmallVectorImpl<Value> *operands) {
831 canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
832 }
833
canonicalizeSetAndOperands(IntegerSet * set,SmallVectorImpl<Value> * operands)834 void mlir::canonicalizeSetAndOperands(IntegerSet *set,
835 SmallVectorImpl<Value> *operands) {
836 canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
837 }
838
839 namespace {
840 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
841 /// maps that supply results into them.
842 ///
843 template <typename AffineOpTy>
844 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
845 using OpRewritePattern<AffineOpTy>::OpRewritePattern;
846
847 /// Replace the affine op with another instance of it with the supplied
848 /// map and mapOperands.
849 void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
850 AffineMap map, ArrayRef<Value> mapOperands) const;
851
matchAndRewrite__anone6b2d6170d11::SimplifyAffineOp852 LogicalResult matchAndRewrite(AffineOpTy affineOp,
853 PatternRewriter &rewriter) const override {
854 static_assert(llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
855 AffineStoreOp, AffineApplyOp, AffineMinOp,
856 AffineMaxOp>::value,
857 "affine load/store/apply/prefetch/min/max op expected");
858 auto map = affineOp.getAffineMap();
859 AffineMap oldMap = map;
860 auto oldOperands = affineOp.getMapOperands();
861 SmallVector<Value, 8> resultOperands(oldOperands);
862 composeAffineMapAndOperands(&map, &resultOperands);
863 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
864 resultOperands.begin()))
865 return failure();
866
867 replaceAffineOp(rewriter, affineOp, map, resultOperands);
868 return success();
869 }
870 };
871
872 // Specialize the template to account for the different build signatures for
873 // affine load, store, and apply ops.
874 template <>
replaceAffineOp(PatternRewriter & rewriter,AffineLoadOp load,AffineMap map,ArrayRef<Value> mapOperands) const875 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
876 PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
877 ArrayRef<Value> mapOperands) const {
878 rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
879 mapOperands);
880 }
881 template <>
replaceAffineOp(PatternRewriter & rewriter,AffinePrefetchOp prefetch,AffineMap map,ArrayRef<Value> mapOperands) const882 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
883 PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
884 ArrayRef<Value> mapOperands) const {
885 rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
886 prefetch, prefetch.memref(), map, mapOperands, prefetch.localityHint(),
887 prefetch.isWrite(), prefetch.isDataCache());
888 }
889 template <>
replaceAffineOp(PatternRewriter & rewriter,AffineStoreOp store,AffineMap map,ArrayRef<Value> mapOperands) const890 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
891 PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
892 ArrayRef<Value> mapOperands) const {
893 rewriter.replaceOpWithNewOp<AffineStoreOp>(
894 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
895 }
896
897 // Generic version for ops that don't have extra operands.
898 template <typename AffineOpTy>
replaceAffineOp(PatternRewriter & rewriter,AffineOpTy op,AffineMap map,ArrayRef<Value> mapOperands) const899 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
900 PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
901 ArrayRef<Value> mapOperands) const {
902 rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
903 }
904 } // end anonymous namespace.
905
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)906 void AffineApplyOp::getCanonicalizationPatterns(
907 OwningRewritePatternList &results, MLIRContext *context) {
908 results.insert<SimplifyAffineOp<AffineApplyOp>>(context);
909 }
910
911 //===----------------------------------------------------------------------===//
912 // Common canonicalization pattern support logic
913 //===----------------------------------------------------------------------===//
914
915 /// This is a common class used for patterns of the form
916 /// "someop(memrefcast) -> someop". It folds the source of any memref_cast
917 /// into the root operation directly.
foldMemRefCast(Operation * op)918 static LogicalResult foldMemRefCast(Operation *op) {
919 bool folded = false;
920 for (OpOperand &operand : op->getOpOperands()) {
921 auto cast = operand.get().getDefiningOp<MemRefCastOp>();
922 if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
923 operand.set(cast.getOperand());
924 folded = true;
925 }
926 }
927 return success(folded);
928 }
929
930 //===----------------------------------------------------------------------===//
931 // AffineDmaStartOp
932 //===----------------------------------------------------------------------===//
933
934 // TODO: Check that map operands are loop IVs or symbols.
build(OpBuilder & builder,OperationState & result,Value srcMemRef,AffineMap srcMap,ValueRange srcIndices,Value destMemRef,AffineMap dstMap,ValueRange destIndices,Value tagMemRef,AffineMap tagMap,ValueRange tagIndices,Value numElements,Value stride,Value elementsPerStride)935 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
936 Value srcMemRef, AffineMap srcMap,
937 ValueRange srcIndices, Value destMemRef,
938 AffineMap dstMap, ValueRange destIndices,
939 Value tagMemRef, AffineMap tagMap,
940 ValueRange tagIndices, Value numElements,
941 Value stride, Value elementsPerStride) {
942 result.addOperands(srcMemRef);
943 result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap));
944 result.addOperands(srcIndices);
945 result.addOperands(destMemRef);
946 result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap));
947 result.addOperands(destIndices);
948 result.addOperands(tagMemRef);
949 result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
950 result.addOperands(tagIndices);
951 result.addOperands(numElements);
952 if (stride) {
953 result.addOperands({stride, elementsPerStride});
954 }
955 }
956
print(OpAsmPrinter & p)957 void AffineDmaStartOp::print(OpAsmPrinter &p) {
958 p << "affine.dma_start " << getSrcMemRef() << '[';
959 p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
960 p << "], " << getDstMemRef() << '[';
961 p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
962 p << "], " << getTagMemRef() << '[';
963 p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
964 p << "], " << getNumElements();
965 if (isStrided()) {
966 p << ", " << getStride();
967 p << ", " << getNumElementsPerStride();
968 }
969 p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
970 << getTagMemRefType();
971 }
972
973 // Parse AffineDmaStartOp.
974 // Ex:
975 // affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
976 // %stride, %num_elt_per_stride
977 // : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
978 //
parse(OpAsmParser & parser,OperationState & result)979 ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
980 OperationState &result) {
981 OpAsmParser::OperandType srcMemRefInfo;
982 AffineMapAttr srcMapAttr;
983 SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
984 OpAsmParser::OperandType dstMemRefInfo;
985 AffineMapAttr dstMapAttr;
986 SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
987 OpAsmParser::OperandType tagMemRefInfo;
988 AffineMapAttr tagMapAttr;
989 SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
990 OpAsmParser::OperandType numElementsInfo;
991 SmallVector<OpAsmParser::OperandType, 2> strideInfo;
992
993 SmallVector<Type, 3> types;
994 auto indexType = parser.getBuilder().getIndexType();
995
996 // Parse and resolve the following list of operands:
997 // *) dst memref followed by its affine maps operands (in square brackets).
998 // *) src memref followed by its affine map operands (in square brackets).
999 // *) tag memref followed by its affine map operands (in square brackets).
1000 // *) number of elements transferred by DMA operation.
1001 if (parser.parseOperand(srcMemRefInfo) ||
1002 parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1003 getSrcMapAttrName(), result.attributes) ||
1004 parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1005 parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1006 getDstMapAttrName(), result.attributes) ||
1007 parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1008 parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1009 getTagMapAttrName(), result.attributes) ||
1010 parser.parseComma() || parser.parseOperand(numElementsInfo))
1011 return failure();
1012
1013 // Parse optional stride and elements per stride.
1014 if (parser.parseTrailingOperandList(strideInfo)) {
1015 return failure();
1016 }
1017 if (!strideInfo.empty() && strideInfo.size() != 2) {
1018 return parser.emitError(parser.getNameLoc(),
1019 "expected two stride related operands");
1020 }
1021 bool isStrided = strideInfo.size() == 2;
1022
1023 if (parser.parseColonTypeList(types))
1024 return failure();
1025
1026 if (types.size() != 3)
1027 return parser.emitError(parser.getNameLoc(), "expected three types");
1028
1029 if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1030 parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1031 parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1032 parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1033 parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1034 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1035 parser.resolveOperand(numElementsInfo, indexType, result.operands))
1036 return failure();
1037
1038 if (isStrided) {
1039 if (parser.resolveOperands(strideInfo, indexType, result.operands))
1040 return failure();
1041 }
1042
1043 // Check that src/dst/tag operand counts match their map.numInputs.
1044 if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1045 dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1046 tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1047 return parser.emitError(parser.getNameLoc(),
1048 "memref operand count not equal to map.numInputs");
1049 return success();
1050 }
1051
verify()1052 LogicalResult AffineDmaStartOp::verify() {
1053 if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
1054 return emitOpError("expected DMA source to be of memref type");
1055 if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
1056 return emitOpError("expected DMA destination to be of memref type");
1057 if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
1058 return emitOpError("expected DMA tag to be of memref type");
1059
1060 // DMAs from different memory spaces supported.
1061 if (getSrcMemorySpace() == getDstMemorySpace()) {
1062 return emitOpError("DMA should be between different memory spaces");
1063 }
1064 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1065 getDstMap().getNumInputs() +
1066 getTagMap().getNumInputs();
1067 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1068 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1069 return emitOpError("incorrect number of operands");
1070 }
1071
1072 Region *scope = getAffineScope(*this);
1073 for (auto idx : getSrcIndices()) {
1074 if (!idx.getType().isIndex())
1075 return emitOpError("src index to dma_start must have 'index' type");
1076 if (!isValidAffineIndexOperand(idx, scope))
1077 return emitOpError("src index must be a dimension or symbol identifier");
1078 }
1079 for (auto idx : getDstIndices()) {
1080 if (!idx.getType().isIndex())
1081 return emitOpError("dst index to dma_start must have 'index' type");
1082 if (!isValidAffineIndexOperand(idx, scope))
1083 return emitOpError("dst index must be a dimension or symbol identifier");
1084 }
1085 for (auto idx : getTagIndices()) {
1086 if (!idx.getType().isIndex())
1087 return emitOpError("tag index to dma_start must have 'index' type");
1088 if (!isValidAffineIndexOperand(idx, scope))
1089 return emitOpError("tag index must be a dimension or symbol identifier");
1090 }
1091 return success();
1092 }
1093
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)1094 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1095 SmallVectorImpl<OpFoldResult> &results) {
1096 /// dma_start(memrefcast) -> dma_start
1097 return foldMemRefCast(*this);
1098 }
1099
1100 //===----------------------------------------------------------------------===//
1101 // AffineDmaWaitOp
1102 //===----------------------------------------------------------------------===//
1103
1104 // TODO: Check that map operands are loop IVs or symbols.
build(OpBuilder & builder,OperationState & result,Value tagMemRef,AffineMap tagMap,ValueRange tagIndices,Value numElements)1105 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1106 Value tagMemRef, AffineMap tagMap,
1107 ValueRange tagIndices, Value numElements) {
1108 result.addOperands(tagMemRef);
1109 result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
1110 result.addOperands(tagIndices);
1111 result.addOperands(numElements);
1112 }
1113
print(OpAsmPrinter & p)1114 void AffineDmaWaitOp::print(OpAsmPrinter &p) {
1115 p << "affine.dma_wait " << getTagMemRef() << '[';
1116 SmallVector<Value, 2> operands(getTagIndices());
1117 p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1118 p << "], ";
1119 p.printOperand(getNumElements());
1120 p << " : " << getTagMemRef().getType();
1121 }
1122
1123 // Parse AffineDmaWaitOp.
1124 // Eg:
1125 // affine.dma_wait %tag[%index], %num_elements
1126 // : memref<1 x i32, (d0) -> (d0), 4>
1127 //
parse(OpAsmParser & parser,OperationState & result)1128 ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
1129 OperationState &result) {
1130 OpAsmParser::OperandType tagMemRefInfo;
1131 AffineMapAttr tagMapAttr;
1132 SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
1133 Type type;
1134 auto indexType = parser.getBuilder().getIndexType();
1135 OpAsmParser::OperandType numElementsInfo;
1136
1137 // Parse tag memref, its map operands, and dma size.
1138 if (parser.parseOperand(tagMemRefInfo) ||
1139 parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1140 getTagMapAttrName(), result.attributes) ||
1141 parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1142 parser.parseColonType(type) ||
1143 parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1144 parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1145 parser.resolveOperand(numElementsInfo, indexType, result.operands))
1146 return failure();
1147
1148 if (!type.isa<MemRefType>())
1149 return parser.emitError(parser.getNameLoc(),
1150 "expected tag to be of memref type");
1151
1152 if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1153 return parser.emitError(parser.getNameLoc(),
1154 "tag memref operand count != to map.numInputs");
1155 return success();
1156 }
1157
verify()1158 LogicalResult AffineDmaWaitOp::verify() {
1159 if (!getOperand(0).getType().isa<MemRefType>())
1160 return emitOpError("expected DMA tag to be of memref type");
1161 Region *scope = getAffineScope(*this);
1162 for (auto idx : getTagIndices()) {
1163 if (!idx.getType().isIndex())
1164 return emitOpError("index to dma_wait must have 'index' type");
1165 if (!isValidAffineIndexOperand(idx, scope))
1166 return emitOpError("index must be a dimension or symbol identifier");
1167 }
1168 return success();
1169 }
1170
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)1171 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1172 SmallVectorImpl<OpFoldResult> &results) {
1173 /// dma_wait(memrefcast) -> dma_wait
1174 return foldMemRefCast(*this);
1175 }
1176
1177 //===----------------------------------------------------------------------===//
1178 // AffineForOp
1179 //===----------------------------------------------------------------------===//
1180
1181 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1182 /// bodyBuilder are empty/null, we include default terminator op.
build(OpBuilder & builder,OperationState & result,ValueRange lbOperands,AffineMap lbMap,ValueRange ubOperands,AffineMap ubMap,int64_t step,ValueRange iterArgs,BodyBuilderFn bodyBuilder)1183 void AffineForOp::build(OpBuilder &builder, OperationState &result,
1184 ValueRange lbOperands, AffineMap lbMap,
1185 ValueRange ubOperands, AffineMap ubMap, int64_t step,
1186 ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1187 assert(((!lbMap && lbOperands.empty()) ||
1188 lbOperands.size() == lbMap.getNumInputs()) &&
1189 "lower bound operand count does not match the affine map");
1190 assert(((!ubMap && ubOperands.empty()) ||
1191 ubOperands.size() == ubMap.getNumInputs()) &&
1192 "upper bound operand count does not match the affine map");
1193 assert(step > 0 && "step has to be a positive integer constant");
1194
1195 for (Value val : iterArgs)
1196 result.addTypes(val.getType());
1197
1198 // Add an attribute for the step.
1199 result.addAttribute(getStepAttrName(),
1200 builder.getIntegerAttr(builder.getIndexType(), step));
1201
1202 // Add the lower bound.
1203 result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap));
1204 result.addOperands(lbOperands);
1205
1206 // Add the upper bound.
1207 result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap));
1208 result.addOperands(ubOperands);
1209
1210 result.addOperands(iterArgs);
1211 // Create a region and a block for the body. The argument of the region is
1212 // the loop induction variable.
1213 Region *bodyRegion = result.addRegion();
1214 bodyRegion->push_back(new Block);
1215 Block &bodyBlock = bodyRegion->front();
1216 Value inductionVar = bodyBlock.addArgument(builder.getIndexType());
1217 for (Value val : iterArgs)
1218 bodyBlock.addArgument(val.getType());
1219
1220 // Create the default terminator if the builder is not provided and if the
1221 // iteration arguments are not provided. Otherwise, leave this to the caller
1222 // because we don't know which values to return from the loop.
1223 if (iterArgs.empty() && !bodyBuilder) {
1224 ensureTerminator(*bodyRegion, builder, result.location);
1225 } else if (bodyBuilder) {
1226 OpBuilder::InsertionGuard guard(builder);
1227 builder.setInsertionPointToStart(&bodyBlock);
1228 bodyBuilder(builder, result.location, inductionVar,
1229 bodyBlock.getArguments().drop_front());
1230 }
1231 }
1232
build(OpBuilder & builder,OperationState & result,int64_t lb,int64_t ub,int64_t step,ValueRange iterArgs,BodyBuilderFn bodyBuilder)1233 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1234 int64_t ub, int64_t step, ValueRange iterArgs,
1235 BodyBuilderFn bodyBuilder) {
1236 auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1237 auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1238 return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1239 bodyBuilder);
1240 }
1241
verify(AffineForOp op)1242 static LogicalResult verify(AffineForOp op) {
1243 // Check that the body defines as single block argument for the induction
1244 // variable.
1245 auto *body = op.getBody();
1246 if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1247 return op.emitOpError(
1248 "expected body to have a single index argument for the "
1249 "induction variable");
1250
1251 // Verify that the bound operands are valid dimension/symbols.
1252 /// Lower bound.
1253 if (op.getLowerBoundMap().getNumInputs() > 0)
1254 if (failed(
1255 verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
1256 op.getLowerBoundMap().getNumDims())))
1257 return failure();
1258 /// Upper bound.
1259 if (op.getUpperBoundMap().getNumInputs() > 0)
1260 if (failed(
1261 verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
1262 op.getUpperBoundMap().getNumDims())))
1263 return failure();
1264
1265 unsigned opNumResults = op.getNumResults();
1266 if (opNumResults == 0)
1267 return success();
1268
1269 // If ForOp defines values, check that the number and types of the defined
1270 // values match ForOp initial iter operands and backedge basic block
1271 // arguments.
1272 if (op.getNumIterOperands() != opNumResults)
1273 return op.emitOpError(
1274 "mismatch between the number of loop-carried values and results");
1275 if (op.getNumRegionIterArgs() != opNumResults)
1276 return op.emitOpError(
1277 "mismatch between the number of basic block args and results");
1278
1279 return success();
1280 }
1281
1282 /// Parse a for operation loop bounds.
parseBound(bool isLower,OperationState & result,OpAsmParser & p)1283 static ParseResult parseBound(bool isLower, OperationState &result,
1284 OpAsmParser &p) {
1285 // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1286 // the map has multiple results.
1287 bool failedToParsedMinMax =
1288 failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
1289
1290 auto &builder = p.getBuilder();
1291 auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
1292 : AffineForOp::getUpperBoundAttrName();
1293
1294 // Parse ssa-id as identity map.
1295 SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
1296 if (p.parseOperandList(boundOpInfos))
1297 return failure();
1298
1299 if (!boundOpInfos.empty()) {
1300 // Check that only one operand was parsed.
1301 if (boundOpInfos.size() > 1)
1302 return p.emitError(p.getNameLoc(),
1303 "expected only one loop bound operand");
1304
1305 // TODO: improve error message when SSA value is not of index type.
1306 // Currently it is 'use of value ... expects different type than prior uses'
1307 if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1308 result.operands))
1309 return failure();
1310
1311 // Create an identity map using symbol id. This representation is optimized
1312 // for storage. Analysis passes may expand it into a multi-dimensional map
1313 // if desired.
1314 AffineMap map = builder.getSymbolIdentityMap();
1315 result.addAttribute(boundAttrName, AffineMapAttr::get(map));
1316 return success();
1317 }
1318
1319 // Get the attribute location.
1320 llvm::SMLoc attrLoc = p.getCurrentLocation();
1321
1322 Attribute boundAttr;
1323 if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
1324 result.attributes))
1325 return failure();
1326
1327 // Parse full form - affine map followed by dim and symbol list.
1328 if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
1329 unsigned currentNumOperands = result.operands.size();
1330 unsigned numDims;
1331 if (parseDimAndSymbolList(p, result.operands, numDims))
1332 return failure();
1333
1334 auto map = affineMapAttr.getValue();
1335 if (map.getNumDims() != numDims)
1336 return p.emitError(
1337 p.getNameLoc(),
1338 "dim operand count and affine map dim count must match");
1339
1340 unsigned numDimAndSymbolOperands =
1341 result.operands.size() - currentNumOperands;
1342 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1343 return p.emitError(
1344 p.getNameLoc(),
1345 "symbol operand count and affine map symbol count must match");
1346
1347 // If the map has multiple results, make sure that we parsed the min/max
1348 // prefix.
1349 if (map.getNumResults() > 1 && failedToParsedMinMax) {
1350 if (isLower) {
1351 return p.emitError(attrLoc, "lower loop bound affine map with "
1352 "multiple results requires 'max' prefix");
1353 }
1354 return p.emitError(attrLoc, "upper loop bound affine map with multiple "
1355 "results requires 'min' prefix");
1356 }
1357 return success();
1358 }
1359
1360 // Parse custom assembly form.
1361 if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
1362 result.attributes.pop_back();
1363 result.addAttribute(
1364 boundAttrName,
1365 AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
1366 return success();
1367 }
1368
1369 return p.emitError(
1370 p.getNameLoc(),
1371 "expected valid affine map representation for loop bounds");
1372 }
1373
parseAffineForOp(OpAsmParser & parser,OperationState & result)1374 static ParseResult parseAffineForOp(OpAsmParser &parser,
1375 OperationState &result) {
1376 auto &builder = parser.getBuilder();
1377 OpAsmParser::OperandType inductionVariable;
1378 // Parse the induction variable followed by '='.
1379 if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
1380 return failure();
1381
1382 // Parse loop bounds.
1383 if (parseBound(/*isLower=*/true, result, parser) ||
1384 parser.parseKeyword("to", " between bounds") ||
1385 parseBound(/*isLower=*/false, result, parser))
1386 return failure();
1387
1388 // Parse the optional loop step, we default to 1 if one is not present.
1389 if (parser.parseOptionalKeyword("step")) {
1390 result.addAttribute(
1391 AffineForOp::getStepAttrName(),
1392 builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
1393 } else {
1394 llvm::SMLoc stepLoc = parser.getCurrentLocation();
1395 IntegerAttr stepAttr;
1396 if (parser.parseAttribute(stepAttr, builder.getIndexType(),
1397 AffineForOp::getStepAttrName().data(),
1398 result.attributes))
1399 return failure();
1400
1401 if (stepAttr.getValue().getSExtValue() < 0)
1402 return parser.emitError(
1403 stepLoc,
1404 "expected step to be representable as a positive signed integer");
1405 }
1406
1407 // Parse the optional initial iteration arguments.
1408 SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
1409 SmallVector<Type, 4> argTypes;
1410 regionArgs.push_back(inductionVariable);
1411
1412 if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
1413 // Parse assignment list and results type list.
1414 if (parser.parseAssignmentList(regionArgs, operands) ||
1415 parser.parseArrowTypeList(result.types))
1416 return failure();
1417 // Resolve input operands.
1418 for (auto operandType : llvm::zip(operands, result.types))
1419 if (parser.resolveOperand(std::get<0>(operandType),
1420 std::get<1>(operandType), result.operands))
1421 return failure();
1422 }
1423 // Induction variable.
1424 Type indexType = builder.getIndexType();
1425 argTypes.push_back(indexType);
1426 // Loop carried variables.
1427 argTypes.append(result.types.begin(), result.types.end());
1428 // Parse the body region.
1429 Region *body = result.addRegion();
1430 if (regionArgs.size() != argTypes.size())
1431 return parser.emitError(
1432 parser.getNameLoc(),
1433 "mismatch between the number of loop-carried values and results");
1434 if (parser.parseRegion(*body, regionArgs, argTypes))
1435 return failure();
1436
1437 AffineForOp::ensureTerminator(*body, builder, result.location);
1438
1439 // Parse the optional attribute list.
1440 return parser.parseOptionalAttrDict(result.attributes);
1441 }
1442
printBound(AffineMapAttr boundMap,Operation::operand_range boundOperands,const char * prefix,OpAsmPrinter & p)1443 static void printBound(AffineMapAttr boundMap,
1444 Operation::operand_range boundOperands,
1445 const char *prefix, OpAsmPrinter &p) {
1446 AffineMap map = boundMap.getValue();
1447
1448 // Check if this bound should be printed using custom assembly form.
1449 // The decision to restrict printing custom assembly form to trivial cases
1450 // comes from the will to roundtrip MLIR binary -> text -> binary in a
1451 // lossless way.
1452 // Therefore, custom assembly form parsing and printing is only supported for
1453 // zero-operand constant maps and single symbol operand identity maps.
1454 if (map.getNumResults() == 1) {
1455 AffineExpr expr = map.getResult(0);
1456
1457 // Print constant bound.
1458 if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
1459 if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
1460 p << constExpr.getValue();
1461 return;
1462 }
1463 }
1464
1465 // Print bound that consists of a single SSA symbol if the map is over a
1466 // single symbol.
1467 if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
1468 if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
1469 p.printOperand(*boundOperands.begin());
1470 return;
1471 }
1472 }
1473 } else {
1474 // Map has multiple results. Print 'min' or 'max' prefix.
1475 p << prefix << ' ';
1476 }
1477
1478 // Print the map and its operands.
1479 p << boundMap;
1480 printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
1481 map.getNumDims(), p);
1482 }
1483
getNumIterOperands()1484 unsigned AffineForOp::getNumIterOperands() {
1485 AffineMap lbMap = getLowerBoundMapAttr().getValue();
1486 AffineMap ubMap = getUpperBoundMapAttr().getValue();
1487
1488 return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
1489 }
1490
print(OpAsmPrinter & p,AffineForOp op)1491 static void print(OpAsmPrinter &p, AffineForOp op) {
1492 p << op.getOperationName() << ' ';
1493 p.printOperand(op.getBody()->getArgument(0));
1494 p << " = ";
1495 printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
1496 p << " to ";
1497 printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);
1498
1499 if (op.getStep() != 1)
1500 p << " step " << op.getStep();
1501
1502 bool printBlockTerminators = false;
1503 if (op.getNumIterOperands() > 0) {
1504 p << " iter_args(";
1505 auto regionArgs = op.getRegionIterArgs();
1506 auto operands = op.getIterOperands();
1507
1508 llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
1509 p << std::get<0>(it) << " = " << std::get<1>(it);
1510 });
1511 p << ") -> (" << op.getResultTypes() << ")";
1512 printBlockTerminators = true;
1513 }
1514
1515 p.printRegion(op.region(),
1516 /*printEntryBlockArgs=*/false, printBlockTerminators);
1517 p.printOptionalAttrDict(op.getAttrs(),
1518 /*elidedAttrs=*/{op.getLowerBoundAttrName(),
1519 op.getUpperBoundAttrName(),
1520 op.getStepAttrName()});
1521 }
1522
1523 /// Fold the constant bounds of a loop.
foldLoopBounds(AffineForOp forOp)1524 static LogicalResult foldLoopBounds(AffineForOp forOp) {
1525 auto foldLowerOrUpperBound = [&forOp](bool lower) {
1526 // Check to see if each of the operands is the result of a constant. If
1527 // so, get the value. If not, ignore it.
1528 SmallVector<Attribute, 8> operandConstants;
1529 auto boundOperands =
1530 lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
1531 for (auto operand : boundOperands) {
1532 Attribute operandCst;
1533 matchPattern(operand, m_Constant(&operandCst));
1534 operandConstants.push_back(operandCst);
1535 }
1536
1537 AffineMap boundMap =
1538 lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
1539 assert(boundMap.getNumResults() >= 1 &&
1540 "bound maps should have at least one result");
1541 SmallVector<Attribute, 4> foldedResults;
1542 if (failed(boundMap.constantFold(operandConstants, foldedResults)))
1543 return failure();
1544
1545 // Compute the max or min as applicable over the results.
1546 assert(!foldedResults.empty() && "bounds should have at least one result");
1547 auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
1548 for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
1549 auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
1550 maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
1551 : llvm::APIntOps::smin(maxOrMin, foldedResult);
1552 }
1553 lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
1554 : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
1555 return success();
1556 };
1557
1558 // Try to fold the lower bound.
1559 bool folded = false;
1560 if (!forOp.hasConstantLowerBound())
1561 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
1562
1563 // Try to fold the upper bound.
1564 if (!forOp.hasConstantUpperBound())
1565 folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
1566 return success(folded);
1567 }
1568
1569 /// Canonicalize the bounds of the given loop.
canonicalizeLoopBounds(AffineForOp forOp)1570 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
1571 SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
1572 SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
1573
1574 auto lbMap = forOp.getLowerBoundMap();
1575 auto ubMap = forOp.getUpperBoundMap();
1576 auto prevLbMap = lbMap;
1577 auto prevUbMap = ubMap;
1578
1579 canonicalizeMapAndOperands(&lbMap, &lbOperands);
1580 lbMap = removeDuplicateExprs(lbMap);
1581
1582 canonicalizeMapAndOperands(&ubMap, &ubOperands);
1583 ubMap = removeDuplicateExprs(ubMap);
1584
1585 // Any canonicalization change always leads to updated map(s).
1586 if (lbMap == prevLbMap && ubMap == prevUbMap)
1587 return failure();
1588
1589 if (lbMap != prevLbMap)
1590 forOp.setLowerBound(lbOperands, lbMap);
1591 if (ubMap != prevUbMap)
1592 forOp.setUpperBound(ubOperands, ubMap);
1593 return success();
1594 }
1595
1596 namespace {
1597 /// This is a pattern to fold trivially empty loops.
1598 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
1599 using OpRewritePattern<AffineForOp>::OpRewritePattern;
1600
matchAndRewrite__anone6b2d6171011::AffineForEmptyLoopFolder1601 LogicalResult matchAndRewrite(AffineForOp forOp,
1602 PatternRewriter &rewriter) const override {
1603 // Check that the body only contains a yield.
1604 if (!llvm::hasSingleElement(*forOp.getBody()))
1605 return failure();
1606 rewriter.eraseOp(forOp);
1607 return success();
1608 }
1609 };
1610 } // end anonymous namespace
1611
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1612 void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
1613 MLIRContext *context) {
1614 results.insert<AffineForEmptyLoopFolder>(context);
1615 }
1616
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1617 LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
1618 SmallVectorImpl<OpFoldResult> &results) {
1619 bool folded = succeeded(foldLoopBounds(*this));
1620 folded |= succeeded(canonicalizeLoopBounds(*this));
1621 return success(folded);
1622 }
1623
getLowerBound()1624 AffineBound AffineForOp::getLowerBound() {
1625 auto lbMap = getLowerBoundMap();
1626 return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
1627 }
1628
getUpperBound()1629 AffineBound AffineForOp::getUpperBound() {
1630 auto lbMap = getLowerBoundMap();
1631 auto ubMap = getUpperBoundMap();
1632 return AffineBound(AffineForOp(*this), lbMap.getNumInputs(),
1633 lbMap.getNumInputs() + ubMap.getNumInputs(), ubMap);
1634 }
1635
setLowerBound(ValueRange lbOperands,AffineMap map)1636 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
1637 assert(lbOperands.size() == map.getNumInputs());
1638 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1639
1640 SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());
1641
1642 auto ubOperands = getUpperBoundOperands();
1643 newOperands.append(ubOperands.begin(), ubOperands.end());
1644 auto iterOperands = getIterOperands();
1645 newOperands.append(iterOperands.begin(), iterOperands.end());
1646 (*this)->setOperands(newOperands);
1647
1648 (*this)->setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
1649 }
1650
setUpperBound(ValueRange ubOperands,AffineMap map)1651 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
1652 assert(ubOperands.size() == map.getNumInputs());
1653 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1654
1655 SmallVector<Value, 4> newOperands(getLowerBoundOperands());
1656 newOperands.append(ubOperands.begin(), ubOperands.end());
1657 auto iterOperands = getIterOperands();
1658 newOperands.append(iterOperands.begin(), iterOperands.end());
1659 (*this)->setOperands(newOperands);
1660
1661 (*this)->setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
1662 }
1663
setLowerBoundMap(AffineMap map)1664 void AffineForOp::setLowerBoundMap(AffineMap map) {
1665 auto lbMap = getLowerBoundMap();
1666 assert(lbMap.getNumDims() == map.getNumDims() &&
1667 lbMap.getNumSymbols() == map.getNumSymbols());
1668 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1669 (void)lbMap;
1670 (*this)->setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
1671 }
1672
setUpperBoundMap(AffineMap map)1673 void AffineForOp::setUpperBoundMap(AffineMap map) {
1674 auto ubMap = getUpperBoundMap();
1675 assert(ubMap.getNumDims() == map.getNumDims() &&
1676 ubMap.getNumSymbols() == map.getNumSymbols());
1677 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1678 (void)ubMap;
1679 (*this)->setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
1680 }
1681
hasConstantLowerBound()1682 bool AffineForOp::hasConstantLowerBound() {
1683 return getLowerBoundMap().isSingleConstant();
1684 }
1685
hasConstantUpperBound()1686 bool AffineForOp::hasConstantUpperBound() {
1687 return getUpperBoundMap().isSingleConstant();
1688 }
1689
getConstantLowerBound()1690 int64_t AffineForOp::getConstantLowerBound() {
1691 return getLowerBoundMap().getSingleConstantResult();
1692 }
1693
getConstantUpperBound()1694 int64_t AffineForOp::getConstantUpperBound() {
1695 return getUpperBoundMap().getSingleConstantResult();
1696 }
1697
setConstantLowerBound(int64_t value)1698 void AffineForOp::setConstantLowerBound(int64_t value) {
1699 setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
1700 }
1701
setConstantUpperBound(int64_t value)1702 void AffineForOp::setConstantUpperBound(int64_t value) {
1703 setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
1704 }
1705
getLowerBoundOperands()1706 AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
1707 return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
1708 }
1709
getUpperBoundOperands()1710 AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
1711 return {operand_begin() + getLowerBoundMap().getNumInputs(),
1712 operand_begin() + getLowerBoundMap().getNumInputs() +
1713 getUpperBoundMap().getNumInputs()};
1714 }
1715
matchingBoundOperandList()1716 bool AffineForOp::matchingBoundOperandList() {
1717 auto lbMap = getLowerBoundMap();
1718 auto ubMap = getUpperBoundMap();
1719 if (lbMap.getNumDims() != ubMap.getNumDims() ||
1720 lbMap.getNumSymbols() != ubMap.getNumSymbols())
1721 return false;
1722
1723 unsigned numOperands = lbMap.getNumInputs();
1724 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
1725 // Compare Value 's.
1726 if (getOperand(i) != getOperand(numOperands + i))
1727 return false;
1728 }
1729 return true;
1730 }
1731
getLoopBody()1732 Region &AffineForOp::getLoopBody() { return region(); }
1733
isDefinedOutsideOfLoop(Value value)1734 bool AffineForOp::isDefinedOutsideOfLoop(Value value) {
1735 return !region().isAncestor(value.getParentRegion());
1736 }
1737
moveOutOfLoop(ArrayRef<Operation * > ops)1738 LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
1739 for (auto *op : ops)
1740 op->moveBefore(*this);
1741 return success();
1742 }
1743
1744 /// Returns true if the provided value is the induction variable of a
1745 /// AffineForOp.
isForInductionVar(Value val)1746 bool mlir::isForInductionVar(Value val) {
1747 return getForInductionVarOwner(val) != AffineForOp();
1748 }
1749
1750 /// Returns the loop parent of an induction variable. If the provided value is
1751 /// not an induction variable, then return nullptr.
getForInductionVarOwner(Value val)1752 AffineForOp mlir::getForInductionVarOwner(Value val) {
1753 auto ivArg = val.dyn_cast<BlockArgument>();
1754 if (!ivArg || !ivArg.getOwner())
1755 return AffineForOp();
1756 auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
1757 return dyn_cast<AffineForOp>(containingInst);
1758 }
1759
1760 /// Extracts the induction variables from a list of AffineForOps and returns
1761 /// them.
extractForInductionVars(ArrayRef<AffineForOp> forInsts,SmallVectorImpl<Value> * ivs)1762 void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
1763 SmallVectorImpl<Value> *ivs) {
1764 ivs->reserve(forInsts.size());
1765 for (auto forInst : forInsts)
1766 ivs->push_back(forInst.getInductionVar());
1767 }
1768
1769 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
1770 /// operations.
1771 template <typename BoundListTy, typename LoopCreatorTy>
buildAffineLoopNestImpl(OpBuilder & builder,Location loc,BoundListTy lbs,BoundListTy ubs,ArrayRef<int64_t> steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn,LoopCreatorTy && loopCreatorFn)1772 static void buildAffineLoopNestImpl(
1773 OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
1774 ArrayRef<int64_t> steps,
1775 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
1776 LoopCreatorTy &&loopCreatorFn) {
1777 assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
1778 assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
1779
1780 // If there are no loops to be constructed, construct the body anyway.
1781 OpBuilder::InsertionGuard guard(builder);
1782 if (lbs.empty()) {
1783 if (bodyBuilderFn)
1784 bodyBuilderFn(builder, loc, ValueRange());
1785 return;
1786 }
1787
1788 // Create the loops iteratively and store the induction variables.
1789 SmallVector<Value, 4> ivs;
1790 ivs.reserve(lbs.size());
1791 for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
1792 // Callback for creating the loop body, always creates the terminator.
1793 auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
1794 ValueRange iterArgs) {
1795 ivs.push_back(iv);
1796 // In the innermost loop, call the body builder.
1797 if (i == e - 1 && bodyBuilderFn) {
1798 OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
1799 bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
1800 }
1801 nestedBuilder.create<AffineYieldOp>(nestedLoc);
1802 };
1803
1804 // Delegate actual loop creation to the callback in order to dispatch
1805 // between constant- and variable-bound loops.
1806 auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
1807 builder.setInsertionPointToStart(loop.getBody());
1808 }
1809 }
1810
1811 /// Creates an affine loop from the bounds known to be constants.
1812 static AffineForOp
buildAffineLoopFromConstants(OpBuilder & builder,Location loc,int64_t lb,int64_t ub,int64_t step,AffineForOp::BodyBuilderFn bodyBuilderFn)1813 buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
1814 int64_t ub, int64_t step,
1815 AffineForOp::BodyBuilderFn bodyBuilderFn) {
1816 return builder.create<AffineForOp>(loc, lb, ub, step, /*iterArgs=*/llvm::None,
1817 bodyBuilderFn);
1818 }
1819
1820 /// Creates an affine loop from the bounds that may or may not be constants.
1821 static AffineForOp
buildAffineLoopFromValues(OpBuilder & builder,Location loc,Value lb,Value ub,int64_t step,AffineForOp::BodyBuilderFn bodyBuilderFn)1822 buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
1823 int64_t step,
1824 AffineForOp::BodyBuilderFn bodyBuilderFn) {
1825 auto lbConst = lb.getDefiningOp<ConstantIndexOp>();
1826 auto ubConst = ub.getDefiningOp<ConstantIndexOp>();
1827 if (lbConst && ubConst)
1828 return buildAffineLoopFromConstants(builder, loc, lbConst.getValue(),
1829 ubConst.getValue(), step,
1830 bodyBuilderFn);
1831 return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
1832 builder.getDimIdentityMap(), step,
1833 /*iterArgs=*/llvm::None, bodyBuilderFn);
1834 }
1835
buildAffineLoopNest(OpBuilder & builder,Location loc,ArrayRef<int64_t> lbs,ArrayRef<int64_t> ubs,ArrayRef<int64_t> steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn)1836 void mlir::buildAffineLoopNest(
1837 OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
1838 ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
1839 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1840 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
1841 buildAffineLoopFromConstants);
1842 }
1843
buildAffineLoopNest(OpBuilder & builder,Location loc,ValueRange lbs,ValueRange ubs,ArrayRef<int64_t> steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn)1844 void mlir::buildAffineLoopNest(
1845 OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
1846 ArrayRef<int64_t> steps,
1847 function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1848 buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
1849 buildAffineLoopFromValues);
1850 }
1851
1852 //===----------------------------------------------------------------------===//
1853 // AffineIfOp
1854 //===----------------------------------------------------------------------===//
1855
1856 namespace {
1857 /// Remove else blocks that have nothing other than a zero value yield.
1858 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
1859 using OpRewritePattern<AffineIfOp>::OpRewritePattern;
1860
matchAndRewrite__anone6b2d6171211::SimplifyDeadElse1861 LogicalResult matchAndRewrite(AffineIfOp ifOp,
1862 PatternRewriter &rewriter) const override {
1863 if (ifOp.elseRegion().empty() ||
1864 !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
1865 return failure();
1866
1867 rewriter.startRootUpdate(ifOp);
1868 rewriter.eraseBlock(ifOp.getElseBlock());
1869 rewriter.finalizeRootUpdate(ifOp);
1870 return success();
1871 }
1872 };
1873 } // end anonymous namespace.
1874
verify(AffineIfOp op)1875 static LogicalResult verify(AffineIfOp op) {
1876 // Verify that we have a condition attribute.
1877 auto conditionAttr =
1878 op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
1879 if (!conditionAttr)
1880 return op.emitOpError(
1881 "requires an integer set attribute named 'condition'");
1882
1883 // Verify that there are enough operands for the condition.
1884 IntegerSet condition = conditionAttr.getValue();
1885 if (op.getNumOperands() != condition.getNumInputs())
1886 return op.emitOpError(
1887 "operand count and condition integer set dimension and "
1888 "symbol count must match");
1889
1890 // Verify that the operands are valid dimension/symbols.
1891 if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(),
1892 condition.getNumDims())))
1893 return failure();
1894
1895 return success();
1896 }
1897
parseAffineIfOp(OpAsmParser & parser,OperationState & result)1898 static ParseResult parseAffineIfOp(OpAsmParser &parser,
1899 OperationState &result) {
1900 // Parse the condition attribute set.
1901 IntegerSetAttr conditionAttr;
1902 unsigned numDims;
1903 if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(),
1904 result.attributes) ||
1905 parseDimAndSymbolList(parser, result.operands, numDims))
1906 return failure();
1907
1908 // Verify the condition operands.
1909 auto set = conditionAttr.getValue();
1910 if (set.getNumDims() != numDims)
1911 return parser.emitError(
1912 parser.getNameLoc(),
1913 "dim operand count and integer set dim count must match");
1914 if (numDims + set.getNumSymbols() != result.operands.size())
1915 return parser.emitError(
1916 parser.getNameLoc(),
1917 "symbol operand count and integer set symbol count must match");
1918
1919 if (parser.parseOptionalArrowTypeList(result.types))
1920 return failure();
1921
1922 // Create the regions for 'then' and 'else'. The latter must be created even
1923 // if it remains empty for the validity of the operation.
1924 result.regions.reserve(2);
1925 Region *thenRegion = result.addRegion();
1926 Region *elseRegion = result.addRegion();
1927
1928 // Parse the 'then' region.
1929 if (parser.parseRegion(*thenRegion, {}, {}))
1930 return failure();
1931 AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
1932 result.location);
1933
1934 // If we find an 'else' keyword then parse the 'else' region.
1935 if (!parser.parseOptionalKeyword("else")) {
1936 if (parser.parseRegion(*elseRegion, {}, {}))
1937 return failure();
1938 AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
1939 result.location);
1940 }
1941
1942 // Parse the optional attribute list.
1943 if (parser.parseOptionalAttrDict(result.attributes))
1944 return failure();
1945
1946 return success();
1947 }
1948
print(OpAsmPrinter & p,AffineIfOp op)1949 static void print(OpAsmPrinter &p, AffineIfOp op) {
1950 auto conditionAttr =
1951 op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
1952 p << "affine.if " << conditionAttr;
1953 printDimAndSymbolList(op.operand_begin(), op.operand_end(),
1954 conditionAttr.getValue().getNumDims(), p);
1955 p.printOptionalArrowTypeList(op.getResultTypes());
1956 p.printRegion(op.thenRegion(),
1957 /*printEntryBlockArgs=*/false,
1958 /*printBlockTerminators=*/op.getNumResults());
1959
1960 // Print the 'else' regions if it has any blocks.
1961 auto &elseRegion = op.elseRegion();
1962 if (!elseRegion.empty()) {
1963 p << " else";
1964 p.printRegion(elseRegion,
1965 /*printEntryBlockArgs=*/false,
1966 /*printBlockTerminators=*/op.getNumResults());
1967 }
1968
1969 // Print the attribute list.
1970 p.printOptionalAttrDict(op.getAttrs(),
1971 /*elidedAttrs=*/op.getConditionAttrName());
1972 }
1973
getIntegerSet()1974 IntegerSet AffineIfOp::getIntegerSet() {
1975 return (*this)
1976 ->getAttrOfType<IntegerSetAttr>(getConditionAttrName())
1977 .getValue();
1978 }
setIntegerSet(IntegerSet newSet)1979 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
1980 (*this)->setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
1981 }
1982
setConditional(IntegerSet set,ValueRange operands)1983 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
1984 setIntegerSet(set);
1985 (*this)->setOperands(operands);
1986 }
1987
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,IntegerSet set,ValueRange args,bool withElseRegion)1988 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
1989 TypeRange resultTypes, IntegerSet set, ValueRange args,
1990 bool withElseRegion) {
1991 assert(resultTypes.empty() || withElseRegion);
1992 result.addTypes(resultTypes);
1993 result.addOperands(args);
1994 result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set));
1995
1996 Region *thenRegion = result.addRegion();
1997 thenRegion->push_back(new Block());
1998 if (resultTypes.empty())
1999 AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2000
2001 Region *elseRegion = result.addRegion();
2002 if (withElseRegion) {
2003 elseRegion->push_back(new Block());
2004 if (resultTypes.empty())
2005 AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2006 }
2007 }
2008
build(OpBuilder & builder,OperationState & result,IntegerSet set,ValueRange args,bool withElseRegion)2009 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2010 IntegerSet set, ValueRange args, bool withElseRegion) {
2011 AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
2012 withElseRegion);
2013 }
2014
2015 /// Canonicalize an affine if op's conditional (integer set + operands).
fold(ArrayRef<Attribute>,SmallVectorImpl<OpFoldResult> &)2016 LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
2017 SmallVectorImpl<OpFoldResult> &) {
2018 auto set = getIntegerSet();
2019 SmallVector<Value, 4> operands(getOperands());
2020 canonicalizeSetAndOperands(&set, &operands);
2021
2022 // Any canonicalization change always leads to either a reduction in the
2023 // number of operands or a change in the number of symbolic operands
2024 // (promotion of dims to symbols).
2025 if (operands.size() < getIntegerSet().getNumInputs() ||
2026 set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
2027 setConditional(set, operands);
2028 return success();
2029 }
2030
2031 return failure();
2032 }
2033
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2034 void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2035 MLIRContext *context) {
2036 results.insert<SimplifyDeadElse>(context);
2037 }
2038
2039 //===----------------------------------------------------------------------===//
2040 // AffineLoadOp
2041 //===----------------------------------------------------------------------===//
2042
build(OpBuilder & builder,OperationState & result,AffineMap map,ValueRange operands)2043 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2044 AffineMap map, ValueRange operands) {
2045 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2046 result.addOperands(operands);
2047 if (map)
2048 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2049 auto memrefType = operands[0].getType().cast<MemRefType>();
2050 result.types.push_back(memrefType.getElementType());
2051 }
2052
build(OpBuilder & builder,OperationState & result,Value memref,AffineMap map,ValueRange mapOperands)2053 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2054 Value memref, AffineMap map, ValueRange mapOperands) {
2055 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2056 result.addOperands(memref);
2057 result.addOperands(mapOperands);
2058 auto memrefType = memref.getType().cast<MemRefType>();
2059 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2060 result.types.push_back(memrefType.getElementType());
2061 }
2062
build(OpBuilder & builder,OperationState & result,Value memref,ValueRange indices)2063 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2064 Value memref, ValueRange indices) {
2065 auto memrefType = memref.getType().cast<MemRefType>();
2066 int64_t rank = memrefType.getRank();
2067 // Create identity map for memrefs with at least one dimension or () -> ()
2068 // for zero-dimensional memrefs.
2069 auto map =
2070 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2071 build(builder, result, memref, map, indices);
2072 }
2073
parseAffineLoadOp(OpAsmParser & parser,OperationState & result)2074 static ParseResult parseAffineLoadOp(OpAsmParser &parser,
2075 OperationState &result) {
2076 auto &builder = parser.getBuilder();
2077 auto indexTy = builder.getIndexType();
2078
2079 MemRefType type;
2080 OpAsmParser::OperandType memrefInfo;
2081 AffineMapAttr mapAttr;
2082 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2083 return failure(
2084 parser.parseOperand(memrefInfo) ||
2085 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2086 AffineLoadOp::getMapAttrName(),
2087 result.attributes) ||
2088 parser.parseOptionalAttrDict(result.attributes) ||
2089 parser.parseColonType(type) ||
2090 parser.resolveOperand(memrefInfo, type, result.operands) ||
2091 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
2092 parser.addTypeToList(type.getElementType(), result.types));
2093 }
2094
print(OpAsmPrinter & p,AffineLoadOp op)2095 static void print(OpAsmPrinter &p, AffineLoadOp op) {
2096 p << "affine.load " << op.getMemRef() << '[';
2097 if (AffineMapAttr mapAttr =
2098 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2099 p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2100 p << ']';
2101 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
2102 p << " : " << op.getMemRefType();
2103 }
2104
2105 /// Verify common indexing invariants of affine.load, affine.store,
2106 /// affine.vector_load and affine.vector_store.
2107 static LogicalResult
verifyMemoryOpIndexing(Operation * op,AffineMapAttr mapAttr,Operation::operand_range mapOperands,MemRefType memrefType,unsigned numIndexOperands)2108 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
2109 Operation::operand_range mapOperands,
2110 MemRefType memrefType, unsigned numIndexOperands) {
2111 if (mapAttr) {
2112 AffineMap map = mapAttr.getValue();
2113 if (map.getNumResults() != memrefType.getRank())
2114 return op->emitOpError("affine map num results must equal memref rank");
2115 if (map.getNumInputs() != numIndexOperands)
2116 return op->emitOpError("expects as many subscripts as affine map inputs");
2117 } else {
2118 if (memrefType.getRank() != numIndexOperands)
2119 return op->emitOpError(
2120 "expects the number of subscripts to be equal to memref rank");
2121 }
2122
2123 Region *scope = getAffineScope(op);
2124 for (auto idx : mapOperands) {
2125 if (!idx.getType().isIndex())
2126 return op->emitOpError("index to load must have 'index' type");
2127 if (!isValidAffineIndexOperand(idx, scope))
2128 return op->emitOpError("index must be a dimension or symbol identifier");
2129 }
2130
2131 return success();
2132 }
2133
verify(AffineLoadOp op)2134 LogicalResult verify(AffineLoadOp op) {
2135 auto memrefType = op.getMemRefType();
2136 if (op.getType() != memrefType.getElementType())
2137 return op.emitOpError("result type must match element type of memref");
2138
2139 if (failed(verifyMemoryOpIndexing(
2140 op.getOperation(),
2141 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2142 op.getMapOperands(), memrefType,
2143 /*numIndexOperands=*/op.getNumOperands() - 1)))
2144 return failure();
2145
2146 return success();
2147 }
2148
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2149 void AffineLoadOp::getCanonicalizationPatterns(
2150 OwningRewritePatternList &results, MLIRContext *context) {
2151 results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
2152 }
2153
fold(ArrayRef<Attribute> cstOperands)2154 OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
2155 /// load(memrefcast) -> load
2156 if (succeeded(foldMemRefCast(*this)))
2157 return getResult();
2158 return OpFoldResult();
2159 }
2160
2161 //===----------------------------------------------------------------------===//
2162 // AffineStoreOp
2163 //===----------------------------------------------------------------------===//
2164
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,AffineMap map,ValueRange mapOperands)2165 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
2166 Value valueToStore, Value memref, AffineMap map,
2167 ValueRange mapOperands) {
2168 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2169 result.addOperands(valueToStore);
2170 result.addOperands(memref);
2171 result.addOperands(mapOperands);
2172 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2173 }
2174
2175 // Use identity map.
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,ValueRange indices)2176 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
2177 Value valueToStore, Value memref,
2178 ValueRange indices) {
2179 auto memrefType = memref.getType().cast<MemRefType>();
2180 int64_t rank = memrefType.getRank();
2181 // Create identity map for memrefs with at least one dimension or () -> ()
2182 // for zero-dimensional memrefs.
2183 auto map =
2184 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2185 build(builder, result, valueToStore, memref, map, indices);
2186 }
2187
parseAffineStoreOp(OpAsmParser & parser,OperationState & result)2188 static ParseResult parseAffineStoreOp(OpAsmParser &parser,
2189 OperationState &result) {
2190 auto indexTy = parser.getBuilder().getIndexType();
2191
2192 MemRefType type;
2193 OpAsmParser::OperandType storeValueInfo;
2194 OpAsmParser::OperandType memrefInfo;
2195 AffineMapAttr mapAttr;
2196 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2197 return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
2198 parser.parseOperand(memrefInfo) ||
2199 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2200 AffineStoreOp::getMapAttrName(),
2201 result.attributes) ||
2202 parser.parseOptionalAttrDict(result.attributes) ||
2203 parser.parseColonType(type) ||
2204 parser.resolveOperand(storeValueInfo, type.getElementType(),
2205 result.operands) ||
2206 parser.resolveOperand(memrefInfo, type, result.operands) ||
2207 parser.resolveOperands(mapOperands, indexTy, result.operands));
2208 }
2209
print(OpAsmPrinter & p,AffineStoreOp op)2210 static void print(OpAsmPrinter &p, AffineStoreOp op) {
2211 p << "affine.store " << op.getValueToStore();
2212 p << ", " << op.getMemRef() << '[';
2213 if (AffineMapAttr mapAttr =
2214 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2215 p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2216 p << ']';
2217 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
2218 p << " : " << op.getMemRefType();
2219 }
2220
verify(AffineStoreOp op)2221 LogicalResult verify(AffineStoreOp op) {
2222 // First operand must have same type as memref element type.
2223 auto memrefType = op.getMemRefType();
2224 if (op.getValueToStore().getType() != memrefType.getElementType())
2225 return op.emitOpError(
2226 "first operand must have same type memref element type");
2227
2228 if (failed(verifyMemoryOpIndexing(
2229 op.getOperation(),
2230 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2231 op.getMapOperands(), memrefType,
2232 /*numIndexOperands=*/op.getNumOperands() - 2)))
2233 return failure();
2234
2235 return success();
2236 }
2237
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2238 void AffineStoreOp::getCanonicalizationPatterns(
2239 OwningRewritePatternList &results, MLIRContext *context) {
2240 results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
2241 }
2242
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)2243 LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
2244 SmallVectorImpl<OpFoldResult> &results) {
2245 /// store(memrefcast) -> store
2246 return foldMemRefCast(*this);
2247 }
2248
2249 //===----------------------------------------------------------------------===//
2250 // AffineMinMaxOpBase
2251 //===----------------------------------------------------------------------===//
2252
verifyAffineMinMaxOp(T op)2253 template <typename T> static LogicalResult verifyAffineMinMaxOp(T op) {
2254 // Verify that operand count matches affine map dimension and symbol count.
2255 if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
2256 return op.emitOpError(
2257 "operand count and affine map dimension and symbol count must match");
2258 return success();
2259 }
2260
printAffineMinMaxOp(OpAsmPrinter & p,T op)2261 template <typename T> static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
2262 p << op.getOperationName() << ' ' << op->getAttr(T::getMapAttrName());
2263 auto operands = op.getOperands();
2264 unsigned numDims = op.map().getNumDims();
2265 p << '(' << operands.take_front(numDims) << ')';
2266
2267 if (operands.size() != numDims)
2268 p << '[' << operands.drop_front(numDims) << ']';
2269 p.printOptionalAttrDict(op.getAttrs(),
2270 /*elidedAttrs=*/{T::getMapAttrName()});
2271 }
2272
2273 template <typename T>
parseAffineMinMaxOp(OpAsmParser & parser,OperationState & result)2274 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
2275 OperationState &result) {
2276 auto &builder = parser.getBuilder();
2277 auto indexType = builder.getIndexType();
2278 SmallVector<OpAsmParser::OperandType, 8> dim_infos;
2279 SmallVector<OpAsmParser::OperandType, 8> sym_infos;
2280 AffineMapAttr mapAttr;
2281 return failure(
2282 parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) ||
2283 parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) ||
2284 parser.parseOperandList(sym_infos,
2285 OpAsmParser::Delimiter::OptionalSquare) ||
2286 parser.parseOptionalAttrDict(result.attributes) ||
2287 parser.resolveOperands(dim_infos, indexType, result.operands) ||
2288 parser.resolveOperands(sym_infos, indexType, result.operands) ||
2289 parser.addTypeToList(indexType, result.types));
2290 }
2291
2292 /// Fold an affine min or max operation with the given operands. The operand
2293 /// list may contain nulls, which are interpreted as the operand not being a
2294 /// constant.
2295 template <typename T>
foldMinMaxOp(T op,ArrayRef<Attribute> operands)2296 static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
2297 static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
2298 "expected affine min or max op");
2299
2300 // Fold the affine map.
2301 // TODO: Fold more cases:
2302 // min(some_affine, some_affine + constant, ...), etc.
2303 SmallVector<int64_t, 2> results;
2304 auto foldedMap = op.map().partialConstantFold(operands, &results);
2305
2306 // If some of the map results are not constant, try changing the map in-place.
2307 if (results.empty()) {
2308 // If the map is the same, report that folding did not happen.
2309 if (foldedMap == op.map())
2310 return {};
2311 op->setAttr("map", AffineMapAttr::get(foldedMap));
2312 return op.getResult();
2313 }
2314
2315 // Otherwise, completely fold the op into a constant.
2316 auto resultIt = std::is_same<T, AffineMinOp>::value
2317 ? std::min_element(results.begin(), results.end())
2318 : std::max_element(results.begin(), results.end());
2319 if (resultIt == results.end())
2320 return {};
2321 return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
2322 }
2323
2324 //===----------------------------------------------------------------------===//
2325 // AffineMinOp
2326 //===----------------------------------------------------------------------===//
//
// %0 = affine.min affine_map<(d0) -> (1000, d0 + 512)> (%i0)
//
2330
fold(ArrayRef<Attribute> operands)2331 OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
2332 return foldMinMaxOp(*this, operands);
2333 }
2334
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)2335 void AffineMinOp::getCanonicalizationPatterns(
2336 OwningRewritePatternList &patterns, MLIRContext *context) {
2337 patterns.insert<SimplifyAffineOp<AffineMinOp>>(context);
2338 }
2339
2340 //===----------------------------------------------------------------------===//
2341 // AffineMaxOp
2342 //===----------------------------------------------------------------------===//
//
// %0 = affine.max affine_map<(d0) -> (1000, d0 + 512)> (%i0)
//
2346
fold(ArrayRef<Attribute> operands)2347 OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
2348 return foldMinMaxOp(*this, operands);
2349 }
2350
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)2351 void AffineMaxOp::getCanonicalizationPatterns(
2352 OwningRewritePatternList &patterns, MLIRContext *context) {
2353 patterns.insert<SimplifyAffineOp<AffineMaxOp>>(context);
2354 }
2355
2356 //===----------------------------------------------------------------------===//
2357 // AffinePrefetchOp
2358 //===----------------------------------------------------------------------===//
2359
2360 //
2361 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
2362 //
parseAffinePrefetchOp(OpAsmParser & parser,OperationState & result)2363 static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
2364 OperationState &result) {
2365 auto &builder = parser.getBuilder();
2366 auto indexTy = builder.getIndexType();
2367
2368 MemRefType type;
2369 OpAsmParser::OperandType memrefInfo;
2370 IntegerAttr hintInfo;
2371 auto i32Type = parser.getBuilder().getIntegerType(32);
2372 StringRef readOrWrite, cacheType;
2373
2374 AffineMapAttr mapAttr;
2375 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2376 if (parser.parseOperand(memrefInfo) ||
2377 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2378 AffinePrefetchOp::getMapAttrName(),
2379 result.attributes) ||
2380 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
2381 parser.parseComma() || parser.parseKeyword("locality") ||
2382 parser.parseLess() ||
2383 parser.parseAttribute(hintInfo, i32Type,
2384 AffinePrefetchOp::getLocalityHintAttrName(),
2385 result.attributes) ||
2386 parser.parseGreater() || parser.parseComma() ||
2387 parser.parseKeyword(&cacheType) ||
2388 parser.parseOptionalAttrDict(result.attributes) ||
2389 parser.parseColonType(type) ||
2390 parser.resolveOperand(memrefInfo, type, result.operands) ||
2391 parser.resolveOperands(mapOperands, indexTy, result.operands))
2392 return failure();
2393
2394 if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
2395 return parser.emitError(parser.getNameLoc(),
2396 "rw specifier has to be 'read' or 'write'");
2397 result.addAttribute(
2398 AffinePrefetchOp::getIsWriteAttrName(),
2399 parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
2400
2401 if (!cacheType.equals("data") && !cacheType.equals("instr"))
2402 return parser.emitError(parser.getNameLoc(),
2403 "cache type has to be 'data' or 'instr'");
2404
2405 result.addAttribute(
2406 AffinePrefetchOp::getIsDataCacheAttrName(),
2407 parser.getBuilder().getBoolAttr(cacheType.equals("data")));
2408
2409 return success();
2410 }
2411
print(OpAsmPrinter & p,AffinePrefetchOp op)2412 static void print(OpAsmPrinter &p, AffinePrefetchOp op) {
2413 p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
2414 AffineMapAttr mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2415 if (mapAttr) {
2416 SmallVector<Value, 2> operands(op.getMapOperands());
2417 p.printAffineMapOfSSAIds(mapAttr, operands);
2418 }
2419 p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", "
2420 << "locality<" << op.localityHint() << ">, "
2421 << (op.isDataCache() ? "data" : "instr");
2422 p.printOptionalAttrDict(
2423 op.getAttrs(),
2424 /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(),
2425 op.getIsDataCacheAttrName(), op.getIsWriteAttrName()});
2426 p << " : " << op.getMemRefType();
2427 }
2428
verify(AffinePrefetchOp op)2429 static LogicalResult verify(AffinePrefetchOp op) {
2430 auto mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2431 if (mapAttr) {
2432 AffineMap map = mapAttr.getValue();
2433 if (map.getNumResults() != op.getMemRefType().getRank())
2434 return op.emitOpError("affine.prefetch affine map num results must equal"
2435 " memref rank");
2436 if (map.getNumInputs() + 1 != op.getNumOperands())
2437 return op.emitOpError("too few operands");
2438 } else {
2439 if (op.getNumOperands() != 1)
2440 return op.emitOpError("too few operands");
2441 }
2442
2443 Region *scope = getAffineScope(op);
2444 for (auto idx : op.getMapOperands()) {
2445 if (!isValidAffineIndexOperand(idx, scope))
2446 return op.emitOpError("index must be a dimension or symbol identifier");
2447 }
2448 return success();
2449 }
2450
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2451 void AffinePrefetchOp::getCanonicalizationPatterns(
2452 OwningRewritePatternList &results, MLIRContext *context) {
2453 // prefetch(memrefcast) -> prefetch
2454 results.insert<SimplifyAffineOp<AffinePrefetchOp>>(context);
2455 }
2456
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)2457 LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
2458 SmallVectorImpl<OpFoldResult> &results) {
2459 /// prefetch(memrefcast) -> prefetch
2460 return foldMemRefCast(*this);
2461 }
2462
2463 //===----------------------------------------------------------------------===//
2464 // AffineParallelOp
2465 //===----------------------------------------------------------------------===//
2466
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ArrayRef<AtomicRMWKind> reductions,ArrayRef<int64_t> ranges)2467 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2468 TypeRange resultTypes,
2469 ArrayRef<AtomicRMWKind> reductions,
2470 ArrayRef<int64_t> ranges) {
2471 SmallVector<AffineExpr, 8> lbExprs(ranges.size(),
2472 builder.getAffineConstantExpr(0));
2473 auto lbMap = AffineMap::get(0, 0, lbExprs, builder.getContext());
2474 SmallVector<AffineExpr, 8> ubExprs;
2475 for (int64_t range : ranges)
2476 ubExprs.push_back(builder.getAffineConstantExpr(range));
2477 auto ubMap = AffineMap::get(0, 0, ubExprs, builder.getContext());
2478 build(builder, result, resultTypes, reductions, lbMap, /*lbArgs=*/{}, ubMap,
2479 /*ubArgs=*/{});
2480 }
2481
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ArrayRef<AtomicRMWKind> reductions,AffineMap lbMap,ValueRange lbArgs,AffineMap ubMap,ValueRange ubArgs)2482 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2483 TypeRange resultTypes,
2484 ArrayRef<AtomicRMWKind> reductions,
2485 AffineMap lbMap, ValueRange lbArgs,
2486 AffineMap ubMap, ValueRange ubArgs) {
2487 auto numDims = lbMap.getNumResults();
2488 // Verify that the dimensionality of both maps are the same.
2489 assert(numDims == ubMap.getNumResults() &&
2490 "num dims and num results mismatch");
2491 // Make default step sizes of 1.
2492 SmallVector<int64_t, 8> steps(numDims, 1);
2493 build(builder, result, resultTypes, reductions, lbMap, lbArgs, ubMap, ubArgs,
2494 steps);
2495 }
2496
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,ArrayRef<AtomicRMWKind> reductions,AffineMap lbMap,ValueRange lbArgs,AffineMap ubMap,ValueRange ubArgs,ArrayRef<int64_t> steps)2497 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2498 TypeRange resultTypes,
2499 ArrayRef<AtomicRMWKind> reductions,
2500 AffineMap lbMap, ValueRange lbArgs,
2501 AffineMap ubMap, ValueRange ubArgs,
2502 ArrayRef<int64_t> steps) {
2503 auto numDims = lbMap.getNumResults();
2504 // Verify that the dimensionality of the maps matches the number of steps.
2505 assert(numDims == ubMap.getNumResults() &&
2506 "num dims and num results mismatch");
2507 assert(numDims == steps.size() && "num dims and num steps mismatch");
2508
2509 result.addTypes(resultTypes);
2510 // Convert the reductions to integer attributes.
2511 SmallVector<Attribute, 4> reductionAttrs;
2512 for (AtomicRMWKind reduction : reductions)
2513 reductionAttrs.push_back(
2514 builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
2515 result.addAttribute(getReductionsAttrName(),
2516 builder.getArrayAttr(reductionAttrs));
2517 result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap));
2518 result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap));
2519 result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps));
2520 result.addOperands(lbArgs);
2521 result.addOperands(ubArgs);
2522 // Create a region and a block for the body.
2523 auto *bodyRegion = result.addRegion();
2524 auto *body = new Block();
2525 // Add all the block arguments.
2526 for (unsigned i = 0; i < numDims; ++i)
2527 body->addArgument(IndexType::get(builder.getContext()));
2528 bodyRegion->push_back(body);
2529 if (resultTypes.empty())
2530 ensureTerminator(*bodyRegion, builder, result.location);
2531 }
2532
getLoopBody()2533 Region &AffineParallelOp::getLoopBody() { return region(); }
2534
isDefinedOutsideOfLoop(Value value)2535 bool AffineParallelOp::isDefinedOutsideOfLoop(Value value) {
2536 return !region().isAncestor(value.getParentRegion());
2537 }
2538
moveOutOfLoop(ArrayRef<Operation * > ops)2539 LogicalResult AffineParallelOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
2540 for (Operation *op : ops)
2541 op->moveBefore(*this);
2542 return success();
2543 }
2544
getNumDims()2545 unsigned AffineParallelOp::getNumDims() { return steps().size(); }
2546
getLowerBoundsOperands()2547 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
2548 return getOperands().take_front(lowerBoundsMap().getNumInputs());
2549 }
2550
getUpperBoundsOperands()2551 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
2552 return getOperands().drop_front(lowerBoundsMap().getNumInputs());
2553 }
2554
getLowerBoundsValueMap()2555 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
2556 return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands());
2557 }
2558
getUpperBoundsValueMap()2559 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
2560 return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands());
2561 }
2562
getRangesValueMap()2563 AffineValueMap AffineParallelOp::getRangesValueMap() {
2564 AffineValueMap out;
2565 AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
2566 &out);
2567 return out;
2568 }
2569
getConstantRanges()2570 Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
2571 // Try to convert all the ranges to constant expressions.
2572 SmallVector<int64_t, 8> out;
2573 AffineValueMap rangesValueMap = getRangesValueMap();
2574 out.reserve(rangesValueMap.getNumResults());
2575 for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
2576 auto expr = rangesValueMap.getResult(i);
2577 auto cst = expr.dyn_cast<AffineConstantExpr>();
2578 if (!cst)
2579 return llvm::None;
2580 out.push_back(cst.getValue());
2581 }
2582 return out;
2583 }
2584
getBody()2585 Block *AffineParallelOp::getBody() { return ®ion().front(); }
2586
getBodyBuilder()2587 OpBuilder AffineParallelOp::getBodyBuilder() {
2588 return OpBuilder(getBody(), std::prev(getBody()->end()));
2589 }
2590
setLowerBounds(ValueRange lbOperands,AffineMap map)2591 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
2592 assert(lbOperands.size() == map.getNumInputs() &&
2593 "operands to map must match number of inputs");
2594 assert(map.getNumResults() >= 1 && "bounds map has at least one result");
2595
2596 auto ubOperands = getUpperBoundsOperands();
2597
2598 SmallVector<Value, 4> newOperands(lbOperands);
2599 newOperands.append(ubOperands.begin(), ubOperands.end());
2600 (*this)->setOperands(newOperands);
2601
2602 lowerBoundsMapAttr(AffineMapAttr::get(map));
2603 }
2604
setUpperBounds(ValueRange ubOperands,AffineMap map)2605 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
2606 assert(ubOperands.size() == map.getNumInputs() &&
2607 "operands to map must match number of inputs");
2608 assert(map.getNumResults() >= 1 && "bounds map has at least one result");
2609
2610 SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
2611 newOperands.append(ubOperands.begin(), ubOperands.end());
2612 (*this)->setOperands(newOperands);
2613
2614 upperBoundsMapAttr(AffineMapAttr::get(map));
2615 }
2616
setLowerBoundsMap(AffineMap map)2617 void AffineParallelOp::setLowerBoundsMap(AffineMap map) {
2618 AffineMap lbMap = lowerBoundsMap();
2619 assert(lbMap.getNumDims() == map.getNumDims() &&
2620 lbMap.getNumSymbols() == map.getNumSymbols());
2621 (void)lbMap;
2622 lowerBoundsMapAttr(AffineMapAttr::get(map));
2623 }
2624
setUpperBoundsMap(AffineMap map)2625 void AffineParallelOp::setUpperBoundsMap(AffineMap map) {
2626 AffineMap ubMap = upperBoundsMap();
2627 assert(ubMap.getNumDims() == map.getNumDims() &&
2628 ubMap.getNumSymbols() == map.getNumSymbols());
2629 (void)ubMap;
2630 upperBoundsMapAttr(AffineMapAttr::get(map));
2631 }
2632
getSteps()2633 SmallVector<int64_t, 8> AffineParallelOp::getSteps() {
2634 SmallVector<int64_t, 8> result;
2635 for (Attribute attr : steps()) {
2636 result.push_back(attr.cast<IntegerAttr>().getInt());
2637 }
2638 return result;
2639 }
2640
setSteps(ArrayRef<int64_t> newSteps)2641 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
2642 stepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
2643 }
2644
verify(AffineParallelOp op)2645 static LogicalResult verify(AffineParallelOp op) {
2646 auto numDims = op.getNumDims();
2647 if (op.lowerBoundsMap().getNumResults() != numDims ||
2648 op.upperBoundsMap().getNumResults() != numDims ||
2649 op.steps().size() != numDims ||
2650 op.getBody()->getNumArguments() != numDims)
2651 return op.emitOpError("region argument count and num results of upper "
2652 "bounds, lower bounds, and steps must all match");
2653
2654 if (op.reductions().size() != op.getNumResults())
2655 return op.emitOpError("a reduction must be specified for each output");
2656
2657 // Verify reduction ops are all valid
2658 for (Attribute attr : op.reductions()) {
2659 auto intAttr = attr.dyn_cast<IntegerAttr>();
2660 if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt()))
2661 return op.emitOpError("invalid reduction attribute");
2662 }
2663
2664 // Verify that the bound operands are valid dimension/symbols.
2665 /// Lower bounds.
2666 if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(),
2667 op.lowerBoundsMap().getNumDims())))
2668 return failure();
2669 /// Upper bounds.
2670 if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(),
2671 op.upperBoundsMap().getNumDims())))
2672 return failure();
2673 return success();
2674 }
2675
canonicalize()2676 LogicalResult AffineValueMap::canonicalize() {
2677 SmallVector<Value, 4> newOperands{operands};
2678 auto newMap = getAffineMap();
2679 composeAffineMapAndOperands(&newMap, &newOperands);
2680 if (newMap == getAffineMap() && newOperands == operands)
2681 return failure();
2682 reset(newMap, newOperands);
2683 return success();
2684 }
2685
2686 /// Canonicalize the bounds of the given loop.
canonicalizeLoopBounds(AffineParallelOp op)2687 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
2688 AffineValueMap lb = op.getLowerBoundsValueMap();
2689 bool lbCanonicalized = succeeded(lb.canonicalize());
2690
2691 AffineValueMap ub = op.getUpperBoundsValueMap();
2692 bool ubCanonicalized = succeeded(ub.canonicalize());
2693
2694 // Any canonicalization change always leads to updated map(s).
2695 if (!lbCanonicalized && !ubCanonicalized)
2696 return failure();
2697
2698 if (lbCanonicalized)
2699 op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
2700 if (ubCanonicalized)
2701 op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
2702
2703 return success();
2704 }
2705
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)2706 LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands,
2707 SmallVectorImpl<OpFoldResult> &results) {
2708 return canonicalizeLoopBounds(*this);
2709 }
2710
print(OpAsmPrinter & p,AffineParallelOp op)2711 static void print(OpAsmPrinter &p, AffineParallelOp op) {
2712 p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (";
2713 p.printAffineMapOfSSAIds(op.lowerBoundsMapAttr(),
2714 op.getLowerBoundsOperands());
2715 p << ") to (";
2716 p.printAffineMapOfSSAIds(op.upperBoundsMapAttr(),
2717 op.getUpperBoundsOperands());
2718 p << ')';
2719 SmallVector<int64_t, 8> steps = op.getSteps();
2720 bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
2721 if (!elideSteps) {
2722 p << " step (";
2723 llvm::interleaveComma(steps, p);
2724 p << ')';
2725 }
2726 if (op.getNumResults()) {
2727 p << " reduce (";
2728 llvm::interleaveComma(op.reductions(), p, [&](auto &attr) {
2729 AtomicRMWKind sym =
2730 *symbolizeAtomicRMWKind(attr.template cast<IntegerAttr>().getInt());
2731 p << "\"" << stringifyAtomicRMWKind(sym) << "\"";
2732 });
2733 p << ") -> (" << op.getResultTypes() << ")";
2734 }
2735
2736 p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
2737 /*printBlockTerminators=*/op.getNumResults());
2738 p.printOptionalAttrDict(
2739 op.getAttrs(),
2740 /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(),
2741 AffineParallelOp::getLowerBoundsMapAttrName(),
2742 AffineParallelOp::getUpperBoundsMapAttrName(),
2743 AffineParallelOp::getStepsAttrName()});
2744 }
2745
//
// operation ::= `affine.parallel` `(` ssa-ids `)` `=` `(` map-of-ssa-ids `)`
//               `to` `(` map-of-ssa-ids `)` steps? reductions? result-types?
//               region attr-dict?
// steps ::= `step` `(` integer-literals `)`
// reductions ::= `reduce` `(` string-literals `)`
// result-types ::= `->` type-list
//
parseAffineParallelOp(OpAsmParser & parser,OperationState & result)2751 static ParseResult parseAffineParallelOp(OpAsmParser &parser,
2752 OperationState &result) {
2753 auto &builder = parser.getBuilder();
2754 auto indexType = builder.getIndexType();
2755 AffineMapAttr lowerBoundsAttr, upperBoundsAttr;
2756 SmallVector<OpAsmParser::OperandType, 4> ivs;
2757 SmallVector<OpAsmParser::OperandType, 4> lowerBoundsMapOperands;
2758 SmallVector<OpAsmParser::OperandType, 4> upperBoundsMapOperands;
2759 if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
2760 OpAsmParser::Delimiter::Paren) ||
2761 parser.parseEqual() ||
2762 parser.parseAffineMapOfSSAIds(
2763 lowerBoundsMapOperands, lowerBoundsAttr,
2764 AffineParallelOp::getLowerBoundsMapAttrName(), result.attributes,
2765 OpAsmParser::Delimiter::Paren) ||
2766 parser.resolveOperands(lowerBoundsMapOperands, indexType,
2767 result.operands) ||
2768 parser.parseKeyword("to") ||
2769 parser.parseAffineMapOfSSAIds(
2770 upperBoundsMapOperands, upperBoundsAttr,
2771 AffineParallelOp::getUpperBoundsMapAttrName(), result.attributes,
2772 OpAsmParser::Delimiter::Paren) ||
2773 parser.resolveOperands(upperBoundsMapOperands, indexType,
2774 result.operands))
2775 return failure();
2776
2777 AffineMapAttr stepsMapAttr;
2778 NamedAttrList stepsAttrs;
2779 SmallVector<OpAsmParser::OperandType, 4> stepsMapOperands;
2780 if (failed(parser.parseOptionalKeyword("step"))) {
2781 SmallVector<int64_t, 4> steps(ivs.size(), 1);
2782 result.addAttribute(AffineParallelOp::getStepsAttrName(),
2783 builder.getI64ArrayAttr(steps));
2784 } else {
2785 if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
2786 AffineParallelOp::getStepsAttrName(),
2787 stepsAttrs,
2788 OpAsmParser::Delimiter::Paren))
2789 return failure();
2790
2791 // Convert steps from an AffineMap into an I64ArrayAttr.
2792 SmallVector<int64_t, 4> steps;
2793 auto stepsMap = stepsMapAttr.getValue();
2794 for (const auto &result : stepsMap.getResults()) {
2795 auto constExpr = result.dyn_cast<AffineConstantExpr>();
2796 if (!constExpr)
2797 return parser.emitError(parser.getNameLoc(),
2798 "steps must be constant integers");
2799 steps.push_back(constExpr.getValue());
2800 }
2801 result.addAttribute(AffineParallelOp::getStepsAttrName(),
2802 builder.getI64ArrayAttr(steps));
2803 }
2804
2805 // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
2806 // quoted strings are a member of the enum AtomicRMWKind.
2807 SmallVector<Attribute, 4> reductions;
2808 if (succeeded(parser.parseOptionalKeyword("reduce"))) {
2809 if (parser.parseLParen())
2810 return failure();
2811 do {
2812 // Parse a single quoted string via the attribute parsing, and then
2813 // verify it is a member of the enum and convert to it's integer
2814 // representation.
2815 StringAttr attrVal;
2816 NamedAttrList attrStorage;
2817 auto loc = parser.getCurrentLocation();
2818 if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
2819 attrStorage))
2820 return failure();
2821 llvm::Optional<AtomicRMWKind> reduction =
2822 symbolizeAtomicRMWKind(attrVal.getValue());
2823 if (!reduction)
2824 return parser.emitError(loc, "invalid reduction value: ") << attrVal;
2825 reductions.push_back(builder.getI64IntegerAttr(
2826 static_cast<int64_t>(reduction.getValue())));
2827 // While we keep getting commas, keep parsing.
2828 } while (succeeded(parser.parseOptionalComma()));
2829 if (parser.parseRParen())
2830 return failure();
2831 }
2832 result.addAttribute(AffineParallelOp::getReductionsAttrName(),
2833 builder.getArrayAttr(reductions));
2834
2835 // Parse return types of reductions (if any)
2836 if (parser.parseOptionalArrowTypeList(result.types))
2837 return failure();
2838
2839 // Now parse the body.
2840 Region *body = result.addRegion();
2841 SmallVector<Type, 4> types(ivs.size(), indexType);
2842 if (parser.parseRegion(*body, ivs, types) ||
2843 parser.parseOptionalAttrDict(result.attributes))
2844 return failure();
2845
2846 // Add a terminator if none was parsed.
2847 AffineParallelOp::ensureTerminator(*body, builder, result.location);
2848 return success();
2849 }
2850
2851 //===----------------------------------------------------------------------===//
2852 // AffineYieldOp
2853 //===----------------------------------------------------------------------===//
2854
verify(AffineYieldOp op)2855 static LogicalResult verify(AffineYieldOp op) {
2856 auto *parentOp = op->getParentOp();
2857 auto results = parentOp->getResults();
2858 auto operands = op.getOperands();
2859
2860 if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
2861 return op.emitOpError() << "only terminates affine.if/for/parallel regions";
2862 if (parentOp->getNumResults() != op.getNumOperands())
2863 return op.emitOpError() << "parent of yield must have same number of "
2864 "results as the yield operands";
2865 for (auto it : llvm::zip(results, operands)) {
2866 if (std::get<0>(it).getType() != std::get<1>(it).getType())
2867 return op.emitOpError()
2868 << "types mismatch between yield op and its parent";
2869 }
2870
2871 return success();
2872 }
2873
2874 //===----------------------------------------------------------------------===//
2875 // AffineVectorLoadOp
2876 //===----------------------------------------------------------------------===//
2877
build(OpBuilder & builder,OperationState & result,VectorType resultType,AffineMap map,ValueRange operands)2878 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
2879 VectorType resultType, AffineMap map,
2880 ValueRange operands) {
2881 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2882 result.addOperands(operands);
2883 if (map)
2884 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2885 result.types.push_back(resultType);
2886 }
2887
build(OpBuilder & builder,OperationState & result,VectorType resultType,Value memref,AffineMap map,ValueRange mapOperands)2888 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
2889 VectorType resultType, Value memref,
2890 AffineMap map, ValueRange mapOperands) {
2891 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2892 result.addOperands(memref);
2893 result.addOperands(mapOperands);
2894 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2895 result.types.push_back(resultType);
2896 }
2897
build(OpBuilder & builder,OperationState & result,VectorType resultType,Value memref,ValueRange indices)2898 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
2899 VectorType resultType, Value memref,
2900 ValueRange indices) {
2901 auto memrefType = memref.getType().cast<MemRefType>();
2902 int64_t rank = memrefType.getRank();
2903 // Create identity map for memrefs with at least one dimension or () -> ()
2904 // for zero-dimensional memrefs.
2905 auto map =
2906 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2907 build(builder, result, resultType, memref, map, indices);
2908 }
2909
parseAffineVectorLoadOp(OpAsmParser & parser,OperationState & result)2910 static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser,
2911 OperationState &result) {
2912 auto &builder = parser.getBuilder();
2913 auto indexTy = builder.getIndexType();
2914
2915 MemRefType memrefType;
2916 VectorType resultType;
2917 OpAsmParser::OperandType memrefInfo;
2918 AffineMapAttr mapAttr;
2919 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2920 return failure(
2921 parser.parseOperand(memrefInfo) ||
2922 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2923 AffineVectorLoadOp::getMapAttrName(),
2924 result.attributes) ||
2925 parser.parseOptionalAttrDict(result.attributes) ||
2926 parser.parseColonType(memrefType) || parser.parseComma() ||
2927 parser.parseType(resultType) ||
2928 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
2929 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
2930 parser.addTypeToList(resultType, result.types));
2931 }
2932
print(OpAsmPrinter & p,AffineVectorLoadOp op)2933 static void print(OpAsmPrinter &p, AffineVectorLoadOp op) {
2934 p << "affine.vector_load " << op.getMemRef() << '[';
2935 if (AffineMapAttr mapAttr =
2936 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2937 p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2938 p << ']';
2939 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
2940 p << " : " << op.getMemRefType() << ", " << op.getType();
2941 }
2942
2943 /// Verify common invariants of affine.vector_load and affine.vector_store.
verifyVectorMemoryOp(Operation * op,MemRefType memrefType,VectorType vectorType)2944 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
2945 VectorType vectorType) {
2946 // Check that memref and vector element types match.
2947 if (memrefType.getElementType() != vectorType.getElementType())
2948 return op->emitOpError(
2949 "requires memref and vector types of the same elemental type");
2950 return success();
2951 }
2952
verify(AffineVectorLoadOp op)2953 static LogicalResult verify(AffineVectorLoadOp op) {
2954 MemRefType memrefType = op.getMemRefType();
2955 if (failed(verifyMemoryOpIndexing(
2956 op.getOperation(),
2957 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2958 op.getMapOperands(), memrefType,
2959 /*numIndexOperands=*/op.getNumOperands() - 1)))
2960 return failure();
2961
2962 if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
2963 op.getVectorType())))
2964 return failure();
2965
2966 return success();
2967 }
2968
2969 //===----------------------------------------------------------------------===//
2970 // AffineVectorStoreOp
2971 //===----------------------------------------------------------------------===//
2972
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,AffineMap map,ValueRange mapOperands)2973 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
2974 Value valueToStore, Value memref, AffineMap map,
2975 ValueRange mapOperands) {
2976 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2977 result.addOperands(valueToStore);
2978 result.addOperands(memref);
2979 result.addOperands(mapOperands);
2980 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2981 }
2982
2983 // Use identity map.
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,ValueRange indices)2984 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
2985 Value valueToStore, Value memref,
2986 ValueRange indices) {
2987 auto memrefType = memref.getType().cast<MemRefType>();
2988 int64_t rank = memrefType.getRank();
2989 // Create identity map for memrefs with at least one dimension or () -> ()
2990 // for zero-dimensional memrefs.
2991 auto map =
2992 rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2993 build(builder, result, valueToStore, memref, map, indices);
2994 }
2995
parseAffineVectorStoreOp(OpAsmParser & parser,OperationState & result)2996 static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser,
2997 OperationState &result) {
2998 auto indexTy = parser.getBuilder().getIndexType();
2999
3000 MemRefType memrefType;
3001 VectorType resultType;
3002 OpAsmParser::OperandType storeValueInfo;
3003 OpAsmParser::OperandType memrefInfo;
3004 AffineMapAttr mapAttr;
3005 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
3006 return failure(
3007 parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3008 parser.parseOperand(memrefInfo) ||
3009 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3010 AffineVectorStoreOp::getMapAttrName(),
3011 result.attributes) ||
3012 parser.parseOptionalAttrDict(result.attributes) ||
3013 parser.parseColonType(memrefType) || parser.parseComma() ||
3014 parser.parseType(resultType) ||
3015 parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
3016 parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
3017 parser.resolveOperands(mapOperands, indexTy, result.operands));
3018 }
3019
print(OpAsmPrinter & p,AffineVectorStoreOp op)3020 static void print(OpAsmPrinter &p, AffineVectorStoreOp op) {
3021 p << "affine.vector_store " << op.getValueToStore();
3022 p << ", " << op.getMemRef() << '[';
3023 if (AffineMapAttr mapAttr =
3024 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
3025 p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
3026 p << ']';
3027 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
3028 p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType();
3029 }
3030
verify(AffineVectorStoreOp op)3031 static LogicalResult verify(AffineVectorStoreOp op) {
3032 MemRefType memrefType = op.getMemRefType();
3033 if (failed(verifyMemoryOpIndexing(
3034 op.getOperation(),
3035 op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
3036 op.getMapOperands(), memrefType,
3037 /*numIndexOperands=*/op.getNumOperands() - 2)))
3038 return failure();
3039
3040 if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
3041 op.getVectorType())))
3042 return failure();
3043
3044 return success();
3045 }
3046
3047 //===----------------------------------------------------------------------===//
3048 // TableGen'd op method definitions
3049 //===----------------------------------------------------------------------===//
3050
3051 #define GET_OP_CLASSES
3052 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
3053