1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/AffineOps/AffineOps.h"
10 #include "mlir/Dialect/StandardOps/Ops.h"
11 #include "mlir/IR/Function.h"
12 #include "mlir/IR/IntegerSet.h"
13 #include "mlir/IR/Matchers.h"
14 #include "mlir/IR/OpImplementation.h"
15 #include "mlir/IR/PatternMatch.h"
16 #include "mlir/Transforms/InliningUtils.h"
17 #include "mlir/Transforms/SideEffectsInterface.h"
18 #include "llvm/ADT/SetVector.h"
19 #include "llvm/ADT/SmallBitVector.h"
20 #include "llvm/Support/Debug.h"
21 
22 using namespace mlir;
23 using llvm::dbgs;
24 
25 #define DEBUG_TYPE "affine-analysis"
26 
27 //===----------------------------------------------------------------------===//
28 // AffineOpsDialect Interfaces
29 //===----------------------------------------------------------------------===//
30 
31 namespace {
32 /// This class defines the interface for handling inlining with affine
33 /// operations.
34 struct AffineInlinerInterface : public DialectInlinerInterface {
35   using DialectInlinerInterface::DialectInlinerInterface;
36 
37   //===--------------------------------------------------------------------===//
38   // Analysis Hooks
39   //===--------------------------------------------------------------------===//
40 
41   /// Returns true if the given region 'src' can be inlined into the region
42   /// 'dest' that is attached to an operation registered to the current dialect.
isLegalToInline__anon86bd9b9c0111::AffineInlinerInterface43   bool isLegalToInline(Region *dest, Region *src,
44                        BlockAndValueMapping &valueMapping) const final {
45     // Conservatively don't allow inlining into affine structures.
46     return false;
47   }
48 
49   /// Returns true if the given operation 'op', that is registered to this
50   /// dialect, can be inlined into the given region, false otherwise.
isLegalToInline__anon86bd9b9c0111::AffineInlinerInterface51   bool isLegalToInline(Operation *op, Region *region,
52                        BlockAndValueMapping &valueMapping) const final {
53     // Always allow inlining affine operations into the top-level region of a
54     // function. There are some edge cases when inlining *into* affine
55     // structures, but that is handled in the other 'isLegalToInline' hook
56     // above.
57     // TODO: We should be able to inline into other regions than functions.
58     return isa<FuncOp>(region->getParentOp());
59   }
60 
61   /// Affine regions should be analyzed recursively.
shouldAnalyzeRecursively__anon86bd9b9c0111::AffineInlinerInterface62   bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
63 };
64 
65 // TODO(mlir): Extend for other ops in this dialect.
66 struct AffineSideEffectsInterface : public SideEffectsDialectInterface {
67   using SideEffectsDialectInterface::SideEffectsDialectInterface;
68 
isSideEffecting__anon86bd9b9c0111::AffineSideEffectsInterface69   SideEffecting isSideEffecting(Operation *op) const override {
70     if (isa<AffineIfOp>(op)) {
71       return Recursive;
72     }
73     return SideEffectsDialectInterface::isSideEffecting(op);
74   };
75 };
76 
77 } // end anonymous namespace
78 
79 //===----------------------------------------------------------------------===//
80 // AffineOpsDialect
81 //===----------------------------------------------------------------------===//
82 
/// Constructs the affine dialect: registers the hand-written operations, the
/// TableGen-generated operation list (via GET_OP_LIST), and the dialect
/// interfaces used for inlining and side-effect analysis.
AffineOpsDialect::AffineOpsDialect(MLIRContext *context)
    : Dialect(getDialectNamespace(), context) {
  addOperations<AffineApplyOp, AffineDmaStartOp, AffineDmaWaitOp, AffineLoadOp,
                AffineStoreOp,
#define GET_OP_LIST
#include "mlir/Dialect/AffineOps/AffineOps.cpp.inc"
                >();
  addInterfaces<AffineInlinerInterface, AffineSideEffectsInterface>();
}
92 
/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type. Used by the folding framework to turn folded
/// attributes back into SSA values; affine relies on the standard ConstantOp.
Operation *AffineOpsDialect::materializeConstant(OpBuilder &builder,
                                                 Attribute value, Type type,
                                                 Location loc) {
  return builder.create<ConstantOp>(loc, type, value);
}
100 
101 /// A utility function to check if a given region is attached to a function.
isFunctionRegion(Region * region)102 static bool isFunctionRegion(Region *region) {
103   return llvm::isa<FuncOp>(region->getParentOp());
104 }
105 
106 /// A utility function to check if a value is defined at the top level of a
107 /// function. A value of index type defined at the top level is always a valid
108 /// symbol.
isTopLevelValue(Value value)109 bool mlir::isTopLevelValue(Value value) {
110   if (auto arg = value.dyn_cast<BlockArgument>())
111     return isFunctionRegion(arg.getOwner()->getParent());
112   return isFunctionRegion(value.getDefiningOp()->getParentRegion());
113 }
114 
115 // Value can be used as a dimension id if it is valid as a symbol, or
116 // it is an induction variable, or it is a result of affine apply operation
117 // with dimension id arguments.
isValidDim(Value value)118 bool mlir::isValidDim(Value value) {
119   // The value must be an index type.
120   if (!value.getType().isIndex())
121     return false;
122 
123   if (auto *op = value.getDefiningOp()) {
124     // Top level operation or constant operation is ok.
125     if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
126       return true;
127     // Affine apply operation is ok if all of its operands are ok.
128     if (auto applyOp = dyn_cast<AffineApplyOp>(op))
129       return applyOp.isValidDim();
130     // The dim op is okay if its operand memref/tensor is defined at the top
131     // level.
132     if (auto dimOp = dyn_cast<DimOp>(op))
133       return isTopLevelValue(dimOp.getOperand());
134     return false;
135   }
136   // This value has to be a block argument for a FuncOp or an affine.for.
137   auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
138   return isa<FuncOp>(parentOp) || isa<AffineForOp>(parentOp);
139 }
140 
/// Returns true if the 'index' dimension of the `memref` defined by
/// `memrefDefOp` is a statically shaped one or defined using a valid symbol.
/// `AnyMemRefDefOp` is any memref-producing op (e.g. alloc/view/subview) that
/// exposes a shaped result type and a `getDynamicSizes()` operand range.
template <typename AnyMemRefDefOp>
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp,
                                    unsigned index) {
  auto memRefType = memrefDefOp.getType();
  // Statically shaped.
  if (!ShapedType::isDynamic(memRefType.getDimSize(index)))
    return true;
  // Get the position of the dimension among dynamic dimensions; the dynamic
  // size operands of the defining op are ordered by that position.
  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
  return isValidSymbol(
      *(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos));
}
155 
156 /// Returns true if the result of the dim op is a valid symbol.
isDimOpValidSymbol(DimOp dimOp)157 static bool isDimOpValidSymbol(DimOp dimOp) {
158   // The dim op is okay if its operand memref/tensor is defined at the top
159   // level.
160   if (isTopLevelValue(dimOp.getOperand()))
161     return true;
162 
163   // The dim op is also okay if its operand memref/tensor is a view/subview
164   // whose corresponding size is a valid symbol.
165   unsigned index = dimOp.getIndex();
166   if (auto viewOp = dyn_cast<ViewOp>(dimOp.getOperand().getDefiningOp()))
167     return isMemRefSizeValidSymbol<ViewOp>(viewOp, index);
168   if (auto subViewOp = dyn_cast<SubViewOp>(dimOp.getOperand().getDefiningOp()))
169     return isMemRefSizeValidSymbol<SubViewOp>(subViewOp, index);
170   if (auto allocOp = dyn_cast<AllocOp>(dimOp.getOperand().getDefiningOp()))
171     return isMemRefSizeValidSymbol<AllocOp>(allocOp, index);
172   return false;
173 }
174 
// Value can be used as a symbol if it is a constant, or it is defined at
// the top level, or it is a result of affine apply operation with symbol
// arguments, or a result of the dim op on a memref satisfying certain
// constraints.
bool mlir::isValidSymbol(Value value) {
  // The value must be an index type.
  if (!value.getType().isIndex())
    return false;

  if (auto *op = value.getDefiningOp()) {
    // Top level operation or constant operation is ok.
    if (isFunctionRegion(op->getParentRegion()) || isa<ConstantOp>(op))
      return true;
    // Affine apply operation is ok if all of its operands are ok.
    if (auto applyOp = dyn_cast<AffineApplyOp>(op))
      return applyOp.isValidSymbol();
    // A dim op result may be a symbol when the inspected memref size is
    // itself static or symbolic (see isDimOpValidSymbol).
    if (auto dimOp = dyn_cast<DimOp>(op)) {
      return isDimOpValidSymbol(dimOp);
    }
  }
  // Otherwise, check that the value is a top level value.
  return isTopLevelValue(value);
}
198 
199 // Returns true if 'value' is a valid index to an affine operation (e.g.
200 // affine.load, affine.store, affine.dma_start, affine.dma_wait).
201 // Returns false otherwise.
isValidAffineIndexOperand(Value value)202 static bool isValidAffineIndexOperand(Value value) {
203   return isValidDim(value) || isValidSymbol(value);
204 }
205 
206 /// Utility function to verify that a set of operands are valid dimension and
207 /// symbol identifiers. The operands should be laid out such that the dimension
208 /// operands are before the symbol operands. This function returns failure if
209 /// there was an invalid operand. An operation is provided to emit any necessary
210 /// errors.
211 template <typename OpTy>
212 static LogicalResult
verifyDimAndSymbolIdentifiers(OpTy & op,Operation::operand_range operands,unsigned numDims)213 verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
214                               unsigned numDims) {
215   unsigned opIt = 0;
216   for (auto operand : operands) {
217     if (opIt++ < numDims) {
218       if (!isValidDim(operand))
219         return op.emitOpError("operand cannot be used as a dimension id");
220     } else if (!isValidSymbol(operand)) {
221       return op.emitOpError("operand cannot be used as a symbol");
222     }
223   }
224   return success();
225 }
226 
227 //===----------------------------------------------------------------------===//
228 // AffineApplyOp
229 //===----------------------------------------------------------------------===//
230 
build(Builder * builder,OperationState & result,AffineMap map,ValueRange operands)231 void AffineApplyOp::build(Builder *builder, OperationState &result,
232                           AffineMap map, ValueRange operands) {
233   result.addOperands(operands);
234   result.types.append(map.getNumResults(), builder->getIndexType());
235   result.addAttribute("map", AffineMapAttr::get(map));
236 }
237 
/// Parses an affine.apply op of the form:
///   affine.apply affine-map dim-and-symbol-use-list attr-dict?
/// and checks that the operand counts line up with the map signature.
ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  AffineMapAttr mapAttr;
  unsigned numDims;
  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  auto map = mapAttr.getValue();

  // The operand list must be exactly `numDims` dimension operands followed by
  // the map's symbol operands.
  if (map.getNumDims() != numDims ||
      numDims + map.getNumSymbols() != result.operands.size()) {
    return parser.emitError(parser.getNameLoc(),
                            "dimension or symbol index mismatch");
  }

  // One index-typed result per map result.
  result.types.append(map.getNumResults(), indexTy);
  return success();
}
259 
/// Prints the affine.apply op: the map attribute, the operands split into
/// dim and symbol lists, then any remaining attributes ("map" is elided
/// since it was printed inline).
void AffineApplyOp::print(OpAsmPrinter &p) {
  p << "affine.apply " << getAttr("map");
  printDimAndSymbolList(operand_begin(), operand_end(),
                        getAffineMap().getNumDims(), p);
  p.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{"map"});
}
266 
verify()267 LogicalResult AffineApplyOp::verify() {
268   // Check that affine map attribute was specified.
269   auto affineMapAttr = getAttrOfType<AffineMapAttr>("map");
270   if (!affineMapAttr)
271     return emitOpError("requires an affine map");
272 
273   // Check input and output dimensions match.
274   auto map = affineMapAttr.getValue();
275 
276   // Verify that operand count matches affine map dimension and symbol count.
277   if (getNumOperands() != map.getNumDims() + map.getNumSymbols())
278     return emitOpError(
279         "operand count and affine map dimension and symbol count must match");
280 
281   // Verify that all operands are of `index` type.
282   for (Type t : getOperandTypes()) {
283     if (!t.isIndex())
284       return emitOpError("operands must be of type 'index'");
285   }
286 
287   if (!getResult().getType().isIndex())
288     return emitOpError("result must be of type 'index'");
289 
290   // Verify that the map only produces one result.
291   if (map.getNumResults() != 1)
292     return emitOpError("mapping must produce one value");
293 
294   return success();
295 }
296 
297 // The result of the affine apply operation can be used as a dimension id if all
298 // its operands are valid dimension ids.
isValidDim()299 bool AffineApplyOp::isValidDim() {
300   return llvm::all_of(getOperands(),
301                       [](Value op) { return mlir::isValidDim(op); });
302 }
303 
304 // The result of the affine apply operation can be used as a symbol if all its
305 // operands are symbols.
isValidSymbol()306 bool AffineApplyOp::isValidSymbol() {
307   return llvm::all_of(getOperands(),
308                       [](Value op) { return mlir::isValidSymbol(op); });
309 }
310 
/// Folds affine.apply: if the single result expression is a bare dim or
/// symbol, forwards the corresponding operand; otherwise attempts constant
/// folding of the whole map over constant operands.
OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
  auto map = getAffineMap();

  // Fold dims and symbols to existing values.
  auto expr = map.getResult(0);
  if (auto dim = expr.dyn_cast<AffineDimExpr>())
    return getOperand(dim.getPosition());
  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
    // Symbol operands trail the dimension operands in the operand list.
    return getOperand(map.getNumDims() + sym.getPosition());

  // Otherwise, default to folding the map.
  SmallVector<Attribute, 1> result;
  if (failed(map.constantFold(operands, result)))
    return {};
  return result[0];
}
327 
renumberOneDim(Value v)328 AffineDimExpr AffineApplyNormalizer::renumberOneDim(Value v) {
329   DenseMap<Value, unsigned>::iterator iterPos;
330   bool inserted = false;
331   std::tie(iterPos, inserted) =
332       dimValueToPosition.insert(std::make_pair(v, dimValueToPosition.size()));
333   if (inserted) {
334     reorderedDims.push_back(v);
335   }
336   return getAffineDimExpr(iterPos->second, v.getContext())
337       .cast<AffineDimExpr>();
338 }
339 
/// Rewrites `other`'s affine map into this normalizer's coordinate space:
/// each of `other`'s dims is renumbered through `renumberOneDim`, and
/// `other`'s symbols are appended after the symbols already concatenated
/// here.
AffineMap AffineApplyNormalizer::renumber(const AffineApplyNormalizer &other) {
  SmallVector<AffineExpr, 8> dimRemapping;
  for (auto v : other.reorderedDims) {
    auto kvp = other.dimValueToPosition.find(v);
    // Grow the remapping table lazily to the highest position seen.
    if (dimRemapping.size() <= kvp->second)
      dimRemapping.resize(kvp->second + 1);
    dimRemapping[kvp->second] = renumberOneDim(kvp->first);
  }
  // Shift `other`'s symbol positions past the symbols already collected here.
  unsigned numSymbols = concatenatedSymbols.size();
  unsigned numOtherSymbols = other.concatenatedSymbols.size();
  SmallVector<AffineExpr, 8> symRemapping(numOtherSymbols);
  for (unsigned idx = 0; idx < numOtherSymbols; ++idx) {
    symRemapping[idx] =
        getAffineSymbolExpr(idx + numSymbols, other.affineMap.getContext());
  }
  concatenatedSymbols.insert(concatenatedSymbols.end(),
                             other.concatenatedSymbols.begin(),
                             other.concatenatedSymbols.end());
  auto map = other.affineMap;
  return map.replaceDimsAndSymbols(dimRemapping, symRemapping,
                                   reorderedDims.size(),
                                   concatenatedSymbols.size());
}
363 
364 // Gather the positions of the operands that are produced by an AffineApplyOp.
365 static llvm::SetVector<unsigned>
indicesFromAffineApplyOp(ArrayRef<Value> operands)366 indicesFromAffineApplyOp(ArrayRef<Value> operands) {
367   llvm::SetVector<unsigned> res;
368   for (auto en : llvm::enumerate(operands))
369     if (isa_and_nonnull<AffineApplyOp>(en.value().getDefiningOp()))
370       res.insert(en.index());
371   return res;
372 }
373 
// Support the special case of a symbol coming from an AffineApplyOp that needs
// to be composed into the current AffineApplyOp.
// This case is handled by rewriting all such symbols into dims for the purpose
// of allowing mathematical AffineMap composition.
// Returns an AffineMap where symbols that come from an AffineApplyOp have been
// rewritten as dims and are ordered after the original dims.
// TODO(andydavis,ntv): This promotion makes AffineMap lose track of which
// symbols are represented as dims. This loss is static but can still be
// recovered dynamically (with `isValidSymbol`). Still this is annoying for the
// semi-affine map case. A dynamic canonicalization of all dims that are valid
// symbols (a.k.a `canonicalizePromotedSymbols`) into symbols helps and even
// results in better simplifications and foldings. But we should evaluate
// whether this behavior is what we really want after using more.
static AffineMap promoteComposedSymbolsAsDims(AffineMap map,
                                              ArrayRef<Value> symbols) {
  if (symbols.empty()) {
    return map;
  }

  // Sanity check on symbols.
  for (auto sym : symbols) {
    assert(isValidSymbol(sym) && "Expected only valid symbols");
    (void)sym;
  }

  // Extract the symbol positions that come from an AffineApplyOp and
  // needs to be rewritten as dims.
  auto symPositions = indicesFromAffineApplyOp(symbols);
  if (symPositions.empty()) {
    return map;
  }

  // Create the new map by replacing each symbol at pos by the next new dim.
  // Promoted symbols become dims numbered after the existing dims; remaining
  // symbols are renumbered compactly from zero.
  unsigned numDims = map.getNumDims();
  unsigned numSymbols = map.getNumSymbols();
  unsigned numNewDims = 0;
  unsigned numNewSymbols = 0;
  SmallVector<AffineExpr, 8> symReplacements(numSymbols);
  for (unsigned i = 0; i < numSymbols; ++i) {
    symReplacements[i] =
        symPositions.count(i) > 0
            ? getAffineDimExpr(numDims + numNewDims++, map.getContext())
            : getAffineSymbolExpr(numNewSymbols++, map.getContext());
  }
  assert(numSymbols >= numNewDims);
  AffineMap newMap = map.replaceDimsAndSymbols(
      {}, symReplacements, numDims + numNewDims, numNewSymbols);

  return newMap;
}
424 
/// The AffineNormalizer composes AffineApplyOp recursively. Its purpose is to
/// keep a correspondence between the mathematical `map` and the `operands` of
/// a given AffineApplyOp. This correspondence is maintained by iterating over
/// the operands and forming an `auxiliaryMap` that can be composed
/// mathematically with `map`. To keep this correspondence in cases where
/// symbols are produced by affine.apply operations, we perform a local rewrite
/// of symbols as dims.
///
/// Rationale for locally rewriting symbols as dims:
/// ================================================
/// The mathematical composition of AffineMap must always concatenate symbols
/// because it does not have enough information to do otherwise. For example,
/// composing `(d0)[s0] -> (d0 + s0)` with itself must produce
/// `(d0)[s0, s1] -> (d0 + s0 + s1)`.
///
/// The result is only equivalent to `(d0)[s0] -> (d0 + 2 * s0)` when
/// applied to the same mlir::Value for both s0 and s1.
/// As a consequence mathematical composition of AffineMap always concatenates
/// symbols.
///
/// When AffineMaps are used in AffineApplyOp however, they may specify
/// composition via symbols, which is ambiguous mathematically. This corner case
/// is handled by locally rewriting such symbols that come from AffineApplyOp
/// into dims and composing through dims.
/// TODO(andydavis, ntv): Composition via symbols comes at a significant code
/// complexity. Alternatively we should investigate whether we want to
/// explicitly disallow symbols coming from affine.apply and instead force the
/// user to compose symbols beforehand. The annoyances may be small (i.e. 1 or 2
/// extra API calls for such uses, which haven't popped up until now) and the
/// benefit potentially big: simpler and more maintainable code for a
/// non-trivial, recursive, procedure.
AffineApplyNormalizer::AffineApplyNormalizer(AffineMap map,
                                             ArrayRef<Value> operands)
    : AffineApplyNormalizer() {
  static_assert(kMaxAffineApplyDepth > 0, "kMaxAffineApplyDepth must be > 0");
  assert(map.getNumInputs() == operands.size() &&
         "number of operands does not match the number of map inputs");

  LLVM_DEBUG(map.print(dbgs() << "\nInput map: "));

  // Promote symbols that come from an AffineApplyOp to dims by rewriting the
  // map to always refer to:
  //   (dims, symbols coming from AffineApplyOp, other symbols).
  // The order of operands can remain unchanged.
  // This is a simplification that relies on 2 ordering properties:
  //   1. rewritten symbols always appear after the original dims in the map;
  //   2. operands are traversed in order and either dispatched to:
  //      a. auxiliaryExprs (dims and symbols rewritten as dims);
  //      b. concatenatedSymbols (all other symbols)
  // This allows operand order to remain unchanged.
  unsigned numDimsBeforeRewrite = map.getNumDims();
  map = promoteComposedSymbolsAsDims(map,
                                     operands.take_back(map.getNumSymbols()));

  LLVM_DEBUG(map.print(dbgs() << "\nRewritten map: "));

  SmallVector<AffineExpr, 8> auxiliaryExprs;
  // Recursion is bounded by kMaxAffineApplyDepth; past that depth operands
  // are dispatched without composing through nested affine.apply ops.
  bool furtherCompose = (affineApplyDepth() <= kMaxAffineApplyDepth);
  // We fully spell out the 2 cases below. In this particular instance a little
  // code duplication greatly improves readability.
  // Note that the first branch would disappear if we only supported full
  // composition (i.e. infinite kMaxAffineApplyDepth).
  if (!furtherCompose) {
    // 1. Only dispatch dims or symbols.
    for (auto en : llvm::enumerate(operands)) {
      auto t = en.value();
      assert(t.getType().isIndex());
      bool isDim = (en.index() < map.getNumDims());
      if (isDim) {
        // a. The mathematical composition of AffineMap composes dims.
        auxiliaryExprs.push_back(renumberOneDim(t));
      } else {
        // b. The mathematical composition of AffineMap concatenates symbols.
        //    We do the same for symbol operands.
        concatenatedSymbols.push_back(t);
      }
    }
  } else {
    assert(numDimsBeforeRewrite <= operands.size());
    // 2. Compose AffineApplyOps and dispatch dims or symbols.
    for (unsigned i = 0, e = operands.size(); i < e; ++i) {
      auto t = operands[i];
      auto affineApply = dyn_cast_or_null<AffineApplyOp>(t.getDefiningOp());
      if (affineApply) {
        // a. Compose affine.apply operations.
        LLVM_DEBUG(affineApply.getOperation()->print(
            dbgs() << "\nCompose AffineApplyOp recursively: "));
        AffineMap affineApplyMap = affineApply.getAffineMap();
        SmallVector<Value, 8> affineApplyOperands(
            affineApply.getOperands().begin(), affineApply.getOperands().end());
        // Recursively normalize the producing apply, then renumber its map
        // into this normalizer's dim/symbol space.
        AffineApplyNormalizer normalizer(affineApplyMap, affineApplyOperands);

        LLVM_DEBUG(normalizer.affineMap.print(
            dbgs() << "\nRenumber into current normalizer: "));

        auto renumberedMap = renumber(normalizer);

        LLVM_DEBUG(
            renumberedMap.print(dbgs() << "\nRecursive composition yields: "));

        auxiliaryExprs.push_back(renumberedMap.getResult(0));
      } else {
        if (i < numDimsBeforeRewrite) {
          // b. The mathematical composition of AffineMap composes dims.
          auxiliaryExprs.push_back(renumberOneDim(t));
        } else {
          // c. The mathematical composition of AffineMap concatenates symbols.
          //    We do the same for symbol operands.
          concatenatedSymbols.push_back(t);
        }
      }
    }
  }

  // Early exit if `map` is already composed.
  if (auxiliaryExprs.empty()) {
    affineMap = map;
    return;
  }

  assert(concatenatedSymbols.size() >= map.getNumSymbols() &&
         "Unexpected number of concatenated symbols");
  auto numDims = dimValueToPosition.size();
  auto numSymbols = concatenatedSymbols.size() - map.getNumSymbols();
  auto auxiliaryMap = AffineMap::get(numDims, numSymbols, auxiliaryExprs);

  LLVM_DEBUG(map.print(dbgs() << "\nCompose map: "));
  LLVM_DEBUG(auxiliaryMap.print(dbgs() << "\nWith map: "));
  LLVM_DEBUG(map.compose(auxiliaryMap).print(dbgs() << "\nResult: "));

  // TODO(andydavis,ntv): Disabling simplification results in major speed gains.
  // Another option is to cache the results as it is expected a lot of redundant
  // work is performed in practice.
  affineMap = simplifyAffineMap(map.compose(auxiliaryMap));

  LLVM_DEBUG(affineMap.print(dbgs() << "\nSimplified result: "));
  LLVM_DEBUG(dbgs() << "\n");
}
563 
/// Normalizes `otherMap` and `otherOperands` into this normalizer's space:
/// the map is renumbered through `renumber`, and the operand list is replaced
/// by this normalizer's reordered dims followed by its concatenated symbols.
void AffineApplyNormalizer::normalize(AffineMap *otherMap,
                                      SmallVectorImpl<Value> *otherOperands) {
  AffineApplyNormalizer other(*otherMap, *otherOperands);
  *otherMap = renumber(other);

  otherOperands->reserve(reorderedDims.size() + concatenatedSymbols.size());
  otherOperands->assign(reorderedDims.begin(), reorderedDims.end());
  otherOperands->append(concatenatedSymbols.begin(), concatenatedSymbols.end());
}
573 
574 /// Implements `map` and `operands` composition and simplification to support
575 /// `makeComposedAffineApply`. This can be called to achieve the same effects
576 /// on `map` and `operands` without creating an AffineApplyOp that needs to be
577 /// immediately deleted.
composeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)578 static void composeAffineMapAndOperands(AffineMap *map,
579                                         SmallVectorImpl<Value> *operands) {
580   AffineApplyNormalizer normalizer(*map, *operands);
581   auto normalizedMap = normalizer.getAffineMap();
582   auto normalizedOperands = normalizer.getOperands();
583   canonicalizeMapAndOperands(&normalizedMap, &normalizedOperands);
584   *map = normalizedMap;
585   *operands = normalizedOperands;
586   assert(*map);
587 }
588 
fullyComposeAffineMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)589 void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
590                                             SmallVectorImpl<Value> *operands) {
591   while (llvm::any_of(*operands, [](Value v) {
592     return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
593   })) {
594     composeAffineMapAndOperands(map, operands);
595   }
596 }
597 
makeComposedAffineApply(OpBuilder & b,Location loc,AffineMap map,ArrayRef<Value> operands)598 AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
599                                             AffineMap map,
600                                             ArrayRef<Value> operands) {
601   AffineMap normalizedMap = map;
602   SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end());
603   composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
604   assert(normalizedMap);
605   return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
606 }
607 
// A symbol may appear as a dim in affine.apply operations. This function
// canonicalizes dims that are valid symbols into actual symbols.
template <class MapOrSet>
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
                                        SmallVectorImpl<Value> *operands) {
  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet->getContext();
  // Operands that stay in dim positions, in their original relative order.
  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());
  // Dim operands promoted to symbols; appended after all other operands.
  SmallVector<Value, 8> remappedSymbols;
  remappedSymbols.reserve(operands->size());
  unsigned nextDim = 0;
  unsigned nextSym = 0;
  unsigned oldNumSyms = mapOrSet->getNumSymbols();
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
    if (i < mapOrSet->getNumDims()) {
      if (isValidSymbol((*operands)[i])) {
        // This is a valid symbol that appears as a dim, canonicalize it.
        // Promoted symbols are numbered after the pre-existing symbols.
        dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
        remappedSymbols.push_back((*operands)[i]);
      } else {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
      }
    } else {
      // Symbol positions are passed through untouched.
      resultOperands.push_back((*operands)[i]);
    }
  }

  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
  *operands = resultOperands;
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
                                              oldNumSyms + nextSym);

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");
}
651 
// Works for either an affine map or an integer set.
//
// Canonicalizes `*mapOrSet` and `*operands` in place by:
//  1) promoting operands in dim positions that are valid symbols to symbol
//     positions (via canonicalizePromotedSymbols),
//  2) dropping dims/symbols that appear in no expression,
//  3) remapping duplicate dim/symbol operands to a single position, and
//  4) folding symbol operands defined by constants into constant exprs.
// On exit the map/set input count matches the rewritten operand list.
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
                                            SmallVectorImpl<Value> *operands) {
  static_assert(std::is_same<MapOrSet, AffineMap>::value ||
                    std::is_same<MapOrSet, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type");

  // Nothing to do for a null map/set or an empty operand list.
  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  // Step 1: move symbol-valid operands out of dim positions.
  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
  mapOrSet->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = mapOrSet->getContext();

  // Rebuilt operand list: surviving dim operands first, then symbol operands.
  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());

  // Steps 2-3 for dims: drop unused positions, deduplicate repeated operands.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    if (usedDims[i]) {
      // Remap dim positions for duplicate operands.
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
      } else {
        dimRemapping[i] = it->second;
      }
    }
  }
  // Steps 2-4 for symbols: additionally fold constant-defined operands.
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    if (!usedSyms[i])
      continue;
    // Handle constant operands (only needed for symbolic operands since
    // constant operands in dimensional positions would have already been
    // promoted to symbolic positions above).
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                     m_Constant(&operandCst))) {
      symRemapping[i] =
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
      continue;
    }
    // Remap symbol positions for duplicate operands.
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
                                        symRemapping[i]));
    } else {
      symRemapping[i] = it->second;
    }
  }
  // Rewrite over the compacted dim/symbol spaces and install the pruned
  // operand list.
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
                                              nextDim, nextSym);
  *operands = resultOperands;
}
730 
canonicalizeMapAndOperands(AffineMap * map,SmallVectorImpl<Value> * operands)731 void mlir::canonicalizeMapAndOperands(AffineMap *map,
732                                       SmallVectorImpl<Value> *operands) {
733   canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
734 }
735 
canonicalizeSetAndOperands(IntegerSet * set,SmallVectorImpl<Value> * operands)736 void mlir::canonicalizeSetAndOperands(IntegerSet *set,
737                                       SmallVectorImpl<Value> *operands) {
738   canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
739 }
740 
741 namespace {
742 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
743 /// maps that supply results into them.
744 ///
745 template <typename AffineOpTy>
746 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
747   using OpRewritePattern<AffineOpTy>::OpRewritePattern;
748 
749   /// Replace the affine op with another instance of it with the supplied
750   /// map and mapOperands.
751   void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
752                        AffineMap map, ArrayRef<Value> mapOperands) const;
753 
matchAndRewrite__anon86bd9b9c0611::SimplifyAffineOp754   PatternMatchResult matchAndRewrite(AffineOpTy affineOp,
755                                      PatternRewriter &rewriter) const override {
756     static_assert(std::is_same<AffineOpTy, AffineLoadOp>::value ||
757                       std::is_same<AffineOpTy, AffinePrefetchOp>::value ||
758                       std::is_same<AffineOpTy, AffineStoreOp>::value ||
759                       std::is_same<AffineOpTy, AffineApplyOp>::value,
760                   "affine load/store/apply op expected");
761     auto map = affineOp.getAffineMap();
762     AffineMap oldMap = map;
763     auto oldOperands = affineOp.getMapOperands();
764     SmallVector<Value, 8> resultOperands(oldOperands);
765     composeAffineMapAndOperands(&map, &resultOperands);
766     if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
767                                     resultOperands.begin()))
768       return this->matchFailure();
769 
770     replaceAffineOp(rewriter, affineOp, map, resultOperands);
771     return this->matchSuccess();
772   }
773 };
774 
// Specialize the template to account for the different build signatures for
// affine load, store, and apply ops.
template <>
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  // Rebuild the load with the composed map/operands; the memref is unchanged.
  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
                                            mapOperands);
}
template <>
void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  // Rebuild the prefetch with the composed map/operands, carrying over the
  // locality hint and read/write and cache-kind flags unchanged.
  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
      prefetch, prefetch.memref(), map, mapOperands,
      prefetch.localityHint().getZExtValue(), prefetch.isWrite(),
      prefetch.isDataCache());
}
template <>
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  // Rebuild the store with the composed map/operands; the stored value and
  // memref are unchanged.
  rewriter.replaceOpWithNewOp<AffineStoreOp>(
      store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
}
template <>
void SimplifyAffineOp<AffineApplyOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineApplyOp apply, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  // affine.apply takes only a map and its operands.
  rewriter.replaceOpWithNewOp<AffineApplyOp>(apply, map, mapOperands);
}
806 } // end anonymous namespace.
807 
/// Registers the canonicalization pattern that composes producer
/// affine.apply ops into this op's map (SimplifyAffineOp above).
void AffineApplyOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyAffineOp<AffineApplyOp>>(context);
}
812 
813 //===----------------------------------------------------------------------===//
814 // Common canonicalization pattern support logic
815 //===----------------------------------------------------------------------===//
816 
817 /// This is a common class used for patterns of the form
818 /// "someop(memrefcast) -> someop".  It folds the source of any memref_cast
819 /// into the root operation directly.
foldMemRefCast(Operation * op)820 static LogicalResult foldMemRefCast(Operation *op) {
821   bool folded = false;
822   for (OpOperand &operand : op->getOpOperands()) {
823     auto cast = dyn_cast_or_null<MemRefCastOp>(operand.get().getDefiningOp());
824     if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
825       operand.set(cast.getOperand());
826       folded = true;
827     }
828   }
829   return success(folded);
830 }
831 
832 //===----------------------------------------------------------------------===//
833 // AffineDmaStartOp
834 //===----------------------------------------------------------------------===//
835 
// TODO(b/133776335) Check that map operands are loop IVs or symbols.

