1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13 // and supports consistency of constants and global variables.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRV.h"
19 #include "SPIRVBuiltins.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVTargetMachine.h"
22 #include "SPIRVUtils.h"
23 
24 using namespace llvm;
SPIRVGlobalRegistry(unsigned PointerSize)25 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
26     : PointerSize(PointerSize) {}
27 
assignIntTypeToVReg(unsigned BitWidth,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)28 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
29                                                     Register VReg,
30                                                     MachineInstr &I,
31                                                     const SPIRVInstrInfo &TII) {
32   SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
33   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
34   return SpirvType;
35 }
36 
assignVectTypeToVReg(SPIRVType * BaseType,unsigned NumElements,Register VReg,MachineInstr & I,const SPIRVInstrInfo & TII)37 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
38     SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
39     const SPIRVInstrInfo &TII) {
40   SPIRVType *SpirvType =
41       getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
42   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
43   return SpirvType;
44 }
45 
assignTypeToVReg(const Type * Type,Register VReg,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool EmitIR)46 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
47     const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
48     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
49 
50   SPIRVType *SpirvType =
51       getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
52   assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
53   return SpirvType;
54 }
55 
assignSPIRVTypeToVReg(SPIRVType * SpirvType,Register VReg,MachineFunction & MF)56 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
57                                                 Register VReg,
58                                                 MachineFunction &MF) {
59   VRegToTypeMap[&MF][VReg] = SpirvType;
60 }
61 
createTypeVReg(MachineIRBuilder & MIRBuilder)62 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
63   auto &MRI = MIRBuilder.getMF().getRegInfo();
64   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
65   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
66   return Res;
67 }
68 
createTypeVReg(MachineRegisterInfo & MRI)69 static Register createTypeVReg(MachineRegisterInfo &MRI) {
70   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
71   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
72   return Res;
73 }
74 
getOpTypeBool(MachineIRBuilder & MIRBuilder)75 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
76   return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
77       .addDef(createTypeVReg(MIRBuilder));
78 }
79 
getOpTypeInt(uint32_t Width,MachineIRBuilder & MIRBuilder,bool IsSigned)80 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
81                                              MachineIRBuilder &MIRBuilder,
82                                              bool IsSigned) {
83   assert(Width <= 64 && "Unsupported integer width!");
84   const SPIRVSubtarget &ST =
85       cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
86   if (ST.canUseExtension(
87           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
88     MIRBuilder.buildInstr(SPIRV::OpExtension)
89         .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
90     MIRBuilder.buildInstr(SPIRV::OpCapability)
91         .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
92   } else if (Width <= 8)
93     Width = 8;
94   else if (Width <= 16)
95     Width = 16;
96   else if (Width <= 32)
97     Width = 32;
98   else if (Width <= 64)
99     Width = 64;
100 
101   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
102                  .addDef(createTypeVReg(MIRBuilder))
103                  .addImm(Width)
104                  .addImm(IsSigned ? 1 : 0);
105   return MIB;
106 }
107 
getOpTypeFloat(uint32_t Width,MachineIRBuilder & MIRBuilder)108 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
109                                                MachineIRBuilder &MIRBuilder) {
110   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
111                  .addDef(createTypeVReg(MIRBuilder))
112                  .addImm(Width);
113   return MIB;
114 }
115 
getOpTypeVoid(MachineIRBuilder & MIRBuilder)116 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
117   return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
118       .addDef(createTypeVReg(MIRBuilder));
119 }
120 
getOpTypeVector(uint32_t NumElems,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder)121 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
122                                                 SPIRVType *ElemType,
123                                                 MachineIRBuilder &MIRBuilder) {
124   auto EleOpc = ElemType->getOpcode();
125   assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
126           EleOpc == SPIRV::OpTypeBool) &&
127          "Invalid vector element type");
128 
129   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
130                  .addDef(createTypeVReg(MIRBuilder))
131                  .addUse(getSPIRVTypeID(ElemType))
132                  .addImm(NumElems);
133   return MIB;
134 }
135 
136 std::tuple<Register, ConstantInt *, bool>
getOrCreateConstIntReg(uint64_t Val,SPIRVType * SpvType,MachineIRBuilder * MIRBuilder,MachineInstr * I,const SPIRVInstrInfo * TII)137 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
138                                             MachineIRBuilder *MIRBuilder,
139                                             MachineInstr *I,
140                                             const SPIRVInstrInfo *TII) {
141   const IntegerType *LLVMIntTy;
142   if (SpvType)
143     LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
144   else
145     LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());
146   bool NewInstr = false;
147   // Find a constant in DT or build a new one.
148   ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
149   Register Res = DT.find(CI, CurMF);
150   if (!Res.isValid()) {
151     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
152     LLT LLTy = LLT::scalar(32);
153     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
154     CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
155     if (MIRBuilder)
156       assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
157     else
158       assignIntTypeToVReg(BitWidth, Res, *I, *TII);
159     DT.add(CI, CurMF, Res);
160     NewInstr = true;
161   }
162   return std::make_tuple(Res, CI, NewInstr);
163 }
164 
getOrCreateConstInt(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)165 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
166                                                   SPIRVType *SpvType,
167                                                   const SPIRVInstrInfo &TII) {
168   assert(SpvType);
169   ConstantInt *CI;
170   Register Res;
171   bool New;
172   std::tie(Res, CI, New) =
173       getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
174   // If we have found Res register which is defined by the passed G_CONSTANT
175   // machine instruction, a new constant instruction should be created.
176   if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
177     return Res;
178   MachineInstrBuilder MIB;
179   MachineBasicBlock &BB = *I.getParent();
180   if (Val) {
181     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
182               .addDef(Res)
183               .addUse(getSPIRVTypeID(SpvType));
184     addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);
185   } else {
186     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
187               .addDef(Res)
188               .addUse(getSPIRVTypeID(SpvType));
189   }
190   const auto &ST = CurMF->getSubtarget();
191   constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
192                                    *ST.getRegisterInfo(), *ST.getRegBankInfo());
193   return Res;
194 }
195 
buildConstantInt(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR)196 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
197                                                MachineIRBuilder &MIRBuilder,
198                                                SPIRVType *SpvType,
199                                                bool EmitIR) {
200   auto &MF = MIRBuilder.getMF();
201   const IntegerType *LLVMIntTy;
202   if (SpvType)
203     LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
204   else
205     LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
206   // Find a constant in DT or build a new one.
207   const auto ConstInt =
208       ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
209   Register Res = DT.find(ConstInt, &MF);
210   if (!Res.isValid()) {
211     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
212     LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
213     Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
214     MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
215     assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
216                      SPIRV::AccessQualifier::ReadWrite, EmitIR);
217     DT.add(ConstInt, &MIRBuilder.getMF(), Res);
218     if (EmitIR) {
219       MIRBuilder.buildConstant(Res, *ConstInt);
220     } else {
221       MachineInstrBuilder MIB;
222       if (Val) {
223         assert(SpvType);
224         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
225                   .addDef(Res)
226                   .addUse(getSPIRVTypeID(SpvType));
227         addNumImm(APInt(BitWidth, Val), MIB);
228       } else {
229         assert(SpvType);
230         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
231                   .addDef(Res)
232                   .addUse(getSPIRVTypeID(SpvType));
233       }
234       const auto &Subtarget = CurMF->getSubtarget();
235       constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
236                                        *Subtarget.getRegisterInfo(),
237                                        *Subtarget.getRegBankInfo());
238     }
239   }
240   return Res;
241 }
242 
buildConstantFP(APFloat Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)243 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
244                                               MachineIRBuilder &MIRBuilder,
245                                               SPIRVType *SpvType) {
246   auto &MF = MIRBuilder.getMF();
247   auto &Ctx = MF.getFunction().getContext();
248   if (!SpvType) {
249     const Type *LLVMFPTy = Type::getFloatTy(Ctx);
250     SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder);
251   }
252   // Find a constant in DT or build a new one.
253   const auto ConstFP = ConstantFP::get(Ctx, Val);
254   Register Res = DT.find(ConstFP, &MF);
255   if (!Res.isValid()) {
256     Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(32));
257     MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
258     assignSPIRVTypeToVReg(SpvType, Res, MF);
259     DT.add(ConstFP, &MF, Res);
260 
261     MachineInstrBuilder MIB;
262     MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
263               .addDef(Res)
264               .addUse(getSPIRVTypeID(SpvType));
265     addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
266   }
267 
268   return Res;
269 }
270 
getOrCreateIntCompositeOrNull(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII,Constant * CA,unsigned BitWidth,unsigned ElemCnt)271 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
272     uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
273     const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
274     unsigned ElemCnt) {
275   // Find a constant vector in DT or build a new one.
276   Register Res = DT.find(CA, CurMF);
277   if (!Res.isValid()) {
278     SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
279     // SpvScalConst should be created before SpvVecConst to avoid undefined ID
280     // error on validation.
281     // TODO: can moved below once sorting of types/consts/defs is implemented.
282     Register SpvScalConst;
283     if (Val)
284       SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII);
285     // TODO: maybe use bitwidth of base type.
286     LLT LLTy = LLT::scalar(32);
287     Register SpvVecConst =
288         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
289     CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
290     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
291     DT.add(CA, CurMF, SpvVecConst);
292     MachineInstrBuilder MIB;
293     MachineBasicBlock &BB = *I.getParent();
294     if (Val) {
295       MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
296                 .addDef(SpvVecConst)
297                 .addUse(getSPIRVTypeID(SpvType));
298       for (unsigned i = 0; i < ElemCnt; ++i)
299         MIB.addUse(SpvScalConst);
300     } else {
301       MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
302                 .addDef(SpvVecConst)
303                 .addUse(getSPIRVTypeID(SpvType));
304     }
305     const auto &Subtarget = CurMF->getSubtarget();
306     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
307                                      *Subtarget.getRegisterInfo(),
308                                      *Subtarget.getRegBankInfo());
309     return SpvVecConst;
310   }
311   return Res;
312 }
313 
314 Register
getOrCreateConsIntVector(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)315 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
316                                               SPIRVType *SpvType,
317                                               const SPIRVInstrInfo &TII) {
318   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
319   assert(LLVMTy->isVectorTy());
320   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
321   Type *LLVMBaseTy = LLVMVecTy->getElementType();
322   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
323   auto ConstVec =
324       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
325   unsigned BW = getScalarOrVectorBitWidth(SpvType);
326   return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstVec, BW,
327                                        SpvType->getOperand(2).getImm());
328 }
329 
330 Register
getOrCreateConsIntArray(uint64_t Val,MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)331 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
332                                              SPIRVType *SpvType,
333                                              const SPIRVInstrInfo &TII) {
334   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
335   assert(LLVMTy->isArrayTy());
336   const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
337   Type *LLVMBaseTy = LLVMArrTy->getElementType();
338   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
339   auto ConstArr =
340       ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
341   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
342   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
343   return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstArr, BW,
344                                        LLVMArrTy->getNumElements());
345 }
346 
getOrCreateIntCompositeOrNull(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR,Constant * CA,unsigned BitWidth,unsigned ElemCnt)347 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
348     uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
349     Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
350   Register Res = DT.find(CA, CurMF);
351   if (!Res.isValid()) {
352     Register SpvScalConst;
353     if (Val || EmitIR) {
354       SPIRVType *SpvBaseType =
355           getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
356       SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
357     }
358     LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
359     Register SpvVecConst =
360         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
361     CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
362     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
363     DT.add(CA, CurMF, SpvVecConst);
364     if (EmitIR) {
365       MIRBuilder.buildSplatVector(SpvVecConst, SpvScalConst);
366     } else {
367       if (Val) {
368         auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
369                        .addDef(SpvVecConst)
370                        .addUse(getSPIRVTypeID(SpvType));
371         for (unsigned i = 0; i < ElemCnt; ++i)
372           MIB.addUse(SpvScalConst);
373       } else {
374         MIRBuilder.buildInstr(SPIRV::OpConstantNull)
375             .addDef(SpvVecConst)
376             .addUse(getSPIRVTypeID(SpvType));
377       }
378     }
379     return SpvVecConst;
380   }
381   return Res;
382 }
383 
384 Register
getOrCreateConsIntVector(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR)385 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
386                                               MachineIRBuilder &MIRBuilder,
387                                               SPIRVType *SpvType, bool EmitIR) {
388   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
389   assert(LLVMTy->isVectorTy());
390   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
391   Type *LLVMBaseTy = LLVMVecTy->getElementType();
392   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
393   auto ConstVec =
394       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
395   unsigned BW = getScalarOrVectorBitWidth(SpvType);
396   return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
397                                        ConstVec, BW,
398                                        SpvType->getOperand(2).getImm());
399 }
400 
401 Register
getOrCreateConsIntArray(uint64_t Val,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType,bool EmitIR)402 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val,
403                                              MachineIRBuilder &MIRBuilder,
404                                              SPIRVType *SpvType, bool EmitIR) {
405   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
406   assert(LLVMTy->isArrayTy());
407   const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
408   Type *LLVMBaseTy = LLVMArrTy->getElementType();
409   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
410   auto ConstArr =
411       ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
412   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
413   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
414   return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
415                                        ConstArr, BW,
416                                        LLVMArrTy->getNumElements());
417 }
418 
419 Register
getOrCreateConstNullPtr(MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)420 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
421                                              SPIRVType *SpvType) {
422   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
423   const PointerType *LLVMPtrTy = cast<PointerType>(LLVMTy);
424   // Find a constant in DT or build a new one.
425   Constant *CP = ConstantPointerNull::get(const_cast<PointerType *>(LLVMPtrTy));
426   Register Res = DT.find(CP, CurMF);
427   if (!Res.isValid()) {
428     LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
429     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
430     CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
431     assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
432     MIRBuilder.buildInstr(SPIRV::OpConstantNull)
433         .addDef(Res)
434         .addUse(getSPIRVTypeID(SpvType));
435     DT.add(CP, CurMF, Res);
436   }
437   return Res;
438 }
439 
buildConstantSampler(Register ResReg,unsigned AddrMode,unsigned Param,unsigned FilerMode,MachineIRBuilder & MIRBuilder,SPIRVType * SpvType)440 Register SPIRVGlobalRegistry::buildConstantSampler(
441     Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode,
442     MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {
443   SPIRVType *SampTy;
444   if (SpvType)
445     SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
446   else
447     SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder);
448 
449   auto Sampler =
450       ResReg.isValid()
451           ? ResReg
452           : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
453   auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
454                  .addDef(Sampler)
455                  .addUse(getSPIRVTypeID(SampTy))
456                  .addImm(AddrMode)
457                  .addImm(Param)
458                  .addImm(FilerMode);
459   assert(Res->getOperand(0).isReg());
460   return Res->getOperand(0).getReg();
461 }
462 
buildGlobalVariable(Register ResVReg,SPIRVType * BaseType,StringRef Name,const GlobalValue * GV,SPIRV::StorageClass::StorageClass Storage,const MachineInstr * Init,bool IsConst,bool HasLinkageTy,SPIRV::LinkageType::LinkageType LinkageType,MachineIRBuilder & MIRBuilder,bool IsInstSelector)463 Register SPIRVGlobalRegistry::buildGlobalVariable(
464     Register ResVReg, SPIRVType *BaseType, StringRef Name,
465     const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
466     const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
467     SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
468     bool IsInstSelector) {
469   const GlobalVariable *GVar = nullptr;
470   if (GV)
471     GVar = cast<const GlobalVariable>(GV);
472   else {
473     // If GV is not passed explicitly, use the name to find or construct
474     // the global variable.
475     Module *M = MIRBuilder.getMF().getFunction().getParent();
476     GVar = M->getGlobalVariable(Name);
477     if (GVar == nullptr) {
478       const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
479       GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
480                                 GlobalValue::ExternalLinkage, nullptr,
481                                 Twine(Name));
482     }
483     GV = GVar;
484   }
485   Register Reg = DT.find(GVar, &MIRBuilder.getMF());
486   if (Reg.isValid()) {
487     if (Reg != ResVReg)
488       MIRBuilder.buildCopy(ResVReg, Reg);
489     return ResVReg;
490   }
491 
492   auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
493                  .addDef(ResVReg)
494                  .addUse(getSPIRVTypeID(BaseType))
495                  .addImm(static_cast<uint32_t>(Storage));
496 
497   if (Init != 0) {
498     MIB.addUse(Init->getOperand(0).getReg());
499   }
500 
501   // ISel may introduce a new register on this step, so we need to add it to
502   // DT and correct its type avoiding fails on the next stage.
503   if (IsInstSelector) {
504     const auto &Subtarget = CurMF->getSubtarget();
505     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
506                                      *Subtarget.getRegisterInfo(),
507                                      *Subtarget.getRegBankInfo());
508   }
509   Reg = MIB->getOperand(0).getReg();
510   DT.add(GVar, &MIRBuilder.getMF(), Reg);
511 
512   // Set to Reg the same type as ResVReg has.
513   auto MRI = MIRBuilder.getMRI();
514   assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
515   if (Reg != ResVReg) {
516     LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
517     MRI->setType(Reg, RegLLTy);
518     assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
519   }
520 
521   // If it's a global variable with name, output OpName for it.
522   if (GVar && GVar->hasName())
523     buildOpName(Reg, GVar->getName(), MIRBuilder);
524 
525   // Output decorations for the GV.
526   // TODO: maybe move to GenerateDecorations pass.
527   if (IsConst)
528     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
529 
530   if (GVar && GVar->getAlign().valueOrOne().value() != 1) {
531     unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value();
532     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});
533   }
534 
535   if (HasLinkageTy)
536     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
537                     {static_cast<uint32_t>(LinkageType)}, Name);
538 
539   SPIRV::BuiltIn::BuiltIn BuiltInId;
540   if (getSpirvBuiltInIdByName(Name, BuiltInId))
541     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,
542                     {static_cast<uint32_t>(BuiltInId)});
543 
544   return Reg;
545 }
546 
getOpTypeArray(uint32_t NumElems,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder,bool EmitIR)547 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
548                                                SPIRVType *ElemType,
549                                                MachineIRBuilder &MIRBuilder,
550                                                bool EmitIR) {
551   assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
552          "Invalid array element type");
553   Register NumElementsVReg =
554       buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
555   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
556                  .addDef(createTypeVReg(MIRBuilder))
557                  .addUse(getSPIRVTypeID(ElemType))
558                  .addUse(NumElementsVReg);
559   return MIB;
560 }
561 
getOpTypeOpaque(const StructType * Ty,MachineIRBuilder & MIRBuilder)562 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
563                                                 MachineIRBuilder &MIRBuilder) {
564   assert(Ty->hasName());
565   const StringRef Name = Ty->hasName() ? Ty->getName() : "";
566   Register ResVReg = createTypeVReg(MIRBuilder);
567   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
568   addStringImm(Name, MIB);
569   buildOpName(ResVReg, Name, MIRBuilder);
570   return MIB;
571 }
572 
getOpTypeStruct(const StructType * Ty,MachineIRBuilder & MIRBuilder,bool EmitIR)573 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
574                                                 MachineIRBuilder &MIRBuilder,
575                                                 bool EmitIR) {
576   SmallVector<Register, 4> FieldTypes;
577   for (const auto &Elem : Ty->elements()) {
578     SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
579     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
580            "Invalid struct element type");
581     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
582   }
583   Register ResVReg = createTypeVReg(MIRBuilder);
584   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
585   for (const auto &Ty : FieldTypes)
586     MIB.addUse(Ty);
587   if (Ty->hasName())
588     buildOpName(ResVReg, Ty->getName(), MIRBuilder);
589   if (Ty->isPacked())
590     buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
591   return MIB;
592 }
593 
getOrCreateSpecialType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual)594 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
595     const Type *Ty, MachineIRBuilder &MIRBuilder,
596     SPIRV::AccessQualifier::AccessQualifier AccQual) {
597   assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type");
598   return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this);
599 }
600 
getOpTypePointer(SPIRV::StorageClass::StorageClass SC,SPIRVType * ElemType,MachineIRBuilder & MIRBuilder,Register Reg)601 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
602     SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,
603     MachineIRBuilder &MIRBuilder, Register Reg) {
604   if (!Reg.isValid())
605     Reg = createTypeVReg(MIRBuilder);
606   return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
607       .addDef(Reg)
608       .addImm(static_cast<uint32_t>(SC))
609       .addUse(getSPIRVTypeID(ElemType));
610 }
611 
getOpTypeForwardPointer(SPIRV::StorageClass::StorageClass SC,MachineIRBuilder & MIRBuilder)612 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
613     SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
614   return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
615       .addUse(createTypeVReg(MIRBuilder))
616       .addImm(static_cast<uint32_t>(SC));
617 }
618 
getOpTypeFunction(SPIRVType * RetType,const SmallVectorImpl<SPIRVType * > & ArgTypes,MachineIRBuilder & MIRBuilder)619 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
620     SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
621     MachineIRBuilder &MIRBuilder) {
622   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
623                  .addDef(createTypeVReg(MIRBuilder))
624                  .addUse(getSPIRVTypeID(RetType));
625   for (const SPIRVType *ArgType : ArgTypes)
626     MIB.addUse(getSPIRVTypeID(ArgType));
627   return MIB;
628 }
629 
getOrCreateOpTypeFunctionWithArgs(const Type * Ty,SPIRVType * RetType,const SmallVectorImpl<SPIRVType * > & ArgTypes,MachineIRBuilder & MIRBuilder)630 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
631     const Type *Ty, SPIRVType *RetType,
632     const SmallVectorImpl<SPIRVType *> &ArgTypes,
633     MachineIRBuilder &MIRBuilder) {
634   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
635   if (Reg.isValid())
636     return getSPIRVTypeForVReg(Reg);
637   SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
638   DT.add(Ty, CurMF, getSPIRVTypeID(SpirvType));
639   return finishCreatingSPIRVType(Ty, SpirvType);
640 }
641 
findSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual,bool EmitIR)642 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
643     const Type *Ty, MachineIRBuilder &MIRBuilder,
644     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
645   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
646   if (Reg.isValid())
647     return getSPIRVTypeForVReg(Reg);
648   if (ForwardPointerTypes.contains(Ty))
649     return ForwardPointerTypes[Ty];
650   return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);
651 }
652 
getSPIRVTypeID(const SPIRVType * SpirvType) const653 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
654   assert(SpirvType && "Attempting to get type id for nullptr type.");
655   if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer)
656     return SpirvType->uses().begin()->getReg();
657   return SpirvType->defs().begin()->getReg();
658 }
659 
createSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccQual,bool EmitIR)660 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
661     const Type *Ty, MachineIRBuilder &MIRBuilder,
662     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
663   if (isSpecialOpaqueType(Ty))
664     return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
665   auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
666   auto t = TypeToSPIRVTypeMap.find(Ty);
667   if (t != TypeToSPIRVTypeMap.end()) {
668     auto tt = t->second.find(&MIRBuilder.getMF());
669     if (tt != t->second.end())
670       return getSPIRVTypeForVReg(tt->second);
671   }
672 
673   if (auto IType = dyn_cast<IntegerType>(Ty)) {
674     const unsigned Width = IType->getBitWidth();
675     return Width == 1 ? getOpTypeBool(MIRBuilder)
676                       : getOpTypeInt(Width, MIRBuilder, false);
677   }
678   if (Ty->isFloatingPointTy())
679     return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
680   if (Ty->isVoidTy())
681     return getOpTypeVoid(MIRBuilder);
682   if (Ty->isVectorTy()) {
683     SPIRVType *El =
684         findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
685     return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
686                            MIRBuilder);
687   }
688   if (Ty->isArrayTy()) {
689     SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
690     return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
691   }
692   if (auto SType = dyn_cast<StructType>(Ty)) {
693     if (SType->isOpaque())
694       return getOpTypeOpaque(SType, MIRBuilder);
695     return getOpTypeStruct(SType, MIRBuilder, EmitIR);
696   }
697   if (auto FType = dyn_cast<FunctionType>(Ty)) {
698     SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
699     SmallVector<SPIRVType *, 4> ParamTypes;
700     for (const auto &t : FType->params()) {
701       ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
702     }
703     return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
704   }
705   if (auto PType = dyn_cast<PointerType>(Ty)) {
706     SPIRVType *SpvElementType;
707     // At the moment, all opaque pointers correspond to i8 element type.
708     // TODO: change the implementation once opaque pointers are supported
709     // in the SPIR-V specification.
710     SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
711     auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
712     // Null pointer means we have a loop in type definitions, make and
713     // return corresponding OpTypeForwardPointer.
714     if (SpvElementType == nullptr) {
715       if (!ForwardPointerTypes.contains(Ty))
716         ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder);
717       return ForwardPointerTypes[PType];
718     }
719     Register Reg(0);
720     // If we have forward pointer associated with this type, use its register
721     // operand to create OpTypePointer.
722     if (ForwardPointerTypes.contains(PType))
723       Reg = getSPIRVTypeID(ForwardPointerTypes[PType]);
724 
725     return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
726   }
727   llvm_unreachable("Unable to convert LLVM type to SPIRVType");
728 }
729 
restOfCreateSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool EmitIR)730 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
731     const Type *Ty, MachineIRBuilder &MIRBuilder,
732     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
733   if (TypesInProcessing.count(Ty) && !Ty->isPointerTy())
734     return nullptr;
735   TypesInProcessing.insert(Ty);
736   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
737   TypesInProcessing.erase(Ty);
738   VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
739   SPIRVToLLVMType[SpirvType] = Ty;
740   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
741   // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
742   // will be added later. For special types it is already added to DT.
743   if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
744       !isSpecialOpaqueType(Ty)) {
745     if (!Ty->isPointerTy())
746       DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
747     else
748       DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
749              Ty->getPointerAddressSpace(), &MIRBuilder.getMF(),
750              getSPIRVTypeID(SpirvType));
751   }
752 
753   return SpirvType;
754 }
755 
getSPIRVTypeForVReg(Register VReg) const756 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
757   auto t = VRegToTypeMap.find(CurMF);
758   if (t != VRegToTypeMap.end()) {
759     auto tt = t->second.find(VReg);
760     if (tt != t->second.end())
761       return tt->second;
762   }
763   return nullptr;
764 }
765 
getOrCreateSPIRVType(const Type * Ty,MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual,bool EmitIR)766 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
767     const Type *Ty, MachineIRBuilder &MIRBuilder,
768     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
769   Register Reg;
770   if (!Ty->isPointerTy())
771     Reg = DT.find(Ty, &MIRBuilder.getMF());
772   else
773     Reg =
774         DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
775                 Ty->getPointerAddressSpace(), &MIRBuilder.getMF());
776 
777   if (Reg.isValid() && !isSpecialOpaqueType(Ty))
778     return getSPIRVTypeForVReg(Reg);
779   TypesInProcessing.clear();
780   SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
781   // Create normal pointer types for the corresponding OpTypeForwardPointers.
782   for (auto &CU : ForwardPointerTypes) {
783     const Type *Ty2 = CU.first;
784     SPIRVType *STy2 = CU.second;
785     if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid())
786       STy2 = getSPIRVTypeForVReg(Reg);
787     else
788       STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
789     if (Ty == Ty2)
790       STy = STy2;
791   }
792   ForwardPointerTypes.clear();
793   return STy;
794 }
795 
isScalarOfType(Register VReg,unsigned TypeOpcode) const796 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
797                                          unsigned TypeOpcode) const {
798   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
799   assert(Type && "isScalarOfType VReg has no type assigned");
800   return Type->getOpcode() == TypeOpcode;
801 }
802 
isScalarOrVectorOfType(Register VReg,unsigned TypeOpcode) const803 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
804                                                  unsigned TypeOpcode) const {
805   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
806   assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
807   if (Type->getOpcode() == TypeOpcode)
808     return true;
809   if (Type->getOpcode() == SPIRV::OpTypeVector) {
810     Register ScalarTypeVReg = Type->getOperand(1).getReg();
811     SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
812     return ScalarType->getOpcode() == TypeOpcode;
813   }
814   return false;
815 }
816 
817 unsigned
getScalarOrVectorBitWidth(const SPIRVType * Type) const818 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
819   assert(Type && "Invalid Type pointer");
820   if (Type->getOpcode() == SPIRV::OpTypeVector) {
821     auto EleTypeReg = Type->getOperand(1).getReg();
822     Type = getSPIRVTypeForVReg(EleTypeReg);
823   }
824   if (Type->getOpcode() == SPIRV::OpTypeInt ||
825       Type->getOpcode() == SPIRV::OpTypeFloat)
826     return Type->getOperand(1).getImm();
827   if (Type->getOpcode() == SPIRV::OpTypeBool)
828     return 1;
829   llvm_unreachable("Attempting to get bit width of non-integer/float type.");
830 }
831 
isScalarOrVectorSigned(const SPIRVType * Type) const832 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
833   assert(Type && "Invalid Type pointer");
834   if (Type->getOpcode() == SPIRV::OpTypeVector) {
835     auto EleTypeReg = Type->getOperand(1).getReg();
836     Type = getSPIRVTypeForVReg(EleTypeReg);
837   }
838   if (Type->getOpcode() == SPIRV::OpTypeInt)
839     return Type->getOperand(2).getImm() != 0;
840   llvm_unreachable("Attempting to get sign of non-integer type.");
841 }
842 
843 SPIRV::StorageClass::StorageClass
getPointerStorageClass(Register VReg) const844 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
845   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
846   assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
847          Type->getOperand(1).isImm() && "Pointer type is expected");
848   return static_cast<SPIRV::StorageClass::StorageClass>(
849       Type->getOperand(1).getImm());
850 }
851 
getOrCreateOpTypeImage(MachineIRBuilder & MIRBuilder,SPIRVType * SampledType,SPIRV::Dim::Dim Dim,uint32_t Depth,uint32_t Arrayed,uint32_t Multisampled,uint32_t Sampled,SPIRV::ImageFormat::ImageFormat ImageFormat,SPIRV::AccessQualifier::AccessQualifier AccessQual)852 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
853     MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
854     uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
855     SPIRV::ImageFormat::ImageFormat ImageFormat,
856     SPIRV::AccessQualifier::AccessQualifier AccessQual) {
857   SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(SampledType), Dim, Depth,
858                                 Arrayed, Multisampled, Sampled, ImageFormat,
859                                 AccessQual);
860   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
861     return Res;
862   Register ResVReg = createTypeVReg(MIRBuilder);
863   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
864   return MIRBuilder.buildInstr(SPIRV::OpTypeImage)
865       .addDef(ResVReg)
866       .addUse(getSPIRVTypeID(SampledType))
867       .addImm(Dim)
868       .addImm(Depth)        // Depth (whether or not it is a Depth image).
869       .addImm(Arrayed)      // Arrayed.
870       .addImm(Multisampled) // Multisampled (0 = only single-sample).
871       .addImm(Sampled)      // Sampled (0 = usage known at runtime).
872       .addImm(ImageFormat)
873       .addImm(AccessQual);
874 }
875 
876 SPIRVType *
getOrCreateOpTypeSampler(MachineIRBuilder & MIRBuilder)877 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
878   SPIRV::SamplerTypeDescriptor TD;
879   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
880     return Res;
881   Register ResVReg = createTypeVReg(MIRBuilder);
882   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
883   return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg);
884 }
885 
getOrCreateOpTypePipe(MachineIRBuilder & MIRBuilder,SPIRV::AccessQualifier::AccessQualifier AccessQual)886 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
887     MachineIRBuilder &MIRBuilder,
888     SPIRV::AccessQualifier::AccessQualifier AccessQual) {
889   SPIRV::PipeTypeDescriptor TD(AccessQual);
890   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
891     return Res;
892   Register ResVReg = createTypeVReg(MIRBuilder);
893   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
894   return MIRBuilder.buildInstr(SPIRV::OpTypePipe)
895       .addDef(ResVReg)
896       .addImm(AccessQual);
897 }
898 
getOrCreateOpTypeDeviceEvent(MachineIRBuilder & MIRBuilder)899 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
900     MachineIRBuilder &MIRBuilder) {
901   SPIRV::DeviceEventTypeDescriptor TD;
902   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
903     return Res;
904   Register ResVReg = createTypeVReg(MIRBuilder);
905   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
906   return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg);
907 }
908 
getOrCreateOpTypeSampledImage(SPIRVType * ImageType,MachineIRBuilder & MIRBuilder)909 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
910     SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
911   SPIRV::SampledImageTypeDescriptor TD(
912       SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
913           ImageType->getOperand(1).getReg())),
914       ImageType);
915   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
916     return Res;
917   Register ResVReg = createTypeVReg(MIRBuilder);
918   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
919   return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage)
920       .addDef(ResVReg)
921       .addUse(getSPIRVTypeID(ImageType));
922 }
923 
getOrCreateOpTypeByOpcode(const Type * Ty,MachineIRBuilder & MIRBuilder,unsigned Opcode)924 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
925     const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
926   Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
927   if (ResVReg.isValid())
928     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
929   ResVReg = createTypeVReg(MIRBuilder);
930   SPIRVType *SpirvTy = MIRBuilder.buildInstr(Opcode).addDef(ResVReg);
931   DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
932   return SpirvTy;
933 }
934 
935 const MachineInstr *
checkSpecialInstr(const SPIRV::SpecialTypeDescriptor & TD,MachineIRBuilder & MIRBuilder)936 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
937                                        MachineIRBuilder &MIRBuilder) {
938   Register Reg = DT.find(TD, &MIRBuilder.getMF());
939   if (Reg.isValid())
940     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg);
941   return nullptr;
942 }
943 
944 // TODO: maybe use tablegen to implement this.
getOrCreateSPIRVTypeByName(StringRef TypeStr,MachineIRBuilder & MIRBuilder,SPIRV::StorageClass::StorageClass SC,SPIRV::AccessQualifier::AccessQualifier AQ)945 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
946     StringRef TypeStr, MachineIRBuilder &MIRBuilder,
947     SPIRV::StorageClass::StorageClass SC,
948     SPIRV::AccessQualifier::AccessQualifier AQ) {
949   unsigned VecElts = 0;
950   auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
951 
952   // Parse strings representing either a SPIR-V or OpenCL builtin type.
953   if (hasBuiltinTypePrefix(TypeStr))
954     return getOrCreateSPIRVType(
955         SPIRV::parseBuiltinTypeNameToTargetExtType(TypeStr.str(), MIRBuilder),
956         MIRBuilder, AQ);
957 
958   // Parse type name in either "typeN" or "type vector[N]" format, where
959   // N is the number of elements of the vector.
960   Type *Ty;
961 
962   TypeStr.consume_front("atomic_");
963 
964   if (TypeStr.starts_with("void")) {
965     Ty = Type::getVoidTy(Ctx);
966     TypeStr = TypeStr.substr(strlen("void"));
967   } else if (TypeStr.starts_with("bool")) {
968     Ty = Type::getIntNTy(Ctx, 1);
969     TypeStr = TypeStr.substr(strlen("bool"));
970   } else if (TypeStr.starts_with("char") || TypeStr.starts_with("uchar")) {
971     Ty = Type::getInt8Ty(Ctx);
972     TypeStr = TypeStr.starts_with("char") ? TypeStr.substr(strlen("char"))
973                                           : TypeStr.substr(strlen("uchar"));
974   } else if (TypeStr.starts_with("short") || TypeStr.starts_with("ushort")) {
975     Ty = Type::getInt16Ty(Ctx);
976     TypeStr = TypeStr.starts_with("short") ? TypeStr.substr(strlen("short"))
977                                            : TypeStr.substr(strlen("ushort"));
978   } else if (TypeStr.starts_with("int") || TypeStr.starts_with("uint")) {
979     Ty = Type::getInt32Ty(Ctx);
980     TypeStr = TypeStr.starts_with("int") ? TypeStr.substr(strlen("int"))
981                                          : TypeStr.substr(strlen("uint"));
982   } else if (TypeStr.starts_with("long") || TypeStr.starts_with("ulong")) {
983     Ty = Type::getInt64Ty(Ctx);
984     TypeStr = TypeStr.starts_with("long") ? TypeStr.substr(strlen("long"))
985                                           : TypeStr.substr(strlen("ulong"));
986   } else if (TypeStr.starts_with("half")) {
987     Ty = Type::getHalfTy(Ctx);
988     TypeStr = TypeStr.substr(strlen("half"));
989   } else if (TypeStr.starts_with("float")) {
990     Ty = Type::getFloatTy(Ctx);
991     TypeStr = TypeStr.substr(strlen("float"));
992   } else if (TypeStr.starts_with("double")) {
993     Ty = Type::getDoubleTy(Ctx);
994     TypeStr = TypeStr.substr(strlen("double"));
995   } else
996     llvm_unreachable("Unable to recognize SPIRV type name.");
997 
998   auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);
999 
1000   // Handle "type*" or  "type* vector[N]".
1001   if (TypeStr.starts_with("*")) {
1002     SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1003     TypeStr = TypeStr.substr(strlen("*"));
1004   }
1005 
1006   // Handle "typeN*" or  "type vector[N]*".
1007   bool IsPtrToVec = TypeStr.consume_back("*");
1008 
1009   if (TypeStr.consume_front(" vector[")) {
1010     TypeStr = TypeStr.substr(0, TypeStr.find(']'));
1011   }
1012   TypeStr.getAsInteger(10, VecElts);
1013   if (VecElts > 0)
1014     SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);
1015 
1016   if (IsPtrToVec)
1017     SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1018 
1019   return SpirvTy;
1020 }
1021 
1022 SPIRVType *
getOrCreateSPIRVIntegerType(unsigned BitWidth,MachineIRBuilder & MIRBuilder)1023 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
1024                                                  MachineIRBuilder &MIRBuilder) {
1025   return getOrCreateSPIRVType(
1026       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
1027       MIRBuilder);
1028 }
1029 
finishCreatingSPIRVType(const Type * LLVMTy,SPIRVType * SpirvType)1030 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
1031                                                         SPIRVType *SpirvType) {
1032   assert(CurMF == SpirvType->getMF());
1033   VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
1034   SPIRVToLLVMType[SpirvType] = LLVMTy;
1035   return SpirvType;
1036 }
1037 
getOrCreateSPIRVIntegerType(unsigned BitWidth,MachineInstr & I,const SPIRVInstrInfo & TII)1038 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
1039     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1040   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
1041   Register Reg = DT.find(LLVMTy, CurMF);
1042   if (Reg.isValid())
1043     return getSPIRVTypeForVReg(Reg);
1044   MachineBasicBlock &BB = *I.getParent();
1045   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
1046                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1047                  .addImm(BitWidth)
1048                  .addImm(0);
1049   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1050   return finishCreatingSPIRVType(LLVMTy, MIB);
1051 }
1052 
1053 SPIRVType *
getOrCreateSPIRVBoolType(MachineIRBuilder & MIRBuilder)1054 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
1055   return getOrCreateSPIRVType(
1056       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
1057       MIRBuilder);
1058 }
1059 
1060 SPIRVType *
getOrCreateSPIRVBoolType(MachineInstr & I,const SPIRVInstrInfo & TII)1061 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
1062                                               const SPIRVInstrInfo &TII) {
1063   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
1064   Register Reg = DT.find(LLVMTy, CurMF);
1065   if (Reg.isValid())
1066     return getSPIRVTypeForVReg(Reg);
1067   MachineBasicBlock &BB = *I.getParent();
1068   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool))
1069                  .addDef(createTypeVReg(CurMF->getRegInfo()));
1070   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1071   return finishCreatingSPIRVType(LLVMTy, MIB);
1072 }
1073 
getOrCreateSPIRVVectorType(SPIRVType * BaseType,unsigned NumElements,MachineIRBuilder & MIRBuilder)1074 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1075     SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
1076   return getOrCreateSPIRVType(
1077       FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1078                            NumElements),
1079       MIRBuilder);
1080 }
1081 
getOrCreateSPIRVVectorType(SPIRVType * BaseType,unsigned NumElements,MachineInstr & I,const SPIRVInstrInfo & TII)1082 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1083     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1084     const SPIRVInstrInfo &TII) {
1085   Type *LLVMTy = FixedVectorType::get(
1086       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1087   Register Reg = DT.find(LLVMTy, CurMF);
1088   if (Reg.isValid())
1089     return getSPIRVTypeForVReg(Reg);
1090   MachineBasicBlock &BB = *I.getParent();
1091   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
1092                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1093                  .addUse(getSPIRVTypeID(BaseType))
1094                  .addImm(NumElements);
1095   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1096   return finishCreatingSPIRVType(LLVMTy, MIB);
1097 }
1098 
getOrCreateSPIRVArrayType(SPIRVType * BaseType,unsigned NumElements,MachineInstr & I,const SPIRVInstrInfo & TII)1099 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
1100     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1101     const SPIRVInstrInfo &TII) {
1102   Type *LLVMTy = ArrayType::get(
1103       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1104   Register Reg = DT.find(LLVMTy, CurMF);
1105   if (Reg.isValid())
1106     return getSPIRVTypeForVReg(Reg);
1107   MachineBasicBlock &BB = *I.getParent();
1108   SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII);
1109   Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII);
1110   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray))
1111                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1112                  .addUse(getSPIRVTypeID(BaseType))
1113                  .addUse(Len);
1114   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1115   return finishCreatingSPIRVType(LLVMTy, MIB);
1116 }
1117 
getOrCreateSPIRVPointerType(SPIRVType * BaseType,MachineIRBuilder & MIRBuilder,SPIRV::StorageClass::StorageClass SC)1118 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1119     SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1120     SPIRV::StorageClass::StorageClass SC) {
1121   const Type *PointerElementType = getTypeForSPIRVType(BaseType);
1122   unsigned AddressSpace = storageClassToAddressSpace(SC);
1123   Type *LLVMTy =
1124       PointerType::get(const_cast<Type *>(PointerElementType), AddressSpace);
1125   Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
1126   if (Reg.isValid())
1127     return getSPIRVTypeForVReg(Reg);
1128   auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
1129                      MIRBuilder.getDebugLoc(),
1130                      MIRBuilder.getTII().get(SPIRV::OpTypePointer))
1131                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1132                  .addImm(static_cast<uint32_t>(SC))
1133                  .addUse(getSPIRVTypeID(BaseType));
1134   DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
1135   return finishCreatingSPIRVType(LLVMTy, MIB);
1136 }
1137 
getOrCreateSPIRVPointerType(SPIRVType * BaseType,MachineInstr & I,const SPIRVInstrInfo & TII,SPIRV::StorageClass::StorageClass SC)1138 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1139     SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
1140     SPIRV::StorageClass::StorageClass SC) {
1141   const Type *PointerElementType = getTypeForSPIRVType(BaseType);
1142   unsigned AddressSpace = storageClassToAddressSpace(SC);
1143   Type *LLVMTy =
1144       PointerType::get(const_cast<Type *>(PointerElementType), AddressSpace);
1145   Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
1146   if (Reg.isValid())
1147     return getSPIRVTypeForVReg(Reg);
1148   MachineBasicBlock &BB = *I.getParent();
1149   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
1150                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1151                  .addImm(static_cast<uint32_t>(SC))
1152                  .addUse(getSPIRVTypeID(BaseType));
1153   DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
1154   return finishCreatingSPIRVType(LLVMTy, MIB);
1155 }
1156 
getOrCreateUndef(MachineInstr & I,SPIRVType * SpvType,const SPIRVInstrInfo & TII)1157 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
1158                                                SPIRVType *SpvType,
1159                                                const SPIRVInstrInfo &TII) {
1160   assert(SpvType);
1161   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
1162   assert(LLVMTy);
1163   // Find a constant in DT or build a new one.
1164   UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
1165   Register Res = DT.find(UV, CurMF);
1166   if (Res.isValid())
1167     return Res;
1168   LLT LLTy = LLT::scalar(32);
1169   Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
1170   CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
1171   assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
1172   DT.add(UV, CurMF, Res);
1173 
1174   MachineInstrBuilder MIB;
1175   MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
1176             .addDef(Res)
1177             .addUse(getSPIRVTypeID(SpvType));
1178   const auto &ST = CurMF->getSubtarget();
1179   constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
1180                                    *ST.getRegisterInfo(), *ST.getRegBankInfo());
1181   return Res;
1182 }
1183