1 //=== ReplaceWithVeclib.cpp - Replace vector intrinsics with veclib calls -===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Replaces LLVM IR instructions with vector operands (i.e., the frem
// instruction or calls to LLVM intrinsics) with matching calls to functions
// from a vector library (e.g., libmvec, SVML) using the TargetLibraryInfo
// interface.
12 //
13 //===----------------------------------------------------------------------===//
14 
15 #include "llvm/CodeGen/ReplaceWithVeclib.h"
16 #include "llvm/ADT/STLExtras.h"
17 #include "llvm/ADT/Statistic.h"
18 #include "llvm/ADT/StringRef.h"
19 #include "llvm/Analysis/DemandedBits.h"
20 #include "llvm/Analysis/GlobalsModRef.h"
21 #include "llvm/Analysis/OptimizationRemarkEmitter.h"
22 #include "llvm/Analysis/TargetLibraryInfo.h"
23 #include "llvm/Analysis/VectorUtils.h"
24 #include "llvm/CodeGen/Passes.h"
25 #include "llvm/IR/DerivedTypes.h"
26 #include "llvm/IR/IRBuilder.h"
27 #include "llvm/IR/InstIterator.h"
28 #include "llvm/IR/VFABIDemangler.h"
29 #include "llvm/Support/TypeSize.h"
30 #include "llvm/Transforms/Utils/ModuleUtils.h"
31 
32 using namespace llvm;
33 
34 #define DEBUG_TYPE "replace-with-veclib"
35 
// Pass statistics (see llvm/ADT/Statistic.h).
//
// Incremented once per successful replacement in replaceWithCallToVeclib.
// NOTE(review): despite the description, this also counts replaced vector
// frem instructions, not only intrinsic calls.
STATISTIC(NumCallsReplaced,
          "Number of calls to intrinsics that have been replaced.");

// Incremented whenever getTLIFunction has to create a new declaration.
STATISTIC(NumTLIFuncDeclAdded,
          "Number of vector library function declarations added.");

// Incremented whenever a newly created declaration is appended to
// `@llvm.compiler.used`.
STATISTIC(NumFuncUsedAdded,
          "Number of functions added to `llvm.compiler.used`");
