//===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This pass modifies the signatures of functions that take aggregate
// arguments and/or return an aggregate value. It also substitutes certain
// LLVM intrinsic calls with calls to functions generated the same way the
// translator does.
//
// NOTE: this is a module pass because it needs to create and modify
// global values (functions).
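//
// For illustration, a function such as
//   define %struct.S @foo(%struct.S %x)
// is cloned into
//   define i32 @foo(i32 %x)
// with the original aggregate types recorded in the !spv.cloned_funcs named
// metadata (see processFunctionSignature below).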
//
//===----------------------------------------------------------------------===//

#include "SPIRV.h"
#include "SPIRVTargetMachine.h"
#include "SPIRVUtils.h"
#include "llvm/CodeGen/IntrinsicLowering.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/LowerMemIntrinsics.h"

using namespace llvm;

namespace llvm {
void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
}

namespace {

class SPIRVPrepareFunctions : public ModulePass {
  Function *processFunctionSignature(Function *F);

public:
  static char ID;
  SPIRVPrepareFunctions() : ModulePass(ID) {
    initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override { return "SPIRV prepare functions"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    ModulePass::getAnalysisUsage(AU);
  }
};

} // namespace

char SPIRVPrepareFunctions::ID = 0;

INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
                "SPIRV prepare functions", false, false)

Function *SPIRVPrepareFunctions::processFunctionSignature(Function *F) {
  IRBuilder<> B(F->getContext());

  bool IsRetAggr = F->getReturnType()->isAggregateType();
  bool HasAggrArg =
      std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
        return Arg.getType()->isAggregateType();
      });
  bool DoClone = IsRetAggr || HasAggrArg;
  if (!DoClone)
    return F;
  SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
  Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
  if (IsRetAggr)
    ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
  SmallVector<Type *, 4> ArgTypes;
  for (const auto &Arg : F->args()) {
    if (Arg.getType()->isAggregateType()) {
      ArgTypes.push_back(B.getInt32Ty());
      ChangedTypes.push_back(
          std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
    } else
      ArgTypes.push_back(Arg.getType());
  }
  FunctionType *NewFTy =
      FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
  Function *NewF =
      Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());

  ValueToValueMapTy VMap;
  auto NewFArgIt = NewF->arg_begin();
  for (auto &Arg : F->args()) {
    StringRef ArgName = Arg.getName();
    NewFArgIt->setName(ArgName);
    VMap[&Arg] = &(*NewFArgIt++);
  }
  SmallVector<ReturnInst *, 8> Returns;

  CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
                    Returns);
  NewF->takeName(F);

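  // Record the original aggregate return/argument types in module-level named
  // metadata so that later stages can recover them. Each operand has roughly
  // the form
  //   !{!"foo", !{i32 -1, %struct.S zeroinitializer}, ...}
  // where -1 stands for the return type and non-negative values are argument
  // indices.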
  NamedMDNode *FuncMD =
      F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
  SmallVector<Metadata *, 2> MDArgs;
  MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
  for (auto &ChangedTyP : ChangedTypes)
    MDArgs.push_back(MDNode::get(
        B.getContext(),
        {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
         ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
  MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
  FuncMD->addOperand(ThisFuncMD);

  for (auto *U : make_early_inc_range(F->users())) {
    if (auto *CI = dyn_cast<CallInst>(U))
      CI->mutateFunctionType(NewF->getFunctionType());
    U->replaceUsesOfWith(F, NewF);
  }
  return NewF;
}

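// Builds the name of the wrapper function for an intrinsic call, e.g.
// "llvm.bswap.i32" becomes "spirv.llvm_bswap_i32".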
static std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
  Function *IntrinsicFunc = II->getCalledFunction();
  assert(IntrinsicFunc && "Missing function");
  std::string FuncName = IntrinsicFunc->getName().str();
  std::replace(FuncName.begin(), FuncName.end(), '.', '_');
  FuncName = "spirv." + FuncName;
  return FuncName;
}

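// Returns an existing function with the given name if its type matches the
// requested one; otherwise creates a new declaration with external linkage
// and the SPIR_FUNC calling convention.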
static Function *getOrCreateFunction(Module *M, Type *RetTy,
                                     ArrayRef<Type *> ArgTypes,
                                     StringRef Name) {
  FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
  Function *F = M->getFunction(Name);
  if (F && F->getFunctionType() == FT)
    return F;
  Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
  if (F)
    NewF->setDSOLocal(F->isDSOLocal());
  NewF->setCallingConv(CallingConv::SPIR_FUNC);
  return NewF;
}

static void lowerIntrinsicToFunction(Module *M, IntrinsicInst *Intrinsic) {
  // @llvm.memset.* calls with constant value and length arguments are
  // emulated later by "storing" a constant array to the destination. For the
  // remaining cases we wrap the intrinsic in a @spirv.llvm_memset_* function
  // and expand the intrinsic to a loop via expandMemSetAsLoop().
  if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic))
    if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
      return; // It is handled later using OpCopyMemorySized.

  std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);
  if (Intrinsic->isVolatile())
    FuncName += ".volatile";
  // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
  Function *F = M->getFunction(FuncName);
  if (F) {
    Intrinsic->setCalledFunction(F);
    return;
  }
  // TODO: copy argument attributes such as nocapture and writeonly.
  FunctionCallee FC =
      M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());
  auto IntrinsicID = Intrinsic->getIntrinsicID();
  Intrinsic->setCalledFunction(FC);

  F = dyn_cast<Function>(FC.getCallee());
  assert(F && "Callee must be a function");

  switch (IntrinsicID) {
  case Intrinsic::memset: {
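    // The call has already been redirected to the @spirv.llvm_memset_* wrapper
    // above, so it no longer classifies as a MemSetInst; static_cast is used
    // to read the original operands.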
    auto *MSI = static_cast<MemSetInst *>(Intrinsic);
    Argument *Dest = F->getArg(0);
    Argument *Val = F->getArg(1);
    Argument *Len = F->getArg(2);
    Argument *IsVolatile = F->getArg(3);
    Dest->setName("dest");
    Val->setName("val");
    Len->setName("len");
    IsVolatile->setName("isvolatile");
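    // Give the wrapper a single-block body that re-creates the memset, expand
    // that memset into an explicit store loop, and erase the temporary call.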
    BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
    IRBuilder<> IRB(EntryBB);
    auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
                                    MSI->isVolatile());
    IRB.CreateRetVoid();
    expandMemSetAsLoop(cast<MemSetInst>(MemSet));
    MemSet->eraseFromParent();
    break;
  }
  case Intrinsic::bswap: {
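    // Give the wrapper a body that re-creates the bswap and let
    // IntrinsicLowering expand it into plain shift/or arithmetic.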
    BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
    IRBuilder<> IRB(EntryBB);
    auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
                                      F->getArg(0));
    IRB.CreateRet(BSwap);
    IntrinsicLowering IL(M->getDataLayout());
    IL.LowerIntrinsicCall(BSwap);
    break;
  }
  default:
    break;
  }
  return;
}

static void lowerFunnelShifts(Module *M, IntrinsicInst *FSHIntrinsic) {
  // Get a separate function - otherwise, we'd have to rework the CFG of the
  // current one. Then simply replace the intrinsic uses with a call to the new
  // function.
  // Generate LLVM IR for i* @spirv.llvm_fsh?_i*(i* %a, i* %b, i* %c).
  FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
  Type *FSHRetTy = FSHFuncTy->getReturnType();
  const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
  Function *FSHFunc =
      getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);

  if (!FSHFunc->empty()) {
    FSHIntrinsic->setCalledFunction(FSHFunc);
    return;
  }
  BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
  IRBuilder<> IRB(RotateBB);
  Type *Ty = FSHFunc->getReturnType();
  // Build the actual funnel shift rotate logic.
  // In the comments, "int" is used interchangeably with "vector of int
  // elements".
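  // In effect, for fshl the generated body computes
  //   r = c urem bw;  result = (a shl r) | (b lshr (bw - r))
  // and for fshr
  //   r = c urem bw;  result = (b lshr r) | (a shl (bw - r))
  // where bw is the element bit width.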
  FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
  Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
  unsigned BitWidth = IntTy->getIntegerBitWidth();
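  // An APInt with BitWidth bits whose value is also BitWidth.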
  ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
  Value *BitWidthForInsts =
      VectorTy
          ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
          : BitWidthConstant;
  Value *RotateModVal =
      IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
  Value *FirstShift = nullptr, *SecShift = nullptr;
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // Shift the less significant number right; the "rotate" number of bits
    // will be 0-filled on the left as a result of this regular shift.
    FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
  } else {
    // Shift the more significant number left; the "rotate" number of bits
    // will be 0-filled on the right as a result of this regular shift.
    FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
  }
  // We want the "rotate" number of the more significant int's LSBs (MSBs) to
  // occupy the leftmost (rightmost) "0 space" left by the previous operation.
  // Therefore, subtract the "rotate" number from the integer bitsize...
  Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // ...and left-shift the more significant int by this number, zero-filling
    // the LSBs.
    SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
  } else {
    // ...and right-shift the less significant int by this number, zero-filling
    // the MSBs.
    SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
  }
  // A bitwise OR of the two shifted values yields the final result.
  IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));

  FSHIntrinsic->setCalledFunction(FSHFunc);
}

static void buildUMulWithOverflowFunc(Module *M, Function *UMulFunc) {
  // Do not regenerate the body if it has already been created.
  if (!UMulFunc->empty())
    return;

  BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", UMulFunc);
  IRBuilder<> IRB(EntryBB);
  // Build the actual unsigned multiplication logic with the overflow
  // indication: compute Mul = A * B, then check whether the unsigned division
  // Div = Mul / A differs from B. If it does, the multiplication overflowed.
  Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
  Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
  Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(1), Div);

  // The umul.with.overflow intrinsic returns a structure whose first element
  // is the multiplication result and whose second element is an overflow bit.
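  // For i32 operands, for example, the return type is { i32, i1 }.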
  Type *StructTy = UMulFunc->getReturnType();
  Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0});
  Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
  IRB.CreateRet(Res);
}

static void lowerUMulWithOverflow(Module *M, IntrinsicInst *UMulIntrinsic) {
  // Get a separate function - otherwise, we'd have to rework the CFG of the
  // current one. Then simply replace the intrinsic uses with a call to the new
  // function.
  FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
  Type *UMulRetTy = UMulFuncTy->getReturnType();
  const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
  Function *UMulFunc =
      getOrCreateFunction(M, UMulRetTy, UMulFuncTy->params(), FuncName);
  buildUMulWithOverflowFunc(M, UMulFunc);
  UMulIntrinsic->setCalledFunction(UMulFunc);
}

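// Replaces calls to the handled intrinsics (memset, bswap, fshl/fshr and
// umul.with.overflow) in F with calls to the generated helper functions.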
static void substituteIntrinsicCalls(Module *M, Function *F) {
  for (BasicBlock &BB : *F) {
    for (Instruction &I : BB) {
      auto *Call = dyn_cast<CallInst>(&I);
      if (!Call)
        continue;
      Call->setTailCall(false);
      Function *CF = Call->getCalledFunction();
      if (!CF || !CF->isIntrinsic())
        continue;
      auto *II = cast<IntrinsicInst>(Call);
      if (II->getIntrinsicID() == Intrinsic::memset ||
          II->getIntrinsicID() == Intrinsic::bswap)
        lowerIntrinsicToFunction(M, II);
      else if (II->getIntrinsicID() == Intrinsic::fshl ||
               II->getIntrinsicID() == Intrinsic::fshr)
        lowerFunnelShifts(M, II);
      else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow)
        lowerUMulWithOverflow(M, II);
    }
  }
}

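// runOnModule first substitutes the handled intrinsic calls, then rewrites
// the signatures of functions with aggregate arguments or return values,
// erasing the original definitions that were cloned.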
bool SPIRVPrepareFunctions::runOnModule(Module &M) {
  for (Function &F : M)
    substituteIntrinsicCalls(&M, &F);

  std::vector<Function *> FuncsWorklist;
  bool Changed = false;
  for (auto &F : M)
    FuncsWorklist.push_back(&F);

  for (auto *Func : FuncsWorklist) {
    Function *F = processFunctionSignature(Func);

    bool CreatedNewF = F != Func;
    Changed |= CreatedNewF;

    // Declarations are kept; only the definitions that were cloned are erased.
    if (!Func->isDeclaration() && CreatedNewF)
      Func->eraseFromParent();
  }

  return Changed;
}

ModulePass *llvm::createSPIRVPrepareFunctionsPass() {
  return new SPIRVPrepareFunctions();
}