1 //===- AMDGPULibCalls.cpp -------------------------------------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 /// \file
10 /// This file does AMD library function optimizations.
11 //
12 //===----------------------------------------------------------------------===//
13
#include "AMDGPU.h"
#include "AMDGPULibFunc.h"
#include "GCNSubtarget.h"
#include "llvm/Analysis/AliasAnalysis.h"
#include "llvm/Analysis/Loads.h"
#include "llvm/IR/IntrinsicsAMDGPU.h"
#include "llvm/InitializePasses.h"
#include "llvm/Support/MathExtras.h"
#include "llvm/Target/TargetMachine.h"
#include <cmath>
22
23 #define DEBUG_TYPE "amdgpu-simplifylib"
24
25 using namespace llvm;
26
// When set, the pass runs in pre-link mode: AMDGPULibCalls::getFunction may
// then declare library functions on demand instead of only reusing existing
// ones.
static cl::opt<bool> EnablePreLink("amdgpu-prelink",
  cl::desc("Enable pre-link mode optimizations"),
  cl::init(false),
  cl::Hidden);

// Names of functions to redirect to their native_* variants. The value "all",
// or a single empty value, selects every function that has a native variant
// (see AMDGPULibCalls::initNativeFuncs).
static cl::list<std::string> UseNative("amdgpu-use-native",
  cl::desc("Comma separated list of functions to replace with native, or all"),
  cl::CommaSeparated, cl::ValueOptional,
  cl::Hidden);

