1 //===- TLSVariableHoist.cpp -------- Remove Redundant TLS Loads ---------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// This pass identifies and eliminates redundant TLS loads. It is enabled by the
// -tls-load-hoist option or by the "tls-load-hoist" function attribute.
// For an example, please refer to the comment at the head of TLSVariableHoist.h.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "llvm/ADT/SmallVector.h"
15 #include "llvm/IR/BasicBlock.h"
16 #include "llvm/IR/Dominators.h"
17 #include "llvm/IR/Function.h"
18 #include "llvm/IR/InstrTypes.h"
19 #include "llvm/IR/Instruction.h"
20 #include "llvm/IR/Instructions.h"
21 #include "llvm/IR/IntrinsicInst.h"
22 #include "llvm/IR/Module.h"
23 #include "llvm/IR/Value.h"
24 #include "llvm/InitializePasses.h"
25 #include "llvm/Pass.h"
26 #include "llvm/Support/Casting.h"
27 #include "llvm/Support/Debug.h"
28 #include "llvm/Support/raw_ostream.h"
29 #include "llvm/Transforms/Scalar.h"
30 #include "llvm/Transforms/Scalar/TLSVariableHoist.h"
31 #include <algorithm>
32 #include <cassert>
33 #include <cstdint>
34 #include <iterator>
35 #include <tuple>
36 #include <utility>
37 
38 using namespace llvm;
39 using namespace tlshoist;
40 
41 #define DEBUG_TYPE "tlshoist"
42 
43 static cl::opt<bool> TLSLoadHoist(
44     "tls-load-hoist", cl::init(false), cl::Hidden,
45     cl::desc("hoist the TLS loads in PIC model to eliminate redundant "
46              "TLS address calculation."));
47 
48 namespace {
49 
50 /// The TLS Variable hoist pass.
51 class TLSVariableHoistLegacyPass : public FunctionPass {
52 public:
53   static char ID; // Pass identification, replacement for typeid
54 
55   TLSVariableHoistLegacyPass() : FunctionPass(ID) {
56     initializeTLSVariableHoistLegacyPassPass(*PassRegistry::getPassRegistry());
57   }
58 
59   bool runOnFunction(Function &Fn) override;
60 
61   StringRef getPassName() const override { return "TLS Variable Hoist"; }
62 
63   void getAnalysisUsage(AnalysisUsage &AU) const override {
64     AU.setPreservesCFG();
65     AU.addRequired<DominatorTreeWrapperPass>();
66     AU.addRequired<LoopInfoWrapperPass>();
67   }
68 
69 private:
70   TLSVariableHoistPass Impl;
71 };
72 
73 } // end anonymous namespace
74 
// Storage for the legacy pass identifier. It is handed to FunctionPass(ID) in
// the constructor; presumably only its address matters, not its value.
char TLSVariableHoistLegacyPass::ID = 0;

// Register the legacy pass under the command-line name "tlshoist" with the
// description "TLS Variable Hoist", and declare its dependencies on the
// dominator-tree and loop-info wrapper passes (the same analyses requested in
// getAnalysisUsage). NOTE(review): the two trailing `false` arguments are, per
// llvm/PassSupport.h, the CFG-only and is-analysis flags.
INITIALIZE_PASS_BEGIN(TLSVariableHoistLegacyPass, "tlshoist",
                      "TLS Variable Hoist", false, false)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(LoopInfoWrapperPass)
INITIALIZE_PASS_END(TLSVariableHoistLegacyPass, "tlshoist",
                    "TLS Variable Hoist", false, false)
