1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13 // and supports consistency of constants and global variables.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRV.h"
19 #include "SPIRVBuiltins.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVTargetMachine.h"
22 #include "SPIRVUtils.h"
23 
24 using namespace llvm;
25 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
26     : PointerSize(PointerSize) {}
27 
28 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
29                                                     Register VReg,
30                                                     MachineInstr &I,
31                                                     const SPIRVInstrInfo &TII) {
32   SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
33   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
34   return SpirvType;
35 }
36 
37 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
38     SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
39     const SPIRVInstrInfo &TII) {
40   SPIRVType *SpirvType =
41       getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
42   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
43   return SpirvType;
44 }
45 
46 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
47     const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
48     SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
49 
50   SPIRVType *SpirvType =
51       getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
52   assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
53   return SpirvType;
54 }
55 
56 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
57                                                 Register VReg,
58                                                 MachineFunction &MF) {
59   VRegToTypeMap[&MF][VReg] = SpirvType;
60 }
61 
62 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
63   auto &MRI = MIRBuilder.getMF().getRegInfo();
64   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
65   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
66   return Res;
67 }
68 
69 static Register createTypeVReg(MachineRegisterInfo &MRI) {
70   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
71   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
72   return Res;
73 }
74 
75 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
76   return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
77       .addDef(createTypeVReg(MIRBuilder));
78 }
79 
80 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
81                                              MachineIRBuilder &MIRBuilder,
82                                              bool IsSigned) {
83   assert(Width <= 64 && "Unsupported integer width!");
84   if (Width <= 8)
85     Width = 8;
86   else if (Width <= 16)
87     Width = 16;
88   else if (Width <= 32)
89     Width = 32;
90   else if (Width <= 64)
91     Width = 64;
92 
93   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
94                  .addDef(createTypeVReg(MIRBuilder))
95                  .addImm(Width)
96                  .addImm(IsSigned ? 1 : 0);
97   return MIB;
98 }
99 
100 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
101                                                MachineIRBuilder &MIRBuilder) {
102   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
103                  .addDef(createTypeVReg(MIRBuilder))
104                  .addImm(Width);
105   return MIB;
106 }
107 
108 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
109   return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
110       .addDef(createTypeVReg(MIRBuilder));
111 }
112 
113 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
114                                                 SPIRVType *ElemType,
115                                                 MachineIRBuilder &MIRBuilder) {
116   auto EleOpc = ElemType->getOpcode();
117   assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
118           EleOpc == SPIRV::OpTypeBool) &&
119          "Invalid vector element type");
120 
121   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
122                  .addDef(createTypeVReg(MIRBuilder))
123                  .addUse(getSPIRVTypeID(ElemType))
124                  .addImm(NumElems);
125   return MIB;
126 }
127 
128 std::tuple<Register, ConstantInt *, bool>
129 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
130                                             MachineIRBuilder *MIRBuilder,
131                                             MachineInstr *I,
132                                             const SPIRVInstrInfo *TII) {
133   const IntegerType *LLVMIntTy;
134   if (SpvType)
135     LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
136   else
137     LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());
138   bool NewInstr = false;
139   // Find a constant in DT or build a new one.
140   ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
141   Register Res = DT.find(CI, CurMF);
142   if (!Res.isValid()) {
143     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
144     LLT LLTy = LLT::scalar(32);
145     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
146     CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
147     if (MIRBuilder)
148       assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
149     else
150       assignIntTypeToVReg(BitWidth, Res, *I, *TII);
151     DT.add(CI, CurMF, Res);
152     NewInstr = true;
153   }
154   return std::make_tuple(Res, CI, NewInstr);
155 }
156 
157 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
158                                                   SPIRVType *SpvType,
159                                                   const SPIRVInstrInfo &TII) {
160   assert(SpvType);
161   ConstantInt *CI;
162   Register Res;
163   bool New;
164   std::tie(Res, CI, New) =
165       getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
166   // If we have found Res register which is defined by the passed G_CONSTANT
167   // machine instruction, a new constant instruction should be created.
168   if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
169     return Res;
170   MachineInstrBuilder MIB;
171   MachineBasicBlock &BB = *I.getParent();
172   if (Val) {
173     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
174               .addDef(Res)
175               .addUse(getSPIRVTypeID(SpvType));
176     addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);
177   } else {
178     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
179               .addDef(Res)
180               .addUse(getSPIRVTypeID(SpvType));
181   }
182   const auto &ST = CurMF->getSubtarget();
183   constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
184                                    *ST.getRegisterInfo(), *ST.getRegBankInfo());
185   return Res;
186 }
187 
188 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
189                                                MachineIRBuilder &MIRBuilder,
190                                                SPIRVType *SpvType,
191                                                bool EmitIR) {
192   auto &MF = MIRBuilder.getMF();
193   const IntegerType *LLVMIntTy;
194   if (SpvType)
195     LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
196   else
197     LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
198   // Find a constant in DT or build a new one.
199   const auto ConstInt =
200       ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
201   Register Res = DT.find(ConstInt, &MF);
202   if (!Res.isValid()) {
203     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
204     LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
205     Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
206     MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
207     assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
208                      SPIRV::AccessQualifier::ReadWrite, EmitIR);
209     DT.add(ConstInt, &MIRBuilder.getMF(), Res);
210     if (EmitIR) {
211       MIRBuilder.buildConstant(Res, *ConstInt);
212     } else {
213       MachineInstrBuilder MIB;
214       if (Val) {
215         assert(SpvType);
216         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
217                   .addDef(Res)
218                   .addUse(getSPIRVTypeID(SpvType));
219         addNumImm(APInt(BitWidth, Val), MIB);
220       } else {
221         assert(SpvType);
222         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
223                   .addDef(Res)
224                   .addUse(getSPIRVTypeID(SpvType));
225       }
226       const auto &Subtarget = CurMF->getSubtarget();
227       constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
228                                        *Subtarget.getRegisterInfo(),
229                                        *Subtarget.getRegBankInfo());
230     }
231   }
232   return Res;
233 }
234 
235 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
236                                               MachineIRBuilder &MIRBuilder,
237                                               SPIRVType *SpvType) {
238   auto &MF = MIRBuilder.getMF();
239   const Type *LLVMFPTy;
240   if (SpvType) {
241     LLVMFPTy = getTypeForSPIRVType(SpvType);
242     assert(LLVMFPTy->isFloatingPointTy());
243   } else {
244     LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext());
245   }
246   // Find a constant in DT or build a new one.
247   const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val);
248   Register Res = DT.find(ConstFP, &MF);
249   if (!Res.isValid()) {
250     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
251     Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
252     MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
253     assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
254     DT.add(ConstFP, &MF, Res);
255     MIRBuilder.buildFConstant(Res, *ConstFP);
256   }
257   return Res;
258 }
259 
260 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
261     uint64_t Val, MachineInstr &I, SPIRVType *SpvType,
262     const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth,
263     unsigned ElemCnt) {
264   // Find a constant vector in DT or build a new one.
265   Register Res = DT.find(CA, CurMF);
266   if (!Res.isValid()) {
267     SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
268     // SpvScalConst should be created before SpvVecConst to avoid undefined ID
269     // error on validation.
270     // TODO: can moved below once sorting of types/consts/defs is implemented.
271     Register SpvScalConst;
272     if (Val)
273       SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII);
274     // TODO: maybe use bitwidth of base type.
275     LLT LLTy = LLT::scalar(32);
276     Register SpvVecConst =
277         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
278     CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
279     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
280     DT.add(CA, CurMF, SpvVecConst);
281     MachineInstrBuilder MIB;
282     MachineBasicBlock &BB = *I.getParent();
283     if (Val) {
284       MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
285                 .addDef(SpvVecConst)
286                 .addUse(getSPIRVTypeID(SpvType));
287       for (unsigned i = 0; i < ElemCnt; ++i)
288         MIB.addUse(SpvScalConst);
289     } else {
290       MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
291                 .addDef(SpvVecConst)
292                 .addUse(getSPIRVTypeID(SpvType));
293     }
294     const auto &Subtarget = CurMF->getSubtarget();
295     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
296                                      *Subtarget.getRegisterInfo(),
297                                      *Subtarget.getRegBankInfo());
298     return SpvVecConst;
299   }
300   return Res;
301 }
302 
303 Register
304 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
305                                               SPIRVType *SpvType,
306                                               const SPIRVInstrInfo &TII) {
307   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
308   assert(LLVMTy->isVectorTy());
309   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
310   Type *LLVMBaseTy = LLVMVecTy->getElementType();
311   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
312   auto ConstVec =
313       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
314   unsigned BW = getScalarOrVectorBitWidth(SpvType);
315   return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstVec, BW,
316                                        SpvType->getOperand(2).getImm());
317 }
318 
319 Register
320 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I,
321                                              SPIRVType *SpvType,
322                                              const SPIRVInstrInfo &TII) {
323   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
324   assert(LLVMTy->isArrayTy());
325   const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
326   Type *LLVMBaseTy = LLVMArrTy->getElementType();
327   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
328   auto ConstArr =
329       ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
330   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
331   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
332   return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstArr, BW,
333                                        LLVMArrTy->getNumElements());
334 }
335 
336 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull(
337     uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR,
338     Constant *CA, unsigned BitWidth, unsigned ElemCnt) {
339   Register Res = DT.find(CA, CurMF);
340   if (!Res.isValid()) {
341     Register SpvScalConst;
342     if (Val || EmitIR) {
343       SPIRVType *SpvBaseType =
344           getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder);
345       SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR);
346     }
347     LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32);
348     Register SpvVecConst =
349         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
350     CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass);
351     assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF);
352     DT.add(CA, CurMF, SpvVecConst);
353     if (EmitIR) {
354       MIRBuilder.buildSplatVector(SpvVecConst, SpvScalConst);
355     } else {
356       if (Val) {
357         auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite)
358                        .addDef(SpvVecConst)
359                        .addUse(getSPIRVTypeID(SpvType));
360         for (unsigned i = 0; i < ElemCnt; ++i)
361           MIB.addUse(SpvScalConst);
362       } else {
363         MIRBuilder.buildInstr(SPIRV::OpConstantNull)
364             .addDef(SpvVecConst)
365             .addUse(getSPIRVTypeID(SpvType));
366       }
367     }
368     return SpvVecConst;
369   }
370   return Res;
371 }
372 
373 Register
374 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val,
375                                               MachineIRBuilder &MIRBuilder,
376                                               SPIRVType *SpvType, bool EmitIR) {
377   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
378   assert(LLVMTy->isVectorTy());
379   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
380   Type *LLVMBaseTy = LLVMVecTy->getElementType();
381   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
382   auto ConstVec =
383       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
384   unsigned BW = getScalarOrVectorBitWidth(SpvType);
385   return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
386                                        ConstVec, BW,
387                                        SpvType->getOperand(2).getImm());
388 }
389 
390 Register
391 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val,
392                                              MachineIRBuilder &MIRBuilder,
393                                              SPIRVType *SpvType, bool EmitIR) {
394   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
395   assert(LLVMTy->isArrayTy());
396   const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy);
397   Type *LLVMBaseTy = LLVMArrTy->getElementType();
398   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
399   auto ConstArr =
400       ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt});
401   SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg());
402   unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy);
403   return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR,
404                                        ConstArr, BW,
405                                        LLVMArrTy->getNumElements());
406 }
407 
408 Register
409 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder,
410                                              SPIRVType *SpvType) {
411   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
412   const PointerType *LLVMPtrTy = cast<PointerType>(LLVMTy);
413   // Find a constant in DT or build a new one.
414   Constant *CP = ConstantPointerNull::get(const_cast<PointerType *>(LLVMPtrTy));
415   Register Res = DT.find(CP, CurMF);
416   if (!Res.isValid()) {
417     LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize);
418     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
419     CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
420     assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
421     MIRBuilder.buildInstr(SPIRV::OpConstantNull)
422         .addDef(Res)
423         .addUse(getSPIRVTypeID(SpvType));
424     DT.add(CP, CurMF, Res);
425   }
426   return Res;
427 }
428 
429 Register SPIRVGlobalRegistry::buildConstantSampler(
430     Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode,
431     MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) {
432   SPIRVType *SampTy;
433   if (SpvType)
434     SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder);
435   else
436     SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder);
437 
438   auto Sampler =
439       ResReg.isValid()
440           ? ResReg
441           : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
442   auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler)
443                  .addDef(Sampler)
444                  .addUse(getSPIRVTypeID(SampTy))
445                  .addImm(AddrMode)
446                  .addImm(Param)
447                  .addImm(FilerMode);
448   assert(Res->getOperand(0).isReg());
449   return Res->getOperand(0).getReg();
450 }
451 
452 Register SPIRVGlobalRegistry::buildGlobalVariable(
453     Register ResVReg, SPIRVType *BaseType, StringRef Name,
454     const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage,
455     const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
456     SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
457     bool IsInstSelector) {
458   const GlobalVariable *GVar = nullptr;
459   if (GV)
460     GVar = cast<const GlobalVariable>(GV);
461   else {
462     // If GV is not passed explicitly, use the name to find or construct
463     // the global variable.
464     Module *M = MIRBuilder.getMF().getFunction().getParent();
465     GVar = M->getGlobalVariable(Name);
466     if (GVar == nullptr) {
467       const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
468       GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
469                                 GlobalValue::ExternalLinkage, nullptr,
470                                 Twine(Name));
471     }
472     GV = GVar;
473   }
474   Register Reg = DT.find(GVar, &MIRBuilder.getMF());
475   if (Reg.isValid()) {
476     if (Reg != ResVReg)
477       MIRBuilder.buildCopy(ResVReg, Reg);
478     return ResVReg;
479   }
480 
481   auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
482                  .addDef(ResVReg)
483                  .addUse(getSPIRVTypeID(BaseType))
484                  .addImm(static_cast<uint32_t>(Storage));
485 
486   if (Init != 0) {
487     MIB.addUse(Init->getOperand(0).getReg());
488   }
489 
490   // ISel may introduce a new register on this step, so we need to add it to
491   // DT and correct its type avoiding fails on the next stage.
492   if (IsInstSelector) {
493     const auto &Subtarget = CurMF->getSubtarget();
494     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
495                                      *Subtarget.getRegisterInfo(),
496                                      *Subtarget.getRegBankInfo());
497   }
498   Reg = MIB->getOperand(0).getReg();
499   DT.add(GVar, &MIRBuilder.getMF(), Reg);
500 
501   // Set to Reg the same type as ResVReg has.
502   auto MRI = MIRBuilder.getMRI();
503   assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
504   if (Reg != ResVReg) {
505     LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
506     MRI->setType(Reg, RegLLTy);
507     assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
508   }
509 
510   // If it's a global variable with name, output OpName for it.
511   if (GVar && GVar->hasName())
512     buildOpName(Reg, GVar->getName(), MIRBuilder);
513 
514   // Output decorations for the GV.
515   // TODO: maybe move to GenerateDecorations pass.
516   if (IsConst)
517     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
518 
519   if (GVar && GVar->getAlign().valueOrOne().value() != 1) {
520     unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value();
521     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment});
522   }
523 
524   if (HasLinkageTy)
525     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
526                     {static_cast<uint32_t>(LinkageType)}, Name);
527 
528   SPIRV::BuiltIn::BuiltIn BuiltInId;
529   if (getSpirvBuiltInIdByName(Name, BuiltInId))
530     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn,
531                     {static_cast<uint32_t>(BuiltInId)});
532 
533   return Reg;
534 }
535 
536 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
537                                                SPIRVType *ElemType,
538                                                MachineIRBuilder &MIRBuilder,
539                                                bool EmitIR) {
540   assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
541          "Invalid array element type");
542   Register NumElementsVReg =
543       buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
544   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
545                  .addDef(createTypeVReg(MIRBuilder))
546                  .addUse(getSPIRVTypeID(ElemType))
547                  .addUse(NumElementsVReg);
548   return MIB;
549 }
550 
551 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
552                                                 MachineIRBuilder &MIRBuilder) {
553   assert(Ty->hasName());
554   const StringRef Name = Ty->hasName() ? Ty->getName() : "";
555   Register ResVReg = createTypeVReg(MIRBuilder);
556   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
557   addStringImm(Name, MIB);
558   buildOpName(ResVReg, Name, MIRBuilder);
559   return MIB;
560 }
561 
562 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
563                                                 MachineIRBuilder &MIRBuilder,
564                                                 bool EmitIR) {
565   SmallVector<Register, 4> FieldTypes;
566   for (const auto &Elem : Ty->elements()) {
567     SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
568     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
569            "Invalid struct element type");
570     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
571   }
572   Register ResVReg = createTypeVReg(MIRBuilder);
573   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
574   for (const auto &Ty : FieldTypes)
575     MIB.addUse(Ty);
576   if (Ty->hasName())
577     buildOpName(ResVReg, Ty->getName(), MIRBuilder);
578   if (Ty->isPacked())
579     buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
580   return MIB;
581 }
582 
583 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType(
584     const Type *Ty, MachineIRBuilder &MIRBuilder,
585     SPIRV::AccessQualifier::AccessQualifier AccQual) {
586   // Some OpenCL and SPIRV builtins like image2d_t are passed in as
587   // pointers, but should be treated as custom types like OpTypeImage.
588   if (auto PType = dyn_cast<PointerType>(Ty)) {
589     assert(!PType->isOpaque());
590     Ty = PType->getNonOpaquePointerElementType();
591   }
592   assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type");
593   return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this);
594 }
595 
596 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(
597     SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType,
598     MachineIRBuilder &MIRBuilder, Register Reg) {
599   if (!Reg.isValid())
600     Reg = createTypeVReg(MIRBuilder);
601   return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
602       .addDef(Reg)
603       .addImm(static_cast<uint32_t>(SC))
604       .addUse(getSPIRVTypeID(ElemType));
605 }
606 
607 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer(
608     SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) {
609   return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
610       .addUse(createTypeVReg(MIRBuilder))
611       .addImm(static_cast<uint32_t>(SC));
612 }
613 
614 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
615     SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
616     MachineIRBuilder &MIRBuilder) {
617   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
618                  .addDef(createTypeVReg(MIRBuilder))
619                  .addUse(getSPIRVTypeID(RetType));
620   for (const SPIRVType *ArgType : ArgTypes)
621     MIB.addUse(getSPIRVTypeID(ArgType));
622   return MIB;
623 }
624 
625 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
626     const Type *Ty, SPIRVType *RetType,
627     const SmallVectorImpl<SPIRVType *> &ArgTypes,
628     MachineIRBuilder &MIRBuilder) {
629   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
630   if (Reg.isValid())
631     return getSPIRVTypeForVReg(Reg);
632   SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
633   return finishCreatingSPIRVType(Ty, SpirvType);
634 }
635 
636 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(
637     const Type *Ty, MachineIRBuilder &MIRBuilder,
638     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
639   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
640   if (Reg.isValid())
641     return getSPIRVTypeForVReg(Reg);
642   if (ForwardPointerTypes.find(Ty) != ForwardPointerTypes.end())
643     return ForwardPointerTypes[Ty];
644   return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);
645 }
646 
647 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
648   assert(SpirvType && "Attempting to get type id for nullptr type.");
649   if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer)
650     return SpirvType->uses().begin()->getReg();
651   return SpirvType->defs().begin()->getReg();
652 }
653 
654 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(
655     const Type *Ty, MachineIRBuilder &MIRBuilder,
656     SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) {
657   if (isSpecialOpaqueType(Ty))
658     return getOrCreateSpecialType(Ty, MIRBuilder, AccQual);
659   auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
660   auto t = TypeToSPIRVTypeMap.find(Ty);
661   if (t != TypeToSPIRVTypeMap.end()) {
662     auto tt = t->second.find(&MIRBuilder.getMF());
663     if (tt != t->second.end())
664       return getSPIRVTypeForVReg(tt->second);
665   }
666 
667   if (auto IType = dyn_cast<IntegerType>(Ty)) {
668     const unsigned Width = IType->getBitWidth();
669     return Width == 1 ? getOpTypeBool(MIRBuilder)
670                       : getOpTypeInt(Width, MIRBuilder, false);
671   }
672   if (Ty->isFloatingPointTy())
673     return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
674   if (Ty->isVoidTy())
675     return getOpTypeVoid(MIRBuilder);
676   if (Ty->isVectorTy()) {
677     SPIRVType *El =
678         findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
679     return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
680                            MIRBuilder);
681   }
682   if (Ty->isArrayTy()) {
683     SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
684     return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
685   }
686   if (auto SType = dyn_cast<StructType>(Ty)) {
687     if (SType->isOpaque())
688       return getOpTypeOpaque(SType, MIRBuilder);
689     return getOpTypeStruct(SType, MIRBuilder, EmitIR);
690   }
691   if (auto FType = dyn_cast<FunctionType>(Ty)) {
692     SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
693     SmallVector<SPIRVType *, 4> ParamTypes;
694     for (const auto &t : FType->params()) {
695       ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
696     }
697     return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
698   }
699   if (auto PType = dyn_cast<PointerType>(Ty)) {
700     SPIRVType *SpvElementType;
701     // At the moment, all opaque pointers correspond to i8 element type.
702     // TODO: change the implementation once opaque pointers are supported
703     // in the SPIR-V specification.
704     if (PType->isOpaque())
705       SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
706     else
707       SpvElementType =
708           findSPIRVType(PType->getNonOpaquePointerElementType(), MIRBuilder,
709                         SPIRV::AccessQualifier::ReadWrite, EmitIR);
710     auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
711     // Null pointer means we have a loop in type definitions, make and
712     // return corresponding OpTypeForwardPointer.
713     if (SpvElementType == nullptr) {
714       if (ForwardPointerTypes.find(Ty) == ForwardPointerTypes.end())
715         ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder);
716       return ForwardPointerTypes[PType];
717     }
718     Register Reg(0);
719     // If we have forward pointer associated with this type, use its register
720     // operand to create OpTypePointer.
721     if (ForwardPointerTypes.find(PType) != ForwardPointerTypes.end())
722       Reg = getSPIRVTypeID(ForwardPointerTypes[PType]);
723 
724     return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
725   }
726   llvm_unreachable("Unable to convert LLVM type to SPIRVType");
727 }
728 
// Build the SPIR-V type for Ty through createSPIRVType and record it in the
// per-function lookup tables (VRegToTypeMap, SPIRVToLLVMType and, where
// appropriate, DT).
//
// TypesInProcessing marks types whose creation is currently in progress.
// Re-entering a non-pointer type that is already being built returns nullptr.
// The pointer branch of createSPIRVType treats a nullptr pointee as a type
// cycle and emits an OpTypeForwardPointer instead. Pointer types themselves are
// never short-circuited here.
SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  // A non-pointer type that is already being created indicates a cycle.
  if (TypesInProcessing.count(Ty) && !Ty->isPointerTy())
    return nullptr;
  TypesInProcessing.insert(Ty);
  SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
  TypesInProcessing.erase(Ty);
  // NOTE(review): SpirvType is dereferenced below without a null check; this
  // assumes createSPIRVType never returns nullptr for Ty -- confirm.
  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
  SPIRVToLLVMType[SpirvType] = Ty;
  Register Reg = DT.find(Ty, &MIRBuilder.getMF());
  // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
  // will be added later. For special types it is already added to DT.
  if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
      !isSpecialOpaqueType(Ty))
    DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));

  return SpirvType;
}
748 
749 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
750   auto t = VRegToTypeMap.find(CurMF);
751   if (t != VRegToTypeMap.end()) {
752     auto tt = t->second.find(VReg);
753     if (tt != t->second.end())
754       return tt->second;
755   }
756   return nullptr;
757 }
758 
// Return the SPIR-V type registered for Ty in the current machine function.
// If DT has no entry yet, create the type and anything it depends on. Special
// opaque types (isSpecialOpaqueType) skip the DT shortcut and always go
// through restOfCreateSPIRVType.
//
// Creation may leave OpTypeForwardPointer placeholders in ForwardPointerTypes
// for pointers that participate in a type cycle. The loop below then resolves
// each placeholder to a real pointer type and clears the table.
//
// NOTE(review): this function is also reachable re-entrantly, e.g.
// createSPIRVType calls getOrCreateSPIRVIntegerType(8, MIRBuilder) for opaque
// pointers. A nested call clears TypesInProcessing and ForwardPointerTypes,
// which the outer call may still rely on. Confirm this is intended.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) {
  Register Reg = DT.find(Ty, &MIRBuilder.getMF());
  if (Reg.isValid() && !isSpecialOpaqueType(Ty))
    return getSPIRVTypeForVReg(Reg);
  // Reset the cycle-detection set before starting creation.
  TypesInProcessing.clear();
  SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
  // Create normal pointer types for the corresponding OpTypeForwardPointers.
  // NOTE(review): restOfCreateSPIRVType can itself insert into
  // ForwardPointerTypes (createSPIRVType's pointer branch does so). If that
  // happens during this loop and the container is a DenseMap, the range-for
  // iterator is invalidated. Confirm it cannot happen, or iterate a snapshot.
  for (auto &CU : ForwardPointerTypes) {
    const Type *Ty2 = CU.first;
    // This initial value is always overwritten below.
    SPIRVType *STy2 = CU.second;
    // Reuse the real pointer type if it already exists; otherwise create it.
    // createSPIRVType's pointer branch reuses the forward pointer's register
    // as the result of the new OpTypePointer.
    if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid())
      STy2 = getSPIRVTypeForVReg(Reg);
    else
      STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
    // If the requested type itself was forward-declared, return the real
    // pointer type rather than the OpTypeForwardPointer placeholder.
    if (Ty == Ty2)
      STy = STy2;
  }
  ForwardPointerTypes.clear();
  return STy;
}
781 
782 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
783                                          unsigned TypeOpcode) const {
784   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
785   assert(Type && "isScalarOfType VReg has no type assigned");
786   return Type->getOpcode() == TypeOpcode;
787 }
788 
789 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
790                                                  unsigned TypeOpcode) const {
791   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
792   assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
793   if (Type->getOpcode() == TypeOpcode)
794     return true;
795   if (Type->getOpcode() == SPIRV::OpTypeVector) {
796     Register ScalarTypeVReg = Type->getOperand(1).getReg();
797     SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
798     return ScalarType->getOpcode() == TypeOpcode;
799   }
800   return false;
801 }
802 
803 unsigned
804 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
805   assert(Type && "Invalid Type pointer");
806   if (Type->getOpcode() == SPIRV::OpTypeVector) {
807     auto EleTypeReg = Type->getOperand(1).getReg();
808     Type = getSPIRVTypeForVReg(EleTypeReg);
809   }
810   if (Type->getOpcode() == SPIRV::OpTypeInt ||
811       Type->getOpcode() == SPIRV::OpTypeFloat)
812     return Type->getOperand(1).getImm();
813   if (Type->getOpcode() == SPIRV::OpTypeBool)
814     return 1;
815   llvm_unreachable("Attempting to get bit width of non-integer/float type.");
816 }
817 
818 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
819   assert(Type && "Invalid Type pointer");
820   if (Type->getOpcode() == SPIRV::OpTypeVector) {
821     auto EleTypeReg = Type->getOperand(1).getReg();
822     Type = getSPIRVTypeForVReg(EleTypeReg);
823   }
824   if (Type->getOpcode() == SPIRV::OpTypeInt)
825     return Type->getOperand(2).getImm() != 0;
826   llvm_unreachable("Attempting to get sign of non-integer type.");
827 }
828 
829 SPIRV::StorageClass::StorageClass
830 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
831   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
832   assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
833          Type->getOperand(1).isImm() && "Pointer type is expected");
834   return static_cast<SPIRV::StorageClass::StorageClass>(
835       Type->getOperand(1).getImm());
836 }
837 
838 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage(
839     MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim,
840     uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled,
841     SPIRV::ImageFormat::ImageFormat ImageFormat,
842     SPIRV::AccessQualifier::AccessQualifier AccessQual) {
843   SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(SampledType), Dim, Depth,
844                                 Arrayed, Multisampled, Sampled, ImageFormat,
845                                 AccessQual);
846   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
847     return Res;
848   Register ResVReg = createTypeVReg(MIRBuilder);
849   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
850   return MIRBuilder.buildInstr(SPIRV::OpTypeImage)
851       .addDef(ResVReg)
852       .addUse(getSPIRVTypeID(SampledType))
853       .addImm(Dim)
854       .addImm(Depth)        // Depth (whether or not it is a Depth image).
855       .addImm(Arrayed)      // Arrayed.
856       .addImm(Multisampled) // Multisampled (0 = only single-sample).
857       .addImm(Sampled)      // Sampled (0 = usage known at runtime).
858       .addImm(ImageFormat)
859       .addImm(AccessQual);
860 }
861 
862 SPIRVType *
863 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) {
864   SPIRV::SamplerTypeDescriptor TD;
865   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
866     return Res;
867   Register ResVReg = createTypeVReg(MIRBuilder);
868   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
869   return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg);
870 }
871 
872 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe(
873     MachineIRBuilder &MIRBuilder,
874     SPIRV::AccessQualifier::AccessQualifier AccessQual) {
875   SPIRV::PipeTypeDescriptor TD(AccessQual);
876   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
877     return Res;
878   Register ResVReg = createTypeVReg(MIRBuilder);
879   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
880   return MIRBuilder.buildInstr(SPIRV::OpTypePipe)
881       .addDef(ResVReg)
882       .addImm(AccessQual);
883 }
884 
885 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent(
886     MachineIRBuilder &MIRBuilder) {
887   SPIRV::DeviceEventTypeDescriptor TD;
888   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
889     return Res;
890   Register ResVReg = createTypeVReg(MIRBuilder);
891   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
892   return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg);
893 }
894 
895 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage(
896     SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) {
897   SPIRV::SampledImageTypeDescriptor TD(
898       SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef(
899           ImageType->getOperand(1).getReg())),
900       ImageType);
901   if (auto *Res = checkSpecialInstr(TD, MIRBuilder))
902     return Res;
903   Register ResVReg = createTypeVReg(MIRBuilder);
904   DT.add(TD, &MIRBuilder.getMF(), ResVReg);
905   return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage)
906       .addDef(ResVReg)
907       .addUse(getSPIRVTypeID(ImageType));
908 }
909 
910 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode(
911     const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) {
912   Register ResVReg = DT.find(Ty, &MIRBuilder.getMF());
913   if (ResVReg.isValid())
914     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg);
915   ResVReg = createTypeVReg(MIRBuilder);
916   DT.add(Ty, &MIRBuilder.getMF(), ResVReg);
917   return MIRBuilder.buildInstr(Opcode).addDef(ResVReg);
918 }
919 
920 const MachineInstr *
921 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD,
922                                        MachineIRBuilder &MIRBuilder) {
923   Register Reg = DT.find(TD, &MIRBuilder.getMF());
924   if (Reg.isValid())
925     return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg);
926   return nullptr;
927 }
928 
929 // TODO: maybe use tablegen to implement this.
930 SPIRVType *
931 SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName(StringRef TypeStr,
932                                                 MachineIRBuilder &MIRBuilder) {
933   unsigned VecElts = 0;
934   auto &Ctx = MIRBuilder.getMF().getFunction().getContext();
935 
936   // Parse type name in either "typeN" or "type vector[N]" format, where
937   // N is the number of elements of the vector.
938   Type *Type;
939   if (TypeStr.startswith("void")) {
940     Type = Type::getVoidTy(Ctx);
941     TypeStr = TypeStr.substr(strlen("void"));
942   } else if (TypeStr.startswith("int") || TypeStr.startswith("uint")) {
943     Type = Type::getInt32Ty(Ctx);
944     TypeStr = TypeStr.startswith("int") ? TypeStr.substr(strlen("int"))
945                                         : TypeStr.substr(strlen("uint"));
946   } else if (TypeStr.startswith("float")) {
947     Type = Type::getFloatTy(Ctx);
948     TypeStr = TypeStr.substr(strlen("float"));
949   } else if (TypeStr.startswith("half")) {
950     Type = Type::getHalfTy(Ctx);
951     TypeStr = TypeStr.substr(strlen("half"));
952   } else if (TypeStr.startswith("opencl.sampler_t")) {
953     Type = StructType::create(Ctx, "opencl.sampler_t");
954   } else
955     llvm_unreachable("Unable to recognize SPIRV type name.");
956   if (TypeStr.startswith(" vector[")) {
957     TypeStr = TypeStr.substr(strlen(" vector["));
958     TypeStr = TypeStr.substr(0, TypeStr.find(']'));
959   }
960   TypeStr.getAsInteger(10, VecElts);
961   auto SpirvTy = getOrCreateSPIRVType(Type, MIRBuilder);
962   if (VecElts > 0)
963     SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder);
964   return SpirvTy;
965 }
966 
967 SPIRVType *
968 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
969                                                  MachineIRBuilder &MIRBuilder) {
970   return getOrCreateSPIRVType(
971       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
972       MIRBuilder);
973 }
974 
975 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
976                                                         SPIRVType *SpirvType) {
977   assert(CurMF == SpirvType->getMF());
978   VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
979   SPIRVToLLVMType[SpirvType] = LLVMTy;
980   DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType));
981   return SpirvType;
982 }
983 
984 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
985     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
986   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
987   Register Reg = DT.find(LLVMTy, CurMF);
988   if (Reg.isValid())
989     return getSPIRVTypeForVReg(Reg);
990   MachineBasicBlock &BB = *I.getParent();
991   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
992                  .addDef(createTypeVReg(CurMF->getRegInfo()))
993                  .addImm(BitWidth)
994                  .addImm(0);
995   return finishCreatingSPIRVType(LLVMTy, MIB);
996 }
997 
998 SPIRVType *
999 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
1000   return getOrCreateSPIRVType(
1001       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
1002       MIRBuilder);
1003 }
1004 
1005 SPIRVType *
1006 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
1007                                               const SPIRVInstrInfo &TII) {
1008   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
1009   Register Reg = DT.find(LLVMTy, CurMF);
1010   if (Reg.isValid())
1011     return getSPIRVTypeForVReg(Reg);
1012   MachineBasicBlock &BB = *I.getParent();
1013   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool))
1014                  .addDef(createTypeVReg(CurMF->getRegInfo()));
1015   return finishCreatingSPIRVType(LLVMTy, MIB);
1016 }
1017 
1018 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1019     SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
1020   return getOrCreateSPIRVType(
1021       FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1022                            NumElements),
1023       MIRBuilder);
1024 }
1025 
1026 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
1027     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1028     const SPIRVInstrInfo &TII) {
1029   Type *LLVMTy = FixedVectorType::get(
1030       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1031   Register Reg = DT.find(LLVMTy, CurMF);
1032   if (Reg.isValid())
1033     return getSPIRVTypeForVReg(Reg);
1034   MachineBasicBlock &BB = *I.getParent();
1035   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
1036                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1037                  .addUse(getSPIRVTypeID(BaseType))
1038                  .addImm(NumElements);
1039   return finishCreatingSPIRVType(LLVMTy, MIB);
1040 }
1041 
1042 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType(
1043     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
1044     const SPIRVInstrInfo &TII) {
1045   Type *LLVMTy = ArrayType::get(
1046       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
1047   Register Reg = DT.find(LLVMTy, CurMF);
1048   if (Reg.isValid())
1049     return getSPIRVTypeForVReg(Reg);
1050   MachineBasicBlock &BB = *I.getParent();
1051   SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII);
1052   Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII);
1053   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray))
1054                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1055                  .addUse(getSPIRVTypeID(BaseType))
1056                  .addUse(Len);
1057   return finishCreatingSPIRVType(LLVMTy, MIB);
1058 }
1059 
1060 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1061     SPIRVType *BaseType, MachineIRBuilder &MIRBuilder,
1062     SPIRV::StorageClass::StorageClass SClass) {
1063   return getOrCreateSPIRVType(
1064       PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1065                        storageClassToAddressSpace(SClass)),
1066       MIRBuilder);
1067 }
1068 
1069 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
1070     SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
1071     SPIRV::StorageClass::StorageClass SC) {
1072   Type *LLVMTy =
1073       PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
1074                        storageClassToAddressSpace(SC));
1075   Register Reg = DT.find(LLVMTy, CurMF);
1076   if (Reg.isValid())
1077     return getSPIRVTypeForVReg(Reg);
1078   MachineBasicBlock &BB = *I.getParent();
1079   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
1080                  .addDef(createTypeVReg(CurMF->getRegInfo()))
1081                  .addImm(static_cast<uint32_t>(SC))
1082                  .addUse(getSPIRVTypeID(BaseType));
1083   return finishCreatingSPIRVType(LLVMTy, MIB);
1084 }
1085 
1086 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
1087                                                SPIRVType *SpvType,
1088                                                const SPIRVInstrInfo &TII) {
1089   assert(SpvType);
1090   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
1091   assert(LLVMTy);
1092   // Find a constant in DT or build a new one.
1093   UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
1094   Register Res = DT.find(UV, CurMF);
1095   if (Res.isValid())
1096     return Res;
1097   LLT LLTy = LLT::scalar(32);
1098   Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
1099   CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass);
1100   assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
1101   DT.add(UV, CurMF, Res);
1102 
1103   MachineInstrBuilder MIB;
1104   MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
1105             .addDef(Res)
1106             .addUse(getSPIRVTypeID(SpvType));
1107   const auto &ST = CurMF->getSubtarget();
1108   constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
1109                                    *ST.getRegisterInfo(), *ST.getRegBankInfo());
1110   return Res;
1111 }
1112