1 //===- UniformityAnalysis.cpp ---------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "llvm/Analysis/UniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/Analysis/CycleAnalysis.h"
12 #include "llvm/Analysis/TargetTransformInfo.h"
13 #include "llvm/IR/Constants.h"
14 #include "llvm/IR/Dominators.h"
15 #include "llvm/IR/InstIterator.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/InitializePasses.h"
18
19 using namespace llvm;
20
21 template <>
hasDivergentDefs(const Instruction & I) const22 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
23 const Instruction &I) const {
24 return isDivergent((const Value *)&I);
25 }
26
27 template <>
markDefsDivergent(const Instruction & Instr)28 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
29 const Instruction &Instr) {
30 return markDivergent(cast<Value>(&Instr));
31 }
32
initialize()33 template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
34 for (auto &I : instructions(F)) {
35 if (TTI->isSourceOfDivergence(&I))
36 markDivergent(I);
37 else if (TTI->isAlwaysUniform(&I))
38 addUniformOverride(I);
39 }
40 for (auto &Arg : F.args()) {
41 if (TTI->isSourceOfDivergence(&Arg)) {
42 markDivergent(&Arg);
43 }
44 }
45 }
46
47 template <>
pushUsers(const Value * V)48 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
49 const Value *V) {
50 for (const auto *User : V->users()) {
51 if (const auto *UserInstr = dyn_cast<const Instruction>(User)) {
52 markDivergent(*UserInstr);
53 }
54 }
55 }
56
57 template <>
pushUsers(const Instruction & Instr)58 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
59 const Instruction &Instr) {
60 assert(!isAlwaysUniform(Instr));
61 if (Instr.isTerminator())
62 return;
63 pushUsers(cast<Value>(&Instr));
64 }
65
66 template <>
usesValueFromCycle(const Instruction & I,const Cycle & DefCycle) const67 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
68 const Instruction &I, const Cycle &DefCycle) const {
69 assert(!isAlwaysUniform(I));
70 for (const Use &U : I.operands()) {
71 if (auto *I = dyn_cast<Instruction>(&U)) {
72 if (DefCycle.contains(I->getParent()))
73 return true;
74 }
75 }
76 return false;
77 }
78
79 template <>
80 void llvm::GenericUniformityAnalysisImpl<
propagateTemporalDivergence(const Instruction & I,const Cycle & DefCycle)81 SSAContext>::propagateTemporalDivergence(const Instruction &I,
82 const Cycle &DefCycle) {
83 if (isDivergent(I))
84 return;
85 for (auto *User : I.users()) {
86 auto *UserInstr = cast<Instruction>(User);
87 if (DefCycle.contains(UserInstr->getParent()))
88 continue;
89 markDivergent(*UserInstr);
90 }
91 }
92
93 template <>
isDivergentUse(const Use & U) const94 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::isDivergentUse(
95 const Use &U) const {
96 const auto *V = U.get();
97 if (isDivergent(V))
98 return true;
99 if (const auto *DefInstr = dyn_cast<Instruction>(V)) {
100 const auto *UseInstr = cast<Instruction>(U.getUser());
101 return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
102 }
103 return false;
104 }
105
// Explicit instantiations for the LLVM IR SSA context: the uniformity info
// class itself and the deleter struct for its implementation object.
// This ensures explicit instantiation of
// GenericUniformityAnalysisImplDeleter::operator()
template class llvm::GenericUniformityInfo<SSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<SSAContext>>;
111
112 //===----------------------------------------------------------------------===//
113 // UniformityInfoAnalysis and related pass implementations
114 //===----------------------------------------------------------------------===//
115
run(Function & F,FunctionAnalysisManager & FAM)116 llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
117 FunctionAnalysisManager &FAM) {
118 auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
119 auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
120 auto &CI = FAM.getResult<CycleAnalysis>(F);
121 UniformityInfo UI{DT, CI, &TTI};
122 // Skip computation if we can assume everything is uniform.
123 if (TTI.hasBranchDivergence(&F))
124 UI.compute();
125
126 return UI;
127 }
128
// Definition of the static AnalysisKey member of UniformityInfoAnalysis.
// NOTE(review): presumably this object's address identifies the analysis to
// the new pass manager; confirm against the AnalysisInfoMixin documentation.
AnalysisKey UniformityInfoAnalysis::Key;
130
/// Creates a printer pass that writes to \p OS.
///
/// Only a reference to the stream is stored; run() writes to it later.
UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
    : OS(OS) {}
133
run(Function & F,FunctionAnalysisManager & AM)134 PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
135 FunctionAnalysisManager &AM) {
136 OS << "UniformityInfo for function '" << F.getName() << "':\n";
137 AM.getResult<UniformityInfoAnalysis>(F).print(OS);
138
139 return PreservedAnalyses::all();
140 }
141
142 //===----------------------------------------------------------------------===//
143 // UniformityInfoWrapperPass Implementation
144 //===----------------------------------------------------------------------===//
145
// Definition of the static pass ID; it is passed to the FunctionPass base
// class by the UniformityInfoWrapperPass constructor.
char UniformityInfoWrapperPass::ID = 0;
147
UniformityInfoWrapperPass()148 UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {
149 initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
150 }
151
// Initialization boilerplate for UniformityInfoWrapperPass under the
// argument string "uniformity" and the description "Uniformity Analysis",
// declaring the dominator tree, cycle info and target transform info wrapper
// passes as dependencies.
// NOTE(review): the two trailing `true` arguments are presumably the
// cfg-only and is-analysis flags of the INITIALIZE_PASS_* macros; confirm
// against their definitions.
INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniformity",
                      "Uniformity Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
INITIALIZE_PASS_DEPENDENCY(CycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniformity",
                    "Uniformity Analysis", true, true)
159
160 void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
161 AU.setPreservesAll();
162 AU.addRequired<DominatorTreeWrapperPass>();
163 AU.addRequiredTransitive<CycleInfoWrapperPass>();
164 AU.addRequired<TargetTransformInfoWrapperPass>();
165 }
166
runOnFunction(Function & F)167 bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
168 auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
169 auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
170 auto &targetTransformInfo =
171 getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
172
173 m_function = &F;
174 m_uniformityInfo = UniformityInfo{domTree, cycleInfo, &targetTransformInfo};
175
176 // Skip computation if we can assume everything is uniform.
177 if (targetTransformInfo.hasBranchDivergence(m_function))
178 m_uniformityInfo.compute();
179
180 return false;
181 }
182
print(raw_ostream & OS,const Module *) const183 void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
184 OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
185 }
186
releaseMemory()187 void UniformityInfoWrapperPass::releaseMemory() {
188 m_uniformityInfo = UniformityInfo{};
189 m_function = nullptr;
190 }
191