1 //===- SPIRVModuleAnalysis.cpp - analysis of global instrs & regs - C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // The analysis collects instructions that should be output at the module level
10 // and performs the global register numbering.
11 //
12 // The results of this analysis are used in AsmPrinter to rename registers
13 // globally and to output required instructions at the module level.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "SPIRVModuleAnalysis.h"
18 #include "SPIRV.h"
19 #include "SPIRVSubtarget.h"
20 #include "SPIRVTargetMachine.h"
21 #include "SPIRVUtils.h"
22 #include "TargetInfo/SPIRVTargetInfo.h"
23 #include "llvm/ADT/STLExtras.h"
24 #include "llvm/CodeGen/MachineModuleInfo.h"
25 #include "llvm/CodeGen/TargetPassConfig.h"
26
27 using namespace llvm;
28
29 #define DEBUG_TYPE "spirv-module-analysis"
30
// Command-line switch to dump the MIR together with the SPIR-V dependency
// info gathered by this analysis (useful when debugging the deps graph).
static cl::opt<bool>
    SPVDumpDeps("spv-dump-deps",
                cl::desc("Dump MIR with SPIR-V dependencies info"),
                cl::Optional, cl::init(false));

// Pass identification: the address of ID serves as the unique pass token.
char llvm::SPIRVModuleAnalysis::ID = 0;

namespace llvm {
void initializeSPIRVModuleAnalysisPass(PassRegistry &);
} // namespace llvm

INITIALIZE_PASS(SPIRVModuleAnalysis, DEBUG_TYPE, "SPIRV module analysis", true,
                true)
