1 //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This pass modifies function signatures containing aggregate arguments
10 // and/or return value. Also it substitutes some llvm intrinsic calls by
11 // function calls, generating these functions as the translator does.
12 //
13 // NOTE: this pass is a module-level one due to the necessity to modify
14 // GVs/functions.
15 //
16 //===----------------------------------------------------------------------===//
17 
18 #include "SPIRV.h"
19 #include "SPIRVTargetMachine.h"
20 #include "SPIRVUtils.h"
21 #include "llvm/IR/IRBuilder.h"
22 #include "llvm/IR/IntrinsicInst.h"
23 #include "llvm/Transforms/Utils/Cloning.h"
24 #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
25 
26 using namespace llvm;
27 
28 namespace llvm {
29 void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
30 }
31 
32 namespace {
33 
// Legacy-PM module pass that (a) rewrites function signatures containing
// aggregate argument/return types and (b) substitutes selected LLVM
// intrinsic calls with calls to generated functions (see file header).
// Module scope is required because it creates and erases functions.
class SPIRVPrepareFunctions : public ModulePass {
  // Clones F with aggregate argument/return types replaced by i32
  // placeholders, recording the original types in "spv.cloned_funcs"
  // module metadata. Returns the clone, or F itself if nothing changed.
  Function *processFunctionSignature(Function *F);

public:
  static char ID; // Pass identification, replacement for typeid.
  SPIRVPrepareFunctions() : ModulePass(ID) {
    initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
  }

  bool runOnModule(Module &M) override;

  StringRef getPassName() const override { return "SPIRV prepare functions"; }

  void getAnalysisUsage(AnalysisUsage &AU) const override {
    ModulePass::getAnalysisUsage(AU);
  }
};
51 
52 } // namespace
53 
char SPIRVPrepareFunctions::ID = 0;

// Registers the pass with the legacy pass manager under the
// "prepare-functions" command-line name.
INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
                "SPIRV prepare functions", false, false)