83 
84 FunctionPass *llvm::createTLSVariableHoistPass() {
85   return new TLSVariableHoistLegacyPass();
86 }
87 
88 /// Perform the TLS Variable Hoist optimization for the given function.
89 bool TLSVariableHoistLegacyPass::runOnFunction(Function &Fn) {
90   if (skipFunction(Fn))
91     return false;
92 
93   LLVM_DEBUG(dbgs() << "********** Begin TLS Variable Hoist **********\n");
94   LLVM_DEBUG(dbgs() << "********** Function: " << Fn.getName() << '\n');
95 
96   bool MadeChange =
97       Impl.runImpl(Fn, getAnalysis<DominatorTreeWrapperPass>().getDomTree(),
98                    getAnalysis<LoopInfoWrapperPass>().getLoopInfo());
99 
100   if (MadeChange) {
101     LLVM_DEBUG(dbgs() << "********** Function after TLS Variable Hoist: "
102                       << Fn.getName() << '\n');
103     LLVM_DEBUG(dbgs() << Fn);
104   }
105   LLVM_DEBUG(dbgs() << "********** End TLS Variable Hoist **********\n");
106 
107   return MadeChange;
108 }
109 
110 void TLSVariableHoistPass::collectTLSCandidate(Instruction *Inst) {
111   // Skip all cast instructions. They are visited indirectly later on.
112   if (Inst->isCast())
113     return;
114 
115   // Scan all operands.
116   for (unsigned Idx = 0, E = Inst->getNumOperands(); Idx != E; ++Idx) {
117     auto *GV = dyn_cast<GlobalVariable>(Inst->getOperand(Idx));
118     if (!GV || !GV->isThreadLocal())
119       continue;
120 
121     // Add Candidate to TLSCandMap (GV --> Candidate).
122     TLSCandMap[GV].addUser(Inst, Idx);
123   }
124 }
125 
126 void TLSVariableHoistPass::collectTLSCandidates(Function &Fn) {
127   // First, quickly check if there is TLS Variable.
128   Module *M = Fn.getParent();
129 
130   bool HasTLS = llvm::any_of(
131       M->globals(), [](GlobalVariable &GV) { return GV.isThreadLocal(); });
132 
133   // If non, directly return.
134   if (!HasTLS)
135     return;
136 
137   TLSCandMap.clear();
138 
139   // Then, collect TLS Variable info.
140   for (BasicBlock &BB : Fn) {
141     // Ignore unreachable basic blocks.
142     if (!DT->isReachableFromEntry(&BB))
143       continue;
144 
145     for (Instruction &Inst : BB)
146       collectTLSCandidate(&Inst);
147   }
148 }
149 
150 static bool oneUseOutsideLoop(tlshoist::TLSCandidate &Cand, LoopInfo *LI) {
151   if (Cand.Users.size() != 1)
152     return false;
153 
154   BasicBlock *BB = Cand.Users[0].Inst->getParent();
155   if (LI->getLoopFor(BB))
156     return false;
157 
158   return true;
159 }
160 
161 Instruction *TLSVariableHoistPass::getNearestLoopDomInst(BasicBlock *BB,
162                                                          Loop *L) {
163   assert(L && "Unexcepted Loop status!");
164 
165   // Get the outermost loop.
166   while (Loop *Parent = L->getParentLoop())
167     L = Parent;
168 
169   BasicBlock *PreHeader = L->getLoopPreheader();
170 
171   // There is unique predecessor outside the loop.
172   if (PreHeader)
173     return PreHeader->getTerminator();
174 
175   BasicBlock *Header = L->getHeader();
176   BasicBlock *Dom = Header;
177   for (BasicBlock *PredBB : predecessors(Header))
178     Dom = DT->findNearestCommonDominator(Dom, PredBB);
179 
180   assert(Dom && "Not find dominator BB!");
181   Instruction *Term = Dom->getTerminator();
182 
183   return Term;
184 }
185 
186 Instruction *TLSVariableHoistPass::getDomInst(Instruction *I1,
187                                               Instruction *I2) {
188   if (!I1)
189     return I2;
190   return DT->findNearestCommonDominator(I1, I2);
191 }
192 
193 BasicBlock::iterator TLSVariableHoistPass::findInsertPos(Function &Fn,
194                                                          GlobalVariable *GV,
195                                                          BasicBlock *&PosBB) {
196   tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
197 
198   // We should hoist the TLS use out of loop, so choose its nearest instruction
199   // which dominate the loop and the outside loops (if exist).
200   Instruction *LastPos = nullptr;
201   for (auto &User : Cand.Users) {
202     BasicBlock *BB = User.Inst->getParent();
203     Instruction *Pos = User.Inst;
204     if (Loop *L = LI->getLoopFor(BB)) {
205       Pos = getNearestLoopDomInst(BB, L);
206       assert(Pos && "Not find insert position out of loop!");
207     }
208     Pos = getDomInst(LastPos, Pos);
209     LastPos = Pos;
210   }
211 
212   assert(LastPos && "Unexpected insert position!");
213   BasicBlock *Parent = LastPos->getParent();
214   PosBB = Parent;
215   return LastPos->getIterator();
216 }
217 
218 // Generate a bitcast (no type change) to replace the uses of TLS Candidate.
219 Instruction *TLSVariableHoistPass::genBitCastInst(Function &Fn,
220                                                   GlobalVariable *GV) {
221   BasicBlock *PosBB = &Fn.getEntryBlock();
222   BasicBlock::iterator Iter = findInsertPos(Fn, GV, PosBB);
223   Type *Ty = GV->getType();
224   auto *CastInst = new BitCastInst(GV, Ty, "tls_bitcast");
225   CastInst->insertInto(PosBB, Iter);
226   return CastInst;
227 }
228 
229 bool TLSVariableHoistPass::tryReplaceTLSCandidate(Function &Fn,
230                                                   GlobalVariable *GV) {
231 
232   tlshoist::TLSCandidate &Cand = TLSCandMap[GV];
233 
234   // If only used 1 time and not in loops, we no need to replace it.
235   if (oneUseOutsideLoop(Cand, LI))
236     return false;
237 
238   // Generate a bitcast (no type change)
239   auto *CastInst = genBitCastInst(Fn, GV);
240 
241   // to replace the uses of TLS Candidate
242   for (auto &User : Cand.Users)
243     User.Inst->setOperand(User.OpndIdx, CastInst);
244 
245   return true;
246 }
247 
248 bool TLSVariableHoistPass::tryReplaceTLSCandidates(Function &Fn) {
249   if (TLSCandMap.empty())
250     return false;
251 
252   bool Replaced = false;
253   for (auto &GV2Cand : TLSCandMap) {
254     GlobalVariable *GV = GV2Cand.first;
255     Replaced |= tryReplaceTLSCandidate(Fn, GV);
256   }
257 
258   return Replaced;
259 }
260 
261 /// Optimize expensive TLS variables in the given function.
262 bool TLSVariableHoistPass::runImpl(Function &Fn, DominatorTree &DT,
263                                    LoopInfo &LI) {
264   if (Fn.hasOptNone())
265     return false;
266 
267   if (!TLSLoadHoist && !Fn.getAttributes().hasFnAttr("tls-load-hoist"))
268     return false;
269 
270   this->LI = &LI;
271   this->DT = &DT;
272   assert(this->LI && this->DT && "Unexcepted requirement!");
273 
274   // Collect all TLS variable candidates.
275   collectTLSCandidates(Fn);
276 
277   bool MadeChange = tryReplaceTLSCandidates(Fn);
278 
279   return MadeChange;
280 }
281 
282 PreservedAnalyses TLSVariableHoistPass::run(Function &F,
283                                             FunctionAnalysisManager &AM) {
284 
285   auto &LI = AM.getResult<LoopAnalysis>(F);
286   auto &DT = AM.getResult<DominatorTreeAnalysis>(F);
287 
288   if (!runImpl(F, DT, LI))
289     return PreservedAnalyses::all();
290 
291   PreservedAnalyses PA;
292   PA.preserveSet<CFGAnalyses>();
293   return PA;
294 }
295