1 //===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass identifies and eliminates redundant TLS loads. It only runs when
// the -tls-load-hoist option is set or the function carries the
// "tls-load-hoist" attribute. For an example, see the comment at the head of
// TLSVariableHoist.h.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/IR/BasicBlock.h"
16 #include "llvm/IR/Dominators.h"
17 #include "llvm/IR/Function.h"
18 #include "llvm/IR/InstrTypes.h"
19 #include "llvm/IR/Instruction.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/IntrinsicInst.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/Value.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "llvm/Transforms/Scalar.h"
30 #include "llvm/Transforms/Scalar/TLSVariableHoist.h"
31 #include <algorithm>
32 #include <cassert>
33 #include <cstdint>
34 #include <iterator>
35 #include <utility>
36 
37 using namespace llvm;
38 using namespace tlshoist;
39 
40 #define DEBUG_TYPE "tlshoist"
41 
42 static cl::opt<bool> TLSLoadHoist(
43     "tls-load-hoist", cl::init(false), cl::Hidden,
44     cl::desc("hoist the TLS loads in PIC model to eliminate redundant "
45              "TLS address calculation."));
46 
47 namespace {
48 
49 /// The TLS Variable hoist pass.
50 class TLSVariableHoistLegacyPass : public FunctionPass {
51 public:
52   static char ID; // Pass identification, replacement for typeid
53 
TLSVariableHoistLegacyPass()54   TLSVariableHoistLegacyPass() : FunctionPass(ID) {
55     initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry());
56   }
57 
58   bool runOnFunction(Function &Fn) override;
59 
getPassName() const60   StringRef getPassName() const override { return "TLS Variable Hoist"; }
61 
getAnalysisUsage(AnalysisUsage & AU) const62   void getAnalysisUsage(AnalysisUsage &AU) const override {
63     AU.setPreservesCFG();
64     AU.addRequired<DominatorTreeWrapperPass>();
65     AU.addRequired<LoopInfoWrapperPass>();
66   }
67 
68 private:
69   TLSVariableHoistPass Impl;
70 };
71 
72 } // end anonymous namespace
73 
// Definition of the legacy pass's identification object (declared as
// TLSVariableHoistLegacyPass::ID in the class).
char TLSVariableHoistLegacyPass::ID = 0;

// Register the legacy pass under the command-line name "tlshoist" with the
// display name "TLS Variable Hoist", and declare the analyses it requires
// (dominator tree and loop info) as initialization dependencies.
// NOTE(review): the two trailing `false` arguments are presumably the macro's
// CFG-only / is-analysis flags — confirm against PassSupport.h.
INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist",
                      "TLS Variable Hoist", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist",
                    "TLS Variable Hoist", false, false)
