1 //===- ConvergenceUtils.cpp -----------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #include "llvm/Analysis/UniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/Analysis/CycleAnalysis.h"
12 #include "llvm/Analysis/TargetTransformInfo.h"
13 #include "llvm/IR/Constants.h"
14 #include "llvm/IR/Dominators.h"
15 #include "llvm/IR/InstIterator.h"
16 #include "llvm/IR/Instructions.h"
17 #include "llvm/InitializePasses.h"
18 
19 using namespace llvm;
20 
21 template <>
22 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::hasDivergentDefs(
23     const Instruction &I) const {
24   return isDivergent((const Value *)&I);
25 }
26 
27 template <>
28 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::markDefsDivergent(
29     const Instruction &Instr, bool AllDefsDivergent) {
30   return markDivergent(&Instr);
31 }
32 
33 template <> void llvm::GenericUniformityAnalysisImpl<SSAContext>::initialize() {
34   for (auto &I : instructions(F)) {
35     if (TTI->isSourceOfDivergence(&I)) {
36       assert(!I.isTerminator());
37       markDivergent(I);
38     } else if (TTI->isAlwaysUniform(&I)) {
39       addUniformOverride(I);
40     }
41   }
42   for (auto &Arg : F.args()) {
43     if (TTI->isSourceOfDivergence(&Arg)) {
44       markDivergent(&Arg);
45     }
46   }
47 }
48 
49 template <>
50 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
51     const Value *V) {
52   for (const auto *User : V->users()) {
53     const auto *UserInstr = dyn_cast<const Instruction>(User);
54     if (!UserInstr)
55       continue;
56     if (isAlwaysUniform(*UserInstr))
57       continue;
58     if (markDivergent(*UserInstr)) {
59       Worklist.push_back(UserInstr);
60     }
61   }
62 }
63 
64 template <>
65 void llvm::GenericUniformityAnalysisImpl<SSAContext>::pushUsers(
66     const Instruction &Instr) {
67   assert(!isAlwaysUniform(Instr));
68   if (Instr.isTerminator())
69     return;
70   pushUsers(cast<Value>(&Instr));
71 }
72 
73 template <>
74 bool llvm::GenericUniformityAnalysisImpl<SSAContext>::usesValueFromCycle(
75     const Instruction &I, const Cycle &DefCycle) const {
76   if (isAlwaysUniform(I))
77     return false;
78   for (const Use &U : I.operands()) {
79     if (auto *I = dyn_cast<Instruction>(&U)) {
80       if (DefCycle.contains(I->getParent()))
81         return true;
82     }
83   }
84   return false;
85 }
86 
// Explicitly instantiate GenericUniformityInfo for LLVM IR (SSAContext) in
// this translation unit, where the SSAContext specializations defined above
// are visible.
// The second instantiation ensures the deleter used to destroy the
// GenericUniformityAnalysisImpl object (its operator()) is emitted here too.
template class llvm::GenericUniformityInfo<SSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<SSAContext>>;
92 
93 //===----------------------------------------------------------------------===//
94 //  UniformityInfoAnalysis and related pass implementations
95 //===----------------------------------------------------------------------===//
96 
97 llvm::UniformityInfo UniformityInfoAnalysis::run(Function &F,
98                                                  FunctionAnalysisManager &FAM) {
99   auto &DT = FAM.getResult<DominatorTreeAnalysis>(F);
100   auto &TTI = FAM.getResult<TargetIRAnalysis>(F);
101   auto &CI = FAM.getResult<CycleAnalysis>(F);
102   return UniformityInfo{F, DT, CI, &TTI};
103 }
104 
// Out-of-line definition of the analysis key for UniformityInfoAnalysis;
// its address identifies the analysis in the new pass manager.
AnalysisKey UniformityInfoAnalysis::Key;
106 
107 UniformityInfoPrinterPass::UniformityInfoPrinterPass(raw_ostream &OS)
108     : OS(OS) {}
109 
110 PreservedAnalyses UniformityInfoPrinterPass::run(Function &F,
111                                                  FunctionAnalysisManager &AM) {
112   OS << "UniformityInfo for function '" << F.getName() << "':\n";
113   AM.getResult<UniformityInfoAnalysis>(F).print(OS);
114 
115   return PreservedAnalyses::all();
116 }
117 
118 //===----------------------------------------------------------------------===//
119 //  UniformityInfoWrapperPass Implementation
120 //===----------------------------------------------------------------------===//
121 
// Legacy pass identifier. Its address, not its value, identifies the pass.
char UniformityInfoWrapperPass::ID = 0;
123 
124 UniformityInfoWrapperPass::UniformityInfoWrapperPass() : FunctionPass(ID) {
125   initializeUniformityInfoWrapperPassPass(*PassRegistry::getPassRegistry());
126 }
127 
128 INITIALIZE_PASS_BEGIN(UniformityInfoWrapperPass, "uniforminfo",
129                       "Uniform Info Analysis", true, true)
130 INITIALIZE_PASS_DEPENDENCY(DominatorTreeWrapperPass)
131 INITIALIZE_PASS_DEPENDENCY(TargetTransformInfoWrapperPass)
132 INITIALIZE_PASS_END(UniformityInfoWrapperPass, "uniforminfo",
133                     "Uniform Info Analysis", true, true)
134 
/// Declares the analyses this pass consumes: dominator tree, cycle info and
/// target transform info. runOnFunction() never changes the IR, so the pass
/// preserves everything.
void UniformityInfoWrapperPass::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.setPreservesAll();
  AU.addRequired<DominatorTreeWrapperPass>();
  AU.addRequired<CycleInfoWrapperPass>();
  AU.addRequired<TargetTransformInfoWrapperPass>();
}
141 
142 bool UniformityInfoWrapperPass::runOnFunction(Function &F) {
143   auto &cycleInfo = getAnalysis<CycleInfoWrapperPass>().getResult();
144   auto &domTree = getAnalysis<DominatorTreeWrapperPass>().getDomTree();
145   auto &targetTransformInfo =
146       getAnalysis<TargetTransformInfoWrapperPass>().getTTI(F);
147 
148   m_function = &F;
149   m_uniformityInfo =
150       UniformityInfo{F, domTree, cycleInfo, &targetTransformInfo};
151   return false;
152 }
153 
154 void UniformityInfoWrapperPass::print(raw_ostream &OS, const Module *) const {
155   OS << "UniformityInfo for function '" << m_function->getName() << "':\n";
156 }
157 
158 void UniformityInfoWrapperPass::releaseMemory() {
159   m_uniformityInfo = UniformityInfo{};
160   m_function = nullptr;
161 }
162