1 //===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass identifies and eliminates redundant TLS loads when the related
// option is set. For an example, please refer to the comment at the head of
// TLSVariableHoist.h.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/IR/BasicBlock.h"
16 #include "llvm/IR/Dominators.h"
17 #include "llvm/IR/Function.h"
18 #include "llvm/IR/InstrTypes.h"
19 #include "llvm/IR/Instruction.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/IntrinsicInst.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/Value.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "llvm/Transforms/Scalar.h"
30 #include "llvm/Transforms/Scalar/TLSVariableHoist.h"
31 #include <algorithm>
32 #include <cassert>
33 #include <cstdint>
34 #include <iterator>
35 #include <tuple>
36 #include <utility>
37 
38 using namespace llvm;
39 using namespace tlshoist;
40 
41 #define DEBUG_TYPE "tlshoist"
42 
43 static cl::opt<bool> TLSLoadHoist(
44     "tls-load-hoist", cl::init(false), cl::Hidden,
45     cl::desc("hoist the TLS loads in PIC model to eliminate redundant "
46              "TLS address calculation."));
47 
48 namespace {
49 
50 /// The TLS Variable hoist pass.
51 class TLSVariableHoistLegacyPass : public FunctionPass {
52 public:
53   static char ID; // Pass identification, replacement for typeid
54 
55   TLSVariableHoistLegacyPass() : FunctionPass(ID) {
56     initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry());
57   }
58 
59   bool runOnFunction(Function &Fn) override;
60 
61   StringRef getPassName() const override { return "TLS Variable Hoist"; }
62 
63   void getAnalysisUsage(AnalysisUsage &AU) const override {
64     AU.setPreservesCFG();
65     AU.addRequired<DominatorTreeWrapperPass>();
66     AU.addRequired<LoopInfoWrapperPass>();
67   }
68 
69 private:
70   TLSVariableHoistPass Impl;
71 };
72 
73 } // end anonymous namespace
74 
// Storage for the legacy pass's identification token (see the class's ID
// declaration).
char TLSVariableHoistLegacyPass::ID = 0;

// Register the legacy pass with the PassRegistry under the name "tlshoist",
// recording its dependencies on DominatorTreeWrapperPass and
// LoopInfoWrapperPass (the same analyses requested in getAnalysisUsage).
INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist",
                      "TLS Variable Hoist", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist",
                    "TLS Variable Hoist", false, false)
