1 //===- UnifyFunctionExitNodes.cpp - Make all functions have a single exit -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass is used to ensure that functions have at most one return and one
10 // unreachable instruction in them.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/Transforms/Utils/UnifyFunctionExitNodes.h"
15 #include "llvm/IR/BasicBlock.h"
16 #include "llvm/IR/Function.h"
17 #include "llvm/IR/Instructions.h"
18 #include "llvm/IR/Type.h"
19 #include "llvm/InitializePasses.h"
20 #include "llvm/Transforms/Utils.h"
21 using namespace llvm;
22 
// Identifier of the legacy pass; handed to the FunctionPass base in the
// constructor. NOTE: per the LLVM legacy pass-manager convention, the address
// of this member (not its value) is what identifies the pass.
char UnifyFunctionExitNodesLegacyPass::ID = 0;
24 
25 UnifyFunctionExitNodesLegacyPass::UnifyFunctionExitNodesLegacyPass()
26     : FunctionPass(ID) {
27   initializeUnifyFunctionExitNodesLegacyPassPass(
28       *PassRegistry::getPassRegistry());
29 }
30 
// Register the legacy pass under the pass argument "mergereturn" with the
// description "Unify function exit nodes". The two trailing `false`s are
// INITIALIZE_PASS's CFG-only and is-analysis flags.
INITIALIZE_PASS(UnifyFunctionExitNodesLegacyPass, "mergereturn",
                "Unify function exit nodes", false, false)
33 
34 Pass *llvm::createUnifyFunctionExitNodesPass() {
35   return new UnifyFunctionExitNodesLegacyPass();
36 }
37 
38 void UnifyFunctionExitNodesLegacyPass::getAnalysisUsage(
39     AnalysisUsage &AU) const {
40   // We preserve the non-critical-edgeness property
41   AU.addPreservedID(BreakCriticalEdgesID);
42   // This is a cluster of orthogonal Transforms
43   AU.addPreservedID(LowerSwitchID);
44 }
45 
46 namespace {
47 
48 bool unifyUnreachableBlocks(Function &F) {
49   std::vector<BasicBlock *> UnreachableBlocks;
50 
51   for (BasicBlock &I : F)
52     if (isa<UnreachableInst>(I.getTerminator()))
53       UnreachableBlocks.push_back(&I);
54 
55   if (UnreachableBlocks.size() <= 1)
56     return false;
57 
58   BasicBlock *UnreachableBlock =
59       BasicBlock::Create(F.getContext(), "UnifiedUnreachableBlock", &F);
60   new UnreachableInst(F.getContext(), UnreachableBlock);
61 
62   for (BasicBlock *BB : UnreachableBlocks) {
63     BB->back().eraseFromParent(); // Remove the unreachable inst.
64     BranchInst::Create(UnreachableBlock, BB);
65   }
66 
67   return true;
68 }
69 
70 bool unifyReturnBlocks(Function &F) {
71   std::vector<BasicBlock *> ReturningBlocks;
72 
73   for (BasicBlock &I : F)
74     if (isa<ReturnInst>(I.getTerminator()))
75       ReturningBlocks.push_back(&I);
76 
77   if (ReturningBlocks.size() <= 1)
78     return false;
79 
80   // Insert a new basic block into the function, add PHI nodes (if the function
81   // returns values), and convert all of the return instructions into
82   // unconditional branches.
83   BasicBlock *NewRetBlock = BasicBlock::Create(F.getContext(),
84                                                "UnifiedReturnBlock", &F);
85 
86   PHINode *PN = nullptr;
87   if (F.getReturnType()->isVoidTy()) {
88     ReturnInst::Create(F.getContext(), nullptr, NewRetBlock);
89   } else {
90     // If the function doesn't return void... add a PHI node to the block...
91     PN = PHINode::Create(F.getReturnType(), ReturningBlocks.size(),
92                          "UnifiedRetVal");
93     PN->insertInto(NewRetBlock, NewRetBlock->end());
94     ReturnInst::Create(F.getContext(), PN, NewRetBlock);
95   }
96 
97   // Loop over all of the blocks, replacing the return instruction with an
98   // unconditional branch.
99   for (BasicBlock *BB : ReturningBlocks) {
100     // Add an incoming element to the PHI node for every return instruction that
101     // is merging into this new block...
102     if (PN)
103       PN->addIncoming(BB->getTerminator()->getOperand(0), BB);
104 
105     BB->back().eraseFromParent(); // Remove the return insn
106     BranchInst::Create(NewRetBlock, BB);
107   }
108 
109   return true;
110 }
111 } // namespace
112 
113 // Unify all exit nodes of the CFG by creating a new BasicBlock, and converting
114 // all returns to unconditional branches to this new basic block. Also, unify
115 // all unreachable blocks.
116 bool UnifyFunctionExitNodesLegacyPass::runOnFunction(Function &F) {
117   bool Changed = false;
118   Changed |= unifyUnreachableBlocks(F);
119   Changed |= unifyReturnBlocks(F);
120   return Changed;
121 }
122 
123 PreservedAnalyses UnifyFunctionExitNodesPass::run(Function &F,
124                                                   FunctionAnalysisManager &AM) {
125   bool Changed = false;
126   Changed |= unifyUnreachableBlocks(F);
127   Changed |= unifyReturnBlocks(F);
128   return Changed ? PreservedAnalyses() : PreservedAnalyses::all();
129 }
130