/// Builds an affine.dma_start op. Operands are appended in the fixed order
/// the operand-index accessors expect: src memref then its map operands,
/// dst memref then its map operands, tag memref then its map operands, the
/// number of elements, and finally the optional stride pair.
void AffineDmaStartOp::build(Builder *builder, OperationState &result,
                             Value srcMemRef, AffineMap srcMap,
                             ValueRange srcIndices, Value destMemRef,
                             AffineMap dstMap, ValueRange destIndices,
                             Value tagMemRef, AffineMap tagMap,
                             ValueRange tagIndices, Value numElements,
                             Value stride, Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap));
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap));
  result.addOperands(destIndices);
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
  // A strided DMA carries two extra trailing operands; a null `stride`
  // means non-strided and both are omitted.
  if (stride) {
    result.addOperands({stride, elementsPerStride});
  }
}
858 
/// Prints affine.dma_start in the form:
///   affine.dma_start %src[map(ids)], %dst[map(ids)], %tag[map(ids)],
///     %num_elements (, %stride, %num_elt_per_stride)?
///       : src-memref-type, dst-memref-type, tag-memref-type
void AffineDmaStartOp::print(OpAsmPrinter &p) {
  p << "affine.dma_start " << getSrcMemRef() << '[';
  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
  p << "], " << getDstMemRef() << '[';
  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
  p << "], " << getTagMemRef() << '[';
  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
  p << "], " << getNumElements();
  // The stride pair is only present on strided DMAs.
  if (isStrided()) {
    p << ", " << getStride();
    p << ", " << getNumElementsPerStride();
  }
  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
    << getTagMemRefType();
}
874 
// Parse AffineDmaStartOp.
// Ex:
//   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
//     %stride, %num_elt_per_stride
//       : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
//
ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  AffineMapAttr srcMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
  OpAsmParser::OperandType dstMemRefInfo;
  AffineMapAttr dstMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
  OpAsmParser::OperandType tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
  OpAsmParser::OperandType numElementsInfo;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) dst memref followed by its affine maps operands (in square brackets).
  // *) src memref followed by its affine map operands (in square brackets).
  // *) tag memref followed by its affine map operands (in square brackets).
  // *) number of elements transferred by DMA operation.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
                                    getSrcMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
                                    getDstMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                    getTagMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo)) {
    return failure();
  }
  // Strides come in pairs (stride, elements per stride) or not at all.
  if (!strideInfo.empty() && strideInfo.size() != 2) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }
  bool isStrided = strideInfo.size() == 2;

  if (parser.parseColonTypeList(types))
    return failure();

  // Exactly three types: src, dst and tag memref types.
  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "expected three types");

  // Resolution must append operands in the same order as build() does.
  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
      parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  // Check that src/dst/tag operand counts match their map.numInputs.
  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
      dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
      tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(parser.getNameLoc(),
                            "memref operand count not equal to map.numInputs");
  return success();
}
953 
/// Verifies an affine.dma_start op: memref-typed src/dst/tag, distinct
/// memory spaces, a consistent operand count, and index-typed operands that
/// are valid affine dims/symbols for each of the three maps.
LogicalResult AffineDmaStartOp::verify() {
  if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA source to be of memref type");
  if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA destination to be of memref type");
  if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA tag to be of memref type");

  // DMAs from different memory spaces supported.
  if (getSrcMemorySpace() == getDstMemorySpace()) {
    return emitOpError("DMA should be between different memory spaces");
  }
  // Total operands: all map inputs + the three memrefs + num_elements,
  // optionally followed by the (stride, elements-per-stride) pair.
  unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
                              getDstMap().getNumInputs() +
                              getTagMap().getNumInputs();
  if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
      getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
    return emitOpError("incorrect number of operands");
  }

  for (auto idx : getSrcIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("src index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx))
      return emitOpError("src index must be a dimension or symbol identifier");
  }
  for (auto idx : getDstIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("dst index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx))
      return emitOpError("dst index must be a dimension or symbol identifier");
  }
  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("tag index to dma_start must have 'index' type");
    if (!isValidAffineIndexOperand(idx))
      return emitOpError("tag index must be a dimension or symbol identifier");
  }
  return success();
}
994 
/// Folds a memref_cast feeding any operand into this op (see foldMemRefCast).
LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}
1000 
1001 //===----------------------------------------------------------------------===//
1002 // AffineDmaWaitOp
1003 //===----------------------------------------------------------------------===//
1004 
// TODO(b/133776335) Check that map operands are loop IVs or symbols.

