1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/IR/Function.h"
13 #include "mlir/IR/IntegerSet.h"
14 #include "mlir/IR/Matchers.h"
15 #include "mlir/IR/OpImplementation.h"
16 #include "mlir/IR/PatternMatch.h"
17 #include "mlir/Transforms/InliningUtils.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/SmallBitVector.h"
20 #include "llvm/ADT/TypeSwitch.h"
21 #include "llvm/Support/Debug.h"
22 
23 using namespace mlir;
24 using llvm::dbgs;
25 
26 #define DEBUG_TYPE "affine-analysis"
27 
28 //===----------------------------------------------------------------------===//
29 // AffineDialect Interfaces
30 //===----------------------------------------------------------------------===//
31 
32 namespace {
33 /// This class defines the interface for handling inlining with affine
34 /// operations.
35 struct AffineInlinerInterface : public DialectInlinerInterface {
36   using DialectInlinerInterface::DialectInlinerInterface;
37 
38   //===--------------------------------------------------------------------===//
39   // Analysis Hooks
40   //===--------------------------------------------------------------------===//
41 
42   /// Returns true if the given region 'src' can be inlined into the region
43   /// 'dest' that is attached to an operation registered to the current dialect.
isLegalToInline__anon07ec94590111::AffineInlinerInterface44   bool isLegalToInline(Region *dest, Region *src,
45                        BlockAndValueMapping &valueMapping) const final {
46     // Conservatively don't allow inlining into affine structures.
47     return false;
48   }
49 
50   /// Returns true if the given operation 'op', that is registered to this
51   /// dialect, can be inlined into the given region, false otherwise.
isLegalToInline__anon07ec94590111::AffineInlinerInterface52   bool isLegalToInline(Operation *op, Region *region,
53                        BlockAndValueMapping &valueMapping) const final {
54     // Always allow inlining affine operations into the top-level region of a
55     // function. There are some edge cases when inlining *into* affine
56     // structures, but that is handled in the other 'isLegalToInline' hook
57     // above.
58     // TODO: We should be able to inline into other regions than functions.
59     return isa<FuncOp>(region->getParentOp());
60   }
61 
62   /// Affine regions should be analyzed recursively.
shouldAnalyzeRecursively__anon07ec94590111::AffineInlinerInterface63   bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
64 };
65 } // end anonymous namespace
66 
67 //===----------------------------------------------------------------------===//
68 // AffineDialect
69 //===----------------------------------------------------------------------===//
70 
AffineDialect(MLIRContext * context)71 AffineDialect::AffineDialect(MLIRContext *context)
72     : Dialect(getDialectNamespace(), context) {
73   addOperations<AffineDmaStartOp, AffineDmaWaitOp,
74 #define GET_OP_LIST
75 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
76                 >();
77   addInterfaces<AffineInlinerInterface>();
78 }
79 
80 /// Materialize a single constant operation from a given attribute value with
81 /// the desired resultant type.
materializeConstant(OpBuilder & builder,Attribute value,Type type,Location loc)82 Operation *AffineDialect::materializeConstant(OpBuilder &builder,
83                                               Attribute value, Type type,
84                                               Location loc) {
85   return builder.create<ConstantOp>(loc, type, value);
86 }
87 
88 /// A utility function to check if a value is defined at the top level of an
89 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
90 /// conservatively assume it is not top-level. A value of index type defined at
91 /// the top level is always a valid symbol.
isTopLevelValue(Value value)92 bool mlir::isTopLevelValue(Value value) {
93   if (auto arg = value.dyn_cast<BlockArgument>()) {
94     // The block owning the argument may be unlinked, e.g. when the surrounding
95     // region has not yet been attached to an Op, at which point the parent Op
96     // is null.
97     Operation *parentOp = arg.getOwner()->getParentOp();
98     return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
99   }
100   // The defining Op may live in an unlinked block so its parent Op may be null.
101   Operation *parentOp = value.getDefiningOp()->getParentOp();
102   return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
103 }
104 
105 /// A utility function to check if a value is defined at the top level of
106 /// `region` or is an argument of `region`. A value of index type defined at the
107 /// top level of a `AffineScope` region is always a valid symbol for all
108 /// uses in that region.
isTopLevelValue(Value value,Region * region)109 static bool isTopLevelValue(Value value, Region *region) {
110   if (auto arg = value.dyn_cast<BlockArgument>())
111     return arg.getParentRegion() == region;
112   return value.getDefiningOp()->getParentRegion() == region;
113 }
114 
115 /// Returns the closest region enclosing `op` that is held by an operation with
116 /// trait `AffineScope`; `nullptr` if there is no such region.
117 //  TODO: getAffineScope should be publicly exposed for affine passes/utilities.
getAffineScope(Operation * op)118 static Region *getAffineScope(Operation *op) {
119   auto *curOp = op;
120   while (auto *parentOp = curOp->getParentOp()) {
121     if (parentOp->hasTrait<OpTrait::AffineScope>())
122       return curOp->getParentRegion();
123     curOp = parentOp;
124   }
125   return nullptr;
126 }
127 
128 // A Value can be used as a dimension id iff it meets one of the following
129 // conditions:
130 // *) It is valid as a symbol.
131 // *) It is an induction variable.
132 // *) It is the result of affine apply operation with dimension id arguments.
isValidDim(Value value)133 bool mlir::isValidDim(Value value) {
134   // The value must be an index type.
135   if (!value.getType().isIndex())
136     return false;
137 
138   if (auto *defOp = value.getDefiningOp())
139     return isValidDim(value, getAffineScope(defOp));
140 
141   // This value has to be a block argument for an op that has the
142   // `AffineScope` trait or for an affine.for or affine.parallel.
143   auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
144   return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
145                       isa<AffineForOp, AffineParallelOp>(parentOp));
146 }
147 
148 // Value can be used as a dimension id iff it meets one of the following
149 // conditions:
150 // *) It is valid as a symbol.
151 // *) It is an induction variable.
152 // *) It is the result of an affine apply operation with dimension id operands.
isValidDim(Value value,Region * region)153 bool mlir::isValidDim(Value value, Region *region) {
154   // The value must be an index type.
155   if (!value.getType().isIndex())
156     return false;
157 
158   // All valid symbols are okay.
159   if (isValidSymbol(value, region))
160     return true;
161 
162   auto *op = value.getDefiningOp();
163   if (!op) {
164     // This value has to be a block argument for an affine.for or an
165     // affine.parallel.
166     auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
167     return isa<AffineForOp, AffineParallelOp>(parentOp);
168   }
169 
170   // Affine apply operation is ok if all of its operands are ok.
171   if (auto applyOp = dyn_cast<AffineApplyOp>(op))
172     return applyOp.isValidDim(region);
173   // The dim op is okay if its operand memref/tensor is defined at the top
174   // level.
175   if (auto dimOp = dyn_cast<DimOp>(op))
176     return isTopLevelValue(dimOp.memrefOrTensor());
177   return false;
178 }
179 
180 /// Returns true if the 'index' dimension of the `memref` defined by
181 /// `memrefDefOp` is a statically  shaped one or defined using a valid symbol
182 /// for `region`.
183 template <typename AnyMemRefDefOp>
isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp,unsigned index,Region * region)184 static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
185                                     Region *region) {
186   auto memRefType = memrefDefOp.getType();
187   // Statically shaped.
188   if (!memRefType.isDynamicDim(index))
189     return true;
190   // Get the position of the dimension among dynamic dimensions;
191   unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
192   return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
193                        region);
194 }
195 
196 /// Returns true if the result of the dim op is a valid symbol for `region`.
isDimOpValidSymbol(DimOp dimOp,Region * region)197 static bool isDimOpValidSymbol(DimOp dimOp, Region *region) {
198   // The dim op is okay if its operand memref/tensor is defined at the top
199   // level.
200   if (isTopLevelValue(dimOp.memrefOrTensor()))
201     return true;
202 
203   // The dim op is also okay if its operand memref/tensor is a view/subview
204   // whose corresponding size is a valid symbol.
205   Optional<int64_t> index = dimOp.getConstantIndex();
206   assert(index.hasValue() &&
207          "expect only `dim` operations with a constant index");
208   int64_t i = index.getValue();
209   return TypeSwitch<Operation *, bool>(dimOp.memrefOrTensor().getDefiningOp())
210       .Case<ViewOp, SubViewOp, AllocOp>(
211           [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
212       .Default([](Operation *) { return false; });
213 }
214 
215 // A value can be used as a symbol (at all its use sites) iff it meets one of
216 // the following conditions:
217 // *) It is a constant.
218 // *) Its defining op or block arg appearance is immediately enclosed by an op
219 //    with `AffineScope` trait.
220 // *) It is the result of an affine.apply operation with symbol operands.
221 // *) It is a result of the dim op on a memref whose corresponding size is a
222 //    valid symbol.
isValidSymbol(Value value)223 bool mlir::isValidSymbol(Value value) {
224   // The value must be an index type.
225   if (!value.getType().isIndex())
226     return false;
227 
228   // Check that the value is a top level value.
229   if (isTopLevelValue(value))
230     return true;
231 
232   if (auto *defOp = value.getDefiningOp())
233     return isValidSymbol(value, getAffineScope(defOp));
234 
235   return false;
236 }
237 
238 /// A value can be used as a symbol for `region` iff it meets onf of the the
239 /// following conditions:
240 /// *) It is a constant.
241 /// *) It is the result of an affine apply operation with symbol arguments.
242 /// *) It is a result of the dim op on a memref whose corresponding size is
243 ///    a valid symbol.
244 /// *) It is defined at the top level of 'region' or is its argument.
245 /// *) It dominates `region`'s parent op.
246 /// If `region` is null, conservatively assume the symbol definition scope does
247 /// not exist and only accept the values that would be symbols regardless of
248 /// the surrounding region structure, i.e. the first three cases above.
isValidSymbol(Value value,Region * region)249 bool mlir::isValidSymbol(Value value, Region *region) {
250   // The value must be an index type.
251   if (!value.getType().isIndex())
252     return false;
253 
254   // A top-level value is a valid symbol.
255   if (region && ::isTopLevelValue(value, region))
256     return true;
257 
258   auto *defOp = value.getDefiningOp();
259   if (!defOp) {
260     // A block argument that is not a top-level value is a valid symbol if it
261     // dominates region's parent op.
262     if (region && !region->getParentOp()->isKnownIsolatedFromAbove())
263       if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
264         return isValidSymbol(value, parentOpRegion);
265     return false;
266   }
267 
268   // Constant operation is ok.
269   Attribute operandCst;
270   if (matchPattern(defOp, m_Constant(&operandCst)))
271     return true;
272 
273   // Affine apply operation is ok if all of its operands are ok.
274   if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
275     return applyOp.isValidSymbol(region);
276 
277   // Dim op results could be valid symbols at any level.
278   if (auto dimOp = dyn_cast<DimOp>(defOp))
279     return isDimOpValidSymbol(dimOp, region);
280 
281   // Check for values dominating `region`'s parent op.
282   if (region && !region->getParentOp()->isKnownIsolatedFromAbove())
283     if (auto *parentRegion = region->getParentOp()->getParentRegion())
284       return isValidSymbol(value, parentRegion);
285 
286   return false;
287 }
288 
289 // Returns true if 'value' is a valid index to an affine operation (e.g.
290 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
291 // `region` provides the polyhedral symbol scope. Returns false otherwise.
isValidAffineIndexOperand(Value value,Region * region)292 static bool isValidAffineIndexOperand(Value value, Region *region) {
293   return isValidDim(value, region) || isValidSymbol(value, region);
294 }
295 
296 /// Utility function to verify that a set of operands are valid dimension and
297 /// symbol identifiers. The operands should be laid out such that the dimension
298 /// operands are before the symbol operands. This function returns failure if
299 /// there was an invalid operand. An operation is provided to emit any necessary
300 /// errors.
301 template <typename OpTy>
302 static LogicalResult
verifyDimAndSymbolIdentifiers(OpTy & op,Operation::operand_range operands,unsigned numDims)303 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
304                               unsigned numDims) {
305   unsigned opIt = 0;
306   for (auto operand : operands) {
307     if (opIt++ < numDims) {
308       if (!isValidDim(operand, getAffineScope(op)))
309         return op.emitOpError("operand cannot be used as a dimension id");
310     } else if (!isValidSymbol(operand, getAffineScope(op))) {
311       return op.emitOpError("operand cannot be used as a symbol");
312     }
313   }
314   return success();
315 }
316 
317 //===----------------------------------------------------------------------===//
318 // AffineApplyOp
319 //===----------------------------------------------------------------------===//
320 
getAffineValueMap()321 AffineValueMap AffineApplyOp::getAffineValueMap() {
322   return AffineValueMap(getAffineMap(), getOperands(), getResult());
323 }
324 
parseAffineApplyOp(OpAsmParser & parser,OperationState & result)325 static ParseResult parseAffineApplyOp(OpAsmParser &parser,
326                                       OperationState &result) {
327   auto &builder = parser.getBuilder();
328   auto indexTy = builder.getIndexType();
329 
330   AffineMapAttr mapAttr;
331   unsigned numDims;
332   if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
333       parseDimAndSymbolList(parser, result.operands, numDims) ||
334       parser.parseOptionalAttrDict(result.attributes))
335     return failure();
336   auto map = mapAttr.getValue();
337 
338   if (map.getNumDims() != numDims ||
339       numDims + map.getNumSymbols() != result.operands.size()) {
340     return parser.emitError(parser.getNameLoc(),
341                             "dimension or symbol index mismatch");
342   }
343 
344   result.types.append(map.getNumResults(), indexTy);
345   return success();
346 }
347 
print(OpAsmPrinter & p,AffineApplyOp op)348 static void print(OpAsmPrinter &p, AffineApplyOp op) {
349   p << AffineApplyOp::getOperationName() << " " << op.mapAttr();
350   printDimAndSymbolList(op.operand_begin(), op.operand_end(),
351                         op.getAffineMap().getNumDims(), p);
352   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
353 }
354 
verify(AffineApplyOp op)355 static LogicalResult verify(AffineApplyOp op) {
356   // Check input and output dimensions match.
357   auto map = op.map();
358 
359   // Verify that operand count matches affine map dimension and symbol count.
360   if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols())
361     return op.emitOpError(
362         "operand count and affine map dimension and symbol count must match");
363 
364   // Verify that the map only produces one result.
365   if (map.getNumResults() != 1)
366     return op.emitOpError("mapping must produce one value");
367 
368   return success();
369 }
370 
371 // The result of the affine apply operation can be used as a dimension id if all
372 // its operands are valid dimension ids.
isValidDim()373 bool AffineApplyOp::isValidDim() {
374   return llvm::all_of(getOperands(),
375                       [](Value op) { return mlir::isValidDim(op); });
376 }
377 
378 // The result of the affine apply operation can be used as a dimension id if all
379 // its operands are valid dimension ids with the parent operation of `region`
380 // defining the polyhedral scope for symbols.
isValidDim(Region * region)381 bool AffineApplyOp::isValidDim(Region *region) {
382   return llvm::all_of(getOperands(),
383                       [&](Value op) { return ::isValidDim(op, region); });
384 }
385 
386 // The result of the affine apply operation can be used as a symbol if all its
387 // operands are symbols.
isValidSymbol()388 bool AffineApplyOp::isValidSymbol() {
389   return llvm::all_of(getOperands(),
390                       [](Value op) { return mlir::isValidSymbol(op); });
391 }
392 
393 // The result of the affine apply operation can be used as a symbol in `region`
394 // if all its operands are symbols in `region`.
isValidSymbol(Region * region)395 bool AffineApplyOp::isValidSymbol(Region *region) {
396   return llvm::all_of(getOperands(), [&](Value operand) {
397     return mlir::isValidSymbol(operand, region);
398   });
399 }
400 
fold(ArrayRef<Attribute> operands)401 OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
402   auto map = getAffineMap();
403 
404   // Fold dims and symbols to existing values.
405   auto expr = map.getResult(0);
406   if (auto dim = expr.dyn_cast<AffineDimExpr>())
407     return getOperand(dim.getPosition());
408   if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
409     return getOperand(map.getNumDims() + sym.getPosition());
410 
411   // Otherwise, default to folding the map.
412   SmallVector<Attribute, 1> result;
413   if (failed(map.constantFold(operands, result)))
414     return {};
415   return result[0];
416 }
417 
renumberOneDim(Value v)418 AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) {
419   DenseMap<Value, unsigned>::iterator iterPos;
420   bool inserted = false;
421   std::tie(iterPos, inserted) =
422       dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size()));
423   if (inserted) {
424     reorderedDims.push_back(v);
425   }
426   return getAffineDimExpr(iterPos->second, v.getContext())
427       .cast<AffineDimExpr>();
428 }
429 
renumber(const AffineApplyNormalizer & other)430 AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) {
431   SmallVector<AffineExpr, 8> dimRemapping;
432   for (auto v : other.reorderedDims) {
433     auto kvp = other.dimValueToPosition.find(v);
434     if (dimRemapping.size() <= kvp->second)
435       dimRemapping.resize(kvp->second + 1);
436     dimRemapping[kvp->second] = renumberOneDim(kvp->first);
437   }
438   unsigned numSymbols = concatenatedSymbols.size();
439   unsigned numOtherSymbols = other.concatenatedSymbols.size();
440   SmallVector<AffineExpr, 8> symRemapping(numOtherSymbols);
441   for (unsigned idx = 0; idx < numOtherSymbols; ++idx) {
442     symRemapping[idx] =
443         getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext());
444   }
445   concatenatedSymbols.insert(concatenatedSymbols.end(),
446                              other.concatenatedSymbols.begin(),
447                              other.concatenatedSymbols.end());
448   auto map = other.affineMap;
449   return map.replaceDimsAndSymbols(dimRemapping, symRemapping,
450                                    reorderedDims.size(),
451                                    concatenatedSymbols.size());
452 }
453 
454 // Gather the positions of the operands that are produced by an AffineApplyOp.
455 static llvm::SetVector<unsigned>
indicesFromAffineApplyOp(ArrayRef<Value> operands)456 indicesFromAffineApplyOp(ArrayRef<Value> operands) {
457   llvm::SetVector<unsigned> res;
458   for (auto en : llvm::enumerate(operands))
459     if (isa_and_nonnull<AffineApplyOp>(en.value().getDefiningOp()))
460       res.insert(en.index());
461   return res;
462 }
463 
464 // Support the special case of a symbol coming from an AffineApplyOp that needs
465 // to be composed into the current AffineApplyOp.
466 // This case is handled by rewriting all such symbols into dims for the purpose
467 // of allowing mathematical AffineMap composition.
468 // Returns an AffineMap where symbols that come from an AffineApplyOp have been
469 // rewritten as dims and are ordered after the original dims.
470 // TODO: This promotion makes AffineMap lose track of which
471 // symbols are represented as dims. This loss is static but can still be
472 // recovered dynamically (with `isValidSymbol`). Still this is annoying for the
473 // semi-affine map case. A dynamic canonicalization of all dims that are valid
474 // symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even
475 // results in better simplifications and foldings. But we should evaluate
476 // whether this behavior is what we really want after using more.
promoteComposedSymbolsAsDims(AffineMap map,ArrayRef<Value> symbols)477 static AffineMap promoteComposedSymbolsAsDims(AffineMap map,
478                                               ArrayRef<Value> symbols) {
479   if (symbols.empty()) {
480     return map;
481   }
482 
483   // Sanity check on symbols.
484   for (auto sym : symbols) {
485     assert(isValidSymbol(sym) && "Expected only valid symbols");
486     (void)sym;
487   }
488 
489   // Extract the symbol positions that come from an AffineApplyOp and
490   // needs to be rewritten as dims.
491   auto symPositions = indicesFromAffineApplyOp(symbols);
492   if (symPositions.empty()) {
493     return map;
494   }
495 
496   // Create the new map by replacing each symbol at pos by the next new dim.
497   unsigned numDims = map.getNumDims();
498   unsigned numSymbols = map.getNumSymbols();
499   unsigned numNewDims = 0;
500   unsigned numNewSymbols = 0;
501   SmallVector<AffineExpr, 8> symReplacements(numSymbols);
502   for (unsigned i = 0; i < numSymbols; ++i) {
503     symReplacements[i] =
504         symPositions.count(i) > 0
505             ? getAffineDimExpr(numDims + numNewDims++, map.getContext())
506             : getAffineSymbolExpr(numNewSymbols++, map.getContext());
507   }
508   assert(numSymbols >= numNewDims);
509   AffineMap newMap = map.replaceDimsAndSymbols(
510       {}, symReplacements, numDims + numNewDims, numNewSymbols);
511 
512   return newMap;
513 }
514 
515 /// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to
516 /// keep a correspondence between the mathematical `map` and the `operands` of
517 /// a given AffineApplyOp. This correspondence is maintained by iterating over
518 /// the operands and forming an `auxiliaryMap` that can be composed
519 /// mathematically with `map`. To keep this correspondence in cases where
520 /// symbols are produced by affine.apply operations, we perform a local rewrite
521 /// of symbols as dims.
522 ///
523 /// Rationale for locally rewriting symbols as dims:
524 /// ================================================
525 /// The mathematical composition of AffineMap must always concatenate symbols
526 /// because it does not have enough information to do otherwise. For example,
527 /// composing `(d0)[s0] -> (d0 + s0)` with itself must produce
528 /// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
529 ///
530 /// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
531 /// applied to the same mlir::Value for both s0 and s1.
532 /// As a consequence mathematical composition of AffineMap always concatenates
533 /// symbols.
534 ///
535 /// When AffineMaps are used in AffineApplyOp however, they may specify
536 /// composition via symbols, which is ambiguous mathematically. This corner case
537 /// is handled by locally rewriting such symbols that come from AffineApplyOp
538 /// into dims and composing through dims.
539 /// TODO: Composition via symbols comes at a significant code
540 /// complexity. Alternatively we should investigate whether we want to
541 /// explicitly disallow symbols coming from affine.apply and instead force the
542 /// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2
543 /// extra API calls for such uses, which haven't popped up until now) and the
544 /// benefit potentially big: simpler and more maintainable code for a
545 /// non-trivial, recursive, procedure.
AffineApplyNormalizer(AffineMap map,ArrayRef<Value> operands)546 AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
547                                              ArrayRef<Value> operands)
548     : AffineApplyNormalizer() {
549   static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0");
550   assert(map.getNumInputs() == operands.size() &&
551          "number of operands does not match the number of map inputs");
552 
553   LLVM_DEBUG(map.print(dbgs() << "\nInput map: "));
554 
555   // Promote symbols that come from an AffineApplyOp to dims by rewriting the
556   // map to always refer to:
557   //   (dims, symbols coming from AffineApplyOp, other symbols).
558   // The order of operands can remain unchanged.
559   // This is a simplification that relies on 2 ordering properties:
560   //   1. rewritten symbols always appear after the original dims in the map;
561   //   2. operands are traversed in order and either dispatched to:
562   //      a. auxiliaryExprs (dims and symbols rewritten as dims);
563   //      b. concatenatedSymbols (all other symbols)
564   // This allows operand order to remain unchanged.
565   unsigned numDimsBeforeRewrite = map.getNumDims();
566   map = promoteComposedSymbolsAsDims(map,
567                                      operands.take_back(map.getNumSymbols()));
568 
569   LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: "));
570 
571   SmallVector<AffineExpr, 8> auxiliaryExprs;
572   bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth);
573   // We fully spell out the 2 cases below. In this particular instance a little
574   // code duplication greatly improves readability.
575   // Note that the first branch would disappear if we only supported full
576   // composition (i.e. infinite kMaxAffineApplyDepth).
577   if (!furtherCompose) {
578     // 1. Only dispatch dims or symbols.
579     for (auto en : llvm::enumerate(operands)) {
580       auto t = en.value();
581       assert(t.getType().isIndex());
582       bool isDim = (en.index() < map.getNumDims());
583       if (isDim) {
584         // a. The mathematical composition of AffineMap composes dims.
585         auxiliaryExprs.push_back(renumberOneDim(t));
586       } else {
587         // b. The mathematical composition of AffineMap concatenates symbols.
588         //    We do the same for symbol operands.
589         concatenatedSymbols.push_back(t);
590       }
591     }
592   } else {
593     assert(numDimsBeforeRewrite <= operands.size());
594     // 2. Compose AffineApplyOps and dispatch dims or symbols.
595     for (unsigned i = 0, e = operands.size(); i < e; ++i) {
596       auto t = operands[i];
597       auto affineApply = t.getDefiningOp<AffineApplyOp>();
598       if (affineApply) {
599         // a. Compose affine.apply operations.
600         LLVM_DEBUG(affineApply.getOperation()->print(
601             dbgs() << "\nCompose AffineApplyOp recursively: "));
602         AffineMap affineApplyMap = affineApply.getAffineMap();
603         SmallVector<Value, 8> affineApplyOperands(
604             affineApply.getOperands().begin(), affineApply.getOperands().end());
605         AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands);
606 
607         LLVM_DEBUG(normalizer.affineMap.print(
608             dbgs() << "\nRenumber into current normalizer: "));
609 
610         auto renumberedMap = renumber(normalizer);
611 
612         LLVM_DEBUG(
613             renumberedMap.print(dbgs() << "\nRecursive composition yields: "));
614 
615         auxiliaryExprs.push_back(renumberedMap.getResult(0));
616       } else {
617         if (i < numDimsBeforeRewrite) {
618           // b. The mathematical composition of AffineMap composes dims.
619           auxiliaryExprs.push_back(renumberOneDim(t));
620         } else {
621           // c. The mathematical composition of AffineMap concatenates symbols.
622           //    Note that the map composition will put symbols already present
623           //    in the map before any symbols coming from the auxiliary map, so
624           //    we insert them before any symbols that are due to renumbering,
625           //    and after the proper symbols we have seen already.
626           concatenatedSymbols.insert(
627               std::next(concatenatedSymbols.begin(), numProperSymbols++), t);
628         }
629       }
630     }
631   }
632 
633   // Early exit if `map` is already composed.
634   if (auxiliaryExprs.empty()) {
635     affineMap = map;
636     return;
637   }
638 
639   assert(concatenatedSymbols.size() >= map.getNumSymbols() &&
640          "Unexpected number of concatenated symbols");
641   auto numDims = dimValueToPosition.size();
642   auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols();
643   auto auxiliaryMap =
644       AffineMap::get(numDims, numSymbols, auxiliaryExprs, map.getContext());
645 
646   LLVM_DEBUG(map.print(dbgs() << "\nCompose map: "));
647   LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: "));
648   LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: "));
649 
650   // TODO: Disabling simplification results in major speed gains.
651   // Another option is to cache the results as it is expected a lot of redundant
652   // work is performed in practice.
653   affineMap = simplifyAffineMap(map.compose(auxiliaryMap));
654 
655   LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: "));
656   LLVM_DEBUG(dbgs() << "\n");
657 }
658 
normalize(AffineMap * otherMap,SmallVectorImpl<Value> * otherOperands)659 void AffineApplyNormalizer::normalize(AffineMap *otherMap,
660                                       SmallVectorImpl<Value> *otherOperands) {
661   AffineApplyNormalizer other(*otherMap, *otherOperands);
662   *otherMap = renumber(other);
663 
664   otherOperands->reserve(reorderedDims.size() + concatenatedSymbols.size());
665   otherOperands->assign(reorderedDims.begin(), reorderedDims.end());
666   otherOperands->append(concatenatedSymbols.begin(), concatenatedSymbols.end());
667 }
668 
669 /// Implements `map` and `operands` composition and simplification to support
670 /// `makeComposedAffineApply`. This can be called to achieve the same effects
671 /// on `map` and `operands` without creating an AffineApplyOp that needs to be
672 /// immediately deleted.
composeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)673 static void composeAffineMapAndOperands(AffineMap *map,
674                                         SmallVectorImpl<Value> *operands) {
675   AffineApplyNormalizer normalizer(*map, *operands);
676   auto normalizedMap = normalizer.getAffineMap();
677   auto normalizedOperands = normalizer.getOperands();
678   canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands);
679   *map = normalizedMap;
680   *operands = normalizedOperands;
681   assert(*map);
682 }
683 
fullyComposeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)684 void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
685                                             SmallVectorImpl<Value> *operands) {
686   while (llvm::any_of(*operands, [](Value v) {
687     return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
688   })) {
689     composeAffineMapAndOperands(map, operands);
690   }
691 }
692 
makeComposedAffineApply(OpBuilder & b,Location loc,AffineMap map,ArrayRef<Value> operands)693 AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
694                                             AffineMap map,
695                                             ArrayRef<Value> operands) {
696   AffineMap normalizedMap = map;
697   SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end());
698   composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
699   assert(normalizedMap);
700   return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
701 }
702 
703 // A symbol may appear as a dim in affine.apply operations. This function
704 // canonicalizes dims that are valid symbols into actual symbols.
705 template <class MapOrSet>
canonicalizePromotedSymbols(MapOrSet * mapOrSet,SmallVectorImpl<Value> * operands)706 static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
707                                         SmallVectorImpl<Value> *operands) {
708   if (!mapOrSet || operands->empty())
709     return;
710 
711   assert(mapOrSet->getNumInputs() == operands->size() &&
712          "map/set inputs must match number of operands");
713 
714   auto *context = mapOrSet->getContext();
715   SmallVector<Value, 8> resultOperands;
716   resultOperands.reserve(operands->size());
717   SmallVector<Value, 8> remappedSymbols;
718   remappedSymbols.reserve(operands->size());
719   unsigned nextDim = 0;
720   unsigned nextSym = 0;
721   unsigned oldNumSyms = mapOrSet->getNumSymbols();
722   SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
723   for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
724     if (i < mapOrSet->getNumDims()) {
725       if (isValidSymbol((*operands)[i])) {
726         // This is a valid symbol that appears as a dim, canonicalize it.
727         dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
728         remappedSymbols.push_back((*operands)[i]);
729       } else {
730         dimRemapping[i] = getAffineDimExpr(nextDim++, context);
731         resultOperands.push_back((*operands)[i]);
732       }
733     } else {
734       resultOperands.push_back((*operands)[i]);
735     }
736   }
737 
738   resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
739   *operands = resultOperands;
740   *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
741                                               oldNumSyms + nextSym);
742 
743   assert(mapOrSet->getNumInputs() == operands->size() &&
744          "map/set inputs must match number of operands");
745 }
746 
747 // Works for either an affine map or an integer set.
748 template <class MapOrSet>
canonicalizeMapOrSetAndOperands(MapOrSet * mapOrSet,SmallVectorImpl<Value> * operands)749 static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
750                                             SmallVectorImpl<Value> *operands) {
751   static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
752                 "Argument must be either of AffineMap or IntegerSet type");
753 
754   if (!mapOrSet || operands->empty())
755     return;
756 
757   assert(mapOrSet->getNumInputs() == operands->size() &&
758          "map/set inputs must match number of operands");
759 
760   canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
761 
762   // Check to see what dims are used.
763   llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
764   llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
765   mapOrSet->walkExprs([&](AffineExpr expr) {
766     if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
767       usedDims[dimExpr.getPosition()] = true;
768     else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
769       usedSyms[symExpr.getPosition()] = true;
770   });
771 
772   auto *context = mapOrSet->getContext();
773 
774   SmallVector<Value, 8> resultOperands;
775   resultOperands.reserve(operands->size());
776 
777   llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
778   SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
779   unsigned nextDim = 0;
780   for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
781     if (usedDims[i]) {
782       // Remap dim positions for duplicate operands.
783       auto it = seenDims.find((*operands)[i]);
784       if (it == seenDims.end()) {
785         dimRemapping[i] = getAffineDimExpr(nextDim++, context);
786         resultOperands.push_back((*operands)[i]);
787         seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
788       } else {
789         dimRemapping[i] = it->second;
790       }
791     }
792   }
793   llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
794   SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
795   unsigned nextSym = 0;
796   for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
797     if (!usedSyms[i])
798       continue;
799     // Handle constant operands (only needed for symbolic operands since
800     // constant operands in dimensional positions would have already been
801     // promoted to symbolic positions above).
802     IntegerAttr operandCst;
803     if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
804                      m_Constant(&operandCst))) {
805       symRemapping[i] =
806           getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
807       continue;
808     }
809     // Remap symbol positions for duplicate operands.
810     auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
811     if (it == seenSymbols.end()) {
812       symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
813       resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
814       seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
815                                         symRemapping[i]));
816     } else {
817       symRemapping[i] = it->second;
818     }
819   }
820   *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
821                                               nextDim, nextSym);
822   *operands = resultOperands;
823 }
824 
canonicalizeMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)825 void mlir::canonicalizeMapAndOperands(AffineMap *map,
826                                       SmallVectorImpl<Value> *operands) {
827   canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
828 }
829 
canonicalizeSetAndOperands(IntegerSet * set,SmallVectorImpl<Value> * operands)830 void mlir::canonicalizeSetAndOperands(IntegerSet *set,
831                                       SmallVectorImpl<Value> *operands) {
832   canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
833 }
834 
835 namespace {
836 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
837 /// maps that supply results into them.
838 ///
839 template <typename AffineOpTy>
840 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
841   using OpRewritePattern<AffineOpTy>::OpRewritePattern;
842 
843   /// Replace the affine op with another instance of it with the supplied
844   /// map and mapOperands.
845   void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
846                        AffineMap map, ArrayRef<Value> mapOperands) const;
847 
matchAndRewrite__anon07ec94590a11::SimplifyAffineOp848   LogicalResult matchAndRewrite(AffineOpTy affineOp,
849                                 PatternRewriter &rewriter) const override {
850     static_assert(llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
851                                   AffineStoreOp, AffineApplyOp, AffineMinOp,
852                                   AffineMaxOp>::value,
853                   "affine load/store/apply/prefetch/min/max op expected");
854     auto map = affineOp.getAffineMap();
855     AffineMap oldMap = map;
856     auto oldOperands = affineOp.getMapOperands();
857     SmallVector<Value, 8> resultOperands(oldOperands);
858     composeAffineMapAndOperands(&map, &resultOperands);
859     if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
860                                     resultOperands.begin()))
861       return failure();
862 
863     replaceAffineOp(rewriter, affineOp, map, resultOperands);
864     return success();
865   }
866 };
867 
868 // Specialize the template to account for the different build signatures for
869 // affine load, store, and apply ops.
870 template <>
replaceAffineOp(PatternRewriter & rewriter,AffineLoadOp load,AffineMap map,ArrayRef<Value> mapOperands) const871 void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
872     PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
873     ArrayRef<Value> mapOperands) const {
874   rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
875                                             mapOperands);
876 }
877 template <>
replaceAffineOp(PatternRewriter & rewriter,AffinePrefetchOp prefetch,AffineMap map,ArrayRef<Value> mapOperands) const878 void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
879     PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
880     ArrayRef<Value> mapOperands) const {
881   rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
882       prefetch, prefetch.memref(), map, mapOperands,
883       prefetch.localityHint().getZExtValue(), prefetch.isWrite(),
884       prefetch.isDataCache());
885 }
886 template <>
replaceAffineOp(PatternRewriter & rewriter,AffineStoreOp store,AffineMap map,ArrayRef<Value> mapOperands) const887 void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
888     PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
889     ArrayRef<Value> mapOperands) const {
890   rewriter.replaceOpWithNewOp<AffineStoreOp>(
891       store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
892 }
893 
894 // Generic version for ops that don't have extra operands.
895 template <typename AffineOpTy>
replaceAffineOp(PatternRewriter & rewriter,AffineOpTy op,AffineMap map,ArrayRef<Value> mapOperands) const896 void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
897     PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
898     ArrayRef<Value> mapOperands) const {
899   rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
900 }
901 } // end anonymous namespace.
902 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)903 void AffineApplyOp::getCanonicalizationPatterns(
904     OwningRewritePatternList &results, MLIRContext *context) {
905   results.insert<SimplifyAffineOp<AffineApplyOp>>(context);
906 }
907 
908 //===----------------------------------------------------------------------===//
909 // Common canonicalization pattern support logic
910 //===----------------------------------------------------------------------===//
911 
912 /// This is a common class used for patterns of the form
913 /// "someop(memrefcast) -> someop".  It folds the source of any memref_cast
914 /// into the root operation directly.
foldMemRefCast(Operation * op)915 static LogicalResult foldMemRefCast(Operation *op) {
916   bool folded = false;
917   for (OpOperand &operand : op->getOpOperands()) {
918     auto cast = operand.get().getDefiningOp<MemRefCastOp>();
919     if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
920       operand.set(cast.getOperand());
921       folded = true;
922     }
923   }
924   return success(folded);
925 }
926 
927 //===----------------------------------------------------------------------===//
928 // AffineDmaStartOp
929 //===----------------------------------------------------------------------===//
930 
931 // TODO: Check that map operands are loop IVs or symbols.
build(OpBuilder & builder,OperationState & result,Value srcMemRef,AffineMap srcMap,ValueRange srcIndices,Value destMemRef,AffineMap dstMap,ValueRange destIndices,Value tagMemRef,AffineMap tagMap,ValueRange tagIndices,Value numElements,Value stride,Value elementsPerStride)932 void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
933                              Value srcMemRef, AffineMap srcMap,
934                              ValueRange srcIndices, Value destMemRef,
935                              AffineMap dstMap, ValueRange destIndices,
936                              Value tagMemRef, AffineMap tagMap,
937                              ValueRange tagIndices, Value numElements,
938                              Value stride, Value elementsPerStride) {
939   result.addOperands(srcMemRef);
940   result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap));
941   result.addOperands(srcIndices);
942   result.addOperands(destMemRef);
943   result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap));
944   result.addOperands(destIndices);
945   result.addOperands(tagMemRef);
946   result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
947   result.addOperands(tagIndices);
948   result.addOperands(numElements);
949   if (stride) {
950     result.addOperands({stride, elementsPerStride});
951   }
952 }
953 
print(OpAsmPrinter & p)954 void AffineDmaStartOp::print(OpAsmPrinter &p) {
955   p << "affine.dma_start " << getSrcMemRef() << '[';
956   p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
957   p << "], " << getDstMemRef() << '[';
958   p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
959   p << "], " << getTagMemRef() << '[';
960   p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
961   p << "], " << getNumElements();
962   if (isStrided()) {
963     p << ", " << getStride();
964     p << ", " << getNumElementsPerStride();
965   }
966   p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
967     << getTagMemRefType();
968 }
969 
// Parse AffineDmaStartOp.
// Ex:
//   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
//     %stride, %num_elt_per_stride
//       : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
//
// The stride and number-of-elements-per-stride operands are optional but must
// be given together. Exactly three types are expected after the colon (source,
// destination, tag, in that order); all other operands are resolved as
// 'index'. Returns failure on any parse/resolve error or when an operand count
// does not match its affine map.
ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  AffineMapAttr srcMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
  OpAsmParser::OperandType dstMemRefInfo;
  AffineMapAttr dstMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
  OpAsmParser::OperandType tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
  OpAsmParser::OperandType numElementsInfo;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse the following list of operands (they are resolved further below,
  // once the types are known):
  // *) src memref followed by its affine map operands (in square brackets).
  // *) dst memref followed by its affine map operands (in square brackets).
  // *) tag memref followed by its affine map operands (in square brackets).
  // *) number of elements transferred by DMA operation.
  // Each map is recorded in `result.attributes` under its attribute name.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
                                    getSrcMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
                                    getDstMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                    getTagMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo)) {
    return failure();
  }
  // Either no stride operands at all, or exactly the pair.
  if (!strideInfo.empty() && strideInfo.size() != 2) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }
  bool isStrided = strideInfo.size() == 2;

  if (parser.parseColonTypeList(types))
    return failure();

  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "expected three types");

  // Resolve in operand order: types[0..2] are the src, dst and tag memref
  // types; every index and the element count are of 'index' type.
  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
      parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  // The stride operands, when present, come last.
  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  // Check that src/dst/tag operand counts match their map.numInputs.
  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
      dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
      tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(parser.getNameLoc(),
                            "memref operand count not equal to map.numInputs");
  return success();
}
1048 
verify()1049 LogicalResult AffineDmaStartOp::verify() {
1050   if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
1051     return emitOpError("expected DMA source to be of memref type");
1052   if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
1053     return emitOpError("expected DMA destination to be of memref type");
1054   if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
1055     return emitOpError("expected DMA tag to be of memref type");
1056 
1057   // DMAs from different memory spaces supported.
1058   if (getSrcMemorySpace() == getDstMemorySpace()) {
1059     return emitOpError("DMA should be between different memory spaces");
1060   }
1061   unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1062                               getDstMap().getNumInputs() +
1063                               getTagMap().getNumInputs();
1064   if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1065       getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1066     return emitOpError("incorrect number of operands");
1067   }
1068 
1069   Region *scope = getAffineScope(*this);
1070   for (auto idx : getSrcIndices()) {
1071     if (!idx.getType().isIndex())
1072       return emitOpError("src index to dma_start must have 'index' type");
1073     if (!isValidAffineIndexOperand(idx, scope))
1074       return emitOpError("src index must be a dimension or symbol identifier");
1075   }
1076   for (auto idx : getDstIndices()) {
1077     if (!idx.getType().isIndex())
1078       return emitOpError("dst index to dma_start must have 'index' type");
1079     if (!isValidAffineIndexOperand(idx, scope))
1080       return emitOpError("dst index must be a dimension or symbol identifier");
1081   }
1082   for (auto idx : getTagIndices()) {
1083     if (!idx.getType().isIndex())
1084       return emitOpError("tag index to dma_start must have 'index' type");
1085     if (!isValidAffineIndexOperand(idx, scope))
1086       return emitOpError("tag index must be a dimension or symbol identifier");
1087   }
1088   return success();
1089 }
1090 
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)1091 LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1092                                      SmallVectorImpl<OpFoldResult> &results) {
1093   /// dma_start(memrefcast) -> dma_start
1094   return foldMemRefCast(*this);
1095 }
1096 
1097 //===----------------------------------------------------------------------===//
1098 // AffineDmaWaitOp
1099 //===----------------------------------------------------------------------===//
1100 
1101 // TODO: Check that map operands are loop IVs or symbols.
build(OpBuilder & builder,OperationState & result,Value tagMemRef,AffineMap tagMap,ValueRange tagIndices,Value numElements)1102 void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1103                             Value tagMemRef, AffineMap tagMap,
1104                             ValueRange tagIndices, Value numElements) {
1105   result.addOperands(tagMemRef);
1106   result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
1107   result.addOperands(tagIndices);
1108   result.addOperands(numElements);
1109 }
1110 
print(OpAsmPrinter & p)1111 void AffineDmaWaitOp::print(OpAsmPrinter &p) {
1112   p << "affine.dma_wait " << getTagMemRef() << '[';
1113   SmallVector<Value, 2> operands(getTagIndices());
1114   p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1115   p << "], ";
1116   p.printOperand(getNumElements());
1117   p << " : " << getTagMemRef().getType();
1118 }
1119 
1120 // Parse AffineDmaWaitOp.
1121 // Eg:
1122 //   affine.dma_wait %tag[%index], %num_elements
1123 //     : memref<1 x i32, (d0) -> (d0), 4>
1124 //
parse(OpAsmParser & parser,OperationState & result)1125 ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
1126                                    OperationState &result) {
1127   OpAsmParser::OperandType tagMemRefInfo;
1128   AffineMapAttr tagMapAttr;
1129   SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
1130   Type type;
1131   auto indexType = parser.getBuilder().getIndexType();
1132   OpAsmParser::OperandType numElementsInfo;
1133 
1134   // Parse tag memref, its map operands, and dma size.
1135   if (parser.parseOperand(tagMemRefInfo) ||
1136       parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1137                                     getTagMapAttrName(), result.attributes) ||
1138       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1139       parser.parseColonType(type) ||
1140       parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1141       parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1142       parser.resolveOperand(numElementsInfo, indexType, result.operands))
1143     return failure();
1144 
1145   if (!type.isa<MemRefType>())
1146     return parser.emitError(parser.getNameLoc(),
1147                             "expected tag to be of memref type");
1148 
1149   if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1150     return parser.emitError(parser.getNameLoc(),
1151                             "tag memref operand count != to map.numInputs");
1152   return success();
1153 }
1154 
verify()1155 LogicalResult AffineDmaWaitOp::verify() {
1156   if (!getOperand(0).getType().isa<MemRefType>())
1157     return emitOpError("expected DMA tag to be of memref type");
1158   Region *scope = getAffineScope(*this);
1159   for (auto idx : getTagIndices()) {
1160     if (!idx.getType().isIndex())
1161       return emitOpError("index to dma_wait must have 'index' type");
1162     if (!isValidAffineIndexOperand(idx, scope))
1163       return emitOpError("index must be a dimension or symbol identifier");
1164   }
1165   return success();
1166 }
1167 
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)1168 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1169                                     SmallVectorImpl<OpFoldResult> &results) {
1170   /// dma_wait(memrefcast) -> dma_wait
1171   return foldMemRefCast(*this);
1172 }
1173 
1174 //===----------------------------------------------------------------------===//
1175 // AffineForOp
1176 //===----------------------------------------------------------------------===//
1177 
build(OpBuilder & builder,OperationState & result,ValueRange lbOperands,AffineMap lbMap,ValueRange ubOperands,AffineMap ubMap,int64_t step,function_ref<void (OpBuilder &,Location,Value)> bodyBuilder)1178 void AffineForOp::build(
1179     OpBuilder &builder, OperationState &result, ValueRange lbOperands,
1180     AffineMap lbMap, ValueRange ubOperands, AffineMap ubMap, int64_t step,
1181     function_ref<void(OpBuilder &, Location, Value)> bodyBuilder) {
1182   assert(((!lbMap && lbOperands.empty()) ||
1183           lbOperands.size() == lbMap.getNumInputs()) &&
1184          "lower bound operand count does not match the affine map");
1185   assert(((!ubMap && ubOperands.empty()) ||
1186           ubOperands.size() == ubMap.getNumInputs()) &&
1187          "upper bound operand count does not match the affine map");
1188   assert(step > 0 && "step has to be a positive integer constant");
1189 
1190   // Add an attribute for the step.
1191   result.addAttribute(getStepAttrName(),
1192                       builder.getIntegerAttr(builder.getIndexType(), step));
1193 
1194   // Add the lower bound.
1195   result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap));
1196   result.addOperands(lbOperands);
1197 
1198   // Add the upper bound.
1199   result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap));
1200   result.addOperands(ubOperands);
1201 
1202   // Create a region and a block for the body.  The argument of the region is
1203   // the loop induction variable.
1204   Region *bodyRegion = result.addRegion();
1205   Block *body = new Block;
1206   Value inductionVar = body->addArgument(IndexType::get(builder.getContext()));
1207   bodyRegion->push_back(body);
1208   if (bodyBuilder) {
1209     OpBuilder::InsertionGuard guard(builder);
1210     builder.setInsertionPointToStart(body);
1211     bodyBuilder(builder, result.location, inductionVar);
1212   } else {
1213     ensureTerminator(*bodyRegion, builder, result.location);
1214   }
1215 }
1216 
build(OpBuilder & builder,OperationState & result,int64_t lb,int64_t ub,int64_t step,function_ref<void (OpBuilder &,Location,Value)> bodyBuilder)1217 void AffineForOp::build(
1218     OpBuilder &builder, OperationState &result, int64_t lb, int64_t ub,
1219     int64_t step,
1220     function_ref<void(OpBuilder &, Location, Value)> bodyBuilder) {
1221   auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1222   auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1223   return build(builder, result, {}, lbMap, {}, ubMap, step, bodyBuilder);
1224 }
1225 
verify(AffineForOp op)1226 static LogicalResult verify(AffineForOp op) {
1227   // Check that the body defines as single block argument for the induction
1228   // variable.
1229   auto *body = op.getBody();
1230   if (body->getNumArguments() != 1 || !body->getArgument(0).getType().isIndex())
1231     return op.emitOpError(
1232         "expected body to have a single index argument for the "
1233         "induction variable");
1234 
1235   // Verify that there are enough operands for the bounds.
1236   AffineMap lowerBoundMap = op.getLowerBoundMap(),
1237             upperBoundMap = op.getUpperBoundMap();
1238   if (op.getNumOperands() !=
1239       (lowerBoundMap.getNumInputs() + upperBoundMap.getNumInputs()))
1240     return op.emitOpError(
1241         "operand count must match with affine map dimension and symbol count");
1242 
1243   // Verify that the bound operands are valid dimension/symbols.
1244   /// Lower bound.
1245   if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
1246                                            op.getLowerBoundMap().getNumDims())))
1247     return failure();
1248   /// Upper bound.
1249   if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
1250                                            op.getUpperBoundMap().getNumDims())))
1251     return failure();
1252   return success();
1253 }
1254 
/// Parse a for operation loop bounds.
///
/// Parses one bound of an affine.for (`isLower` selects which one) and
/// records it in `result` as an affine map attribute plus any operands.
/// Three spellings are accepted, tried in this order:
///   1. a single SSA operand, stored with a symbol identity map;
///   2. an affine map attribute followed by its dim and symbol operand list;
///   3. an integer attribute, stored as a constant affine map.
/// An optional 'max' (lower bound) or 'min' (upper bound) keyword may precede
/// the bound and is mandatory when the map has more than one result.
static ParseResult parseBound(bool isLower, OperationState &result,
                              OpAsmParser &p) {
  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
  // the map has multiple results.
  // Note: despite its spelling, this flag is true when the keyword was absent.
  bool failedToParsedMinMax =
      failed(p.parseOptionalKeyword(isLower ? "max" : "min"));

  auto &builder = p.getBuilder();
  auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
                               : AffineForOp::getUpperBoundAttrName();

  // Parse ssa-id as identity map.
  SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
  if (p.parseOperandList(boundOpInfos))
    return failure();

  // Form 1: a bare SSA operand was found.
  if (!boundOpInfos.empty()) {
    // Check that only one operand was parsed.
    if (boundOpInfos.size() > 1)
      return p.emitError(p.getNameLoc(),
                         "expected only one loop bound operand");

    // TODO: improve error message when SSA value is not of index type.
    // Currently it is 'use of value ... expects different type than prior uses'
    if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
                         result.operands))
      return failure();

    // Create an identity map using symbol id. This representation is optimized
    // for storage. Analysis passes may expand it into a multi-dimensional map
    // if desired.
    AffineMap map = builder.getSymbolIdentityMap();
    result.addAttribute(boundAttrName, AffineMapAttr::get(map));
    return success();
  }

  // Get the attribute location.
  // (Captured before parsing so the min/max diagnostics below point at the
  // start of the bound attribute.)
  llvm::SMLoc attrLoc = p.getCurrentLocation();

  // No SSA operand: the bound must be an attribute. It is parsed with index
  // type and stored into `result.attributes` under `boundAttrName`.
  Attribute boundAttr;
  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
                       result.attributes))
    return failure();

  // Parse full form - affine map followed by dim and symbol list.
  if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
    // Remember how many operands existed before this bound so that only the
    // ones appended by parseDimAndSymbolList are counted below.
    unsigned currentNumOperands = result.operands.size();
    unsigned numDims;
    if (parseDimAndSymbolList(p, result.operands, numDims))
      return failure();

    auto map = affineMapAttr.getValue();
    if (map.getNumDims() != numDims)
      return p.emitError(
          p.getNameLoc(),
          "dim operand count and affine map dim count must match");

    unsigned numDimAndSymbolOperands =
        result.operands.size() - currentNumOperands;
    if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
      return p.emitError(
          p.getNameLoc(),
          "symbol operand count and affine map symbol count must match");

    // If the map has multiple results, make sure that we parsed the min/max
    // prefix.
    if (map.getNumResults() > 1 && failedToParsedMinMax) {
      if (isLower) {
        return p.emitError(attrLoc, "lower loop bound affine map with "
                                    "multiple results requires 'max' prefix");
      }
      return p.emitError(attrLoc, "upper loop bound affine map with multiple "
                                  "results requires 'min' prefix");
    }
    return success();
  }

  // Parse custom assembly form.
  if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
    // Drop the last attribute in the list -- presumably the integer attribute
    // that parseAttribute appended above -- and store the bound as a constant
    // affine map under the same name instead.
    result.attributes.pop_back();
    result.addAttribute(
        boundAttrName,
        AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
    return success();
  }

  // Neither an affine map nor an integer attribute.
  return p.emitError(
      p.getNameLoc(),
      "expected valid affine map representation for loop bounds");
}
1346 
parseAffineForOp(OpAsmParser & parser,OperationState & result)1347 static ParseResult parseAffineForOp(OpAsmParser &parser,
1348                                     OperationState &result) {
1349   auto &builder = parser.getBuilder();
1350   OpAsmParser::OperandType inductionVariable;
1351   // Parse the induction variable followed by '='.
1352   if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
1353     return failure();
1354 
1355   // Parse loop bounds.
1356   if (parseBound(/*isLower=*/true, result, parser) ||
1357       parser.parseKeyword("to", " between bounds") ||
1358       parseBound(/*isLower=*/false, result, parser))
1359     return failure();
1360 
1361   // Parse the optional loop step, we default to 1 if one is not present.
1362   if (parser.parseOptionalKeyword("step")) {
1363     result.addAttribute(
1364         AffineForOp::getStepAttrName(),
1365         builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
1366   } else {
1367     llvm::SMLoc stepLoc = parser.getCurrentLocation();
1368     IntegerAttr stepAttr;
1369     if (parser.parseAttribute(stepAttr, builder.getIndexType(),
1370                               AffineForOp::getStepAttrName().data(),
1371                               result.attributes))
1372       return failure();
1373 
1374     if (stepAttr.getValue().getSExtValue() < 0)
1375       return parser.emitError(
1376           stepLoc,
1377           "expected step to be representable as a positive signed integer");
1378   }
1379 
1380   // Parse the body region.
1381   Region *body = result.addRegion();
1382   if (parser.parseRegion(*body, inductionVariable, builder.getIndexType()))
1383     return failure();
1384 
1385   AffineForOp::ensureTerminator(*body, builder, result.location);
1386 
1387   // Parse the optional attribute list.
1388   return parser.parseOptionalAttrDict(result.attributes);
1389 }
1390 
printBound(AffineMapAttr boundMap,Operation::operand_range boundOperands,const char * prefix,OpAsmPrinter & p)1391 static void printBound(AffineMapAttr boundMap,
1392                        Operation::operand_range boundOperands,
1393                        const char *prefix, OpAsmPrinter &p) {
1394   AffineMap map = boundMap.getValue();
1395 
1396   // Check if this bound should be printed using custom assembly form.
1397   // The decision to restrict printing custom assembly form to trivial cases
1398   // comes from the will to roundtrip MLIR binary -> text -> binary in a
1399   // lossless way.
1400   // Therefore, custom assembly form parsing and printing is only supported for
1401   // zero-operand constant maps and single symbol operand identity maps.
1402   if (map.getNumResults() == 1) {
1403     AffineExpr expr = map.getResult(0);
1404 
1405     // Print constant bound.
1406     if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
1407       if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
1408         p << constExpr.getValue();
1409         return;
1410       }
1411     }
1412 
1413     // Print bound that consists of a single SSA symbol if the map is over a
1414     // single symbol.
1415     if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
1416       if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
1417         p.printOperand(*boundOperands.begin());
1418         return;
1419       }
1420     }
1421   } else {
1422     // Map has multiple results. Print 'min' or 'max' prefix.
1423     p << prefix << ' ';
1424   }
1425 
1426   // Print the map and its operands.
1427   p << boundMap;
1428   printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
1429                         map.getNumDims(), p);
1430 }
1431 
print(OpAsmPrinter & p,AffineForOp op)1432 static void print(OpAsmPrinter &p, AffineForOp op) {
1433   p << op.getOperationName() << ' ';
1434   p.printOperand(op.getBody()->getArgument(0));
1435   p << " = ";
1436   printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
1437   p << " to ";
1438   printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);
1439 
1440   if (op.getStep() != 1)
1441     p << " step " << op.getStep();
1442   p.printRegion(op.region(),
1443                 /*printEntryBlockArgs=*/false,
1444                 /*printBlockTerminators=*/false);
1445   p.printOptionalAttrDict(op.getAttrs(),
1446                           /*elidedAttrs=*/{op.getLowerBoundAttrName(),
1447                                            op.getUpperBoundAttrName(),
1448                                            op.getStepAttrName()});
1449 }
1450 
1451 /// Fold the constant bounds of a loop.
foldLoopBounds(AffineForOp forOp)1452 static LogicalResult foldLoopBounds(AffineForOp forOp) {
1453   auto foldLowerOrUpperBound = [&forOp](bool lower) {
1454     // Check to see if each of the operands is the result of a constant.  If
1455     // so, get the value.  If not, ignore it.
1456     SmallVector<Attribute, 8> operandConstants;
1457     auto boundOperands =
1458         lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
1459     for (auto operand : boundOperands) {
1460       Attribute operandCst;
1461       matchPattern(operand, m_Constant(&operandCst));
1462       operandConstants.push_back(operandCst);
1463     }
1464 
1465     AffineMap boundMap =
1466         lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
1467     assert(boundMap.getNumResults() >= 1 &&
1468            "bound maps should have at least one result");
1469     SmallVector<Attribute, 4> foldedResults;
1470     if (failed(boundMap.constantFold(operandConstants, foldedResults)))
1471       return failure();
1472 
1473     // Compute the max or min as applicable over the results.
1474     assert(!foldedResults.empty() && "bounds should have at least one result");
1475     auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
1476     for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
1477       auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
1478       maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
1479                        : llvm::APIntOps::smin(maxOrMin, foldedResult);
1480     }
1481     lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
1482           : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
1483     return success();
1484   };
1485 
1486   // Try to fold the lower bound.
1487   bool folded = false;
1488   if (!forOp.hasConstantLowerBound())
1489     folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
1490 
1491   // Try to fold the upper bound.
1492   if (!forOp.hasConstantUpperBound())
1493     folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
1494   return success(folded);
1495 }
1496 
1497 /// Canonicalize the bounds of the given loop.
canonicalizeLoopBounds(AffineForOp forOp)1498 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
1499   SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
1500   SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
1501 
1502   auto lbMap = forOp.getLowerBoundMap();
1503   auto ubMap = forOp.getUpperBoundMap();
1504   auto prevLbMap = lbMap;
1505   auto prevUbMap = ubMap;
1506 
1507   canonicalizeMapAndOperands(&lbMap, &lbOperands);
1508   lbMap = removeDuplicateExprs(lbMap);
1509 
1510   canonicalizeMapAndOperands(&ubMap, &ubOperands);
1511   ubMap = removeDuplicateExprs(ubMap);
1512 
1513   // Any canonicalization change always leads to updated map(s).
1514   if (lbMap == prevLbMap && ubMap == prevUbMap)
1515     return failure();
1516 
1517   if (lbMap != prevLbMap)
1518     forOp.setLowerBound(lbOperands, lbMap);
1519   if (ubMap != prevUbMap)
1520     forOp.setUpperBound(ubOperands, ubMap);
1521   return success();
1522 }
1523 
1524 namespace {
1525 /// This is a pattern to fold trivially empty loops.
1526 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
1527   using OpRewritePattern<AffineForOp>::OpRewritePattern;
1528 
matchAndRewrite__anon07ec94590c11::AffineForEmptyLoopFolder1529   LogicalResult matchAndRewrite(AffineForOp forOp,
1530                                 PatternRewriter &rewriter) const override {
1531     // Check that the body only contains a yield.
1532     if (!llvm::hasSingleElement(*forOp.getBody()))
1533       return failure();
1534     rewriter.eraseOp(forOp);
1535     return success();
1536   }
1537 };
1538 } // end anonymous namespace
1539 
/// Registers the canonicalization patterns of affine.for into `results`:
/// currently only AffineForEmptyLoopFolder.
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<AffineForEmptyLoopFolder>(context);
}
1544 
fold(ArrayRef<Attribute> operands,SmallVectorImpl<OpFoldResult> & results)1545 LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
1546                                 SmallVectorImpl<OpFoldResult> &results) {
1547   bool folded = succeeded(foldLoopBounds(*this));
1548   folded |= succeeded(canonicalizeLoopBounds(*this));
1549   return success(folded);
1550 }
1551 
getLowerBound()1552 AffineBound AffineForOp::getLowerBound() {
1553   auto lbMap = getLowerBoundMap();
1554   return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
1555 }
1556 
getUpperBound()1557 AffineBound AffineForOp::getUpperBound() {
1558   auto lbMap = getLowerBoundMap();
1559   auto ubMap = getUpperBoundMap();
1560   return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), getNumOperands(),
1561                      ubMap);
1562 }
1563 
setLowerBound(ValueRange lbOperands,AffineMap map)1564 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
1565   assert(lbOperands.size() == map.getNumInputs());
1566   assert(map.getNumResults() >= 1 && "bound map has at least one result");
1567 
1568   SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());
1569 
1570   auto ubOperands = getUpperBoundOperands();
1571   newOperands.append(ubOperands.begin(), ubOperands.end());
1572   getOperation()->setOperands(newOperands);
1573 
1574   setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
1575 }
1576 
setUpperBound(ValueRange ubOperands,AffineMap map)1577 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
1578   assert(ubOperands.size() == map.getNumInputs());
1579   assert(map.getNumResults() >= 1 && "bound map has at least one result");
1580 
1581   SmallVector<Value, 4> newOperands(getLowerBoundOperands());
1582   newOperands.append(ubOperands.begin(), ubOperands.end());
1583   getOperation()->setOperands(newOperands);
1584 
1585   setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
1586 }
1587 
setLowerBoundMap(AffineMap map)1588 void AffineForOp::setLowerBoundMap(AffineMap map) {
1589   auto lbMap = getLowerBoundMap();
1590   assert(lbMap.getNumDims() == map.getNumDims() &&
1591          lbMap.getNumSymbols() == map.getNumSymbols());
1592   assert(map.getNumResults() >= 1 && "bound map has at least one result");
1593   (void)lbMap;
1594   setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
1595 }
1596 
setUpperBoundMap(AffineMap map)1597 void AffineForOp::setUpperBoundMap(AffineMap map) {
1598   auto ubMap = getUpperBoundMap();
1599   assert(ubMap.getNumDims() == map.getNumDims() &&
1600          ubMap.getNumSymbols() == map.getNumSymbols());
1601   assert(map.getNumResults() >= 1 && "bound map has at least one result");
1602   (void)ubMap;
1603   setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
1604 }
1605 
hasConstantLowerBound()1606 bool AffineForOp::hasConstantLowerBound() {
1607   return getLowerBoundMap().isSingleConstant();
1608 }
1609 
hasConstantUpperBound()1610 bool AffineForOp::hasConstantUpperBound() {
1611   return getUpperBoundMap().isSingleConstant();
1612 }
1613 
getConstantLowerBound()1614 int64_t AffineForOp::getConstantLowerBound() {
1615   return getLowerBoundMap().getSingleConstantResult();
1616 }
1617 
getConstantUpperBound()1618 int64_t AffineForOp::getConstantUpperBound() {
1619   return getUpperBoundMap().getSingleConstantResult();
1620 }
1621 
setConstantLowerBound(int64_t value)1622 void AffineForOp::setConstantLowerBound(int64_t value) {
1623   setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
1624 }
1625 
setConstantUpperBound(int64_t value)1626 void AffineForOp::setConstantUpperBound(int64_t value) {
1627   setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
1628 }
1629 
getLowerBoundOperands()1630 AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
1631   return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
1632 }
1633 
getUpperBoundOperands()1634 AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
1635   return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
1636 }
1637 
matchingBoundOperandList()1638 bool AffineForOp::matchingBoundOperandList() {
1639   auto lbMap = getLowerBoundMap();
1640   auto ubMap = getUpperBoundMap();
1641   if (lbMap.getNumDims() != ubMap.getNumDims() ||
1642       lbMap.getNumSymbols() != ubMap.getNumSymbols())
1643     return false;
1644 
1645   unsigned numOperands = lbMap.getNumInputs();
1646   for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
1647     // Compare Value 's.
1648     if (getOperand(i) != getOperand(numOperands + i))
1649       return false;
1650   }
1651   return true;
1652 }
1653 
/// Returns the region of this operation as the loop body.
Region &AffineForOp::getLoopBody() { return region(); }
1655 
isDefinedOutsideOfLoop(Value value)1656 bool AffineForOp::isDefinedOutsideOfLoop(Value value) {
1657   return !region().isAncestor(value.getParentRegion());
1658 }
1659 
moveOutOfLoop(ArrayRef<Operation * > ops)1660 LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
1661   for (auto *op : ops)
1662     op->moveBefore(*this);
1663   return success();
1664 }
1665 
1666 /// Returns if the provided value is the induction variable of a AffineForOp.
isForInductionVar(Value val)1667 bool mlir::isForInductionVar(Value val) {
1668   return getForInductionVarOwner(val) != AffineForOp();
1669 }
1670 
1671 /// Returns the loop parent of an induction variable. If the provided value is
1672 /// not an induction variable, then return nullptr.
getForInductionVarOwner(Value val)1673 AffineForOp mlir::getForInductionVarOwner(Value val) {
1674   auto ivArg = val.dyn_cast<BlockArgument>();
1675   if (!ivArg || !ivArg.getOwner())
1676     return AffineForOp();
1677   auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
1678   return dyn_cast<AffineForOp>(containingInst);
1679 }
1680 
1681 /// Extracts the induction variables from a list of AffineForOps and returns
1682 /// them.
extractForInductionVars(ArrayRef<AffineForOp> forInsts,SmallVectorImpl<Value> * ivs)1683 void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
1684                                    SmallVectorImpl<Value> *ivs) {
1685   ivs->reserve(forInsts.size());
1686   for (auto forInst : forInsts)
1687     ivs->push_back(forInst.getInductionVar());
1688 }
1689 
1690 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
1691 /// operations.
1692 template <typename BoundListTy, typename LoopCreatorTy>
buildAffineLoopNestImpl(OpBuilder & builder,Location loc,BoundListTy lbs,BoundListTy ubs,ArrayRef<int64_t> steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilderFn,LoopCreatorTy && loopCreatorFn)1693 static void buildAffineLoopNestImpl(
1694     OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
1695     ArrayRef<int64_t> steps,
1696     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
1697     LoopCreatorTy &&loopCreatorFn) {
1698   assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
1699   assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
1700 
1701   // If there are no loops to be constructed, construct the body anyway.
1702   OpBuilder::InsertionGuard guard(builder);
1703   if (lbs.empty()) {
1704     if (bodyBuilderFn)
1705       bodyBuilderFn(builder, loc, ValueRange());
1706     return;
1707   }
1708 
1709   // Create the loops iteratively and store the induction variables.
1710   SmallVector<Value, 4> ivs;
1711   ivs.reserve(lbs.size());
1712   for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
1713     // Callback for creating the loop body, always creates the terminator.
1714     auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc,
1715                         Value iv) {
1716       ivs.push_back(iv);
1717       // In the innermost loop, call the body builder.
1718       if (i == e - 1 && bodyBuilderFn) {
1719         OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
1720         bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
1721       }
1722       nestedBuilder.create<AffineYieldOp>(nestedLoc);
1723     };
1724 
1725     // Delegate actual loop creation to the callback in order to dispatch
1726     // between constant- and variable-bound loops.
1727     auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
1728     builder.setInsertionPointToStart(loop.getBody());
1729   }
1730 }
1731 
1732 /// Creates an affine loop from the bounds known to be constants.
buildAffineLoopFromConstants(OpBuilder & builder,Location loc,int64_t lb,int64_t ub,int64_t step,function_ref<void (OpBuilder &,Location,Value)> bodyBuilderFn)1733 static AffineForOp buildAffineLoopFromConstants(
1734     OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step,
1735     function_ref<void(OpBuilder &, Location, Value)> bodyBuilderFn) {
1736   return builder.create<AffineForOp>(loc, lb, ub, step, bodyBuilderFn);
1737 }
1738 
1739 /// Creates an affine loop from the bounds that may or may not be constants.
buildAffineLoopFromValues(OpBuilder & builder,Location loc,Value lb,Value ub,int64_t step,function_ref<void (OpBuilder &,Location,Value)> bodyBuilderFn)1740 static AffineForOp buildAffineLoopFromValues(
1741     OpBuilder &builder, Location loc, Value lb, Value ub, int64_t step,
1742     function_ref<void(OpBuilder &, Location, Value)> bodyBuilderFn) {
1743   auto lbConst = lb.getDefiningOp<ConstantIndexOp>();
1744   auto ubConst = ub.getDefiningOp<ConstantIndexOp>();
1745   if (lbConst && ubConst)
1746     return buildAffineLoopFromConstants(builder, loc, lbConst.getValue(),
1747                                         ubConst.getValue(), step,
1748                                         bodyBuilderFn);
1749   return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
1750                                      builder.getDimIdentityMap(), step,
1751                                      bodyBuilderFn);
1752 }
1753 
/// Builds a nest of affine loops with constant bounds `lbs`/`ubs` and steps
/// `steps`, creating each loop through buildAffineLoopFromConstants.
/// `bodyBuilderFn` is forwarded to buildAffineLoopNestImpl, which handles the
/// nesting and the invocation of the body callback.
void mlir::buildAffineLoopNest(
    OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
    ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
                          buildAffineLoopFromConstants);
}
1761 
/// Builds a nest of affine loops whose bounds `lbs`/`ubs` are SSA values, with
/// steps `steps`, creating each loop through buildAffineLoopFromValues.
/// `bodyBuilderFn` is forwarded to buildAffineLoopNestImpl, which handles the
/// nesting and the invocation of the body callback.
void mlir::buildAffineLoopNest(
    OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
                          buildAffineLoopFromValues);
}
1769 
1770 //===----------------------------------------------------------------------===//
1771 // AffineIfOp
1772 //===----------------------------------------------------------------------===//
1773 
1774 namespace {
1775 /// Remove else blocks that have nothing other than a zero value yield.
1776 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
1777   using OpRewritePattern<AffineIfOp>::OpRewritePattern;
1778 
matchAndRewrite__anon07ec94590e11::SimplifyDeadElse1779   LogicalResult matchAndRewrite(AffineIfOp ifOp,
1780                                 PatternRewriter &rewriter) const override {
1781     if (ifOp.elseRegion().empty() ||
1782         !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
1783       return failure();
1784 
1785     rewriter.startRootUpdate(ifOp);
1786     rewriter.eraseBlock(ifOp.getElseBlock());
1787     rewriter.finalizeRootUpdate(ifOp);
1788     return success();
1789   }
1790 };
1791 } // end anonymous namespace.
1792 
verify(AffineIfOp op)1793 static LogicalResult verify(AffineIfOp op) {
1794   // Verify that we have a condition attribute.
1795   auto conditionAttr =
1796       op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
1797   if (!conditionAttr)
1798     return op.emitOpError(
1799         "requires an integer set attribute named 'condition'");
1800 
1801   // Verify that there are enough operands for the condition.
1802   IntegerSet condition = conditionAttr.getValue();
1803   if (op.getNumOperands() != condition.getNumInputs())
1804     return op.emitOpError(
1805         "operand count and condition integer set dimension and "
1806         "symbol count must match");
1807 
1808   // Verify that the operands are valid dimension/symbols.
1809   if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(),
1810                                            condition.getNumDims())))
1811     return failure();
1812 
1813   return success();
1814 }
1815 
parseAffineIfOp(OpAsmParser & parser,OperationState & result)1816 static ParseResult parseAffineIfOp(OpAsmParser &parser,
1817                                    OperationState &result) {
1818   // Parse the condition attribute set.
1819   IntegerSetAttr conditionAttr;
1820   unsigned numDims;
1821   if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(),
1822                             result.attributes) ||
1823       parseDimAndSymbolList(parser, result.operands, numDims))
1824     return failure();
1825 
1826   // Verify the condition operands.
1827   auto set = conditionAttr.getValue();
1828   if (set.getNumDims() != numDims)
1829     return parser.emitError(
1830         parser.getNameLoc(),
1831         "dim operand count and integer set dim count must match");
1832   if (numDims + set.getNumSymbols() != result.operands.size())
1833     return parser.emitError(
1834         parser.getNameLoc(),
1835         "symbol operand count and integer set symbol count must match");
1836 
1837   if (parser.parseOptionalArrowTypeList(result.types))
1838     return failure();
1839 
1840   // Create the regions for 'then' and 'else'.  The latter must be created even
1841   // if it remains empty for the validity of the operation.
1842   result.regions.reserve(2);
1843   Region *thenRegion = result.addRegion();
1844   Region *elseRegion = result.addRegion();
1845 
1846   // Parse the 'then' region.
1847   if (parser.parseRegion(*thenRegion, {}, {}))
1848     return failure();
1849   AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
1850                                result.location);
1851 
1852   // If we find an 'else' keyword then parse the 'else' region.
1853   if (!parser.parseOptionalKeyword("else")) {
1854     if (parser.parseRegion(*elseRegion, {}, {}))
1855       return failure();
1856     AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
1857                                  result.location);
1858   }
1859 
1860   // Parse the optional attribute list.
1861   if (parser.parseOptionalAttrDict(result.attributes))
1862     return failure();
1863 
1864   return success();
1865 }
1866 
print(OpAsmPrinter & p,AffineIfOp op)1867 static void print(OpAsmPrinter &p, AffineIfOp op) {
1868   auto conditionAttr =
1869       op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
1870   p << "affine.if " << conditionAttr;
1871   printDimAndSymbolList(op.operand_begin(), op.operand_end(),
1872                         conditionAttr.getValue().getNumDims(), p);
1873   p.printOptionalArrowTypeList(op.getResultTypes());
1874   p.printRegion(op.thenRegion(),
1875                 /*printEntryBlockArgs=*/false,
1876                 /*printBlockTerminators=*/op.getNumResults());
1877 
1878   // Print the 'else' regions if it has any blocks.
1879   auto &elseRegion = op.elseRegion();
1880   if (!elseRegion.empty()) {
1881     p << " else";
1882     p.printRegion(elseRegion,
1883                   /*printEntryBlockArgs=*/false,
1884                   /*printBlockTerminators=*/op.getNumResults());
1885   }
1886 
1887   // Print the attribute list.
1888   p.printOptionalAttrDict(op.getAttrs(),
1889                           /*elidedAttrs=*/op.getConditionAttrName());
1890 }
1891 
getIntegerSet()1892 IntegerSet AffineIfOp::getIntegerSet() {
1893   return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue();
1894 }
setIntegerSet(IntegerSet newSet)1895 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
1896   setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
1897 }
1898 
/// Replaces both the condition integer set and the operand list of this op:
/// the condition attribute is set to `set` and the operands to `operands`.
void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
  setIntegerSet(set);
  getOperation()->setOperands(operands);
}
1903 
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,IntegerSet set,ValueRange args,bool withElseRegion)1904 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
1905                        TypeRange resultTypes, IntegerSet set, ValueRange args,
1906                        bool withElseRegion) {
1907   assert(resultTypes.empty() || withElseRegion);
1908   result.addTypes(resultTypes);
1909   result.addOperands(args);
1910   result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set));
1911 
1912   Region *thenRegion = result.addRegion();
1913   thenRegion->push_back(new Block());
1914   if (resultTypes.empty())
1915     AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
1916 
1917   Region *elseRegion = result.addRegion();
1918   if (withElseRegion) {
1919     elseRegion->push_back(new Block());
1920     if (resultTypes.empty())
1921       AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
1922   }
1923 }
1924 
/// Builds an affine.if without results by delegating to the builder that
/// takes result types, passing an empty type list.
void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       IntegerSet set, ValueRange args, bool withElseRegion) {
  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
                    withElseRegion);
}
1930 
1931 /// Canonicalize an affine if op's conditional (integer set + operands).
fold(ArrayRef<Attribute>,SmallVectorImpl<OpFoldResult> &)1932 LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
1933                                SmallVectorImpl<OpFoldResult> &) {
1934   auto set = getIntegerSet();
1935   SmallVector<Value, 4> operands(getOperands());
1936   canonicalizeSetAndOperands(&set, &operands);
1937 
1938   // Any canonicalization change always leads to either a reduction in the
1939   // number of operands or a change in the number of symbolic operands
1940   // (promotion of dims to symbols).
1941   if (operands.size() < getIntegerSet().getNumInputs() ||
1942       set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
1943     setConditional(set, operands);
1944     return success();
1945   }
1946 
1947   return failure();
1948 }
1949 
/// Registers the canonicalization patterns of affine.if into `results`:
/// currently only SimplifyDeadElse.
void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  results.insert<SimplifyDeadElse>(context);
}
1954 
1955 //===----------------------------------------------------------------------===//
1956 // AffineLoadOp
1957 //===----------------------------------------------------------------------===//
1958 
build(OpBuilder & builder,OperationState & result,AffineMap map,ValueRange operands)1959 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
1960                          AffineMap map, ValueRange operands) {
1961   assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
1962   result.addOperands(operands);
1963   if (map)
1964     result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
1965   auto memrefType = operands[0].getType().cast<MemRefType>();
1966   result.types.push_back(memrefType.getElementType());
1967 }
1968 
build(OpBuilder & builder,OperationState & result,Value memref,AffineMap map,ValueRange mapOperands)1969 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
1970                          Value memref, AffineMap map, ValueRange mapOperands) {
1971   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
1972   result.addOperands(memref);
1973   result.addOperands(mapOperands);
1974   auto memrefType = memref.getType().cast<MemRefType>();
1975   result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
1976   result.types.push_back(memrefType.getElementType());
1977 }
1978 
build(OpBuilder & builder,OperationState & result,Value memref,ValueRange indices)1979 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
1980                          Value memref, ValueRange indices) {
1981   auto memrefType = memref.getType().cast<MemRefType>();
1982   auto rank = memrefType.getRank();
1983   // Create identity map for memrefs with at least one dimension or () -> ()
1984   // for zero-dimensional memrefs.
1985   auto map =
1986       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
1987   build(builder, result, memref, map, indices);
1988 }
1989 
parseAffineLoadOp(OpAsmParser & parser,OperationState & result)1990 static ParseResult parseAffineLoadOp(OpAsmParser &parser,
1991                                      OperationState &result) {
1992   auto &builder = parser.getBuilder();
1993   auto indexTy = builder.getIndexType();
1994 
1995   MemRefType type;
1996   OpAsmParser::OperandType memrefInfo;
1997   AffineMapAttr mapAttr;
1998   SmallVector<OpAsmParser::OperandType, 1> mapOperands;
1999   return failure(
2000       parser.parseOperand(memrefInfo) ||
2001       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2002                                     AffineLoadOp::getMapAttrName(),
2003                                     result.attributes) ||
2004       parser.parseOptionalAttrDict(result.attributes) ||
2005       parser.parseColonType(type) ||
2006       parser.resolveOperand(memrefInfo, type, result.operands) ||
2007       parser.resolveOperands(mapOperands, indexTy, result.operands) ||
2008       parser.addTypeToList(type.getElementType(), result.types));
2009 }
2010 
print(OpAsmPrinter & p,AffineLoadOp op)2011 static void print(OpAsmPrinter &p, AffineLoadOp op) {
2012   p << "affine.load " << op.getMemRef() << '[';
2013   if (AffineMapAttr mapAttr =
2014           op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2015     p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2016   p << ']';
2017   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
2018   p << " : " << op.getMemRefType();
2019 }
2020 
2021 /// Verify common indexing invariants of affine.load, affine.store,
2022 /// affine.vector_load and affine.vector_store.
2023 static LogicalResult
verifyMemoryOpIndexing(Operation * op,AffineMapAttr mapAttr,Operation::operand_range mapOperands,MemRefType memrefType,unsigned numIndexOperands)2024 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
2025                        Operation::operand_range mapOperands,
2026                        MemRefType memrefType, unsigned numIndexOperands) {
2027   if (mapAttr) {
2028     AffineMap map = mapAttr.getValue();
2029     if (map.getNumResults() != memrefType.getRank())
2030       return op->emitOpError("affine map num results must equal memref rank");
2031     if (map.getNumInputs() != numIndexOperands)
2032       return op->emitOpError("expects as many subscripts as affine map inputs");
2033   } else {
2034     if (memrefType.getRank() != numIndexOperands)
2035       return op->emitOpError(
2036           "expects the number of subscripts to be equal to memref rank");
2037   }
2038 
2039   Region *scope = getAffineScope(op);
2040   for (auto idx : mapOperands) {
2041     if (!idx.getType().isIndex())
2042       return op->emitOpError("index to load must have 'index' type");
2043     if (!isValidAffineIndexOperand(idx, scope))
2044       return op->emitOpError("index must be a dimension or symbol identifier");
2045   }
2046 
2047   return success();
2048 }
2049 
verify(AffineLoadOp op)2050 LogicalResult verify(AffineLoadOp op) {
2051   auto memrefType = op.getMemRefType();
2052   if (op.getType() != memrefType.getElementType())
2053     return op.emitOpError("result type must match element type of memref");
2054 
2055   if (failed(verifyMemoryOpIndexing(
2056           op.getOperation(),
2057           op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2058           op.getMapOperands(), memrefType,
2059           /*numIndexOperands=*/op.getNumOperands() - 1)))
2060     return failure();
2061 
2062   return success();
2063 }
2064 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2065 void AffineLoadOp::getCanonicalizationPatterns(
2066     OwningRewritePatternList &results, MLIRContext *context) {
2067   results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
2068 }
2069 
fold(ArrayRef<Attribute> cstOperands)2070 OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
2071   /// load(memrefcast) -> load
2072   if (succeeded(foldMemRefCast(*this)))
2073     return getResult();
2074   return OpFoldResult();
2075 }
2076 
2077 //===----------------------------------------------------------------------===//
2078 // AffineStoreOp
2079 //===----------------------------------------------------------------------===//
2080 
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,AffineMap map,ValueRange mapOperands)2081 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
2082                           Value valueToStore, Value memref, AffineMap map,
2083                           ValueRange mapOperands) {
2084   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2085   result.addOperands(valueToStore);
2086   result.addOperands(memref);
2087   result.addOperands(mapOperands);
2088   result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2089 }
2090 
2091 // Use identity map.
build(OpBuilder & builder,OperationState & result,Value valueToStore,Value memref,ValueRange indices)2092 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
2093                           Value valueToStore, Value memref,
2094                           ValueRange indices) {
2095   auto memrefType = memref.getType().cast<MemRefType>();
2096   auto rank = memrefType.getRank();
2097   // Create identity map for memrefs with at least one dimension or () -> ()
2098   // for zero-dimensional memrefs.
2099   auto map =
2100       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2101   build(builder, result, valueToStore, memref, map, indices);
2102 }
2103 
parseAffineStoreOp(OpAsmParser & parser,OperationState & result)2104 static ParseResult parseAffineStoreOp(OpAsmParser &parser,
2105                                       OperationState &result) {
2106   auto indexTy = parser.getBuilder().getIndexType();
2107 
2108   MemRefType type;
2109   OpAsmParser::OperandType storeValueInfo;
2110   OpAsmParser::OperandType memrefInfo;
2111   AffineMapAttr mapAttr;
2112   SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2113   return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
2114                  parser.parseOperand(memrefInfo) ||
2115                  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2116                                                AffineStoreOp::getMapAttrName(),
2117                                                result.attributes) ||
2118                  parser.parseOptionalAttrDict(result.attributes) ||
2119                  parser.parseColonType(type) ||
2120                  parser.resolveOperand(storeValueInfo, type.getElementType(),
2121                                        result.operands) ||
2122                  parser.resolveOperand(memrefInfo, type, result.operands) ||
2123                  parser.resolveOperands(mapOperands, indexTy, result.operands));
2124 }
2125 
print(OpAsmPrinter & p,AffineStoreOp op)2126 static void print(OpAsmPrinter &p, AffineStoreOp op) {
2127   p << "affine.store " << op.getValueToStore();
2128   p << ", " << op.getMemRef() << '[';
2129   if (AffineMapAttr mapAttr =
2130           op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2131     p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2132   p << ']';
2133   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
2134   p << " : " << op.getMemRefType();
2135 }
2136 
verify(AffineStoreOp op)2137 LogicalResult verify(AffineStoreOp op) {
2138   // First operand must have same type as memref element type.
2139   auto memrefType = op.getMemRefType();
2140   if (op.getValueToStore().getType() != memrefType.getElementType())
2141     return op.emitOpError(
2142         "first operand must have same type memref element type");
2143 
2144   if (failed(verifyMemoryOpIndexing(
2145           op.getOperation(),
2146           op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2147           op.getMapOperands(), memrefType,
2148           /*numIndexOperands=*/op.getNumOperands() - 2)))
2149     return failure();
2150 
2151   return success();
2152 }
2153 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2154 void AffineStoreOp::getCanonicalizationPatterns(
2155     OwningRewritePatternList &results, MLIRContext *context) {
2156   results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
2157 }
2158 
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)2159 LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
2160                                   SmallVectorImpl<OpFoldResult> &results) {
2161   /// store(memrefcast) -> store
2162   return foldMemRefCast(*this);
2163 }
2164 
2165 //===----------------------------------------------------------------------===//
2166 // AffineMinMaxOpBase
2167 //===----------------------------------------------------------------------===//
2168 
2169 template <typename T>
verifyAffineMinMaxOp(T op)2170 static LogicalResult verifyAffineMinMaxOp(T op) {
2171   // Verify that operand count matches affine map dimension and symbol count.
2172   if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
2173     return op.emitOpError(
2174         "operand count and affine map dimension and symbol count must match");
2175   return success();
2176 }
2177 
2178 template <typename T>
printAffineMinMaxOp(OpAsmPrinter & p,T op)2179 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
2180   p << op.getOperationName() << ' ' << op.getAttr(T::getMapAttrName());
2181   auto operands = op.getOperands();
2182   unsigned numDims = op.map().getNumDims();
2183   p << '(' << operands.take_front(numDims) << ')';
2184 
2185   if (operands.size() != numDims)
2186     p << '[' << operands.drop_front(numDims) << ']';
2187   p.printOptionalAttrDict(op.getAttrs(),
2188                           /*elidedAttrs=*/{T::getMapAttrName()});
2189 }
2190 
2191 template <typename T>
parseAffineMinMaxOp(OpAsmParser & parser,OperationState & result)2192 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
2193                                        OperationState &result) {
2194   auto &builder = parser.getBuilder();
2195   auto indexType = builder.getIndexType();
2196   SmallVector<OpAsmParser::OperandType, 8> dim_infos;
2197   SmallVector<OpAsmParser::OperandType, 8> sym_infos;
2198   AffineMapAttr mapAttr;
2199   return failure(
2200       parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) ||
2201       parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) ||
2202       parser.parseOperandList(sym_infos,
2203                               OpAsmParser::Delimiter::OptionalSquare) ||
2204       parser.parseOptionalAttrDict(result.attributes) ||
2205       parser.resolveOperands(dim_infos, indexType, result.operands) ||
2206       parser.resolveOperands(sym_infos, indexType, result.operands) ||
2207       parser.addTypeToList(indexType, result.types));
2208 }
2209 
2210 /// Fold an affine min or max operation with the given operands. The operand
2211 /// list may contain nulls, which are interpreted as the operand not being a
2212 /// constant.
2213 template <typename T>
foldMinMaxOp(T op,ArrayRef<Attribute> operands)2214 static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
2215   static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
2216                 "expected affine min or max op");
2217 
2218   // Fold the affine map.
2219   // TODO: Fold more cases:
2220   // min(some_affine, some_affine + constant, ...), etc.
2221   SmallVector<int64_t, 2> results;
2222   auto foldedMap = op.map().partialConstantFold(operands, &results);
2223 
2224   // If some of the map results are not constant, try changing the map in-place.
2225   if (results.empty()) {
2226     // If the map is the same, report that folding did not happen.
2227     if (foldedMap == op.map())
2228       return {};
2229     op.setAttr("map", AffineMapAttr::get(foldedMap));
2230     return op.getResult();
2231   }
2232 
2233   // Otherwise, completely fold the op into a constant.
2234   auto resultIt = std::is_same<T, AffineMinOp>::value
2235                       ? std::min_element(results.begin(), results.end())
2236                       : std::max_element(results.begin(), results.end());
2237   if (resultIt == results.end())
2238     return {};
2239   return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
2240 }
2241 
2242 //===----------------------------------------------------------------------===//
2243 // AffineMinOp
2244 //===----------------------------------------------------------------------===//
2245 //
2246 //   %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
2247 //
2248 
fold(ArrayRef<Attribute> operands)2249 OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
2250   return foldMinMaxOp(*this, operands);
2251 }
2252 
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)2253 void AffineMinOp::getCanonicalizationPatterns(
2254     OwningRewritePatternList &patterns, MLIRContext *context) {
2255   patterns.insert<SimplifyAffineOp<AffineMinOp>>(context);
2256 }
2257 
2258 //===----------------------------------------------------------------------===//
2259 // AffineMaxOp
2260 //===----------------------------------------------------------------------===//
2261 //
2262 //   %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
2263 //
2264 
fold(ArrayRef<Attribute> operands)2265 OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
2266   return foldMinMaxOp(*this, operands);
2267 }
2268 
getCanonicalizationPatterns(OwningRewritePatternList & patterns,MLIRContext * context)2269 void AffineMaxOp::getCanonicalizationPatterns(
2270     OwningRewritePatternList &patterns, MLIRContext *context) {
2271   patterns.insert<SimplifyAffineOp<AffineMaxOp>>(context);
2272 }
2273 
2274 //===----------------------------------------------------------------------===//
2275 // AffinePrefetchOp
2276 //===----------------------------------------------------------------------===//
2277 
2278 //
2279 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
2280 //
parseAffinePrefetchOp(OpAsmParser & parser,OperationState & result)2281 static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
2282                                          OperationState &result) {
2283   auto &builder = parser.getBuilder();
2284   auto indexTy = builder.getIndexType();
2285 
2286   MemRefType type;
2287   OpAsmParser::OperandType memrefInfo;
2288   IntegerAttr hintInfo;
2289   auto i32Type = parser.getBuilder().getIntegerType(32);
2290   StringRef readOrWrite, cacheType;
2291 
2292   AffineMapAttr mapAttr;
2293   SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2294   if (parser.parseOperand(memrefInfo) ||
2295       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2296                                     AffinePrefetchOp::getMapAttrName(),
2297                                     result.attributes) ||
2298       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
2299       parser.parseComma() || parser.parseKeyword("locality") ||
2300       parser.parseLess() ||
2301       parser.parseAttribute(hintInfo, i32Type,
2302                             AffinePrefetchOp::getLocalityHintAttrName(),
2303                             result.attributes) ||
2304       parser.parseGreater() || parser.parseComma() ||
2305       parser.parseKeyword(&cacheType) ||
2306       parser.parseOptionalAttrDict(result.attributes) ||
2307       parser.parseColonType(type) ||
2308       parser.resolveOperand(memrefInfo, type, result.operands) ||
2309       parser.resolveOperands(mapOperands, indexTy, result.operands))
2310     return failure();
2311 
2312   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
2313     return parser.emitError(parser.getNameLoc(),
2314                             "rw specifier has to be 'read' or 'write'");
2315   result.addAttribute(
2316       AffinePrefetchOp::getIsWriteAttrName(),
2317       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
2318 
2319   if (!cacheType.equals("data") && !cacheType.equals("instr"))
2320     return parser.emitError(parser.getNameLoc(),
2321                             "cache type has to be 'data' or 'instr'");
2322 
2323   result.addAttribute(
2324       AffinePrefetchOp::getIsDataCacheAttrName(),
2325       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
2326 
2327   return success();
2328 }
2329 
print(OpAsmPrinter & p,AffinePrefetchOp op)2330 static void print(OpAsmPrinter &p, AffinePrefetchOp op) {
2331   p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
2332   AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2333   if (mapAttr) {
2334     SmallVector<Value, 2> operands(op.getMapOperands());
2335     p.printAffineMapOfSSAIds(mapAttr, operands);
2336   }
2337   p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", "
2338     << "locality<" << op.localityHint() << ">, "
2339     << (op.isDataCache() ? "data" : "instr");
2340   p.printOptionalAttrDict(
2341       op.getAttrs(),
2342       /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(),
2343                        op.getIsDataCacheAttrName(), op.getIsWriteAttrName()});
2344   p << " : " << op.getMemRefType();
2345 }
2346 
verify(AffinePrefetchOp op)2347 static LogicalResult verify(AffinePrefetchOp op) {
2348   auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2349   if (mapAttr) {
2350     AffineMap map = mapAttr.getValue();
2351     if (map.getNumResults() != op.getMemRefType().getRank())
2352       return op.emitOpError("affine.prefetch affine map num results must equal"
2353                             " memref rank");
2354     if (map.getNumInputs() + 1 != op.getNumOperands())
2355       return op.emitOpError("too few operands");
2356   } else {
2357     if (op.getNumOperands() != 1)
2358       return op.emitOpError("too few operands");
2359   }
2360 
2361   Region *scope = getAffineScope(op);
2362   for (auto idx : op.getMapOperands()) {
2363     if (!isValidAffineIndexOperand(idx, scope))
2364       return op.emitOpError("index must be a dimension or symbol identifier");
2365   }
2366   return success();
2367 }
2368 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2369 void AffinePrefetchOp::getCanonicalizationPatterns(
2370     OwningRewritePatternList &results, MLIRContext *context) {
2371   // prefetch(memrefcast) -> prefetch
2372   results.insert<SimplifyAffineOp<AffinePrefetchOp>>(context);
2373 }
2374 
fold(ArrayRef<Attribute> cstOperands,SmallVectorImpl<OpFoldResult> & results)2375 LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
2376                                      SmallVectorImpl<OpFoldResult> &results) {
2377   /// prefetch(memrefcast) -> prefetch
2378   return foldMemRefCast(*this);
2379 }
2380 
2381 //===----------------------------------------------------------------------===//
2382 // AffineParallelOp
2383 //===----------------------------------------------------------------------===//
2384 
build(OpBuilder & builder,OperationState & result,ArrayRef<Type> resultTypes,ArrayRef<AtomicRMWKind> reductions,ArrayRef<int64_t> ranges)2385 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2386                              ArrayRef<Type> resultTypes,
2387                              ArrayRef<AtomicRMWKind> reductions,
2388                              ArrayRef<int64_t> ranges) {
2389   SmallVector<AffineExpr, 8> lbExprs(ranges.size(),
2390                                      builder.getAffineConstantExpr(0));
2391   auto lbMap = AffineMap::get(0, 0, lbExprs, builder.getContext());
2392   SmallVector<AffineExpr, 8> ubExprs;
2393   for (int64_t range : ranges)
2394     ubExprs.push_back(builder.getAffineConstantExpr(range));
2395   auto ubMap = AffineMap::get(0, 0, ubExprs, builder.getContext());
2396   build(builder, result, resultTypes, reductions, lbMap, /*lbArgs=*/{}, ubMap,
2397         /*ubArgs=*/{});
2398 }
2399 
build(OpBuilder & builder,OperationState & result,ArrayRef<Type> resultTypes,ArrayRef<AtomicRMWKind> reductions,AffineMap lbMap,ValueRange lbArgs,AffineMap ubMap,ValueRange ubArgs)2400 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2401                              ArrayRef<Type> resultTypes,
2402                              ArrayRef<AtomicRMWKind> reductions,
2403                              AffineMap lbMap, ValueRange lbArgs,
2404                              AffineMap ubMap, ValueRange ubArgs) {
2405   auto numDims = lbMap.getNumResults();
2406   // Verify that the dimensionality of both maps are the same.
2407   assert(numDims == ubMap.getNumResults() &&
2408          "num dims and num results mismatch");
2409   // Make default step sizes of 1.
2410   SmallVector<int64_t, 8> steps(numDims, 1);
2411   build(builder, result, resultTypes, reductions, lbMap, lbArgs, ubMap, ubArgs,
2412         steps);
2413 }
2414 
build(OpBuilder & builder,OperationState & result,ArrayRef<Type> resultTypes,ArrayRef<AtomicRMWKind> reductions,AffineMap lbMap,ValueRange lbArgs,AffineMap ubMap,ValueRange ubArgs,ArrayRef<int64_t> steps)2415 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2416                              ArrayRef<Type> resultTypes,
2417                              ArrayRef<AtomicRMWKind> reductions,
2418                              AffineMap lbMap, ValueRange lbArgs,
2419                              AffineMap ubMap, ValueRange ubArgs,
2420                              ArrayRef<int64_t> steps) {
2421   auto numDims = lbMap.getNumResults();
2422   // Verify that the dimensionality of the maps matches the number of steps.
2423   assert(numDims == ubMap.getNumResults() &&
2424          "num dims and num results mismatch");
2425   assert(numDims == steps.size() && "num dims and num steps mismatch");
2426 
2427   result.addTypes(resultTypes);
2428   // Convert the reductions to integer attributes.
2429   SmallVector<Attribute, 4> reductionAttrs;
2430   for (AtomicRMWKind reduction : reductions)
2431     reductionAttrs.push_back(
2432         builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
2433   result.addAttribute(getReductionsAttrName(),
2434                       builder.getArrayAttr(reductionAttrs));
2435   result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap));
2436   result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap));
2437   result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps));
2438   result.addOperands(lbArgs);
2439   result.addOperands(ubArgs);
2440   // Create a region and a block for the body.
2441   auto bodyRegion = result.addRegion();
2442   auto body = new Block();
2443   // Add all the block arguments.
2444   for (unsigned i = 0; i < numDims; ++i)
2445     body->addArgument(IndexType::get(builder.getContext()));
2446   bodyRegion->push_back(body);
2447   if (resultTypes.empty())
2448     ensureTerminator(*bodyRegion, builder, result.location);
2449 }
2450 
getLoopBody()2451 Region &AffineParallelOp::getLoopBody() { return region(); }
2452 
isDefinedOutsideOfLoop(Value value)2453 bool AffineParallelOp::isDefinedOutsideOfLoop(Value value) {
2454   return !region().isAncestor(value.getParentRegion());
2455 }
2456 
moveOutOfLoop(ArrayRef<Operation * > ops)2457 LogicalResult AffineParallelOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
2458   for (Operation *op : ops)
2459     op->moveBefore(*this);
2460   return success();
2461 }
2462 
getNumDims()2463 unsigned AffineParallelOp::getNumDims() { return steps().size(); }
2464 
getLowerBoundsOperands()2465 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
2466   return getOperands().take_front(lowerBoundsMap().getNumInputs());
2467 }
2468 
getUpperBoundsOperands()2469 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
2470   return getOperands().drop_front(lowerBoundsMap().getNumInputs());
2471 }
2472 
getLowerBoundsValueMap()2473 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
2474   return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands());
2475 }
2476 
getUpperBoundsValueMap()2477 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
2478   return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands());
2479 }
2480 
getRangesValueMap()2481 AffineValueMap AffineParallelOp::getRangesValueMap() {
2482   AffineValueMap out;
2483   AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
2484                              &out);
2485   return out;
2486 }
2487 
getConstantRanges()2488 Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
2489   // Try to convert all the ranges to constant expressions.
2490   SmallVector<int64_t, 8> out;
2491   AffineValueMap rangesValueMap = getRangesValueMap();
2492   out.reserve(rangesValueMap.getNumResults());
2493   for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
2494     auto expr = rangesValueMap.getResult(i);
2495     auto cst = expr.dyn_cast<AffineConstantExpr>();
2496     if (!cst)
2497       return llvm::None;
2498     out.push_back(cst.getValue());
2499   }
2500   return out;
2501 }
2502 
getBody()2503 Block *AffineParallelOp::getBody() { return &region().front(); }
2504 
getBodyBuilder()2505 OpBuilder AffineParallelOp::getBodyBuilder() {
2506   return OpBuilder(getBody(), std::prev(getBody()->end()));
2507 }
2508 
setSteps(ArrayRef<int64_t> newSteps)2509 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
2510   assert(newSteps.size() == getNumDims() && "steps & num dims mismatch");
2511   setAttr(getStepsAttrName(), getBodyBuilder().getI64ArrayAttr(newSteps));
2512 }
2513 
verify(AffineParallelOp op)2514 static LogicalResult verify(AffineParallelOp op) {
2515   auto numDims = op.getNumDims();
2516   if (op.lowerBoundsMap().getNumResults() != numDims ||
2517       op.upperBoundsMap().getNumResults() != numDims ||
2518       op.steps().size() != numDims ||
2519       op.getBody()->getNumArguments() != numDims)
2520     return op.emitOpError("region argument count and num results of upper "
2521                           "bounds, lower bounds, and steps must all match");
2522 
2523   if (op.reductions().size() != op.getNumResults())
2524     return op.emitOpError("a reduction must be specified for each output");
2525 
2526   // Verify reduction  ops are all valid
2527   for (Attribute attr : op.reductions()) {
2528     auto intAttr = attr.dyn_cast<IntegerAttr>();
2529     if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt()))
2530       return op.emitOpError("invalid reduction attribute");
2531   }
2532 
2533   // Verify that the bound operands are valid dimension/symbols.
2534   /// Lower bounds.
2535   if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(),
2536                                            op.lowerBoundsMap().getNumDims())))
2537     return failure();
2538   /// Upper bounds.
2539   if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(),
2540                                            op.upperBoundsMap().getNumDims())))
2541     return failure();
2542   return success();
2543 }
2544 
print(OpAsmPrinter & p,AffineParallelOp op)2545 static void print(OpAsmPrinter &p, AffineParallelOp op) {
2546   p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (";
2547   p.printAffineMapOfSSAIds(op.lowerBoundsMapAttr(),
2548                            op.getLowerBoundsOperands());
2549   p << ") to (";
2550   p.printAffineMapOfSSAIds(op.upperBoundsMapAttr(),
2551                            op.getUpperBoundsOperands());
2552   p << ')';
2553   SmallVector<int64_t, 4> steps;
2554   bool elideSteps = true;
2555   for (auto attr : op.steps()) {
2556     auto step = attr.cast<IntegerAttr>().getInt();
2557     elideSteps &= (step == 1);
2558     steps.push_back(step);
2559   }
2560   if (!elideSteps) {
2561     p << " step (";
2562     llvm::interleaveComma(steps, p);
2563     p << ')';
2564   }
2565   if (op.getNumResults()) {
2566     p << " reduce (";
2567     llvm::interleaveComma(op.reductions(), p, [&](auto &attr) {
2568       AtomicRMWKind sym =
2569           *symbolizeAtomicRMWKind(attr.template cast<IntegerAttr>().getInt());
2570       p << "\"" << stringifyAtomicRMWKind(sym) << "\"";
2571     });
2572     p << ") -> (" << op.getResultTypes() << ")";
2573   }
2574 
2575   p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
2576                 /*printBlockTerminators=*/op.getNumResults());
2577   p.printOptionalAttrDict(
2578       op.getAttrs(),
2579       /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(),
2580                        AffineParallelOp::getLowerBoundsMapAttrName(),
2581                        AffineParallelOp::getUpperBoundsMapAttrName(),
2582                        AffineParallelOp::getStepsAttrName()});
2583 }
2584 
2585 //
2586 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` `(` map-of-ssa-ids `)`
2587 //               `to` `(` map-of-ssa-ids `)` steps? region attr-dict?
2588 // steps     ::= `steps` `(` integer-literals `)`
2589 //
parseAffineParallelOp(OpAsmParser & parser,OperationState & result)2590 static ParseResult parseAffineParallelOp(OpAsmParser &parser,
2591                                          OperationState &result) {
2592   auto &builder = parser.getBuilder();
2593   auto indexType = builder.getIndexType();
2594   AffineMapAttr lowerBoundsAttr, upperBoundsAttr;
2595   SmallVector<OpAsmParser::OperandType, 4> ivs;
2596   SmallVector<OpAsmParser::OperandType, 4> lowerBoundsMapOperands;
2597   SmallVector<OpAsmParser::OperandType, 4> upperBoundsMapOperands;
2598   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
2599                                      OpAsmParser::Delimiter::Paren) ||
2600       parser.parseEqual() ||
2601       parser.parseAffineMapOfSSAIds(
2602           lowerBoundsMapOperands, lowerBoundsAttr,
2603           AffineParallelOp::getLowerBoundsMapAttrName(), result.attributes,
2604           OpAsmParser::Delimiter::Paren) ||
2605       parser.resolveOperands(lowerBoundsMapOperands, indexType,
2606                              result.operands) ||
2607       parser.parseKeyword("to") ||
2608       parser.parseAffineMapOfSSAIds(
2609           upperBoundsMapOperands, upperBoundsAttr,
2610           AffineParallelOp::getUpperBoundsMapAttrName(), result.attributes,
2611           OpAsmParser::Delimiter::Paren) ||
2612       parser.resolveOperands(upperBoundsMapOperands, indexType,
2613                              result.operands))
2614     return failure();
2615 
2616   AffineMapAttr stepsMapAttr;
2617   NamedAttrList stepsAttrs;
2618   SmallVector<OpAsmParser::OperandType, 4> stepsMapOperands;
2619   if (failed(parser.parseOptionalKeyword("step"))) {
2620     SmallVector<int64_t, 4> steps(ivs.size(), 1);
2621     result.addAttribute(AffineParallelOp::getStepsAttrName(),
2622                         builder.getI64ArrayAttr(steps));
2623   } else {
2624     if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
2625                                       AffineParallelOp::getStepsAttrName(),
2626                                       stepsAttrs,
2627                                       OpAsmParser::Delimiter::Paren))
2628       return failure();
2629 
2630     // Convert steps from an AffineMap into an I64ArrayAttr.
2631     SmallVector<int64_t, 4> steps;
2632     auto stepsMap = stepsMapAttr.getValue();
2633     for (const auto &result : stepsMap.getResults()) {
2634       auto constExpr = result.dyn_cast<AffineConstantExpr>();
2635       if (!constExpr)
2636         return parser.emitError(parser.getNameLoc(),
2637                                 "steps must be constant integers");
2638       steps.push_back(constExpr.getValue());
2639     }
2640     result.addAttribute(AffineParallelOp::getStepsAttrName(),
2641                         builder.getI64ArrayAttr(steps));
2642   }
2643 
2644   // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
2645   // quoted strings a member of the enum AtomicRMWKind.
2646   SmallVector<Attribute, 4> reductions;
2647   if (succeeded(parser.parseOptionalKeyword("reduce"))) {
2648     if (parser.parseLParen())
2649       return failure();
2650     do {
      // Parse a single quoted string via the attribute parsing, and then
      // verify it is a member of the enum and convert to its integer
      // representation.
2654       StringAttr attrVal;
2655       NamedAttrList attrStorage;
2656       auto loc = parser.getCurrentLocation();
2657       if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
2658                                 attrStorage))
2659         return failure();
2660       llvm::Optional<AtomicRMWKind> reduction =
2661           symbolizeAtomicRMWKind(attrVal.getValue());
2662       if (!reduction)
2663         return parser.emitError(loc, "invalid reduction value: ") << attrVal;
2664       reductions.push_back(builder.getI64IntegerAttr(
2665           static_cast<int64_t>(reduction.getValue())));
2666       // While we keep getting commas, keep parsing.
2667     } while (succeeded(parser.parseOptionalComma()));
2668     if (parser.parseRParen())
2669       return failure();
2670   }
2671   result.addAttribute(AffineParallelOp::getReductionsAttrName(),
2672                       builder.getArrayAttr(reductions));
2673 
2674   // Parse return types of reductions (if any)
2675   if (parser.parseOptionalArrowTypeList(result.types))
2676     return failure();
2677 
2678   // Now parse the body.
2679   Region *body = result.addRegion();
2680   SmallVector<Type, 4> types(ivs.size(), indexType);
2681   if (parser.parseRegion(*body, ivs, types) ||
2682       parser.parseOptionalAttrDict(result.attributes))
2683     return failure();
2684 
2685   // Add a terminator if none was parsed.
2686   AffineParallelOp::ensureTerminator(*body, builder, result.location);
2687   return success();
2688 }
2689 
2690 //===----------------------------------------------------------------------===//
2691 // AffineYieldOp
2692 //===----------------------------------------------------------------------===//
2693 
verify(AffineYieldOp op)2694 static LogicalResult verify(AffineYieldOp op) {
2695   auto parentOp = op.getParentOp();
2696   auto results = parentOp->getResults();
2697   auto operands = op.getOperands();
2698 
2699   if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
2700     return op.emitOpError()
2701            << "affine.terminate only terminates If, For or Parallel regions";
2702   if (parentOp->getNumResults() != op.getNumOperands())
2703     return op.emitOpError() << "parent of yield must have same number of "
2704                                "results as the yield operands";
2705   for (auto it : llvm::zip(results, operands)) {
2706     if (std::get<0>(it).getType() != std::get<1>(it).getType())
2707       return op.emitOpError()
2708              << "types mismatch between yield op and its parent";
2709   }
2710 
2711   return success();
2712 }
2713 
2714 //===----------------------------------------------------------------------===//
2715 // AffineVectorLoadOp
2716 //===----------------------------------------------------------------------===//
2717 
parseAffineVectorLoadOp(OpAsmParser & parser,OperationState & result)2718 static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser,
2719                                            OperationState &result) {
2720   auto &builder = parser.getBuilder();
2721   auto indexTy = builder.getIndexType();
2722 
2723   MemRefType memrefType;
2724   VectorType resultType;
2725   OpAsmParser::OperandType memrefInfo;
2726   AffineMapAttr mapAttr;
2727   SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2728   return failure(
2729       parser.parseOperand(memrefInfo) ||
2730       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2731                                     AffineVectorLoadOp::getMapAttrName(),
2732                                     result.attributes) ||
2733       parser.parseOptionalAttrDict(result.attributes) ||
2734       parser.parseColonType(memrefType) || parser.parseComma() ||
2735       parser.parseType(resultType) ||
2736       parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
2737       parser.resolveOperands(mapOperands, indexTy, result.operands) ||
2738       parser.addTypeToList(resultType, result.types));
2739 }
2740 
print(OpAsmPrinter & p,AffineVectorLoadOp op)2741 static void print(OpAsmPrinter &p, AffineVectorLoadOp op) {
2742   p << "affine.vector_load " << op.getMemRef() << '[';
2743   if (AffineMapAttr mapAttr =
2744           op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2745     p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2746   p << ']';
2747   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
2748   p << " : " << op.getMemRefType() << ", " << op.getType();
2749 }
2750 
2751 /// Verify common invariants of affine.vector_load and affine.vector_store.
verifyVectorMemoryOp(Operation * op,MemRefType memrefType,VectorType vectorType)2752 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
2753                                           VectorType vectorType) {
2754   // Check that memref and vector element types match.
2755   if (memrefType.getElementType() != vectorType.getElementType())
2756     return op->emitOpError(
2757         "requires memref and vector types of the same elemental type");
2758   return success();
2759 }
2760 
verify(AffineVectorLoadOp op)2761 static LogicalResult verify(AffineVectorLoadOp op) {
2762   MemRefType memrefType = op.getMemRefType();
2763   if (failed(verifyMemoryOpIndexing(
2764           op.getOperation(),
2765           op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2766           op.getMapOperands(), memrefType,
2767           /*numIndexOperands=*/op.getNumOperands() - 1)))
2768     return failure();
2769 
2770   if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
2771                                   op.getVectorType())))
2772     return failure();
2773 
2774   return success();
2775 }
2776 
2777 //===----------------------------------------------------------------------===//
2778 // AffineVectorStoreOp
2779 //===----------------------------------------------------------------------===//
2780 
parseAffineVectorStoreOp(OpAsmParser & parser,OperationState & result)2781 static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser,
2782                                             OperationState &result) {
2783   auto indexTy = parser.getBuilder().getIndexType();
2784 
2785   MemRefType memrefType;
2786   VectorType resultType;
2787   OpAsmParser::OperandType storeValueInfo;
2788   OpAsmParser::OperandType memrefInfo;
2789   AffineMapAttr mapAttr;
2790   SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2791   return failure(
2792       parser.parseOperand(storeValueInfo) || parser.parseComma() ||
2793       parser.parseOperand(memrefInfo) ||
2794       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2795                                     AffineVectorStoreOp::getMapAttrName(),
2796                                     result.attributes) ||
2797       parser.parseOptionalAttrDict(result.attributes) ||
2798       parser.parseColonType(memrefType) || parser.parseComma() ||
2799       parser.parseType(resultType) ||
2800       parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
2801       parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
2802       parser.resolveOperands(mapOperands, indexTy, result.operands));
2803 }
2804 
print(OpAsmPrinter & p,AffineVectorStoreOp op)2805 static void print(OpAsmPrinter &p, AffineVectorStoreOp op) {
2806   p << "affine.vector_store " << op.getValueToStore();
2807   p << ", " << op.getMemRef() << '[';
2808   if (AffineMapAttr mapAttr =
2809           op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2810     p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2811   p << ']';
2812   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
2813   p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType();
2814 }
2815 
verify(AffineVectorStoreOp op)2816 static LogicalResult verify(AffineVectorStoreOp op) {
2817   MemRefType memrefType = op.getMemRefType();
2818   if (failed(verifyMemoryOpIndexing(
2819           op.getOperation(),
2820           op.getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2821           op.getMapOperands(), memrefType,
2822           /*numIndexOperands=*/op.getNumOperands() - 2)))
2823     return failure();
2824 
2825   if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
2826                                   op.getVectorType())))
2827     return failure();
2828 
2829   return success();
2830 }
2831 
2832 //===----------------------------------------------------------------------===//
2833 // TableGen'd op method definitions
2834 //===----------------------------------------------------------------------===//
2835 
2836 #define GET_OP_CLASSES
2837 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
2838