//===- SCF.cpp - Structured Control Flow Operations -----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
8 
9 #include "mlir/Dialect/SCF/SCF.h"
10 #include "mlir/Dialect/MemRef/IR/MemRef.h"
11 #include "mlir/Dialect/StandardOps/IR/Ops.h"
12 #include "mlir/Dialect/Tensor/IR/Tensor.h"
13 #include "mlir/IR/BlockAndValueMapping.h"
14 #include "mlir/IR/PatternMatch.h"
15 #include "mlir/Support/MathExtras.h"
16 #include "mlir/Transforms/InliningUtils.h"
17 
18 using namespace mlir;
19 using namespace mlir::scf;
20 
21 #include "mlir/Dialect/SCF/SCFOpsDialect.cpp.inc"
22 
//===----------------------------------------------------------------------===//
// SCFDialect Dialect Interfaces
//===----------------------------------------------------------------------===//
26 
27 namespace {
28 struct SCFInlinerInterface : public DialectInlinerInterface {
29   using DialectInlinerInterface::DialectInlinerInterface;
30   // We don't have any special restrictions on what can be inlined into
31   // destination regions (e.g. while/conditional bodies). Always allow it.
isLegalToInline__anon56feca5b0111::SCFInlinerInterface32   bool isLegalToInline(Region *dest, Region *src, bool wouldBeCloned,
33                        BlockAndValueMapping &valueMapping) const final {
34     return true;
35   }
36   // Operations in scf dialect are always legal to inline since they are
37   // pure.
isLegalToInline__anon56feca5b0111::SCFInlinerInterface38   bool isLegalToInline(Operation *, Region *, bool,
39                        BlockAndValueMapping &) const final {
40     return true;
41   }
42   // Handle the given inlined terminator by replacing it with a new operation
43   // as necessary. Required when the region has only one block.
handleTerminator__anon56feca5b0111::SCFInlinerInterface44   void handleTerminator(Operation *op,
45                         ArrayRef<Value> valuesToRepl) const final {
46     auto retValOp = dyn_cast<scf::YieldOp>(op);
47     if (!retValOp)
48       return;
49 
50     for (auto retValue : llvm::zip(valuesToRepl, retValOp.getOperands())) {
51       std::get<0>(retValue).replaceAllUsesWith(std::get<1>(retValue));
52     }
53   }
54 };
55 } // end anonymous namespace
56 
//===----------------------------------------------------------------------===//
// SCFDialect
//===----------------------------------------------------------------------===//
60 
initialize()61 void SCFDialect::initialize() {
62   addOperations<
63 #define GET_OP_LIST
64 #include "mlir/Dialect/SCF/SCFOps.cpp.inc"
65       >();
66   addInterfaces<SCFInlinerInterface>();
67 }
68 
69 /// Default callback for IfOp builders. Inserts a yield without arguments.
buildTerminatedBody(OpBuilder & builder,Location loc)70 void mlir::scf::buildTerminatedBody(OpBuilder &builder, Location loc) {
71   builder.create<scf::YieldOp>(loc);
72 }
73 
//===----------------------------------------------------------------------===//
// ExecuteRegionOp
//===----------------------------------------------------------------------===//
77 
78 /// Replaces the given op with the contents of the given single-block region,
79 /// using the operands of the block terminator to replace operation results.
replaceOpWithRegion(PatternRewriter & rewriter,Operation * op,Region & region,ValueRange blockArgs={})80 static void replaceOpWithRegion(PatternRewriter &rewriter, Operation *op,
81                                 Region &region, ValueRange blockArgs = {}) {
82   assert(llvm::hasSingleElement(region) && "expected single-region block");
83   Block *block = &region.front();
84   Operation *terminator = block->getTerminator();
85   ValueRange results = terminator->getOperands();
86   rewriter.mergeBlockBefore(block, op, blockArgs);
87   rewriter.replaceOp(op, results);
88   rewriter.eraseOp(terminator);
89 }
90 
91 ///
92 /// (ssa-id `=`)? `execute_region` `->` function-result-type `{`
93 ///    block+
94 /// `}`
95 ///
96 /// Example:
97 ///   scf.execute_region -> i32 {
98 ///     %idx = load %rI[%i] : memref<128xi32>
99 ///     return %idx : i32
100 ///   }
101 ///
parseExecuteRegionOp(OpAsmParser & parser,OperationState & result)102 static ParseResult parseExecuteRegionOp(OpAsmParser &parser,
103                                         OperationState &result) {
104   if (parser.parseOptionalArrowTypeList(result.types))
105     return failure();
106 
107   // Introduce the body region and parse it.
108   Region *body = result.addRegion();
109   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}) ||
110       parser.parseOptionalAttrDict(result.attributes))
111     return failure();
112 
113   return success();
114 }
115 
print(OpAsmPrinter & p,ExecuteRegionOp op)116 static void print(OpAsmPrinter &p, ExecuteRegionOp op) {
117   p << ExecuteRegionOp::getOperationName();
118   if (op.getNumResults() > 0)
119     p << " -> " << op.getResultTypes();
120 
121   p.printRegion(op.region(),
122                 /*printEntryBlockArgs=*/false,
123                 /*printBlockTerminators=*/true);
124 
125   p.printOptionalAttrDict(op->getAttrs());
126 }
127 
verify(ExecuteRegionOp op)128 static LogicalResult verify(ExecuteRegionOp op) {
129   if (op.region().empty())
130     return op.emitOpError("region needs to have at least one block");
131   if (op.region().front().getNumArguments() > 0)
132     return op.emitOpError("region cannot have any arguments");
133   return success();
134 }
135 
136 // Inline an ExecuteRegionOp if it only contains one block.
137 //     "test.foo"() : () -> ()
138 //      %v = scf.execute_region -> i64 {
139 //        %x = "test.val"() : () -> i64
140 //        scf.yield %x : i64
141 //      }
142 //      "test.bar"(%v) : (i64) -> ()
143 //
144 //  becomes
145 //
146 //     "test.foo"() : () -> ()
147 //     %x = "test.val"() : () -> i64
148 //     "test.bar"(%x) : (i64) -> ()
149 //
150 struct SingleBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
151   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
152 
matchAndRewriteSingleBlockExecuteInliner153   LogicalResult matchAndRewrite(ExecuteRegionOp op,
154                                 PatternRewriter &rewriter) const override {
155     if (!llvm::hasSingleElement(op.region()))
156       return failure();
157     replaceOpWithRegion(rewriter, op, op.region());
158     return success();
159   }
160 };
161 
162 // Inline an ExecuteRegionOp if its parent can contain multiple blocks.
163 // TODO generalize the conditions for operations which can be inlined into.
164 // func @func_execute_region_elim() {
165 //     "test.foo"() : () -> ()
166 //     %v = scf.execute_region -> i64 {
167 //       %c = "test.cmp"() : () -> i1
168 //       cond_br %c, ^bb2, ^bb3
169 //     ^bb2:
170 //       %x = "test.val1"() : () -> i64
171 //       br ^bb4(%x : i64)
172 //     ^bb3:
173 //       %y = "test.val2"() : () -> i64
174 //       br ^bb4(%y : i64)
175 //     ^bb4(%z : i64):
176 //       scf.yield %z : i64
177 //     }
178 //     "test.bar"(%v) : (i64) -> ()
179 //   return
180 // }
181 //
182 //  becomes
183 //
184 // func @func_execute_region_elim() {
185 //    "test.foo"() : () -> ()
186 //    %c = "test.cmp"() : () -> i1
187 //    cond_br %c, ^bb1, ^bb2
188 //  ^bb1:  // pred: ^bb0
189 //    %x = "test.val1"() : () -> i64
190 //    br ^bb3(%x : i64)
191 //  ^bb2:  // pred: ^bb0
192 //    %y = "test.val2"() : () -> i64
193 //    br ^bb3(%y : i64)
194 //  ^bb3(%z: i64):  // 2 preds: ^bb1, ^bb2
195 //    "test.bar"(%z) : (i64) -> ()
196 //    return
197 //  }
198 //
199 struct MultiBlockExecuteInliner : public OpRewritePattern<ExecuteRegionOp> {
200   using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
201 
matchAndRewriteMultiBlockExecuteInliner202   LogicalResult matchAndRewrite(ExecuteRegionOp op,
203                                 PatternRewriter &rewriter) const override {
204     if (!isa<FuncOp, ExecuteRegionOp>(op->getParentOp()))
205       return failure();
206 
207     Block *prevBlock = op->getBlock();
208     Block *postBlock = rewriter.splitBlock(prevBlock, op->getIterator());
209     rewriter.setInsertionPointToEnd(prevBlock);
210 
211     rewriter.create<BranchOp>(op.getLoc(), &op.region().front());
212 
213     for (Block &blk : op.region()) {
214       if (YieldOp yieldOp = dyn_cast<YieldOp>(blk.getTerminator())) {
215         rewriter.setInsertionPoint(yieldOp);
216         rewriter.create<BranchOp>(yieldOp.getLoc(), postBlock,
217                                   yieldOp.results());
218         rewriter.eraseOp(yieldOp);
219       }
220     }
221 
222     rewriter.inlineRegionBefore(op.region(), postBlock);
223     SmallVector<Value> blockArgs;
224 
225     for (auto res : op.getResults())
226       blockArgs.push_back(postBlock->addArgument(res.getType()));
227 
228     rewriter.replaceOp(op, blockArgs);
229     return success();
230   }
231 };
232 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)233 void ExecuteRegionOp::getCanonicalizationPatterns(RewritePatternSet &results,
234                                                   MLIRContext *context) {
235   results.add<SingleBlockExecuteInliner, MultiBlockExecuteInliner>(context);
236 }
237 
//===----------------------------------------------------------------------===//
// ConditionOp
//===----------------------------------------------------------------------===//
241 
242 MutableOperandRange
getMutableSuccessorOperands(Optional<unsigned> index)243 ConditionOp::getMutableSuccessorOperands(Optional<unsigned> index) {
244   // Pass all operands except the condition to the successor region.
245   return argsMutable();
246 }
247 
//===----------------------------------------------------------------------===//
// ForOp
//===----------------------------------------------------------------------===//
251 
build(OpBuilder & builder,OperationState & result,Value lb,Value ub,Value step,ValueRange iterArgs,BodyBuilderFn bodyBuilder)252 void ForOp::build(OpBuilder &builder, OperationState &result, Value lb,
253                   Value ub, Value step, ValueRange iterArgs,
254                   BodyBuilderFn bodyBuilder) {
255   result.addOperands({lb, ub, step});
256   result.addOperands(iterArgs);
257   for (Value v : iterArgs)
258     result.addTypes(v.getType());
259   Region *bodyRegion = result.addRegion();
260   bodyRegion->push_back(new Block);
261   Block &bodyBlock = bodyRegion->front();
262   bodyBlock.addArgument(builder.getIndexType());
263   for (Value v : iterArgs)
264     bodyBlock.addArgument(v.getType());
265 
266   // Create the default terminator if the builder is not provided and if the
267   // iteration arguments are not provided. Otherwise, leave this to the caller
268   // because we don't know which values to return from the loop.
269   if (iterArgs.empty() && !bodyBuilder) {
270     ForOp::ensureTerminator(*bodyRegion, builder, result.location);
271   } else if (bodyBuilder) {
272     OpBuilder::InsertionGuard guard(builder);
273     builder.setInsertionPointToStart(&bodyBlock);
274     bodyBuilder(builder, result.location, bodyBlock.getArgument(0),
275                 bodyBlock.getArguments().drop_front());
276   }
277 }
278 
verify(ForOp op)279 static LogicalResult verify(ForOp op) {
280   if (auto cst = op.step().getDefiningOp<ConstantIndexOp>())
281     if (cst.getValue() <= 0)
282       return op.emitOpError("constant step operand must be positive");
283 
284   // Check that the body defines as single block argument for the induction
285   // variable.
286   auto *body = op.getBody();
287   if (!body->getArgument(0).getType().isIndex())
288     return op.emitOpError(
289         "expected body first argument to be an index argument for "
290         "the induction variable");
291 
292   auto opNumResults = op.getNumResults();
293   if (opNumResults == 0)
294     return success();
295   // If ForOp defines values, check that the number and types of
296   // the defined values match ForOp initial iter operands and backedge
297   // basic block arguments.
298   if (op.getNumIterOperands() != opNumResults)
299     return op.emitOpError(
300         "mismatch in number of loop-carried values and defined values");
301   if (op.getNumRegionIterArgs() != opNumResults)
302     return op.emitOpError(
303         "mismatch in number of basic block args and defined values");
304   auto iterOperands = op.getIterOperands();
305   auto iterArgs = op.getRegionIterArgs();
306   auto opResults = op.getResults();
307   unsigned i = 0;
308   for (auto e : llvm::zip(iterOperands, iterArgs, opResults)) {
309     if (std::get<0>(e).getType() != std::get<2>(e).getType())
310       return op.emitOpError() << "types mismatch between " << i
311                               << "th iter operand and defined value";
312     if (std::get<1>(e).getType() != std::get<2>(e).getType())
313       return op.emitOpError() << "types mismatch between " << i
314                               << "th iter region arg and defined value";
315 
316     i++;
317   }
318 
319   return RegionBranchOpInterface::verifyTypes(op);
320 }
321 
322 /// Prints the initialization list in the form of
323 ///   <prefix>(%inner = %outer, %inner2 = %outer2, <...>)
324 /// where 'inner' values are assumed to be region arguments and 'outer' values
325 /// are regular SSA values.
printInitializationList(OpAsmPrinter & p,Block::BlockArgListType blocksArgs,ValueRange initializers,StringRef prefix="")326 static void printInitializationList(OpAsmPrinter &p,
327                                     Block::BlockArgListType blocksArgs,
328                                     ValueRange initializers,
329                                     StringRef prefix = "") {
330   assert(blocksArgs.size() == initializers.size() &&
331          "expected same length of arguments and initializers");
332   if (initializers.empty())
333     return;
334 
335   p << prefix << '(';
336   llvm::interleaveComma(llvm::zip(blocksArgs, initializers), p, [&](auto it) {
337     p << std::get<0>(it) << " = " << std::get<1>(it);
338   });
339   p << ")";
340 }
341 
print(OpAsmPrinter & p,ForOp op)342 static void print(OpAsmPrinter &p, ForOp op) {
343   p << op.getOperationName() << " " << op.getInductionVar() << " = "
344     << op.lowerBound() << " to " << op.upperBound() << " step " << op.step();
345 
346   printInitializationList(p, op.getRegionIterArgs(), op.getIterOperands(),
347                           " iter_args");
348   if (!op.getIterOperands().empty())
349     p << " -> (" << op.getIterOperands().getTypes() << ')';
350   p.printRegion(op.region(),
351                 /*printEntryBlockArgs=*/false,
352                 /*printBlockTerminators=*/op.hasIterOperands());
353   p.printOptionalAttrDict(op->getAttrs());
354 }
355 
parseForOp(OpAsmParser & parser,OperationState & result)356 static ParseResult parseForOp(OpAsmParser &parser, OperationState &result) {
357   auto &builder = parser.getBuilder();
358   OpAsmParser::OperandType inductionVariable, lb, ub, step;
359   // Parse the induction variable followed by '='.
360   if (parser.parseRegionArgument(inductionVariable) || parser.parseEqual())
361     return failure();
362 
363   // Parse loop bounds.
364   Type indexType = builder.getIndexType();
365   if (parser.parseOperand(lb) ||
366       parser.resolveOperand(lb, indexType, result.operands) ||
367       parser.parseKeyword("to") || parser.parseOperand(ub) ||
368       parser.resolveOperand(ub, indexType, result.operands) ||
369       parser.parseKeyword("step") || parser.parseOperand(step) ||
370       parser.resolveOperand(step, indexType, result.operands))
371     return failure();
372 
373   // Parse the optional initial iteration arguments.
374   SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
375   SmallVector<Type, 4> argTypes;
376   regionArgs.push_back(inductionVariable);
377 
378   if (succeeded(parser.parseOptionalKeyword("iter_args"))) {
379     // Parse assignment list and results type list.
380     if (parser.parseAssignmentList(regionArgs, operands) ||
381         parser.parseArrowTypeList(result.types))
382       return failure();
383     // Resolve input operands.
384     for (auto operandType : llvm::zip(operands, result.types))
385       if (parser.resolveOperand(std::get<0>(operandType),
386                                 std::get<1>(operandType), result.operands))
387         return failure();
388   }
389   // Induction variable.
390   argTypes.push_back(indexType);
391   // Loop carried variables
392   argTypes.append(result.types.begin(), result.types.end());
393   // Parse the body region.
394   Region *body = result.addRegion();
395   if (regionArgs.size() != argTypes.size())
396     return parser.emitError(
397         parser.getNameLoc(),
398         "mismatch in number of loop-carried values and defined values");
399 
400   if (parser.parseRegion(*body, regionArgs, argTypes))
401     return failure();
402 
403   ForOp::ensureTerminator(*body, builder, result.location);
404 
405   // Parse the optional attribute list.
406   if (parser.parseOptionalAttrDict(result.attributes))
407     return failure();
408 
409   return success();
410 }
411 
getLoopBody()412 Region &ForOp::getLoopBody() { return region(); }
413 
isDefinedOutsideOfLoop(Value value)414 bool ForOp::isDefinedOutsideOfLoop(Value value) {
415   return !region().isAncestor(value.getParentRegion());
416 }
417 
moveOutOfLoop(ArrayRef<Operation * > ops)418 LogicalResult ForOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
419   for (auto *op : ops)
420     op->moveBefore(*this);
421   return success();
422 }
423 
getForInductionVarOwner(Value val)424 ForOp mlir::scf::getForInductionVarOwner(Value val) {
425   auto ivArg = val.dyn_cast<BlockArgument>();
426   if (!ivArg)
427     return ForOp();
428   assert(ivArg.getOwner() && "unlinked block argument");
429   auto *containingOp = ivArg.getOwner()->getParentOp();
430   return dyn_cast_or_null<ForOp>(containingOp);
431 }
432 
433 /// Return operands used when entering the region at 'index'. These operands
434 /// correspond to the loop iterator operands, i.e., those excluding the
435 /// induction variable. LoopOp only has one region, so 0 is the only valid value
436 /// for `index`.
getSuccessorEntryOperands(unsigned index)437 OperandRange ForOp::getSuccessorEntryOperands(unsigned index) {
438   assert(index == 0 && "invalid region index");
439 
440   // The initial operands map to the loop arguments after the induction
441   // variable.
442   return initArgs();
443 }
444 
445 /// Given the region at `index`, or the parent operation if `index` is None,
446 /// return the successor regions. These are the regions that may be selected
447 /// during the flow of control. `operands` is a set of optional attributes that
448 /// correspond to a constant value for each operand, or null if that operand is
449 /// not a constant.
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)450 void ForOp::getSuccessorRegions(Optional<unsigned> index,
451                                 ArrayRef<Attribute> operands,
452                                 SmallVectorImpl<RegionSuccessor> &regions) {
453   // If the predecessor is the ForOp, branch into the body using the iterator
454   // arguments.
455   if (!index.hasValue()) {
456     regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
457     return;
458   }
459 
460   // Otherwise, the loop may branch back to itself or the parent operation.
461   assert(index.getValue() == 0 && "expected loop region");
462   regions.push_back(RegionSuccessor(&getLoopBody(), getRegionIterArgs()));
463   regions.push_back(RegionSuccessor(getResults()));
464 }
465 
getNumRegionInvocations(ArrayRef<Attribute> operands,SmallVectorImpl<int64_t> & countPerRegion)466 void ForOp::getNumRegionInvocations(ArrayRef<Attribute> operands,
467                                     SmallVectorImpl<int64_t> &countPerRegion) {
468   assert(countPerRegion.empty());
469   countPerRegion.resize(1);
470 
471   auto lb = operands[0].dyn_cast_or_null<IntegerAttr>();
472   auto ub = operands[1].dyn_cast_or_null<IntegerAttr>();
473   auto step = operands[2].dyn_cast_or_null<IntegerAttr>();
474 
475   // Loop bounds are not known statically.
476   if (!lb || !ub || !step || step.getValue().getSExtValue() == 0) {
477     countPerRegion[0] = -1;
478     return;
479   }
480 
481   countPerRegion[0] =
482       ceilDiv(ub.getValue().getSExtValue() - lb.getValue().getSExtValue(),
483               step.getValue().getSExtValue());
484 }
485 
buildLoopNest(OpBuilder & builder,Location loc,ValueRange lbs,ValueRange ubs,ValueRange steps,ValueRange iterArgs,function_ref<ValueVector (OpBuilder &,Location,ValueRange,ValueRange)> bodyBuilder)486 LoopNest mlir::scf::buildLoopNest(
487     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
488     ValueRange steps, ValueRange iterArgs,
489     function_ref<ValueVector(OpBuilder &, Location, ValueRange, ValueRange)>
490         bodyBuilder) {
491   assert(lbs.size() == ubs.size() &&
492          "expected the same number of lower and upper bounds");
493   assert(lbs.size() == steps.size() &&
494          "expected the same number of lower bounds and steps");
495 
496   // If there are no bounds, call the body-building function and return early.
497   if (lbs.empty()) {
498     ValueVector results =
499         bodyBuilder ? bodyBuilder(builder, loc, ValueRange(), iterArgs)
500                     : ValueVector();
501     assert(results.size() == iterArgs.size() &&
502            "loop nest body must return as many values as loop has iteration "
503            "arguments");
504     return LoopNest();
505   }
506 
507   // First, create the loop structure iteratively using the body-builder
508   // callback of `ForOp::build`. Do not create `YieldOp`s yet.
509   OpBuilder::InsertionGuard guard(builder);
510   SmallVector<scf::ForOp, 4> loops;
511   SmallVector<Value, 4> ivs;
512   loops.reserve(lbs.size());
513   ivs.reserve(lbs.size());
514   ValueRange currentIterArgs = iterArgs;
515   Location currentLoc = loc;
516   for (unsigned i = 0, e = lbs.size(); i < e; ++i) {
517     auto loop = builder.create<scf::ForOp>(
518         currentLoc, lbs[i], ubs[i], steps[i], currentIterArgs,
519         [&](OpBuilder &nestedBuilder, Location nestedLoc, Value iv,
520             ValueRange args) {
521           ivs.push_back(iv);
522           // It is safe to store ValueRange args because it points to block
523           // arguments of a loop operation that we also own.
524           currentIterArgs = args;
525           currentLoc = nestedLoc;
526         });
527     // Set the builder to point to the body of the newly created loop. We don't
528     // do this in the callback because the builder is reset when the callback
529     // returns.
530     builder.setInsertionPointToStart(loop.getBody());
531     loops.push_back(loop);
532   }
533 
534   // For all loops but the innermost, yield the results of the nested loop.
535   for (unsigned i = 0, e = loops.size() - 1; i < e; ++i) {
536     builder.setInsertionPointToEnd(loops[i].getBody());
537     builder.create<scf::YieldOp>(loc, loops[i + 1].getResults());
538   }
539 
540   // In the body of the innermost loop, call the body building function if any
541   // and yield its results.
542   builder.setInsertionPointToStart(loops.back().getBody());
543   ValueVector results = bodyBuilder
544                             ? bodyBuilder(builder, currentLoc, ivs,
545                                           loops.back().getRegionIterArgs())
546                             : ValueVector();
547   assert(results.size() == iterArgs.size() &&
548          "loop nest body must return as many values as loop has iteration "
549          "arguments");
550   builder.setInsertionPointToEnd(loops.back().getBody());
551   builder.create<scf::YieldOp>(loc, results);
552 
553   // Return the loops.
554   LoopNest res;
555   res.loops.assign(loops.begin(), loops.end());
556   return res;
557 }
558 
buildLoopNest(OpBuilder & builder,Location loc,ValueRange lbs,ValueRange ubs,ValueRange steps,function_ref<void (OpBuilder &,Location,ValueRange)> bodyBuilder)559 LoopNest mlir::scf::buildLoopNest(
560     OpBuilder &builder, Location loc, ValueRange lbs, ValueRange ubs,
561     ValueRange steps,
562     function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilder) {
563   // Delegate to the main function by wrapping the body builder.
564   return buildLoopNest(builder, loc, lbs, ubs, steps, llvm::None,
565                        [&bodyBuilder](OpBuilder &nestedBuilder,
566                                       Location nestedLoc, ValueRange ivs,
567                                       ValueRange) -> ValueVector {
568                          if (bodyBuilder)
569                            bodyBuilder(nestedBuilder, nestedLoc, ivs);
570                          return {};
571                        });
572 }
573 
574 namespace {
575 // Fold away ForOp iter arguments when:
576 // 1) The op yields the iter arguments.
577 // 2) The iter arguments have no use and the corresponding outer region
578 // iterators (inputs) are yielded.
579 // 3) The iter arguments have no use and the corresponding (operation) results
580 // have no use.
581 //
582 // These arguments must be defined outside of
583 // the ForOp region and can just be forwarded after simplifying the op inits,
584 // yields and returns.
585 //
586 // The implementation uses `mergeBlockBefore` to steal the content of the
587 // original ForOp and avoid cloning.
588 struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
589   using OpRewritePattern<scf::ForOp>::OpRewritePattern;
590 
matchAndRewrite__anon56feca5b0511::ForOpIterArgsFolder591   LogicalResult matchAndRewrite(scf::ForOp forOp,
592                                 PatternRewriter &rewriter) const final {
593     bool canonicalize = false;
594     Block &block = forOp.region().front();
595     auto yieldOp = cast<scf::YieldOp>(block.getTerminator());
596 
597     // An internal flat vector of block transfer
598     // arguments `newBlockTransferArgs` keeps the 1-1 mapping of original to
599     // transformed block argument mappings. This plays the role of a
600     // BlockAndValueMapping for the particular use case of calling into
601     // `mergeBlockBefore`.
602     SmallVector<bool, 4> keepMask;
603     keepMask.reserve(yieldOp.getNumOperands());
604     SmallVector<Value, 4> newBlockTransferArgs, newIterArgs, newYieldValues,
605         newResultValues;
606     newBlockTransferArgs.reserve(1 + forOp.getNumIterOperands());
607     newBlockTransferArgs.push_back(Value()); // iv placeholder with null value
608     newIterArgs.reserve(forOp.getNumIterOperands());
609     newYieldValues.reserve(yieldOp.getNumOperands());
610     newResultValues.reserve(forOp.getNumResults());
611     for (auto it : llvm::zip(forOp.getIterOperands(),   // iter from outside
612                              forOp.getRegionIterArgs(), // iter inside region
613                              forOp.getResults(),        // op results
614                              yieldOp.getOperands()      // iter yield
615                              )) {
616       // Forwarded is `true` when:
617       // 1) The region `iter` argument is yielded.
618       // 2) The region `iter` argument has no use, and the corresponding iter
619       // operand (input) is yielded.
620       // 3) The region `iter` argument has no use, and the corresponding op
621       // result has no use.
622       bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
623                         (std::get<1>(it).use_empty() &&
624                          (std::get<0>(it) == std::get<3>(it) ||
625                           std::get<2>(it).use_empty())));
626       keepMask.push_back(!forwarded);
627       canonicalize |= forwarded;
628       if (forwarded) {
629         newBlockTransferArgs.push_back(std::get<0>(it));
630         newResultValues.push_back(std::get<0>(it));
631         continue;
632       }
633       newIterArgs.push_back(std::get<0>(it));
634       newYieldValues.push_back(std::get<3>(it));
635       newBlockTransferArgs.push_back(Value()); // placeholder with null value
636       newResultValues.push_back(Value());      // placeholder with null value
637     }
638 
639     if (!canonicalize)
640       return failure();
641 
642     scf::ForOp newForOp = rewriter.create<scf::ForOp>(
643         forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(),
644         newIterArgs);
645     Block &newBlock = newForOp.region().front();
646 
647     // Replace the null placeholders with newly constructed values.
648     newBlockTransferArgs[0] = newBlock.getArgument(0); // iv
649     for (unsigned idx = 0, collapsedIdx = 0, e = newResultValues.size();
650          idx != e; ++idx) {
651       Value &blockTransferArg = newBlockTransferArgs[1 + idx];
652       Value &newResultVal = newResultValues[idx];
653       assert((blockTransferArg && newResultVal) ||
654              (!blockTransferArg && !newResultVal));
655       if (!blockTransferArg) {
656         blockTransferArg = newForOp.getRegionIterArgs()[collapsedIdx];
657         newResultVal = newForOp.getResult(collapsedIdx++);
658       }
659     }
660 
661     Block &oldBlock = forOp.region().front();
662     assert(oldBlock.getNumArguments() == newBlockTransferArgs.size() &&
663            "unexpected argument size mismatch");
664 
665     // No results case: the scf::ForOp builder already created a zero
666     // result terminator. Merge before this terminator and just get rid of the
667     // original terminator that has been merged in.
668     if (newIterArgs.empty()) {
669       auto newYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
670       rewriter.mergeBlockBefore(&oldBlock, newYieldOp, newBlockTransferArgs);
671       rewriter.eraseOp(newBlock.getTerminator()->getPrevNode());
672       rewriter.replaceOp(forOp, newResultValues);
673       return success();
674     }
675 
676     // No terminator case: merge and rewrite the merged terminator.
677     auto cloneFilteredTerminator = [&](scf::YieldOp mergedTerminator) {
678       OpBuilder::InsertionGuard g(rewriter);
679       rewriter.setInsertionPoint(mergedTerminator);
680       SmallVector<Value, 4> filteredOperands;
681       filteredOperands.reserve(newResultValues.size());
682       for (unsigned idx = 0, e = keepMask.size(); idx < e; ++idx)
683         if (keepMask[idx])
684           filteredOperands.push_back(mergedTerminator.getOperand(idx));
685       rewriter.create<scf::YieldOp>(mergedTerminator.getLoc(),
686                                     filteredOperands);
687     };
688 
689     rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);
690     auto mergedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
691     cloneFilteredTerminator(mergedYieldOp);
692     rewriter.eraseOp(mergedYieldOp);
693     rewriter.replaceOp(forOp, newResultValues);
694     return success();
695   }
696 };
697 
698 /// Rewriting pattern that erases loops that are known not to iterate and
699 /// replaces single-iteration loops with their bodies.
700 struct SimplifyTrivialLoops : public OpRewritePattern<ForOp> {
701   using OpRewritePattern<ForOp>::OpRewritePattern;
702 
matchAndRewrite__anon56feca5b0511::SimplifyTrivialLoops703   LogicalResult matchAndRewrite(ForOp op,
704                                 PatternRewriter &rewriter) const override {
705     // If the upper bound is the same as the lower bound, the loop does not
706     // iterate, just remove it.
707     if (op.lowerBound() == op.upperBound()) {
708       rewriter.replaceOp(op, op.getIterOperands());
709       return success();
710     }
711 
712     auto lb = op.lowerBound().getDefiningOp<ConstantOp>();
713     auto ub = op.upperBound().getDefiningOp<ConstantOp>();
714     if (!lb || !ub)
715       return failure();
716 
717     // If the loop is known to have 0 iterations, remove it.
718     llvm::APInt lbValue = lb.getValue().cast<IntegerAttr>().getValue();
719     llvm::APInt ubValue = ub.getValue().cast<IntegerAttr>().getValue();
720     if (lbValue.sge(ubValue)) {
721       rewriter.replaceOp(op, op.getIterOperands());
722       return success();
723     }
724 
725     auto step = op.step().getDefiningOp<ConstantOp>();
726     if (!step)
727       return failure();
728 
729     // If the loop is known to have 1 iteration, inline its body and remove the
730     // loop.
731     llvm::APInt stepValue = step.getValue().cast<IntegerAttr>().getValue();
732     if ((lbValue + stepValue).sge(ubValue)) {
733       SmallVector<Value, 4> blockArgs;
734       blockArgs.reserve(op.getNumIterOperands() + 1);
735       blockArgs.push_back(op.lowerBound());
736       llvm::append_range(blockArgs, op.getIterOperands());
737       replaceOpWithRegion(rewriter, op, op.getLoopBody(), blockArgs);
738       return success();
739     }
740 
741     return failure();
742   }
743 };
744 
745 /// Perform a replacement of one iter OpOperand of an scf.for to the
746 /// `replacement` value which is expected to be the source of a tensor.cast.
747 /// tensor.cast ops are inserted inside the block to account for the type cast.
static ForOp replaceTensorCastForOpIterArg(PatternRewriter &rewriter,
                                           OpOperand &operand,
                                           Value replacement) {
  Type oldType = operand.get().getType(), newType = replacement.getType();
  assert(oldType.isa<RankedTensorType>() && newType.isa<RankedTensorType>() &&
         "expected ranked tensor types");

  // 1. Create new iter operands, exactly 1 is replaced.
  ForOp forOp = cast<ForOp>(operand.getOwner());
  assert(operand.getOperandNumber() >= forOp.getNumControlOperands() &&
         "expected an iter OpOperand");
  // Nothing to rewrite if the types already agree.
  if (operand.get().getType() == replacement.getType())
    return forOp;
  SmallVector<Value> newIterOperands;
  for (OpOperand &opOperand : forOp.getIterOpOperands()) {
    if (opOperand.getOperandNumber() == operand.getOperandNumber()) {
      newIterOperands.push_back(replacement);
      continue;
    }
    newIterOperands.push_back(opOperand.get());
  }

  // 2. Create the new forOp shell. The bbArg/result for the replaced operand
  // now carries `newType`.
  scf::ForOp newForOp = rewriter.create<scf::ForOp>(
      forOp.getLoc(), forOp.lowerBound(), forOp.upperBound(), forOp.step(),
      newIterOperands);
  Block &newBlock = newForOp.region().front();
  SmallVector<Value, 4> newBlockTransferArgs(newBlock.getArguments().begin(),
                                             newBlock.getArguments().end());

  // 3. Inject an incoming cast op at the beginning of the block for the bbArg
  // corresponding to the `replacement` value. The old body ops still expect
  // `oldType`, so they will be remapped to this cast result instead of the
  // raw bbArg.
  OpBuilder::InsertionGuard g(rewriter);
  rewriter.setInsertionPoint(&newBlock, newBlock.begin());
  BlockArgument newRegionIterArg = newForOp.getRegionIterArgForOpOperand(
      newForOp->getOpOperand(operand.getOperandNumber()));
  Value castIn = rewriter.create<tensor::CastOp>(newForOp.getLoc(), oldType,
                                                 newRegionIterArg);
  newBlockTransferArgs[newRegionIterArg.getArgNumber()] = castIn;

  // 4. Steal the old block ops, mapping to the newBlockTransferArgs.
  Block &oldBlock = forOp.region().front();
  rewriter.mergeBlocks(&oldBlock, &newBlock, newBlockTransferArgs);

  // 5. Inject an outgoing cast op at the end of the block and yield it instead.
  auto clonedYieldOp = cast<scf::YieldOp>(newBlock.getTerminator());
  rewriter.setInsertionPoint(clonedYieldOp);
  // Yield position = bbArg position minus the induction variable(s).
  unsigned yieldIdx =
      newRegionIterArg.getArgNumber() - forOp.getNumInductionVars();
  Value castOut = rewriter.create<tensor::CastOp>(
      newForOp.getLoc(), newType, clonedYieldOp.getOperand(yieldIdx));
  SmallVector<Value> newYieldOperands = clonedYieldOp.getOperands();
  newYieldOperands[yieldIdx] = castOut;
  rewriter.create<scf::YieldOp>(newForOp.getLoc(), newYieldOperands);
  rewriter.eraseOp(clonedYieldOp);

  // 6. Inject an outgoing cast op after the forOp.
  // NOTE(review): the write into local `newResults` is never read (the vector
  // is not returned); the created cast op is real IR though, and the caller
  // (ForOpTensorCastFolder) creates its own outgoing cast — confirm this
  // extra cast is intentional rather than leftover.
  rewriter.setInsertionPointAfter(newForOp);
  SmallVector<Value> newResults = newForOp.getResults();
  newResults[yieldIdx] = rewriter.create<tensor::CastOp>(
      newForOp.getLoc(), oldType, newResults[yieldIdx]);

  return newForOp;
}
812 
/// Fold scf.for iter_arg/result pairs that go through an incoming/outgoing
/// tensor.cast op pair so as to pull the tensor.cast inside the scf.for:
815 ///
816 /// ```
817 ///   %0 = tensor.cast %t0 : tensor<32x1024xf32> to tensor<?x?xf32>
818 ///   %1 = scf.for %i = %c0 to %c1024 step %c32 iter_args(%iter_t0 = %0)
819 ///      -> (tensor<?x?xf32>) {
820 ///     %2 = call @do(%iter_t0) : (tensor<?x?xf32>) -> tensor<?x?xf32>
821 ///     scf.yield %2 : tensor<?x?xf32>
822 ///   }
823 ///   %2 = tensor.cast %1 : tensor<?x?xf32> to tensor<32x1024xf32>
824 ///   use_of(%2)
825 /// ```
826 ///
827 /// folds into:
828 ///
829 /// ```
830 ///   %0 = scf.for %arg2 = %c0 to %c1024 step %c32 iter_args(%arg3 = %arg0)
831 ///       -> (tensor<32x1024xf32>) {
832 ///     %2 = tensor.cast %arg3 : tensor<32x1024xf32> to tensor<?x?xf32>
833 ///     %3 = call @do(%2) : (tensor<?x?xf32>) -> tensor<?x?xf32>
834 ///     %4 = tensor.cast %3 : tensor<?x?xf32> to tensor<32x1024xf32>
835 ///     scf.yield %4 : tensor<32x1024xf32>
836 ///   }
837 ///   use_of(%0)
838 /// ```
839 struct ForOpTensorCastFolder : public OpRewritePattern<ForOp> {
840   using OpRewritePattern<ForOp>::OpRewritePattern;
841 
matchAndRewrite__anon56feca5b0511::ForOpTensorCastFolder842   LogicalResult matchAndRewrite(ForOp op,
843                                 PatternRewriter &rewriter) const override {
844     for (auto it : llvm::zip(op.getIterOpOperands(), op.getResults())) {
845       OpOperand &iterOpOperand = std::get<0>(it);
846       auto incomingCast = iterOpOperand.get().getDefiningOp<tensor::CastOp>();
847       if (!incomingCast)
848         continue;
849       if (!std::get<1>(it).hasOneUse())
850         continue;
851       auto outgoingCastOp =
852           dyn_cast<tensor::CastOp>(*std::get<1>(it).user_begin());
853       if (!outgoingCastOp)
854         continue;
855 
856       // Must be a tensor.cast op pair with matching types.
857       if (outgoingCastOp.getResult().getType() !=
858           incomingCast.source().getType())
859         continue;
860 
861       // Create a new ForOp with that iter operand replaced.
862       auto newForOp = replaceTensorCastForOpIterArg(rewriter, iterOpOperand,
863                                                     incomingCast.source());
864 
865       // Insert outgoing cast and use it to replace the corresponding result.
866       rewriter.setInsertionPointAfter(newForOp);
867       SmallVector<Value> replacements = newForOp.getResults();
868       unsigned returnIdx =
869           iterOpOperand.getOperandNumber() - op.getNumControlOperands();
870       replacements[returnIdx] = rewriter.create<tensor::CastOp>(
871           op.getLoc(), incomingCast.dest().getType(), replacements[returnIdx]);
872       rewriter.replaceOp(op, replacements);
873       return success();
874     }
875     return failure();
876   }
877 };
878 
879 /// Canonicalize the iter_args of an scf::ForOp that involve a tensor_load and
880 /// for which only the last loop iteration is actually visible outside of the
881 /// loop. The canonicalization looks for a pattern such as:
882 /// ```
883 ///    %t0 = ... : tensor_type
884 ///    %0 = scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
885 ///      ...
///      // %m is either buffer_cast(%bb0) or defined above the loop
887 ///      %m... : memref_type
888 ///      ... // uses of %m with potential inplace updates
889 ///      %new_tensor = tensor_load %m : memref_type
890 ///      ...
891 ///      scf.yield %new_tensor : tensor_type
892 ///    }
893 /// ```
894 ///
895 /// `%bb0` may have either 0 or 1 use. If it has 1 use it must be exactly a
896 /// `%m = buffer_cast %bb0` op that feeds into the yielded `tensor_load`
897 /// op.
898 ///
/// If no aliasing write to the memref `%m`, from which `%new_tensor` is
/// loaded,
900 /// occurs between tensor_load and yield then the value %0 visible outside of
901 /// the loop is the last `tensor_load` produced in the loop.
902 ///
903 /// For now, we approximate the absence of aliasing by only supporting the case
904 /// when the tensor_load is the operation immediately preceding the yield.
905 ///
906 /// The canonicalization rewrites the pattern as:
907 /// ```
908 ///    // %m is either a buffer_cast or defined above
909 ///    %m... : memref_type
910 ///    scf.for ... iter_args(%bb0 : %t0) -> (tensor_type) {
911 ///      ... // uses of %m with potential inplace updates
912 ///      scf.yield %bb0: tensor_type
913 ///    }
914 ///    %0 = tensor_load %m : memref_type
915 /// ```
916 ///
917 /// A later bbArg canonicalization will further rewrite as:
918 /// ```
919 ///    // %m is either a buffer_cast or defined above
920 ///    %m... : memref_type
921 ///    scf.for ... { // no iter_args
922 ///      ... // uses of %m with potential inplace updates
923 ///    }
924 ///    %0 = tensor_load %m : memref_type
925 /// ```
struct LastTensorLoadCanonicalization : public OpRewritePattern<ForOp> {
  using OpRewritePattern<ForOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ForOp forOp,
                                PatternRewriter &rewriter) const override {
    // scf.for bodies are single-block regions terminated by scf.yield.
    assert(std::next(forOp.region().begin()) == forOp.region().end() &&
           "unexpected multiple blocks");

    Location loc = forOp.getLoc();
    // Maps each rewritten forOp result to the tensor_load recreated after
    // the loop.
    DenseMap<Value, Value> replacements;
    for (BlockArgument bbArg : forOp.getRegionIterArgs()) {
      // Index of this iter_arg among the yielded values (skip the IV).
      unsigned idx = bbArg.getArgNumber() - /*numIv=*/1;
      auto yieldOp = cast<scf::YieldOp>(forOp.region().front().getTerminator());
      Value yieldVal = yieldOp->getOperand(idx);
      auto tensorLoadOp = yieldVal.getDefiningOp<memref::TensorLoadOp>();
      bool isTensor = bbArg.getType().isa<TensorType>();

      memref::BufferCastOp bufferCastOp;
      // Either bbArg has no use or it has a single buffer_cast use.
      if (bbArg.hasOneUse())
        bufferCastOp =
            dyn_cast<memref::BufferCastOp>(*bbArg.getUsers().begin());
      // Match: tensor iter_arg whose yielded value is a tensor_load, and
      // whose only (optional) use is a buffer_cast.
      if (!isTensor || !tensorLoadOp || (!bbArg.use_empty() && !bufferCastOp))
        continue;
      // If bufferCastOp is present, it must feed into the `tensorLoadOp`.
      if (bufferCastOp && tensorLoadOp.memref() != bufferCastOp)
        continue;
      // TODO: Any aliasing write of tensorLoadOp.memref() nested under `forOp`
      // must be before `tensorLoadOp` in the block so that the lastWrite
      // property is not subject to additional side-effects.
      // For now, we only support the case when tensorLoadOp appears immediately
      // before the terminator.
      if (tensorLoadOp->getNextNode() != yieldOp)
        continue;

      // Clone the optional bufferCastOp before forOp.
      if (bufferCastOp) {
        rewriter.setInsertionPoint(forOp);
        rewriter.replaceOpWithNewOp<memref::BufferCastOp>(
            bufferCastOp, bufferCastOp.memref().getType(),
            bufferCastOp.tensor());
      }

      // Clone the tensorLoad after forOp.
      rewriter.setInsertionPointAfter(forOp);
      Value newTensorLoad =
          rewriter.create<memref::TensorLoadOp>(loc, tensorLoadOp.memref());
      Value forOpResult = forOp.getResult(bbArg.getArgNumber() - /*iv=*/1);
      replacements.insert(std::make_pair(forOpResult, newTensorLoad));

      // Make the terminator just yield the bbArg, the old tensorLoadOp + the
      // old bbArg (that is now directly yielded) will canonicalize away.
      rewriter.startRootUpdate(yieldOp);
      yieldOp.setOperand(idx, bbArg);
      rewriter.finalizeRootUpdate(yieldOp);
    }
    if (replacements.empty())
      return failure();

    // We want to replace a subset of the results of `forOp`. rewriter.replaceOp
    // replaces the whole op and erase it unconditionally. This is wrong for
    // `forOp` as it generally contains ops with side effects.
    // Instead, use `rewriter.replaceOpWithIf`.
    SmallVector<Value> newResults;
    newResults.reserve(forOp.getNumResults());
    for (Value v : forOp.getResults()) {
      auto it = replacements.find(v);
      newResults.push_back((it != replacements.end()) ? it->second : v);
    }
    // NOTE(review): the functor relies on being invoked once per result use
    // in result order via the shared `idx` counter — confirm against the
    // replaceOpWithIf contract.
    unsigned idx = 0;
    rewriter.replaceOpWithIf(forOp, newResults, [&](OpOperand &op) {
      return op.get() != newResults[idx++];
    });
    return success();
  }
};
1002 } // namespace
1003 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1004 void ForOp::getCanonicalizationPatterns(RewritePatternSet &results,
1005                                         MLIRContext *context) {
1006   results.add<ForOpIterArgsFolder, SimplifyTrivialLoops,
1007               LastTensorLoadCanonicalization, ForOpTensorCastFolder>(context);
1008 }
1009 
1010 //===----------------------------------------------------------------------===//
1011 // IfOp
1012 //===----------------------------------------------------------------------===//
1013 
build(OpBuilder & builder,OperationState & result,Value cond,bool withElseRegion)1014 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1015                  bool withElseRegion) {
1016   build(builder, result, /*resultTypes=*/llvm::None, cond, withElseRegion);
1017 }
1018 
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,Value cond,bool withElseRegion)1019 void IfOp::build(OpBuilder &builder, OperationState &result,
1020                  TypeRange resultTypes, Value cond, bool withElseRegion) {
1021   auto addTerminator = [&](OpBuilder &nested, Location loc) {
1022     if (resultTypes.empty())
1023       IfOp::ensureTerminator(*nested.getInsertionBlock()->getParent(), nested,
1024                              loc);
1025   };
1026 
1027   build(builder, result, resultTypes, cond, addTerminator,
1028         withElseRegion ? addTerminator
1029                        : function_ref<void(OpBuilder &, Location)>());
1030 }
1031 
build(OpBuilder & builder,OperationState & result,TypeRange resultTypes,Value cond,function_ref<void (OpBuilder &,Location)> thenBuilder,function_ref<void (OpBuilder &,Location)> elseBuilder)1032 void IfOp::build(OpBuilder &builder, OperationState &result,
1033                  TypeRange resultTypes, Value cond,
1034                  function_ref<void(OpBuilder &, Location)> thenBuilder,
1035                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1036   assert(thenBuilder && "the builder callback for 'then' must be present");
1037 
1038   result.addOperands(cond);
1039   result.addTypes(resultTypes);
1040 
1041   OpBuilder::InsertionGuard guard(builder);
1042   Region *thenRegion = result.addRegion();
1043   builder.createBlock(thenRegion);
1044   thenBuilder(builder, result.location);
1045 
1046   Region *elseRegion = result.addRegion();
1047   if (!elseBuilder)
1048     return;
1049 
1050   builder.createBlock(elseRegion);
1051   elseBuilder(builder, result.location);
1052 }
1053 
build(OpBuilder & builder,OperationState & result,Value cond,function_ref<void (OpBuilder &,Location)> thenBuilder,function_ref<void (OpBuilder &,Location)> elseBuilder)1054 void IfOp::build(OpBuilder &builder, OperationState &result, Value cond,
1055                  function_ref<void(OpBuilder &, Location)> thenBuilder,
1056                  function_ref<void(OpBuilder &, Location)> elseBuilder) {
1057   build(builder, result, TypeRange(), cond, thenBuilder, elseBuilder);
1058 }
1059 
/// Verifies an scf.if: values can only flow out through yields in both
/// branches, so an op that defines results must have an else region.
static LogicalResult verify(IfOp op) {
  if (op.getNumResults() != 0 && op.elseRegion().empty())
    return op.emitOpError("must have an else block if defining values");

  // Region/result type agreement is checked by the region branch interface.
  return RegionBranchOpInterface::verifyTypes(op);
}
1066 
parseIfOp(OpAsmParser & parser,OperationState & result)1067 static ParseResult parseIfOp(OpAsmParser &parser, OperationState &result) {
1068   // Create the regions for 'then'.
1069   result.regions.reserve(2);
1070   Region *thenRegion = result.addRegion();
1071   Region *elseRegion = result.addRegion();
1072 
1073   auto &builder = parser.getBuilder();
1074   OpAsmParser::OperandType cond;
1075   Type i1Type = builder.getIntegerType(1);
1076   if (parser.parseOperand(cond) ||
1077       parser.resolveOperand(cond, i1Type, result.operands))
1078     return failure();
1079   // Parse optional results type list.
1080   if (parser.parseOptionalArrowTypeList(result.types))
1081     return failure();
1082   // Parse the 'then' region.
1083   if (parser.parseRegion(*thenRegion, /*arguments=*/{}, /*argTypes=*/{}))
1084     return failure();
1085   IfOp::ensureTerminator(*thenRegion, parser.getBuilder(), result.location);
1086 
1087   // If we find an 'else' keyword then parse the 'else' region.
1088   if (!parser.parseOptionalKeyword("else")) {
1089     if (parser.parseRegion(*elseRegion, /*arguments=*/{}, /*argTypes=*/{}))
1090       return failure();
1091     IfOp::ensureTerminator(*elseRegion, parser.getBuilder(), result.location);
1092   }
1093 
1094   // Parse the optional attribute list.
1095   if (parser.parseOptionalAttrDict(result.attributes))
1096     return failure();
1097   return success();
1098 }
1099 
print(OpAsmPrinter & p,IfOp op)1100 static void print(OpAsmPrinter &p, IfOp op) {
1101   bool printBlockTerminators = false;
1102 
1103   p << IfOp::getOperationName() << " " << op.condition();
1104   if (!op.results().empty()) {
1105     p << " -> (" << op.getResultTypes() << ")";
1106     // Print yield explicitly if the op defines values.
1107     printBlockTerminators = true;
1108   }
1109   p.printRegion(op.thenRegion(),
1110                 /*printEntryBlockArgs=*/false,
1111                 /*printBlockTerminators=*/printBlockTerminators);
1112 
1113   // Print the 'else' regions if it exists and has a block.
1114   auto &elseRegion = op.elseRegion();
1115   if (!elseRegion.empty()) {
1116     p << " else";
1117     p.printRegion(elseRegion,
1118                   /*printEntryBlockArgs=*/false,
1119                   /*printBlockTerminators=*/printBlockTerminators);
1120   }
1121 
1122   p.printOptionalAttrDict(op->getAttrs());
1123 }
1124 
1125 /// Given the region at `index`, or the parent operation if `index` is None,
1126 /// return the successor regions. These are the regions that may be selected
1127 /// during the flow of control. `operands` is a set of optional attributes that
1128 /// correspond to a constant value for each operand, or null if that operand is
1129 /// not a constant.
void IfOp::getSuccessorRegions(Optional<unsigned> index,
                               ArrayRef<Attribute> operands,
                               SmallVectorImpl<RegionSuccessor> &regions) {
  // The `then` and the `else` region branch back to the parent operation.
  if (index.hasValue()) {
    regions.push_back(RegionSuccessor(getResults()));
    return;
  }

  // Don't consider the else region if it is empty.
  Region *elseRegion = &this->elseRegion();
  if (elseRegion->empty())
    elseRegion = nullptr;

  // Otherwise, the successor is dependent on the condition.
  bool condition;
  if (auto condAttr = operands.front().dyn_cast_or_null<IntegerAttr>()) {
    // Constant condition: exactly one region can be entered.
    condition = condAttr.getValue().isOneValue();
  } else {
    // If the condition isn't constant, both regions may be executed.
    regions.push_back(RegionSuccessor(&thenRegion()));
    // If the else region does not exist, it is not a viable successor.
    if (elseRegion)
      regions.push_back(RegionSuccessor(elseRegion));
    return;
  }

  // Add the successor regions using the condition.
  // NOTE(review): with a constant-false condition and no else region this
  // pushes a RegionSuccessor holding a null region — presumably interpreted
  // by callers as "no region is entered"; confirm against the
  // RegionBranchOpInterface contract.
  regions.push_back(RegionSuccessor(condition ? &thenRegion() : elseRegion));
}
1160 
1161 namespace {
1162 // Pattern to remove unused IfOp results.
struct RemoveUnusedResults : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  // Moves the body of `source` into `dest` and shrinks the terminator so it
  // only yields the operands corresponding to `usedResults`.
  void transferBody(Block *source, Block *dest, ArrayRef<OpResult> usedResults,
                    PatternRewriter &rewriter) const {
    // Move all operations to the destination block.
    rewriter.mergeBlocks(source, dest);
    // Replace the yield op by one that returns only the used values.
    auto yieldOp = cast<scf::YieldOp>(dest->getTerminator());
    SmallVector<Value, 4> usedOperands;
    llvm::transform(usedResults, std::back_inserter(usedOperands),
                    [&](OpResult result) {
                      return yieldOp.getOperand(result.getResultNumber());
                    });
    rewriter.updateRootInPlace(yieldOp,
                               [&]() { yieldOp->setOperands(usedOperands); });
  }

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    // Compute the list of used results.
    SmallVector<OpResult, 4> usedResults;
    llvm::copy_if(op.getResults(), std::back_inserter(usedResults),
                  [](OpResult result) { return !result.use_empty(); });

    // Replace the operation if only a subset of its results have uses.
    if (usedResults.size() == op.getNumResults())
      return failure();

    // Compute the result types of the replacement operation.
    SmallVector<Type, 4> newTypes;
    llvm::transform(usedResults, std::back_inserter(newTypes),
                    [](OpResult result) { return result.getType(); });

    // Create a replacement operation with empty then and else regions.
    auto emptyBuilder = [](OpBuilder &, Location) {};
    auto newOp = rewriter.create<IfOp>(op.getLoc(), newTypes, op.condition(),
                                       emptyBuilder, emptyBuilder);

    // Move the bodies and replace the terminators (note there is a then and
    // an else region since the operation returns results).
    transferBody(op.getBody(0), newOp.getBody(0), usedResults, rewriter);
    transferBody(op.getBody(1), newOp.getBody(1), usedResults, rewriter);

    // Replace the operation by the new one. Unused positions keep the
    // default (null) Value; they have no uses to rewire.
    SmallVector<Value, 4> repResults(op.getNumResults());
    for (auto en : llvm::enumerate(usedResults))
      repResults[en.value().getResultNumber()] = newOp.getResult(en.index());
    rewriter.replaceOp(op, repResults);
    return success();
  }
};
1215 
1216 struct RemoveStaticCondition : public OpRewritePattern<IfOp> {
1217   using OpRewritePattern<IfOp>::OpRewritePattern;
1218 
matchAndRewrite__anon56feca5b0911::RemoveStaticCondition1219   LogicalResult matchAndRewrite(IfOp op,
1220                                 PatternRewriter &rewriter) const override {
1221     auto constant = op.condition().getDefiningOp<ConstantOp>();
1222     if (!constant)
1223       return failure();
1224 
1225     if (constant.getValue().cast<BoolAttr>().getValue())
1226       replaceOpWithRegion(rewriter, op, op.thenRegion());
1227     else if (!op.elseRegion().empty())
1228       replaceOpWithRegion(rewriter, op, op.elseRegion());
1229     else
1230       rewriter.eraseOp(op);
1231 
1232     return success();
1233   }
1234 };
1235 
1236 struct ConvertTrivialIfToSelect : public OpRewritePattern<IfOp> {
1237   using OpRewritePattern<IfOp>::OpRewritePattern;
1238 
matchAndRewrite__anon56feca5b0911::ConvertTrivialIfToSelect1239   LogicalResult matchAndRewrite(IfOp op,
1240                                 PatternRewriter &rewriter) const override {
1241     if (op->getNumResults() == 0)
1242       return failure();
1243 
1244     if (!llvm::hasSingleElement(op.thenRegion().front()) ||
1245         !llvm::hasSingleElement(op.elseRegion().front()))
1246       return failure();
1247 
1248     auto cond = op.condition();
1249     auto thenYieldArgs =
1250         cast<scf::YieldOp>(op.thenRegion().front().getTerminator())
1251             .getOperands();
1252     auto elseYieldArgs =
1253         cast<scf::YieldOp>(op.elseRegion().front().getTerminator())
1254             .getOperands();
1255     SmallVector<Value> results(op->getNumResults());
1256     assert(thenYieldArgs.size() == results.size());
1257     assert(elseYieldArgs.size() == results.size());
1258     for (auto it : llvm::enumerate(llvm::zip(thenYieldArgs, elseYieldArgs))) {
1259       Value trueVal = std::get<0>(it.value());
1260       Value falseVal = std::get<1>(it.value());
1261       if (trueVal == falseVal)
1262         results[it.index()] = trueVal;
1263       else
1264         results[it.index()] =
1265             rewriter.create<SelectOp>(op.getLoc(), cond, trueVal, falseVal);
1266     }
1267 
1268     rewriter.replaceOp(op, results);
1269     return success();
1270   }
1271 };
1272 
1273 /// Allow the true region of an if to assume the condition is true
1274 /// and vice versa. For example:
1275 ///
1276 ///   scf.if %cmp {
1277 ///      print(%cmp)
1278 ///   }
1279 ///
1280 ///  becomes
1281 ///
1282 ///   scf.if %cmp {
1283 ///      print(true)
1284 ///   }
1285 ///
struct ConditionPropagation : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(IfOp op,
                                PatternRewriter &rewriter) const override {
    // Early exit if the condition is constant since replacing a constant
    // in the body with another constant isn't a simplification.
    if (op.condition().getDefiningOp<ConstantOp>())
      return failure();

    bool changed = false;
    mlir::Type i1Ty = rewriter.getI1Type();

    // These variables serve to prevent creating duplicate constants
    // and hold constant true or false values.
    Value constantTrue = nullptr;
    Value constantFalse = nullptr;

    // Walk the condition's uses; early-inc iteration keeps the traversal
    // valid while `use.set` rewires uses in place.
    for (OpOperand &use :
         llvm::make_early_inc_range(op.condition().getUses())) {
      // Uses nested inside the then-region see the condition as true.
      if (op.thenRegion().isAncestor(use.getOwner()->getParentRegion())) {
        changed = true;

        if (!constantTrue)
          constantTrue = rewriter.create<mlir::ConstantOp>(
              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1));

        rewriter.updateRootInPlace(use.getOwner(),
                                   [&]() { use.set(constantTrue); });
      } else if (op.elseRegion().isAncestor(
                     use.getOwner()->getParentRegion())) {
        // Uses nested inside the else-region see the condition as false.
        changed = true;

        if (!constantFalse)
          constantFalse = rewriter.create<mlir::ConstantOp>(
              op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 0));

        rewriter.updateRootInPlace(use.getOwner(),
                                   [&]() { use.set(constantFalse); });
      }
    }

    return success(changed);
  }
};
1331 
1332 /// Remove any statements from an if that are equivalent to the condition
1333 /// or its negation. For example:
1334 ///
1335 ///    %res:2 = scf.if %cmp {
1336 ///       yield something(), true
1337 ///    } else {
1338 ///       yield something2(), false
1339 ///    }
1340 ///    print(%res#1)
1341 ///
1342 ///  becomes
1343 ///    %res = scf.if %cmp {
1344 ///       yield something()
1345 ///    } else {
1346 ///       yield something2()
1347 ///    }
1348 ///    print(%cmp)
1349 ///
1350 /// Additionally if both branches yield the same value, replace all uses
1351 /// of the result with the yielded value.
1352 ///
1353 ///    %res:2 = scf.if %cmp {
1354 ///       yield something(), %arg1
1355 ///    } else {
1356 ///       yield something2(), %arg1
1357 ///    }
1358 ///    print(%res#1)
1359 ///
1360 ///  becomes
1361 ///    %res = scf.if %cmp {
1362 ///       yield something()
1363 ///    } else {
1364 ///       yield something2()
1365 ///    }
1366 ///    print(%arg1)
1367 ///
1368 struct ReplaceIfYieldWithConditionOrValue : public OpRewritePattern<IfOp> {
1369   using OpRewritePattern<IfOp>::OpRewritePattern;
1370 
matchAndRewrite__anon56feca5b0911::ReplaceIfYieldWithConditionOrValue1371   LogicalResult matchAndRewrite(IfOp op,
1372                                 PatternRewriter &rewriter) const override {
1373     // Early exit if there are no results that could be replaced.
1374     if (op.getNumResults() == 0)
1375       return failure();
1376 
1377     auto trueYield = cast<scf::YieldOp>(op.thenRegion().back().getTerminator());
1378     auto falseYield =
1379         cast<scf::YieldOp>(op.elseRegion().back().getTerminator());
1380 
1381     rewriter.setInsertionPoint(op->getBlock(),
1382                                op.getOperation()->getIterator());
1383     bool changed = false;
1384     Type i1Ty = rewriter.getI1Type();
1385     for (auto tup :
1386          llvm::zip(trueYield.results(), falseYield.results(), op.results())) {
1387       Value trueResult, falseResult, opResult;
1388       std::tie(trueResult, falseResult, opResult) = tup;
1389 
1390       if (trueResult == falseResult) {
1391         if (!opResult.use_empty()) {
1392           opResult.replaceAllUsesWith(trueResult);
1393           changed = true;
1394         }
1395         continue;
1396       }
1397 
1398       auto trueYield = trueResult.getDefiningOp<ConstantOp>();
1399       if (!trueYield)
1400         continue;
1401 
1402       if (!trueYield.getType().isInteger(1))
1403         continue;
1404 
1405       auto falseYield = falseResult.getDefiningOp<ConstantOp>();
1406       if (!falseYield)
1407         continue;
1408 
1409       bool trueVal = trueYield.getValue().cast<BoolAttr>().getValue();
1410       bool falseVal = falseYield.getValue().cast<BoolAttr>().getValue();
1411       if (!trueVal && falseVal) {
1412         if (!opResult.use_empty()) {
1413           Value notCond = rewriter.create<XOrOp>(
1414               op.getLoc(), op.condition(),
1415               rewriter.create<mlir::ConstantOp>(
1416                   op.getLoc(), i1Ty, rewriter.getIntegerAttr(i1Ty, 1)));
1417           opResult.replaceAllUsesWith(notCond);
1418           changed = true;
1419         }
1420       }
1421       if (trueVal && !falseVal) {
1422         if (!opResult.use_empty()) {
1423           opResult.replaceAllUsesWith(op.condition());
1424           changed = true;
1425         }
1426       }
1427     }
1428     return success(changed);
1429   }
1430 };
1431 
1432 /// Merge any consecutive scf.if's with the same condition.
1433 ///
1434 ///    scf.if %cond {
1435 ///       firstCodeTrue();...
1436 ///    } else {
1437 ///       firstCodeFalse();...
1438 ///    }
1439 ///    %res = scf.if %cond {
1440 ///       secondCodeTrue();...
1441 ///    } else {
1442 ///       secondCodeFalse();...
1443 ///    }
1444 ///
1445 ///  becomes
1446 ///    %res = scf.if %cmp {
1447 ///       firstCodeTrue();...
1448 ///       secondCodeTrue();...
1449 ///    } else {
1450 ///       firstCodeFalse();...
1451 ///       secondCodeFalse();...
1452 ///    }
struct CombineIfs : public OpRewritePattern<IfOp> {
  using OpRewritePattern<IfOp>::OpRewritePattern;