/// Builds an affine.dma_wait op. Operand order is fixed: tag memref, its
/// map operands, then the number of elements.
void AffineDmaWaitOp::build(Builder *builder, OperationState &result,
                            Value tagMemRef, AffineMap tagMap,
                            ValueRange tagIndices, Value numElements) {
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}
1014 
print(OpAsmPrinter & p)1015 void AffineDmaWaitOp::print(OpAsmPrinter &p) {
1016   p << "affine.dma_wait " << getTagMemRef() << '[';
1017   SmallVector<Value, 2> operands(getTagIndices());
1018   p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1019   p << "], ";
1020   p.printOperand(getNumElements());
1021   p << " : " << getTagMemRef().getType();
1022 }
1023 
// Parse AffineDmaWaitOp.
// Eg:
//   affine.dma_wait %tag[%index], %num_elements
//     : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its map operands, and dma size.
  // Operands are resolved in the same order build() appends them.
  if (parser.parseOperand(tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                    getTagMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  // The trailing type must be the tag's memref type.
  if (!type.isa<MemRefType>())
    return parser.emitError(parser.getNameLoc(),
                            "expected tag to be of memref type");

  // The bracketed operand list must supply exactly the map's inputs.
  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(parser.getNameLoc(),
                            "tag memref operand count != to map.numInputs");
  return success();
}
1058 
verify()1059 LogicalResult AffineDmaWaitOp::verify() {
1060   if (!getOperand(0).getType().isa<MemRefType>())
1061     return emitOpError("expected DMA tag to be of memref type");
1062   for (auto idx : getTagIndices()) {
1063     if (!idx.getType().isIndex())
1064       return emitOpError("index to dma_wait must have 'index' type");
1065     if (!isValidAffineIndexOperand(idx))
1066       return emitOpError("index must be a dimension or symbol identifier");
1067   }
1068   return success();
1069 }
1070 
/// Folds a memref_cast feeding any operand into this op (see foldMemRefCast).
LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}
1076 
1077 //===----------------------------------------------------------------------===//
1078 // AffineForOp
1079 //===----------------------------------------------------------------------===//
1080 
/// Builds an affine.for with the given lower/upper bound maps and their
/// operands and a positive constant step. Lower-bound operands precede
/// upper-bound operands in the operand list, matching the attribute layout
/// the verifier and accessors expect.
void AffineForOp::build(Builder *builder, OperationState &result,
                        ValueRange lbOperands, AffineMap lbMap,
                        ValueRange ubOperands, AffineMap ubMap, int64_t step) {
  assert(((!lbMap && lbOperands.empty()) ||
          lbOperands.size() == lbMap.getNumInputs()) &&
         "lower bound operand count does not match the affine map");
  assert(((!ubMap && ubOperands.empty()) ||
          ubOperands.size() == ubMap.getNumInputs()) &&
         "upper bound operand count does not match the affine map");
  assert(step > 0 && "step has to be a positive integer constant");

  // Add an attribute for the step.
  result.addAttribute(getStepAttrName(),
                      builder->getIntegerAttr(builder->getIndexType(), step));

  // Add the lower bound.
  result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap));
  result.addOperands(lbOperands);

  // Add the upper bound.
  result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap));
  result.addOperands(ubOperands);

  // Create a region and a block for the body.  The argument of the region is
  // the loop induction variable.
  Region *bodyRegion = result.addRegion();
  Block *body = new Block();
  body->addArgument(IndexType::get(builder->getContext()));
  bodyRegion->push_back(body);
  // Guarantees the body ends with an affine.terminator.
  ensureTerminator(*bodyRegion, *builder, result.location);

  // Set the operands list as resizable so that we can freely modify the bounds.
  result.setOperandListToResizable();
}
1115 
build(Builder * builder,OperationState & result,int64_t lb,int64_t ub,int64_t step)1116 void AffineForOp::build(Builder *builder, OperationState &result, int64_t lb,
1117                         int64_t ub, int64_t step) {
1118   auto lbMap = AffineMap::getConstantMap(lb, builder->getContext());
1119   auto ubMap = AffineMap::getConstantMap(ub, builder->getContext());
1120   return build(builder, result, {}, lbMap, {}, ubMap, step);
1121 }
1122 
/// Verifies an affine.for op: a single index-typed induction variable,
/// operand count matching the two bound maps, and bound operands that are
/// valid affine dims/symbols.
static LogicalResult verify(AffineForOp op) {
  // Check that the body defines as single block argument for the induction
  // variable.
  auto *body = op.getBody();
  if (body->getNumArguments() != 1 || !body->getArgument(0).getType().isIndex())
    return op.emitOpError(
        "expected body to have a single index argument for the "
        "induction variable");

  // Verify that there are enough operands for the bounds.
  AffineMap lowerBoundMap = op.getLowerBoundMap(),
            upperBoundMap = op.getUpperBoundMap();
  if (op.getNumOperands() !=
      (lowerBoundMap.getNumInputs() + upperBoundMap.getNumInputs()))
    return op.emitOpError(
        "operand count must match with affine map dimension and symbol count");

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bound.
  if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
                                           op.getLowerBoundMap().getNumDims())))
    return failure();
  /// Upper bound.
  if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
                                           op.getUpperBoundMap().getNumDims())))
    return failure();
  return success();
}
1151 
1152 /// Parse a for operation loop bounds.
parseBound(bool isLower,OperationState & result,OpAsmParser & p)1153 static ParseResult parseBound(bool isLower, OperationState &result,
1154                               OpAsmParser &p) {
1155   // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1156   // the map has multiple results.
1157   bool failedToParsedMinMax =
1158       failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
1159 
1160   auto &builder = p.getBuilder();
1161   auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
1162                                : AffineForOp::getUpperBoundAttrName();
1163 
1164   // Parse ssa-id as identity map.
1165   SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
1166   if (p.parseOperandList(boundOpInfos))
1167     return failure();
1168 
1169   if (!boundOpInfos.empty()) {
1170     // Check that only one operand was parsed.
1171     if (boundOpInfos.size() > 1)
1172       return p.emitError(p.getNameLoc(),
1173                          "expected only one loop bound operand");
1174 
1175     // TODO: improve error message when SSA value is not of index type.
1176     // Currently it is 'use of value ... expects different type than prior uses'
1177     if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1178                          result.operands))
1179       return failure();
1180 
1181     // Create an identity map using symbol id. This representation is optimized
1182     // for storage. Analysis passes may expand it into a multi-dimensional map
1183     // if desired.
1184     AffineMap map = builder.getSymbolIdentityMap();
1185     result.addAttribute(boundAttrName, AffineMapAttr::get(map));
1186     return success();
1187   }
1188 
1189   // Get the attribute location.
1190   llvm::SMLoc attrLoc = p.getCurrentLocation();
1191 
1192   Attribute boundAttr;
1193   if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
1194                        result.attributes))
1195     return failure();
1196 
1197   // Parse full form - affine map followed by dim and symbol list.
1198   if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
1199     unsigned currentNumOperands = result.operands.size();
1200     unsigned numDims;
1201     if (parseDimAndSymbolList(p, result.operands, numDims))
1202       return failure();
1203 
1204     auto map = affineMapAttr.getValue();
1205     if (map.getNumDims() != numDims)
1206       return p.emitError(
1207           p.getNameLoc(),
1208           "dim operand count and integer set dim count must match");
1209 
1210     unsigned numDimAndSymbolOperands =
1211         result.operands.size() - currentNumOperands;
1212     if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1213       return p.emitError(
1214           p.getNameLoc(),
1215           "symbol operand count and integer set symbol count must match");
1216 
1217     // If the map has multiple results, make sure that we parsed the min/max
1218     // prefix.
1219     if (map.getNumResults() > 1 && failedToParsedMinMax) {
1220       if (isLower) {
1221         return p.emitError(attrLoc, "lower loop bound affine map with "
1222                                     "multiple results requires 'max' prefix");
1223       }
1224       return p.emitError(attrLoc, "upper loop bound affine map with multiple "
1225                                   "results requires 'min' prefix");
1226     }
1227     return success();
1228   }
1229 
1230   // Parse custom assembly form.
1231   if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
1232     result.attributes.pop_back();
1233     result.addAttribute(
1234         boundAttrName,
1235         AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
1236     return success();
1237   }
1238 
1239   return p.emitError(
1240       p.getNameLoc(),
1241       "expected valid affine map representation for loop bounds");
1242 }
1243 
/// Parses an affine.for op: induction variable, '=' lower bound, 'to' upper
/// bound, optional 'step', the body region, and an optional attribute dict.
static ParseResult parseAffineForOp(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType inductionVariable;
  // Parse the induction variable followed by '='.
  if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
    return failure();

  // Parse loop bounds.
  if (parseBound(/*isLower=*/true, result, parser) ||
      parser.parseKeyword("to", " between bounds") ||
      parseBound(/*isLower=*/false, result, parser))
    return failure();

  // Parse the optional loop step, we default to 1 if one is not present.
  if (parser.parseOptionalKeyword("step")) {
    result.addAttribute(
        AffineForOp::getStepAttrName(),
        builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
  } else {
    llvm::SMLoc stepLoc = parser.getCurrentLocation();
    IntegerAttr stepAttr;
    if (parser.parseAttribute(stepAttr, builder.getIndexType(),
                              AffineForOp::getStepAttrName().data(),
                              result.attributes))
      return failure();

    // NOTE(review): only negative steps are rejected here, so a step of 0
    // parses successfully even though the error text says "positive" and
    // the builder asserts step > 0 — confirm whether 0 should be diagnosed.
    if (stepAttr.getValue().getSExtValue() < 0)
      return parser.emitError(
          stepLoc,
          "expected step to be representable as a positive signed integer");
  }

  // Parse the body region.
  Region *body = result.addRegion();
  if (parser.parseRegion(*body, inductionVariable, builder.getIndexType()))
    return failure();

  // Guarantees the body ends with an affine.terminator.
  AffineForOp::ensureTerminator(*body, builder, result.location);

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  // Set the operands list as resizable so that we can freely modify the bounds.
  result.setOperandListToResizable();
  return success();
}
1292 
printBound(AffineMapAttr boundMap,Operation::operand_range boundOperands,const char * prefix,OpAsmPrinter & p)1293 static void printBound(AffineMapAttr boundMap,
1294                        Operation::operand_range boundOperands,
1295                        const char *prefix, OpAsmPrinter &p) {
1296   AffineMap map = boundMap.getValue();
1297 
1298   // Check if this bound should be printed using custom assembly form.
1299   // The decision to restrict printing custom assembly form to trivial cases
1300   // comes from the will to roundtrip MLIR binary -> text -> binary in a
1301   // lossless way.
1302   // Therefore, custom assembly form parsing and printing is only supported for
1303   // zero-operand constant maps and single symbol operand identity maps.
1304   if (map.getNumResults() == 1) {
1305     AffineExpr expr = map.getResult(0);
1306 
1307     // Print constant bound.
1308     if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
1309       if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
1310         p << constExpr.getValue();
1311         return;
1312       }
1313     }
1314 
1315     // Print bound that consists of a single SSA symbol if the map is over a
1316     // single symbol.
1317     if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
1318       if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
1319         p.printOperand(*boundOperands.begin());
1320         return;
1321       }
1322     }
1323   } else {
1324     // Map has multiple results. Print 'min' or 'max' prefix.
1325     p << prefix << ' ';
1326   }
1327 
1328   // Print the map and its operands.
1329   p << boundMap;
1330   printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
1331                         map.getNumDims(), p);
1332 }
1333 
/// Prints an affine.for op: induction variable, '=' lb 'to' ub, the optional
/// 'step', the body region, and any remaining (non-bound, non-step)
/// attributes.
static void print(OpAsmPrinter &p, AffineForOp op) {
  p << "affine.for ";
  p.printOperand(op.getBody()->getArgument(0));
  p << " = ";
  // Multi-result lower bounds print as 'max' over results, upper bounds as
  // 'min' (see printBound).
  printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
  p << " to ";
  printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);

  // The default step of 1 is elided.
  if (op.getStep() != 1)
    p << " step " << op.getStep();
  p.printRegion(op.region(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);
  // Bound maps and step were printed inline above, so elide them here.
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/{op.getLowerBoundAttrName(),
                                           op.getUpperBoundAttrName(),
                                           op.getStepAttrName()});
}
1352 
/// Fold the constant bounds of a loop. For each non-constant bound whose
/// operands all fold to constants, constant-folds the bound map and replaces
/// the bound with the single resulting constant (max over results for the
/// lower bound, min for the upper bound). Returns success if either bound
/// was folded.
static LogicalResult foldLoopBounds(AffineForOp forOp) {
  auto foldLowerOrUpperBound = [&forOp](bool lower) {
    // Check to see if each of the operands is the result of a constant.  If
    // so, get the value.  If not, ignore it.
    SmallVector<Attribute, 8> operandConstants;
    auto boundOperands =
        lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
    for (auto operand : boundOperands) {
      Attribute operandCst;
      // Leaves operandCst null for non-constant operands; constantFold below
      // then fails and the bound is left untouched.
      matchPattern(operand, m_Constant(&operandCst));
      operandConstants.push_back(operandCst);
    }

    AffineMap boundMap =
        lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
    assert(boundMap.getNumResults() >= 1 &&
           "bound maps should have at least one result");
    SmallVector<Attribute, 4> foldedResults;
    if (failed(boundMap.constantFold(operandConstants, foldedResults)))
      return failure();

    // Compute the max or min as applicable over the results.
    assert(!foldedResults.empty() && "bounds should have at least one result");
    auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
    for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
      auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
      // A multi-result lower bound means 'max' of the results; an upper bound
      // means 'min' — signed comparisons on the APInt values.
      maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
                       : llvm::APIntOps::smin(maxOrMin, foldedResult);
    }
    lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
          : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
    return success();
  };

  // Try to fold the lower bound.
  bool folded = false;
  if (!forOp.hasConstantLowerBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));

  // Try to fold the upper bound.
  if (!forOp.hasConstantUpperBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
  return success(folded);
}
1398 
/// Canonicalize the bounds of the given loop: simplifies the bound maps and
/// their operand lists. Returns success if either bound changed.
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());

  auto lbMap = forOp.getLowerBoundMap();
  auto ubMap = forOp.getUpperBoundMap();
  // Keep the originals so we can detect whether canonicalization changed
  // anything.
  auto prevLbMap = lbMap;
  auto prevUbMap = ubMap;

  // Updates the maps and their operand vectors in place.
  canonicalizeMapAndOperands(&lbMap, &lbOperands);
  canonicalizeMapAndOperands(&ubMap, &ubOperands);

  // Any canonicalization change always leads to updated map(s).
  if (lbMap == prevLbMap && ubMap == prevUbMap)
    return failure();

  if (lbMap != prevLbMap)
    forOp.setLowerBound(lbOperands, lbMap);
  if (ubMap != prevUbMap)
    forOp.setUpperBound(ubOperands, ubMap);
  return success();
}
1422 
namespace {
/// This is a pattern to fold trivially empty loops.
struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
  using OpRewritePattern<AffineForOp>::OpRewritePattern;

