1fcaf7f86SDimitry Andric //===-- SPIRVPrepareFunctions.cpp - modify function signatures --*- C++ -*-===//
2fcaf7f86SDimitry Andric //
3fcaf7f86SDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4fcaf7f86SDimitry Andric // See https://llvm.org/LICENSE.txt for license information.
5fcaf7f86SDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6fcaf7f86SDimitry Andric //
7fcaf7f86SDimitry Andric //===----------------------------------------------------------------------===//
8fcaf7f86SDimitry Andric //
9fcaf7f86SDimitry Andric // This pass modifies function signatures containing aggregate arguments
1006c3fb27SDimitry Andric // and/or return value before IRTranslator. Information about the original
1106c3fb27SDimitry Andric // signatures is stored in metadata. It is used during call lowering to
1206c3fb27SDimitry Andric // restore correct SPIR-V types of function arguments and return values.
1306c3fb27SDimitry Andric // This pass also substitutes some llvm intrinsic calls with calls to newly
1406c3fb27SDimitry Andric // generated functions (as the Khronos LLVM/SPIR-V Translator does).
15fcaf7f86SDimitry Andric //
16fcaf7f86SDimitry Andric // NOTE: this pass is a module-level one due to the necessity to modify
17fcaf7f86SDimitry Andric // GVs/functions.
18fcaf7f86SDimitry Andric //
19fcaf7f86SDimitry Andric //===----------------------------------------------------------------------===//
20fcaf7f86SDimitry Andric 
21fcaf7f86SDimitry Andric #include "SPIRV.h"
22*5f757f3fSDimitry Andric #include "SPIRVSubtarget.h"
23fcaf7f86SDimitry Andric #include "SPIRVTargetMachine.h"
24fcaf7f86SDimitry Andric #include "SPIRVUtils.h"
25bdd1243dSDimitry Andric #include "llvm/CodeGen/IntrinsicLowering.h"
26fcaf7f86SDimitry Andric #include "llvm/IR/IRBuilder.h"
27fcaf7f86SDimitry Andric #include "llvm/IR/IntrinsicInst.h"
28*5f757f3fSDimitry Andric #include "llvm/IR/Intrinsics.h"
29*5f757f3fSDimitry Andric #include "llvm/IR/IntrinsicsSPIRV.h"
30fcaf7f86SDimitry Andric #include "llvm/Transforms/Utils/Cloning.h"
31fcaf7f86SDimitry Andric #include "llvm/Transforms/Utils/LowerMemIntrinsics.h"
32fcaf7f86SDimitry Andric 
33fcaf7f86SDimitry Andric using namespace llvm;
34fcaf7f86SDimitry Andric 
35fcaf7f86SDimitry Andric namespace llvm {
36fcaf7f86SDimitry Andric void initializeSPIRVPrepareFunctionsPass(PassRegistry &);
37fcaf7f86SDimitry Andric }
38fcaf7f86SDimitry Andric 
39fcaf7f86SDimitry Andric namespace {
40fcaf7f86SDimitry Andric 
41fcaf7f86SDimitry Andric class SPIRVPrepareFunctions : public ModulePass {
42*5f757f3fSDimitry Andric   const SPIRVTargetMachine &TM;
4306c3fb27SDimitry Andric   bool substituteIntrinsicCalls(Function *F);
4406c3fb27SDimitry Andric   Function *removeAggregateTypesFromSignature(Function *F);
45fcaf7f86SDimitry Andric 
46fcaf7f86SDimitry Andric public:
47fcaf7f86SDimitry Andric   static char ID;
SPIRVPrepareFunctions(const SPIRVTargetMachine & TM)48*5f757f3fSDimitry Andric   SPIRVPrepareFunctions(const SPIRVTargetMachine &TM) : ModulePass(ID), TM(TM) {
49fcaf7f86SDimitry Andric     initializeSPIRVPrepareFunctionsPass(*PassRegistry::getPassRegistry());
50fcaf7f86SDimitry Andric   }
51fcaf7f86SDimitry Andric 
52fcaf7f86SDimitry Andric   bool runOnModule(Module &M) override;
53fcaf7f86SDimitry Andric 
getPassName() const54fcaf7f86SDimitry Andric   StringRef getPassName() const override { return "SPIRV prepare functions"; }
55fcaf7f86SDimitry Andric 
getAnalysisUsage(AnalysisUsage & AU) const56fcaf7f86SDimitry Andric   void getAnalysisUsage(AnalysisUsage &AU) const override {
57fcaf7f86SDimitry Andric     ModulePass::getAnalysisUsage(AU);
58fcaf7f86SDimitry Andric   }
59fcaf7f86SDimitry Andric };
60fcaf7f86SDimitry Andric 
61fcaf7f86SDimitry Andric } // namespace
62fcaf7f86SDimitry Andric 
63fcaf7f86SDimitry Andric char SPIRVPrepareFunctions::ID = 0;
64fcaf7f86SDimitry Andric 
65fcaf7f86SDimitry Andric INITIALIZE_PASS(SPIRVPrepareFunctions, "prepare-functions",
66fcaf7f86SDimitry Andric                 "SPIRV prepare functions", false, false)
67fcaf7f86SDimitry Andric 
lowerLLVMIntrinsicName(IntrinsicInst * II)6806c3fb27SDimitry Andric std::string lowerLLVMIntrinsicName(IntrinsicInst *II) {
6906c3fb27SDimitry Andric   Function *IntrinsicFunc = II->getCalledFunction();
7006c3fb27SDimitry Andric   assert(IntrinsicFunc && "Missing function");
7106c3fb27SDimitry Andric   std::string FuncName = IntrinsicFunc->getName().str();
7206c3fb27SDimitry Andric   std::replace(FuncName.begin(), FuncName.end(), '.', '_');
7306c3fb27SDimitry Andric   FuncName = "spirv." + FuncName;
7406c3fb27SDimitry Andric   return FuncName;
7506c3fb27SDimitry Andric }
7606c3fb27SDimitry Andric 
getOrCreateFunction(Module * M,Type * RetTy,ArrayRef<Type * > ArgTypes,StringRef Name)7706c3fb27SDimitry Andric static Function *getOrCreateFunction(Module *M, Type *RetTy,
7806c3fb27SDimitry Andric                                      ArrayRef<Type *> ArgTypes,
7906c3fb27SDimitry Andric                                      StringRef Name) {
8006c3fb27SDimitry Andric   FunctionType *FT = FunctionType::get(RetTy, ArgTypes, false);
8106c3fb27SDimitry Andric   Function *F = M->getFunction(Name);
8206c3fb27SDimitry Andric   if (F && F->getFunctionType() == FT)
8306c3fb27SDimitry Andric     return F;
8406c3fb27SDimitry Andric   Function *NewF = Function::Create(FT, GlobalValue::ExternalLinkage, Name, M);
8506c3fb27SDimitry Andric   if (F)
8606c3fb27SDimitry Andric     NewF->setDSOLocal(F->isDSOLocal());
8706c3fb27SDimitry Andric   NewF->setCallingConv(CallingConv::SPIR_FUNC);
8806c3fb27SDimitry Andric   return NewF;
8906c3fb27SDimitry Andric }
9006c3fb27SDimitry Andric 
lowerIntrinsicToFunction(IntrinsicInst * Intrinsic)9106c3fb27SDimitry Andric static bool lowerIntrinsicToFunction(IntrinsicInst *Intrinsic) {
9206c3fb27SDimitry Andric   // For @llvm.memset.* intrinsic cases with constant value and length arguments
9306c3fb27SDimitry Andric   // are emulated via "storing" a constant array to the destination. For other
9406c3fb27SDimitry Andric   // cases we wrap the intrinsic in @spirv.llvm_memset_* function and expand the
9506c3fb27SDimitry Andric   // intrinsic to a loop via expandMemSetAsLoop().
9606c3fb27SDimitry Andric   if (auto *MSI = dyn_cast<MemSetInst>(Intrinsic))
9706c3fb27SDimitry Andric     if (isa<Constant>(MSI->getValue()) && isa<ConstantInt>(MSI->getLength()))
9806c3fb27SDimitry Andric       return false; // It is handled later using OpCopyMemorySized.
9906c3fb27SDimitry Andric 
10006c3fb27SDimitry Andric   Module *M = Intrinsic->getModule();
10106c3fb27SDimitry Andric   std::string FuncName = lowerLLVMIntrinsicName(Intrinsic);
10206c3fb27SDimitry Andric   if (Intrinsic->isVolatile())
10306c3fb27SDimitry Andric     FuncName += ".volatile";
10406c3fb27SDimitry Andric   // Redirect @llvm.intrinsic.* call to @spirv.llvm_intrinsic_*
10506c3fb27SDimitry Andric   Function *F = M->getFunction(FuncName);
10606c3fb27SDimitry Andric   if (F) {
10706c3fb27SDimitry Andric     Intrinsic->setCalledFunction(F);
10806c3fb27SDimitry Andric     return true;
10906c3fb27SDimitry Andric   }
11006c3fb27SDimitry Andric   // TODO copy arguments attributes: nocapture writeonly.
11106c3fb27SDimitry Andric   FunctionCallee FC =
11206c3fb27SDimitry Andric       M->getOrInsertFunction(FuncName, Intrinsic->getFunctionType());
11306c3fb27SDimitry Andric   auto IntrinsicID = Intrinsic->getIntrinsicID();
11406c3fb27SDimitry Andric   Intrinsic->setCalledFunction(FC);
11506c3fb27SDimitry Andric 
11606c3fb27SDimitry Andric   F = dyn_cast<Function>(FC.getCallee());
11706c3fb27SDimitry Andric   assert(F && "Callee must be a function");
11806c3fb27SDimitry Andric 
11906c3fb27SDimitry Andric   switch (IntrinsicID) {
12006c3fb27SDimitry Andric   case Intrinsic::memset: {
12106c3fb27SDimitry Andric     auto *MSI = static_cast<MemSetInst *>(Intrinsic);
12206c3fb27SDimitry Andric     Argument *Dest = F->getArg(0);
12306c3fb27SDimitry Andric     Argument *Val = F->getArg(1);
12406c3fb27SDimitry Andric     Argument *Len = F->getArg(2);
12506c3fb27SDimitry Andric     Argument *IsVolatile = F->getArg(3);
12606c3fb27SDimitry Andric     Dest->setName("dest");
12706c3fb27SDimitry Andric     Val->setName("val");
12806c3fb27SDimitry Andric     Len->setName("len");
12906c3fb27SDimitry Andric     IsVolatile->setName("isvolatile");
13006c3fb27SDimitry Andric     BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
13106c3fb27SDimitry Andric     IRBuilder<> IRB(EntryBB);
13206c3fb27SDimitry Andric     auto *MemSet = IRB.CreateMemSet(Dest, Val, Len, MSI->getDestAlign(),
13306c3fb27SDimitry Andric                                     MSI->isVolatile());
13406c3fb27SDimitry Andric     IRB.CreateRetVoid();
13506c3fb27SDimitry Andric     expandMemSetAsLoop(cast<MemSetInst>(MemSet));
13606c3fb27SDimitry Andric     MemSet->eraseFromParent();
13706c3fb27SDimitry Andric     break;
13806c3fb27SDimitry Andric   }
13906c3fb27SDimitry Andric   case Intrinsic::bswap: {
14006c3fb27SDimitry Andric     BasicBlock *EntryBB = BasicBlock::Create(M->getContext(), "entry", F);
14106c3fb27SDimitry Andric     IRBuilder<> IRB(EntryBB);
14206c3fb27SDimitry Andric     auto *BSwap = IRB.CreateIntrinsic(Intrinsic::bswap, Intrinsic->getType(),
14306c3fb27SDimitry Andric                                       F->getArg(0));
14406c3fb27SDimitry Andric     IRB.CreateRet(BSwap);
14506c3fb27SDimitry Andric     IntrinsicLowering IL(M->getDataLayout());
14606c3fb27SDimitry Andric     IL.LowerIntrinsicCall(BSwap);
14706c3fb27SDimitry Andric     break;
14806c3fb27SDimitry Andric   }
14906c3fb27SDimitry Andric   default:
15006c3fb27SDimitry Andric     break;
15106c3fb27SDimitry Andric   }
15206c3fb27SDimitry Andric   return true;
15306c3fb27SDimitry Andric }
15406c3fb27SDimitry Andric 
lowerFunnelShifts(IntrinsicInst * FSHIntrinsic)15506c3fb27SDimitry Andric static void lowerFunnelShifts(IntrinsicInst *FSHIntrinsic) {
15606c3fb27SDimitry Andric   // Get a separate function - otherwise, we'd have to rework the CFG of the
15706c3fb27SDimitry Andric   // current one. Then simply replace the intrinsic uses with a call to the new
15806c3fb27SDimitry Andric   // function.
15906c3fb27SDimitry Andric   // Generate LLVM IR for  i* @spirv.llvm_fsh?_i* (i* %a, i* %b, i* %c)
16006c3fb27SDimitry Andric   Module *M = FSHIntrinsic->getModule();
16106c3fb27SDimitry Andric   FunctionType *FSHFuncTy = FSHIntrinsic->getFunctionType();
16206c3fb27SDimitry Andric   Type *FSHRetTy = FSHFuncTy->getReturnType();
16306c3fb27SDimitry Andric   const std::string FuncName = lowerLLVMIntrinsicName(FSHIntrinsic);
16406c3fb27SDimitry Andric   Function *FSHFunc =
16506c3fb27SDimitry Andric       getOrCreateFunction(M, FSHRetTy, FSHFuncTy->params(), FuncName);
16606c3fb27SDimitry Andric 
16706c3fb27SDimitry Andric   if (!FSHFunc->empty()) {
16806c3fb27SDimitry Andric     FSHIntrinsic->setCalledFunction(FSHFunc);
16906c3fb27SDimitry Andric     return;
17006c3fb27SDimitry Andric   }
17106c3fb27SDimitry Andric   BasicBlock *RotateBB = BasicBlock::Create(M->getContext(), "rotate", FSHFunc);
17206c3fb27SDimitry Andric   IRBuilder<> IRB(RotateBB);
17306c3fb27SDimitry Andric   Type *Ty = FSHFunc->getReturnType();
17406c3fb27SDimitry Andric   // Build the actual funnel shift rotate logic.
17506c3fb27SDimitry Andric   // In the comments, "int" is used interchangeably with "vector of int
17606c3fb27SDimitry Andric   // elements".
17706c3fb27SDimitry Andric   FixedVectorType *VectorTy = dyn_cast<FixedVectorType>(Ty);
17806c3fb27SDimitry Andric   Type *IntTy = VectorTy ? VectorTy->getElementType() : Ty;
17906c3fb27SDimitry Andric   unsigned BitWidth = IntTy->getIntegerBitWidth();
18006c3fb27SDimitry Andric   ConstantInt *BitWidthConstant = IRB.getInt({BitWidth, BitWidth});
18106c3fb27SDimitry Andric   Value *BitWidthForInsts =
18206c3fb27SDimitry Andric       VectorTy
18306c3fb27SDimitry Andric           ? IRB.CreateVectorSplat(VectorTy->getNumElements(), BitWidthConstant)
18406c3fb27SDimitry Andric           : BitWidthConstant;
18506c3fb27SDimitry Andric   Value *RotateModVal =
18606c3fb27SDimitry Andric       IRB.CreateURem(/*Rotate*/ FSHFunc->getArg(2), BitWidthForInsts);
18706c3fb27SDimitry Andric   Value *FirstShift = nullptr, *SecShift = nullptr;
18806c3fb27SDimitry Andric   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
18906c3fb27SDimitry Andric     // Shift the less significant number right, the "rotate" number of bits
19006c3fb27SDimitry Andric     // will be 0-filled on the left as a result of this regular shift.
19106c3fb27SDimitry Andric     FirstShift = IRB.CreateLShr(FSHFunc->getArg(1), RotateModVal);
19206c3fb27SDimitry Andric   } else {
19306c3fb27SDimitry Andric     // Shift the more significant number left, the "rotate" number of bits
19406c3fb27SDimitry Andric     // will be 0-filled on the right as a result of this regular shift.
19506c3fb27SDimitry Andric     FirstShift = IRB.CreateShl(FSHFunc->getArg(0), RotateModVal);
19606c3fb27SDimitry Andric   }
19706c3fb27SDimitry Andric   // We want the "rotate" number of the more significant int's LSBs (MSBs) to
19806c3fb27SDimitry Andric   // occupy the leftmost (rightmost) "0 space" left by the previous operation.
19906c3fb27SDimitry Andric   // Therefore, subtract the "rotate" number from the integer bitsize...
20006c3fb27SDimitry Andric   Value *SubRotateVal = IRB.CreateSub(BitWidthForInsts, RotateModVal);
20106c3fb27SDimitry Andric   if (FSHIntrinsic->getIntrinsicID() == Intrinsic::fshr) {
20206c3fb27SDimitry Andric     // ...and left-shift the more significant int by this number, zero-filling
20306c3fb27SDimitry Andric     // the LSBs.
20406c3fb27SDimitry Andric     SecShift = IRB.CreateShl(FSHFunc->getArg(0), SubRotateVal);
20506c3fb27SDimitry Andric   } else {
20606c3fb27SDimitry Andric     // ...and right-shift the less significant int by this number, zero-filling
20706c3fb27SDimitry Andric     // the MSBs.
20806c3fb27SDimitry Andric     SecShift = IRB.CreateLShr(FSHFunc->getArg(1), SubRotateVal);
20906c3fb27SDimitry Andric   }
21006c3fb27SDimitry Andric   // A simple binary addition of the shifted ints yields the final result.
21106c3fb27SDimitry Andric   IRB.CreateRet(IRB.CreateOr(FirstShift, SecShift));
21206c3fb27SDimitry Andric 
21306c3fb27SDimitry Andric   FSHIntrinsic->setCalledFunction(FSHFunc);
21406c3fb27SDimitry Andric }
21506c3fb27SDimitry Andric 
buildUMulWithOverflowFunc(Function * UMulFunc)21606c3fb27SDimitry Andric static void buildUMulWithOverflowFunc(Function *UMulFunc) {
21706c3fb27SDimitry Andric   // The function body is already created.
21806c3fb27SDimitry Andric   if (!UMulFunc->empty())
21906c3fb27SDimitry Andric     return;
22006c3fb27SDimitry Andric 
22106c3fb27SDimitry Andric   BasicBlock *EntryBB = BasicBlock::Create(UMulFunc->getParent()->getContext(),
22206c3fb27SDimitry Andric                                            "entry", UMulFunc);
22306c3fb27SDimitry Andric   IRBuilder<> IRB(EntryBB);
22406c3fb27SDimitry Andric   // Build the actual unsigned multiplication logic with the overflow
22506c3fb27SDimitry Andric   // indication. Do unsigned multiplication Mul = A * B. Then check
22606c3fb27SDimitry Andric   // if unsigned division Div = Mul / A is not equal to B. If so,
22706c3fb27SDimitry Andric   // then overflow has happened.
22806c3fb27SDimitry Andric   Value *Mul = IRB.CreateNUWMul(UMulFunc->getArg(0), UMulFunc->getArg(1));
22906c3fb27SDimitry Andric   Value *Div = IRB.CreateUDiv(Mul, UMulFunc->getArg(0));
23006c3fb27SDimitry Andric   Value *Overflow = IRB.CreateICmpNE(UMulFunc->getArg(0), Div);
23106c3fb27SDimitry Andric 
23206c3fb27SDimitry Andric   // umul.with.overflow intrinsic return a structure, where the first element
23306c3fb27SDimitry Andric   // is the multiplication result, and the second is an overflow bit.
23406c3fb27SDimitry Andric   Type *StructTy = UMulFunc->getReturnType();
23506c3fb27SDimitry Andric   Value *Agg = IRB.CreateInsertValue(PoisonValue::get(StructTy), Mul, {0});
23606c3fb27SDimitry Andric   Value *Res = IRB.CreateInsertValue(Agg, Overflow, {1});
23706c3fb27SDimitry Andric   IRB.CreateRet(Res);
23806c3fb27SDimitry Andric }
23906c3fb27SDimitry Andric 
lowerExpectAssume(IntrinsicInst * II)240*5f757f3fSDimitry Andric static void lowerExpectAssume(IntrinsicInst *II) {
241*5f757f3fSDimitry Andric   // If we cannot use the SPV_KHR_expect_assume extension, then we need to
242*5f757f3fSDimitry Andric   // ignore the intrinsic and move on. It should be removed later on by LLVM.
243*5f757f3fSDimitry Andric   // Otherwise we should lower the intrinsic to the corresponding SPIR-V
244*5f757f3fSDimitry Andric   // instruction.
245*5f757f3fSDimitry Andric   // For @llvm.assume we have OpAssumeTrueKHR.
246*5f757f3fSDimitry Andric   // For @llvm.expect we have OpExpectKHR.
247*5f757f3fSDimitry Andric   //
248*5f757f3fSDimitry Andric   // We need to lower this into a builtin and then the builtin into a SPIR-V
249*5f757f3fSDimitry Andric   // instruction.
250*5f757f3fSDimitry Andric   if (II->getIntrinsicID() == Intrinsic::assume) {
251*5f757f3fSDimitry Andric     Function *F = Intrinsic::getDeclaration(
252*5f757f3fSDimitry Andric         II->getModule(), Intrinsic::SPVIntrinsics::spv_assume);
253*5f757f3fSDimitry Andric     II->setCalledFunction(F);
254*5f757f3fSDimitry Andric   } else if (II->getIntrinsicID() == Intrinsic::expect) {
255*5f757f3fSDimitry Andric     Function *F = Intrinsic::getDeclaration(
256*5f757f3fSDimitry Andric         II->getModule(), Intrinsic::SPVIntrinsics::spv_expect,
257*5f757f3fSDimitry Andric         {II->getOperand(0)->getType()});
258*5f757f3fSDimitry Andric     II->setCalledFunction(F);
259*5f757f3fSDimitry Andric   } else {
260*5f757f3fSDimitry Andric     llvm_unreachable("Unknown intrinsic");
261*5f757f3fSDimitry Andric   }
262*5f757f3fSDimitry Andric 
263*5f757f3fSDimitry Andric   return;
264*5f757f3fSDimitry Andric }
265*5f757f3fSDimitry Andric 
lowerUMulWithOverflow(IntrinsicInst * UMulIntrinsic)26606c3fb27SDimitry Andric static void lowerUMulWithOverflow(IntrinsicInst *UMulIntrinsic) {
26706c3fb27SDimitry Andric   // Get a separate function - otherwise, we'd have to rework the CFG of the
26806c3fb27SDimitry Andric   // current one. Then simply replace the intrinsic uses with a call to the new
26906c3fb27SDimitry Andric   // function.
27006c3fb27SDimitry Andric   Module *M = UMulIntrinsic->getModule();
27106c3fb27SDimitry Andric   FunctionType *UMulFuncTy = UMulIntrinsic->getFunctionType();
27206c3fb27SDimitry Andric   Type *FSHLRetTy = UMulFuncTy->getReturnType();
27306c3fb27SDimitry Andric   const std::string FuncName = lowerLLVMIntrinsicName(UMulIntrinsic);
27406c3fb27SDimitry Andric   Function *UMulFunc =
27506c3fb27SDimitry Andric       getOrCreateFunction(M, FSHLRetTy, UMulFuncTy->params(), FuncName);
27606c3fb27SDimitry Andric   buildUMulWithOverflowFunc(UMulFunc);
27706c3fb27SDimitry Andric   UMulIntrinsic->setCalledFunction(UMulFunc);
27806c3fb27SDimitry Andric }
27906c3fb27SDimitry Andric 
28006c3fb27SDimitry Andric // Substitutes calls to LLVM intrinsics with either calls to SPIR-V intrinsics
28106c3fb27SDimitry Andric // or calls to proper generated functions. Returns True if F was modified.
substituteIntrinsicCalls(Function * F)28206c3fb27SDimitry Andric bool SPIRVPrepareFunctions::substituteIntrinsicCalls(Function *F) {
28306c3fb27SDimitry Andric   bool Changed = false;
28406c3fb27SDimitry Andric   for (BasicBlock &BB : *F) {
28506c3fb27SDimitry Andric     for (Instruction &I : BB) {
28606c3fb27SDimitry Andric       auto Call = dyn_cast<CallInst>(&I);
28706c3fb27SDimitry Andric       if (!Call)
28806c3fb27SDimitry Andric         continue;
28906c3fb27SDimitry Andric       Function *CF = Call->getCalledFunction();
29006c3fb27SDimitry Andric       if (!CF || !CF->isIntrinsic())
29106c3fb27SDimitry Andric         continue;
29206c3fb27SDimitry Andric       auto *II = cast<IntrinsicInst>(Call);
29306c3fb27SDimitry Andric       if (II->getIntrinsicID() == Intrinsic::memset ||
29406c3fb27SDimitry Andric           II->getIntrinsicID() == Intrinsic::bswap)
29506c3fb27SDimitry Andric         Changed |= lowerIntrinsicToFunction(II);
29606c3fb27SDimitry Andric       else if (II->getIntrinsicID() == Intrinsic::fshl ||
29706c3fb27SDimitry Andric                II->getIntrinsicID() == Intrinsic::fshr) {
29806c3fb27SDimitry Andric         lowerFunnelShifts(II);
29906c3fb27SDimitry Andric         Changed = true;
30006c3fb27SDimitry Andric       } else if (II->getIntrinsicID() == Intrinsic::umul_with_overflow) {
30106c3fb27SDimitry Andric         lowerUMulWithOverflow(II);
30206c3fb27SDimitry Andric         Changed = true;
303*5f757f3fSDimitry Andric       } else if (II->getIntrinsicID() == Intrinsic::assume ||
304*5f757f3fSDimitry Andric                  II->getIntrinsicID() == Intrinsic::expect) {
305*5f757f3fSDimitry Andric         const SPIRVSubtarget &STI = TM.getSubtarget<SPIRVSubtarget>(*F);
306*5f757f3fSDimitry Andric         if (STI.canUseExtension(SPIRV::Extension::SPV_KHR_expect_assume))
307*5f757f3fSDimitry Andric           lowerExpectAssume(II);
308*5f757f3fSDimitry Andric         Changed = true;
30906c3fb27SDimitry Andric       }
31006c3fb27SDimitry Andric     }
31106c3fb27SDimitry Andric   }
31206c3fb27SDimitry Andric   return Changed;
31306c3fb27SDimitry Andric }
31406c3fb27SDimitry Andric 
31506c3fb27SDimitry Andric // Returns F if aggregate argument/return types are not present or cloned F
31606c3fb27SDimitry Andric // function with the types replaced by i32 types. The change in types is
31706c3fb27SDimitry Andric // noted in 'spv.cloned_funcs' metadata for later restoration.
31806c3fb27SDimitry Andric Function *
removeAggregateTypesFromSignature(Function * F)31906c3fb27SDimitry Andric SPIRVPrepareFunctions::removeAggregateTypesFromSignature(Function *F) {
320fcaf7f86SDimitry Andric   IRBuilder<> B(F->getContext());
321fcaf7f86SDimitry Andric 
322fcaf7f86SDimitry Andric   bool IsRetAggr = F->getReturnType()->isAggregateType();
323fcaf7f86SDimitry Andric   bool HasAggrArg =
324fcaf7f86SDimitry Andric       std::any_of(F->arg_begin(), F->arg_end(), [](Argument &Arg) {
325fcaf7f86SDimitry Andric         return Arg.getType()->isAggregateType();
326fcaf7f86SDimitry Andric       });
327fcaf7f86SDimitry Andric   bool DoClone = IsRetAggr || HasAggrArg;
328fcaf7f86SDimitry Andric   if (!DoClone)
329fcaf7f86SDimitry Andric     return F;
330fcaf7f86SDimitry Andric   SmallVector<std::pair<int, Type *>, 4> ChangedTypes;
331fcaf7f86SDimitry Andric   Type *RetType = IsRetAggr ? B.getInt32Ty() : F->getReturnType();
332fcaf7f86SDimitry Andric   if (IsRetAggr)
333fcaf7f86SDimitry Andric     ChangedTypes.push_back(std::pair<int, Type *>(-1, F->getReturnType()));
334fcaf7f86SDimitry Andric   SmallVector<Type *, 4> ArgTypes;
335fcaf7f86SDimitry Andric   for (const auto &Arg : F->args()) {
336fcaf7f86SDimitry Andric     if (Arg.getType()->isAggregateType()) {
337fcaf7f86SDimitry Andric       ArgTypes.push_back(B.getInt32Ty());
338fcaf7f86SDimitry Andric       ChangedTypes.push_back(
339fcaf7f86SDimitry Andric           std::pair<int, Type *>(Arg.getArgNo(), Arg.getType()));
340fcaf7f86SDimitry Andric     } else
341fcaf7f86SDimitry Andric       ArgTypes.push_back(Arg.getType());
342fcaf7f86SDimitry Andric   }
343fcaf7f86SDimitry Andric   FunctionType *NewFTy =
344fcaf7f86SDimitry Andric       FunctionType::get(RetType, ArgTypes, F->getFunctionType()->isVarArg());
345fcaf7f86SDimitry Andric   Function *NewF =
346fcaf7f86SDimitry Andric       Function::Create(NewFTy, F->getLinkage(), F->getName(), *F->getParent());
347fcaf7f86SDimitry Andric 
348fcaf7f86SDimitry Andric   ValueToValueMapTy VMap;
349fcaf7f86SDimitry Andric   auto NewFArgIt = NewF->arg_begin();
350fcaf7f86SDimitry Andric   for (auto &Arg : F->args()) {
351fcaf7f86SDimitry Andric     StringRef ArgName = Arg.getName();
352fcaf7f86SDimitry Andric     NewFArgIt->setName(ArgName);
353fcaf7f86SDimitry Andric     VMap[&Arg] = &(*NewFArgIt++);
354fcaf7f86SDimitry Andric   }
355fcaf7f86SDimitry Andric   SmallVector<ReturnInst *, 8> Returns;
356fcaf7f86SDimitry Andric 
357fcaf7f86SDimitry Andric   CloneFunctionInto(NewF, F, VMap, CloneFunctionChangeType::LocalChangesOnly,
358fcaf7f86SDimitry Andric                     Returns);
359fcaf7f86SDimitry Andric   NewF->takeName(F);
360fcaf7f86SDimitry Andric 
361fcaf7f86SDimitry Andric   NamedMDNode *FuncMD =
362fcaf7f86SDimitry Andric       F->getParent()->getOrInsertNamedMetadata("spv.cloned_funcs");
363fcaf7f86SDimitry Andric   SmallVector<Metadata *, 2> MDArgs;
364fcaf7f86SDimitry Andric   MDArgs.push_back(MDString::get(B.getContext(), NewF->getName()));
365fcaf7f86SDimitry Andric   for (auto &ChangedTyP : ChangedTypes)
366fcaf7f86SDimitry Andric     MDArgs.push_back(MDNode::get(
367fcaf7f86SDimitry Andric         B.getContext(),
368fcaf7f86SDimitry Andric         {ConstantAsMetadata::get(B.getInt32(ChangedTyP.first)),
369fcaf7f86SDimitry Andric          ValueAsMetadata::get(Constant::getNullValue(ChangedTyP.second))}));
370fcaf7f86SDimitry Andric   MDNode *ThisFuncMD = MDNode::get(B.getContext(), MDArgs);
371fcaf7f86SDimitry Andric   FuncMD->addOperand(ThisFuncMD);
372fcaf7f86SDimitry Andric 
373fcaf7f86SDimitry Andric   for (auto *U : make_early_inc_range(F->users())) {
374fcaf7f86SDimitry Andric     if (auto *CI = dyn_cast<CallInst>(U))
375fcaf7f86SDimitry Andric       CI->mutateFunctionType(NewF->getFunctionType());
376fcaf7f86SDimitry Andric     U->replaceUsesOfWith(F, NewF);
377fcaf7f86SDimitry Andric   }
378fcaf7f86SDimitry Andric   return NewF;
379fcaf7f86SDimitry Andric }
380fcaf7f86SDimitry Andric 
runOnModule(Module & M)381fcaf7f86SDimitry Andric bool SPIRVPrepareFunctions::runOnModule(Module &M) {
38206c3fb27SDimitry Andric   bool Changed = false;
383fcaf7f86SDimitry Andric   for (Function &F : M)
38406c3fb27SDimitry Andric     Changed |= substituteIntrinsicCalls(&F);
385fcaf7f86SDimitry Andric 
386fcaf7f86SDimitry Andric   std::vector<Function *> FuncsWorklist;
387fcaf7f86SDimitry Andric   for (auto &F : M)
388fcaf7f86SDimitry Andric     FuncsWorklist.push_back(&F);
389fcaf7f86SDimitry Andric 
39006c3fb27SDimitry Andric   for (auto *F : FuncsWorklist) {
39106c3fb27SDimitry Andric     Function *NewF = removeAggregateTypesFromSignature(F);
392fcaf7f86SDimitry Andric 
39306c3fb27SDimitry Andric     if (NewF != F) {
39406c3fb27SDimitry Andric       F->eraseFromParent();
39506c3fb27SDimitry Andric       Changed = true;
396fcaf7f86SDimitry Andric     }
397fcaf7f86SDimitry Andric   }
398fcaf7f86SDimitry Andric   return Changed;
399fcaf7f86SDimitry Andric }
400fcaf7f86SDimitry Andric 
401*5f757f3fSDimitry Andric ModulePass *
createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine & TM)402*5f757f3fSDimitry Andric llvm::createSPIRVPrepareFunctionsPass(const SPIRVTargetMachine &TM) {
403*5f757f3fSDimitry Andric   return new SPIRVPrepareFunctions(TM);
404fcaf7f86SDimitry Andric }
405