1 //===- MachineUniformityAnalysis.cpp --------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8
9 #include "llvm/CodeGen/MachineUniformityAnalysis.h"
10 #include "llvm/ADT/GenericUniformityImpl.h"
11 #include "llvm/CodeGen/MachineCycleAnalysis.h"
12 #include "llvm/CodeGen/MachineDominators.h"
13 #include "llvm/CodeGen/MachineRegisterInfo.h"
14 #include "llvm/CodeGen/MachineSSAContext.h"
15 #include "llvm/CodeGen/TargetInstrInfo.h"
16 #include "llvm/InitializePasses.h"
17
18 using namespace llvm;
19
20 template <>
hasDivergentDefs(const MachineInstr & I) const21 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::hasDivergentDefs(
22 const MachineInstr &I) const {
23 for (auto &op : I.all_defs()) {
24 if (isDivergent(op.getReg()))
25 return true;
26 }
27 return false;
28 }
29
30 template <>
markDefsDivergent(const MachineInstr & Instr)31 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::markDefsDivergent(
32 const MachineInstr &Instr) {
33 bool insertedDivergent = false;
34 const auto &MRI = F.getRegInfo();
35 const auto &RBI = *F.getSubtarget().getRegBankInfo();
36 const auto &TRI = *MRI.getTargetRegisterInfo();
37 for (auto &op : Instr.all_defs()) {
38 if (!op.getReg().isVirtual())
39 continue;
40 assert(!op.getSubReg());
41 if (TRI.isUniformReg(MRI, RBI, op.getReg()))
42 continue;
43 insertedDivergent |= markDivergent(op.getReg());
44 }
45 return insertedDivergent;
46 }
47
48 template <>
initialize()49 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::initialize() {
50 const auto &InstrInfo = *F.getSubtarget().getInstrInfo();
51
52 for (const MachineBasicBlock &block : F) {
53 for (const MachineInstr &instr : block) {
54 auto uniformity = InstrInfo.getInstructionUniformity(instr);
55 if (uniformity == InstructionUniformity::AlwaysUniform) {
56 addUniformOverride(instr);
57 continue;
58 }
59
60 if (uniformity == InstructionUniformity::NeverUniform) {
61 markDivergent(instr);
62 }
63 }
64 }
65 }
66
67 template <>
pushUsers(Register Reg)68 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
69 Register Reg) {
70 assert(isDivergent(Reg));
71 const auto &RegInfo = F.getRegInfo();
72 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
73 markDivergent(UserInstr);
74 }
75 }
76
77 template <>
pushUsers(const MachineInstr & Instr)78 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::pushUsers(
79 const MachineInstr &Instr) {
80 assert(!isAlwaysUniform(Instr));
81 if (Instr.isTerminator())
82 return;
83 for (const MachineOperand &op : Instr.all_defs()) {
84 auto Reg = op.getReg();
85 if (isDivergent(Reg))
86 pushUsers(Reg);
87 }
88 }
89
90 template <>
usesValueFromCycle(const MachineInstr & I,const MachineCycle & DefCycle) const91 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::usesValueFromCycle(
92 const MachineInstr &I, const MachineCycle &DefCycle) const {
93 assert(!isAlwaysUniform(I));
94 for (auto &Op : I.operands()) {
95 if (!Op.isReg() || !Op.readsReg())
96 continue;
97 auto Reg = Op.getReg();
98
99 // FIXME: Physical registers need to be properly checked instead of always
100 // returning true
101 if (Reg.isPhysical())
102 return true;
103
104 auto *Def = F.getRegInfo().getVRegDef(Reg);
105 if (DefCycle.contains(Def->getParent()))
106 return true;
107 }
108 return false;
109 }
110
111 template <>
112 void llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::
propagateTemporalDivergence(const MachineInstr & I,const MachineCycle & DefCycle)113 propagateTemporalDivergence(const MachineInstr &I,
114 const MachineCycle &DefCycle) {
115 const auto &RegInfo = F.getRegInfo();
116 for (auto &Op : I.all_defs()) {
117 if (!Op.getReg().isVirtual())
118 continue;
119 auto Reg = Op.getReg();
120 if (isDivergent(Reg))
121 continue;
122 for (MachineInstr &UserInstr : RegInfo.use_instructions(Reg)) {
123 if (DefCycle.contains(UserInstr.getParent()))
124 continue;
125 markDivergent(UserInstr);
126 }
127 }
128 }
129
130 template <>
isDivergentUse(const MachineOperand & U) const131 bool llvm::GenericUniformityAnalysisImpl<MachineSSAContext>::isDivergentUse(
132 const MachineOperand &U) const {
133 if (!U.isReg())
134 return false;
135
136 auto Reg = U.getReg();
137 if (isDivergent(Reg))
138 return true;
139
140 const auto &RegInfo = F.getRegInfo();
141 auto *Def = RegInfo.getOneDef(Reg);
142 if (!Def)
143 return true;
144
145 auto *DefInstr = Def->getParent();
146 auto *UseInstr = U.getParent();
147 return isTemporalDivergent(*UseInstr->getParent(), *DefInstr);
148 }
149
// Explicit instantiations for the machine SSA context. Among other things,
// this ensures GenericUniformityAnalysisImplDeleter::operator() is
// instantiated in this translation unit for
// GenericUniformityAnalysisImpl<MachineSSAContext>.
template class llvm::GenericUniformityInfo<MachineSSAContext>;
template struct llvm::GenericUniformityAnalysisImplDeleter<
    llvm::GenericUniformityAnalysisImpl<MachineSSAContext>>;