  /// Erases 'forOp' when its body holds nothing but the implicit terminator,
  /// i.e. the loop has no observable effect.
  PatternMatchResult matchAndRewrite(AffineForOp forOp,
                                     PatternRewriter &rewriter) const override {
    // Check that the body only contains a terminator.
    if (!has_single_element(*forOp.getBody()))
      return matchFailure();
    rewriter.eraseOp(forOp);
    return matchSuccess();
  }
};
} // end anonymous namespace

/// Registers affine.for canonicalization patterns (empty-loop folding).
void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<AffineForEmptyLoopFolder>(context);
}
1443 
/// Folds this loop in place: constant-folds each bound map, then
/// canonicalizes the bound maps and operand lists. Returns success if
/// anything changed. 'operands' and 'results' are unused — all changes are
/// applied directly to the op's attributes and operands.
LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
  bool folded = succeeded(foldLoopBounds(*this));
  folded |= succeeded(canonicalizeLoopBounds(*this));
  return success(folded);
}
1450 
/// Returns the lower bound; its operands are the leading operands of the op.
AffineBound AffineForOp::getLowerBound() {
  auto lbMap = getLowerBoundMap();
  return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
}

/// Returns the upper bound; its operands start immediately after the lower
/// bound's operands (the op stores lb operands followed by ub operands).
AffineBound AffineForOp::getUpperBound() {
  auto lbMap = getLowerBoundMap();
  auto ubMap = getUpperBoundMap();
  return AffineBound(AffineForOp(*this), lbMap.getNumInputs(), getNumOperands(),
                     ubMap);
}
1462 
/// Replaces the lower bound with 'map' applied to 'lbOperands'. The op's
/// operand list is rebuilt as [new lb operands, existing ub operands].
void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");

  SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());

  auto ubOperands = getUpperBoundOperands();
  newOperands.append(ubOperands.begin(), ubOperands.end());
  getOperation()->setOperands(newOperands);

  setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
}

