1 //===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis collects instructions that should be output at the module level
10 // and performs the global register numbering.
11 //
12 // The results of this analysis are used in AsmPrinter to rename registers
13 // globally and to output required instructions at the module level.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVModuleAnalysis.h"
18 #include "SPIRV.h"
19 #include "SPIRVGlobalRegistry.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVTargetMachine.h"
22 #include "SPIRVUtils.h"
23 #include "TargetInfo/SPIRVTargetInfo.h"
24 #include "llvm/CodeGen/MachineModuleInfo.h"
25 #include "llvm/CodeGen/TargetPassConfig.h"
26 
27 using namespace llvm;
28 
29 #define DEBUG_TYPE "spirv-module-analysis"
30 
// -spv-dump-deps: when enabled, processDefInstrs hands the MachineModuleInfo
// to GR->buildDepsGraph (otherwise it passes nullptr) so that MIR can be
// dumped together with the SPIR-V dependency information. Off by default.
static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(false));

// Legacy pass-manager identifier for SPIRVModuleAnalysis.
char llvm::SPIRVModuleAnalysis::ID = 0;

// Forward declaration of the pass initializer; its definition is produced by
// the INITIALIZE_PASS macro below.
namespace llvm {
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
} // namespace llvm

// Register the pass under the command-line name DEBUG_TYPE
// ("spirv-module-analysis"). Per INITIALIZE_PASS(passName, arg, name, cfg,
// analysis), the two trailing `true`s mark it as CFG-only and as an analysis.
INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)
44 
45 // Retrieve an unsigned from an MDNode with a list of them as operands.
46 static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
47                                 unsigned DefaultVal = 0) {
48   if (MdNode && OpIndex < MdNode->getNumOperands()) {
49     const auto &Op = MdNode->getOperand(OpIndex);
50     return mdconst::extract<ConstantInt>(Op)->getZExtValue();
51   }
52   return DefaultVal;
53 }
54 
55 void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
56   MAI.MaxID = 0;
57   for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
58     MAI.MS[i].clear();
59   MAI.RegisterAliasTable.clear();
60   MAI.InstrsToDelete.clear();
61   MAI.FuncNameMap.clear();
62   MAI.GlobalVarList.clear();
63 
64   // TODO: determine memory model and source language from the configuratoin.
65   MAI.Mem = SPIRV::MemoryModel::OpenCL;
66   MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
67   unsigned PtrSize = ST->getPointerSize();
68   MAI.Addr = PtrSize == 32   ? SPIRV::AddressingModel::Physical32
69              : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
70                              : SPIRV::AddressingModel::Logical;
71   // Get the OpenCL version number from metadata.
72   // TODO: support other source languages.
73   MAI.SrcLangVersion = 0;
74   if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
75     // Construct version literal according to OpenCL 2.2 environment spec.
76     auto VersionMD = VerNode->getOperand(0);
77     unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
78     unsigned MinorNum = getMetadataUInt(VersionMD, 1);
79     unsigned RevNum = getMetadataUInt(VersionMD, 2);
80     MAI.SrcLangVersion = 0 | (MajorNum << 16) | (MinorNum << 8) | RevNum;
81   }
82 }
83 
84 // True if there is an instruction in the MS list with all the same operands as
85 // the given instruction has (after the given starting index).
86 // TODO: maybe it needs to check Opcodes too.
87 static bool findSameInstrInMS(const MachineInstr &A,
88                               SPIRV::ModuleSectionType MSType,
89                               SPIRV::ModuleAnalysisInfo &MAI,
90                               bool UpdateRegAliases,
91                               unsigned StartOpIndex = 0) {
92   for (const auto *B : MAI.MS[MSType]) {
93     const unsigned NumAOps = A.getNumOperands();
94     if (NumAOps == B->getNumOperands() && A.getNumDefs() == B->getNumDefs()) {
95       bool AllOpsMatch = true;
96       for (unsigned i = StartOpIndex; i < NumAOps && AllOpsMatch; ++i) {
97         if (A.getOperand(i).isReg() && B->getOperand(i).isReg()) {
98           Register RegA = A.getOperand(i).getReg();
99           Register RegB = B->getOperand(i).getReg();
100           AllOpsMatch = MAI.getRegisterAlias(A.getMF(), RegA) ==
101                         MAI.getRegisterAlias(B->getMF(), RegB);
102         } else {
103           AllOpsMatch = A.getOperand(i).isIdenticalTo(B->getOperand(i));
104         }
105       }
106       if (AllOpsMatch) {
107         if (UpdateRegAliases) {
108           assert(A.getOperand(0).isReg() && B->getOperand(0).isReg());
109           Register LocalReg = A.getOperand(0).getReg();
110           Register GlobalReg =
111               MAI.getRegisterAlias(B->getMF(), B->getOperand(0).getReg());
112           MAI.setRegisterAlias(A.getMF(), LocalReg, GlobalReg);
113         }
114         return true;
115       }
116     }
117   }
118   return false;
119 }
120 
121 // Collect MI which defines the register in the given machine function.
122 static void collectDefInstr(Register Reg, const MachineFunction *MF,
123                             SPIRV::ModuleAnalysisInfo *MAI,
124                             SPIRV::ModuleSectionType MSType,
125                             bool DoInsert = true) {
126   assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
127   MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
128   assert(MI && "There should be an instruction that defines the register");
129   MAI->setSkipEmission(MI);
130   if (DoInsert)
131     MAI->MS[MSType].push_back(MI);
132 }
133 
134 void SPIRVModuleAnalysis::collectGlobalEntities(
135     const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
136     SPIRV::ModuleSectionType MSType,
137     std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
138     bool UsePreOrder) {
139   DenseSet<const SPIRV::DTSortableEntry *> Visited;
140   for (const auto *E : DepsGraph) {
141     std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
142     // NOTE: here we prefer recursive approach over iterative because
143     // we don't expect depchains long enough to cause SO.
144     RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
145                     &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
146       if (Visited.count(E) || !Pred(E))
147         return;
148       Visited.insert(E);
149 
150       // Traversing deps graph in post-order allows us to get rid of
151       // register aliases preprocessing.
152       // But pre-order is required for correct processing of function
153       // declaration and arguments processing.
154       if (!UsePreOrder)
155         for (auto *S : E->getDeps())
156           RecHoistUtil(S);
157 
158       Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
159       bool IsFirst = true;
160       for (auto &U : *E) {
161         const MachineFunction *MF = U.first;
162         Register Reg = U.second;
163         MAI.setRegisterAlias(MF, Reg, GlobalReg);
164         if (!MF->getRegInfo().getUniqueVRegDef(Reg))
165           continue;
166         collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
167         IsFirst = false;
168         if (E->getIsGV())
169           MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
170       }
171 
172       if (UsePreOrder)
173         for (auto *S : E->getDeps())
174           RecHoistUtil(S);
175     };
176     RecHoistUtil(E);
177   }
178 }
179 
180 // The function initializes global register alias table for types, consts,
181 // global vars and func decls and collects these instruction for output
182 // at module level. Also it collects explicit OpExtension/OpCapability
183 // instructions.
184 void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
185   std::vector<SPIRV::DTSortableEntry *> DepsGraph;
186 
187   GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);
188 
189   collectGlobalEntities(
190       DepsGraph, SPIRV::MB_TypeConstVars,
191       [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); }, false);
192 
193   collectGlobalEntities(
194       DepsGraph, SPIRV::MB_ExtFuncDecls,
195       [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
196 }
197 
198 // Look for IDs declared with Import linkage, and map the imported name string
199 // to the register defining that variable (which will usually be the result of
200 // an OpFunction). This lets us call externally imported functions using
201 // the correct ID registers.
202 void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
203                                            const Function &F) {
204   if (MI.getOpcode() == SPIRV::OpDecorate) {
205     // If it's got Import linkage.
206     auto Dec = MI.getOperand(1).getImm();
207     if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
208       auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
209       if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
210         // Map imported function name to function ID register.
211         std::string Name = getStringImm(MI, 2);
212         Register Target = MI.getOperand(0).getReg();
213         // TODO: check defs from different MFs.
214         MAI.FuncNameMap[Name] = MAI.getRegisterAlias(MI.getMF(), Target);
215       }
216     }
217   } else if (MI.getOpcode() == SPIRV::OpFunction) {
218     // Record all internal OpFunction declarations.
219     Register Reg = MI.defs().begin()->getReg();
220     Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
221     assert(GlobalReg.isValid());
222     // TODO: check that it does not conflict with existing entries.
223     MAI.FuncNameMap[F.getGlobalIdentifier()] = GlobalReg;
224   }
225 }
226 
227 // Collect the given instruction in the specified MS. We assume global register
228 // numbering has already occurred by this point. We can directly compare reg
229 // arguments when detecting duplicates.
230 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
231                               SPIRV::ModuleSectionType MSType) {
232   MAI.setSkipEmission(&MI);
233   if (findSameInstrInMS(MI, MSType, MAI, false))
234     return; // Found a duplicate, so don't add it.
235   // No duplicates, so add it.
236   MAI.MS[MSType].push_back(&MI);
237 }
238 
239 // Some global instructions make reference to function-local ID regs, so cannot
240 // be correctly collected until these registers are globally numbered.
241 void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
242   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
243     if ((*F).isDeclaration())
244       continue;
245     MachineFunction *MF = MMI->getMachineFunction(*F);
246     assert(MF);
247     for (MachineBasicBlock &MBB : *MF)
248       for (MachineInstr &MI : MBB) {
249         if (MAI.getSkipEmission(&MI))
250           continue;
251         const unsigned OpCode = MI.getOpcode();
252         if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
253           collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames);
254         } else if (OpCode == SPIRV::OpEntryPoint) {
255           collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints);
256         } else if (TII->isDecorationInstr(MI)) {
257           collectOtherInstr(MI, MAI, SPIRV::MB_Annotations);
258           collectFuncNames(MI, *F);
259         } else if (OpCode == SPIRV::OpFunction) {
260           collectFuncNames(MI, *F);
261         }
262       }
263   }
264 }
265 
266 // Number registers in all functions globally from 0 onwards and store
267 // the result in global register alias table.
268 void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
269   for (auto F = M.begin(), E = M.end(); F != E; ++F) {
270     if ((*F).isDeclaration())
271       continue;
272     MachineFunction *MF = MMI->getMachineFunction(*F);
273     assert(MF);
274     for (MachineBasicBlock &MBB : *MF) {
275       for (MachineInstr &MI : MBB) {
276         for (MachineOperand &Op : MI.operands()) {
277           if (!Op.isReg())
278             continue;
279           Register Reg = Op.getReg();
280           if (MAI.hasRegisterAlias(MF, Reg))
281             continue;
282           Register NewReg = Register::index2VirtReg(MAI.getNextID());
283           MAI.setRegisterAlias(MF, Reg, NewReg);
284         }
285       }
286     }
287   }
288 }
289 
// Out-of-class definition of the static analysis result. Because it is a
// static member, it outlives each pass run. setBaseInfo re-initializes its
// fields at the start of every run, and its consumers (per the file header,
// the AsmPrinter) read it afterwards.
struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
291 
// Declare the analyses runOnModule fetches: TargetPassConfig (to reach the
// SPIRVTargetMachine and its subtarget) and MachineModuleInfo (to map each IR
// function to its MachineFunction).
void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
  AU.addRequired<TargetPassConfig>();
  AU.addRequired<MachineModuleInfoWrapperPass>();
}
296 
297 bool SPIRVModuleAnalysis::runOnModule(Module &M) {
298   SPIRVTargetMachine &TM =
299       getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
300   ST = TM.getSubtargetImpl();
301   GR = ST->getSPIRVGlobalRegistry();
302   TII = ST->getInstrInfo();
303 
304   MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();
305 
306   setBaseInfo(M);
307 
308   // TODO: Process type/const/global var/func decl instructions, number their
309   // destination registers from 0 to N, collect Extensions and Capabilities.
310   processDefInstrs(M);
311 
312   // Number rest of registers from N+1 onwards.
313   numberRegistersGlobally(M);
314 
315   // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
316   processOtherInstrs(M);
317 
318   return false;
319 }
320