1 //===-- SPIRVGlobalRegistry.cpp - SPIR-V Global Registry --------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file contains the implementation of the SPIRVGlobalRegistry class, 10 // which is used to maintain rich type information required for SPIR-V even 11 // after lowering from LLVM IR to GMIR. It can convert an llvm::Type into 12 // an OpTypeXXX instruction, and map it to a virtual register. Also it builds 13 // and supports consistency of constants and global variables. 14 // 15 //===----------------------------------------------------------------------===// 16 17 #include "SPIRVGlobalRegistry.h" 18 #include "SPIRV.h" 19 #include "SPIRVBuiltins.h" 20 #include "SPIRVSubtarget.h" 21 #include "SPIRVTargetMachine.h" 22 #include "SPIRVUtils.h" 23 24 using namespace llvm; 25 SPIRVGlobalRegistry::SPIRVGlobalRegistry(unsigned PointerSize) 26 : PointerSize(PointerSize) {} 27 28 SPIRVType *SPIRVGlobalRegistry::assignIntTypeToVReg(unsigned BitWidth, 29 Register VReg, 30 MachineInstr &I, 31 const SPIRVInstrInfo &TII) { 32 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); 33 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF); 34 return SpirvType; 35 } 36 37 SPIRVType *SPIRVGlobalRegistry::assignVectTypeToVReg( 38 SPIRVType *BaseType, unsigned NumElements, Register VReg, MachineInstr &I, 39 const SPIRVInstrInfo &TII) { 40 SPIRVType *SpirvType = 41 getOrCreateSPIRVVectorType(BaseType, NumElements, I, TII); 42 assignSPIRVTypeToVReg(SpirvType, VReg, *CurMF); 43 return SpirvType; 44 } 45 46 SPIRVType *SPIRVGlobalRegistry::assignTypeToVReg( 47 const Type *Type, Register VReg, MachineIRBuilder &MIRBuilder, 48 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { 49 50 SPIRVType *SpirvType = 51 getOrCreateSPIRVType(Type, MIRBuilder, AccessQual, EmitIR); 52 assignSPIRVTypeToVReg(SpirvType, VReg, MIRBuilder.getMF()); 53 return SpirvType; 54 } 55 56 void SPIRVGlobalRegistry::assignSPIRVTypeToVReg(SPIRVType *SpirvType, 57 Register VReg, 58 MachineFunction &MF) { 59 VRegToTypeMap[&MF][VReg] = SpirvType; 60 } 61 62 static Register createTypeVReg(MachineIRBuilder &MIRBuilder) { 63 auto &MRI = MIRBuilder.getMF().getRegInfo(); 64 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); 65 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 66 return Res; 67 } 68 69 static Register createTypeVReg(MachineRegisterInfo &MRI) { 70 auto Res = MRI.createGenericVirtualRegister(LLT::scalar(32)); 71 MRI.setRegClass(Res, &SPIRV::TYPERegClass); 72 return Res; 73 } 74 75 SPIRVType *SPIRVGlobalRegistry::getOpTypeBool(MachineIRBuilder &MIRBuilder) { 76 return MIRBuilder.buildInstr(SPIRV::OpTypeBool) 77 .addDef(createTypeVReg(MIRBuilder)); 78 } 79 80 SPIRVType *SPIRVGlobalRegistry::getOpTypeInt(uint32_t Width, 81 MachineIRBuilder &MIRBuilder, 82 bool IsSigned) { 83 assert(Width <= 64 && "Unsupported integer width!"); 84 const SPIRVSubtarget &ST = 85 cast<SPIRVSubtarget>(MIRBuilder.getMF().getSubtarget()); 86 if (ST.canUseExtension( 87 SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers)) { 88 MIRBuilder.buildInstr(SPIRV::OpExtension) 89 .addImm(SPIRV::Extension::SPV_INTEL_arbitrary_precision_integers); 90 MIRBuilder.buildInstr(SPIRV::OpCapability) 91 .addImm(SPIRV::Capability::ArbitraryPrecisionIntegersINTEL); 92 } else if (Width <= 8) 93 Width = 8; 94 else if (Width <= 16) 95 Width = 16; 96 else if (Width <= 32) 97 Width = 32; 98 else if (Width <= 64) 99 Width = 64; 100 101 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeInt) 102 .addDef(createTypeVReg(MIRBuilder)) 103 .addImm(Width) 104 .addImm(IsSigned ? 1 : 0); 105 return MIB; 106 } 107 108 SPIRVType *SPIRVGlobalRegistry::getOpTypeFloat(uint32_t Width, 109 MachineIRBuilder &MIRBuilder) { 110 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFloat) 111 .addDef(createTypeVReg(MIRBuilder)) 112 .addImm(Width); 113 return MIB; 114 } 115 116 SPIRVType *SPIRVGlobalRegistry::getOpTypeVoid(MachineIRBuilder &MIRBuilder) { 117 return MIRBuilder.buildInstr(SPIRV::OpTypeVoid) 118 .addDef(createTypeVReg(MIRBuilder)); 119 } 120 121 SPIRVType *SPIRVGlobalRegistry::getOpTypeVector(uint32_t NumElems, 122 SPIRVType *ElemType, 123 MachineIRBuilder &MIRBuilder) { 124 auto EleOpc = ElemType->getOpcode(); 125 assert((EleOpc == SPIRV::OpTypeInt || EleOpc == SPIRV::OpTypeFloat || 126 EleOpc == SPIRV::OpTypeBool) && 127 "Invalid vector element type"); 128 129 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeVector) 130 .addDef(createTypeVReg(MIRBuilder)) 131 .addUse(getSPIRVTypeID(ElemType)) 132 .addImm(NumElems); 133 return MIB; 134 } 135 136 std::tuple<Register, ConstantInt *, bool> 137 SPIRVGlobalRegistry::getOrCreateConstIntReg(uint64_t Val, SPIRVType *SpvType, 138 MachineIRBuilder *MIRBuilder, 139 MachineInstr *I, 140 const SPIRVInstrInfo *TII) { 141 const IntegerType *LLVMIntTy; 142 if (SpvType) 143 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType)); 144 else 145 LLVMIntTy = IntegerType::getInt32Ty(CurMF->getFunction().getContext()); 146 bool NewInstr = false; 147 // Find a constant in DT or build a new one. 148 ConstantInt *CI = ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); 149 Register Res = DT.find(CI, CurMF); 150 if (!Res.isValid()) { 151 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 152 LLT LLTy = LLT::scalar(32); 153 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 154 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 155 if (MIRBuilder) 156 assignTypeToVReg(LLVMIntTy, Res, *MIRBuilder); 157 else 158 assignIntTypeToVReg(BitWidth, Res, *I, *TII); 159 DT.add(CI, CurMF, Res); 160 NewInstr = true; 161 } 162 return std::make_tuple(Res, CI, NewInstr); 163 } 164 165 Register SPIRVGlobalRegistry::getOrCreateConstInt(uint64_t Val, MachineInstr &I, 166 SPIRVType *SpvType, 167 const SPIRVInstrInfo &TII) { 168 assert(SpvType); 169 ConstantInt *CI; 170 Register Res; 171 bool New; 172 std::tie(Res, CI, New) = 173 getOrCreateConstIntReg(Val, SpvType, nullptr, &I, &TII); 174 // If we have found Res register which is defined by the passed G_CONSTANT 175 // machine instruction, a new constant instruction should be created. 176 if (!New && (!I.getOperand(0).isReg() || Res != I.getOperand(0).getReg())) 177 return Res; 178 MachineInstrBuilder MIB; 179 MachineBasicBlock &BB = *I.getParent(); 180 if (Val) { 181 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantI)) 182 .addDef(Res) 183 .addUse(getSPIRVTypeID(SpvType)); 184 addNumImm(APInt(getScalarOrVectorBitWidth(SpvType), Val), MIB); 185 } else { 186 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) 187 .addDef(Res) 188 .addUse(getSPIRVTypeID(SpvType)); 189 } 190 const auto &ST = CurMF->getSubtarget(); 191 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(), 192 *ST.getRegisterInfo(), *ST.getRegBankInfo()); 193 return Res; 194 } 195 196 Register SPIRVGlobalRegistry::buildConstantInt(uint64_t Val, 197 MachineIRBuilder &MIRBuilder, 198 SPIRVType *SpvType, 199 bool EmitIR) { 200 auto &MF = MIRBuilder.getMF(); 201 const IntegerType *LLVMIntTy; 202 if (SpvType) 203 LLVMIntTy = cast<IntegerType>(getTypeForSPIRVType(SpvType)); 204 else 205 LLVMIntTy = IntegerType::getInt32Ty(MF.getFunction().getContext()); 206 // Find a constant in DT or build a new one. 207 const auto ConstInt = 208 ConstantInt::get(const_cast<IntegerType *>(LLVMIntTy), Val); 209 Register Res = DT.find(ConstInt, &MF); 210 if (!Res.isValid()) { 211 unsigned BitWidth = SpvType ? getScalarOrVectorBitWidth(SpvType) : 32; 212 LLT LLTy = LLT::scalar(EmitIR ? BitWidth : 32); 213 Res = MF.getRegInfo().createGenericVirtualRegister(LLTy); 214 MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 215 assignTypeToVReg(LLVMIntTy, Res, MIRBuilder, 216 SPIRV::AccessQualifier::ReadWrite, EmitIR); 217 DT.add(ConstInt, &MIRBuilder.getMF(), Res); 218 if (EmitIR) { 219 MIRBuilder.buildConstant(Res, *ConstInt); 220 } else { 221 MachineInstrBuilder MIB; 222 if (Val) { 223 assert(SpvType); 224 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantI) 225 .addDef(Res) 226 .addUse(getSPIRVTypeID(SpvType)); 227 addNumImm(APInt(BitWidth, Val), MIB); 228 } else { 229 assert(SpvType); 230 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantNull) 231 .addDef(Res) 232 .addUse(getSPIRVTypeID(SpvType)); 233 } 234 const auto &Subtarget = CurMF->getSubtarget(); 235 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 236 *Subtarget.getRegisterInfo(), 237 *Subtarget.getRegBankInfo()); 238 } 239 } 240 return Res; 241 } 242 243 Register SPIRVGlobalRegistry::buildConstantFP(APFloat Val, 244 MachineIRBuilder &MIRBuilder, 245 SPIRVType *SpvType) { 246 auto &MF = MIRBuilder.getMF(); 247 auto &Ctx = MF.getFunction().getContext(); 248 if (!SpvType) { 249 const Type *LLVMFPTy = Type::getFloatTy(Ctx); 250 SpvType = getOrCreateSPIRVType(LLVMFPTy, MIRBuilder); 251 } 252 // Find a constant in DT or build a new one. 253 const auto ConstFP = ConstantFP::get(Ctx, Val); 254 Register Res = DT.find(ConstFP, &MF); 255 if (!Res.isValid()) { 256 Res = MF.getRegInfo().createGenericVirtualRegister(LLT::scalar(32)); 257 MF.getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 258 assignSPIRVTypeToVReg(SpvType, Res, MF); 259 DT.add(ConstFP, &MF, Res); 260 261 MachineInstrBuilder MIB; 262 MIB = MIRBuilder.buildInstr(SPIRV::OpConstantF) 263 .addDef(Res) 264 .addUse(getSPIRVTypeID(SpvType)); 265 addNumImm(ConstFP->getValueAPF().bitcastToAPInt(), MIB); 266 } 267 268 return Res; 269 } 270 271 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull( 272 uint64_t Val, MachineInstr &I, SPIRVType *SpvType, 273 const SPIRVInstrInfo &TII, Constant *CA, unsigned BitWidth, 274 unsigned ElemCnt) { 275 // Find a constant vector in DT or build a new one. 276 Register Res = DT.find(CA, CurMF); 277 if (!Res.isValid()) { 278 SPIRVType *SpvBaseType = getOrCreateSPIRVIntegerType(BitWidth, I, TII); 279 // SpvScalConst should be created before SpvVecConst to avoid undefined ID 280 // error on validation. 281 // TODO: can moved below once sorting of types/consts/defs is implemented. 282 Register SpvScalConst; 283 if (Val) 284 SpvScalConst = getOrCreateConstInt(Val, I, SpvBaseType, TII); 285 // TODO: maybe use bitwidth of base type. 286 LLT LLTy = LLT::scalar(32); 287 Register SpvVecConst = 288 CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 289 CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass); 290 assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF); 291 DT.add(CA, CurMF, SpvVecConst); 292 MachineInstrBuilder MIB; 293 MachineBasicBlock &BB = *I.getParent(); 294 if (Val) { 295 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantComposite)) 296 .addDef(SpvVecConst) 297 .addUse(getSPIRVTypeID(SpvType)); 298 for (unsigned i = 0; i < ElemCnt; ++i) 299 MIB.addUse(SpvScalConst); 300 } else { 301 MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpConstantNull)) 302 .addDef(SpvVecConst) 303 .addUse(getSPIRVTypeID(SpvType)); 304 } 305 const auto &Subtarget = CurMF->getSubtarget(); 306 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 307 *Subtarget.getRegisterInfo(), 308 *Subtarget.getRegBankInfo()); 309 return SpvVecConst; 310 } 311 return Res; 312 } 313 314 Register 315 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, MachineInstr &I, 316 SPIRVType *SpvType, 317 const SPIRVInstrInfo &TII) { 318 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 319 assert(LLVMTy->isVectorTy()); 320 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy); 321 Type *LLVMBaseTy = LLVMVecTy->getElementType(); 322 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 323 auto ConstVec = 324 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt); 325 unsigned BW = getScalarOrVectorBitWidth(SpvType); 326 return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstVec, BW, 327 SpvType->getOperand(2).getImm()); 328 } 329 330 Register 331 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, MachineInstr &I, 332 SPIRVType *SpvType, 333 const SPIRVInstrInfo &TII) { 334 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 335 assert(LLVMTy->isArrayTy()); 336 const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy); 337 Type *LLVMBaseTy = LLVMArrTy->getElementType(); 338 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 339 auto ConstArr = 340 ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt}); 341 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg()); 342 unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy); 343 return getOrCreateIntCompositeOrNull(Val, I, SpvType, TII, ConstArr, BW, 344 LLVMArrTy->getNumElements()); 345 } 346 347 Register SPIRVGlobalRegistry::getOrCreateIntCompositeOrNull( 348 uint64_t Val, MachineIRBuilder &MIRBuilder, SPIRVType *SpvType, bool EmitIR, 349 Constant *CA, unsigned BitWidth, unsigned ElemCnt) { 350 Register Res = DT.find(CA, CurMF); 351 if (!Res.isValid()) { 352 Register SpvScalConst; 353 if (Val || EmitIR) { 354 SPIRVType *SpvBaseType = 355 getOrCreateSPIRVIntegerType(BitWidth, MIRBuilder); 356 SpvScalConst = buildConstantInt(Val, MIRBuilder, SpvBaseType, EmitIR); 357 } 358 LLT LLTy = EmitIR ? LLT::fixed_vector(ElemCnt, BitWidth) : LLT::scalar(32); 359 Register SpvVecConst = 360 CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 361 CurMF->getRegInfo().setRegClass(SpvVecConst, &SPIRV::IDRegClass); 362 assignSPIRVTypeToVReg(SpvType, SpvVecConst, *CurMF); 363 DT.add(CA, CurMF, SpvVecConst); 364 if (EmitIR) { 365 MIRBuilder.buildSplatVector(SpvVecConst, SpvScalConst); 366 } else { 367 if (Val) { 368 auto MIB = MIRBuilder.buildInstr(SPIRV::OpConstantComposite) 369 .addDef(SpvVecConst) 370 .addUse(getSPIRVTypeID(SpvType)); 371 for (unsigned i = 0; i < ElemCnt; ++i) 372 MIB.addUse(SpvScalConst); 373 } else { 374 MIRBuilder.buildInstr(SPIRV::OpConstantNull) 375 .addDef(SpvVecConst) 376 .addUse(getSPIRVTypeID(SpvType)); 377 } 378 } 379 return SpvVecConst; 380 } 381 return Res; 382 } 383 384 Register 385 SPIRVGlobalRegistry::getOrCreateConsIntVector(uint64_t Val, 386 MachineIRBuilder &MIRBuilder, 387 SPIRVType *SpvType, bool EmitIR) { 388 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 389 assert(LLVMTy->isVectorTy()); 390 const FixedVectorType *LLVMVecTy = cast<FixedVectorType>(LLVMTy); 391 Type *LLVMBaseTy = LLVMVecTy->getElementType(); 392 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 393 auto ConstVec = 394 ConstantVector::getSplat(LLVMVecTy->getElementCount(), ConstInt); 395 unsigned BW = getScalarOrVectorBitWidth(SpvType); 396 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR, 397 ConstVec, BW, 398 SpvType->getOperand(2).getImm()); 399 } 400 401 Register 402 SPIRVGlobalRegistry::getOrCreateConsIntArray(uint64_t Val, 403 MachineIRBuilder &MIRBuilder, 404 SPIRVType *SpvType, bool EmitIR) { 405 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 406 assert(LLVMTy->isArrayTy()); 407 const ArrayType *LLVMArrTy = cast<ArrayType>(LLVMTy); 408 Type *LLVMBaseTy = LLVMArrTy->getElementType(); 409 const auto ConstInt = ConstantInt::get(LLVMBaseTy, Val); 410 auto ConstArr = 411 ConstantArray::get(const_cast<ArrayType *>(LLVMArrTy), {ConstInt}); 412 SPIRVType *SpvBaseTy = getSPIRVTypeForVReg(SpvType->getOperand(1).getReg()); 413 unsigned BW = getScalarOrVectorBitWidth(SpvBaseTy); 414 return getOrCreateIntCompositeOrNull(Val, MIRBuilder, SpvType, EmitIR, 415 ConstArr, BW, 416 LLVMArrTy->getNumElements()); 417 } 418 419 Register 420 SPIRVGlobalRegistry::getOrCreateConstNullPtr(MachineIRBuilder &MIRBuilder, 421 SPIRVType *SpvType) { 422 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 423 const PointerType *LLVMPtrTy = cast<PointerType>(LLVMTy); 424 // Find a constant in DT or build a new one. 425 Constant *CP = ConstantPointerNull::get(const_cast<PointerType *>(LLVMPtrTy)); 426 Register Res = DT.find(CP, CurMF); 427 if (!Res.isValid()) { 428 LLT LLTy = LLT::pointer(LLVMPtrTy->getAddressSpace(), PointerSize); 429 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 430 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 431 assignSPIRVTypeToVReg(SpvType, Res, *CurMF); 432 MIRBuilder.buildInstr(SPIRV::OpConstantNull) 433 .addDef(Res) 434 .addUse(getSPIRVTypeID(SpvType)); 435 DT.add(CP, CurMF, Res); 436 } 437 return Res; 438 } 439 440 Register SPIRVGlobalRegistry::buildConstantSampler( 441 Register ResReg, unsigned AddrMode, unsigned Param, unsigned FilerMode, 442 MachineIRBuilder &MIRBuilder, SPIRVType *SpvType) { 443 SPIRVType *SampTy; 444 if (SpvType) 445 SampTy = getOrCreateSPIRVType(getTypeForSPIRVType(SpvType), MIRBuilder); 446 else 447 SampTy = getOrCreateSPIRVTypeByName("opencl.sampler_t", MIRBuilder); 448 449 auto Sampler = 450 ResReg.isValid() 451 ? ResReg 452 : MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); 453 auto Res = MIRBuilder.buildInstr(SPIRV::OpConstantSampler) 454 .addDef(Sampler) 455 .addUse(getSPIRVTypeID(SampTy)) 456 .addImm(AddrMode) 457 .addImm(Param) 458 .addImm(FilerMode); 459 assert(Res->getOperand(0).isReg()); 460 return Res->getOperand(0).getReg(); 461 } 462 463 Register SPIRVGlobalRegistry::buildGlobalVariable( 464 Register ResVReg, SPIRVType *BaseType, StringRef Name, 465 const GlobalValue *GV, SPIRV::StorageClass::StorageClass Storage, 466 const MachineInstr *Init, bool IsConst, bool HasLinkageTy, 467 SPIRV::LinkageType::LinkageType LinkageType, MachineIRBuilder &MIRBuilder, 468 bool IsInstSelector) { 469 const GlobalVariable *GVar = nullptr; 470 if (GV) 471 GVar = cast<const GlobalVariable>(GV); 472 else { 473 // If GV is not passed explicitly, use the name to find or construct 474 // the global variable. 475 Module *M = MIRBuilder.getMF().getFunction().getParent(); 476 GVar = M->getGlobalVariable(Name); 477 if (GVar == nullptr) { 478 const Type *Ty = getTypeForSPIRVType(BaseType); // TODO: check type. 479 GVar = new GlobalVariable(*M, const_cast<Type *>(Ty), false, 480 GlobalValue::ExternalLinkage, nullptr, 481 Twine(Name)); 482 } 483 GV = GVar; 484 } 485 Register Reg = DT.find(GVar, &MIRBuilder.getMF()); 486 if (Reg.isValid()) { 487 if (Reg != ResVReg) 488 MIRBuilder.buildCopy(ResVReg, Reg); 489 return ResVReg; 490 } 491 492 auto MIB = MIRBuilder.buildInstr(SPIRV::OpVariable) 493 .addDef(ResVReg) 494 .addUse(getSPIRVTypeID(BaseType)) 495 .addImm(static_cast<uint32_t>(Storage)); 496 497 if (Init != 0) { 498 MIB.addUse(Init->getOperand(0).getReg()); 499 } 500 501 // ISel may introduce a new register on this step, so we need to add it to 502 // DT and correct its type avoiding fails on the next stage. 503 if (IsInstSelector) { 504 const auto &Subtarget = CurMF->getSubtarget(); 505 constrainSelectedInstRegOperands(*MIB, *Subtarget.getInstrInfo(), 506 *Subtarget.getRegisterInfo(), 507 *Subtarget.getRegBankInfo()); 508 } 509 Reg = MIB->getOperand(0).getReg(); 510 DT.add(GVar, &MIRBuilder.getMF(), Reg); 511 512 // Set to Reg the same type as ResVReg has. 513 auto MRI = MIRBuilder.getMRI(); 514 assert(MRI->getType(ResVReg).isPointer() && "Pointer type is expected"); 515 if (Reg != ResVReg) { 516 LLT RegLLTy = LLT::pointer(MRI->getType(ResVReg).getAddressSpace(), 32); 517 MRI->setType(Reg, RegLLTy); 518 assignSPIRVTypeToVReg(BaseType, Reg, MIRBuilder.getMF()); 519 } 520 521 // If it's a global variable with name, output OpName for it. 522 if (GVar && GVar->hasName()) 523 buildOpName(Reg, GVar->getName(), MIRBuilder); 524 525 // Output decorations for the GV. 526 // TODO: maybe move to GenerateDecorations pass. 527 if (IsConst) 528 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Constant, {}); 529 530 if (GVar && GVar->getAlign().valueOrOne().value() != 1) { 531 unsigned Alignment = (unsigned)GVar->getAlign().valueOrOne().value(); 532 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::Alignment, {Alignment}); 533 } 534 535 if (HasLinkageTy) 536 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, 537 {static_cast<uint32_t>(LinkageType)}, Name); 538 539 SPIRV::BuiltIn::BuiltIn BuiltInId; 540 if (getSpirvBuiltInIdByName(Name, BuiltInId)) 541 buildOpDecorate(Reg, MIRBuilder, SPIRV::Decoration::BuiltIn, 542 {static_cast<uint32_t>(BuiltInId)}); 543 544 return Reg; 545 } 546 547 SPIRVType *SPIRVGlobalRegistry::getOpTypeArray(uint32_t NumElems, 548 SPIRVType *ElemType, 549 MachineIRBuilder &MIRBuilder, 550 bool EmitIR) { 551 assert((ElemType->getOpcode() != SPIRV::OpTypeVoid) && 552 "Invalid array element type"); 553 Register NumElementsVReg = 554 buildConstantInt(NumElems, MIRBuilder, nullptr, EmitIR); 555 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeArray) 556 .addDef(createTypeVReg(MIRBuilder)) 557 .addUse(getSPIRVTypeID(ElemType)) 558 .addUse(NumElementsVReg); 559 return MIB; 560 } 561 562 SPIRVType *SPIRVGlobalRegistry::getOpTypeOpaque(const StructType *Ty, 563 MachineIRBuilder &MIRBuilder) { 564 assert(Ty->hasName()); 565 const StringRef Name = Ty->hasName() ? Ty->getName() : ""; 566 Register ResVReg = createTypeVReg(MIRBuilder); 567 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeOpaque).addDef(ResVReg); 568 addStringImm(Name, MIB); 569 buildOpName(ResVReg, Name, MIRBuilder); 570 return MIB; 571 } 572 573 SPIRVType *SPIRVGlobalRegistry::getOpTypeStruct(const StructType *Ty, 574 MachineIRBuilder &MIRBuilder, 575 bool EmitIR) { 576 SmallVector<Register, 4> FieldTypes; 577 for (const auto &Elem : Ty->elements()) { 578 SPIRVType *ElemTy = findSPIRVType(Elem, MIRBuilder); 579 assert(ElemTy && ElemTy->getOpcode() != SPIRV::OpTypeVoid && 580 "Invalid struct element type"); 581 FieldTypes.push_back(getSPIRVTypeID(ElemTy)); 582 } 583 Register ResVReg = createTypeVReg(MIRBuilder); 584 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeStruct).addDef(ResVReg); 585 for (const auto &Ty : FieldTypes) 586 MIB.addUse(Ty); 587 if (Ty->hasName()) 588 buildOpName(ResVReg, Ty->getName(), MIRBuilder); 589 if (Ty->isPacked()) 590 buildOpDecorate(ResVReg, MIRBuilder, SPIRV::Decoration::CPacked, {}); 591 return MIB; 592 } 593 594 SPIRVType *SPIRVGlobalRegistry::getOrCreateSpecialType( 595 const Type *Ty, MachineIRBuilder &MIRBuilder, 596 SPIRV::AccessQualifier::AccessQualifier AccQual) { 597 assert(isSpecialOpaqueType(Ty) && "Not a special opaque builtin type"); 598 return SPIRV::lowerBuiltinType(Ty, AccQual, MIRBuilder, this); 599 } 600 601 SPIRVType *SPIRVGlobalRegistry::getOpTypePointer( 602 SPIRV::StorageClass::StorageClass SC, SPIRVType *ElemType, 603 MachineIRBuilder &MIRBuilder, Register Reg) { 604 if (!Reg.isValid()) 605 Reg = createTypeVReg(MIRBuilder); 606 return MIRBuilder.buildInstr(SPIRV::OpTypePointer) 607 .addDef(Reg) 608 .addImm(static_cast<uint32_t>(SC)) 609 .addUse(getSPIRVTypeID(ElemType)); 610 } 611 612 SPIRVType *SPIRVGlobalRegistry::getOpTypeForwardPointer( 613 SPIRV::StorageClass::StorageClass SC, MachineIRBuilder &MIRBuilder) { 614 return MIRBuilder.buildInstr(SPIRV::OpTypeForwardPointer) 615 .addUse(createTypeVReg(MIRBuilder)) 616 .addImm(static_cast<uint32_t>(SC)); 617 } 618 619 SPIRVType *SPIRVGlobalRegistry::getOpTypeFunction( 620 SPIRVType *RetType, const SmallVectorImpl<SPIRVType *> &ArgTypes, 621 MachineIRBuilder &MIRBuilder) { 622 auto MIB = MIRBuilder.buildInstr(SPIRV::OpTypeFunction) 623 .addDef(createTypeVReg(MIRBuilder)) 624 .addUse(getSPIRVTypeID(RetType)); 625 for (const SPIRVType *ArgType : ArgTypes) 626 MIB.addUse(getSPIRVTypeID(ArgType)); 627 return MIB; 628 } 629 630 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeFunctionWithArgs( 631 const Type *Ty, SPIRVType *RetType, 632 const SmallVectorImpl<SPIRVType *> &ArgTypes, 633 MachineIRBuilder &MIRBuilder) { 634 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 635 if (Reg.isValid()) 636 return getSPIRVTypeForVReg(Reg); 637 SPIRVType *SpirvType = getOpTypeFunction(RetType, ArgTypes, MIRBuilder); 638 DT.add(Ty, CurMF, getSPIRVTypeID(SpirvType)); 639 return finishCreatingSPIRVType(Ty, SpirvType); 640 } 641 642 SPIRVType *SPIRVGlobalRegistry::findSPIRVType( 643 const Type *Ty, MachineIRBuilder &MIRBuilder, 644 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) { 645 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 646 if (Reg.isValid()) 647 return getSPIRVTypeForVReg(Reg); 648 if (ForwardPointerTypes.contains(Ty)) 649 return ForwardPointerTypes[Ty]; 650 return restOfCreateSPIRVType(Ty, MIRBuilder, AccQual, EmitIR); 651 } 652 653 Register SPIRVGlobalRegistry::getSPIRVTypeID(const SPIRVType *SpirvType) const { 654 assert(SpirvType && "Attempting to get type id for nullptr type."); 655 if (SpirvType->getOpcode() == SPIRV::OpTypeForwardPointer) 656 return SpirvType->uses().begin()->getReg(); 657 return SpirvType->defs().begin()->getReg(); 658 } 659 660 SPIRVType *SPIRVGlobalRegistry::createSPIRVType( 661 const Type *Ty, MachineIRBuilder &MIRBuilder, 662 SPIRV::AccessQualifier::AccessQualifier AccQual, bool EmitIR) { 663 if (isSpecialOpaqueType(Ty)) 664 return getOrCreateSpecialType(Ty, MIRBuilder, AccQual); 665 auto &TypeToSPIRVTypeMap = DT.getTypes()->getAllUses(); 666 auto t = TypeToSPIRVTypeMap.find(Ty); 667 if (t != TypeToSPIRVTypeMap.end()) { 668 auto tt = t->second.find(&MIRBuilder.getMF()); 669 if (tt != t->second.end()) 670 return getSPIRVTypeForVReg(tt->second); 671 } 672 673 if (auto IType = dyn_cast<IntegerType>(Ty)) { 674 const unsigned Width = IType->getBitWidth(); 675 return Width == 1 ? getOpTypeBool(MIRBuilder) 676 : getOpTypeInt(Width, MIRBuilder, false); 677 } 678 if (Ty->isFloatingPointTy()) 679 return getOpTypeFloat(Ty->getPrimitiveSizeInBits(), MIRBuilder); 680 if (Ty->isVoidTy()) 681 return getOpTypeVoid(MIRBuilder); 682 if (Ty->isVectorTy()) { 683 SPIRVType *El = 684 findSPIRVType(cast<FixedVectorType>(Ty)->getElementType(), MIRBuilder); 685 return getOpTypeVector(cast<FixedVectorType>(Ty)->getNumElements(), El, 686 MIRBuilder); 687 } 688 if (Ty->isArrayTy()) { 689 SPIRVType *El = findSPIRVType(Ty->getArrayElementType(), MIRBuilder); 690 return getOpTypeArray(Ty->getArrayNumElements(), El, MIRBuilder, EmitIR); 691 } 692 if (auto SType = dyn_cast<StructType>(Ty)) { 693 if (SType->isOpaque()) 694 return getOpTypeOpaque(SType, MIRBuilder); 695 return getOpTypeStruct(SType, MIRBuilder, EmitIR); 696 } 697 if (auto FType = dyn_cast<FunctionType>(Ty)) { 698 SPIRVType *RetTy = findSPIRVType(FType->getReturnType(), MIRBuilder); 699 SmallVector<SPIRVType *, 4> ParamTypes; 700 for (const auto &t : FType->params()) { 701 ParamTypes.push_back(findSPIRVType(t, MIRBuilder)); 702 } 703 return getOpTypeFunction(RetTy, ParamTypes, MIRBuilder); 704 } 705 if (auto PType = dyn_cast<PointerType>(Ty)) { 706 SPIRVType *SpvElementType; 707 // At the moment, all opaque pointers correspond to i8 element type. 708 // TODO: change the implementation once opaque pointers are supported 709 // in the SPIR-V specification. 710 SpvElementType = getOrCreateSPIRVIntegerType(8, MIRBuilder); 711 auto SC = addressSpaceToStorageClass(PType->getAddressSpace()); 712 // Null pointer means we have a loop in type definitions, make and 713 // return corresponding OpTypeForwardPointer. 714 if (SpvElementType == nullptr) { 715 if (!ForwardPointerTypes.contains(Ty)) 716 ForwardPointerTypes[PType] = getOpTypeForwardPointer(SC, MIRBuilder); 717 return ForwardPointerTypes[PType]; 718 } 719 Register Reg(0); 720 // If we have forward pointer associated with this type, use its register 721 // operand to create OpTypePointer. 722 if (ForwardPointerTypes.contains(PType)) 723 Reg = getSPIRVTypeID(ForwardPointerTypes[PType]); 724 725 return getOpTypePointer(SC, SpvElementType, MIRBuilder, Reg); 726 } 727 llvm_unreachable("Unable to convert LLVM type to SPIRVType"); 728 } 729 730 SPIRVType *SPIRVGlobalRegistry::restOfCreateSPIRVType( 731 const Type *Ty, MachineIRBuilder &MIRBuilder, 732 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { 733 if (TypesInProcessing.count(Ty) && !Ty->isPointerTy()) 734 return nullptr; 735 TypesInProcessing.insert(Ty); 736 SPIRVType *SpirvType = createSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); 737 TypesInProcessing.erase(Ty); 738 VRegToTypeMap[&MIRBuilder.getMF()][getSPIRVTypeID(SpirvType)] = SpirvType; 739 SPIRVToLLVMType[SpirvType] = Ty; 740 Register Reg = DT.find(Ty, &MIRBuilder.getMF()); 741 // Do not add OpTypeForwardPointer to DT, a corresponding normal pointer type 742 // will be added later. For special types it is already added to DT. 743 if (SpirvType->getOpcode() != SPIRV::OpTypeForwardPointer && !Reg.isValid() && 744 !isSpecialOpaqueType(Ty)) { 745 if (!Ty->isPointerTy()) 746 DT.add(Ty, &MIRBuilder.getMF(), getSPIRVTypeID(SpirvType)); 747 else 748 DT.add(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()), 749 Ty->getPointerAddressSpace(), &MIRBuilder.getMF(), 750 getSPIRVTypeID(SpirvType)); 751 } 752 753 return SpirvType; 754 } 755 756 SPIRVType *SPIRVGlobalRegistry::getSPIRVTypeForVReg(Register VReg) const { 757 auto t = VRegToTypeMap.find(CurMF); 758 if (t != VRegToTypeMap.end()) { 759 auto tt = t->second.find(VReg); 760 if (tt != t->second.end()) 761 return tt->second; 762 } 763 return nullptr; 764 } 765 766 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVType( 767 const Type *Ty, MachineIRBuilder &MIRBuilder, 768 SPIRV::AccessQualifier::AccessQualifier AccessQual, bool EmitIR) { 769 Register Reg; 770 if (!Ty->isPointerTy()) 771 Reg = DT.find(Ty, &MIRBuilder.getMF()); 772 else 773 Reg = 774 DT.find(Type::getInt8Ty(MIRBuilder.getMF().getFunction().getContext()), 775 Ty->getPointerAddressSpace(), &MIRBuilder.getMF()); 776 777 if (Reg.isValid() && !isSpecialOpaqueType(Ty)) 778 return getSPIRVTypeForVReg(Reg); 779 TypesInProcessing.clear(); 780 SPIRVType *STy = restOfCreateSPIRVType(Ty, MIRBuilder, AccessQual, EmitIR); 781 // Create normal pointer types for the corresponding OpTypeForwardPointers. 782 for (auto &CU : ForwardPointerTypes) { 783 const Type *Ty2 = CU.first; 784 SPIRVType *STy2 = CU.second; 785 if ((Reg = DT.find(Ty2, &MIRBuilder.getMF())).isValid()) 786 STy2 = getSPIRVTypeForVReg(Reg); 787 else 788 STy2 = restOfCreateSPIRVType(Ty2, MIRBuilder, AccessQual, EmitIR); 789 if (Ty == Ty2) 790 STy = STy2; 791 } 792 ForwardPointerTypes.clear(); 793 return STy; 794 } 795 796 bool SPIRVGlobalRegistry::isScalarOfType(Register VReg, 797 unsigned TypeOpcode) const { 798 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 799 assert(Type && "isScalarOfType VReg has no type assigned"); 800 return Type->getOpcode() == TypeOpcode; 801 } 802 803 bool SPIRVGlobalRegistry::isScalarOrVectorOfType(Register VReg, 804 unsigned TypeOpcode) const { 805 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 806 assert(Type && "isScalarOrVectorOfType VReg has no type assigned"); 807 if (Type->getOpcode() == TypeOpcode) 808 return true; 809 if (Type->getOpcode() == SPIRV::OpTypeVector) { 810 Register ScalarTypeVReg = Type->getOperand(1).getReg(); 811 SPIRVType *ScalarType = getSPIRVTypeForVReg(ScalarTypeVReg); 812 return ScalarType->getOpcode() == TypeOpcode; 813 } 814 return false; 815 } 816 817 unsigned 818 SPIRVGlobalRegistry::getScalarOrVectorBitWidth(const SPIRVType *Type) const { 819 assert(Type && "Invalid Type pointer"); 820 if (Type->getOpcode() == SPIRV::OpTypeVector) { 821 auto EleTypeReg = Type->getOperand(1).getReg(); 822 Type = getSPIRVTypeForVReg(EleTypeReg); 823 } 824 if (Type->getOpcode() == SPIRV::OpTypeInt || 825 Type->getOpcode() == SPIRV::OpTypeFloat) 826 return Type->getOperand(1).getImm(); 827 if (Type->getOpcode() == SPIRV::OpTypeBool) 828 return 1; 829 llvm_unreachable("Attempting to get bit width of non-integer/float type."); 830 } 831 832 bool SPIRVGlobalRegistry::isScalarOrVectorSigned(const SPIRVType *Type) const { 833 assert(Type && "Invalid Type pointer"); 834 if (Type->getOpcode() == SPIRV::OpTypeVector) { 835 auto EleTypeReg = Type->getOperand(1).getReg(); 836 Type = getSPIRVTypeForVReg(EleTypeReg); 837 } 838 if (Type->getOpcode() == SPIRV::OpTypeInt) 839 return Type->getOperand(2).getImm() != 0; 840 llvm_unreachable("Attempting to get sign of non-integer type."); 841 } 842 843 SPIRV::StorageClass::StorageClass 844 SPIRVGlobalRegistry::getPointerStorageClass(Register VReg) const { 845 SPIRVType *Type = getSPIRVTypeForVReg(VReg); 846 assert(Type && Type->getOpcode() == SPIRV::OpTypePointer && 847 Type->getOperand(1).isImm() && "Pointer type is expected"); 848 return static_cast<SPIRV::StorageClass::StorageClass>( 849 Type->getOperand(1).getImm()); 850 } 851 852 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeImage( 853 MachineIRBuilder &MIRBuilder, SPIRVType *SampledType, SPIRV::Dim::Dim Dim, 854 uint32_t Depth, uint32_t Arrayed, uint32_t Multisampled, uint32_t Sampled, 855 SPIRV::ImageFormat::ImageFormat ImageFormat, 856 SPIRV::AccessQualifier::AccessQualifier AccessQual) { 857 SPIRV::ImageTypeDescriptor TD(SPIRVToLLVMType.lookup(SampledType), Dim, Depth, 858 Arrayed, Multisampled, Sampled, ImageFormat, 859 AccessQual); 860 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 861 return Res; 862 Register ResVReg = createTypeVReg(MIRBuilder); 863 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 864 return MIRBuilder.buildInstr(SPIRV::OpTypeImage) 865 .addDef(ResVReg) 866 .addUse(getSPIRVTypeID(SampledType)) 867 .addImm(Dim) 868 .addImm(Depth) // Depth (whether or not it is a Depth image). 869 .addImm(Arrayed) // Arrayed. 870 .addImm(Multisampled) // Multisampled (0 = only single-sample). 871 .addImm(Sampled) // Sampled (0 = usage known at runtime). 872 .addImm(ImageFormat) 873 .addImm(AccessQual); 874 } 875 876 SPIRVType * 877 SPIRVGlobalRegistry::getOrCreateOpTypeSampler(MachineIRBuilder &MIRBuilder) { 878 SPIRV::SamplerTypeDescriptor TD; 879 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 880 return Res; 881 Register ResVReg = createTypeVReg(MIRBuilder); 882 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 883 return MIRBuilder.buildInstr(SPIRV::OpTypeSampler).addDef(ResVReg); 884 } 885 886 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypePipe( 887 MachineIRBuilder &MIRBuilder, 888 SPIRV::AccessQualifier::AccessQualifier AccessQual) { 889 SPIRV::PipeTypeDescriptor TD(AccessQual); 890 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 891 return Res; 892 Register ResVReg = createTypeVReg(MIRBuilder); 893 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 894 return MIRBuilder.buildInstr(SPIRV::OpTypePipe) 895 .addDef(ResVReg) 896 .addImm(AccessQual); 897 } 898 899 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeDeviceEvent( 900 MachineIRBuilder &MIRBuilder) { 901 SPIRV::DeviceEventTypeDescriptor TD; 902 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 903 return Res; 904 Register ResVReg = createTypeVReg(MIRBuilder); 905 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 906 return MIRBuilder.buildInstr(SPIRV::OpTypeDeviceEvent).addDef(ResVReg); 907 } 908 909 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeSampledImage( 910 SPIRVType *ImageType, MachineIRBuilder &MIRBuilder) { 911 SPIRV::SampledImageTypeDescriptor TD( 912 SPIRVToLLVMType.lookup(MIRBuilder.getMF().getRegInfo().getVRegDef( 913 ImageType->getOperand(1).getReg())), 914 ImageType); 915 if (auto *Res = checkSpecialInstr(TD, MIRBuilder)) 916 return Res; 917 Register ResVReg = createTypeVReg(MIRBuilder); 918 DT.add(TD, &MIRBuilder.getMF(), ResVReg); 919 return MIRBuilder.buildInstr(SPIRV::OpTypeSampledImage) 920 .addDef(ResVReg) 921 .addUse(getSPIRVTypeID(ImageType)); 922 } 923 924 SPIRVType *SPIRVGlobalRegistry::getOrCreateOpTypeByOpcode( 925 const Type *Ty, MachineIRBuilder &MIRBuilder, unsigned Opcode) { 926 Register ResVReg = DT.find(Ty, &MIRBuilder.getMF()); 927 if (ResVReg.isValid()) 928 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(ResVReg); 929 ResVReg = createTypeVReg(MIRBuilder); 930 SPIRVType *SpirvTy = MIRBuilder.buildInstr(Opcode).addDef(ResVReg); 931 DT.add(Ty, &MIRBuilder.getMF(), ResVReg); 932 return SpirvTy; 933 } 934 935 const MachineInstr * 936 SPIRVGlobalRegistry::checkSpecialInstr(const SPIRV::SpecialTypeDescriptor &TD, 937 MachineIRBuilder &MIRBuilder) { 938 Register Reg = DT.find(TD, &MIRBuilder.getMF()); 939 if (Reg.isValid()) 940 return MIRBuilder.getMF().getRegInfo().getUniqueVRegDef(Reg); 941 return nullptr; 942 } 943 944 // TODO: maybe use tablegen to implement this. 945 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVTypeByName( 946 StringRef TypeStr, MachineIRBuilder &MIRBuilder, 947 SPIRV::StorageClass::StorageClass SC, 948 SPIRV::AccessQualifier::AccessQualifier AQ) { 949 unsigned VecElts = 0; 950 auto &Ctx = MIRBuilder.getMF().getFunction().getContext(); 951 952 // Parse strings representing either a SPIR-V or OpenCL builtin type. 953 if (hasBuiltinTypePrefix(TypeStr)) 954 return getOrCreateSPIRVType( 955 SPIRV::parseBuiltinTypeNameToTargetExtType(TypeStr.str(), MIRBuilder), 956 MIRBuilder, AQ); 957 958 // Parse type name in either "typeN" or "type vector[N]" format, where 959 // N is the number of elements of the vector. 960 Type *Ty; 961 962 if (TypeStr.starts_with("atomic_")) 963 TypeStr = TypeStr.substr(strlen("atomic_")); 964 965 if (TypeStr.starts_with("void")) { 966 Ty = Type::getVoidTy(Ctx); 967 TypeStr = TypeStr.substr(strlen("void")); 968 } else if (TypeStr.starts_with("bool")) { 969 Ty = Type::getIntNTy(Ctx, 1); 970 TypeStr = TypeStr.substr(strlen("bool")); 971 } else if (TypeStr.starts_with("char") || TypeStr.starts_with("uchar")) { 972 Ty = Type::getInt8Ty(Ctx); 973 TypeStr = TypeStr.starts_with("char") ? TypeStr.substr(strlen("char")) 974 : TypeStr.substr(strlen("uchar")); 975 } else if (TypeStr.starts_with("short") || TypeStr.starts_with("ushort")) { 976 Ty = Type::getInt16Ty(Ctx); 977 TypeStr = TypeStr.starts_with("short") ? TypeStr.substr(strlen("short")) 978 : TypeStr.substr(strlen("ushort")); 979 } else if (TypeStr.starts_with("int") || TypeStr.starts_with("uint")) { 980 Ty = Type::getInt32Ty(Ctx); 981 TypeStr = TypeStr.starts_with("int") ? TypeStr.substr(strlen("int")) 982 : TypeStr.substr(strlen("uint")); 983 } else if (TypeStr.starts_with("long") || TypeStr.starts_with("ulong")) { 984 Ty = Type::getInt64Ty(Ctx); 985 TypeStr = TypeStr.starts_with("long") ? TypeStr.substr(strlen("long")) 986 : TypeStr.substr(strlen("ulong")); 987 } else if (TypeStr.starts_with("half")) { 988 Ty = Type::getHalfTy(Ctx); 989 TypeStr = TypeStr.substr(strlen("half")); 990 } else if (TypeStr.starts_with("float")) { 991 Ty = Type::getFloatTy(Ctx); 992 TypeStr = TypeStr.substr(strlen("float")); 993 } else if (TypeStr.starts_with("double")) { 994 Ty = Type::getDoubleTy(Ctx); 995 TypeStr = TypeStr.substr(strlen("double")); 996 } else 997 llvm_unreachable("Unable to recognize SPIRV type name."); 998 999 auto SpirvTy = getOrCreateSPIRVType(Ty, MIRBuilder, AQ); 1000 1001 // Handle "type*" or "type* vector[N]". 1002 if (TypeStr.starts_with("*")) { 1003 SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC); 1004 TypeStr = TypeStr.substr(strlen("*")); 1005 } 1006 1007 // Handle "typeN*" or "type vector[N]*". 1008 bool IsPtrToVec = TypeStr.consume_back("*"); 1009 1010 if (TypeStr.starts_with(" vector[")) { 1011 TypeStr = TypeStr.substr(strlen(" vector[")); 1012 TypeStr = TypeStr.substr(0, TypeStr.find(']')); 1013 } 1014 TypeStr.getAsInteger(10, VecElts); 1015 if (VecElts > 0) 1016 SpirvTy = getOrCreateSPIRVVectorType(SpirvTy, VecElts, MIRBuilder); 1017 1018 if (IsPtrToVec) 1019 SpirvTy = getOrCreateSPIRVPointerType(SpirvTy, MIRBuilder, SC); 1020 1021 return SpirvTy; 1022 } 1023 1024 SPIRVType * 1025 SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType(unsigned BitWidth, 1026 MachineIRBuilder &MIRBuilder) { 1027 return getOrCreateSPIRVType( 1028 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), BitWidth), 1029 MIRBuilder); 1030 } 1031 1032 SPIRVType *SPIRVGlobalRegistry::finishCreatingSPIRVType(const Type *LLVMTy, 1033 SPIRVType *SpirvType) { 1034 assert(CurMF == SpirvType->getMF()); 1035 VRegToTypeMap[CurMF][getSPIRVTypeID(SpirvType)] = SpirvType; 1036 SPIRVToLLVMType[SpirvType] = LLVMTy; 1037 return SpirvType; 1038 } 1039 1040 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVIntegerType( 1041 unsigned BitWidth, MachineInstr &I, const SPIRVInstrInfo &TII) { 1042 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), BitWidth); 1043 Register Reg = DT.find(LLVMTy, CurMF); 1044 if (Reg.isValid()) 1045 return getSPIRVTypeForVReg(Reg); 1046 MachineBasicBlock &BB = *I.getParent(); 1047 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeInt)) 1048 .addDef(createTypeVReg(CurMF->getRegInfo())) 1049 .addImm(BitWidth) 1050 .addImm(0); 1051 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB)); 1052 return finishCreatingSPIRVType(LLVMTy, MIB); 1053 } 1054 1055 SPIRVType * 1056 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineIRBuilder &MIRBuilder) { 1057 return getOrCreateSPIRVType( 1058 IntegerType::get(MIRBuilder.getMF().getFunction().getContext(), 1), 1059 MIRBuilder); 1060 } 1061 1062 SPIRVType * 1063 SPIRVGlobalRegistry::getOrCreateSPIRVBoolType(MachineInstr &I, 1064 const SPIRVInstrInfo &TII) { 1065 Type *LLVMTy = IntegerType::get(CurMF->getFunction().getContext(), 1); 1066 Register Reg = DT.find(LLVMTy, CurMF); 1067 if (Reg.isValid()) 1068 return getSPIRVTypeForVReg(Reg); 1069 MachineBasicBlock &BB = *I.getParent(); 1070 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeBool)) 1071 .addDef(createTypeVReg(CurMF->getRegInfo())); 1072 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB)); 1073 return finishCreatingSPIRVType(LLVMTy, MIB); 1074 } 1075 1076 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( 1077 SPIRVType *BaseType, unsigned NumElements, MachineIRBuilder &MIRBuilder) { 1078 return getOrCreateSPIRVType( 1079 FixedVectorType::get(const_cast<Type *>(getTypeForSPIRVType(BaseType)), 1080 NumElements), 1081 MIRBuilder); 1082 } 1083 1084 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVVectorType( 1085 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, 1086 const SPIRVInstrInfo &TII) { 1087 Type *LLVMTy = FixedVectorType::get( 1088 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements); 1089 Register Reg = DT.find(LLVMTy, CurMF); 1090 if (Reg.isValid()) 1091 return getSPIRVTypeForVReg(Reg); 1092 MachineBasicBlock &BB = *I.getParent(); 1093 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeVector)) 1094 .addDef(createTypeVReg(CurMF->getRegInfo())) 1095 .addUse(getSPIRVTypeID(BaseType)) 1096 .addImm(NumElements); 1097 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB)); 1098 return finishCreatingSPIRVType(LLVMTy, MIB); 1099 } 1100 1101 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVArrayType( 1102 SPIRVType *BaseType, unsigned NumElements, MachineInstr &I, 1103 const SPIRVInstrInfo &TII) { 1104 Type *LLVMTy = ArrayType::get( 1105 const_cast<Type *>(getTypeForSPIRVType(BaseType)), NumElements); 1106 Register Reg = DT.find(LLVMTy, CurMF); 1107 if (Reg.isValid()) 1108 return getSPIRVTypeForVReg(Reg); 1109 MachineBasicBlock &BB = *I.getParent(); 1110 SPIRVType *SpirvType = getOrCreateSPIRVIntegerType(32, I, TII); 1111 Register Len = getOrCreateConstInt(NumElements, I, SpirvType, TII); 1112 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypeArray)) 1113 .addDef(createTypeVReg(CurMF->getRegInfo())) 1114 .addUse(getSPIRVTypeID(BaseType)) 1115 .addUse(Len); 1116 DT.add(LLVMTy, CurMF, getSPIRVTypeID(MIB)); 1117 return finishCreatingSPIRVType(LLVMTy, MIB); 1118 } 1119 1120 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( 1121 SPIRVType *BaseType, MachineIRBuilder &MIRBuilder, 1122 SPIRV::StorageClass::StorageClass SC) { 1123 const Type *PointerElementType = getTypeForSPIRVType(BaseType); 1124 unsigned AddressSpace = storageClassToAddressSpace(SC); 1125 Type *LLVMTy = 1126 PointerType::get(const_cast<Type *>(PointerElementType), AddressSpace); 1127 Register Reg = DT.find(PointerElementType, AddressSpace, CurMF); 1128 if (Reg.isValid()) 1129 return getSPIRVTypeForVReg(Reg); 1130 auto MIB = BuildMI(MIRBuilder.getMBB(), MIRBuilder.getInsertPt(), 1131 MIRBuilder.getDebugLoc(), 1132 MIRBuilder.getTII().get(SPIRV::OpTypePointer)) 1133 .addDef(createTypeVReg(CurMF->getRegInfo())) 1134 .addImm(static_cast<uint32_t>(SC)) 1135 .addUse(getSPIRVTypeID(BaseType)); 1136 DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB)); 1137 return finishCreatingSPIRVType(LLVMTy, MIB); 1138 } 1139 1140 SPIRVType *SPIRVGlobalRegistry::getOrCreateSPIRVPointerType( 1141 SPIRVType *BaseType, MachineInstr &I, const SPIRVInstrInfo &TII, 1142 SPIRV::StorageClass::StorageClass SC) { 1143 const Type *PointerElementType = getTypeForSPIRVType(BaseType); 1144 unsigned AddressSpace = storageClassToAddressSpace(SC); 1145 Type *LLVMTy = 1146 PointerType::get(const_cast<Type *>(PointerElementType), AddressSpace); 1147 Register Reg = DT.find(PointerElementType, AddressSpace, CurMF); 1148 if (Reg.isValid()) 1149 return getSPIRVTypeForVReg(Reg); 1150 MachineBasicBlock &BB = *I.getParent(); 1151 auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpTypePointer)) 1152 .addDef(createTypeVReg(CurMF->getRegInfo())) 1153 .addImm(static_cast<uint32_t>(SC)) 1154 .addUse(getSPIRVTypeID(BaseType)); 1155 DT.add(PointerElementType, AddressSpace, CurMF, getSPIRVTypeID(MIB)); 1156 return finishCreatingSPIRVType(LLVMTy, MIB); 1157 } 1158 1159 Register SPIRVGlobalRegistry::getOrCreateUndef(MachineInstr &I, 1160 SPIRVType *SpvType, 1161 const SPIRVInstrInfo &TII) { 1162 assert(SpvType); 1163 const Type *LLVMTy = getTypeForSPIRVType(SpvType); 1164 assert(LLVMTy); 1165 // Find a constant in DT or build a new one. 1166 UndefValue *UV = UndefValue::get(const_cast<Type *>(LLVMTy)); 1167 Register Res = DT.find(UV, CurMF); 1168 if (Res.isValid()) 1169 return Res; 1170 LLT LLTy = LLT::scalar(32); 1171 Res = CurMF->getRegInfo().createGenericVirtualRegister(LLTy); 1172 CurMF->getRegInfo().setRegClass(Res, &SPIRV::IDRegClass); 1173 assignSPIRVTypeToVReg(SpvType, Res, *CurMF); 1174 DT.add(UV, CurMF, Res); 1175 1176 MachineInstrBuilder MIB; 1177 MIB = BuildMI(*I.getParent(), I, I.getDebugLoc(), TII.get(SPIRV::OpUndef)) 1178 .addDef(Res) 1179 .addUse(getSPIRVTypeID(SpvType)); 1180 const auto &ST = CurMF->getSubtarget(); 1181 constrainSelectedInstRegOperands(*MIB, *ST.getInstrInfo(), 1182 *ST.getRegisterInfo(), *ST.getRegBankInfo()); 1183 return Res; 1184 } 1185