/// Replaces the upper bound with 'map' applied to 'ubOperands'. The op's
/// operand list is rebuilt as [existing lb operands, new ub operands].
void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");

  SmallVector<Value, 4> newOperands(getLowerBoundOperands());
  newOperands.append(ubOperands.begin(), ubOperands.end());
  getOperation()->setOperands(newOperands);

  setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
}
1486 
/// Replaces the lower bound map only; 'map' must take the same number of dims
/// and symbols so the existing operand list stays valid.
void AffineForOp::setLowerBoundMap(AffineMap map) {
  auto lbMap = getLowerBoundMap();
  assert(lbMap.getNumDims() == map.getNumDims() &&
         lbMap.getNumSymbols() == map.getNumSymbols());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  (void)lbMap; // Only used in asserts.
  setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
}

/// Replaces the upper bound map only; 'map' must take the same number of dims
/// and symbols so the existing operand list stays valid.
void AffineForOp::setUpperBoundMap(AffineMap map) {
  auto ubMap = getUpperBoundMap();
  assert(ubMap.getNumDims() == map.getNumDims() &&
         ubMap.getNumSymbols() == map.getNumSymbols());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  (void)ubMap; // Only used in asserts.
  setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
}
1504 
/// Returns true if the lower bound map has exactly one constant result.
bool AffineForOp::hasConstantLowerBound() {
  return getLowerBoundMap().isSingleConstant();
}

