1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the implementation of the SPIRVGlobalRegistry class, 10 // which is used to maintain rich type information required for SPIR-V even 11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into 12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds 13 // and supports consistency of constants and global variables. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "SPIRVGlobalRegistry.h" 18 #include "SPIRV.h" 19 #include "SPIRVSubtarget.h" 20 #include "SPIRVTargetMachine.h" 21 #include "SPIRVUtils.h" 22 23 using namespace llvm; 24 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize) 25 : PointerSize(PointerSize) {} 26 27 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg( 28 const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, 29 SPIRV::AccessQualifier AccessQual, bool EmitIR) { 30 31 SPIRVType *SpirvType = 32 getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); 33 assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF()); 34 return SpirvType; 35 } 36 37 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType, 38 Register VReg, 39 MachineFunction &MF) { 40 VRegToTypeMap[&MF][VReg] = SpirvType; 41 } 42 43 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) { 44 auto &MRI = MIRBuilder.getMF().getRegInfo(); 45 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); 46 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 47 return Res; 48 } 49 50 static Register createTypeVReg(MachineRegisterInfo &MRI) { 51 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); 52 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 53 return Res; 54 } 55 56 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) { 57 return MIRBuilder.buildInstr(SPIRV::OpTypeBool) 58 .addDef(createTypeVReg(MIRBuilder)); 59 } 60 61 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width, 62 MachineIRBuilder &MIRBuilder, 63 bool IsSigned) { 64 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt) 65 .addDef(createTypeVReg(MIRBuilder)) 66 .addImm(Width) 67 .addImm(IsSigned ? 1 : 0); 68 return MIB; 69 } 70 71 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width, 72 MachineIRBuilder &MIRBuilder) { 73 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat) 74 .addDef(createTypeVReg(MIRBuilder)) 75 .addImm(Width); 76 return MIB; 77 } 78 79 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) { 80 return MIRBuilder.buildInstr(SPIRV::OpTypeVoid) 81 .addDef(createTypeVReg(MIRBuilder)); 82 } 83 84 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, 85 SPIRVType *ElemType, 86 MachineIRBuilder &MIRBuilder) { 87 auto EleOpc = ElemType->getOpcode(); 88 assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || 89 EleOpc == SPIRV::OpTypeBool) && 90 "Invalid vector element type"); 91 92 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector) 93 .addDef(createTypeVReg(MIRBuilder)) 94 .addUse(getSPIRVTypeID(ElemType)) 95 .addImm(NumElems); 96 return MIB; 97 } 98 99 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val, 100 MachineIRBuilder &MIRBuilder, 101 SPIRVType *SpvType, 102 bool EmitIR) { 103 auto &MF = MIRBuilder.getMF(); 104 const IntegerType *LLVMIntTy; 105 if (SpvType) 106 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType)); 107 else 108 LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext()); 109 // Find a constant in DT or build a new one. 110 const auto ConstInt = 111 ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); 112 Register Res = DT.find(ConstInt, &MF); 113 if (!Res.isValid()) { 114 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 115 Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); 116 assignTypeToVReg(LLVMIntTy, Res, MIRBuilder); 117 if (EmitIR) 118 MIRBuilder.buildConstant(Res, *ConstInt); 119 else 120 MIRBuilder.buildInstr(SPIRV::OpConstantI) 121 .addDef(Res) 122 .addImm(ConstInt->getSExtValue()); 123 } 124 return Res; 125 } 126 127 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val, 128 MachineIRBuilder &MIRBuilder, 129 SPIRVType *SpvType) { 130 auto &MF = MIRBuilder.getMF(); 131 const Type *LLVMFPTy; 132 if (SpvType) { 133 LLVMFPTy = getTypeForSPIRVType(SpvType); 134 assert(LLVMFPTy->isFloatingPointTy()); 135 } else { 136 LLVMFPTy = IntegerType::getFloatTy(MF.getFunction().getContext()); 137 } 138 // Find a constant in DT or build a new one. 139 const auto ConstFP = ConstantFP::get(LLVMFPTy->getContext(), Val); 140 Register Res = DT.find(ConstFP, &MF); 141 if (!Res.isValid()) { 142 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 143 Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(BitWidth)); 144 assignTypeToVReg(LLVMFPTy, Res, MIRBuilder); 145 MIRBuilder.buildFConstant(Res, *ConstFP); 146 } 147 return Res; 148 } 149 150 Register SPIRVGlobalRegistry::buildGlobalVariable( 151 Register ResVReg, SPIRVType *BaseType, StringRef Name, 152 const GlobalValue *GV, SPIRV::StorageClass Storage, 153 const MachineInstr *Init, bool IsConst, bool HasLinkageTy, 154 SPIRV::LinkageType LinkageType, MachineIRBuilder &MIRBuilder, 155 bool IsInstSelector) { 156 const GlobalVariable *GVar = nullptr; 157 if (GV) 158 GVar = cast<const GlobalVariable>(GV); 159 else { 160 // If GV is not passed explicitly, use the name to find or construct 161 // the global variable. 162 Module *M = MIRBuilder.getMF().getFunction().getParent(); 163 GVar = M->getGlobalVariable(Name); 164 if (GVar == nullptr) { 165 const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type. 166 GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false, 167 GlobalValue::ExternalLinkage, nullptr, 168 Twine(Name)); 169 } 170 GV = GVar; 171 } 172 Register Reg; 173 auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable) 174 .addDef(ResVReg) 175 .addUse(getSPIRVTypeID(BaseType)) 176 .addImm(static_cast<uint32_t>(Storage)); 177 178 if (Init != 0) { 179 MIB.addUse(Init->getOperand(0).getReg()); 180 } 181 182 // ISel may introduce a new register on this step, so we need to add it to 183 // DT and correct its type avoiding fails on the next stage. 184 if (IsInstSelector) { 185 const auto &Subtarget = CurMF->getSubtarget(); 186 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 187 *Subtarget.getRegisterInfo(), 188 *Subtarget.getRegBankInfo()); 189 } 190 Reg = MIB->getOperand(0).getReg(); 191 DT.add(GVar, &MIRBuilder.getMF(), Reg); 192 193 // Set to Reg the same type as ResVReg has. 194 auto MRI = MIRBuilder.getMRI(); 195 assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected"); 196 if (Reg != ResVReg) { 197 LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32); 198 MRI->setType(Reg, RegLLTy); 199 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF()); 200 } 201 202 // If it's a global variable with name, output OpName for it. 203 if (GVar && GVar->hasName()) 204 buildOpName(Reg, GVar->getName(), MIRBuilder); 205 206 // Output decorations for the GV. 207 // TODO: maybe move to GenerateDecorations pass. 208 if (IsConst) 209 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {}); 210 211 if (GVar && GVar->getAlign().valueOrOne().value() != 1) 212 buildOpDecorate( 213 Reg, MIRBuilder, SPIRV::Decoration::Alignment, 214 {static_cast<uint32_t>(GVar->getAlign().valueOrOne().value())}); 215 216 if (HasLinkageTy) 217 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, 218 {static_cast<uint32_t>(LinkageType)}, Name); 219 return Reg; 220 } 221 222 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems, 223 SPIRVType *ElemType, 224 MachineIRBuilder &MIRBuilder, 225 bool EmitIR) { 226 assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) && 227 "Invalid array element type"); 228 Register NumElementsVReg = 229 buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR); 230 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray) 231 .addDef(createTypeVReg(MIRBuilder)) 232 .addUse(getSPIRVTypeID(ElemType)) 233 .addUse(NumElementsVReg); 234 return MIB; 235 } 236 237 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer(SPIRV::StorageClass SC, 238 SPIRVType *ElemType, 239 MachineIRBuilder &MIRBuilder) { 240 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypePointer) 241 .addDef(createTypeVReg(MIRBuilder)) 242 .addImm(static_cast<uint32_t>(SC)) 243 .addUse(getSPIRVTypeID(ElemType)); 244 return MIB; 245 } 246 247 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction( 248 SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes, 249 MachineIRBuilder &MIRBuilder) { 250 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction) 251 .addDef(createTypeVReg(MIRBuilder)) 252 .addUse(getSPIRVTypeID(RetType)); 253 for (const SPIRVType *ArgType : ArgTypes) 254 MIB.addUse(getSPIRVTypeID(ArgType)); 255 return MIB; 256 } 257 258 SPIRVType *SPIRVGlobalRegistry::createSPIRVType(const Type *Ty, 259 MachineIRBuilder &MIRBuilder, 260 SPIRV::AccessQualifier AccQual, 261 bool EmitIR) { 262 if (auto IType = dyn_cast<IntegerType>(Ty)) { 263 const unsigned Width = IType->getBitWidth(); 264 return Width == 1 ? getOpTypeBool(MIRBuilder) 265 : getOpTypeInt(Width, MIRBuilder, false); 266 } 267 if (Ty->isFloatingPointTy()) 268 return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); 269 if (Ty->isVoidTy()) 270 return getOpTypeVoid(MIRBuilder); 271 if (Ty->isVectorTy()) { 272 auto El = getOrCreateSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), 273 MIRBuilder); 274 return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El, 275 MIRBuilder); 276 } 277 if (Ty->isArrayTy()) { 278 auto *El = getOrCreateSPIRVType(Ty->getArrayElementType(), MIRBuilder); 279 return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR); 280 } 281 assert(!isa<StructType>(Ty) && "Unsupported StructType"); 282 if (auto FType = dyn_cast<FunctionType>(Ty)) { 283 SPIRVType *RetTy = getOrCreateSPIRVType(FType->getReturnType(), MIRBuilder); 284 SmallVector<SPIRVType *, 4> ParamTypes; 285 for (const auto &t : FType->params()) { 286 ParamTypes.push_back(getOrCreateSPIRVType(t, MIRBuilder)); 287 } 288 return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); 289 } 290 if (auto PType = dyn_cast<PointerType>(Ty)) { 291 SPIRVType *SpvElementType; 292 // At the moment, all opaque pointers correspond to i8 element type. 293 // TODO: change the implementation once opaque pointers are supported 294 // in the SPIR-V specification. 295 if (PType->isOpaque()) { 296 SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder); 297 } else { 298 Type *ElemType = PType->getNonOpaquePointerElementType(); 299 // TODO: support OpenCL and SPIRV builtins like image2d_t that are passed 300 // as pointers, but should be treated as custom types like OpTypeImage. 301 assert(!isa<StructType>(ElemType) && "Unsupported StructType pointer"); 302 303 // Otherwise, treat it as a regular pointer type. 304 SpvElementType = getOrCreateSPIRVType( 305 ElemType, MIRBuilder, SPIRV::AccessQualifier::ReadWrite, EmitIR); 306 } 307 auto SC = addressSpaceToStorageClass(PType->getAddressSpace()); 308 return getOpTypePointer(SC, SpvElementType, MIRBuilder); 309 } 310 llvm_unreachable("Unable to convert LLVM type to SPIRVType"); 311 } 312 313 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const { 314 auto t = VRegToTypeMap.find(CurMF); 315 if (t != VRegToTypeMap.end()) { 316 auto tt = t->second.find(VReg); 317 if (tt != t->second.end()) 318 return tt->second; 319 } 320 return nullptr; 321 } 322 323 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( 324 const Type *Type, MachineIRBuilder &MIRBuilder, 325 SPIRV::AccessQualifier AccessQual, bool EmitIR) { 326 Register Reg = DT.find(Type, &MIRBuilder.getMF()); 327 if (Reg.isValid()) 328 return getSPIRVTypeForVReg(Reg); 329 SPIRVType *SpirvType = createSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); 330 return restOfCreateSPIRVType(Type, SpirvType); 331 } 332 333 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg, 334 unsigned TypeOpcode) const { 335 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 336 assert(Type && "isScalarOfType VReg has no type assigned"); 337 return Type->getOpcode() == TypeOpcode; 338 } 339 340 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg, 341 unsigned TypeOpcode) const { 342 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 343 assert(Type && "isScalarOrVectorOfType VReg has no type assigned"); 344 if (Type->getOpcode() == TypeOpcode) 345 return true; 346 if (Type->getOpcode() == SPIRV::OpTypeVector) { 347 Register ScalarTypeVReg = Type->getOperand(1).getReg(); 348 SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg); 349 return ScalarType->getOpcode() == TypeOpcode; 350 } 351 return false; 352 } 353 354 unsigned 355 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const { 356 assert(Type && "Invalid Type pointer"); 357 if (Type->getOpcode() == SPIRV::OpTypeVector) { 358 auto EleTypeReg = Type->getOperand(1).getReg(); 359 Type = getSPIRVTypeForVReg(EleTypeReg); 360 } 361 if (Type->getOpcode() == SPIRV::OpTypeInt || 362 Type->getOpcode() == SPIRV::OpTypeFloat) 363 return Type->getOperand(1).getImm(); 364 if (Type->getOpcode() == SPIRV::OpTypeBool) 365 return 1; 366 llvm_unreachable("Attempting to get bit width of non-integer/float type."); 367 } 368 369 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const { 370 assert(Type && "Invalid Type pointer"); 371 if (Type->getOpcode() == SPIRV::OpTypeVector) { 372 auto EleTypeReg = Type->getOperand(1).getReg(); 373 Type = getSPIRVTypeForVReg(EleTypeReg); 374 } 375 if (Type->getOpcode() == SPIRV::OpTypeInt) 376 return Type->getOperand(2).getImm() != 0; 377 llvm_unreachable("Attempting to get sign of non-integer type."); 378 } 379 380 SPIRV::StorageClass 381 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const { 382 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 383 assert(Type && Type->getOpcode() == SPIRV::OpTypePointer && 384 Type->getOperand(1).isImm() && "Pointer type is expected"); 385 return static_cast<SPIRV::StorageClass>(Type->getOperand(1).getImm()); 386 } 387 388 SPIRVType * 389 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth, 390 MachineIRBuilder &MIRBuilder) { 391 return getOrCreateSPIRVType( 392 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth), 393 MIRBuilder); 394 } 395 396 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType(const Type *LLVMTy, 397 SPIRVType *SpirvType) { 398 assert(CurMF == SpirvType->getMF()); 399 VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType; 400 SPIRVToLLVMType[SpirvType] = LLVMTy; 401 DT.add(LLVMTy, CurMF, getSPIRVTypeID(SpirvType)); 402 return SpirvType; 403 } 404 405 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType( 406 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) { 407 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth); 408 Register Reg = DT.find(LLVMTy, CurMF); 409 if (Reg.isValid()) 410 return getSPIRVTypeForVReg(Reg); 411 MachineBasicBlock &BB = *I.getParent(); 412 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt)) 413 .addDef(createTypeVReg(CurMF->getRegInfo())) 414 .addImm(BitWidth) 415 .addImm(0); 416 return restOfCreateSPIRVType(LLVMTy, MIB); 417 } 418 419 SPIRVType * 420 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) { 421 return getOrCreateSPIRVType( 422 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1), 423 MIRBuilder); 424 } 425 426 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( 427 SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) { 428 return getOrCreateSPIRVType( 429 FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 430 NumElements), 431 MIRBuilder); 432 } 433 434 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( 435 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, 436 const SPIRVInstrInfo &TII) { 437 Type *LLVMTy = FixedVectorType::get( 438 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements); 439 MachineBasicBlock &BB = *I.getParent(); 440 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector)) 441 .addDef(createTypeVReg(CurMF->getRegInfo())) 442 .addUse(getSPIRVTypeID(BaseType)) 443 .addImm(NumElements); 444 return restOfCreateSPIRVType(LLVMTy, MIB); 445 } 446 447 SPIRVType * 448 SPIRVGlobalRegistry::getOrCreateSPIRVPointerType(SPIRVType *BaseType, 449 MachineIRBuilder &MIRBuilder, 450 SPIRV::StorageClass SClass) { 451 return getOrCreateSPIRVType( 452 PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 453 storageClassToAddressSpace(SClass)), 454 MIRBuilder); 455 } 456 457 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( 458 SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII, 459 SPIRV::StorageClass SC) { 460 Type *LLVMTy = 461 PointerType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 462 storageClassToAddressSpace(SC)); 463 MachineBasicBlock &BB = *I.getParent(); 464 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer)) 465 .addDef(createTypeVReg(CurMF->getRegInfo())) 466 .addImm(static_cast<uint32_t>(SC)) 467 .addUse(getSPIRVTypeID(BaseType)); 468 return restOfCreateSPIRVType(LLVMTy, MIB); 469 } 470