1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13 // and supports consistency of constants and global variables.
14 //
15 //===----------------------------------------------------------------------===//
16
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRV.h"
19 #include "SPIRVBuiltins.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVTargetMachine.h"
22 #include "SPIRVUtils.h"
23
24 using namespace llvm;
SPIRVGlobalRegistry(unsigned PointerSize)25 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
26 : PointerSize(PointerSize) {}
27
assignIntTypeToVReg(unsigned BitWidth,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)28 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
29 Register VReg,
30 MachineInstr &I,
31 const SPIRVInstrInfo &TII) {
32 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
33 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
34 return SpirvType;
35 }
36
assignVectTypeToVReg(SPIRVType * BaseType,unsigned NumElements,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)37 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
38 SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
39 const SPIRVInstrInfo &TII) {
40 SPIRVType *SpirvType =
41 getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
42 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
43 return SpirvType;
44 }
45
assignTypeToVReg(const Type * Type,Register VReg,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool EmitIR)46 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
47 const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
48 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
49
50 SPIRVType *SpirvType =
51 getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
52 assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
53 return SpirvType;
54 }
55
assignSPIRVTypeToVReg(SPIRVType * SpirvType,Register VReg,MachineFunction & MF)56 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
57 Register VReg,
58 MachineFunction &MF) {
59 VRegToTypeMap[&MF][VReg] = SpirvType;
60 }
61
createTypeVReg(MachineIRBuilder & MIRBuilder)62 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
63 auto &MRI = MIRBuilder.getMF().getRegInfo();
64 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
65 MRI.setRegClass(Res, &SPIRV::TYPERegClass);
66 return Res;
67 }
68
createTypeVReg(MachineRegisterInfo & MRI)69 static Register createTypeVReg(MachineRegisterInfo &MRI) {
70 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
71 MRI.setRegClass(Res, &SPIRV::TYPERegClass);
72 return Res;
73 }
74
getOpTypeBool(MachineIRBuilder & MIRBuilder)75 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
76 return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
77 .addDef(createTypeVReg(MIRBuilder));
78 }
79
getOpTypeInt(uint32_t Width,MachineIRBuilder & MIRBuilder,bool IsSigned)80 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
81 MachineIRBuilder &MIRBuilder,
82 bool IsSigned) {
83 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
84 .addDef(createTypeVReg(MIRBuilder))
85 .addImm(Width)
86 .addImm(IsSigned ? 1 : 0);
87 return MIB;
88 }
89
getOpTypeFloat(uint32_t Width,MachineIRBuilder & MIRBuilder)90 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
91 MachineIRBuilder &MIRBuilder) {
92 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
93 .addDef(createTypeVReg(MIRBuilder))
94 .addImm(Width);
95 return MIB;
96 }
97
getOpTypeVoid(MachineIRBuilder & MIRBuilder)98 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
99 return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
100 .addDef(createTypeVReg(MIRBuilder));
101 }
102
getOpTypeVector(uint32_t NumElems,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder)103 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
104 SPIRVType *ElemType,
105 MachineIRBuilder &MIRBuilder) {
106 auto EleOpc = ElemType->getOpcode();
107 assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
108 EleOpc == SPIRV::OpTypeBool) &&
109 "Invalid vector element type");
110
111 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
112 .addDef(createTypeVReg(MIRBuilder))
113 .addUse(getSPIRVTypeID(ElemType))
114 .addImm(NumElems);
115 return MIB;
116 }
117
118 std::tuple<Register, ConstantInt *, bool>
getOrCreateConstIntReg(uint64_t Val,SPIRVType * SpvType,MachineIRBuilder * MIRBuilder,MachineInstr * I,const SPIRVInstrInfo * TII)119 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
120 MachineIRBuilder *MIRBuilder,
121 MachineInstr *I,
122 const SPIRVInstrInfo *TII) {
123 const IntegerType *LLVMIntTy;
124 if (SpvType)
125 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
126 else
127 LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());
128 bool NewInstr = false;
129 // Find a constant in DT or build a new one.
130 ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
131 Register Res = DT.find(CI, CurMF);
132 if (!Res.isValid()) {
133 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
134 LLT LLTy = LLT::scalar(32);
135 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
136 if (MIRBuilder)
137 assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
138 else
139 assignIntTypeToVReg(BitWidth, Res, *I, *TII);
140 DT.add(CI, CurMF, Res);
141 NewInstr = true;
142 }
143 return std::make_tuple(Res, CI, NewInstr);
144 }
145
getOrCreateConstInt(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)146 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
147 SPIRVType *SpvType,
148 const SPIRVInstrInfo &TII) {
149 assert(SpvType);
150 ConstantInt *CI;
151 Register Res;
152 bool New;
153 std::tie(Res, CI, New) =
154 getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
155 // If we have found Res register which is defined by the passed G_CONSTANT
156 // machine instruction, a new constant instruction should be created.
157 if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
158 return Res;
159 MachineInstrBuilder MIB;
160 MachineBasicBlock &BB = *I.getParent();
161 if (Val) {
162 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
163 .addDef(Res)
164 .addUse(getSPIRVTypeID(SpvType));
165 addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);
166 } else {
167 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
168 .addDef(Res)
169 .addUse(getSPIRVTypeID(SpvType));
170 }
171 const auto &ST = CurMF->getSubtarget();
172 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
173 *ST.getRegisterInfo(), *ST.getRegBankInfo());
174 return Res;
175 }
176
buildConstantInt(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR)177 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
178 MachineIRBuilder &MIRBuilder,
179 SPIRVType *SpvType,
180 bool EmitIR) {
181 auto &MF = MIRBuilder.getMF();
182 const IntegerType *LLVMIntTy;
183 if (SpvType)
184 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
185 else
186 LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
187 // Find a constant in DT or build a new one.
188 const auto ConstInt =
189 ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
190 Register Res = DT.find(ConstInt, &MF);
191 if (!Res.isValid()) {
192 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
193 LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
194 Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
195 assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
196 SPIRV::AccessQualifier::ReadWrite, EmitIR);
197 DT.add(ConstInt, &MIRBuilder.getMF(), Res);
198 if (EmitIR) {
199 MIRBuilder.buildConstant(Res, *ConstInt);
200 } else {
201 MachineInstrBuilder MIB;
202 if (Val) {
203 assert(SpvType);
204 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
205 .addDef(Res)
206 .addUse(getSPIRVTypeID(SpvType));
207 addNumImm(APInt(BitWidth, Val), MIB);
208 } else {
209 assert(SpvType);
210 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
211 .addDef(Res)
212 .addUse(getSPIRVTypeID(SpvType));
213 }
214 const auto &Subtarget = CurMF->getSubtarget();
215 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
216 *Subtarget.getRegisterInfo(),
217 *Subtarget.getRegBankInfo());
218 }
219 }
220 return Res;
221 }
222
buildConstantFP(APFloat Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)223 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
224 MachineIRBuilder &MIRBuilder,
225 SPIRVType *SpvType) {
226 auto &MF = MIRBuilder.getMF();
227 const Type *LLVMFPTy;
228 if (SpvType) {
229 LLVMFPTy = getTypeForSPIRVType(SpvType);
230 assert(LLVMFPTy->isFloatingPointTy());
231 } else {
232 LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext());
233 }
234 // Find a constant in DT or build a new one.
235 const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val);
236 Register Res = DT.find(ConstFP, &MF);
237 if (!Res.isValid()) {
238 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
239 Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
240 assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
241 DT.add(ConstFP, &MF, Res);
242 MIRBuilder.buildFConstant(Res, *ConstFP);
243 }
244 return Res;
245 }
246
getOrCreateIntCompositeOrNull(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,Constant * CA,unsigned BitWidth,unsigned ElemCnt)247 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
248 uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
249 const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
250 unsigned ElemCnt) {
251 // Find a constant vector in DT or build a new one.
252 Register Res = DT.find(CA, CurMF);
253 if (!Res.isValid()) {
254 SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
255 // SpvScalConst should be created before SpvVecConst to avoid undefined ID
256 // error on validation.
257 // TODO: can moved below once sorting of types/consts/defs is implemented.
258 Register SpvScalConst;
259 if (Val)
260 SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII);
261 // TODO: maybe use bitwidth of base type.
262 LLT LLTy = LLT::scalar(32);
263 Register SpvVecConst =
264 CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
265 assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
266 DT.add(CA, CurMF, SpvVecConst);
267 MachineInstrBuilder MIB;
268 MachineBasicBlock &BB = *I.getParent();
269 if (Val) {
270 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
271 .addDef(SpvVecConst)
272 .addUse(getSPIRVTypeID(SpvType));
273 for (unsigned i = 0; i < ElemCnt; ++i)
274 MIB.addUse(SpvScalConst);
275 } else {
276 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
277 .addDef(SpvVecConst)
278 .addUse(getSPIRVTypeID(SpvType));
279 }
280 const auto &Subtarget = CurMF->getSubtarget();
281 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
282 *Subtarget.getRegisterInfo(),
283 *Subtarget.getRegBankInfo());
284 return SpvVecConst;
285 }
286 return Res;
287 }
288
289 Register
getOrCreateConsIntVector(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)290 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
291 SPIRVType *SpvType,
292 const SPIRVInstrInfo &TII) {
293 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
294 assert(LLVMTy->isVectorTy());
295 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
296 Type *LLVMBaseTy = LLVMVecTy->getElementType();
297 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
298 auto ConstVec =
299 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
300 unsigned BW = getScalarOrVectorBitWidth(SpvType);
301 return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstVec, BW,
302 SpvType->getOperand(2).getImm());
303 }
304
305 Register
getOrCreateConsIntArray(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)306 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
307 SPIRVType *SpvType,
308 const SPIRVInstrInfo &TII) {
309 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
310 assert(LLVMTy->isArrayTy());
311 const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
312 Type *LLVMBaseTy = LLVMArrTy->getElementType();
313 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
314 auto ConstArr =
315 ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
316 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
317 unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
318 return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstArr, BW,
319 LLVMArrTy->getNumElements());
320 }
321
getOrCreateIntCompositeOrNull(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR,Constant * CA,unsigned BitWidth,unsigned ElemCnt)322 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
323 uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
324 Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
325 Register Res = DT.find(CA, CurMF);
326 if (!Res.isValid()) {
327 Register SpvScalConst;
328 if (Val || EmitIR) {
329 SPIRVType *SpvBaseType =
330 getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
331 SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
332 }
333 LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
334 Register SpvVecConst =
335 CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
336 assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
337 DT.add(CA, CurMF, SpvVecConst);
338 if (EmitIR) {
339 MIRBuilder.buildSplatVector(SpvVecConst, SpvScalConst);
340 } else {
341 if (Val) {
342 auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
343 .addDef(SpvVecConst)
344 .addUse(getSPIRVTypeID(SpvType));
345 for (unsigned i = 0; i < ElemCnt; ++i)
346 MIB.addUse(SpvScalConst);
347 } else {
348 MIRBuilder.buildInstr(SPIRV::OpConstantNull)
349 .addDef(SpvVecConst)
350 .addUse(getSPIRVTypeID(SpvType));
351 }
352 }
353 return SpvVecConst;
354 }
355 return Res;
356 }
357
358 Register
getOrCreateConsIntVector(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR)359 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
360 MachineIRBuilder &MIRBuilder,
361 SPIRVType *SpvType, bool EmitIR) {
362 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
363 assert(LLVMTy->isVectorTy());
364 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
365 Type *LLVMBaseTy = LLVMVecTy->getElementType();
366 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
367 auto ConstVec =
368 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
369 unsigned BW = getScalarOrVectorBitWidth(SpvType);
370 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
371 ConstVec, BW,
372 SpvType->getOperand(2).getImm());
373 }
374
375 Register
getOrCreateConsIntArray(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR)376 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val,
377 MachineIRBuilder &MIRBuilder,
378 SPIRVType *SpvType, bool EmitIR) {
379 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
380 assert(LLVMTy->isArrayTy());
381 const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
382 Type *LLVMBaseTy = LLVMArrTy->getElementType();
383 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
384 auto ConstArr =
385 ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
386 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
387 unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
388 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
389 ConstArr, BW,
390 LLVMArrTy->getNumElements());
391 }
392
393 Register
getOrCreateConstNullPtr(MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)394 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
395 SPIRVType *SpvType) {
396 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
397 const PointerType *LLVMPtrTy = cast<PointerType>(LLVMTy);
398 // Find a constant in DT or build a new one.
399 Constant *CP = ConstantPointerNull::get(const_cast<PointerType *>(LLVMPtrTy));
400 Register Res = DT.find(CP, CurMF);
401 if (!Res.isValid()) {
402 LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
403 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
404 assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
405 MIRBuilder.buildInstr(SPIRV::OpConstantNull)
406 .addDef(Res)
407 .addUse(getSPIRVTypeID(SpvType));
408 DT.add(CP, CurMF, Res);
409 }
410 return Res;
411 }
412
buildConstantSampler(Register ResReg,unsigned AddrMode,unsigned Param,unsigned FilerMode,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)413 Register SPIRVGlobalRegistry::buildConstantSampler(
414 Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode,
415 MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {
416 SPIRVType *SampTy;
417 if (SpvType)
418 SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
419 else
420 SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder);
421
422 auto Sampler =
423 ResReg.isValid()
424 ? ResReg
425 : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
426 auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
427 .addDef(Sampler)
428 .addUse(getSPIRVTypeID(SampTy))
429 .addImm(AddrMode)
430 .addImm(Param)
431 .addImm(FilerMode);
432 assert(Res->getOperand(0).isReg());
433 return Res->getOperand(0).getReg();
434 }
435
buildGlobalVariable(Register ResVReg,SPIRVType * BaseType,StringRef Name,const GlobalValue * GV,SPIRV::StorageClass::StorageClass Storage,const MachineInstr * Init,bool IsConst,bool HasLinkageTy,SPIRV::LinkageType::LinkageType LinkageType,MachineIRBuilder & MIRBuilder,bool IsInstSelector)436 Register SPIRVGlobalRegistry::buildGlobalVariable(
437 Register ResVReg, SPIRVType *BaseType, StringRef Name,
438 const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
439 const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
440 SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
441 bool IsInstSelector) {
442 const GlobalVariable *GVar = nullptr;
443 if (GV)
444 GVar = cast<const GlobalVariable>(GV);
445 else {
446 // If GV is not passed explicitly, use the name to find or construct
447 // the global variable.
448 Module *M = MIRBuilder.getMF().getFunction().getParent();
449 GVar = M->getGlobalVariable(Name);
450 if (GVar == nullptr) {
451 const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
452 GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
453 GlobalValue::ExternalLinkage, nullptr,
454 Twine(Name));
455 }
456 GV = GVar;
457 }
458 Register Reg = DT.find(GVar, &MIRBuilder.getMF());
459 if (Reg.isValid()) {
460 if (Reg != ResVReg)
461 MIRBuilder.buildCopy(ResVReg, Reg);
462 return ResVReg;
463 }
464
465 auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
466 .addDef(ResVReg)
467 .addUse(getSPIRVTypeID(BaseType))
468 .addImm(static_cast<uint32_t>(Storage));
469
470 if (Init != 0) {
471 MIB.addUse(Init->getOperand(0).getReg());
472 }
473
474 // ISel may introduce a new register on this step, so we need to add it to
475 // DT and correct its type avoiding fails on the next stage.
476 if (IsInstSelector) {
477 const auto &Subtarget = CurMF->getSubtarget();
478 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
479 *Subtarget.getRegisterInfo(),
480 *Subtarget.getRegBankInfo());
481 }
482 Reg = MIB->getOperand(0).getReg();
483 DT.add(GVar, &MIRBuilder.getMF(), Reg);
484
485 // Set to Reg the same type as ResVReg has.
486 auto MRI = MIRBuilder.getMRI();
487 assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
488 if (Reg != ResVReg) {
489 LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
490 MRI->setType(Reg, RegLLTy);
491 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
492 }
493
494 // If it's a global variable with name, output OpName for it.
495 if (GVar && GVar->hasName())
496 buildOpName(Reg, GVar->getName(), MIRBuilder);
497
498 // Output decorations for the GV.
499 // TODO: maybe move to GenerateDecorations pass.
500 if (IsConst)
501 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
502
503 if (GVar && GVar->getAlign().valueOrOne().value() != 1) {
504 unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value();
505 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});
506 }
507
508 if (HasLinkageTy)
509 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
510 {static_cast<uint32_t>(LinkageType)}, Name);
511
512 SPIRV::BuiltIn::BuiltIn BuiltInId;
513 if (getSpirvBuiltInIdByName(Name, BuiltInId))
514 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,
515 {static_cast<uint32_t>(BuiltInId)});
516
517 return Reg;
518 }
519
getOpTypeArray(uint32_t NumElems,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder,bool EmitIR)520 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
521 SPIRVType *ElemType,
522 MachineIRBuilder &MIRBuilder,
523 bool EmitIR) {
524 assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
525 "Invalid array element type");
526 Register NumElementsVReg =
527 buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
528 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
529 .addDef(createTypeVReg(MIRBuilder))
530 .addUse(getSPIRVTypeID(ElemType))
531 .addUse(NumElementsVReg);
532 return MIB;
533 }
534
getOpTypeOpaque(const StructType * Ty,MachineIRBuilder & MIRBuilder)535 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
536 MachineIRBuilder &MIRBuilder) {
537 assert(Ty->hasName());
538 const StringRef Name = Ty->hasName() ? Ty->getName() : "";
539 Register ResVReg = createTypeVReg(MIRBuilder);
540 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
541 addStringImm(Name, MIB);
542 buildOpName(ResVReg, Name, MIRBuilder);
543 return MIB;
544 }
545
getOpTypeStruct(const StructType * Ty,MachineIRBuilder & MIRBuilder,bool EmitIR)546 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
547 MachineIRBuilder &MIRBuilder,
548 bool EmitIR) {
549 SmallVector<Register, 4> FieldTypes;
550 for (const auto &Elem : Ty->elements()) {
551 SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
552 assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
553 "Invalid struct element type");
554 FieldTypes.push_back(getSPIRVTypeID(ElemTy));
555 }
556 Register ResVReg = createTypeVReg(MIRBuilder);
557 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
558 for (const auto &Ty : FieldTypes)
559 MIB.addUse(Ty);
560 if (Ty->hasName())
561 buildOpName(ResVReg, Ty->getName(), MIRBuilder);
562 if (Ty->isPacked())
563 buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
564 return MIB;
565 }
566
getOrCreateSpecialType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual)567 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
568 const Type *Ty, MachineIRBuilder &MIRBuilder,
569 SPIRV::AccessQualifier::AccessQualifier AccQual) {
570 // Some OpenCL and SPIRV builtins like image2d_t are passed in as
571 // pointers, but should be treated as custom types like OpTypeImage.
572 if (auto PType = dyn_cast<PointerType>(Ty)) {
573 assert(!PType->isOpaque());
574 Ty = PType->getNonOpaquePointerElementType();
575 }
576 auto SType = cast<StructType>(Ty);
577 assert(isSpecialOpaqueType(SType) && "Not a special opaque builtin type");
578 return SPIRV::lowerBuiltinType(SType, AccQual, MIRBuilder, this);
579 }
580
getOpTypePointer(SPIRV::StorageClass::StorageClass SC,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder,Register Reg)581 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
582 SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,
583 MachineIRBuilder &MIRBuilder, Register Reg) {
584 if (!Reg.isValid())
585 Reg = createTypeVReg(MIRBuilder);
586 return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
587 .addDef(Reg)
588 .addImm(static_cast<uint32_t>(SC))
589 .addUse(getSPIRVTypeID(ElemType));
590 }
591
getOpTypeForwardPointer(SPIRV::StorageClass::StorageClass SC,MachineIRBuilder & MIRBuilder)592 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
593 SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
594 return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
595 .addUse(createTypeVReg(MIRBuilder))
596 .addImm(static_cast<uint32_t>(SC));
597 }
598
getOpTypeFunction(SPIRVType * RetType,const SmallVectorImpl<SPIRVType * > & ArgTypes,MachineIRBuilder & MIRBuilder)599 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
600 SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
601 MachineIRBuilder &MIRBuilder) {
602 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
603 .addDef(createTypeVReg(MIRBuilder))
604 .addUse(getSPIRVTypeID(RetType));
605 for (const SPIRVType *ArgType : ArgTypes)
606 MIB.addUse(getSPIRVTypeID(ArgType));
607 return MIB;
608 }
609
getOrCreateOpTypeFunctionWithArgs(const Type * Ty,SPIRVType * RetType,const SmallVectorImpl<SPIRVType * > & ArgTypes,MachineIRBuilder & MIRBuilder)610 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
611 const Type *Ty, SPIRVType *RetType,
612 const SmallVectorImpl<SPIRVType *> &ArgTypes,
613 MachineIRBuilder &MIRBuilder) {
614 Register Reg = DT.find(Ty, &MIRBuilder.getMF());
615 if (Reg.isValid())
616 return getSPIRVTypeForVReg(Reg);
617 SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
618 return finishCreatingSPIRVType(Ty, SpirvType);
619 }
620
findSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual,bool EmitIR)621 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
622 const Type *Ty, MachineIRBuilder &MIRBuilder,
623 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
624 Register Reg = DT.find(Ty, &MIRBuilder.getMF());
625 if (Reg.isValid())
626 return getSPIRVTypeForVReg(Reg);
627 if (ForwardPointerTypes.find(Ty) != ForwardPointerTypes.end())
628 return ForwardPointerTypes[Ty];
629 return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);
630 }
631
getSPIRVTypeID(const SPIRVType * SpirvType) const632 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
633 assert(SpirvType && "Attempting to get type id for nullptr type.");
634 if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer)
635 return SpirvType->uses().begin()->getReg();
636 return SpirvType->defs().begin()->getReg();
637 }
638
createSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual,bool EmitIR)639 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
640 const Type *Ty, MachineIRBuilder &MIRBuilder,
641 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
642 if (isSpecialOpaqueType(Ty))
643 return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
644 auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
645 auto t = TypeToSPIRVTypeMap.find(Ty);
646 if (t != TypeToSPIRVTypeMap.end()) {
647 auto tt = t->second.find(&MIRBuilder.getMF());
648 if (tt != t->second.end())
649 return getSPIRVTypeForVReg(tt->second);
650 }
651
652 if (auto IType = dyn_cast<IntegerType>(Ty)) {
653 const unsigned Width = IType->getBitWidth();
654 return Width == 1 ? getOpTypeBool(MIRBuilder)
655 : getOpTypeInt(Width, MIRBuilder, false);
656 }
657 if (Ty->isFloatingPointTy())
658 return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
659 if (Ty->isVoidTy())
660 return getOpTypeVoid(MIRBuilder);
661 if (Ty->isVectorTy()) {
662 SPIRVType *El =
663 findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
664 return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
665 MIRBuilder);
666 }
667 if (Ty->isArrayTy()) {
668 SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
669 return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
670 }
671 if (auto SType = dyn_cast<StructType>(Ty)) {
672 if (SType->isOpaque())
673 return getOpTypeOpaque(SType, MIRBuilder);
674 return getOpTypeStruct(SType, MIRBuilder, EmitIR);
675 }
676 if (auto FType = dyn_cast<FunctionType>(Ty)) {
677 SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
678 SmallVector<SPIRVType *, 4> ParamTypes;
679 for (const auto &t : FType->params()) {
680 ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
681 }
682 return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
683 }
684 if (auto PType = dyn_cast<PointerType>(Ty)) {
685 SPIRVType *SpvElementType;
686 // At the moment, all opaque pointers correspond to i8 element type.
687 // TODO: change the implementation once opaque pointers are supported
688 // in the SPIR-V specification.
689 if (PType->isOpaque())
690 SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
691 else
692 SpvElementType =
693 findSPIRVType(PType->getNonOpaquePointerElementType(), MIRBuilder,
694 SPIRV::AccessQualifier::ReadWrite, EmitIR);
695 auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
696 // Null pointer means we have a loop in type definitions, make and
697 // return corresponding OpTypeForwardPointer.
698 if (SpvElementType == nullptr) {
699 if (ForwardPointerTypes.find(Ty) == ForwardPointerTypes.end())
700 ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder);
701 return ForwardPointerTypes[PType];
702 }
703 Register Reg(0);
704 // If we have forward pointer associated with this type, use its register
705 // operand to create OpTypePointer.
706 if (ForwardPointerTypes.find(PType) != ForwardPointerTypes.end())
707 Reg = getSPIRVTypeID(ForwardPointerTypes[PType]);
708
709 return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
710 }
711 llvm_unreachable("Unable to convert LLVM type to SPIRVType");
712 }
713
restOfCreateSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool EmitIR)714 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
715 const Type *Ty, MachineIRBuilder &MIRBuilder,
716 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
717 if (TypesInProcessing.count(Ty) && !Ty->isPointerTy())
718 return nullptr;
719 TypesInProcessing.insert(Ty);
720 SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
721 TypesInProcessing.erase(Ty);
722 VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
723 SPIRVToLLVMType[SpirvType] = Ty;
724 Register Reg = DT.find(Ty, &MIRBuilder.getMF());
725 // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
726 // will be added later. For special types it is already added to DT.
727 if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
728 !isSpecialOpaqueType(Ty))
729 DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
730
731 return SpirvType;
732 }
733
getSPIRVTypeForVReg(Register VReg) const734 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
735 auto t = VRegToTypeMap.find(CurMF);
736 if (t != VRegToTypeMap.end()) {
737 auto tt = t->second.find(VReg);
738 if (tt != t->second.end())
739 return tt->second;
740 }
741 return nullptr;
742 }
743
getOrCreateSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool EmitIR)744 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
745 const Type *Ty, MachineIRBuilder &MIRBuilder,
746 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
747 Register Reg = DT.find(Ty, &MIRBuilder.getMF());
748 if (Reg.isValid() && !isSpecialOpaqueType(Ty))
749 return getSPIRVTypeForVReg(Reg);
750 TypesInProcessing.clear();
751 SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
752 // Create normal pointer types for the corresponding OpTypeForwardPointers.
753 for (auto &CU : ForwardPointerTypes) {
754 const Type *Ty2 = CU.first;
755 SPIRVType *STy2 = CU.second;
756 if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid())
757 STy2 = getSPIRVTypeForVReg(Reg);
758 else
759 STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
760 if (Ty == Ty2)
761 STy = STy2;
762 }
763 ForwardPointerTypes.clear();
764 return STy;
765 }
766
isScalarOfType(Register VReg,unsigned TypeOpcode) const767 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
768 unsigned TypeOpcode) const {
769 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
770 assert(Type && "isScalarOfType VReg has no type assigned");
771 return Type->getOpcode() == TypeOpcode;
772 }
773
isScalarOrVectorOfType(Register VReg,unsigned TypeOpcode) const774 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
775 unsigned TypeOpcode) const {
776 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
777 assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
778 if (Type->getOpcode() == TypeOpcode)
779 return true;
780 if (Type->getOpcode() == SPIRV::OpTypeVector) {
781 Register ScalarTypeVReg = Type->getOperand(1).getReg();
782 SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
783 return ScalarType->getOpcode() == TypeOpcode;
784 }
785 return false;
786 }
787
788 unsigned
getScalarOrVectorBitWidth(const SPIRVType * Type) const789 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
790 assert(Type && "Invalid Type pointer");
791 if (Type->getOpcode() == SPIRV::OpTypeVector) {
792 auto EleTypeReg = Type->getOperand(1).getReg();
793 Type = getSPIRVTypeForVReg(EleTypeReg);
794 }
795 if (Type->getOpcode() == SPIRV::OpTypeInt ||
796 Type->getOpcode() == SPIRV::OpTypeFloat)
797 return Type->getOperand(1).getImm();
798 if (Type->getOpcode() == SPIRV::OpTypeBool)
799 return 1;
800 llvm_unreachable("Attempting to get bit width of non-integer/float type.");
801 }
802
isScalarOrVectorSigned(const SPIRVType * Type) const803 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
804 assert(Type && "Invalid Type pointer");
805 if (Type->getOpcode() == SPIRV::OpTypeVector) {
806 auto EleTypeReg = Type->getOperand(1).getReg();
807 Type = getSPIRVTypeForVReg(EleTypeReg);
808 }
809 if (Type->getOpcode() == SPIRV::OpTypeInt)
810 return Type->getOperand(2).getImm() != 0;
811 llvm_unreachable("Attempting to get sign of non-integer type.");
812 }
813
814 SPIRV::StorageClass::StorageClass
getPointerStorageClass(Register VReg) const815 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
816 SPIRVType *Type = getSPIRVTypeForVReg(VReg);
817 assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
818 Type->getOperand(1).isImm() && "Pointer type is expected");
819 return static_cast<SPIRV::StorageClass::StorageClass>(
820 Type->getOperand(1).getImm());
821 }
822
getOrCreateOpTypeImage(MachineIRBuilder & MIRBuilder,SPIRVType * SampledType,SPIRV::Dim::Dim Dim,uint32_t Depth,uint32_t Arrayed,uint32_t Multisampled,uint32_t Sampled,SPIRV::ImageFormat::ImageFormat ImageFormat,SPIRV::AccessQualifier::AccessQualifier AccessQual)823 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
824 MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
825 uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
826 SPIRV::ImageFormat::ImageFormat ImageFormat,
827 SPIRV::AccessQualifier::AccessQualifier AccessQual) {
828 SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(SampledType), Dim, Depth,
829 Arrayed, Multisampled, Sampled, ImageFormat,
830 AccessQual);
831 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
832 return Res;
833 Register ResVReg = createTypeVReg(MIRBuilder);
834 DT.add(TD, &MIRBuilder.getMF(), ResVReg);
835 return MIRBuilder.buildInstr(SPIRV::OpTypeImage)
836 .addDef(ResVReg)
837 .addUse(getSPIRVTypeID(SampledType))
838 .addImm(Dim)
839 .addImm(Depth) // Depth (whether or not it is a Depth image).
840 .addImm(Arrayed) // Arrayed.
841 .addImm(Multisampled) // Multisampled (0 = only single-sample).
842 .addImm(Sampled) // Sampled (0 = usage known at runtime).
843 .addImm(ImageFormat)
844 .addImm(AccessQual);
845 }
846
847 SPIRVType *
getOrCreateOpTypeSampler(MachineIRBuilder & MIRBuilder)848 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
849 SPIRV::SamplerTypeDescriptor TD;
850 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
851 return Res;
852 Register ResVReg = createTypeVReg(MIRBuilder);
853 DT.add(TD, &MIRBuilder.getMF(), ResVReg);
854 return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg);
855 }
856
getOrCreateOpTypePipe(MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual)857 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
858 MachineIRBuilder &MIRBuilder,
859 SPIRV::AccessQualifier::AccessQualifier AccessQual) {
860 SPIRV::PipeTypeDescriptor TD(AccessQual);
861 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
862 return Res;
863 Register ResVReg = createTypeVReg(MIRBuilder);
864 DT.add(TD, &MIRBuilder.getMF(), ResVReg);
865 return MIRBuilder.buildInstr(SPIRV::OpTypePipe)
866 .addDef(ResVReg)
867 .addImm(AccessQual);
868 }
869
getOrCreateOpTypeDeviceEvent(MachineIRBuilder & MIRBuilder)870 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
871 MachineIRBuilder &MIRBuilder) {
872 SPIRV::DeviceEventTypeDescriptor TD;
873 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
874 return Res;
875 Register ResVReg = createTypeVReg(MIRBuilder);
876 DT.add(TD, &MIRBuilder.getMF(), ResVReg);
877 return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg);
878 }
879
getOrCreateOpTypeSampledImage(SPIRVType * ImageType,MachineIRBuilder & MIRBuilder)880 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
881 SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
882 SPIRV::SampledImageTypeDescriptor TD(
883 SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
884 ImageType->getOperand(1).getReg())),
885 ImageType);
886 if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
887 return Res;
888 Register ResVReg = createTypeVReg(MIRBuilder);
889 DT.add(TD, &MIRBuilder.getMF(), ResVReg);
890 return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage)
891 .addDef(ResVReg)
892 .addUse(getSPIRVTypeID(ImageType));
893 }
894
getOrCreateOpTypeByOpcode(const Type * Ty,MachineIRBuilder & MIRBuilder,unsigned Opcode)895 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
896 const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
897 Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
898 if (ResVReg.isValid())
899 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
900 ResVReg = createTypeVReg(MIRBuilder);
901 DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
902 return MIRBuilder.buildInstr(Opcode).addDef(ResVReg);
903 }
904
905 const MachineInstr *
checkSpecialInstr(const SPIRV::SpecialTypeDescriptor & TD,MachineIRBuilder & MIRBuilder)906 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
907 MachineIRBuilder &MIRBuilder) {
908 Register Reg = DT.find(TD, &MIRBuilder.getMF());
909 if (Reg.isValid())
910 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg);
911 return nullptr;
912 }
913
914 // TODO: maybe use tablegen to implement this.
915 SPIRVType *
getOrCreateSPIRVTypeByName(StringRef TypeStr,MachineIRBuilder & MIRBuilder)916 SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(StringRef TypeStr,
917 MachineIRBuilder &MIRBuilder) {
918 unsigned VecElts = 0;
919 auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
920
921 // Parse type name in either "typeN" or "type vector[N]" format, where
922 // N is the number of elements of the vector.
923 Type *Type;
924 if (TypeStr.startswith("void")) {
925 Type = Type::getVoidTy(Ctx);
926 TypeStr = TypeStr.substr(strlen("void"));
927 } else if (TypeStr.startswith("int") || TypeStr.startswith("uint")) {
928 Type = Type::getInt32Ty(Ctx);
929 TypeStr = TypeStr.startswith("int") ? TypeStr.substr(strlen("int"))
930 : TypeStr.substr(strlen("uint"));
931 } else if (TypeStr.startswith("float")) {
932 Type = Type::getFloatTy(Ctx);
933 TypeStr = TypeStr.substr(strlen("float"));
934 } else if (TypeStr.startswith("half")) {
935 Type = Type::getHalfTy(Ctx);
936 TypeStr = TypeStr.substr(strlen("half"));
937 } else if (TypeStr.startswith("opencl.sampler_t")) {
938 Type = StructType::create(Ctx, "opencl.sampler_t");
939 } else
940 llvm_unreachable("Unable to recognize SPIRV type name.");
941 if (TypeStr.startswith(" vector[")) {
942 TypeStr = TypeStr.substr(strlen(" vector["));
943 TypeStr = TypeStr.substr(0, TypeStr.find(']'));
944 }
945 TypeStr.getAsInteger(10, VecElts);
946 auto SpirvTy = getOrCreateSPIRVType(Type, MIRBuilder);
947 if (VecElts > 0)
948 SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);
949 return SpirvTy;
950 }
951
952 SPIRVType *
getOrCreateSPIRVIntegerType(unsigned BitWidth,MachineIRBuilder & MIRBuilder)953 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
954 MachineIRBuilder &MIRBuilder) {
955 return getOrCreateSPIRVType(
956 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
957 MIRBuilder);
958 }
959
finishCreatingSPIRVType(const Type * LLVMTy,SPIRVType * SpirvType)960 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
961 SPIRVType *SpirvType) {
962 assert(CurMF == SpirvType->getMF());
963 VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
964 SPIRVToLLVMType[SpirvType] = LLVMTy;
965 DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType));
966 return SpirvType;
967 }
968
getOrCreateSPIRVIntegerType(unsigned BitWidth,MachineInstr & I,const SPIRVInstrInfo & TII)969 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
970 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
971 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
972 Register Reg = DT.find(LLVMTy, CurMF);
973 if (Reg.isValid())
974 return getSPIRVTypeForVReg(Reg);
975 MachineBasicBlock &BB = *I.getParent();
976 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
977 .addDef(createTypeVReg(CurMF->getRegInfo()))
978 .addImm(BitWidth)
979 .addImm(0);
980 return finishCreatingSPIRVType(LLVMTy, MIB);
981 }
982
983 SPIRVType *
getOrCreateSPIRVBoolType(MachineIRBuilder & MIRBuilder)984 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
985 return getOrCreateSPIRVType(
986 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
987 MIRBuilder);
988 }
989
990 SPIRVType *
getOrCreateSPIRVBoolType(MachineInstr & I,const SPIRVInstrInfo & TII)991 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
992 const SPIRVInstrInfo &TII) {
993 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
994 Register Reg = DT.find(LLVMTy, CurMF);
995 if (Reg.isValid())
996 return getSPIRVTypeForVReg(Reg);
997 MachineBasicBlock &BB = *I.getParent();
998 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool))
999 .addDef(createTypeVReg(CurMF->getRegInfo()));
1000 return finishCreatingSPIRVType(LLVMTy, MIB);
1001 }
1002
getOrCreateSPIRVVectorType(SPIRVType * BaseType,unsigned NumElements,MachineIRBuilder & MIRBuilder)1003 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1004 SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
1005 return getOrCreateSPIRVType(
1006 FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1007 NumElements),
1008 MIRBuilder);
1009 }
1010
getOrCreateSPIRVVectorType(SPIRVType * BaseType,unsigned NumElements,MachineInstr & I,const SPIRVInstrInfo & TII)1011 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1012 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1013 const SPIRVInstrInfo &TII) {
1014 Type *LLVMTy = FixedVectorType::get(
1015 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1016 Register Reg = DT.find(LLVMTy, CurMF);
1017 if (Reg.isValid())
1018 return getSPIRVTypeForVReg(Reg);
1019 MachineBasicBlock &BB = *I.getParent();
1020 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
1021 .addDef(createTypeVReg(CurMF->getRegInfo()))
1022 .addUse(getSPIRVTypeID(BaseType))
1023 .addImm(NumElements);
1024 return finishCreatingSPIRVType(LLVMTy, MIB);
1025 }
1026
getOrCreateSPIRVArrayType(SPIRVType * BaseType,unsigned NumElements,MachineInstr & I,const SPIRVInstrInfo & TII)1027 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
1028 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1029 const SPIRVInstrInfo &TII) {
1030 Type *LLVMTy = ArrayType::get(
1031 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1032 Register Reg = DT.find(LLVMTy, CurMF);
1033 if (Reg.isValid())
1034 return getSPIRVTypeForVReg(Reg);
1035 MachineBasicBlock &BB = *I.getParent();
1036 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII);
1037 Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII);
1038 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray))
1039 .addDef(createTypeVReg(CurMF->getRegInfo()))
1040 .addUse(getSPIRVTypeID(BaseType))
1041 .addUse(Len);
1042 return finishCreatingSPIRVType(LLVMTy, MIB);
1043 }
1044
getOrCreateSPIRVPointerType(SPIRVType * BaseType,MachineIRBuilder & MIRBuilder,SPIRV::StorageClass::StorageClass SClass)1045 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1046 SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1047 SPIRV::StorageClass::StorageClass SClass) {
1048 return getOrCreateSPIRVType(
1049 PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1050 storageClassToAddressSpace(SClass)),
1051 MIRBuilder);
1052 }
1053
getOrCreateSPIRVPointerType(SPIRVType * BaseType,MachineInstr & I,const SPIRVInstrInfo & TII,SPIRV::StorageClass::StorageClass SC)1054 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1055 SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
1056 SPIRV::StorageClass::StorageClass SC) {
1057 Type *LLVMTy =
1058 PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1059 storageClassToAddressSpace(SC));
1060 Register Reg = DT.find(LLVMTy, CurMF);
1061 if (Reg.isValid())
1062 return getSPIRVTypeForVReg(Reg);
1063 MachineBasicBlock &BB = *I.getParent();
1064 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
1065 .addDef(createTypeVReg(CurMF->getRegInfo()))
1066 .addImm(static_cast<uint32_t>(SC))
1067 .addUse(getSPIRVTypeID(BaseType));
1068 return finishCreatingSPIRVType(LLVMTy, MIB);
1069 }
1070
getOrCreateUndef(MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)1071 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
1072 SPIRVType *SpvType,
1073 const SPIRVInstrInfo &TII) {
1074 assert(SpvType);
1075 const Type *LLVMTy = getTypeForSPIRVType(SpvType);
1076 assert(LLVMTy);
1077 // Find a constant in DT or build a new one.
1078 UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
1079 Register Res = DT.find(UV, CurMF);
1080 if (Res.isValid())
1081 return Res;
1082 LLT LLTy = LLT::scalar(32);
1083 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
1084 assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
1085 DT.add(UV, CurMF, Res);
1086
1087 MachineInstrBuilder MIB;
1088 MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
1089 .addDef(Res)
1090 .addUse(getSPIRVTypeID(SpvType));
1091 const auto &ST = CurMF->getSubtarget();
1092 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
1093 *ST.getRegisterInfo(), *ST.getRegBankInfo());
1094 return Res;
1095 }
1096