1 //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lowering of LLVM calls to machine code calls for
10 // GlobalISel.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "SPIRVCallLowering.h"
15 #include "MCTargetDesc/SPIRVBaseInfo.h"
16 #include "SPIRV.h"
17 #include "SPIRVBuiltins.h"
18 #include "SPIRVGlobalRegistry.h"
19 #include "SPIRVISelLowering.h"
20 #include "SPIRVRegisterInfo.h"
21 #include "SPIRVSubtarget.h"
22 #include "SPIRVUtils.h"
23 #include "llvm/CodeGen/FunctionLoweringInfo.h"
24 #include "llvm/Support/ModRef.h"
25 
26 using namespace llvm;
27 
28 SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
29                                      SPIRVGlobalRegistry *GR)
30     : CallLowering(&TLI), GR(GR) {}
31 
32 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
33                                     const Value *Val, ArrayRef<Register> VRegs,
34                                     FunctionLoweringInfo &FLI,
35                                     Register SwiftErrorVReg) const {
36   // Currently all return types should use a single register.
37   // TODO: handle the case of multiple registers.
38   if (VRegs.size() > 1)
39     return false;
40   if (Val) {
41     const auto &STI = MIRBuilder.getMF().getSubtarget();
42     return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
43         .addUse(VRegs[0])
44         .constrainAllUses(MIRBuilder.getTII(), *STI.getRegisterInfo(),
45                           *STI.getRegBankInfo());
46   }
47   MIRBuilder.buildInstr(SPIRV::OpReturn);
48   return true;
49 }
50 
51 // Based on the LLVM function attributes, get a SPIR-V FunctionControl.
52 static uint32_t getFunctionControl(const Function &F) {
53   MemoryEffects MemEffects = F.getMemoryEffects();
54 
55   uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
56 
57   if (F.hasFnAttribute(Attribute::AttrKind::NoInline))
58     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
59   else if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline))
60     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
61 
62   if (MemEffects.doesNotAccessMemory())
63     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
64   else if (MemEffects.onlyReadsMemory())
65     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
66 
67   return FuncControl;
68 }
69 
70 static ConstantInt *getConstInt(MDNode *MD, unsigned NumOp) {
71   if (MD->getNumOperands() > NumOp) {
72     auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(NumOp));
73     if (CMeta)
74       return dyn_cast<ConstantInt>(CMeta->getValue());
75   }
76   return nullptr;
77 }
78 
79 // This code restores function args/retvalue types for composite cases
80 // because the final types should still be aggregate whereas they're i32
81 // during the translation to cope with aggregate flattening etc.
82 static FunctionType *getOriginalFunctionType(const Function &F) {
83   auto *NamedMD = F.getParent()->getNamedMetadata("spv.cloned_funcs");
84   if (NamedMD == nullptr)
85     return F.getFunctionType();
86 
87   Type *RetTy = F.getFunctionType()->getReturnType();
88   SmallVector<Type *, 4> ArgTypes;
89   for (auto &Arg : F.args())
90     ArgTypes.push_back(Arg.getType());
91 
92   auto ThisFuncMDIt =
93       std::find_if(NamedMD->op_begin(), NamedMD->op_end(), [&F](MDNode *N) {
94         return isa<MDString>(N->getOperand(0)) &&
95                cast<MDString>(N->getOperand(0))->getString() == F.getName();
96       });
97   // TODO: probably one function can have numerous type mutations,
98   // so we should support this.
99   if (ThisFuncMDIt != NamedMD->op_end()) {
100     auto *ThisFuncMD = *ThisFuncMDIt;
101     MDNode *MD = dyn_cast<MDNode>(ThisFuncMD->getOperand(1));
102     assert(MD && "MDNode operand is expected");
103     ConstantInt *Const = getConstInt(MD, 0);
104     if (Const) {
105       auto *CMeta = dyn_cast<ConstantAsMetadata>(MD->getOperand(1));
106       assert(CMeta && "ConstantAsMetadata operand is expected");
107       assert(Const->getSExtValue() >= -1);
108       // Currently -1 indicates return value, greater values mean
109       // argument numbers.
110       if (Const->getSExtValue() == -1)
111         RetTy = CMeta->getType();
112       else
113         ArgTypes[Const->getSExtValue()] = CMeta->getType();
114     }
115   }
116 
117   return FunctionType::get(RetTy, ArgTypes, F.isVarArg());
118 }
119 
120 static MDString *getKernelArgAttribute(const Function &KernelFunction,
121                                        unsigned ArgIdx,
122                                        const StringRef AttributeName) {
123   assert(KernelFunction.getCallingConv() == CallingConv::SPIR_KERNEL &&
124          "Kernel attributes are attached/belong only to kernel functions");
125 
126   // Lookup the argument attribute in metadata attached to the kernel function.
127   MDNode *Node = KernelFunction.getMetadata(AttributeName);
128   if (Node && ArgIdx < Node->getNumOperands())
129     return cast<MDString>(Node->getOperand(ArgIdx));
130 
131   // Sometimes metadata containing kernel attributes is not attached to the
132   // function, but can be found in the named module-level metadata instead.
133   // For example:
134   //   !opencl.kernels = !{!0}
135   //   !0 = !{void ()* @someKernelFunction, !1, ...}
136   //   !1 = !{!"kernel_arg_addr_space", ...}
137   // In this case the actual index of searched argument attribute is ArgIdx + 1,
138   // since the first metadata node operand is occupied by attribute name
139   // ("kernel_arg_addr_space" in the example above).
140   unsigned MDArgIdx = ArgIdx + 1;
141   NamedMDNode *OpenCLKernelsMD =
142       KernelFunction.getParent()->getNamedMetadata("opencl.kernels");
143   if (!OpenCLKernelsMD || OpenCLKernelsMD->getNumOperands() == 0)
144     return nullptr;
145 
146   // KernelToMDNodeList contains kernel function declarations followed by
147   // corresponding MDNodes for each attribute. Search only MDNodes "belonging"
148   // to the currently lowered kernel function.
149   MDNode *KernelToMDNodeList = OpenCLKernelsMD->getOperand(0);
150   bool FoundLoweredKernelFunction = false;
151   for (const MDOperand &Operand : KernelToMDNodeList->operands()) {
152     ValueAsMetadata *MaybeValue = dyn_cast<ValueAsMetadata>(Operand);
153     if (MaybeValue && dyn_cast<Function>(MaybeValue->getValue())->getName() ==
154                           KernelFunction.getName()) {
155       FoundLoweredKernelFunction = true;
156       continue;
157     }
158     if (MaybeValue && FoundLoweredKernelFunction)
159       return nullptr;
160 
161     MDNode *MaybeNode = dyn_cast<MDNode>(Operand);
162     if (FoundLoweredKernelFunction && MaybeNode &&
163         cast<MDString>(MaybeNode->getOperand(0))->getString() ==
164             AttributeName &&
165         MDArgIdx < MaybeNode->getNumOperands())
166       return cast<MDString>(MaybeNode->getOperand(MDArgIdx));
167   }
168   return nullptr;
169 }
170 
171 static SPIRV::AccessQualifier::AccessQualifier
172 getArgAccessQual(const Function &F, unsigned ArgIdx) {
173   if (F.getCallingConv() != CallingConv::SPIR_KERNEL)
174     return SPIRV::AccessQualifier::ReadWrite;
175 
176   MDString *ArgAttribute =
177       getKernelArgAttribute(F, ArgIdx, "kernel_arg_access_qual");
178   if (!ArgAttribute)
179     return SPIRV::AccessQualifier::ReadWrite;
180 
181   if (ArgAttribute->getString().compare("read_only") == 0)
182     return SPIRV::AccessQualifier::ReadOnly;
183   if (ArgAttribute->getString().compare("write_only") == 0)
184     return SPIRV::AccessQualifier::WriteOnly;
185   return SPIRV::AccessQualifier::ReadWrite;
186 }
187 
188 static std::vector<SPIRV::Decoration::Decoration>
189 getKernelArgTypeQual(const Function &KernelFunction, unsigned ArgIdx) {
190   MDString *ArgAttribute =
191       getKernelArgAttribute(KernelFunction, ArgIdx, "kernel_arg_type_qual");
192   if (ArgAttribute && ArgAttribute->getString().compare("volatile") == 0)
193     return {SPIRV::Decoration::Volatile};
194   return {};
195 }
196 
197 static SPIRVType *getArgSPIRVType(const Function &F, unsigned ArgIdx,
198                                   SPIRVGlobalRegistry *GR,
199                                   MachineIRBuilder &MIRBuilder) {
200   // Read argument's access qualifier from metadata or default.
201   SPIRV::AccessQualifier::AccessQualifier ArgAccessQual =
202       getArgAccessQual(F, ArgIdx);
203 
204   Type *OriginalArgType = getOriginalFunctionType(F)->getParamType(ArgIdx);
205 
206   // In case of non-kernel SPIR-V function or already TargetExtType, use the
207   // original IR type.
208   if (F.getCallingConv() != CallingConv::SPIR_KERNEL ||
209       isSpecialOpaqueType(OriginalArgType))
210     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
211 
212   MDString *MDKernelArgType =
213       getKernelArgAttribute(F, ArgIdx, "kernel_arg_type");
214   if (!MDKernelArgType || (!MDKernelArgType->getString().ends_with("*") &&
215                            !MDKernelArgType->getString().ends_with("_t")))
216     return GR->getOrCreateSPIRVType(OriginalArgType, MIRBuilder, ArgAccessQual);
217 
218   if (MDKernelArgType->getString().ends_with("*"))
219     return GR->getOrCreateSPIRVTypeByName(
220         MDKernelArgType->getString(), MIRBuilder,
221         addressSpaceToStorageClass(OriginalArgType->getPointerAddressSpace()));
222 
223   if (MDKernelArgType->getString().ends_with("_t"))
224     return GR->getOrCreateSPIRVTypeByName(
225         "opencl." + MDKernelArgType->getString().str(), MIRBuilder,
226         SPIRV::StorageClass::Function, ArgAccessQual);
227 
228   llvm_unreachable("Unable to recognize argument type name.");
229 }
230 
231 static bool isEntryPoint(const Function &F) {
232   // OpenCL handling: any function with the SPIR_KERNEL
233   // calling convention will be a potential entry point.
234   if (F.getCallingConv() == CallingConv::SPIR_KERNEL)
235     return true;
236 
237   // HLSL handling: special attribute are emitted from the
238   // front-end.
239   if (F.getFnAttribute("hlsl.shader").isValid())
240     return true;
241 
242   return false;
243 }
244 
245 static SPIRV::ExecutionModel::ExecutionModel
246 getExecutionModel(const SPIRVSubtarget &STI, const Function &F) {
247   if (STI.isOpenCLEnv())
248     return SPIRV::ExecutionModel::Kernel;
249 
250   auto attribute = F.getFnAttribute("hlsl.shader");
251   if (!attribute.isValid()) {
252     report_fatal_error(
253         "This entry point lacks mandatory hlsl.shader attribute.");
254   }
255 
256   const auto value = attribute.getValueAsString();
257   if (value == "compute")
258     return SPIRV::ExecutionModel::GLCompute;
259 
260   report_fatal_error("This HLSL entry point is not supported by this backend.");
261 }
262 
263 bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
264                                              const Function &F,
265                                              ArrayRef<ArrayRef<Register>> VRegs,
266                                              FunctionLoweringInfo &FLI) const {
267   assert(GR && "Must initialize the SPIRV type registry before lowering args.");
268   GR->setCurrentFunc(MIRBuilder.getMF());
269 
270   // Assign types and names to all args, and store their types for later.
271   FunctionType *FTy = getOriginalFunctionType(F);
272   SmallVector<SPIRVType *, 4> ArgTypeVRegs;
273   if (VRegs.size() > 0) {
274     unsigned i = 0;
275     for (const auto &Arg : F.args()) {
276       // Currently formal args should use single registers.
277       // TODO: handle the case of multiple registers.
278       if (VRegs[i].size() > 1)
279         return false;
280       auto *SpirvTy = getArgSPIRVType(F, i, GR, MIRBuilder);
281       GR->assignSPIRVTypeToVReg(SpirvTy, VRegs[i][0], MIRBuilder.getMF());
282       ArgTypeVRegs.push_back(SpirvTy);
283 
284       if (Arg.hasName())
285         buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
286       if (Arg.getType()->isPointerTy()) {
287         auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
288         if (DerefBytes != 0)
289           buildOpDecorate(VRegs[i][0], MIRBuilder,
290                           SPIRV::Decoration::MaxByteOffset, {DerefBytes});
291       }
292       if (Arg.hasAttribute(Attribute::Alignment)) {
293         auto Alignment = static_cast<unsigned>(
294             Arg.getAttribute(Attribute::Alignment).getValueAsInt());
295         buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
296                         {Alignment});
297       }
298       if (Arg.hasAttribute(Attribute::ReadOnly)) {
299         auto Attr =
300             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
301         buildOpDecorate(VRegs[i][0], MIRBuilder,
302                         SPIRV::Decoration::FuncParamAttr, {Attr});
303       }
304       if (Arg.hasAttribute(Attribute::ZExt)) {
305         auto Attr =
306             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
307         buildOpDecorate(VRegs[i][0], MIRBuilder,
308                         SPIRV::Decoration::FuncParamAttr, {Attr});
309       }
310       if (Arg.hasAttribute(Attribute::NoAlias)) {
311         auto Attr =
312             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoAlias);
313         buildOpDecorate(VRegs[i][0], MIRBuilder,
314                         SPIRV::Decoration::FuncParamAttr, {Attr});
315       }
316 
317       if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
318         std::vector<SPIRV::Decoration::Decoration> ArgTypeQualDecs =
319             getKernelArgTypeQual(F, i);
320         for (SPIRV::Decoration::Decoration Decoration : ArgTypeQualDecs)
321           buildOpDecorate(VRegs[i][0], MIRBuilder, Decoration, {});
322       }
323 
324       MDNode *Node = F.getMetadata("spirv.ParameterDecorations");
325       if (Node && i < Node->getNumOperands() &&
326           isa<MDNode>(Node->getOperand(i))) {
327         MDNode *MD = cast<MDNode>(Node->getOperand(i));
328         for (const MDOperand &MDOp : MD->operands()) {
329           MDNode *MD2 = dyn_cast<MDNode>(MDOp);
330           assert(MD2 && "Metadata operand is expected");
331           ConstantInt *Const = getConstInt(MD2, 0);
332           assert(Const && "MDOperand should be ConstantInt");
333           auto Dec =
334               static_cast<SPIRV::Decoration::Decoration>(Const->getZExtValue());
335           std::vector<uint32_t> DecVec;
336           for (unsigned j = 1; j < MD2->getNumOperands(); j++) {
337             ConstantInt *Const = getConstInt(MD2, j);
338             assert(Const && "MDOperand should be ConstantInt");
339             DecVec.push_back(static_cast<uint32_t>(Const->getZExtValue()));
340           }
341           buildOpDecorate(VRegs[i][0], MIRBuilder, Dec, DecVec);
342         }
343       }
344       ++i;
345     }
346   }
347 
348   // Generate a SPIR-V type for the function.
349   auto MRI = MIRBuilder.getMRI();
350   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
351   MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
352   if (F.isDeclaration())
353     GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
354   SPIRVType *RetTy = GR->getOrCreateSPIRVType(FTy->getReturnType(), MIRBuilder);
355   SPIRVType *FuncTy = GR->getOrCreateOpTypeFunctionWithArgs(
356       FTy, RetTy, ArgTypeVRegs, MIRBuilder);
357 
358   // Build the OpTypeFunction declaring it.
359   uint32_t FuncControl = getFunctionControl(F);
360 
361   MIRBuilder.buildInstr(SPIRV::OpFunction)
362       .addDef(FuncVReg)
363       .addUse(GR->getSPIRVTypeID(RetTy))
364       .addImm(FuncControl)
365       .addUse(GR->getSPIRVTypeID(FuncTy));
366 
367   // Add OpFunctionParameters.
368   int i = 0;
369   for (const auto &Arg : F.args()) {
370     assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
371     MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
372     MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
373         .addDef(VRegs[i][0])
374         .addUse(GR->getSPIRVTypeID(ArgTypeVRegs[i]));
375     if (F.isDeclaration())
376       GR->add(&Arg, &MIRBuilder.getMF(), VRegs[i][0]);
377     i++;
378   }
379   // Name the function.
380   if (F.hasName())
381     buildOpName(FuncVReg, F.getName(), MIRBuilder);
382 
383   // Handle entry points and function linkage.
384   if (isEntryPoint(F)) {
385     const auto &STI = MIRBuilder.getMF().getSubtarget<SPIRVSubtarget>();
386     auto executionModel = getExecutionModel(STI, F);
387     auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
388                    .addImm(static_cast<uint32_t>(executionModel))
389                    .addUse(FuncVReg);
390     addStringImm(F.getName(), MIB);
391   } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage ||
392              F.getLinkage() == GlobalValue::LinkOnceODRLinkage) {
393     auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import
394                                    : SPIRV::LinkageType::Export;
395     buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
396                     {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
397   }
398 
399   return true;
400 }
401 
402 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
403                                   CallLoweringInfo &Info) const {
404   // Currently call returns should have single vregs.
405   // TODO: handle the case of multiple registers.
406   if (Info.OrigRet.Regs.size() > 1)
407     return false;
408   MachineFunction &MF = MIRBuilder.getMF();
409   GR->setCurrentFunc(MF);
410   FunctionType *FTy = nullptr;
411   const Function *CF = nullptr;
412 
413   // Emit a regular OpFunctionCall. If it's an externally declared function,
414   // be sure to emit its type and function declaration here. It will be hoisted
415   // globally later.
416   if (Info.Callee.isGlobal()) {
417     CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
418     // TODO: support constexpr casts and indirect calls.
419     if (CF == nullptr)
420       return false;
421     FTy = getOriginalFunctionType(*CF);
422   }
423 
424   MachineRegisterInfo *MRI = MIRBuilder.getMRI();
425   Register ResVReg =
426       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
427   std::string FuncName = Info.Callee.getGlobal()->getName().str();
428   std::string DemangledName = getOclOrSpirvBuiltinDemangledName(FuncName);
429   const auto *ST = static_cast<const SPIRVSubtarget *>(&MF.getSubtarget());
430   // TODO: check that it's OCL builtin, then apply OpenCL_std.
431   if (!DemangledName.empty() && CF && CF->isDeclaration() &&
432       ST->canUseExtInstSet(SPIRV::InstructionSet::OpenCL_std)) {
433     const Type *OrigRetTy = Info.OrigRet.Ty;
434     if (FTy)
435       OrigRetTy = FTy->getReturnType();
436     SmallVector<Register, 8> ArgVRegs;
437     for (auto Arg : Info.OrigArgs) {
438       assert(Arg.Regs.size() == 1 && "Call arg has multiple VRegs");
439       ArgVRegs.push_back(Arg.Regs[0]);
440       SPIRVType *SPIRVTy = GR->getOrCreateSPIRVType(Arg.Ty, MIRBuilder);
441       if (!GR->getSPIRVTypeForVReg(Arg.Regs[0]))
442         GR->assignSPIRVTypeToVReg(SPIRVTy, Arg.Regs[0], MIRBuilder.getMF());
443     }
444     if (auto Res = SPIRV::lowerBuiltin(
445             DemangledName, SPIRV::InstructionSet::OpenCL_std, MIRBuilder,
446             ResVReg, OrigRetTy, ArgVRegs, GR))
447       return *Res;
448   }
449   if (CF && CF->isDeclaration() &&
450       !GR->find(CF, &MIRBuilder.getMF()).isValid()) {
451     // Emit the type info and forward function declaration to the first MBB
452     // to ensure VReg definition dependencies are valid across all MBBs.
453     MachineIRBuilder FirstBlockBuilder;
454     FirstBlockBuilder.setMF(MF);
455     FirstBlockBuilder.setMBB(*MF.getBlockNumbered(0));
456 
457     SmallVector<ArrayRef<Register>, 8> VRegArgs;
458     SmallVector<SmallVector<Register, 1>, 8> ToInsert;
459     for (const Argument &Arg : CF->args()) {
460       if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
461         continue; // Don't handle zero sized types.
462       Register Reg = MRI->createGenericVirtualRegister(LLT::scalar(32));
463       MRI->setRegClass(Reg, &SPIRV::IDRegClass);
464       ToInsert.push_back({Reg});
465       VRegArgs.push_back(ToInsert.back());
466     }
467     // TODO: Reuse FunctionLoweringInfo
468     FunctionLoweringInfo FuncInfo;
469     lowerFormalArguments(FirstBlockBuilder, *CF, VRegArgs, FuncInfo);
470   }
471 
472   // Make sure there's a valid return reg, even for functions returning void.
473   if (!ResVReg.isValid())
474     ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
475   SPIRVType *RetType =
476       GR->assignTypeToVReg(FTy->getReturnType(), ResVReg, MIRBuilder);
477 
478   // Emit the OpFunctionCall and its args.
479   auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
480                  .addDef(ResVReg)
481                  .addUse(GR->getSPIRVTypeID(RetType))
482                  .add(Info.Callee);
483 
484   for (const auto &Arg : Info.OrigArgs) {
485     // Currently call args should have single vregs.
486     if (Arg.Regs.size() > 1)
487       return false;
488     MIB.addUse(Arg.Regs[0]);
489   }
490   return MIB.constrainAllUses(MIRBuilder.getTII(), *ST->getRegisterInfo(),
491                               *ST->getRegBankInfo());
492 }
493