1 //===- AffineOps.cpp - MLIR Affine Operations -----------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/Affine/IR/AffineOps.h"
10 #include "mlir/Dialect/Affine/IR/AffineValueMap.h"
11 #include "mlir/Dialect/MemRef/IR/MemRef.h"
12 #include "mlir/Dialect/StandardOps/IR/Ops.h"
13 #include "mlir/IR/BlockAndValueMapping.h"
14 #include "mlir/IR/BuiltinOps.h"
15 #include "mlir/IR/IntegerSet.h"
16 #include "mlir/IR/Matchers.h"
17 #include "mlir/IR/OpImplementation.h"
18 #include "mlir/IR/PatternMatch.h"
19 #include "mlir/Transforms/InliningUtils.h"
20 #include "llvm/ADT/SetVector.h"
21 #include "llvm/ADT/SmallBitVector.h"
22 #include "llvm/ADT/TypeSwitch.h"
23 #include "llvm/Support/Debug.h"
24 
25 using namespace mlir;
26 
27 #define DEBUG_TYPE "affine-analysis"
28 
29 #include "mlir/Dialect/Affine/IR/AffineOpsDialect.cpp.inc"
30 
/// A utility function to check if a value is defined at the top level of
/// `region` or is an argument of `region`. A value of index type defined at
/// the top level of an `AffineScope` region is always a valid symbol for all
/// uses in that region.
static bool isTopLevelValue(Value value, Region *region) {
36   if (auto arg = value.dyn_cast<BlockArgument>())
37     return arg.getParentRegion() == region;
38   return value.getDefiningOp()->getParentRegion() == region;
39 }
40 
41 /// Checks if `value` known to be a legal affine dimension or symbol in `src`
42 /// region remains legal if the operation that uses it is inlined into `dest`
43 /// with the given value mapping. `legalityCheck` is either `isValidDim` or
44 /// `isValidSymbol`, depending on the value being required to remain a valid
45 /// dimension or symbol.
46 static bool
remainsLegalAfterInline(Value value, Region *src, Region *dest,
                        const BlockAndValueMapping &mapping,
                        function_ref<bool(Value, Region *)> legalityCheck) {
  // If the value is a valid dimension for any reason other than being
  // a top-level value, it will remain valid: constants get inlined
  // with the function, transitive affine applies also get inlined and
  // will be checked themselves, etc.
  if (!isTopLevelValue(value, src))
    return true;
56 
57   // If it's a top-level value because it's a block operand, i.e. a
58   // function argument, check whether the value replacing it after
59   // inlining is a valid dimension in the new region.
60   if (value.isa<BlockArgument>())
61     return legalityCheck(mapping.lookup(value), dest);
62 
63   // If it's a top-level value because it's defined in the region,
64   // it can only be inlined if the defining op is a constant or a
65   // `dim`, which can appear anywhere and be valid, since the defining
66   // op won't be top-level anymore after inlining.
67   Attribute operandCst;
68   return matchPattern(value.getDefiningOp(), m_Constant(&operandCst)) ||
69          value.getDefiningOp<memref::DimOp>() ||
70          value.getDefiningOp<tensor::DimOp>();
71 }
72 
73 /// Checks if all values known to be legal affine dimensions or symbols in `src`
74 /// remain so if their respective users are inlined into `dest`.
75 static bool
remainsLegalAfterInline(ValueRange values, Region *src, Region *dest,
77                         const BlockAndValueMapping &mapping,
78                         function_ref<bool(Value, Region *)> legalityCheck) {
79   return llvm::all_of(values, [&](Value v) {
80     return remainsLegalAfterInline(v, src, dest, mapping, legalityCheck);
81   });
82 }
83 
84 /// Checks if an affine read or write operation remains legal after inlining
85 /// from `src` to `dest`.
86 template <typename OpTy>
static bool remainsLegalAfterInline(OpTy op, Region *src, Region *dest,
88                                     const BlockAndValueMapping &mapping) {
89   static_assert(llvm::is_one_of<OpTy, AffineReadOpInterface,
90                                 AffineWriteOpInterface>::value,
91                 "only ops with affine read/write interface are supported");
92 
93   AffineMap map = op.getAffineMap();
94   ValueRange dimOperands = op.getMapOperands().take_front(map.getNumDims());
95   ValueRange symbolOperands =
96       op.getMapOperands().take_back(map.getNumSymbols());
97   if (!remainsLegalAfterInline(
98           dimOperands, src, dest, mapping,
99           static_cast<bool (*)(Value, Region *)>(isValidDim)))
100     return false;
101   if (!remainsLegalAfterInline(
102           symbolOperands, src, dest, mapping,
103           static_cast<bool (*)(Value, Region *)>(isValidSymbol)))
104     return false;
105   return true;
106 }
107 
108 /// Checks if an affine apply operation remains legal after inlining from `src`
109 /// to `dest`.
110 template <>
bool remainsLegalAfterInline(AffineApplyOp op, Region *src, Region *dest,
112                              const BlockAndValueMapping &mapping) {
113   // If it's a valid dimension, we need to check that it remains so.
114   if (isValidDim(op.getResult(), src))
115     return remainsLegalAfterInline(
116         op.getMapOperands(), src, dest, mapping,
117         static_cast<bool (*)(Value, Region *)>(isValidDim));
118 
119   // Otherwise it must be a valid symbol, check that it remains so.
120   return remainsLegalAfterInline(
121       op.getMapOperands(), src, dest, mapping,
122       static_cast<bool (*)(Value, Region *)>(isValidSymbol));
123 }
124 
125 //===----------------------------------------------------------------------===//
126 // AffineDialect Interfaces
127 //===----------------------------------------------------------------------===//
128 
129 namespace {
130 /// This class defines the interface for handling inlining with affine
131 /// operations.
132 struct AffineInlinerInterface : public DialectInlinerInterface {
133   using DialectInlinerInterface::DialectInlinerInterface;
134 
135   //===--------------------------------------------------------------------===//
136   // Analysis Hooks
137   //===--------------------------------------------------------------------===//
138 
139   /// Returns true if the given region 'src' can be inlined into the region
140   /// 'dest' that is attached to an operation registered to the current dialect.
141   /// 'wouldBeCloned' is set if the region is cloned into its new location
142   /// rather than moved, indicating there may be other users.
  bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
144                        BlockAndValueMapping &valueMapping) const final {
145     // We can inline into affine loops and conditionals if this doesn't break
146     // affine value categorization rules.
147     Operation *destOp = dest->getParentOp();
148     if (!isa<AffineParallelOp, AffineForOp, AffineIfOp>(destOp))
149       return false;
150 
151     // Multi-block regions cannot be inlined into affine constructs, all of
152     // which require single-block regions.
153     if (!llvm::hasSingleElement(*src))
154       return false;
155 
156     // Side-effecting operations that the affine dialect cannot understand
157     // should not be inlined.
158     Block &srcBlock = src->front();
159     for (Operation &op : srcBlock) {
      // Ops with no side effects are fine.
161       if (auto iface = dyn_cast<MemoryEffectOpInterface>(op)) {
162         if (iface.hasNoEffect())
163           continue;
164       }
165 
166       // Assuming the inlined region is valid, we only need to check if the
167       // inlining would change it.
168       bool remainsValid =
169           llvm::TypeSwitch<Operation *, bool>(&op)
170               .Case<AffineApplyOp, AffineReadOpInterface,
171                     AffineWriteOpInterface>([&](auto op) {
172                 return remainsLegalAfterInline(op, src, dest, valueMapping);
173               })
174               .Default([](Operation *) {
175                 // Conservatively disallow inlining ops we cannot reason about.
176                 return false;
177               });
178 
179       if (!remainsValid)
180         return false;
181     }
182 
183     return true;
184   }
185 
186   /// Returns true if the given operation 'op', that is registered to this
187   /// dialect, can be inlined into the given region, false otherwise.
  bool isLegalToInline(Operation *op, Region *region, bool wouldBeCloned,
189                        BlockAndValueMapping &valueMapping) const final {
190     // Always allow inlining affine operations into a region that is marked as
191     // affine scope, or into affine loops and conditionals. There are some edge
192     // cases when inlining *into* affine structures, but that is handled in the
193     // other 'isLegalToInline' hook above.
194     Operation *parentOp = region->getParentOp();
195     return parentOp->hasTrait<OpTrait::AffineScope>() ||
196            isa<AffineForOp, AffineParallelOp, AffineIfOp>(parentOp);
197   }
198 
199   /// Affine regions should be analyzed recursively.
  bool shouldAnalyzeRecursively(Operation *op) const final { return true; }
201 };
202 } // end anonymous namespace
203 
204 //===----------------------------------------------------------------------===//
205 // AffineDialect
206 //===----------------------------------------------------------------------===//
207 
void AffineDialect::initialize() {
209   addOperations<AffineDmaStartOp, AffineDmaWaitOp,
210 #define GET_OP_LIST
211 #include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
212                 >();
213   addInterfaces<AffineInlinerInterface>();
214 }
215 
216 /// Materialize a single constant operation from a given attribute value with
217 /// the desired resultant type.
Operation *AffineDialect::materializeConstant(OpBuilder &builder,
219                                               Attribute value, Type type,
220                                               Location loc) {
221   return builder.create<ConstantOp>(loc, type, value);
222 }
223 
224 /// A utility function to check if a value is defined at the top level of an
225 /// op with trait `AffineScope`. If the value is defined in an unlinked region,
226 /// conservatively assume it is not top-level. A value of index type defined at
227 /// the top level is always a valid symbol.
bool mlir::isTopLevelValue(Value value) {
229   if (auto arg = value.dyn_cast<BlockArgument>()) {
230     // The block owning the argument may be unlinked, e.g. when the surrounding
231     // region has not yet been attached to an Op, at which point the parent Op
232     // is null.
233     Operation *parentOp = arg.getOwner()->getParentOp();
234     return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
235   }
236   // The defining Op may live in an unlinked block so its parent Op may be null.
237   Operation *parentOp = value.getDefiningOp()->getParentOp();
238   return parentOp && parentOp->hasTrait<OpTrait::AffineScope>();
239 }
240 
241 /// Returns the closest region enclosing `op` that is held by an operation with
242 /// trait `AffineScope`; `nullptr` if there is no such region.
243 //  TODO: getAffineScope should be publicly exposed for affine passes/utilities.
static Region *getAffineScope(Operation *op) {
245   auto *curOp = op;
246   while (auto *parentOp = curOp->getParentOp()) {
247     if (parentOp->hasTrait<OpTrait::AffineScope>())
248       return curOp->getParentRegion();
249     curOp = parentOp;
250   }
251   return nullptr;
252 }
253 
// A Value can be used as a dimension id iff it meets one of the following
// conditions:
// *) It is valid as a symbol.
// *) It is an induction variable.
// *) It is the result of an affine apply operation with dimension id
//    arguments.
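//
// For example, given the illustrative IR below, the induction variable %i and
// the affine.apply result %a are valid dimension ids; %n, being a valid
// symbol, can also be used as a dimension id:
//   affine.for %i = 0 to %n {
//     %a = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
//   }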
bool mlir::isValidDim(Value value) {
260   // The value must be an index type.
261   if (!value.getType().isIndex())
262     return false;
263 
264   if (auto *defOp = value.getDefiningOp())
265     return isValidDim(value, getAffineScope(defOp));
266 
267   // This value has to be a block argument for an op that has the
268   // `AffineScope` trait or for an affine.for or affine.parallel.
269   auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
270   return parentOp && (parentOp->hasTrait<OpTrait::AffineScope>() ||
271                       isa<AffineForOp, AffineParallelOp>(parentOp));
272 }
273 
274 // Value can be used as a dimension id iff it meets one of the following
275 // conditions:
276 // *) It is valid as a symbol.
277 // *) It is an induction variable.
278 // *) It is the result of an affine apply operation with dimension id operands.
bool mlir::isValidDim(Value value, Region *region) {
280   // The value must be an index type.
281   if (!value.getType().isIndex())
282     return false;
283 
284   // All valid symbols are okay.
285   if (isValidSymbol(value, region))
286     return true;
287 
288   auto *op = value.getDefiningOp();
289   if (!op) {
290     // This value has to be a block argument for an affine.for or an
291     // affine.parallel.
292     auto *parentOp = value.cast<BlockArgument>().getOwner()->getParentOp();
293     return isa<AffineForOp, AffineParallelOp>(parentOp);
294   }
295 
296   // Affine apply operation is ok if all of its operands are ok.
297   if (auto applyOp = dyn_cast<AffineApplyOp>(op))
298     return applyOp.isValidDim(region);
299   // The dim op is okay if its operand memref/tensor is defined at the top
300   // level.
301   if (auto dimOp = dyn_cast<memref::DimOp>(op))
302     return isTopLevelValue(dimOp.source());
303   if (auto dimOp = dyn_cast<tensor::DimOp>(op))
304     return isTopLevelValue(dimOp.source());
305   return false;
306 }
307 
/// Returns true if the 'index' dimension of the `memref` defined by
/// `memrefDefOp` is a statically shaped one or defined using a valid symbol
/// for `region`.
template <typename AnyMemRefDefOp>
static bool isMemRefSizeValidSymbol(AnyMemRefDefOp memrefDefOp, unsigned index,
313                                     Region *region) {
314   auto memRefType = memrefDefOp.getType();
315   // Statically shaped.
316   if (!memRefType.isDynamicDim(index))
317     return true;
  // Get the position of the dimension among dynamic dimensions.
319   unsigned dynamicDimPos = memRefType.getDynamicDimIndex(index);
320   return isValidSymbol(*(memrefDefOp.getDynamicSizes().begin() + dynamicDimPos),
321                        region);
322 }
323 
324 /// Returns true if the result of the dim op is a valid symbol for `region`.
325 template <typename OpTy>
static bool isDimOpValidSymbol(OpTy dimOp, Region *region) {
327   // The dim op is okay if its source is defined at the top level.
328   if (isTopLevelValue(dimOp.source()))
329     return true;
330 
331   // Conservatively handle remaining BlockArguments as non-valid symbols.
332   // E.g. scf.for iterArgs.
333   if (dimOp.source().template isa<BlockArgument>())
334     return false;
335 
336   // The dim op is also okay if its operand memref is a view/subview whose
337   // corresponding size is a valid symbol.
338   Optional<int64_t> index = dimOp.getConstantIndex();
339   assert(index.hasValue() &&
340          "expect only `dim` operations with a constant index");
341   int64_t i = index.getValue();
342   return TypeSwitch<Operation *, bool>(dimOp.source().getDefiningOp())
343       .Case<memref::ViewOp, memref::SubViewOp, memref::AllocOp>(
344           [&](auto op) { return isMemRefSizeValidSymbol(op, i, region); })
345       .Default([](Operation *) { return false; });
346 }
347 
348 // A value can be used as a symbol (at all its use sites) iff it meets one of
349 // the following conditions:
350 // *) It is a constant.
351 // *) Its defining op or block arg appearance is immediately enclosed by an op
352 //    with `AffineScope` trait.
353 // *) It is the result of an affine.apply operation with symbol operands.
354 // *) It is a result of the dim op on a memref whose corresponding size is a
355 //    valid symbol.
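//
// For example, in the illustrative IR below, the function argument %n (a
// top-level value of a function, which has the `AffineScope` trait) and the
// constant %c42 are valid symbols, while the induction variable %i is not:
//   func @f(%A: memref<?xf32>, %n: index) {
//     %c42 = constant 42 : index
//     affine.for %i = 0 to %n { ... }
//   }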
bool mlir::isValidSymbol(Value value) {
357   // The value must be an index type.
358   if (!value.getType().isIndex())
359     return false;
360 
361   // Check that the value is a top level value.
362   if (isTopLevelValue(value))
363     return true;
364 
365   if (auto *defOp = value.getDefiningOp())
366     return isValidSymbol(value, getAffineScope(defOp));
367 
368   return false;
369 }
370 
371 /// A value can be used as a symbol for `region` iff it meets one of the
372 /// following conditions:
373 /// *) It is a constant.
374 /// *) It is the result of an affine apply operation with symbol arguments.
375 /// *) It is a result of the dim op on a memref whose corresponding size is
376 ///    a valid symbol.
377 /// *) It is defined at the top level of 'region' or is its argument.
378 /// *) It dominates `region`'s parent op.
379 /// If `region` is null, conservatively assume the symbol definition scope does
380 /// not exist and only accept the values that would be symbols regardless of
381 /// the surrounding region structure, i.e. the first three cases above.
bool mlir::isValidSymbol(Value value, Region *region) {
383   // The value must be an index type.
384   if (!value.getType().isIndex())
385     return false;
386 
387   // A top-level value is a valid symbol.
388   if (region && ::isTopLevelValue(value, region))
389     return true;
390 
391   auto *defOp = value.getDefiningOp();
392   if (!defOp) {
393     // A block argument that is not a top-level value is a valid symbol if it
394     // dominates region's parent op.
395     Operation *regionOp = region ? region->getParentOp() : nullptr;
396     if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
397       if (auto *parentOpRegion = region->getParentOp()->getParentRegion())
398         return isValidSymbol(value, parentOpRegion);
399     return false;
400   }
401 
402   // Constant operation is ok.
403   Attribute operandCst;
404   if (matchPattern(defOp, m_Constant(&operandCst)))
405     return true;
406 
407   // Affine apply operation is ok if all of its operands are ok.
408   if (auto applyOp = dyn_cast<AffineApplyOp>(defOp))
409     return applyOp.isValidSymbol(region);
410 
411   // Dim op results could be valid symbols at any level.
412   if (auto dimOp = dyn_cast<memref::DimOp>(defOp))
413     return isDimOpValidSymbol(dimOp, region);
414   if (auto dimOp = dyn_cast<tensor::DimOp>(defOp))
415     return isDimOpValidSymbol(dimOp, region);
416 
417   // Check for values dominating `region`'s parent op.
418   Operation *regionOp = region ? region->getParentOp() : nullptr;
419   if (regionOp && !regionOp->hasTrait<OpTrait::IsIsolatedFromAbove>())
420     if (auto *parentRegion = region->getParentOp()->getParentRegion())
421       return isValidSymbol(value, parentRegion);
422 
423   return false;
424 }
425 
426 // Returns true if 'value' is a valid index to an affine operation (e.g.
427 // affine.load, affine.store, affine.dma_start, affine.dma_wait) where
428 // `region` provides the polyhedral symbol scope. Returns false otherwise.
static bool isValidAffineIndexOperand(Value value, Region *region) {
430   return isValidDim(value, region) || isValidSymbol(value, region);
431 }
432 
433 /// Prints dimension and symbol list.
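/// For example, with two dimension operands and one symbol operand this prints
/// "(%d0, %d1)[%s0]" (SSA names are illustrative).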
static void printDimAndSymbolList(Operation::operand_iterator begin,
435                                   Operation::operand_iterator end,
436                                   unsigned numDims, OpAsmPrinter &printer) {
437   OperandRange operands(begin, end);
438   printer << '(' << operands.take_front(numDims) << ')';
439   if (operands.size() > numDims)
440     printer << '[' << operands.drop_front(numDims) << ']';
441 }
442 
443 /// Parses dimension and symbol list and returns true if parsing failed.
ParseResult mlir::parseDimAndSymbolList(OpAsmParser &parser,
445                                         SmallVectorImpl<Value> &operands,
446                                         unsigned &numDims) {
447   SmallVector<OpAsmParser::OperandType, 8> opInfos;
448   if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
449     return failure();
450   // Store number of dimensions for validation by caller.
451   numDims = opInfos.size();
452 
453   // Parse the optional symbol operands.
454   auto indexTy = parser.getBuilder().getIndexType();
455   return failure(parser.parseOperandList(
456                      opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
457                  parser.resolveOperands(opInfos, indexTy, operands));
458 }
459 
460 /// Utility function to verify that a set of operands are valid dimension and
461 /// symbol identifiers. The operands should be laid out such that the dimension
462 /// operands are before the symbol operands. This function returns failure if
463 /// there was an invalid operand. An operation is provided to emit any necessary
464 /// errors.
465 template <typename OpTy>
466 static LogicalResult
verifyDimAndSymbolIdentifiers(OpTy &op, Operation::operand_range operands,
468                               unsigned numDims) {
469   unsigned opIt = 0;
470   for (auto operand : operands) {
471     if (opIt++ < numDims) {
472       if (!isValidDim(operand, getAffineScope(op)))
473         return op.emitOpError("operand cannot be used as a dimension id");
474     } else if (!isValidSymbol(operand, getAffineScope(op))) {
475       return op.emitOpError("operand cannot be used as a symbol");
476     }
477   }
478   return success();
479 }
480 
481 //===----------------------------------------------------------------------===//
482 // AffineApplyOp
483 //===----------------------------------------------------------------------===//
484 
AffineValueMap AffineApplyOp::getAffineValueMap() {
486   return AffineValueMap(getAffineMap(), getOperands(), getResult());
487 }
488 
static ParseResult parseAffineApplyOp(OpAsmParser &parser,
490                                       OperationState &result) {
491   auto &builder = parser.getBuilder();
492   auto indexTy = builder.getIndexType();
493 
494   AffineMapAttr mapAttr;
495   unsigned numDims;
496   if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
497       parseDimAndSymbolList(parser, result.operands, numDims) ||
498       parser.parseOptionalAttrDict(result.attributes))
499     return failure();
500   auto map = mapAttr.getValue();
501 
502   if (map.getNumDims() != numDims ||
503       numDims + map.getNumSymbols() != result.operands.size()) {
504     return parser.emitError(parser.getNameLoc(),
505                             "dimension or symbol index mismatch");
506   }
507 
508   result.types.append(map.getNumResults(), indexTy);
509   return success();
510 }
511 
static void print(OpAsmPrinter &p, AffineApplyOp op) {
513   p << AffineApplyOp::getOperationName() << " " << op.mapAttr();
514   printDimAndSymbolList(op.operand_begin(), op.operand_end(),
515                         op.getAffineMap().getNumDims(), p);
516   p.printOptionalAttrDict(op->getAttrs(), /*elidedAttrs=*/{"map"});
517 }
518 
static LogicalResult verify(AffineApplyOp op) {
520   // Check input and output dimensions match.
521   auto map = op.map();
522 
523   // Verify that operand count matches affine map dimension and symbol count.
524   if (op.getNumOperands() != map.getNumDims() + map.getNumSymbols())
525     return op.emitOpError(
526         "operand count and affine map dimension and symbol count must match");
527 
528   // Verify that the map only produces one result.
529   if (map.getNumResults() != 1)
530     return op.emitOpError("mapping must produce one value");
531 
532   return success();
533 }
534 
535 // The result of the affine apply operation can be used as a dimension id if all
536 // its operands are valid dimension ids.
bool AffineApplyOp::isValidDim() {
538   return llvm::all_of(getOperands(),
539                       [](Value op) { return mlir::isValidDim(op); });
540 }
541 
542 // The result of the affine apply operation can be used as a dimension id if all
543 // its operands are valid dimension ids with the parent operation of `region`
544 // defining the polyhedral scope for symbols.
bool AffineApplyOp::isValidDim(Region *region) {
546   return llvm::all_of(getOperands(),
547                       [&](Value op) { return ::isValidDim(op, region); });
548 }
549 
550 // The result of the affine apply operation can be used as a symbol if all its
551 // operands are symbols.
bool AffineApplyOp::isValidSymbol() {
553   return llvm::all_of(getOperands(),
554                       [](Value op) { return mlir::isValidSymbol(op); });
555 }
556 
557 // The result of the affine apply operation can be used as a symbol in `region`
558 // if all its operands are symbols in `region`.
bool AffineApplyOp::isValidSymbol(Region *region) {
560   return llvm::all_of(getOperands(), [&](Value operand) {
561     return mlir::isValidSymbol(operand, region);
562   });
563 }
564 
OpFoldResult AffineApplyOp::fold(ArrayRef<Attribute> operands) {
566   auto map = getAffineMap();
567 
568   // Fold dims and symbols to existing values.
569   auto expr = map.getResult(0);
570   if (auto dim = expr.dyn_cast<AffineDimExpr>())
571     return getOperand(dim.getPosition());
572   if (auto sym = expr.dyn_cast<AffineSymbolExpr>())
573     return getOperand(map.getNumDims() + sym.getPosition());
574 
575   // Otherwise, default to folding the map.
576   SmallVector<Attribute, 1> result;
577   if (failed(map.constantFold(operands, result)))
578     return {};
579   return result[0];
580 }
581 
582 /// Replace all occurrences of AffineExpr at position `pos` in `map` by the
583 /// defining AffineApplyOp expression and operands.
584 /// When `dimOrSymbolPosition < dims.size()`, AffineDimExpr@[pos] is replaced.
585 /// When `dimOrSymbolPosition >= dims.size()`,
586 /// AffineSymbolExpr@[pos - dims.size()] is replaced.
/// Mutate `map`, `dims` and `syms` in place as follows:
///   1. `dims` and `syms` are only appended to.
///   2. `map` dim and symbols are gradually shifted to higher positions.
///   3. Old `dim` and `sym` entries are replaced by nullptr.
/// This avoids the need for any bookkeeping.
static LogicalResult replaceDimOrSym(AffineMap *map,
593                                      unsigned dimOrSymbolPosition,
594                                      SmallVectorImpl<Value> &dims,
595                                      SmallVectorImpl<Value> &syms) {
596   bool isDimReplacement = (dimOrSymbolPosition < dims.size());
597   unsigned pos = isDimReplacement ? dimOrSymbolPosition
598                                   : dimOrSymbolPosition - dims.size();
599   Value &v = isDimReplacement ? dims[pos] : syms[pos];
600   if (!v)
601     return failure();
602 
603   auto affineApply = v.getDefiningOp<AffineApplyOp>();
604   if (!affineApply)
605     return failure();
606 
  // At this point we will perform a replacement of `v`, so immediately set the
  // corresponding entry in `dims` or `syms` to nullptr.
609   v = nullptr;
610 
611   // Compute the map, dims and symbols coming from the AffineApplyOp.
612   AffineMap composeMap = affineApply.getAffineMap();
613   assert(composeMap.getNumResults() == 1 && "affine.apply with >1 results");
614   AffineExpr composeExpr =
615       composeMap.shiftDims(dims.size()).shiftSymbols(syms.size()).getResult(0);
616   ValueRange composeDims =
617       affineApply.getMapOperands().take_front(composeMap.getNumDims());
618   ValueRange composeSyms =
619       affineApply.getMapOperands().take_back(composeMap.getNumSymbols());
620 
621   // Perform the replacement and append the dims and symbols where relevant.
622   MLIRContext *ctx = map->getContext();
623   AffineExpr toReplace = isDimReplacement ? getAffineDimExpr(pos, ctx)
624                                           : getAffineSymbolExpr(pos, ctx);
625   *map = map->replace(toReplace, composeExpr, dims.size(), syms.size());
626   dims.append(composeDims.begin(), composeDims.end());
627   syms.append(composeSyms.begin(), composeSyms.end());
628 
629   return success();
630 }
631 
632 /// Iterate over `operands` and fold away all those produced by an AffineApplyOp
633 /// iteratively. Perform canonicalization of map and operands as well as
634 /// AffineMap simplification. `map` and `operands` are mutated in place.
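///
/// For example, given the map (d0, d1) -> (d0 + d1) with operands (%0, %j),
/// where %0 = affine.apply affine_map<(d0) -> (d0 floordiv 4)>(%i), composition
/// produces a map equivalent to (d0, d1) -> (d0 floordiv 4 + d1) applied
/// directly to %i and %j (illustrative example; dim numbering may differ).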
static void composeAffineMapAndOperands(AffineMap *map,
636                                         SmallVectorImpl<Value> *operands) {
637   if (map->getNumResults() == 0) {
638     canonicalizeMapAndOperands(map, operands);
639     *map = simplifyAffineMap(*map);
640     return;
641   }
642 
643   MLIRContext *ctx = map->getContext();
644   SmallVector<Value, 4> dims(operands->begin(),
645                              operands->begin() + map->getNumDims());
646   SmallVector<Value, 4> syms(operands->begin() + map->getNumDims(),
647                              operands->end());
648 
649   // Iterate over dims and symbols coming from AffineApplyOp and replace until
650   // exhaustion. This iteratively mutates `map`, `dims` and `syms`. Both `dims`
651   // and `syms` can only increase by construction.
  // The implementation uses a `while` loop to support the case of symbols
  // that may be constructed from dims; this may be overkill.
654   while (true) {
655     bool changed = false;
656     for (unsigned pos = 0; pos != dims.size() + syms.size(); ++pos)
657       if ((changed |= succeeded(replaceDimOrSym(map, pos, dims, syms))))
658         break;
659     if (!changed)
660       break;
661   }
662 
663   // Clear operands so we can fill them anew.
664   operands->clear();
665 
666   // At this point we may have introduced null operands, prune them out before
667   // canonicalizing map and operands.
668   unsigned nDims = 0, nSyms = 0;
669   SmallVector<AffineExpr, 4> dimReplacements, symReplacements;
670   dimReplacements.reserve(dims.size());
671   symReplacements.reserve(syms.size());
672   for (auto *container : {&dims, &syms}) {
673     bool isDim = (container == &dims);
674     auto &repls = isDim ? dimReplacements : symReplacements;
675     for (auto en : llvm::enumerate(*container)) {
676       Value v = en.value();
677       if (!v) {
        assert((isDim ? !map->isFunctionOfDim(en.index())
                      : !map->isFunctionOfSymbol(en.index())) &&
               "map is function of unexpected expr@pos");
681         repls.push_back(getAffineConstantExpr(0, ctx));
682         continue;
683       }
684       repls.push_back(isDim ? getAffineDimExpr(nDims++, ctx)
685                             : getAffineSymbolExpr(nSyms++, ctx));
686       operands->push_back(v);
687     }
688   }
689   *map = map->replaceDimsAndSymbols(dimReplacements, symReplacements, nDims,
690                                     nSyms);
691 
692   // Canonicalize and simplify before returning.
693   canonicalizeMapAndOperands(map, operands);
694   *map = simplifyAffineMap(*map);
695 }
696 
void mlir::fullyComposeAffineMapAndOperands(AffineMap *map,
698                                             SmallVectorImpl<Value> *operands) {
699   while (llvm::any_of(*operands, [](Value v) {
700     return isa_and_nonnull<AffineApplyOp>(v.getDefiningOp());
701   })) {
702     composeAffineMapAndOperands(map, operands);
703   }
704 }
705 
AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
707                                             AffineMap map,
708                                             ValueRange operands) {
709   AffineMap normalizedMap = map;
710   SmallVector<Value, 8> normalizedOperands(operands.begin(), operands.end());
711   composeAffineMapAndOperands(&normalizedMap, &normalizedOperands);
712   assert(normalizedMap);
713   return b.create<AffineApplyOp>(loc, normalizedMap, normalizedOperands);
714 }
715 
AffineApplyOp mlir::makeComposedAffineApply(OpBuilder &b, Location loc,
717                                             AffineExpr e, ValueRange values) {
718   return makeComposedAffineApply(
719       b, loc, AffineMap::inferFromExprList(ArrayRef<AffineExpr>{e}).front(),
720       values);
721 }
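
// A minimal usage sketch of the expression-based overload (the builder `b`,
// location `loc` and values `i`, `j` are illustrative):
//   AffineExpr d0, d1;
//   bindDims(b.getContext(), d0, d1);
//   Value v = makeComposedAffineApply(b, loc, d0 + d1 * 4, {i, j});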
722 
723 // A symbol may appear as a dim in affine.apply operations. This function
724 // canonicalizes dims that are valid symbols into actual symbols.
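//
// For example, a map (d0, d1)[s0] -> (d0 + d1 + s0) whose d1 operand is a
// valid symbol is rewritten to (d0)[s0, s1] -> (d0 + s1 + s0), with that
// operand moved to the new trailing symbol position s1 (illustrative example).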
725 template <class MapOrSet>
static void canonicalizePromotedSymbols(MapOrSet *mapOrSet,
727                                         SmallVectorImpl<Value> *operands) {
728   if (!mapOrSet || operands->empty())
729     return;
730 
731   assert(mapOrSet->getNumInputs() == operands->size() &&
732          "map/set inputs must match number of operands");
733 
734   auto *context = mapOrSet->getContext();
735   SmallVector<Value, 8> resultOperands;
736   resultOperands.reserve(operands->size());
737   SmallVector<Value, 8> remappedSymbols;
738   remappedSymbols.reserve(operands->size());
739   unsigned nextDim = 0;
740   unsigned nextSym = 0;
741   unsigned oldNumSyms = mapOrSet->getNumSymbols();
742   SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
743   for (unsigned i = 0, e = mapOrSet->getNumInputs(); i != e; ++i) {
744     if (i < mapOrSet->getNumDims()) {
745       if (isValidSymbol((*operands)[i])) {
746         // This is a valid symbol that appears as a dim, canonicalize it.
747         dimRemapping[i] = getAffineSymbolExpr(oldNumSyms + nextSym++, context);
748         remappedSymbols.push_back((*operands)[i]);
749       } else {
750         dimRemapping[i] = getAffineDimExpr(nextDim++, context);
751         resultOperands.push_back((*operands)[i]);
752       }
753     } else {
754       resultOperands.push_back((*operands)[i]);
755     }
756   }
757 
758   resultOperands.append(remappedSymbols.begin(), remappedSymbols.end());
759   *operands = resultOperands;
760   *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, {}, nextDim,
761                                               oldNumSyms + nextSym);
762 
763   assert(mapOrSet->getNumInputs() == operands->size() &&
764          "map/set inputs must match number of operands");
765 }
766 
767 // Works for either an affine map or an integer set.
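//
// For example, a map (d0, d1) -> (d0 + 1) canonicalizes to (d0) -> (d0 + 1),
// dropping the unused d1 operand; duplicate operands are merged, constant
// symbol operands are folded into the map, and dims that are valid symbols
// are first promoted to symbol positions (illustrative summary).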
768 template <class MapOrSet>
static void canonicalizeMapOrSetAndOperands(MapOrSet *mapOrSet,
770                                             SmallVectorImpl<Value> *operands) {
771   static_assert(llvm::is_one_of<MapOrSet, AffineMap, IntegerSet>::value,
772                 "Argument must be either of AffineMap or IntegerSet type");
773 
774   if (!mapOrSet || operands->empty())
775     return;
776 
777   assert(mapOrSet->getNumInputs() == operands->size() &&
778          "map/set inputs must match number of operands");
779 
780   canonicalizePromotedSymbols<MapOrSet>(mapOrSet, operands);
781 
782   // Check to see what dims are used.
783   llvm::SmallBitVector usedDims(mapOrSet->getNumDims());
784   llvm::SmallBitVector usedSyms(mapOrSet->getNumSymbols());
785   mapOrSet->walkExprs([&](AffineExpr expr) {
786     if (auto dimExpr = expr.dyn_cast<AffineDimExpr>())
787       usedDims[dimExpr.getPosition()] = true;
788     else if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>())
789       usedSyms[symExpr.getPosition()] = true;
790   });
791 
792   auto *context = mapOrSet->getContext();
793 
794   SmallVector<Value, 8> resultOperands;
795   resultOperands.reserve(operands->size());
796 
797   llvm::SmallDenseMap<Value, AffineExpr, 8> seenDims;
798   SmallVector<AffineExpr, 8> dimRemapping(mapOrSet->getNumDims());
799   unsigned nextDim = 0;
800   for (unsigned i = 0, e = mapOrSet->getNumDims(); i != e; ++i) {
801     if (usedDims[i]) {
802       // Remap dim positions for duplicate operands.
803       auto it = seenDims.find((*operands)[i]);
804       if (it == seenDims.end()) {
805         dimRemapping[i] = getAffineDimExpr(nextDim++, context);
806         resultOperands.push_back((*operands)[i]);
807         seenDims.insert(std::make_pair((*operands)[i], dimRemapping[i]));
808       } else {
809         dimRemapping[i] = it->second;
810       }
811     }
812   }
813   llvm::SmallDenseMap<Value, AffineExpr, 8> seenSymbols;
814   SmallVector<AffineExpr, 8> symRemapping(mapOrSet->getNumSymbols());
815   unsigned nextSym = 0;
816   for (unsigned i = 0, e = mapOrSet->getNumSymbols(); i != e; ++i) {
817     if (!usedSyms[i])
818       continue;
819     // Handle constant operands (only needed for symbolic operands since
820     // constant operands in dimensional positions would have already been
821     // promoted to symbolic positions above).
822     IntegerAttr operandCst;
823     if (matchPattern((*operands)[i + mapOrSet->getNumDims()],
824                      m_Constant(&operandCst))) {
825       symRemapping[i] =
826           getAffineConstantExpr(operandCst.getValue().getSExtValue(), context);
827       continue;
828     }
829     // Remap symbol positions for duplicate operands.
830     auto it = seenSymbols.find((*operands)[i + mapOrSet->getNumDims()]);
831     if (it == seenSymbols.end()) {
832       symRemapping[i] = getAffineSymbolExpr(nextSym++, context);
833       resultOperands.push_back((*operands)[i + mapOrSet->getNumDims()]);
834       seenSymbols.insert(std::make_pair((*operands)[i + mapOrSet->getNumDims()],
835                                         symRemapping[i]));
836     } else {
837       symRemapping[i] = it->second;
838     }
839   }
840   *mapOrSet = mapOrSet->replaceDimsAndSymbols(dimRemapping, symRemapping,
841                                               nextDim, nextSym);
842   *operands = resultOperands;
843 }
844 
void mlir::canonicalizeMapAndOperands(AffineMap *map,
846                                       SmallVectorImpl<Value> *operands) {
847   canonicalizeMapOrSetAndOperands<AffineMap>(map, operands);
848 }
849 
void mlir::canonicalizeSetAndOperands(IntegerSet *set,
851                                       SmallVectorImpl<Value> *operands) {
852   canonicalizeMapOrSetAndOperands<IntegerSet>(set, operands);
853 }
854 
855 namespace {
856 /// Simplify AffineApply, AffineLoad, and AffineStore operations by composing
857 /// maps that supply results into them.
858 ///
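/// For example (illustrative IR):
///   %0 = affine.apply affine_map<(d0) -> (d0 + 1)>(%i)
///   %1 = affine.load %A[%0] : memref<100xf32>
/// is rewritten so that the load uses the composed map directly:
///   %1 = affine.load %A[%i + 1] : memref<100xf32>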
859 template <typename AffineOpTy>
860 struct SimplifyAffineOp : public OpRewritePattern<AffineOpTy> {
861   using OpRewritePattern<AffineOpTy>::OpRewritePattern;
862 
863   /// Replace the affine op with another instance of it with the supplied
864   /// map and mapOperands.
865   void replaceAffineOp(PatternRewriter &rewriter, AffineOpTy affineOp,
866                        AffineMap map, ArrayRef<Value> mapOperands) const;
867 
  LogicalResult matchAndRewrite(AffineOpTy affineOp,
869                                 PatternRewriter &rewriter) const override {
870     static_assert(
871         llvm::is_one_of<AffineOpTy, AffineLoadOp, AffinePrefetchOp,
872                         AffineStoreOp, AffineApplyOp, AffineMinOp, AffineMaxOp,
873                         AffineVectorStoreOp, AffineVectorLoadOp>::value,
874         "affine load/store/vectorstore/vectorload/apply/prefetch/min/max op "
875         "expected");
876     auto map = affineOp.getAffineMap();
877     AffineMap oldMap = map;
878     auto oldOperands = affineOp.getMapOperands();
879     SmallVector<Value, 8> resultOperands(oldOperands);
880     composeAffineMapAndOperands(&map, &resultOperands);
881     canonicalizeMapAndOperands(&map, &resultOperands);
882     if (map == oldMap && std::equal(oldOperands.begin(), oldOperands.end(),
883                                     resultOperands.begin()))
884       return failure();
885 
886     replaceAffineOp(rewriter, affineOp, map, resultOperands);
887     return success();
888   }
889 };
890 
// Specialize the template to account for the different build signatures for
// affine load, prefetch, store, and vector load/store ops.
template <>
void SimplifyAffineOp<AffineLoadOp>::replaceAffineOp(
895     PatternRewriter &rewriter, AffineLoadOp load, AffineMap map,
896     ArrayRef<Value> mapOperands) const {
897   rewriter.replaceOpWithNewOp<AffineLoadOp>(load, load.getMemRef(), map,
898                                             mapOperands);
899 }
900 template <>
void SimplifyAffineOp<AffinePrefetchOp>::replaceAffineOp(
902     PatternRewriter &rewriter, AffinePrefetchOp prefetch, AffineMap map,
903     ArrayRef<Value> mapOperands) const {
904   rewriter.replaceOpWithNewOp<AffinePrefetchOp>(
905       prefetch, prefetch.memref(), map, mapOperands, prefetch.localityHint(),
906       prefetch.isWrite(), prefetch.isDataCache());
907 }
908 template <>
void SimplifyAffineOp<AffineStoreOp>::replaceAffineOp(
910     PatternRewriter &rewriter, AffineStoreOp store, AffineMap map,
911     ArrayRef<Value> mapOperands) const {
912   rewriter.replaceOpWithNewOp<AffineStoreOp>(
913       store, store.getValueToStore(), store.getMemRef(), map, mapOperands);
914 }
915 template <>
void SimplifyAffineOp<AffineVectorLoadOp>::replaceAffineOp(
917     PatternRewriter &rewriter, AffineVectorLoadOp vectorload, AffineMap map,
918     ArrayRef<Value> mapOperands) const {
919   rewriter.replaceOpWithNewOp<AffineVectorLoadOp>(
920       vectorload, vectorload.getVectorType(), vectorload.getMemRef(), map,
921       mapOperands);
922 }
923 template <>
void SimplifyAffineOp<AffineVectorStoreOp>::replaceAffineOp(
925     PatternRewriter &rewriter, AffineVectorStoreOp vectorstore, AffineMap map,
926     ArrayRef<Value> mapOperands) const {
927   rewriter.replaceOpWithNewOp<AffineVectorStoreOp>(
928       vectorstore, vectorstore.getValueToStore(), vectorstore.getMemRef(), map,
929       mapOperands);
930 }
931 
932 // Generic version for ops that don't have extra operands.
933 template <typename AffineOpTy>
void SimplifyAffineOp<AffineOpTy>::replaceAffineOp(
935     PatternRewriter &rewriter, AffineOpTy op, AffineMap map,
936     ArrayRef<Value> mapOperands) const {
937   rewriter.replaceOpWithNewOp<AffineOpTy>(op, map, mapOperands);
938 }
939 } // end anonymous namespace.
940 
void AffineApplyOp::getCanonicalizationPatterns(RewritePatternSet &results,
942                                                 MLIRContext *context) {
943   results.add<SimplifyAffineOp<AffineApplyOp>>(context);
944 }
945 
946 //===----------------------------------------------------------------------===//
947 // Common canonicalization pattern support logic
948 //===----------------------------------------------------------------------===//
949 
950 /// This is a common class used for patterns of the form
951 /// "someop(memrefcast) -> someop".  It folds the source of any memref.cast
952 /// into the root operation directly.
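///
/// For example (illustrative IR):
///   %0 = memref.cast %A : memref<100xf32> to memref<?xf32>
///   %1 = affine.load %0[%i] : memref<?xf32>
/// is folded to load directly from %A (the ranked source of the cast).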
static LogicalResult foldMemRefCast(Operation *op, Value ignore = nullptr) {
954   bool folded = false;
955   for (OpOperand &operand : op->getOpOperands()) {
956     auto cast = operand.get().getDefiningOp<memref::CastOp>();
957     if (cast && operand.get() != ignore &&
958         !cast.getOperand().getType().isa<UnrankedMemRefType>()) {
959       operand.set(cast.getOperand());
960       folded = true;
961     }
962   }
963   return success(folded);
964 }
965 
966 //===----------------------------------------------------------------------===//
967 // AffineDmaStartOp
968 //===----------------------------------------------------------------------===//
969 
970 // TODO: Check that map operands are loop IVs or symbols.
void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
972                              Value srcMemRef, AffineMap srcMap,
973                              ValueRange srcIndices, Value destMemRef,
974                              AffineMap dstMap, ValueRange destIndices,
975                              Value tagMemRef, AffineMap tagMap,
976                              ValueRange tagIndices, Value numElements,
977                              Value stride, Value elementsPerStride) {
978   result.addOperands(srcMemRef);
979   result.addAttribute(getSrcMapAttrName(), AffineMapAttr::get(srcMap));
980   result.addOperands(srcIndices);
981   result.addOperands(destMemRef);
982   result.addAttribute(getDstMapAttrName(), AffineMapAttr::get(dstMap));
983   result.addOperands(destIndices);
984   result.addOperands(tagMemRef);
985   result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
986   result.addOperands(tagIndices);
987   result.addOperands(numElements);
988   if (stride) {
989     result.addOperands({stride, elementsPerStride});
990   }
991 }
992 
void AffineDmaStartOp::print(OpAsmPrinter &p) {
994   p << "affine.dma_start " << getSrcMemRef() << '[';
995   p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
996   p << "], " << getDstMemRef() << '[';
997   p.printAffineMapOfSSAIds(getDstMapAttr(), getDstIndices());
998   p << "], " << getTagMemRef() << '[';
999   p.printAffineMapOfSSAIds(getTagMapAttr(), getTagIndices());
1000   p << "], " << getNumElements();
1001   if (isStrided()) {
1002     p << ", " << getStride();
1003     p << ", " << getNumElementsPerStride();
1004   }
1005   p << " : " << getSrcMemRefType() << ", " << getDstMemRefType() << ", "
1006     << getTagMemRefType();
1007 }
1008 
1009 // Parse AffineDmaStartOp.
1010 // Ex:
1011 //   affine.dma_start %src[%i, %j], %dst[%k, %l], %tag[%index], %size,
1012 //     %stride, %num_elt_per_stride
1013 //       : memref<3076 x f32, 0>, memref<1024 x f32, 2>, memref<1 x i32>
1014 //
ParseResult AffineDmaStartOp::parse(OpAsmParser &parser,
1016                                     OperationState &result) {
1017   OpAsmParser::OperandType srcMemRefInfo;
1018   AffineMapAttr srcMapAttr;
1019   SmallVector<OpAsmParser::OperandType, 4> srcMapOperands;
1020   OpAsmParser::OperandType dstMemRefInfo;
1021   AffineMapAttr dstMapAttr;
1022   SmallVector<OpAsmParser::OperandType, 4> dstMapOperands;
1023   OpAsmParser::OperandType tagMemRefInfo;
1024   AffineMapAttr tagMapAttr;
1025   SmallVector<OpAsmParser::OperandType, 4> tagMapOperands;
1026   OpAsmParser::OperandType numElementsInfo;
1027   SmallVector<OpAsmParser::OperandType, 2> strideInfo;
1028 
1029   SmallVector<Type, 3> types;
1030   auto indexType = parser.getBuilder().getIndexType();
1031 
  // Parse and resolve the following list of operands:
  // *) src memref followed by its affine map operands (in square brackets).
  // *) dst memref followed by its affine map operands (in square brackets).
  // *) tag memref followed by its affine map operands (in square brackets).
  // *) number of elements transferred by DMA operation.
1037   if (parser.parseOperand(srcMemRefInfo) ||
1038       parser.parseAffineMapOfSSAIds(srcMapOperands, srcMapAttr,
1039                                     getSrcMapAttrName(), result.attributes) ||
1040       parser.parseComma() || parser.parseOperand(dstMemRefInfo) ||
1041       parser.parseAffineMapOfSSAIds(dstMapOperands, dstMapAttr,
1042                                     getDstMapAttrName(), result.attributes) ||
1043       parser.parseComma() || parser.parseOperand(tagMemRefInfo) ||
1044       parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1045                                     getTagMapAttrName(), result.attributes) ||
1046       parser.parseComma() || parser.parseOperand(numElementsInfo))
1047     return failure();
1048 
1049   // Parse optional stride and elements per stride.
1050   if (parser.parseTrailingOperandList(strideInfo)) {
1051     return failure();
1052   }
1053   if (!strideInfo.empty() && strideInfo.size() != 2) {
1054     return parser.emitError(parser.getNameLoc(),
1055                             "expected two stride related operands");
1056   }
1057   bool isStrided = strideInfo.size() == 2;
1058 
1059   if (parser.parseColonTypeList(types))
1060     return failure();
1061 
1062   if (types.size() != 3)
1063     return parser.emitError(parser.getNameLoc(), "expected three types");
1064 
1065   if (parser.resolveOperand(srcMemRefInfo, types[0], result.operands) ||
1066       parser.resolveOperands(srcMapOperands, indexType, result.operands) ||
1067       parser.resolveOperand(dstMemRefInfo, types[1], result.operands) ||
1068       parser.resolveOperands(dstMapOperands, indexType, result.operands) ||
1069       parser.resolveOperand(tagMemRefInfo, types[2], result.operands) ||
1070       parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1071       parser.resolveOperand(numElementsInfo, indexType, result.operands))
1072     return failure();
1073 
1074   if (isStrided) {
1075     if (parser.resolveOperands(strideInfo, indexType, result.operands))
1076       return failure();
1077   }
1078 
1079   // Check that src/dst/tag operand counts match their map.numInputs.
1080   if (srcMapOperands.size() != srcMapAttr.getValue().getNumInputs() ||
1081       dstMapOperands.size() != dstMapAttr.getValue().getNumInputs() ||
1082       tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1083     return parser.emitError(parser.getNameLoc(),
1084                             "memref operand count not equal to map.numInputs");
1085   return success();
1086 }
1087 
LogicalResult AffineDmaStartOp::verify() {
1089   if (!getOperand(getSrcMemRefOperandIndex()).getType().isa<MemRefType>())
1090     return emitOpError("expected DMA source to be of memref type");
1091   if (!getOperand(getDstMemRefOperandIndex()).getType().isa<MemRefType>())
1092     return emitOpError("expected DMA destination to be of memref type");
1093   if (!getOperand(getTagMemRefOperandIndex()).getType().isa<MemRefType>())
1094     return emitOpError("expected DMA tag to be of memref type");
1095 
1096   unsigned numInputsAllMaps = getSrcMap().getNumInputs() +
1097                               getDstMap().getNumInputs() +
1098                               getTagMap().getNumInputs();
1099   if (getNumOperands() != numInputsAllMaps + 3 + 1 &&
1100       getNumOperands() != numInputsAllMaps + 3 + 1 + 2) {
1101     return emitOpError("incorrect number of operands");
1102   }
1103 
1104   Region *scope = getAffineScope(*this);
1105   for (auto idx : getSrcIndices()) {
1106     if (!idx.getType().isIndex())
1107       return emitOpError("src index to dma_start must have 'index' type");
1108     if (!isValidAffineIndexOperand(idx, scope))
1109       return emitOpError("src index must be a dimension or symbol identifier");
1110   }
1111   for (auto idx : getDstIndices()) {
1112     if (!idx.getType().isIndex())
1113       return emitOpError("dst index to dma_start must have 'index' type");
1114     if (!isValidAffineIndexOperand(idx, scope))
1115       return emitOpError("dst index must be a dimension or symbol identifier");
1116   }
1117   for (auto idx : getTagIndices()) {
1118     if (!idx.getType().isIndex())
1119       return emitOpError("tag index to dma_start must have 'index' type");
1120     if (!isValidAffineIndexOperand(idx, scope))
1121       return emitOpError("tag index must be a dimension or symbol identifier");
1122   }
1123   return success();
1124 }
1125 
LogicalResult AffineDmaStartOp::fold(ArrayRef<Attribute> cstOperands,
1127                                      SmallVectorImpl<OpFoldResult> &results) {
1128   /// dma_start(memrefcast) -> dma_start
1129   return foldMemRefCast(*this);
1130 }
1131 
1132 //===----------------------------------------------------------------------===//
1133 // AffineDmaWaitOp
1134 //===----------------------------------------------------------------------===//
1135 
1136 // TODO: Check that map operands are loop IVs or symbols.
void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
1138                             Value tagMemRef, AffineMap tagMap,
1139                             ValueRange tagIndices, Value numElements) {
1140   result.addOperands(tagMemRef);
1141   result.addAttribute(getTagMapAttrName(), AffineMapAttr::get(tagMap));
1142   result.addOperands(tagIndices);
1143   result.addOperands(numElements);
1144 }
1145 
void AffineDmaWaitOp::print(OpAsmPrinter &p) {
1147   p << "affine.dma_wait " << getTagMemRef() << '[';
1148   SmallVector<Value, 2> operands(getTagIndices());
1149   p.printAffineMapOfSSAIds(getTagMapAttr(), operands);
1150   p << "], ";
1151   p.printOperand(getNumElements());
1152   p << " : " << getTagMemRef().getType();
1153 }
1154 
1155 // Parse AffineDmaWaitOp.
1156 // Eg:
1157 //   affine.dma_wait %tag[%index], %num_elements
1158 //     : memref<1 x i32, (d0) -> (d0), 4>
1159 //
ParseResult AffineDmaWaitOp::parse(OpAsmParser &parser,
1161                                    OperationState &result) {
1162   OpAsmParser::OperandType tagMemRefInfo;
1163   AffineMapAttr tagMapAttr;
1164   SmallVector<OpAsmParser::OperandType, 2> tagMapOperands;
1165   Type type;
1166   auto indexType = parser.getBuilder().getIndexType();
1167   OpAsmParser::OperandType numElementsInfo;
1168 
1169   // Parse tag memref, its map operands, and dma size.
1170   if (parser.parseOperand(tagMemRefInfo) ||
1171       parser.parseAffineMapOfSSAIds(tagMapOperands, tagMapAttr,
1172                                     getTagMapAttrName(), result.attributes) ||
1173       parser.parseComma() || parser.parseOperand(numElementsInfo) ||
1174       parser.parseColonType(type) ||
1175       parser.resolveOperand(tagMemRefInfo, type, result.operands) ||
1176       parser.resolveOperands(tagMapOperands, indexType, result.operands) ||
1177       parser.resolveOperand(numElementsInfo, indexType, result.operands))
1178     return failure();
1179 
1180   if (!type.isa<MemRefType>())
1181     return parser.emitError(parser.getNameLoc(),
1182                             "expected tag to be of memref type");
1183 
1184   if (tagMapOperands.size() != tagMapAttr.getValue().getNumInputs())
1185     return parser.emitError(parser.getNameLoc(),
1186                             "tag memref operand count does not match map.numInputs");
1187   return success();
1188 }
1189 
1190 LogicalResult AffineDmaWaitOp::verify() {
1191   if (!getOperand(0).getType().isa<MemRefType>())
1192     return emitOpError("expected DMA tag to be of memref type");
1193   Region *scope = getAffineScope(*this);
1194   for (auto idx : getTagIndices()) {
1195     if (!idx.getType().isIndex())
1196       return emitOpError("index to dma_wait must have 'index' type");
1197     if (!isValidAffineIndexOperand(idx, scope))
1198       return emitOpError("index must be a dimension or symbol identifier");
1199   }
1200   return success();
1201 }
1202 
1203 LogicalResult AffineDmaWaitOp::fold(ArrayRef<Attribute> cstOperands,
1204                                     SmallVectorImpl<OpFoldResult> &results) {
1205   /// dma_wait(memrefcast) -> dma_wait
1206   return foldMemRefCast(*this);
1207 }
1208 
1209 //===----------------------------------------------------------------------===//
1210 // AffineForOp
1211 //===----------------------------------------------------------------------===//
1212 
1213 /// 'bodyBuilder' is used to build the body of affine.for. If iterArgs and
1214 /// bodyBuilder are empty/null, we include the default terminator op.
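/// For illustration only (values are illustrative, not tied to a particular
/// caller): a build with constant bound maps, step 2, and one iter_args value
/// produces IR of the form:
///   %r = affine.for %i = 0 to 10 step 2 iter_args(%acc = %init) -> (f32) {
///     ...
///     affine.yield %next : f32
///   }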
1215 void AffineForOp::build(OpBuilder &builder, OperationState &result,
1216                         ValueRange lbOperands, AffineMap lbMap,
1217                         ValueRange ubOperands, AffineMap ubMap, int64_t step,
1218                         ValueRange iterArgs, BodyBuilderFn bodyBuilder) {
1219   assert(((!lbMap && lbOperands.empty()) ||
1220           lbOperands.size() == lbMap.getNumInputs()) &&
1221          "lower bound operand count does not match the affine map");
1222   assert(((!ubMap && ubOperands.empty()) ||
1223           ubOperands.size() == ubMap.getNumInputs()) &&
1224          "upper bound operand count does not match the affine map");
1225   assert(step > 0 && "step has to be a positive integer constant");
1226 
1227   for (Value val : iterArgs)
1228     result.addTypes(val.getType());
1229 
1230   // Add an attribute for the step.
1231   result.addAttribute(getStepAttrName(),
1232                       builder.getIntegerAttr(builder.getIndexType(), step));
1233 
1234   // Add the lower bound.
1235   result.addAttribute(getLowerBoundAttrName(), AffineMapAttr::get(lbMap));
1236   result.addOperands(lbOperands);
1237 
1238   // Add the upper bound.
1239   result.addAttribute(getUpperBoundAttrName(), AffineMapAttr::get(ubMap));
1240   result.addOperands(ubOperands);
1241 
1242   result.addOperands(iterArgs);
1243   // Create a region and a block for the body.  The argument of the region is
1244   // the loop induction variable.
1245   Region *bodyRegion = result.addRegion();
1246   bodyRegion->push_back(new Block);
1247   Block &bodyBlock = bodyRegion->front();
1248   Value inductionVar = bodyBlock.addArgument(builder.getIndexType());
1249   for (Value val : iterArgs)
1250     bodyBlock.addArgument(val.getType());
1251 
1252   // Create the default terminator if the builder is not provided and if the
1253   // iteration arguments are not provided. Otherwise, leave this to the caller
1254   // because we don't know which values to return from the loop.
1255   if (iterArgs.empty() && !bodyBuilder) {
1256     ensureTerminator(*bodyRegion, builder, result.location);
1257   } else if (bodyBuilder) {
1258     OpBuilder::InsertionGuard guard(builder);
1259     builder.setInsertionPointToStart(&bodyBlock);
1260     bodyBuilder(builder, result.location, inductionVar,
1261                 bodyBlock.getArguments().drop_front());
1262   }
1263 }
1264 
1265 void AffineForOp::build(OpBuilder &builder, OperationState &result, int64_t lb,
1266                         int64_t ub, int64_t step, ValueRange iterArgs,
1267                         BodyBuilderFn bodyBuilder) {
1268   auto lbMap = AffineMap::getConstantMap(lb, builder.getContext());
1269   auto ubMap = AffineMap::getConstantMap(ub, builder.getContext());
1270   return build(builder, result, {}, lbMap, {}, ubMap, step, iterArgs,
1271                bodyBuilder);
1272 }
1273 
1274 static LogicalResult verify(AffineForOp op) {
1275   // Check that the body defines a single block argument for the induction
1276   // variable.
1277   auto *body = op.getBody();
1278   if (body->getNumArguments() == 0 || !body->getArgument(0).getType().isIndex())
1279     return op.emitOpError(
1280         "expected body to have a single index argument for the "
1281         "induction variable");
1282 
1283   // Verify that the bound operands are valid dimension/symbols.
1284   /// Lower bound.
1285   if (op.getLowerBoundMap().getNumInputs() > 0)
1286     if (failed(
1287             verifyDimAndSymbolIdentifiers(op, op.getLowerBoundOperands(),
1288                                           op.getLowerBoundMap().getNumDims())))
1289       return failure();
1290   /// Upper bound.
1291   if (op.getUpperBoundMap().getNumInputs() > 0)
1292     if (failed(
1293             verifyDimAndSymbolIdentifiers(op, op.getUpperBoundOperands(),
1294                                           op.getUpperBoundMap().getNumDims())))
1295       return failure();
1296 
1297   unsigned opNumResults = op.getNumResults();
1298   if (opNumResults == 0)
1299     return success();
1300 
1301   // If ForOp defines values, check that the number and types of the defined
1302   // values match ForOp initial iter operands and backedge basic block
1303   // arguments.
1304   if (op.getNumIterOperands() != opNumResults)
1305     return op.emitOpError(
1306         "mismatch between the number of loop-carried values and results");
1307   if (op.getNumRegionIterArgs() != opNumResults)
1308     return op.emitOpError(
1309         "mismatch between the number of basic block args and results");
1310 
1311   return success();
1312 }
1313 
1314 /// Parse a for operation loop bounds.
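/// The accepted forms are, e.g. (examples illustrative only): a bare constant
/// ("%i = 0 to 10"), a single SSA symbol ("%i = %lb to %ub"), or the full form
/// of an affine map applied to dim/symbol operands, which needs a 'max'/'min'
/// prefix when the map has multiple results
/// ("%i = max #lb(%d)[%s] to min #ub()[%n]").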
1315 static ParseResult parseBound(bool isLower, OperationState &result,
1316                               OpAsmParser &p) {
1317   // 'min' / 'max' prefixes are generally syntactic sugar, but are required if
1318   // the map has multiple results.
1319   bool failedToParsedMinMax =
1320       failed(p.parseOptionalKeyword(isLower ? "max" : "min"));
1321 
1322   auto &builder = p.getBuilder();
1323   auto boundAttrName = isLower ? AffineForOp::getLowerBoundAttrName()
1324                                : AffineForOp::getUpperBoundAttrName();
1325 
1326   // Parse ssa-id as identity map.
1327   SmallVector<OpAsmParser::OperandType, 1> boundOpInfos;
1328   if (p.parseOperandList(boundOpInfos))
1329     return failure();
1330 
1331   if (!boundOpInfos.empty()) {
1332     // Check that only one operand was parsed.
1333     if (boundOpInfos.size() > 1)
1334       return p.emitError(p.getNameLoc(),
1335                          "expected only one loop bound operand");
1336 
1337     // TODO: improve error message when SSA value is not of index type.
1338     // Currently it is 'use of value ... expects different type than prior uses'
1339     if (p.resolveOperand(boundOpInfos.front(), builder.getIndexType(),
1340                          result.operands))
1341       return failure();
1342 
1343     // Create an identity map using symbol id. This representation is optimized
1344     // for storage. Analysis passes may expand it into a multi-dimensional map
1345     // if desired.
1346     AffineMap map = builder.getSymbolIdentityMap();
1347     result.addAttribute(boundAttrName, AffineMapAttr::get(map));
1348     return success();
1349   }
1350 
1351   // Get the attribute location.
1352   llvm::SMLoc attrLoc = p.getCurrentLocation();
1353 
1354   Attribute boundAttr;
1355   if (p.parseAttribute(boundAttr, builder.getIndexType(), boundAttrName,
1356                        result.attributes))
1357     return failure();
1358 
1359   // Parse full form - affine map followed by dim and symbol list.
1360   if (auto affineMapAttr = boundAttr.dyn_cast<AffineMapAttr>()) {
1361     unsigned currentNumOperands = result.operands.size();
1362     unsigned numDims;
1363     if (parseDimAndSymbolList(p, result.operands, numDims))
1364       return failure();
1365 
1366     auto map = affineMapAttr.getValue();
1367     if (map.getNumDims() != numDims)
1368       return p.emitError(
1369           p.getNameLoc(),
1370           "dim operand count and affine map dim count must match");
1371 
1372     unsigned numDimAndSymbolOperands =
1373         result.operands.size() - currentNumOperands;
1374     if (numDims + map.getNumSymbols() != numDimAndSymbolOperands)
1375       return p.emitError(
1376           p.getNameLoc(),
1377           "symbol operand count and affine map symbol count must match");
1378 
1379     // If the map has multiple results, make sure that we parsed the min/max
1380     // prefix.
1381     if (map.getNumResults() > 1 && failedToParsedMinMax) {
1382       if (isLower) {
1383         return p.emitError(attrLoc, "lower loop bound affine map with "
1384                                     "multiple results requires 'max' prefix");
1385       }
1386       return p.emitError(attrLoc, "upper loop bound affine map with multiple "
1387                                   "results requires 'min' prefix");
1388     }
1389     return success();
1390   }
1391 
1392   // Parse custom assembly form.
1393   if (auto integerAttr = boundAttr.dyn_cast<IntegerAttr>()) {
1394     result.attributes.pop_back();
1395     result.addAttribute(
1396         boundAttrName,
1397         AffineMapAttr::get(builder.getConstantAffineMap(integerAttr.getInt())));
1398     return success();
1399   }
1400 
1401   return p.emitError(
1402       p.getNameLoc(),
1403       "expected valid affine map representation for loop bounds");
1404 }
1405 
1406 static ParseResult parseAffineForOp(OpAsmParser &parser,
1407                                     OperationState &result) {
1408   auto &builder = parser.getBuilder();
1409   OpAsmParser::OperandType inductionVariable;
1410   // Parse the induction variable followed by '='.
1411   if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
1412     return failure();
1413 
1414   // Parse loop bounds.
1415   if (parseBound(/*isLower=*/true, result, parser) ||
1416       parser.parseKeyword("to", " between bounds") ||
1417       parseBound(/*isLower=*/false, result, parser))
1418     return failure();
1419 
1420   // Parse the optional loop step; default to 1 if it is not present.
1421   if (parser.parseOptionalKeyword("step")) {
1422     result.addAttribute(
1423         AffineForOp::getStepAttrName(),
1424         builder.getIntegerAttr(builder.getIndexType(), /*value=*/1));
1425   } else {
1426     llvm::SMLoc stepLoc = parser.getCurrentLocation();
1427     IntegerAttr stepAttr;
1428     if (parser.parseAttribute(stepAttr, builder.getIndexType(),
1429                               AffineForOp::getStepAttrName().data(),
1430                               result.attributes))
1431       return failure();
1432 
1433     if (stepAttr.getValue().getSExtValue() < 0)
1434       return parser.emitError(
1435           stepLoc,
1436           "expected step to be representable as a positive signed integer");
1437   }
1438 
1439   // Parse the optional initial iteration arguments.
1440   SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
1441   SmallVector<Type, 4> argTypes;
1442   regionArgs.push_back(inductionVariable);
1443 
1444   if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
1445     // Parse assignment list and results type list.
1446     if (parser.parseAssignmentList(regionArgs, operands) ||
1447         parser.parseArrowTypeList(result.types))
1448       return failure();
1449     // Resolve input operands.
1450     for (auto operandType : llvm::zip(operands, result.types))
1451       if (parser.resolveOperand(std::get<0>(operandType),
1452                                 std::get<1>(operandType), result.operands))
1453         return failure();
1454   }
1455   // Induction variable.
1456   Type indexType = builder.getIndexType();
1457   argTypes.push_back(indexType);
1458   // Loop carried variables.
1459   argTypes.append(result.types.begin(), result.types.end());
1460   // Parse the body region.
1461   Region *body = result.addRegion();
1462   if (regionArgs.size() != argTypes.size())
1463     return parser.emitError(
1464         parser.getNameLoc(),
1465         "mismatch between the number of loop-carried values and results");
1466   if (parser.parseRegion(*body, regionArgs, argTypes))
1467     return failure();
1468 
1469   AffineForOp::ensureTerminator(*body, builder, result.location);
1470 
1471   // Parse the optional attribute list.
1472   return parser.parseOptionalAttrDict(result.attributes);
1473 }
1474 
1475 static void printBound(AffineMapAttr boundMap,
1476                        Operation::operand_range boundOperands,
1477                        const char *prefix, OpAsmPrinter &p) {
1478   AffineMap map = boundMap.getValue();
1479 
1480   // Check if this bound should be printed using custom assembly form.
1481   // The decision to restrict the custom assembly form to trivial cases
1482   // comes from the desire to round-trip MLIR binary -> text -> binary in a
1483   // lossless way.
1484   // Therefore, custom assembly form parsing and printing is only supported for
1485   // zero-operand constant maps and single symbol operand identity maps.
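  // For example (illustrative): a bound map "() -> (42)" with no operands
  // prints as "42", a bound map "()[s0] -> (s0)" with one operand %n prints as
  // "%n", and everything else falls through to the generic form below.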
1486   if (map.getNumResults() == 1) {
1487     AffineExpr expr = map.getResult(0);
1488 
1489     // Print constant bound.
1490     if (map.getNumDims() == 0 && map.getNumSymbols() == 0) {
1491       if (auto constExpr = expr.dyn_cast<AffineConstantExpr>()) {
1492         p << constExpr.getValue();
1493         return;
1494       }
1495     }
1496 
1497     // Print bound that consists of a single SSA symbol if the map is over a
1498     // single symbol.
1499     if (map.getNumDims() == 0 && map.getNumSymbols() == 1) {
1500       if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
1501         p.printOperand(*boundOperands.begin());
1502         return;
1503       }
1504     }
1505   } else {
1506     // Map has multiple results. Print 'min' or 'max' prefix.
1507     p << prefix << ' ';
1508   }
1509 
1510   // Print the map and its operands.
1511   p << boundMap;
1512   printDimAndSymbolList(boundOperands.begin(), boundOperands.end(),
1513                         map.getNumDims(), p);
1514 }
1515 
1516 unsigned AffineForOp::getNumIterOperands() {
1517   AffineMap lbMap = getLowerBoundMapAttr().getValue();
1518   AffineMap ubMap = getUpperBoundMapAttr().getValue();
1519 
1520   return getNumOperands() - lbMap.getNumInputs() - ubMap.getNumInputs();
1521 }
1522 
1523 static void print(OpAsmPrinter &p, AffineForOp op) {
1524   p << op.getOperationName() << ' ';
1525   p.printOperand(op.getBody()->getArgument(0));
1526   p << " = ";
1527   printBound(op.getLowerBoundMapAttr(), op.getLowerBoundOperands(), "max", p);
1528   p << " to ";
1529   printBound(op.getUpperBoundMapAttr(), op.getUpperBoundOperands(), "min", p);
1530 
1531   if (op.getStep() != 1)
1532     p << " step " << op.getStep();
1533 
1534   bool printBlockTerminators = false;
1535   if (op.getNumIterOperands() > 0) {
1536     p << " iter_args(";
1537     auto regionArgs = op.getRegionIterArgs();
1538     auto operands = op.getIterOperands();
1539 
1540     llvm::interleaveComma(llvm::zip(regionArgs, operands), p, [&](auto it) {
1541       p << std::get<0>(it) << " = " << std::get<1>(it);
1542     });
1543     p << ") -> (" << op.getResultTypes() << ")";
1544     printBlockTerminators = true;
1545   }
1546 
1547   p.printRegion(op.region(),
1548                 /*printEntryBlockArgs=*/false, printBlockTerminators);
1549   p.printOptionalAttrDict(op->getAttrs(),
1550                           /*elidedAttrs=*/{op.getLowerBoundAttrName(),
1551                                            op.getUpperBoundAttrName(),
1552                                            op.getStepAttrName()});
1553 }
1554 
1555 /// Fold the constant bounds of a loop.
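/// For example (illustrative): a lower bound "max (d0) -> (d0, 16)" whose
/// single operand is a constant 8 folds to the constant lower bound 16, since
/// lower bounds take the maximum over the map's results.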
1556 static LogicalResult foldLoopBounds(AffineForOp forOp) {
1557   auto foldLowerOrUpperBound = [&forOp](bool lower) {
1558     // Check to see if each of the operands is the result of a constant.  If
1559     // so, get the value.  If not, ignore it.
1560     SmallVector<Attribute, 8> operandConstants;
1561     auto boundOperands =
1562         lower ? forOp.getLowerBoundOperands() : forOp.getUpperBoundOperands();
1563     for (auto operand : boundOperands) {
1564       Attribute operandCst;
1565       matchPattern(operand, m_Constant(&operandCst));
1566       operandConstants.push_back(operandCst);
1567     }
1568 
1569     AffineMap boundMap =
1570         lower ? forOp.getLowerBoundMap() : forOp.getUpperBoundMap();
1571     assert(boundMap.getNumResults() >= 1 &&
1572            "bound maps should have at least one result");
1573     SmallVector<Attribute, 4> foldedResults;
1574     if (failed(boundMap.constantFold(operandConstants, foldedResults)))
1575       return failure();
1576 
1577     // Compute the max or min as applicable over the results.
1578     assert(!foldedResults.empty() && "bounds should have at least one result");
1579     auto maxOrMin = foldedResults[0].cast<IntegerAttr>().getValue();
1580     for (unsigned i = 1, e = foldedResults.size(); i < e; i++) {
1581       auto foldedResult = foldedResults[i].cast<IntegerAttr>().getValue();
1582       maxOrMin = lower ? llvm::APIntOps::smax(maxOrMin, foldedResult)
1583                        : llvm::APIntOps::smin(maxOrMin, foldedResult);
1584     }
1585     lower ? forOp.setConstantLowerBound(maxOrMin.getSExtValue())
1586           : forOp.setConstantUpperBound(maxOrMin.getSExtValue());
1587     return success();
1588   };
1589 
1590   // Try to fold the lower bound.
1591   bool folded = false;
1592   if (!forOp.hasConstantLowerBound())
1593     folded |= succeeded(foldLowerOrUpperBound(/*lower=*/true));
1594 
1595   // Try to fold the upper bound.
1596   if (!forOp.hasConstantUpperBound())
1597     folded |= succeeded(foldLowerOrUpperBound(/*lower=*/false));
1598   return success(folded);
1599 }
1600 
1601 /// Canonicalize the bounds of the given loop.
1602 static LogicalResult canonicalizeLoopBounds(AffineForOp forOp) {
1603   SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
1604   SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
1605 
1606   auto lbMap = forOp.getLowerBoundMap();
1607   auto ubMap = forOp.getUpperBoundMap();
1608   auto prevLbMap = lbMap;
1609   auto prevUbMap = ubMap;
1610 
1611   composeAffineMapAndOperands(&lbMap, &lbOperands);
1612   canonicalizeMapAndOperands(&lbMap, &lbOperands);
1613   lbMap = removeDuplicateExprs(lbMap);
1614 
1615   composeAffineMapAndOperands(&ubMap, &ubOperands);
1616   canonicalizeMapAndOperands(&ubMap, &ubOperands);
1617   ubMap = removeDuplicateExprs(ubMap);
1618 
1619   // Any canonicalization change always leads to updated map(s).
1620   if (lbMap == prevLbMap && ubMap == prevUbMap)
1621     return failure();
1622 
1623   if (lbMap != prevLbMap)
1624     forOp.setLowerBound(lbOperands, lbMap);
1625   if (ubMap != prevUbMap)
1626     forOp.setUpperBound(ubOperands, ubMap);
1627   return success();
1628 }
1629 
1630 namespace {
1631 /// This is a pattern to fold trivially empty loop bodies.
1632 /// TODO: This should be moved into the folding hook.
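/// For example (illustrative): a loop whose body holds only
/// "affine.yield %arg" for its region iter_arg %arg is replaced by the
/// corresponding initial iter operand.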
1633 struct AffineForEmptyLoopFolder : public OpRewritePattern<AffineForOp> {
1634   using OpRewritePattern<AffineForOp>::OpRewritePattern;
1635 
1636   LogicalResult matchAndRewrite(AffineForOp forOp,
1637                                 PatternRewriter &rewriter) const override {
1638     // Check that the body only contains a yield.
1639     if (!llvm::hasSingleElement(*forOp.getBody()))
1640       return failure();
1641     // The initial values of the iteration arguments would be the op's results.
1642     rewriter.replaceOp(forOp, forOp.getIterOperands());
1643     return success();
1644   }
1645 };
1646 } // end anonymous namespace
1647 
1648 void AffineForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1649                                               MLIRContext *context) {
1650   results.add<AffineForEmptyLoopFolder>(context);
1651 }
1652 
1653 /// Returns true if the affine.for has zero iterations in trivial cases.
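/// For example (illustrative), "affine.for %i = 4 to 4" and
/// "affine.for %i = 7 to 3" both have a trivially known zero trip count.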
1654 static bool hasTrivialZeroTripCount(AffineForOp op) {
1655   if (!op.hasConstantBounds())
1656     return false;
1657   int64_t lb = op.getConstantLowerBound();
1658   int64_t ub = op.getConstantUpperBound();
1659   return ub - lb <= 0;
1660 }
1661 
1662 LogicalResult AffineForOp::fold(ArrayRef<Attribute> operands,
1663                                 SmallVectorImpl<OpFoldResult> &results) {
1664   bool folded = succeeded(foldLoopBounds(*this));
1665   folded |= succeeded(canonicalizeLoopBounds(*this));
1666   if (hasTrivialZeroTripCount(*this)) {
1667     // The initial values of the loop-carried variables (iter_args) are the
1668     // results of the op.
1669     results.assign(getIterOperands().begin(), getIterOperands().end());
1670     folded = true;
1671   }
1672   return success(folded);
1673 }
1674 
1675 AffineBound AffineForOp::getLowerBound() {
1676   auto lbMap = getLowerBoundMap();
1677   return AffineBound(AffineForOp(*this), 0, lbMap.getNumInputs(), lbMap);
1678 }
1679 
1680 AffineBound AffineForOp::getUpperBound() {
1681   auto lbMap = getLowerBoundMap();
1682   auto ubMap = getUpperBoundMap();
1683   return AffineBound(AffineForOp(*this), lbMap.getNumInputs(),
1684                      lbMap.getNumInputs() + ubMap.getNumInputs(), ubMap);
1685 }
1686 
1687 void AffineForOp::setLowerBound(ValueRange lbOperands, AffineMap map) {
1688   assert(lbOperands.size() == map.getNumInputs());
1689   assert(map.getNumResults() >= 1 && "bound map has at least one result");
1690 
1691   SmallVector<Value, 4> newOperands(lbOperands.begin(), lbOperands.end());
1692 
1693   auto ubOperands = getUpperBoundOperands();
1694   newOperands.append(ubOperands.begin(), ubOperands.end());
1695   auto iterOperands = getIterOperands();
1696   newOperands.append(iterOperands.begin(), iterOperands.end());
1697   (*this)->setOperands(newOperands);
1698 
1699   (*this)->setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
1700 }
1701 
1702 void AffineForOp::setUpperBound(ValueRange ubOperands, AffineMap map) {
1703   assert(ubOperands.size() == map.getNumInputs());
1704   assert(map.getNumResults() >= 1 && "bound map has at least one result");
1705 
1706   SmallVector<Value, 4> newOperands(getLowerBoundOperands());
1707   newOperands.append(ubOperands.begin(), ubOperands.end());
1708   auto iterOperands = getIterOperands();
1709   newOperands.append(iterOperands.begin(), iterOperands.end());
1710   (*this)->setOperands(newOperands);
1711 
1712   (*this)->setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
1713 }
1714 
1715 void AffineForOp::setLowerBoundMap(AffineMap map) {
1716   auto lbMap = getLowerBoundMap();
1717   assert(lbMap.getNumDims() == map.getNumDims() &&
1718          lbMap.getNumSymbols() == map.getNumSymbols());
1719   assert(map.getNumResults() >= 1 && "bound map has at least one result");
1720   (void)lbMap;
1721   (*this)->setAttr(getLowerBoundAttrName(), AffineMapAttr::get(map));
1722 }
1723 
1724 void AffineForOp::setUpperBoundMap(AffineMap map) {
1725   auto ubMap = getUpperBoundMap();
1726   assert(ubMap.getNumDims() == map.getNumDims() &&
1727          ubMap.getNumSymbols() == map.getNumSymbols());
1728   assert(map.getNumResults() >= 1 && "bound map has at least one result");
1729   (void)ubMap;
1730   (*this)->setAttr(getUpperBoundAttrName(), AffineMapAttr::get(map));
1731 }
1732 
1733 bool AffineForOp::hasConstantLowerBound() {
1734   return getLowerBoundMap().isSingleConstant();
1735 }
1736 
1737 bool AffineForOp::hasConstantUpperBound() {
1738   return getUpperBoundMap().isSingleConstant();
1739 }
1740 
1741 int64_t AffineForOp::getConstantLowerBound() {
1742   return getLowerBoundMap().getSingleConstantResult();
1743 }
1744 
1745 int64_t AffineForOp::getConstantUpperBound() {
1746   return getUpperBoundMap().getSingleConstantResult();
1747 }
1748 
1749 void AffineForOp::setConstantLowerBound(int64_t value) {
1750   setLowerBound({}, AffineMap::getConstantMap(value, getContext()));
1751 }
1752 
1753 void AffineForOp::setConstantUpperBound(int64_t value) {
1754   setUpperBound({}, AffineMap::getConstantMap(value, getContext()));
1755 }
1756 
1757 AffineForOp::operand_range AffineForOp::getLowerBoundOperands() {
1758   return {operand_begin(), operand_begin() + getLowerBoundMap().getNumInputs()};
1759 }
1760 
1761 AffineForOp::operand_range AffineForOp::getUpperBoundOperands() {
1762   return {operand_begin() + getLowerBoundMap().getNumInputs(),
1763           operand_begin() + getLowerBoundMap().getNumInputs() +
1764               getUpperBoundMap().getNumInputs()};
1765 }
1766 
1767 bool AffineForOp::matchingBoundOperandList() {
1768   auto lbMap = getLowerBoundMap();
1769   auto ubMap = getUpperBoundMap();
1770   if (lbMap.getNumDims() != ubMap.getNumDims() ||
1771       lbMap.getNumSymbols() != ubMap.getNumSymbols())
1772     return false;
1773 
1774   unsigned numOperands = lbMap.getNumInputs();
1775   for (unsigned i = 0, e = lbMap.getNumInputs(); i < e; i++) {
1776     // Compare Values.
1777     if (getOperand(i) != getOperand(numOperands + i))
1778       return false;
1779   }
1780   return true;
1781 }
1782 
1783 Region &AffineForOp::getLoopBody() { return region(); }
1784 
1785 bool AffineForOp::isDefinedOutsideOfLoop(Value value) {
1786   return !region().isAncestor(value.getParentRegion());
1787 }
1788 
1789 LogicalResult AffineForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
1790   for (auto *op : ops)
1791     op->moveBefore(*this);
1792   return success();
1793 }
1794 
1795 /// Returns true if the provided value is the induction variable of an
1796 /// AffineForOp.
1797 bool mlir::isForInductionVar(Value val) {
1798   return getForInductionVarOwner(val) != AffineForOp();
1799 }
1800 
1801 /// Returns the loop parent of an induction variable. If the provided value is
1802 /// not an induction variable, then returns a null AffineForOp.
1803 AffineForOp mlir::getForInductionVarOwner(Value val) {
1804   auto ivArg = val.dyn_cast<BlockArgument>();
1805   if (!ivArg || !ivArg.getOwner())
1806     return AffineForOp();
1807   auto *containingInst = ivArg.getOwner()->getParent()->getParentOp();
1808   return dyn_cast<AffineForOp>(containingInst);
1809 }
1810 
1811 /// Extracts the induction variables from a list of AffineForOps and returns
1812 /// them.
1813 void mlir::extractForInductionVars(ArrayRef<AffineForOp> forInsts,
1814                                    SmallVectorImpl<Value> *ivs) {
1815   ivs->reserve(forInsts.size());
1816   for (auto forInst : forInsts)
1817     ivs->push_back(forInst.getInductionVar());
1818 }
1819 
1820 /// Builds an affine loop nest, using "loopCreatorFn" to create individual loop
1821 /// operations.
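///
/// A minimal caller-side sketch of the constant-bound wrapper below (bounds
/// and body are illustrative only):
///   buildAffineLoopNest(builder, loc, /*lbs=*/{0, 0}, /*ubs=*/{64, 64},
///                       /*steps=*/{1, 1},
///                       [](OpBuilder &nested, Location nestedLoc,
///                          ValueRange ivs) {
///                         // `ivs` holds one induction variable per loop.
///                       });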
1822 template <typename BoundListTy, typename LoopCreatorTy>
1823 static void buildAffineLoopNestImpl(
1824     OpBuilder &builder, Location loc, BoundListTy lbs, BoundListTy ubs,
1825     ArrayRef<int64_t> steps,
1826     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn,
1827     LoopCreatorTy &&loopCreatorFn) {
1828   assert(lbs.size() == ubs.size() && "Mismatch in number of arguments");
1829   assert(lbs.size() == steps.size() && "Mismatch in number of arguments");
1830 
1831   // If there are no loops to be constructed, construct the body anyway.
1832   OpBuilder::InsertionGuard guard(builder);
1833   if (lbs.empty()) {
1834     if (bodyBuilderFn)
1835       bodyBuilderFn(builder, loc, ValueRange());
1836     return;
1837   }
1838 
1839   // Create the loops iteratively and store the induction variables.
1840   SmallVector<Value, 4> ivs;
1841   ivs.reserve(lbs.size());
1842   for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
1843     // Callback for creating the loop body; it always creates the terminator.
1844     auto loopBody = [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
1845                         ValueRange iterArgs) {
1846       ivs.push_back(iv);
1847       // In the innermost loop, call the body builder.
1848       if (i == e - 1 && bodyBuilderFn) {
1849         OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
1850         bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
1851       }
1852       nestedBuilder.create<AffineYieldOp>(nestedLoc);
1853     };
1854 
1855     // Delegate actual loop creation to the callback in order to dispatch
1856     // between constant- and variable-bound loops.
1857     auto loop = loopCreatorFn(builder, loc, lbs[i], ubs[i], steps[i], loopBody);
1858     builder.setInsertionPointToStart(loop.getBody());
1859   }
1860 }
1861 
1862 /// Creates an affine loop from the bounds known to be constants.
1863 static AffineForOp
1864 buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
1865                              int64_t ub, int64_t step,
1866                              AffineForOp::BodyBuilderFn bodyBuilderFn) {
1867   return builder.create<AffineForOp>(loc, lb, ub, step, /*iterArgs=*/llvm::None,
1868                                      bodyBuilderFn);
1869 }
1870 
1871 /// Creates an affine loop from the bounds that may or may not be constants.
1872 static AffineForOp
1873 buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
1874                           int64_t step,
1875                           AffineForOp::BodyBuilderFn bodyBuilderFn) {
1876   auto lbConst = lb.getDefiningOp<ConstantIndexOp>();
1877   auto ubConst = ub.getDefiningOp<ConstantIndexOp>();
1878   if (lbConst && ubConst)
1879     return buildAffineLoopFromConstants(builder, loc, lbConst.getValue(),
1880                                         ubConst.getValue(), step,
1881                                         bodyBuilderFn);
1882   return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
1883                                      builder.getDimIdentityMap(), step,
1884                                      /*iterArgs=*/llvm::None, bodyBuilderFn);
1885 }
1886 
1887 void mlir::buildAffineLoopNest(
1888     OpBuilder &builder, Location loc, ArrayRef<int64_t> lbs,
1889     ArrayRef<int64_t> ubs, ArrayRef<int64_t> steps,
1890     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1891   buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
1892                           buildAffineLoopFromConstants);
1893 }
1894 
1895 void mlir::buildAffineLoopNest(
1896     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
1897     ArrayRef<int64_t> steps,
1898     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
1899   buildAffineLoopNestImpl(builder, loc, lbs, ubs, steps, bodyBuilderFn,
1900                           buildAffineLoopFromValues);
1901 }
1902 
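/// Creates a new affine.for that carries the original iter operands plus
/// `newIterOperands`, moves the original loop body into it, adds block
/// arguments for `newIterArgs`, appends `newYieldedValues` to the terminator,
/// and, if `replaceLoopResults` is set, redirects uses of the original results
/// to the leading results of the new loop. The original (now body-less) loop
/// is left for the caller to erase.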
1903 AffineForOp mlir::replaceForOpWithNewYields(OpBuilder &b, AffineForOp loop,
1904                                             ValueRange newIterOperands,
1905                                             ValueRange newYieldedValues,
1906                                             ValueRange newIterArgs,
1907                                             bool replaceLoopResults) {
1908   assert(newIterOperands.size() == newYieldedValues.size() &&
1909          "newIterOperands must be of the same size as newYieldedValues");
1910   // Create a new loop before the existing one, with the extra operands.
1911   OpBuilder::InsertionGuard g(b);
1912   b.setInsertionPoint(loop);
1913   auto operands = llvm::to_vector<4>(loop.getIterOperands());
1914   operands.append(newIterOperands.begin(), newIterOperands.end());
1915   SmallVector<Value, 4> lbOperands(loop.getLowerBoundOperands());
1916   SmallVector<Value, 4> ubOperands(loop.getUpperBoundOperands());
1917   SmallVector<Value, 4> steps(loop.getStep());
1918   auto lbMap = loop.getLowerBoundMap();
1919   auto ubMap = loop.getUpperBoundMap();
1920   AffineForOp newLoop =
1921       b.create<AffineForOp>(loop.getLoc(), lbOperands, lbMap, ubOperands, ubMap,
1922                             loop.getStep(), operands);
1923   // Take the body of the original parent loop.
1924   newLoop.getLoopBody().takeBody(loop.getLoopBody());
1925   for (Value val : newIterArgs)
1926     newLoop.getLoopBody().addArgument(val.getType());
1927 
1928   // Update yield operation with new values to be added.
1929   if (!newYieldedValues.empty()) {
1930     auto yield = cast<AffineYieldOp>(newLoop.getBody()->getTerminator());
1931     b.setInsertionPoint(yield);
1932     auto yieldOperands = llvm::to_vector<4>(yield.getOperands());
1933     yieldOperands.append(newYieldedValues.begin(), newYieldedValues.end());
1934     b.create<AffineYieldOp>(yield.getLoc(), yieldOperands);
1935     yield.erase();
1936   }
1937   if (replaceLoopResults) {
1938     for (auto it : llvm::zip(loop.getResults(), newLoop.getResults().take_front(
1939                                                     loop.getNumResults()))) {
1940       std::get<0>(it).replaceAllUsesWith(std::get<1>(it));
1941     }
1942   }
1943   return newLoop;
1944 }
1945 
1946 //===----------------------------------------------------------------------===//
1947 // AffineIfOp
1948 //===----------------------------------------------------------------------===//
1949 
1950 namespace {
1951 /// Remove else blocks that contain nothing other than a yield with no operands.
1952 struct SimplifyDeadElse : public OpRewritePattern<AffineIfOp> {
1953   using OpRewritePattern<AffineIfOp>::OpRewritePattern;
1954 
1955   LogicalResult matchAndRewrite(AffineIfOp ifOp,
1956                                 PatternRewriter &rewriter) const override {
1957     if (ifOp.elseRegion().empty() ||
1958         !llvm::hasSingleElement(*ifOp.getElseBlock()) || ifOp.getNumResults())
1959       return failure();
1960 
1961     rewriter.startRootUpdate(ifOp);
1962     rewriter.eraseBlock(ifOp.getElseBlock());
1963     rewriter.finalizeRootUpdate(ifOp);
1964     return success();
1965   }
1966 };
1967 
1968 /// Removes an affine.if operation whose condition is trivially always true or
1969 /// always false, promoting the surviving then/else block into the parent block.
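/// For example (illustrative): a condition using the canonical empty set
/// "affine_set<(d0) : (1 == 0)>" is treated as always false, while a single
/// trivial equality such as "affine_set<(d0) : (0 == 0)>" is treated as
/// always true.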
1970 struct AlwaysTrueOrFalseIf : public OpRewritePattern<AffineIfOp> {
1971   using OpRewritePattern<AffineIfOp>::OpRewritePattern;
1972 
1973   LogicalResult matchAndRewrite(AffineIfOp op,
1974                                 PatternRewriter &rewriter) const override {
1975 
1976     auto isTriviallyFalse = [](IntegerSet iSet) {
1977       return iSet.isEmptyIntegerSet();
1978     };
1979 
1980     auto isTriviallyTrue = [](IntegerSet iSet) {
1981       return (iSet.getNumEqualities() == 1 && iSet.getNumInequalities() == 0 &&
1982               iSet.getConstraint(0) == 0);
1983     };
1984 
1985     IntegerSet affineIfConditions = op.getIntegerSet();
1986     Block *blockToMove;
1987     if (isTriviallyFalse(affineIfConditions)) {
1988       // The absence, or equivalently, the emptiness of the else region need not
1989       // be checked when affine.if is returning results because if an affine.if
1990       // operation is returning results, it always has a non-empty else region.
1991       if (op.getNumResults() == 0 && !op.hasElse()) {
1992         // If the else region is absent, or equivalently, empty, remove the
1993         // affine.if operation (which is not returning any results).
1994         rewriter.eraseOp(op);
1995         return success();
1996       }
1997       blockToMove = op.getElseBlock();
1998     } else if (isTriviallyTrue(affineIfConditions)) {
1999       blockToMove = op.getThenBlock();
2000     } else {
2001       return failure();
2002     }
2003     Operation *blockToMoveTerminator = blockToMove->getTerminator();
2004     // Promote the "blockToMove" block to the parent operation block between the
2005     // prologue and epilogue of "op".
2006     rewriter.mergeBlockBefore(blockToMove, op);
2007     // Replace the "op" operation with the operands of the
2008     // "blockToMoveTerminator" operation. Note that "blockToMoveTerminator" is
2009     // the affine.yield operation present in the "blockToMove" block. It has no
2010     // operands when affine.if is not returning results and therefore, in that
2011     // case, replaceOp just erases "op". When affine.if is not returning
2012     // results, the affine.yield operation can be omitted. It gets inserted
2013     // implicitly.
2014     rewriter.replaceOp(op, blockToMoveTerminator->getOperands());
2015     // Erase the "blockToMoveTerminator" operation since it is now in the parent
2016     // operation block, which already has its own terminator.
2017     rewriter.eraseOp(blockToMoveTerminator);
2018     return success();
2019   }
2020 };
2021 } // end anonymous namespace.
2022 
2023 static LogicalResult verify(AffineIfOp op) {
2024   // Verify that we have a condition attribute.
2025   auto conditionAttr =
2026       op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
2027   if (!conditionAttr)
2028     return op.emitOpError(
2029         "requires an integer set attribute named 'condition'");
2030 
2031   // Verify that there are enough operands for the condition.
2032   IntegerSet condition = conditionAttr.getValue();
2033   if (op.getNumOperands() != condition.getNumInputs())
2034     return op.emitOpError(
2035         "operand count and condition integer set dimension and "
2036         "symbol count must match");
2037 
2038   // Verify that the operands are valid dimension/symbols.
2039   if (failed(verifyDimAndSymbolIdentifiers(op, op.getOperands(),
2040                                            condition.getNumDims())))
2041     return failure();
2042 
2043   return success();
2044 }
2045 
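// For illustration, the custom affine.if syntax being parsed below looks like
// (set and operands are illustrative only):
//   affine.if #set(%i)[%n] {
//     ...
//   } else {
//     ...
//   }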
2046 static ParseResult parseAffineIfOp(OpAsmParser &parser,
2047                                    OperationState &result) {
2048   // Parse the condition attribute set.
2049   IntegerSetAttr conditionAttr;
2050   unsigned numDims;
2051   if (parser.parseAttribute(conditionAttr, AffineIfOp::getConditionAttrName(),
2052                             result.attributes) ||
2053       parseDimAndSymbolList(parser, result.operands, numDims))
2054     return failure();
2055 
2056   // Verify the condition operands.
2057   auto set = conditionAttr.getValue();
2058   if (set.getNumDims() != numDims)
2059     return parser.emitError(
2060         parser.getNameLoc(),
2061         "dim operand count and integer set dim count must match");
2062   if (numDims + set.getNumSymbols() != result.operands.size())
2063     return parser.emitError(
2064         parser.getNameLoc(),
2065         "symbol operand count and integer set symbol count must match");
2066 
2067   if (parser.parseOptionalArrowTypeList(result.types))
2068     return failure();
2069 
2070   // Create the regions for 'then' and 'else'.  The latter must be created even
2071   // if it remains empty for the validity of the operation.
2072   result.regions.reserve(2);
2073   Region *thenRegion = result.addRegion();
2074   Region *elseRegion = result.addRegion();
2075 
2076   // Parse the 'then' region.
2077   if (parser.parseRegion(*thenRegion, {}, {}))
2078     return failure();
2079   AffineIfOp::ensureTerminator(*thenRegion, parser.getBuilder(),
2080                                result.location);
2081 
2082   // If we find an 'else' keyword then parse the 'else' region.
2083   if (!parser.parseOptionalKeyword("else")) {
2084     if (parser.parseRegion(*elseRegion, {}, {}))
2085       return failure();
2086     AffineIfOp::ensureTerminator(*elseRegion, parser.getBuilder(),
2087                                  result.location);
2088   }
2089 
2090   // Parse the optional attribute list.
2091   if (parser.parseOptionalAttrDict(result.attributes))
2092     return failure();
2093 
2094   return success();
2095 }
2096 
2097 static void print(OpAsmPrinter &p, AffineIfOp op) {
2098   auto conditionAttr =
2099       op->getAttrOfType<IntegerSetAttr>(op.getConditionAttrName());
2100   p << "affine.if " << conditionAttr;
2101   printDimAndSymbolList(op.operand_begin(), op.operand_end(),
2102                         conditionAttr.getValue().getNumDims(), p);
2103   p.printOptionalArrowTypeList(op.getResultTypes());
2104   p.printRegion(op.thenRegion(),
2105                 /*printEntryBlockArgs=*/false,
2106                 /*printBlockTerminators=*/op.getNumResults());
2107 
2108   // Print the 'else' region if it has any blocks.
2109   auto &elseRegion = op.elseRegion();
2110   if (!elseRegion.empty()) {
2111     p << " else";
2112     p.printRegion(elseRegion,
2113                   /*printEntryBlockArgs=*/false,
2114                   /*printBlockTerminators=*/op.getNumResults());
2115   }
2116 
2117   // Print the attribute list.
2118   p.printOptionalAttrDict(op->getAttrs(),
2119                           /*elidedAttrs=*/op.getConditionAttrName());
2120 }
2121 
2122 IntegerSet AffineIfOp::getIntegerSet() {
2123   return (*this)
2124       ->getAttrOfType<IntegerSetAttr>(getConditionAttrName())
2125       .getValue();
2126 }
2127 
2128 void AffineIfOp::setIntegerSet(IntegerSet newSet) {
2129   (*this)->setAttr(getConditionAttrName(), IntegerSetAttr::get(newSet));
2130 }
2131 
2132 void AffineIfOp::setConditional(IntegerSet set, ValueRange operands) {
2133   setIntegerSet(set);
2134   (*this)->setOperands(operands);
2135 }
2136 
2137 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2138                        TypeRange resultTypes, IntegerSet set, ValueRange args,
2139                        bool withElseRegion) {
2140   assert(resultTypes.empty() || withElseRegion);
2141   result.addTypes(resultTypes);
2142   result.addOperands(args);
2143   result.addAttribute(getConditionAttrName(), IntegerSetAttr::get(set));
2144 
2145   Region *thenRegion = result.addRegion();
2146   thenRegion->push_back(new Block());
2147   if (resultTypes.empty())
2148     AffineIfOp::ensureTerminator(*thenRegion, builder, result.location);
2149 
2150   Region *elseRegion = result.addRegion();
2151   if (withElseRegion) {
2152     elseRegion->push_back(new Block());
2153     if (resultTypes.empty())
2154       AffineIfOp::ensureTerminator(*elseRegion, builder, result.location);
2155   }
2156 }
2157 
2158 void AffineIfOp::build(OpBuilder &builder, OperationState &result,
2159                        IntegerSet set, ValueRange args, bool withElseRegion) {
2160   AffineIfOp::build(builder, result, /*resultTypes=*/{}, set, args,
2161                     withElseRegion);
2162 }
2163 
2164 /// Canonicalize an affine if op's conditional (integer set + operands).
2165 LogicalResult AffineIfOp::fold(ArrayRef<Attribute>,
2166                                SmallVectorImpl<OpFoldResult> &) {
2167   auto set = getIntegerSet();
2168   SmallVector<Value, 4> operands(getOperands());
2169   canonicalizeSetAndOperands(&set, &operands);
2170 
2171   // Any canonicalization change always leads to either a reduction in the
2172   // number of operands or a change in the number of symbolic operands
2173   // (promotion of dims to symbols).
2174   if (operands.size() < getIntegerSet().getNumInputs() ||
2175       set.getNumSymbols() > getIntegerSet().getNumSymbols()) {
2176     setConditional(set, operands);
2177     return success();
2178   }
2179 
2180   return failure();
2181 }
2182 
2183 void AffineIfOp::getCanonicalizationPatterns(RewritePatternSet &results,
2184                                              MLIRContext *context) {
2185   results.add<SimplifyDeadElse, AlwaysTrueOrFalseIf>(context);
2186 }
2187 
2188 //===----------------------------------------------------------------------===//
2189 // AffineLoadOp
2190 //===----------------------------------------------------------------------===//
2191 
2192 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2193                          AffineMap map, ValueRange operands) {
2194   assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
2195   result.addOperands(operands);
2196   if (map)
2197     result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2198   auto memrefType = operands[0].getType().cast<MemRefType>();
2199   result.types.push_back(memrefType.getElementType());
2200 }
2201 
2202 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2203                          Value memref, AffineMap map, ValueRange mapOperands) {
2204   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2205   result.addOperands(memref);
2206   result.addOperands(mapOperands);
2207   auto memrefType = memref.getType().cast<MemRefType>();
2208   result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2209   result.types.push_back(memrefType.getElementType());
2210 }
2211 
2212 void AffineLoadOp::build(OpBuilder &builder, OperationState &result,
2213                          Value memref, ValueRange indices) {
2214   auto memrefType = memref.getType().cast<MemRefType>();
2215   int64_t rank = memrefType.getRank();
2216   // Create identity map for memrefs with at least one dimension or () -> ()
2217   // for zero-dimensional memrefs.
2218   auto map =
2219       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2220   build(builder, result, memref, map, indices);
2221 }
2222 
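// For illustration, the custom affine.load syntax being parsed below looks
// like (names are illustrative only):
//   %v = affine.load %buf[%i + 3, %j] : memref<100x100xf32>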
2223 static ParseResult parseAffineLoadOp(OpAsmParser &parser,
2224                                      OperationState &result) {
2225   auto &builder = parser.getBuilder();
2226   auto indexTy = builder.getIndexType();
2227 
2228   MemRefType type;
2229   OpAsmParser::OperandType memrefInfo;
2230   AffineMapAttr mapAttr;
2231   SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2232   return failure(
2233       parser.parseOperand(memrefInfo) ||
2234       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2235                                     AffineLoadOp::getMapAttrName(),
2236                                     result.attributes) ||
2237       parser.parseOptionalAttrDict(result.attributes) ||
2238       parser.parseColonType(type) ||
2239       parser.resolveOperand(memrefInfo, type, result.operands) ||
2240       parser.resolveOperands(mapOperands, indexTy, result.operands) ||
2241       parser.addTypeToList(type.getElementType(), result.types));
2242 }
2243 
2244 static void print(OpAsmPrinter &p, AffineLoadOp op) {
2245   p << "affine.load " << op.getMemRef() << '[';
2246   if (AffineMapAttr mapAttr =
2247           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2248     p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2249   p << ']';
2250   p.printOptionalAttrDict(op->getAttrs(),
2251                           /*elidedAttrs=*/{op.getMapAttrName()});
2252   p << " : " << op.getMemRefType();
2253 }
2254 
2255 /// Verify common indexing invariants of affine.load, affine.store,
2256 /// affine.vector_load and affine.vector_store.
2257 static LogicalResult
2258 verifyMemoryOpIndexing(Operation *op, AffineMapAttr mapAttr,
2259                        Operation::operand_range mapOperands,
2260                        MemRefType memrefType, unsigned numIndexOperands) {
2261   if (mapAttr) {
2262     AffineMap map = mapAttr.getValue();
2263     if (map.getNumResults() != memrefType.getRank())
2264       return op->emitOpError("affine map num results must equal memref rank");
2265     if (map.getNumInputs() != numIndexOperands)
2266       return op->emitOpError("expects as many subscripts as affine map inputs");
2267   } else {
2268     if (memrefType.getRank() != numIndexOperands)
2269       return op->emitOpError(
2270           "expects the number of subscripts to be equal to memref rank");
2271   }
2272 
2273   Region *scope = getAffineScope(op);
2274   for (auto idx : mapOperands) {
2275     if (!idx.getType().isIndex())
2276       return op->emitOpError("index to load must have 'index' type");
2277     if (!isValidAffineIndexOperand(idx, scope))
2278       return op->emitOpError("index must be a dimension or symbol identifier");
2279   }
2280 
2281   return success();
2282 }
2283 
2284 LogicalResult verify(AffineLoadOp op) {
2285   auto memrefType = op.getMemRefType();
2286   if (op.getType() != memrefType.getElementType())
2287     return op.emitOpError("result type must match element type of memref");
2288 
2289   if (failed(verifyMemoryOpIndexing(
2290           op.getOperation(),
2291           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2292           op.getMapOperands(), memrefType,
2293           /*numIndexOperands=*/op.getNumOperands() - 1)))
2294     return failure();
2295 
2296   return success();
2297 }
2298 
2299 void AffineLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
2300                                                MLIRContext *context) {
2301   results.add<SimplifyAffineOp<AffineLoadOp>>(context);
2302 }
2303 
2304 OpFoldResult AffineLoadOp::fold(ArrayRef<Attribute> cstOperands) {
2305   /// load(memrefcast) -> load
2306   if (succeeded(foldMemRefCast(*this)))
2307     return getResult();
2308   return OpFoldResult();
2309 }
2310 
2311 //===----------------------------------------------------------------------===//
2312 // AffineStoreOp
2313 //===----------------------------------------------------------------------===//
2314 
2315 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
2316                           Value valueToStore, Value memref, AffineMap map,
2317                           ValueRange mapOperands) {
2318   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
2319   result.addOperands(valueToStore);
2320   result.addOperands(memref);
2321   result.addOperands(mapOperands);
2322   result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
2323 }
2324 
2325 // Use identity map.
2326 void AffineStoreOp::build(OpBuilder &builder, OperationState &result,
2327                           Value valueToStore, Value memref,
2328                           ValueRange indices) {
2329   auto memrefType = memref.getType().cast<MemRefType>();
2330   int64_t rank = memrefType.getRank();
2331   // Create identity map for memrefs with at least one dimension or () -> ()
2332   // for zero-dimensional memrefs.
2333   auto map =
2334       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
2335   build(builder, result, valueToStore, memref, map, indices);
2336 }
2337 
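// For illustration, the custom affine.store syntax being parsed below looks
// like (names are illustrative only):
//   affine.store %v, %buf[%i + 3, %j] : memref<100x100xf32>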
2338 static ParseResult parseAffineStoreOp(OpAsmParser &parser,
2339                                       OperationState &result) {
2340   auto indexTy = parser.getBuilder().getIndexType();
2341 
2342   MemRefType type;
2343   OpAsmParser::OperandType storeValueInfo;
2344   OpAsmParser::OperandType memrefInfo;
2345   AffineMapAttr mapAttr;
2346   SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2347   return failure(parser.parseOperand(storeValueInfo) || parser.parseComma() ||
2348                  parser.parseOperand(memrefInfo) ||
2349                  parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2350                                                AffineStoreOp::getMapAttrName(),
2351                                                result.attributes) ||
2352                  parser.parseOptionalAttrDict(result.attributes) ||
2353                  parser.parseColonType(type) ||
2354                  parser.resolveOperand(storeValueInfo, type.getElementType(),
2355                                        result.operands) ||
2356                  parser.resolveOperand(memrefInfo, type, result.operands) ||
2357                  parser.resolveOperands(mapOperands, indexTy, result.operands));
2358 }
2359 
2360 static void print(OpAsmPrinter &p, AffineStoreOp op) {
2361   p << "affine.store " << op.getValueToStore();
2362   p << ", " << op.getMemRef() << '[';
2363   if (AffineMapAttr mapAttr =
2364           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
2365     p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
2366   p << ']';
2367   p.printOptionalAttrDict(op->getAttrs(),
2368                           /*elidedAttrs=*/{op.getMapAttrName()});
2369   p << " : " << op.getMemRefType();
2370 }
2371 
2372 LogicalResult verify(AffineStoreOp op) {
2373   // First operand must have same type as memref element type.
2374   auto memrefType = op.getMemRefType();
2375   if (op.getValueToStore().getType() != memrefType.getElementType())
2376     return op.emitOpError(
2377         "first operand must have same type as memref element type");
2378 
2379   if (failed(verifyMemoryOpIndexing(
2380           op.getOperation(),
2381           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
2382           op.getMapOperands(), memrefType,
2383           /*numIndexOperands=*/op.getNumOperands() - 2)))
2384     return failure();
2385 
2386   return success();
2387 }
2388 
2389 void AffineStoreOp::getCanonicalizationPatterns(RewritePatternSet &results,
2390                                                 MLIRContext *context) {
2391   results.add<SimplifyAffineOp<AffineStoreOp>>(context);
2392 }
2393 
2394 LogicalResult AffineStoreOp::fold(ArrayRef<Attribute> cstOperands,
2395                                   SmallVectorImpl<OpFoldResult> &results) {
2396   /// store(memrefcast) -> store
2397   return foldMemRefCast(*this, getValueToStore());
2398 }
2399 
2400 //===----------------------------------------------------------------------===//
2401 // AffineMinMaxOpBase
2402 //===----------------------------------------------------------------------===//
2403 
2404 template <typename T>
2405 static LogicalResult verifyAffineMinMaxOp(T op) {
2406   // Verify that operand count matches affine map dimension and symbol count.
2407   if (op.getNumOperands() != op.map().getNumDims() + op.map().getNumSymbols())
2408     return op.emitOpError(
2409         "operand count and affine map dimension and symbol count must match");
2410   return success();
2411 }
2412 
2413 template <typename T>
2414 static void printAffineMinMaxOp(OpAsmPrinter &p, T op) {
2415   p << op.getOperationName() << ' ' << op->getAttr(T::getMapAttrName());
2416   auto operands = op.getOperands();
2417   unsigned numDims = op.map().getNumDims();
2418   p << '(' << operands.take_front(numDims) << ')';
2419 
2420   if (operands.size() != numDims)
2421     p << '[' << operands.drop_front(numDims) << ']';
2422   p.printOptionalAttrDict(op->getAttrs(),
2423                           /*elidedAttrs=*/{T::getMapAttrName()});
2424 }
2425 
2426 template <typename T>
2427 static ParseResult parseAffineMinMaxOp(OpAsmParser &parser,
2428                                        OperationState &result) {
2429   auto &builder = parser.getBuilder();
2430   auto indexType = builder.getIndexType();
2431   SmallVector<OpAsmParser::OperandType, 8> dimInfos;
2432   SmallVector<OpAsmParser::OperandType, 8> symInfos;
2433   AffineMapAttr mapAttr;
2434   return failure(
2435       parser.parseAttribute(mapAttr, T::getMapAttrName(), result.attributes) ||
2436       parser.parseOperandList(dimInfos, OpAsmParser::Delimiter::Paren) ||
2437       parser.parseOperandList(symInfos,
2438                               OpAsmParser::Delimiter::OptionalSquare) ||
2439       parser.parseOptionalAttrDict(result.attributes) ||
2440       parser.resolveOperands(dimInfos, indexType, result.operands) ||
2441       parser.resolveOperands(symInfos, indexType, result.operands) ||
2442       parser.addTypeToList(indexType, result.types));
2443 }
2444 
2445 /// Fold an affine min or max operation with the given operands. The operand
2446 /// list may contain nulls, which are interpreted as the operand not being a
2447 /// constant.
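///
/// For illustration (example values assumed): with a constant operand
/// %c10 = 10,
///   %0 = affine.min affine_map<(d0) -> (d0 + 2, 16)>(%c10)
/// folds all results to constants (12 and 16), so the op folds to the
/// constant 12. If only some results fold, the map is simplified in place
/// instead of folding the op away.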
2448 template <typename T>
2449 static OpFoldResult foldMinMaxOp(T op, ArrayRef<Attribute> operands) {
2450   static_assert(llvm::is_one_of<T, AffineMinOp, AffineMaxOp>::value,
2451                 "expected affine min or max op");
2452 
2453   // Fold the affine map.
2454   // TODO: Fold more cases:
2455   // min(some_affine, some_affine + constant, ...), etc.
2456   SmallVector<int64_t, 2> results;
2457   auto foldedMap = op.map().partialConstantFold(operands, &results);
2458 
2459   // If some of the map results are not constant, try changing the map in-place.
2460   if (results.empty()) {
2461     // If the map is the same, report that folding did not happen.
2462     if (foldedMap == op.map())
2463       return {};
2464     op->setAttr("map", AffineMapAttr::get(foldedMap));
2465     return op.getResult();
2466   }
2467 
2468   // Otherwise, completely fold the op into a constant.
2469   auto resultIt = std::is_same<T, AffineMinOp>::value
2470                       ? std::min_element(results.begin(), results.end())
2471                       : std::max_element(results.begin(), results.end());
2472   if (resultIt == results.end())
2473     return {};
2474   return IntegerAttr::get(IndexType::get(op.getContext()), *resultIt);
2475 }
2476 
2477 /// Remove duplicated expressions in affine min/max ops.
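///
/// For example (illustrative):
///   %0 = affine.min affine_map<(d0) -> (d0, d0 + 1, d0)>(%i)
/// is rewritten to:
///   %0 = affine.min affine_map<(d0) -> (d0, d0 + 1)>(%i)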
2478 template <typename T>
2479 struct DeduplicateAffineMinMaxExpressions : public OpRewritePattern<T> {
2480   using OpRewritePattern<T>::OpRewritePattern;
2481 
2482   LogicalResult matchAndRewrite(T affineOp,
2483                                 PatternRewriter &rewriter) const override {
2484     AffineMap oldMap = affineOp.getAffineMap();
2485 
2486     SmallVector<AffineExpr, 4> newExprs;
2487     for (AffineExpr expr : oldMap.getResults()) {
2488       // This is a linear scan over newExprs, but it should be fine given that
2489       // we typically just have a few expressions per op.
2490       if (!llvm::is_contained(newExprs, expr))
2491         newExprs.push_back(expr);
2492     }
2493 
2494     if (newExprs.size() == oldMap.getNumResults())
2495       return failure();
2496 
2497     auto newMap = AffineMap::get(oldMap.getNumDims(), oldMap.getNumSymbols(),
2498                                  newExprs, rewriter.getContext());
2499     rewriter.replaceOpWithNewOp<T>(affineOp, newMap, affineOp.getMapOperands());
2500 
2501     return success();
2502   }
2503 };
2504 
2505 /// Merge an affine min/max op into its consumer if the consumer is also an
2506 /// affine min/max op of the same kind.
2507 ///
2508 /// This pattern requires that the producer affine min/max op is bound to a
2509 /// dimension/symbol that is used as a standalone expression in the consumer
2510 /// affine op's map.
2511 ///
2512 /// For example, a pattern like the following:
2513 ///
2514 ///   %0 = affine.min affine_map<()[s0] -> (s0 + 16, s0 * 8)> ()[%sym1]
2515 ///   %1 = affine.min affine_map<(d0)[s0] -> (s0 + 4, d0)> (%0)[%sym2]
2516 ///
2517 /// Can be turned into:
2518 ///
2519 ///   %1 = affine.min affine_map<
2520 ///          ()[s0, s1] -> (s0 + 4, s1 + 16, s1 * 8)> ()[%sym2, %sym1]
2521 template <typename T>
2522 struct MergeAffineMinMaxOp : public OpRewritePattern<T> {
2523   using OpRewritePattern<T>::OpRewritePattern;
2524 
2525   LogicalResult matchAndRewrite(T affineOp,
2526                                 PatternRewriter &rewriter) const override {
2527     AffineMap oldMap = affineOp.getAffineMap();
2528     ValueRange dimOperands =
2529         affineOp.getMapOperands().take_front(oldMap.getNumDims());
2530     ValueRange symOperands =
2531         affineOp.getMapOperands().take_back(oldMap.getNumSymbols());
2532 
2533     auto newDimOperands = llvm::to_vector<8>(dimOperands);
2534     auto newSymOperands = llvm::to_vector<8>(symOperands);
2535     SmallVector<AffineExpr, 4> newExprs;
2536     SmallVector<T, 4> producerOps;
2537 
2538     // Go over each expression to see whether it is a single dimension/symbol
2539     // whose corresponding operand is the result of another affine min/max op.
2540     // If so, it can be merged into this affine op.
2541     for (AffineExpr expr : oldMap.getResults()) {
2542       if (auto symExpr = expr.dyn_cast<AffineSymbolExpr>()) {
2543         Value symValue = symOperands[symExpr.getPosition()];
2544         if (auto producerOp = symValue.getDefiningOp<T>()) {
2545           producerOps.push_back(producerOp);
2546           continue;
2547         }
2548       } else if (auto dimExpr = expr.dyn_cast<AffineDimExpr>()) {
2549         Value dimValue = dimOperands[dimExpr.getPosition()];
2550         if (auto producerOp = dimValue.getDefiningOp<T>()) {
2551           producerOps.push_back(producerOp);
2552           continue;
2553         }
2554       }
2555       // For the above cases we will remove the expression by merging the
2556       // producer affine min/max's affine expressions. Otherwise we need to
2557       // keep the existing expression.
2558       newExprs.push_back(expr);
2559     }
2560 
2561     if (producerOps.empty())
2562       return failure();
2563 
2564     unsigned numUsedDims = oldMap.getNumDims();
2565     unsigned numUsedSyms = oldMap.getNumSymbols();
2566 
2567     // Now go over all producer affine ops and merge their expressions.
2568     for (T producerOp : producerOps) {
2569       AffineMap producerMap = producerOp.getAffineMap();
2570       unsigned numProducerDims = producerMap.getNumDims();
2571       unsigned numProducerSyms = producerMap.getNumSymbols();
2572 
2573       // Collect all dimension/symbol values.
2574       ValueRange dimValues =
2575           producerOp.getMapOperands().take_front(numProducerDims);
2576       ValueRange symValues =
2577           producerOp.getMapOperands().take_back(numProducerSyms);
2578       newDimOperands.append(dimValues.begin(), dimValues.end());
2579       newSymOperands.append(symValues.begin(), symValues.end());
2580 
2581       // Shift the expressions to avoid dim/symbol position overlap.
2582       for (AffineExpr expr : producerMap.getResults()) {
2583         newExprs.push_back(expr.shiftDims(numProducerDims, numUsedDims)
2584                                .shiftSymbols(numProducerSyms, numUsedSyms));
2585       }
2586 
2587       numUsedDims += numProducerDims;
2588       numUsedSyms += numProducerSyms;
2589     }
2590 
2591     auto newMap = AffineMap::get(numUsedDims, numUsedSyms, newExprs,
2592                                  rewriter.getContext());
2593     auto newOperands =
2594         llvm::to_vector<8>(llvm::concat<Value>(newDimOperands, newSymOperands));
2595     rewriter.replaceOpWithNewOp<T>(affineOp, newMap, newOperands);
2596 
2597     return success();
2598   }
2599 };
2600 
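/// Canonicalize an affine min/max op whose map has a single result by
/// rewriting it into an equivalent affine.apply: the min/max of one
/// expression is that expression itself.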
2601 template <typename T>
2602 struct CanonicalizeSingleResultAffineMinMaxOp : public OpRewritePattern<T> {
2603   using OpRewritePattern<T>::OpRewritePattern;
2604 
2605   LogicalResult matchAndRewrite(T affineOp,
2606                                 PatternRewriter &rewriter) const override {
2607     if (affineOp.map().getNumResults() != 1)
2608       return failure();
2609     rewriter.replaceOpWithNewOp<AffineApplyOp>(affineOp, affineOp.map(),
2610                                                affineOp.getOperands());
2611     return success();
2612   }
2613 };
2614 
2615 //===----------------------------------------------------------------------===//
2616 // AffineMinOp
2617 //===----------------------------------------------------------------------===//
2618 //
2619 //   %0 = affine.min (d0) -> (1000, d0 + 512) (%i0)
2620 //
2621 
2622 OpFoldResult AffineMinOp::fold(ArrayRef<Attribute> operands) {
2623   return foldMinMaxOp(*this, operands);
2624 }
2625 
2626 void AffineMinOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2627                                               MLIRContext *context) {
2628   patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMinOp>,
2629                DeduplicateAffineMinMaxExpressions<AffineMinOp>,
2630                MergeAffineMinMaxOp<AffineMinOp>, SimplifyAffineOp<AffineMinOp>>(
2631       context);
2632 }
2633 
2634 //===----------------------------------------------------------------------===//
2635 // AffineMaxOp
2636 //===----------------------------------------------------------------------===//
2637 //
2638 //   %0 = affine.max (d0) -> (1000, d0 + 512) (%i0)
2639 //
2640 
2641 OpFoldResult AffineMaxOp::fold(ArrayRef<Attribute> operands) {
2642   return foldMinMaxOp(*this, operands);
2643 }
2644 
2645 void AffineMaxOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
2646                                               MLIRContext *context) {
2647   patterns.add<CanonicalizeSingleResultAffineMinMaxOp<AffineMaxOp>,
2648                DeduplicateAffineMinMaxExpressions<AffineMaxOp>,
2649                MergeAffineMinMaxOp<AffineMaxOp>, SimplifyAffineOp<AffineMaxOp>>(
2650       context);
2651 }
2652 
2653 //===----------------------------------------------------------------------===//
2654 // AffinePrefetchOp
2655 //===----------------------------------------------------------------------===//
2656 
2657 //
2658 // affine.prefetch %0[%i, %j + 5], read, locality<3>, data : memref<400x400xi32>
2659 //
2660 static ParseResult parseAffinePrefetchOp(OpAsmParser &parser,
2661                                          OperationState &result) {
2662   auto &builder = parser.getBuilder();
2663   auto indexTy = builder.getIndexType();
2664 
2665   MemRefType type;
2666   OpAsmParser::OperandType memrefInfo;
2667   IntegerAttr hintInfo;
2668   auto i32Type = parser.getBuilder().getIntegerType(32);
2669   StringRef readOrWrite, cacheType;
2670 
2671   AffineMapAttr mapAttr;
2672   SmallVector<OpAsmParser::OperandType, 1> mapOperands;
2673   if (parser.parseOperand(memrefInfo) ||
2674       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
2675                                     AffinePrefetchOp::getMapAttrName(),
2676                                     result.attributes) ||
2677       parser.parseComma() || parser.parseKeyword(&readOrWrite) ||
2678       parser.parseComma() || parser.parseKeyword("locality") ||
2679       parser.parseLess() ||
2680       parser.parseAttribute(hintInfo, i32Type,
2681                             AffinePrefetchOp::getLocalityHintAttrName(),
2682                             result.attributes) ||
2683       parser.parseGreater() || parser.parseComma() ||
2684       parser.parseKeyword(&cacheType) ||
2685       parser.parseOptionalAttrDict(result.attributes) ||
2686       parser.parseColonType(type) ||
2687       parser.resolveOperand(memrefInfo, type, result.operands) ||
2688       parser.resolveOperands(mapOperands, indexTy, result.operands))
2689     return failure();
2690 
2691   if (!readOrWrite.equals("read") && !readOrWrite.equals("write"))
2692     return parser.emitError(parser.getNameLoc(),
2693                             "rw specifier has to be 'read' or 'write'");
2694   result.addAttribute(
2695       AffinePrefetchOp::getIsWriteAttrName(),
2696       parser.getBuilder().getBoolAttr(readOrWrite.equals("write")));
2697 
2698   if (!cacheType.equals("data") && !cacheType.equals("instr"))
2699     return parser.emitError(parser.getNameLoc(),
2700                             "cache type has to be 'data' or 'instr'");
2701 
2702   result.addAttribute(
2703       AffinePrefetchOp::getIsDataCacheAttrName(),
2704       parser.getBuilder().getBoolAttr(cacheType.equals("data")));
2705 
2706   return success();
2707 }
2708 
2709 static void print(OpAsmPrinter &p, AffinePrefetchOp op) {
2710   p << AffinePrefetchOp::getOperationName() << " " << op.memref() << '[';
2711   AffineMapAttr mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2712   if (mapAttr) {
2713     SmallVector<Value, 2> operands(op.getMapOperands());
2714     p.printAffineMapOfSSAIds(mapAttr, operands);
2715   }
2716   p << ']' << ", " << (op.isWrite() ? "write" : "read") << ", "
2717     << "locality<" << op.localityHint() << ">, "
2718     << (op.isDataCache() ? "data" : "instr");
2719   p.printOptionalAttrDict(
2720       op->getAttrs(),
2721       /*elidedAttrs=*/{op.getMapAttrName(), op.getLocalityHintAttrName(),
2722                        op.getIsDataCacheAttrName(), op.getIsWriteAttrName()});
2723   p << " : " << op.getMemRefType();
2724 }
2725 
2726 static LogicalResult verify(AffinePrefetchOp op) {
2727   auto mapAttr = op->getAttrOfType<AffineMapAttr>(op.getMapAttrName());
2728   if (mapAttr) {
2729     AffineMap map = mapAttr.getValue();
2730     if (map.getNumResults() != op.getMemRefType().getRank())
2731       return op.emitOpError("affine.prefetch affine map num results must equal"
2732                             " memref rank");
2733     if (map.getNumInputs() + 1 != op.getNumOperands())
2734       return op.emitOpError("too few operands");
2735   } else {
2736     if (op.getNumOperands() != 1)
2737       return op.emitOpError("too few operands");
2738   }
2739 
2740   Region *scope = getAffineScope(op);
2741   for (auto idx : op.getMapOperands()) {
2742     if (!isValidAffineIndexOperand(idx, scope))
2743       return op.emitOpError("index must be a dimension or symbol identifier");
2744   }
2745   return success();
2746 }
2747 
2748 void AffinePrefetchOp::getCanonicalizationPatterns(RewritePatternSet &results,
2749                                                    MLIRContext *context) {
2750   // prefetch(memrefcast) -> prefetch
2751   results.add<SimplifyAffineOp<AffinePrefetchOp>>(context);
2752 }
2753 
2754 LogicalResult AffinePrefetchOp::fold(ArrayRef<Attribute> cstOperands,
2755                                      SmallVectorImpl<OpFoldResult> &results) {
2756   /// prefetch(memrefcast) -> prefetch
2757   return foldMemRefCast(*this);
2758 }
2759 
2760 //===----------------------------------------------------------------------===//
2761 // AffineParallelOp
2762 //===----------------------------------------------------------------------===//
2763 
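/// Builds an affine parallel loop nest iterating over the hyper-rectangular
/// space defined by `ranges`: each dimension runs from 0 to the corresponding
/// range with unit step (e.g., ranges = {128, 256} gives bounds (0, 0) to
/// (128, 256) with step (1, 1)).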
2764 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2765                              TypeRange resultTypes,
2766                              ArrayRef<AtomicRMWKind> reductions,
2767                              ArrayRef<int64_t> ranges) {
2768   SmallVector<AffineMap> lbs(ranges.size(), builder.getConstantAffineMap(0));
2769   auto ubs = llvm::to_vector<4>(llvm::map_range(ranges, [&](int64_t value) {
2770     return builder.getConstantAffineMap(value);
2771   }));
2772   SmallVector<int64_t> steps(ranges.size(), 1);
2773   build(builder, result, resultTypes, reductions, lbs, /*lbArgs=*/{}, ubs,
2774         /*ubArgs=*/{}, steps);
2775 }
2776 
2777 void AffineParallelOp::build(OpBuilder &builder, OperationState &result,
2778                              TypeRange resultTypes,
2779                              ArrayRef<AtomicRMWKind> reductions,
2780                              ArrayRef<AffineMap> lbMaps, ValueRange lbArgs,
2781                              ArrayRef<AffineMap> ubMaps, ValueRange ubArgs,
2782                              ArrayRef<int64_t> steps) {
2783   assert(llvm::all_of(lbMaps,
2784                       [lbMaps](AffineMap m) {
2785                         return m.getNumDims() == lbMaps[0].getNumDims() &&
2786                                m.getNumSymbols() == lbMaps[0].getNumSymbols();
2787                       }) &&
2788          "expected all lower bounds maps to have the same number of dimensions "
2789          "and symbols");
2790   assert(llvm::all_of(ubMaps,
2791                       [ubMaps](AffineMap m) {
2792                         return m.getNumDims() == ubMaps[0].getNumDims() &&
2793                                m.getNumSymbols() == ubMaps[0].getNumSymbols();
2794                       }) &&
2795          "expected all upper bounds maps to have the same number of dimensions "
2796          "and symbols");
2797   assert((lbMaps.empty() || lbMaps[0].getNumInputs() == lbArgs.size()) &&
2798          "expected lower bound maps to have as many inputs as lower bound "
2799          "operands");
2800   assert((ubMaps.empty() || ubMaps[0].getNumInputs() == ubArgs.size()) &&
2801          "expected upper bound maps to have as many inputs as upper bound "
2802          "operands");
2803 
2804   result.addTypes(resultTypes);
2805 
2806   // Convert the reductions to integer attributes.
2807   SmallVector<Attribute, 4> reductionAttrs;
2808   for (AtomicRMWKind reduction : reductions)
2809     reductionAttrs.push_back(
2810         builder.getI64IntegerAttr(static_cast<int64_t>(reduction)));
2811   result.addAttribute(getReductionsAttrName(),
2812                       builder.getArrayAttr(reductionAttrs));
2813 
2814   // Concatenates maps defined in the same input space (same dimensions and
2815   // symbols); returns the empty map if no maps are provided.
2816   auto concatMapsSameInput = [&builder](ArrayRef<AffineMap> maps,
2817                                         SmallVectorImpl<int32_t> &groups) {
2818     if (maps.empty())
2819       return AffineMap::get(builder.getContext());
2820     SmallVector<AffineExpr> exprs;
2821     groups.reserve(groups.size() + maps.size());
2822     exprs.reserve(maps.size());
2823     for (AffineMap m : maps) {
2824       llvm::append_range(exprs, m.getResults());
2825       groups.push_back(m.getNumResults());
2826     }
2827     return AffineMap::get(maps[0].getNumDims(), maps[0].getNumSymbols(), exprs,
2828                           maps[0].getContext());
2829   };
2830 
2831   // Set up the bounds.
2832   SmallVector<int32_t> lbGroups, ubGroups;
2833   AffineMap lbMap = concatMapsSameInput(lbMaps, lbGroups);
2834   AffineMap ubMap = concatMapsSameInput(ubMaps, ubGroups);
2835   result.addAttribute(getLowerBoundsMapAttrName(), AffineMapAttr::get(lbMap));
2836   result.addAttribute(getLowerBoundsGroupsAttrName(),
2837                       builder.getI32TensorAttr(lbGroups));
2838   result.addAttribute(getUpperBoundsMapAttrName(), AffineMapAttr::get(ubMap));
2839   result.addAttribute(getUpperBoundsGroupsAttrName(),
2840                       builder.getI32TensorAttr(ubGroups));
2841   result.addAttribute(getStepsAttrName(), builder.getI64ArrayAttr(steps));
2842   result.addOperands(lbArgs);
2843   result.addOperands(ubArgs);
2844 
2845   // Create a region and a block for the body.
2846   auto *bodyRegion = result.addRegion();
2847   auto *body = new Block();
2848   // Add all the block arguments.
2849   for (unsigned i = 0, e = steps.size(); i < e; ++i)
2850     body->addArgument(IndexType::get(builder.getContext()));
2851   bodyRegion->push_back(body);
2852   if (resultTypes.empty())
2853     ensureTerminator(*bodyRegion, builder, result.location);
2854 }
2855 
2856 Region &AffineParallelOp::getLoopBody() { return region(); }
2857 
2858 bool AffineParallelOp::isDefinedOutsideOfLoop(Value value) {
2859   return !region().isAncestor(value.getParentRegion());
2860 }
2861 
2862 LogicalResult AffineParallelOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
2863   for (Operation *op : ops)
2864     op->moveBefore(*this);
2865   return success();
2866 }
2867 
2868 unsigned AffineParallelOp::getNumDims() { return steps().size(); }
2869 
2870 AffineParallelOp::operand_range AffineParallelOp::getLowerBoundsOperands() {
2871   return getOperands().take_front(lowerBoundsMap().getNumInputs());
2872 }
2873 
2874 AffineParallelOp::operand_range AffineParallelOp::getUpperBoundsOperands() {
2875   return getOperands().drop_front(lowerBoundsMap().getNumInputs());
2876 }
2877 
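/// Returns the lower bound map for loop dimension `pos`. The groups attribute
/// records how many consecutive results of the flat bounds map belong to each
/// dimension; e.g., with groups = [2, 1] on a 3-result map, position 1 maps
/// to the single-result slice starting at result 2 (illustrative example).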
2878 AffineMap AffineParallelOp::getLowerBoundMap(unsigned pos) {
2879   unsigned start = 0;
2880   for (unsigned i = 0; i < pos; ++i)
2881     start += lowerBoundsGroups().getValue<int32_t>(i);
2882   return lowerBoundsMap().getSliceMap(
2883       start, lowerBoundsGroups().getValue<int32_t>(pos));
2884 }
2885 
2886 AffineMap AffineParallelOp::getUpperBoundMap(unsigned pos) {
2887   unsigned start = 0;
2888   for (unsigned i = 0; i < pos; ++i)
2889     start += upperBoundsGroups().getValue<int32_t>(i);
2890   return upperBoundsMap().getSliceMap(
2891       start, upperBoundsGroups().getValue<int32_t>(pos));
2892 }
2893 
2894 AffineValueMap AffineParallelOp::getLowerBoundsValueMap() {
2895   return AffineValueMap(lowerBoundsMap(), getLowerBoundsOperands());
2896 }
2897 
2898 AffineValueMap AffineParallelOp::getUpperBoundsValueMap() {
2899   return AffineValueMap(upperBoundsMap(), getUpperBoundsOperands());
2900 }
2901 
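/// Returns the constant extent of each loop dimension (upper bound minus
/// lower bound) if the op has no min/max bounds and every extent folds to a
/// constant; returns llvm::None otherwise.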
2902 Optional<SmallVector<int64_t, 8>> AffineParallelOp::getConstantRanges() {
2903   if (hasMinMaxBounds())
2904     return llvm::None;
2905 
2906   // Try to convert all the ranges to constant expressions.
2907   SmallVector<int64_t, 8> out;
2908   AffineValueMap rangesValueMap;
2909   AffineValueMap::difference(getUpperBoundsValueMap(), getLowerBoundsValueMap(),
2910                              &rangesValueMap);
2911   out.reserve(rangesValueMap.getNumResults());
2912   for (unsigned i = 0, e = rangesValueMap.getNumResults(); i < e; ++i) {
2913     auto expr = rangesValueMap.getResult(i);
2914     auto cst = expr.dyn_cast<AffineConstantExpr>();
2915     if (!cst)
2916       return llvm::None;
2917     out.push_back(cst.getValue());
2918   }
2919   return out;
2920 }
2921 
2922 Block *AffineParallelOp::getBody() { return &region().front(); }
2923 
2924 OpBuilder AffineParallelOp::getBodyBuilder() {
2925   return OpBuilder(getBody(), std::prev(getBody()->end()));
2926 }
2927 
2928 void AffineParallelOp::setLowerBounds(ValueRange lbOperands, AffineMap map) {
2929   assert(lbOperands.size() == map.getNumInputs() &&
2930          "operands to map must match number of inputs");
2931 
2932   auto ubOperands = getUpperBoundsOperands();
2933 
2934   SmallVector<Value, 4> newOperands(lbOperands);
2935   newOperands.append(ubOperands.begin(), ubOperands.end());
2936   (*this)->setOperands(newOperands);
2937 
2938   lowerBoundsMapAttr(AffineMapAttr::get(map));
2939 }
2940 
2941 void AffineParallelOp::setUpperBounds(ValueRange ubOperands, AffineMap map) {
2942   assert(ubOperands.size() == map.getNumInputs() &&
2943          "operands to map must match number of inputs");
2944 
2945   SmallVector<Value, 4> newOperands(getLowerBoundsOperands());
2946   newOperands.append(ubOperands.begin(), ubOperands.end());
2947   (*this)->setOperands(newOperands);
2948 
2949   upperBoundsMapAttr(AffineMapAttr::get(map));
2950 }
2951 
2952 void AffineParallelOp::setLowerBoundsMap(AffineMap map) {
2953   AffineMap lbMap = lowerBoundsMap();
2954   assert(lbMap.getNumDims() == map.getNumDims() &&
2955          lbMap.getNumSymbols() == map.getNumSymbols());
2956   (void)lbMap;
2957   lowerBoundsMapAttr(AffineMapAttr::get(map));
2958 }
2959 
2960 void AffineParallelOp::setUpperBoundsMap(AffineMap map) {
2961   AffineMap ubMap = upperBoundsMap();
2962   assert(ubMap.getNumDims() == map.getNumDims() &&
2963          ubMap.getNumSymbols() == map.getNumSymbols());
2964   (void)ubMap;
2965   upperBoundsMapAttr(AffineMapAttr::get(map));
2966 }
2967 
2968 SmallVector<int64_t, 8> AffineParallelOp::getSteps() {
2969   SmallVector<int64_t, 8> result;
2970   for (Attribute attr : steps()) {
2971     result.push_back(attr.cast<IntegerAttr>().getInt());
2972   }
2973   return result;
2974 }
2975 
2976 void AffineParallelOp::setSteps(ArrayRef<int64_t> newSteps) {
2977   stepsAttr(getBodyBuilder().getI64ArrayAttr(newSteps));
2978 }
2979 
2980 static LogicalResult verify(AffineParallelOp op) {
2981   auto numDims = op.getNumDims();
2982   if (op.lowerBoundsGroups().getNumElements() != numDims ||
2983       op.upperBoundsGroups().getNumElements() != numDims ||
2984       op.steps().size() != numDims ||
2985       op.getBody()->getNumArguments() != numDims) {
2986     return op.emitOpError()
2987            << "the number of region arguments ("
2988            << op.getBody()->getNumArguments()
2989            << ") and the number of map groups for lower ("
2990            << op.lowerBoundsGroups().getNumElements() << ") and upper bound ("
2991            << op.upperBoundsGroups().getNumElements()
2992            << "), and the number of steps (" << op.steps().size()
2993            << ") must all match";
2994   }
2995 
2996   unsigned expectedNumLBResults = 0;
2997   for (APInt v : op.lowerBoundsGroups())
2998     expectedNumLBResults += v.getZExtValue();
2999   if (expectedNumLBResults != op.lowerBoundsMap().getNumResults())
3000     return op.emitOpError() << "expected lower bounds map to have "
3001                             << expectedNumLBResults << " results";
3002   unsigned expectedNumUBResults = 0;
3003   for (APInt v : op.upperBoundsGroups())
3004     expectedNumUBResults += v.getZExtValue();
3005   if (expectedNumUBResults != op.upperBoundsMap().getNumResults())
3006     return op.emitOpError() << "expected upper bounds map to have "
3007                             << expectedNumUBResults << " results";
3008 
3009   if (op.reductions().size() != op.getNumResults())
3010     return op.emitOpError("a reduction must be specified for each output");
3011 
3012   // Verify that the reduction ops are all valid.
3013   for (Attribute attr : op.reductions()) {
3014     auto intAttr = attr.dyn_cast<IntegerAttr>();
3015     if (!intAttr || !symbolizeAtomicRMWKind(intAttr.getInt()))
3016       return op.emitOpError("invalid reduction attribute");
3017   }
3018 
3019   // Verify that the bound operands are valid dimension/symbols.
3020   /// Lower bounds.
3021   if (failed(verifyDimAndSymbolIdentifiers(op, op.getLowerBoundsOperands(),
3022                                            op.lowerBoundsMap().getNumDims())))
3023     return failure();
3024   /// Upper bounds.
3025   if (failed(verifyDimAndSymbolIdentifiers(op, op.getUpperBoundsOperands(),
3026                                            op.upperBoundsMap().getNumDims())))
3027     return failure();
3028   return success();
3029 }
3030 
3031 LogicalResult AffineValueMap::canonicalize() {
3032   SmallVector<Value, 4> newOperands{operands};
3033   auto newMap = getAffineMap();
3034   composeAffineMapAndOperands(&newMap, &newOperands);
3035   if (newMap == getAffineMap() && newOperands == operands)
3036     return failure();
3037   reset(newMap, newOperands);
3038   return success();
3039 }
3040 
3041 /// Canonicalize the bounds of the given loop.
3042 static LogicalResult canonicalizeLoopBounds(AffineParallelOp op) {
3043   AffineValueMap lb = op.getLowerBoundsValueMap();
3044   bool lbCanonicalized = succeeded(lb.canonicalize());
3045 
3046   AffineValueMap ub = op.getUpperBoundsValueMap();
3047   bool ubCanonicalized = succeeded(ub.canonicalize());
3048 
3049   // Any canonicalization change always leads to updated map(s).
3050   if (!lbCanonicalized && !ubCanonicalized)
3051     return failure();
3052 
3053   if (lbCanonicalized)
3054     op.setLowerBounds(lb.getOperands(), lb.getAffineMap());
3055   if (ubCanonicalized)
3056     op.setUpperBounds(ub.getOperands(), ub.getAffineMap());
3057 
3058   return success();
3059 }
3060 
3061 LogicalResult AffineParallelOp::fold(ArrayRef<Attribute> operands,
3062                                      SmallVectorImpl<OpFoldResult> &results) {
3063   return canonicalizeLoopBounds(*this);
3064 }
3065 
3066 /// Prints a lower(upper) bound of an affine parallel loop with max(min)
3067 /// conditions in it. `mapAttr` is a flat list of affine expressions and `group`
3068 /// identifies which of those expressions form max/min groups. `operands`
3069 /// are the SSA values of dimensions and symbols and `keyword` is either "min"
3070 /// or "max".
3071 static void printMinMaxBound(OpAsmPrinter &p, AffineMapAttr mapAttr,
3072                              DenseIntElementsAttr group, ValueRange operands,
3073                              StringRef keyword) {
3074   AffineMap map = mapAttr.getValue();
3075   unsigned numDims = map.getNumDims();
3076   ValueRange dimOperands = operands.take_front(numDims);
3077   ValueRange symOperands = operands.drop_front(numDims);
3078   unsigned start = 0;
3079   for (llvm::APInt groupSize : group) {
3080     if (start != 0)
3081       p << ", ";
3082 
3083     unsigned size = groupSize.getZExtValue();
3084     if (size == 1) {
3085       p.printAffineExprOfSSAIds(map.getResult(start), dimOperands, symOperands);
3086       ++start;
3087     } else {
3088       p << keyword << '(';
3089       AffineMap submap = map.getSliceMap(start, size);
3090       p.printAffineMapOfSSAIds(AffineMapAttr::get(submap), operands);
3091       p << ')';
3092       start += size;
3093     }
3094   }
3095 }
3096 
3097 static void print(OpAsmPrinter &p, AffineParallelOp op) {
3098   p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = (";
3099   printMinMaxBound(p, op.lowerBoundsMapAttr(), op.lowerBoundsGroupsAttr(),
3100                    op.getLowerBoundsOperands(), "max");
3101   p << ") to (";
3102   printMinMaxBound(p, op.upperBoundsMapAttr(), op.upperBoundsGroupsAttr(),
3103                    op.getUpperBoundsOperands(), "min");
3104   p << ')';
3105   SmallVector<int64_t, 8> steps = op.getSteps();
3106   bool elideSteps = llvm::all_of(steps, [](int64_t step) { return step == 1; });
3107   if (!elideSteps) {
3108     p << " step (";
3109     llvm::interleaveComma(steps, p);
3110     p << ')';
3111   }
3112   if (op.getNumResults()) {
3113     p << " reduce (";
3114     llvm::interleaveComma(op.reductions(), p, [&](auto &attr) {
3115       AtomicRMWKind sym =
3116           *symbolizeAtomicRMWKind(attr.template cast<IntegerAttr>().getInt());
3117       p << "\"" << stringifyAtomicRMWKind(sym) << "\"";
3118     });
3119     p << ") -> (" << op.getResultTypes() << ")";
3120   }
3121 
3122   p.printRegion(op.region(), /*printEntryBlockArgs=*/false,
3123                 /*printBlockTerminators=*/op.getNumResults());
3124   p.printOptionalAttrDict(
3125       op->getAttrs(),
3126       /*elidedAttrs=*/{AffineParallelOp::getReductionsAttrName(),
3127                        AffineParallelOp::getLowerBoundsMapAttrName(),
3128                        AffineParallelOp::getLowerBoundsGroupsAttrName(),
3129                        AffineParallelOp::getUpperBoundsMapAttrName(),
3130                        AffineParallelOp::getUpperBoundsGroupsAttrName(),
3131                        AffineParallelOp::getStepsAttrName()});
3132 }
3133 
3134 /// Given a list of lists of parsed operands, populates `uniqueOperands` with
3135 /// unique operands. Also populates `replacements` with affine expressions of
3136 /// `kind` that can be used to update affine maps previously accepting
3137 /// `operands` to accept `uniqueOperands` instead.
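///
/// For example (illustrative), given the operand lists {%a, %b} and {%b, %c}
/// with `kind` = DimId, `uniqueOperands` becomes {%a, %b, %c} and
/// `replacements` becomes {d0, d1, d1, d2}.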
3138 static void deduplicateAndResolveOperands(
3139     OpAsmParser &parser,
3140     ArrayRef<SmallVector<OpAsmParser::OperandType>> operands,
3141     SmallVectorImpl<Value> &uniqueOperands,
3142     SmallVectorImpl<AffineExpr> &replacements, AffineExprKind kind) {
3143   assert((kind == AffineExprKind::DimId || kind == AffineExprKind::SymbolId) &&
3144          "expected operands to be dim or symbol expression");
3145 
3146   Type indexType = parser.getBuilder().getIndexType();
3147   for (const auto &list : operands) {
3148     SmallVector<Value> valueOperands;
3149     parser.resolveOperands(list, indexType, valueOperands);
3150     for (Value operand : valueOperands) {
3151       unsigned pos = std::distance(uniqueOperands.begin(),
3152                                    llvm::find(uniqueOperands, operand));
3153       if (pos == uniqueOperands.size())
3154         uniqueOperands.push_back(operand);
3155       replacements.push_back(
3156           kind == AffineExprKind::DimId
3157               ? getAffineDimExpr(pos, parser.getBuilder().getContext())
3158               : getAffineSymbolExpr(pos, parser.getBuilder().getContext()));
3159     }
3160   }
3161 }
3162 
3163 namespace {
3164 enum class MinMaxKind { Min, Max };
3165 } // namespace
3166 
3167 /// Parses an affine map that can contain a min/max for groups of its results,
3168 /// e.g., max(expr-1, expr-2), expr-3, max(expr-4, expr-5, expr-6). Populates
3169 /// `result` attributes with the map (flat list of expressions) and the grouping
3170 /// (list of integers that specify how many expressions to put into each
3171 /// min/max) attributes. Deduplicates repeated operands.
3172 ///
3173 /// parallel-bound       ::= `(` parallel-group-list `)`
3174 /// parallel-group-list  ::= parallel-group (`,` parallel-group-list)?
3175 /// parallel-group       ::= simple-group | min-max-group
3176 /// simple-group         ::= expr-of-ssa-ids
3177 /// min-max-group        ::= ( `min` | `max` ) `(` expr-of-ssa-ids-list `)`
3178 /// expr-of-ssa-ids-list ::= expr-of-ssa-ids (`,` expr-of-ssa-ids-list)?
3179 ///
3180 /// Examples:
3181 ///   (%0, min(%1 + %2, %3), %4, min(%5 floordiv 32, %6))
3182 ///   (%0, max(%1 - 2 * %2))
3183 static ParseResult parseAffineMapWithMinMax(OpAsmParser &parser,
3184                                             OperationState &result,
3185                                             MinMaxKind kind) {
3186   constexpr llvm::StringLiteral tmpAttrName = "__pseudo_bound_map";
3187 
3188   StringRef mapName = kind == MinMaxKind::Min
3189                           ? AffineParallelOp::getUpperBoundsMapAttrName()
3190                           : AffineParallelOp::getLowerBoundsMapAttrName();
3191   StringRef groupsName = kind == MinMaxKind::Min
3192                              ? AffineParallelOp::getUpperBoundsGroupsAttrName()
3193                              : AffineParallelOp::getLowerBoundsGroupsAttrName();
3194 
3195   if (failed(parser.parseLParen()))
3196     return failure();
3197 
3198   if (succeeded(parser.parseOptionalRParen())) {
3199     result.addAttribute(
3200         mapName, AffineMapAttr::get(parser.getBuilder().getEmptyAffineMap()));
3201     result.addAttribute(groupsName, parser.getBuilder().getI32TensorAttr({}));
3202     return success();
3203   }
3204 
3205   SmallVector<AffineExpr> flatExprs;
3206   SmallVector<SmallVector<OpAsmParser::OperandType>> flatDimOperands;
3207   SmallVector<SmallVector<OpAsmParser::OperandType>> flatSymOperands;
3208   SmallVector<int32_t> numMapsPerGroup;
3209   SmallVector<OpAsmParser::OperandType> mapOperands;
3210   do {
3211     if (succeeded(parser.parseOptionalKeyword(
3212             kind == MinMaxKind::Min ? "min" : "max"))) {
3213       mapOperands.clear();
3214       AffineMapAttr map;
3215       if (failed(parser.parseAffineMapOfSSAIds(mapOperands, map, tmpAttrName,
3216                                                result.attributes,
3217                                                OpAsmParser::Delimiter::Paren)))
3218         return failure();
3219       result.attributes.erase(tmpAttrName);
3220       llvm::append_range(flatExprs, map.getValue().getResults());
3221       auto operandsRef = llvm::makeArrayRef(mapOperands);
3222       auto dimsRef = operandsRef.take_front(map.getValue().getNumDims());
3223       SmallVector<OpAsmParser::OperandType> dims(dimsRef.begin(),
3224                                                  dimsRef.end());
3225       auto symsRef = operandsRef.drop_front(map.getValue().getNumDims());
3226       SmallVector<OpAsmParser::OperandType> syms(symsRef.begin(),
3227                                                  symsRef.end());
3228       flatDimOperands.append(map.getValue().getNumResults(), dims);
3229       flatSymOperands.append(map.getValue().getNumResults(), syms);
3230       numMapsPerGroup.push_back(map.getValue().getNumResults());
3231     } else {
3232       if (failed(parser.parseAffineExprOfSSAIds(flatDimOperands.emplace_back(),
3233                                                 flatSymOperands.emplace_back(),
3234                                                 flatExprs.emplace_back())))
3235         return failure();
3236       numMapsPerGroup.push_back(1);
3237     }
3238   } while (succeeded(parser.parseOptionalComma()));
3239 
3240   if (failed(parser.parseRParen()))
3241     return failure();
3242 
3243   unsigned totalNumDims = 0;
3244   unsigned totalNumSyms = 0;
3245   for (unsigned i = 0, e = flatExprs.size(); i < e; ++i) {
3246     unsigned numDims = flatDimOperands[i].size();
3247     unsigned numSyms = flatSymOperands[i].size();
3248     flatExprs[i] = flatExprs[i]
3249                        .shiftDims(numDims, totalNumDims)
3250                        .shiftSymbols(numSyms, totalNumSyms);
3251     totalNumDims += numDims;
3252     totalNumSyms += numSyms;
3253   }
3254 
3255   // Deduplicate map operands.
3256   SmallVector<Value> dimOperands, symOperands;
3257   SmallVector<AffineExpr> dimRplacements, symRepacements;
3258   deduplicateAndResolveOperands(parser, flatDimOperands, dimOperands,
3259                                 dimRplacements, AffineExprKind::DimId);
3260   deduplicateAndResolveOperands(parser, flatSymOperands, symOperands,
3261                                 symRepacements, AffineExprKind::SymbolId);
3262 
3263   result.operands.append(dimOperands.begin(), dimOperands.end());
3264   result.operands.append(symOperands.begin(), symOperands.end());
3265 
3266   Builder &builder = parser.getBuilder();
3267   auto flatMap = AffineMap::get(totalNumDims, totalNumSyms, flatExprs,
3268                                 parser.getBuilder().getContext());
3269   flatMap = flatMap.replaceDimsAndSymbols(
3270       dimRplacements, symRepacements, dimOperands.size(), symOperands.size());
3271 
3272   result.addAttribute(mapName, AffineMapAttr::get(flatMap));
3273   result.addAttribute(groupsName, builder.getI32TensorAttr(numMapsPerGroup));
3274   return success();
3275 }
3276 
3277 //
3278 // operation ::= `affine.parallel` `(` ssa-ids `)` `=` parallel-bound
3279 //               `to` parallel-bound steps? region attr-dict?
3280 // steps     ::= `steps` `(` integer-literals `)`
3281 //
3282 static ParseResult parseAffineParallelOp(OpAsmParser &parser,
3283                                          OperationState &result) {
3284   auto &builder = parser.getBuilder();
3285   auto indexType = builder.getIndexType();
3286   SmallVector<OpAsmParser::OperandType, 4> ivs;
3287   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
3288                                      OpAsmParser::Delimiter::Paren) ||
3289       parser.parseEqual() ||
3290       parseAffineMapWithMinMax(parser, result, MinMaxKind::Max) ||
3291       parser.parseKeyword("to") ||
3292       parseAffineMapWithMinMax(parser, result, MinMaxKind::Min))
3293     return failure();
3294 
3295   AffineMapAttr stepsMapAttr;
3296   NamedAttrList stepsAttrs;
3297   SmallVector<OpAsmParser::OperandType, 4> stepsMapOperands;
3298   if (failed(parser.parseOptionalKeyword("step"))) {
3299     SmallVector<int64_t, 4> steps(ivs.size(), 1);
3300     result.addAttribute(AffineParallelOp::getStepsAttrName(),
3301                         builder.getI64ArrayAttr(steps));
3302   } else {
3303     if (parser.parseAffineMapOfSSAIds(stepsMapOperands, stepsMapAttr,
3304                                       AffineParallelOp::getStepsAttrName(),
3305                                       stepsAttrs,
3306                                       OpAsmParser::Delimiter::Paren))
3307       return failure();
3308 
3309     // Convert steps from an AffineMap into an I64ArrayAttr.
3310     SmallVector<int64_t, 4> steps;
3311     auto stepsMap = stepsMapAttr.getValue();
3312     for (const auto &result : stepsMap.getResults()) {
3313       auto constExpr = result.dyn_cast<AffineConstantExpr>();
3314       if (!constExpr)
3315         return parser.emitError(parser.getNameLoc(),
3316                                 "steps must be constant integers");
3317       steps.push_back(constExpr.getValue());
3318     }
3319     result.addAttribute(AffineParallelOp::getStepsAttrName(),
3320                         builder.getI64ArrayAttr(steps));
3321   }
3322 
3323   // Parse optional clause of the form: `reduce ("addf", "maxf")`, where the
3324   // quoted strings are a member of the enum AtomicRMWKind.
3325   SmallVector<Attribute, 4> reductions;
3326   if (succeeded(parser.parseOptionalKeyword("reduce"))) {
3327     if (parser.parseLParen())
3328       return failure();
3329     do {
3330       // Parse a single quoted string via the attribute parsing, and then
3331       // verify it is a member of the enum and convert to its integer
3332       // representation.
3333       StringAttr attrVal;
3334       NamedAttrList attrStorage;
3335       auto loc = parser.getCurrentLocation();
3336       if (parser.parseAttribute(attrVal, builder.getNoneType(), "reduce",
3337                                 attrStorage))
3338         return failure();
3339       llvm::Optional<AtomicRMWKind> reduction =
3340           symbolizeAtomicRMWKind(attrVal.getValue());
3341       if (!reduction)
3342         return parser.emitError(loc, "invalid reduction value: ") << attrVal;
3343       reductions.push_back(builder.getI64IntegerAttr(
3344           static_cast<int64_t>(reduction.getValue())));
3345       // While we keep getting commas, keep parsing.
3346     } while (succeeded(parser.parseOptionalComma()));
3347     if (parser.parseRParen())
3348       return failure();
3349   }
3350   result.addAttribute(AffineParallelOp::getReductionsAttrName(),
3351                       builder.getArrayAttr(reductions));
3352 
3353   // Parse return types of reductions (if any)
3354   if (parser.parseOptionalArrowTypeList(result.types))
3355     return failure();
3356 
3357   // Now parse the body.
3358   Region *body = result.addRegion();
3359   SmallVector<Type, 4> types(ivs.size(), indexType);
3360   if (parser.parseRegion(*body, ivs, types) ||
3361       parser.parseOptionalAttrDict(result.attributes))
3362     return failure();
3363 
3364   // Add a terminator if none was parsed.
3365   AffineParallelOp::ensureTerminator(*body, builder, result.location);
3366   return success();
3367 }
3368 
3369 //===----------------------------------------------------------------------===//
3370 // AffineYieldOp
3371 //===----------------------------------------------------------------------===//
3372 
3373 static LogicalResult verify(AffineYieldOp op) {
3374   auto *parentOp = op->getParentOp();
3375   auto results = parentOp->getResults();
3376   auto operands = op.getOperands();
3377 
3378   if (!isa<AffineParallelOp, AffineIfOp, AffineForOp>(parentOp))
3379     return op.emitOpError() << "only terminates affine.if/for/parallel regions";
3380   if (parentOp->getNumResults() != op.getNumOperands())
3381     return op.emitOpError() << "parent of yield must have same number of "
3382                                "results as the yield operands";
3383   for (auto it : llvm::zip(results, operands)) {
3384     if (std::get<0>(it).getType() != std::get<1>(it).getType())
3385       return op.emitOpError()
3386              << "types mismatch between yield op and its parent";
3387   }
3388 
3389   return success();
3390 }
3391 
3392 //===----------------------------------------------------------------------===//
3393 // AffineVectorLoadOp
3394 //===----------------------------------------------------------------------===//
3395 
3396 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
3397                                VectorType resultType, AffineMap map,
3398                                ValueRange operands) {
3399   assert(operands.size() == 1 + map.getNumInputs() && "inconsistent operands");
3400   result.addOperands(operands);
3401   if (map)
3402     result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
3403   result.types.push_back(resultType);
3404 }
3405 
3406 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
3407                                VectorType resultType, Value memref,
3408                                AffineMap map, ValueRange mapOperands) {
3409   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3410   result.addOperands(memref);
3411   result.addOperands(mapOperands);
3412   result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
3413   result.types.push_back(resultType);
3414 }
3415 
3416 void AffineVectorLoadOp::build(OpBuilder &builder, OperationState &result,
3417                                VectorType resultType, Value memref,
3418                                ValueRange indices) {
3419   auto memrefType = memref.getType().cast<MemRefType>();
3420   int64_t rank = memrefType.getRank();
3421   // Create identity map for memrefs with at least one dimension or () -> ()
3422   // for zero-dimensional memrefs.
3423   auto map =
3424       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3425   build(builder, result, resultType, memref, map, indices);
3426 }
3427 
3428 void AffineVectorLoadOp::getCanonicalizationPatterns(RewritePatternSet &results,
3429                                                      MLIRContext *context) {
3430   results.add<SimplifyAffineOp<AffineVectorLoadOp>>(context);
3431 }
3432 
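//
// Illustrative textual form handled by the parser/printer below (SSA names
// are just examples):
//
//   affine.vector_load %0[%i0 + 3, %i1 + 7] : memref<100x100xf32>, vector<8xf32>
//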
3433 static ParseResult parseAffineVectorLoadOp(OpAsmParser &parser,
3434                                            OperationState &result) {
3435   auto &builder = parser.getBuilder();
3436   auto indexTy = builder.getIndexType();
3437 
3438   MemRefType memrefType;
3439   VectorType resultType;
3440   OpAsmParser::OperandType memrefInfo;
3441   AffineMapAttr mapAttr;
3442   SmallVector<OpAsmParser::OperandType, 1> mapOperands;
3443   return failure(
3444       parser.parseOperand(memrefInfo) ||
3445       parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
3446                                     AffineVectorLoadOp::getMapAttrName(),
3447                                     result.attributes) ||
3448       parser.parseOptionalAttrDict(result.attributes) ||
3449       parser.parseColonType(memrefType) || parser.parseComma() ||
3450       parser.parseType(resultType) ||
3451       parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
3452       parser.resolveOperands(mapOperands, indexTy, result.operands) ||
3453       parser.addTypeToList(resultType, result.types));
3454 }
3455 
3456 static void print(OpAsmPrinter &p, AffineVectorLoadOp op) {
3457   p << "affine.vector_load " << op.getMemRef() << '[';
3458   if (AffineMapAttr mapAttr =
3459           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
3460     p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
3461   p << ']';
3462   p.printOptionalAttrDict(op->getAttrs(),
3463                           /*elidedAttrs=*/{op.getMapAttrName()});
3464   p << " : " << op.getMemRefType() << ", " << op.getType();
3465 }
3466 
3467 /// Verify common invariants of affine.vector_load and affine.vector_store.
3468 static LogicalResult verifyVectorMemoryOp(Operation *op, MemRefType memrefType,
3469                                           VectorType vectorType) {
3470   // Check that memref and vector element types match.
3471   if (memrefType.getElementType() != vectorType.getElementType())
3472     return op->emitOpError(
3473         "requires memref and vector types of the same elemental type");
3474   return success();
3475 }
3476 
3477 static LogicalResult verify(AffineVectorLoadOp op) {
3478   MemRefType memrefType = op.getMemRefType();
3479   if (failed(verifyMemoryOpIndexing(
3480           op.getOperation(),
3481           op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
3482           op.getMapOperands(), memrefType,
3483           /*numIndexOperands=*/op.getNumOperands() - 1)))
3484     return failure();
3485 
3486   if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
3487                                   op.getVectorType())))
3488     return failure();
3489 
3490   return success();
3491 }
3492 
3493 //===----------------------------------------------------------------------===//
3494 // AffineVectorStoreOp
3495 //===----------------------------------------------------------------------===//
3496 
3497 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
3498                                 Value valueToStore, Value memref, AffineMap map,
3499                                 ValueRange mapOperands) {
3500   assert(map.getNumInputs() == mapOperands.size() && "inconsistent index info");
3501   result.addOperands(valueToStore);
3502   result.addOperands(memref);
3503   result.addOperands(mapOperands);
3504   result.addAttribute(getMapAttrName(), AffineMapAttr::get(map));
3505 }
3506 
3507 // Use identity map.
3508 void AffineVectorStoreOp::build(OpBuilder &builder, OperationState &result,
3509                                 Value valueToStore, Value memref,
3510                                 ValueRange indices) {
3511   auto memrefType = memref.getType().cast<MemRefType>();
3512   int64_t rank = memrefType.getRank();
3513   // Create identity map for memrefs with at least one dimension or () -> ()
3514   // for zero-dimensional memrefs.
3515   auto map =
3516       rank ? builder.getMultiDimIdentityMap(rank) : builder.getEmptyAffineMap();
3517   build(builder, result, valueToStore, memref, map, indices);
3518 }
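
/// Registers SimplifyAffineOp (defined earlier in this file) for
/// affine.vector_store; it is expected to compose and canonicalize the op's
/// affine map and index operands.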
void AffineVectorStoreOp::getCanonicalizationPatterns(
    RewritePatternSet &results, MLIRContext *context) {
  results.add<SimplifyAffineOp<AffineVectorStoreOp>>(context);
}

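/// Parses the custom assembly form of affine.vector_store. Mirroring the
/// printer below, a well-formed instance looks roughly like (illustrative
/// values only):
///
///   affine.vector_store %v0, %0[%i0 + 3, %i1 + 7]
///       : memref<100x100xf32>, vector<8xf32>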
static ParseResult parseAffineVectorStoreOp(OpAsmParser &parser,
                                            OperationState &result) {
  auto indexTy = parser.getBuilder().getIndexType();

  MemRefType memrefType;
  VectorType resultType;
  OpAsmParser::OperandType storeValueInfo;
  OpAsmParser::OperandType memrefInfo;
  AffineMapAttr mapAttr;
  SmallVector<OpAsmParser::OperandType, 1> mapOperands;
  return failure(
      parser.parseOperand(storeValueInfo) || parser.parseComma() ||
      parser.parseOperand(memrefInfo) ||
      parser.parseAffineMapOfSSAIds(mapOperands, mapAttr,
                                    AffineVectorStoreOp::getMapAttrName(),
                                    result.attributes) ||
      parser.parseOptionalAttrDict(result.attributes) ||
      parser.parseColonType(memrefType) || parser.parseComma() ||
      parser.parseType(resultType) ||
      parser.resolveOperand(storeValueInfo, resultType, result.operands) ||
      parser.resolveOperand(memrefInfo, memrefType, result.operands) ||
      parser.resolveOperands(mapOperands, indexTy, result.operands));
}

static void print(OpAsmPrinter &p, AffineVectorStoreOp op) {
  p << "affine.vector_store " << op.getValueToStore();
  p << ", " << op.getMemRef() << '[';
  if (AffineMapAttr mapAttr =
          op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()))
    p.printAffineMapOfSSAIds(mapAttr, op.getMapOperands());
  p << ']';
  p.printOptionalAttrDict(op->getAttrs(),
                          /*elidedAttrs=*/{op.getMapAttrName()});
  p << " : " << op.getMemRefType() << ", " << op.getValueToStore().getType();
}

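/// Verifies an affine.vector_store. The stored value and the memref account
/// for the first two operands (hence numIndexOperands = getNumOperands() - 2);
/// the rest must match the attached affine map, and the memref/vector element
/// types must agree.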
static LogicalResult verify(AffineVectorStoreOp op) {
  MemRefType memrefType = op.getMemRefType();
  if (failed(verifyMemoryOpIndexing(
          op.getOperation(),
          op->getAttrOfType<AffineMapAttr>(op.getMapAttrName()),
          op.getMapOperands(), memrefType,
          /*numIndexOperands=*/op.getNumOperands() - 2)))
    return failure();

  if (failed(verifyVectorMemoryOp(op.getOperation(), memrefType,
                                  op.getVectorType())))
    return failure();

  return success();
}

//===----------------------------------------------------------------------===//
// TableGen'd op method definitions
//===----------------------------------------------------------------------===//

#define GET_OP_CLASSES
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"