155
computeMachineUniformityInfo(MachineFunction & F,const MachineCycleInfo & cycleInfo,const MachineDomTree & domTree,bool HasBranchDivergence)156 MachineUniformityInfo llvm::computeMachineUniformityInfo(
157 MachineFunction &F, const MachineCycleInfo &cycleInfo,
158 const MachineDomTree &domTree, bool HasBranchDivergence) {
159 assert(F.getRegInfo().isSSA() && "Expected to be run on SSA form!");
160 MachineUniformityInfo UI(domTree, cycleInfo);
161 if (HasBranchDivergence)
162 UI.compute();
163 return UI;
164 }
165
166 namespace {
167
168 /// Legacy analysis pass which computes a \ref MachineUniformityInfo.
169 class MachineUniformityAnalysisPass : public MachineFunctionPass {
170 MachineUniformityInfo UI;
171
172 public:
173 static char ID;
174
175 MachineUniformityAnalysisPass();
176
getUniformityInfo()177 MachineUniformityInfo &getUniformityInfo() { return UI; }
getUniformityInfo() const178 const MachineUniformityInfo &getUniformityInfo() const { return UI; }
179
180 bool runOnMachineFunction(MachineFunction &F) override;
181 void getAnalysisUsage(AnalysisUsage &AU) const override;
182 void print(raw_ostream &OS, const Module *M = nullptr) const override;
183
184 // TODO: verify analysis
185 };
186
187 class MachineUniformityInfoPrinterPass : public MachineFunctionPass {
188 public:
189 static char ID;
190
191 MachineUniformityInfoPrinterPass();
192
193 bool runOnMachineFunction(MachineFunction &F) override;
194 void getAnalysisUsage(AnalysisUsage &AU) const override;
195 };
196
197 } // namespace
198
199 char MachineUniformityAnalysisPass::ID = 0;
200
MachineUniformityAnalysisPass()201 MachineUniformityAnalysisPass::MachineUniformityAnalysisPass()
202 : MachineFunctionPass(ID) {
203 initializeMachineUniformityAnalysisPassPass(*PassRegistry::getPassRegistry());
204 }
205
// Register the legacy analysis pass under the command-line name
// "machine-uniformity", declaring the cycle-info and dominator-tree analyses
// it depends on. Per the INITIALIZE_PASS_BEGIN signature, the trailing
// `true, true` arguments are the CFG-only and is-analysis flags.
INITIALIZE_PASS_BEGIN(MachineUniformityAnalysisPass, "machine-uniformity",
                      "Machine Uniformity Info Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineCycleInfoWrapperPass)
INITIALIZE_PASS_DEPENDENCY(MachineDominatorTree)
INITIALIZE_PASS_END(MachineUniformityAnalysisPass, "machine-uniformity",
                    "Machine Uniformity Info Analysis", true, true)
212
213 void MachineUniformityAnalysisPass::getAnalysisUsage(AnalysisUsage &AU) const {
214 AU.setPreservesAll();
215 AU.addRequired<MachineCycleInfoWrapperPass>();
216 AU.addRequired<MachineDominatorTree>();
217 MachineFunctionPass::getAnalysisUsage(AU);
218 }
219
runOnMachineFunction(MachineFunction & MF)220 bool MachineUniformityAnalysisPass::runOnMachineFunction(MachineFunction &MF) {
221 auto &DomTree = getAnalysis<MachineDominatorTree>().getBase();
222 auto &CI = getAnalysis<MachineCycleInfoWrapperPass>().getCycleInfo();
223 // FIXME: Query TTI::hasBranchDivergence. -run-pass seems to end up with a
224 // default NoTTI
225 UI = computeMachineUniformityInfo(MF, CI, DomTree, true);
226 return false;
227 }
228
print(raw_ostream & OS,const Module *) const229 void MachineUniformityAnalysisPass::print(raw_ostream &OS,
230 const Module *) const {
231 OS << "MachineUniformityInfo for function: " << UI.getFunction().getName()
232 << "\n";
233 UI.print(OS);
234 }
235
236 char MachineUniformityInfoPrinterPass::ID = 0;
237
MachineUniformityInfoPrinterPass()238 MachineUniformityInfoPrinterPass::MachineUniformityInfoPrinterPass()
239 : MachineFunctionPass(ID) {
240 initializeMachineUniformityInfoPrinterPassPass(
241 *PassRegistry::getPassRegistry());
242 }
243
// Register the printer pass under the command-line name
// "print-machine-uniformity". It depends on MachineUniformityAnalysisPass,
// whose result it prints. Per the INITIALIZE_PASS_BEGIN signature, the
// trailing `true, true` arguments are the CFG-only and is-analysis flags.
INITIALIZE_PASS_BEGIN(MachineUniformityInfoPrinterPass,
                      "print-machine-uniformity",
                      "Print Machine Uniformity Info Analysis", true, true)
INITIALIZE_PASS_DEPENDENCY(MachineUniformityAnalysisPass)
INITIALIZE_PASS_END(MachineUniformityInfoPrinterPass,
                    "print-machine-uniformity",
                    "Print Machine Uniformity Info Analysis", true, true)
251
252 void MachineUniformityInfoPrinterPass::getAnalysisUsage(
253 AnalysisUsage &AU) const {
254 AU.setPreservesAll();
255 AU.addRequired<MachineUniformityAnalysisPass>();
256 MachineFunctionPass::getAnalysisUsage(AU);
257 }
258
runOnMachineFunction(MachineFunction & F)259 bool MachineUniformityInfoPrinterPass::runOnMachineFunction(
260 MachineFunction &F) {
261 auto &UI = getAnalysis<MachineUniformityAnalysisPass>();
262 UI.print(errs());
263 return false;
264 }
265