  // Matches an scf.if immediately preceded by another scf.if on the same
  // condition value and merges them into a single scf.if yielding the
  // concatenation of both result lists.
  LogicalResult matchAndRewrite(IfOp nextIf,
                                PatternRewriter &rewriter) const override {
    // The candidate must have a predecessor op in its block.
    Block *parent = nextIf->getBlock();
    if (nextIf == &parent->front())
      return failure();

    // That predecessor must itself be an scf.if...
    auto prevIf = dyn_cast<IfOp>(nextIf->getPrevNode());
    if (!prevIf)
      return failure();

    // ...guarded by the exact same SSA condition value.
    if (nextIf.condition() != prevIf.condition())
      return failure();

    // Don't permit merging if a result of the first if is used
    // within the second.
    if (llvm::any_of(prevIf->getUsers(),
                     [&](Operation *user) { return nextIf->isAncestor(user); }))
      return failure();

    // Combined result list: prevIf's results first, then nextIf's.
    SmallVector<Type> mergedTypes(prevIf.getResultTypes());
    llvm::append_range(mergedTypes, nextIf.getResultTypes());

    // Build the merged op without an else region; the placeholder then-block
    // is erased so the spliced blocks below become the then region.
    IfOp combinedIf = rewriter.create<IfOp>(
        nextIf.getLoc(), mergedTypes, nextIf.condition(), /*hasElse=*/false);
    rewriter.eraseBlock(&combinedIf.thenRegion().back());

    // Grab both then-terminators before the blocks are moved around.
    YieldOp thenYield = prevIf.thenYield();
    YieldOp thenYield2 = nextIf.thenYield();

    // Move prevIf's then blocks into the combined op, then append nextIf's
    // then block after them.
    combinedIf.thenRegion().getBlocks().splice(
        combinedIf.thenRegion().getBlocks().begin(),
        prevIf.thenRegion().getBlocks());

    rewriter.mergeBlocks(nextIf.thenBlock(), combinedIf.thenBlock());
    rewriter.setInsertionPointToEnd(combinedIf.thenBlock());

    // Replace the two old yields with a single yield of all merged values.
    SmallVector<Value> mergedYields(thenYield.getOperands());
    llvm::append_range(mergedYields, thenYield2.getOperands());
    rewriter.create<YieldOp>(thenYield2.getLoc(), mergedYields);
    rewriter.eraseOp(thenYield);
    rewriter.eraseOp(thenYield2);

    // Seed the combined else region with prevIf's else blocks (no-op when
    // prevIf has no else region).
    combinedIf.elseRegion().getBlocks().splice(
        combinedIf.elseRegion().getBlocks().begin(),
        prevIf.elseRegion().getBlocks());

    if (!nextIf.elseRegion().empty()) {
      if (combinedIf.elseRegion().empty()) {
        // Only nextIf had an else: adopt it wholesale. prevIf then had no
        // results, so nextIf's else yield already matches the merged results.
        combinedIf.elseRegion().getBlocks().splice(
            combinedIf.elseRegion().getBlocks().begin(),
            nextIf.elseRegion().getBlocks());
      } else {
        // Both had an else: merge the blocks and rebuild a combined yield,
        // mirroring the then-region handling above.
        YieldOp elseYield = combinedIf.elseYield();
        YieldOp elseYield2 = nextIf.elseYield();
        rewriter.mergeBlocks(nextIf.elseBlock(), combinedIf.elseBlock());

        rewriter.setInsertionPointToEnd(combinedIf.elseBlock());

        SmallVector<Value> mergedElseYields(elseYield.getOperands());
        llvm::append_range(mergedElseYields, elseYield2.getOperands());

        rewriter.create<YieldOp>(elseYield2.getLoc(), mergedElseYields);
        rewriter.eraseOp(elseYield);
        rewriter.eraseOp(elseYield2);
      }
    }

    // Split the combined results back into the two original result lists and
    // replace both ifs.
    SmallVector<Value> prevValues;
    SmallVector<Value> nextValues;
    for (auto pair : llvm::enumerate(combinedIf.getResults())) {
      if (pair.index() < prevIf.getNumResults())
        prevValues.push_back(pair.value());
      else
        nextValues.push_back(pair.value());
    }
    rewriter.replaceOp(prevIf, prevValues);
    rewriter.replaceOp(nextIf, nextValues);
    return success();
  }
};
1536 
1537 /// Pattern to remove an empty else branch.
1538 struct RemoveEmptyElseBranch : public OpRewritePattern<IfOp> {
1539   using OpRewritePattern<IfOp>::OpRewritePattern;
1540 
matchAndRewrite__anon56feca5b0911::RemoveEmptyElseBranch1541   LogicalResult matchAndRewrite(IfOp ifOp,
1542                                 PatternRewriter &rewriter) const override {
1543     // Cannot remove else region when there are operation results.
1544     if (ifOp.getNumResults())
1545       return failure();
1546     Block *elseBlock = ifOp.elseBlock();
1547     if (!elseBlock || !llvm::hasSingleElement(*elseBlock))
1548       return failure();
1549     auto newIfOp = rewriter.cloneWithoutRegions(ifOp);
1550     rewriter.inlineRegionBefore(ifOp.thenRegion(), newIfOp.thenRegion(),
1551                                 newIfOp.thenRegion().begin());
1552     rewriter.eraseOp(ifOp);
1553     return success();
1554   }
1555 };
1556 
1557 } // namespace
1558 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1559 void IfOp::getCanonicalizationPatterns(RewritePatternSet &results,
1560                                        MLIRContext *context) {
1561   results
1562       .add<RemoveUnusedResults, RemoveStaticCondition, ConvertTrivialIfToSelect,
1563            ConditionPropagation, ReplaceIfYieldWithConditionOrValue, CombineIfs,
1564            RemoveEmptyElseBranch>(context);
1565 }
1566 
thenBlock()1567 Block *IfOp::thenBlock() { return &thenRegion().back(); }
thenYield()1568 YieldOp IfOp::thenYield() { return cast<YieldOp>(&thenBlock()->back()); }
elseBlock()1569 Block *IfOp::elseBlock() {
1570   Region &r = elseRegion();
1571   if (r.empty())
1572     return nullptr;
1573   return &r.back();
1574 }
elseYield()1575 YieldOp IfOp::elseYield() { return cast<YieldOp>(&elseBlock()->back()); }
1576 
1577 //===----------------------------------------------------------------------===//
1578 // ParallelOp
1579 //===----------------------------------------------------------------------===//
1580 
build(OpBuilder & builder,OperationState & result,ValueRange lowerBounds,ValueRange upperBounds,ValueRange steps,ValueRange initVals,function_ref<void (OpBuilder &,Location,ValueRange,ValueRange)> bodyBuilderFn)1581 void ParallelOp::build(
1582     OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
1583     ValueRange upperBounds, ValueRange steps, ValueRange initVals,
1584     function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)>
1585         bodyBuilderFn) {
1586   result.addOperands(lowerBounds);
1587   result.addOperands(upperBounds);
1588   result.addOperands(steps);
1589   result.addOperands(initVals);
1590   result.addAttribute(
1591       ParallelOp::getOperandSegmentSizeAttr(),
1592       builder.getI32VectorAttr({static_cast<int32_t>(lowerBounds.size()),
1593                                 static_cast<int32_t>(upperBounds.size()),
1594                                 static_cast<int32_t>(steps.size()),
1595                                 static_cast<int32_t>(initVals.size())}));
1596   result.addTypes(initVals.getTypes());
1597 
1598   OpBuilder::InsertionGuard guard(builder);
1599   unsigned numIVs = steps.size();
1600   SmallVector<Type, 8> argTypes(numIVs, builder.getIndexType());
1601   Region *bodyRegion = result.addRegion();
1602   Block *bodyBlock = builder.createBlock(bodyRegion, {}, argTypes);
1603 
1604   if (bodyBuilderFn) {
1605     builder.setInsertionPointToStart(bodyBlock);
1606     bodyBuilderFn(builder, result.location,
1607                   bodyBlock->getArguments().take_front(numIVs),
1608                   bodyBlock->getArguments().drop_front(numIVs));
1609   }
1610   ParallelOp::ensureTerminator(*bodyRegion, builder, result.location);
1611 }
1612 
// Convenience builder for a parallel loop without reductions: adapts the
// three-argument body callback to the four-argument form and forwards to the
// main builder with an empty init-value list.
void ParallelOp::build(
    OpBuilder &builder, OperationState &result, ValueRange lowerBounds,
    ValueRange upperBounds, ValueRange steps,
    function_ref<void(OpBuilder &, Location, ValueRange)> bodyBuilderFn) {
  // Only pass a non-null wrapper if bodyBuilderFn is non-null itself. Make sure
  // we don't capture a reference to a temporary by constructing the lambda at
  // function level.
  auto wrappedBuilderFn = [&bodyBuilderFn](OpBuilder &nestedBuilder,
                                           Location nestedLoc, ValueRange ivs,
                                           ValueRange) {
    // Forward to the user callback, dropping the (empty) reduction args.
    bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
  };
  // A default-constructed function_ref is null, which the main builder treats
  // as "no body callback".
  function_ref<void(OpBuilder &, Location, ValueRange, ValueRange)> wrapper;
  if (bodyBuilderFn)
    wrapper = wrappedBuilderFn;

  build(builder, result, lowerBounds, upperBounds, steps, ValueRange(),
        wrapper);
}
1632 
// Verifies structural invariants of scf.parallel: non-empty bound/step lists,
// positive constant steps, index-typed IVs matching the step count, an
// operand-less terminator, and a one-to-one, type-matching correspondence
// between results, nested scf.reduce ops, and init values.
static LogicalResult verify(ParallelOp op) {
  // Check that there is at least one value in lowerBound, upperBound and step.
  // It is sufficient to test only step, because it is ensured already that the
  // number of elements in lowerBound, upperBound and step are the same.
  Operation::operand_range stepValues = op.step();
  if (stepValues.empty())
    return op.emitOpError(
        "needs at least one tuple element for lowerBound, upperBound and step");

  // Check whether all constant step values are positive. Non-constant steps
  // cannot be checked statically and are accepted as-is.
  for (Value stepValue : stepValues)
    if (auto cst = stepValue.getDefiningOp<ConstantIndexOp>())
      if (cst.getValue() <= 0)
        return op.emitOpError("constant step operand must be positive");

  // Check that the body defines the same number of block arguments as the
  // number of tuple elements in step.
  Block *body = op.getBody();
  if (body->getNumArguments() != stepValues.size())
    return op.emitOpError()
           << "expects the same number of induction variables: "
           << body->getNumArguments()
           << " as bound and step values: " << stepValues.size();
  for (auto arg : body->getArguments())
    if (!arg.getType().isIndex())
      return op.emitOpError(
          "expects arguments for the induction variable to be of index type");

  // Check that the yield has no results; loop results are produced by nested
  // scf.reduce ops, not by the terminator.
  Operation *yield = body->getTerminator();
  if (yield->getNumOperands() != 0)
    return yield->emitOpError() << "not allowed to have operands inside '"
                                << ParallelOp::getOperationName() << "'";

  // Check that the number of results is the same as the number of ReduceOps.
  SmallVector<ReduceOp, 4> reductions(body->getOps<ReduceOp>());
  auto resultsSize = op.results().size();
  auto reductionsSize = reductions.size();
  auto initValsSize = op.initVals().size();
  if (resultsSize != reductionsSize)
    return op.emitOpError()
           << "expects number of results: " << resultsSize
           << " to be the same as number of reductions: " << reductionsSize;
  if (resultsSize != initValsSize)
    return op.emitOpError()
           << "expects number of results: " << resultsSize
           << " to be the same as number of initial values: " << initValsSize;

  // Check that the types of the results and reductions are the same. The i-th
  // scf.reduce (in body order) corresponds to the i-th result.
  for (auto resultAndReduce : llvm::zip(op.results(), reductions)) {
    auto resultType = std::get<0>(resultAndReduce).getType();
    auto reduceOp = std::get<1>(resultAndReduce);
    auto reduceType = reduceOp.operand().getType();
    if (resultType != reduceType)
      return reduceOp.emitOpError()
             << "expects type of reduce: " << reduceType
             << " to be the same as result type: " << resultType;
  }
  return success();
}
1693 
parseParallelOp(OpAsmParser & parser,OperationState & result)1694 static ParseResult parseParallelOp(OpAsmParser &parser,
1695                                    OperationState &result) {
1696   auto &builder = parser.getBuilder();
1697   // Parse an opening `(` followed by induction variables followed by `)`
1698   SmallVector<OpAsmParser::OperandType, 4> ivs;
1699   if (parser.parseRegionArgumentList(ivs, /*requiredOperandCount=*/-1,
1700                                      OpAsmParser::Delimiter::Paren))
1701     return failure();
1702 
1703   // Parse loop bounds.
1704   SmallVector<OpAsmParser::OperandType, 4> lower;
1705   if (parser.parseEqual() ||
1706       parser.parseOperandList(lower, ivs.size(),
1707                               OpAsmParser::Delimiter::Paren) ||
1708       parser.resolveOperands(lower, builder.getIndexType(), result.operands))
1709     return failure();
1710 
1711   SmallVector<OpAsmParser::OperandType, 4> upper;
1712   if (parser.parseKeyword("to") ||
1713       parser.parseOperandList(upper, ivs.size(),
1714                               OpAsmParser::Delimiter::Paren) ||
1715       parser.resolveOperands(upper, builder.getIndexType(), result.operands))
1716     return failure();
1717 
1718   // Parse step values.
1719   SmallVector<OpAsmParser::OperandType, 4> steps;
1720   if (parser.parseKeyword("step") ||
1721       parser.parseOperandList(steps, ivs.size(),
1722                               OpAsmParser::Delimiter::Paren) ||
1723       parser.resolveOperands(steps, builder.getIndexType(), result.operands))
1724     return failure();
1725 
1726   // Parse init values.
1727   SmallVector<OpAsmParser::OperandType, 4> initVals;
1728   if (succeeded(parser.parseOptionalKeyword("init"))) {
1729     if (parser.parseOperandList(initVals, /*requiredOperandCount=*/-1,
1730                                 OpAsmParser::Delimiter::Paren))
1731       return failure();
1732   }
1733 
1734   // Parse optional results in case there is a reduce.
1735   if (parser.parseOptionalArrowTypeList(result.types))
1736     return failure();
1737 
1738   // Now parse the body.
1739   Region *body = result.addRegion();
1740   SmallVector<Type, 4> types(ivs.size(), builder.getIndexType());
1741   if (parser.parseRegion(*body, ivs, types))
1742     return failure();
1743 
1744   // Set `operand_segment_sizes` attribute.
1745   result.addAttribute(
1746       ParallelOp::getOperandSegmentSizeAttr(),
1747       builder.getI32VectorAttr({static_cast<int32_t>(lower.size()),
1748                                 static_cast<int32_t>(upper.size()),
1749                                 static_cast<int32_t>(steps.size()),
1750                                 static_cast<int32_t>(initVals.size())}));
1751 
1752   // Parse attributes.
1753   if (parser.parseOptionalAttrDict(result.attributes))
1754     return failure();
1755 
1756   if (!initVals.empty())
1757     parser.resolveOperands(initVals, result.types, parser.getNameLoc(),
1758                            result.operands);
1759   // Add a terminator if none was parsed.
1760   ForOp::ensureTerminator(*body, builder, result.location);
1761 
1762   return success();
1763 }
1764 
print(OpAsmPrinter & p,ParallelOp op)1765 static void print(OpAsmPrinter &p, ParallelOp op) {
1766   p << op.getOperationName() << " (" << op.getBody()->getArguments() << ") = ("
1767     << op.lowerBound() << ") to (" << op.upperBound() << ") step (" << op.step()
1768     << ")";
1769   if (!op.initVals().empty())
1770     p << " init (" << op.initVals() << ")";
1771   p.printOptionalArrowTypeList(op.getResultTypes());
1772   p.printRegion(op.region(), /*printEntryBlockArgs=*/false);
1773   p.printOptionalAttrDict(
1774       op->getAttrs(), /*elidedAttrs=*/ParallelOp::getOperandSegmentSizeAttr());
1775 }
1776 
getLoopBody()1777 Region &ParallelOp::getLoopBody() { return region(); }
1778 
isDefinedOutsideOfLoop(Value value)1779 bool ParallelOp::isDefinedOutsideOfLoop(Value value) {
1780   return !region().isAncestor(value.getParentRegion());
1781 }
1782 
moveOutOfLoop(ArrayRef<Operation * > ops)1783 LogicalResult ParallelOp::moveOutOfLoop(ArrayRef<Operation *> ops) {
1784   for (auto *op : ops)
1785     op->moveBefore(*this);
1786   return success();
1787 }
1788 
getParallelForInductionVarOwner(Value val)1789 ParallelOp mlir::scf::getParallelForInductionVarOwner(Value val) {
1790   auto ivArg = val.dyn_cast<BlockArgument>();
1791   if (!ivArg)
1792     return ParallelOp();
1793   assert(ivArg.getOwner() && "unlinked block argument");
1794   auto *containingOp = ivArg.getOwner()->getParentOp();
1795   return dyn_cast<ParallelOp>(containingOp);
1796 }
1797 
1798 namespace {
1799 // Collapse loop dimensions that perform a single iteration.
struct CollapseSingleIterationLoops : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  // Drops every loop dimension whose constant bounds/step imply exactly one
  // iteration; if all dimensions are dropped, the body (including nested
  // scf.reduce ops) is inlined in place of the loop.
  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    // Maps each collapsed induction variable to its lower bound value.
    BlockAndValueMapping mapping;
    // Compute new loop bounds that omit all single-iteration loop dimensions.
    SmallVector<Value, 2> newLowerBounds;
    SmallVector<Value, 2> newUpperBounds;
    SmallVector<Value, 2> newSteps;
    newLowerBounds.reserve(op.lowerBound().size());
    newUpperBounds.reserve(op.upperBound().size());
    newSteps.reserve(op.step().size());
    for (auto dim : llvm::zip(op.lowerBound(), op.upperBound(), op.step(),
                              op.getInductionVars())) {
      Value lowerBound, upperBound, step, iv;
      std::tie(lowerBound, upperBound, step, iv) = dim;
      // Collect the statically known loop bounds.
      auto lowerBoundConstant =
          dyn_cast_or_null<ConstantIndexOp>(lowerBound.getDefiningOp());
      auto upperBoundConstant =
          dyn_cast_or_null<ConstantIndexOp>(upperBound.getDefiningOp());
      auto stepConstant =
          dyn_cast_or_null<ConstantIndexOp>(step.getDefiningOp());
      // Replace the loop induction variable by the lower bound if the loop
      // performs a single iteration. Otherwise, copy the loop bounds.
      // Single iteration <=> 0 < ub - lb <= step.
      if (lowerBoundConstant && upperBoundConstant && stepConstant &&
          (upperBoundConstant.getValue() - lowerBoundConstant.getValue()) > 0 &&
          (upperBoundConstant.getValue() - lowerBoundConstant.getValue()) <=
              stepConstant.getValue()) {
        mapping.map(iv, lowerBound);
      } else {
        newLowerBounds.push_back(lowerBound);
        newUpperBounds.push_back(upperBound);
        newSteps.push_back(step);
      }
    }
    // Exit if none of the loop dimensions perform a single iteration.
    if (newLowerBounds.size() == op.lowerBound().size())
      return failure();

    if (newLowerBounds.empty()) {
      // All of the loop dimensions perform a single iteration. Inline
      // loop body and nested ReduceOp's
      SmallVector<Value> results;
      results.reserve(op.initVals().size());
      for (auto &bodyOp : op.getLoopBody().front().without_terminator()) {
        auto reduce = dyn_cast<ReduceOp>(bodyOp);
        if (!reduce) {
          rewriter.clone(bodyOp, mapping);
          continue;
        }
        // Inline the reduction body: block argument 0 is the accumulator
        // (seeded with the matching init value), argument 1 the element
        // produced by this (single) iteration.
        Block &reduceBlock = reduce.reductionOperator().front();
        auto initValIndex = results.size();
        mapping.map(reduceBlock.getArgument(0), op.initVals()[initValIndex]);
        mapping.map(reduceBlock.getArgument(1),
                    mapping.lookupOrDefault(reduce.operand()));
        for (auto &reduceBodyOp : reduceBlock.without_terminator())
          rewriter.clone(reduceBodyOp, mapping);

        // The value fed to scf.reduce.return becomes the loop's result.
        auto result = mapping.lookupOrDefault(
            cast<ReduceReturnOp>(reduceBlock.getTerminator()).result());
        results.push_back(result);
      }
      rewriter.replaceOp(op, results);
      return success();
    }
    // Replace the parallel loop by lower-dimensional parallel loop.
    auto newOp =
        rewriter.create<ParallelOp>(op.getLoc(), newLowerBounds, newUpperBounds,
                                    newSteps, op.initVals(), nullptr);
    // Clone the loop body and remap the block arguments of the collapsed loops
    // (inlining does not support a cancellable block argument mapping).
    rewriter.cloneRegionBefore(op.region(), newOp.region(),
                               newOp.region().begin(), mapping);
    rewriter.replaceOp(op, newOp.getResults());
    return success();
  }
};
1879 
1880 /// Removes parallel loops in which at least one lower/upper bound pair consists
1881 /// of the same values - such loops have an empty iteration domain.
1882 struct RemoveEmptyParallelLoops : public OpRewritePattern<ParallelOp> {
1883   using OpRewritePattern<ParallelOp>::OpRewritePattern;
1884 
matchAndRewrite__anon56feca5b1311::RemoveEmptyParallelLoops1885   LogicalResult matchAndRewrite(ParallelOp op,
1886                                 PatternRewriter &rewriter) const override {
1887     for (auto dim : llvm::zip(op.lowerBound(), op.upperBound())) {
1888       if (std::get<0>(dim) == std::get<1>(dim)) {
1889         rewriter.replaceOp(op, op.initVals());
1890         return success();
1891       }
1892     }
1893     return failure();
1894   }
1895 };
1896 
// Fuses a perfectly nested pair of scf.parallel ops (the outer body contains
// only the inner loop plus its terminator) into a single loop over the
// combined iteration space.
struct MergeNestedParallelLoops : public OpRewritePattern<ParallelOp> {
  using OpRewritePattern<ParallelOp>::OpRewritePattern;