// Short aliases for the llvm::numbers constants used in the folding tables.
#define MATH_PI numbers::pi
#define MATH_E numbers::e
#define MATH_SQRT2 numbers::sqrt2
#define MATH_SQRT1_2 numbers::inv_sqrt2
41
42 namespace llvm {
43
44 class AMDGPULibCalls {
45 private:
46
47 typedef llvm::AMDGPULibFunc FuncInfo;
48
49 const TargetMachine *TM;
50
51 // -fuse-native.
52 bool AllNative = false;
53
54 bool useNativeFunc(const StringRef F) const;
55
56 // Return a pointer (pointer expr) to the function if function defintion with
57 // "FuncName" exists. It may create a new function prototype in pre-link mode.
58 FunctionCallee getFunction(Module *M, const FuncInfo &fInfo);
59
60 // Replace a normal function with its native version.
61 bool replaceWithNative(CallInst *CI, const FuncInfo &FInfo);
62
63 bool parseFunctionName(const StringRef& FMangledName,
64 FuncInfo *FInfo=nullptr /*out*/);
65
66 bool TDOFold(CallInst *CI, const FuncInfo &FInfo);
67
68 /* Specialized optimizations */
69
70 // recip (half or native)
71 bool fold_recip(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
72
73 // divide (half or native)
74 bool fold_divide(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
75
76 // pow/powr/pown
77 bool fold_pow(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
78
79 // rootn
80 bool fold_rootn(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
81
82 // fma/mad
83 bool fold_fma_mad(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
84
85 // -fuse-native for sincos
86 bool sincosUseNative(CallInst *aCI, const FuncInfo &FInfo);
87
88 // evaluate calls if calls' arguments are constants.
89 bool evaluateScalarMathFunc(FuncInfo &FInfo, double& Res0,
90 double& Res1, Constant *copr0, Constant *copr1, Constant *copr2);
91 bool evaluateCall(CallInst *aCI, FuncInfo &FInfo);
92
93 // exp
94 bool fold_exp(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
95
96 // exp2
97 bool fold_exp2(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
98
99 // exp10
100 bool fold_exp10(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
101
102 // log
103 bool fold_log(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
104
105 // log2
106 bool fold_log2(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
107
108 // log10
109 bool fold_log10(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
110
111 // sqrt
112 bool fold_sqrt(CallInst *CI, IRBuilder<> &B, const FuncInfo &FInfo);
113
114 // sin/cos
115 bool fold_sincos(CallInst * CI, IRBuilder<> &B, AliasAnalysis * AA);
116
117 // __read_pipe/__write_pipe
118 bool fold_read_write_pipe(CallInst *CI, IRBuilder<> &B, FuncInfo &FInfo);
119
120 // llvm.amdgcn.wavefrontsize
121 bool fold_wavefrontsize(CallInst *CI, IRBuilder<> &B);
122
123 // Get insertion point at entry.
124 BasicBlock::iterator getEntryIns(CallInst * UI);
125 // Insert an Alloc instruction.
126 AllocaInst* insertAlloca(CallInst * UI, IRBuilder<> &B, const char *prefix);
127 // Get a scalar native builtin signle argument FP function
128 FunctionCallee getNativeFunction(Module *M, const FuncInfo &FInfo);
129
130 protected:
131 CallInst *CI;
132
133 bool isUnsafeMath(const CallInst *CI) const;
134
replaceCall(Value * With)135 void replaceCall(Value *With) {
136 CI->replaceAllUsesWith(With);
137 CI->eraseFromParent();
138 }
139
140 public:
AMDGPULibCalls(const TargetMachine * TM_=nullptr)141 AMDGPULibCalls(const TargetMachine *TM_ = nullptr) : TM(TM_) {}
142
143 bool fold(CallInst *CI, AliasAnalysis *AA = nullptr);
144
145 void initNativeFuncs();
146
147 // Replace a normal math function call with that native version
148 bool useNative(CallInst *CI);
149 };
150
151 } // end llvm namespace
152
namespace {

// Legacy-PM function pass that applies the AMDGPULibCalls simplifications.
class AMDGPUSimplifyLibCalls : public FunctionPass {

  AMDGPULibCalls Simplifier;

public:
  static char ID; // Pass identification

  AMDGPUSimplifyLibCalls(const TargetMachine *TM = nullptr)
    : FunctionPass(ID), Simplifier(TM) {
    initializeAMDGPUSimplifyLibCallsPass(*PassRegistry::getPassRegistry());
  }

  // Alias analysis is required. It is presumably handed to
  // AMDGPULibCalls::fold, whose sin/cos folding takes an AliasAnalysis;
  // verify in runOnFunction.
  void getAnalysisUsage(AnalysisUsage &AU) const override {
    AU.addRequired<AAResultsWrapperPass>();
  }

  bool runOnFunction(Function &M) override;
};

// Legacy-PM function pass that redirects math library calls to their
// native_* variants, as selected by -amdgpu-use-native.
class AMDGPUUseNativeCalls : public FunctionPass {

  AMDGPULibCalls Simplifier;

public:
  static char ID; // Pass identification

  AMDGPUUseNativeCalls() : FunctionPass(ID) {
    initializeAMDGPUUseNativeCallsPass(*PassRegistry::getPassRegistry());
    // Resolve the -amdgpu-use-native selection once, at construction.
    Simplifier.initNativeFuncs();
  }

  bool runOnFunction(Function &F) override;
};

} // end anonymous namespace.
190
// Pass identity anchors for the legacy pass manager.
char AMDGPUSimplifyLibCalls::ID = 0;
char AMDGPUUseNativeCalls::ID = 0;

// Register both passes with the legacy pass registry; the simplifier
// declares its dependency on alias analysis.
INITIALIZE_PASS_BEGIN(AMDGPUSimplifyLibCalls, "amdgpu-simplifylib",
                      "Simplify well-known AMD library calls", false, false)
INITIALIZE_PASS_DEPENDENCY(AAResultsWrapperPass)
INITIALIZE_PASS_END(AMDGPUSimplifyLibCalls, "amdgpu-simplifylib",
                    "Simplify well-known AMD library calls", false, false)

INITIALIZE_PASS(AMDGPUUseNativeCalls, "amdgpu-usenative",
                "Replace builtin math calls with that native versions.",
                false, false)
203
204 template <typename IRB>
205 static CallInst *CreateCallEx(IRB &B, FunctionCallee Callee, Value *Arg,
206 const Twine &Name = "") {
207 CallInst *R = B.CreateCall(Callee, Arg, Name);
208 if (Function *F = dyn_cast<Function>(Callee.getCallee()))
209 R->setCallingConv(F->getCallingConv());
210 return R;
211 }
212
213 template <typename IRB>
CreateCallEx2(IRB & B,FunctionCallee Callee,Value * Arg1,Value * Arg2,const Twine & Name="")214 static CallInst *CreateCallEx2(IRB &B, FunctionCallee Callee, Value *Arg1,
215 Value *Arg2, const Twine &Name = "") {
216 CallInst *R = B.CreateCall(Callee, {Arg1, Arg2}, Name);
217 if (Function *F = dyn_cast<Function>(Callee.getCallee()))
218 R->setCallingConv(F->getCallingConv());
219 return R;
220 }
221
// Data structures for table-driven optimizations.
// The tables work for both f32 and f64 functions with 1 input argument:
// TDOFold compares the constant argument against `input` and, on an exact
// match, replaces the call with `result`.

struct TableEntry {
  double result; // value the call folds to
  double input;  // argument value that triggers the fold
};
229
/* a list of {result, input} */
// One table per library function. An argument matches an entry only when
// ConstantFP::isExactlyValue(input) holds for it (see TDOFold), which is why
// +0.0 and -0.0 get separate entries wherever both are folded.
static const TableEntry tbl_acos[] = {
  {MATH_PI / 2.0, 0.0},
  {MATH_PI / 2.0, -0.0},
  {0.0, 1.0},
  {MATH_PI, -1.0}
};
static const TableEntry tbl_acosh[] = {
  {0.0, 1.0}
};
static const TableEntry tbl_acospi[] = {
  {0.5, 0.0},
  {0.5, -0.0},
  {0.0, 1.0},
  {1.0, -1.0}
};
static const TableEntry tbl_asin[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {MATH_PI / 2.0, 1.0},
  {-MATH_PI / 2.0, -1.0}
};
static const TableEntry tbl_asinh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_asinpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {0.5, 1.0},
  {-0.5, -1.0}
};
static const TableEntry tbl_atan[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {MATH_PI / 4.0, 1.0},
  {-MATH_PI / 4.0, -1.0}
};
static const TableEntry tbl_atanh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_atanpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {0.25, 1.0},
  {-0.25, -1.0}
};
static const TableEntry tbl_cbrt[] = {
  {0.0, 0.0},
  {-0.0, -0.0},
  {1.0, 1.0},
  {-1.0, -1.0},
};
static const TableEntry tbl_cos[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_cosh[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_cospi[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_erfc[] = {
  {1.0, 0.0},
  {1.0, -0.0}
};
static const TableEntry tbl_erf[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_exp[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {MATH_E, 1.0}
};
static const TableEntry tbl_exp2[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {2.0, 1.0}
};
static const TableEntry tbl_exp10[] = {
  {1.0, 0.0},
  {1.0, -0.0},
  {10.0, 1.0}
};
static const TableEntry tbl_expm1[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_log[] = {
  {0.0, 1.0},
  {1.0, MATH_E}
};
static const TableEntry tbl_log2[] = {
  {0.0, 1.0},
  {1.0, 2.0}
};
static const TableEntry tbl_log10[] = {
  {0.0, 1.0},
  {1.0, 10.0}
};
static const TableEntry tbl_rsqrt[] = {
  {1.0, 1.0},
  {MATH_SQRT1_2, 2.0}
};
static const TableEntry tbl_sin[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sinh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sinpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_sqrt[] = {
  {0.0, 0.0},
  {1.0, 1.0},
  {MATH_SQRT2, 2.0}
};
static const TableEntry tbl_tan[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tanh[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
static const TableEntry tbl_tanpi[] = {
  {0.0, 0.0},
  {-0.0, -0.0}
};
// tgamma(n) = (n - 1)! for the small positive integers listed.
static const TableEntry tbl_tgamma[] = {
  {1.0, 1.0},
  {1.0, 2.0},
  {2.0, 3.0},
  {6.0, 4.0}
};
374
HasNative(AMDGPULibFunc::EFuncId id)375 static bool HasNative(AMDGPULibFunc::EFuncId id) {
376 switch(id) {
377 case AMDGPULibFunc::EI_DIVIDE:
378 case AMDGPULibFunc::EI_COS:
379 case AMDGPULibFunc::EI_EXP:
380 case AMDGPULibFunc::EI_EXP2:
381 case AMDGPULibFunc::EI_EXP10:
382 case AMDGPULibFunc::EI_LOG:
383 case AMDGPULibFunc::EI_LOG2:
384 case AMDGPULibFunc::EI_LOG10:
385 case AMDGPULibFunc::EI_POWR:
386 case AMDGPULibFunc::EI_RECIP:
387 case AMDGPULibFunc::EI_RSQRT:
388 case AMDGPULibFunc::EI_SIN:
389 case AMDGPULibFunc::EI_SINCOS:
390 case AMDGPULibFunc::EI_SQRT:
391 case AMDGPULibFunc::EI_TAN:
392 return true;
393 default:;
394 }
395 return false;
396 }
397
398 struct TableRef {
399 size_t size;
400 const TableEntry *table; // variable size: from 0 to (size - 1)
401
TableRefTableRef402 TableRef() : size(0), table(nullptr) {}
403
404 template <size_t N>
TableRefTableRef405 TableRef(const TableEntry (&tbl)[N]) : size(N), table(&tbl[0]) {}
406 };
407
// Return the constant-folding table for library function \p id, or an empty
// TableRef (size 0) when there is none. The native_* forms (EI_N*) share the
// table of their standard counterpart.
static TableRef getOptTable(AMDGPULibFunc::EFuncId id) {
  switch(id) {
  case AMDGPULibFunc::EI_ACOS:    return TableRef(tbl_acos);
  case AMDGPULibFunc::EI_ACOSH:   return TableRef(tbl_acosh);
  case AMDGPULibFunc::EI_ACOSPI:  return TableRef(tbl_acospi);
  case AMDGPULibFunc::EI_ASIN:    return TableRef(tbl_asin);
  case AMDGPULibFunc::EI_ASINH:   return TableRef(tbl_asinh);
  case AMDGPULibFunc::EI_ASINPI:  return TableRef(tbl_asinpi);
  case AMDGPULibFunc::EI_ATAN:    return TableRef(tbl_atan);
  case AMDGPULibFunc::EI_ATANH:   return TableRef(tbl_atanh);
  case AMDGPULibFunc::EI_ATANPI:  return TableRef(tbl_atanpi);
  case AMDGPULibFunc::EI_CBRT:    return TableRef(tbl_cbrt);
  case AMDGPULibFunc::EI_NCOS:
  case AMDGPULibFunc::EI_COS:     return TableRef(tbl_cos);
  case AMDGPULibFunc::EI_COSH:    return TableRef(tbl_cosh);
  case AMDGPULibFunc::EI_COSPI:   return TableRef(tbl_cospi);
  case AMDGPULibFunc::EI_ERFC:    return TableRef(tbl_erfc);
  case AMDGPULibFunc::EI_ERF:     return TableRef(tbl_erf);
  case AMDGPULibFunc::EI_EXP:     return TableRef(tbl_exp);
  case AMDGPULibFunc::EI_NEXP2:
  case AMDGPULibFunc::EI_EXP2:    return TableRef(tbl_exp2);
  case AMDGPULibFunc::EI_EXP10:   return TableRef(tbl_exp10);
  case AMDGPULibFunc::EI_EXPM1:   return TableRef(tbl_expm1);
  case AMDGPULibFunc::EI_LOG:     return TableRef(tbl_log);
  case AMDGPULibFunc::EI_NLOG2:
  case AMDGPULibFunc::EI_LOG2:    return TableRef(tbl_log2);
  case AMDGPULibFunc::EI_LOG10:   return TableRef(tbl_log10);
  case AMDGPULibFunc::EI_NRSQRT:
  case AMDGPULibFunc::EI_RSQRT:   return TableRef(tbl_rsqrt);
  case AMDGPULibFunc::EI_NSIN:
  case AMDGPULibFunc::EI_SIN:     return TableRef(tbl_sin);
  case AMDGPULibFunc::EI_SINH:    return TableRef(tbl_sinh);
  case AMDGPULibFunc::EI_SINPI:   return TableRef(tbl_sinpi);
  case AMDGPULibFunc::EI_NSQRT:
  case AMDGPULibFunc::EI_SQRT:    return TableRef(tbl_sqrt);
  case AMDGPULibFunc::EI_TAN:     return TableRef(tbl_tan);
  case AMDGPULibFunc::EI_TANH:    return TableRef(tbl_tanh);
  case AMDGPULibFunc::EI_TANPI:   return TableRef(tbl_tanpi);
  case AMDGPULibFunc::EI_TGAMMA:  return TableRef(tbl_tgamma);
  default:;
  }
  // No table for this function: callers test size == 0.
  return TableRef();
}
451
getVecSize(const AMDGPULibFunc & FInfo)452 static inline int getVecSize(const AMDGPULibFunc& FInfo) {
453 return FInfo.getLeads()[0].VectorSize;
454 }
455
/// Element type (F32, F64, ...) of the leading parameter of \p FInfo.
static inline AMDGPULibFunc::EType getArgType(const AMDGPULibFunc& FInfo) {
  const auto &Lead = FInfo.getLeads()[0];
  return static_cast<AMDGPULibFunc::EType>(Lead.ArgType);
}
459
getFunction(Module * M,const FuncInfo & fInfo)460 FunctionCallee AMDGPULibCalls::getFunction(Module *M, const FuncInfo &fInfo) {
461 // If we are doing PreLinkOpt, the function is external. So it is safe to
462 // use getOrInsertFunction() at this stage.
463
464 return EnablePreLink ? AMDGPULibFunc::getOrInsertFunction(M, fInfo)
465 : AMDGPULibFunc::getFunction(M, fInfo);
466 }
467
parseFunctionName(const StringRef & FMangledName,FuncInfo * FInfo)468 bool AMDGPULibCalls::parseFunctionName(const StringRef& FMangledName,
469 FuncInfo *FInfo) {
470 return AMDGPULibFunc::parse(FMangledName, *FInfo);
471 }
472
isUnsafeMath(const CallInst * CI) const473 bool AMDGPULibCalls::isUnsafeMath(const CallInst *CI) const {
474 if (auto Op = dyn_cast<FPMathOperator>(CI))
475 if (Op->isFast())
476 return true;
477 const Function *F = CI->getParent()->getParent();
478 Attribute Attr = F->getFnAttribute("unsafe-fp-math");
479 return Attr.getValueAsBool();
480 }
481
useNativeFunc(const StringRef F) const482 bool AMDGPULibCalls::useNativeFunc(const StringRef F) const {
483 return AllNative || llvm::is_contained(UseNative, F);
484 }
485
initNativeFuncs()486 void AMDGPULibCalls::initNativeFuncs() {
487 AllNative = useNativeFunc("all") ||
488 (UseNative.getNumOccurrences() && UseNative.size() == 1 &&
489 UseNative.begin()->empty());
490 }
491
sincosUseNative(CallInst * aCI,const FuncInfo & FInfo)492 bool AMDGPULibCalls::sincosUseNative(CallInst *aCI, const FuncInfo &FInfo) {
493 bool native_sin = useNativeFunc("sin");
494 bool native_cos = useNativeFunc("cos");
495
496 if (native_sin && native_cos) {
497 Module *M = aCI->getModule();
498 Value *opr0 = aCI->getArgOperand(0);
499
500 AMDGPULibFunc nf;
501 nf.getLeads()[0].ArgType = FInfo.getLeads()[0].ArgType;
502 nf.getLeads()[0].VectorSize = FInfo.getLeads()[0].VectorSize;
503
504 nf.setPrefix(AMDGPULibFunc::NATIVE);
505 nf.setId(AMDGPULibFunc::EI_SIN);
506 FunctionCallee sinExpr = getFunction(M, nf);
507
508 nf.setPrefix(AMDGPULibFunc::NATIVE);
509 nf.setId(AMDGPULibFunc::EI_COS);
510 FunctionCallee cosExpr = getFunction(M, nf);
511 if (sinExpr && cosExpr) {
512 Value *sinval = CallInst::Create(sinExpr, opr0, "splitsin", aCI);
513 Value *cosval = CallInst::Create(cosExpr, opr0, "splitcos", aCI);
514 new StoreInst(cosval, aCI->getArgOperand(1), aCI);
515
516 DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
517 << " with native version of sin/cos");
518
519 replaceCall(sinval);
520 return true;
521 }
522 }
523 return false;
524 }
525
useNative(CallInst * aCI)526 bool AMDGPULibCalls::useNative(CallInst *aCI) {
527 CI = aCI;
528 Function *Callee = aCI->getCalledFunction();
529
530 FuncInfo FInfo;
531 if (!parseFunctionName(Callee->getName(), &FInfo) || !FInfo.isMangled() ||
532 FInfo.getPrefix() != AMDGPULibFunc::NOPFX ||
533 getArgType(FInfo) == AMDGPULibFunc::F64 || !HasNative(FInfo.getId()) ||
534 !(AllNative || useNativeFunc(FInfo.getName()))) {
535 return false;
536 }
537
538 if (FInfo.getId() == AMDGPULibFunc::EI_SINCOS)
539 return sincosUseNative(aCI, FInfo);
540
541 FInfo.setPrefix(AMDGPULibFunc::NATIVE);
542 FunctionCallee F = getFunction(aCI->getModule(), FInfo);
543 if (!F)
544 return false;
545
546 aCI->setCalledFunction(F);
547 DEBUG_WITH_TYPE("usenative", dbgs() << "<useNative> replace " << *aCI
548 << " with native version");
549 return true;
550 }
551
552 // Clang emits call of __read_pipe_2 or __read_pipe_4 for OpenCL read_pipe
553 // builtin, with appended type size and alignment arguments, where 2 or 4
554 // indicates the original number of arguments. The library has optimized version
555 // of __read_pipe_2/__read_pipe_4 when the type size and alignment has the same
556 // power of 2 value. This function transforms __read_pipe_2 to __read_pipe_2_N
557 // for such cases where N is the size in bytes of the type (N = 1, 2, 4, 8, ...,
558 // 128). The same for __read_pipe_4, write_pipe_2, and write_pipe_4.
fold_read_write_pipe(CallInst * CI,IRBuilder<> & B,FuncInfo & FInfo)559 bool AMDGPULibCalls::fold_read_write_pipe(CallInst *CI, IRBuilder<> &B,
560 FuncInfo &FInfo) {
561 auto *Callee = CI->getCalledFunction();
562 if (!Callee->isDeclaration())
563 return false;
564
565 assert(Callee->hasName() && "Invalid read_pipe/write_pipe function");
566 auto *M = Callee->getParent();
567 auto &Ctx = M->getContext();
568 std::string Name = std::string(Callee->getName());
569 auto NumArg = CI->getNumArgOperands();
570 if (NumArg != 4 && NumArg != 6)
571 return false;
572 auto *PacketSize = CI->getArgOperand(NumArg - 2);
573 auto *PacketAlign = CI->getArgOperand(NumArg - 1);
574 if (!isa<ConstantInt>(PacketSize) || !isa<ConstantInt>(PacketAlign))
575 return false;
576 unsigned Size = cast<ConstantInt>(PacketSize)->getZExtValue();
577 Align Alignment = cast<ConstantInt>(PacketAlign)->getAlignValue();
578 if (Alignment != Size)
579 return false;
580
581 Type *PtrElemTy;
582 if (Size <= 8)
583 PtrElemTy = Type::getIntNTy(Ctx, Size * 8);
584 else
585 PtrElemTy = FixedVectorType::get(Type::getInt64Ty(Ctx), Size / 8);
586 unsigned PtrArgLoc = CI->getNumArgOperands() - 3;
587 auto PtrArg = CI->getArgOperand(PtrArgLoc);
588 unsigned PtrArgAS = PtrArg->getType()->getPointerAddressSpace();
589 auto *PtrTy = llvm::PointerType::get(PtrElemTy, PtrArgAS);
590
591 SmallVector<llvm::Type *, 6> ArgTys;
592 for (unsigned I = 0; I != PtrArgLoc; ++I)
593 ArgTys.push_back(CI->getArgOperand(I)->getType());
594 ArgTys.push_back(PtrTy);
595
596 Name = Name + "_" + std::to_string(Size);
597 auto *FTy = FunctionType::get(Callee->getReturnType(),
598 ArrayRef<Type *>(ArgTys), false);
599 AMDGPULibFunc NewLibFunc(Name, FTy);
600 FunctionCallee F = AMDGPULibFunc::getOrInsertFunction(M, NewLibFunc);
601 if (!F)
602 return false;
603
604 auto *BCast = B.CreatePointerCast(PtrArg, PtrTy);
605 SmallVector<Value *, 6> Args;
606 for (unsigned I = 0; I != PtrArgLoc; ++I)
607 Args.push_back(CI->getArgOperand(I));
608 Args.push_back(BCast);
609
610 auto *NCI = B.CreateCall(F, Args);
611 NCI->setAttributes(CI->getAttributes());
612 CI->replaceAllUsesWith(NCI);
613 CI->dropAllReferences();
614 CI->eraseFromParent();
615
616 return true;
617 }
618
619 // This function returns false if no change; return true otherwise.
fold(CallInst * CI,AliasAnalysis * AA)620 bool AMDGPULibCalls::fold(CallInst *CI, AliasAnalysis *AA) {
621 this->CI = CI;
622 Function *Callee = CI->getCalledFunction();
623
624 // Ignore indirect calls.
625 if (Callee == 0) return false;
626
627 BasicBlock *BB = CI->getParent();
628 LLVMContext &Context = CI->getParent()->getContext();
629 IRBuilder<> B(Context);
630
631 // Set the builder to the instruction after the call.
632 B.SetInsertPoint(BB, CI->getIterator());
633
634 // Copy fast flags from the original call.
635 if (const FPMathOperator *FPOp = dyn_cast<const FPMathOperator>(CI))
636 B.setFastMathFlags(FPOp->getFastMathFlags());
637
638 switch (Callee->getIntrinsicID()) {
639 default:
640 break;
641 case Intrinsic::amdgcn_wavefrontsize:
642 return !EnablePreLink && fold_wavefrontsize(CI, B);
643 }
644
645 FuncInfo FInfo;
646 if (!parseFunctionName(Callee->getName(), &FInfo))
647 return false;
648
649 // Further check the number of arguments to see if they match.
650 if (CI->getNumArgOperands() != FInfo.getNumArgs())
651 return false;
652
653 if (TDOFold(CI, FInfo))
654 return true;
655
656 // Under unsafe-math, evaluate calls if possible.
657 // According to Brian Sumner, we can do this for all f32 function calls
658 // using host's double function calls.
659 if (isUnsafeMath(CI) && evaluateCall(CI, FInfo))
660 return true;
661
662 // Specilized optimizations for each function call
663 switch (FInfo.getId()) {
664 case AMDGPULibFunc::EI_RECIP:
665 // skip vector function
666 assert ((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
667 FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
668 "recip must be an either native or half function");
669 return (getVecSize(FInfo) != 1) ? false : fold_recip(CI, B, FInfo);
670
671 case AMDGPULibFunc::EI_DIVIDE:
672 // skip vector function
673 assert ((FInfo.getPrefix() == AMDGPULibFunc::NATIVE ||
674 FInfo.getPrefix() == AMDGPULibFunc::HALF) &&
675 "divide must be an either native or half function");
676 return (getVecSize(FInfo) != 1) ? false : fold_divide(CI, B, FInfo);
677
678 case AMDGPULibFunc::EI_POW:
679 case AMDGPULibFunc::EI_POWR:
680 case AMDGPULibFunc::EI_POWN:
681 return fold_pow(CI, B, FInfo);
682
683 case AMDGPULibFunc::EI_ROOTN:
684 // skip vector function
685 return (getVecSize(FInfo) != 1) ? false : fold_rootn(CI, B, FInfo);
686
687 case AMDGPULibFunc::EI_FMA:
688 case AMDGPULibFunc::EI_MAD:
689 case AMDGPULibFunc::EI_NFMA:
690 // skip vector function
691 return (getVecSize(FInfo) != 1) ? false : fold_fma_mad(CI, B, FInfo);
692
693 case AMDGPULibFunc::EI_SQRT:
694 return isUnsafeMath(CI) && fold_sqrt(CI, B, FInfo);
695 case AMDGPULibFunc::EI_COS:
696 case AMDGPULibFunc::EI_SIN:
697 if ((getArgType(FInfo) == AMDGPULibFunc::F32 ||
698 getArgType(FInfo) == AMDGPULibFunc::F64)
699 && (FInfo.getPrefix() == AMDGPULibFunc::NOPFX))
700 return fold_sincos(CI, B, AA);
701
702 break;
703 case AMDGPULibFunc::EI_READ_PIPE_2:
704 case AMDGPULibFunc::EI_READ_PIPE_4:
705 case AMDGPULibFunc::EI_WRITE_PIPE_2:
706 case AMDGPULibFunc::EI_WRITE_PIPE_4:
707 return fold_read_write_pipe(CI, B, FInfo);
708
709 default:
710 break;
711 }
712
713 return false;
714 }
715
TDOFold(CallInst * CI,const FuncInfo & FInfo)716 bool AMDGPULibCalls::TDOFold(CallInst *CI, const FuncInfo &FInfo) {
717 // Table-Driven optimization
718 const TableRef tr = getOptTable(FInfo.getId());
719 if (tr.size==0)
720 return false;
721
722 int const sz = (int)tr.size;
723 const TableEntry * const ftbl = tr.table;
724 Value *opr0 = CI->getArgOperand(0);
725
726 if (getVecSize(FInfo) > 1) {
727 if (ConstantDataVector *CV = dyn_cast<ConstantDataVector>(opr0)) {
728 SmallVector<double, 0> DVal;
729 for (int eltNo = 0; eltNo < getVecSize(FInfo); ++eltNo) {
730 ConstantFP *eltval = dyn_cast<ConstantFP>(
731 CV->getElementAsConstant((unsigned)eltNo));
732 assert(eltval && "Non-FP arguments in math function!");
733 bool found = false;
734 for (int i=0; i < sz; ++i) {
735 if (eltval->isExactlyValue(ftbl[i].input)) {
736 DVal.push_back(ftbl[i].result);
737 found = true;
738 break;
739 }
740 }
741 if (!found) {
742 // This vector constants not handled yet.
743 return false;
744 }
745 }
746 LLVMContext &context = CI->getParent()->getParent()->getContext();
747 Constant *nval;
748 if (getArgType(FInfo) == AMDGPULibFunc::F32) {
749 SmallVector<float, 0> FVal;
750 for (unsigned i = 0; i < DVal.size(); ++i) {
751 FVal.push_back((float)DVal[i]);
752 }
753 ArrayRef<float> tmp(FVal);
754 nval = ConstantDataVector::get(context, tmp);
755 } else { // F64
756 ArrayRef<double> tmp(DVal);
757 nval = ConstantDataVector::get(context, tmp);
758 }
759 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
760 replaceCall(nval);
761 return true;
762 }
763 } else {
764 // Scalar version
765 if (ConstantFP *CF = dyn_cast<ConstantFP>(opr0)) {
766 for (int i = 0; i < sz; ++i) {
767 if (CF->isExactlyValue(ftbl[i].input)) {
768 Value *nval = ConstantFP::get(CF->getType(), ftbl[i].result);
769 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
770 replaceCall(nval);
771 return true;
772 }
773 }
774 }
775 }
776
777 return false;
778 }
779
replaceWithNative(CallInst * CI,const FuncInfo & FInfo)780 bool AMDGPULibCalls::replaceWithNative(CallInst *CI, const FuncInfo &FInfo) {
781 Module *M = CI->getModule();
782 if (getArgType(FInfo) != AMDGPULibFunc::F32 ||
783 FInfo.getPrefix() != AMDGPULibFunc::NOPFX ||
784 !HasNative(FInfo.getId()))
785 return false;
786
787 AMDGPULibFunc nf = FInfo;
788 nf.setPrefix(AMDGPULibFunc::NATIVE);
789 if (FunctionCallee FPExpr = getFunction(M, nf)) {
790 LLVM_DEBUG(dbgs() << "AMDIC: " << *CI << " ---> ");
791
792 CI->setCalledFunction(FPExpr);
793
794 LLVM_DEBUG(dbgs() << *CI << '\n');
795
796 return true;
797 }
798 return false;
799 }
800
801 // [native_]half_recip(c) ==> 1.0/c
fold_recip(CallInst * CI,IRBuilder<> & B,const FuncInfo & FInfo)802 bool AMDGPULibCalls::fold_recip(CallInst *CI, IRBuilder<> &B,
803 const FuncInfo &FInfo) {
804 Value *opr0 = CI->getArgOperand(0);
805 if (ConstantFP *CF = dyn_cast<ConstantFP>(opr0)) {
806 // Just create a normal div. Later, InstCombine will be able
807 // to compute the divide into a constant (avoid check float infinity
808 // or subnormal at this point).
809 Value *nval = B.CreateFDiv(ConstantFP::get(CF->getType(), 1.0),
810 opr0,
811 "recip2div");
812 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *nval << "\n");
813 replaceCall(nval);
814 return true;
815 }
816 return false;
817 }
818
819 // [native_]half_divide(x, c) ==> x/c
fold_divide(CallInst * CI,IRBuilder<> & B,const FuncInfo & FInfo)820 bool AMDGPULibCalls::fold_divide(CallInst *CI, IRBuilder<> &B,
821 const FuncInfo &FInfo) {
822 Value *opr0 = CI->getArgOperand(0);
823 Value *opr1 = CI->getArgOperand(1);
824 ConstantFP *CF0 = dyn_cast<ConstantFP>(opr0);
825 ConstantFP *CF1 = dyn_cast<ConstantFP>(opr1);
826
827 if ((CF0 && CF1) || // both are constants
828 (CF1 && (getArgType(FInfo) == AMDGPULibFunc::F32)))
829 // CF1 is constant && f32 divide
830 {
831 Value *nval1 = B.CreateFDiv(ConstantFP::get(opr1->getType(), 1.0),
832 opr1, "__div2recip");
833 Value *nval = B.CreateFMul(opr0, nval1, "__div2mul");
834 replaceCall(nval);
835 return true;
836 }
837 return false;
838 }
839
namespace llvm {
// Host-side base-2 logarithm, used when constant-folding pow-family calls.
// std::log2 is guaranteed by C++11 and later, so the former platform-macro
// fallback (log(V) / ln2, which rounds differently) is no longer needed.
static double log2(double V) {
  return std::log2(V);
}
} // end namespace llvm
849
fold_pow(CallInst * CI,IRBuilder<> & B,const FuncInfo & FInfo)850 bool AMDGPULibCalls::fold_pow(CallInst *CI, IRBuilder<> &B,
851 const FuncInfo &FInfo) {
852 assert((FInfo.getId() == AMDGPULibFunc::EI_POW ||
853 FInfo.getId() == AMDGPULibFunc::EI_POWR ||
854 FInfo.getId() == AMDGPULibFunc::EI_POWN) &&
855 "fold_pow: encounter a wrong function call");
856
857 Value *opr0, *opr1;
858 ConstantFP *CF;
859 ConstantInt *CINT;
860 ConstantAggregateZero *CZero;
861 Type *eltType;
862
863 opr0 = CI->getArgOperand(0);
864 opr1 = CI->getArgOperand(1);
865 CZero = dyn_cast<ConstantAggregateZero>(opr1);
866 if (getVecSize(FInfo) == 1) {
867 eltType = opr0->getType();
868 CF = dyn_cast<ConstantFP>(opr1);
869 CINT = dyn_cast<ConstantInt>(opr1);
870 } else {
871 VectorType *VTy = dyn_cast<VectorType>(opr0->getType());
872 assert(VTy && "Oprand of vector function should be of vectortype");
873 eltType = VTy->getElementType();
874 ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr1);
875
876 // Now, only Handle vector const whose elements have the same value.
877 CF = CDV ? dyn_cast_or_null<ConstantFP>(CDV->getSplatValue()) : nullptr;
878 CINT = CDV ? dyn_cast_or_null<ConstantInt>(CDV->getSplatValue()) : nullptr;
879 }
880
881 // No unsafe math , no constant argument, do nothing
882 if (!isUnsafeMath(CI) && !CF && !CINT && !CZero)
883 return false;
884
885 // 0x1111111 means that we don't do anything for this call.
886 int ci_opr1 = (CINT ? (int)CINT->getSExtValue() : 0x1111111);
887
888 if ((CF && CF->isZero()) || (CINT && ci_opr1 == 0) || CZero) {
889 // pow/powr/pown(x, 0) == 1
890 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> 1\n");
891 Constant *cnval = ConstantFP::get(eltType, 1.0);
892 if (getVecSize(FInfo) > 1) {
893 cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
894 }
895 replaceCall(cnval);
896 return true;
897 }
898 if ((CF && CF->isExactlyValue(1.0)) || (CINT && ci_opr1 == 1)) {
899 // pow/powr/pown(x, 1.0) = x
900 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << "\n");
901 replaceCall(opr0);
902 return true;
903 }
904 if ((CF && CF->isExactlyValue(2.0)) || (CINT && ci_opr1 == 2)) {
905 // pow/powr/pown(x, 2.0) = x*x
906 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " * " << *opr0
907 << "\n");
908 Value *nval = B.CreateFMul(opr0, opr0, "__pow2");
909 replaceCall(nval);
910 return true;
911 }
912 if ((CF && CF->isExactlyValue(-1.0)) || (CINT && ci_opr1 == -1)) {
913 // pow/powr/pown(x, -1.0) = 1.0/x
914 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> 1 / " << *opr0 << "\n");
915 Constant *cnval = ConstantFP::get(eltType, 1.0);
916 if (getVecSize(FInfo) > 1) {
917 cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
918 }
919 Value *nval = B.CreateFDiv(cnval, opr0, "__powrecip");
920 replaceCall(nval);
921 return true;
922 }
923
924 Module *M = CI->getModule();
925 if (CF && (CF->isExactlyValue(0.5) || CF->isExactlyValue(-0.5))) {
926 // pow[r](x, [-]0.5) = sqrt(x)
927 bool issqrt = CF->isExactlyValue(0.5);
928 if (FunctionCallee FPExpr =
929 getFunction(M, AMDGPULibFunc(issqrt ? AMDGPULibFunc::EI_SQRT
930 : AMDGPULibFunc::EI_RSQRT,
931 FInfo))) {
932 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> "
933 << FInfo.getName().c_str() << "(" << *opr0 << ")\n");
934 Value *nval = CreateCallEx(B,FPExpr, opr0, issqrt ? "__pow2sqrt"
935 : "__pow2rsqrt");
936 replaceCall(nval);
937 return true;
938 }
939 }
940
941 if (!isUnsafeMath(CI))
942 return false;
943
944 // Unsafe Math optimization
945
946 // Remember that ci_opr1 is set if opr1 is integral
947 if (CF) {
948 double dval = (getArgType(FInfo) == AMDGPULibFunc::F32)
949 ? (double)CF->getValueAPF().convertToFloat()
950 : CF->getValueAPF().convertToDouble();
951 int ival = (int)dval;
952 if ((double)ival == dval) {
953 ci_opr1 = ival;
954 } else
955 ci_opr1 = 0x11111111;
956 }
957
958 // pow/powr/pown(x, c) = [1/](x*x*..x); where
959 // trunc(c) == c && the number of x == c && |c| <= 12
960 unsigned abs_opr1 = (ci_opr1 < 0) ? -ci_opr1 : ci_opr1;
961 if (abs_opr1 <= 12) {
962 Constant *cnval;
963 Value *nval;
964 if (abs_opr1 == 0) {
965 cnval = ConstantFP::get(eltType, 1.0);
966 if (getVecSize(FInfo) > 1) {
967 cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
968 }
969 nval = cnval;
970 } else {
971 Value *valx2 = nullptr;
972 nval = nullptr;
973 while (abs_opr1 > 0) {
974 valx2 = valx2 ? B.CreateFMul(valx2, valx2, "__powx2") : opr0;
975 if (abs_opr1 & 1) {
976 nval = nval ? B.CreateFMul(nval, valx2, "__powprod") : valx2;
977 }
978 abs_opr1 >>= 1;
979 }
980 }
981
982 if (ci_opr1 < 0) {
983 cnval = ConstantFP::get(eltType, 1.0);
984 if (getVecSize(FInfo) > 1) {
985 cnval = ConstantDataVector::getSplat(getVecSize(FInfo), cnval);
986 }
987 nval = B.CreateFDiv(cnval, nval, "__1powprod");
988 }
989 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> "
990 << ((ci_opr1 < 0) ? "1/prod(" : "prod(") << *opr0
991 << ")\n");
992 replaceCall(nval);
993 return true;
994 }
995
996 // powr ---> exp2(y * log2(x))
997 // pown/pow ---> powr(fabs(x), y) | (x & ((int)y << 31))
998 FunctionCallee ExpExpr =
999 getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_EXP2, FInfo));
1000 if (!ExpExpr)
1001 return false;
1002
1003 bool needlog = false;
1004 bool needabs = false;
1005 bool needcopysign = false;
1006 Constant *cnval = nullptr;
1007 if (getVecSize(FInfo) == 1) {
1008 CF = dyn_cast<ConstantFP>(opr0);
1009
1010 if (CF) {
1011 double V = (getArgType(FInfo) == AMDGPULibFunc::F32)
1012 ? (double)CF->getValueAPF().convertToFloat()
1013 : CF->getValueAPF().convertToDouble();
1014
1015 V = log2(std::abs(V));
1016 cnval = ConstantFP::get(eltType, V);
1017 needcopysign = (FInfo.getId() != AMDGPULibFunc::EI_POWR) &&
1018 CF->isNegative();
1019 } else {
1020 needlog = true;
1021 needcopysign = needabs = FInfo.getId() != AMDGPULibFunc::EI_POWR &&
1022 (!CF || CF->isNegative());
1023 }
1024 } else {
1025 ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr0);
1026
1027 if (!CDV) {
1028 needlog = true;
1029 needcopysign = needabs = FInfo.getId() != AMDGPULibFunc::EI_POWR;
1030 } else {
1031 assert ((int)CDV->getNumElements() == getVecSize(FInfo) &&
1032 "Wrong vector size detected");
1033
1034 SmallVector<double, 0> DVal;
1035 for (int i=0; i < getVecSize(FInfo); ++i) {
1036 double V = (getArgType(FInfo) == AMDGPULibFunc::F32)
1037 ? (double)CDV->getElementAsFloat(i)
1038 : CDV->getElementAsDouble(i);
1039 if (V < 0.0) needcopysign = true;
1040 V = log2(std::abs(V));
1041 DVal.push_back(V);
1042 }
1043 if (getArgType(FInfo) == AMDGPULibFunc::F32) {
1044 SmallVector<float, 0> FVal;
1045 for (unsigned i=0; i < DVal.size(); ++i) {
1046 FVal.push_back((float)DVal[i]);
1047 }
1048 ArrayRef<float> tmp(FVal);
1049 cnval = ConstantDataVector::get(M->getContext(), tmp);
1050 } else {
1051 ArrayRef<double> tmp(DVal);
1052 cnval = ConstantDataVector::get(M->getContext(), tmp);
1053 }
1054 }
1055 }
1056
1057 if (needcopysign && (FInfo.getId() == AMDGPULibFunc::EI_POW)) {
1058 // We cannot handle corner cases for a general pow() function, give up
1059 // unless y is a constant integral value. Then proceed as if it were pown.
1060 if (getVecSize(FInfo) == 1) {
1061 if (const ConstantFP *CF = dyn_cast<ConstantFP>(opr1)) {
1062 double y = (getArgType(FInfo) == AMDGPULibFunc::F32)
1063 ? (double)CF->getValueAPF().convertToFloat()
1064 : CF->getValueAPF().convertToDouble();
1065 if (y != (double)(int64_t)y)
1066 return false;
1067 } else
1068 return false;
1069 } else {
1070 if (const ConstantDataVector *CDV = dyn_cast<ConstantDataVector>(opr1)) {
1071 for (int i=0; i < getVecSize(FInfo); ++i) {
1072 double y = (getArgType(FInfo) == AMDGPULibFunc::F32)
1073 ? (double)CDV->getElementAsFloat(i)
1074 : CDV->getElementAsDouble(i);
1075 if (y != (double)(int64_t)y)
1076 return false;
1077 }
1078 } else
1079 return false;
1080 }
1081 }
1082
1083 Value *nval;
1084 if (needabs) {
1085 FunctionCallee AbsExpr =
1086 getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_FABS, FInfo));
1087 if (!AbsExpr)
1088 return false;
1089 nval = CreateCallEx(B, AbsExpr, opr0, "__fabs");
1090 } else {
1091 nval = cnval ? cnval : opr0;
1092 }
1093 if (needlog) {
1094 FunctionCallee LogExpr =
1095 getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_LOG2, FInfo));
1096 if (!LogExpr)
1097 return false;
1098 nval = CreateCallEx(B,LogExpr, nval, "__log2");
1099 }
1100
1101 if (FInfo.getId() == AMDGPULibFunc::EI_POWN) {
1102 // convert int(32) to fp(f32 or f64)
1103 opr1 = B.CreateSIToFP(opr1, nval->getType(), "pownI2F");
1104 }
1105 nval = B.CreateFMul(opr1, nval, "__ylogx");
1106 nval = CreateCallEx(B,ExpExpr, nval, "__exp2");
1107
1108 if (needcopysign) {
1109 Value *opr_n;
1110 Type* rTy = opr0->getType();
1111 Type* nTyS = eltType->isDoubleTy() ? B.getInt64Ty() : B.getInt32Ty();
1112 Type *nTy = nTyS;
1113 if (const auto *vTy = dyn_cast<FixedVectorType>(rTy))
1114 nTy = FixedVectorType::get(nTyS, vTy);
1115 unsigned size = nTy->getScalarSizeInBits();
1116 opr_n = CI->getArgOperand(1);
1117 if (opr_n->getType()->isIntegerTy())
1118 opr_n = B.CreateZExtOrBitCast(opr_n, nTy, "__ytou");
1119 else
1120 opr_n = B.CreateFPToSI(opr1, nTy, "__ytou");
1121
1122 Value *sign = B.CreateShl(opr_n, size-1, "__yeven");
1123 sign = B.CreateAnd(B.CreateBitCast(opr0, nTy), sign, "__pow_sign");
1124 nval = B.CreateOr(B.CreateBitCast(nval, nTy), sign);
1125 nval = B.CreateBitCast(nval, opr0->getType());
1126 }
1127
1128 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> "
1129 << "exp2(" << *opr1 << " * log2(" << *opr0 << "))\n");
1130 replaceCall(nval);
1131
1132 return true;
1133 }
1134
fold_rootn(CallInst * CI,IRBuilder<> & B,const FuncInfo & FInfo)1135 bool AMDGPULibCalls::fold_rootn(CallInst *CI, IRBuilder<> &B,
1136 const FuncInfo &FInfo) {
1137 Value *opr0 = CI->getArgOperand(0);
1138 Value *opr1 = CI->getArgOperand(1);
1139
1140 ConstantInt *CINT = dyn_cast<ConstantInt>(opr1);
1141 if (!CINT) {
1142 return false;
1143 }
1144 int ci_opr1 = (int)CINT->getSExtValue();
1145 if (ci_opr1 == 1) { // rootn(x, 1) = x
1146 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << "\n");
1147 replaceCall(opr0);
1148 return true;
1149 }
1150 if (ci_opr1 == 2) { // rootn(x, 2) = sqrt(x)
1151 Module *M = CI->getModule();
1152 if (FunctionCallee FPExpr =
1153 getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
1154 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> sqrt(" << *opr0 << ")\n");
1155 Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2sqrt");
1156 replaceCall(nval);
1157 return true;
1158 }
1159 } else if (ci_opr1 == 3) { // rootn(x, 3) = cbrt(x)
1160 Module *M = CI->getModule();
1161 if (FunctionCallee FPExpr =
1162 getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_CBRT, FInfo))) {
1163 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> cbrt(" << *opr0 << ")\n");
1164 Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2cbrt");
1165 replaceCall(nval);
1166 return true;
1167 }
1168 } else if (ci_opr1 == -1) { // rootn(x, -1) = 1.0/x
1169 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> 1.0 / " << *opr0 << "\n");
1170 Value *nval = B.CreateFDiv(ConstantFP::get(opr0->getType(), 1.0),
1171 opr0,
1172 "__rootn2div");
1173 replaceCall(nval);
1174 return true;
1175 } else if (ci_opr1 == -2) { // rootn(x, -2) = rsqrt(x)
1176 Module *M = CI->getModule();
1177 if (FunctionCallee FPExpr =
1178 getFunction(M, AMDGPULibFunc(AMDGPULibFunc::EI_RSQRT, FInfo))) {
1179 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> rsqrt(" << *opr0
1180 << ")\n");
1181 Value *nval = CreateCallEx(B,FPExpr, opr0, "__rootn2rsqrt");
1182 replaceCall(nval);
1183 return true;
1184 }
1185 }
1186 return false;
1187 }
1188
fold_fma_mad(CallInst * CI,IRBuilder<> & B,const FuncInfo & FInfo)1189 bool AMDGPULibCalls::fold_fma_mad(CallInst *CI, IRBuilder<> &B,
1190 const FuncInfo &FInfo) {
1191 Value *opr0 = CI->getArgOperand(0);
1192 Value *opr1 = CI->getArgOperand(1);
1193 Value *opr2 = CI->getArgOperand(2);
1194
1195 ConstantFP *CF0 = dyn_cast<ConstantFP>(opr0);
1196 ConstantFP *CF1 = dyn_cast<ConstantFP>(opr1);
1197 if ((CF0 && CF0->isZero()) || (CF1 && CF1->isZero())) {
1198 // fma/mad(a, b, c) = c if a=0 || b=0
1199 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr2 << "\n");
1200 replaceCall(opr2);
1201 return true;
1202 }
1203 if (CF0 && CF0->isExactlyValue(1.0f)) {
1204 // fma/mad(a, b, c) = b+c if a=1
1205 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr1 << " + " << *opr2
1206 << "\n");
1207 Value *nval = B.CreateFAdd(opr1, opr2, "fmaadd");
1208 replaceCall(nval);
1209 return true;
1210 }
1211 if (CF1 && CF1->isExactlyValue(1.0f)) {
1212 // fma/mad(a, b, c) = a+c if b=1
1213 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " + " << *opr2
1214 << "\n");
1215 Value *nval = B.CreateFAdd(opr0, opr2, "fmaadd");
1216 replaceCall(nval);
1217 return true;
1218 }
1219 if (ConstantFP *CF = dyn_cast<ConstantFP>(opr2)) {
1220 if (CF->isZero()) {
1221 // fma/mad(a, b, c) = a*b if c=0
1222 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> " << *opr0 << " * "
1223 << *opr1 << "\n");
1224 Value *nval = B.CreateFMul(opr0, opr1, "fmamul");
1225 replaceCall(nval);
1226 return true;
1227 }
1228 }
1229
1230 return false;
1231 }
1232
1233 // Get a scalar native builtin signle argument FP function
getNativeFunction(Module * M,const FuncInfo & FInfo)1234 FunctionCallee AMDGPULibCalls::getNativeFunction(Module *M,
1235 const FuncInfo &FInfo) {
1236 if (getArgType(FInfo) == AMDGPULibFunc::F64 || !HasNative(FInfo.getId()))
1237 return nullptr;
1238 FuncInfo nf = FInfo;
1239 nf.setPrefix(AMDGPULibFunc::NATIVE);
1240 return getFunction(M, nf);
1241 }
1242
1243 // fold sqrt -> native_sqrt (x)
fold_sqrt(CallInst * CI,IRBuilder<> & B,const FuncInfo & FInfo)1244 bool AMDGPULibCalls::fold_sqrt(CallInst *CI, IRBuilder<> &B,
1245 const FuncInfo &FInfo) {
1246 if (getArgType(FInfo) == AMDGPULibFunc::F32 && (getVecSize(FInfo) == 1) &&
1247 (FInfo.getPrefix() != AMDGPULibFunc::NATIVE)) {
1248 if (FunctionCallee FPExpr = getNativeFunction(
1249 CI->getModule(), AMDGPULibFunc(AMDGPULibFunc::EI_SQRT, FInfo))) {
1250 Value *opr0 = CI->getArgOperand(0);
1251 LLVM_DEBUG(errs() << "AMDIC: " << *CI << " ---> "
1252 << "sqrt(" << *opr0 << ")\n");
1253 Value *nval = CreateCallEx(B,FPExpr, opr0, "__sqrt");
1254 replaceCall(nval);
1255 return true;
1256 }
1257 }
1258 return false;
1259 }
1260
1261 // fold sin, cos -> sincos.
fold_sincos(CallInst * CI,IRBuilder<> & B,AliasAnalysis * AA)1262 bool AMDGPULibCalls::fold_sincos(CallInst *CI, IRBuilder<> &B,
1263 AliasAnalysis *AA) {
1264 AMDGPULibFunc fInfo;
1265 if (!AMDGPULibFunc::parse(CI->getCalledFunction()->getName(), fInfo))
1266 return false;
1267
1268 assert(fInfo.getId() == AMDGPULibFunc::EI_SIN ||
1269 fInfo.getId() == AMDGPULibFunc::EI_COS);
1270 bool const isSin = fInfo.getId() == AMDGPULibFunc::EI_SIN;
1271
1272 Value *CArgVal = CI->getArgOperand(0);
1273 BasicBlock * const CBB = CI->getParent();
1274
1275 int const MaxScan = 30;
1276 bool Changed = false;
1277
1278 { // fold in load value.
1279 LoadInst *LI = dyn_cast<LoadInst>(CArgVal);
1280 if (LI && LI->getParent() == CBB) {
1281 BasicBlock::iterator BBI = LI->getIterator();
1282 Value *AvailableVal = FindAvailableLoadedValue(LI, CBB, BBI, MaxScan, AA);
1283 if (AvailableVal) {
1284 Changed = true;
1285 CArgVal->replaceAllUsesWith(AvailableVal);
1286 if (CArgVal->getNumUses() == 0)
1287 LI->eraseFromParent();
1288 CArgVal = CI->getArgOperand(0);
1289 }
1290 }
1291 }
1292
1293 Module *M = CI->getModule();
1294 fInfo.setId(isSin ? AMDGPULibFunc::EI_COS : AMDGPULibFunc::EI_SIN);
1295 std::string const PairName = fInfo.mangle();
1296
1297 CallInst *UI = nullptr;
1298 for (User* U : CArgVal->users()) {
1299 CallInst *XI = dyn_cast_or_null<CallInst>(U);
1300 if (!XI || XI == CI || XI->getParent() != CBB)
1301 continue;
1302
1303 Function *UCallee = XI->getCalledFunction();
1304 if (!UCallee || !UCallee->getName().equals(PairName))
1305 continue;
1306
1307 BasicBlock::iterator BBI = CI->getIterator();
1308 if (BBI == CI->getParent()->begin())
1309 break;
1310 --BBI;
1311 for (int I = MaxScan; I > 0 && BBI != CBB->begin(); --BBI, --I) {
1312 if (cast<Instruction>(BBI) == XI) {
1313 UI = XI;
1314 break;
1315 }
1316 }
1317 if (UI) break;
1318 }
1319
1320 if (!UI)
1321 return Changed;
1322
1323 // Merge the sin and cos.
1324
1325 // for OpenCL 2.0 we have only generic implementation of sincos
1326 // function.
1327 AMDGPULibFunc nf(AMDGPULibFunc::EI_SINCOS, fInfo);
1328 nf.getLeads()[0].PtrKind = AMDGPULibFunc::getEPtrKindFromAddrSpace(AMDGPUAS::FLAT_ADDRESS);
1329 FunctionCallee Fsincos = getFunction(M, nf);
1330 if (!Fsincos)
1331 return Changed;
1332
1333 BasicBlock::iterator ItOld = B.GetInsertPoint();
1334 AllocaInst *Alloc = insertAlloca(UI, B, "__sincos_");
1335 B.SetInsertPoint(UI);
1336
1337 Value *P = Alloc;
1338 Type *PTy = Fsincos.getFunctionType()->getParamType(1);
1339 // The allocaInst allocates the memory in private address space. This need
1340 // to be bitcasted to point to the address space of cos pointer type.
1341 // In OpenCL 2.0 this is generic, while in 1.2 that is private.
1342 if (PTy->getPointerAddressSpace() != AMDGPUAS::PRIVATE_ADDRESS)
1343 P = B.CreateAddrSpaceCast(Alloc, PTy);
1344 CallInst *Call = CreateCallEx2(B, Fsincos, UI->getArgOperand(0), P);
1345
1346 LLVM_DEBUG(errs() << "AMDIC: fold_sincos (" << *CI << ", " << *UI << ") with "
1347 << *Call << "\n");
1348
1349 if (!isSin) { // CI->cos, UI->sin
1350 B.SetInsertPoint(&*ItOld);
1351 UI->replaceAllUsesWith(&*Call);
1352 Instruction *Reload = B.CreateLoad(Alloc->getAllocatedType(), Alloc);
1353 CI->replaceAllUsesWith(Reload);
1354 UI->eraseFromParent();
1355 CI->eraseFromParent();
1356 } else { // CI->sin, UI->cos
1357 Instruction *Reload = B.CreateLoad(Alloc->getAllocatedType(), Alloc);
1358 UI->replaceAllUsesWith(Reload);
1359 CI->replaceAllUsesWith(Call);
1360 UI->eraseFromParent();
1361 CI->eraseFromParent();
1362 }
1363 return true;
1364 }
1365
fold_wavefrontsize(CallInst * CI,IRBuilder<> & B)1366 bool AMDGPULibCalls::fold_wavefrontsize(CallInst *CI, IRBuilder<> &B) {
1367 if (!TM)
1368 return false;
1369
1370 StringRef CPU = TM->getTargetCPU();
1371 StringRef Features = TM->getTargetFeatureString();
1372 if ((CPU.empty() || CPU.equals_lower("generic")) &&
1373 (Features.empty() ||
1374 Features.find_lower("wavefrontsize") == StringRef::npos))
1375 return false;
1376
1377 Function *F = CI->getParent()->getParent();
1378 const GCNSubtarget &ST = TM->getSubtarget<GCNSubtarget>(*F);
1379 unsigned N = ST.getWavefrontSize();
1380
1381 LLVM_DEBUG(errs() << "AMDIC: fold_wavefrontsize (" << *CI << ") with "
1382 << N << "\n");
1383
1384 CI->replaceAllUsesWith(ConstantInt::get(B.getInt32Ty(), N));
1385 CI->eraseFromParent();
1386 return true;
1387 }
1388
1389 // Get insertion point at entry.
getEntryIns(CallInst * UI)1390 BasicBlock::iterator AMDGPULibCalls::getEntryIns(CallInst * UI) {
1391 Function * Func = UI->getParent()->getParent();
1392 BasicBlock * BB = &Func->getEntryBlock();
1393 assert(BB && "Entry block not found!");
1394 BasicBlock::iterator ItNew = BB->begin();
1395 return ItNew;
1396 }
1397
1398 // Insert a AllocsInst at the beginning of function entry block.
insertAlloca(CallInst * UI,IRBuilder<> & B,const char * prefix)1399 AllocaInst* AMDGPULibCalls::insertAlloca(CallInst *UI, IRBuilder<> &B,
1400 const char *prefix) {
1401 BasicBlock::iterator ItNew = getEntryIns(UI);
1402 Function *UCallee = UI->getCalledFunction();
1403 Type *RetType = UCallee->getReturnType();
1404 B.SetInsertPoint(&*ItNew);
1405 AllocaInst *Alloc = B.CreateAlloca(RetType, 0,
1406 std::string(prefix) + UI->getName());
1407 Alloc->setAlignment(
1408 Align(UCallee->getParent()->getDataLayout().getTypeAllocSize(RetType)));
1409 return Alloc;
1410 }
1411
evaluateScalarMathFunc(FuncInfo & FInfo,double & Res0,double & Res1,Constant * copr0,Constant * copr1,Constant * copr2)1412 bool AMDGPULibCalls::evaluateScalarMathFunc(FuncInfo &FInfo,
1413 double& Res0, double& Res1,
1414 Constant *copr0, Constant *copr1,
1415 Constant *copr2) {
1416 // By default, opr0/opr1/opr3 holds values of float/double type.
1417 // If they are not float/double, each function has to its
1418 // operand separately.
1419 double opr0=0.0, opr1=0.0, opr2=0.0;
1420 ConstantFP *fpopr0 = dyn_cast_or_null<ConstantFP>(copr0);
1421 ConstantFP *fpopr1 = dyn_cast_or_null<ConstantFP>(copr1);
1422 ConstantFP *fpopr2 = dyn_cast_or_null<ConstantFP>(copr2);
1423 if (fpopr0) {
1424 opr0 = (getArgType(FInfo) == AMDGPULibFunc::F64)
1425 ? fpopr0->getValueAPF().convertToDouble()
1426 : (double)fpopr0->getValueAPF().convertToFloat();
1427 }
1428
1429 if (fpopr1) {
1430 opr1 = (getArgType(FInfo) == AMDGPULibFunc::F64)
1431 ? fpopr1->getValueAPF().convertToDouble()
1432 : (double)fpopr1->getValueAPF().convertToFloat();
1433 }
1434
1435 if (fpopr2) {
1436 opr2 = (getArgType(FInfo) == AMDGPULibFunc::F64)
1437 ? fpopr2->getValueAPF().convertToDouble()
1438 : (double)fpopr2->getValueAPF().convertToFloat();
1439 }
1440
1441 switch (FInfo.getId()) {
1442 default : return false;
1443
1444 case AMDGPULibFunc::EI_ACOS:
1445 Res0 = acos(opr0);
1446 return true;
1447
1448 case AMDGPULibFunc::EI_ACOSH:
1449 // acosh(x) == log(x + sqrt(x*x - 1))
1450 Res0 = log(opr0 + sqrt(opr0*opr0 - 1.0));
1451 return true;
1452
1453 case AMDGPULibFunc::EI_ACOSPI:
1454 Res0 = acos(opr0) / MATH_PI;
1455 return true;
1456
1457 case AMDGPULibFunc::EI_ASIN:
1458 Res0 = asin(opr0);
1459 return true;
1460
1461 case AMDGPULibFunc::EI_ASINH:
1462 // asinh(x) == log(x + sqrt(x*x + 1))
1463 Res0 = log(opr0 + sqrt(opr0*opr0 + 1.0));
1464 return true;
1465
1466 case AMDGPULibFunc::EI_ASINPI:
1467 Res0 = asin(opr0) / MATH_PI;
1468 return true;
1469
1470 case AMDGPULibFunc::EI_ATAN:
1471 Res0 = atan(opr0);
1472 return true;
1473
1474 case AMDGPULibFunc::EI_ATANH:
1475 // atanh(x) == (log(x+1) - log(x-1))/2;
1476 Res0 = (log(opr0 + 1.0) - log(opr0 - 1.0))/2.0;
1477 return true;
1478
1479 case AMDGPULibFunc::EI_ATANPI:
1480 Res0 = atan(opr0) / MATH_PI;
1481 return true;
1482
1483 case AMDGPULibFunc::EI_CBRT:
1484 Res0 = (opr0 < 0.0) ? -pow(-opr0, 1.0/3.0) : pow(opr0, 1.0/3.0);
1485 return true;
1486
1487 case AMDGPULibFunc::EI_COS:
1488 Res0 = cos(opr0);
1489 return true;
1490
1491 case AMDGPULibFunc::EI_COSH:
1492 Res0 = cosh(opr0);
1493 return true;
1494
1495 case AMDGPULibFunc::EI_COSPI:
1496 Res0 = cos(MATH_PI * opr0);
1497 return true;
1498
1499 case AMDGPULibFunc::EI_EXP:
1500 Res0 = exp(opr0);
1501 return true;
1502
1503 case AMDGPULibFunc::EI_EXP2:
1504 Res0 = pow(2.0, opr0);
1505 return true;
1506
1507 case AMDGPULibFunc::EI_EXP10:
1508 Res0 = pow(10.0, opr0);
1509 return true;
1510
1511 case AMDGPULibFunc::EI_EXPM1:
1512 Res0 = exp(opr0) - 1.0;
1513 return true;
1514
1515 case AMDGPULibFunc::EI_LOG:
1516 Res0 = log(opr0);
1517 return true;
1518
1519 case AMDGPULibFunc::EI_LOG2:
1520 Res0 = log(opr0) / log(2.0);
1521 return true;
1522
1523 case AMDGPULibFunc::EI_LOG10:
1524 Res0 = log(opr0) / log(10.0);
1525 return true;
1526
1527 case AMDGPULibFunc::EI_RSQRT:
1528 Res0 = 1.0 / sqrt(opr0);
1529 return true;
1530
1531 case AMDGPULibFunc::EI_SIN:
1532 Res0 = sin(opr0);
1533 return true;
1534
1535 case AMDGPULibFunc::EI_SINH:
1536 Res0 = sinh(opr0);
1537 return true;
1538
1539 case AMDGPULibFunc::EI_SINPI:
1540 Res0 = sin(MATH_PI * opr0);
1541 return true;
1542
1543 case AMDGPULibFunc::EI_SQRT:
1544 Res0 = sqrt(opr0);
1545 return true;
1546
1547 case AMDGPULibFunc::EI_TAN:
1548 Res0 = tan(opr0);
1549 return true;
1550
1551 case AMDGPULibFunc::EI_TANH:
1552 Res0 = tanh(opr0);
1553 return true;
1554
1555 case AMDGPULibFunc::EI_TANPI:
1556 Res0 = tan(MATH_PI * opr0);
1557 return true;
1558
1559 case AMDGPULibFunc::EI_RECIP:
1560 Res0 = 1.0 / opr0;
1561 return true;
1562
1563 // two-arg functions
1564 case AMDGPULibFunc::EI_DIVIDE:
1565 Res0 = opr0 / opr1;
1566 return true;
1567
1568 case AMDGPULibFunc::EI_POW:
1569 case AMDGPULibFunc::EI_POWR:
1570 Res0 = pow(opr0, opr1);
1571 return true;
1572
1573 case AMDGPULibFunc::EI_POWN: {
1574 if (ConstantInt *iopr1 = dyn_cast_or_null<ConstantInt>(copr1)) {
1575 double val = (double)iopr1->getSExtValue();
1576 Res0 = pow(opr0, val);
1577 return true;
1578 }
1579 return false;
1580 }
1581
1582 case AMDGPULibFunc::EI_ROOTN: {
1583 if (ConstantInt *iopr1 = dyn_cast_or_null<ConstantInt>(copr1)) {
1584 double val = (double)iopr1->getSExtValue();
1585 Res0 = pow(opr0, 1.0 / val);
1586 return true;
1587 }
1588 return false;
1589 }
1590
1591 // with ptr arg
1592 case AMDGPULibFunc::EI_SINCOS:
1593 Res0 = sin(opr0);
1594 Res1 = cos(opr0);
1595 return true;
1596
1597 // three-arg functions
1598 case AMDGPULibFunc::EI_FMA:
1599 case AMDGPULibFunc::EI_MAD:
1600 Res0 = opr0 * opr1 + opr2;
1601 return true;
1602 }
1603
1604 return false;
1605 }
1606
evaluateCall(CallInst * aCI,FuncInfo & FInfo)1607 bool AMDGPULibCalls::evaluateCall(CallInst *aCI, FuncInfo &FInfo) {
1608 int numArgs = (int)aCI->getNumArgOperands();
1609 if (numArgs > 3)
1610 return false;
1611
1612 Constant *copr0 = nullptr;
1613 Constant *copr1 = nullptr;
1614 Constant *copr2 = nullptr;
1615 if (numArgs > 0) {
1616 if ((copr0 = dyn_cast<Constant>(aCI->getArgOperand(0))) == nullptr)
1617 return false;
1618 }
1619
1620 if (numArgs > 1) {
1621 if ((copr1 = dyn_cast<Constant>(aCI->getArgOperand(1))) == nullptr) {
1622 if (FInfo.getId() != AMDGPULibFunc::EI_SINCOS)
1623 return false;
1624 }
1625 }
1626
1627 if (numArgs > 2) {
1628 if ((copr2 = dyn_cast<Constant>(aCI->getArgOperand(2))) == nullptr)
1629 return false;
1630 }
1631
1632 // At this point, all arguments to aCI are constants.
1633
1634 // max vector size is 16, and sincos will generate two results.
1635 double DVal0[16], DVal1[16];
1636 bool hasTwoResults = (FInfo.getId() == AMDGPULibFunc::EI_SINCOS);
1637 if (getVecSize(FInfo) == 1) {
1638 if (!evaluateScalarMathFunc(FInfo, DVal0[0],
1639 DVal1[0], copr0, copr1, copr2)) {
1640 return false;
1641 }
1642 } else {
1643 ConstantDataVector *CDV0 = dyn_cast_or_null<ConstantDataVector>(copr0);
1644 ConstantDataVector *CDV1 = dyn_cast_or_null<ConstantDataVector>(copr1);
1645 ConstantDataVector *CDV2 = dyn_cast_or_null<ConstantDataVector>(copr2);
1646 for (int i=0; i < getVecSize(FInfo); ++i) {
1647 Constant *celt0 = CDV0 ? CDV0->getElementAsConstant(i) : nullptr;
1648 Constant *celt1 = CDV1 ? CDV1->getElementAsConstant(i) : nullptr;
1649 Constant *celt2 = CDV2 ? CDV2->getElementAsConstant(i) : nullptr;
1650 if (!evaluateScalarMathFunc(FInfo, DVal0[i],
1651 DVal1[i], celt0, celt1, celt2)) {
1652 return false;
1653 }
1654 }
1655 }
1656
1657 LLVMContext &context = CI->getParent()->getParent()->getContext();
1658 Constant *nval0, *nval1;
1659 if (getVecSize(FInfo) == 1) {
1660 nval0 = ConstantFP::get(CI->getType(), DVal0[0]);
1661 if (hasTwoResults)
1662 nval1 = ConstantFP::get(CI->getType(), DVal1[0]);
1663 } else {
1664 if (getArgType(FInfo) == AMDGPULibFunc::F32) {
1665 SmallVector <float, 0> FVal0, FVal1;
1666 for (int i=0; i < getVecSize(FInfo); ++i)
1667 FVal0.push_back((float)DVal0[i]);
1668 ArrayRef<float> tmp0(FVal0);
1669 nval0 = ConstantDataVector::get(context, tmp0);
1670 if (hasTwoResults) {
1671 for (int i=0; i < getVecSize(FInfo); ++i)
1672 FVal1.push_back((float)DVal1[i]);
1673 ArrayRef<float> tmp1(FVal1);
1674 nval1 = ConstantDataVector::get(context, tmp1);
1675 }
1676 } else {
1677 ArrayRef<double> tmp0(DVal0);
1678 nval0 = ConstantDataVector::get(context, tmp0);
1679 if (hasTwoResults) {
1680 ArrayRef<double> tmp1(DVal1);
1681 nval1 = ConstantDataVector::get(context, tmp1);
1682 }
1683 }
1684 }
1685
1686 if (hasTwoResults) {
1687 // sincos
1688 assert(FInfo.getId() == AMDGPULibFunc::EI_SINCOS &&
1689 "math function with ptr arg not supported yet");
1690 new StoreInst(nval1, aCI->getArgOperand(1), aCI);
1691 }
1692
1693 replaceCall(nval0);
1694 return true;
1695 }
1696
1697 // Public interface to the Simplify LibCalls pass.
createAMDGPUSimplifyLibCallsPass(const TargetMachine * TM)1698 FunctionPass *llvm::createAMDGPUSimplifyLibCallsPass(const TargetMachine *TM) {
1699 return new AMDGPUSimplifyLibCalls(TM);
1700 }
1701
createAMDGPUUseNativeCallsPass()1702 FunctionPass *llvm::createAMDGPUUseNativeCallsPass() {
1703 return new AMDGPUUseNativeCalls();
1704 }
1705
runOnFunction(Function & F)1706 bool AMDGPUSimplifyLibCalls::runOnFunction(Function &F) {
1707 if (skipFunction(F))
1708 return false;
1709
1710 bool Changed = false;
1711 auto AA = &getAnalysis<AAResultsWrapperPass>().getAAResults();
1712
1713 LLVM_DEBUG(dbgs() << "AMDIC: process function ";
1714 F.printAsOperand(dbgs(), false, F.getParent()); dbgs() << '\n';);
1715
1716 for (auto &BB : F) {
1717 for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ) {
1718 // Ignore non-calls.
1719 CallInst *CI = dyn_cast<CallInst>(I);
1720 ++I;
1721 // Ignore intrinsics that do not become real instructions.
1722 if (!CI || isa<DbgInfoIntrinsic>(CI) || CI->isLifetimeStartOrEnd())
1723 continue;
1724
1725 // Ignore indirect calls.
1726 Function *Callee = CI->getCalledFunction();
1727 if (Callee == 0) continue;
1728
1729 LLVM_DEBUG(dbgs() << "AMDIC: try folding " << *CI << "\n";
1730 dbgs().flush());
1731 if(Simplifier.fold(CI, AA))
1732 Changed = true;
1733 }
1734 }
1735 return Changed;
1736 }
1737
run(Function & F,FunctionAnalysisManager & AM)1738 PreservedAnalyses AMDGPUSimplifyLibCallsPass::run(Function &F,
1739 FunctionAnalysisManager &AM) {
1740 AMDGPULibCalls Simplifier(&TM);
1741 Simplifier.initNativeFuncs();
1742
1743 bool Changed = false;
1744 auto AA = &AM.getResult<AAManager>(F);
1745
1746 LLVM_DEBUG(dbgs() << "AMDIC: process function ";
1747 F.printAsOperand(dbgs(), false, F.getParent()); dbgs() << '\n';);
1748
1749 for (auto &BB : F) {
1750 for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E;) {
1751 // Ignore non-calls.
1752 CallInst *CI = dyn_cast<CallInst>(I);
1753 ++I;
1754 // Ignore intrinsics that do not become real instructions.
1755 if (!CI || isa<DbgInfoIntrinsic>(CI) || CI->isLifetimeStartOrEnd())
1756 continue;
1757
1758 // Ignore indirect calls.
1759 Function *Callee = CI->getCalledFunction();
1760 if (Callee == 0)
1761 continue;
1762
1763 LLVM_DEBUG(dbgs() << "AMDIC: try folding " << *CI << "\n";
1764 dbgs().flush());
1765 if (Simplifier.fold(CI, AA))
1766 Changed = true;
1767 }
1768 }
1769 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
1770 }
1771
runOnFunction(Function & F)1772 bool AMDGPUUseNativeCalls::runOnFunction(Function &F) {
1773 if (skipFunction(F) || UseNative.empty())
1774 return false;
1775
1776 bool Changed = false;
1777 for (auto &BB : F) {
1778 for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E; ) {
1779 // Ignore non-calls.
1780 CallInst *CI = dyn_cast<CallInst>(I);
1781 ++I;
1782 if (!CI) continue;
1783
1784 // Ignore indirect calls.
1785 Function *Callee = CI->getCalledFunction();
1786 if (Callee == 0) continue;
1787
1788 if(Simplifier.useNative(CI))
1789 Changed = true;
1790 }
1791 }
1792 return Changed;
1793 }
1794
run(Function & F,FunctionAnalysisManager & AM)1795 PreservedAnalyses AMDGPUUseNativeCallsPass::run(Function &F,
1796 FunctionAnalysisManager &AM) {
1797 if (UseNative.empty())
1798 return PreservedAnalyses::all();
1799
1800 AMDGPULibCalls Simplifier;
1801 Simplifier.initNativeFuncs();
1802
1803 bool Changed = false;
1804 for (auto &BB : F) {
1805 for (BasicBlock::iterator I = BB.begin(), E = BB.end(); I != E;) {
1806 // Ignore non-calls.
1807 CallInst *CI = dyn_cast<CallInst>(I);
1808 ++I;
1809 if (!CI)
1810 continue;
1811
1812 // Ignore indirect calls.
1813 Function *Callee = CI->getCalledFunction();
1814 if (Callee == 0)
1815 continue;
1816
1817 if (Simplifier.useNative(CI))
1818 Changed = true;
1819 }
1820 }
1821 return Changed ? PreservedAnalyses::none() : PreservedAnalyses::all();
1822 }
1823