1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
// an OpTypeXXX instruction, and map it to a virtual register. It also builds
// constants and global variables and keeps them consistent.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRV.h"
19 #include "SPIRVSubtarget.h"
20 #include "SPIRVTargetMachine.h"
21 #include "SPIRVUtils.h"
22 
23 using namespace llvm;
24 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
25     : PointerSize(PointerSize) {}
26 
27 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth,
28                                                     Register VReg,
29                                                     MachineInstr &I,
30                                                     const SPIRVInstrInfo &TII) {
31   SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
32   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
33   return SpirvType;
34 }
35 
36 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg(
37     SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I,
38     const SPIRVInstrInfo &TII) {
39   SPIRVType *SpirvType =
40       getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII);
41   assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF);
42   return SpirvType;
43 }
44 
45 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
46     const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
47     SPIRV::AccessQualifier AccessQual, bool EmitIR) {
48 
49   SPIRVType *SpirvType =
50       getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
51   assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
52   return SpirvType;
53 }
54 
55 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
56                                                 Register VReg,
57                                                 MachineFunction &MF) {
58   VRegToTypeMap[&MF][VReg] = SpirvType;
59 }
60 
61 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
62   auto &MRI = MIRBuilder.getMF().getRegInfo();
63   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
64   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
65   return Res;
66 }
67 
68 static Register createTypeVReg(MachineRegisterInfo &MRI) {
69   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
70   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
71   return Res;
72 }
73 
74 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
75   return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
76       .addDef(createTypeVReg(MIRBuilder));
77 }
78 
79 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
80                                              MachineIRBuilder &MIRBuilder,
81                                              bool IsSigned) {
82   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
83                  .addDef(createTypeVReg(MIRBuilder))
84                  .addImm(Width)
85                  .addImm(IsSigned ? 1 : 0);
86   return MIB;
87 }
88 
89 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
90                                                MachineIRBuilder &MIRBuilder) {
91   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
92                  .addDef(createTypeVReg(MIRBuilder))
93                  .addImm(Width);
94   return MIB;
95 }
96 
97 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
98   return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
99       .addDef(createTypeVReg(MIRBuilder));
100 }
101 
102 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
103                                                 SPIRVType *ElemType,
104                                                 MachineIRBuilder &MIRBuilder) {
105   auto EleOpc = ElemType->getOpcode();
106   assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
107           EleOpc == SPIRV::OpTypeBool) &&
108          "Invalid vector element type");
109 
110   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
111                  .addDef(createTypeVReg(MIRBuilder))
112                  .addUse(getSPIRVTypeID(ElemType))
113                  .addImm(NumElems);
114   return MIB;
115 }
116 
117 std::tuple<Register, ConstantInt *, bool>
118 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType,
119                                             MachineIRBuilder *MIRBuilder,
120                                             MachineInstr *I,
121                                             const SPIRVInstrInfo *TII) {
122   const IntegerType *LLVMIntTy;
123   if (SpvType)
124     LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
125   else
126     LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext());
127   bool NewInstr = false;
128   // Find a constant in DT or build a new one.
129   ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
130   Register Res = DT.find(CI, CurMF);
131   if (!Res.isValid()) {
132     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
133     LLT LLTy = LLT::scalar(32);
134     Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
135     if (MIRBuilder)
136       assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder);
137     else
138       assignIntTypeToVReg(BitWidth, Res, *I, *TII);
139     DT.add(CI, CurMF, Res);
140     NewInstr = true;
141   }
142   return std::make_tuple(Res, CI, NewInstr);
143 }
144 
145 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I,
146                                                   SPIRVType *SpvType,
147                                                   const SPIRVInstrInfo &TII) {
148   assert(SpvType);
149   ConstantInt *CI;
150   Register Res;
151   bool New;
152   std::tie(Res, CI, New) =
153       getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII);
154   // If we have found Res register which is defined by the passed G_CONSTANT
155   // machine instruction, a new constant instruction should be created.
156   if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg()))
157     return Res;
158   MachineInstrBuilder MIB;
159   MachineBasicBlock &BB = *I.getParent();
160   if (Val) {
161     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI))
162               .addDef(Res)
163               .addUse(getSPIRVTypeID(SpvType));
164     addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB);
165   } else {
166     MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
167               .addDef(Res)
168               .addUse(getSPIRVTypeID(SpvType));
169   }
170   const auto &ST = CurMF->getSubtarget();
171   constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
172                                    *ST.getRegisterInfo(), *ST.getRegBankInfo());
173   return Res;
174 }
175 
176 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
177                                                MachineIRBuilder &MIRBuilder,
178                                                SPIRVType *SpvType,
179                                                bool EmitIR) {
180   auto &MF = MIRBuilder.getMF();
181   const IntegerType *LLVMIntTy;
182   if (SpvType)
183     LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
184   else
185     LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
186   // Find a constant in DT or build a new one.
187   const auto ConstInt =
188       ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
189   Register Res = DT.find(ConstInt, &MF);
190   if (!Res.isValid()) {
191     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
192     LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32);
193     Res = MF.getRegInfo().createGenericVirtualRegister(LLTy);
194     assignTypeToVReg(LLVMIntTy, Res, MIRBuilder,
195                      SPIRV::AccessQualifier::ReadWrite, EmitIR);
196     DT.add(ConstInt, &MIRBuilder.getMF(), Res);
197     if (EmitIR) {
198       MIRBuilder.buildConstant(Res, *ConstInt);
199     } else {
200       MachineInstrBuilder MIB;
201       if (Val) {
202         assert(SpvType);
203         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI)
204                   .addDef(Res)
205                   .addUse(getSPIRVTypeID(SpvType));
206         addNumImm(APInt(BitWidth, Val), MIB);
207       } else {
208         assert(SpvType);
209         MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull)
210                   .addDef(Res)
211                   .addUse(getSPIRVTypeID(SpvType));
212       }
213       const auto &Subtarget = CurMF->getSubtarget();
214       constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
215                                        *Subtarget.getRegisterInfo(),
216                                        *Subtarget.getRegBankInfo());
217     }
218   }
219   return Res;
220 }
221 
222 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
223                                               MachineIRBuilder &MIRBuilder,
224                                               SPIRVType *SpvType) {
225   auto &MF = MIRBuilder.getMF();
226   const Type *LLVMFPTy;
227   if (SpvType) {
228     LLVMFPTy = getTypeForSPIRVType(SpvType);
229     assert(LLVMFPTy->isFloatingPointTy());
230   } else {
231     LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext());
232   }
233   // Find a constant in DT or build a new one.
234   const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val);
235   Register Res = DT.find(ConstFP, &MF);
236   if (!Res.isValid()) {
237     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
238     Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
239     assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
240     DT.add(ConstFP, &MF, Res);
241     MIRBuilder.buildFConstant(Res, *ConstFP);
242   }
243   return Res;
244 }
245 
246 Register
247 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I,
248                                               SPIRVType *SpvType,
249                                               const SPIRVInstrInfo &TII) {
250   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
251   assert(LLVMTy->isVectorTy());
252   const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy);
253   Type *LLVMBaseTy = LLVMVecTy->getElementType();
254   // Find a constant vector in DT or build a new one.
255   const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val);
256   auto ConstVec =
257       ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt);
258   Register Res = DT.find(ConstVec, CurMF);
259   if (!Res.isValid()) {
260     unsigned BitWidth = getScalarOrVectorBitWidth(SpvType);
261     SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII);
262     // SpvScalConst should be created before SpvVecConst to avoid undefined ID
263     // error on validation.
264     // TODO: can moved below once sorting of types/consts/defs is implemented.
265     Register SpvScalConst;
266     if (Val)
267       SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII);
268     // TODO: maybe use bitwidth of base type.
269     LLT LLTy = LLT::scalar(32);
270     Register SpvVecConst =
271         CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
272     const unsigned ElemCnt = SpvType->getOperand(2).getImm();
273     assignVectTypeToVReg(SpvBaseType, ElemCnt, SpvVecConst, I, TII);
274     DT.add(ConstVec, CurMF, SpvVecConst);
275     MachineInstrBuilder MIB;
276     MachineBasicBlock &BB = *I.getParent();
277     if (Val) {
278       MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite))
279                 .addDef(SpvVecConst)
280                 .addUse(getSPIRVTypeID(SpvType));
281       for (unsigned i = 0; i < ElemCnt; ++i)
282         MIB.addUse(SpvScalConst);
283     } else {
284       MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull))
285                 .addDef(SpvVecConst)
286                 .addUse(getSPIRVTypeID(SpvType));
287     }
288     const auto &Subtarget = CurMF->getSubtarget();
289     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
290                                      *Subtarget.getRegisterInfo(),
291                                      *Subtarget.getRegBankInfo());
292     return SpvVecConst;
293   }
294   return Res;
295 }
296 
297 Register SPIRVGlobalRegistry::buildGlobalVariable(
298     Register ResVReg, SPIRVType *BaseType, StringRef Name,
299     const GlobalValue *GV, SPIRV::StorageClass Storage,
300     const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
301     SPIRV::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
302     bool IsInstSelector) {
303   const GlobalVariable *GVar = nullptr;
304   if (GV)
305     GVar = cast<const GlobalVariable>(GV);
306   else {
307     // If GV is not passed explicitly, use the name to find or construct
308     // the global variable.
309     Module *M = MIRBuilder.getMF().getFunction().getParent();
310     GVar = M->getGlobalVariable(Name);
311     if (GVar == nullptr) {
312       const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
313       GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
314                                 GlobalValue::ExternalLinkage, nullptr,
315                                 Twine(Name));
316     }
317     GV = GVar;
318   }
319   Register Reg = DT.find(GVar, &MIRBuilder.getMF());
320   if (Reg.isValid()) {
321     if (Reg != ResVReg)
322       MIRBuilder.buildCopy(ResVReg, Reg);
323     return ResVReg;
324   }
325 
326   auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
327                  .addDef(ResVReg)
328                  .addUse(getSPIRVTypeID(BaseType))
329                  .addImm(static_cast<uint32_t>(Storage));
330 
331   if (Init != 0) {
332     MIB.addUse(Init->getOperand(0).getReg());
333   }
334 
335   // ISel may introduce a new register on this step, so we need to add it to
336   // DT and correct its type avoiding fails on the next stage.
337   if (IsInstSelector) {
338     const auto &Subtarget = CurMF->getSubtarget();
339     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
340                                      *Subtarget.getRegisterInfo(),
341                                      *Subtarget.getRegBankInfo());
342   }
343   Reg = MIB->getOperand(0).getReg();
344   DT.add(GVar, &MIRBuilder.getMF(), Reg);
345 
346   // Set to Reg the same type as ResVReg has.
347   auto MRI = MIRBuilder.getMRI();
348   assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
349   if (Reg != ResVReg) {
350     LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
351     MRI->setType(Reg, RegLLTy);
352     assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
353   }
354 
355   // If it's a global variable with name, output OpName for it.
356   if (GVar && GVar->hasName())
357     buildOpName(Reg, GVar->getName(), MIRBuilder);
358 
359   // Output decorations for the GV.
360   // TODO: maybe move to GenerateDecorations pass.
361   if (IsConst)
362     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
363 
364   if (GVar && GVar->getAlign().valueOrOne().value() != 1)
365     buildOpDecorate(
366         Reg, MIRBuilder, SPIRV::Decoration::Alignment,
367         {static_cast<uint32_t>(GVar->getAlign().valueOrOne().value())});
368 
369   if (HasLinkageTy)
370     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
371                     {static_cast<uint32_t>(LinkageType)}, Name);
372   return Reg;
373 }
374 
375 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
376                                                SPIRVType *ElemType,
377                                                MachineIRBuilder &MIRBuilder,
378                                                bool EmitIR) {
379   assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
380          "Invalid array element type");
381   Register NumElementsVReg =
382       buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
383   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
384                  .addDef(createTypeVReg(MIRBuilder))
385                  .addUse(getSPIRVTypeID(ElemType))
386                  .addUse(NumElementsVReg);
387   return MIB;
388 }
389 
390 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty,
391                                                 MachineIRBuilder &MIRBuilder) {
392   assert(Ty->hasName());
393   const StringRef Name = Ty->hasName() ? Ty->getName() : "";
394   Register ResVReg = createTypeVReg(MIRBuilder);
395   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg);
396   addStringImm(Name, MIB);
397   buildOpName(ResVReg, Name, MIRBuilder);
398   return MIB;
399 }
400 
401 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty,
402                                                 MachineIRBuilder &MIRBuilder,
403                                                 bool EmitIR) {
404   SmallVector<Register, 4> FieldTypes;
405   for (const auto &Elem : Ty->elements()) {
406     SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder);
407     assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid &&
408            "Invalid struct element type");
409     FieldTypes.push_back(getSPIRVTypeID(ElemTy));
410   }
411   Register ResVReg = createTypeVReg(MIRBuilder);
412   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg);
413   for (const auto &Ty : FieldTypes)
414     MIB.addUse(Ty);
415   if (Ty->hasName())
416     buildOpName(ResVReg, Ty->getName(), MIRBuilder);
417   if (Ty->isPacked())
418     buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {});
419   return MIB;
420 }
421 
422 static bool isOpenCLBuiltinType(const StructType *SType) {
423   return SType->isOpaque() && SType->hasName() &&
424          SType->getName().startswith("opencl.");
425 }
426 
427 static bool isSPIRVBuiltinType(const StructType *SType) {
428   return SType->isOpaque() && SType->hasName() &&
429          SType->getName().startswith("spirv.");
430 }
431 
432 static bool isSpecialType(const Type *Ty) {
433   if (auto PType = dyn_cast<PointerType>(Ty)) {
434     if (!PType->isOpaque())
435       Ty = PType->getNonOpaquePointerElementType();
436   }
437   if (auto SType = dyn_cast<StructType>(Ty))
438     return isOpenCLBuiltinType(SType) || isSPIRVBuiltinType(SType);
439   return false;
440 }
441 
442 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(SPIRV::StorageClass SC,
443                                                  SPIRVType *ElemType,
444                                                  MachineIRBuilder &MIRBuilder,
445                                                  Register Reg) {
446   if (!Reg.isValid())
447     Reg = createTypeVReg(MIRBuilder);
448   return MIRBuilder.buildInstr(SPIRV::OpTypePointer)
449       .addDef(Reg)
450       .addImm(static_cast<uint32_t>(SC))
451       .addUse(getSPIRVTypeID(ElemType));
452 }
453 
454 SPIRVType *
455 SPIRVGlobalRegistry::getOpTypeForwardPointer(SPIRV::StorageClass SC,
456                                              MachineIRBuilder &MIRBuilder) {
457   return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer)
458       .addUse(createTypeVReg(MIRBuilder))
459       .addImm(static_cast<uint32_t>(SC));
460 }
461 
462 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
463     SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
464     MachineIRBuilder &MIRBuilder) {
465   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
466                  .addDef(createTypeVReg(MIRBuilder))
467                  .addUse(getSPIRVTypeID(RetType));
468   for (const SPIRVType *ArgType : ArgTypes)
469     MIB.addUse(getSPIRVTypeID(ArgType));
470   return MIB;
471 }
472 
473 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs(
474     const Type *Ty, SPIRVType *RetType,
475     const SmallVectorImpl<SPIRVType *> &ArgTypes,
476     MachineIRBuilder &MIRBuilder) {
477   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
478   if (Reg.isValid())
479     return getSPIRVTypeForVReg(Reg);
480   SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder);
481   return finishCreatingSPIRVType(Ty, SpirvType);
482 }
483 
484 SPIRVType *SPIRVGlobalRegistry::findSPIRVType(const Type *Ty,
485                                               MachineIRBuilder &MIRBuilder,
486                                               SPIRV::AccessQualifier AccQual,
487                                               bool EmitIR) {
488   Register Reg = DT.find(Ty, &MIRBuilder.getMF());
489   if (Reg.isValid())
490     return getSPIRVTypeForVReg(Reg);
491   if (ForwardPointerTypes.find(Ty) != ForwardPointerTypes.end())
492     return ForwardPointerTypes[Ty];
493   return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR);
494 }
495 
496 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const {
497   assert(SpirvType && "Attempting to get type id for nullptr type.");
498   if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer)
499     return SpirvType->uses().begin()->getReg();
500   return SpirvType->defs().begin()->getReg();
501 }
502 
/// Build the SPIR-V type instruction for LLVM type \p Ty in the builder's
/// function. Element, member and parameter types are resolved recursively
/// through findSPIRVType.
///
/// If DT's type map already holds \p Ty for this function, that type is
/// returned. For a typed pointer whose pointee lookup yields nullptr (a loop in
/// the type definitions), an OpTypeForwardPointer is created (once) and
/// returned instead of an OpTypePointer.
///
/// NOTE(review): AccQual is not used in this function body.
SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty,
                                                MachineIRBuilder &MIRBuilder,
                                                SPIRV::AccessQualifier AccQual,
                                                bool EmitIR) {
  // Builtin (OpenCL/SPIR-V) struct types are not expected here.
  assert(!isSpecialType(Ty));
  // Reuse a type already recorded for (Ty, this function).
  auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses();
  auto t = TypeToSPIRVTypeMap.find(Ty);
  if (t != TypeToSPIRVTypeMap.end()) {
    auto tt = t->second.find(&MIRBuilder.getMF());
    if (tt != t->second.end())
      return getSPIRVTypeForVReg(tt->second);
  }

  // i1 maps to OpTypeBool; other integers become unsigned OpTypeInt.
  if (auto IType = dyn_cast<IntegerType>(Ty)) {
    const unsigned Width = IType->getBitWidth();
    return Width == 1 ? getOpTypeBool(MIRBuilder)
                      : getOpTypeInt(Width, MIRBuilder, false);
  }
  if (Ty->isFloatingPointTy())
    return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
  if (Ty->isVoidTy())
    return getOpTypeVoid(MIRBuilder);
  if (Ty->isVectorTy()) {
    SPIRVType *El =
        findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder);
    return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
                           MIRBuilder);
  }
  if (Ty->isArrayTy()) {
    SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder);
    return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
  }
  if (auto SType = dyn_cast<StructType>(Ty)) {
    if (SType->isOpaque())
      return getOpTypeOpaque(SType, MIRBuilder);
    return getOpTypeStruct(SType, MIRBuilder, EmitIR);
  }
  if (auto FType = dyn_cast<FunctionType>(Ty)) {
    SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder);
    SmallVector<SPIRVType *, 4> ParamTypes;
    for (const auto &t : FType->params()) {
      ParamTypes.push_back(findSPIRVType(t, MIRBuilder));
    }
    return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
  }
  if (auto PType = dyn_cast<PointerType>(Ty)) {
    SPIRVType *SpvElementType;
    // At the moment, all opaque pointers correspond to i8 element type.
    // TODO: change the implementation once opaque pointers are supported
    // in the SPIR-V specification.
    // NOTE(review): getOrCreateSPIRVIntegerType goes through the top-level
    // getOrCreateSPIRVType, which clears TypesInProcessing and
    // ForwardPointerTypes while this recursion may still be in progress.
    if (PType->isOpaque())
      SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
    else
      SpvElementType =
          findSPIRVType(PType->getNonOpaquePointerElementType(), MIRBuilder,
                        SPIRV::AccessQualifier::ReadWrite, EmitIR);
    auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
    // Null pointer means we have a loop in type definitions, make and
    // return corresponding OpTypeForwardPointer.
    if (SpvElementType == nullptr) {
      if (ForwardPointerTypes.find(Ty) == ForwardPointerTypes.end())
        ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder);
      return ForwardPointerTypes[PType];
    }
    Register Reg(0);
    // If we have forward pointer associated with this type, use its register
    // operand to create OpTypePointer.
    if (ForwardPointerTypes.find(PType) != ForwardPointerTypes.end())
      Reg = getSPIRVTypeID(ForwardPointerTypes[PType]);

    return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg);
  }
  llvm_unreachable("Unable to convert LLVM type to SPIRVType");
}
577 
/// Create the SPIR-V type for \p Ty via createSPIRVType and record it in
/// VRegToTypeMap and SPIRVToLLVMType, and in DT unless it is a forward pointer
/// or a special type.
///
/// Returns nullptr when a non-pointer \p Ty is already being created higher
/// up in the recursion. createSPIRVType treats that as a loop in the type
/// definitions and emits an OpTypeForwardPointer for the pointer involved.
SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier AccessQual, bool EmitIR) {
  // Pointer types may be re-entered: the inner visit is where the cycle is
  // detected (via the nullptr below) and the forward pointer is produced.
  if (TypesInProcessing.count(Ty) && !Ty->isPointerTy())
    return nullptr;
  TypesInProcessing.insert(Ty);
  SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
  // NOTE(review): for a re-entered pointer type this also drops the outer
  // visit's entry from the set; confirm that is intended.
  TypesInProcessing.erase(Ty);
  VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType;
  SPIRVToLLVMType[SpirvType] = Ty;
  Register Reg = DT.find(Ty, &MIRBuilder.getMF());
  // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type
  // will be added later. For special types it is already added to DT.
  if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() &&
      !isSpecialType(Ty))
    DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType));

  return SpirvType;
}
597 
598 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
599   auto t = VRegToTypeMap.find(CurMF);
600   if (t != VRegToTypeMap.end()) {
601     auto tt = t->second.find(VReg);
602     if (tt != t->second.end())
603       return tt->second;
604   }
605   return nullptr;
606 }
607 
/// Return the SPIR-V type for \p Ty in the builder's function, creating it
/// (and any types it depends on) when it is not yet in DT.
///
/// This is the top-level entry point of type creation. It resets
/// TypesInProcessing, builds the requested type, and then, for every
/// OpTypeForwardPointer created along the way, makes sure the corresponding
/// pointer type exists. Finally it clears ForwardPointerTypes. If \p Ty itself
/// was forward-declared, the completed pointer type is returned.
SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
    const Type *Ty, MachineIRBuilder &MIRBuilder,
    SPIRV::AccessQualifier AccessQual, bool EmitIR) {
  Register Reg = DT.find(Ty, &MIRBuilder.getMF());
  if (Reg.isValid())
    return getSPIRVTypeForVReg(Reg);
  // Start cycle detection from a clean state for this request.
  // NOTE(review): createSPIRVType can re-enter this function (opaque-pointer
  // case via getOrCreateSPIRVIntegerType), which clears this set and
  // ForwardPointerTypes mid-recursion; confirm that is harmless.
  TypesInProcessing.clear();
  SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR);
  // Create normal pointer types for the corresponding OpTypeForwardPointers.
  // NOTE(review): restOfCreateSPIRVType could in principle insert into
  // ForwardPointerTypes while it is being iterated; confirm it cannot here.
  for (auto &CU : ForwardPointerTypes) {
    const Type *Ty2 = CU.first;
    SPIRVType *STy2 = CU.second;
    // Reuse the pointer type if it is registered already; otherwise build it
    // now (createSPIRVType reuses the forward pointer's register as its id).
    if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid())
      STy2 = getSPIRVTypeForVReg(Reg);
    else
      STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR);
    if (Ty == Ty2)
      STy = STy2;
  }
  ForwardPointerTypes.clear();
  return STy;
}
630 
631 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
632                                          unsigned TypeOpcode) const {
633   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
634   assert(Type && "isScalarOfType VReg has no type assigned");
635   return Type->getOpcode() == TypeOpcode;
636 }
637 
638 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
639                                                  unsigned TypeOpcode) const {
640   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
641   assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
642   if (Type->getOpcode() == TypeOpcode)
643     return true;
644   if (Type->getOpcode() == SPIRV::OpTypeVector) {
645     Register ScalarTypeVReg = Type->getOperand(1).getReg();
646     SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
647     return ScalarType->getOpcode() == TypeOpcode;
648   }
649   return false;
650 }
651 
652 unsigned
653 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
654   assert(Type && "Invalid Type pointer");
655   if (Type->getOpcode() == SPIRV::OpTypeVector) {
656     auto EleTypeReg = Type->getOperand(1).getReg();
657     Type = getSPIRVTypeForVReg(EleTypeReg);
658   }
659   if (Type->getOpcode() == SPIRV::OpTypeInt ||
660       Type->getOpcode() == SPIRV::OpTypeFloat)
661     return Type->getOperand(1).getImm();
662   if (Type->getOpcode() == SPIRV::OpTypeBool)
663     return 1;
664   llvm_unreachable("Attempting to get bit width of non-integer/float type.");
665 }
666 
667 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
668   assert(Type && "Invalid Type pointer");
669   if (Type->getOpcode() == SPIRV::OpTypeVector) {
670     auto EleTypeReg = Type->getOperand(1).getReg();
671     Type = getSPIRVTypeForVReg(EleTypeReg);
672   }
673   if (Type->getOpcode() == SPIRV::OpTypeInt)
674     return Type->getOperand(2).getImm() != 0;
675   llvm_unreachable("Attempting to get sign of non-integer type.");
676 }
677 
678 SPIRV::StorageClass
679 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
680   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
681   assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
682          Type->getOperand(1).isImm() && "Pointer type is expected");
683   return static_cast<SPIRV::StorageClass>(Type->getOperand(1).getImm());
684 }
685 
686 SPIRVType *
687 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
688                                                  MachineIRBuilder &MIRBuilder) {
689   return getOrCreateSPIRVType(
690       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
691       MIRBuilder);
692 }
693 
694 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy,
695                                                         SPIRVType *SpirvType) {
696   assert(CurMF == SpirvType->getMF());
697   VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
698   SPIRVToLLVMType[SpirvType] = LLVMTy;
699   DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType));
700   return SpirvType;
701 }
702 
703 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
704     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
705   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
706   Register Reg = DT.find(LLVMTy, CurMF);
707   if (Reg.isValid())
708     return getSPIRVTypeForVReg(Reg);
709   MachineBasicBlock &BB = *I.getParent();
710   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
711                  .addDef(createTypeVReg(CurMF->getRegInfo()))
712                  .addImm(BitWidth)
713                  .addImm(0);
714   return finishCreatingSPIRVType(LLVMTy, MIB);
715 }
716 
717 SPIRVType *
718 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
719   return getOrCreateSPIRVType(
720       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
721       MIRBuilder);
722 }
723 
724 SPIRVType *
725 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I,
726                                               const SPIRVInstrInfo &TII) {
727   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1);
728   Register Reg = DT.find(LLVMTy, CurMF);
729   if (Reg.isValid())
730     return getSPIRVTypeForVReg(Reg);
731   MachineBasicBlock &BB = *I.getParent();
732   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool))
733                  .addDef(createTypeVReg(CurMF->getRegInfo()));
734   return finishCreatingSPIRVType(LLVMTy, MIB);
735 }
736 
737 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
738     SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
739   return getOrCreateSPIRVType(
740       FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
741                            NumElements),
742       MIRBuilder);
743 }
744 
745 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
746     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
747     const SPIRVInstrInfo &TII) {
748   Type *LLVMTy = FixedVectorType::get(
749       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
750   Register Reg = DT.find(LLVMTy, CurMF);
751   if (Reg.isValid())
752     return getSPIRVTypeForVReg(Reg);
753   MachineBasicBlock &BB = *I.getParent();
754   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
755                  .addDef(createTypeVReg(CurMF->getRegInfo()))
756                  .addUse(getSPIRVTypeID(BaseType))
757                  .addImm(NumElements);
758   return finishCreatingSPIRVType(LLVMTy, MIB);
759 }
760 
761 SPIRVType *
762 SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(SPIRVType *BaseType,
763                                                  MachineIRBuilder &MIRBuilder,
764                                                  SPIRV::StorageClass SClass) {
765   return getOrCreateSPIRVType(
766       PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
767                        storageClassToAddressSpace(SClass)),
768       MIRBuilder);
769 }
770 
771 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
772     SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
773     SPIRV::StorageClass SC) {
774   Type *LLVMTy =
775       PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
776                        storageClassToAddressSpace(SC));
777   Register Reg = DT.find(LLVMTy, CurMF);
778   if (Reg.isValid())
779     return getSPIRVTypeForVReg(Reg);
780   MachineBasicBlock &BB = *I.getParent();
781   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
782                  .addDef(createTypeVReg(CurMF->getRegInfo()))
783                  .addImm(static_cast<uint32_t>(SC))
784                  .addUse(getSPIRVTypeID(BaseType));
785   return finishCreatingSPIRVType(LLVMTy, MIB);
786 }
787 
788 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I,
789                                                SPIRVType *SpvType,
790                                                const SPIRVInstrInfo &TII) {
791   assert(SpvType);
792   const Type *LLVMTy = getTypeForSPIRVType(SpvType);
793   assert(LLVMTy);
794   // Find a constant in DT or build a new one.
795   UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy));
796   Register Res = DT.find(UV, CurMF);
797   if (Res.isValid())
798     return Res;
799   LLT LLTy = LLT::scalar(32);
800   Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy);
801   assignSPIRVTypeToVReg(SpvType, Res, *CurMF);
802   DT.add(UV, CurMF, Res);
803 
804   MachineInstrBuilder MIB;
805   MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef))
806             .addDef(Res)
807             .addUse(getSPIRVTypeID(SpvType));
808   const auto &ST = CurMF->getSubtarget();
809   constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(),
810                                    *ST.getRegisterInfo(), *ST.getRegBankInfo());
811   return Res;
812 }
813