  LogicalResult matchAndRewrite(ParallelOp op,
                                PatternRewriter &rewriter) const override {
    // The outer body must hold exactly one op besides the terminator...
    Block &outerBody = op.getLoopBody().front();
    if (!llvm::hasSingleElement(outerBody.without_terminator()))
      return failure();

    // ...and that op must itself be an scf.parallel.
    auto innerOp = dyn_cast<ParallelOp>(outerBody.front());
    if (!innerOp)
      return failure();

    auto hasVal = [](const auto &range, Value val) {
      return llvm::find(range, val) != range.end();
    };

    // The inner bounds/steps must not use the outer induction variables,
    // otherwise the fused iteration space would not be rectangular.
    for (auto val : outerBody.getArguments())
      if (hasVal(innerOp.lowerBound(), val) ||
          hasVal(innerOp.upperBound(), val) || hasVal(innerOp.step(), val))
        return failure();

    // Reductions are not supported yet.
    if (!op.initVals().empty() || !innerOp.initVals().empty())
      return failure();

    // Clones the inner body into the fused loop, remapping both the outer and
    // the inner induction variables onto the combined IV list (outer first).
    auto bodyBuilder = [&](OpBuilder &builder, Location /*loc*/,
                           ValueRange iterVals, ValueRange) {
      Block &innerBody = innerOp.getLoopBody().front();
      assert(iterVals.size() ==
             (outerBody.getNumArguments() + innerBody.getNumArguments()));
      BlockAndValueMapping mapping;
      mapping.map(outerBody.getArguments(),
                  iterVals.take_front(outerBody.getNumArguments()));
      mapping.map(innerBody.getArguments(),
                  iterVals.take_back(innerBody.getNumArguments()));
      for (Operation &op : innerBody.without_terminator())
        builder.clone(op, mapping);
    };

    auto concatValues = [](const auto &first, const auto &second) {
      SmallVector<Value> ret;
      ret.reserve(first.size() + second.size());
      ret.assign(first.begin(), first.end());
      ret.append(second.begin(), second.end());
      return ret;
    };

    // Outer dimensions come first in the fused loop.
    auto newLowerBounds = concatValues(op.lowerBound(), innerOp.lowerBound());
    auto newUpperBounds = concatValues(op.upperBound(), innerOp.upperBound());
    auto newSteps = concatValues(op.step(), innerOp.step());

    rewriter.replaceOpWithNewOp<ParallelOp>(op, newLowerBounds, newUpperBounds,
                                            newSteps, llvm::None, bodyBuilder);
    return success();
  }
};
1954 
1955 } // namespace
1956 
getCanonicalizationPatterns(RewritePatternSet & results,MLIRContext * context)1957 void ParallelOp::getCanonicalizationPatterns(RewritePatternSet &results,
1958                                              MLIRContext *context) {
1959   results.add<CollapseSingleIterationLoops, RemoveEmptyParallelLoops,
1960               MergeNestedParallelLoops>(context);
1961 }
1962 
1963 //===----------------------------------------------------------------------===//
1964 // ReduceOp
1965 //===----------------------------------------------------------------------===//
1966 
build(OpBuilder & builder,OperationState & result,Value operand,function_ref<void (OpBuilder &,Location,Value,Value)> bodyBuilderFn)1967 void ReduceOp::build(
1968     OpBuilder &builder, OperationState &result, Value operand,
1969     function_ref<void(OpBuilder &, Location, Value, Value)> bodyBuilderFn) {
1970   auto type = operand.getType();
1971   result.addOperands(operand);
1972 
1973   OpBuilder::InsertionGuard guard(builder);
1974   Region *bodyRegion = result.addRegion();
1975   Block *body = builder.createBlock(bodyRegion, {}, ArrayRef<Type>{type, type});
1976   if (bodyBuilderFn)
1977     bodyBuilderFn(builder, result.location, body->getArgument(0),
1978                   body->getArgument(1));
1979 }
1980 
verify(ReduceOp op)1981 static LogicalResult verify(ReduceOp op) {
1982   // The region of a ReduceOp has two arguments of the same type as its operand.
1983   auto type = op.operand().getType();
1984   Block &block = op.reductionOperator().front();
1985   if (block.empty())
1986     return op.emitOpError("the block inside reduce should not be empty");
1987   if (block.getNumArguments() != 2 ||
1988       llvm::any_of(block.getArguments(), [&](const BlockArgument &arg) {
1989         return arg.getType() != type;
1990       }))
1991     return op.emitOpError()
1992            << "expects two arguments to reduce block of type " << type;
1993 
1994   // Check that the block is terminated by a ReduceReturnOp.
1995   if (!isa<ReduceReturnOp>(block.getTerminator()))
1996     return op.emitOpError("the block inside reduce should be terminated with a "
1997                           "'scf.reduce.return' op");
1998 
1999   return success();
2000 }
2001 
parseReduceOp(OpAsmParser & parser,OperationState & result)2002 static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) {
2003   // Parse an opening `(` followed by the reduced value followed by `)`
2004   OpAsmParser::OperandType operand;
2005   if (parser.parseLParen() || parser.parseOperand(operand) ||
2006       parser.parseRParen())
2007     return failure();
2008 
2009   Type resultType;
2010   // Parse the type of the operand (and also what reduce computes on).
2011   if (parser.parseColonType(resultType) ||
2012       parser.resolveOperand(operand, resultType, result.operands))
2013     return failure();
2014 
2015   // Now parse the body.
2016   Region *body = result.addRegion();
2017   if (parser.parseRegion(*body, /*arguments=*/{}, /*argTypes=*/{}))
2018     return failure();
2019 
2020   return success();
2021 }
2022 
// Prints `scf.reduce(%operand)  : type` followed by the body region.
// NOTE(review): the trailing space after `) ` plus the leading space in
// `" : "` yields a double blank before the colon. It round-trips (the parser
// skips whitespace), but confirm no FileCheck test pins the exact spelling
// before cleaning it up.
static void print(OpAsmPrinter &p, ReduceOp op) {
  p << op.getOperationName() << "(" << op.operand() << ") ";
  p << " : " << op.operand().getType();
  p.printRegion(op.reductionOperator());
}
2028 
2029 //===----------------------------------------------------------------------===//
2030 // ReduceReturnOp
2031 //===----------------------------------------------------------------------===//
2032 
verify(ReduceReturnOp op)2033 static LogicalResult verify(ReduceReturnOp op) {
2034   // The type of the return value should be the same type as the type of the
2035   // operand of the enclosing ReduceOp.
2036   auto reduceOp = cast<ReduceOp>(op->getParentOp());
2037   Type reduceType = reduceOp.operand().getType();
2038   if (reduceType != op.result().getType())
2039     return op.emitOpError() << "needs to have type " << reduceType
2040                             << " (the type of the enclosing ReduceOp)";
2041   return success();
2042 }
2043 
2044 //===----------------------------------------------------------------------===//
2045 // WhileOp
2046 //===----------------------------------------------------------------------===//
2047 
getSuccessorEntryOperands(unsigned index)2048 OperandRange WhileOp::getSuccessorEntryOperands(unsigned index) {
2049   assert(index == 0 &&
2050          "WhileOp is expected to branch only to the first region");
2051 
2052   return inits();
2053 }
2054 
getConditionOp()2055 ConditionOp WhileOp::getConditionOp() {
2056   return cast<ConditionOp>(before().front().getTerminator());
2057 }
2058 
getAfterArguments()2059 Block::BlockArgListType WhileOp::getAfterArguments() {
2060   return after().front().getArguments();
2061 }
2062 
getSuccessorRegions(Optional<unsigned> index,ArrayRef<Attribute> operands,SmallVectorImpl<RegionSuccessor> & regions)2063 void WhileOp::getSuccessorRegions(Optional<unsigned> index,
2064                                   ArrayRef<Attribute> operands,
2065                                   SmallVectorImpl<RegionSuccessor> &regions) {
2066   (void)operands;
2067 
2068   if (!index.hasValue()) {
2069     regions.emplace_back(&before(), before().getArguments());
2070     return;
2071   }
2072 
2073   assert(*index < 2 && "there are only two regions in a WhileOp");
2074   if (*index == 0) {
2075     regions.emplace_back(&after(), after().getArguments());
2076     regions.emplace_back(getResults());
2077     return;
2078   }
2079 
2080   regions.emplace_back(&before(), before().getArguments());
2081 }
2082 
2083 /// Parses a `while` op.
2084 ///
2085 /// op ::= `scf.while` assignments `:` function-type region `do` region
2086 ///         `attributes` attribute-dict
2087 /// initializer ::= /* empty */ | `(` assignment-list `)`
2088 /// assignment-list ::= assignment | assignment `,` assignment-list
2089 /// assignment ::= ssa-value `=` ssa-value
static ParseResult parseWhileOp(OpAsmParser &parser, OperationState &result) {
  SmallVector<OpAsmParser::OperandType, 4> regionArgs, operands;
  Region *before = result.addRegion();
  Region *after = result.addRegion();

  // Parse the optional `(%arg = %init, ...)` list binding the before-region
  // arguments to the init operands.
  OptionalParseResult listResult =
      parser.parseOptionalAssignmentList(regionArgs, operands);
  if (listResult.hasValue() && failed(listResult.getValue()))
    return failure();

  // The functional type maps init operand types to op result types.
  FunctionType functionType;
  llvm::SMLoc typeLoc = parser.getCurrentLocation();
  if (failed(parser.parseColonType(functionType)))
    return failure();

  result.addTypes(functionType.getResults());

  if (functionType.getNumInputs() != operands.size()) {
    return parser.emitError(typeLoc)
           << "expected as many input types as operands "
           << "(expected " << operands.size() << " got "
           << functionType.getNumInputs() << ")";
  }

  // Resolve input operands.
  if (failed(parser.resolveOperands(operands, functionType.getInputs(),
                                    parser.getCurrentLocation(),
                                    result.operands)))
    return failure();

  // Parse the "before" region (with the bound arguments), the `do` keyword,
  // the "after" region, and the optional trailing attribute dictionary.
  return failure(
      parser.parseRegion(*before, regionArgs, functionType.getInputs()) ||
      parser.parseKeyword("do") || parser.parseRegion(*after) ||
      parser.parseOptionalAttrDictWithKeyword(result.attributes));
}
2125 
/// Prints a `while` op. Mirrors parseWhileOp: the "(arg = init, ...)"
/// assignment list, a `:` function type, the 'before' region without its
/// (already printed) entry block arguments, `do`, the 'after' region, and an
/// optional keyword attribute dictionary.
static void print(OpAsmPrinter &p, scf::WhileOp op) {
  p << op.getOperationName();
  // Pair the 'before' entry block arguments with the init operands.
  printInitializationList(p, op.before().front().getArguments(), op.inits(),
                          " ");
  p << " : ";
  p.printFunctionalType(op.inits().getTypes(), op.results().getTypes());
  // Entry block arguments were already shown by the initialization list.
  p.printRegion(op.before(), /*printEntryBlockArgs=*/false);
  p << " do";
  p.printRegion(op.after());
  p.printOptionalAttrDictWithKeyword(op->getAttrs());
}
2138 
2139 /// Verifies that two ranges of types match, i.e. have the same number of
2140 /// entries and that types are pairwise equals. Reports errors on the given
2141 /// operation in case of mismatch.
2142 template <typename OpTy>
verifyTypeRangesMatch(OpTy op,TypeRange left,TypeRange right,StringRef message)2143 static LogicalResult verifyTypeRangesMatch(OpTy op, TypeRange left,
2144                                            TypeRange right, StringRef message) {
2145   if (left.size() != right.size())
2146     return op.emitOpError("expects the same number of ") << message;
2147 
2148   for (unsigned i = 0, e = left.size(); i < e; ++i) {
2149     if (left[i] != right[i]) {
2150       InFlightDiagnostic diag = op.emitOpError("expects the same types for ")
2151                                 << message;
2152       diag.attachNote() << "for argument " << i << ", found " << left[i]
2153                         << " and " << right[i];
2154       return diag;
2155     }
2156   }
2157 
2158   return success();
2159 }
2160 
2161 /// Verifies that the first block of the given `region` is terminated by a
2162 /// YieldOp. Reports errors on the given operation if it is not the case.
2163 template <typename TerminatorTy>
verifyAndGetTerminator(scf::WhileOp op,Region & region,StringRef errorMessage)2164 static TerminatorTy verifyAndGetTerminator(scf::WhileOp op, Region &region,
2165                                            StringRef errorMessage) {
2166   Operation *terminatorOperation = region.front().getTerminator();
2167   if (auto yield = dyn_cast_or_null<TerminatorTy>(terminatorOperation))
2168     return yield;
2169 
2170   auto diag = op.emitOpError(errorMessage);
2171   if (terminatorOperation)
2172     diag.attachNote(terminatorOperation->getLoc()) << "terminator here";
2173   return nullptr;
2174 }
2175 
verify(scf::WhileOp op)2176 static LogicalResult verify(scf::WhileOp op) {
2177   if (failed(RegionBranchOpInterface::verifyTypes(op)))
2178     return failure();
2179 
2180   auto beforeTerminator = verifyAndGetTerminator<scf::ConditionOp>(
2181       op, op.before(),
2182       "expects the 'before' region to terminate with 'scf.condition'");
2183   if (!beforeTerminator)
2184     return failure();
2185 
2186   auto afterTerminator = verifyAndGetTerminator<scf::YieldOp>(
2187       op, op.after(),
2188       "expects the 'after' region to terminate with 'scf.yield'");
2189   return success(afterTerminator != nullptr);
2190 }
2191 
2192 namespace {
2193 /// Replace uses of the condition within the do block with true, since otherwise
2194 /// the block would not be evaluated.
2195 ///
2196 /// scf.while (..) : (i1, ...) -> ... {
2197 ///  %condition = call @evaluate_condition() : () -> i1
2198 ///  scf.condition(%condition) %condition : i1, ...
2199 /// } do {
2200 /// ^bb0(%arg0: i1, ...):
2201 ///    use(%arg0)
2202 ///    ...
2203 ///
2204 /// becomes
2205 /// scf.while (..) : (i1, ...) -> ... {
2206 ///  %condition = call @evaluate_condition() : () -> i1
2207 ///  scf.condition(%condition) %condition : i1, ...
2208 /// } do {
2209 /// ^bb0(%arg0: i1, ...):
2210 ///    use(%true)
2211 ///    ...
2212 struct WhileConditionTruth : public OpRewritePattern<WhileOp> {
2213   using OpRewritePattern<WhileOp>::OpRewritePattern;
2214 
matchAndRewrite__anon56feca5b1811::WhileConditionTruth2215   LogicalResult matchAndRewrite(WhileOp op,
2216                                 PatternRewriter &rewriter) const override {
2217     auto term = op.getConditionOp();
2218 
2219     // These variables serve to prevent creating duplicate constants
2220     // and hold constant true or false values.
2221     Value constantTrue = nullptr;
2222 
2223     bool replaced = false;
2224     for (auto yieldedAndBlockArgs :
2225          llvm::zip(term.args(), op.getAfterArguments())) {
2226       if (std::get<0>(yieldedAndBlockArgs) == term.condition()) {
2227         if (!std::get<1>(yieldedAndBlockArgs).use_empty()) {
2228           if (!constantTrue)
2229             constantTrue = rewriter.create<mlir::ConstantOp>(
2230                 op.getLoc(), term.condition().getType(),
2231                 rewriter.getBoolAttr(true));
2232 
2233           std::get<1>(yieldedAndBlockArgs).replaceAllUsesWith(constantTrue);
2234           replaced = true;
2235         }
2236       }
2237     }
2238     return success(replaced);
2239   }
2240 };
2241 } // namespace
2242 
getCanonicalizationPatterns(OwningRewritePatternList & results,MLIRContext * context)2243 void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
2244                                           MLIRContext *context) {
2245   results.insert<WhileConditionTruth>(context);
2246 }
2247 
2248 //===----------------------------------------------------------------------===//
2249 // TableGen'd op method definitions
2250 //===----------------------------------------------------------------------===//
2251 
2252 #define GET_OP_CLASSES
2253 #include "mlir/Dialect/SCF/SCFOps.cpp.inc"
2254