1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13 // and supports consistency of constants and global variables.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRV.h"
19 #include "SPIRVBuiltins.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVTargetMachine.h"
22 #include "SPIRVUtils.h"
23 
24 using namespace llvm;
25 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
26     : PointerSize(PointerSize) {}
27 
28 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
29                                                     Register VReg,
30                                                     MachineInstr &I,
31                                                     const SPIRVInstrInfo &TII) {
32   SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
33   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
34   return SpirvType;
35 }
36 
37 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
38     SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
39     const SPIRVInstrInfo &TII) {
40   SPIRVType *SpirvType =
41       getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
42   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
43   return SpirvType;
44 }
45 
46 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
47     const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
48     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
49 
50   SPIRVType *SpirvType =
51       getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
52   assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
53   return SpirvType;
54 }
55 
56 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
57                                                 Register VReg,
58                                                 MachineFunction &MF) {
59   VRegToTypeMap[&MF][VReg] = SpirvType;
60 }
61 
62 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
63   auto &MRI = MIRBuilder.getMF().getRegInfo();
64   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
65   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
66   return Res;
67 }
68 
69 static Register createTypeVReg(MachineRegisterInfo &MRI) {
70   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
71   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
72   return Res;
73 }
74 
75 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
76   return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
77       .addDef(createTypeVReg(MIRBuilder));
78 }
79 
80 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
81                                              MachineIRBuilder &MIRBuilder,
82                                              bool IsSigned) {
83   assert(Width <= 64 && "Unsupported integer width!");
84   const SPIRVSubtarget &ST =
85       cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget());
86   if (ST.canUseExtension(
87           SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) {
88     MIRBuilder.buildInstr(SPIRV::OpExtension)
89         .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers);
90     MIRBuilder.buildInstr(SPIRV::OpCapability)
91         .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL);
92   } else if (Width <= 8)
93     Width = 8;
94   else if (Width <= 16)
95     Width = 16;
96   else if (Width <= 32)
97     Width = 32;
98   else if (Width <= 64)
99     Width = 64;
100 
101   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
102                  .addDef(createTypeVReg(MIRBuilder))
103                  .addImm(Width)
104                  .addImm(IsSigned ? 1 : 0);
105   return MIB;
106 }
107 
108 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
109                                                MachineIRBuilder &MIRBuilder) {
110   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
111                  .addDef(createTypeVReg(MIRBuilder))
112                  .addImm(Width);
113   return MIB;
114 }
115 
116 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
117   return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
118       .addDef(createTypeVReg(MIRBuilder));
119 }
120 
121 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
122                                                 SPIRVType *ElemType,
123                                                 MachineIRBuilder &MIRBuilder) {
124   auto EleOpc = ElemType->getOpcode();
125   assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
126           EleOpc == SPIRV::OpTypeBool) &&
127          "Invalid vector element type");
128 
129   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
130                  .addDef(createTypeVReg(MIRBuilder))
131                  .addUse(getSPIRVTypeID(ElemType))
132                  .addImm(NumElems);
133   return MIB;
134 }
135 
136 std::tuple<Register, ConstantInt *, bool>
137 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
138                                             MachineIRBuilder *MIRBuilder,
139                                             MachineInstr *I,
140                                             const SPIRVInstrInfo *TII) {
141   const IntegerType *LLVMIntTy;
142   if (SpvType)
143     LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
144   else
145     LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());
146   bool NewInstr = false;
147   // Find a constant in DT or build a new one.
148   ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
149   Register Res = DT.find(CI, CurMF);
150   if (!Res.isValid()) {
151     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
152     LLT LLTy = LLT::scalar(32);
153     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
154     CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
155     if (MIRBuilder)
156       assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
157     else
158       assignIntTypeToVReg(BitWidth, Res, *I, *TII);
159     DT.add(CI, CurMF, Res);
160     NewInstr = true;
161   }
162   return std::make_tuple(Res, CI, NewInstr);
163 }
164 
165 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
166                                                   SPIRVType *SpvType,
167                                                   const SPIRVInstrInfo &TII) {
168   assert(SpvType);
169   ConstantInt *CI;
170   Register Res;
171   bool New;
172   std::tie(Res, CI, New) =
173       getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
174   // If we have found Res register which is defined by the passed G_CONSTANT
175   // machine instruction, a new constant instruction should be created.
176   if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
177     return Res;
178   MachineInstrBuilder MIB;
179   MachineBasicBlock &BB = *I.getParent();
180   if (Val) {
181     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
182               .addDef(Res)
183               .addUse(getSPIRVTypeID(SpvType));
184     addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);
185   } else {
186     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
187               .addDef(Res)
188               .addUse(getSPIRVTypeID(SpvType));
189   }
190   const auto &ST = CurMF->getSubtarget();
191   constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
192                                    *ST.getRegisterInfo(), *ST.getRegBankInfo());
193   return Res;
194 }
195 
196 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
197                                                MachineIRBuilder &MIRBuilder,
198                                                SPIRVType *SpvType,
199                                                bool EmitIR) {
200   auto &MF = MIRBuilder.getMF();
201   const IntegerType *LLVMIntTy;
202   if (SpvType)
203     LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
204   else
205     LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
206   // Find a constant in DT or build a new one.
207   const auto ConstInt =
208       ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
209   Register Res = DT.find(ConstInt, &MF);
210   if (!Res.isValid()) {
211     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
212     LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
213     Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
214     MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
215     assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
216                      SPIRV::AccessQualifier::ReadWrite, EmitIR);
217     DT.add(ConstInt, &MIRBuilder.getMF(), Res);
218     if (EmitIR) {
219       MIRBuilder.buildConstant(Res, *ConstInt);
220     } else {
221       MachineInstrBuilder MIB;
222       if (Val) {
223         assert(SpvType);
224         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
225                   .addDef(Res)
226                   .addUse(getSPIRVTypeID(SpvType));
227         addNumImm(APInt(BitWidth, Val), MIB);
228       } else {
229         assert(SpvType);
230         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
231                   .addDef(Res)
232                   .addUse(getSPIRVTypeID(SpvType));
233       }
234       const auto &Subtarget = CurMF->getSubtarget();
235       constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
236                                        *Subtarget.getRegisterInfo(),
237                                        *Subtarget.getRegBankInfo());
238     }
239   }
240   return Res;
241 }
242 
243 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
244                                               MachineIRBuilder &MIRBuilder,
245                                               SPIRVType *SpvType) {
246   auto &MF = MIRBuilder.getMF();
247   auto &Ctx = MF.getFunction().getContext();
248   if (!SpvType) {
249     const Type *LLVMFPTy = Type::getFloatTy(Ctx);
250     SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder);
251   }
252   // Find a constant in DT or build a new one.
253   const auto ConstFP = ConstantFP::get(Ctx, Val);
254   Register Res = DT.find(ConstFP, &MF);
255   if (!Res.isValid()) {
256     Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(32));
257     MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
258     assignSPIRVTypeToVReg(SpvType, Res, MF);
259     DT.add(ConstFP, &MF, Res);
260 
261     MachineInstrBuilder MIB;
262     MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF)
263               .addDef(Res)
264               .addUse(getSPIRVTypeID(SpvType));
265     addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB);
266   }
267 
268   return Res;
269 }
270 
271 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
272     uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
273     const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
274     unsigned ElemCnt) {
275   // Find a constant vector in DT or build a new one.
276   Register Res = DT.find(CA, CurMF);
277   if (!Res.isValid()) {
278     SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
279     // SpvScalConst should be created before SpvVecConst to avoid undefined ID
280     // error on validation.
281     // TODO: can moved below once sorting of types/consts/defs is implemented.
282     Register SpvScalConst;
283     if (Val)
284       SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII);
285     // TODO: maybe use bitwidth of base type.
286     LLT LLTy = LLT::scalar(32);
287     Register SpvVecConst =
288         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
289     CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
290     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
291     DT.add(CA, CurMF, SpvVecConst);
292     MachineInstrBuilder MIB;
293     MachineBasicBlock &BB = *I.getParent();
294     if (Val) {
295       MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
296                 .addDef(SpvVecConst)
297                 .addUse(getSPIRVTypeID(SpvType));
298       for (unsigned i = 0; i < ElemCnt; ++i)
299         MIB.addUse(SpvScalConst);
300     } else {
301       MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
302                 .addDef(SpvVecConst)
303                 .addUse(getSPIRVTypeID(SpvType));
304     }
305     const auto &Subtarget = CurMF->getSubtarget();
306     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
307                                      *Subtarget.getRegisterInfo(),
308                                      *Subtarget.getRegBankInfo());
309     return SpvVecConst;
310   }
311   return Res;
312 }
313 
314 Register
315 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
316                                               SPIRVType *SpvType,
317                                               const SPIRVInstrInfo &TII) {
318   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
319   assert(LLVMTy->isVectorTy());
320   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
321   Type *LLVMBaseTy = LLVMVecTy->getElementType();
322   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
323   auto ConstVec =
324       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
325   unsigned BW = getScalarOrVectorBitWidth(SpvType);
326   return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstVec, BW,
327                                        SpvType->getOperand(2).getImm());
328 }
329 
330 Register
331 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
332                                              SPIRVType *SpvType,
333                                              const SPIRVInstrInfo &TII) {
334   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
335   assert(LLVMTy->isArrayTy());
336   const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
337   Type *LLVMBaseTy = LLVMArrTy->getElementType();
338   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
339   auto ConstArr =
340       ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
341   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
342   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
343   return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstArr, BW,
344                                        LLVMArrTy->getNumElements());
345 }
346 
347 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
348     uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
349     Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
350   Register Res = DT.find(CA, CurMF);
351   if (!Res.isValid()) {
352     Register SpvScalConst;
353     if (Val || EmitIR) {
354       SPIRVType *SpvBaseType =
355           getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
356       SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
357     }
358     LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
359     Register SpvVecConst =
360         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
361     CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
362     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
363     DT.add(CA, CurMF, SpvVecConst);
364     if (EmitIR) {
365       MIRBuilder.buildSplatVector(SpvVecConst, SpvScalConst);
366     } else {
367       if (Val) {
368         auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
369                        .addDef(SpvVecConst)
370                        .addUse(getSPIRVTypeID(SpvType));
371         for (unsigned i = 0; i < ElemCnt; ++i)
372           MIB.addUse(SpvScalConst);
373       } else {
374         MIRBuilder.buildInstr(SPIRV::OpConstantNull)
375             .addDef(SpvVecConst)
376             .addUse(getSPIRVTypeID(SpvType));
377       }
378     }
379     return SpvVecConst;
380   }
381   return Res;
382 }
383 
384 Register
385 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
386                                               MachineIRBuilder &MIRBuilder,
387                                               SPIRVType *SpvType, bool EmitIR) {
388   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
389   assert(LLVMTy->isVectorTy());
390   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
391   Type *LLVMBaseTy = LLVMVecTy->getElementType();
392   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
393   auto ConstVec =
394       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
395   unsigned BW = getScalarOrVectorBitWidth(SpvType);
396   return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
397                                        ConstVec, BW,
398                                        SpvType->getOperand(2).getImm());
399 }
400 
401 Register
402 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val,
403                                              MachineIRBuilder &MIRBuilder,
404                                              SPIRVType *SpvType, bool EmitIR) {
405   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
406   assert(LLVMTy->isArrayTy());
407   const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
408   Type *LLVMBaseTy = LLVMArrTy->getElementType();
409   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
410   auto ConstArr =
411       ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
412   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
413   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
414   return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
415                                        ConstArr, BW,
416                                        LLVMArrTy->getNumElements());
417 }
418 
419 Register
420 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
421                                              SPIRVType *SpvType) {
422   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
423   const PointerType *LLVMPtrTy = cast<PointerType>(LLVMTy);
424   // Find a constant in DT or build a new one.
425   Constant *CP = ConstantPointerNull::get(const_cast<PointerType *>(LLVMPtrTy));
426   Register Res = DT.find(CP, CurMF);
427   if (!Res.isValid()) {
428     LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
429     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
430     CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
431     assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
432     MIRBuilder.buildInstr(SPIRV::OpConstantNull)
433         .addDef(Res)
434         .addUse(getSPIRVTypeID(SpvType));
435     DT.add(CP, CurMF, Res);
436   }
437   return Res;
438 }
439 
440 Register SPIRVGlobalRegistry::buildConstantSampler(
441     Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode,
442     MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {
443   SPIRVType *SampTy;
444   if (SpvType)
445     SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
446   else
447     SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder);
448 
449   auto Sampler =
450       ResReg.isValid()
451           ? ResReg
452           : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
453   auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
454                  .addDef(Sampler)
455                  .addUse(getSPIRVTypeID(SampTy))
456                  .addImm(AddrMode)
457                  .addImm(Param)
458                  .addImm(FilerMode);
459   assert(Res->getOperand(0).isReg());
460   return Res->getOperand(0).getReg();
461 }
462 
463 Register SPIRVGlobalRegistry::buildGlobalVariable(
464     Register ResVReg, SPIRVType *BaseType, StringRef Name,
465     const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
466     const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
467     SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
468     bool IsInstSelector) {
469   const GlobalVariable *GVar = nullptr;
470   if (GV)
471     GVar = cast<const GlobalVariable>(GV);
472   else {
473     // If GV is not passed explicitly, use the name to find or construct
474     // the global variable.
475     Module *M = MIRBuilder.getMF().getFunction().getParent();
476     GVar = M->getGlobalVariable(Name);
477     if (GVar == nullptr) {
478       const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
479       GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
480                                 GlobalValue::ExternalLinkage, nullptr,
481                                 Twine(Name));
482     }
483     GV = GVar;
484   }
485   Register Reg = DT.find(GVar, &MIRBuilder.getMF());
486   if (Reg.isValid()) {
487     if (Reg != ResVReg)
488       MIRBuilder.buildCopy(ResVReg, Reg);
489     return ResVReg;
490   }
491 
492   auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
493                  .addDef(ResVReg)
494                  .addUse(getSPIRVTypeID(BaseType))
495                  .addImm(static_cast<uint32_t>(Storage));
496 
497   if (Init != 0) {
498     MIB.addUse(Init->getOperand(0).getReg());
499   }
500 
501   // ISel may introduce a new register on this step, so we need to add it to
502   // DT and correct its type avoiding fails on the next stage.
503   if (IsInstSelector) {
504     const auto &Subtarget = CurMF->getSubtarget();
505     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
506                                      *Subtarget.getRegisterInfo(),
507                                      *Subtarget.getRegBankInfo());
508   }
509   Reg = MIB->getOperand(0).getReg();
510   DT.add(GVar, &MIRBuilder.getMF(), Reg);
511 
512   // Set to Reg the same type as ResVReg has.
513   auto MRI = MIRBuilder.getMRI();
514   assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
515   if (Reg != ResVReg) {
516     LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
517     MRI->setType(Reg, RegLLTy);
518     assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
519   }
520 
521   // If it's a global variable with name, output OpName for it.
522   if (GVar && GVar->hasName())
523     buildOpName(Reg, GVar->getName(), MIRBuilder);
524 
525   // Output decorations for the GV.
526   // TODO: maybe move to GenerateDecorations pass.
527   if (IsConst)
528     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
529 
530   if (GVar && GVar->getAlign().valueOrOne().value() != 1) {
531     unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value();
532     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});
533   }
534 
535   if (HasLinkageTy)
536     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
537                     {static_cast<uint32_t>(LinkageType)}, Name);
538 
539   SPIRV::BuiltIn::BuiltIn BuiltInId;
540   if (getSpirvBuiltInIdByName(Name, BuiltInId))
541     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,
542                     {static_cast<uint32_t>(BuiltInId)});
543 
544   return Reg;
545 }
546 
547 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
548                                                SPIRVType *ElemType,
549                                                MachineIRBuilder &MIRBuilder,
550                                                bool EmitIR) {
551   assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
552          "Invalid array element type");
553   Register NumElementsVReg =
554       buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
555   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
556                  .addDef(createTypeVReg(MIRBuilder))
557                  .addUse(getSPIRVTypeID(ElemType))
558                  .addUse(NumElementsVReg);
559   return MIB;
560 }
561 
562 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
563                                                 MachineIRBuilder &MIRBuilder) {
564   assert(Ty->hasName());
565   const StringRef Name = Ty->hasName() ? Ty->getName() : "";
566   Register ResVReg = createTypeVReg(MIRBuilder);
567   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
568   addStringImm(Name, MIB);
569   buildOpName(ResVReg, Name, MIRBuilder);
570   return MIB;
571 }
572 
573 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
574                                                 MachineIRBuilder &MIRBuilder,
575                                                 bool EmitIR) {
576   SmallVector<Register, 4> FieldTypes;
577   for (const auto &Elem : Ty->elements()) {
578     SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
579     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
580            "Invalid struct element type");
581     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
582   }
583   Register ResVReg = createTypeVReg(MIRBuilder);
584   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
585   for (const auto &Ty : FieldTypes)
586     MIB.addUse(Ty);
587   if (Ty->hasName())
588     buildOpName(ResVReg, Ty->getName(), MIRBuilder);
589   if (Ty->isPacked())
590     buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
591   return MIB;
592 }
593 
594 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
595     const Type *Ty, MachineIRBuilder &MIRBuilder,
596     SPIRV::AccessQualifier::AccessQualifier AccQual) {
597   assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type");
598   return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this);
599 }
600 
601 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
602     SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,
603     MachineIRBuilder &MIRBuilder, Register Reg) {
604   if (!Reg.isValid())
605     Reg = createTypeVReg(MIRBuilder);
606   return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
607       .addDef(Reg)
608       .addImm(static_cast<uint32_t>(SC))
609       .addUse(getSPIRVTypeID(ElemType));
610 }
611 
612 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
613     SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
614   return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
615       .addUse(createTypeVReg(MIRBuilder))
616       .addImm(static_cast<uint32_t>(SC));
617 }
618 
619 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
620     SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
621     MachineIRBuilder &MIRBuilder) {
622   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
623                  .addDef(createTypeVReg(MIRBuilder))
624                  .addUse(getSPIRVTypeID(RetType));
625   for (const SPIRVType *ArgType : ArgTypes)
626     MIB.addUse(getSPIRVTypeID(ArgType));
627   return MIB;
628 }
629 
630 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
631     const Type *Ty, SPIRVType *RetType,
632     const SmallVectorImpl<SPIRVType *> &ArgTypes,
633     MachineIRBuilder &MIRBuilder) {
634   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
635   if (Reg.isValid())
636     return getSPIRVTypeForVReg(Reg);
637   SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
638   DT.add(Ty, CurMF, getSPIRVTypeID(SpirvType));
639   return finishCreatingSPIRVType(Ty, SpirvType);
640 }
641 
642 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
643     const Type *Ty, MachineIRBuilder &MIRBuilder,
644     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
645   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
646   if (Reg.isValid())
647     return getSPIRVTypeForVReg(Reg);
648   if (ForwardPointerTypes.contains(Ty))
649     return ForwardPointerTypes[Ty];
650   return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);
651 }
652 
653 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
654   assert(SpirvType && "Attempting to get type id for nullptr type.");
655   if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer)
656     return SpirvType->uses().begin()->getReg();
657   return SpirvType->defs().begin()->getReg();
658 }
659 
// Build the SPIR-V type instruction corresponding to LLVM type Ty in the
// function MIRBuilder is building into, dispatching on the kind of Ty.
// Special opaque builtin types are delegated to getOrCreateSpecialType, and a
// type already recorded in DT for this function is returned as is.
// This function does not itself record a newly built type in DT or
// VRegToTypeMap; restOfCreateSPIRVType does that after calling it.
// Unsupported LLVM types hit llvm_unreachable.
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
  if (isSpecialOpaqueType(Ty))
    return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
  // Reuse a type already registered in DT for this machine function.
  auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
  auto t = TypeToSPIRVTypeMap.find(Ty);
  if (t != TypeToSPIRVTypeMap.end()) {
    auto tt = t->second.find(&MIRBuilder.getMF());
    if (tt != t->second.end())
      return getSPIRVTypeForVReg(tt->second);
  }

  // i1 maps to OpTypeBool; every other integer becomes an OpTypeInt created
  // with IsSigned == false.
  if (auto IType = dyn_cast<IntegerType>(Ty)) {
    const unsigned Width = IType->getBitWidth();
    return Width == 1 ? getOpTypeBool(MIRBuilder)
                      : getOpTypeInt(Width, MIRBuilder, false);
  }
  if (Ty->isFloatingPointTy())
    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
  if (Ty->isVoidTy())
    return getOpTypeVoid(MIRBuilder);
  // Aggregates resolve their element types first via findSPIRVType.
  // NOTE(review): the cast<FixedVectorType> below assumes Ty is not a
  // scalable vector.
  if (Ty->isVectorTy()) {
    SPIRVType *El =
        findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
    return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
                           MIRBuilder);
  }
  if (Ty->isArrayTy()) {
    SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
    return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
  }
  if (auto SType = dyn_cast<StructType>(Ty)) {
    if (SType->isOpaque())
      return getOpTypeOpaque(SType, MIRBuilder);
    return getOpTypeStruct(SType, MIRBuilder, EmitIR);
  }
  if (auto FType = dyn_cast<FunctionType>(Ty)) {
    SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
    SmallVector<SPIRVType *, 4> ParamTypes;
    for (const auto &t : FType->params()) {
      ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
    }
    return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
  }
  if (auto PType = dyn_cast<PointerType>(Ty)) {
    SPIRVType *SpvElementType;
    // At the moment, all opaque pointers correspond to i8 element type.
    // TODO: change the implementation once opaque pointers are supported
    // in the SPIR-V specification.
    SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
    auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
    // Null pointer means we have a loop in type definitions, make and
    // return corresponding OpTypeForwardPointer.
    if (SpvElementType == nullptr) {
      if (!ForwardPointerTypes.contains(Ty))
        ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder);
      return ForwardPointerTypes[PType];
    }
    // Reg(0) is an invalid register, so getOpTypePointer allocates a fresh
    // id unless a forward pointer already reserved one.
    Register Reg(0);
    // If we have forward pointer associated with this type, use its register
    // operand to create OpTypePointer.
    if (ForwardPointerTypes.contains(PType))
      Reg = getSPIRVTypeID(ForwardPointerTypes[PType]);

    return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
  }
  llvm_unreachable("Unable to convert LLVM type to SPIRVType");
}
729 
730 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
731     const Type *Ty, MachineIRBuilder &MIRBuilder,
732     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
733   if (TypesInProcessing.count(Ty) && !Ty->isPointerTy())
734     return nullptr;
735   TypesInProcessing.insert(Ty);
736   SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
737   TypesInProcessing.erase(Ty);
738   VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
739   SPIRVToLLVMType[SpirvType] = Ty;
740   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
741   // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
742   // will be added later. For special types it is already added to DT.
743   if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
744       !isSpecialOpaqueType(Ty)) {
745     if (!Ty->isPointerTy())
746       DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));
747     else
748       DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
749              Ty->getPointerAddressSpace(), &MIRBuilder.getMF(),
750              getSPIRVTypeID(SpirvType));
751   }
752 
753   return SpirvType;
754 }
755 
756 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
757   auto t = VRegToTypeMap.find(CurMF);
758   if (t != VRegToTypeMap.end()) {
759     auto tt = t->second.find(VReg);
760     if (tt != t->second.end())
761       return tt->second;
762   }
763   return nullptr;
764 }
765 
// Returns the SPIR-V type for the LLVM type Ty in the machine function of
// MIRBuilder, creating it (and any types it depends on) when DT has no entry
// for it yet.
//
// The DT lookup mirrors how restOfCreateSPIRVType registers types. Non-pointer
// types are keyed by the type itself. Every pointer type is keyed by
// (i8, address space), because createSPIRVType currently lowers all pointers
// with an i8 element type.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  Register Reg;
  if (!Ty->isPointerTy())
    Reg = DT.find(Ty, &MIRBuilder.getMF());
  else
    Reg =
        DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()),
                Ty->getPointerAddressSpace(), &MIRBuilder.getMF());

  // Special opaque types never take this fast path; they always go through
  // restOfCreateSPIRVType below.
  if (Reg.isValid() && !isSpecialOpaqueType(Ty))
    return getSPIRVTypeForVReg(Reg);
  // Start a fresh top-level creation by resetting the cycle-detection set.
  TypesInProcessing.clear();
  SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
  // Create normal pointer types for the corresponding OpTypeForwardPointers.
  // NOTE(review): DT.find(Ty2, ...) below uses the non-pointer key, while
  // restOfCreateSPIRVType records pointer types under (i8, address space).
  // For pointer keys this lookup therefore presumably never hits, and the
  // pointer type is always (re)created. Confirm this is intended before
  // changing the lookup, since the OpTypePointer must define the
  // forward-pointer register.
  // NOTE(review): restOfCreateSPIRVType can reach getOrCreateSPIRVIntegerType
  // -> getOrCreateSPIRVType, which inserts into and clears ForwardPointerTypes
  // while this loop iterates over it. Verify that this cannot invalidate the
  // iterator.
  for (auto &CU : ForwardPointerTypes) {
    const Type *Ty2 = CU.first;
    // The initial value is overwritten on both branches below.
    SPIRVType *STy2 = CU.second;
    if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid())
      STy2 = getSPIRVTypeForVReg(Reg);
    else
      STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
    // If the requested type itself was forward-declared, return the completed
    // pointer type rather than the forward declaration.
    if (Ty == Ty2)
      STy = STy2;
  }
  ForwardPointerTypes.clear();
  return STy;
}
795 
796 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
797                                          unsigned TypeOpcode) const {
798   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
799   assert(Type && "isScalarOfType VReg has no type assigned");
800   return Type->getOpcode() == TypeOpcode;
801 }
802 
803 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
804                                                  unsigned TypeOpcode) const {
805   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
806   assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
807   if (Type->getOpcode() == TypeOpcode)
808     return true;
809   if (Type->getOpcode() == SPIRV::OpTypeVector) {
810     Register ScalarTypeVReg = Type->getOperand(1).getReg();
811     SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
812     return ScalarType->getOpcode() == TypeOpcode;
813   }
814   return false;
815 }
816 
817 unsigned
818 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
819   assert(Type && "Invalid Type pointer");
820   if (Type->getOpcode() == SPIRV::OpTypeVector) {
821     auto EleTypeReg = Type->getOperand(1).getReg();
822     Type = getSPIRVTypeForVReg(EleTypeReg);
823   }
824   if (Type->getOpcode() == SPIRV::OpTypeInt ||
825       Type->getOpcode() == SPIRV::OpTypeFloat)
826     return Type->getOperand(1).getImm();
827   if (Type->getOpcode() == SPIRV::OpTypeBool)
828     return 1;
829   llvm_unreachable("Attempting to get bit width of non-integer/float type.");
830 }
831 
832 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
833   assert(Type && "Invalid Type pointer");
834   if (Type->getOpcode() == SPIRV::OpTypeVector) {
835     auto EleTypeReg = Type->getOperand(1).getReg();
836     Type = getSPIRVTypeForVReg(EleTypeReg);
837   }
838   if (Type->getOpcode() == SPIRV::OpTypeInt)
839     return Type->getOperand(2).getImm() != 0;
840   llvm_unreachable("Attempting to get sign of non-integer type.");
841 }
842 
843 SPIRV::StorageClass::StorageClass
844 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
845   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
846   assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
847          Type->getOperand(1).isImm() && "Pointer type is expected");
848   return static_cast<SPIRV::StorageClass::StorageClass>(
849       Type->getOperand(1).getImm());
850 }
851 
852 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
853     MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
854     uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
855     SPIRV::ImageFormat::ImageFormat ImageFormat,
856     SPIRV::AccessQualifier::AccessQualifier AccessQual) {
857   SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(SampledType), Dim, Depth,
858                                 Arrayed, Multisampled, Sampled, ImageFormat,
859                                 AccessQual);
860   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
861     return Res;
862   Register ResVReg = createTypeVReg(MIRBuilder);
863   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
864   return MIRBuilder.buildInstr(SPIRV::OpTypeImage)
865       .addDef(ResVReg)
866       .addUse(getSPIRVTypeID(SampledType))
867       .addImm(Dim)
868       .addImm(Depth)        // Depth (whether or not it is a Depth image).
869       .addImm(Arrayed)      // Arrayed.
870       .addImm(Multisampled) // Multisampled (0 = only single-sample).
871       .addImm(Sampled)      // Sampled (0 = usage known at runtime).
872       .addImm(ImageFormat)
873       .addImm(AccessQual);
874 }
875 
876 SPIRVType *
877 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
878   SPIRV::SamplerTypeDescriptor TD;
879   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
880     return Res;
881   Register ResVReg = createTypeVReg(MIRBuilder);
882   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
883   return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg);
884 }
885 
886 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
887     MachineIRBuilder &MIRBuilder,
888     SPIRV::AccessQualifier::AccessQualifier AccessQual) {
889   SPIRV::PipeTypeDescriptor TD(AccessQual);
890   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
891     return Res;
892   Register ResVReg = createTypeVReg(MIRBuilder);
893   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
894   return MIRBuilder.buildInstr(SPIRV::OpTypePipe)
895       .addDef(ResVReg)
896       .addImm(AccessQual);
897 }
898 
899 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
900     MachineIRBuilder &MIRBuilder) {
901   SPIRV::DeviceEventTypeDescriptor TD;
902   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
903     return Res;
904   Register ResVReg = createTypeVReg(MIRBuilder);
905   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
906   return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg);
907 }
908 
909 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
910     SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
911   SPIRV::SampledImageTypeDescriptor TD(
912       SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
913           ImageType->getOperand(1).getReg())),
914       ImageType);
915   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
916     return Res;
917   Register ResVReg = createTypeVReg(MIRBuilder);
918   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
919   return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage)
920       .addDef(ResVReg)
921       .addUse(getSPIRVTypeID(ImageType));
922 }
923 
924 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
925     const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
926   Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
927   if (ResVReg.isValid())
928     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
929   ResVReg = createTypeVReg(MIRBuilder);
930   SPIRVType *SpirvTy = MIRBuilder.buildInstr(Opcode).addDef(ResVReg);
931   DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
932   return SpirvTy;
933 }
934 
935 const MachineInstr *
936 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
937                                        MachineIRBuilder &MIRBuilder) {
938   Register Reg = DT.find(TD, &MIRBuilder.getMF());
939   if (Reg.isValid())
940     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg);
941   return nullptr;
942 }
943 
944 // TODO: maybe use tablegen to implement this.
945 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(
946     StringRef TypeStr, MachineIRBuilder &MIRBuilder,
947     SPIRV::StorageClass::StorageClass SC,
948     SPIRV::AccessQualifier::AccessQualifier AQ) {
949   unsigned VecElts = 0;
950   auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
951 
952   // Parse strings representing either a SPIR-V or OpenCL builtin type.
953   if (hasBuiltinTypePrefix(TypeStr))
954     return getOrCreateSPIRVType(
955         SPIRV::parseBuiltinTypeNameToTargetExtType(TypeStr.str(), MIRBuilder),
956         MIRBuilder, AQ);
957 
958   // Parse type name in either "typeN" or "type vector[N]" format, where
959   // N is the number of elements of the vector.
960   Type *Ty;
961 
962   if (TypeStr.starts_with("atomic_"))
963     TypeStr = TypeStr.substr(strlen("atomic_"));
964 
965   if (TypeStr.starts_with("void")) {
966     Ty = Type::getVoidTy(Ctx);
967     TypeStr = TypeStr.substr(strlen("void"));
968   } else if (TypeStr.starts_with("bool")) {
969     Ty = Type::getIntNTy(Ctx, 1);
970     TypeStr = TypeStr.substr(strlen("bool"));
971   } else if (TypeStr.starts_with("char") || TypeStr.starts_with("uchar")) {
972     Ty = Type::getInt8Ty(Ctx);
973     TypeStr = TypeStr.starts_with("char") ? TypeStr.substr(strlen("char"))
974                                           : TypeStr.substr(strlen("uchar"));
975   } else if (TypeStr.starts_with("short") || TypeStr.starts_with("ushort")) {
976     Ty = Type::getInt16Ty(Ctx);
977     TypeStr = TypeStr.starts_with("short") ? TypeStr.substr(strlen("short"))
978                                            : TypeStr.substr(strlen("ushort"));
979   } else if (TypeStr.starts_with("int") || TypeStr.starts_with("uint")) {
980     Ty = Type::getInt32Ty(Ctx);
981     TypeStr = TypeStr.starts_with("int") ? TypeStr.substr(strlen("int"))
982                                          : TypeStr.substr(strlen("uint"));
983   } else if (TypeStr.starts_with("long") || TypeStr.starts_with("ulong")) {
984     Ty = Type::getInt64Ty(Ctx);
985     TypeStr = TypeStr.starts_with("long") ? TypeStr.substr(strlen("long"))
986                                           : TypeStr.substr(strlen("ulong"));
987   } else if (TypeStr.starts_with("half")) {
988     Ty = Type::getHalfTy(Ctx);
989     TypeStr = TypeStr.substr(strlen("half"));
990   } else if (TypeStr.starts_with("float")) {
991     Ty = Type::getFloatTy(Ctx);
992     TypeStr = TypeStr.substr(strlen("float"));
993   } else if (TypeStr.starts_with("double")) {
994     Ty = Type::getDoubleTy(Ctx);
995     TypeStr = TypeStr.substr(strlen("double"));
996   } else
997     llvm_unreachable("Unable to recognize SPIRV type name.");
998 
999   auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ);
1000 
1001   // Handle "type*" or  "type* vector[N]".
1002   if (TypeStr.starts_with("*")) {
1003     SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1004     TypeStr = TypeStr.substr(strlen("*"));
1005   }
1006 
1007   // Handle "typeN*" or  "type vector[N]*".
1008   bool IsPtrToVec = TypeStr.consume_back("*");
1009 
1010   if (TypeStr.starts_with(" vector[")) {
1011     TypeStr = TypeStr.substr(strlen(" vector["));
1012     TypeStr = TypeStr.substr(0, TypeStr.find(']'));
1013   }
1014   TypeStr.getAsInteger(10, VecElts);
1015   if (VecElts > 0)
1016     SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);
1017 
1018   if (IsPtrToVec)
1019     SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC);
1020 
1021   return SpirvTy;
1022 }
1023 
1024 SPIRVType *
1025 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
1026                                                  MachineIRBuilder &MIRBuilder) {
1027   return getOrCreateSPIRVType(
1028       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
1029       MIRBuilder);
1030 }
1031 
1032 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
1033                                                         SPIRVType *SpirvType) {
1034   assert(CurMF == SpirvType->getMF());
1035   VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
1036   SPIRVToLLVMType[SpirvType] = LLVMTy;
1037   return SpirvType;
1038 }
1039 
1040 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
1041     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
1042   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
1043   Register Reg = DT.find(LLVMTy, CurMF);
1044   if (Reg.isValid())
1045     return getSPIRVTypeForVReg(Reg);
1046   MachineBasicBlock &BB = *I.getParent();
1047   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
1048                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1049                  .addImm(BitWidth)
1050                  .addImm(0);
1051   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1052   return finishCreatingSPIRVType(LLVMTy, MIB);
1053 }
1054 
1055 SPIRVType *
1056 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
1057   return getOrCreateSPIRVType(
1058       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
1059       MIRBuilder);
1060 }
1061 
1062 SPIRVType *
1063 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
1064                                               const SPIRVInstrInfo &TII) {
1065   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
1066   Register Reg = DT.find(LLVMTy, CurMF);
1067   if (Reg.isValid())
1068     return getSPIRVTypeForVReg(Reg);
1069   MachineBasicBlock &BB = *I.getParent();
1070   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool))
1071                  .addDef(createTypeVReg(CurMF->getRegInfo()));
1072   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1073   return finishCreatingSPIRVType(LLVMTy, MIB);
1074 }
1075 
1076 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1077     SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
1078   return getOrCreateSPIRVType(
1079       FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1080                            NumElements),
1081       MIRBuilder);
1082 }
1083 
1084 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1085     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1086     const SPIRVInstrInfo &TII) {
1087   Type *LLVMTy = FixedVectorType::get(
1088       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1089   Register Reg = DT.find(LLVMTy, CurMF);
1090   if (Reg.isValid())
1091     return getSPIRVTypeForVReg(Reg);
1092   MachineBasicBlock &BB = *I.getParent();
1093   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
1094                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1095                  .addUse(getSPIRVTypeID(BaseType))
1096                  .addImm(NumElements);
1097   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1098   return finishCreatingSPIRVType(LLVMTy, MIB);
1099 }
1100 
1101 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
1102     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1103     const SPIRVInstrInfo &TII) {
1104   Type *LLVMTy = ArrayType::get(
1105       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1106   Register Reg = DT.find(LLVMTy, CurMF);
1107   if (Reg.isValid())
1108     return getSPIRVTypeForVReg(Reg);
1109   MachineBasicBlock &BB = *I.getParent();
1110   SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII);
1111   Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII);
1112   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray))
1113                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1114                  .addUse(getSPIRVTypeID(BaseType))
1115                  .addUse(Len);
1116   DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB));
1117   return finishCreatingSPIRVType(LLVMTy, MIB);
1118 }
1119 
1120 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1121     SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1122     SPIRV::StorageClass::StorageClass SC) {
1123   const Type *PointerElementType = getTypeForSPIRVType(BaseType);
1124   unsigned AddressSpace = storageClassToAddressSpace(SC);
1125   Type *LLVMTy =
1126       PointerType::get(const_cast<Type *>(PointerElementType), AddressSpace);
1127   Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
1128   if (Reg.isValid())
1129     return getSPIRVTypeForVReg(Reg);
1130   auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(),
1131                      MIRBuilder.getDebugLoc(),
1132                      MIRBuilder.getTII().get(SPIRV::OpTypePointer))
1133                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1134                  .addImm(static_cast<uint32_t>(SC))
1135                  .addUse(getSPIRVTypeID(BaseType));
1136   DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
1137   return finishCreatingSPIRVType(LLVMTy, MIB);
1138 }
1139 
1140 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1141     SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
1142     SPIRV::StorageClass::StorageClass SC) {
1143   const Type *PointerElementType = getTypeForSPIRVType(BaseType);
1144   unsigned AddressSpace = storageClassToAddressSpace(SC);
1145   Type *LLVMTy =
1146       PointerType::get(const_cast<Type *>(PointerElementType), AddressSpace);
1147   Register Reg = DT.find(PointerElementType, AddressSpace, CurMF);
1148   if (Reg.isValid())
1149     return getSPIRVTypeForVReg(Reg);
1150   MachineBasicBlock &BB = *I.getParent();
1151   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
1152                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1153                  .addImm(static_cast<uint32_t>(SC))
1154                  .addUse(getSPIRVTypeID(BaseType));
1155   DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB));
1156   return finishCreatingSPIRVType(LLVMTy, MIB);
1157 }
1158 
1159 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
1160                                                SPIRVType *SpvType,
1161                                                const SPIRVInstrInfo &TII) {
1162   assert(SpvType);
1163   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
1164   assert(LLVMTy);
1165   // Find a constant in DT or build a new one.
1166   UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
1167   Register Res = DT.find(UV, CurMF);
1168   if (Res.isValid())
1169     return Res;
1170   LLT LLTy = LLT::scalar(32);
1171   Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
1172   CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
1173   assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
1174   DT.add(UV, CurMF, Res);
1175 
1176   MachineInstrBuilder MIB;
1177   MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
1178             .addDef(Res)
1179             .addUse(getSPIRVTypeID(SpvType));
1180   const auto &ST = CurMF->getSubtarget();
1181   constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
1182                                    *ST.getRegisterInfo(), *ST.getRegBankInfo());
1183   return Res;
1184 }
1185