1 //===- UniformityAnalysis.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/UniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/Analysis/CycleAnalysis.h"
12 #include "llvm/Analysis/TargetTransformInfo.h"
13 #include "llvm/IR/Constants.h"
14 #include "llvm/IR/Dominators.h"
15 #include "llvm/IR/InstIterator.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/InitializePasses.h"
18 
19 using namespace llvm;
20 
21 template <>
22 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
23     const Instruction &I) const {
24   return isDivergent((const Value *)&I);
25 }
26 
27 template <>
28 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
29     const Instruction &Instr) {
30   return markDivergent(cast<Value>(&Instr));
31 }
32 
33 template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
34   for (auto &I : instructions(F)) {
35     if (TTI->isSourceOfDivergence(&I))
36       markDivergent(I);
37     else if (TTI->isAlwaysUniform(&I))
38       addUniformOverride(I);
39   }
40   for (auto &Arg : F.args()) {
41     if (TTI->isSourceOfDivergence(&Arg)) {
42       markDivergent(&Arg);
43     }
44   }
45 }
46 
47 template <>
48 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
49     const Value *V) {
50   for (const auto *User : V->users()) {
51     if (const auto *UserInstr = dyn_cast<const Instruction>(User)) {
52       markDivergent(*UserInstr);
53     }
54   }
55 }
56 
57 template <>
58 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
59     const Instruction &Instr) {
60   assert(!isAlwaysUniform(Instr));
61   if (Instr.isTerminator())
62     return;
63   pushUsers(cast<Value>(&Instr));
64 }
65 
66 template <>
67 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
68     const Instruction &I, const Cycle &DefCycle) const {
69   assert(!isAlwaysUniform(I));
70   for (const Use &U : I.operands()) {
71     if (auto *I = dyn_cast<Instruction>(&U)) {
72       if (DefCycle.contains(I->getParent()))
73         return true;
74     }
75   }
76   return false;
77 }
78 
79 template <>
80 void llvm::GenericUniformityAnalysisImpl<
81     SSAContext>::propagateTemporalDivergence(const Instruction &I,
82                                              const Cycle &DefCycle) {
83   if (isDivergent(I))
84     return;
85   for (auto *User : I.users()) {
86     auto *UserInstr = cast<Instruction>(User);
87     if (DefCycle.contains(UserInstr->getParent()))
88       continue;
89     markDivergent(*UserInstr);
90   }
91 }
92 
93 template <>
94 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
95     const Use &U) const {
96   const auto *V = U.get();
97   if (isDivergent(V))
98     return true;
99   if (const auto *DefInstr = dyn_cast<Instruction>(V)) {
100     const auto *UseInstr = cast<Instruction>(U.getUser());
101     return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
102   }
103   return false;
104 }
105 
// Explicit instantiation definitions for the LLVM IR (SSAContext) flavour of
// the generic uniformity analysis, so these members are emitted in this file:
//  - GenericUniformityInfo<SSAContext>, the public result type.
//  - GenericUniformityAnalysisImplDeleter for the matching impl class; this
//    ensures explicit instantiation of
//    GenericUniformityAnalysisImpl::ImplDeleter::operator().
template class llvm::GenericUniformityInfo<SSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<SSAContext>>;
111 
112 //===----------------------------------------------------------------------===//
113 //  UniformityInfoAnalysis and related pass implementations
114 //===----------------------------------------------------------------------===//
115 
116 llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
117                                                  FunctionAnalysisManager &FAM) {
118   auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
119   auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
120   auto &CI = FAM.getResult<CycleAnalysis>(F);
121   UniformityInfo UI{F, DT, CI, &TTI};
122   // Skip computation if we can assume everything is uniform.
123   if (TTI.hasBranchDivergence(&F))
124     UI.compute();
125 
126   return UI;
127 }
128 
129 AnalysisKey UniformityInfoAnalysis::Key;
130 
131 UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
132     : OS(OS) {}
133 
134 PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
135                                                  FunctionAnalysisManager &AM) {
136   OS << "UniformityInfo for function '" << F.getName() << "':\n";
137   AM.getResult<UniformityInfoAnalysis>(F).print(OS);
138 
139   return PreservedAnalyses::all();
140 }
141 
142 //===----------------------------------------------------------------------===//
143 //  UniformityInfoWrapperPass Implementation
144 //===----------------------------------------------------------------------===//
145 
146 char UniformityInfoWrapperPass::ID = 0;
147 
148 UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {
149   initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
150 }
151 
// Register the legacy wrapper pass under the command-line name "uniformity".
// Per the INITIALIZE_PASS_BEGIN(passName, arg, name, cfg, analysis) macro
// signature, the trailing `true, true` mark it as CFG-only and as an analysis.
// The dependencies listed here are the same three passes required in
// UniformityInfoWrapperPass::getAnalysisUsage.
INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
                      "Uniformity Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
                    "Uniformity Analysis", true, true)
159 
160 void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
161   AU.setPreservesAll();
162   AU.addRequired<DominatorTreeWrapperPass>();
163   AU.addRequiredTransitive<CycleInfoWrapperPass>();
164   AU.addRequired<TargetTransformInfoWrapperPass>();
165 }
166 
167 bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
168   auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
169   auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
170   auto &targetTransformInfo =
171       getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
172 
173   m_function = &F;
174   m_uniformityInfo =
175       UniformityInfo{F, domTree, cycleInfo, &targetTransformInfo};
176 
177   // Skip computation if we can assume everything is uniform.
178   if (targetTransformInfo.hasBranchDivergence(m_function))
179     m_uniformityInfo.compute();
180 
181   return false;
182 }
183 
184 void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
185   OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
186 }
187 
188 void UniformityInfoWrapperPass::releaseMemory() {
189   m_uniformityInfo = UniformityInfo{};
190   m_function = nullptr;
191 }
192