/// Returns true if the upper bound map has exactly one constant result.
bool AffineForOp::hasConstantUpperBound() {
  return getUpperBoundMap().isSingleConstant();
}

/// Returns the lower bound's constant value; only valid when
/// hasConstantLowerBound() holds.
int64_t AffineForOp::getConstantLowerBound() {
  return getLowerBoundMap().getSingleConstantResult();
}

/// Returns the upper bound's constant value; only valid when
/// hasConstantUpperBound() holds.
int64_t AffineForOp::getConstantUpperBound() {
  return getUpperBoundMap().getSingleConstantResult();
}

/// Sets the lower bound to the zero-operand constant map '() -> (value)'.
void AffineForOp::setConstantLowerBound(int64_t value) {
  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}

/// Sets the upper bound to the zero-operand constant map '() -> (value)'.
void AffineForOp::setConstantUpperBound(int64_t value) {
  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}
1528 
/// Returns the operands feeding the lower bound map (the leading operands).
AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
  return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}

/// Returns the operands feeding the upper bound map (all remaining operands
/// after the lower bound's).
AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
  return {operand_begin() + getLowerBoundMap().getNumInputs(), operand_end()};
}
1536 
matchingBoundOperandList()1537 bool AffineForOp::matchingBoundOperandList() {
1538   auto lbMap = getLowerBoundMap();
1539   auto ubMap = getUpperBoundMap();
1540   if (lbMap.getNumDims() != ubMap.getNumDims() ||
1541       lbMap.getNumSymbols() != ubMap.getNumSymbols())
1542     return false;
1543 
1544   unsigned numOperands = lbMap.getNumInputs();
1545   for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
1546     // Compare Value 's.
1547     if (getOperand(i) != getOperand(numOperands + i))
1548       return false;
1549   }
1550   return true;
1551 }
1552 
/// Returns the loop body region (LoopLikeOpInterface hook).
Region &AffineForOp::getLoopBody() { return region(); }

/// Returns true if 'value' is defined outside this loop's region, i.e. it is
/// invariant with respect to the loop (LoopLikeOpInterface hook).
bool AffineForOp::isDefinedOutsideOfLoop(Value value) {
  return !region().isAncestor(value.getParentRegion());
}

/// Hoists 'ops' out of the loop by moving them immediately before this op.
/// Callers are responsible for having checked the ops are safe to hoist.
LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
  for (auto *op : ops)
    op->moveBefore(*this);
  return success();
}
1564 
/// Returns if the provided value is the induction variable of a AffineForOp.
bool mlir::isForInductionVar(Value val) {
  return getForInductionVarOwner(val) != AffineForOp();
}

/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, then return nullptr.
AffineForOp mlir::getForInductionVarOwner(Value val) {
  // An induction variable is a block argument of an affine.for body block.
  auto ivArg = val.dyn_cast<BlockArgument>();
  if (!ivArg || !ivArg.getOwner())
    return AffineForOp();
  auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
  // dyn_cast yields a null AffineForOp when the parent op is something else.
  return dyn_cast<AffineForOp>(containingInst);
}

