//===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//

#include "mlir/Dialect/Affine/IR/AffineOps.h"
#include "mlir/Dialect/Affine/IR/AffineValueMap.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/BlockAndValueMapping.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/OpImplementation.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/InliningUtils.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallBitVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"

using namespace mlir;
using llvm::dbgs;

#define DEBUG_TYPE "affine-analysis"

/// A utility function to check if a value is defined at the top level of
/// `region` or is an argument of `region`. A value of index type defined at the
/// top level of an `AffineScope` region is always a valid symbol for all
/// uses in that region.
static bool isTopLevelValue(Value value, Region *region) {
  if (auto arg = value.dyn_cast<BlockArgument>())
    return arg.getParentRegion() == region;
  return value.getDefiningOp()->getParentRegion() == region;
}

/// Checks if `value` known to be a legal affine dimension or symbol in `src`
/// region remains legal if the operation that uses it is inlined into `dest`
/// with the given value mapping. `legalityCheck` is either `isValidDim` or
/// `isValidSymbol`, depending on the value being required to remain a valid
/// dimension or symbol.
static bool
remainsLegalAfterInline(Value value, Region *src, Region *dest,
                        const BlockAndValueMapping &mapping,
                        function_ref<bool(Value, Region *)> legalityCheck) {
  // If the value is a valid dimension for any other reason than being
  // a top-level value, it will remain valid: constants get inlined
  // with the function, transitive affine applies also get inlined and
  // will be checked themselves, etc.
  if (!isTopLevelValue(value, src))
    return true;

  // If it's a top-level value because it's a block operand, i.e. a
  // function argument, check whether the value replacing it after
  // inlining is a valid dimension in the new region.
  if (value.isa<BlockArgument>())
    return legalityCheck(mapping.lookup(value), dest);

  // If it's a top-level value because it's defined in the region,
  // it can only be inlined if the defining op is a constant or a
  // `dim`, which can appear anywhere and be valid, since the defining
  // op won't be top-level anymore after inlining.
  Attribute operandCst;
  return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
         value.getDefiningOp<DimOp>();
}

/// Checks if all values known to be legal affine dimensions or symbols in `src`
/// remain so if their respective users are inlined into `dest`.
static bool
remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
                        const BlockAndValueMapping &mapping,
                        function_ref<bool(Value, Region *)> legalityCheck) {
  return llvm::all_of(values, [&](Value v) {
    return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
  });
}

/// Checks if an affine read or write operation remains legal after inlining
/// from `src` to `dest`.
template <typename OpTy>
static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
                                    const BlockAndValueMapping &mapping) {
  static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
                                AffineWriteOpInterface>::value,
                "only ops with affine read/write interface are supported");

  AffineMap map = op.getAffineMap();
  ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
  ValueRange symbolOperands =
      op.getMapOperands().take_back(map.getNumSymbols());
  if (!remainsLegalAfterInline(
          dimOperands, src, dest, mapping,
          static_cast<bool (*)(Value, Region *)>(isValidDim)))
    return false;
  if (!remainsLegalAfterInline(
          symbolOperands, src, dest, mapping,
          static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
    return false;
  return true;
}

/// Checks if an affine apply operation remains legal after inlining from `src`
/// to `dest`.
template <>
bool remainsLegalAfterInline(AffineApplyOp op, Region *src, Region *dest,
                             const BlockAndValueMapping &mapping) {
  // If it's a valid dimension, we need to check that it remains so.
  if (isValidDim(op.getResult(), src))
    return remainsLegalAfterInline(
        op.getMapOperands(), src, dest, mapping,
        static_cast<bool (*)(Value, Region *)>(isValidDim));

  // Otherwise it must be a valid symbol, check that it remains so.
  return remainsLegalAfterInline(
      op.getMapOperands(), src, dest, mapping,
      static_cast<bool (*)(Value, Region *)>(isValidSymbol));
}

//===----------------------------------------------------------------------===//
// AffineDialect Interfaces
//===----------------------------------------------------------------------===//

namespace {
/// This class defines the interface for handling inlining with affine
/// operations.
struct AffineInlinerInterface : public DialectInlinerInterface {
  using DialectInlinerInterface::DialectInlinerInterface;

  //===--------------------------------------------------------------------===//
  // Analysis Hooks
  //===--------------------------------------------------------------------===//

  /// Returns true if the given region 'src' can be inlined into the region
  /// 'dest' that is attached to an operation registered to the current dialect.
  /// 'wouldBeCloned' is set if the region is cloned into its new location
  /// rather than moved, indicating there may be other users.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
                       BlockAndValueMapping &valueMapping) const final {
    // We can inline into affine loops and conditionals if this doesn't break
    // affine value categorization rules.
    Operation *destOp = dest->getParentOp();
    if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
      return false;

    // Multi-block regions cannot be inlined into affine constructs, all of
    // which require single-block regions.
    if (!llvm::hasSingleElement(*src))
      return false;

    // Side-effecting operations that the affine dialect cannot understand
    // should not be inlined.
    Block &srcBlock = src->front();
    for (Operation &op : srcBlock) {
      // Ops with no side effects are fine.
      if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
        if (iface.hasNoEffect())
          continue;
      }

      // Assuming the inlined region is valid, we only need to check if the
      // inlining would change it.
      bool remainsValid =
          llvm::TypeSwitch<Operation *, bool>(&op)
              .Case<AffineApplyOp, AffineReadOpInterface,
                    AffineWriteOpInterface>([&](auto op) {
                return remainsLegalAfterInline(op, src, dest, valueMapping);
              })
              .Default([](Operation *) {
                // Conservatively disallow inlining ops we cannot reason about.
                return false;
              });

      if (!remainsValid)
        return false;
    }

    return true;
  }

  /// Returns true if the given operation 'op', that is registered to this
  /// dialect, can be inlined into the given region, false otherwise.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
                       BlockAndValueMapping &valueMapping) const final {
    // Always allow inlining affine operations into a region that is marked as
    // affine scope, or into affine loops and conditionals. There are some edge
    // cases when inlining *into* affine structures, but that is handled in the
    // other 'isLegalToInline' hook above.
    Operation *parentOp = region->getParentOp();
    return parentOp->hasTrait<OpTrait::AffineScope>() ||
           isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
  }

  /// Affine regions should be analyzed recursively.
  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
};
} // end anonymous namespace

//===----------------------------------------------------------------------===//
// AffineDialect
//===----------------------------------------------------------------------===//

void AffineDialect::initialize() {
  addOperations<AffineDmaStartOp, AffineDmaWaitOp,
#define GET_OP_LIST
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
                >();
  addInterfaces<AffineInlinerInterface>();
}

/// Materialize a single constant operation from a given attribute value with
/// the desired resultant type.
Operation *AffineDialect::materializeConstant(OpBuilder &builder,
                                              Attribute value, Type type,
                                              Location loc) {
  return builder.create<ConstantOp>(loc, type, value);
}

/// A utility function to check if a value is defined at the top level of an
/// op with trait `AffineScope`. If the value is defined in an unlinked region,
/// conservatively assume it is not top-level. A value of index type defined at
/// the top level is always a valid symbol.
bool mlir::isTopLevelValue(Value value) {
  if (auto arg = value.dyn_cast<BlockArgument>()) {
    // The block owning the argument may be unlinked, e.g. when the surrounding
    // region has not yet been attached to an Op, at which point the parent Op
    // is null.
    Operation *parentOp = arg.getOwner()->getParentOp();
    return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
  }
  // The defining Op may live in an unlinked block so its parent Op may be null.
  Operation *parentOp = value.getDefiningOp()->getParentOp();
  return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
}

/// Returns the closest region enclosing `op` that is held by an operation with
/// trait `AffineScope`; `nullptr` if there is no such region.
//  TODO: getAffineScope should be publicly exposed for affine passes/utilities.
static Region *getAffineScope(Operation *op) {
  auto *curOp = op;
  while (auto *parentOp = curOp->getParentOp()) {
    if (parentOp->hasTrait<OpTrait::AffineScope>())
      return curOp->getParentRegion();
    curOp = parentOp;
  }
  return nullptr;
}

// A Value can be used as a dimension id iff it meets one of the following
// conditions:
// *) It is valid as a symbol.
// *) It is an induction variable.
// *) It is the result of an affine apply operation with dimension id arguments.
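// For example (an illustrative snippet; %A is an assumed memref<10xf32>), the
// induction variable of an affine.for is a valid dimension id inside the loop
// body and can directly index an affine.load:
//
//   affine.for %i = 0 to 10 {
//     %v = affine.load %A[%i] : memref<10xf32>
//   }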
bool mlir::isValidDim(Value value) {
  // The value must be an index type.
  if (!value.getType().isIndex())
    return false;

  if (auto *defOp = value.getDefiningOp())
    return isValidDim(value, getAffineScope(defOp));

  // This value has to be a block argument for an op that has the
  // `AffineScope` trait or for an affine.for or affine.parallel.
  auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
  return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
                      isa<AffineForOp, AffineParallelOp>(parentOp));
}

// Value can be used as a dimension id iff it meets one of the following
// conditions:
// *) It is valid as a symbol.
// *) It is an induction variable.
// *) It is the result of an affine apply operation with dimension id operands.
bool mlir::isValidDim(Value value, Region *region) {
  // The value must be an index type.
  if (!value.getType().isIndex())
    return false;

  // All valid symbols are okay.
  if (isValidSymbol(value, region))
    return true;

  auto *op = value.getDefiningOp();
  if (!op) {
    // This value has to be a block argument for an affine.for or an
    // affine.parallel.
    auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
    return isa<AffineForOp, AffineParallelOp>(parentOp);
  }

  // Affine apply operation is ok if all of its operands are ok.
  if (auto applyOp = dyn_cast<AffineApplyOp>(op))
    return applyOp.isValidDim(region);
  // The dim op is okay if its operand memref/tensor is defined at the top
  // level.
  if (auto dimOp = dyn_cast<DimOp>(op))
    return isTopLevelValue(dimOp.memrefOrTensor());
  return false;
}

/// Returns true if the 'index' dimension of the `memref` defined by
/// `memrefDefOp` is a statically shaped one or defined using a valid symbol
/// for `region`.
template <typename AnyMemRefDefOp>
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
                                    Region *region) {
  auto memRefType = memrefDefOp.getType();
  // Statically shaped.
  if (!memRefType.isDynamicDim(index))
    return true;
  // Get the position of the dimension among dynamic dimensions.
  unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
  return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
                       region);
}

/// Returns true if the result of the dim op is a valid symbol for `region`.
static bool isDimOpValidSymbol(DimOp dimOp, Region *region) {
  // The dim op is okay if its operand memref/tensor is defined at the top
  // level.
  if (isTopLevelValue(dimOp.memrefOrTensor()))
    return true;

  // Conservatively handle remaining BlockArguments as non-valid symbols.
  // E.g. scf.for iterArgs.
  if (dimOp.memrefOrTensor().isa<BlockArgument>())
    return false;

  // The dim op is also okay if its operand memref/tensor is a view/subview
  // whose corresponding size is a valid symbol.
  Optional<int64_t> index = dimOp.getConstantIndex();
  assert(index.hasValue() &&
         "expect only `dim` operations with a constant index");
  int64_t i = index.getValue();
  return TypeSwitch<Operation *, bool>(dimOp.memrefOrTensor().getDefiningOp())
      .Case<ViewOp, SubViewOp, AllocOp>(
          [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
      .Default([](Operation *) { return false; });
}

// A value can be used as a symbol (at all its use sites) iff it meets one of
// the following conditions:
// *) It is a constant.
// *) Its defining op or block arg appearance is immediately enclosed by an op
//    with `AffineScope` trait.
// *) It is the result of an affine.apply operation with symbol operands.
// *) It is a result of the dim op on a memref whose corresponding size is a
//    valid symbol.
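// For example (illustrative): constants and top-level values of index type are
// valid symbols, so both %n and %c4 below may be used in symbol positions
// anywhere inside @f:
//
//   func @f(%A: memref<?xf32>, %n: index) {
//     %c4 = constant 4 : index
//     ...
//   }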
bool mlir::isValidSymbol(Value value) {
  // The value must be an index type.
  if (!value.getType().isIndex())
    return false;

  // Check that the value is a top level value.
  if (isTopLevelValue(value))
    return true;

  if (auto *defOp = value.getDefiningOp())
    return isValidSymbol(value, getAffineScope(defOp));

  return false;
}

/// A value can be used as a symbol for `region` iff it meets one of the
/// following conditions:
/// *) It is a constant.
/// *) It is the result of an affine apply operation with symbol arguments.
/// *) It is a result of the dim op on a memref whose corresponding size is
///    a valid symbol.
/// *) It is defined at the top level of 'region' or is its argument.
/// *) It dominates `region`'s parent op.
/// If `region` is null, conservatively assume the symbol definition scope does
/// not exist and only accept the values that would be symbols regardless of
/// the surrounding region structure, i.e. the first three cases above.
bool mlir::isValidSymbol(Value value, Region *region) {
  // The value must be an index type.
  if (!value.getType().isIndex())
    return false;

  // A top-level value is a valid symbol.
  if (region && ::isTopLevelValue(value, region))
    return true;

  auto *defOp = value.getDefiningOp();
  if (!defOp) {
    // A block argument that is not a top-level value is a valid symbol if it
    // dominates region's parent op.
    if (region && !region->getParentOp()->isKnownIsolatedFromAbove())
      if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
        return isValidSymbol(value, parentOpRegion);
    return false;
  }

  // Constant operation is ok.
  Attribute operandCst;
  if (matchPattern(defOp, m_Constant(&operandCst)))
    return true;

  // Affine apply operation is ok if all of its operands are ok.
  if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
    return applyOp.isValidSymbol(region);

  // Dim op results could be valid symbols at any level.
  if (auto dimOp = dyn_cast<DimOp>(defOp))
    return isDimOpValidSymbol(dimOp, region);

  // Check for values dominating `region`'s parent op.
  if (region && !region->getParentOp()->isKnownIsolatedFromAbove())
    if (auto *parentRegion = region->getParentOp()->getParentRegion())
      return isValidSymbol(value, parentRegion);

  return false;
}

// Returns true if 'value' is a valid index to an affine operation (e.g.
// affine.load, affine.store, affine.dma_start, affine.dma_wait) where
// `region` provides the polyhedral symbol scope. Returns false otherwise.
static bool isValidAffineIndexOperand(Value value, Region *region) {
  return isValidDim(value, region) || isValidSymbol(value, region);
}

/// Prints dimension and symbol list.
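/// For example (illustrative), two dimension operands followed by one symbol
/// operand print as:
///   (%i, %j)[%n]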
static void printDimAndSymbolList(Operation::operand_iterator begin,
                                  Operation::operand_iterator end,
                                  unsigned numDims, OpAsmPrinter &printer) {
  OperandRange operands(begin, end);
  printer << '(' << operands.take_front(numDims) << ')';
  if (operands.size() > numDims)
    printer << '[' << operands.drop_front(numDims) << ']';
}

/// Parses dimension and symbol list and returns true if parsing failed.
ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
                                        SmallVectorImpl<Value> &operands,
                                        unsigned &numDims) {
  SmallVector<OpAsmParser::OperandType, 8> opInfos;
  if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
    return failure();
  // Store number of dimensions for validation by caller.
  numDims = opInfos.size();

  // Parse the optional symbol operands.
  auto indexTy = parser.getBuilder().getIndexType();
  return failure(parser.parseOperandList(
                     opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
                 parser.resolveOperands(opInfos, indexTy, operands));
}

/// Utility function to verify that a set of operands are valid dimension and
/// symbol identifiers. The operands should be laid out such that the dimension
/// operands are before the symbol operands. This function returns failure if
/// there was an invalid operand. An operation is provided to emit any necessary
/// errors.
template <typename OpTy>
static LogicalResult
verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
                              unsigned numDims) {
  unsigned opIt = 0;
  for (auto operand : operands) {
    if (opIt++ < numDims) {
      if (!isValidDim(operand, getAffineScope(op)))
        return op.emitOpError("operand cannot be used as a dimension id");
    } else if (!isValidSymbol(operand, getAffineScope(op))) {
      return op.emitOpError("operand cannot be used as a symbol");
    }
  }
  return success();
}

//===----------------------------------------------------------------------===//
// AffineApplyOp
//===----------------------------------------------------------------------===//

AffineValueMap AffineApplyOp::getAffineValueMap() {
  return AffineValueMap(getAffineMap(), getOperands(), getResult());
}

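// Parse AffineApplyOp.
// Ex (illustrative):
//   %0 = affine.apply affine_map<(d0)[s0] -> (d0 + s0)>(%i)[%n]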
static ParseResult parseAffineApplyOp(OpAsmParser &parser,
                                      OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  AffineMapAttr mapAttr;
  unsigned numDims;
  if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims) ||
      parser.parseOptionalAttrDict(result.attributes))
    return failure();
  auto map = mapAttr.getValue();

  if (map.getNumDims() != numDims ||
      numDims + map.getNumSymbols() != result.operands.size()) {
    return parser.emitError(parser.getNameLoc(),
                            "dimension or symbol index mismatch");
  }

  result.types.append(map.getNumResults(), indexTy);
  return success();
}

static void print(OpAsmPrinter &p, AffineApplyOp op) {
  p << AffineApplyOp::getOperationName() << " " << op.mapAttr();
  printDimAndSymbolList(op.operand_begin(), op.operand_end(),
                        op.getAffineMap().getNumDims(), p);
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"map"});
}

static LogicalResult verify(AffineApplyOp op) {
  // Check input and output dimensions match.
  auto map = op.map();

  // Verify that operand count matches affine map dimension and symbol count.
  if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols())
    return op.emitOpError(
        "operand count and affine map dimension and symbol count must match");

  // Verify that the map only produces one result.
  if (map.getNumResults() != 1)
    return op.emitOpError("mapping must produce one value");

  return success();
}

// The result of the affine apply operation can be used as a dimension id if all
// its operands are valid dimension ids.
bool AffineApplyOp::isValidDim() {
  return llvm::all_of(getOperands(),
                      [](Value op) { return mlir::isValidDim(op); });
}

// The result of the affine apply operation can be used as a dimension id if all
// its operands are valid dimension ids with the parent operation of `region`
// defining the polyhedral scope for symbols.
bool AffineApplyOp::isValidDim(Region *region) {
  return llvm::all_of(getOperands(),
                      [&](Value op) { return ::isValidDim(op, region); });
}

// The result of the affine apply operation can be used as a symbol if all its
// operands are symbols.
bool AffineApplyOp::isValidSymbol() {
  return llvm::all_of(getOperands(),
                      [](Value op) { return mlir::isValidSymbol(op); });
}

// The result of the affine apply operation can be used as a symbol in `region`
// if all its operands are symbols in `region`.
bool AffineApplyOp::isValidSymbol(Region *region) {
  return llvm::all_of(getOperands(), [&](Value operand) {
    return mlir::isValidSymbol(operand, region);
  });
}

OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
  auto map = getAffineMap();

  // Fold dims and symbols to existing values.
  auto expr = map.getResult(0);
  if (auto dim = expr.dyn_cast<AffineDimExpr>())
    return getOperand(dim.getPosition());
  if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
    return getOperand(map.getNumDims() + sym.getPosition());

  // Otherwise, default to folding the map.
  SmallVector<Attribute, 1> result;
  if (failed(map.constantFold(operands, result)))
    return {};
  return result[0];
}

/// Replace all occurrences of AffineExpr at position `pos` in `map` by the
/// defining AffineApplyOp expression and operands.
/// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
/// When `dimOrSymbolPosition >= dims.size()`,
/// AffineSymbolExpr@[pos - dims.size()] is replaced.
/// Mutate `map`, `dims` and `syms` in place as follows:
///   1. `dims` and `syms` are only appended to.
///   2. `map` dim and symbols are gradually shifted to higher positions.
///   3. Old `dim` and `sym` entries are replaced by nullptr.
/// This avoids the need for any bookkeeping.
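///
/// Illustrative sketch of one replacement step (assuming this exact state):
/// with map = (d0, d1) -> (d0 + d1), dims = [%a, %b], syms = [], and %a defined
/// by affine.apply affine_map<(d0) -> (d0 * 2)>(%i), replacing dim 0 rewrites
/// the expression to (d2 * 2 + d1), appends %i to `dims`, and sets dims[0] to
/// nullptr, leaving dims = [nullptr, %b, %i].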
static LogicalResult replaceDimOrSym(AffineMap *map,
                                     unsigned dimOrSymbolPosition,
                                     SmallVectorImpl<Value> &dims,
                                     SmallVectorImpl<Value> &syms) {
  bool isDimReplacement = (dimOrSymbolPosition < dims.size());
  unsigned pos = isDimReplacement ? dimOrSymbolPosition
                                  : dimOrSymbolPosition - dims.size();
  Value &v = isDimReplacement ? dims[pos] : syms[pos];
  if (!v)
    return failure();

  auto affineApply = v.getDefiningOp<AffineApplyOp>();
  if (!affineApply)
    return failure();

  // At this point we will perform a replacement of `v`, set the entry in `dim`
  // or `sym` to nullptr immediately.
  v = nullptr;

  // Compute the map, dims and symbols coming from the AffineApplyOp.
  AffineMap composeMap = affineApply.getAffineMap();
  assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
  AffineExpr composeExpr =
      composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
  ValueRange composeDims =
      affineApply.getMapOperands().take_front(composeMap.getNumDims());
  ValueRange composeSyms =
      affineApply.getMapOperands().take_back(composeMap.getNumSymbols());

  // Perform the replacement and append the dims and symbols where relevant.
  MLIRContext *ctx = map->getContext();
  AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
                                          : getAffineSymbolExpr(pos, ctx);
  *map = map->replace(toReplace, composeExpr, dims.size(), syms.size());
  dims.append(composeDims.begin(), composeDims.end());
  syms.append(composeSyms.begin(), composeSyms.end());

  return success();
}

/// Iterate over `operands` and fold away all those produced by an AffineApplyOp
/// iteratively. Perform canonicalization of map and operands as well as
/// AffineMap simplification. `map` and `operands` are mutated in place.
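///
/// For example (illustrative), given
///   %a = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
/// composing the map (d0) -> (d0 * 2) with operand %a yields the map
/// (d0) -> (d0 * 2 + 2) with operand %i.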
static void composeAffineMapAndOperands(AffineMap *map,
                                        SmallVectorImpl<Value> *operands) {
  if (map->getNumResults() == 0) {
    canonicalizeMapAndOperands(map, operands);
    *map = simplifyAffineMap(*map);
    return;
  }

  MLIRContext *ctx = map->getContext();
  SmallVector<Value, 4> dims(operands->begin(),
                             operands->begin() + map->getNumDims());
  SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
                             operands->end());

  // Iterate over dims and symbols coming from AffineApplyOp and replace until
  // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
  // and `syms` can only increase by construction.
  // The implementation uses a `while` loop to support the case of symbols
  // that may be constructed from dims; this may be overkill.
  while (true) {
    bool changed = false;
    for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
      if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
        break;
    if (!changed)
      break;
  }

  // Clear operands so we can fill them anew.
  operands->clear();

  // At this point we may have introduced null operands, prune them out before
  // canonicalizing map and operands.
  unsigned nDims = 0, nSyms = 0;
  SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
  dimReplacements.reserve(dims.size());
  symReplacements.reserve(syms.size());
  for (auto *container : {&dims, &syms}) {
    bool isDim = (container == &dims);
    auto &repls = isDim ? dimReplacements : symReplacements;
    for (auto en : llvm::enumerate(*container)) {
      Value v = en.value();
      if (!v) {
        assert((isDim ? !map->isFunctionOfDim(en.index())
                      : !map->isFunctionOfSymbol(en.index())) &&
               "map is function of unexpected expr@pos");
        repls.push_back(getAffineConstantExpr(0, ctx));
        continue;
      }
      repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
                            : getAffineSymbolExpr(nSyms++, ctx));
      operands->push_back(v);
    }
  }
  *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
                                    nSyms);

  // Canonicalize and simplify before returning.
  canonicalizeMapAndOperands(map, operands);
  *map = simplifyAffineMap(*map);
}

void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
                                            SmallVectorImpl<Value> *operands) {
  while (llvm::any_of(*operands, [](Value v) {
    return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
  })) {
    composeAffineMapAndOperands(map, operands);
  }
}

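// Example usage (an illustrative sketch; `b`, `loc`, `i` and `j` are assumed to
// be an OpBuilder, a Location and two index Values): build a single
// affine.apply computing i + j + 1, folding any affine.apply ops that define
// i or j into the resulting map.
//
//   AffineExpr d0, d1;
//   bindDims(b.getContext(), d0, d1);
//   AffineMap map = AffineMap::get(/*dimCount=*/2, /*symbolCount=*/0, d0 + d1 + 1);
//   AffineApplyOp apply = makeComposedAffineApply(b, loc, map, {i, j});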
AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
                                            AffineMap map,
                                            ArrayRef<Value> operands) {
  AffineMap normalizedMap = map;
  SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end());
  composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
  assert(normalizedMap);
  return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
}

// A symbol may appear as a dim in affine.apply operations. This function
// canonicalizes dims that are valid symbols into actual symbols.
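// For example (illustrative): if %n is a valid symbol, the map
// (d0, d1) -> (d0 + d1) applied to (%i, %n) is rewritten to
// (d0)[s0] -> (d0 + s0) applied to (%i)[%n].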
template <class MapOrSet>
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
                                        SmallVectorImpl<Value> *operands) {
  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  auto *context = mapOrSet->getContext();
  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());
  SmallVector<Value, 8> remappedSymbols;
  remappedSymbols.reserve(operands->size());
  unsigned nextDim = 0;
  unsigned nextSym = 0;
  unsigned oldNumSyms = mapOrSet->getNumSymbols();
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
    if (i < mapOrSet->getNumDims()) {
      if (isValidSymbol((*operands)[i])) {
        // This is a valid symbol that appears as a dim, canonicalize it.
        dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
        remappedSymbols.push_back((*operands)[i]);
      } else {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
      }
    } else {
      resultOperands.push_back((*operands)[i]);
    }
  }

  resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
  *operands = resultOperands;
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
                                              oldNumSyms + nextSym);

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");
}

// Works for either an affine map or an integer set.
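// Canonicalization drops unused dims/symbols, de-duplicates repeated operands,
// and inlines constant symbol operands. For example (illustrative), the map
// (d0, d1)[s0] -> (d0 + s0) applied to (%i, %j)[%c42], where %c42 is a constant
// 42, becomes (d0) -> (d0 + 42) applied to (%i).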
template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
                                            SmallVectorImpl<Value> *operands) {
  static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
                "Argument must be either of AffineMap or IntegerSet type");

  if (!mapOrSet || operands->empty())
    return;

  assert(mapOrSet->getNumInputs() == operands->size() &&
         "map/set inputs must match number of operands");

  canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);

  // Check to see what dims are used.
  llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
  llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
  mapOrSet->walkExprs([&](AffineExpr expr) {
    if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
      usedDims[dimExpr.getPosition()] = true;
    else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
      usedSyms[symExpr.getPosition()] = true;
  });

  auto *context = mapOrSet->getContext();

  SmallVector<Value, 8> resultOperands;
  resultOperands.reserve(operands->size());

  llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
  SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
  unsigned nextDim = 0;
  for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
    if (usedDims[i]) {
      // Remap dim positions for duplicate operands.
      auto it = seenDims.find((*operands)[i]);
      if (it == seenDims.end()) {
        dimRemapping[i] = getAffineDimExpr(nextDim++, context);
        resultOperands.push_back((*operands)[i]);
        seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
      } else {
        dimRemapping[i] = it->second;
      }
    }
  }
  llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
  SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
  unsigned nextSym = 0;
  for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
    if (!usedSyms[i])
      continue;
    // Handle constant operands (only needed for symbolic operands since
    // constant operands in dimensional positions would have already been
    // promoted to symbolic positions above).
    IntegerAttr operandCst;
    if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
                     m_Constant(&operandCst))) {
      symRemapping[i] =
          getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
      continue;
    }
    // Remap symbol positions for duplicate operands.
    auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
    if (it == seenSymbols.end()) {
      symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
      resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
      seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
                                        symRemapping[i]));
    } else {
      symRemapping[i] = it->second;
    }
  }
  *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
                                              nextDim, nextSym);
  *operands = resultOperands;
}

void mlir::canonicalizeMapAndOperands(AffineMap *map,
                                      SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
}

void mlir::canonicalizeSetAndOperands(IntegerSet *set,
                                      SmallVectorImpl<Value> *operands) {
  canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
}

namespace {
/// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
/// maps that supply results into them.
///
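/// For example (an illustrative snippet; %A is an assumed memref<10xf32>),
/// with %idx defined as
///   %idx = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
/// the operation
///   %v = affine.load %A[%idx] : memref<10xf32>
/// is rewritten to
///   %v = affine.load %A[%i + 1] : memref<10xf32>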
template <typename AffineOpTy>
struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
  using OpRewritePattern<AffineOpTy>::OpRewritePattern;

  /// Replace the affine op with another instance of it with the supplied
  /// map and mapOperands.
  void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
                       AffineMap map, ArrayRef<Value> mapOperands) const;

  LogicalResult matchAndRewrite(AffineOpTy affineOp,
                                PatternRewriter &rewriter) const override {
    static_assert(llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
                                  AffineStoreOp, AffineApplyOp, AffineMinOp,
                                  AffineMaxOp>::value,
                  "affine load/store/apply/prefetch/min/max op expected");
    auto map = affineOp.getAffineMap();
    AffineMap oldMap = map;
    auto oldOperands = affineOp.getMapOperands();
    SmallVector<Value, 8> resultOperands(oldOperands);
    composeAffineMapAndOperands(&map, &resultOperands);
    if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
                                    resultOperands.begin()))
      return failure();

    replaceAffineOp(rewriter, affineOp, map, resultOperands);
    return success();
  }
};

// Specialize the template to account for the different build signatures for
// affine load, prefetch, and store ops.
template <>
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
                                            mapOperands);
}
template <>
void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
      prefetch, prefetch.memref(), map, mapOperands, prefetch.localityHint(),
      prefetch.isWrite(), prefetch.isDataCache());
}
template <>
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
    PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineStoreOp>(
      store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
}

// Generic version for ops that don't have extra operands.
template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
    PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
    ArrayRef<Value> mapOperands) const {
  rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
}
} // end anonymous namespace.

void AffineApplyOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyAffineOp<AffineApplyOp>>(context);
}

//===----------------------------------------------------------------------===//
// Common canonicalization pattern support logic
//===----------------------------------------------------------------------===//

/// This is a common utility used for patterns of the form
/// "someop(memrefcast) -> someop". It folds the source of any memref_cast
/// into the root operation directly.
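///
/// For example (illustrative), given
///   %0 = memref_cast %A : memref<4xf32> to memref<?xf32>
/// an op that uses %0 as a memref operand is folded to use %A directly, as
/// long as the cast source is a ranked memref.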
static LogicalResult foldMemRefCast(Operation *op) {
  bool folded = false;
  for (OpOperand &operand : op->getOpOperands()) {
    auto cast = operand.get().getDefiningOp<MemRefCastOp>();
    if (cast && !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
      operand.set(cast.getOperand());
      folded = true;
    }
  }
  return success(folded);
}

//===----------------------------------------------------------------------===//
// AffineDmaStartOp
//===----------------------------------------------------------------------===//

// TODO: Check that map operands are loop IVs or symbols.
void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
                             Value srcMemRef, AffineMap srcMap,
                             ValueRange srcIndices, Value destMemRef,
                             AffineMap dstMap, ValueRange destIndices,
                             Value tagMemRef, AffineMap tagMap,
                             ValueRange tagIndices, Value numElements,
                             Value stride, Value elementsPerStride) {
  result.addOperands(srcMemRef);
  result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap));
  result.addOperands(srcIndices);
  result.addOperands(destMemRef);
  result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap));
  result.addOperands(destIndices);
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
  if (stride) {
    result.addOperands({stride, elementsPerStride});
  }
}

void AffineDmaStartOp::print(OpAsmPrinter &p) {
  p << "affine.dma_start " << getSrcMemRef() << '[';
  p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
  p << "], " << getDstMemRef() << '[';
  p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
  p << "], " << getTagMemRef() << '[';
  p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
  p << "], " << getNumElements();
  if (isStrided()) {
    p << ", " << getStride();
    p << ", " << getNumElementsPerStride();
  }
  p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
    << getTagMemRefType();
}

// Parse AffineDmaStartOp.
// Ex:
//   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
//     %stride, %num_elt_per_stride
//       : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
//
ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
                                    OperationState &result) {
  OpAsmParser::OperandType srcMemRefInfo;
  AffineMapAttr srcMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
  OpAsmParser::OperandType dstMemRefInfo;
  AffineMapAttr dstMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
  OpAsmParser::OperandType tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
  OpAsmParser::OperandType numElementsInfo;
  SmallVector<OpAsmParser::OperandType, 2> strideInfo;

  SmallVector<Type, 3> types;
  auto indexType = parser.getBuilder().getIndexType();

  // Parse and resolve the following list of operands:
  // *) src memref followed by its affine map operands (in square brackets).
  // *) dst memref followed by its affine map operands (in square brackets).
  // *) tag memref followed by its affine map operands (in square brackets).
  // *) number of elements transferred by the DMA operation.
  if (parser.parseOperand(srcMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
                                    getSrcMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
                                    getDstMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                    getTagMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo))
    return failure();

  // Parse optional stride and elements per stride.
  if (parser.parseTrailingOperandList(strideInfo)) {
    return failure();
  }
  if (!strideInfo.empty() && strideInfo.size() != 2) {
    return parser.emitError(parser.getNameLoc(),
                            "expected two stride related operands");
  }
  bool isStrided = strideInfo.size() == 2;

  if (parser.parseColonTypeList(types))
    return failure();

  if (types.size() != 3)
    return parser.emitError(parser.getNameLoc(), "expected three types");

  if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
      parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
      parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
      parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
      parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  if (isStrided) {
    if (parser.resolveOperands(strideInfo, indexType, result.operands))
      return failure();
  }

  // Check that src/dst/tag operand counts match their map.numInputs.
  if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
      dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
      tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(parser.getNameLoc(),
                            "memref operand count not equal to map.numInputs");
  return success();
}

LogicalResult AffineDmaStartOp::verify() {
  if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA source to be of memref type");
  if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA destination to be of memref type");
  if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
    return emitOpError("expected DMA tag to be of memref type");

  // Only DMAs between different memory spaces are supported.
1061   if (getSrcMemorySpace() == getDstMemorySpace()) {
1062     return emitOpError("DMA should be between different memory spaces");
1063   }
1064   unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1065                               getDstMap().getNumInputs() +
1066                               getTagMap().getNumInputs();
1067   if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1068       getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1069     return emitOpError("incorrect number of operands");
1070   }
1071 
1072   Region *scope = getAffineScope(*this);
1073   for (auto idx : getSrcIndices()) {
1074     if (!idx.getType().isIndex())
1075       return emitOpError("src index to dma_start must have 'index' type");
1076     if (!isValidAffineIndexOperand(idx, scope))
1077       return emitOpError("src index must be a dimension or symbol identifier");
1078   }
1079   for (auto idx : getDstIndices()) {
1080     if (!idx.getType().isIndex())
1081       return emitOpError("dst index to dma_start must have 'index' type");
1082     if (!isValidAffineIndexOperand(idx, scope))
1083       return emitOpError("dst index must be a dimension or symbol identifier");
1084   }
1085   for (auto idx : getTagIndices()) {
1086     if (!idx.getType().isIndex())
1087       return emitOpError("tag index to dma_start must have 'index' type");
1088     if (!isValidAffineIndexOperand(idx, scope))
1089       return emitOpError("tag index must be a dimension or symbol identifier");
1090   }
1091   return success();
1092 }
1093 
LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
                                     SmallVectorImpl<OpFoldResult> &results) {
  /// dma_start(memrefcast) -> dma_start
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// AffineDmaWaitOp
//===----------------------------------------------------------------------===//

// TODO: Check that map operands are loop IVs or symbols.
void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
                            Value tagMemRef, AffineMap tagMap,
                            ValueRange tagIndices, Value numElements) {
  result.addOperands(tagMemRef);
  result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
  result.addOperands(tagIndices);
  result.addOperands(numElements);
}

void AffineDmaWaitOp::print(OpAsmPrinter &p) {
  p << "affine.dma_wait " << getTagMemRef() << '[';
  SmallVector<Value, 2> operands(getTagIndices());
  p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
  p << "], ";
  p.printOperand(getNumElements());
  p << " : " << getTagMemRef().getType();
}

// Parse AffineDmaWaitOp.
// Eg:
//   affine.dma_wait %tag[%index], %num_elements
//     : memref<1 x i32, (d0) -> (d0), 4>
//
ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
                                   OperationState &result) {
  OpAsmParser::OperandType tagMemRefInfo;
  AffineMapAttr tagMapAttr;
  SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
  Type type;
  auto indexType = parser.getBuilder().getIndexType();
  OpAsmParser::OperandType numElementsInfo;

  // Parse tag memref, its map operands, and dma size.
  if (parser.parseOperand(tagMemRefInfo) ||
      parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
                                    getTagMapAttrName(), result.attributes) ||
      parser.parseComma() || parser.parseOperand(numElementsInfo) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
      parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
      parser.resolveOperand(numElementsInfo, indexType, result.operands))
    return failure();

  if (!type.isa<MemRefType>())
    return parser.emitError(parser.getNameLoc(),
                            "expected tag to be of memref type");

  if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
    return parser.emitError(parser.getNameLoc(),
                            "tag memref operand count != to map.numInputs");
  return success();
}

LogicalResult AffineDmaWaitOp::verify() {
  if (!getOperand(0).getType().isa<MemRefType>())
    return emitOpError("expected DMA tag to be of memref type");
  Region *scope = getAffineScope(*this);
  for (auto idx : getTagIndices()) {
    if (!idx.getType().isIndex())
      return emitOpError("index to dma_wait must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return emitOpError("index must be a dimension or symbol identifier");
  }
  return success();
}

LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
                                    SmallVectorImpl<OpFoldResult> &results) {
  /// dma_wait(memrefcast) -> dma_wait
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// AffineForOp
//===----------------------------------------------------------------------===//

/// 'bodyBuilder' is used to build the body of affine.for. If 'iterArgs' and
/// 'bodyBuilder' are empty/null, a default terminator op is inserted.
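/// For example, a caller could create a loop that carries one value across
/// iterations (an illustrative sketch; `b`, `loc` and `init` are assumed to be
/// an OpBuilder, a Location and a Value already in scope):
///
///   auto loop = b.create<AffineForOp>(
///       loc, /*lb=*/0, /*ub=*/16, /*step=*/1, /*iterArgs=*/ValueRange{init},
///       [](OpBuilder &nested, Location nestedLoc, Value iv,
///          ValueRange args) {
///         // Forward the loop-carried value unchanged.
///         nested.create<AffineYieldOp>(nestedLoc, args);
///       });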
void AffineForOp::build(OpBuilder &builder, OperationState &result,
                        ValueRange lbOperands, AffineMap lbMap,
                        ValueRange ubOperands, AffineMap ubMap, int64_t step,
                        ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
  assert(((!lbMap && lbOperands.empty()) ||
          lbOperands.size() == lbMap.getNumInputs()) &&
         "lower bound operand count does not match the affine map");
  assert(((!ubMap && ubOperands.empty()) ||
          ubOperands.size() == ubMap.getNumInputs()) &&
         "upper bound operand count does not match the affine map");
  assert(step > 0 && "step has to be a positive integer constant");

  for (Value val : iterArgs)
    result.addTypes(val.getType());

  // Add an attribute for the step.
  result.addAttribute(getStepAttrName(),
                      builder.getIntegerAttr(builder.getIndexType(), step));

  // Add the lower bound.
  result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap));
  result.addOperands(lbOperands);

  // Add the upper bound.
  result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap));
  result.addOperands(ubOperands);

  result.addOperands(iterArgs);
  // Create a region and a block for the body.  The argument of the region is
  // the loop induction variable.
  Region *bodyRegion = result.addRegion();
  bodyRegion->push_back(new Block);
  Block &bodyBlock = bodyRegion->front();
  Value inductionVar = bodyBlock.addArgument(builder.getIndexType());
  for (Value val : iterArgs)
    bodyBlock.addArgument(val.getType());

  // Create the default terminator if the builder is not provided and if the
  // iteration arguments are not provided. Otherwise, leave this to the caller
  // because we don't know which values to return from the loop.
  if (iterArgs.empty() && !bodyBuilder) {
    ensureTerminator(*bodyRegion, builder, result.location);
  } else if (bodyBuilder) {
    OpBuilder::InsertionGuard guard(builder);
    builder.setInsertionPointToStart(&bodyBlock);
    bodyBuilder(builder, result.location, inductionVar,
                bodyBlock.getArguments().drop_front());
  }
}

void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
                        int64_t ub, int64_t step, ValueRange iterArgs,
                        BodyBuilderFn bodyBuilder) {
  auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
  auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
  return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
               bodyBuilder);
}

static LogicalResult verify(AffineForOp op) {
  // Check that the body defines a single block argument for the induction
  // variable.
  auto *body = op.getBody();
  if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
    return op.emitOpError(
        "expected body to have a single index argument for the "
        "induction variable");

  // Verify that the bound operands are valid dimension/symbols.
  /// Lower bound.
  if (op.getLowerBoundMap().getNumInputs() > 0)
    if (failed(
            verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
                                          op.getLowerBoundMap().getNumDims())))
      return failure();
  /// Upper bound.
  if (op.getUpperBoundMap().getNumInputs() > 0)
    if (failed(
            verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
                                          op.getUpperBoundMap().getNumDims())))
      return failure();

  unsigned opNumResults = op.getNumResults();
  if (opNumResults == 0)
    return success();

  // If the ForOp defines values, check that the number and types of the
  // defined values match the ForOp's initial iter operands and the backedge
  // basic block arguments.
  if (op.getNumIterOperands() != opNumResults)
    return op.emitOpError(
        "mismatch between the number of loop-carried values and results");
  if (op.getNumRegionIterArgs() != opNumResults)
    return op.emitOpError(
        "mismatch between the number of basic block args and results");

  return success();
}

/// Parse a for operation loop bounds.
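/// A bound is either a bare SSA symbol, an integer constant, or the full form
/// of an affine map followed by its dim and symbol operands. For example
/// (illustrative):
///
///   affine.for %i = %lb to 128 step 2 { ... }
///   affine.for %i = 0 to min affine_map<()[s0] -> (s0, 100)>()[%n] { ... }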
static ParseResult parseBound(bool isLower, OperationState &result,
                              OpAsmParser &p) {
  // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
  // the map has multiple results.
  bool failedToParsedMinMax =
      failed(p.parseOptionalKeyword(isLower ? "max" : "min"));

  auto &builder = p.getBuilder();
  auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
                               : AffineForOp::getUpperBoundAttrName();

  // Parse ssa-id as identity map.
  SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
  if (p.parseOperandList(boundOpInfos))
    return failure();

  if (!boundOpInfos.empty()) {
    // Check that only one operand was parsed.
    if (boundOpInfos.size() > 1)
      return p.emitError(p.getNameLoc(),
                         "expected only one loop bound operand");

    // TODO: improve error message when SSA value is not of index type.
    // Currently it is 'use of value ... expects different type than prior uses'
    if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
                         result.operands))
      return failure();

    // Create an identity map using symbol id. This representation is optimized
    // for storage. Analysis passes may expand it into a multi-dimensional map
    // if desired.
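    // E.g. (illustrative): a bare `%n` bound is stored as the single-symbol
    // identity map ()[s0] -> (s0) with `%n` as its only operand.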
    AffineMap map = builder.getSymbolIdentityMap();
    result.addAttribute(boundAttrName, AffineMapAttr::get(map));
    return success();
  }

  // Get the attribute location.
  llvm::SMLoc attrLoc = p.getCurrentLocation();

  Attribute boundAttr;
  if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
                       result.attributes))
    return failure();

  // Parse full form - affine map followed by dim and symbol list.
  if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
    unsigned currentNumOperands = result.operands.size();
    unsigned numDims;
    if (parseDimAndSymbolList(p, result.operands, numDims))
      return failure();

    auto map = affineMapAttr.getValue();
    if (map.getNumDims() != numDims)
      return p.emitError(
          p.getNameLoc(),
          "dim operand count and affine map dim count must match");

    unsigned numDimAndSymbolOperands =
        result.operands.size() - currentNumOperands;
    if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
      return p.emitError(
          p.getNameLoc(),
          "symbol operand count and affine map symbol count must match");

    // If the map has multiple results, make sure that we parsed the min/max
    // prefix.
    if (map.getNumResults() > 1 && failedToParsedMinMax) {
      if (isLower) {
        return p.emitError(attrLoc, "lower loop bound affine map with "
                                    "multiple results requires 'max' prefix");
      }
      return p.emitError(attrLoc, "upper loop bound affine map with multiple "
                                  "results requires 'min' prefix");
    }
    return success();
  }

  // Parse custom assembly form.
  if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
    result.attributes.pop_back();
    result.addAttribute(
        boundAttrName,
        AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
    return success();
  }

  return p.emitError(
      p.getNameLoc(),
      "expected valid affine map representation for loop bounds");
}

static ParseResult parseAffineForOp(OpAsmParser &parser,
                                    OperationState &result) {
  auto &builder = parser.getBuilder();
  OpAsmParser::OperandType inductionVariable;
  // Parse the induction variable followed by '='.
  if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
    return failure();

  // Parse loop bounds.
  if (parseBound(/*isLower=*/true, result, parser) ||
      parser.parseKeyword("to", " between bounds") ||
      parseBound(/*isLower=*/false, result, parser))
    return failure();

  // Parse the optional loop step; default to 1 if it is not present.
  if (parser.parseOptionalKeyword("step")) {
    result.addAttribute(
        AffineForOp::getStepAttrName(),
        builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
  } else {
    llvm::SMLoc stepLoc = parser.getCurrentLocation();
    IntegerAttr stepAttr;
    if (parser.parseAttribute(stepAttr, builder.getIndexType(),
                              AffineForOp::getStepAttrName().data(),
                              result.attributes))
      return failure();

    if (stepAttr.getValue().getSExtValue() < 0)
      return parser.emitError(
          stepLoc,
          "expected step to be representable as a positive signed integer");
  }

  // Parse the optional initial iteration arguments.
  SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
  SmallVector<Type, 4> argTypes;
  regionArgs.push_back(inductionVariable);

  if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
    // Parse the assignment list and the result type list.
    if (parser.parseAssignmentList(regionArgs, operands) ||
        parser.parseArrowTypeList(result.types))
      return failure();
    // Resolve input operands.
    for (auto operandType : llvm::zip(operands, result.types))
      if (parser.resolveOperand(std::get<0>(operandType),
                                std::get<1>(operandType), result.operands))
        return failure();
  }
  // Induction variable.
  Type indexType = builder.getIndexType();
  argTypes.push_back(indexType);
  // Loop-carried variables.
  argTypes.append(result.types.begin(), result.types.end());
  // Parse the body region.
  Region *body = result.addRegion();
  if (regionArgs.size() != argTypes.size())
    return parser.emitError(
        parser.getNameLoc(),
        "mismatch between the number of loop-carried values and results");
  if (parser.parseRegion(*body, regionArgs, argTypes))
    return failure();

  AffineForOp::ensureTerminator(*body, builder, result.location);

  // Parse the optional attribute list.
  return parser.parseOptionalAttrDict(result.attributes);
}

static void printBound(AffineMapAttr boundMap,
                       Operation::operand_range boundOperands,
                       const char *prefix, OpAsmPrinter &p) {
  AffineMap map = boundMap.getValue();

  // Check if this bound should be printed using the custom assembly form.
  // Printing the custom assembly form is restricted to trivial cases so that
  // MLIR can round-trip binary -> text -> binary losslessly.
  // Therefore, custom assembly form parsing and printing is only supported for
  // zero-operand constant maps and single-symbol-operand identity maps.
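  // E.g. (illustrative): such bounds print as `0` or `%ub`, while anything
  // else uses the full form, e.g. `min affine_map<()[s0] -> (s0, 100)>()[%n]`.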
  if (map.getNumResults() == 1) {
    AffineExpr expr = map.getResult(0);

    // Print constant bound.
    if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
      if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
        p << constExpr.getValue();
        return;
      }
    }

    // Print bound that consists of a single SSA symbol if the map is over a
    // single symbol.
    if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
      if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
        p.printOperand(*boundOperands.begin());
        return;
      }
    }
  } else {
    // Map has multiple results. Print 'min' or 'max' prefix.
    p << prefix << ' ';
  }

  // Print the map and its operands.
  p << boundMap;
  printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
                        map.getNumDims(), p);
}

unsigned AffineForOp::getNumIterOperands() {
  AffineMap lbMap = getLowerBoundMapAttr().getValue();
  AffineMap ubMap = getUpperBoundMapAttr().getValue();

  return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
}

static void print(OpAsmPrinter &p, AffineForOp op) {
  p << op.getOperationName() << ' ';
  p.printOperand(op.getBody()->getArgument(0));
  p << " = ";
  printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
  p << " to ";
  printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);

  if (op.getStep() != 1)
    p << " step " << op.getStep();

  bool printBlockTerminators = false;
  if (op.getNumIterOperands() > 0) {
    p << " iter_args(";
    auto regionArgs = op.getRegionIterArgs();
    auto operands = op.getIterOperands();

    llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
      p << std::get<0>(it) << " = " << std::get<1>(it);
    });
    p << ") -> (" << op.getResultTypes() << ")";
    printBlockTerminators = true;
  }

  p.printRegion(op.region(),
                /*printEntryBlockArgs=*/false, printBlockTerminators);
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/{op.getLowerBoundAttrName(),
                                           op.getUpperBoundAttrName(),
                                           op.getStepAttrName()});
}

/// Fold the constant bounds of a loop.
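/// E.g. (illustrative): with %c4 defined by `constant 4 : index`, a lower
/// bound `max affine_map<()[s0] -> (0, s0)>()[%c4]` folds to the constant
/// lower bound 4.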
static LogicalResult foldLoopBounds(AffineForOp forOp) {
  auto foldLowerOrUpperBound = [&forOp](bool lower) {
    // Check to see if each of the operands is the result of a constant.  If
    // so, get the value.  If not, ignore it.
    SmallVector<Attribute, 8> operandConstants;
    auto boundOperands =
        lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
    for (auto operand : boundOperands) {
      Attribute operandCst;
      matchPattern(operand, m_Constant(&operandCst));
      operandConstants.push_back(operandCst);
    }

    AffineMap boundMap =
        lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
    assert(boundMap.getNumResults() >= 1 &&
           "bound maps should have at least one result");
    SmallVector<Attribute, 4> foldedResults;
    if (failed(boundMap.constantFold(operandConstants, foldedResults)))
      return failure();

    // Compute the max or min as applicable over the results.
    assert(!foldedResults.empty() && "bounds should have at least one result");
    auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
    for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
      auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
      maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
                       : llvm::APIntOps::smin(maxOrMin, foldedResult);
    }
    lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
          : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
    return success();
  };

  // Try to fold the lower bound.
  bool folded = false;
  if (!forOp.hasConstantLowerBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));

  // Try to fold the upper bound.
  if (!forOp.hasConstantUpperBound())
    folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
  return success(folded);
}

/// Canonicalize the bounds of the given loop.
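/// E.g. (illustrative): duplicate operands and duplicate result expressions
/// are removed, so a bound such as `min(%n, %n)` simplifies to a single-result
/// map over a single `%n` operand.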
static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
  SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
  SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());

  auto lbMap = forOp.getLowerBoundMap();
  auto ubMap = forOp.getUpperBoundMap();
  auto prevLbMap = lbMap;
  auto prevUbMap = ubMap;

  canonicalizeMapAndOperands(&lbMap, &lbOperands);
  lbMap = removeDuplicateExprs(lbMap);

  canonicalizeMapAndOperands(&ubMap, &ubOperands);
  ubMap = removeDuplicateExprs(ubMap);

  // Any canonicalization change always leads to updated map(s).
  if (lbMap == prevLbMap && ubMap == prevUbMap)
    return failure();

  if (lbMap != prevLbMap)
    forOp.setLowerBound(lbOperands, lbMap);
  if (ubMap != prevUbMap)
    forOp.setUpperBound(ubOperands, ubMap);
  return success();
}

namespace {
/// This is a pattern to fold trivially empty loops.
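/// E.g. (illustrative): `affine.for %i = 0 to 16 { }`, whose body holds only
/// the implicitly created affine.yield terminator, is erased.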
struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
  using OpRewritePattern<AffineForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineForOp forOp,
                                PatternRewriter &rewriter) const override {
    // Check that the body only contains a yield.
    if (!llvm::hasSingleElement(*forOp.getBody()))
      return failure();
    rewriter.eraseOp(forOp);
    return success();
  }
};
} // end anonymous namespace

void AffineForOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                              MLIRContext *context) {
  results.insert<AffineForEmptyLoopFolder>(context);
}

LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
                                SmallVectorImpl<OpFoldResult> &results) {
  bool folded = succeeded(foldLoopBounds(*this));
  folded |= succeeded(canonicalizeLoopBounds(*this));
  return success(folded);
}

AffineBound AffineForOp::getLowerBound() {
  auto lbMap = getLowerBoundMap();
  return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
}

AffineBound AffineForOp::getUpperBound() {
  auto lbMap = getLowerBoundMap();
  auto ubMap = getUpperBoundMap();
  return AffineBound(AffineForOp(*this), lbMap.getNumInputs(),
                     lbMap.getNumInputs() + ubMap.getNumInputs(), ubMap);
}

void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
  assert(lbOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");

  SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());

  auto ubOperands = getUpperBoundOperands();
  newOperands.append(ubOperands.begin(), ubOperands.end());
  auto iterOperands = getIterOperands();
  newOperands.append(iterOperands.begin(), iterOperands.end());
  (*this)->setOperands(newOperands);

  (*this)->setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
}

void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
  assert(ubOperands.size() == map.getNumInputs());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");

  SmallVector<Value, 4> newOperands(getLowerBoundOperands());
  newOperands.append(ubOperands.begin(), ubOperands.end());
  auto iterOperands = getIterOperands();
  newOperands.append(iterOperands.begin(), iterOperands.end());
  (*this)->setOperands(newOperands);

  (*this)->setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
}

void AffineForOp::setLowerBoundMap(AffineMap map) {
  auto lbMap = getLowerBoundMap();
  assert(lbMap.getNumDims() == map.getNumDims() &&
         lbMap.getNumSymbols() == map.getNumSymbols());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  (void)lbMap;
  (*this)->setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
}

void AffineForOp::setUpperBoundMap(AffineMap map) {
  auto ubMap = getUpperBoundMap();
  assert(ubMap.getNumDims() == map.getNumDims() &&
         ubMap.getNumSymbols() == map.getNumSymbols());
  assert(map.getNumResults() >= 1 && "bound map has at least one result");
  (void)ubMap;
  (*this)->setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
}

bool AffineForOp::hasConstantLowerBound() {
  return getLowerBoundMap().isSingleConstant();
}

bool AffineForOp::hasConstantUpperBound() {
  return getUpperBoundMap().isSingleConstant();
}

int64_t AffineForOp::getConstantLowerBound() {
  return getLowerBoundMap().getSingleConstantResult();
}

int64_t AffineForOp::getConstantUpperBound() {
  return getUpperBoundMap().getSingleConstantResult();
}

void AffineForOp::setConstantLowerBound(int64_t value) {
  setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
}

void AffineForOp::setConstantUpperBound(int64_t value) {
  setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
}

AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
  return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
}

AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
  return {operand_begin() + getLowerBoundMap().getNumInputs(),
          operand_begin() + getLowerBoundMap().getNumInputs() +
              getUpperBoundMap().getNumInputs()};
}

bool AffineForOp::matchingBoundOperandList() {
  auto lbMap = getLowerBoundMap();
  auto ubMap = getUpperBoundMap();
  if (lbMap.getNumDims() != ubMap.getNumDims() ||
      lbMap.getNumSymbols() != ubMap.getNumSymbols())
    return false;

  unsigned numOperands = lbMap.getNumInputs();
  for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
    // Compare Values.
    if (getOperand(i) != getOperand(numOperands + i))
      return false;
  }
  return true;
}

Region &AffineForOp::getLoopBody() { return region(); }

bool AffineForOp::isDefinedOutsideOfLoop(Value value) {
  return !region().isAncestor(value.getParentRegion());
}

LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
  for (auto *op : ops)
    op->moveBefore(*this);
  return success();
}

/// Returns true if the provided value is the induction variable of an
/// AffineForOp.
bool mlir::isForInductionVar(Value val) {
  return getForInductionVarOwner(val) != AffineForOp();
}

/// Returns the loop parent of an induction variable. If the provided value is
/// not an induction variable, returns a null AffineForOp.
AffineForOp mlir::getForInductionVarOwner(Value val) {
  auto ivArg = val.dyn_cast<BlockArgument>();
  if (!ivArg || !ivArg.getOwner())
    return AffineForOp();
  auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
  return dyn_cast<AffineForOp>(containingInst);
}

/// Extracts the induction variables from a list of AffineForOps and returns
/// them.
void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
                                   SmallVectorImpl<Value> *ivs) {
  ivs->reserve(forInsts.size());
  for (auto forInst : forInsts)
    ivs->push_back(forInst.getInductionVar());
}

/// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
/// operations.
template <typename BoundListTy, typename LoopCreatorTy>
static void buildAffineLoopNestImpl(
    OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
    LoopCreatorTy &&loopCreatorFn) {
  assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
  assert(lbs.size() == steps.size() && "Mismatch in number of arguments");

  // If there are no loops to be constructed, construct the body anyway.
  OpBuilder::InsertionGuard guard(builder);
  if (lbs.empty()) {
    if (bodyBuilderFn)
      bodyBuilderFn(builder, loc, ValueRange());
    return;
  }

  // Create the loops iteratively and store the induction variables.
  SmallVector<Value, 4> ivs;
  ivs.reserve(lbs.size());
  for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
    // Callback for creating the loop body, always creates the terminator.
    auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
                        ValueRange iterArgs) {
      ivs.push_back(iv);
      // In the innermost loop, call the body builder.
      if (i == e - 1 && bodyBuilderFn) {
        OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
        bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
      }
      nestedBuilder.create<AffineYieldOp>(nestedLoc);
    };

    // Delegate actual loop creation to the callback in order to dispatch
    // between constant- and variable-bound loops.
    auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
    builder.setInsertionPointToStart(loop.getBody());
  }
}

/// Creates an affine loop from the bounds known to be constants.
static AffineForOp
buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
                             int64_t ub, int64_t step,
                             AffineForOp::BodyBuilderFn bodyBuilderFn) {
  return builder.create<AffineForOp>(loc, lb, ub, step, /*iterArgs=*/llvm::None,
                                     bodyBuilderFn);
}

/// Creates an affine loop from the bounds that may or may not be constants.
static AffineForOp
buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
                          int64_t step,
                          AffineForOp::BodyBuilderFn bodyBuilderFn) {
  auto lbConst = lb.getDefiningOp<ConstantIndexOp>();
  auto ubConst = ub.getDefiningOp<ConstantIndexOp>();
  if (lbConst && ubConst)
    return buildAffineLoopFromConstants(builder, loc, lbConst.getValue(),
                                        ubConst.getValue(), step,
                                        bodyBuilderFn);
  return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
                                     builder.getDimIdentityMap(), step,
                                     /*iterArgs=*/llvm::None, bodyBuilderFn);
}

void mlir::buildAffineLoopNest(
    OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
    ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
                          buildAffineLoopFromConstants);
}

void mlir::buildAffineLoopNest(
    OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
    ArrayRef<int64_t> steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
                          buildAffineLoopFromValues);
}
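
// For example, a caller could materialize a 2-d nest over constant bounds as
// follows (an illustrative sketch; `b`, `loc` and `memref` are assumed to be
// an OpBuilder, a Location and a Value of memref type already in scope):
//
//   buildAffineLoopNest(b, loc, /*lbs=*/{0, 0}, /*ubs=*/{128, 64},
//                       /*steps=*/{1, 1},
//                       [&](OpBuilder &nested, Location nestedLoc,
//                           ValueRange ivs) {
//                         // Load the element addressed by the two IVs.
//                         nested.create<AffineLoadOp>(nestedLoc, memref, ivs);
//                       });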

//===----------------------------------------------------------------------===//
// AffineIfOp
//===----------------------------------------------------------------------===//

namespace {
/// Remove else blocks that have nothing other than a zero value yield.
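/// E.g. (illustrative): in `affine.if #set(%i) { ... } else { }`, where the
/// else block holds only its implicit zero-operand affine.yield and the op has
/// no results, the else block is erased.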
struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
  using OpRewritePattern<AffineIfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(AffineIfOp ifOp,
                                PatternRewriter &rewriter) const override {
    if (ifOp.elseRegion().empty() ||
        !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
      return failure();

    rewriter.startRootUpdate(ifOp);
    rewriter.eraseBlock(ifOp.getElseBlock());
    rewriter.finalizeRootUpdate(ifOp);
    return success();
  }
};
} // end anonymous namespace.

static LogicalResult verify(AffineIfOp op) {
  // Verify that we have a condition attribute.
  auto conditionAttr =
      op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
  if (!conditionAttr)
    return op.emitOpError(
        "requires an integer set attribute named 'condition'");

  // Verify that there are enough operands for the condition.
  IntegerSet condition = conditionAttr.getValue();
  if (op.getNumOperands() != condition.getNumInputs())
    return op.emitOpError(
        "operand count and condition integer set dimension and "
        "symbol count must match");

  // Verify that the operands are valid dimension/symbols.
  if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(),
                                           condition.getNumDims())))
    return failure();

  return success();
}

static ParseResult parseAffineIfOp(OpAsmParser &parser,
                                   OperationState &result) {
  // Parse the condition attribute set.
  IntegerSetAttr conditionAttr;
  unsigned numDims;
  if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(),
                            result.attributes) ||
      parseDimAndSymbolList(parser, result.operands, numDims))
    return failure();

  // Verify the condition operands.
  auto set = conditionAttr.getValue();
  if (set.getNumDims() != numDims)
    return parser.emitError(
        parser.getNameLoc(),
        "dim operand count and integer set dim count must match");
  if (numDims + set.getNumSymbols() != result.operands.size())
    return parser.emitError(
        parser.getNameLoc(),
        "symbol operand count and integer set symbol count must match");

  if (parser.parseOptionalArrowTypeList(result.types))
    return failure();

  // Create the regions for 'then' and 'else'.  The latter must be created even
  // if it remains empty for the operation to be valid.
  result.regions.reserve(2);
  Region *thenRegion = result.addRegion();
  Region *elseRegion = result.addRegion();

  // Parse the 'then' region.
  if (parser.parseRegion(*thenRegion, {}, {}))
    return failure();
  AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
                               result.location);

  // If we find an 'else' keyword then parse the 'else' region.
  if (!parser.parseOptionalKeyword("else")) {
    if (parser.parseRegion(*elseRegion, {}, {}))
      return failure();
    AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
                                 result.location);
  }

  // Parse the optional attribute list.
  if (parser.parseOptionalAttrDict(result.attributes))
    return failure();

  return success();
}

static void print(OpAsmPrinter &p, AffineIfOp op) {
  auto conditionAttr =
      op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
  p << "affine.if " << conditionAttr;
  printDimAndSymbolList(op.operand_begin(), op.operand_end(),
                        conditionAttr.getValue().getNumDims(), p);
  p.printOptionalArrowTypeList(op.getResultTypes());
  p.printRegion(op.thenRegion(),
                /*printEntryBlockArgs=*/false,
                /*printBlockTerminators=*/op.getNumResults());

  // Print the 'else' region if it has any blocks.
  auto &elseRegion = op.elseRegion();
  if (!elseRegion.empty()) {
    p << " else";
    p.printRegion(elseRegion,
                  /*printEntryBlockArgs=*/false,
                  /*printBlockTerminators=*/op.getNumResults());
  }

  // Print the attribute list.
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/op.getConditionAttrName());
}

IntegerSet AffineIfOp::getIntegerSet() {
  return (*this)
      ->getAttrOfType<IntegerSetAttr>(getConditionAttrName())
      .getValue();
}
void AffineIfOp::setIntegerSet(IntegerSet newSet) {
  (*this)->setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
}

void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
  setIntegerSet(set);
  (*this)->setOperands(operands);
}

void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       TypeRange resultTypes, IntegerSet set, ValueRange args,
                       bool withElseRegion) {
  assert(resultTypes.empty() || withElseRegion);
  result.addTypes(resultTypes);
  result.addOperands(args);
  result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set));

  Region *thenRegion = result.addRegion();
  thenRegion->push_back(new Block());
  if (resultTypes.empty())
    AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);

  Region *elseRegion = result.addRegion();
  if (withElseRegion) {
    elseRegion->push_back(new Block());
    if (resultTypes.empty())
      AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
  }
}

void AffineIfOp::build(OpBuilder &builder, OperationState &result,
                       IntegerSet set, ValueRange args, bool withElseRegion) {
  AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
                    withElseRegion);
}

/// Canonicalize an affine if op's conditional (integer set + operands).
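/// E.g. (illustrative): duplicate or unused operands of the set are dropped
/// and dims that are valid symbols are promoted, so a condition applied to
/// (%n, %n) becomes one applied to a single operand.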
LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
                               SmallVectorImpl<OpFoldResult> &) {
  auto set = getIntegerSet();
  SmallVector<Value, 4> operands(getOperands());
  canonicalizeSetAndOperands(&set, &operands);

  // Any canonicalization change always leads to either a reduction in the
  // number of operands or a change in the number of symbolic operands
  // (promotion of dims to symbols).
  if (operands.size() < getIntegerSet().getNumInputs() ||
      set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
    setConditional(set, operands);
    return success();
  }

  return failure();
}

void AffineIfOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
                                             MLIRContext *context) {
  results.insert<SimplifyDeadElse>(context);
}

//===----------------------------------------------------------------------===//
// AffineLoadOp
//===----------------------------------------------------------------------===//

void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         AffineMap map, ValueRange operands) {
  assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
  result.addOperands(operands);
  if (map)
    result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
  auto memrefType = operands[0].getType().cast<MemRefType>();
  result.types.push_back(memrefType.getElementType());
}

void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         Value memref, AffineMap map, ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(memref);
  result.addOperands(mapOperands);
  auto memrefType = memref.getType().cast<MemRefType>();
  result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
  result.types.push_back(memrefType.getElementType());
}

void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
                         Value memref, ValueRange indices) {
  auto memrefType = memref.getType().cast<MemRefType>();
  int64_t rank = memrefType.getRank();
  // Create identity map for memrefs with at least one dimension or () -> ()
  // for zero-dimensional memrefs.
  auto map =
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
  build(builder, result, memref, map, indices);
}

static ParseResult parseAffineLoadOp(OpAsmParser &parser,
                                     OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexTy = builder.getIndexType();

  MemRefType type;
  OpAsmParser::OperandType memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
  return failure(
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineLoadOp::getMapAttrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(type) ||
      parser.resolveOperand(memrefInfo, type, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands) ||
      parser.addTypeToList(type.getElementType(), result.types));
}

static void print(OpAsmPrinter &p, AffineLoadOp op) {
  p << "affine.load " << op.getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
  p << ']';
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
  p << " : " << op.getMemRefType();
}

/// Verify common indexing invariants of affine.load, affine.store,
/// affine.vector_load and affine.vector_store.
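/// E.g. (illustrative): for `affine.load %A[%i + 3, %j] : memref<?x?xf32>` the
/// map has two results, matching the memref rank, and two inputs, matching the
/// two subscript operands.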
static LogicalResult
verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
                       Operation::operand_range mapOperands,
                       MemRefType memrefType, unsigned numIndexOperands) {
  if (mapAttr) {
    AffineMap map = mapAttr.getValue();
    if (map.getNumResults() != memrefType.getRank())
      return op->emitOpError("affine map num results must equal memref rank");
    if (map.getNumInputs() != numIndexOperands)
      return op->emitOpError("expects as many subscripts as affine map inputs");
  } else {
    if (memrefType.getRank() != numIndexOperands)
      return op->emitOpError(
          "expects the number of subscripts to be equal to memref rank");
  }

  Region *scope = getAffineScope(op);
  for (auto idx : mapOperands) {
    if (!idx.getType().isIndex())
      return op->emitOpError("index to load must have 'index' type");
    if (!isValidAffineIndexOperand(idx, scope))
      return op->emitOpError("index must be a dimension or symbol identifier");
  }

  return success();
}

LogicalResult verify(AffineLoadOp op) {
  auto memrefType = op.getMemRefType();
  if (op.getType() != memrefType.getElementType())
    return op.emitOpError("result type must match element type of memref");

  if (failed(verifyMemoryOpIndexing(
          op.getOperation(),
          op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
          op.getMapOperands(), memrefType,
          /*numIndexOperands=*/op.getNumOperands() - 1)))
    return failure();

  return success();
}

void AffineLoadOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyAffineOp<AffineLoadOp>>(context);
}

OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
  /// load(memrefcast) -> load
  if (succeeded(foldMemRefCast(*this)))
    return getResult();
  return OpFoldResult();
}

//===----------------------------------------------------------------------===//
// AffineStoreOp
//===----------------------------------------------------------------------===//

void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
                          Value valueToStore, Value memref, AffineMap map,
                          ValueRange mapOperands) {
  assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
  result.addOperands(valueToStore);
  result.addOperands(memref);
  result.addOperands(mapOperands);
  result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
}

// Use identity map.
void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
                          Value valueToStore, Value memref,
                          ValueRange indices) {
  auto memrefType = memref.getType().cast<MemRefType>();
  int64_t rank = memrefType.getRank();
  // Create identity map for memrefs with at least one dimension or () -> ()
  // for zero-dimensional memrefs.
  auto map =
      rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
  build(builder, result, valueToStore, memref, map, indices);
}

static ParseResult parseAffineStoreOp(OpAsmParser &parser,
                                      OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType type;
  OpAsmParser::OperandType storeValueInfo;
  OpAsmParser::OperandType memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
  return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
                 parser.parseOperand(memrefInfo) ||
                 parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                               AffineStoreOp::getMapAttrName(),
                                               result.attributes) ||
                 parser.parseOptionalAttrDict(result.attributes) ||
                 parser.parseColonType(type) ||
                 parser.resolveOperand(storeValueInfo, type.getElementType(),
                                       result.operands) ||
                 parser.resolveOperand(memrefInfo, type, result.operands) ||
                 parser.resolveOperands(mapOperands, indexTy, result.operands));
}

static void print(OpAsmPrinter &p, AffineStoreOp op) {
  p << "affine.store " << op.getValueToStore();
  p << ", " << op.getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
  p << ']';
  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
  p << " : " << op.getMemRefType();
}

LogicalResult verify(AffineStoreOp op) {
  // The first operand must have the same type as the memref element type.
  auto memrefType = op.getMemRefType();
  if (op.getValueToStore().getType() != memrefType.getElementType())
    return op.emitOpError(
        "first operand must have same type as memref element type");

  if (failed(verifyMemoryOpIndexing(
          op.getOperation(),
          op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
          op.getMapOperands(), memrefType,
          /*numIndexOperands=*/op.getNumOperands() - 2)))
    return failure();

  return success();
}

void AffineStoreOp::getCanonicalizationPatterns(
    OwningRewritePatternList &results, MLIRContext *context) {
  results.insert<SimplifyAffineOp<AffineStoreOp>>(context);
}

LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
                                  SmallVectorImpl<OpFoldResult> &results) {
  /// store(memrefcast) -> store
  return foldMemRefCast(*this);
}

//===----------------------------------------------------------------------===//
// AffineMinMaxOpBase
//===----------------------------------------------------------------------===//

template <typename T> static LogicalResult verifyAffineMinMaxOp(T op) {
  // Verify that operand count matches affine map dimension and symbol count.
  if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
    return op.emitOpError(
        "operand count and affine map dimension and symbol count must match");
  return success();
}

template <typename T> static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
  p << op.getOperationName() << ' ' << op->getAttr(T::getMapAttrName());
  auto operands = op.getOperands();
  unsigned numDims = op.map().getNumDims();
  p << '(' << operands.take_front(numDims) << ')';

  if (operands.size() != numDims)
    p << '[' << operands.drop_front(numDims) << ']';
  p.printOptionalAttrDict(op.getAttrs(),
                          /*elidedAttrs=*/{T::getMapAttrName()});
}

template <typename T>
static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
                                       OperationState &result) {
  auto &builder = parser.getBuilder();
  auto indexType = builder.getIndexType();
  SmallVector<OpAsmParser::OperandType, 8> dim_infos;
  SmallVector<OpAsmParser::OperandType, 8> sym_infos;
  AffineMapAttr mapAttr;
  return failure(
      parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) ||
      parser.parseOperandList(dim_infos, OpAsmParser::Delimiter::Paren) ||
      parser.parseOperandList(sym_infos,
                              OpAsmParser::Delimiter::OptionalSquare) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.resolveOperands(dim_infos, indexType, result.operands) ||
      parser.resolveOperands(sym_infos, indexType, result.operands) ||
      parser.addTypeToList(indexType, result.types));
}

/// Fold an affine min or max operation with the given operands. The operand
/// list may contain nulls, which are interpreted as the operand not being a
/// constant.
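/// E.g. (illustrative): with %c4 constant,
/// `affine.min affine_map<(d0) -> (d0, 10)>(%c4)` folds to the constant 4;
/// if only some results fold, the op keeps the partially simplified map.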
template <typename T>
static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
  static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
                "expected affine min or max op");

  // Fold the affine map.
  // TODO: Fold more cases:
  // min(some_affine, some_affine + constant, ...), etc.
  SmallVector<int64_t, 2> results;
  auto foldedMap = op.map().partialConstantFold(operands, &results);

  // If some of the map results are not constant, try changing the map in-place.
  if (results.empty()) {
    // If the map is the same, report that folding did not happen.
    if (foldedMap == op.map())
      return {};
    op->setAttr("map", AffineMapAttr::get(foldedMap));
    return op.getResult();
  }

  // Otherwise, completely fold the op into a constant.
  auto resultIt = std::is_same<T, AffineMinOp>::value
                      ? std::min_element(results.begin(), results.end())
                      : std::max_element(results.begin(), results.end());
  if (resultIt == results.end())
    return {};
  return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
}

//===----------------------------------------------------------------------===//
// AffineMinOp
//===----------------------------------------------------------------------===//
//
//   %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
//

OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
  return foldMinMaxOp(*this, operands);
}

void AffineMinOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<SimplifyAffineOp<AffineMinOp>>(context);
}

//===----------------------------------------------------------------------===//
// AffineMaxOp
//===----------------------------------------------------------------------===//
//
//   %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
//

OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
  return foldMinMaxOp(*this, operands);
}

void AffineMaxOp::getCanonicalizationPatterns(
    OwningRewritePatternList &patterns, MLIRContext *context) {
  patterns.insert<SimplifyAffineOp<AffineMaxOp>>(context);
}

2356 //===----------------------------------------------------------------------===//
2357 // AffinePrefetchOp
2358 //===----------------------------------------------------------------------===//
2359 
2360 //
2361 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
2362 //
parseAffinePrefetchOp(OpAsmParser & parser,OperationState & result)2363 static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
2364                                          OperationState &result) {
2365   auto &builder = parser.getBuilder();
2366   auto indexTy = builder.getIndexType();
2367 
2368   MemRefType type;
2369   OpAsmParser::OperandType memrefInfo;
2370   IntegerAttr hintInfo;
2371   auto i32Type = parser.getBuilder().getIntegerType(32);
2372   StringRef readOrWrite, cacheType;
2373 
2374   AffineMapAttr mapAttr;
2375   SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2376   if (parser.parseOperand(memrefInfo) ||
2377       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2378                                     AffinePrefetchOp::getMapAttrName(),
2379                                     result.attributes) ||
2380       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
2381       parser.parseComma() || parser.parseKeyword("locality") ||
2382       parser.parseLess() ||
2383       parser.parseAttribute(hintInfo, i32Type,
2384                             AffinePrefetchOp::getLocalityHintAttrName(),
2385                             result.attributes) ||
2386       parser.parseGreater() || parser.parseComma() ||
2387       parser.parseKeyword(&cacheType) ||
2388       parser.parseOptionalAttrDict(result.attributes) ||
2389       parser.parseColonType(type) ||
2390       parser.resolveOperand(memrefInfo, type, result.operands) ||
2391       parser.resolveOperands(mapOperands, indexTy, result.operands))
2392     return failure();
2393 
2394   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
2395     return parser.emitError(parser.getNameLoc(),
2396                             "rw specifier has to be 'read' or 'write'");
2397   result.addAttribute(
2398       AffinePrefetchOp::getIsWriteAttrName(),
2399       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
2400 
2401   if (!cacheType.equals("data") && !cacheType.equals("instr"))
2402     return parser.emitError(parser.getNameLoc(),
2403                             "cache type has to be 'data' or 'instr'");
2404 
2405   result.addAttribute(
2406       AffinePrefetchOp::getIsDataCacheAttrName(),
2407       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
2408 
2409   return success();
2410 }
2411 
2412 static void print(OpAsmPrinter &p, AffinePrefetchOp op) {
2413   p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
2414   AffineMapAttr mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2415   if (mapAttr) {
2416     SmallVector<Value, 2> operands(op.getMapOperands());
2417     p.printAffineMapOfSSAIds(mapAttr, operands);
2418   }
2419   p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", "
2420     << "locality<" << op.localityHint() << ">, "
2421     << (op.isDataCache() ? "data" : "instr");
2422   p.printOptionalAttrDict(
2423       op.getAttrs(),
2424       /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(),
2425                        op.getIsDataCacheAttrName(), op.getIsWriteAttrName()});
2426   p << " : " << op.getMemRefType();
2427 }
2428 
2429 static LogicalResult verify(AffinePrefetchOp op) {
2430   auto mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2431   if (mapAttr) {
2432     AffineMap map = mapAttr.getValue();
2433     if (map.getNumResults() != op.getMemRefType().getRank())
2434       return op.emitOpError("affine.prefetch affine map num results must equal"
2435                             " memref rank");
2436     if (map.getNumInputs() + 1 != op.getNumOperands())
2437       return op.emitOpError("too few operands");
2438   } else {
2439     if (op.getNumOperands() != 1)
2440       return op.emitOpError("too few operands");
2441   }
2442 
2443   Region *scope = getAffineScope(op);
2444   for (auto idx : op.getMapOperands()) {
2445     if (!isValidAffineIndexOperand(idx, scope))
2446       return op.emitOpError("index must be a dimension or symbol identifier");
2447   }
2448   return success();
2449 }
2450 
2451 void AffinePrefetchOp::getCanonicalizationPatterns(
2452     OwningRewritePatternList &results, MLIRContext *context) {
2453   // prefetch(memrefcast) -> prefetch
2454   results.insert<SimplifyAffineOp<AffinePrefetchOp>>(context);
2455 }
2456 
2457 LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
2458                                      SmallVectorImpl<OpFoldResult> &results) {
2459   /// prefetch(memrefcast) -> prefetch
2460   return foldMemRefCast(*this);
2461 }
2462 
2463 //===----------------------------------------------------------------------===//
2464 // AffineParallelOp
2465 //===----------------------------------------------------------------------===//
2466 
2467 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2468                              TypeRange resultTypes,
2469                              ArrayRef<AtomicRMWKind> reductions,
2470                              ArrayRef<int64_t> ranges) {
2471   SmallVector<AffineExpr, 8> lbExprs(ranges.size(),
2472                                      builder.getAffineConstantExpr(0));
2473   auto lbMap = AffineMap::get(0, 0, lbExprs, builder.getContext());
2474   SmallVector<AffineExpr, 8> ubExprs;
2475   for (int64_t range : ranges)
2476     ubExprs.push_back(builder.getAffineConstantExpr(range));
2477   auto ubMap = AffineMap::get(0, 0, ubExprs, builder.getContext());
2478   build(builder, result, resultTypes, reductions, lbMap, /*lbArgs=*/{}, ubMap,
2479         /*ubArgs=*/{});
2480 }
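// Sketch of the IR this range-based builder produces (illustrative, assuming
// no result types or reductions and ranges {10, 20}); the step clause is
// elided because all steps default to 1:
//
//   affine.parallel (%i, %j) = (0, 0) to (10, 20) {
//     ...
//   }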
2481 
2482 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2483                              TypeRange resultTypes,
2484                              ArrayRef<AtomicRMWKind> reductions,
2485                              AffineMap lbMap, ValueRange lbArgs,
2486                              AffineMap ubMap, ValueRange ubArgs) {
2487   auto numDims = lbMap.getNumResults();
2488   // Verify that the dimensionality of both maps is the same.
2489   assert(numDims == ubMap.getNumResults() &&
2490          "num dims and num results mismatch");
2491   // Make default step sizes of 1.
2492   SmallVector<int64_t, 8> steps(numDims, 1);
2493   build(builder, result, resultTypes, reductions, lbMap, lbArgs, ubMap, ubArgs,
2494         steps);
2495 }
2496 
2497 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2498                              TypeRange resultTypes,
2499                              ArrayRef<AtomicRMWKind> reductions,
2500                              AffineMap lbMap, ValueRange lbArgs,
2501                              AffineMap ubMap, ValueRange ubArgs,
2502                              ArrayRef<int64_t> steps) {
2503   auto numDims = lbMap.getNumResults();
2504   // Verify that the dimensionality of the maps matches the number of steps.
2505   assert(numDims == ubMap.getNumResults() &&
2506          "num dims and num results mismatch");
2507   assert(numDims == steps.size() && "num dims and num steps mismatch");
2508 
2509   result.addTypes(resultTypes);
2510   // Convert the reductions to integer attributes.
2511   SmallVector<Attribute, 4> reductionAttrs;
2512   for (AtomicRMWKind reduction : reductions)
2513     reductionAttrs.push_back(
2514         builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
2515   result.addAttribute(getReductionsAttrName(),
2516                       builder.getArrayAttr(reductionAttrs));
2517   result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap));
2518   result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap));
2519   result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps));
2520   result.addOperands(lbArgs);
2521   result.addOperands(ubArgs);
2522   // Create a region and a block for the body.
2523   auto *bodyRegion = result.addRegion();
2524   auto *body = new Block();
2525   // Add all the block arguments.
2526   for (unsigned i = 0; i < numDims; ++i)
2527     body->addArgument(IndexType::get(builder.getContext()));
2528   bodyRegion->push_back(body);
2529   if (resultTypes.empty())
2530     ensureTerminator(*bodyRegion, builder, result.location);
2531 }
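// Minimal C++ usage sketch (illustrative; `b`, `loc`, and the chosen ranges
// are assumptions, not part of the original source): creating a 2-d parallel
// band through the range-based convenience overload above.
//
//   b.create<AffineParallelOp>(loc, /*resultTypes=*/TypeRange(),
//                              /*reductions=*/ArrayRef<AtomicRMWKind>(),
//                              /*ranges=*/ArrayRef<int64_t>({128, 128}));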
2532 
2533 Region &AffineParallelOp::getLoopBody() { return region(); }
2534 
2535 bool AffineParallelOp::isDefinedOutsideOfLoop(Value value) {
2536   return !region().isAncestor(value.getParentRegion());
2537 }
2538 
2539 LogicalResult AffineParallelOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
2540   for (Operation *op : ops)
2541     op->moveBefore(*this);
2542   return success();
2543 }
2544 
2545 unsigned AffineParallelOp::getNumDims() { return steps().size(); }
2546 
2547 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
2548   return getOperands().take_front(lowerBoundsMap().getNumInputs());
2549 }
2550 
2551 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
2552   return getOperands().drop_front(lowerBoundsMap().getNumInputs());
2553 }
2554 
2555 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
2556   return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands());
2557 }
2558 
2559 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
2560   return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands());
2561 }
2562 
2563 AffineValueMap AffineParallelOp::getRangesValueMap() {
2564   AffineValueMap out;
2565   AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
2566                              &out);
2567   return out;
2568 }
2569 
2570 Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
2571   // Try to convert all the ranges to constant expressions.
2572   SmallVector<int64_t, 8> out;
2573   AffineValueMap rangesValueMap = getRangesValueMap();
2574   out.reserve(rangesValueMap.getNumResults());
2575   for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
2576     auto expr = rangesValueMap.getResult(i);
2577     auto cst = expr.dyn_cast<AffineConstantExpr>();
2578     if (!cst)
2579       return llvm::None;
2580     out.push_back(cst.getValue());
2581   }
2582   return out;
2583 }
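// For example (illustrative): with bounds (0, 0) to (10, %N) the second range
// remains symbolic, so this returns llvm::None; fully constant bounds such as
// (0, 0) to (10, 20) produce {10, 20}.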
2584 
2585 Block *AffineParallelOp::getBody() { return &region().front(); }
2586 
2587 OpBuilder AffineParallelOp::getBodyBuilder() {
2588   return OpBuilder(getBody(), std::prev(getBody()->end()));
2589 }
2590 
2591 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
2592   assert(lbOperands.size() == map.getNumInputs() &&
2593          "operands to map must match number of inputs");
2594   assert(map.getNumResults() >= 1 && "bounds map has at least one result");
2595 
2596   auto ubOperands = getUpperBoundsOperands();
2597 
2598   SmallVector<Value, 4> newOperands(lbOperands);
2599   newOperands.append(ubOperands.begin(), ubOperands.end());
2600   (*this)->setOperands(newOperands);
2601 
2602   lowerBoundsMapAttr(AffineMapAttr::get(map));
2603 }
2604 
2605 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
2606   assert(ubOperands.size() == map.getNumInputs() &&
2607          "operands to map must match number of inputs");
2608   assert(map.getNumResults() >= 1 && "bounds map has at least one result");
2609 
2610   SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
2611   newOperands.append(ubOperands.begin(), ubOperands.end());
2612   (*this)->setOperands(newOperands);
2613 
2614   upperBoundsMapAttr(AffineMapAttr::get(map));
2615 }
2616 
2617 void AffineParallelOp::setLowerBoundsMap(AffineMap map) {
2618   AffineMap lbMap = lowerBoundsMap();
2619   assert(lbMap.getNumDims() == map.getNumDims() &&
2620          lbMap.getNumSymbols() == map.getNumSymbols());
2621   (void)lbMap;
2622   lowerBoundsMapAttr(AffineMapAttr::get(map));
2623 }
2624 
2625 void AffineParallelOp::setUpperBoundsMap(AffineMap map) {
2626   AffineMap ubMap = upperBoundsMap();
2627   assert(ubMap.getNumDims() == map.getNumDims() &&
2628          ubMap.getNumSymbols() == map.getNumSymbols());
2629   (void)ubMap;
2630   upperBoundsMapAttr(AffineMapAttr::get(map));
2631 }
2632 
2633 SmallVector<int64_t, 8> AffineParallelOp::getSteps() {
2634   SmallVector<int64_t, 8> result;
2635   for (Attribute attr : steps()) {
2636     result.push_back(attr.cast<IntegerAttr>().getInt());
2637   }
2638   return result;
2639 }
2640 
2641 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
2642   stepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
2643 }
2644 
2645 static LogicalResult verify(AffineParallelOp op) {
2646   auto numDims = op.getNumDims();
2647   if (op.lowerBoundsMap().getNumResults() != numDims ||
2648       op.upperBoundsMap().getNumResults() != numDims ||
2649       op.steps().size() != numDims ||
2650       op.getBody()->getNumArguments() != numDims)
2651     return op.emitOpError("region argument count and num results of upper "
2652                           "bounds, lower bounds, and steps must all match");
2653 
2654   if (op.reductions().size() != op.getNumResults())
2655     return op.emitOpError("a reduction must be specified for each output");
2656 
2657   // Verify that the reduction ops are all valid.
2658   for (Attribute attr : op.reductions()) {
2659     auto intAttr = attr.dyn_cast<IntegerAttr>();
2660     if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt()))
2661       return op.emitOpError("invalid reduction attribute");
2662   }
2663 
2664   // Verify that the bound operands are valid dimension/symbols.
2665   /// Lower bounds.
2666   if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(),
2667                                            op.lowerBoundsMap().getNumDims())))
2668     return failure();
2669   /// Upper bounds.
2670   if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(),
2671                                            op.upperBoundsMap().getNumDims())))
2672     return failure();
2673   return success();
2674 }
2675 
2676 LogicalResult AffineValueMap::canonicalize() {
2677   SmallVector<Value, 4> newOperands{operands};
2678   auto newMap = getAffineMap();
2679   composeAffineMapAndOperands(&newMap, &newOperands);
2680   if (newMap == getAffineMap() && newOperands == operands)
2681     return failure();
2682   reset(newMap, newOperands);
2683   return success();
2684 }
2685 
2686 /// Canonicalize the bounds of the given loop.
2687 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
2688   AffineValueMap lb = op.getLowerBoundsValueMap();
2689   bool lbCanonicalized = succeeded(lb.canonicalize());
2690 
2691   AffineValueMap ub = op.getUpperBoundsValueMap();
2692   bool ubCanonicalized = succeeded(ub.canonicalize());
2693 
2694   // Any canonicalization change always leads to updated map(s).
2695   if (!lbCanonicalized && !ubCanonicalized)
2696     return failure();
2697 
2698   if (lbCanonicalized)
2699     op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
2700   if (ubCanonicalized)
2701     op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
2702 
2703   return success();
2704 }
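// Sketch of the effect on hypothetical IR: if an upper bound operand is the
// result of %0 = affine.apply affine_map<(d0) -> (d0 + 1)>(%N), composing it
// into the bound map rewrites
//
//   affine.parallel (%i) = (0) to (%0)
// into
//   affine.parallel (%i) = (0) to (%N + 1)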
2705 
2706 LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands,
2707                                      SmallVectorImpl<OpFoldResult> &results) {
2708   return canonicalizeLoopBounds(*this);
2709 }
2710 
2711 static void print(OpAsmPrinter &p, AffineParallelOp op) {
2712   p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (";
2713   p.printAffineMapOfSSAIds(op.lowerBoundsMapAttr(),
2714                            op.getLowerBoundsOperands());
2715   p << ") to (";
2716   p.printAffineMapOfSSAIds(op.upperBoundsMapAttr(),
2717                            op.getUpperBoundsOperands());
2718   p << ')';
2719   SmallVector<int64_t, 8> steps = op.getSteps();
2720   bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
2721   if (!elideSteps) {
2722     p << " step (";
2723     llvm::interleaveComma(steps, p);
2724     p << ')';
2725   }
2726   if (op.getNumResults()) {
2727     p << " reduce (";
2728     llvm::interleaveComma(op.reductions(), p, [&](auto &attr) {
2729       AtomicRMWKind sym =
2730           *symbolizeAtomicRMWKind(attr.template cast<IntegerAttr>().getInt());
2731       p << "\"" << stringifyAtomicRMWKind(sym) << "\"";
2732     });
2733     p << ") -> (" << op.getResultTypes() << ")";
2734   }
2735 
2736   p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
2737                 /*printBlockTerminators=*/op.getNumResults());
2738   p.printOptionalAttrDict(
2739       op.getAttrs(),
2740       /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(),
2741                        AffineParallelOp::getLowerBoundsMapAttrName(),
2742                        AffineParallelOp::getUpperBoundsMapAttrName(),
2743                        AffineParallelOp::getStepsAttrName()});
2744 }
2745 
2746 //
2747 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` `(` map-of-ssa-ids `)`
2748 //               `to` `(` map-of-ssa-ids `)` steps? region attr-dict?
2749 // steps     ::= `step` `(` integer-literals `)`
2750 //
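// For example (illustrative only; %buf is an assumed memref):
//
//   %sum = affine.parallel (%i) = (0) to (100) step (2) reduce ("addf")
//       -> (f32) {
//     %v = affine.load %buf[%i] : memref<100xf32>
//     affine.yield %v : f32
//   }
//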
2751 static ParseResult parseAffineParallelOp(OpAsmParser &parser,
2752                                          OperationState &result) {
2753   auto &builder = parser.getBuilder();
2754   auto indexType = builder.getIndexType();
2755   AffineMapAttr lowerBoundsAttr, upperBoundsAttr;
2756   SmallVector<OpAsmParser::OperandType, 4> ivs;
2757   SmallVector<OpAsmParser::OperandType, 4> lowerBoundsMapOperands;
2758   SmallVector<OpAsmParser::OperandType, 4> upperBoundsMapOperands;
2759   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
2760                                      OpAsmParser::Delimiter::Paren) ||
2761       parser.parseEqual() ||
2762       parser.parseAffineMapOfSSAIds(
2763           lowerBoundsMapOperands, lowerBoundsAttr,
2764           AffineParallelOp::getLowerBoundsMapAttrName(), result.attributes,
2765           OpAsmParser::Delimiter::Paren) ||
2766       parser.resolveOperands(lowerBoundsMapOperands, indexType,
2767                              result.operands) ||
2768       parser.parseKeyword("to") ||
2769       parser.parseAffineMapOfSSAIds(
2770           upperBoundsMapOperands, upperBoundsAttr,
2771           AffineParallelOp::getUpperBoundsMapAttrName(), result.attributes,
2772           OpAsmParser::Delimiter::Paren) ||
2773       parser.resolveOperands(upperBoundsMapOperands, indexType,
2774                              result.operands))
2775     return failure();
2776 
2777   AffineMapAttr stepsMapAttr;
2778   NamedAttrList stepsAttrs;
2779   SmallVector<OpAsmParser::OperandType, 4> stepsMapOperands;
2780   if (failed(parser.parseOptionalKeyword("step"))) {
2781     SmallVector<int64_t, 4> steps(ivs.size(), 1);
2782     result.addAttribute(AffineParallelOp::getStepsAttrName(),
2783                         builder.getI64ArrayAttr(steps));
2784   } else {
2785     if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
2786                                       AffineParallelOp::getStepsAttrName(),
2787                                       stepsAttrs,
2788                                       OpAsmParser::Delimiter::Paren))
2789       return failure();
2790 
2791     // Convert steps from an AffineMap into an I64ArrayAttr.
2792     SmallVector<int64_t, 4> steps;
2793     auto stepsMap = stepsMapAttr.getValue();
2794     for (const auto &result : stepsMap.getResults()) {
2795       auto constExpr = result.dyn_cast<AffineConstantExpr>();
2796       if (!constExpr)
2797         return parser.emitError(parser.getNameLoc(),
2798                                 "steps must be constant integers");
2799       steps.push_back(constExpr.getValue());
2800     }
2801     result.addAttribute(AffineParallelOp::getStepsAttrName(),
2802                         builder.getI64ArrayAttr(steps));
2803   }
2804 
2805   // Parse the optional clause of the form: `reduce ("addf", "maxf")`, where
2806   // each quoted string is a member of the enum AtomicRMWKind.
2807   SmallVector<Attribute, 4> reductions;
2808   if (succeeded(parser.parseOptionalKeyword("reduce"))) {
2809     if (parser.parseLParen())
2810       return failure();
2811     do {
2812       // Parse a single quoted string via the attribute parsing, and then
2813       // verify it is a member of the enum and convert to its integer
2814       // representation.
2815       StringAttr attrVal;
2816       NamedAttrList attrStorage;
2817       auto loc = parser.getCurrentLocation();
2818       if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
2819                                 attrStorage))
2820         return failure();
2821       llvm::Optional<AtomicRMWKind> reduction =
2822           symbolizeAtomicRMWKind(attrVal.getValue());
2823       if (!reduction)
2824         return parser.emitError(loc, "invalid reduction value: ") << attrVal;
2825       reductions.push_back(builder.getI64IntegerAttr(
2826           static_cast<int64_t>(reduction.getValue())));
2827       // While we keep getting commas, keep parsing.
2828     } while (succeeded(parser.parseOptionalComma()));
2829     if (parser.parseRParen())
2830       return failure();
2831   }
2832   result.addAttribute(AffineParallelOp::getReductionsAttrName(),
2833                       builder.getArrayAttr(reductions));
2834 
2835   // Parse return types of reductions (if any)
2836   if (parser.parseOptionalArrowTypeList(result.types))
2837     return failure();
2838 
2839   // Now parse the body.
2840   Region *body = result.addRegion();
2841   SmallVector<Type, 4> types(ivs.size(), indexType);
2842   if (parser.parseRegion(*body, ivs, types) ||
2843       parser.parseOptionalAttrDict(result.attributes))
2844     return failure();
2845 
2846   // Add a terminator if none was parsed.
2847   AffineParallelOp::ensureTerminator(*body, builder, result.location);
2848   return success();
2849 }
2850 
2851 //===----------------------------------------------------------------------===//
2852 // AffineYieldOp
2853 //===----------------------------------------------------------------------===//
2854 
2855 static LogicalResult verify(AffineYieldOp op) {
2856   auto *parentOp = op->getParentOp();
2857   auto results = parentOp->getResults();
2858   auto operands = op.getOperands();
2859 
2860   if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
2861     return op.emitOpError() << "only terminates affine.if/for/parallel regions";
2862   if (parentOp->getNumResults() != op.getNumOperands())
2863     return op.emitOpError() << "parent of yield must have same number of "
2864                                "results as the yield operands";
2865   for (auto it : llvm::zip(results, operands)) {
2866     if (std::get<0>(it).getType() != std::get<1>(it).getType())
2867       return op.emitOpError()
2868              << "types mismatch between yield op and its parent";
2869   }
2870 
2871   return success();
2872 }
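// For instance (illustrative; %init and %x are assumed values), an affine.for
// producing one f32 result must be terminated by a matching affine.yield:
//
//   %res = affine.for %i = 0 to 10 iter_args(%acc = %init) -> (f32) {
//     %next = addf %acc, %x : f32
//     affine.yield %next : f32
//   }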
2873 
2874 //===----------------------------------------------------------------------===//
2875 // AffineVectorLoadOp
2876 //===----------------------------------------------------------------------===//
2877 
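// Example use (illustrative; operands and types assumed):
//
//   %1 = affine.vector_load %0[%i0 + 3, %i1 + 7]
//       : memref<100x100xf32>, vector<8xf32>
//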
2878 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
2879                                VectorType resultType, AffineMap map,
2880                                ValueRange operands) {
2881   assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2882   result.addOperands(operands);
2883   if (map)
2884     result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2885   result.types.push_back(resultType);
2886 }
2887 
2888 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
2889                                VectorType resultType, Value memref,
2890                                AffineMap map, ValueRange mapOperands) {
2891   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2892   result.addOperands(memref);
2893   result.addOperands(mapOperands);
2894   result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2895   result.types.push_back(resultType);
2896 }
2897 
2898 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
2899                                VectorType resultType, Value memref,
2900                                ValueRange indices) {
2901   auto memrefType = memref.getType().cast<MemRefType>();
2902   int64_t rank = memrefType.getRank();
2903   // Create identity map for memrefs with at least one dimension or () -> ()
2904   // for zero-dimensional memrefs.
2905   auto map =
2906       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2907   build(builder, result, resultType, memref, map, indices);
2908 }
2909 
2910 static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser,
2911                                            OperationState &result) {
2912   auto &builder = parser.getBuilder();
2913   auto indexTy = builder.getIndexType();
2914 
2915   MemRefType memrefType;
2916   VectorType resultType;
2917   OpAsmParser::OperandType memrefInfo;
2918   AffineMapAttr mapAttr;
2919   SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2920   return failure(
2921       parser.parseOperand(memrefInfo) ||
2922       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2923                                     AffineVectorLoadOp::getMapAttrName(),
2924                                     result.attributes) ||
2925       parser.parseOptionalAttrDict(result.attributes) ||
2926       parser.parseColonType(memrefType) || parser.parseComma() ||
2927       parser.parseType(resultType) ||
2928       parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
2929       parser.resolveOperands(mapOperands, indexTy, result.operands) ||
2930       parser.addTypeToList(resultType, result.types));
2931 }
2932 
2933 static void print(OpAsmPrinter &p, AffineVectorLoadOp op) {
2934   p << "affine.vector_load " << op.getMemRef() << '[';
2935   if (AffineMapAttr mapAttr =
2936           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2937     p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2938   p << ']';
2939   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
2940   p << " : " << op.getMemRefType() << ", " << op.getType();
2941 }
2942 
2943 /// Verify common invariants of affine.vector_load and affine.vector_store.
2944 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
2945                                           VectorType vectorType) {
2946   // Check that memref and vector element types match.
2947   if (memrefType.getElementType() != vectorType.getElementType())
2948     return op->emitOpError(
2949         "requires memref and vector types of the same elemental type");
2950   return success();
2951 }
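// For example (illustrative), the following elemental type mismatch is
// rejected here:
//
//   affine.vector_load %m[%i] : memref<100xf32>, vector<8xi64>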
2952 
2953 static LogicalResult verify(AffineVectorLoadOp op) {
2954   MemRefType memrefType = op.getMemRefType();
2955   if (failed(verifyMemoryOpIndexing(
2956           op.getOperation(),
2957           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2958           op.getMapOperands(), memrefType,
2959           /*numIndexOperands=*/op.getNumOperands() - 1)))
2960     return failure();
2961 
2962   if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
2963                                   op.getVectorType())))
2964     return failure();
2965 
2966   return success();
2967 }
2968 
2969 //===----------------------------------------------------------------------===//
2970 // AffineVectorStoreOp
2971 //===----------------------------------------------------------------------===//
2972 
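// Example use (illustrative; operands and types assumed):
//
//   affine.vector_store %v, %0[%i0 + 3, %i1 + 7]
//       : memref<100x100xf32>, vector<8xf32>
//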
2973 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
2974                                 Value valueToStore, Value memref, AffineMap map,
2975                                 ValueRange mapOperands) {
2976   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2977   result.addOperands(valueToStore);
2978   result.addOperands(memref);
2979   result.addOperands(mapOperands);
2980   result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2981 }
2982 
2983 // Use identity map.
2984 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
2985                                 Value valueToStore, Value memref,
2986                                 ValueRange indices) {
2987   auto memrefType = memref.getType().cast<MemRefType>();
2988   int64_t rank = memrefType.getRank();
2989   // Create identity map for memrefs with at least one dimension or () -> ()
2990   // for zero-dimensional memrefs.
2991   auto map =
2992       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2993   build(builder, result, valueToStore, memref, map, indices);
2994 }
2995 
2996 static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser,
2997                                             OperationState &result) {
2998   auto indexTy = parser.getBuilder().getIndexType();
2999 
3000   MemRefType memrefType;
3001   VectorType resultType;
3002   OpAsmParser::OperandType storeValueInfo;
3003   OpAsmParser::OperandType memrefInfo;
3004   AffineMapAttr mapAttr;
3005   SmallVector<OpAsmParser::OperandType, 1> mapOperands;
3006   return failure(
3007       parser.parseOperand(storeValueInfo) || parser.parseComma() ||
3008       parser.parseOperand(memrefInfo) ||
3009       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3010                                     AffineVectorStoreOp::getMapAttrName(),
3011                                     result.attributes) ||
3012       parser.parseOptionalAttrDict(result.attributes) ||
3013       parser.parseColonType(memrefType) || parser.parseComma() ||
3014       parser.parseType(resultType) ||
3015       parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
3016       parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
3017       parser.resolveOperands(mapOperands, indexTy, result.operands));
3018 }
3019 
3020 static void print(OpAsmPrinter &p, AffineVectorStoreOp op) {
3021   p << "affine.vector_store " << op.getValueToStore();
3022   p << ", " << op.getMemRef() << '[';
3023   if (AffineMapAttr mapAttr =
3024           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
3025     p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
3026   p << ']';
3027   p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{op.getMapAttrName()});
3028   p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType();
3029 }
3030 
3031 static LogicalResult verify(AffineVectorStoreOp op) {
3032   MemRefType memrefType = op.getMemRefType();
3033   if (failed(verifyMemoryOpIndexing(
3034           op.getOperation(),
3035           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
3036           op.getMapOperands(), memrefType,
3037           /*numIndexOperands=*/op.getNumOperands() - 2)))
3038     return failure();
3039 
3040   if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
3041                                   op.getVectorType())))
3042     return failure();
3043 
3044   return success();
3045 }
3046 
3047 //===----------------------------------------------------------------------===//
3048 // TableGen'd op method definitions
3049 //===----------------------------------------------------------------------===//
3050 
3051 #define GET_OP_CLASSES
3052 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
3053