1 //===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass identifies and eliminates redundant TLS loads when the related
// option is set. For an example, please refer to the comment at the head of
// TLSVariableHoist.h.
11 //
12 //===----------------------------------------------------------------------===//
13
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/IR/BasicBlock.h"
16 #include "llvm/IR/Dominators.h"
17 #include "llvm/IR/Function.h"
18 #include "llvm/IR/InstrTypes.h"
19 #include "llvm/IR/Instruction.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/IntrinsicInst.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/Value.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "llvm/Transforms/Scalar.h"
30 #include "llvm/Transforms/Scalar/TLSVariableHoist.h"
31 #include <algorithm>
32 #include <cassert>
33 #include <cstdint>
34 #include <iterator>
35 #include <utility>
36
37 using namespace llvm;
38 using namespace tlshoist;
39
40 #define DEBUG_TYPE "tlshoist"
41
42 static cl::opt<bool> TLSLoadHoist(
43 "tls-load-hoist", cl::init(false), cl::Hidden,
44 cl::desc("hoist the TLS loads in PIC model to eliminate redundant "
45 "TLS address calculation."));
46
47 namespace {
48
49 /// The TLS Variable hoist pass.
50 class TLSVariableHoistLegacyPass : public FunctionPass {
51 public:
52 static char ID; // Pass identification, replacement for typeid
53
TLSVariableHoistLegacyPass()54 TLSVariableHoistLegacyPass() : FunctionPass(ID) {
55 initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry());
56 }
57
58 bool runOnFunction(Function &Fn) override;
59
getPassName() const60 StringRef getPassName() const override { return "TLS Variable Hoist"; }
61
getAnalysisUsage(AnalysisUsage & AU) const62 void getAnalysisUsage(AnalysisUsage &AU) const override {
63 AU.setPreservesCFG();
64 AU.addRequired<DominatorTreeWrapperPass>();
65 AU.addRequired<LoopInfoWrapperPass>();
66 }
67
68 private:
69 TLSVariableHoistPass Impl;
70 };
71
72 } // end anonymous namespace
73
// Out-of-class definition of the static pass identifier declared in
// TLSVariableHoistLegacyPass.
char TLSVariableHoistLegacyPass::ID = 0;

// Register the legacy pass under the argument name "tlshoist" and list the
// same two analysis passes that getAnalysisUsage() requests as dependencies.
INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist",
                      "TLS Variable Hoist", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist",
                    "TLS Variable Hoist", false, false)
