1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "mlir/Dialect/AffineOps/AffineOps.h"
10 #include "mlir/Dialect/StandardOps/Ops.h"
11 #include "mlir/IR/Function.h"
12 #include "mlir/IR/IntegerSet.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Transforms/InliningUtils.h"
17 #include "mlir/Transforms/SideEffectsInterface.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/SmallBitVector.h"
20 #include "llvm/Support/Debug.h"
21
22 using namespace mlir;
23 using llvm::dbgs;
24
25 #define DEBUG_TYPE "affine-analysis"
26
27 //===----------------------------------------------------------------------===//
28 // AffineOpsDialect Interfaces
29 //===----------------------------------------------------------------------===//
30
31 namespace {
/// This class defines the interface for handling inlining with affine
/// operations.
struct AffineInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// Returns true if the given region 'src' can be inlined into the region
  /// 'dest' that is attached to an operation registered to the current
  /// dialect.
  bool isLegalToInline(Region *dest, Region *src,
                       BlockAndValueMapping &valueMapping) const final {
    // Conservatively don't allow inlining into affine structures.
    return false;
  }

  /// Returns true if the given operation 'op', that is registered to this
  /// dialect, can be inlined into the given region, false otherwise.
  bool isLegalToInline(Operation *op, Region *region,
                       BlockAndValueMapping &valueMapping) const final {
    // Always allow inlining affine operations into the top-level region of a
    // function. There are some edge cases when inlining *into* affine
    // structures, but that is handled in the other 'isLegalToInline' hook
    // above.
    // TODO: We should be able to inline into other regions than functions.
    return isa<FuncOp>(region->getParentOp());
  }

  /// Affine regions should be analyzed recursively: legality of the parent
  /// op alone does not determine legality of the ops nested inside it.
  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
};
64
65 // TODO(mlir): Extend for other ops in this dialect.
66 struct AffineSideEffectsInterface : public SideEffectsDialectInterface {
67 using SideEffectsDialectInterface::SideEffectsDialectInterface;
68
isSideEffecting__anon86bd9b9c0111::AffineSideEffectsInterface69 SideEffecting isSideEffecting(Operation *op) const override {
70 if (isa<AffineIfOp>(op)) {
71 return Recursive;
72 }
73 return SideEffectsDialectInterface::isSideEffecting(op);
74 };
75 };
76
77 } // end anonymous namespace
78
79 //===----------------------------------------------------------------------===//
80 // AffineOpsDialect
81 //===----------------------------------------------------------------------===//
82
/// Dialect constructor: registers the hand-written affine operations plus the
/// ODS-generated operation list, and attaches the inliner and side-effect
/// dialect interfaces.
AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<AffineApplyOp, AffineDmaStartOp, AffineDmaWaitOp, AffineLoadOp,
                AffineStoreOp,
