1 //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 // This file implements the lowering of LLVM calls to machine code calls for 10 // GlobalISel. 11 // 12 //===----------------------------------------------------------------------===// 13 14 #include "SPIRVCallLowering.h" 15 #include "MCTargetDesc/SPIRVBaseInfo.h" 16 #include "SPIRV.h" 17 #include "SPIRVGlobalRegistry.h" 18 #include "SPIRVISelLowering.h" 19 #include "SPIRVRegisterInfo.h" 20 #include "SPIRVSubtarget.h" 21 #include "SPIRVUtils.h" 22 #include "llvm/CodeGen/FunctionLoweringInfo.h" 23 24 using namespace llvm; 25 26 SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI, 27 const SPIRVSubtarget &ST, 28 SPIRVGlobalRegistry *GR) 29 : CallLowering(&TLI), ST(ST), GR(GR) {} 30 31 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder, 32 const Value *Val, ArrayRef<Register> VRegs, 33 FunctionLoweringInfo &FLI, 34 Register SwiftErrorVReg) const { 35 // Currently all return types should use a single register. 36 // TODO: handle the case of multiple registers. 37 if (VRegs.size() > 1) 38 return false; 39 if (Val) 40 return MIRBuilder.buildInstr(SPIRV::OpReturnValue) 41 .addUse(VRegs[0]) 42 .constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(), 43 *ST.getRegBankInfo()); 44 MIRBuilder.buildInstr(SPIRV::OpReturn); 45 return true; 46 } 47 48 // Based on the LLVM function attributes, get a SPIR-V FunctionControl. 49 static uint32_t getFunctionControl(const Function &F) { 50 uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None); 51 if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) { 52 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline); 53 } 54 if (F.hasFnAttribute(Attribute::AttrKind::ReadNone)) { 55 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure); 56 } 57 if (F.hasFnAttribute(Attribute::AttrKind::ReadOnly)) { 58 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const); 59 } 60 if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) { 61 FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline); 62 } 63 return FuncControl; 64 } 65 66 bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder, 67 const Function &F, 68 ArrayRef<ArrayRef<Register>> VRegs, 69 FunctionLoweringInfo &FLI) const { 70 assert(GR && "Must initialize the SPIRV type registry before lowering args."); 71 GR->setCurrentFunc(MIRBuilder.getMF()); 72 73 // Assign types and names to all args, and store their types for later. 74 SmallVector<Register, 4> ArgTypeVRegs; 75 if (VRegs.size() > 0) { 76 unsigned i = 0; 77 for (const auto &Arg : F.args()) { 78 // Currently formal args should use single registers. 79 // TODO: handle the case of multiple registers. 80 if (VRegs[i].size() > 1) 81 return false; 82 auto *SpirvTy = 83 GR->assignTypeToVReg(Arg.getType(), VRegs[i][0], MIRBuilder); 84 ArgTypeVRegs.push_back(GR->getSPIRVTypeID(SpirvTy)); 85 86 if (Arg.hasName()) 87 buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder); 88 if (Arg.getType()->isPointerTy()) { 89 auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes()); 90 if (DerefBytes != 0) 91 buildOpDecorate(VRegs[i][0], MIRBuilder, 92 SPIRV::Decoration::MaxByteOffset, {DerefBytes}); 93 } 94 if (Arg.hasAttribute(Attribute::Alignment)) { 95 buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment, 96 {static_cast<unsigned>(Arg.getParamAlignment())}); 97 } 98 if (Arg.hasAttribute(Attribute::ReadOnly)) { 99 auto Attr = 100 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite); 101 buildOpDecorate(VRegs[i][0], MIRBuilder, 102 SPIRV::Decoration::FuncParamAttr, {Attr}); 103 } 104 if (Arg.hasAttribute(Attribute::ZExt)) { 105 auto Attr = 106 static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext); 107 buildOpDecorate(VRegs[i][0], MIRBuilder, 108 SPIRV::Decoration::FuncParamAttr, {Attr}); 109 } 110 ++i; 111 } 112 } 113 114 // Generate a SPIR-V type for the function. 115 auto MRI = MIRBuilder.getMRI(); 116 Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32)); 117 MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass); 118 if (F.isDeclaration()) 119 GR->add(&F, &MIRBuilder.getMF(), FuncVReg); 120 121 auto *FTy = F.getFunctionType(); 122 auto FuncTy = GR->assignTypeToVReg(FTy, FuncVReg, MIRBuilder); 123 124 // Build the OpTypeFunction declaring it. 125 Register ReturnTypeID = FuncTy->getOperand(1).getReg(); 126 uint32_t FuncControl = getFunctionControl(F); 127 128 MIRBuilder.buildInstr(SPIRV::OpFunction) 129 .addDef(FuncVReg) 130 .addUse(ReturnTypeID) 131 .addImm(FuncControl) 132 .addUse(GR->getSPIRVTypeID(FuncTy)); 133 134 // Add OpFunctionParameters. 135 const unsigned NumArgs = ArgTypeVRegs.size(); 136 for (unsigned i = 0; i < NumArgs; ++i) { 137 assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs"); 138 MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass); 139 MIRBuilder.buildInstr(SPIRV::OpFunctionParameter) 140 .addDef(VRegs[i][0]) 141 .addUse(ArgTypeVRegs[i]); 142 if (F.isDeclaration()) 143 GR->add(F.getArg(i), &MIRBuilder.getMF(), VRegs[i][0]); 144 } 145 // Name the function. 146 if (F.hasName()) 147 buildOpName(FuncVReg, F.getName(), MIRBuilder); 148 149 // Handle entry points and function linkage. 150 if (F.getCallingConv() == CallingConv::SPIR_KERNEL) { 151 auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint) 152 .addImm(static_cast<uint32_t>(SPIRV::ExecutionModel::Kernel)) 153 .addUse(FuncVReg); 154 addStringImm(F.getName(), MIB); 155 } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage || 156 F.getLinkage() == GlobalValue::LinkOnceODRLinkage) { 157 auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import 158 : SPIRV::LinkageType::Export; 159 buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes, 160 {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier()); 161 } 162 163 return true; 164 } 165 166 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder, 167 CallLoweringInfo &Info) const { 168 // Currently call returns should have single vregs. 169 // TODO: handle the case of multiple registers. 170 if (Info.OrigRet.Regs.size() > 1) 171 return false; 172 173 GR->setCurrentFunc(MIRBuilder.getMF()); 174 Register ResVReg = 175 Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0]; 176 // Emit a regular OpFunctionCall. If it's an externally declared function, 177 // be sure to emit its type and function declaration here. It will be 178 // hoisted globally later. 179 if (Info.Callee.isGlobal()) { 180 auto *CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal()); 181 // TODO: support constexpr casts and indirect calls. 182 if (CF == nullptr) 183 return false; 184 if (CF->isDeclaration()) { 185 // Emit the type info and forward function declaration to the first MBB 186 // to ensure VReg definition dependencies are valid across all MBBs. 187 MachineBasicBlock::iterator OldII = MIRBuilder.getInsertPt(); 188 MachineBasicBlock &OldBB = MIRBuilder.getMBB(); 189 MachineBasicBlock &FirstBB = *MIRBuilder.getMF().getBlockNumbered(0); 190 MIRBuilder.setInsertPt(FirstBB, FirstBB.instr_end()); 191 192 SmallVector<ArrayRef<Register>, 8> VRegArgs; 193 SmallVector<SmallVector<Register, 1>, 8> ToInsert; 194 for (const Argument &Arg : CF->args()) { 195 if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero()) 196 continue; // Don't handle zero sized types. 197 ToInsert.push_back({MIRBuilder.getMRI()->createGenericVirtualRegister( 198 LLT::scalar(32))}); 199 VRegArgs.push_back(ToInsert.back()); 200 } 201 // TODO: Reuse FunctionLoweringInfo. 202 FunctionLoweringInfo FuncInfo; 203 lowerFormalArguments(MIRBuilder, *CF, VRegArgs, FuncInfo); 204 MIRBuilder.setInsertPt(OldBB, OldII); 205 } 206 } 207 208 // Make sure there's a valid return reg, even for functions returning void. 209 if (!ResVReg.isValid()) { 210 ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass); 211 } 212 SPIRVType *RetType = 213 GR->assignTypeToVReg(Info.OrigRet.Ty, ResVReg, MIRBuilder); 214 215 // Emit the OpFunctionCall and its args. 216 auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall) 217 .addDef(ResVReg) 218 .addUse(GR->getSPIRVTypeID(RetType)) 219 .add(Info.Callee); 220 221 for (const auto &Arg : Info.OrigArgs) { 222 // Currently call args should have single vregs. 223 if (Arg.Regs.size() > 1) 224 return false; 225 MIB.addUse(Arg.Regs[0]); 226 } 227 return MIB.constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(), 228 *ST.getRegBankInfo()); 229 } 230