1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file contains the implementation of the SPIRVGlobalRegistry class,
10 // which is used to maintain rich type information required for SPIR-V even
11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into
12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds
13 // and supports consistency of constants and global variables.
14 //
15 //===----------------------------------------------------------------------===//
16 
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRV.h"
19 #include "SPIRVSubtarget.h"
20 #include "SPIRVTargetMachine.h"
21 #include "SPIRVUtils.h"
22 
23 using namespace llvm;
24 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize)
25     : PointerSize(PointerSize) {}
26 
27 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg(
28     const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder,
29     SPIRV::AccessQualifier AccessQual, bool EmitIR) {
30 
31   SPIRVType *SpirvType =
32       getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
33   assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF());
34   return SpirvType;
35 }
36 
37 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType,
38                                                 Register VReg,
39                                                 MachineFunction &MF) {
40   VRegToTypeMap[&MF][VReg] = SpirvType;
41 }
42 
43 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) {
44   auto &MRI = MIRBuilder.getMF().getRegInfo();
45   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
46   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
47   return Res;
48 }
49 
50 static Register createTypeVReg(MachineRegisterInfo &MRI) {
51   auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32));
52   MRI.setRegClass(Res, &SPIRV::TYPERegClass);
53   return Res;
54 }
55 
56 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) {
57   return MIRBuilder.buildInstr(SPIRV::OpTypeBool)
58       .addDef(createTypeVReg(MIRBuilder));
59 }
60 
61 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width,
62                                              MachineIRBuilder &MIRBuilder,
63                                              bool IsSigned) {
64   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt)
65                  .addDef(createTypeVReg(MIRBuilder))
66                  .addImm(Width)
67                  .addImm(IsSigned ? 1 : 0);
68   return MIB;
69 }
70 
71 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width,
72                                                MachineIRBuilder &MIRBuilder) {
73   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat)
74                  .addDef(createTypeVReg(MIRBuilder))
75                  .addImm(Width);
76   return MIB;
77 }
78 
79 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) {
80   return MIRBuilder.buildInstr(SPIRV::OpTypeVoid)
81       .addDef(createTypeVReg(MIRBuilder));
82 }
83 
84 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems,
85                                                 SPIRVType *ElemType,
86                                                 MachineIRBuilder &MIRBuilder) {
87   auto EleOpc = ElemType->getOpcode();
88   assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat ||
89           EleOpc == SPIRV::OpTypeBool) &&
90          "Invalid vector element type");
91 
92   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector)
93                  .addDef(createTypeVReg(MIRBuilder))
94                  .addUse(getSPIRVTypeID(ElemType))
95                  .addImm(NumElems);
96   return MIB;
97 }
98 
99 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val,
100                                                MachineIRBuilder &MIRBuilder,
101                                                SPIRVType *SpvType,
102                                                bool EmitIR) {
103   auto &MF = MIRBuilder.getMF();
104   const IntegerType *LLVMIntTy;
105   if (SpvType)
106     LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType));
107   else
108     LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext());
109   // Find a constant in DT or build a new one.
110   const auto ConstInt =
111       ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val);
112   Register Res = DT.find(ConstInt, &MF);
113   if (!Res.isValid()) {
114     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
115     Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
116     assignTypeToVReg(LLVMIntTy, Res, MIRBuilder);
117     if (EmitIR)
118       MIRBuilder.buildConstant(Res, *ConstInt);
119     else
120       MIRBuilder.buildInstr(SPIRV::OpConstantI)
121           .addDef(Res)
122           .addImm(ConstInt->getSExtValue());
123   }
124   return Res;
125 }
126 
127 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val,
128                                               MachineIRBuilder &MIRBuilder,
129                                               SPIRVType *SpvType) {
130   auto &MF = MIRBuilder.getMF();
131   const Type *LLVMFPTy;
132   if (SpvType) {
133     LLVMFPTy = getTypeForSPIRVType(SpvType);
134     assert(LLVMFPTy->isFloatingPointTy());
135   } else {
136     LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext());
137   }
138   // Find a constant in DT or build a new one.
139   const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val);
140   Register Res = DT.find(ConstFP, &MF);
141   if (!Res.isValid()) {
142     unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32;
143     Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth));
144     assignTypeToVReg(LLVMFPTy, Res, MIRBuilder);
145     MIRBuilder.buildFConstant(Res, *ConstFP);
146   }
147   return Res;
148 }
149 
150 Register SPIRVGlobalRegistry::buildGlobalVariable(
151     Register ResVReg, SPIRVType *BaseType, StringRef Name,
152     const GlobalValue *GV, SPIRV::StorageClass Storage,
153     const MachineInstr *Init, bool IsConst, bool HasLinkageTy,
154     SPIRV::LinkageType LinkageType, MachineIRBuilder &MIRBuilder,
155     bool IsInstSelector) {
156   const GlobalVariable *GVar = nullptr;
157   if (GV)
158     GVar = cast<const GlobalVariable>(GV);
159   else {
160     // If GV is not passed explicitly, use the name to find or construct
161     // the global variable.
162     Module *M = MIRBuilder.getMF().getFunction().getParent();
163     GVar = M->getGlobalVariable(Name);
164     if (GVar == nullptr) {
165       const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type.
166       GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false,
167                                 GlobalValue::ExternalLinkage, nullptr,
168                                 Twine(Name));
169     }
170     GV = GVar;
171   }
172   Register Reg;
173   auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable)
174                  .addDef(ResVReg)
175                  .addUse(getSPIRVTypeID(BaseType))
176                  .addImm(static_cast<uint32_t>(Storage));
177 
178   if (Init != 0) {
179     MIB.addUse(Init->getOperand(0).getReg());
180   }
181 
182   // ISel may introduce a new register on this step, so we need to add it to
183   // DT and correct its type avoiding fails on the next stage.
184   if (IsInstSelector) {
185     const auto &Subtarget = CurMF->getSubtarget();
186     constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(),
187                                      *Subtarget.getRegisterInfo(),
188                                      *Subtarget.getRegBankInfo());
189   }
190   Reg = MIB->getOperand(0).getReg();
191   DT.add(GVar, &MIRBuilder.getMF(), Reg);
192 
193   // Set to Reg the same type as ResVReg has.
194   auto MRI = MIRBuilder.getMRI();
195   assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected");
196   if (Reg != ResVReg) {
197     LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32);
198     MRI->setType(Reg, RegLLTy);
199     assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF());
200   }
201 
202   // If it's a global variable with name, output OpName for it.
203   if (GVar && GVar->hasName())
204     buildOpName(Reg, GVar->getName(), MIRBuilder);
205 
206   // Output decorations for the GV.
207   // TODO: maybe move to GenerateDecorations pass.
208   if (IsConst)
209     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {});
210 
211   if (GVar && GVar->getAlign().valueOrOne().value() != 1)
212     buildOpDecorate(
213         Reg, MIRBuilder, SPIRV::Decoration::Alignment,
214         {static_cast<uint32_t>(GVar->getAlign().valueOrOne().value())});
215 
216   if (HasLinkageTy)
217     buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
218                     {static_cast<uint32_t>(LinkageType)}, Name);
219   return Reg;
220 }
221 
222 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems,
223                                                SPIRVType *ElemType,
224                                                MachineIRBuilder &MIRBuilder,
225                                                bool EmitIR) {
226   assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) &&
227          "Invalid array element type");
228   Register NumElementsVReg =
229       buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR);
230   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray)
231                  .addDef(createTypeVReg(MIRBuilder))
232                  .addUse(getSPIRVTypeID(ElemType))
233                  .addUse(NumElementsVReg);
234   return MIB;
235 }
236 
237 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(SPIRV::StorageClass SC,
238                                                  SPIRVType *ElemType,
239                                                  MachineIRBuilder &MIRBuilder) {
240   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypePointer)
241                  .addDef(createTypeVReg(MIRBuilder))
242                  .addImm(static_cast<uint32_t>(SC))
243                  .addUse(getSPIRVTypeID(ElemType));
244   return MIB;
245 }
246 
247 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction(
248     SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes,
249     MachineIRBuilder &MIRBuilder) {
250   auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction)
251                  .addDef(createTypeVReg(MIRBuilder))
252                  .addUse(getSPIRVTypeID(RetType));
253   for (const SPIRVType *ArgType : ArgTypes)
254     MIB.addUse(getSPIRVTypeID(ArgType));
255   return MIB;
256 }
257 
258 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty,
259                                                 MachineIRBuilder &MIRBuilder,
260                                                 SPIRV::AccessQualifier AccQual,
261                                                 bool EmitIR) {
262   if (auto IType = dyn_cast<IntegerType>(Ty)) {
263     const unsigned Width = IType->getBitWidth();
264     return Width == 1 ? getOpTypeBool(MIRBuilder)
265                       : getOpTypeInt(Width, MIRBuilder, false);
266   }
267   if (Ty->isFloatingPointTy())
268     return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder);
269   if (Ty->isVoidTy())
270     return getOpTypeVoid(MIRBuilder);
271   if (Ty->isVectorTy()) {
272     auto El = getOrCreateSPIRVType(cast<FixedVectorType>(Ty)->getElementType(),
273                                    MIRBuilder);
274     return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El,
275                            MIRBuilder);
276   }
277   if (Ty->isArrayTy()) {
278     auto *El = getOrCreateSPIRVType(Ty->getArrayElementType(), MIRBuilder);
279     return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR);
280   }
281   assert(!isa<StructType>(Ty) && "Unsupported StructType");
282   if (auto FType = dyn_cast<FunctionType>(Ty)) {
283     SPIRVType *RetTy = getOrCreateSPIRVType(FType->getReturnType(), MIRBuilder);
284     SmallVector<SPIRVType *, 4> ParamTypes;
285     for (const auto &t : FType->params()) {
286       ParamTypes.push_back(getOrCreateSPIRVType(t, MIRBuilder));
287     }
288     return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder);
289   }
290   if (auto PType = dyn_cast<PointerType>(Ty)) {
291     SPIRVType *SpvElementType;
292     // At the moment, all opaque pointers correspond to i8 element type.
293     // TODO: change the implementation once opaque pointers are supported
294     // in the SPIR-V specification.
295     if (PType->isOpaque()) {
296       SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder);
297     } else {
298       Type *ElemType = PType->getNonOpaquePointerElementType();
299       // TODO: support OpenCL and SPIRV builtins like image2d_t that are passed
300       // as pointers, but should be treated as custom types like OpTypeImage.
301       assert(!isa<StructType>(ElemType) && "Unsupported StructType pointer");
302 
303       // Otherwise, treat it as a regular pointer type.
304       SpvElementType = getOrCreateSPIRVType(
305           ElemType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR);
306     }
307     auto SC = addressSpaceToStorageClass(PType->getAddressSpace());
308     return getOpTypePointer(SC, SpvElementType, MIRBuilder);
309   }
310   llvm_unreachable("Unable to convert LLVM type to SPIRVType");
311 }
312 
313 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const {
314   auto t = VRegToTypeMap.find(CurMF);
315   if (t != VRegToTypeMap.end()) {
316     auto tt = t->second.find(VReg);
317     if (tt != t->second.end())
318       return tt->second;
319   }
320   return nullptr;
321 }
322 
323 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType(
324     const Type *Type, MachineIRBuilder &MIRBuilder,
325     SPIRV::AccessQualifier AccessQual, bool EmitIR) {
326   Register Reg = DT.find(Type, &MIRBuilder.getMF());
327   if (Reg.isValid())
328     return getSPIRVTypeForVReg(Reg);
329   SPIRVType *SpirvType = createSPIRVType(Type, MIRBuilder, AccessQual, EmitIR);
330   return restOfCreateSPIRVType(Type, SpirvType);
331 }
332 
333 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg,
334                                          unsigned TypeOpcode) const {
335   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
336   assert(Type && "isScalarOfType VReg has no type assigned");
337   return Type->getOpcode() == TypeOpcode;
338 }
339 
340 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg,
341                                                  unsigned TypeOpcode) const {
342   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
343   assert(Type && "isScalarOrVectorOfType VReg has no type assigned");
344   if (Type->getOpcode() == TypeOpcode)
345     return true;
346   if (Type->getOpcode() == SPIRV::OpTypeVector) {
347     Register ScalarTypeVReg = Type->getOperand(1).getReg();
348     SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg);
349     return ScalarType->getOpcode() == TypeOpcode;
350   }
351   return false;
352 }
353 
354 unsigned
355 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const {
356   assert(Type && "Invalid Type pointer");
357   if (Type->getOpcode() == SPIRV::OpTypeVector) {
358     auto EleTypeReg = Type->getOperand(1).getReg();
359     Type = getSPIRVTypeForVReg(EleTypeReg);
360   }
361   if (Type->getOpcode() == SPIRV::OpTypeInt ||
362       Type->getOpcode() == SPIRV::OpTypeFloat)
363     return Type->getOperand(1).getImm();
364   if (Type->getOpcode() == SPIRV::OpTypeBool)
365     return 1;
366   llvm_unreachable("Attempting to get bit width of non-integer/float type.");
367 }
368 
369 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const {
370   assert(Type && "Invalid Type pointer");
371   if (Type->getOpcode() == SPIRV::OpTypeVector) {
372     auto EleTypeReg = Type->getOperand(1).getReg();
373     Type = getSPIRVTypeForVReg(EleTypeReg);
374   }
375   if (Type->getOpcode() == SPIRV::OpTypeInt)
376     return Type->getOperand(2).getImm() != 0;
377   llvm_unreachable("Attempting to get sign of non-integer type.");
378 }
379 
380 SPIRV::StorageClass
381 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const {
382   SPIRVType *Type = getSPIRVTypeForVReg(VReg);
383   assert(Type && Type->getOpcode() == SPIRV::OpTypePointer &&
384          Type->getOperand(1).isImm() && "Pointer type is expected");
385   return static_cast<SPIRV::StorageClass>(Type->getOperand(1).getImm());
386 }
387 
388 SPIRVType *
389 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth,
390                                                  MachineIRBuilder &MIRBuilder) {
391   return getOrCreateSPIRVType(
392       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth),
393       MIRBuilder);
394 }
395 
396 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(const Type *LLVMTy,
397                                                       SPIRVType *SpirvType) {
398   assert(CurMF == SpirvType->getMF());
399   VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType;
400   SPIRVToLLVMType[SpirvType] = LLVMTy;
401   DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType));
402   return SpirvType;
403 }
404 
405 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(
406     unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) {
407   Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth);
408   Register Reg = DT.find(LLVMTy, CurMF);
409   if (Reg.isValid())
410     return getSPIRVTypeForVReg(Reg);
411   MachineBasicBlock &BB = *I.getParent();
412   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt))
413                  .addDef(createTypeVReg(CurMF->getRegInfo()))
414                  .addImm(BitWidth)
415                  .addImm(0);
416   return restOfCreateSPIRVType(LLVMTy, MIB);
417 }
418 
419 SPIRVType *
420 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) {
421   return getOrCreateSPIRVType(
422       IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1),
423       MIRBuilder);
424 }
425 
426 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
427     SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) {
428   return getOrCreateSPIRVType(
429       FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
430                            NumElements),
431       MIRBuilder);
432 }
433 
434 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType(
435     SPIRVType *BaseType, unsigned NumElements, MachineInstr &I,
436     const SPIRVInstrInfo &TII) {
437   Type *LLVMTy = FixedVectorType::get(
438       const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements);
439   MachineBasicBlock &BB = *I.getParent();
440   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector))
441                  .addDef(createTypeVReg(CurMF->getRegInfo()))
442                  .addUse(getSPIRVTypeID(BaseType))
443                  .addImm(NumElements);
444   return restOfCreateSPIRVType(LLVMTy, MIB);
445 }
446 
447 SPIRVType *
448 SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(SPIRVType *BaseType,
449                                                  MachineIRBuilder &MIRBuilder,
450                                                  SPIRV::StorageClass SClass) {
451   return getOrCreateSPIRVType(
452       PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
453                        storageClassToAddressSpace(SClass)),
454       MIRBuilder);
455 }
456 
457 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(
458     SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII,
459     SPIRV::StorageClass SC) {
460   Type *LLVMTy =
461       PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)),
462                        storageClassToAddressSpace(SC));
463   MachineBasicBlock &BB = *I.getParent();
464   auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer))
465                  .addDef(createTypeVReg(CurMF->getRegInfo()))
466                  .addImm(static_cast<uint32_t>(SC))
467                  .addUse(getSPIRVTypeID(BaseType));
468   return restOfCreateSPIRVType(LLVMTy, MIB);
469 }
470