82
83 FunctionPass *llvm::createTLSVariableHoistPass() {
84 return new TLSVariableHoistLegacyPass();
85 }
86
87 /// Perform the TLS Variable Hoist optimization for the given function.
runOnFunction(Function & Fn)88 bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) {
89 if (skipFunction(Fn))
90 return false;
91
92 LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n");
93 LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n');
94
95 bool MadeChange =
96 Impl.runImpl(Fn, getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
97 getAnalysis<LoopInfoWrapperPass>().getLoopInfo());
98
99 if (MadeChange) {
100 LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: "
101 << Fn.getName() << '\n');
102 LLVM_DEBUG(dbgs() << Fn);
103 }
104 LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n");
105
106 return MadeChange;
107 }
108
collectTLSCandidate(Instruction * Inst)109 void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) {
110 // Skip all cast instructions. They are visited indirectly later on.
111 if (Inst->isCast())
112 return;
113
114 // Scan all operands.
115 for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) {
116 auto *GV = dyn_cast<GlobalVariable>(Inst->getOperand(Idx));
117 if (!GV || !GV->isThreadLocal())
118 continue;
119
120 // Add Candidate to TLSCandMap (GV --> Candidate).
121 TLSCandMap[GV].addUser(Inst, Idx);
122 }
123 }
124
collectTLSCandidates(Function & Fn)125 void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) {
126 // First, quickly check if there is TLS Variable.
127 Module *M = Fn.getParent();
128
129 bool HasTLS = llvm::any_of(
130 M->globals(), [](GlobalVariable &GV) { return GV.isThreadLocal(); });
131
132 // If non, directly return.
133 if (!HasTLS)
134 return;
135
136 TLSCandMap.clear();
137
138 // Then, collect TLS Variable info.
139 for (BasicBlock &BB : Fn) {
140 // Ignore unreachable basic blocks.
141 if (!DT->isReachableFromEntry(&BB))
142 continue;
143
144 for (Instruction &Inst : BB)
145 collectTLSCandidate(&Inst);
146 }
147 }
148
oneUseOutsideLoop(tlshoist::TLSCandidate & Cand,LoopInfo * LI)149 static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) {
150 if (Cand.Users.size() != 1)
151 return false;
152
153 BasicBlock *BB = Cand.Users[0].Inst->getParent();
154 if (LI->getLoopFor(BB))
155 return false;
156
157 return true;
158 }
159
getNearestLoopDomInst(BasicBlock * BB,Loop * L)160 Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB,
161 Loop *L) {
162 assert(L && "Unexcepted Loop status!");
163
164 // Get the outermost loop.
165 while (Loop *Parent = L->getParentLoop())
166 L = Parent;
167
168 BasicBlock *PreHeader = L->getLoopPreheader();
169
170 // There is unique predecessor outside the loop.
171 if (PreHeader)
172 return PreHeader->getTerminator();
173
174 BasicBlock *Header = L->getHeader();
175 BasicBlock *Dom = Header;
176 for (BasicBlock *PredBB : predecessors(Header))
177 Dom = DT->findNearestCommonDominator(Dom, PredBB);
178
179 assert(Dom && "Not find dominator BB!");
180 Instruction *Term = Dom->getTerminator();
181
182 return Term;
183 }
184
getDomInst(Instruction * I1,Instruction * I2)185 Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1,
186 Instruction *I2) {
187 if (!I1)
188 return I2;
189 return DT->findNearestCommonDominator(I1, I2);
190 }
191
findInsertPos(Function & Fn,GlobalVariable * GV,BasicBlock * & PosBB)192 BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn,
193 GlobalVariable *GV,
194 BasicBlock *&PosBB) {
195 tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
196
197 // We should hoist the TLS use out of loop, so choose its nearest instruction
198 // which dominate the loop and the outside loops (if exist).
199 Instruction *LastPos = nullptr;
200 for (auto &User : Cand.Users) {
201 BasicBlock *BB = User.Inst->getParent();
202 Instruction *Pos = User.Inst;
203 if (Loop *L = LI->getLoopFor(BB)) {
204 Pos = getNearestLoopDomInst(BB, L);
205 assert(Pos && "Not find insert position out of loop!");
206 }
207 Pos = getDomInst(LastPos, Pos);
208 LastPos = Pos;
209 }
210
211 assert(LastPos && "Unexpected insert position!");
212 BasicBlock *Parent = LastPos->getParent();
213 PosBB = Parent;
214 return LastPos->getIterator();
215 }
216
217 // Generate a bitcast (no type change) to replace the uses of TLS Candidate.
genBitCastInst(Function & Fn,GlobalVariable * GV)218 Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn,
219 GlobalVariable *GV) {
220 BasicBlock *PosBB = &Fn.getEntryBlock();
221 BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB);
222 Type *Ty = GV->getType();
223 auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast");
224 CastInst->insertInto(PosBB, Iter);
225 return CastInst;
226 }
227
tryReplaceTLSCandidate(Function & Fn,GlobalVariable * GV)228 bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn,
229 GlobalVariable *GV) {
230
231 tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
232
233 // If only used 1 time and not in loops, we no need to replace it.
234 if (oneUseOutsideLoop(Cand, LI))
235 return false;
236
237 // Generate a bitcast (no type change)
238 auto *CastInst = genBitCastInst(Fn, GV);
239
240 // to replace the uses of TLS Candidate
241 for (auto &User : Cand.Users)
242 User.Inst->setOperand(User.OpndIdx, CastInst);
243
244 return true;
245 }
246
tryReplaceTLSCandidates(Function & Fn)247 bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) {
248 if (TLSCandMap.empty())
249 return false;
250
251 bool Replaced = false;
252 for (auto &GV2Cand : TLSCandMap) {
253 GlobalVariable *GV = GV2Cand.first;
254 Replaced |= tryReplaceTLSCandidate(Fn, GV);
255 }
256
257 return Replaced;
258 }
259
260 /// Optimize expensive TLS variables in the given function.
runImpl(Function & Fn,DominatorTree & DT,LoopInfo & LI)261 bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT,
262 LoopInfo &LI) {
263 if (Fn.hasOptNone())
264 return false;
265
266 if (!TLSLoadHoist && !Fn.getAttributes().hasFnAttr("tls-load-hoist"))
267 return false;
268
269 this->LI = &LI;
270 this->DT = &DT;
271 assert(this->LI && this->DT && "Unexcepted requirement!");
272
273 // Collect all TLS variable candidates.
274 collectTLSCandidates(Fn);
275
276 bool MadeChange = tryReplaceTLSCandidates(Fn);
277
278 return MadeChange;
279 }
280
run(Function & F,FunctionAnalysisManager & AM)281 PreservedAnalyses TLSVariableHoistPass::run(Function &F,
282 FunctionAnalysisManager &AM) {
283
284 auto &LI = AM.getResult<LoopAnalysis>(F);
285 auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
286
287 if (!runImpl(F, DT, LI))
288 return PreservedAnalyses::all();
289
290 PreservedAnalyses PA;
291 PA.preserveSet<CFGAnalyses>();
292 return PA;
293 }
294