// Pull in the list of ops generated from AffineOps.td.
#define GET_OP_LIST
#include "mlir/Dialect/AffineOps/AffineOps.cpp.inc"
                >();
  addInterfaces<AffineInlinerInterface, AffineSideEffectsInterface>();
}
92
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type. Used by folding to turn computed attribute
/// values back into IR; delegates to the standard dialect's ConstantOp.
Operation *AffineOpsDialect::materializeConstant(OpBuilder &builder,
                                                 Attribute value, Type type,
                                                 Location loc) {
  return builder.create<ConstantOp>(loc, type, value);
}
100
101 /// A utility function to check if a given region is attached to a function.
isFunctionRegion(Region * region)102 static bool isFunctionRegion(Region *region) {
103 return llvm::isa<FuncOp>(region->getParentOp());
104 }
105
106 /// A utility function to check if a value is defined at the top level of a
107 /// function. A value of index type defined at the top level is always a valid
108 /// symbol.
isTopLevelValue(Value value)109 bool mlir::isTopLevelValue(Value value) {
110 if (auto arg = value.dyn_cast<BlockArgument>())
111 return isFunctionRegion(arg.getOwner()->getParent());
112 return isFunctionRegion(value.getDefiningOp()->getParentRegion());
113 }
114
115 // Value can be used as a dimension id if it is valid as a symbol, or
116 // it is an induction variable, or it is a result of affine apply operation
117 // with dimension id arguments.
isValidDim(Value value)118 bool mlir::isValidDim(Value value) {
119 // The value must be an index type.
120 if (!value.getType().isIndex())
121 return false;
122
123 if (auto *op = value.getDefiningOp()) {
124 // Top level operation or constant operation is ok.
125 if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
126 return true;
127 // Affine apply operation is ok if all of its operands are ok.
128 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
129 return applyOp.isValidDim();
130 // The dim op is okay if its operand memref/tensor is defined at the top
131 // level.
132 if (auto dimOp = dyn_cast<DimOp>(op))
133 return isTopLevelValue(dimOp.getOperand());
134 return false;
135 }
136 // This value has to be a block argument for a FuncOp or an affine.for.
137 auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
138 return isa<FuncOp>(parentOp) || isa<AffineForOp>(parentOp);
139 }
140
141 /// Returns true if the 'index' dimension of the `memref` defined by
142 /// `memrefDefOp` is a statically shaped one or defined using a valid symbol.
143 template <typename AnyMemRefDefOp>
isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp,unsigned index)144 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp,
145 unsigned index) {
146 auto memRefType = memrefDefOp.getType();
147 // Statically shaped.
148 if (!ShapedType::isDynamic(memRefType.getDimSize(index)))
149 return true;
150 // Get the position of the dimension among dynamic dimensions;
151 unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
152 return isValidSymbol(
153 *(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos));
154 }
155
156 /// Returns true if the result of the dim op is a valid symbol.
isDimOpValidSymbol(DimOp dimOp)157 static bool isDimOpValidSymbol(DimOp dimOp) {
158 // The dim op is okay if its operand memref/tensor is defined at the top
159 // level.
160 if (isTopLevelValue(dimOp.getOperand()))
161 return true;
162
163 // The dim op is also okay if its operand memref/tensor is a view/subview
164 // whose corresponding size is a valid symbol.
165 unsigned index = dimOp.getIndex();
166 if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand().getDefiningOp()))
167 return isMemRefSizeValidSymbol<ViewOp>(viewOp, index);
168 if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand().getDefiningOp()))
169 return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index);
170 if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand().getDefiningOp()))
171 return isMemRefSizeValidSymbol<AllocOp>(allocOp, index);
172 return false;
173 }
174
175 // Value can be used as a symbol if it is a constant, or it is defined at
176 // the top level, or it is a result of affine apply operation with symbol
177 // arguments, or a result of the dim op on a memref satisfying certain
178 // constraints.
isValidSymbol(Value value)179 bool mlir::isValidSymbol(Value value) {
180 // The value must be an index type.
181 if (!value.getType().isIndex())
182 return false;
183
184 if (auto *op = value.getDefiningOp()) {
185 // Top level operation or constant operation is ok.
186 if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
187 return true;
188 // Affine apply operation is ok if all of its operands are ok.
189 if (auto applyOp = dyn_cast<AffineApplyOp>(op))
190 return applyOp.isValidSymbol();
191 if (auto dimOp = dyn_cast<DimOp>(op)) {
192 return isDimOpValidSymbol(dimOp);
193 }
194 }
195 // Otherwise, check that the value is a top level value.
196 return isTopLevelValue(value);
197 }
198
// Returns true if 'value' is a valid index to an affine operation (e.g.
// affine.load, affine.store, affine.dma_start, affine.dma_wait).
// Returns false otherwise. An index operand may be used either as a
// dimension identifier or as a symbol.
static bool isValidAffineIndexOperand(Value value) {
  return isValidDim(value) || isValidSymbol(value);
}
205
206 /// Utility function to verify that a set of operands are valid dimension and
207 /// symbol identifiers. The operands should be laid out such that the dimension
208 /// operands are before the symbol operands. This function returns failure if
209 /// there was an invalid operand. An operation is provided to emit any necessary
210 /// errors.
211 template <typename OpTy>
212 static LogicalResult
verifyDimAndSymbolIdentifiers(OpTy & op,Operation::operand_range operands,unsigned numDims)213 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
214 unsigned numDims) {
215 unsigned opIt = 0;
216 for (auto operand : operands) {
217 if (opIt++ < numDims) {
218 if (!isValidDim(operand))
219 return op.emitOpError("operand cannot be used as a dimension id");
220 } else if (!isValidSymbol(operand)) {
221 return op.emitOpError("operand cannot be used as a symbol");
222 }
223 }
224 return success();
225 }
226
227 //===----------------------------------------------------------------------===//
228 // AffineApplyOp
229 //===----------------------------------------------------------------------===//
230
/// Builds an affine.apply op applying `map` to `operands`, producing one
/// 'index'-typed result per result expression of the map.
void AffineApplyOp::build(Builder *builder, OperationState &result,
                          AffineMap map, ValueRange operands) {
  result.addOperands(operands);
  // One 'index' result per map result.
  result.types.append(map.getNumResults(), builder->getIndexType());
  result.addAttribute("map", AffineMapAttr::get(map));
}
237
/// Parses an affine.apply op: the affine map attribute followed by a
/// parenthesized dim operand list, a bracketed symbol operand list, and an
/// optional attribute dictionary.
ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  AffineMapAttr mapAttr;
  unsigned numDims;
  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  auto map = mapAttr.getValue();

  // The parsed operand list must match the map's dim and symbol counts.
  if (map.getNumDims() != numDims ||
      numDims + map.getNumSymbols() != result.operands.size()) {
    return parser.emitError(parser.getNameLoc(),
                            "dimension or symbol index mismatch");
  }

  result.types.append(map.getNumResults(), indexTy);
  return success();
}
259
/// Prints an affine.apply op: the map attribute followed by the dim and
/// symbol operand lists.
void AffineApplyOp::print(OpAsmPrinter &p) {
  p << "affine.apply " << getAttr("map");
  printDimAndSymbolList(operand_begin(), operand_end(),
                        getAffineMap().getNumDims(), p);
  // The map was printed inline above, so elide it from the attribute dict.
  p.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"map"});
}
266
verify()267 LogicalResult AffineApplyOp::verify() {
268 // Check that affine map attribute was specified.
269 auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
270 if (!affineMapAttr)
271 return emitOpError("requires an affine map");
272
273 // Check input and output dimensions match.
274 auto map = affineMapAttr.getValue();
275
276 // Verify that operand count matches affine map dimension and symbol count.
277 if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
278 return emitOpError(
279 "operand count and affine map dimension and symbol count must match");
280
281 // Verify that all operands are of `index` type.
282 for (Type t : getOperandTypes()) {
283 if (!t.isIndex())
284 return emitOpError("operands must be of type 'index'");
285 }
286
287 if (!getResult().getType().isIndex())
288 return emitOpError("result must be of type 'index'");
289
290 // Verify that the map only produces one result.
291 if (map.getNumResults() != 1)
292 return emitOpError("mapping must produce one value");
293
294 return success();
295 }
296
297 // The result of the affine apply operation can be used as a dimension id if all
298 // its operands are valid dimension ids.
isValidDim()299 bool AffineApplyOp::isValidDim() {
300 return llvm::all_of(getOperands(),
301 [](Value op) { return mlir::isValidDim(op); });
302 }
303
304 // The result of the affine apply operation can be used as a symbol if all its
305 // operands are symbols.
isValidSymbol()306 bool AffineApplyOp::isValidSymbol() {
307 return llvm::all_of(getOperands(),
308 [](Value op) { return mlir::isValidSymbol(op); });
309 }
310
fold(ArrayRef<Attribute> operands)311 OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
312 auto map = getAffineMap();
313
314 // Fold dims and symbols to existing values.
315 auto expr = map.getResult(0);
316 if (auto dim = expr.dyn_cast<AffineDimExpr>())
317 return getOperand(dim.getPosition());
318 if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
319 return getOperand(map.getNumDims() + sym.getPosition());
320
321 // Otherwise, default to folding the map.
322 SmallVector<Attribute, 1> result;
323 if (failed(map.constantFold(operands, result)))
324 return {};
325 return result[0];
326 }
327
renumberOneDim(Value v)328 AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) {
329 DenseMap<Value, unsigned>::iterator iterPos;
330 bool inserted = false;
331 std::tie(iterPos, inserted) =
332 dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size()));
333 if (inserted) {
334 reorderedDims.push_back(v);
335 }
336 return getAffineDimExpr(iterPos->second, v.getContext())
337 .cast<AffineDimExpr>();
338 }
339
/// Merges the dims and symbols of `other` into this normalizer's numbering:
/// each of `other`'s dims is renumbered through renumberOneDim (deduplicating
/// against dims already registered here), and all of `other`'s symbols are
/// appended after the symbols already concatenated. Returns `other`'s affine
/// map rewritten into this normalizer's dim/symbol space.
AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) {
  SmallVector<AffineExpr, 8> dimRemapping;
  for (auto v : other.reorderedDims) {
    auto kvp = other.dimValueToPosition.find(v);
    if (dimRemapping.size() <= kvp->second)
      dimRemapping.resize(kvp->second + 1);
    dimRemapping[kvp->second] = renumberOneDim(kvp->first);
  }
  // `other`'s symbols keep their relative order but are shifted past the
  // symbols already present in this normalizer.
  unsigned numSymbols = concatenatedSymbols.size();
  unsigned numOtherSymbols = other.concatenatedSymbols.size();
  SmallVector<AffineExpr, 8> symRemapping(numOtherSymbols);
  for (unsigned idx = 0; idx < numOtherSymbols; ++idx) {
    symRemapping[idx] =
        getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext());
  }
  concatenatedSymbols.insert(concatenatedSymbols.end(),
                             other.concatenatedSymbols.begin(),
                             other.concatenatedSymbols.end());
  auto map = other.affineMap;
  return map.replaceDimsAndSymbols(dimRemapping, symRemapping,
                                   reorderedDims.size(),
                                   concatenatedSymbols.size());
}
363
364 // Gather the positions of the operands that are produced by an AffineApplyOp.
365 static llvm::SetVector<unsigned>
indicesFromAffineApplyOp(ArrayRef<Value> operands)366 indicesFromAffineApplyOp(ArrayRef<Value> operands) {
367 llvm::SetVector<unsigned> res;
368 for (auto en : llvm::enumerate(operands))
369 if (isa_and_nonnull<AffineApplyOp>(en.value().getDefiningOp()))
370 res.insert(en.index());
371 return res;
372 }
373
374 // Support the special case of a symbol coming from an AffineApplyOp that needs
375 // to be composed into the current AffineApplyOp.
376 // This case is handled by rewriting all such symbols into dims for the purpose
377 // of allowing mathematical AffineMap composition.
378 // Returns an AffineMap where symbols that come from an AffineApplyOp have been
379 // rewritten as dims and are ordered after the original dims.
380 // TODO(andydavis,ntv): This promotion makes AffineMap lose track of which
381 // symbols are represented as dims. This loss is static but can still be
382 // recovered dynamically (with `isValidSymbol`). Still this is annoying for the
383 // semi-affine map case. A dynamic canonicalization of all dims that are valid
384 // symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even
385 // results in better simplifications and foldings. But we should evaluate
386 // whether this behavior is what we really want after using more.
promoteComposedSymbolsAsDims(AffineMap map,ArrayRef<Value> symbols)387 static AffineMap promoteComposedSymbolsAsDims(AffineMap map,
388 ArrayRef<Value> symbols) {
389 if (symbols.empty()) {
390 return map;
391 }
392
393 // Sanity check on symbols.
394 for (auto sym : symbols) {
395 assert(isValidSymbol(sym) && "Expected only valid symbols");
396 (void)sym;
397 }
398
399 // Extract the symbol positions that come from an AffineApplyOp and
400 // needs to be rewritten as dims.
401 auto symPositions = indicesFromAffineApplyOp(symbols);
402 if (symPositions.empty()) {
403 return map;
404 }
405
406 // Create the new map by replacing each symbol at pos by the next new dim.
407 unsigned numDims = map.getNumDims();
408 unsigned numSymbols = map.getNumSymbols();
409 unsigned numNewDims = 0;
410 unsigned numNewSymbols = 0;
411 SmallVector<AffineExpr, 8> symReplacements(numSymbols);
412 for (unsigned i = 0; i < numSymbols; ++i) {
413 symReplacements[i] =
414 symPositions.count(i) > 0
415 ? getAffineDimExpr(numDims + numNewDims++, map.getContext())
416 : getAffineSymbolExpr(numNewSymbols++, map.getContext());
417 }
418 assert(numSymbols >= numNewDims);
419 AffineMap newMap = map.replaceDimsAndSymbols(
420 {}, symReplacements, numDims + numNewDims, numNewSymbols);
421
422 return newMap;
423 }
424
425 /// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to
426 /// keep a correspondence between the mathematical `map` and the `operands` of
427 /// a given AffineApplyOp. This correspondence is maintained by iterating over
428 /// the operands and forming an `auxiliaryMap` that can be composed
429 /// mathematically with `map`. To keep this correspondence in cases where
430 /// symbols are produced by affine.apply operations, we perform a local rewrite
431 /// of symbols as dims.
432 ///
433 /// Rationale for locally rewriting symbols as dims:
434 /// ================================================
435 /// The mathematical composition of AffineMap must always concatenate symbols
436 /// because it does not have enough information to do otherwise. For example,
437 /// composing `(d0)[s0] -> (d0 + s0)` with itself must produce
438 /// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
439 ///
440 /// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
441 /// applied to the same mlir::Value for both s0 and s1.
442 /// As a consequence mathematical composition of AffineMap always concatenates
443 /// symbols.
444 ///
445 /// When AffineMaps are used in AffineApplyOp however, they may specify
446 /// composition via symbols, which is ambiguous mathematically. This corner case
447 /// is handled by locally rewriting such symbols that come from AffineApplyOp
448 /// into dims and composing through dims.
449 /// TODO(andydavis, ntv): Composition via symbols comes at a significant code
450 /// complexity. Alternatively we should investigate whether we want to
451 /// explicitly disallow symbols coming from affine.apply and instead force the
452 /// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2
453 /// extra API calls for such uses, which haven't popped up until now) and the
454 /// benefit potentially big: simpler and more maintainable code for a
455 /// non-trivial, recursive, procedure.
/// Constructs a normalizer for `map` applied to `operands`: promotes
/// affine.apply-produced symbols to dims, recursively composes any
/// affine.apply producers of the operands (up to kMaxAffineApplyDepth), and
/// stores the resulting composed map and renumbered operand lists in the
/// normalizer's members.
AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
                                             ArrayRef<Value> operands)
    : AffineApplyNormalizer() {
  static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0");
  assert(map.getNumInputs() == operands.size() &&
         "number of operands does not match the number of map inputs");

  LLVM_DEBUG(map.print(dbgs() << "\nInput map: "));

  // Promote symbols that come from an AffineApplyOp to dims by rewriting the
  // map to always refer to:
  // (dims, symbols coming from AffineApplyOp, other symbols).
  // The order of operands can remain unchanged.
  // This is a simplification that relies on 2 ordering properties:
  // 1. rewritten symbols always appear after the original dims in the map;
  // 2. operands are traversed in order and either dispatched to:
  //    a. auxiliaryExprs (dims and symbols rewritten as dims);
  //    b. concatenatedSymbols (all other symbols)
  // This allows operand order to remain unchanged.
  unsigned numDimsBeforeRewrite = map.getNumDims();
  map = promoteComposedSymbolsAsDims(map,
                                     operands.take_back(map.getNumSymbols()));

  LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: "));

  SmallVector<AffineExpr, 8> auxiliaryExprs;
  // Recursion is capped at kMaxAffineApplyDepth to bound compile time.
  bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth);
  // We fully spell out the 2 cases below. In this particular instance a little
  // code duplication greatly improves readability.
  // Note that the first branch would disappear if we only supported full
  // composition (i.e. infinite kMaxAffineApplyDepth).
  if (!furtherCompose) {
    // 1. Only dispatch dims or symbols.
    for (auto en : llvm::enumerate(operands)) {
      auto t = en.value();
      assert(t.getType().isIndex());
      bool isDim = (en.index() < map.getNumDims());
      if (isDim) {
        // a. The mathematical composition of AffineMap composes dims.
        auxiliaryExprs.push_back(renumberOneDim(t));
      } else {
        // b. The mathematical composition of AffineMap concatenates symbols.
        //    We do the same for symbol operands.
        concatenatedSymbols.push_back(t);
      }
    }
  } else {
    assert(numDimsBeforeRewrite <= operands.size());
    // 2. Compose AffineApplyOps and dispatch dims or symbols.
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      auto t = operands[i];
      auto affineApply = dyn_cast_or_null<AffineApplyOp>(t.getDefiningOp());
      if (affineApply) {
        // a. Compose affine.apply operations.
        LLVM_DEBUG(affineApply.getOperation()->print(
            dbgs() << "\nCompose AffineApplyOp recursively: "));
        AffineMap affineApplyMap = affineApply.getAffineMap();
        SmallVector<Value, 8> affineApplyOperands(
            affineApply.getOperands().begin(), affineApply.getOperands().end());
        // Recursively normalize the producer, then fold its dims/symbols
        // into this normalizer's numbering.
        AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands);

        LLVM_DEBUG(normalizer.affineMap.print(
            dbgs() << "\nRenumber into current normalizer: "));

        auto renumberedMap = renumber(normalizer);

        LLVM_DEBUG(
            renumberedMap.print(dbgs() << "\nRecursive composition yields: "));

        auxiliaryExprs.push_back(renumberedMap.getResult(0));
      } else {
        if (i < numDimsBeforeRewrite) {
          // b. The mathematical composition of AffineMap composes dims.
          auxiliaryExprs.push_back(renumberOneDim(t));
        } else {
          // c. The mathematical composition of AffineMap concatenates symbols.
          //    We do the same for symbol operands.
          concatenatedSymbols.push_back(t);
        }
      }
    }
  }

  // Early exit if `map` is already composed.
  if (auxiliaryExprs.empty()) {
    affineMap = map;
    return;
  }

  assert(concatenatedSymbols.size() >= map.getNumSymbols() &&
         "Unexpected number of concatenated symbols");
  auto numDims = dimValueToPosition.size();
  auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols();
  auto auxiliaryMap = AffineMap::get(numDims, numSymbols, auxiliaryExprs);

  LLVM_DEBUG(map.print(dbgs() << "\nCompose map: "));
  LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: "));
  LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: "));

  // TODO(andydavis,ntv): Disabling simplification results in major speed gains.
  // Another option is to cache the results as it is expected a lot of redundant
  // work is performed in practice.
  affineMap = simplifyAffineMap(map.compose(auxiliaryMap));

  LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: "));
  LLVM_DEBUG(dbgs() << "\n");
}
563
/// Normalizes (`otherMap`, `otherOperands`) in place against this normalizer:
/// rewrites `otherMap` into this normalizer's dim/symbol numbering and
/// replaces `otherOperands` with the merged operand list (renumbered dims
/// followed by concatenated symbols).
void AffineApplyNormalizer::normalize(AffineMap *otherMap,
                                      SmallVectorImpl<Value> *otherOperands) {
  AffineApplyNormalizer other(*otherMap, *otherOperands);
  *otherMap = renumber(other);

  otherOperands->reserve(reorderedDims.size() + concatenatedSymbols.size());
  otherOperands->assign(reorderedDims.begin(), reorderedDims.end());
  otherOperands->append(concatenatedSymbols.begin(), concatenatedSymbols.end());
}
573
/// Implements `map` and `operands` composition and simplification to support
/// `makeComposedAffineApply`. This can be called to achieve the same effects
/// on `map` and `operands` without creating an AffineApplyOp that needs to be
/// immediately deleted. Performs one round of normalization followed by
/// canonicalization of the resulting map and operands.
static void composeAffineMapAndOperands(AffineMap *map,
                                        SmallVectorImpl<Value> *operands) {
  AffineApplyNormalizer normalizer(*map, *operands);
  auto normalizedMap = normalizer.getAffineMap();
  auto normalizedOperands = normalizer.getOperands();
  canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands);
  *map = normalizedMap;
  *operands = normalizedOperands;
  assert(*map);
}
588
fullyComposeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)589 void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
590 SmallVectorImpl<Value> *operands) {
591 while (llvm::any_of(*operands, [](Value v) {
592 return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
593 })) {
594 composeAffineMapAndOperands(map, operands);
595 }
596 }
597
/// Creates an affine.apply at `loc` whose map and operands are the composed
/// and canonicalized form of (`map`, `operands`).
AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
                                            AffineMap map,
                                            ArrayRef<Value> operands) {
  AffineMap normalizedMap = map;
  SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end());
  composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
  assert(normalizedMap);
  return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
}
607
// A symbol may appear as a dim in affine.apply operations. This function
// canonicalizes dims that are valid symbols into actual symbols. The
// rewritten symbols are appended after the pre-existing symbols, and the
// operand list is reordered to match (remaining dims first, then old
// symbols, then the newly promoted ones).
template <class MapOrSet>
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
                                        SmallVectorImpl<Value> *operands) {
  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet->getContext();
  // Operands that stay in dim position plus all symbol operands.
  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());
  // Dim operands promoted to symbols; appended at the end.
  SmallVector<Value, 8> remappedSymbols;
  remappedSymbols.reserve(operands->size());
  unsigned nextDim = 0;
  unsigned nextSym = 0;
  unsigned oldNumSyms = mapOrSet->getNumSymbols();
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
    if (i < mapOrSet->getNumDims()) {
      if (isValidSymbol((*operands)[i])) {
        // This is a valid symbol that appears as a dim, canonicalize it.
        dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
        remappedSymbols.push_back((*operands)[i]);
      } else {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
      }
    } else {
      // Symbol operands keep their relative position.
      resultOperands.push_back((*operands)[i]);
    }
  }

  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
  *operands = resultOperands;
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
                                              oldNumSyms + nextSym);

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");
}
651
// Works for either an affine map or an integer set. Canonicalization:
// promotes valid-symbol dims to symbols, drops unused dims/symbols,
// deduplicates repeated operands, and folds constant symbol operands into
// the map/set itself.
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
                                            SmallVectorImpl<Value> *operands) {
  static_assert(std::is_same<MapOrSet, AffineMap>::value ||
                    std::is_same<MapOrSet, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type");

  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
  mapOrSet->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = mapOrSet->getContext();

  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());

  // Deduplicate dim operands: each distinct Value gets one new dim position;
  // unused dims are dropped entirely.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    if (usedDims[i]) {
      // Remap dim positions for duplicate operands.
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
      } else {
        dimRemapping[i] = it->second;
      }
    }
  }
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    if (!usedSyms[i])
      continue;
    // Handle constant operands (only needed for symbolic operands since
    // constant operands in dimensional positions would have already been
    // promoted to symbolic positions above).
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                     m_Constant(&operandCst))) {
      symRemapping[i] =
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
      continue;
    }
    // Remap symbol positions for duplicate operands.
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
                                        symRemapping[i]));
    } else {
      symRemapping[i] = it->second;
    }
  }
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
                                              nextDim, nextSym);
  *operands = resultOperands;
}
730
/// Canonicalizes `map` and its `operands`: promotes valid-symbol dims to
/// symbols, removes unused and duplicate operands, and folds constant
/// symbol operands into the map.
void mlir::canonicalizeMapAndOperands(AffineMap *map,
                                      SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
}
735
/// Canonicalizes `set` together with its operand list in place, using the
/// same shared helper as the AffineMap variant above.
void mlir::canonicalizeSetAndOperands(IntegerSet *set,
                                      SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
}
740
741 namespace {
742 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
743 /// maps that supply results into them.
744 ///
745 template <typename AffineOpTy>
746 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
747 using OpRewritePattern<AffineOpTy>::OpRewritePattern;
748
749 /// Replace the affine op with another instance of it with the supplied
750 /// map and mapOperands.
751 void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
752 AffineMap map, ArrayRef<Value> mapOperands) const;
753
matchAndRewrite__anon86bd9b9c0611::SimplifyAffineOp754 PatternMatchResult matchAndRewrite(AffineOpTy affineOp,
755 PatternRewriter &rewriter) const override {
756 static_assert(std::is_same<AffineOpTy, AffineLoadOp>::value ||
757 std::is_same<AffineOpTy, AffinePrefetchOp>::value ||
758 std::is_same<AffineOpTy, AffineStoreOp>::value ||
759 std::is_same<AffineOpTy, AffineApplyOp>::value,
760 "affine load/store/apply op expected");
761 auto map = affineOp.getAffineMap();
762 AffineMap oldMap = map;
763 auto oldOperands = affineOp.getMapOperands();
764 SmallVector<Value, 8> resultOperands(oldOperands);
765 composeAffineMapAndOperands(&map, &resultOperands);
766 if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
767 resultOperands.begin()))
768 return this->matchFailure();
769
770 replaceAffineOp(rewriter, affineOp, map, resultOperands);
771 return this->matchSuccess();
772 }
773 };
774
775 // Specialize the template to account for the different build signatures for
776 // affine load, store, and apply ops.
777 template <>
replaceAffineOp(PatternRewriter & rewriter,AffineLoadOp load,AffineMap map,ArrayRef<Value> mapOperands) const778 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
779 PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
780 ArrayRef<Value> mapOperands) const {
781 rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
782 mapOperands);
783 }
784 template <>
replaceAffineOp(PatternRewriter & rewriter,AffinePrefetchOp prefetch,AffineMap map,ArrayRef<Value> mapOperands) const785 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
786 PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
787 ArrayRef<Value> mapOperands) const {
788 rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
789 prefetch, prefetch.memref(), map, mapOperands,
790 prefetch.localityHint().getZExtValue(), prefetch.isWrite(),
791 prefetch.isDataCache());
792 }
793 template <>
replaceAffineOp(PatternRewriter & rewriter,AffineStoreOp store,AffineMap map,ArrayRef<Value> mapOperands) const794 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
795 PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
796 ArrayRef<Value> mapOperands) const {
797 rewriter.replaceOpWithNewOp<AffineStoreOp>(
798 store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
799 }
800 template <>
replaceAffineOp(PatternRewriter & rewriter,AffineApplyOp apply,AffineMap map,ArrayRef<Value> mapOperands) const801 void SimplifyAffineOp<AffineApplyOp>::replaceAffineOp(
802 PatternRewriter &rewriter, AffineApplyOp apply, AffineMap map,
803 ArrayRef<Value> mapOperands) const {
804 rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, mapOperands);
805 }
806 } // end anonymous namespace.
807
/// Registers the map-composition simplification pattern for affine.apply.
void AffineApplyOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyAffineOp<AffineApplyOp>>(context);
}
812
813 //===----------------------------------------------------------------------===//
814 // Common canonicalization pattern support logic
815 //===----------------------------------------------------------------------===//
816
817 /// This is a common class used for patterns of the form
818 /// "someop(memrefcast) -> someop". It folds the source of any memref_cast
819 /// into the root operation directly.
foldMemRefCast(Operation * op)820 static LogicalResult foldMemRefCast(Operation *op) {
821 bool folded = false;
822 for (OpOperand &operand : op->getOpOperands()) {
823 auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
824 if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
825 operand.set(cast.getOperand());
826 folded = true;
827 }
828 }
829 return success(folded);
830 }
831
832 //===----------------------------------------------------------------------===//
833 // AffineDmaStartOp
834 //===----------------------------------------------------------------------===//
835
// TODO(b/133776335) Check that map operands are loop IVs or symbols.
/// Builds an affine.dma_start. Operands are appended in this fixed order:
/// src memref, src map operands, dst memref, dst map operands, tag memref,
/// tag map operands, number of elements, then (optionally) the stride and
/// elements-per-stride pair. Accessors below rely on this layout.
void AffineDmaStartOp::build(Builder *builder, OperationState &result,
                             Value srcMemRef, AffineMap srcMap,
                             ValueRange srcIndices, Value destMemRef,
                             AffineMap dstMap, ValueRange destIndices,
                             Value tagMemRef, AffineMap tagMap,
                             ValueRange tagIndices, Value numElements,
                             Value stride, Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap));
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap));
  result.addOperands(destIndices);
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
  // The stride operands are optional; when a stride is given, both it and
  // the elements-per-stride count are added.
  if (stride) {
    result.addOperands({stride, elementsPerStride});
  }
}
858
/// Prints affine.dma_start as:
///   src[src map], dst[dst map], tag[tag map], num_elements
///   (, stride, elts_per_stride)? : src type, dst type, tag type
void AffineDmaStartOp::print(OpAsmPrinter &p) {
  p << "affine.dma_start " << getSrcMemRef() << '[';
  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
  p << "], " << getDstMemRef() << '[';
  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
  p << "], " << getTagMemRef() << '[';
  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
  p << "], " << getNumElements();
  // Stride operands are only present for strided DMAs.
  if (isStrided()) {
    p << ", " << getStride();
    p << ", " << getNumElementsPerStride();
  }
  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
    << getTagMemRefType();
}
874
// Parse AffineDmaStartOp.
// Ex:
//   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
//     %stride, %num_elt_per_stride
//       : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
//
ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  AffineMapAttr srcMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
  OpAsmParser::OperandType dstMemRefInfo;
  AffineMapAttr dstMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
  OpAsmParser::OperandType tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
  OpAsmParser::OperandType numElementsInfo;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) dst memref followed by its affine maps operands (in square brackets).
  // *) src memref followed by its affine map operands (in square brackets).
  // *) tag memref followed by its affine map operands (in square brackets).
  // *) number of elements transferred by DMA operation.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
                                    getSrcMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
                                    getDstMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                    getTagMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo)) {
    return failure();
  }
  // Either both stride-related operands are present, or neither is.
  if (!strideInfo.empty() && strideInfo.size() != 2) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }
  bool isStrided = strideInfo.size() == 2;

  if (parser.parseColonTypeList(types))
    return failure();

  // Exactly three types are expected: src, dst, and tag memref types.
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "expected three types");

  // Resolve operands in the same order build() adds them.
  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
      parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  // Check that src/dst/tag operand counts match their map.numInputs.
  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
      dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
      tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(parser.getNameLoc(),
                            "memref operand count not equal to map.numInputs");
  return success();
}
953
verify()954 LogicalResult AffineDmaStartOp::verify() {
955 if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
956 return emitOpError("expected DMA source to be of memref type");
957 if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
958 return emitOpError("expected DMA destination to be of memref type");
959 if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
960 return emitOpError("expected DMA tag to be of memref type");
961
962 // DMAs from different memory spaces supported.
963 if (getSrcMemorySpace() == getDstMemorySpace()) {
964 return emitOpError("DMA should be between different memory spaces");
965 }
966 unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
967 getDstMap().getNumInputs() +
968 getTagMap().getNumInputs();
969 if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
970 getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
971 return emitOpError("incorrect number of operands");
972 }
973
974 for (auto idx : getSrcIndices()) {
975 if (!idx.getType().isIndex())
976 return emitOpError("src index to dma_start must have 'index' type");
977 if (!isValidAffineIndexOperand(idx))
978 return emitOpError("src index must be a dimension or symbol identifier");
979 }
980 for (auto idx : getDstIndices()) {
981 if (!idx.getType().isIndex())
982 return emitOpError("dst index to dma_start must have 'index' type");
983 if (!isValidAffineIndexOperand(idx))
984 return emitOpError("dst index must be a dimension or symbol identifier");
985 }
986 for (auto idx : getTagIndices()) {
987 if (!idx.getType().isIndex())
988 return emitOpError("tag index to dma_start must have 'index' type");
989 if (!isValidAffineIndexOperand(idx))
990 return emitOpError("tag index must be a dimension or symbol identifier");
991 }
992 return success();
993 }
994
/// Folds memref_cast feeding any operand into the op itself; the op has no
/// results, so `results` is never populated.
LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}
1000
1001 //===----------------------------------------------------------------------===//
1002 // AffineDmaWaitOp
1003 //===----------------------------------------------------------------------===//
1004
// TODO(b/133776335) Check that map operands are loop IVs or symbols.
/// Builds an affine.dma_wait. Operand order is: tag memref, tag map
/// operands, then the number of elements; accessors rely on this layout.
void AffineDmaWaitOp::build(Builder *builder, OperationState &result,
                            Value tagMemRef, AffineMap tagMap,
                            ValueRange tagIndices, Value numElements) {
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}
1014
/// Prints affine.dma_wait as: tag[tag map], num_elements : tag memref type.
void AffineDmaWaitOp::print(OpAsmPrinter &p) {
  p << "affine.dma_wait " << getTagMemRef() << '[';
  // Materialize the tag map operands for the map printer.
  SmallVector<Value, 2> operands(getTagIndices());
  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
  p << "], ";
  p.printOperand(getNumElements());
  p << " : " << getTagMemRef().getType();
}
1023
// Parse AffineDmaWaitOp.
// Eg:
//   affine.dma_wait %tag[%index], %num_elements
//     : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its map operands, and dma size; resolve them in the
  // same order build() adds the operands.
  if (parser.parseOperand(tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                    getTagMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  // The trailing type must be the tag's memref type.
  if (!type.isa<MemRefType>())
    return parser.emitError(parser.getNameLoc(),
                            "expected tag to be of memref type");

  // The parsed map operand count must agree with the tag map's input count.
  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(parser.getNameLoc(),
                            "tag memref operand count != to map.numInputs");
  return success();
}
1058
verify()1059 LogicalResult AffineDmaWaitOp::verify() {
1060 if (!getOperand(0).getType().isa<MemRefType>())
1061 return emitOpError("expected DMA tag to be of memref type");
1062 for (auto idx : getTagIndices()) {
1063 if (!idx.getType().isIndex())
1064 return emitOpError("index to dma_wait must have 'index' type");
1065 if (!isValidAffineIndexOperand(idx))
1066 return emitOpError("index must be a dimension or symbol identifier");
1067 }
1068 return success();
1069 }
1070
/// Folds memref_cast feeding any operand into the op itself; the op has no
/// results, so `results` is never populated.
LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}
1076
1077 //===----------------------------------------------------------------------===//
1078 // AffineForOp
1079 //===----------------------------------------------------------------------===//
1080
/// Builds an affine.for with the given bound maps and their operands.
/// Creates the body region with a single block carrying the index-typed
/// induction variable as its only argument, and ensures the body is
/// terminated.
void AffineForOp::build(Builder *builder, OperationState &result,
                        ValueRange lbOperands, AffineMap lbMap,
                        ValueRange ubOperands, AffineMap ubMap, int64_t step) {
  assert(((!lbMap && lbOperands.empty()) ||
          lbOperands.size() == lbMap.getNumInputs()) &&
         "lower bound operand count does not match the affine map");
  assert(((!ubMap && ubOperands.empty()) ||
          ubOperands.size() == ubMap.getNumInputs()) &&
         "upper bound operand count does not match the affine map");
  assert(step > 0 && "step has to be a positive integer constant");

  // Add an attribute for the step.
  result.addAttribute(getStepAttrName(),
                      builder->getIntegerAttr(builder->getIndexType(), step));

  // Add the lower bound.
  result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap));
  result.addOperands(lbOperands);

  // Add the upper bound.
  result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap));
  result.addOperands(ubOperands);

  // Create a region and a block for the body. The argument of the region is
  // the loop induction variable.
  Region *bodyRegion = result.addRegion();
  Block *body = new Block();
  body->addArgument(IndexType::get(builder->getContext()));
  bodyRegion->push_back(body);
  ensureTerminator(*bodyRegion, *builder, result.location);

  // Set the operands list as resizable so that we can freely modify the bounds.
  result.setOperandListToResizable();
}
1115
build(Builder * builder,OperationState & result,int64_t lb,int64_t ub,int64_t step)1116 void AffineForOp::build(Builder *builder, OperationState &result, int64_t lb,
1117 int64_t ub, int64_t step) {
1118 auto lbMap = AffineMap::getConstantMap(lb, builder->getContext());
1119 auto ubMap = AffineMap::getConstantMap(ub, builder->getContext());
1120 return build(builder, result, {}, lbMap, {}, ubMap, step);
1121 }
1122
/// Verifies an affine.for: a single index-typed body argument (the induction
/// variable), an operand count matching both bound maps, and bound operands
/// that are valid affine dimension/symbol identifiers.
static LogicalResult verify(AffineForOp op) {
  // Check that the body defines as single block argument for the induction
  // variable.
  auto *body = op.getBody();
  if (body->getNumArguments() != 1 || !body->getArgument(0).getType().isIndex())
    return op.emitOpError(
        "expected body to have a single index argument for the "
        "induction variable");

  // Verify that there are enough operands for the bounds.
  AffineMap lowerBoundMap = op.getLowerBoundMap(),
            upperBoundMap = op.getUpperBoundMap();
  if (op.getNumOperands() !=
      (lowerBoundMap.getNumInputs() + upperBoundMap.getNumInputs()))
    return op.emitOpError(
        "operand count must match with affine map dimension and symbol count");

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bound.
  if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
                                           op.getLowerBoundMap().getNumDims())))
    return failure();
  /// Upper bound.
  if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
                                           op.getUpperBoundMap().getNumDims())))
    return failure();
  return success();
}
1151
1152 /// Parse a for operation loop bounds.
parseBound(bool isLower,OperationState & result,OpAsmParser & p)1153 static ParseResult parseBound(bool isLower, OperationState &result,
1154 OpAsmParser &p) {
1155 // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1156 // the map has multiple results.
1157 bool failedToParsedMinMax =
1158 failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
1159
1160 auto &builder = p.getBuilder();
1161 auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
1162 : AffineForOp::getUpperBoundAttrName();
1163
1164 // Parse ssa-id as identity map.
1165 SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
1166 if (p.parseOperandList(boundOpInfos))
1167 return failure();
1168
1169 if (!boundOpInfos.empty()) {
1170 // Check that only one operand was parsed.
1171 if (boundOpInfos.size() > 1)
1172 return p.emitError(p.getNameLoc(),
1173 "expected only one loop bound operand");
1174
1175 // TODO: improve error message when SSA value is not of index type.
1176 // Currently it is 'use of value ... expects different type than prior uses'
1177 if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1178 result.operands))
1179 return failure();
1180
1181 // Create an identity map using symbol id. This representation is optimized
1182 // for storage. Analysis passes may expand it into a multi-dimensional map
1183 // if desired.
1184 AffineMap map = builder.getSymbolIdentityMap();
1185 result.addAttribute(boundAttrName, AffineMapAttr::get(map));
1186 return success();
1187 }
1188
1189 // Get the attribute location.
1190 llvm::SMLoc attrLoc = p.getCurrentLocation();
1191
1192 Attribute boundAttr;
1193 if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
1194 result.attributes))
1195 return failure();
1196
1197 // Parse full form - affine map followed by dim and symbol list.
1198 if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
1199 unsigned currentNumOperands = result.operands.size();
1200 unsigned numDims;
1201 if (parseDimAndSymbolList(p, result.operands, numDims))
1202 return failure();
1203
1204 auto map = affineMapAttr.getValue();
1205 if (map.getNumDims() != numDims)
1206 return p.emitError(
1207 p.getNameLoc(),
1208 "dim operand count and integer set dim count must match");
1209
1210 unsigned numDimAndSymbolOperands =
1211 result.operands.size() - currentNumOperands;
1212 if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1213 return p.emitError(
1214 p.getNameLoc(),
1215 "symbol operand count and integer set symbol count must match");
1216
1217 // If the map has multiple results, make sure that we parsed the min/max
1218 // prefix.
1219 if (map.getNumResults() > 1 && failedToParsedMinMax) {
1220 if (isLower) {
1221 return p.emitError(attrLoc, "lower loop bound affine map with "
1222 "multiple results requires 'max' prefix");
1223 }
1224 return p.emitError(attrLoc, "upper loop bound affine map with multiple "
1225 "results requires 'min' prefix");
1226 }
1227 return success();
1228 }
1229
1230 // Parse custom assembly form.
1231 if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
1232 result.attributes.pop_back();
1233 result.addAttribute(
1234 boundAttrName,
1235 AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
1236 return success();
1237 }
1238
1239 return p.emitError(
1240 p.getNameLoc(),
1241 "expected valid affine map representation for loop bounds");
1242 }
1243
/// Parses an affine.for: induction variable, '=' lower bound, 'to' upper
/// bound, optional 'step', body region, then an optional attribute dict.
static ParseResult parseAffineForOp(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType inductionVariable;
  // Parse the induction variable followed by '='.
  if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
    return failure();

  // Parse loop bounds.
  if (parseBound(/*isLower=*/true, result, parser) ||
      parser.parseKeyword("to", " between bounds") ||
      parseBound(/*isLower=*/false, result, parser))
    return failure();

  // Parse the optional loop step, we default to 1 if one is not present.
  if (parser.parseOptionalKeyword("step")) {
    result.addAttribute(
        AffineForOp::getStepAttrName(),
        builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
  } else {
    llvm::SMLoc stepLoc = parser.getCurrentLocation();
    IntegerAttr stepAttr;
    if (parser.parseAttribute(stepAttr, builder.getIndexType(),
                              AffineForOp::getStepAttrName().data(),
                              result.attributes))
      return failure();

    // NOTE(review): this rejects only negative steps, so a step of 0 passes
    // the parser even though the builder asserts step > 0 — confirm intended.
    if (stepAttr.getValue().getSExtValue() < 0)
      return parser.emitError(
          stepLoc,
          "expected step to be representable as a positive signed integer");
  }

  // Parse the body region.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, inductionVariable, builder.getIndexType()))
    return failure();

  // Ensure the parsed body region is properly terminated.
  AffineForOp::ensureTerminator(*body, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Set the operands list as resizable so that we can freely modify the bounds.
  result.setOperandListToResizable();
  return success();
}
1292
/// Prints a single loop bound: a bare constant, a single SSA symbol, or the
/// full map + operand-list form, prefixed by 'max'/'min' (via `prefix`) when
/// the map has multiple results.
static void printBound(AffineMapAttr boundMap,
                       Operation::operand_range boundOperands,
                       const char *prefix, OpAsmPrinter &p) {
  AffineMap map = boundMap.getValue();

  // Check if this bound should be printed using custom assembly form.
  // The decision to restrict printing custom assembly form to trivial cases
  // comes from the will to roundtrip MLIR binary -> text -> binary in a
  // lossless way.
  // Therefore, custom assembly form parsing and printing is only supported for
  // zero-operand constant maps and single symbol operand identity maps.
  if (map.getNumResults() == 1) {
    AffineExpr expr = map.getResult(0);

    // Print constant bound.
    if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
      if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
        p << constExpr.getValue();
        return;
      }
    }

    // Print bound that consists of a single SSA symbol if the map is over a
    // single symbol.
    if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
      if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
        p.printOperand(*boundOperands.begin());
        return;
      }
    }
  } else {
    // Map has multiple results. Print 'min' or 'max' prefix.
    p << prefix << ' ';
  }

  // Print the map and its operands.
  p << boundMap;
  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
                        map.getNumDims(), p);
}
1333
/// Prints an affine.for: induction variable, bounds, optional step (elided
/// when 1), body region, and any attributes beyond the bound/step ones.
static void print(OpAsmPrinter &p, AffineForOp op) {
  p << "affine.for ";
  p.printOperand(op.getBody()->getArgument(0));
  p << " = ";
  printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
  p << " to ";
  printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);

  // A step of 1 is the default and is not printed.
  if (op.getStep() != 1)
    p << " step " << op.getStep();
  p.printRegion(op.region(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  // Bound maps and the step are printed inline above, so elide them here.
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/{op.getLowerBoundAttrName(),
                                           op.getUpperBoundAttrName(),
                                           op.getStepAttrName()});
}
1352
/// Fold the constant bounds of a loop. Returns success if either bound was
/// replaced with a constant.
static LogicalResult foldLoopBounds(AffineForOp forOp) {
  // Attempts to constant-fold one bound (lower when `lower` is true) by
  // folding its bound map over the constant-producing bound operands.
  auto foldLowerOrUpperBound = [&forOp](bool lower) {
    // Check to see if each of the operands is the result of a constant.  If
    // so, get the value.  If not, ignore it.
    SmallVector<Attribute, 8> operandConstants;
    auto boundOperands =
        lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
    for (auto operand : boundOperands) {
      Attribute operandCst;
      // Leaves operandCst null for non-constant operands.
      matchPattern(operand, m_Constant(&operandCst));
      operandConstants.push_back(operandCst);
    }

    AffineMap boundMap =
        lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
    assert(boundMap.getNumResults() >= 1 &&
           "bound maps should have at least one result");
    SmallVector<Attribute, 4> foldedResults;
    if (failed(boundMap.constantFold(operandConstants, foldedResults)))
      return failure();

    // Compute the max or min as applicable over the results.
    assert(!foldedResults.empty() && "bounds should have at least one result");
    auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
    for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
      auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
      // A lower bound is the max of its results; an upper bound the min.
      maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
                       : llvm::APIntOps::smin(maxOrMin, foldedResult);
    }
    lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
          : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
    return success();
  };

  // Try to fold the lower bound.
  bool folded = false;
  if (!forOp.hasConstantLowerBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));

  // Try to fold the upper bound.
  if (!forOp.hasConstantUpperBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
  return success(folded);
}
1398
1399 /// Canonicalize the bounds of the given loop.
canonicalizeLoopBounds(AffineForOp forOp)1400 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
1401 SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
1402 SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
1403
1404 auto lbMap = forOp.getLowerBoundMap();
1405 auto ubMap = forOp.getUpperBoundMap();
1406 auto prevLbMap = lbMap;
1407 auto prevUbMap = ubMap;
1408
1409 canonicalizeMapAndOperands(&lbMap, &lbOperands);
1410 canonicalizeMapAndOperands(&ubMap, &ubOperands);
1411
1412 // Any canonicalization change always leads to updated map(s).
1413 if (lbMap == prevLbMap && ubMap == prevUbMap)
1414 return failure();
1415
1416 if (lbMap != prevLbMap)
1417 forOp.setLowerBound(lbOperands, lbMap);
1418 if (ubMap != prevUbMap)
1419 forOp.setUpperBound(ubOperands, ubMap);
1420 return success();
1421 }
1422
namespace {
/// This is a pattern to fold trivially empty loops.
struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
  using OpRewritePattern<AffineForOp>::OpRewritePattern;