44
45 // Retrieve an unsigned from an MDNode with a list of them as operands.
getMetadataUInt(MDNode * MdNode,unsigned OpIndex,unsigned DefaultVal=0)46 static unsigned getMetadataUInt(MDNode *MdNode, unsigned OpIndex,
47 unsigned DefaultVal = 0) {
48 if (MdNode && OpIndex < MdNode->getNumOperands()) {
49 const auto &Op = MdNode->getOperand(OpIndex);
50 return mdconst::extract<ConstantInt>(Op)->getZExtValue();
51 }
52 return DefaultVal;
53 }
54
55 static SPIRV::Requirements
getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,unsigned i,const SPIRVSubtarget & ST,SPIRV::RequirementHandler & Reqs)56 getSymbolicOperandRequirements(SPIRV::OperandCategory::OperandCategory Category,
57 unsigned i, const SPIRVSubtarget &ST,
58 SPIRV::RequirementHandler &Reqs) {
59 unsigned ReqMinVer = getSymbolicOperandMinVersion(Category, i);
60 unsigned ReqMaxVer = getSymbolicOperandMaxVersion(Category, i);
61 unsigned TargetVer = ST.getSPIRVVersion();
62 bool MinVerOK = !ReqMinVer || !TargetVer || TargetVer >= ReqMinVer;
63 bool MaxVerOK = !ReqMaxVer || !TargetVer || TargetVer <= ReqMaxVer;
64 CapabilityList ReqCaps = getSymbolicOperandCapabilities(Category, i);
65 ExtensionList ReqExts = getSymbolicOperandExtensions(Category, i);
66 if (ReqCaps.empty()) {
67 if (ReqExts.empty()) {
68 if (MinVerOK && MaxVerOK)
69 return {true, {}, {}, ReqMinVer, ReqMaxVer};
70 return {false, {}, {}, 0, 0};
71 }
72 } else if (MinVerOK && MaxVerOK) {
73 for (auto Cap : ReqCaps) { // Only need 1 of the capabilities to work.
74 if (Reqs.isCapabilityAvailable(Cap))
75 return {true, {Cap}, {}, ReqMinVer, ReqMaxVer};
76 }
77 }
78 // If there are no capabilities, or we can't satisfy the version or
79 // capability requirements, use the list of extensions (if the subtarget
80 // can handle them all).
81 if (llvm::all_of(ReqExts, [&ST](const SPIRV::Extension::Extension &Ext) {
82 return ST.canUseExtension(Ext);
83 })) {
84 return {true, {}, ReqExts, 0, 0}; // TODO: add versions to extensions.
85 }
86 return {false, {}, {}, 0, 0};
87 }
88
// Seed the module-level analysis info (MAI): reset all previously collected
// state, then derive the addressing model, memory model and source language
// from module metadata (falling back to subtarget defaults), registering the
// capabilities those selections require.
void SPIRVModuleAnalysis::setBaseInfo(const Module &M) {
  MAI.MaxID = 0;
  for (int i = 0; i < SPIRV::NUM_MODULE_SECTIONS; i++)
    MAI.MS[i].clear();
  MAI.RegisterAliasTable.clear();
  MAI.InstrsToDelete.clear();
  MAI.FuncMap.clear();
  MAI.GlobalVarList.clear();
  MAI.ExtInstSetMap.clear();
  MAI.Reqs.clear();
  MAI.Reqs.initAvailableCapabilities(*ST);

  // TODO: determine memory model and source language from the configuration.
  if (auto MemModel = M.getNamedMetadata("spirv.MemoryModel")) {
    // Explicit "spirv.MemoryModel" metadata: operand 0 is the addressing
    // model, operand 1 is the memory model.
    auto MemMD = MemModel->getOperand(0);
    MAI.Addr = static_cast<SPIRV::AddressingModel::AddressingModel>(
        getMetadataUInt(MemMD, 0));
    MAI.Mem =
        static_cast<SPIRV::MemoryModel::MemoryModel>(getMetadataUInt(MemMD, 1));
  } else {
    // No metadata: default to OpenCL and pick the addressing model from the
    // target pointer width.
    MAI.Mem = SPIRV::MemoryModel::OpenCL;
    unsigned PtrSize = ST->getPointerSize();
    MAI.Addr = PtrSize == 32 ? SPIRV::AddressingModel::Physical32
               : PtrSize == 64 ? SPIRV::AddressingModel::Physical64
               : SPIRV::AddressingModel::Logical;
  }
  // Get the OpenCL version number from metadata.
  // TODO: support other source languages.
  if (auto VerNode = M.getNamedMetadata("opencl.ocl.version")) {
    MAI.SrcLang = SPIRV::SourceLanguage::OpenCL_C;
    // Construct version literal in accordance with SPIRV-LLVM-Translator.
    // TODO: support multiple OCL version metadata.
    assert(VerNode->getNumOperands() > 0 && "Invalid SPIR");
    auto VersionMD = VerNode->getOperand(0);
    unsigned MajorNum = getMetadataUInt(VersionMD, 0, 2);
    unsigned MinorNum = getMetadataUInt(VersionMD, 1);
    unsigned RevNum = getMetadataUInt(VersionMD, 2);
    // Encoded as (Major * 100 + Minor) * 1000 + Revision.
    MAI.SrcLangVersion = (MajorNum * 100 + MinorNum) * 1000 + RevNum;
  } else {
    MAI.SrcLang = SPIRV::SourceLanguage::Unknown;
    MAI.SrcLangVersion = 0;
  }

  if (auto ExtNode = M.getNamedMetadata("opencl.used.extensions")) {
    // Record every extension string listed in the metadata nodes.
    for (unsigned I = 0, E = ExtNode->getNumOperands(); I != E; ++I) {
      MDNode *MD = ExtNode->getOperand(I);
      if (!MD || MD->getNumOperands() == 0)
        continue;
      for (unsigned J = 0, N = MD->getNumOperands(); J != N; ++J)
        MAI.SrcExt.insert(cast<MDString>(MD->getOperand(J))->getString());
    }
  }

  // Update required capabilities for this memory model, addressing model and
  // source language.
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand,
                                 MAI.Mem, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::SourceLanguageOperand,
                                 MAI.SrcLang, *ST);
  MAI.Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
                                 MAI.Addr, *ST);

  // TODO: check if it's required by default.
  MAI.ExtInstSetMap[static_cast<unsigned>(SPIRV::InstructionSet::OpenCL_std)] =
      Register::index2VirtReg(MAI.getNextID());
}
155
156 // Collect MI which defines the register in the given machine function.
collectDefInstr(Register Reg,const MachineFunction * MF,SPIRV::ModuleAnalysisInfo * MAI,SPIRV::ModuleSectionType MSType,bool DoInsert=true)157 static void collectDefInstr(Register Reg, const MachineFunction *MF,
158 SPIRV::ModuleAnalysisInfo *MAI,
159 SPIRV::ModuleSectionType MSType,
160 bool DoInsert = true) {
161 assert(MAI->hasRegisterAlias(MF, Reg) && "Cannot find register alias");
162 MachineInstr *MI = MF->getRegInfo().getUniqueVRegDef(Reg);
163 assert(MI && "There should be an instruction that defines the register");
164 MAI->setSkipEmission(MI);
165 if (DoInsert)
166 MAI->MS[MSType].push_back(MI);
167 }
168
// Walk DepsGraph and, for every entry satisfying Pred, assign one fresh
// global virtual register shared by all of the entry's per-MF occurrences,
// record the per-MF register aliases, and collect the defining instructions
// into module section MSType. Dependencies are visited post-order by default,
// or pre-order when UsePreOrder is set.
void SPIRVModuleAnalysis::collectGlobalEntities(
    const std::vector<SPIRV::DTSortableEntry *> &DepsGraph,
    SPIRV::ModuleSectionType MSType,
    std::function<bool(const SPIRV::DTSortableEntry *)> Pred,
    bool UsePreOrder = false) {
  DenseSet<const SPIRV::DTSortableEntry *> Visited;
  for (const auto *E : DepsGraph) {
    std::function<void(const SPIRV::DTSortableEntry *)> RecHoistUtil;
    // NOTE: here we prefer recursive approach over iterative because
    // we don't expect depchains long enough to cause SO.
    RecHoistUtil = [MSType, UsePreOrder, &Visited, &Pred,
                    &RecHoistUtil](const SPIRV::DTSortableEntry *E) {
      if (Visited.count(E) || !Pred(E))
        return;
      Visited.insert(E);

      // Traversing deps graph in post-order allows us to get rid of
      // register aliases preprocessing.
      // But pre-order is required for correct processing of function
      // declaration and arguments processing.
      if (!UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);

      // One shared global register represents this entity in every MF.
      Register GlobalReg = Register::index2VirtReg(MAI.getNextID());
      bool IsFirst = true;
      for (auto &U : *E) {
        const MachineFunction *MF = U.first;
        Register Reg = U.second;
        MAI.setRegisterAlias(MF, Reg, GlobalReg);
        if (!MF->getRegInfo().getUniqueVRegDef(Reg))
          continue;
        // Only the first occurrence is actually inserted into the section;
        // later ones are just marked as skipped.
        collectDefInstr(Reg, MF, &MAI, MSType, IsFirst);
        IsFirst = false;
        if (E->getIsGV())
          MAI.GlobalVarList.push_back(MF->getRegInfo().getUniqueVRegDef(Reg));
      }

      if (UsePreOrder)
        for (auto *S : E->getDeps())
          RecHoistUtil(S);
    };
    RecHoistUtil(E);
  }
}
214
215 // The function initializes global register alias table for types, consts,
216 // global vars and func decls and collects these instruction for output
217 // at module level. Also it collects explicit OpExtension/OpCapability
218 // instructions.
processDefInstrs(const Module & M)219 void SPIRVModuleAnalysis::processDefInstrs(const Module &M) {
220 std::vector<SPIRV::DTSortableEntry *> DepsGraph;
221
222 GR->buildDepsGraph(DepsGraph, SPVDumpDeps ? MMI : nullptr);
223
224 collectGlobalEntities(
225 DepsGraph, SPIRV::MB_TypeConstVars,
226 [](const SPIRV::DTSortableEntry *E) { return !E->getIsFunc(); });
227
228 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
229 MachineFunction *MF = MMI->getMachineFunction(*F);
230 if (!MF)
231 continue;
232 // Iterate through and collect OpExtension/OpCapability instructions.
233 for (MachineBasicBlock &MBB : *MF) {
234 for (MachineInstr &MI : MBB) {
235 if (MI.getOpcode() == SPIRV::OpExtension) {
236 // Here, OpExtension just has a single enum operand, not a string.
237 auto Ext = SPIRV::Extension::Extension(MI.getOperand(0).getImm());
238 MAI.Reqs.addExtension(Ext);
239 MAI.setSkipEmission(&MI);
240 } else if (MI.getOpcode() == SPIRV::OpCapability) {
241 auto Cap = SPIRV::Capability::Capability(MI.getOperand(0).getImm());
242 MAI.Reqs.addCapability(Cap);
243 MAI.setSkipEmission(&MI);
244 }
245 }
246 }
247 }
248
249 collectGlobalEntities(
250 DepsGraph, SPIRV::MB_ExtFuncDecls,
251 [](const SPIRV::DTSortableEntry *E) { return E->getIsFunc(); }, true);
252 }
253
254 // True if there is an instruction in the MS list with all the same operands as
255 // the given instruction has (after the given starting index).
256 // TODO: maybe it needs to check Opcodes too.
findSameInstrInMS(const MachineInstr & A,SPIRV::ModuleSectionType MSType,SPIRV::ModuleAnalysisInfo & MAI,unsigned StartOpIndex=0)257 static bool findSameInstrInMS(const MachineInstr &A,
258 SPIRV::ModuleSectionType MSType,
259 SPIRV::ModuleAnalysisInfo &MAI,
260 unsigned StartOpIndex = 0) {
261 for (const auto *B : MAI.MS[MSType]) {
262 const unsigned NumAOps = A.getNumOperands();
263 if (NumAOps != B->getNumOperands() || A.getNumDefs() != B->getNumDefs())
264 continue;
265 bool AllOpsMatch = true;
266 for (unsigned i = StartOpIndex; i < NumAOps && AllOpsMatch; ++i) {
267 if (A.getOperand(i).isReg() && B->getOperand(i).isReg()) {
268 Register RegA = A.getOperand(i).getReg();
269 Register RegB = B->getOperand(i).getReg();
270 AllOpsMatch = MAI.getRegisterAlias(A.getMF(), RegA) ==
271 MAI.getRegisterAlias(B->getMF(), RegB);
272 } else {
273 AllOpsMatch = A.getOperand(i).isIdenticalTo(B->getOperand(i));
274 }
275 }
276 if (AllOpsMatch)
277 return true;
278 }
279 return false;
280 }
281
282 // Look for IDs declared with Import linkage, and map the corresponding function
283 // to the register defining that variable (which will usually be the result of
284 // an OpFunction). This lets us call externally imported functions using
285 // the correct ID registers.
collectFuncNames(MachineInstr & MI,const Function * F)286 void SPIRVModuleAnalysis::collectFuncNames(MachineInstr &MI,
287 const Function *F) {
288 if (MI.getOpcode() == SPIRV::OpDecorate) {
289 // If it's got Import linkage.
290 auto Dec = MI.getOperand(1).getImm();
291 if (Dec == static_cast<unsigned>(SPIRV::Decoration::LinkageAttributes)) {
292 auto Lnk = MI.getOperand(MI.getNumOperands() - 1).getImm();
293 if (Lnk == static_cast<unsigned>(SPIRV::LinkageType::Import)) {
294 // Map imported function name to function ID register.
295 const Function *ImportedFunc =
296 F->getParent()->getFunction(getStringImm(MI, 2));
297 Register Target = MI.getOperand(0).getReg();
298 MAI.FuncMap[ImportedFunc] = MAI.getRegisterAlias(MI.getMF(), Target);
299 }
300 }
301 } else if (MI.getOpcode() == SPIRV::OpFunction) {
302 // Record all internal OpFunction declarations.
303 Register Reg = MI.defs().begin()->getReg();
304 Register GlobalReg = MAI.getRegisterAlias(MI.getMF(), Reg);
305 assert(GlobalReg.isValid());
306 MAI.FuncMap[F] = GlobalReg;
307 }
308 }
309
310 // Collect the given instruction in the specified MS. We assume global register
311 // numbering has already occurred by this point. We can directly compare reg
312 // arguments when detecting duplicates.
collectOtherInstr(MachineInstr & MI,SPIRV::ModuleAnalysisInfo & MAI,SPIRV::ModuleSectionType MSType,bool Append=true)313 static void collectOtherInstr(MachineInstr &MI, SPIRV::ModuleAnalysisInfo &MAI,
314 SPIRV::ModuleSectionType MSType,
315 bool Append = true) {
316 MAI.setSkipEmission(&MI);
317 if (findSameInstrInMS(MI, MSType, MAI))
318 return; // Found a duplicate, so don't add it.
319 // No duplicates, so add it.
320 if (Append)
321 MAI.MS[MSType].push_back(&MI);
322 else
323 MAI.MS[MSType].insert(MAI.MS[MSType].begin(), &MI);
324 }
325
326 // Some global instructions make reference to function-local ID regs, so cannot
327 // be correctly collected until these registers are globally numbered.
processOtherInstrs(const Module & M)328 void SPIRVModuleAnalysis::processOtherInstrs(const Module &M) {
329 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
330 if ((*F).isDeclaration())
331 continue;
332 MachineFunction *MF = MMI->getMachineFunction(*F);
333 assert(MF);
334 for (MachineBasicBlock &MBB : *MF)
335 for (MachineInstr &MI : MBB) {
336 if (MAI.getSkipEmission(&MI))
337 continue;
338 const unsigned OpCode = MI.getOpcode();
339 if (OpCode == SPIRV::OpName || OpCode == SPIRV::OpMemberName) {
340 collectOtherInstr(MI, MAI, SPIRV::MB_DebugNames);
341 } else if (OpCode == SPIRV::OpEntryPoint) {
342 collectOtherInstr(MI, MAI, SPIRV::MB_EntryPoints);
343 } else if (TII->isDecorationInstr(MI)) {
344 collectOtherInstr(MI, MAI, SPIRV::MB_Annotations);
345 collectFuncNames(MI, &*F);
346 } else if (TII->isConstantInstr(MI)) {
347 // Now OpSpecConstant*s are not in DT,
348 // but they need to be collected anyway.
349 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars);
350 } else if (OpCode == SPIRV::OpFunction) {
351 collectFuncNames(MI, &*F);
352 } else if (OpCode == SPIRV::OpTypeForwardPointer) {
353 collectOtherInstr(MI, MAI, SPIRV::MB_TypeConstVars, false);
354 }
355 }
356 }
357 }
358
359 // Number registers in all functions globally from 0 onwards and store
360 // the result in global register alias table. Some registers are already
361 // numbered in collectGlobalEntities.
numberRegistersGlobally(const Module & M)362 void SPIRVModuleAnalysis::numberRegistersGlobally(const Module &M) {
363 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
364 if ((*F).isDeclaration())
365 continue;
366 MachineFunction *MF = MMI->getMachineFunction(*F);
367 assert(MF);
368 for (MachineBasicBlock &MBB : *MF) {
369 for (MachineInstr &MI : MBB) {
370 for (MachineOperand &Op : MI.operands()) {
371 if (!Op.isReg())
372 continue;
373 Register Reg = Op.getReg();
374 if (MAI.hasRegisterAlias(MF, Reg))
375 continue;
376 Register NewReg = Register::index2VirtReg(MAI.getNextID());
377 MAI.setRegisterAlias(MF, Reg, NewReg);
378 }
379 if (MI.getOpcode() != SPIRV::OpExtInst)
380 continue;
381 auto Set = MI.getOperand(2).getImm();
382 if (MAI.ExtInstSetMap.find(Set) == MAI.ExtInstSetMap.end())
383 MAI.ExtInstSetMap[Set] = Register::index2VirtReg(MAI.getNextID());
384 }
385 }
386 }
387 }
388
// Find OpIEqual and OpBranchConditional instructions originating from
// OpSwitches, mark them skipped for emission. Also mark MBB skipped if it
// contains only these instructions.
static void processSwitches(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
                            MachineModuleInfo *MMI) {
  // Registers holding a switch selector, or a value derived from one.
  DenseSet<Register> SwitchRegs;
  for (auto F = M.begin(), E = M.end(); F != E; ++F) {
    MachineFunction *MF = MMI->getMachineFunction(*F);
    if (!MF)
      continue;
    for (MachineBasicBlock &MBB : *MF)
      for (MachineInstr &MI : MBB) {
        if (MAI.getSkipEmission(&MI))
          continue;
        if (MI.getOpcode() == SPIRV::OpSwitch) {
          assert(MI.getOperand(0).isReg());
          SwitchRegs.insert(MI.getOperand(0).getReg());
        }
        // An OpISubS fed by a switch register produces another switch-derived
        // value: propagate the tracking and skip the instruction itself.
        if (MI.getOpcode() == SPIRV::OpISubS &&
            SwitchRegs.contains(MI.getOperand(2).getReg())) {
          SwitchRegs.insert(MI.getOperand(0).getReg());
          MAI.setSkipEmission(&MI);
        }
        // Below we only care about comparisons against a tracked register.
        if ((MI.getOpcode() != SPIRV::OpIEqual &&
             MI.getOpcode() != SPIRV::OpULessThanEqual) ||
            !MI.getOperand(2).isReg() ||
            !SwitchRegs.contains(MI.getOperand(2).getReg()))
          continue;
        Register CmpReg = MI.getOperand(0).getReg();
        // The comparison is expected to be immediately followed by the
        // conditional branch that consumes its result.
        MachineInstr *CBr = MI.getNextNode();
        assert(CBr && CBr->getOpcode() == SPIRV::OpBranchConditional &&
               CBr->getOperand(0).isReg() &&
               CBr->getOperand(0).getReg() == CmpReg);
        MAI.setSkipEmission(&MI);
        MAI.setSkipEmission(CBr);
        // If the block consists solely of this compare+branch pair, the whole
        // MBB can be skipped during emission.
        if (&MBB.front() == &MI && &MBB.back() == CBr)
          MAI.MBBsToSkip.insert(&MBB);
      }
  }
}
429
// RequirementHandler implementations.