82 
83 FunctionPass *llvm::createTLSVariableHoistPass() {
84   return new TLSVariableHoistLegacyPass();
85 }
86 
87 /// Perform the TLS Variable Hoist optimization for the given function.
runOnFunction(Function & Fn)88 bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) {
89   if (skipFunction(Fn))
90     return false;
91 
92   LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n");
93   LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n');
94 
95   bool MadeChange =
96       Impl.runImpl(Fn, getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
97                    getAnalysis<LoopInfoWrapperPass>().getLoopInfo());
98 
99   if (MadeChange) {
100     LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: "
101                       << Fn.getName() << '\n');
102     LLVM_DEBUG(dbgs() << Fn);
103   }
104   LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n");
105 
106   return MadeChange;
107 }
108 
collectTLSCandidate(Instruction * Inst)109 void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) {
110   // Skip all cast instructions. They are visited indirectly later on.
111   if (Inst->isCast())
112     return;
113 
114   // Scan all operands.
115   for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) {
116     auto *GV = dyn_cast<GlobalVariable>(Inst->getOperand(Idx));
117     if (!GV || !GV->isThreadLocal())
118       continue;
119 
120     // Add Candidate to TLSCandMap (GV --> Candidate).
121     TLSCandMap[GV].addUser(Inst, Idx);
122   }
123 }
124 
collectTLSCandidates(Function & Fn)125 void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) {
126   // First, quickly check if there is TLS Variable.
127   Module *M = Fn.getParent();
128 
129   bool HasTLS = llvm::any_of(
130       M->globals(), [](GlobalVariable &GV) { return GV.isThreadLocal(); });
131 
132   // If non, directly return.
133   if (!HasTLS)
134     return;
135 
136   TLSCandMap.clear();
137 
138   // Then, collect TLS Variable info.
139   for (BasicBlock &BB : Fn) {
140     // Ignore unreachable basic blocks.
141     if (!DT->isReachableFromEntry(&BB))
142       continue;
143 
144     for (Instruction &Inst : BB)
145       collectTLSCandidate(&Inst);
146   }
147 }
148 
oneUseOutsideLoop(tlshoist::TLSCandidate & Cand,LoopInfo * LI)149 static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) {
150   if (Cand.Users.size() != 1)
151     return false;
152 
153   BasicBlock *BB = Cand.Users[0].Inst->getParent();
154   if (LI->getLoopFor(BB))
155     return false;
156 
157   return true;
158 }
159 
getNearestLoopDomInst(BasicBlock * BB,Loop * L)160 Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB,
161                                                          Loop *L) {
162   assert(L && "Unexcepted Loop status!");
163 
164   // Get the outermost loop.
165   while (Loop *Parent = L->getParentLoop())
166     L = Parent;
167 
168   BasicBlock *PreHeader = L->getLoopPreheader();
169 
170   // There is unique predecessor outside the loop.
171   if (PreHeader)
172     return PreHeader->getTerminator();
173 
174   BasicBlock *Header = L->getHeader();
175   BasicBlock *Dom = Header;
176   for (BasicBlock *PredBB : predecessors(Header))
177     Dom = DT->findNearestCommonDominator(Dom, PredBB);
178 
179   assert(Dom && "Not find dominator BB!");
180   Instruction *Term = Dom->getTerminator();
181 
182   return Term;
183 }
184 
getDomInst(Instruction * I1,Instruction * I2)185 Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1,
186                                               Instruction *I2) {
187   if (!I1)
188     return I2;
189   return DT->findNearestCommonDominator(I1, I2);
190 }
191 
findInsertPos(Function & Fn,GlobalVariable * GV,BasicBlock * & PosBB)192 BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn,
193                                                          GlobalVariable *GV,
194                                                          BasicBlock *&PosBB) {
195   tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
196 
197   // We should hoist the TLS use out of loop, so choose its nearest instruction
198   // which dominate the loop and the outside loops (if exist).
199   Instruction *LastPos = nullptr;
200   for (auto &User : Cand.Users) {
201     BasicBlock *BB = User.Inst->getParent();
202     Instruction *Pos = User.Inst;
203     if (Loop *L = LI->getLoopFor(BB)) {
204       Pos = getNearestLoopDomInst(BB, L);
205       assert(Pos && "Not find insert position out of loop!");
206     }
207     Pos = getDomInst(LastPos, Pos);
208     LastPos = Pos;
209   }
210 
211   assert(LastPos && "Unexpected insert position!");
212   BasicBlock *Parent = LastPos->getParent();
213   PosBB = Parent;
214   return LastPos->getIterator();
215 }
216 
217 // Generate a bitcast (no type change) to replace the uses of TLS Candidate.
genBitCastInst(Function & Fn,GlobalVariable * GV)218 Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn,
219                                                   GlobalVariable *GV) {
220   BasicBlock *PosBB = &Fn.getEntryBlock();
221   BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB);
222   Type *Ty = GV->getType();
223   auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast");
224   CastInst->insertInto(PosBB, Iter);
225   return CastInst;
226 }
227 
tryReplaceTLSCandidate(Function & Fn,GlobalVariable * GV)228 bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn,
229                                                   GlobalVariable *GV) {
230 
231   tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
232 
233   // If only used 1 time and not in loops, we no need to replace it.
234   if (oneUseOutsideLoop(Cand, LI))
235     return false;
236 
237   // Generate a bitcast (no type change)
238   auto *CastInst = genBitCastInst(Fn, GV);
239 
240   // to replace the uses of TLS Candidate
241   for (auto &User : Cand.Users)
242     User.Inst->setOperand(User.OpndIdx, CastInst);
243 
244   return true;
245 }
246 
tryReplaceTLSCandidates(Function & Fn)247 bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) {
248   if (TLSCandMap.empty())
249     return false;
250 
251   bool Replaced = false;
252   for (auto &GV2Cand : TLSCandMap) {
253     GlobalVariable *GV = GV2Cand.first;
254     Replaced |= tryReplaceTLSCandidate(Fn, GV);
255   }
256 
257   return Replaced;
258 }
259 
260 /// Optimize expensive TLS variables in the given function.
runImpl(Function & Fn,DominatorTree & DT,LoopInfo & LI)261 bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT,
262                                    LoopInfo &LI) {
263   if (Fn.hasOptNone())
264     return false;
265 
266   if (!TLSLoadHoist && !Fn.getAttributes().hasFnAttr("tls-load-hoist"))
267     return false;
268 
269   this->LI = &LI;
270   this->DT = &DT;
271   assert(this->LI && this->DT && "Unexcepted requirement!");
272 
273   // Collect all TLS variable candidates.
274   collectTLSCandidates(Fn);
275 
276   bool MadeChange = tryReplaceTLSCandidates(Fn);
277 
278   return MadeChange;
279 }
280 
run(Function & F,FunctionAnalysisManager & AM)281 PreservedAnalyses TLSVariableHoistPass::run(Function &F,
282                                             FunctionAnalysisManager &AM) {
283 
284   auto &LI = AM.getResult<LoopAnalysis>(F);
285   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
286 
287   if (!runImpl(F, DT, LI))
288     return PreservedAnalyses::all();
289 
290   PreservedAnalyses PA;
291   PA.preserveSet<CFGAnalyses>();
292   return PA;
293 }
294