58 
Function *SPIRVPrepareFunctions::processFunctionSignature(Function *F) {
  IRBuilder<> B(F->getContext());

  // A clone is needed iff the return type or any argument is an aggregate.
  bool IsRetAggr = F->getReturnType()->isAggregateType();
  bool HasAggrArg =
      std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
        return Arg.getType()->isAggregateType();
      });
  bool DoClone = IsRetAggr || HasAggrArg;
  if (!DoClone)
    return F;
  // (position, original type) pairs; position -1 denotes the return type,
  // otherwise it is the argument number. Consumed below when emitting the
  // "spv.cloned_funcs" metadata so the original types can be recovered.
  SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
  // Every aggregate is replaced with an i32 placeholder in the new type.
  Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
  if (IsRetAggr)
    ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
  SmallVector<Type *, 4> ArgTypes;
  for (const auto &Arg : F->args()) {
    if (Arg.getType()->isAggregateType()) {
      ArgTypes.push_back(B.getInt32Ty());
      ChangedTypes.push_back(
          std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
    } else
      ArgTypes.push_back(Arg.getType());
  }
  FunctionType *NewFTy =
      FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
  Function *NewF =
      Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());

  // Map old arguments to new ones (and carry over names) so the cloner can
  // remap all uses inside the copied body.
  ValueToValueMapTy VMap;
  auto NewFArgIt = NewF->arg_begin();
  for (auto &Arg : F->args()) {
    StringRef ArgName = Arg.getName();
    NewFArgIt->setName(ArgName);
    VMap[&Arg] = &(*NewFArgIt++);
  }
  SmallVector<ReturnInst *, 8> Returns;

  CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
                    Returns);
  // Steal F's name so the clone is externally indistinguishable from the
  // original (F gets a suffixed name until it is erased by the caller).
  NewF->takeName(F);

  // Record { new-name, (position, null-of-original-type)... } in module-level
  // metadata. Null constants are used purely as type carriers.
  NamedMDNode *FuncMD =
      F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
  SmallVector<Metadata *, 2> MDArgs;
  MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
  for (auto &ChangedTyP : ChangedTypes)
    MDArgs.push_back(MDNode::get(
        B.getContext(),
        {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
         ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
  MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
  FuncMD->addOperand(ThisFuncMD);

  // Retarget existing call sites at the clone. The call type must be mutated
  // before the callee operand is swapped so the call stays well-formed.
  // NOTE(review): call-site attributes/operand types for the replaced
  // aggregate positions are not adjusted here -- presumably handled later in
  // the backend; verify against the SPIR-V call lowering.
  for (auto *U : make_early_inc_range(F->users())) {
    if (auto *CI = dyn_cast<CallInst>(U))
      CI->mutateFunctionType(NewF->getFunctionType());
    U->replaceUsesOfWith(F, NewF);
  }
  return NewF;
}
120 
121 std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
122   Function *IntrinsicFunc = II->getCalledFunction();
123   assert(IntrinsicFunc && "Missing function");
124   std::string FuncName = IntrinsicFunc->getName().str();
125   std::replace(FuncName.begin(), FuncName.end(), '.', '_');
126   FuncName = "spirv." + FuncName;
127   return FuncName;
128 }
129 
130 static Function *getOrCreateFunction(Module *M, Type *RetTy,
131                                      ArrayRef<Type *> ArgTypes,
132                                      StringRef Name) {
133   FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
134   Function *F = M->getFunction(Name);
135   if (F && F->getFunctionType() == FT)
136     return F;
137   Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
138   if (F)
139     NewF->setDSOLocal(F->isDSOLocal());
140   NewF->setCallingConv(CallingConv::SPIR_FUNC);
141   return NewF;
142 }
143 
// Replaces a llvm.fshl/llvm.fshr call with a call to a generated
// "spirv.llvm_fsh?_i*" function whose body implements the funnel-shift
// rotate via two plain shifts and an OR, mirroring the translator's output.
static void lowerFunnelShifts(Module *M, IntrinsicInst *FSHIntrinsic) {
  // Get a separate function - otherwise, we'd have to rework the CFG of the
  // current one. Then simply replace the intrinsic uses with a call to the new
  // function.
  // Generate LLVM IR for  i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
  FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
  Type *FSHRetTy = FSHFuncTy->getReturnType();
  const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
  Function *FSHFunc =
      getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);

  // Body already emitted for an earlier call site: just retarget this call.
  if (!FSHFunc->empty()) {
    FSHIntrinsic->setCalledFunction(FSHFunc);
    return;
  }
  BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
  IRBuilder<> IRB(RotateBB);
  Type *Ty = FSHFunc->getReturnType();
  // Build the actual funnel shift rotate logic.
  // In the comments, "int" is used interchangeably with "vector of int
  // elements".
  FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
  Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
  unsigned BitWidth = IntTy->getIntegerBitWidth();
  // APInt(BitWidth, BitWidth): a constant whose *value* is the bit width,
  // used as the modulus for the shift amount.
  ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
  Value *BitWidthForInsts =
      VectorTy
          ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
          : BitWidthConstant;
  // Reduce the shift amount modulo the bit width, as the intrinsic requires.
  Value *RotateModVal =
      IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
  Value *FirstShift = nullptr, *SecShift = nullptr;
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // Shift the less significant number right, the "rotate" number of bits
    // will be 0-filled on the left as a result of this regular shift.
    FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
  } else {
    // Shift the more significant number left, the "rotate" number of bits
    // will be 0-filled on the right as a result of this regular shift.
    FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
  }
  // We want the "rotate" number of the more significant int's LSBs (MSBs) to
  // occupy the leftmost (rightmost) "0 space" left by the previous operation.
  // Therefore, subtract the "rotate" number from the integer bitsize...
  Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
  if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
    // ...and left-shift the more significant int by this number, zero-filling
    // the LSBs.
    SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
  } else {
    // ...and right-shift the less significant int by this number, zero-filling
    // the MSBs.
    SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
  }
  // A simple binary addition of the shifted ints yields the final result.
  IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));

  FSHIntrinsic->setCalledFunction(FSHFunc);
}
203 
204 static void buildUMulWithOverflowFunc(Module *M, Function *UMulFunc) {
205   // The function body is already created.
206   if (!UMulFunc->empty())
207     return;
208 
209   BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", UMulFunc);
210   IRBuilder<> IRB(EntryBB);
211   // Build the actual unsigned multiplication logic with the overflow
212   // indication. Do unsigned multiplication Mul = A * B. Then check
213   // if unsigned division Div = Mul / A is not equal to B. If so,
214   // then overflow has happened.
215   Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
216   Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
217   Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div);
218 
219   // umul.with.overflow intrinsic return a structure, where the first element
220   // is the multiplication result, and the second is an overflow bit.
221   Type *StructTy = UMulFunc->getReturnType();
222   Value *Agg = IRB.CreateInsertValue(UndefValue::get(StructTy), Mul, {0});
223   Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
224   IRB.CreateRet(Res);
225 }
226 
227 static void lowerUMulWithOverflow(Module *M, IntrinsicInst *UMulIntrinsic) {
228   // Get a separate function - otherwise, we'd have to rework the CFG of the
229   // current one. Then simply replace the intrinsic uses with a call to the new
230   // function.
231   FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
232   Type *FSHLRetTy = UMulFuncTy->getReturnType();
233   const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
234   Function *UMulFunc =
235       getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);
236   buildUMulWithOverflowFunc(M, UMulFunc);
237   UMulIntrinsic->setCalledFunction(UMulFunc);
238 }
239 
240 static void substituteIntrinsicCalls(Module *M, Function *F) {
241   for (BasicBlock &BB : *F) {
242     for (Instruction &I : BB) {
243       auto Call = dyn_cast<CallInst>(&I);
244       if (!Call)
245         continue;
246       Call->setTailCall(false);
247       Function *CF = Call->getCalledFunction();
248       if (!CF || !CF->isIntrinsic())
249         continue;
250       auto *II = cast<IntrinsicInst>(Call);
251       if (II->getIntrinsicID() == Intrinsic::fshl ||
252           II->getIntrinsicID() == Intrinsic::fshr)
253         lowerFunnelShifts(M, II);
254       else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow)
255         lowerUMulWithOverflow(M, II);
256     }
257   }
258 }
259 
260 bool SPIRVPrepareFunctions::runOnModule(Module &M) {
261   for (Function &F : M)
262     substituteIntrinsicCalls(&M, &F);
263 
264   std::vector<Function *> FuncsWorklist;
265   bool Changed = false;
266   for (auto &F : M)
267     FuncsWorklist.push_back(&F);
268 
269   for (auto *Func : FuncsWorklist) {
270     Function *F = processFunctionSignature(Func);
271 
272     bool CreatedNewF = F != Func;
273 
274     if (Func->isDeclaration()) {
275       Changed |= CreatedNewF;
276       continue;
277     }
278 
279     if (CreatedNewF)
280       Func->eraseFromParent();
281   }
282 
283   return Changed;
284 }
285 
286 ModulePass *llvm::createSPIRVPrepareFunctionsPass() {
287   return new SPIRVPrepareFunctions();
288 }
289