1 //===- SCFToStandard.cpp - ControlFlow to CFG conversion ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert scf.for, scf.if and scf.parallel
// ops into standard CFG ops.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/MLIRContext.h"
21 #include "mlir/IR/Module.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "mlir/Transforms/Passes.h"
25 #include "mlir/Transforms/Utils.h"
26 
27 using namespace mlir;
28 using namespace mlir::scf;
29 
30 namespace {
31 
32 struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
33   void runOnOperation() override;
34 };
35 
36 // Create a CFG subgraph for the loop around its body blocks (if the body
37 // contained other loops, they have been already lowered to a flow of blocks).
38 // Maintain the invariants that a CFG subgraph created for any loop has a single
39 // entry and a single exit, and that the entry/exit blocks are respectively
40 // first/last blocks in the parent region.  The original loop operation is
41 // replaced by the initialization operations that set up the initial value of
42 // the loop induction variable (%iv) and computes the loop bounds that are loop-
43 // invariant for affine loops.  The operations following the original scf.for
44 // are split out into a separate continuation (exit) block. A condition block is
45 // created before the continuation block. It checks the exit condition of the
46 // loop and branches either to the continuation block, or to the first block of
47 // the body. The condition block takes as arguments the values of the induction
48 // variable followed by loop-carried values. Since it dominates both the body
49 // blocks and the continuation block, loop-carried values are visible in all of
50 // those blocks. Induction variable modification is appended to the last block
51 // of the body (which is the exit block from the body subgraph thanks to the
52 // invariant we maintain) along with a branch that loops back to the condition
53 // block. Loop-carried values are the loop terminator operands, which are
54 // forwarded to the branch.
55 //
56 //      +---------------------------------+
57 //      |   <code before the ForOp>       |
58 //      |   <definitions of %init...>     |
59 //      |   <compute initial %iv value>   |
60 //      |   br cond(%iv, %init...)        |
61 //      +---------------------------------+
62 //             |
63 //  -------|   |
64 //  |      v   v
65 //  |   +--------------------------------+
66 //  |   | cond(%iv, %init...):           |
67 //  |   |   <compare %iv to upper bound> |
68 //  |   |   cond_br %r, body, end        |
69 //  |   +--------------------------------+
70 //  |          |               |
71 //  |          |               -------------|
72 //  |          v                            |
73 //  |   +--------------------------------+  |
74 //  |   | body-first:                    |  |
75 //  |   |   <%init visible by dominance> |  |
76 //  |   |   <body contents>              |  |
77 //  |   +--------------------------------+  |
78 //  |                   |                   |
79 //  |                  ...                  |
80 //  |                   |                   |
81 //  |   +--------------------------------+  |
82 //  |   | body-last:                     |  |
83 //  |   |   <body contents>              |  |
84 //  |   |   <operands of yield = %yields>|  |
85 //  |   |   %new_iv =<add step to %iv>   |  |
86 //  |   |   br cond(%new_iv, %yields)    |  |
87 //  |   +--------------------------------+  |
88 //  |          |                            |
89 //  |-----------        |--------------------
90 //                      v
91 //      +--------------------------------+
92 //      | end:                           |
93 //      |   <code after the ForOp>       |
94 //      |   <%init visible by dominance> |
95 //      +--------------------------------+
96 //
97 struct ForLowering : public OpRewritePattern<ForOp> {
98   using OpRewritePattern<ForOp>::OpRewritePattern;
99 
100   LogicalResult matchAndRewrite(ForOp forOp,
101                                 PatternRewriter &rewriter) const override;
102 };
103 
104 // Create a CFG subgraph for the scf.if operation (including its "then" and
105 // optional "else" operation blocks).  We maintain the invariants that the
106 // subgraph has a single entry and a single exit point, and that the entry/exit
107 // blocks are respectively the first/last block of the enclosing region. The
108 // operations following the scf.if are split into a continuation (subgraph
109 // exit) block. The condition is lowered to a chain of blocks that implement the
110 // short-circuit scheme. The "scf.if" operation is replaced with a conditional
111 // branch to either the first block of the "then" region, or to the first block
// of the "else" region. In these blocks, each "scf.yield" is replaced with an
// unconditional branch to the post-dominating block. When the "scf.if" does
// not return values, the
114 // post-dominating block is the same as the continuation block. When it returns
115 // values, the post-dominating block is a new block with arguments that
116 // correspond to the values returned by the "scf.if" that unconditionally
117 // branches to the continuation block. This allows block arguments to dominate
118 // any uses of the hitherto "scf.if" results that they replaced. (Inserting a
119 // new block allows us to avoid modifying the argument list of an existing
120 // block, which is illegal in a conversion pattern). When the "else" region is
121 // empty, which is only allowed for "scf.if"s that don't return values, the
122 // condition branches directly to the continuation block.
123 //
124 // CFG for a scf.if with else and without results.
125 //
126 //      +--------------------------------+
127 //      | <code before the IfOp>         |
128 //      | cond_br %cond, %then, %else    |
129 //      +--------------------------------+
130 //             |              |
131 //             |              --------------|
132 //             v                            |
133 //      +--------------------------------+  |
134 //      | then:                          |  |
135 //      |   <then contents>              |  |
136 //      |   br continue                  |  |
137 //      +--------------------------------+  |
138 //             |                            |
139 //   |----------               |-------------
140 //   |                         V
141 //   |  +--------------------------------+
142 //   |  | else:                          |
143 //   |  |   <else contents>              |
144 //   |  |   br continue                  |
145 //   |  +--------------------------------+
146 //   |         |
147 //   ------|   |
148 //         v   v
149 //      +--------------------------------+
150 //      | continue:                      |
151 //      |   <code after the IfOp>        |
152 //      +--------------------------------+
153 //
154 // CFG for a scf.if with results.
155 //
156 //      +--------------------------------+
157 //      | <code before the IfOp>         |
158 //      | cond_br %cond, %then, %else    |
159 //      +--------------------------------+
160 //             |              |
161 //             |              --------------|
162 //             v                            |
163 //      +--------------------------------+  |
164 //      | then:                          |  |
165 //      |   <then contents>              |  |
166 //      |   br dom(%args...)             |  |
167 //      +--------------------------------+  |
168 //             |                            |
169 //   |----------               |-------------
170 //   |                         V
171 //   |  +--------------------------------+
172 //   |  | else:                          |
173 //   |  |   <else contents>              |
174 //   |  |   br dom(%args...)             |
175 //   |  +--------------------------------+
176 //   |         |
177 //   ------|   |
178 //         v   v
179 //      +--------------------------------+
180 //      | dom(%args...):                 |
181 //      |   br continue                  |
182 //      +--------------------------------+
183 //             |
184 //             v
185 //      +--------------------------------+
186 //      | continue:                      |
187 //      | <code after the IfOp>          |
188 //      +--------------------------------+
189 //
190 struct IfLowering : public OpRewritePattern<IfOp> {
191   using OpRewritePattern<IfOp>::OpRewritePattern;
192 
193   LogicalResult matchAndRewrite(IfOp ifOp,
194                                 PatternRewriter &rewriter) const override;
195 };
196 
197 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
198   using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
199 
200   LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
201                                 PatternRewriter &rewriter) const override;
202 };
203 } // namespace
204 
matchAndRewrite(ForOp forOp,PatternRewriter & rewriter) const205 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
206                                            PatternRewriter &rewriter) const {
207   Location loc = forOp.getLoc();
208 
209   // Start by splitting the block containing the 'scf.for' into two parts.
210   // The part before will get the init code, the part after will be the end
211   // point.
212   auto *initBlock = rewriter.getInsertionBlock();
213   auto initPosition = rewriter.getInsertionPoint();
214   auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
215 
216   // Use the first block of the loop body as the condition block since it is the
217   // block that has the induction variable and loop-carried values as arguments.
218   // Split out all operations from the first block into a new block. Move all
219   // body blocks from the loop body region to the region containing the loop.
220   auto *conditionBlock = &forOp.region().front();
221   auto *firstBodyBlock =
222       rewriter.splitBlock(conditionBlock, conditionBlock->begin());
223   auto *lastBodyBlock = &forOp.region().back();
224   rewriter.inlineRegionBefore(forOp.region(), endBlock);
225   auto iv = conditionBlock->getArgument(0);
226 
227   // Append the induction variable stepping logic to the last body block and
228   // branch back to the condition block. Loop-carried values are taken from
229   // operands of the loop terminator.
230   Operation *terminator = lastBodyBlock->getTerminator();
231   rewriter.setInsertionPointToEnd(lastBodyBlock);
232   auto step = forOp.step();
233   auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
234   if (!stepped)
235     return failure();
236 
237   SmallVector<Value, 8> loopCarried;
238   loopCarried.push_back(stepped);
239   loopCarried.append(terminator->operand_begin(), terminator->operand_end());
240   rewriter.create<BranchOp>(loc, conditionBlock, loopCarried);
241   rewriter.eraseOp(terminator);
242 
243   // Compute loop bounds before branching to the condition.
244   rewriter.setInsertionPointToEnd(initBlock);
245   Value lowerBound = forOp.lowerBound();
246   Value upperBound = forOp.upperBound();
247   if (!lowerBound || !upperBound)
248     return failure();
249 
250   // The initial values of loop-carried values is obtained from the operands
251   // of the loop operation.
252   SmallVector<Value, 8> destOperands;
253   destOperands.push_back(lowerBound);
254   auto iterOperands = forOp.getIterOperands();
255   destOperands.append(iterOperands.begin(), iterOperands.end());
256   rewriter.create<BranchOp>(loc, conditionBlock, destOperands);
257 
258   // With the body block done, we can fill in the condition block.
259   rewriter.setInsertionPointToEnd(conditionBlock);
260   auto comparison =
261       rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, iv, upperBound);
262 
263   rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
264                                 ArrayRef<Value>(), endBlock, ArrayRef<Value>());
265   // The result of the loop operation is the values of the condition block
266   // arguments except the induction variable on the last iteration.
267   rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
268   return success();
269 }
270 
matchAndRewrite(IfOp ifOp,PatternRewriter & rewriter) const271 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
272                                           PatternRewriter &rewriter) const {
273   auto loc = ifOp.getLoc();
274 
275   // Start by splitting the block containing the 'scf.if' into two parts.
276   // The part before will contain the condition, the part after will be the
277   // continuation point.
278   auto *condBlock = rewriter.getInsertionBlock();
279   auto opPosition = rewriter.getInsertionPoint();
280   auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
281   Block *continueBlock;
282   if (ifOp.getNumResults() == 0) {
283     continueBlock = remainingOpsBlock;
284   } else {
285     continueBlock =
286         rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
287     rewriter.create<BranchOp>(loc, remainingOpsBlock);
288   }
289 
290   // Move blocks from the "then" region to the region containing 'scf.if',
291   // place it before the continuation block, and branch to it.
292   auto &thenRegion = ifOp.thenRegion();
293   auto *thenBlock = &thenRegion.front();
294   Operation *thenTerminator = thenRegion.back().getTerminator();
295   ValueRange thenTerminatorOperands = thenTerminator->getOperands();
296   rewriter.setInsertionPointToEnd(&thenRegion.back());
297   rewriter.create<BranchOp>(loc, continueBlock, thenTerminatorOperands);
298   rewriter.eraseOp(thenTerminator);
299   rewriter.inlineRegionBefore(thenRegion, continueBlock);
300 
301   // Move blocks from the "else" region (if present) to the region containing
302   // 'scf.if', place it before the continuation block and branch to it.  It
303   // will be placed after the "then" regions.
304   auto *elseBlock = continueBlock;
305   auto &elseRegion = ifOp.elseRegion();
306   if (!elseRegion.empty()) {
307     elseBlock = &elseRegion.front();
308     Operation *elseTerminator = elseRegion.back().getTerminator();
309     ValueRange elseTerminatorOperands = elseTerminator->getOperands();
310     rewriter.setInsertionPointToEnd(&elseRegion.back());
311     rewriter.create<BranchOp>(loc, continueBlock, elseTerminatorOperands);
312     rewriter.eraseOp(elseTerminator);
313     rewriter.inlineRegionBefore(elseRegion, continueBlock);
314   }
315 
316   rewriter.setInsertionPointToEnd(condBlock);
317   rewriter.create<CondBranchOp>(loc, ifOp.condition(), thenBlock,
318                                 /*trueArgs=*/ArrayRef<Value>(), elseBlock,
319                                 /*falseArgs=*/ArrayRef<Value>());
320 
321   // Ok, we're done!
322   rewriter.replaceOp(ifOp, continueBlock->getArguments());
323   return success();
324 }
325 
326 LogicalResult
matchAndRewrite(ParallelOp parallelOp,PatternRewriter & rewriter) const327 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
328                                   PatternRewriter &rewriter) const {
329   Location loc = parallelOp.getLoc();
330   BlockAndValueMapping mapping;
331 
332   // For a parallel loop, we essentially need to create an n-dimensional loop
333   // nest. We do this by translating to scf.for ops and have those lowered in
334   // a further rewrite. If a parallel loop contains reductions (and thus returns
335   // values), forward the initial values for the reductions down the loop
336   // hierarchy and bubble up the results by modifying the "yield" terminator.
337   SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.initVals());
338   bool first = true;
339   SmallVector<Value, 4> loopResults(iterArgs);
340   for (auto loop_operands :
341        llvm::zip(parallelOp.getInductionVars(), parallelOp.lowerBound(),
342                  parallelOp.upperBound(), parallelOp.step())) {
343     Value iv, lower, upper, step;
344     std::tie(iv, lower, upper, step) = loop_operands;
345     ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
346     mapping.map(iv, forOp.getInductionVar());
347     auto iterRange = forOp.getRegionIterArgs();
348     iterArgs.assign(iterRange.begin(), iterRange.end());
349 
350     if (first) {
351       // Store the results of the outermost loop that will be used to replace
352       // the results of the parallel loop when it is fully rewritten.
353       loopResults.assign(forOp.result_begin(), forOp.result_end());
354       first = false;
355     } else if (!forOp.getResults().empty()) {
356       // A loop is constructed with an empty "yield" terminator if there are
357       // no results.
358       rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
359       rewriter.create<YieldOp>(loc, forOp.getResults());
360     }
361 
362     rewriter.setInsertionPointToStart(forOp.getBody());
363   }
364 
365   // Now copy over the contents of the body.
366   SmallVector<Value, 4> yieldOperands;
367   yieldOperands.reserve(parallelOp.getNumResults());
368   for (auto &op : parallelOp.getBody()->without_terminator()) {
369     // Reduction blocks are handled differently.
370     auto reduce = dyn_cast<ReduceOp>(op);
371     if (!reduce) {
372       rewriter.clone(op, mapping);
373       continue;
374     }
375 
376     // Clone the body of the reduction operation into the body of the loop,
377     // using operands of "scf.reduce" and iteration arguments corresponding
378     // to the reduction value to replace arguments of the reduction block.
379     // Collect operands of "scf.reduce.return" to be returned by a final
380     // "scf.yield" instead.
381     Value arg = iterArgs[yieldOperands.size()];
382     Block &reduceBlock = reduce.reductionOperator().front();
383     mapping.map(reduceBlock.getArgument(0), mapping.lookupOrDefault(arg));
384     mapping.map(reduceBlock.getArgument(1),
385                 mapping.lookupOrDefault(reduce.operand()));
386     for (auto &nested : reduceBlock.without_terminator())
387       rewriter.clone(nested, mapping);
388     yieldOperands.push_back(
389         mapping.lookup(reduceBlock.getTerminator()->getOperand(0)));
390   }
391 
392   if (!yieldOperands.empty()) {
393     rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
394     rewriter.create<YieldOp>(loc, yieldOperands);
395   }
396 
397   rewriter.replaceOp(parallelOp, loopResults);
398 
399   return success();
400 }
401 
populateLoopToStdConversionPatterns(OwningRewritePatternList & patterns,MLIRContext * ctx)402 void mlir::populateLoopToStdConversionPatterns(
403     OwningRewritePatternList &patterns, MLIRContext *ctx) {
404   patterns.insert<ForLowering, IfLowering, ParallelLowering>(ctx);
405 }
406 
runOnOperation()407 void SCFToStandardPass::runOnOperation() {
408   OwningRewritePatternList patterns;
409   populateLoopToStdConversionPatterns(patterns, &getContext());
410   // Configure conversion to lower out scf.for, scf.if and scf.parallel.
411   // Anything else is fine.
412   ConversionTarget target(getContext());
413   target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp>();
414   target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
415   if (failed(applyPartialConversion(getOperation(), target, patterns)))
416     signalPassFailure();
417 }
418 
createLowerToCFGPass()419 std::unique_ptr<Pass> mlir::createLowerToCFGPass() {
420   return std::make_unique<SCFToStandardPass>();
421 }
422