83 
84 FunctionPass *llvm::createTLSVariableHoistPass() {
85   return new TLSVariableHoistLegacyPass();
86 }
87 
88 /// Perform the TLS Variable Hoist optimization for the given function.
89 bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) {
90   if (skipFunction(Fn))
91     return false;
92 
93   LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n");
94   LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n');
95 
96   bool MadeChange =
97       Impl.runImpl(Fn, getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
98                    getAnalysis<LoopInfoWrapperPass>().getLoopInfo());
99 
100   if (MadeChange) {
101     LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: "
102                       << Fn.getName() << '\n');
103     LLVM_DEBUG(dbgs() << Fn);
104   }
105   LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n");
106 
107   return MadeChange;
108 }
109 
110 void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) {
111   // Skip all cast instructions. They are visited indirectly later on.
112   if (Inst->isCast())
113     return;
114 
115   // Scan all operands.
116   for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) {
117     auto *GV = dyn_cast<GlobalVariable>(Inst->getOperand(Idx));
118     if (!GV || !GV->isThreadLocal())
119       continue;
120 
121     // Add Candidate to TLSCandMap (GV --> Candidate).
122     TLSCandMap[GV].addUser(Inst, Idx);
123   }
124 }
125 
126 void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) {
127   // First, quickly check if there is TLS Variable.
128   Module *M = Fn.getParent();
129 
130   bool HasTLS = llvm::any_of(
131       M->globals(), [](GlobalVariable &GV) { return GV.isThreadLocal(); });
132 
133   // If non, directly return.
134   if (!HasTLS)
135     return;
136 
137   TLSCandMap.clear();
138 
139   // Then, collect TLS Variable info.
140   for (BasicBlock &BB : Fn) {
141     // Ignore unreachable basic blocks.
142     if (!DT->isReachableFromEntry(&BB))
143       continue;
144 
145     for (Instruction &Inst : BB)
146       collectTLSCandidate(&Inst);
147   }
148 }
149 
150 static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) {
151   if (Cand.Users.size() != 1)
152     return false;
153 
154   BasicBlock *BB = Cand.Users[0].Inst->getParent();
155   if (LI->getLoopFor(BB))
156     return false;
157 
158   return true;
159 }
160 
161 Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB,
162                                                          Loop *L) {
163   assert(L && "Unexcepted Loop status!");
164 
165   // Get the outermost loop.
166   while (Loop *Parent = L->getParentLoop())
167     L = Parent;
168 
169   BasicBlock *PreHeader = L->getLoopPreheader();
170 
171   // There is unique predecessor outside the loop.
172   if (PreHeader)
173     return PreHeader->getTerminator();
174 
175   BasicBlock *Header = L->getHeader();
176   BasicBlock *Dom = Header;
177   for (BasicBlock *PredBB : predecessors(Header))
178     Dom = DT->findNearestCommonDominator(Dom, PredBB);
179 
180   assert(Dom && "Not find dominator BB!");
181   Instruction *Term = Dom->getTerminator();
182 
183   return Term;
184 }
185 
186 Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1,
187                                               Instruction *I2) {
188   if (!I1)
189     return I2;
190   if (DT->dominates(I1, I2))
191     return I1;
192   if (DT->dominates(I2, I1))
193     return I2;
194 
195   // If there is no dominance relation, use common dominator.
196   BasicBlock *DomBB =
197       DT->findNearestCommonDominator(I1->getParent(), I2->getParent());
198 
199   Instruction *Dom = DomBB->getTerminator();
200   assert(Dom && "Common dominator not found!");
201 
202   return Dom;
203 }
204 
205 BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn,
206                                                          GlobalVariable *GV,
207                                                          BasicBlock *&PosBB) {
208   tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
209 
210   // We should hoist the TLS use out of loop, so choose its nearest instruction
211   // which dominate the loop and the outside loops (if exist).
212   Instruction *LastPos = nullptr;
213   for (auto &User : Cand.Users) {
214     BasicBlock *BB = User.Inst->getParent();
215     Instruction *Pos = User.Inst;
216     if (Loop *L = LI->getLoopFor(BB)) {
217       Pos = getNearestLoopDomInst(BB, L);
218       assert(Pos && "Not find insert position out of loop!");
219     }
220     Pos = getDomInst(LastPos, Pos);
221     LastPos = Pos;
222   }
223 
224   assert(LastPos && "Unexpected insert position!");
225   BasicBlock *Parent = LastPos->getParent();
226   PosBB = Parent;
227   return LastPos->getIterator();
228 }
229 
230 // Generate a bitcast (no type change) to replace the uses of TLS Candidate.
231 Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn,
232                                                   GlobalVariable *GV) {
233   BasicBlock *PosBB = &Fn.getEntryBlock();
234   BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB);
235   Type *Ty = GV->getType();
236   auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast");
237   PosBB->getInstList().insert(Iter, CastInst);
238   return CastInst;
239 }
240 
241 bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn,
242                                                   GlobalVariable *GV) {
243 
244   tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
245 
246   // If only used 1 time and not in loops, we no need to replace it.
247   if (oneUseOutsideLoop(Cand, LI))
248     return false;
249 
250   // Generate a bitcast (no type change)
251   auto *CastInst = genBitCastInst(Fn, GV);
252 
253   // to replace the uses of TLS Candidate
254   for (auto &User : Cand.Users)
255     User.Inst->setOperand(User.OpndIdx, CastInst);
256 
257   return true;
258 }
259 
260 bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) {
261   if (TLSCandMap.empty())
262     return false;
263 
264   bool Replaced = false;
265   for (auto &GV2Cand : TLSCandMap) {
266     GlobalVariable *GV = GV2Cand.first;
267     Replaced |= tryReplaceTLSCandidate(Fn, GV);
268   }
269 
270   return Replaced;
271 }
272 
273 /// Optimize expensive TLS variables in the given function.
274 bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT,
275                                    LoopInfo &LI) {
276   if (Fn.hasOptNone())
277     return false;
278 
279   if (!TLSLoadHoist && !Fn.getAttributes().hasFnAttr("tls-load-hoist"))
280     return false;
281 
282   this->LI = &LI;
283   this->DT = &DT;
284   assert(this->LI && this->DT && "Unexcepted requirement!");
285 
286   // Collect all TLS variable candidates.
287   collectTLSCandidates(Fn);
288 
289   bool MadeChange = tryReplaceTLSCandidates(Fn);
290 
291   return MadeChange;
292 }
293 
294 PreservedAnalyses TLSVariableHoistPass::run(Function &F,
295                                             FunctionAnalysisManager &AM) {
296 
297   auto &LI = AM.getResult<LoopAnalysis>(F);
298   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
299 
300   if (!runImpl(F, DT, LI))
301     return PreservedAnalyses::all();
302 
303   PreservedAnalyses PA;
304   PA.preserveSet<CFGAnalyses>();
305   return PA;
306 }
307