44 
45 /// Returns a vector Function that it adds to the Module \p M. When an \p
46 /// ScalarFunc is not null, it copies its attributes to the newly created
47 /// Function.
48 Function *getTLIFunction(Module *M, FunctionType *VectorFTy,
49                          const StringRef TLIName,
50                          Function *ScalarFunc = nullptr) {
51   Function *TLIFunc = M->getFunction(TLIName);
52   if (!TLIFunc) {
53     TLIFunc =
54         Function::Create(VectorFTy, Function::ExternalLinkage, TLIName, *M);
55     if (ScalarFunc)
56       TLIFunc->copyAttributesFrom(ScalarFunc);
57 
58     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Added vector library function `"
59                       << TLIName << "` of type `" << *(TLIFunc->getType())
60                       << "` to module.\n");
61 
62     ++NumTLIFuncDeclAdded;
63     // Add the freshly created function to llvm.compiler.used, similar to as it
64     // is done in InjectTLIMappings.
65     appendToCompilerUsed(*M, {TLIFunc});
66     LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Adding `" << TLIName
67                       << "` to `@llvm.compiler.used`.\n");
68     ++NumFuncUsedAdded;
69   }
70   return TLIFunc;
71 }
72 
73 /// Replace the instruction \p I with a call to the corresponding function from
74 /// the vector library (\p TLIVecFunc).
75 static void replaceWithTLIFunction(Instruction &I, VFInfo &Info,
76                                    Function *TLIVecFunc) {
77   IRBuilder<> IRBuilder(&I);
78   auto *CI = dyn_cast<CallInst>(&I);
79   SmallVector<Value *> Args(CI ? CI->args() : I.operands());
80   if (auto OptMaskpos = Info.getParamIndexForOptionalMask()) {
81     auto *MaskTy =
82         VectorType::get(Type::getInt1Ty(I.getContext()), Info.Shape.VF);
83     Args.insert(Args.begin() + OptMaskpos.value(),
84                 Constant::getAllOnesValue(MaskTy));
85   }
86 
87   // If it is a call instruction, preserve the operand bundles.
88   SmallVector<OperandBundleDef, 1> OpBundles;
89   if (CI)
90     CI->getOperandBundlesAsDefs(OpBundles);
91 
92   auto *Replacement = IRBuilder.CreateCall(TLIVecFunc, Args, OpBundles);
93   I.replaceAllUsesWith(Replacement);
94   // Preserve fast math flags for FP math.
95   if (isa<FPMathOperator>(Replacement))
96     Replacement->copyFastMathFlags(&I);
97 }
98 
99 /// Returns true when successfully replaced \p I with a suitable function taking
100 /// vector arguments, based on available mappings in the \p TLI. Currently only
101 /// works when \p I is a call to vectorized intrinsic or the frem instruction.
102 static bool replaceWithCallToVeclib(const TargetLibraryInfo &TLI,
103                                     Instruction &I) {
104   // At the moment VFABI assumes the return type is always widened unless it is
105   // a void type.
106   auto *VTy = dyn_cast<VectorType>(I.getType());
107   ElementCount EC(VTy ? VTy->getElementCount() : ElementCount::getFixed(0));
108 
109   // Compute the argument types of the corresponding scalar call and the scalar
110   // function name. For calls, it additionally finds the function to replace
111   // and checks that all vector operands match the previously found EC.
112   SmallVector<Type *, 8> ScalarArgTypes;
113   std::string ScalarName;
114   Function *FuncToReplace = nullptr;
115   auto *CI = dyn_cast<CallInst>(&I);
116   if (CI) {
117     FuncToReplace = CI->getCalledFunction();
118     Intrinsic::ID IID = FuncToReplace->getIntrinsicID();
119     assert(IID != Intrinsic::not_intrinsic && "Not an intrinsic");
120     for (auto Arg : enumerate(CI->args())) {
121       auto *ArgTy = Arg.value()->getType();
122       if (isVectorIntrinsicWithScalarOpAtArg(IID, Arg.index())) {
123         ScalarArgTypes.push_back(ArgTy);
124       } else if (auto *VectorArgTy = dyn_cast<VectorType>(ArgTy)) {
125         ScalarArgTypes.push_back(VectorArgTy->getElementType());
126         // When return type is void, set EC to the first vector argument, and
127         // disallow vector arguments with different ECs.
128         if (EC.isZero())
129           EC = VectorArgTy->getElementCount();
130         else if (EC != VectorArgTy->getElementCount())
131           return false;
132       } else
133         // Exit when it is supposed to be a vector argument but it isn't.
134         return false;
135     }
136     // Try to reconstruct the name for the scalar version of the instruction,
137     // using scalar argument types.
138     ScalarName = Intrinsic::isOverloaded(IID)
139                      ? Intrinsic::getName(IID, ScalarArgTypes, I.getModule())
140                      : Intrinsic::getName(IID).str();
141   } else {
142     assert(VTy && "Return type must be a vector");
143     auto *ScalarTy = VTy->getScalarType();
144     LibFunc Func;
145     if (!TLI.getLibFunc(I.getOpcode(), ScalarTy, Func))
146       return false;
147     ScalarName = TLI.getName(Func);
148     ScalarArgTypes = {ScalarTy, ScalarTy};
149   }
150 
151   // Try to find the mapping for the scalar version of this intrinsic and the
152   // exact vector width of the call operands in the TargetLibraryInfo. First,
153   // check with a non-masked variant, and if that fails try with a masked one.
154   const VecDesc *VD =
155       TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ false);
156   if (!VD && !(VD = TLI.getVectorMappingInfo(ScalarName, EC, /*Masked*/ true)))
157     return false;
158 
159   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Found TLI mapping from: `" << ScalarName
160                     << "` and vector width " << EC << " to: `"
161                     << VD->getVectorFnName() << "`.\n");
162 
163   // Replace the call to the intrinsic with a call to the vector library
164   // function.
165   Type *ScalarRetTy = I.getType()->getScalarType();
166   FunctionType *ScalarFTy =
167       FunctionType::get(ScalarRetTy, ScalarArgTypes, /*isVarArg*/ false);
168   const std::string MangledName = VD->getVectorFunctionABIVariantString();
169   auto OptInfo = VFABI::tryDemangleForVFABI(MangledName, ScalarFTy);
170   if (!OptInfo)
171     return false;
172 
173   // There is no guarantee that the vectorized instructions followed the VFABI
174   // specification when being created, this is why we need to add extra check to
175   // make sure that the operands of the vector function obtained via VFABI match
176   // the operands of the original vector instruction.
177   if (CI) {
178     for (auto VFParam : OptInfo->Shape.Parameters) {
179       if (VFParam.ParamKind == VFParamKind::GlobalPredicate)
180         continue;
181 
182       // tryDemangleForVFABI must return valid ParamPos, otherwise it could be
183       // a bug in the VFABI parser.
184       assert(VFParam.ParamPos < CI->arg_size() &&
185              "ParamPos has invalid range.");
186       Type *OrigTy = CI->getArgOperand(VFParam.ParamPos)->getType();
187       if (OrigTy->isVectorTy() != (VFParam.ParamKind == VFParamKind::Vector)) {
188         LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Will not replace: " << ScalarName
189                           << ". Wrong type at index " << VFParam.ParamPos
190                           << ": " << *OrigTy << "\n");
191         return false;
192       }
193     }
194   }
195 
196   FunctionType *VectorFTy = VFABI::createFunctionType(*OptInfo, ScalarFTy);
197   if (!VectorFTy)
198     return false;
199 
200   Function *TLIFunc = getTLIFunction(I.getModule(), VectorFTy,
201                                      VD->getVectorFnName(), FuncToReplace);
202 
203   replaceWithTLIFunction(I, *OptInfo, TLIFunc);
204   LLVM_DEBUG(dbgs() << DEBUG_TYPE << ": Replaced call to `" << ScalarName
205                     << "` with call to `" << TLIFunc->getName() << "`.\n");
206   ++NumCallsReplaced;
207   return true;
208 }
209 
210 /// Supported instruction \p I must be a vectorized frem or a call to an
211 /// intrinsic that returns either void or a vector.
212 static bool isSupportedInstruction(Instruction *I) {
213   Type *Ty = I->getType();
214   if (auto *CI = dyn_cast<CallInst>(I))
215     return (Ty->isVectorTy() || Ty->isVoidTy()) && CI->getCalledFunction() &&
216            CI->getCalledFunction()->getIntrinsicID() !=
217                Intrinsic::not_intrinsic;
218   if (I->getOpcode() == Instruction::FRem && Ty->isVectorTy())
219     return true;
220   return false;
221 }
222 
223 static bool runImpl(const TargetLibraryInfo &TLI, Function &F) {
224   bool Changed = false;
225   SmallVector<Instruction *> ReplacedCalls;
226   for (auto &I : instructions(F)) {
227     if (!isSupportedInstruction(&I))
228       continue;
229     if (replaceWithCallToVeclib(TLI, I)) {
230       ReplacedCalls.push_back(&I);
231       Changed = true;
232     }
233   }
234   // Erase the calls to the intrinsics that have been replaced
235   // with calls to the vector library.
236   for (auto *CI : ReplacedCalls)
237     CI->eraseFromParent();
238   return Changed;
239 }
240 
241 ////////////////////////////////////////////////////////////////////////////////
242 // New pass manager implementation.
243 ////////////////////////////////////////////////////////////////////////////////
244 PreservedAnalyses ReplaceWithVeclib::run(Function &F,
245                                          FunctionAnalysisManager &AM) {
246   const TargetLibraryInfo &TLI = AM.getResult<TargetLibraryAnalysis>(F);
247   auto Changed = runImpl(TLI, F);
248   if (Changed) {
249     LLVM_DEBUG(dbgs() << "Instructions replaced with vector libraries: "
250                       << NumCallsReplaced << "\n");
251 
252     PreservedAnalyses PA;
253     PA.preserveSet<CFGAnalyses>();
254     PA.preserve<TargetLibraryAnalysis>();
255     PA.preserve<ScalarEvolutionAnalysis>();
256     PA.preserve<LoopAccessAnalysis>();
257     PA.preserve<DemandedBitsAnalysis>();
258     PA.preserve<OptimizationRemarkEmitterAnalysis>();
259     return PA;
260   }
261 
262   // The pass did not replace any calls, hence it preserves all analyses.
263   return PreservedAnalyses::all();
264 }
265 
266 ////////////////////////////////////////////////////////////////////////////////
267 // Legacy PM Implementation.
268 ////////////////////////////////////////////////////////////////////////////////
269 bool ReplaceWithVeclibLegacy::runOnFunction(Function &F) {
270   const TargetLibraryInfo &TLI =
271       getAnalysis<TargetLibraryInfoWrapperPass>().getTLI(F);
272   return runImpl(TLI, F);
273 }
274 
/// Declares the analyses this legacy pass needs and the ones it leaves valid.
void ReplaceWithVeclibLegacy::getAnalysisUsage(AnalysisUsage &AU) const {
  // Only instructions are replaced and erased; blocks and edges are untouched.
  AU.setPreservesCFG();
  // Needed to look up the scalar-to-vector library mappings.
  AU.addRequired<TargetLibraryInfoWrapperPass>();
  // Analyses declared as still valid after the rewrite.
  AU.addPreserved<TargetLibraryInfoWrapperPass>();
  AU.addPreserved<ScalarEvolutionWrapperPass>();
  AU.addPreserved<AAResultsWrapperPass>();
  AU.addPreserved<OptimizationRemarkEmitterWrapperPass>();
  AU.addPreserved<GlobalsAAWrapperPass>();
}
284 
285 ////////////////////////////////////////////////////////////////////////////////
286 // Legacy Pass manager initialization
287 ////////////////////////////////////////////////////////////////////////////////
// Storage for the legacy pass ID referenced by the registration macros below.
char ReplaceWithVeclibLegacy::ID = 0;

// Register the legacy pass under DEBUG_TYPE ("replace-with-veclib"), declaring
// its dependency on TargetLibraryInfoWrapperPass.
INITIALIZE_PASS_BEGIN(ReplaceWithVeclibLegacy, DEBUG_TYPE,
                      "Replace intrinsics with calls to vector library", false,
                      false)
INITIALIZE_PASS_DEPENDENCY(TargetLibraryInfoWrapperPass)
INITIALIZE_PASS_END(ReplaceWithVeclibLegacy, DEBUG_TYPE,
                    "Replace intrinsics with calls to vector library", false,
                    false)

/// Creates a new instance of the legacy ReplaceWithVeclib pass; ownership
/// presumably passes to the caller's pass manager (TODO confirm).
FunctionPass *llvm::createReplaceWithVeclibLegacyPass() {
  return new ReplaceWithVeclibLegacy();
}
301