1 //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass modifies function signatures containing aggregate arguments
10 // and/or return value before IRTranslator. Information about the original
11 // signatures is stored in metadata. It is used during call lowering to
12 // restore correct SPIR-V types of function arguments and return values.
13 // This pass also substitutes some llvm intrinsic calls with calls to newly
14 // generated functions (as the Khronos LLVM/SPIR-V Translator does).
15 //
16 // NOTE: this pass is a module-level one due to the necessity to modify
17 // GVs/functions.
18 //
19 //===----------------------------------------------------------------------===//
20
21 #include "SPIRV.h"
22 #include "SPIRVSubtarget.h"
23 #include "SPIRVTargetMachine.h"
24 #include "SPIRVUtils.h"
25 #include "llvm/CodeGen/IntrinsicLowering.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/IntrinsicInst.h"
28 #include "llvm/IR/Intrinsics.h"
29 #include "llvm/IR/IntrinsicsSPIRV.h"
30 #include "llvm/Transforms/Utils/Cloning.h"
31 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
32
33 using namespace llvm;
34
namespace llvm {
// Declared here (defined by the INITIALIZE_PASS macro below) so the pass
// constructor can register itself with the global PassRegistry.
void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
} // namespace llvm
38
namespace {

/// Module pass that rewrites function signatures containing aggregate
/// argument/return types (recording the originals in metadata for later
/// restoration during call lowering) and substitutes selected LLVM
/// intrinsic calls with calls to generated functions.
class SPIRVPrepareFunctions : public ModulePass {
  // Target machine, used to query the subtarget for extension availability.
  const SPIRVTargetMachine &TM;
  // Replaces supported intrinsic calls in F; returns true if F was modified.
  bool substituteIntrinsicCalls(Function *F);
  // Returns F unchanged, or a clone of F with aggregate argument/return
  // types replaced by i32 (the original F is erased by the caller).
  Function *removeAggregateTypesFromSignature(Function *F);

public:
  static char ID;
  SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) {
    initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override { return "SPIRV prepare functions"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    ModulePass::getAnalysisUsage(AU);
  }
};

} // namespace
62
// Pass identification; the address of ID is the unique pass identifier.
char SPIRVPrepareFunctions::ID = 0;

INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
                "SPIRV prepare functions", false, false)