  PatternMatchResult matchAndRewrite(AffineForOp forOp,
                                     PatternRewriter &rewriter) const override {
    // Check that the body only contains a terminator.
    if (!has_single_element(*forOp.getBody()))
      return matchFailure();
    // The body is empty, so the loop can simply be erased.
    rewriter.eraseOp(forOp);
    return matchSuccess();
  }
};
} // end anonymous namespace
1438
/// Registers the empty-loop folding pattern for affine.for.
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<AffineForEmptyLoopFolder>(context);
}
1443
/// Folds constant bounds and canonicalizes the bound maps/operands in place.
/// `results` is never populated; success only signals in-place mutation.
LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
  bool folded = succeeded(foldLoopBounds(*this));
  folded |= succeeded(canonicalizeLoopBounds(*this));
  return success(folded);
}
1450
getLowerBound()1451 AffineBound AffineForOp::getLowerBound() {
1452 auto lbMap = getLowerBoundMap();
1453 return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
1454 }
1455
getUpperBound()1456 AffineBound AffineForOp::getUpperBound() {
1457 auto lbMap = getLowerBoundMap();
1458 auto ubMap = getUpperBoundMap();
1459 return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), getNumOperands(),
1460 ubMap);
1461 }
1462
setLowerBound(ValueRange lbOperands,AffineMap map)1463 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
1464 assert(lbOperands.size() == map.getNumInputs());
1465 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1466
1467 SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());
1468
1469 auto ubOperands = getUpperBoundOperands();
1470 newOperands.append(ubOperands.begin(), ubOperands.end());
1471 getOperation()->setOperands(newOperands);
1472
1473 setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
1474 }
1475
setUpperBound(ValueRange ubOperands,AffineMap map)1476 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
1477 assert(ubOperands.size() == map.getNumInputs());
1478 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1479
1480 SmallVector<Value, 4> newOperands(getLowerBoundOperands());
1481 newOperands.append(ubOperands.begin(), ubOperands.end());
1482 getOperation()->setOperands(newOperands);
1483
1484 setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
1485 }
1486
setLowerBoundMap(AffineMap map)1487 void AffineForOp::setLowerBoundMap(AffineMap map) {
1488 auto lbMap = getLowerBoundMap();
1489 assert(lbMap.getNumDims() == map.getNumDims() &&
1490 lbMap.getNumSymbols() == map.getNumSymbols());
1491 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1492 (void)lbMap;
1493 setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
1494 }
1495
setUpperBoundMap(AffineMap map)1496 void AffineForOp::setUpperBoundMap(AffineMap map) {
1497 auto ubMap = getUpperBoundMap();
1498 assert(ubMap.getNumDims() == map.getNumDims() &&
1499 ubMap.getNumSymbols() == map.getNumSymbols());
1500 assert(map.getNumResults() >= 1 && "bound map has at least one result");
1501 (void)ubMap;
1502 setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
1503 }
1504
hasConstantLowerBound()1505 bool AffineForOp::hasConstantLowerBound() {
1506 return getLowerBoundMap().isSingleConstant();
1507 }
1508
hasConstantUpperBound()1509 bool AffineForOp::hasConstantUpperBound() {
1510 return getUpperBoundMap().isSingleConstant();
1511 }
1512
getConstantLowerBound()1513 int64_t AffineForOp::getConstantLowerBound() {
1514 return getLowerBoundMap().getSingleConstantResult();
1515 }
1516
getConstantUpperBound()1517 int64_t AffineForOp::getConstantUpperBound() {
1518 return getUpperBoundMap().getSingleConstantResult();
1519 }
1520
setConstantLowerBound(int64_t value)1521 void AffineForOp::setConstantLowerBound(int64_t value) {
1522 setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
1523 }
1524
setConstantUpperBound(int64_t value)1525 void AffineForOp::setConstantUpperBound(int64_t value) {
1526 setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
1527 }
1528
getLowerBoundOperands()1529 AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
1530 return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
1531 }
1532
getUpperBoundOperands()1533 AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
1534 return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
1535 }
1536
matchingBoundOperandList()1537 bool AffineForOp::matchingBoundOperandList() {
1538 auto lbMap = getLowerBoundMap();
1539 auto ubMap = getUpperBoundMap();
1540 if (lbMap.getNumDims() != ubMap.getNumDims() ||
1541 lbMap.getNumSymbols() != ubMap.getNumSymbols())
1542 return false;
1543
1544 unsigned numOperands = lbMap.getNumInputs();
1545 for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
1546 // Compare Value 's.
1547 if (getOperand(i) != getOperand(numOperands + i))
1548 return false;
1549 }
1550 return true;
1551 }
1552
getLoopBody()1553 Region &AffineForOp::getLoopBody() { return region(); }
1554
isDefinedOutsideOfLoop(Value value)1555 bool AffineForOp::isDefinedOutsideOfLoop(Value value) {
1556 return !region().isAncestor(value.getParentRegion());
1557 }
1558
moveOutOfLoop(ArrayRef<Operation * > ops)1559 LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
1560 for (auto *op : ops)
1561 op->moveBefore(*this);
1562 return success();
1563 }
1564
1565 /// Returns if the provided value is the induction variable of a AffineForOp.
isForInductionVar(Value val)1566 bool mlir::isForInductionVar(Value val) {
1567 return getForInductionVarOwner(val) != AffineForOp();
1568 }
1569
1570 /// Returns the loop parent of an induction variable. If the provided value is
1571 /// not an induction variable, then return nullptr.
getForInductionVarOwner(Value val)1572 AffineForOp mlir::getForInductionVarOwner(Value val) {
1573 auto ivArg = val.dyn_cast<BlockArgument>();
1574 if (!ivArg || !ivArg.getOwner())
1575 return AffineForOp();
1576 auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
1577 return dyn_cast<AffineForOp>(containingInst);
1578 }
1579
1580 /// Extracts the induction variables from a list of AffineForOps and returns
1581 /// them.
extractForInductionVars(ArrayRef<AffineForOp> forInsts,SmallVectorImpl<Value> * ivs)1582 void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
1583 SmallVectorImpl<Value> *ivs) {
1584 ivs->reserve(forInsts.size());
1585 for (auto forInst : forInsts)
1586 ivs->push_back(forInst.getInductionVar());
1587 }
1588
1589 //===----------------------------------------------------------------------===//
1590 // AffineIfOp
1591 //===----------------------------------------------------------------------===//
1592
verify(AffineIfOp op)1593 static LogicalResult verify(AffineIfOp op) {
1594 // Verify that we have a condition attribute.
1595 auto conditionAttr =
1596 op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
1597 if (!conditionAttr)
1598 return op.emitOpError(
1599 "requires an integer set attribute named 'condition'");
1600
1601 // Verify that there are enough operands for the condition.
1602 IntegerSet condition = conditionAttr.getValue();
1603 if (op.getNumOperands() != condition.getNumInputs())
1604 return op.emitOpError(
1605 "operand count and condition integer set dimension and "
1606 "symbol count must match");
1607
1608 // Verify that the operands are valid dimension/symbols.
1609 if (failed(verifyDimAndSymbolIdentifiers(
1610 op, op.getOperation()->getNonSuccessorOperands(),
1611 condition.getNumDims())))
1612 return failure();
1613
1614 // Verify that the entry of each child region does not have arguments.
1615 for (auto ®ion : op.getOperation()->getRegions()) {
1616 for (auto &b : region)
1617 if (b.getNumArguments() != 0)
1618 return op.emitOpError(
1619 "requires that child entry blocks have no arguments");
1620 }
1621 return success();
1622 }
1623
parseAffineIfOp(OpAsmParser & parser,OperationState & result)1624 static ParseResult parseAffineIfOp(OpAsmParser &parser,
1625 OperationState &result) {
1626 // Parse the condition attribute set.
1627 IntegerSetAttr conditionAttr;
1628 unsigned numDims;
1629 if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(),
1630 result.attributes) ||
1631 parseDimAndSymbolList(parser, result.operands, numDims))
1632 return failure();
1633
1634 // Verify the condition operands.
1635 auto set = conditionAttr.getValue();
1636 if (set.getNumDims() != numDims)
1637 return parser.emitError(
1638 parser.getNameLoc(),
1639 "dim operand count and integer set dim count must match");
1640 if (numDims + set.getNumSymbols() != result.operands.size())
1641 return parser.emitError(
1642 parser.getNameLoc(),
1643 "symbol operand count and integer set symbol count must match");
1644
1645 // Create the regions for 'then' and 'else'. The latter must be created even
1646 // if it remains empty for the validity of the operation.
1647 result.regions.reserve(2);
1648 Region *thenRegion = result.addRegion();
1649 Region *elseRegion = result.addRegion();
1650
1651 // Parse the 'then' region.
1652 if (parser.parseRegion(*thenRegion, {}, {}))
1653 return failure();
1654 AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
1655 result.location);
1656
1657 // If we find an 'else' keyword then parse the 'else' region.
1658 if (!parser.parseOptionalKeyword("else")) {
1659 if (parser.parseRegion(*elseRegion, {}, {}))
1660 return failure();
1661 AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
1662 result.location);
1663 }
1664
1665 // Parse the optional attribute list.
1666 if (parser.parseOptionalAttrDict(result.attributes))
1667 return failure();
1668
1669 return success();
1670 }
1671
print(OpAsmPrinter & p,AffineIfOp op)1672 static void print(OpAsmPrinter &p, AffineIfOp op) {
1673 auto conditionAttr =
1674 op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
1675 p << "affine.if " << conditionAttr;
1676 printDimAndSymbolList(op.operand_begin(), op.operand_end(),
1677 conditionAttr.getValue().getNumDims(), p);
1678 p.printRegion(op.thenRegion(),
1679 /*printEntryBlockArgs=*/false,
1680 /*printBlockTerminators=*/false);
1681
1682 // Print the 'else' regions if it has any blocks.
1683 auto &elseRegion = op.elseRegion();
1684 if (!elseRegion.empty()) {
1685 p << " else";
1686 p.printRegion(elseRegion,
1687 /*printEntryBlockArgs=*/false,
1688 /*printBlockTerminators=*/false);
1689 }
1690
1691 // Print the attribute list.
1692 p.printOptionalAttrDict(op.getAttrs(),
1693 /*elidedAttrs=*/op.getConditionAttrName());
1694 }
1695
getIntegerSet()1696 IntegerSet AffineIfOp::getIntegerSet() {
1697 return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue();
1698 }
setIntegerSet(IntegerSet newSet)1699 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
1700 setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
1701 }
1702
setConditional(IntegerSet set,ValueRange operands)1703 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
1704 setIntegerSet(set);
1705 getOperation()->setOperands(operands);
1706 }
1707
build(Builder * builder,OperationState & result,IntegerSet set,ValueRange args,bool withElseRegion)1708 void AffineIfOp::build(Builder *builder, OperationState &result, IntegerSet set,
1709 ValueRange args, bool withElseRegion) {
1710 result.addOperands(args);
1711 result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set));
1712 Region *thenRegion = result.addRegion();
1713 Region *elseRegion = result.addRegion();
1714 AffineIfOp::ensureTerminator(*thenRegion, *builder, result.location);
1715 if (withElseRegion)
1716 AffineIfOp::ensureTerminator(*elseRegion, *builder, result.location);
1717 }
1718
1719 /// Canonicalize an affine if op's conditional (integer set + operands).
fold(ArrayRef<Attribute>,SmallVectorImpl<OpFoldResult> &)1720 LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
1721 SmallVectorImpl<OpFoldResult> &) {
1722 auto set = getIntegerSet();
1723 SmallVector<Value, 4> operands(getOperands());
1724 canonicalizeSetAndOperands(&set, &operands);
1725
1726 // Any canonicalization change always leads to either a reduction in the
1727 // number of operands or a change in the number of symbolic operands
1728 // (promotion of dims to symbols).
1729 if (operands.size() < getIntegerSet().getNumInputs() ||
1730 set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
1731 setConditional(set, operands);
1732 return success();
1733 }
1734
1735 return failure();
1736 }
1737
1738 //===----------------------------------------------------------------------===//
1739 // AffineLoadOp
1740 //===----------------------------------------------------------------------===//
1741
build(Builder * builder,OperationState & result,AffineMap map,ValueRange operands)1742 void AffineLoadOp::build(Builder *builder, OperationState &result,
1743 AffineMap map, ValueRange operands) {
1744 assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
1745 result.addOperands(operands);
1746 if (map)
1747 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
1748 auto memrefType = operands[0].getType().cast<MemRefType>();
1749 result.types.push_back(memrefType.getElementType());
1750 }
1751
build(Builder * builder,OperationState & result,Value memref,AffineMap map,ValueRange mapOperands)1752 void AffineLoadOp::build(Builder *builder, OperationState &result, Value memref,
1753 AffineMap map, ValueRange mapOperands) {
1754 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
1755 result.addOperands(memref);
1756 result.addOperands(mapOperands);
1757 auto memrefType = memref.getType().cast<MemRefType>();
1758 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
1759 result.types.push_back(memrefType.getElementType());
1760 }
1761
build(Builder * builder,OperationState & result,Value memref,ValueRange indices)1762 void AffineLoadOp::build(Builder *builder, OperationState &result, Value memref,
1763 ValueRange indices) {
1764 auto memrefType = memref.getType().cast<MemRefType>();
1765 auto rank = memrefType.getRank();
1766 // Create identity map for memrefs with at least one dimension or () -> ()
1767 // for zero-dimensional memrefs.
1768 auto map = rank ? builder->getMultiDimIdentityMap(rank)
1769 : builder->getEmptyAffineMap();
1770 build(builder, result, memref, map, indices);
1771 }
1772
parse(OpAsmParser & parser,OperationState & result)1773 ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
1774 auto &builder = parser.getBuilder();
1775 auto indexTy = builder.getIndexType();
1776
1777 MemRefType type;
1778 OpAsmParser::OperandType memrefInfo;
1779 AffineMapAttr mapAttr;
1780 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
1781 return failure(
1782 parser.parseOperand(memrefInfo) ||
1783 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(),
1784 result.attributes) ||
1785 parser.parseOptionalAttrDict(result.attributes) ||
1786 parser.parseColonType(type) ||
1787 parser.resolveOperand(memrefInfo, type, result.operands) ||
1788 parser.resolveOperands(mapOperands, indexTy, result.operands) ||
1789 parser.addTypeToList(type.getElementType(), result.types));
1790 }
1791
print(OpAsmPrinter & p)1792 void AffineLoadOp::print(OpAsmPrinter &p) {
1793 p << "affine.load " << getMemRef() << '[';
1794 if (AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()))
1795 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
1796 p << ']';
1797 p.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
1798 p << " : " << getMemRefType();
1799 }
1800
verify()1801 LogicalResult AffineLoadOp::verify() {
1802 if (getType() != getMemRefType().getElementType())
1803 return emitOpError("result type must match element type of memref");
1804
1805 auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
1806 if (mapAttr) {
1807 AffineMap map = getAttrOfType<AffineMapAttr>(getMapAttrName()).getValue();
1808 if (map.getNumResults() != getMemRefType().getRank())
1809 return emitOpError("affine.load affine map num results must equal"
1810 " memref rank");
1811 if (map.getNumInputs() != getNumOperands() - 1)
1812 return emitOpError("expects as many subscripts as affine map inputs");
1813 } else {
1814 if (getMemRefType().getRank() != getNumOperands() - 1)
1815 return emitOpError(
1816 "expects the number of subscripts to be equal to memref rank");
1817 }
1818
1819 for (auto idx : getMapOperands()) {
1820 if (!idx.getType().isIndex())
1821 return emitOpError("index to load must have 'index' type");
1822 if (!isValidAffineIndexOperand(idx))
1823 return emitOpError("index must be a dimension or symbol identifier");
1824 }
1825 return success();
1826 }
1827
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1828 void AffineLoadOp::getCanonicalizationPatterns(
1829 OwningRewritePatternList &results, MLIRContext *context) {
1830 results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
1831 }
1832
fold(ArrayRef<Attribute> cstOperands)1833 OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
1834 /// load(memrefcast) -> load
1835 if (succeeded(foldMemRefCast(*this)))
1836 return getResult();
1837 return OpFoldResult();
1838 }
1839
1840 //===----------------------------------------------------------------------===//
1841 // AffineStoreOp
1842 //===----------------------------------------------------------------------===//
1843
build(Builder * builder,OperationState & result,Value valueToStore,Value memref,AffineMap map,ValueRange mapOperands)1844 void AffineStoreOp::build(Builder *builder, OperationState &result,
1845 Value valueToStore, Value memref, AffineMap map,
1846 ValueRange mapOperands) {
1847 assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
1848 result.addOperands(valueToStore);
1849 result.addOperands(memref);
1850 result.addOperands(mapOperands);
1851 result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
1852 }
1853
1854 // Use identity map.
build(Builder * builder,OperationState & result,Value valueToStore,Value memref,ValueRange indices)1855 void AffineStoreOp::build(Builder *builder, OperationState &result,
1856 Value valueToStore, Value memref,
1857 ValueRange indices) {
1858 auto memrefType = memref.getType().cast<MemRefType>();
1859 auto rank = memrefType.getRank();
1860 // Create identity map for memrefs with at least one dimension or () -> ()
1861 // for zero-dimensional memrefs.
1862 auto map = rank ? builder->getMultiDimIdentityMap(rank)
1863 : builder->getEmptyAffineMap();
1864 build(builder, result, valueToStore, memref, map, indices);
1865 }
1866
parse(OpAsmParser & parser,OperationState & result)1867 ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
1868 auto indexTy = parser.getBuilder().getIndexType();
1869
1870 MemRefType type;
1871 OpAsmParser::OperandType storeValueInfo;
1872 OpAsmParser::OperandType memrefInfo;
1873 AffineMapAttr mapAttr;
1874 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
1875 return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
1876 parser.parseOperand(memrefInfo) ||
1877 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
1878 getMapAttrName(),
1879 result.attributes) ||
1880 parser.parseOptionalAttrDict(result.attributes) ||
1881 parser.parseColonType(type) ||
1882 parser.resolveOperand(storeValueInfo, type.getElementType(),
1883 result.operands) ||
1884 parser.resolveOperand(memrefInfo, type, result.operands) ||
1885 parser.resolveOperands(mapOperands, indexTy, result.operands));
1886 }
1887
print(OpAsmPrinter & p)1888 void AffineStoreOp::print(OpAsmPrinter &p) {
1889 p << "affine.store " << getValueToStore();
1890 p << ", " << getMemRef() << '[';
1891 if (AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()))
1892 p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
1893 p << ']';
1894 p.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
1895 p << " : " << getMemRefType();
1896 }
1897
verify()1898 LogicalResult AffineStoreOp::verify() {
1899 // First operand must have same type as memref element type.
1900 if (getValueToStore().getType() != getMemRefType().getElementType())
1901 return emitOpError("first operand must have same type memref element type");
1902
1903 auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
1904 if (mapAttr) {
1905 AffineMap map = mapAttr.getValue();
1906 if (map.getNumResults() != getMemRefType().getRank())
1907 return emitOpError("affine.store affine map num results must equal"
1908 " memref rank");
1909 if (map.getNumInputs() != getNumOperands() - 2)
1910 return emitOpError("expects as many subscripts as affine map inputs");
1911 } else {
1912 if (getMemRefType().getRank() != getNumOperands() - 2)
1913 return emitOpError(
1914 "expects the number of subscripts to be equal to memref rank");
1915 }
1916
1917 for (auto idx : getMapOperands()) {
1918 if (!idx.getType().isIndex())
1919 return emitOpError("index to store must have 'index' type");
1920 if (!isValidAffineIndexOperand(idx))
1921 return emitOpError("index must be a dimension or symbol identifier");
1922 }
1923 return success();
1924 }
1925
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)1926 void AffineStoreOp::getCanonicalizationPatterns(
1927 OwningRewritePatternList &results, MLIRContext *context) {
1928 results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
1929 }
1930
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)1931 LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
1932 SmallVectorImpl<OpFoldResult> &results) {
1933 /// store(memrefcast) -> store
1934 return foldMemRefCast(*this);
1935 }
1936
1937 //===----------------------------------------------------------------------===//
1938 // AffineMinOp
1939 //===----------------------------------------------------------------------===//
1940 //
1941 // %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
1942 //
1943
parseAffineMinOp(OpAsmParser & parser,OperationState & result)1944 static ParseResult parseAffineMinOp(OpAsmParser &parser,
1945 OperationState &result) {
1946 auto &builder = parser.getBuilder();
1947 auto indexType = builder.getIndexType();
1948 SmallVector<OpAsmParser::OperandType, 8> dim_infos;
1949 SmallVector<OpAsmParser::OperandType, 8> sym_infos;
1950 AffineMapAttr mapAttr;
1951 return failure(
1952 parser.parseAttribute(mapAttr, AffineMinOp::getMapAttrName(),
1953 result.attributes) ||
1954 parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) ||
1955 parser.parseOperandList(sym_infos,
1956 OpAsmParser::Delimiter::OptionalSquare) ||
1957 parser.parseOptionalAttrDict(result.attributes) ||
1958 parser.resolveOperands(dim_infos, indexType, result.operands) ||
1959 parser.resolveOperands(sym_infos, indexType, result.operands) ||
1960 parser.addTypeToList(indexType, result.types));
1961 }
1962
print(OpAsmPrinter & p,AffineMinOp op)1963 static void print(OpAsmPrinter &p, AffineMinOp op) {
1964 p << op.getOperationName() << ' '
1965 << op.getAttr(AffineMinOp::getMapAttrName());
1966 auto operands = op.getOperands();
1967 unsigned numDims = op.map().getNumDims();
1968 p << '(' << operands.take_front(numDims) << ')';
1969
1970 if (operands.size() != numDims)
1971 p << '[' << operands.drop_front(numDims) << ']';
1972 p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
1973 }
1974
verify(AffineMinOp op)1975 static LogicalResult verify(AffineMinOp op) {
1976 // Verify that operand count matches affine map dimension and symbol count.
1977 if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
1978 return op.emitOpError(
1979 "operand count and affine map dimension and symbol count must match");
1980 return success();
1981 }
1982
fold(ArrayRef<Attribute> operands)1983 OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
1984 // Fold the affine map.
1985 // TODO(andydavis, ntv) Fold more cases: partial static information,
1986 // min(some_affine, some_affine + constant, ...).
1987 SmallVector<Attribute, 2> results;
1988 if (failed(map().constantFold(operands, results)))
1989 return {};
1990
1991 // Compute and return min of folded map results.
1992 int64_t min = std::numeric_limits<int64_t>::max();
1993 int minIndex = -1;
1994 for (unsigned i = 0, e = results.size(); i < e; ++i) {
1995 auto intAttr = results[i].cast<IntegerAttr>();
1996 if (intAttr.getInt() < min) {
1997 min = intAttr.getInt();
1998 minIndex = i;
1999 }
2000 }
2001 if (minIndex < 0)
2002 return {};
2003 return results[minIndex];
2004 }
2005
2006 //===----------------------------------------------------------------------===//
2007 // AffinePrefetchOp
2008 //===----------------------------------------------------------------------===//
2009
2010 //
2011 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
2012 //
parseAffinePrefetchOp(OpAsmParser & parser,OperationState & result)2013 static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
2014 OperationState &result) {
2015 auto &builder = parser.getBuilder();
2016 auto indexTy = builder.getIndexType();
2017
2018 MemRefType type;
2019 OpAsmParser::OperandType memrefInfo;
2020 IntegerAttr hintInfo;
2021 auto i32Type = parser.getBuilder().getIntegerType(32);
2022 StringRef readOrWrite, cacheType;
2023
2024 AffineMapAttr mapAttr;
2025 SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2026 if (parser.parseOperand(memrefInfo) ||
2027 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2028 AffinePrefetchOp::getMapAttrName(),
2029 result.attributes) ||
2030 parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
2031 parser.parseComma() || parser.parseKeyword("locality") ||
2032 parser.parseLess() ||
2033 parser.parseAttribute(hintInfo, i32Type,
2034 AffinePrefetchOp::getLocalityHintAttrName(),
2035 result.attributes) ||
2036 parser.parseGreater() || parser.parseComma() ||
2037 parser.parseKeyword(&cacheType) ||
2038 parser.parseOptionalAttrDict(result.attributes) ||
2039 parser.parseColonType(type) ||
2040 parser.resolveOperand(memrefInfo, type, result.operands) ||
2041 parser.resolveOperands(mapOperands, indexTy, result.operands))
2042 return failure();
2043
2044 if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
2045 return parser.emitError(parser.getNameLoc(),
2046 "rw specifier has to be 'read' or 'write'");
2047 result.addAttribute(
2048 AffinePrefetchOp::getIsWriteAttrName(),
2049 parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
2050
2051 if (!cacheType.equals("data") && !cacheType.equals("instr"))
2052 return parser.emitError(parser.getNameLoc(),
2053 "cache type has to be 'data' or 'instr'");
2054
2055 result.addAttribute(
2056 AffinePrefetchOp::getIsDataCacheAttrName(),
2057 parser.getBuilder().getBoolAttr(cacheType.equals("data")));
2058
2059 return success();
2060 }
2061
print(OpAsmPrinter & p,AffinePrefetchOp op)2062 static void print(OpAsmPrinter &p, AffinePrefetchOp op) {
2063 p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
2064 AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2065 if (mapAttr) {
2066 SmallVector<Value, 2> operands(op.getMapOperands());
2067 p.printAffineMapOfSSAIds(mapAttr, operands);
2068 }
2069 p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", "
2070 << "locality<" << op.localityHint() << ">, "
2071 << (op.isDataCache() ? "data" : "instr");
2072 p.printOptionalAttrDict(
2073 op.getAttrs(),
2074 /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(),
2075 op.getIsDataCacheAttrName(), op.getIsWriteAttrName()});
2076 p << " : " << op.getMemRefType();
2077 }
2078
verify(AffinePrefetchOp op)2079 static LogicalResult verify(AffinePrefetchOp op) {
2080 auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2081 if (mapAttr) {
2082 AffineMap map = mapAttr.getValue();
2083 if (map.getNumResults() != op.getMemRefType().getRank())
2084 return op.emitOpError("affine.prefetch affine map num results must equal"
2085 " memref rank");
2086 if (map.getNumInputs() + 1 != op.getNumOperands())
2087 return op.emitOpError("too few operands");
2088 } else {
2089 if (op.getNumOperands() != 1)
2090 return op.emitOpError("too few operands");
2091 }
2092
2093 for (auto idx : op.getMapOperands()) {
2094 if (!isValidAffineIndexOperand(idx))
2095 return op.emitOpError("index must be a dimension or symbol identifier");
2096 }
2097 return success();
2098 }
2099
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2100 void AffinePrefetchOp::getCanonicalizationPatterns(
2101 OwningRewritePatternList &results, MLIRContext *context) {
2102 // prefetch(memrefcast) -> prefetch
2103 results.insert<SimplifyAffineOp<AffinePrefetchOp>>(context);
2104 }
2105
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)2106 LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
2107 SmallVectorImpl<OpFoldResult> &results) {
2108 /// prefetch(memrefcast) -> prefetch
2109 return foldMemRefCast(*this);
2110 }
2111
2112 //===----------------------------------------------------------------------===//
2113 // TableGen'd op method definitions
2114 //===----------------------------------------------------------------------===//
2115
2116 #define GET_OP_CLASSES
2117 #include "mlir/Dialect/AffineOps/AffineOps.cpp.inc"
2118