// Resolve the requirements implied by symbolic operand (Category, i) for the
// given subtarget and merge them into this handler.
void SPIRV::RequirementHandler::getAndAddRequirements(
    SPIRV::OperandCategory::OperandCategory Category, uint32_t i,
    const SPIRVSubtarget &ST) {
  addRequirements(getSymbolicOperandRequirements(Category, i, ST, *this));
}
436
pruneCapabilities(const CapabilityList & ToPrune)437 void SPIRV::RequirementHandler::pruneCapabilities(
438 const CapabilityList &ToPrune) {
439 for (const auto &Cap : ToPrune) {
440 AllCaps.insert(Cap);
441 auto FoundIndex = std::find(MinimalCaps.begin(), MinimalCaps.end(), Cap);
442 if (FoundIndex != MinimalCaps.end())
443 MinimalCaps.erase(FoundIndex);
444 CapabilityList ImplicitDecls =
445 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
446 pruneCapabilities(ImplicitDecls);
447 }
448 }
449
addCapabilities(const CapabilityList & ToAdd)450 void SPIRV::RequirementHandler::addCapabilities(const CapabilityList &ToAdd) {
451 for (const auto &Cap : ToAdd) {
452 bool IsNewlyInserted = AllCaps.insert(Cap).second;
453 if (!IsNewlyInserted) // Don't re-add if it's already been declared.
454 continue;
455 CapabilityList ImplicitDecls =
456 getSymbolicOperandCapabilities(OperandCategory::CapabilityOperand, Cap);
457 pruneCapabilities(ImplicitDecls);
458 MinimalCaps.push_back(Cap);
459 }
460 }
461
// Merge a single Requirements record into the handler's accumulated state:
// its capability, its extensions, and its allowed SPIR-V version window.
// Aborts compilation when the record is unsatisfiable or the version windows
// conflict.
void SPIRV::RequirementHandler::addRequirements(
    const SPIRV::Requirements &Req) {
  if (!Req.IsSatisfiable)
    report_fatal_error("Adding SPIR-V requirements this target can't satisfy.");

  if (Req.Cap.has_value())
    addCapabilities({Req.Cap.value()});

  addExtensions(Req.Exts);

  if (Req.MinVer) {
    // Raising the minimum must not cross the accumulated maximum.
    if (MaxVersion && Req.MinVer > MaxVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: >= " << Req.MinVer
                        << " and <= " << MaxVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MinVersion == 0 || Req.MinVer > MinVersion)
      MinVersion = Req.MinVer;
  }

  if (Req.MaxVer) {
    // Lowering the maximum must not cross the accumulated minimum.
    if (MinVersion && Req.MaxVer < MinVersion) {
      LLVM_DEBUG(dbgs() << "Conflicting version requirements: <= " << Req.MaxVer
                        << " and >= " << MinVersion << "\n");
      report_fatal_error("Adding SPIR-V requirements that can't be satisfied.");
    }

    if (MaxVersion == 0 || Req.MaxVer < MaxVersion)
      MaxVersion = Req.MaxVer;
  }
}
494
checkSatisfiable(const SPIRVSubtarget & ST) const495 void SPIRV::RequirementHandler::checkSatisfiable(
496 const SPIRVSubtarget &ST) const {
497 // Report as many errors as possible before aborting the compilation.
498 bool IsSatisfiable = true;
499 auto TargetVer = ST.getSPIRVVersion();
500
501 if (MaxVersion && TargetVer && MaxVersion < TargetVer) {
502 LLVM_DEBUG(
503 dbgs() << "Target SPIR-V version too high for required features\n"
504 << "Required max version: " << MaxVersion << " target version "
505 << TargetVer << "\n");
506 IsSatisfiable = false;
507 }
508
509 if (MinVersion && TargetVer && MinVersion > TargetVer) {
510 LLVM_DEBUG(dbgs() << "Target SPIR-V version too low for required features\n"
511 << "Required min version: " << MinVersion
512 << " target version " << TargetVer << "\n");
513 IsSatisfiable = false;
514 }
515
516 if (MinVersion && MaxVersion && MinVersion > MaxVersion) {
517 LLVM_DEBUG(
518 dbgs()
519 << "Version is too low for some features and too high for others.\n"
520 << "Required SPIR-V min version: " << MinVersion
521 << " required SPIR-V max version " << MaxVersion << "\n");
522 IsSatisfiable = false;
523 }
524
525 for (auto Cap : MinimalCaps) {
526 if (AvailableCaps.contains(Cap))
527 continue;
528 LLVM_DEBUG(dbgs() << "Capability not supported: "
529 << getSymbolicOperandMnemonic(
530 OperandCategory::CapabilityOperand, Cap)
531 << "\n");
532 IsSatisfiable = false;
533 }
534
535 for (auto Ext : AllExtensions) {
536 if (ST.canUseExtension(Ext))
537 continue;
538 LLVM_DEBUG(dbgs() << "Extension not suported: "
539 << getSymbolicOperandMnemonic(
540 OperandCategory::ExtensionOperand, Ext)
541 << "\n");
542 IsSatisfiable = false;
543 }
544
545 if (!IsSatisfiable)
546 report_fatal_error("Unable to meet SPIR-V requirements for this target.");
547 }
548
// Add the given capabilities and all their implicitly defined capabilities too.
void SPIRV::RequirementHandler::addAvailableCaps(const CapabilityList &ToAdd) {
  for (const auto Cap : ToAdd)
    // Recurse only on first insertion, so already-processed capabilities (and
    // any cyclic implications) are not revisited.
    if (AvailableCaps.insert(Cap).second)
      addAvailableCaps(getSymbolicOperandCapabilities(
          SPIRV::OperandCategory::CapabilityOperand, Cap));
}
556
namespace llvm {
namespace SPIRV {
// Seed AvailableCaps with every capability this subtarget can support, based
// on its OpenCL profile and the OpenCL/SPIR-V versions it targets.
// TODO: implement for targets other than OpenCL.
void RequirementHandler::initAvailableCapabilities(const SPIRVSubtarget &ST) {
  if (!ST.isOpenCLEnv())
    return;
  // Add the min requirements for different OpenCL and SPIR-V versions.
  addAvailableCaps({Capability::Addresses, Capability::Float16Buffer,
                    Capability::Int16, Capability::Int8, Capability::Kernel,
                    Capability::Linkage, Capability::Vector16,
                    Capability::Groups, Capability::GenericPointer,
                    Capability::Shader});
  if (ST.hasOpenCLFullProfile())
    addAvailableCaps({Capability::Int64, Capability::Int64Atomics});
  if (ST.hasOpenCLImageSupport()) {
    addAvailableCaps({Capability::ImageBasic, Capability::LiteralSampler,
                      Capability::Image1D, Capability::SampledBuffer,
                      Capability::ImageBuffer});
    if (ST.isAtLeastOpenCLVer(20))
      addAvailableCaps({Capability::ImageReadWrite});
  }
  // SPIR-V 1.1 combined with OpenCL 2.2.
  if (ST.isAtLeastSPIRVVer(11) && ST.isAtLeastOpenCLVer(22))
    addAvailableCaps({Capability::SubgroupDispatch, Capability::PipeStorage});
  // SPIR-V 1.3 group-non-uniform capabilities.
  if (ST.isAtLeastSPIRVVer(13))
    addAvailableCaps({Capability::GroupNonUniform,
                      Capability::GroupNonUniformVote,
                      Capability::GroupNonUniformArithmetic,
                      Capability::GroupNonUniformBallot,
                      Capability::GroupNonUniformClustered,
                      Capability::GroupNonUniformShuffle,
                      Capability::GroupNonUniformShuffleRelative});
  // SPIR-V 1.4 float-controls capabilities.
  if (ST.isAtLeastSPIRVVer(14))
    addAvailableCaps({Capability::DenormPreserve, Capability::DenormFlushToZero,
                      Capability::SignedZeroInfNanPreserve,
                      Capability::RoundingModeRTE,
                      Capability::RoundingModeRTZ});
  // TODO: verify if this needs some checks.
  addAvailableCaps({Capability::Float16, Capability::Float64});

  // TODO: add OpenCL extensions.
}
} // namespace SPIRV
} // namespace llvm
600
601 // Add the required capabilities from a decoration instruction (including
602 // BuiltIns).
addOpDecorateReqs(const MachineInstr & MI,unsigned DecIndex,SPIRV::RequirementHandler & Reqs,const SPIRVSubtarget & ST)603 static void addOpDecorateReqs(const MachineInstr &MI, unsigned DecIndex,
604 SPIRV::RequirementHandler &Reqs,
605 const SPIRVSubtarget &ST) {
606 int64_t DecOp = MI.getOperand(DecIndex).getImm();
607 auto Dec = static_cast<SPIRV::Decoration::Decoration>(DecOp);
608 Reqs.addRequirements(getSymbolicOperandRequirements(
609 SPIRV::OperandCategory::DecorationOperand, Dec, ST, Reqs));
610
611 if (Dec == SPIRV::Decoration::BuiltIn) {
612 int64_t BuiltInOp = MI.getOperand(DecIndex + 1).getImm();
613 auto BuiltIn = static_cast<SPIRV::BuiltIn::BuiltIn>(BuiltInOp);
614 Reqs.addRequirements(getSymbolicOperandRequirements(
615 SPIRV::OperandCategory::BuiltInOperand, BuiltIn, ST, Reqs));
616 }
617 }
618
// Add requirements for image handling.
static void addOpTypeImageReqs(const MachineInstr &MI,
                               SPIRV::RequirementHandler &Reqs,
                               const SPIRVSubtarget &ST) {
  assert(MI.getNumOperands() >= 8 && "Insufficient operands for OpTypeImage");
  // The operand indices used here are based on the OpTypeImage layout, which
  // the MachineInstr follows as well.
  int64_t ImgFormatOp = MI.getOperand(7).getImm();
  auto ImgFormat = static_cast<SPIRV::ImageFormat::ImageFormat>(ImgFormatOp);
  Reqs.getAndAddRequirements(SPIRV::OperandCategory::ImageFormatOperand,
                             ImgFormat, ST);

  bool IsArrayed = MI.getOperand(4).getImm() == 1;
  bool IsMultisampled = MI.getOperand(5).getImm() == 1;
  // "Sampled" operand value 2 marks an image used without a sampler.
  bool NoSampler = MI.getOperand(6).getImm() == 2;
  // Add dimension requirements.
  assert(MI.getOperand(2).isImm());
  switch (MI.getOperand(2).getImm()) {
  case SPIRV::Dim::DIM_1D:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::Image1D
                                   : SPIRV::Capability::Sampled1D);
    break;
  case SPIRV::Dim::DIM_2D:
    // 2D only needs an extra capability when multisampled and sampler-less.
    if (IsMultisampled && NoSampler)
      Reqs.addRequirements(SPIRV::Capability::ImageMSArray);
    break;
  case SPIRV::Dim::DIM_Cube:
    Reqs.addRequirements(SPIRV::Capability::Shader);
    if (IsArrayed)
      Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageCubeArray
                                     : SPIRV::Capability::SampledCubeArray);
    break;
  case SPIRV::Dim::DIM_Rect:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageRect
                                   : SPIRV::Capability::SampledRect);
    break;
  case SPIRV::Dim::DIM_Buffer:
    Reqs.addRequirements(NoSampler ? SPIRV::Capability::ImageBuffer
                                   : SPIRV::Capability::SampledBuffer);
    break;
  case SPIRV::Dim::DIM_SubpassData:
    Reqs.addRequirements(SPIRV::Capability::InputAttachment);
    break;
  }

  // Has optional access qualifier.
  // TODO: check if it's OpenCL's kernel.
  if (MI.getNumOperands() > 8 &&
      MI.getOperand(8).getImm() == SPIRV::AccessQualifier::ReadWrite)
    Reqs.addRequirements(SPIRV::Capability::ImageReadWrite);
  else
    Reqs.addRequirements(SPIRV::Capability::ImageBasic);
}
672
addInstrRequirements(const MachineInstr & MI,SPIRV::RequirementHandler & Reqs,const SPIRVSubtarget & ST)673 void addInstrRequirements(const MachineInstr &MI,
674 SPIRV::RequirementHandler &Reqs,
675 const SPIRVSubtarget &ST) {
676 switch (MI.getOpcode()) {
677 case SPIRV::OpMemoryModel: {
678 int64_t Addr = MI.getOperand(0).getImm();
679 Reqs.getAndAddRequirements(SPIRV::OperandCategory::AddressingModelOperand,
680 Addr, ST);
681 int64_t Mem = MI.getOperand(1).getImm();
682 Reqs.getAndAddRequirements(SPIRV::OperandCategory::MemoryModelOperand, Mem,
683 ST);
684 break;
685 }
686 case SPIRV::OpEntryPoint: {
687 int64_t Exe = MI.getOperand(0).getImm();
688 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModelOperand,
689 Exe, ST);
690 break;
691 }
692 case SPIRV::OpExecutionMode:
693 case SPIRV::OpExecutionModeId: {
694 int64_t Exe = MI.getOperand(1).getImm();
695 Reqs.getAndAddRequirements(SPIRV::OperandCategory::ExecutionModeOperand,
696 Exe, ST);
697 break;
698 }
699 case SPIRV::OpTypeMatrix:
700 Reqs.addCapability(SPIRV::Capability::Matrix);
701 break;
702 case SPIRV::OpTypeInt: {
703 unsigned BitWidth = MI.getOperand(1).getImm();
704 if (BitWidth == 64)
705 Reqs.addCapability(SPIRV::Capability::Int64);
706 else if (BitWidth == 16)
707 Reqs.addCapability(SPIRV::Capability::Int16);
708 else if (BitWidth == 8)
709 Reqs.addCapability(SPIRV::Capability::Int8);
710 break;
711 }
712 case SPIRV::OpTypeFloat: {
713 unsigned BitWidth = MI.getOperand(1).getImm();
714 if (BitWidth == 64)
715 Reqs.addCapability(SPIRV::Capability::Float64);
716 else if (BitWidth == 16)
717 Reqs.addCapability(SPIRV::Capability::Float16);
718 break;
719 }
720 case SPIRV::OpTypeVector: {
721 unsigned NumComponents = MI.getOperand(2).getImm();
722 if (NumComponents == 8 || NumComponents == 16)
723 Reqs.addCapability(SPIRV::Capability::Vector16);
724 break;
725 }
726 case SPIRV::OpTypePointer: {
727 auto SC = MI.getOperand(1).getImm();
728 Reqs.getAndAddRequirements(SPIRV::OperandCategory::StorageClassOperand, SC,
729 ST);
730 // If it's a type of pointer to float16, add Float16Buffer capability.
731 assert(MI.getOperand(2).isReg());
732 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
733 SPIRVType *TypeDef = MRI.getVRegDef(MI.getOperand(2).getReg());
734 if (TypeDef->getOpcode() == SPIRV::OpTypeFloat &&
735 TypeDef->getOperand(1).getImm() == 16)
736 Reqs.addCapability(SPIRV::Capability::Float16Buffer);
737 break;
738 }
739 case SPIRV::OpBitReverse:
740 case SPIRV::OpTypeRuntimeArray:
741 Reqs.addCapability(SPIRV::Capability::Shader);
742 break;
743 case SPIRV::OpTypeOpaque:
744 case SPIRV::OpTypeEvent:
745 Reqs.addCapability(SPIRV::Capability::Kernel);
746 break;
747 case SPIRV::OpTypePipe:
748 case SPIRV::OpTypeReserveId:
749 Reqs.addCapability(SPIRV::Capability::Pipes);
750 break;
751 case SPIRV::OpTypeDeviceEvent:
752 case SPIRV::OpTypeQueue:
753 case SPIRV::OpBuildNDRange:
754 Reqs.addCapability(SPIRV::Capability::DeviceEnqueue);
755 break;
756 case SPIRV::OpDecorate:
757 case SPIRV::OpDecorateId:
758 case SPIRV::OpDecorateString:
759 addOpDecorateReqs(MI, 1, Reqs, ST);
760 break;
761 case SPIRV::OpMemberDecorate:
762 case SPIRV::OpMemberDecorateString:
763 addOpDecorateReqs(MI, 2, Reqs, ST);
764 break;
765 case SPIRV::OpInBoundsPtrAccessChain:
766 Reqs.addCapability(SPIRV::Capability::Addresses);
767 break;
768 case SPIRV::OpConstantSampler:
769 Reqs.addCapability(SPIRV::Capability::LiteralSampler);
770 break;
771 case SPIRV::OpTypeImage:
772 addOpTypeImageReqs(MI, Reqs, ST);
773 break;
774 case SPIRV::OpTypeSampler:
775 Reqs.addCapability(SPIRV::Capability::ImageBasic);
776 break;
777 case SPIRV::OpTypeForwardPointer:
778 // TODO: check if it's OpenCL's kernel.
779 Reqs.addCapability(SPIRV::Capability::Addresses);
780 break;
781 case SPIRV::OpAtomicFlagTestAndSet:
782 case SPIRV::OpAtomicLoad:
783 case SPIRV::OpAtomicStore:
784 case SPIRV::OpAtomicExchange:
785 case SPIRV::OpAtomicCompareExchange:
786 case SPIRV::OpAtomicIIncrement:
787 case SPIRV::OpAtomicIDecrement:
788 case SPIRV::OpAtomicIAdd:
789 case SPIRV::OpAtomicISub:
790 case SPIRV::OpAtomicUMin:
791 case SPIRV::OpAtomicUMax:
792 case SPIRV::OpAtomicSMin:
793 case SPIRV::OpAtomicSMax:
794 case SPIRV::OpAtomicAnd:
795 case SPIRV::OpAtomicOr:
796 case SPIRV::OpAtomicXor: {
797 const MachineRegisterInfo &MRI = MI.getMF()->getRegInfo();
798 const MachineInstr *InstrPtr = &MI;
799 if (MI.getOpcode() == SPIRV::OpAtomicStore) {
800 assert(MI.getOperand(3).isReg());
801 InstrPtr = MRI.getVRegDef(MI.getOperand(3).getReg());
802 assert(InstrPtr && "Unexpected type instruction for OpAtomicStore");
803 }
804 assert(InstrPtr->getOperand(1).isReg() && "Unexpected operand in atomic");
805 Register TypeReg = InstrPtr->getOperand(1).getReg();
806 SPIRVType *TypeDef = MRI.getVRegDef(TypeReg);
807 if (TypeDef->getOpcode() == SPIRV::OpTypeInt) {
808 unsigned BitWidth = TypeDef->getOperand(1).getImm();
809 if (BitWidth == 64)
810 Reqs.addCapability(SPIRV::Capability::Int64Atomics);
811 }
812 break;
813 }
814 case SPIRV::OpGroupNonUniformIAdd:
815 case SPIRV::OpGroupNonUniformFAdd:
816 case SPIRV::OpGroupNonUniformIMul:
817 case SPIRV::OpGroupNonUniformFMul:
818 case SPIRV::OpGroupNonUniformSMin:
819 case SPIRV::OpGroupNonUniformUMin:
820 case SPIRV::OpGroupNonUniformFMin:
821 case SPIRV::OpGroupNonUniformSMax:
822 case SPIRV::OpGroupNonUniformUMax:
823 case SPIRV::OpGroupNonUniformFMax:
824 case SPIRV::OpGroupNonUniformBitwiseAnd:
825 case SPIRV::OpGroupNonUniformBitwiseOr:
826 case SPIRV::OpGroupNonUniformBitwiseXor:
827 case SPIRV::OpGroupNonUniformLogicalAnd:
828 case SPIRV::OpGroupNonUniformLogicalOr:
829 case SPIRV::OpGroupNonUniformLogicalXor: {
830 assert(MI.getOperand(3).isImm());
831 int64_t GroupOp = MI.getOperand(3).getImm();
832 switch (GroupOp) {
833 case SPIRV::GroupOperation::Reduce:
834 case SPIRV::GroupOperation::InclusiveScan:
835 case SPIRV::GroupOperation::ExclusiveScan:
836 Reqs.addCapability(SPIRV::Capability::Kernel);
837 Reqs.addCapability(SPIRV::Capability::GroupNonUniformArithmetic);
838 Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
839 break;
840 case SPIRV::GroupOperation::ClusteredReduce:
841 Reqs.addCapability(SPIRV::Capability::GroupNonUniformClustered);
842 break;
843 case SPIRV::GroupOperation::PartitionedReduceNV:
844 case SPIRV::GroupOperation::PartitionedInclusiveScanNV:
845 case SPIRV::GroupOperation::PartitionedExclusiveScanNV:
846 Reqs.addCapability(SPIRV::Capability::GroupNonUniformPartitionedNV);
847 break;
848 }
849 break;
850 }
851 case SPIRV::OpGroupNonUniformShuffle:
852 case SPIRV::OpGroupNonUniformShuffleXor:
853 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffle);
854 break;
855 case SPIRV::OpGroupNonUniformShuffleUp:
856 case SPIRV::OpGroupNonUniformShuffleDown:
857 Reqs.addCapability(SPIRV::Capability::GroupNonUniformShuffleRelative);
858 break;
859 case SPIRV::OpGroupAll:
860 case SPIRV::OpGroupAny:
861 case SPIRV::OpGroupBroadcast:
862 case SPIRV::OpGroupIAdd:
863 case SPIRV::OpGroupFAdd:
864 case SPIRV::OpGroupFMin:
865 case SPIRV::OpGroupUMin:
866 case SPIRV::OpGroupSMin:
867 case SPIRV::OpGroupFMax:
868 case SPIRV::OpGroupUMax:
869 case SPIRV::OpGroupSMax:
870 Reqs.addCapability(SPIRV::Capability::Groups);
871 break;
872 case SPIRV::OpGroupNonUniformElect:
873 Reqs.addCapability(SPIRV::Capability::GroupNonUniform);
874 break;
875 case SPIRV::OpGroupNonUniformAll:
876 case SPIRV::OpGroupNonUniformAny:
877 case SPIRV::OpGroupNonUniformAllEqual:
878 Reqs.addCapability(SPIRV::Capability::GroupNonUniformVote);
879 break;
880 case SPIRV::OpGroupNonUniformBroadcast:
881 case SPIRV::OpGroupNonUniformBroadcastFirst:
882 case SPIRV::OpGroupNonUniformBallot:
883 case SPIRV::OpGroupNonUniformInverseBallot:
884 case SPIRV::OpGroupNonUniformBallotBitExtract:
885 case SPIRV::OpGroupNonUniformBallotBitCount:
886 case SPIRV::OpGroupNonUniformBallotFindLSB:
887 case SPIRV::OpGroupNonUniformBallotFindMSB:
888 Reqs.addCapability(SPIRV::Capability::GroupNonUniformBallot);
889 break;
890 default:
891 break;
892 }
893 }
894
collectReqs(const Module & M,SPIRV::ModuleAnalysisInfo & MAI,MachineModuleInfo * MMI,const SPIRVSubtarget & ST)895 static void collectReqs(const Module &M, SPIRV::ModuleAnalysisInfo &MAI,
896 MachineModuleInfo *MMI, const SPIRVSubtarget &ST) {
897 // Collect requirements for existing instructions.
898 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
899 MachineFunction *MF = MMI->getMachineFunction(*F);
900 if (!MF)
901 continue;
902 for (const MachineBasicBlock &MBB : *MF)
903 for (const MachineInstr &MI : MBB)
904 addInstrRequirements(MI, MAI.Reqs, ST);
905 }
906 // Collect requirements for OpExecutionMode instructions.
907 auto Node = M.getNamedMetadata("spirv.ExecutionMode");
908 if (Node) {
909 for (unsigned i = 0; i < Node->getNumOperands(); i++) {
910 MDNode *MDN = cast<MDNode>(Node->getOperand(i));
911 const MDOperand &MDOp = MDN->getOperand(1);
912 if (auto *CMeta = dyn_cast<ConstantAsMetadata>(MDOp)) {
913 Constant *C = CMeta->getValue();
914 if (ConstantInt *Const = dyn_cast<ConstantInt>(C)) {
915 auto EM = Const->getZExtValue();
916 MAI.Reqs.getAndAddRequirements(
917 SPIRV::OperandCategory::ExecutionModeOperand, EM, ST);
918 }
919 }
920 }
921 }
922 for (auto FI = M.begin(), E = M.end(); FI != E; ++FI) {
923 const Function &F = *FI;
924 if (F.isDeclaration())
925 continue;
926 if (F.getMetadata("reqd_work_group_size"))
927 MAI.Reqs.getAndAddRequirements(
928 SPIRV::OperandCategory::ExecutionModeOperand,
929 SPIRV::ExecutionMode::LocalSize, ST);
930 if (F.getMetadata("work_group_size_hint"))
931 MAI.Reqs.getAndAddRequirements(
932 SPIRV::OperandCategory::ExecutionModeOperand,
933 SPIRV::ExecutionMode::LocalSizeHint, ST);
934 if (F.getMetadata("intel_reqd_sub_group_size"))
935 MAI.Reqs.getAndAddRequirements(
936 SPIRV::OperandCategory::ExecutionModeOperand,
937 SPIRV::ExecutionMode::SubgroupSize, ST);
938 if (F.getMetadata("vec_type_hint"))
939 MAI.Reqs.getAndAddRequirements(
940 SPIRV::OperandCategory::ExecutionModeOperand,
941 SPIRV::ExecutionMode::VecTypeHint, ST);
942 }
943 }
944
getFastMathFlags(const MachineInstr & I)945 static unsigned getFastMathFlags(const MachineInstr &I) {
946 unsigned Flags = SPIRV::FPFastMathMode::None;
947 if (I.getFlag(MachineInstr::MIFlag::FmNoNans))
948 Flags |= SPIRV::FPFastMathMode::NotNaN;
949 if (I.getFlag(MachineInstr::MIFlag::FmNoInfs))
950 Flags |= SPIRV::FPFastMathMode::NotInf;
951 if (I.getFlag(MachineInstr::MIFlag::FmNsz))
952 Flags |= SPIRV::FPFastMathMode::NSZ;
953 if (I.getFlag(MachineInstr::MIFlag::FmArcp))
954 Flags |= SPIRV::FPFastMathMode::AllowRecip;
955 if (I.getFlag(MachineInstr::MIFlag::FmReassoc))
956 Flags |= SPIRV::FPFastMathMode::Fast;
957 return Flags;
958 }
959
handleMIFlagDecoration(MachineInstr & I,const SPIRVSubtarget & ST,const SPIRVInstrInfo & TII,SPIRV::RequirementHandler & Reqs)960 static void handleMIFlagDecoration(MachineInstr &I, const SPIRVSubtarget &ST,
961 const SPIRVInstrInfo &TII,
962 SPIRV::RequirementHandler &Reqs) {
963 if (I.getFlag(MachineInstr::MIFlag::NoSWrap) && TII.canUseNSW(I) &&
964 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
965 SPIRV::Decoration::NoSignedWrap, ST, Reqs)
966 .IsSatisfiable) {
967 buildOpDecorate(I.getOperand(0).getReg(), I, TII,
968 SPIRV::Decoration::NoSignedWrap, {});
969 }
970 if (I.getFlag(MachineInstr::MIFlag::NoUWrap) && TII.canUseNUW(I) &&
971 getSymbolicOperandRequirements(SPIRV::OperandCategory::DecorationOperand,
972 SPIRV::Decoration::NoUnsignedWrap, ST,
973 Reqs)
974 .IsSatisfiable) {
975 buildOpDecorate(I.getOperand(0).getReg(), I, TII,
976 SPIRV::Decoration::NoUnsignedWrap, {});
977 }
978 if (!TII.canUseFastMathFlags(I))
979 return;
980 unsigned FMFlags = getFastMathFlags(I);
981 if (FMFlags == SPIRV::FPFastMathMode::None)
982 return;
983 Register DstReg = I.getOperand(0).getReg();
984 buildOpDecorate(DstReg, I, TII, SPIRV::Decoration::FPFastMathMode, {FMFlags});
985 }
986
987 // Walk all functions and add decorations related to MI flags.
addDecorations(const Module & M,const SPIRVInstrInfo & TII,MachineModuleInfo * MMI,const SPIRVSubtarget & ST,SPIRV::ModuleAnalysisInfo & MAI)988 static void addDecorations(const Module &M, const SPIRVInstrInfo &TII,
989 MachineModuleInfo *MMI, const SPIRVSubtarget &ST,
990 SPIRV::ModuleAnalysisInfo &MAI) {
991 for (auto F = M.begin(), E = M.end(); F != E; ++F) {
992 MachineFunction *MF = MMI->getMachineFunction(*F);
993 if (!MF)
994 continue;
995 for (auto &MBB : *MF)
996 for (auto &MI : MBB)
997 handleMIFlagDecoration(MI, ST, TII, MAI.Reqs);
998 }
999 }
1000
// Out-of-line definition of the pass's static analysis-result storage. It is
// populated by runOnModule and later consumed by the SPIR-V AsmPrinter for
// global register renaming and module-level instruction emission.
struct SPIRV::ModuleAnalysisInfo SPIRVModuleAnalysis::MAI;
1002
getAnalysisUsage(AnalysisUsage & AU) const1003 void SPIRVModuleAnalysis::getAnalysisUsage(AnalysisUsage &AU) const {
1004 AU.addRequired<TargetPassConfig>();
1005 AU.addRequired<MachineModuleInfoWrapperPass>();
1006 }
1007
// Drive the whole module analysis: cache subtarget accessors, add MI-flag
// decorations, collect capability/extension requirements, then number
// registers globally and bucket instructions for module-level emission.
// Always returns false: the pass only fills the static MAI, it does not
// modify the IR.
bool SPIRVModuleAnalysis::runOnModule(Module &M) {
  SPIRVTargetMachine &TM =
      getAnalysis<TargetPassConfig>().getTM<SPIRVTargetMachine>();
  ST = TM.getSubtargetImpl();
  GR = ST->getSPIRVGlobalRegistry();
  TII = ST->getInstrInfo();

  MMI = &getAnalysis<MachineModuleInfoWrapperPass>().getMMI();

  // NOTE(review): MAI is a static member shared across runs; presumably
  // setBaseInfo re-initializes it per module — confirm before reusing the
  // pass on multiple modules in one process.
  setBaseInfo(M);

  // Decorations must be in place before requirements are collected, since
  // collectReqs scans the (now decorated) machine instructions.
  addDecorations(M, *TII, MMI, *ST, MAI);

  collectReqs(M, MAI, MMI, *ST);

  processSwitches(M, MAI, MMI);

  // Process type/const/global var/func decl instructions, number their
  // destination registers from 0 to N, collect Extensions and Capabilities.
  processDefInstrs(M);

  // Number rest of registers from N+1 onwards.
  numberRegistersGlobally(M);

  // Collect OpName, OpEntryPoint, OpDecorate etc, process other instructions.
  processOtherInstrs(M);

  // If there are no entry points, we need the Linkage capability.
  if (MAI.MS[SPIRV::MB_EntryPoints].empty())
    MAI.Reqs.addCapability(SPIRV::Capability::Linkage);

  return false;
}
1041