1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13 // and supports consistency of constants and global variables.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRV.h"
19 #include "SPIRVBuiltins.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVTargetMachine.h"
22 #include "SPIRVUtils.h"
23
24 using namespace llvm;
// Constructs the registry and stores PointerSize in the member of the same
// name; that member is later used when building pointer LLTs (see
// getOrCreateConstNullPtr).
SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
    : PointerSize(PointerSize) {}
27
assignIntTypeToVReg(unsigned BitWidth,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)28 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
29 Register VReg,
30 MachineInstr &I,
31 const SPIRVInstrInfo &TII) {
32 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
33 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
34 return SpirvType;
35 }
36
assignVectTypeToVReg(SPIRVType * BaseType,unsigned NumElements,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)37 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
38 SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
39 const SPIRVInstrInfo &TII) {
40 SPIRVType *SpirvType =
41 getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
42 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
43 return SpirvType;
44 }
45
assignTypeToVReg(const Type * Type,Register VReg,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool EmitIR)46 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
47 const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
48 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
49
50 SPIRVType *SpirvType =
51 getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
52 assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
53 return SpirvType;
54 }
55
assignSPIRVTypeToVReg(SPIRVType * SpirvType,Register VReg,MachineFunction & MF)56 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
57 Register VReg,
58 MachineFunction &MF) {
59 VRegToTypeMap[&MF][VReg] = SpirvType;
60 }
61
createTypeVReg(MachineIRBuilder & MIRBuilder)62 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
63 auto &MRI = MIRBuilder.getMF().getRegInfo();
64 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
65 MRI.setRegClass(Res, &SPIRV::TYPERegClass);
66 return Res;
67 }
68
createTypeVReg(MachineRegisterInfo & MRI)69 static Register createTypeVReg(MachineRegisterInfo &MRI) {
70 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
71 MRI.setRegClass(Res, &SPIRV::TYPERegClass);
72 return Res;
73 }
74
getOpTypeBool(MachineIRBuilder & MIRBuilder)75 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
76 return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
77 .addDef(createTypeVReg(MIRBuilder));
78 }
79
getOpTypeInt(uint32_t Width,MachineIRBuilder & MIRBuilder,bool IsSigned)80 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
81 MachineIRBuilder &MIRBuilder,
82 bool IsSigned) {
83 assert(Width <= 64 && "Unsupported integer width!");
84 const SPIRVSubtarget &ST =
85 cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
86 if (ST.canUseExtension(
87 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
88 MIRBuilder.buildInstr(SPIRV::OpExtension)
89 .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
90 MIRBuilder.buildInstr(SPIRV::OpCapability)
91 .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
92 } else if (Width <= 8)
93 Width = 8;
94 else if (Width <= 16)
95 Width = 16;
96 else if (Width <= 32)
97 Width = 32;
98 else if (Width <= 64)
99 Width = 64;
100
101 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
102 .addDef(createTypeVReg(MIRBuilder))
103 .addImm(Width)
104 .addImm(IsSigned ? 1 : 0);
105 return MIB;
106 }
107
getOpTypeFloat(uint32_t Width,MachineIRBuilder & MIRBuilder)108 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
109 MachineIRBuilder &MIRBuilder) {
110 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
111 .addDef(createTypeVReg(MIRBuilder))
112 .addImm(Width);
113 return MIB;
114 }
115
getOpTypeVoid(MachineIRBuilder & MIRBuilder)116 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
117 return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
118 .addDef(createTypeVReg(MIRBuilder));
119 }
120
getOpTypeVector(uint32_t NumElems,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder)121 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
122 SPIRVType *ElemType,
123 MachineIRBuilder &MIRBuilder) {
124 auto EleOpc = ElemType->getOpcode();
125 assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
126 EleOpc == SPIRV::OpTypeBool) &&
127 "Invalid vector element type");
128
129 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
130 .addDef(createTypeVReg(MIRBuilder))
131 .addUse(getSPIRVTypeID(ElemType))
132 .addImm(NumElems);
133 return MIB;
134 }
135
136 std::tuple<Register, ConstantInt *, bool>
getOrCreateConstIntReg(uint64_t Val,SPIRVType * SpvType,MachineIRBuilder * MIRBuilder,MachineInstr * I,const SPIRVInstrInfo * TII)137 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
138 MachineIRBuilder *MIRBuilder,
139 MachineInstr *I,
140 const SPIRVInstrInfo *TII) {
141 const IntegerType *LLVMIntTy;
142 if (SpvType)
143 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
144 else
145 LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());
146 bool NewInstr = false;
147 // Find a constant in DT or build a new one.
148 ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
149 Register Res = DT.find(CI, CurMF);
150 if (!Res.isValid()) {
151 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
152 LLT LLTy = LLT::scalar(32);
153 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
154 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
155 if (MIRBuilder)
156 assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
157 else
158 assignIntTypeToVReg(BitWidth, Res, *I, *TII);
159 DT.add(CI, CurMF, Res);
160 NewInstr = true;
161 }
162 return std::make_tuple(Res, CI, NewInstr);
163 }
164
getOrCreateConstInt(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)165 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
166 SPIRVType *SpvType,
167 const SPIRVInstrInfo &TII) {
168 assert(SpvType);
169 ConstantInt *CI;
170 Register Res;
171 bool New;
172 std::tie(Res, CI, New) =
173 getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
174 // If we have found Res register which is defined by the passed G_CONSTANT
175 // machine instruction, a new constant instruction should be created.
176 if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
177 return Res;
178 MachineInstrBuilder MIB;
179 MachineBasicBlock &BB = *I.getParent();
180 if (Val) {
181 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
182 .addDef(Res)
183 .addUse(getSPIRVTypeID(SpvType));
184 addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);
185 } else {
186 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
187 .addDef(Res)
188 .addUse(getSPIRVTypeID(SpvType));
189 }
190 const auto &ST = CurMF->getSubtarget();
191 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
192 *ST.getRegisterInfo(), *ST.getRegBankInfo());
193 return Res;
194 }
195
buildConstantInt(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR)196 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
197 MachineIRBuilder &MIRBuilder,
198 SPIRVType *SpvType,
199 bool EmitIR) {
200 auto &MF = MIRBuilder.getMF();
201 const IntegerType *LLVMIntTy;
202 if (SpvType)
203 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
204 else
205 LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
206 // Find a constant in DT or build a new one.
207 const auto ConstInt =
208 ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
209 Register Res = DT.find(ConstInt, &MF);
210 if (!Res.isValid()) {
211 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
212 LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
213 Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
214 MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
215 assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
216 SPIRV::AccessQualifier::ReadWrite, EmitIR);
217 DT.add(ConstInt, &MIRBuilder.getMF(), Res);
218 if (EmitIR) {
219 MIRBuilder.buildConstant(Res, *ConstInt);
220 } else {
221 MachineInstrBuilder MIB;
222 if (Val) {
223 assert(SpvType);
224 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
225 .addDef(Res)
226 .addUse(getSPIRVTypeID(SpvType));
227 addNumImm(APInt(BitWidth, Val), MIB);
228 } else {
229 assert(SpvType);
230 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
231 .addDef(Res)
232 .addUse(getSPIRVTypeID(SpvType));
233 }
234 const auto &Subtarget = CurMF->getSubtarget();
235 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
236 *Subtarget.getRegisterInfo(),
237 *Subtarget.getRegBankInfo());
238 }
239 }
240 return Res;
241 }
242
buildConstantFP(APFloat Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)243 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
244 MachineIRBuilder &MIRBuilder,
245 SPIRVType *SpvType) {
246 auto &MF = MIRBuilder.getMF();
247 auto &Ctx = MF.getFunction().getContext();
248 if (!SpvType) {
249 const Type *LLVMFPTy = Type::getFloatTy(Ctx);
250 SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder);
251 }
252 // Find a constant in DT or build a new one.
253 const auto ConstFP = ConstantFP::get(Ctx, Val);
254 Register Res = DT.find(ConstFP, &MF);
255 if (!Res.isValid()) {
256 Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(32));
257 MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
258 assignSPIRVTypeToVReg(SpvType, Res, MF);
259 DT.add(ConstFP, &MF, Res);
260
261 MachineInstrBuilder MIB;
262 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
263 .addDef(Res)
264 .addUse(getSPIRVTypeID(SpvType));
265 addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
266 }
267
268 return Res;
269 }
270
getOrCreateIntCompositeOrNull(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,Constant * CA,unsigned BitWidth,unsigned ElemCnt)271 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
272 uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
273 const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
274 unsigned ElemCnt) {
275 // Find a constant vector in DT or build a new one.
276 Register Res = DT.find(CA, CurMF);
277 if (!Res.isValid()) {
278 SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
279 // SpvScalConst should be created before SpvVecConst to avoid undefined ID
280 // error on validation.
281 // TODO: can moved below once sorting of types/consts/defs is implemented.
282 Register SpvScalConst;
283 if (Val)
284 SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII);
285 // TODO: maybe use bitwidth of base type.
286 LLT LLTy = LLT::scalar(32);
287 Register SpvVecConst =
288 CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
289 CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
290 assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
291 DT.add(CA, CurMF, SpvVecConst);
292 MachineInstrBuilder MIB;
293 MachineBasicBlock &BB = *I.getParent();
294 if (Val) {
295 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
296 .addDef(SpvVecConst)
297 .addUse(getSPIRVTypeID(SpvType));
298 for (unsigned i = 0; i < ElemCnt; ++i)
299 MIB.addUse(SpvScalConst);
300 } else {
301 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
302 .addDef(SpvVecConst)
303 .addUse(getSPIRVTypeID(SpvType));
304 }
305 const auto &Subtarget = CurMF->getSubtarget();
306 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
307 *Subtarget.getRegisterInfo(),
308 *Subtarget.getRegBankInfo());
309 return SpvVecConst;
310 }
311 return Res;
312 }
313
314 Register
getOrCreateConsIntVector(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)315 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
316 SPIRVType *SpvType,
317 const SPIRVInstrInfo &TII) {
318 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
319 assert(LLVMTy->isVectorTy());
320 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
321 Type *LLVMBaseTy = LLVMVecTy->getElementType();
322 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
323 auto ConstVec =
324 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
325 unsigned BW = getScalarOrVectorBitWidth(SpvType);
326 return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstVec, BW,
327 SpvType->getOperand(2).getImm());
328 }
329
330 Register
getOrCreateConsIntArray(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)331 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
332 SPIRVType *SpvType,
333 const SPIRVInstrInfo &TII) {
334 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
335 assert(LLVMTy->isArrayTy());
336 const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
337 Type *LLVMBaseTy = LLVMArrTy->getElementType();
338 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
339 auto ConstArr =
340 ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
341 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
342 unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
343 return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstArr, BW,
344 LLVMArrTy->getNumElements());
345 }
346
getOrCreateIntCompositeOrNull(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR,Constant * CA,unsigned BitWidth,unsigned ElemCnt)347 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
348 uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
349 Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
350 Register Res = DT.find(CA, CurMF);
351 if (!Res.isValid()) {
352 Register SpvScalConst;
353 if (Val || EmitIR) {
354 SPIRVType *SpvBaseType =
355 getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
356 SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
357 }
358 LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
359 Register SpvVecConst =
360 CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
361 CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
362 assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
363 DT.add(CA, CurMF, SpvVecConst);
364 if (EmitIR) {
365 MIRBuilder.buildSplatVector(SpvVecConst, SpvScalConst);
366 } else {
367 if (Val) {
368 auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
369 .addDef(SpvVecConst)
370 .addUse(getSPIRVTypeID(SpvType));
371 for (unsigned i = 0; i < ElemCnt; ++i)
372 MIB.addUse(SpvScalConst);
373 } else {
374 MIRBuilder.buildInstr(SPIRV::OpConstantNull)
375 .addDef(SpvVecConst)
376 .addUse(getSPIRVTypeID(SpvType));
377 }
378 }
379 return SpvVecConst;
380 }
381 return Res;
382 }
383
384 Register
getOrCreateConsIntVector(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR)385 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
386 MachineIRBuilder &MIRBuilder,
387 SPIRVType *SpvType, bool EmitIR) {
388 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
389 assert(LLVMTy->isVectorTy());
390 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
391 Type *LLVMBaseTy = LLVMVecTy->getElementType();
392 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
393 auto ConstVec =
394 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
395 unsigned BW = getScalarOrVectorBitWidth(SpvType);
396 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
397 ConstVec, BW,
398 SpvType->getOperand(2).getImm());
399 }
400
401 Register
getOrCreateConsIntArray(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR)402 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val,
403 MachineIRBuilder &MIRBuilder,
404 SPIRVType *SpvType, bool EmitIR) {
405 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
406 assert(LLVMTy->isArrayTy());
407 const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
408 Type *LLVMBaseTy = LLVMArrTy->getElementType();
409 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
410 auto ConstArr =
411 ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
412 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
413 unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
414 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
415 ConstArr, BW,
416 LLVMArrTy->getNumElements());
417 }
418
419 Register
getOrCreateConstNullPtr(MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)420 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
421 SPIRVType *SpvType) {
422 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
423 const PointerType *LLVMPtrTy = cast<PointerType>(LLVMTy);
424 // Find a constant in DT or build a new one.
425 Constant *CP = ConstantPointerNull::get(const_cast<PointerType *>(LLVMPtrTy));
426 Register Res = DT.find(CP, CurMF);
427 if (!Res.isValid()) {
428 LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
429 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
430 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
431 assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
432 MIRBuilder.buildInstr(SPIRV::OpConstantNull)
433 .addDef(Res)
434 .addUse(getSPIRVTypeID(SpvType));
435 DT.add(CP, CurMF, Res);
436 }
437 return Res;
438 }
439
buildConstantSampler(Register ResReg,unsigned AddrMode,unsigned Param,unsigned FilerMode,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)440 Register SPIRVGlobalRegistry::buildConstantSampler(
441 Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode,
442 MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {
443 SPIRVType *SampTy;
444 if (SpvType)
445 SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
446 else
447 SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder);
448
449 auto Sampler =
450 ResReg.isValid()
451 ? ResReg
452 : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
453 auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
454 .addDef(Sampler)
455 .addUse(getSPIRVTypeID(SampTy))
456 .addImm(AddrMode)
457 .addImm(Param)
458 .addImm(FilerMode);
459 assert(Res->getOperand(0).isReg());
460 return Res->getOperand(0).getReg();
461 }
462
buildGlobalVariable(Register ResVReg,SPIRVType * BaseType,StringRef Name,const GlobalValue * GV,SPIRV::StorageClass::StorageClass Storage,const MachineInstr * Init,bool IsConst,bool HasLinkageTy,SPIRV::LinkageType::LinkageType LinkageType,MachineIRBuilder & MIRBuilder,bool IsInstSelector)463 Register SPIRVGlobalRegistry::buildGlobalVariable(
464 Register ResVReg, SPIRVType *BaseType, StringRef Name,
465 const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
466 const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
467 SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
468 bool IsInstSelector) {
469 const GlobalVariable *GVar = nullptr;
470 if (GV)
471 GVar = cast<const GlobalVariable>(GV);
472 else {
473 // If GV is not passed explicitly, use the name to find or construct
474 // the global variable.
475 Module *M = MIRBuilder.getMF().getFunction().getParent();
476 GVar = M->getGlobalVariable(Name);
477 if (GVar == nullptr) {
478 const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
479 GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
480 GlobalValue::ExternalLinkage, nullptr,
481 Twine(Name));
482 }
483 GV = GVar;
484 }
485 Register Reg = DT.find(GVar, &MIRBuilder.getMF());
486 if (Reg.isValid()) {
487 if (Reg != ResVReg)
488 MIRBuilder.buildCopy(ResVReg, Reg);
489 return ResVReg;
490 }
491
492 auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
493 .addDef(ResVReg)
494 .addUse(getSPIRVTypeID(BaseType))
495 .addImm(static_cast<uint32_t>(Storage));
496
497 if (Init != 0) {
498 MIB.addUse(Init->getOperand(0).getReg());
499 }
500
501 // ISel may introduce a new register on this step, so we need to add it to
502 // DT and correct its type avoiding fails on the next stage.
503 if (IsInstSelector) {
504 const auto &Subtarget = CurMF->getSubtarget();
505 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
506 *Subtarget.getRegisterInfo(),
507 *Subtarget.getRegBankInfo());
508 }
509 Reg = MIB->getOperand(0).getReg();
510 DT.add(GVar, &MIRBuilder.getMF(), Reg);
511
512 // Set to Reg the same type as ResVReg has.
513 auto MRI = MIRBuilder.getMRI();
514 assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
515 if (Reg != ResVReg) {
516 LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
517 MRI->setType(Reg, RegLLTy);
518 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
519 }
520
521 // If it's a global variable with name, output OpName for it.
522 if (GVar && GVar->hasName())
523 buildOpName(Reg, GVar->getName(), MIRBuilder);
524
525 // Output decorations for the GV.
526 // TODO: maybe move to GenerateDecorations pass.
527 if (IsConst)
528 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
529
530 if (GVar && GVar->getAlign().valueOrOne().value() != 1) {
531 unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value();
532 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});
533 }
534
535 if (HasLinkageTy)
536 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
537 {static_cast<uint32_t>(LinkageType)}, Name);
538
539 SPIRV::BuiltIn::BuiltIn BuiltInId;
540 if (getSpirvBuiltInIdByName(Name, BuiltInId))
541 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,
542 {static_cast<uint32_t>(BuiltInId)});
543
544 return Reg;
545 }
546
getOpTypeArray(uint32_t NumElems,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder,bool EmitIR)547 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
548 SPIRVType *ElemType,
549 MachineIRBuilder &MIRBuilder,
550 bool EmitIR) {
551 assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
552 "Invalid array element type");
553 Register NumElementsVReg =
554 buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
555 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
556 .addDef(createTypeVReg(MIRBuilder))
557 .addUse(getSPIRVTypeID(ElemType))
558 .addUse(NumElementsVReg);
559 return MIB;
560 }
561
getOpTypeOpaque(const StructType * Ty,MachineIRBuilder & MIRBuilder)562 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
563 MachineIRBuilder &MIRBuilder) {
564 assert(Ty->hasName());
565 const StringRef Name = Ty->hasName() ? Ty->getName() : "";
566 Register ResVReg = createTypeVReg(MIRBuilder);
567 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
568 addStringImm(Name, MIB);
569 buildOpName(ResVReg, Name, MIRBuilder);
570 return MIB;
571 }
572
getOpTypeStruct(const StructType * Ty,MachineIRBuilder & MIRBuilder,bool EmitIR)573 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
574 MachineIRBuilder &MIRBuilder,
575 bool EmitIR) {
576 SmallVector<Register, 4> FieldTypes;
577 for (const auto &Elem : Ty->elements()) {
578 SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
579 assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
580 "Invalid struct element type");
581 FieldTypes.push_back(getSPIRVTypeID(ElemTy));
582 }
583 Register ResVReg = createTypeVReg(MIRBuilder);
584 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
585 for (const auto &Ty : FieldTypes)
586 MIB.addUse(Ty);
587 if (Ty->hasName())
588 buildOpName(ResVReg, Ty->getName(), MIRBuilder);
589 if (Ty->isPacked())
590 buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
591 return MIB;
592 }
593
getOrCreateSpecialType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual)594 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
595 const Type *Ty, MachineIRBuilder &MIRBuilder,
596 SPIRV::AccessQualifier::AccessQualifier AccQual) {
597 assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type");
598 return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this);
599 }
600
getOpTypePointer(SPIRV::StorageClass::StorageClass SC,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder,Register Reg)601 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
602 SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,
603 MachineIRBuilder &MIRBuilder, Register Reg) {
604 if (!Reg.isValid())
605 Reg = createTypeVReg(MIRBuilder);
606 return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
607 .addDef(Reg)
608 .addImm(static_cast<uint32_t>(SC))
609 .addUse(getSPIRVTypeID(ElemType));
610 }
611
getOpTypeForwardPointer(SPIRV::StorageClass::StorageClass SC,MachineIRBuilder & MIRBuilder)612 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
613 SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
614 return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
615 .addUse(createTypeVReg(MIRBuilder))
616 .addImm(static_cast<uint32_t>(SC));
617 }
618
getOpTypeFunction(SPIRVType * RetType,const SmallVectorImpl<SPIRVType * > & ArgTypes,MachineIRBuilder & MIRBuilder)619 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
620 SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
621 MachineIRBuilder &MIRBuilder) {
622 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
623 .addDef(createTypeVReg(MIRBuilder))
624 .addUse(getSPIRVTypeID(RetType));
625 for (const SPIRVType *ArgType : ArgTypes)
626 MIB.addUse(getSPIRVTypeID(ArgType));
627 return MIB;
628 }
629
getOrCreateOpTypeFunctionWithArgs(const Type * Ty,SPIRVType * RetType,const SmallVectorImpl<SPIRVType * > & ArgTypes,MachineIRBuilder & MIRBuilder)630 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
631 const Type *Ty, SPIRVType *RetType,
632 const SmallVectorImpl<SPIRVType *> &ArgTypes,
633 MachineIRBuilder &MIRBuilder) {
634 Register Reg = DT.find(Ty, &MIRBuilder.getMF());
635 if (Reg.isValid())
636 return getSPIRVTypeForVReg(Reg);
637 SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
638 DT.add(Ty, CurMF, getSPIRVTypeID(SpirvType));
639 return finishCreatingSPIRVType(Ty, SpirvType);
640 }
641
findSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual,bool EmitIR)642 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
643 const Type *Ty, MachineIRBuilder &MIRBuilder,
644 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
645 Register Reg = DT.find(Ty, &MIRBuilder.getMF());
646 if (Reg.isValid())
647 return getSPIRVTypeForVReg(Reg);
648 if (ForwardPointerTypes.contains(Ty))
649 return ForwardPointerTypes[Ty];
650 return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);
651 }
652
getSPIRVTypeID(const SPIRVType * SpirvType) const653 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
654 assert(SpirvType && "Attempting to get type id for nullptr type.");
655 if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer)
656 return SpirvType->uses().begin()->getReg();
657 return SpirvType->defs().begin()->getReg();
658 }
659
createSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual,bool EmitIR)660 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
661 const Type *Ty, MachineIRBuilder &MIRBuilder,
662 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
663 if (isSpecialOpaqueType(Ty))
664 return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
665 auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
666 auto t = TypeToSPIRVTypeMap.find(Ty);
667 if (t != TypeToSPIRVTypeMap.end()) {
668 auto tt = t->second.find(&MIRBuilder.getMF());
669 if (tt != t->second.end())
670 return getSPIRVTypeForVReg(tt->second);
671 }
672
673 if (auto IType = dyn_cast<IntegerType>(Ty)) {
674 const unsigned Width = IType->getBitWidth();
675 return Width == 1 ? getOpTypeBool(MIRBuilder)
676 : getOpTypeInt(Width, MIRBuilder, false);
677 }
678 if (Ty->isFloatingPointTy())
679 return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
680 if (Ty->isVoidTy())
681 return getOpTypeVoid(MIRBuilder);
682 if (Ty->isVectorTy()) {
683 SPIRVType *El =
684 findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
685 return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
686 MIRBuilder);
687 }
688 if (Ty->isArrayTy()) {
689 SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
690 return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
691 }
692 if (auto SType = dyn_cast<StructType>(Ty)) {
693 if (SType->isOpaque())
694 return getOpTypeOpaque(SType, MIRBuilder);
695 return getOpTypeStruct(SType, MIRBuilder, EmitIR);
696 }
697 if (auto FType = dyn_cast<FunctionType>(Ty)) {
698 SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
699 SmallVector<SPIRVType *, 4> ParamTypes;
700 for (const auto &t : FType->params()) {
701 ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
702 }
703 return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
704 }
705 if (auto PType = dyn_cast<PointerType>(Ty)) {
706 SPIRVType *SpvElementType;
707 // At the moment, all opaque pointers correspond to i8 element type.
708 // TODO: change the implementation once opaque pointers are supported
709 // in the SPIR-V specification.
710 SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
711 auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
712 // Null pointer means we have a loop in type definitions, make and
713 // return corresponding OpTypeForwardPointer.
714 if (SpvElementType == nullptr) {
715 if (!ForwardPointerTypes.contains(Ty))
716 ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder);
717 return ForwardPointerTypes[PType];
718 }
719 Register Reg(0);
720 // If we have forward pointer associated with this type, use its register
721 // operand to create OpTypePointer.
722 if (ForwardPointerTypes.contains(PType))
723 Reg = getSPIRVTypeID(ForwardPointerTypes[PType]);
724
725 return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
726 }
727 llvm_unreachable("Unable to convert LLVM type to SPIRVType");
728 }
729
restOfCreateSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool EmitIR)730 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
731 const Type *Ty, MachineIRBuilder &MIRBuilder,
732 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
733 if (TypesInProcessing.count(Ty) && !Ty->isPointerTy())
734 return nullptr;
735 TypesInProcessing.insert(Ty);
736 SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
737 TypesInProcessing.erase(Ty);
738 VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
739 SPIRVToLLVMType[SpirvType] = Ty;
740 Register Reg = DT.find(Ty, &MIRBuilder.getMF());
741 // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
742 // will be added later. For special types it is already added to DT.
743 if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
744 !isSpecialOpaqueType(Ty)) {
745 if (!Ty->isPointerTy())
746 DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
747 else
748 DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
749 Ty->getPointerAddressSpace(), &MIRBuilder.getMF(),
750 getSPIRVTypeID(SpirvType));
751 }
752
753 return SpirvType;
754 }
755
getSPIRVTypeForVReg(Register VReg) const756 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
757 auto t = VRegToTypeMap.find(CurMF);
758 if (t != VRegToTypeMap.end()) {
759 auto tt = t->second.find(VReg);
760 if (tt != t->second.end())
761 return tt->second;
762 }
763 return nullptr;
764 }
765
getOrCreateSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool EmitIR)766 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
767 const Type *Ty, MachineIRBuilder &MIRBuilder,
768 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
769 Register Reg;
770 if (!Ty->isPointerTy())
771 Reg = DT.find(Ty, &MIRBuilder.getMF());
772 else
773 Reg =
774 DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
775 Ty->getPointerAddressSpace(), &MIRBuilder.getMF());
776
777 if (Reg.isValid() && !isSpecialOpaqueType(Ty))
778 return getSPIRVTypeForVReg(Reg);
779 TypesInProcessing.clear();
780 SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
781 // Create normal pointer types for the corresponding OpTypeForwardPointers.
782 for (auto &CU : ForwardPointerTypes) {
783 const Type *Ty2 = CU.first;
784 SPIRVType *STy2 = CU.second;
785 if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid())
786 STy2 = getSPIRVTypeForVReg(Reg);
787 else
788 STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
789 if (Ty == Ty2)
790 STy = STy2;
791 }
792 ForwardPointerTypes.clear();
793 return STy;
794 }
795
isScalarOfType(Register VReg,unsigned TypeOpcode) const796 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
797 unsigned TypeOpcode) const {
798 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
799 assert(Type && "isScalarOfType VReg has no type assigned");
800 return Type->getOpcode() == TypeOpcode;
801 }
802
isScalarOrVectorOfType(Register VReg,unsigned TypeOpcode) const803 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
804 unsigned TypeOpcode) const {
805 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
806 assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
807 if (Type->getOpcode() == TypeOpcode)
808 return true;
809 if (Type->getOpcode() == SPIRV::OpTypeVector) {
810 Register ScalarTypeVReg = Type->getOperand(1).getReg();
811 SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
812 return ScalarType->getOpcode() == TypeOpcode;
813 }
814 return false;
815 }
816
817 unsigned
getScalarOrVectorBitWidth(const SPIRVType * Type) const818 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
819 assert(Type && "Invalid Type pointer");
820 if (Type->getOpcode() == SPIRV::OpTypeVector) {
821 auto EleTypeReg = Type->getOperand(1).getReg();
822 Type = getSPIRVTypeForVReg(EleTypeReg);
823 }
824 if (Type->getOpcode() == SPIRV::OpTypeInt ||
825 Type->getOpcode() == SPIRV::OpTypeFloat)
826 return Type->getOperand(1).getImm();
827 if (Type->getOpcode() == SPIRV::OpTypeBool)
828 return 1;
829 llvm_unreachable("Attempting to get bit width of non-integer/float type.");
830 }
831
isScalarOrVectorSigned(const SPIRVType * Type) const832 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
833 assert(Type && "Invalid Type pointer");
834 if (Type->getOpcode() == SPIRV::OpTypeVector) {
835 auto EleTypeReg = Type->getOperand(1).getReg();
836 Type = getSPIRVTypeForVReg(EleTypeReg);
837 }
838 if (Type->getOpcode() == SPIRV::OpTypeInt)
839 return Type->getOperand(2).getImm() != 0;
840 llvm_unreachable("Attempting to get sign of non-integer type.");
841 }
842
843 SPIRV::StorageClass::StorageClass
getPointerStorageClass(Register VReg) const844 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
845 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
846 assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
847 Type->getOperand(1).isImm() && "Pointer type is expected");
848 return static_cast<SPIRV::StorageClass::StorageClass>(
849 Type->getOperand(1).getImm());
850 }
851
getOrCreateOpTypeImage(MachineIRBuilder & MIRBuilder,SPIRVType * SampledType,SPIRV::Dim::Dim Dim,uint32_t Depth,uint32_t Arrayed,uint32_t Multisampled,uint32_t Sampled,SPIRV::ImageFormat::ImageFormat ImageFormat,SPIRV::AccessQualifier::AccessQualifier AccessQual)852 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
853 MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
854 uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
855 SPIRV::ImageFormat::ImageFormat ImageFormat,
856 SPIRV::AccessQualifier::AccessQualifier AccessQual) {
857 SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(SampledType), Dim, Depth,
858 Arrayed, Multisampled, Sampled, ImageFormat,
859 AccessQual);
860 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
861 return Res;
862 Register ResVReg = createTypeVReg(MIRBuilder);
863 DT.add(TD, &MIRBuilder.getMF(), ResVReg);
864 return MIRBuilder.buildInstr(SPIRV::OpTypeImage)
865 .addDef(ResVReg)
866 .addUse(getSPIRVTypeID(SampledType))
867 .addImm(Dim)
868 .addImm(Depth) // Depth (whether or not it is a Depth image).
869 .addImm(Arrayed) // Arrayed.
870 .addImm(Multisampled) // Multisampled (0 = only single-sample).
871 .addImm(Sampled) // Sampled (0 = usage known at runtime).
872 .addImm(ImageFormat)
873 .addImm(AccessQual);
874 }
875
876 SPIRVType *
getOrCreateOpTypeSampler(MachineIRBuilder & MIRBuilder)877 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
878 SPIRV::SamplerTypeDescriptor TD;
879 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
880 return Res;
881 Register ResVReg = createTypeVReg(MIRBuilder);
882 DT.add(TD, &MIRBuilder.getMF(), ResVReg);
883 return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg);
884 }
885
getOrCreateOpTypePipe(MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual)886 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
887 MachineIRBuilder &MIRBuilder,
888 SPIRV::AccessQualifier::AccessQualifier AccessQual) {
889 SPIRV::PipeTypeDescriptor TD(AccessQual);
890 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
891 return Res;
892 Register ResVReg = createTypeVReg(MIRBuilder);
893 DT.add(TD, &MIRBuilder.getMF(), ResVReg);
894 return MIRBuilder.buildInstr(SPIRV::OpTypePipe)
895 .addDef(ResVReg)
896 .addImm(AccessQual);
897 }
898
getOrCreateOpTypeDeviceEvent(MachineIRBuilder & MIRBuilder)899 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
900 MachineIRBuilder &MIRBuilder) {
901 SPIRV::DeviceEventTypeDescriptor TD;
902 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
903 return Res;
904 Register ResVReg = createTypeVReg(MIRBuilder);
905 DT.add(TD, &MIRBuilder.getMF(), ResVReg);
906 return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg);
907 }
908
getOrCreateOpTypeSampledImage(SPIRVType * ImageType,MachineIRBuilder & MIRBuilder)909 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
910 SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
911 SPIRV::SampledImageTypeDescriptor TD(
912 SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
913 ImageType->getOperand(1).getReg())),
914 ImageType);
915 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
916 return Res;
917 Register ResVReg = createTypeVReg(MIRBuilder);
918 DT.add(TD, &MIRBuilder.getMF(), ResVReg);
919 return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage)
920 .addDef(ResVReg)
921 .addUse(getSPIRVTypeID(ImageType));
922 }
923
getOrCreateOpTypeByOpcode(const Type * Ty,MachineIRBuilder & MIRBuilder,unsigned Opcode)924 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
925 const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
926 Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
927 if (ResVReg.isValid())
928 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
929 ResVReg = createTypeVReg(MIRBuilder);
930 SPIRVType *SpirvTy = MIRBuilder.buildInstr(Opcode).addDef(ResVReg);
931 DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
932 return SpirvTy;
933 }
934
935 const MachineInstr *
checkSpecialInstr(const SPIRV::SpecialTypeDescriptor & TD,MachineIRBuilder & MIRBuilder)936 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
937 MachineIRBuilder &MIRBuilder) {
938 Register Reg = DT.find(TD, &MIRBuilder.getMF());
939 if (Reg.isValid())
940 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg);
941 return nullptr;
942 }
943
944 // TODO: maybe use tablegen to implement this.
getOrCreateSPIRVTypeByName(StringRef TypeStr,MachineIRBuilder & MIRBuilder,SPIRV::StorageClass::StorageClass SC,SPIRV::AccessQualifier::AccessQualifier AQ)945 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
946 StringRef TypeStr, MachineIRBuilder &MIRBuilder,
947 SPIRV::StorageClass::StorageClass SC,
948 SPIRV::AccessQualifier::AccessQualifier AQ) {
949 unsigned VecElts = 0;
950 auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
951
952 // Parse strings representing either a SPIR-V or OpenCL builtin type.
953 if (hasBuiltinTypePrefix(TypeStr))
954 return getOrCreateSPIRVType(
955 SPIRV::parseBuiltinTypeNameToTargetExtType(TypeStr.str(), MIRBuilder),
956 MIRBuilder, AQ);
957
958 // Parse type name in either "typeN" or "type vector[N]" format, where
959 // N is the number of elements of the vector.
960 Type *Ty;
961
962 TypeStr.consume_front("atomic_");
963
964 if (TypeStr.starts_with("void")) {
965 Ty = Type::getVoidTy(Ctx);
966 TypeStr = TypeStr.substr(strlen("void"));
967 } else if (TypeStr.starts_with("bool")) {
968 Ty = Type::getIntNTy(Ctx, 1);
969 TypeStr = TypeStr.substr(strlen("bool"));
970 } else if (TypeStr.starts_with("char") || TypeStr.starts_with("uchar")) {
971 Ty = Type::getInt8Ty(Ctx);
972 TypeStr = TypeStr.starts_with("char") ? TypeStr.substr(strlen("char"))
973 : TypeStr.substr(strlen("uchar"));
974 } else if (TypeStr.starts_with("short") || TypeStr.starts_with("ushort")) {
975 Ty = Type::getInt16Ty(Ctx);
976 TypeStr = TypeStr.starts_with("short") ? TypeStr.substr(strlen("short"))
977 : TypeStr.substr(strlen("ushort"));
978 } else if (TypeStr.starts_with("int") || TypeStr.starts_with("uint")) {
979 Ty = Type::getInt32Ty(Ctx);
980 TypeStr = TypeStr.starts_with("int") ? TypeStr.substr(strlen("int"))
981 : TypeStr.substr(strlen("uint"));
982 } else if (TypeStr.starts_with("long") || TypeStr.starts_with("ulong")) {
983 Ty = Type::getInt64Ty(Ctx);
984 TypeStr = TypeStr.starts_with("long") ? TypeStr.substr(strlen("long"))
985 : TypeStr.substr(strlen("ulong"));
986 } else if (TypeStr.starts_with("half")) {
987 Ty = Type::getHalfTy(Ctx);
988 TypeStr = TypeStr.substr(strlen("half"));
989 } else if (TypeStr.starts_with("float")) {
990 Ty = Type::getFloatTy(Ctx);
991 TypeStr = TypeStr.substr(strlen("float"));
992 } else if (TypeStr.starts_with("double")) {
993 Ty = Type::getDoubleTy(Ctx);
994 TypeStr = TypeStr.substr(strlen("double"));
995 } else
996 llvm_unreachable("Unable to recognize SPIRV type name.");
997
998 auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);
999
1000 // Handle "type*" or "type* vector[N]".
1001 if (TypeStr.starts_with("*")) {
1002 SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1003 TypeStr = TypeStr.substr(strlen("*"));
1004 }
1005
1006 // Handle "typeN*" or "type vector[N]*".
1007 bool IsPtrToVec = TypeStr.consume_back("*");
1008
1009 if (TypeStr.consume_front(" vector[")) {
1010 TypeStr = TypeStr.substr(0, TypeStr.find(']'));
1011 }
1012 TypeStr.getAsInteger(10, VecElts);
1013 if (VecElts > 0)
1014 SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);
1015
1016 if (IsPtrToVec)
1017 SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1018
1019 return SpirvTy;
1020 }
1021
1022 SPIRVType *
getOrCreateSPIRVIntegerType(unsigned BitWidth,MachineIRBuilder & MIRBuilder)1023 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
1024 MachineIRBuilder &MIRBuilder) {
1025 return getOrCreateSPIRVType(
1026 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
1027 MIRBuilder);
1028 }
1029
finishCreatingSPIRVType(const Type * LLVMTy,SPIRVType * SpirvType)1030 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
1031 SPIRVType *SpirvType) {
1032 assert(CurMF == SpirvType->getMF());
1033 VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
1034 SPIRVToLLVMType[SpirvType] = LLVMTy;
1035 return SpirvType;
1036 }
1037
getOrCreateSPIRVIntegerType(unsigned BitWidth,MachineInstr & I,const SPIRVInstrInfo & TII)1038 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
1039 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1040 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
1041 Register Reg = DT.find(LLVMTy, CurMF);
1042 if (Reg.isValid())
1043 return getSPIRVTypeForVReg(Reg);
1044 MachineBasicBlock &BB = *I.getParent();
1045 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
1046 .addDef(createTypeVReg(CurMF->getRegInfo()))
1047 .addImm(BitWidth)
1048 .addImm(0);
1049 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1050 return finishCreatingSPIRVType(LLVMTy, MIB);
1051 }
1052
1053 SPIRVType *
getOrCreateSPIRVBoolType(MachineIRBuilder & MIRBuilder)1054 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
1055 return getOrCreateSPIRVType(
1056 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
1057 MIRBuilder);
1058 }
1059
1060 SPIRVType *
getOrCreateSPIRVBoolType(MachineInstr & I,const SPIRVInstrInfo & TII)1061 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
1062 const SPIRVInstrInfo &TII) {
1063 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
1064 Register Reg = DT.find(LLVMTy, CurMF);
1065 if (Reg.isValid())
1066 return getSPIRVTypeForVReg(Reg);
1067 MachineBasicBlock &BB = *I.getParent();
1068 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool))
1069 .addDef(createTypeVReg(CurMF->getRegInfo()));
1070 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1071 return finishCreatingSPIRVType(LLVMTy, MIB);
1072 }
1073
getOrCreateSPIRVVectorType(SPIRVType * BaseType,unsigned NumElements,MachineIRBuilder & MIRBuilder)1074 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1075 SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
1076 return getOrCreateSPIRVType(
1077 FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1078 NumElements),
1079 MIRBuilder);
1080 }
1081
getOrCreateSPIRVVectorType(SPIRVType * BaseType,unsigned NumElements,MachineInstr & I,const SPIRVInstrInfo & TII)1082 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1083 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1084 const SPIRVInstrInfo &TII) {
1085 Type *LLVMTy = FixedVectorType::get(
1086 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1087 Register Reg = DT.find(LLVMTy, CurMF);
1088 if (Reg.isValid())
1089 return getSPIRVTypeForVReg(Reg);
1090 MachineBasicBlock &BB = *I.getParent();
1091 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
1092 .addDef(createTypeVReg(CurMF->getRegInfo()))
1093 .addUse(getSPIRVTypeID(BaseType))
1094 .addImm(NumElements);
1095 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1096 return finishCreatingSPIRVType(LLVMTy, MIB);
1097 }
1098
getOrCreateSPIRVArrayType(SPIRVType * BaseType,unsigned NumElements,MachineInstr & I,const SPIRVInstrInfo & TII)1099 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
1100 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1101 const SPIRVInstrInfo &TII) {
1102 Type *LLVMTy = ArrayType::get(
1103 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1104 Register Reg = DT.find(LLVMTy, CurMF);
1105 if (Reg.isValid())
1106 return getSPIRVTypeForVReg(Reg);
1107 MachineBasicBlock &BB = *I.getParent();
1108 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII);
1109 Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII);
1110 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray))
1111 .addDef(createTypeVReg(CurMF->getRegInfo()))
1112 .addUse(getSPIRVTypeID(BaseType))
1113 .addUse(Len);
1114 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1115 return finishCreatingSPIRVType(LLVMTy, MIB);
1116 }
1117
getOrCreateSPIRVPointerType(SPIRVType * BaseType,MachineIRBuilder & MIRBuilder,SPIRV::StorageClass::StorageClass SC)1118 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1119 SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1120 SPIRV::StorageClass::StorageClass SC) {
1121 const Type *PointerElementType = getTypeForSPIRVType(BaseType);
1122 unsigned AddressSpace = storageClassToAddressSpace(SC);
1123 Type *LLVMTy =
1124 PointerType::get(const_cast<Type *>(PointerElementType), AddressSpace);
1125 Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
1126 if (Reg.isValid())
1127 return getSPIRVTypeForVReg(Reg);
1128 auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
1129 MIRBuilder.getDebugLoc(),
1130 MIRBuilder.getTII().get(SPIRV::OpTypePointer))
1131 .addDef(createTypeVReg(CurMF->getRegInfo()))
1132 .addImm(static_cast<uint32_t>(SC))
1133 .addUse(getSPIRVTypeID(BaseType));
1134 DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
1135 return finishCreatingSPIRVType(LLVMTy, MIB);
1136 }
1137
getOrCreateSPIRVPointerType(SPIRVType * BaseType,MachineInstr & I,const SPIRVInstrInfo & TII,SPIRV::StorageClass::StorageClass SC)1138 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1139 SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
1140 SPIRV::StorageClass::StorageClass SC) {
1141 const Type *PointerElementType = getTypeForSPIRVType(BaseType);
1142 unsigned AddressSpace = storageClassToAddressSpace(SC);
1143 Type *LLVMTy =
1144 PointerType::get(const_cast<Type *>(PointerElementType), AddressSpace);
1145 Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
1146 if (Reg.isValid())
1147 return getSPIRVTypeForVReg(Reg);
1148 MachineBasicBlock &BB = *I.getParent();
1149 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
1150 .addDef(createTypeVReg(CurMF->getRegInfo()))
1151 .addImm(static_cast<uint32_t>(SC))
1152 .addUse(getSPIRVTypeID(BaseType));
1153 DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
1154 return finishCreatingSPIRVType(LLVMTy, MIB);
1155 }
1156
getOrCreateUndef(MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)1157 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
1158 SPIRVType *SpvType,
1159 const SPIRVInstrInfo &TII) {
1160 assert(SpvType);
1161 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
1162 assert(LLVMTy);
1163 // Find a constant in DT or build a new one.
1164 UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
1165 Register Res = DT.find(UV, CurMF);
1166 if (Res.isValid())
1167 return Res;
1168 LLT LLTy = LLT::scalar(32);
1169 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
1170 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
1171 assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
1172 DT.add(UV, CurMF, Res);
1173
1174 MachineInstrBuilder MIB;
1175 MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
1176 .addDef(Res)
1177 .addUse(getSPIRVTypeID(SpvType));
1178 const auto &ST = CurMF->getSubtarget();
1179 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
1180 *ST.getRegisterInfo(), *ST.getRegBankInfo());
1181 return Res;
1182 }
1183