/// Extracts the induction variables from a list of AffineForOps and returns
/// them.
void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
                                   SmallVectorImpl<Value> *ivs) {
  ivs->reserve(forInsts.size());
  for (auto forInst : forInsts)
    ivs->push_back(forInst.getInductionVar());
}
1588 
1589 //===----------------------------------------------------------------------===//
1590 // AffineIfOp
1591 //===----------------------------------------------------------------------===//
1592 
/// Verifies an affine.if op: the 'condition' integer set attribute must be
/// present, the operand count must match the set's total inputs, operands
/// must be valid affine dims/symbols, and the child regions' entry blocks
/// must take no arguments.
static LogicalResult verify(AffineIfOp op) {
  // Verify that we have a condition attribute.
  auto conditionAttr =
      op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
  if (!conditionAttr)
    return op.emitOpError(
        "requires an integer set attribute named 'condition'");

  // Verify that there are enough operands for the condition.
  IntegerSet condition = conditionAttr.getValue();
  if (op.getNumOperands() != condition.getNumInputs())
    return op.emitOpError(
        "operand count and condition integer set dimension and "
        "symbol count must match");

  // Verify that the operands are valid dimension/symbols.
  if (failed(verifyDimAndSymbolIdentifiers(
          op, op.getOperation()->getNonSuccessorOperands(),
          condition.getNumDims())))
    return failure();

  // Verify that the entry of each child region does not have arguments.
  for (auto &region : op.getOperation()->getRegions()) {
    for (auto &b : region)
      if (b.getNumArguments() != 0)
        return op.emitOpError(
            "requires that child entry blocks have no arguments");
  }
  return success();
}
1623 
/// Parses an affine.if op:
///   affine.if integer-set dim-and-symbol-list region ('else' region)?
///     attr-dict?
static ParseResult parseAffineIfOp(OpAsmParser &parser,
                                   OperationState &result) {
  // Parse the condition attribute set.
  IntegerSetAttr conditionAttr;
  unsigned numDims;
  if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(),
                            result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims))
    return failure();

  // Verify the condition operands.
  auto set = conditionAttr.getValue();
  if (set.getNumDims() != numDims)
    return parser.emitError(
        parser.getNameLoc(),
        "dim operand count and integer set dim count must match");
  if (numDims + set.getNumSymbols() != result.operands.size())
    return parser.emitError(
        parser.getNameLoc(),
        "symbol operand count and integer set symbol count must match");

  // Create the regions for 'then' and 'else'.  The latter must be created even
  // if it remains empty for the validity of the operation.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, {}, {}))
    return failure();
  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
                               result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  // Note: parseOptionalKeyword returns success() when the keyword is present.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, {}, {}))
      return failure();
    AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
                                 result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}
1671 
/// Prints an affine.if op: condition set, operand list, 'then' region, the
/// 'else' region when non-empty, and any remaining attributes.
static void print(OpAsmPrinter &p, AffineIfOp op) {
  auto conditionAttr =
      op.getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
  p << "affine.if " << conditionAttr;
  printDimAndSymbolList(op.operand_begin(), op.operand_end(),
                        conditionAttr.getValue().getNumDims(), p);
  p.printRegion(op.thenRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/false);

  // Print the 'else' regions if it has any blocks.
  auto &elseRegion = op.elseRegion();
  if (!elseRegion.empty()) {
    p << " else";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/false);
  }

  // Print the attribute list. The condition was printed inline above.
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/op.getConditionAttrName());
}
1695 
/// Returns the integer set that forms the if-condition.
IntegerSet AffineIfOp::getIntegerSet() {
  return getAttrOfType<IntegerSetAttr>(getConditionAttrName()).getValue();
}
/// Replaces the condition set only; the operand list must already match the
/// new set's dim/symbol counts.
void AffineIfOp::setIntegerSet(IntegerSet newSet) {
  setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
}

/// Replaces both the condition set and its operands together.
void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
  setIntegerSet(set);
  getOperation()->setOperands(operands);
}
1707 
/// Builds an affine.if with condition 'set' over 'args'. Both regions are
/// always created (the verifier requires the 'else' region to exist), but
/// the 'else' region only receives a terminator when 'withElseRegion' is set.
void AffineIfOp::build(Builder *builder, OperationState &result, IntegerSet set,
                       ValueRange args, bool withElseRegion) {
  result.addOperands(args);
  result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set));
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();
  AffineIfOp::ensureTerminator(*thenRegion, *builder, result.location);
  if (withElseRegion)
    AffineIfOp::ensureTerminator(*elseRegion, *builder, result.location);
}
1718 
/// Canonicalize an affine if op's conditional (integer set + operands).
/// Returns success (in-place update) if canonicalization changed anything.
LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
                               SmallVectorImpl<OpFoldResult> &) {
  auto set = getIntegerSet();
  SmallVector<Value, 4> operands(getOperands());
  canonicalizeSetAndOperands(&set, &operands);

  // Any canonicalization change always leads to either a reduction in the
  // number of operands or a change in the number of symbolic operands
  // (promotion of dims to symbols).
  // Note: getIntegerSet() still returns the pre-canonicalization set here,
  // so these comparisons are old-vs-new.
  if (operands.size() < getIntegerSet().getNumInputs() ||
      set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
    setConditional(set, operands);
    return success();
  }

  return failure();
}
1737 
1738 //===----------------------------------------------------------------------===//
1739 // AffineLoadOp
1740 //===----------------------------------------------------------------------===//
1741 
/// Builds an affine.load from a combined operand list [memref, map
/// operands...]; the result type is the memref's element type. 'builder' is
/// unused here but kept for build-method signature uniformity.
void AffineLoadOp::build(Builder *builder, OperationState &result,
                         AffineMap map, ValueRange operands) {
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
  result.addOperands(operands);
  if (map)
    result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
  auto memrefType = operands[0].getType().cast<MemRefType>();
  result.types.push_back(memrefType.getElementType());
}

/// Builds an affine.load of 'memref' at the indices produced by applying
/// 'map' to 'mapOperands'.
void AffineLoadOp::build(Builder *builder, OperationState &result, Value memref,
                         AffineMap map, ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(memref);
  result.addOperands(mapOperands);
  auto memrefType = memref.getType().cast<MemRefType>();
  result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
  result.types.push_back(memrefType.getElementType());
}

/// Builds an affine.load of 'memref' directly at 'indices' via an identity
/// map.
void AffineLoadOp::build(Builder *builder, OperationState &result, Value memref,
                         ValueRange indices) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto rank = memrefType.getRank();
  // Create identity map for memrefs with at least one dimension or () -> ()
  // for zero-dimensional memrefs.
  auto map = rank ? builder->getMultiDimIdentityMap(rank)
                  : builder->getEmptyAffineMap();
  build(builder, result, memref, map, indices);
}
1772 
/// Parses an affine.load op:
///   affine.load ssa-memref '[' affine-map-of-ssa-ids ']' attr-dict?
///     ':' memref-type
ParseResult AffineLoadOp::parse(OpAsmParser &parser, OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::OperandType memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
  // The '||' chain short-circuits on the first failing parse step.
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr, getMapAttrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types));
}
1791 
/// Prints an affine.load op: memref, bracketed affine subscripts, optional
/// attributes (map elided — printed inline), and the memref type.
void AffineLoadOp::print(OpAsmPrinter &p) {
  p << "affine.load " << getMemRef() << '[';
  if (AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
  p << " : " << getMemRefType();
}
1800 
verify()1801 LogicalResult AffineLoadOp::verify() {
1802   if (getType() != getMemRefType().getElementType())
1803     return emitOpError("result type must match element type of memref");
1804 
1805   auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
1806   if (mapAttr) {
1807     AffineMap map = getAttrOfType<AffineMapAttr>(getMapAttrName()).getValue();
1808     if (map.getNumResults() != getMemRefType().getRank())
1809       return emitOpError("affine.load affine map num results must equal"
1810                          " memref rank");
1811     if (map.getNumInputs() != getNumOperands() - 1)
1812       return emitOpError("expects as many subscripts as affine map inputs");
1813   } else {
1814     if (getMemRefType().getRank() != getNumOperands() - 1)
1815       return emitOpError(
1816           "expects the number of subscripts to be equal to memref rank");
1817   }
1818 
1819   for (auto idx : getMapOperands()) {
1820     if (!idx.getType().isIndex())
1821       return emitOpError("index to load must have 'index' type");
1822     if (!isValidAffineIndexOperand(idx))
1823       return emitOpError("index must be a dimension or symbol identifier");
1824   }
1825   return success();
1826 }
1827 
/// Registers affine.load canonicalization patterns (affine map and operand
/// simplification via SimplifyAffineOp).
void AffineLoadOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
}

/// Folds the memref operand through memref casts; 'cstOperands' is unused.
OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}
1839 
1840 //===----------------------------------------------------------------------===//
1841 // AffineStoreOp
1842 //===----------------------------------------------------------------------===//
1843 
/// Builds an affine.store of 'valueToStore' into 'memref' at the indices
/// produced by applying 'map' to 'mapOperands'. 'builder' is unused here but
/// kept for build-method signature uniformity.
void AffineStoreOp::build(Builder *builder, OperationState &result,
                          Value valueToStore, Value memref, AffineMap map,
                          ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(valueToStore);
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
}