67
lowerLLVMIntrinsicName(IntrinsicInst * II)68 std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
69 Function *IntrinsicFunc = II->getCalledFunction();
70 assert(IntrinsicFunc && "Missing function");
71 std::string FuncName = IntrinsicFunc->getName().str();
72 std::replace(FuncName.begin(), FuncName.end(), '.', '_');
73 FuncName = "spirv." + FuncName;
74 return FuncName;
75 }
76
getOrCreateFunction(Module * M,Type * RetTy,ArrayRef<Type * > ArgTypes,StringRef Name)77 static Function *getOrCreateFunction(Module *M, Type *RetTy,
78 ArrayRef<Type *> ArgTypes,
79 StringRef Name) {
80 FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
81 Function *F = M->getFunction(Name);
82 if (F && F->getFunctionType() == FT)
83 return F;
84 Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
85 if (F)
86 NewF->setDSOLocal(F->isDSOLocal());
87 NewF->setCallingConv(CallingConv::SPIR_FUNC);
88 return NewF;
89 }
90
/// Redirects the given intrinsic call to a generated @spirv.llvm_<name>_*
/// wrapper function and, for memset/bswap, emits an expanded body for that
/// wrapper. Returns true if the call was redirected.
static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
  // For @llvm.memset.* intrinsic cases with constant value and length arguments
  // are emulated via "storing" a constant array to the destination. For other
  // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
  // intrinsic to a loop via expandMemSetAsLoop().
  if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic))
    if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
      return false; // It is handled later using OpCopyMemorySized.

  Module *M = Intrinsic->getModule();
  std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);
  // Volatile and non-volatile variants get distinct wrapper functions.
  if (Intrinsic->isVolatile())
    FuncName += ".volatile";
  // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
  Function *F = M->getFunction(FuncName);
  if (F) {
    // The wrapper (with a body) was already generated for an earlier call
    // site; just retarget this call.
    Intrinsic->setCalledFunction(F);
    return true;
  }
  // TODO copy arguments attributes: nocapture writeonly.
  FunctionCallee FC =
      M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());
  // Capture the intrinsic ID *before* redirecting the call: once the callee
  // is replaced, the instruction no longer reports an intrinsic ID.
  auto IntrinsicID = Intrinsic->getIntrinsicID();
  Intrinsic->setCalledFunction(FC);

  F = dyn_cast<Function>(FC.getCallee());
  assert(F && "Callee must be a function");

  // Emit the wrapper body for the intrinsics that need one.
  switch (IntrinsicID) {
  case Intrinsic::memset: {
    // Only alignment/volatility are read off the original call site here.
    auto *MSI = static_cast<MemSetInst *>(Intrinsic);
    Argument *Dest = F->getArg(0);
    Argument *Val = F->getArg(1);
    Argument *Len = F->getArg(2);
    Argument *IsVolatile = F->getArg(3);
    Dest->setName("dest");
    Val->setName("val");
    Len->setName("len");
    IsVolatile->setName("isvolatile");
    BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
    IRBuilder<> IRB(EntryBB);
    // Re-create the memset on the wrapper's own arguments, expand it into an
    // explicit store loop, then drop the temporary intrinsic call.
    auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
                                    MSI->isVolatile());
    IRB.CreateRetVoid();
    expandMemSetAsLoop(cast<MemSetInst>(MemSet));
    MemSet->eraseFromParent();
    break;
  }
  case Intrinsic::bswap: {
    BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
    IRBuilder<> IRB(EntryBB);
    // Re-create the bswap inside the wrapper, then let IntrinsicLowering
    // expand it into plain shift/mask arithmetic.
    auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
                                      F->getArg(0));
    IRB.CreateRet(BSwap);
    IntrinsicLowering IL(M->getDataLayout());
    IL.LowerIntrinsicCall(BSwap);
    break;
  }
  default:
    break;
  }
  return true;
}
154
/// Replaces an @llvm.fshl/@llvm.fshr call with a call to a generated
/// @spirv.llvm_fsh?_* function whose body implements the funnel shift with
/// plain shift/or arithmetic.
static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {
  // Get a separate function - otherwise, we'd have to rework the CFG of the
  // current one. Then simply replace the intrinsic uses with a call to the new
  // function.
  // Generate LLVM IR for i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
  Module *M = FSHIntrinsic->getModule();
  FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
  Type *FSHRetTy = FSHFuncTy->getReturnType();
  const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
  Function *FSHFunc =
      getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);

  // A non-empty function already has the body below; just retarget the call.
  if (!FSHFunc->empty()) {
    FSHIntrinsic->setCalledFunction(FSHFunc);
    return;
  }
  BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
  IRBuilder<> IRB(RotateBB);
  Type *Ty = FSHFunc->getReturnType();
  // Build the actual funnel shift rotate logic.
  // In the comments, "int" is used interchangeably with "vector of int
  // elements".
  FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
  Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
  unsigned BitWidth = IntTy->getIntegerBitWidth();
  // APInt{BitWidth, BitWidth}: a BitWidth-bit constant whose value is the
  // element bit width itself.
  ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
  // For vectors, splat the bit-width constant across all lanes.
  Value *BitWidthForInsts =
      VectorTy
          ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
          : BitWidthConstant;
  // Funnel shifts rotate modulo the bit width; reduce the shift amount first.
  Value *RotateModVal =
      IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
  Value *FirstShift = nullptr, *SecShift = nullptr;
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // Shift the less significant number right, the "rotate" number of bits
    // will be 0-filled on the left as a result of this regular shift.
    FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
  } else {
    // Shift the more significant number left, the "rotate" number of bits
    // will be 0-filled on the right as a result of this regular shift.
    FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
  }
  // We want the "rotate" number of the more significant int's LSBs (MSBs) to
  // occupy the leftmost (rightmost) "0 space" left by the previous operation.
  // Therefore, subtract the "rotate" number from the integer bitsize...
  Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // ...and left-shift the more significant int by this number, zero-filling
    // the LSBs.
    SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
  } else {
    // ...and right-shift the less significant int by this number, zero-filling
    // the MSBs.
    SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
  }
  // A simple binary addition of the shifted ints yields the final result.
  IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));

  FSHIntrinsic->setCalledFunction(FSHFunc);
}
215
buildUMulWithOverflowFunc(Function * UMulFunc)216 static void buildUMulWithOverflowFunc(Function *UMulFunc) {
217 // The function body is already created.
218 if (!UMulFunc->empty())
219 return;
220
221 BasicBlock *EntryBB = BasicBlock::Create(UMulFunc->getParent()->getContext(),
222 "entry", UMulFunc);
223 IRBuilder<> IRB(EntryBB);
224 // Build the actual unsigned multiplication logic with the overflow
225 // indication. Do unsigned multiplication Mul = A * B. Then check
226 // if unsigned division Div = Mul / A is not equal to B. If so,
227 // then overflow has happened.
228 Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
229 Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
230 Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div);
231
232 // umul.with.overflow intrinsic return a structure, where the first element
233 // is the multiplication result, and the second is an overflow bit.
234 Type *StructTy = UMulFunc->getReturnType();
235 Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0});
236 Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
237 IRB.CreateRet(Res);
238 }
239
lowerExpectAssume(IntrinsicInst * II)240 static void lowerExpectAssume(IntrinsicInst *II) {
241 // If we cannot use the SPV_KHR_expect_assume extension, then we need to
242 // ignore the intrinsic and move on. It should be removed later on by LLVM.
243 // Otherwise we should lower the intrinsic to the corresponding SPIR-V
244 // instruction.
245 // For @llvm.assume we have OpAssumeTrueKHR.
246 // For @llvm.expect we have OpExpectKHR.
247 //
248 // We need to lower this into a builtin and then the builtin into a SPIR-V
249 // instruction.
250 if (II->getIntrinsicID() == Intrinsic::assume) {
251 Function *F = Intrinsic::getDeclaration(
252 II->getModule(), Intrinsic::SPVIntrinsics::spv_assume);
253 II->setCalledFunction(F);
254 } else if (II->getIntrinsicID() == Intrinsic::expect) {
255 Function *F = Intrinsic::getDeclaration(
256 II->getModule(), Intrinsic::SPVIntrinsics::spv_expect,
257 {II->getOperand(0)->getType()});
258 II->setCalledFunction(F);
259 } else {
260 llvm_unreachable("Unknown intrinsic");
261 }
262
263 return;
264 }
265
lowerUMulWithOverflow(IntrinsicInst * UMulIntrinsic)266 static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) {
267 // Get a separate function - otherwise, we'd have to rework the CFG of the
268 // current one. Then simply replace the intrinsic uses with a call to the new
269 // function.
270 Module *M = UMulIntrinsic->getModule();
271 FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
272 Type *FSHLRetTy = UMulFuncTy->getReturnType();
273 const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
274 Function *UMulFunc =
275 getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);
276 buildUMulWithOverflowFunc(UMulFunc);
277 UMulIntrinsic->setCalledFunction(UMulFunc);
278 }
279
280 // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
281 // or calls to proper generated functions. Returns True if F was modified.
substituteIntrinsicCalls(Function * F)282 bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
283 bool Changed = false;
284 for (BasicBlock &BB : *F) {
285 for (Instruction &I : BB) {
286 auto Call = dyn_cast<CallInst>(&I);
287 if (!Call)
288 continue;
289 Function *CF = Call->getCalledFunction();
290 if (!CF || !CF->isIntrinsic())
291 continue;
292 auto *II = cast<IntrinsicInst>(Call);
293 if (II->getIntrinsicID() == Intrinsic::memset ||
294 II->getIntrinsicID() == Intrinsic::bswap)
295 Changed |= lowerIntrinsicToFunction(II);
296 else if (II->getIntrinsicID() == Intrinsic::fshl ||
297 II->getIntrinsicID() == Intrinsic::fshr) {
298 lowerFunnelShifts(II);
299 Changed = true;
300 } else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow) {
301 lowerUMulWithOverflow(II);
302 Changed = true;
303 } else if (II->getIntrinsicID() == Intrinsic::assume ||
304 II->getIntrinsicID() == Intrinsic::expect) {
305 const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F);
306 if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume))
307 lowerExpectAssume(II);
308 Changed = true;
309 }
310 }
311 }
312 return Changed;
313 }
314
// Returns F if aggregate argument/return types are not present or cloned F
// function with the types replaced by i32 types. The change in types is
// noted in 'spv.cloned_funcs' metadata for later restoration.
Function *
SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
  IRBuilder<> B(F->getContext());

  // Cloning is needed only when the return type or some argument type is an
  // aggregate.
  bool IsRetAggr = F->getReturnType()->isAggregateType();
  bool HasAggrArg =
      std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
        return Arg.getType()->isAggregateType();
      });
  bool DoClone = IsRetAggr || HasAggrArg;
  if (!DoClone)
    return F;
  // Pairs of (position, original type); position -1 denotes the return type,
  // non-negative values are argument indices.
  SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
  Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
  if (IsRetAggr)
    ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
  // Build the new parameter list: aggregates become i32, others are kept.
  SmallVector<Type *, 4> ArgTypes;
  for (const auto &Arg : F->args()) {
    if (Arg.getType()->isAggregateType()) {
      ArgTypes.push_back(B.getInt32Ty());
      ChangedTypes.push_back(
          std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
    } else
      ArgTypes.push_back(Arg.getType());
  }
  FunctionType *NewFTy =
      FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
  Function *NewF =
      Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());

  // Map old arguments to new ones (preserving names) so the body can be
  // cloned across.
  ValueToValueMapTy VMap;
  auto NewFArgIt = NewF->arg_begin();
  for (auto &Arg : F->args()) {
    StringRef ArgName = Arg.getName();
    NewFArgIt->setName(ArgName);
    VMap[&Arg] = &(*NewFArgIt++);
  }
  SmallVector<ReturnInst *, 8> Returns;

  CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
                    Returns);
  // Steal F's name so call sites keep referring to the same symbol name
  // (the caller erases F afterwards).
  NewF->takeName(F);

  // Record the original types in module-level metadata so call lowering can
  // restore the correct SPIR-V types. Each entry is:
  //   !{fn-name, !{position, null-of-original-type}, ...}
  NamedMDNode *FuncMD =
      F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
  SmallVector<Metadata *, 2> MDArgs;
  MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
  for (auto &ChangedTyP : ChangedTypes)
    MDArgs.push_back(MDNode::get(
        B.getContext(),
        {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
         ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
  MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
  FuncMD->addOperand(ThisFuncMD);

  // Retarget all users of F at NewF. Mutating the call type first keeps each
  // call site consistent with the new callee. early_inc_range is required
  // because replaceUsesOfWith changes F's use list while iterating.
  for (auto *U : make_early_inc_range(F->users())) {
    if (auto *CI = dyn_cast<CallInst>(U))
      CI->mutateFunctionType(NewF->getFunctionType());
    U->replaceUsesOfWith(F, NewF);
  }
  return NewF;
}
380
runOnModule(Module & M)381 bool SPIRVPrepareFunctions::runOnModule(Module &M) {
382 bool Changed = false;
383 for (Function &F : M)
384 Changed |= substituteIntrinsicCalls(&F);
385
386 std::vector<Function *> FuncsWorklist;
387 for (auto &F : M)
388 FuncsWorklist.push_back(&F);
389
390 for (auto *F : FuncsWorklist) {
391 Function *NewF = removeAggregateTypesFromSignature(F);
392
393 if (NewF != F) {
394 F->eraseFromParent();
395 Changed = true;
396 }
397 }
398 return Changed;
399 }
400
// Factory used by the SPIR-V pass pipeline; the caller owns the returned
// pass object.
ModulePass *
llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) {
  return new SPIRVPrepareFunctions(TM);
}
405