1 //===--- SPIRVCallLowering.cpp - Call lowering ------------------*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements the lowering of LLVM calls to machine code calls for
10 // GlobalISel.
11 //
12 //===----------------------------------------------------------------------===//
13 
14 #include "SPIRVCallLowering.h"
15 #include "MCTargetDesc/SPIRVBaseInfo.h"
16 #include "SPIRV.h"
17 #include "SPIRVGlobalRegistry.h"
18 #include "SPIRVISelLowering.h"
19 #include "SPIRVRegisterInfo.h"
20 #include "SPIRVSubtarget.h"
21 #include "SPIRVUtils.h"
22 #include "llvm/CodeGen/FunctionLoweringInfo.h"
23 
24 using namespace llvm;
25 
26 SPIRVCallLowering::SPIRVCallLowering(const SPIRVTargetLowering &TLI,
27                                      const SPIRVSubtarget &ST,
28                                      SPIRVGlobalRegistry *GR)
29     : CallLowering(&TLI), ST(ST), GR(GR) {}
30 
31 bool SPIRVCallLowering::lowerReturn(MachineIRBuilder &MIRBuilder,
32                                     const Value *Val, ArrayRef<Register> VRegs,
33                                     FunctionLoweringInfo &FLI,
34                                     Register SwiftErrorVReg) const {
35   // Currently all return types should use a single register.
36   // TODO: handle the case of multiple registers.
37   if (VRegs.size() > 1)
38     return false;
39   if (Val)
40     return MIRBuilder.buildInstr(SPIRV::OpReturnValue)
41         .addUse(VRegs[0])
42         .constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(),
43                           *ST.getRegBankInfo());
44   MIRBuilder.buildInstr(SPIRV::OpReturn);
45   return true;
46 }
47 
48 // Based on the LLVM function attributes, get a SPIR-V FunctionControl.
49 static uint32_t getFunctionControl(const Function &F) {
50   uint32_t FuncControl = static_cast<uint32_t>(SPIRV::FunctionControl::None);
51   if (F.hasFnAttribute(Attribute::AttrKind::AlwaysInline)) {
52     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Inline);
53   }
54   if (F.hasFnAttribute(Attribute::AttrKind::ReadNone)) {
55     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Pure);
56   }
57   if (F.hasFnAttribute(Attribute::AttrKind::ReadOnly)) {
58     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::Const);
59   }
60   if (F.hasFnAttribute(Attribute::AttrKind::NoInline)) {
61     FuncControl |= static_cast<uint32_t>(SPIRV::FunctionControl::DontInline);
62   }
63   return FuncControl;
64 }
65 
66 bool SPIRVCallLowering::lowerFormalArguments(MachineIRBuilder &MIRBuilder,
67                                              const Function &F,
68                                              ArrayRef<ArrayRef<Register>> VRegs,
69                                              FunctionLoweringInfo &FLI) const {
70   assert(GR && "Must initialize the SPIRV type registry before lowering args.");
71   GR->setCurrentFunc(MIRBuilder.getMF());
72 
73   // Assign types and names to all args, and store their types for later.
74   SmallVector<Register, 4> ArgTypeVRegs;
75   if (VRegs.size() > 0) {
76     unsigned i = 0;
77     for (const auto &Arg : F.args()) {
78       // Currently formal args should use single registers.
79       // TODO: handle the case of multiple registers.
80       if (VRegs[i].size() > 1)
81         return false;
82       auto *SpirvTy =
83           GR->assignTypeToVReg(Arg.getType(), VRegs[i][0], MIRBuilder);
84       ArgTypeVRegs.push_back(GR->getSPIRVTypeID(SpirvTy));
85 
86       if (Arg.hasName())
87         buildOpName(VRegs[i][0], Arg.getName(), MIRBuilder);
88       if (Arg.getType()->isPointerTy()) {
89         auto DerefBytes = static_cast<unsigned>(Arg.getDereferenceableBytes());
90         if (DerefBytes != 0)
91           buildOpDecorate(VRegs[i][0], MIRBuilder,
92                           SPIRV::Decoration::MaxByteOffset, {DerefBytes});
93       }
94       if (Arg.hasAttribute(Attribute::Alignment)) {
95         buildOpDecorate(VRegs[i][0], MIRBuilder, SPIRV::Decoration::Alignment,
96                         {static_cast<unsigned>(Arg.getParamAlignment())});
97       }
98       if (Arg.hasAttribute(Attribute::ReadOnly)) {
99         auto Attr =
100             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::NoWrite);
101         buildOpDecorate(VRegs[i][0], MIRBuilder,
102                         SPIRV::Decoration::FuncParamAttr, {Attr});
103       }
104       if (Arg.hasAttribute(Attribute::ZExt)) {
105         auto Attr =
106             static_cast<unsigned>(SPIRV::FunctionParameterAttribute::Zext);
107         buildOpDecorate(VRegs[i][0], MIRBuilder,
108                         SPIRV::Decoration::FuncParamAttr, {Attr});
109       }
110       ++i;
111     }
112   }
113 
114   // Generate a SPIR-V type for the function.
115   auto MRI = MIRBuilder.getMRI();
116   Register FuncVReg = MRI->createGenericVirtualRegister(LLT::scalar(32));
117   MRI->setRegClass(FuncVReg, &SPIRV::IDRegClass);
118   if (F.isDeclaration())
119     GR->add(&F, &MIRBuilder.getMF(), FuncVReg);
120 
121   auto *FTy = F.getFunctionType();
122   auto FuncTy = GR->assignTypeToVReg(FTy, FuncVReg, MIRBuilder);
123 
124   // Build the OpTypeFunction declaring it.
125   Register ReturnTypeID = FuncTy->getOperand(1).getReg();
126   uint32_t FuncControl = getFunctionControl(F);
127 
128   MIRBuilder.buildInstr(SPIRV::OpFunction)
129       .addDef(FuncVReg)
130       .addUse(ReturnTypeID)
131       .addImm(FuncControl)
132       .addUse(GR->getSPIRVTypeID(FuncTy));
133 
134   // Add OpFunctionParameters.
135   const unsigned NumArgs = ArgTypeVRegs.size();
136   for (unsigned i = 0; i < NumArgs; ++i) {
137     assert(VRegs[i].size() == 1 && "Formal arg has multiple vregs");
138     MRI->setRegClass(VRegs[i][0], &SPIRV::IDRegClass);
139     MIRBuilder.buildInstr(SPIRV::OpFunctionParameter)
140         .addDef(VRegs[i][0])
141         .addUse(ArgTypeVRegs[i]);
142     if (F.isDeclaration())
143       GR->add(F.getArg(i), &MIRBuilder.getMF(), VRegs[i][0]);
144   }
145   // Name the function.
146   if (F.hasName())
147     buildOpName(FuncVReg, F.getName(), MIRBuilder);
148 
149   // Handle entry points and function linkage.
150   if (F.getCallingConv() == CallingConv::SPIR_KERNEL) {
151     auto MIB = MIRBuilder.buildInstr(SPIRV::OpEntryPoint)
152                    .addImm(static_cast<uint32_t>(SPIRV::ExecutionModel::Kernel))
153                    .addUse(FuncVReg);
154     addStringImm(F.getName(), MIB);
155   } else if (F.getLinkage() == GlobalValue::LinkageTypes::ExternalLinkage ||
156              F.getLinkage() == GlobalValue::LinkOnceODRLinkage) {
157     auto LnkTy = F.isDeclaration() ? SPIRV::LinkageType::Import
158                                    : SPIRV::LinkageType::Export;
159     buildOpDecorate(FuncVReg, MIRBuilder, SPIRV::Decoration::LinkageAttributes,
160                     {static_cast<uint32_t>(LnkTy)}, F.getGlobalIdentifier());
161   }
162 
163   return true;
164 }
165 
166 bool SPIRVCallLowering::lowerCall(MachineIRBuilder &MIRBuilder,
167                                   CallLoweringInfo &Info) const {
168   // Currently call returns should have single vregs.
169   // TODO: handle the case of multiple registers.
170   if (Info.OrigRet.Regs.size() > 1)
171     return false;
172 
173   GR->setCurrentFunc(MIRBuilder.getMF());
174   Register ResVReg =
175       Info.OrigRet.Regs.empty() ? Register(0) : Info.OrigRet.Regs[0];
176   // Emit a regular OpFunctionCall. If it's an externally declared function,
177   // be sure to emit its type and function declaration here. It will be
178   // hoisted globally later.
179   if (Info.Callee.isGlobal()) {
180     auto *CF = dyn_cast_or_null<const Function>(Info.Callee.getGlobal());
181     // TODO: support constexpr casts and indirect calls.
182     if (CF == nullptr)
183       return false;
184     if (CF->isDeclaration()) {
185       // Emit the type info and forward function declaration to the first MBB
186       // to ensure VReg definition dependencies are valid across all MBBs.
187       MachineBasicBlock::iterator OldII = MIRBuilder.getInsertPt();
188       MachineBasicBlock &OldBB = MIRBuilder.getMBB();
189       MachineBasicBlock &FirstBB = *MIRBuilder.getMF().getBlockNumbered(0);
190       MIRBuilder.setInsertPt(FirstBB, FirstBB.instr_end());
191 
192       SmallVector<ArrayRef<Register>, 8> VRegArgs;
193       SmallVector<SmallVector<Register, 1>, 8> ToInsert;
194       for (const Argument &Arg : CF->args()) {
195         if (MIRBuilder.getDataLayout().getTypeStoreSize(Arg.getType()).isZero())
196           continue; // Don't handle zero sized types.
197         ToInsert.push_back({MIRBuilder.getMRI()->createGenericVirtualRegister(
198             LLT::scalar(32))});
199         VRegArgs.push_back(ToInsert.back());
200       }
201       // TODO: Reuse FunctionLoweringInfo.
202       FunctionLoweringInfo FuncInfo;
203       lowerFormalArguments(MIRBuilder, *CF, VRegArgs, FuncInfo);
204       MIRBuilder.setInsertPt(OldBB, OldII);
205     }
206   }
207 
208   // Make sure there's a valid return reg, even for functions returning void.
209   if (!ResVReg.isValid()) {
210     ResVReg = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::IDRegClass);
211   }
212   SPIRVType *RetType =
213       GR->assignTypeToVReg(Info.OrigRet.Ty, ResVReg, MIRBuilder);
214 
215   // Emit the OpFunctionCall and its args.
216   auto MIB = MIRBuilder.buildInstr(SPIRV::OpFunctionCall)
217                  .addDef(ResVReg)
218                  .addUse(GR->getSPIRVTypeID(RetType))
219                  .add(Info.Callee);
220 
221   for (const auto &Arg : Info.OrigArgs) {
222     // Currently call args should have single vregs.
223     if (Arg.Regs.size() > 1)
224       return false;
225     MIB.addUse(Arg.Regs[0]);
226   }
227   return MIB.constrainAllUses(MIRBuilder.getTII(), *ST.getRegisterInfo(),
228                               *ST.getRegBankInfo());
229 }
230