// Use identity map.
/// Builds an affine.store of 'valueToStore' into 'memref' directly at
/// 'indices'.
void AffineStoreOp::build(Builder *builder, OperationState &result,
                          Value valueToStore, Value memref,
                          ValueRange indices) {
  auto memrefType = memref.getType().cast<MemRefType>();
  auto rank = memrefType.getRank();
  // Create identity map for memrefs with at least one dimension or () -> ()
  // for zero-dimensional memrefs.
  auto map = rank ? builder->getMultiDimIdentityMap(rank)
                  : builder->getEmptyAffineMap();
  build(builder, result, valueToStore, memref, map, indices);
}
1866 
/// Parses an affine.store op:
///   affine.store ssa-value ',' ssa-memref '[' affine-map-of-ssa-ids ']'
///     attr-dict? ':' memref-type
ParseResult AffineStoreOp::parse(OpAsmParser &parser, OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType type;
  OpAsmParser::OperandType storeValueInfo;
  OpAsmParser::OperandType memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
  // The '||' chain short-circuits on the first failing parse step. The
  // stored value resolves against the memref's element type.
  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
                 parser.parseOperand(memrefInfo) ||
                 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                               getMapAttrName(),
                                               result.attributes) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.resolveOperand(storeValueInfo, type.getElementType(),
                                       result.operands) ||
                 parser.resolveOperand(memrefInfo, type, result.operands) ||
                 parser.resolveOperands(mapOperands, indexTy, result.operands));
}
1887 
/// Prints an affine.store op: value, memref, bracketed affine subscripts,
/// optional attributes (map elided — printed inline), and the memref type.
void AffineStoreOp::print(OpAsmPrinter &p) {
  p << "affine.store " << getValueToStore();
  p << ", " << getMemRef() << '[';
  if (AffineMapAttr mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName()))
    p.printAffineMapOfSSAIds(mapAttr, getMapOperands());
  p << ']';
  p.printOptionalAttrDict(getAttrs(), /*elidedAttrs=*/{getMapAttrName()});
  p << " : " << getMemRefType();
}
1897 
verify()1898 LogicalResult AffineStoreOp::verify() {
1899   // First operand must have same type as memref element type.
1900   if (getValueToStore().getType() != getMemRefType().getElementType())
1901     return emitOpError("first operand must have same type memref element type");
1902 
1903   auto mapAttr = getAttrOfType<AffineMapAttr>(getMapAttrName());
1904   if (mapAttr) {
1905     AffineMap map = mapAttr.getValue();
1906     if (map.getNumResults() != getMemRefType().getRank())
1907       return emitOpError("affine.store affine map num results must equal"
1908                          " memref rank");
1909     if (map.getNumInputs() != getNumOperands() - 2)
1910       return emitOpError("expects as many subscripts as affine map inputs");
1911   } else {
1912     if (getMemRefType().getRank() != getNumOperands() - 2)
1913       return emitOpError(
1914           "expects the number of subscripts to be equal to memref rank");
1915   }
1916 
1917   for (auto idx : getMapOperands()) {
1918     if (!idx.getType().isIndex())
1919       return emitOpError("index to store must have 'index' type");
1920     if (!isValidAffineIndexOperand(idx))
1921       return emitOpError("index must be a dimension or symbol identifier");
1922   }
1923   return success();
1924 }
1925 
/// Registers affine.store canonicalization patterns (affine map and operand
/// simplification via SimplifyAffineOp).
void AffineStoreOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
}

/// Folds the memref operand through memref casts; 'cstOperands' and
/// 'results' are unused.
LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
                                  SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this);
}
1936 
1937 //===----------------------------------------------------------------------===//
1938 // AffineMinOp
1939 //===----------------------------------------------------------------------===//
1940 //
1941 //   %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
1942 //
1943 
parseAffineMinOp(OpAsmParser & parser,OperationState & result)1944 static ParseResult parseAffineMinOp(OpAsmParser &parser,
1945                                     OperationState &result) {
1946   auto &builder = parser.getBuilder();
1947   auto indexType = builder.getIndexType();
1948   SmallVector<OpAsmParser::OperandType, 8> dim_infos;
1949   SmallVector<OpAsmParser::OperandType, 8> sym_infos;
1950   AffineMapAttr mapAttr;
1951   return failure(
1952       parser.parseAttribute(mapAttr, AffineMinOp::getMapAttrName(),
1953                             result.attributes) ||
1954       parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) ||
1955       parser.parseOperandList(sym_infos,
1956                               OpAsmParser::Delimiter::OptionalSquare) ||
1957       parser.parseOptionalAttrDict(result.attributes) ||
1958       parser.resolveOperands(dim_infos, indexType, result.operands) ||
1959       parser.resolveOperands(sym_infos, indexType, result.operands) ||
1960       parser.addTypeToList(indexType, result.types));
1961 }
1962 
print(OpAsmPrinter & p,AffineMinOp op)1963 static void print(OpAsmPrinter &p, AffineMinOp op) {
1964   p << op.getOperationName() << ' '
1965     << op.getAttr(AffineMinOp::getMapAttrName());
1966   auto operands = op.getOperands();
1967   unsigned numDims = op.map().getNumDims();
1968   p << '(' << operands.take_front(numDims) << ')';
1969 
1970   if (operands.size() != numDims)
1971     p << '[' << operands.drop_front(numDims) << ']';
1972   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
1973 }
1974 
/// Verifies an affine.min op: the operand count must equal the map's total
/// dim + symbol input count.
static LogicalResult verify(AffineMinOp op) {
  // Verify that operand count matches affine map dimension and symbol count.
  if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
    return op.emitOpError(
        "operand count and affine map dimension and symbol count must match");
  return success();
}
1982 
/// Folds affine.min when every result of its map constant-folds over
/// 'operands': returns the smallest folded value as an IntegerAttr, or a
/// null OpFoldResult when folding is not possible.
OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
  // Fold the affine map.
  // TODO(andydavis, ntv) Fold more cases: partial static information,
  // min(some_affine, some_affine + constant, ...).
  SmallVector<Attribute, 2> results;
  if (failed(map().constantFold(operands, results)))
    return {};

  // Compute and return min of folded map results.
  int64_t min = std::numeric_limits<int64_t>::max();
  int minIndex = -1;
  for (unsigned i = 0, e = results.size(); i < e; ++i) {
    auto intAttr = results[i].cast<IntegerAttr>();
    if (intAttr.getInt() < min) {
      min = intAttr.getInt();
      minIndex = i;
    }
  }
  // minIndex stays -1 only when the folded result list was empty.
  if (minIndex < 0)
    return {};
  return results[minIndex];
}
2005 
2006 //===----------------------------------------------------------------------===//
2007 // AffinePrefetchOp
2008 //===----------------------------------------------------------------------===//
2009 
2010 //
2011 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
2012 //
static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
                                         OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::OperandType memrefInfo;
  IntegerAttr hintInfo;
  // The locality hint is stored as an i32 attribute.
  auto i32Type = parser.getBuilder().getIntegerType(32);
  StringRef readOrWrite, cacheType;

  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
  // Parse, in order: the memref operand, the subscript affine map with its
  // SSA operands, the read/write keyword, `locality<N>`, the cache-type
  // keyword, any extra attributes, and finally the memref type. Operand
  // resolution happens last so `type` is known.
  if (parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffinePrefetchOp::getMapAttrName(),
                                    result.attributes) ||
      parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
      parser.parseComma() || parser.parseKeyword("locality") ||
      parser.parseLess() ||
      parser.parseAttribute(hintInfo, i32Type,
                            AffinePrefetchOp::getLocalityHintAttrName(),
                            result.attributes) ||
      parser.parseGreater() || parser.parseComma() ||
      parser.parseKeyword(&cacheType) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands))
    return failure();

  // The read/write specifier is parsed as a free keyword and validated here,
  // then recorded as a boolean "isWrite" attribute.
  if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
    return parser.emitError(parser.getNameLoc(),
                            "rw specifier has to be 'read' or 'write'");
  result.addAttribute(
      AffinePrefetchOp::getIsWriteAttrName(),
      parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));

  // Likewise, the cache-type keyword becomes a boolean "isDataCache"
  // attribute (true for the data cache, false for the instruction cache).
  if (!cacheType.equals("data") && !cacheType.equals("instr"))
    return parser.emitError(parser.getNameLoc(),
                            "cache type has to be 'data' or 'instr'");

  result.addAttribute(
      AffinePrefetchOp::getIsDataCacheAttrName(),
      parser.getBuilder().getBoolAttr(cacheType.equals("data")));

  return success();
}
2061 
print(OpAsmPrinter & p,AffinePrefetchOp op)2062 static void print(OpAsmPrinter &p, AffinePrefetchOp op) {
2063   p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
2064   AffineMapAttr mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2065   if (mapAttr) {
2066     SmallVector<Value, 2> operands(op.getMapOperands());
2067     p.printAffineMapOfSSAIds(mapAttr, operands);
2068   }
2069   p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", "
2070     << "locality<" << op.localityHint() << ">, "
2071     << (op.isDataCache() ? "data" : "instr");
2072   p.printOptionalAttrDict(
2073       op.getAttrs(),
2074       /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(),
2075                        op.getIsDataCacheAttrName(), op.getIsWriteAttrName()});
2076   p << " : " << op.getMemRefType();
2077 }
2078 
verify(AffinePrefetchOp op)2079 static LogicalResult verify(AffinePrefetchOp op) {
2080   auto mapAttr = op.getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2081   if (mapAttr) {
2082     AffineMap map = mapAttr.getValue();
2083     if (map.getNumResults() != op.getMemRefType().getRank())
2084       return op.emitOpError("affine.prefetch affine map num results must equal"
2085                             " memref rank");
2086     if (map.getNumInputs() + 1 != op.getNumOperands())
2087       return op.emitOpError("too few operands");
2088   } else {
2089     if (op.getNumOperands() != 1)
2090       return op.emitOpError("too few operands");
2091   }
2092 
2093   for (auto idx : op.getMapOperands()) {
2094     if (!isValidAffineIndexOperand(idx))
2095       return op.emitOpError("index must be a dimension or symbol identifier");
2096   }
2097   return success();
2098 }
2099 
/// Registers canonicalization patterns for affine.prefetch. SimplifyAffineOp
/// composes/simplifies the subscript map and drops duplicate operands.
void AffinePrefetchOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  // prefetch(memrefcast) -> prefetch
  results.insert<SimplifyAffineOp<AffinePrefetchOp>>(context);
}
2105 
/// Folds affine.prefetch in place by looking through memref_cast operands;
/// returns success only when an operand was actually replaced.
LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// prefetch(memrefcast) -> prefetch
  return foldMemRefCast(*this);
}
2111 
2112 //===----------------------------------------------------------------------===//
2113 // TableGen'd op method definitions
2114 //===----------------------------------------------------------------------===//
2115 
2116 #define GET_OP_CLASSES
2117 #include "mlir/Dialect/AffineOps/AffineOps.cpp.inc"
2118