1 //===- SCFToStandard.cpp - ControlFlow to CFG conversion ------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This file implements a pass to convert scf.for, scf.if, scf.parallel,
// scf.while and scf.execute_region ops into standard CFG ops.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "mlir/Conversion/SCFToStandard/SCFToStandard.h"
15 #include "../PassDetail.h"
16 #include "mlir/Dialect/SCF/SCF.h"
17 #include "mlir/Dialect/StandardOps/IR/Ops.h"
18 #include "mlir/IR/BlockAndValueMapping.h"
19 #include "mlir/IR/Builders.h"
20 #include "mlir/IR/BuiltinOps.h"
21 #include "mlir/IR/MLIRContext.h"
22 #include "mlir/IR/PatternMatch.h"
23 #include "mlir/Transforms/DialectConversion.h"
24 #include "mlir/Transforms/Passes.h"
25 #include "mlir/Transforms/Utils.h"
26
27 using namespace mlir;
28 using namespace mlir::scf;
29
30 namespace {
31
32 struct SCFToStandardPass : public SCFToStandardBase<SCFToStandardPass> {
33 void runOnOperation() override;
34 };
35
// Create a CFG subgraph for the loop around its body blocks (if the body
// contained other loops, they have been already lowered to a flow of blocks).
// Maintain the invariants that a CFG subgraph created for any loop has a single
// entry and a single exit, and that the entry/exit blocks are respectively
// first/last blocks in the parent region. The original loop operation is
// replaced by the initialization operations that set up the initial value of
// the loop induction variable (%iv) and compute the loop bounds that are loop-
// invariant for affine loops. The operations following the original scf.for
// are split out into a separate continuation (exit) block. A condition block is
// created before the continuation block. It checks the exit condition of the
// loop and branches either to the continuation block, or to the first block of
// the body. The condition block takes as arguments the values of the induction
// variable followed by loop-carried values. Since it dominates both the body
// blocks and the continuation block, loop-carried values are visible in all of
// those blocks. Induction variable modification is appended to the last block
// of the body (which is the exit block from the body subgraph thanks to the
// invariant we maintain) along with a branch that loops back to the condition
// block. Loop-carried values are the loop terminator operands, which are
// forwarded to the branch.
//
//      +---------------------------------+
//      |   <code before the ForOp>       |
//      |   <definitions of %init...>     |
//      |   <compute initial %iv value>   |
//      |   br cond(%iv, %init...)        |
//      +---------------------------------+
//             |
//  -------|   |
//  |      v   v
//  |   +--------------------------------+
//  |   | cond(%iv, %init...):           |
//  |   |   <compare %iv to upper bound> |
//  |   |   cond_br %r, body, end        |
//  |   +--------------------------------+
//  |          |               |
//  |          |               -------------|
//  |          v                            |
//  |   +--------------------------------+  |
//  |   | body-first:                    |  |
//  |   |   <%init visible by dominance> |  |
//  |   |   <body contents>              |  |
//  |   +--------------------------------+  |
//  |                   |                   |
//  |                  ...                  |
//  |                   |                   |
//  |   +--------------------------------+  |
//  |   | body-last:                     |  |
//  |   |   <body contents>              |  |
//  |   |   <operands of yield = %yields>|  |
//  |   |   %new_iv =<add step to %iv>   |  |
//  |   |   br cond(%new_iv, %yields)    |  |
//  |   +--------------------------------+  |
//  |          |                            |
//  |-----------        |--------------------
//                      v
//      +--------------------------------+
//      | end:                           |
//      |   <code after the ForOp>       |
//      |   <%init visible by dominance> |
//      +--------------------------------+
//
97 struct ForLowering : public OpRewritePattern<ForOp> {
98 using OpRewritePattern<ForOp>::OpRewritePattern;
99
100 LogicalResult matchAndRewrite(ForOp forOp,
101 PatternRewriter &rewriter) const override;
102 };
103
// Create a CFG subgraph for the scf.if operation (including its "then" and
// optional "else" operation blocks). We maintain the invariants that the
// subgraph has a single entry and a single exit point, and that the entry/exit
// blocks are respectively the first/last block of the enclosing region. The
// operations following the scf.if are split into a continuation (subgraph
// exit) block. The condition is lowered to a chain of blocks that implement the
// short-circuit scheme. The "scf.if" operation is replaced with a conditional
// branch to either the first block of the "then" region, or to the first block
// of the "else" region. In these blocks, "scf.yield" is replaced with an
// unconditional branch to the post-dominating block. When the "scf.if" does
// not return values, the post-dominating block is the same as the continuation
// block. When it returns values, the post-dominating block is a new block with
// arguments that correspond to the values returned by the "scf.if" that
// unconditionally branches to the continuation block. This allows block
// arguments to dominate any uses of the hitherto "scf.if" results that they
// replaced. (Inserting a new block allows us to avoid modifying the argument
// list of an existing block, which is illegal in a conversion pattern). When
// the "else" region is empty, which is only allowed for "scf.if"s that don't
// return values, the condition branches directly to the continuation block.
//
// CFG for a scf.if with else and without results.
//
//      +--------------------------------+
//      | <code before the IfOp>         |
//      | cond_br %cond, %then, %else    |
//      +--------------------------------+
//             |              |
//             |              --------------|
//             v                            |
//      +--------------------------------+  |
//      | then:                          |  |
//      |   <then contents>              |  |
//      |   br continue                  |  |
//      +--------------------------------+  |
//             |                            |
//   |----------               |-------------
//   |                         V
//   |  +--------------------------------+
//   |  | else:                          |
//   |  |   <else contents>              |
//   |  |   br continue                  |
//   |  +--------------------------------+
//   |         |
//   ------|   |
//         v   v
//      +--------------------------------+
//      | continue:                      |
//      |   <code after the IfOp>        |
//      +--------------------------------+
//
// CFG for a scf.if with results.
//
//      +--------------------------------+
//      | <code before the IfOp>         |
//      | cond_br %cond, %then, %else    |
//      +--------------------------------+
//             |              |
//             |              --------------|
//             v                            |
//      +--------------------------------+  |
//      | then:                          |  |
//      |   <then contents>              |  |
//      |   br dom(%args...)             |  |
//      +--------------------------------+  |
//             |                            |
//   |----------               |-------------
//   |                         V
//   |  +--------------------------------+
//   |  | else:                          |
//   |  |   <else contents>              |
//   |  |   br dom(%args...)             |
//   |  +--------------------------------+
//   |         |
//   ------|   |
//         v   v
//      +--------------------------------+
//      | dom(%args...):                 |
//      |   br continue                  |
//      +--------------------------------+
//             |
//             v
//      +--------------------------------+
//      | continue:                      |
//      |   <code after the IfOp>        |
//      +--------------------------------+
//
190 struct IfLowering : public OpRewritePattern<IfOp> {
191 using OpRewritePattern<IfOp>::OpRewritePattern;
192
193 LogicalResult matchAndRewrite(IfOp ifOp,
194 PatternRewriter &rewriter) const override;
195 };
196
197 struct ExecuteRegionLowering : public OpRewritePattern<ExecuteRegionOp> {
198 using OpRewritePattern<ExecuteRegionOp>::OpRewritePattern;
199
200 LogicalResult matchAndRewrite(ExecuteRegionOp op,
201 PatternRewriter &rewriter) const override;
202 };
203
204 struct ParallelLowering : public OpRewritePattern<mlir::scf::ParallelOp> {
205 using OpRewritePattern<mlir::scf::ParallelOp>::OpRewritePattern;
206
207 LogicalResult matchAndRewrite(mlir::scf::ParallelOp parallelOp,
is_tranx_end_pos(const char * log_file_name,my_off_t log_file_pos)208 PatternRewriter &rewriter) const override;
209 };
210
/// Create a CFG subgraph for this loop construct. The regions of the loop need
/// not be a single block anymore (for example, if other SCF constructs that
/// they contain have been already converted to CFG), but need to be single-exit
/// from the last block of each region. The operations following the original
/// WhileOp are split into a new continuation block. Both regions of the WhileOp
/// are inlined, and their terminators are rewritten to organize the control
/// flow implementing the loop as follows.
///
///      +---------------------------------+
///      |   <code before the WhileOp>     |
///      |   br ^before(%operands...)      |
///      +---------------------------------+
///             |
///  -------|   |
///  |      v   v
///  |   +--------------------------------+
///  |   | ^before(%bargs...):            |
///  |   |   %vals... = <some payload>    |
///  |   +--------------------------------+
///  |                   |
///  |                  ...
///  |                   |
///  |   +--------------------------------+
///  |   | ^before-last:                  |
///  |   |   %cond = <compute condition>  |
///  |   |   cond_br %cond,               |
///  |   |        ^after(%vals...), ^cont |
///  |   +--------------------------------+
///  |          |               |
///  |          |               -------------|
///  |          v                            |
///  |   +--------------------------------+  |
///  |   | ^after(%aargs...):             |  |
///  |   |   <body contents>              |  |
///  |   +--------------------------------+  |
///  |                   |                   |
///  |                  ...                  |
///  |                   |                   |
///  |   +--------------------------------+  |
///  |   | ^after-last:                   |  |
///  |   |   %yields... = <some payload>  |  |
///  |   |   br ^before(%yields...)       |  |
///  |   +--------------------------------+  |
///  |          |                            |
///  |-----------        |--------------------
///                      v
///      +--------------------------------+
///      | ^cont:                         |
///      |   <code after the WhileOp>     |
///      |   <%vals from 'before' region  |
///      |          visible by dominance> |
///      +--------------------------------+
///
/// Values are communicated between ex-regions (the groups of blocks that used
/// to form a region before inlining) through block arguments of their
/// entry blocks, which are visible in all other dominated blocks. Similarly,
/// the results of the WhileOp are defined in the 'before' region, which is
/// required to have a single exiting block, and are therefore accessible in
/// the continuation block due to dominance.
270 struct WhileLowering : public OpRewritePattern<WhileOp> {
271 using OpRewritePattern<WhileOp>::OpRewritePattern;
272
273 LogicalResult matchAndRewrite(WhileOp whileOp,
274 PatternRewriter &rewriter) const override;
275 };
276
/// Optimized version of the above for the case of the "after" region merely
/// forwarding its arguments back to the "before" region (i.e., a "do-while"
/// loop). This avoids inlining the "after" region completely and branches back
/// to the "before" entry instead.
281 struct DoWhileLowering : public OpRewritePattern<WhileOp> {
282 using OpRewritePattern<WhileOp>::OpRewritePattern;
283
284 LogicalResult matchAndRewrite(WhileOp whileOp,
285 PatternRewriter &rewriter) const override;
286 };
287 } // namespace
288
289 LogicalResult ForLowering::matchAndRewrite(ForOp forOp,
290 PatternRewriter &rewriter) const {
291 Location loc = forOp.getLoc();
292
293 // Start by splitting the block containing the 'scf.for' into two parts.
294 // The part before will get the init code, the part after will be the end
295 // point.
296 auto *initBlock = rewriter.getInsertionBlock();
297 auto initPosition = rewriter.getInsertionPoint();
298 auto *endBlock = rewriter.splitBlock(initBlock, initPosition);
299
300 // Use the first block of the loop body as the condition block since it is the
301 // block that has the induction variable and loop-carried values as arguments.
302 // Split out all operations from the first block into a new block. Move all
303 // body blocks from the loop body region to the region containing the loop.
304 auto *conditionBlock = &forOp.region().front();
305 auto *firstBodyBlock =
306 rewriter.splitBlock(conditionBlock, conditionBlock->begin());
307 auto *lastBodyBlock = &forOp.region().back();
308 rewriter.inlineRegionBefore(forOp.region(), endBlock);
309 auto iv = conditionBlock->getArgument(0);
310
311 // Append the induction variable stepping logic to the last body block and
312 // branch back to the condition block. Loop-carried values are taken from
313 // operands of the loop terminator.
314 Operation *terminator = lastBodyBlock->getTerminator();
315 rewriter.setInsertionPointToEnd(lastBodyBlock);
316 auto step = forOp.step();
317 auto stepped = rewriter.create<AddIOp>(loc, iv, step).getResult();
318 if (!stepped)
319 return failure();
320
321 SmallVector<Value, 8> loopCarried;
322 loopCarried.push_back(stepped);
323 loopCarried.append(terminator->operand_begin(), terminator->operand_end());
324 rewriter.create<BranchOp>(loc, conditionBlock, loopCarried);
325 rewriter.eraseOp(terminator);
326
327 // Compute loop bounds before branching to the condition.
328 rewriter.setInsertionPointToEnd(initBlock);
329 Value lowerBound = forOp.lowerBound();
330 Value upperBound = forOp.upperBound();
331 if (!lowerBound || !upperBound)
332 return failure();
333
334 // The initial values of loop-carried values is obtained from the operands
335 // of the loop operation.
336 SmallVector<Value, 8> destOperands;
Repl_semi_sync_master()337 destOperands.push_back(lowerBound);
338 auto iterOperands = forOp.getIterOperands();
339 destOperands.append(iterOperands.begin(), iterOperands.end());
340 rewriter.create<BranchOp>(loc, conditionBlock, destOperands);
341
342 // With the body block done, we can fill in the condition block.
343 rewriter.setInsertionPointToEnd(conditionBlock);
344 auto comparison =
345 rewriter.create<CmpIOp>(loc, CmpIPredicate::slt, iv, upperBound);
346
347 rewriter.create<CondBranchOp>(loc, comparison, firstBodyBlock,
348 ArrayRef<Value>(), endBlock, ArrayRef<Value>());
349 // The result of the loop operation is the values of the condition block
350 // arguments except the induction variable on the last iteration.
351 rewriter.replaceOp(forOp, conditionBlock->getArguments().drop_front());
352 return success();
init_object()353 }
354
355 LogicalResult IfLowering::matchAndRewrite(IfOp ifOp,
356 PatternRewriter &rewriter) const {
357 auto loc = ifOp.getLoc();
358
359 // Start by splitting the block containing the 'scf.if' into two parts.
360 // The part before will contain the condition, the part after will be the
361 // continuation point.
362 auto *condBlock = rewriter.getInsertionBlock();
363 auto opPosition = rewriter.getInsertionPoint();
364 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
365 Block *continueBlock;
366 if (ifOp.getNumResults() == 0) {
367 continueBlock = remainingOpsBlock;
368 } else {
369 continueBlock =
370 rewriter.createBlock(remainingOpsBlock, ifOp.getResultTypes());
371 rewriter.create<BranchOp>(loc, remainingOpsBlock);
372 }
373
374 // Move blocks from the "then" region to the region containing 'scf.if',
375 // place it before the continuation block, and branch to it.
376 auto &thenRegion = ifOp.thenRegion();
377 auto *thenBlock = &thenRegion.front();
378 Operation *thenTerminator = thenRegion.back().getTerminator();
379 ValueRange thenTerminatorOperands = thenTerminator->getOperands();
380 rewriter.setInsertionPointToEnd(&thenRegion.back());
381 rewriter.create<BranchOp>(loc, continueBlock, thenTerminatorOperands);
382 rewriter.eraseOp(thenTerminator);
383 rewriter.inlineRegionBefore(thenRegion, continueBlock);
384
385 // Move blocks from the "else" region (if present) to the region containing
386 // 'scf.if', place it before the continuation block and branch to it. It
387 // will be placed after the "then" regions.
388 auto *elseBlock = continueBlock;
389 auto &elseRegion = ifOp.elseRegion();
390 if (!elseRegion.empty()) {
391 elseBlock = &elseRegion.front();
392 Operation *elseTerminator = elseRegion.back().getTerminator();
393 ValueRange elseTerminatorOperands = elseTerminator->getOperands();
394 rewriter.setInsertionPointToEnd(&elseRegion.back());
395 rewriter.create<BranchOp>(loc, continueBlock, elseTerminatorOperands);
396 rewriter.eraseOp(elseTerminator);
397 rewriter.inlineRegionBefore(elseRegion, continueBlock);
398 }
399
400 rewriter.setInsertionPointToEnd(condBlock);
401 rewriter.create<CondBranchOp>(loc, ifOp.condition(), thenBlock,
402 /*trueArgs=*/ArrayRef<Value>(), elseBlock,
403 /*falseArgs=*/ArrayRef<Value>());
404
405 // Ok, we're done!
406 rewriter.replaceOp(ifOp, continueBlock->getArguments());
407 return success();
408 }
409
410 LogicalResult
411 ExecuteRegionLowering::matchAndRewrite(ExecuteRegionOp op,
412 PatternRewriter &rewriter) const {
413 auto loc = op.getLoc();
414
415 auto *condBlock = rewriter.getInsertionBlock();
416 auto opPosition = rewriter.getInsertionPoint();
417 auto *remainingOpsBlock = rewriter.splitBlock(condBlock, opPosition);
418
419 auto ®ion = op.region();
420 rewriter.setInsertionPointToEnd(condBlock);
421 rewriter.create<BranchOp>(loc, ®ion.front());
422
423 for (Block &block : region) {
424 if (auto terminator = dyn_cast<scf::YieldOp>(block.getTerminator())) {
425 ValueRange terminatorOperands = terminator->getOperands();
disable_master()426 rewriter.setInsertionPointToEnd(&block);
427 rewriter.create<BranchOp>(loc, remainingOpsBlock, terminatorOperands);
428 rewriter.eraseOp(terminator);
429 }
430 }
431
432 rewriter.inlineRegionBefore(region, remainingOpsBlock);
433
434 SmallVector<Value> vals;
435 for (auto arg : remainingOpsBlock->addArguments(op->getResultTypes())) {
436 vals.push_back(arg);
437 }
438 rewriter.replaceOp(op, vals);
439 return success();
440 }
441
442 LogicalResult
443 ParallelLowering::matchAndRewrite(ParallelOp parallelOp,
444 PatternRewriter &rewriter) const {
445 Location loc = parallelOp.getLoc();
446
447 // For a parallel loop, we essentially need to create an n-dimensional loop
448 // nest. We do this by translating to scf.for ops and have those lowered in
449 // a further rewrite. If a parallel loop contains reductions (and thus returns
450 // values), forward the initial values for the reductions down the loop
451 // hierarchy and bubble up the results by modifying the "yield" terminator.
452 SmallVector<Value, 4> iterArgs = llvm::to_vector<4>(parallelOp.initVals());
cleanup()453 SmallVector<Value, 4> ivs;
454 ivs.reserve(parallelOp.getNumLoops());
455 bool first = true;
456 SmallVector<Value, 4> loopResults(iterArgs);
457 for (auto loop_operands :
458 llvm::zip(parallelOp.getInductionVars(), parallelOp.lowerBound(),
459 parallelOp.upperBound(), parallelOp.step())) {
460 Value iv, lower, upper, step;
461 std::tie(iv, lower, upper, step) = loop_operands;
462 ForOp forOp = rewriter.create<ForOp>(loc, lower, upper, step, iterArgs);
463 ivs.push_back(forOp.getInductionVar());
464 auto iterRange = forOp.getRegionIterArgs();
465 iterArgs.assign(iterRange.begin(), iterRange.end());
lock()466
467 if (first) {
468 // Store the results of the outermost loop that will be used to replace
469 // the results of the parallel loop when it is fully rewritten.
470 loopResults.assign(forOp.result_begin(), forOp.result_end());
471 first = false;
472 } else if (!forOp.getResults().empty()) {
473 // A loop is constructed with an empty "yield" terminator if there are
474 // no results.
475 rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
cond_broadcast()476 rewriter.create<scf::YieldOp>(loc, forOp.getResults());
477 }
478
479 rewriter.setInsertionPointToStart(forOp.getBody());
480 }
cond_timewait(struct timespec * wait_time)481
482 // First, merge reduction blocks into the main region.
483 SmallVector<Value, 4> yieldOperands;
484 yieldOperands.reserve(parallelOp.getNumResults());
485 for (auto &op : *parallelOp.getBody()) {
486 auto reduce = dyn_cast<ReduceOp>(op);
487 if (!reduce)
488 continue;
489
490 Block &reduceBlock = reduce.reductionOperator().front();
491 Value arg = iterArgs[yieldOperands.size()];
492 yieldOperands.push_back(reduceBlock.getTerminator()->getOperand(0));
add_slave()493 rewriter.eraseOp(reduceBlock.getTerminator());
494 rewriter.mergeBlockBefore(&reduceBlock, &op, {arg, reduce.operand()});
495 rewriter.eraseOp(reduce);
496 }
497
498 // Then merge the loop body without the terminator.
499 rewriter.eraseOp(parallelOp.getBody()->getTerminator());
remove_slave()500 Block *newBody = rewriter.getInsertionBlock();
501 if (newBody->empty())
502 rewriter.mergeBlocks(parallelOp.getBody(), newBody, ivs);
503 else
504 rewriter.mergeBlockBefore(parallelOp.getBody(), newBody->getTerminator(),
505 ivs);
506
507 // Finally, create the terminator if required (for loops with no results, it
508 // has been already created in loop construction).
509 if (!yieldOperands.empty()) {
510 rewriter.setInsertionPointToEnd(rewriter.getInsertionBlock());
511 rewriter.create<scf::YieldOp>(loc, yieldOperands);
512 }
513
514 rewriter.replaceOp(parallelOp, loopResults);
515
516 return success();
517 }
518
report_reply_packet(uint32 server_id,const uchar * packet,ulong packet_len)519 LogicalResult WhileLowering::matchAndRewrite(WhileOp whileOp,
520 PatternRewriter &rewriter) const {
521 OpBuilder::InsertionGuard guard(rewriter);
522 Location loc = whileOp.getLoc();
523
524 // Split the current block before the WhileOp to create the inlining point.
525 Block *currentBlock = rewriter.getInsertionBlock();
526 Block *continuation =
527 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
528
529 // Inline both regions.
530 Block *after = &whileOp.after().front();
531 Block *afterLast = &whileOp.after().back();
532 Block *before = &whileOp.before().front();
533 Block *beforeLast = &whileOp.before().back();
534 rewriter.inlineRegionBefore(whileOp.after(), continuation);
535 rewriter.inlineRegionBefore(whileOp.before(), after);
536
537 // Branch to the "before" region.
538 rewriter.setInsertionPointToEnd(currentBlock);
539 rewriter.create<BranchOp>(loc, before, whileOp.inits());
540
541 // Replace terminators with branches. Assuming bodies are SESE, which holds
542 // given only the patterns from this file, we only need to look at the last
543 // block. This should be reconsidered if we allow break/continue in SCF.
544 rewriter.setInsertionPointToEnd(beforeLast);
545 auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
546 rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.condition(), after,
547 condOp.args(), continuation,
548 ValueRange());
549
550 rewriter.setInsertionPointToEnd(afterLast);
551 auto yieldOp = cast<scf::YieldOp>(afterLast->getTerminator());
552 rewriter.replaceOpWithNewOp<BranchOp>(yieldOp, before, yieldOp.results());
553
554 // Replace the op with values "yielded" from the "before" region, which are
555 // visible by dominance.
556 rewriter.replaceOp(whileOp, condOp.args());
557
558 return success();
559 }
560
561 LogicalResult
562 DoWhileLowering::matchAndRewrite(WhileOp whileOp,
563 PatternRewriter &rewriter) const {
564 if (!llvm::hasSingleElement(whileOp.after()))
565 return rewriter.notifyMatchFailure(whileOp,
566 "do-while simplification applicable to "
report_reply_binlog(uint32 server_id,const char * log_file_name,my_off_t log_file_pos)567 "single-block 'after' region only");
568
569 Block &afterBlock = whileOp.after().front();
570 if (!llvm::hasSingleElement(afterBlock))
571 return rewriter.notifyMatchFailure(whileOp,
572 "do-while simplification applicable "
573 "only if 'after' region has no payload");
574
575 auto yield = dyn_cast<scf::YieldOp>(&afterBlock.front());
576 if (!yield || yield.results() != afterBlock.getArguments())
577 return rewriter.notifyMatchFailure(whileOp,
578 "do-while simplification applicable "
579 "only to forwarding 'after' regions");
580
581 // Split the current block before the WhileOp to create the inlining point.
582 OpBuilder::InsertionGuard guard(rewriter);
583 Block *currentBlock = rewriter.getInsertionBlock();
584 Block *continuation =
585 rewriter.splitBlock(currentBlock, rewriter.getInsertionPoint());
586
587 // Only the "before" region should be inlined.
588 Block *before = &whileOp.before().front();
589 Block *beforeLast = &whileOp.before().back();
590 rewriter.inlineRegionBefore(whileOp.before(), continuation);
591
592 // Branch to the "before" region.
593 rewriter.setInsertionPointToEnd(currentBlock);
594 rewriter.create<BranchOp>(whileOp.getLoc(), before, whileOp.inits());
595
596 // Loop around the "before" region based on condition.
597 rewriter.setInsertionPointToEnd(beforeLast);
598 auto condOp = cast<ConditionOp>(beforeLast->getTerminator());
599 rewriter.replaceOpWithNewOp<CondBranchOp>(condOp, condOp.condition(), before,
600 condOp.args(), continuation,
601 ValueRange());
602
603 // Replace the op with values "yielded" from the "before" region, which are
604 // visible by dominance.
605 rewriter.replaceOp(whileOp, condOp.args());
606
607 return success();
608 }
609
610 void mlir::populateLoopToStdConversionPatterns(RewritePatternSet &patterns) {
611 patterns.add<ForLowering, IfLowering, ParallelLowering, WhileLowering,
612 ExecuteRegionLowering>(patterns.getContext());
613 patterns.add<DoWhileLowering>(patterns.getContext(), /*benefit=*/2);
614 }
615
616 void SCFToStandardPass::runOnOperation() {
617 RewritePatternSet patterns(&getContext());
618 populateLoopToStdConversionPatterns(patterns);
619 // Configure conversion to lower out scf.for, scf.if, scf.parallel and
620 // scf.while. Anything else is fine.
621 ConversionTarget target(getContext());
622 target.addIllegalOp<scf::ForOp, scf::IfOp, scf::ParallelOp, scf::WhileOp,
623 scf::ExecuteRegionOp>();
624 target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });
625 if (failed(
626 applyPartialConversion(getOperation(), target, std::move(patterns))))
627 signalPassFailure();
628 }
629
630 std::unique_ptr<Pass> mlir::createLowerToCFGPass() {
631 return std::make_unique<SCFToStandardPass>();
632 }
633