//===- FunctionSpecialization.cpp - Function Specialization ---------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This specialises functions with constant parameters. Constant parameters
// like function pointers and constant globals are propagated to the callee by
// specializing the function. The main benefit of this pass at the moment is
// that indirect calls are transformed into direct calls, which exposes
// inlining opportunities that the inliner would otherwise not have been able
// to exploit. That is why, by design, function specialisation is run before
// the inliner in the optimisation pipeline. Otherwise, we would only benefit
// from constant passing, which is a valid use-case too, but hasn't been
// explored much in terms of performance uplifts, cost-model and compile-time
// impact.
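//
// As a rough illustration (the IR names below are invented for this comment,
// not taken from any particular test), specializing a function on a
// function-pointer argument turns an indirect call into a direct one:
//
//     define internal i32 @compute(i32 %x, i32 (i32)* %binop) {
//       %r = call i32 %binop(i32 %x)      ; indirect call
//       ret i32 %r
//     }
//     ...
//     %r = call i32 @compute(i32 %v, i32 (i32)* @square)
//
// After specializing @compute for %binop == @square, the call site is
// redirected to the clone, whose body now contains a direct, inlinable call:
//
//     define internal i32 @compute.1(i32 %x, i32 (i32)* %binop) {
//       %r = call i32 @square(i32 %x)     ; direct call, now inlinable
//       ret i32 %r
//     }
//     ...
//     %r = call i32 @compute.1(i32 %v, i32 (i32)* @square)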
//
// Current limitations:
// - It does not yet handle integer ranges. We do support "literal constants",
//   but that's off by default under an option.
// - The cost-model could be further looked into (it mainly focuses on inlining
//   benefits).
//
// Ideas:
// - With a function specialization attribute for arguments, we could have
//   a direct way to steer function specialization, avoiding the cost-model,
//   and thus control compile-times / code-size.
//
// Todos:
// - Specializing recursive functions relies on running the transformation a
//   number of times, which is controlled by option
//   `func-specialization-max-iters`. Thus, increasing this value linearly
//   increases the number of times recursive functions get specialized; see
//   also the discussion in https://reviews.llvm.org/D106426 for details.
//   Perhaps there is a compile-time friendlier way to control/limit the number
//   of specialisations for recursive functions.
// - Don't transform the function if function specialization does not trigger;
//   the SCCPSolver may make IR changes.
//
// References:
// - 2021 LLVM Dev Mtg “Introducing function specialisation, and can we enable
//   it by default?”, https://www.youtube.com/watch?v=zJiCjeXgV5Q
//
//===----------------------------------------------------------------------===//

#include "llvm/ADT/Statistic.h"
#include "llvm/Analysis/CodeMetrics.h"
#include "llvm/Analysis/InlineCost.h"
#include "llvm/Analysis/LoopInfo.h"
#include "llvm/Analysis/TargetTransformInfo.h"
#include "llvm/Analysis/ValueLattice.h"
#include "llvm/Analysis/ValueLatticeUtils.h"
#include "llvm/IR/IntrinsicInst.h"
#include "llvm/Transforms/Scalar/SCCP.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "llvm/Transforms/Utils/SCCPSolver.h"
#include "llvm/Transforms/Utils/SizeOpts.h"
#include <cmath>

using namespace llvm;

#define DEBUG_TYPE "function-specialization"

STATISTIC(NumFuncSpecialized, "Number of functions specialized");

static cl::opt<bool> ForceFunctionSpecialization(
    "force-function-specialization", cl::init(false), cl::Hidden,
    cl::desc("Force function specialization for every call site with a "
             "constant argument"));

static cl::opt<unsigned> FuncSpecializationMaxIters(
    "func-specialization-max-iters", cl::Hidden,
    cl::desc("The maximum number of iterations function specialization is run"),
    cl::init(1));

static cl::opt<unsigned> MaxClonesThreshold(
    "func-specialization-max-clones", cl::Hidden,
    cl::desc("The maximum number of clones allowed for a single function "
             "specialization"),
    cl::init(3));

static cl::opt<unsigned> SmallFunctionThreshold(
    "func-specialization-size-threshold", cl::Hidden,
    cl::desc("Don't specialize functions that have fewer than this threshold "
             "number of instructions"),
    cl::init(100));

static cl::opt<unsigned>
    AvgLoopIterationCount("func-specialization-avg-iters-cost", cl::Hidden,
                          cl::desc("Average loop iteration count cost"),
                          cl::init(10));

static cl::opt<bool> SpecializeOnAddresses(
    "func-specialization-on-address", cl::init(false), cl::Hidden,
    cl::desc("Enable function specialization on the address of global values"));

// Disabled by default as it can significantly increase compilation times.
// Running nikic's compile-time tracker on x86 with instruction count as the
// metric shows a 3-4% regression for SPASS while being neutral for all other
// benchmarks of the LLVM test suite.
//
// https://llvm-compile-time-tracker.com
// https://github.com/nikic/llvm-compile-time-tracker
static cl::opt<bool> EnableSpecializationForLiteralConstant(
    "function-specialization-for-literal-constant", cl::init(false), cl::Hidden,
    cl::desc("Enable specialization of functions that take a literal constant "
             "as an argument."));

namespace {
// Bookkeeping struct to pass data from the analysis and profitability phase
// to the actual transform helper functions.
struct SpecializationInfo {
  SmallVector<ArgInfo, 8> Args; // Stores the {formal,actual} argument pairs.
  InstructionCost Gain;         // Profitability: Gain = Bonus - Cost.
};
} // Anonymous namespace

using FuncList = SmallVectorImpl<Function *>;
using CallArgBinding = std::pair<CallBase *, Constant *>;
using CallSpecBinding = std::pair<CallBase *, SpecializationInfo>;
// We are using MapVector because it guarantees deterministic iteration
// order across executions.
using SpecializationMap = SmallMapVector<CallBase *, SpecializationInfo, 8>;

// Helper to check if \p LV is either a constant or a constant
// range with a single element. This should cover exactly the same cases as the
// old ValueLatticeElement::isConstant() and is intended to be used in the
// transition to ValueLatticeElement.
static bool isConstant(const ValueLatticeElement &LV) {
  return LV.isConstant() ||
         (LV.isConstantRange() && LV.getConstantRange().isSingleElement());
}

// Helper to check if \p LV is overdefined, i.e. it is neither unknown nor
// undef and does not hold a single constant value.
static bool isOverdefined(const ValueLatticeElement &LV) {
  return !LV.isUnknownOrUndef() && !isConstant(LV);
}

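// Check whether \p Alloca can be treated as holding a single constant: apart
// from its use in \p Call (possibly through a bitcast), it may only have one
// non-volatile store of the value we return. Returns nullptr if the stored
// value is not a constant or if any other use is found.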
static Constant *getPromotableAlloca(AllocaInst *Alloca, CallInst *Call) {
  Value *StoreValue = nullptr;
  for (auto *User : Alloca->users()) {
    // We can't use llvm::isAllocaPromotable() as that would fail because of
    // the usage in the CallInst, which is what we check here.
    if (User == Call)
      continue;
    if (auto *Bitcast = dyn_cast<BitCastInst>(User)) {
      if (!Bitcast->hasOneUse() || *Bitcast->user_begin() != Call)
        return nullptr;
      continue;
    }

    if (auto *Store = dyn_cast<StoreInst>(User)) {
      // Bail if there is more than one store or the store is volatile.
      if (StoreValue || Store->isVolatile())
        return nullptr;
      StoreValue = Store->getValueOperand();
      continue;
    }
    // Bail if there is any other unknown usage.
    return nullptr;
  }
  return dyn_cast_or_null<Constant>(StoreValue);
}

// A constant stack value is an AllocaInst that has a single constant
// value stored to it. Return this constant if such an alloca stack value
// is a function argument.
static Constant *getConstantStackValue(CallInst *Call, Value *Val,
                                       SCCPSolver &Solver) {
  if (!Val)
    return nullptr;
  Val = Val->stripPointerCasts();
  if (auto *ConstVal = dyn_cast<ConstantInt>(Val))
    return ConstVal;
  auto *Alloca = dyn_cast<AllocaInst>(Val);
  if (!Alloca || !Alloca->getAllocatedType()->isIntegerTy())
    return nullptr;
  return getPromotableAlloca(Alloca, Call);
}

// To support specializing recursive functions, it is important to propagate
// constant arguments because after a first iteration of specialisation, a
// reduced example may look like this:
//
//     define internal void @RecursiveFn(i32* arg1) {
//       %temp = alloca i32, align 4
//       store i32 2, i32* %temp, align 4
//       call void @RecursiveFn.1(i32* nonnull %temp)
//       ret void
//     }
//
// Before the next iteration, we need to propagate the constant like so,
// which allows further specialization in subsequent iterations:
//
//     @funcspec.arg = internal constant i32 2
//
//     define internal void @RecursiveFn(i32* arg1) {
//       call void @RecursiveFn.1(i32* nonnull @funcspec.arg)
//       ret void
//     }
//
static void constantArgPropagation(FuncList &WorkList, Module &M,
                                   SCCPSolver &Solver) {
  // Iterate over the argument-tracked functions to see if there are any new
  // constant values for the call instructions via stack variables.
  for (auto *F : WorkList) {

    for (auto *User : F->users()) {

      auto *Call = dyn_cast<CallInst>(User);
      if (!Call)
        continue;

      bool Changed = false;
      for (const Use &U : Call->args()) {
        unsigned Idx = Call->getArgOperandNo(&U);
        Value *ArgOp = Call->getArgOperand(Idx);
        Type *ArgOpType = ArgOp->getType();

        if (!Call->onlyReadsMemory(Idx) || !ArgOpType->isPointerTy())
          continue;

        auto *ConstVal = getConstantStackValue(Call, ArgOp, Solver);
        if (!ConstVal)
          continue;

        Value *GV = new GlobalVariable(M, ConstVal->getType(), true,
                                       GlobalValue::InternalLinkage, ConstVal,
                                       "funcspec.arg");
        if (ArgOpType != ConstVal->getType())
          GV = ConstantExpr::getBitCast(cast<Constant>(GV), ArgOpType);

        Call->setArgOperand(Idx, GV);
        Changed = true;
      }

      // Add the changed CallInst to the solver's worklist.
      if (Changed)
        Solver.visitCall(*Call);
    }
  }
}

// ssa_copy intrinsics are introduced by the SCCP solver. These intrinsics
// interfere with the constantArgPropagation optimization.
static void removeSSACopy(Function &F) {
  for (BasicBlock &BB : F) {
    for (Instruction &Inst : llvm::make_early_inc_range(BB)) {
      auto *II = dyn_cast<IntrinsicInst>(&Inst);
      if (!II)
        continue;
      if (II->getIntrinsicID() != Intrinsic::ssa_copy)
        continue;
      Inst.replaceAllUsesWith(II->getOperand(0));
      Inst.eraseFromParent();
    }
  }
}

static void removeSSACopy(Module &M) {
  for (Function &F : M)
    removeSSACopy(F);
}

namespace {
class FunctionSpecializer {

  /// The IPSCCP Solver.
  SCCPSolver &Solver;

  /// Analyses used to help determine if a function should be specialized.
  std::function<AssumptionCache &(Function &)> GetAC;
  std::function<TargetTransformInfo &(Function &)> GetTTI;
  std::function<TargetLibraryInfo &(Function &)> GetTLI;

  SmallPtrSet<Function *, 4> SpecializedFuncs;
  SmallPtrSet<Function *, 4> FullySpecialized;
  SmallVector<Instruction *> ReplacedWithConstant;
  DenseMap<Function *, CodeMetrics> FunctionMetrics;

public:
  FunctionSpecializer(SCCPSolver &Solver,
                      std::function<AssumptionCache &(Function &)> GetAC,
                      std::function<TargetTransformInfo &(Function &)> GetTTI,
                      std::function<TargetLibraryInfo &(Function &)> GetTLI)
      : Solver(Solver), GetAC(GetAC), GetTTI(GetTTI), GetTLI(GetTLI) {}

  ~FunctionSpecializer() {
    // Eliminate dead code.
    removeDeadInstructions();
    removeDeadFunctions();
  }

  /// Attempt to specialize functions in the module to enable constant
  /// propagation across function boundaries.
  ///
  /// \returns true if at least one function is specialized.
  bool specializeFunctions(FuncList &Candidates, FuncList &WorkList) {
    bool Changed = false;
    for (auto *F : Candidates) {
      if (!isCandidateFunction(F))
        continue;

      auto Cost = getSpecializationCost(F);
      if (!Cost.isValid()) {
        LLVM_DEBUG(
            dbgs() << "FnSpecialization: Invalid specialization cost.\n");
        continue;
      }

      LLVM_DEBUG(dbgs() << "FnSpecialization: Specialization cost for "
                        << F->getName() << " is " << Cost << "\n");

      SmallVector<CallSpecBinding, 8> Specializations;
      if (!calculateGains(F, Cost, Specializations)) {
        LLVM_DEBUG(dbgs() << "FnSpecialization: No possible constants found\n");
        continue;
      }

      Changed = true;
      for (auto &Entry : Specializations)
        specializeFunction(F, Entry.second, WorkList);
    }

    updateSpecializedFuncs(Candidates, WorkList);
    NumFuncSpecialized += NbFunctionsSpecialized;
    return Changed;
  }

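  /// Remove the instructions that were folded to constants and queued for
  /// deletion by tryToReplaceWithConstant().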
  void removeDeadInstructions() {
    for (auto *I : ReplacedWithConstant) {
      LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead instruction " << *I
                        << "\n");
      I->eraseFromParent();
    }
    ReplacedWithConstant.clear();
  }

  void removeDeadFunctions() {
    for (auto *F : FullySpecialized) {
      LLVM_DEBUG(dbgs() << "FnSpecialization: Removing dead function "
                        << F->getName() << "\n");
      F->eraseFromParent();
    }
    FullySpecialized.clear();
  }

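  /// Replace \p V with the constant (or undef) that the solver has determined
  /// for it and revisit the affected users. If \p V is an instruction that is
  /// now trivially dead, queue it for removal. Returns true on replacement.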
  bool tryToReplaceWithConstant(Value *V) {
    if (!V->getType()->isSingleValueType() || isa<CallBase>(V) ||
        V->user_empty())
      return false;

    const ValueLatticeElement &IV = Solver.getLatticeValueFor(V);
    if (isOverdefined(IV))
      return false;
    auto *Const =
        isConstant(IV) ? Solver.getConstant(IV) : UndefValue::get(V->getType());

    LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing " << *V
                      << "\nFnSpecialization: with " << *Const << "\n");

    // Record uses of V to avoid visiting irrelevant uses of Const later.
    SmallVector<Instruction *> UseInsts;
    for (auto *U : V->users())
      if (auto *I = dyn_cast<Instruction>(U))
        if (Solver.isBlockExecutable(I->getParent()))
          UseInsts.push_back(I);

    V->replaceAllUsesWith(Const);

    for (auto *I : UseInsts)
      Solver.visit(I);

    // Remove the instruction from Block and Solver.
    if (auto *I = dyn_cast<Instruction>(V)) {
      if (I->isSafeToRemove()) {
        ReplacedWithConstant.push_back(I);
        Solver.removeLatticeValueFor(I);
      }
    }
    return true;
  }

private:
  // The number of functions specialised, used for collecting statistics and
  // also in the cost model.
  unsigned NbFunctionsSpecialized = 0;

  // Compute the code metrics for function \p F.
  CodeMetrics &analyzeFunction(Function *F) {
    auto I = FunctionMetrics.insert({F, CodeMetrics()});
    CodeMetrics &Metrics = I.first->second;
    if (I.second) {
      // The code metrics were not cached.
      SmallPtrSet<const Value *, 32> EphValues;
      CodeMetrics::collectEphemeralValues(F, &(GetAC)(*F), EphValues);
      for (BasicBlock &BB : *F)
        Metrics.analyzeBasicBlock(&BB, (GetTTI)(*F), EphValues);

      LLVM_DEBUG(dbgs() << "FnSpecialization: Code size of function "
                        << F->getName() << " is " << Metrics.NumInsts
                        << " instructions\n");
    }
    return Metrics;
  }

  /// Clone the function \p F and remove the ssa_copy intrinsics added by
  /// the SCCPSolver in the cloned version.
  Function *cloneCandidateFunction(Function *F, ValueToValueMapTy &Mappings) {
    Function *Clone = CloneFunction(F, Mappings);
    removeSSACopy(*Clone);
    return Clone;
  }

  /// This function decides whether it's worthwhile to specialize function
  /// \p F based on the known constant values its arguments can take on. It
  /// only discovers potential specialization opportunities without actually
  /// applying them.
  ///
  /// \returns true if any specializations have been found.
  bool calculateGains(Function *F, InstructionCost Cost,
                      SmallVectorImpl<CallSpecBinding> &WorkList) {
    SpecializationMap Specializations;
    // Determine if we should specialize the function based on the values the
    // argument can take on. If specialization is not profitable, we continue
    // on to the next argument.
    for (Argument &FormalArg : F->args()) {
      // Determine if this argument is interesting. If we know the argument can
      // take on any constant values, they are collected in ActualArgs.
      SmallVector<CallArgBinding, 8> ActualArgs;
      if (!isArgumentInteresting(&FormalArg, ActualArgs)) {
        LLVM_DEBUG(dbgs() << "FnSpecialization: Argument "
                          << FormalArg.getNameOrAsOperand()
                          << " is not interesting\n");
        continue;
      }

      for (const auto &Entry : ActualArgs) {
        CallBase *Call = Entry.first;
        Constant *ActualArg = Entry.second;

        auto I = Specializations.insert({Call, SpecializationInfo()});
        SpecializationInfo &S = I.first->second;

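        // On the first visit of this call site, seed the gain with the
        // negated specialization cost; the per-argument bonuses added below
        // must outweigh it. When forcing specialization we use a token
        // positive value instead, so the candidate is always kept.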
        if (I.second)
          S.Gain = ForceFunctionSpecialization ? 1 : 0 - Cost;
        if (!ForceFunctionSpecialization)
          S.Gain += getSpecializationBonus(&FormalArg, ActualArg);
        S.Args.push_back({&FormalArg, ActualArg});
      }
    }

    // Remove unprofitable specializations.
    Specializations.remove_if(
        [](const auto &Entry) { return Entry.second.Gain <= 0; });

    // Clear the MapVector and return the underlying vector.
    WorkList = Specializations.takeVector();

    // Sort the candidates in descending order.
    llvm::stable_sort(WorkList, [](const auto &L, const auto &R) {
      return L.second.Gain > R.second.Gain;
    });

    // Truncate the worklist to 'MaxClonesThreshold' candidates if necessary.
    if (WorkList.size() > MaxClonesThreshold) {
      LLVM_DEBUG(dbgs() << "FnSpecialization: Number of candidates exceeds "
                        << "the maximum number of clones threshold.\n"
                        << "FnSpecialization: Truncating worklist to "
                        << MaxClonesThreshold << " candidates.\n");
      WorkList.erase(WorkList.begin() + MaxClonesThreshold, WorkList.end());
    }

    LLVM_DEBUG(dbgs() << "FnSpecialization: Specializations for function "
                      << F->getName() << "\n";
               for (const auto &Entry : WorkList) {
                 dbgs() << "FnSpecialization:   Gain = " << Entry.second.Gain
                        << "\n";
                 for (const ArgInfo &Arg : Entry.second.Args)
                   dbgs() << "FnSpecialization:   FormalArg = "
                          << Arg.Formal->getNameOrAsOperand()
                          << ", ActualArg = "
                          << Arg.Actual->getNameOrAsOperand() << "\n";
               });

    return !WorkList.empty();
  }

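  /// Return true if \p F may be considered for specialization in the current
  /// iteration: it is not itself a previously created clone, is not optimized
  /// for size, is reachable, and is not always_inline.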
  bool isCandidateFunction(Function *F) {
    // Do not specialize the cloned function again.
    if (SpecializedFuncs.contains(F))
      return false;

    // If we're optimizing the function for size, we shouldn't specialize it.
    if (F->hasOptSize() ||
        shouldOptimizeForSize(F, nullptr, nullptr, PGSOQueryType::IRPass))
      return false;

    // Exit if the function is not executable. There's no point in specializing
    // a dead function.
    if (!Solver.isBlockExecutable(&F->getEntryBlock()))
      return false;

    // Specializing a function that will always be inlined is a waste of time.
    if (F->hasFnAttribute(Attribute::AlwaysInline))
      return false;

    LLVM_DEBUG(dbgs() << "FnSpecialization: Try function: " << F->getName()
                      << "\n");
    return true;
  }

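  /// Create a clone of \p F specialized on the argument bindings in \p S,
  /// redirect the matching call sites to the clone, and add the clone to
  /// \p WorkList. If the original function has no remaining callers outside
  /// of itself, mark it unreachable so it can be removed later.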
  void specializeFunction(Function *F, SpecializationInfo &S,
                          FuncList &WorkList) {
    ValueToValueMapTy Mappings;
    Function *Clone = cloneCandidateFunction(F, Mappings);

    // Rewrite calls to the function so that they call the clone instead.
    rewriteCallSites(Clone, S.Args, Mappings);

    // Initialize the lattice state of the arguments of the function clone,
    // marking the argument on which we specialized the function constant
    // with the given value.
    Solver.markArgInFuncSpecialization(Clone, S.Args);

    // Record the new clone and bump the specialization counter.
    WorkList.push_back(Clone);
    NbFunctionsSpecialized++;

    // If the function has been completely specialized, the original function
    // is no longer needed. Mark it unreachable.
    if (F->getNumUses() == 0 || all_of(F->users(), [F](User *U) {
          if (auto *CS = dyn_cast<CallBase>(U))
            return CS->getFunction() == F;
          return false;
        })) {
      Solver.markFunctionUnreachable(F);
      FullySpecialized.insert(F);
    }
  }

  /// Compute and return the cost of specializing function \p F.
  InstructionCost getSpecializationCost(Function *F) {
    CodeMetrics &Metrics = analyzeFunction(F);
    // If the code metrics reveal that we shouldn't duplicate the function, we
    // shouldn't specialize it. Likewise, if the function is small enough that
    // it is likely to be fully inlined anyway, specializing it is not
    // worthwhile. In both cases, set the specialization cost to Invalid.
    if (Metrics.notDuplicatable || !Metrics.NumInsts.isValid() ||
        (!ForceFunctionSpecialization &&
         *Metrics.NumInsts.getValue() < SmallFunctionThreshold)) {
      InstructionCost C{};
      C.setInvalid();
      return C;
    }

    // Otherwise, set the specialization cost to be the cost of all the
    // instructions in the function and penalty for specializing more functions.
    unsigned Penalty = NbFunctionsSpecialized + 1;
    return Metrics.NumInsts * InlineConstants::InstrCost * Penalty;
  }

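  /// Compute the bonus of a user \p U of a to-be-constant value: the TTI
  /// size-and-latency cost of \p U itself, plus (for loads and casts) the
  /// cost of its transitive users, scaled by the loop depth of \p U assuming
  /// AvgLoopIterationCount iterations per loop level. For example, with the
  /// default AvgLoopIterationCount of 10, a user at loop depth 2 contributes
  /// 100x its flat cost.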
  InstructionCost getUserBonus(User *U, llvm::TargetTransformInfo &TTI,
                               LoopInfo &LI) {
    auto *I = dyn_cast_or_null<Instruction>(U);
    // If U is not an instruction, we do not know how to evaluate it. Keep the
    // minimum possible cost for now so that it doesn't affect the
    // specialization decision.
    if (!I)
      return std::numeric_limits<unsigned>::min();

    auto Cost = TTI.getUserCost(U, TargetTransformInfo::TCK_SizeAndLatency);

    // Traverse recursively if there are more uses.
    // TODO: Any other instructions to be added here?
    if (I->mayReadFromMemory() || I->isCast())
      for (auto *User : I->users())
        Cost += getUserBonus(User, TTI, LI);

    // Scale the cost by the loop depth of the user.
    auto LoopDepth = LI.getLoopDepth(I->getParent());
    Cost *= std::pow((double)AvgLoopIterationCount, LoopDepth);
    return Cost;
  }

  /// Compute a bonus for replacing argument \p A with constant \p C.
  InstructionCost getSpecializationBonus(Argument *A, Constant *C) {
    Function *F = A->getParent();
    DominatorTree DT(*F);
    LoopInfo LI(DT);
    auto &TTI = (GetTTI)(*F);
    LLVM_DEBUG(dbgs() << "FnSpecialization: Analysing bonus for constant: "
                      << C->getNameOrAsOperand() << "\n");

    InstructionCost TotalCost = 0;
    for (auto *U : A->users()) {
      TotalCost += getUserBonus(U, TTI, LI);
      LLVM_DEBUG(dbgs() << "FnSpecialization:   User cost ";
                 TotalCost.print(dbgs()); dbgs() << " for: " << *U << "\n");
    }

    // The below heuristic is only concerned with exposing inlining
    // opportunities via indirect call promotion. If the argument is not a
    // (potentially casted) function pointer, give up.
    Function *CalledFunction = dyn_cast<Function>(C->stripPointerCasts());
    if (!CalledFunction)
      return TotalCost;

    // Get TTI for the called function (used for the inline cost).
    auto &CalleeTTI = (GetTTI)(*CalledFunction);

    // Look at all the call sites whose called value is the argument.
    // Specializing the function on the argument would allow these indirect
    // calls to be promoted to direct calls. If the indirect call promotion
    // would likely enable the called function to be inlined, specializing is a
    // good idea.
    int Bonus = 0;
    for (User *U : A->users()) {
      if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
        continue;
      auto *CS = cast<CallBase>(U);
      if (CS->getCalledOperand() != A)
        continue;

      // Get the cost of inlining the called function at this call site. Note
      // that this is only an estimate. The called function may eventually
      // change in a way that leads to it not being inlined here, even though
      // inlining looks profitable now. For example, one of its called
      // functions may be inlined into it, making the called function too large
      // to be inlined into this call site.
      //
      // We apply a boost for performing indirect call promotion by increasing
      // the default threshold by the threshold for indirect calls.
      auto Params = getInlineParams();
      Params.DefaultThreshold += InlineConstants::IndirectCallThreshold;
      InlineCost IC =
          getInlineCost(*CS, CalledFunction, Params, CalleeTTI, GetAC, GetTLI);

      // We clamp the bonus for this call to be between zero and the default
      // threshold.
      if (IC.isAlways())
        Bonus += Params.DefaultThreshold;
      else if (IC.isVariable() && IC.getCostDelta() > 0)
        Bonus += IC.getCostDelta();

      LLVM_DEBUG(dbgs() << "FnSpecialization:   Inlining bonus " << Bonus
                        << " for user " << *U << "\n");
    }

    return TotalCost + Bonus;
  }

  /// Determine if we should specialize a function based on the incoming values
  /// of the given argument.
  ///
  /// This function implements the goal-directed heuristic. It determines if
  /// specializing the function based on the incoming values of argument \p A
  /// would result in any significant optimization opportunities. If
  /// optimization opportunities exist, the constant values of \p A on which to
  /// specialize the function are collected in \p Constants.
  ///
  /// \returns true if the function should be specialized on the given
  /// argument.
  bool isArgumentInteresting(Argument *A,
                             SmallVectorImpl<CallArgBinding> &Constants) {
    // For now, don't attempt to specialize functions based on the values of
    // composite types.
    if (!A->getType()->isSingleValueType() || A->user_empty())
      return false;

    // If the argument isn't overdefined, there's nothing to do. It should
    // already be constant.
    if (!Solver.getLatticeValueFor(A).isOverdefined()) {
      LLVM_DEBUG(dbgs() << "FnSpecialization: Nothing to do, argument "
                        << A->getNameOrAsOperand()
                        << " is already constant?\n");
      return false;
    }

    // Collect the constant values that the argument can take on. If the
    // argument can't take on any constant values, we aren't going to
    // specialize the function. While it's possible to specialize the function
    // based on non-constant arguments, there's likely not much benefit to
    // constant propagation in doing so.
    //
    // TODO 1: currently it won't specialize if the number of calls using the
    // same argument exceeds the threshold, e.g. foo(a) x 4 and foo(b) x 1, but
    // it might be beneficial to take the occurrences into account in the cost
    // model; we would then need to find the unique constants.
    //
    // TODO 2: this currently does not support constant ranges, i.e. integer
    // ranges.
    //
    getPossibleConstants(A, Constants);

    if (Constants.empty())
      return false;

    LLVM_DEBUG(dbgs() << "FnSpecialization: Found interesting argument "
                      << A->getNameOrAsOperand() << "\n");
    return true;
  }

  /// Collect in \p Constants all the constant values that argument \p A can
  /// take on.
  void getPossibleConstants(Argument *A,
                            SmallVectorImpl<CallArgBinding> &Constants) {
    Function *F = A->getParent();

    // The SCCP solver does not record an argument that will be constructed on
    // the stack.
    if (A->hasByValAttr() && !F->onlyReadsMemory())
      return;

    // Iterate over all the call sites of the argument's parent function.
    for (User *U : F->users()) {
      if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
        continue;
      auto &CS = *cast<CallBase>(U);
      // If the call site has the minsize attribute set, it won't be
      // specialized.
      if (CS.hasFnAttr(Attribute::MinSize))
        continue;

      // If the parent of the call site will never be executed, we don't need
      // to worry about the passed value.
      if (!Solver.isBlockExecutable(CS.getParent()))
        continue;

      auto *V = CS.getArgOperand(A->getArgNo());
      if (isa<PoisonValue>(V))
        return;

      // TrackValueOfGlobalVariable only tracks scalar global variables.
      if (auto *GV = dyn_cast<GlobalVariable>(V)) {
        // Check if we want to specialize on the address of non-constant
        // global values.
        if (!GV->isConstant())
          if (!SpecializeOnAddresses)
            return;

        if (!GV->getValueType()->isSingleValueType())
          return;
      }

      if (isa<Constant>(V) && (Solver.getLatticeValueFor(V).isConstant() ||
                               EnableSpecializationForLiteralConstant))
        Constants.push_back({&CS, cast<Constant>(V)});
    }
  }

  /// Rewrite calls to function \p F to call function \p Clone instead.
  ///
  /// This function modifies calls to function \p F as long as the actual
  /// arguments match those in \p Args. Note that for recursive calls we
  /// need to compare against the cloned formal arguments.
  ///
  /// Callsites that have been marked with the MinSize function attribute won't
  /// be specialized and rewritten.
  void rewriteCallSites(Function *Clone, const SmallVectorImpl<ArgInfo> &Args,
                        ValueToValueMapTy &Mappings) {
    assert(!Args.empty() && "Specialization without arguments");
    Function *F = Args[0].Formal->getParent();

    SmallVector<CallBase *, 8> CallSitesToRewrite;
    for (auto *U : F->users()) {
      if (!isa<CallInst>(U) && !isa<InvokeInst>(U))
        continue;
      auto &CS = *cast<CallBase>(U);
      if (!CS.getCalledFunction() || CS.getCalledFunction() != F)
        continue;
      CallSitesToRewrite.push_back(&CS);
    }

    LLVM_DEBUG(dbgs() << "FnSpecialization: Replacing call sites of "
                      << F->getName() << " with " << Clone->getName() << "\n");

    for (auto *CS : CallSitesToRewrite) {
      LLVM_DEBUG(dbgs() << "FnSpecialization:   "
                        << CS->getFunction()->getName() << " ->" << *CS
                        << "\n");
      if (/* recursive call */
          (CS->getFunction() == Clone &&
           all_of(Args,
                  [CS, &Mappings](const ArgInfo &Arg) {
                    unsigned ArgNo = Arg.Formal->getArgNo();
                    return CS->getArgOperand(ArgNo) == Mappings[Arg.Formal];
                  })) ||
          /* normal call */
          all_of(Args, [CS](const ArgInfo &Arg) {
            unsigned ArgNo = Arg.Formal->getArgNo();
            return CS->getArgOperand(ArgNo) == Arg.Actual;
          })) {
        CS->setCalledFunction(Clone);
        Solver.markOverdefined(CS);
      }
    }
  }

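  /// Register the newly created clones in \p WorkList with the solver: mark
  /// them as specialized and argument-tracked, add them to \p Candidates for
  /// the next iteration, mark their entry blocks executable, and fold any
  /// arguments that are already known to be constant.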
  void updateSpecializedFuncs(FuncList &Candidates, FuncList &WorkList) {
    for (auto *F : WorkList) {
      SpecializedFuncs.insert(F);

      // Initialize the state of the newly created functions, marking them
      // argument-tracked and executable.
      if (F->hasExactDefinition() && !F->hasFnAttribute(Attribute::Naked))
        Solver.addTrackedFunction(F);

      Solver.addArgumentTrackedFunction(F);
      Candidates.push_back(F);
      Solver.markBlockExecutable(&F->front());

      // Replace the function arguments for the specialized functions.
      for (Argument &Arg : F->args())
        if (!Arg.use_empty() && tryToReplaceWithConstant(&Arg))
          LLVM_DEBUG(dbgs() << "FnSpecialization: Replaced constant argument: "
                            << Arg.getNameOrAsOperand() << "\n");
    }
  }
};
} // namespace

bool llvm::runFunctionSpecialization(
    Module &M, const DataLayout &DL,
    std::function<TargetLibraryInfo &(Function &)> GetTLI,
    std::function<TargetTransformInfo &(Function &)> GetTTI,
    std::function<AssumptionCache &(Function &)> GetAC,
    function_ref<AnalysisResultsForFn(Function &)> GetAnalysis) {
  SCCPSolver Solver(DL, GetTLI, M.getContext());
  FunctionSpecializer FS(Solver, GetAC, GetTTI, GetTLI);
  bool Changed = false;

  // Loop over all functions, marking arguments to those with their addresses
  // taken or that are external as overdefined.
  for (Function &F : M) {
    if (F.isDeclaration())
      continue;
    if (F.hasFnAttribute(Attribute::NoDuplicate))
      continue;

    LLVM_DEBUG(dbgs() << "\nFnSpecialization: Analysing decl: " << F.getName()
                      << "\n");
    Solver.addAnalysis(F, GetAnalysis(F));

    // Determine if we can track the function's arguments. If so, add the
    // function to the solver's set of argument-tracked functions.
    if (canTrackArgumentsInterprocedurally(&F)) {
      LLVM_DEBUG(dbgs() << "FnSpecialization: Can track arguments\n");
      Solver.addArgumentTrackedFunction(&F);
      continue;
    } else {
      LLVM_DEBUG(dbgs() << "FnSpecialization: Can't track arguments!\n"
                        << "FnSpecialization: Doesn't have local linkage, or "
                        << "has its address taken\n");
    }

    // Assume the function is called.
    Solver.markBlockExecutable(&F.front());

    // Assume nothing about the incoming arguments.
    for (Argument &AI : F.args())
      Solver.markOverdefined(&AI);
  }

  // Determine if we can track any of the module's global variables. If so, add
  // the global variables we can track to the solver's set of tracked global
  // variables.
  for (GlobalVariable &G : M.globals()) {
    G.removeDeadConstantUsers();
    if (canTrackGlobalVariableInterprocedurally(&G))
      Solver.trackValueOfGlobalVariable(&G);
  }

  auto &TrackedFuncs = Solver.getArgumentTrackedFunctions();
  SmallVector<Function *, 16> FuncDecls(TrackedFuncs.begin(),
                                        TrackedFuncs.end());

  // If there are no tracked functions, there is nothing to do: don't run the
  // solver; just remove any ssa_copy intrinsics that may have been introduced
  // and return.
  if (TrackedFuncs.empty()) {
    removeSSACopy(M);
    return false;
  }

  // Solve for constants.
  auto RunSCCPSolver = [&](auto &WorkList) {
    bool ResolvedUndefs = true;

    while (ResolvedUndefs) {
      // Not running the solver unnecessarily is checked in the regression test
      // nothing-to-do.ll, so if this debug message is changed, that regression
      // test needs updating too.
      LLVM_DEBUG(dbgs() << "FnSpecialization: Running solver\n");

      Solver.solve();
      LLVM_DEBUG(dbgs() << "FnSpecialization: Resolving undefs\n");
      ResolvedUndefs = false;
      for (Function *F : WorkList)
        if (Solver.resolvedUndefsIn(*F))
          ResolvedUndefs = true;
    }

    for (auto *F : WorkList) {
      for (BasicBlock &BB : *F) {
        if (!Solver.isBlockExecutable(&BB))
          continue;
        // FIXME: The solver may make changes to the function here, so set
        // Changed, even if later function specialization does not trigger.
        for (auto &I : make_early_inc_range(BB))
          Changed |= FS.tryToReplaceWithConstant(&I);
      }
    }
  };

#ifndef NDEBUG
  LLVM_DEBUG(dbgs() << "FnSpecialization: Worklist fn decls:\n");
  for (auto *F : FuncDecls)
    LLVM_DEBUG(dbgs() << "FnSpecialization: *) " << F->getName() << "\n");
#endif

  // Initially resolve the constants in all the argument-tracked functions.
  RunSCCPSolver(FuncDecls);

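  // Iteratively specialize the argument-tracked functions. Each iteration
  // specializes the current candidates, re-runs the solver on the new clones,
  // and propagates constants that were stored to the stack, so that (for
  // example) recursive functions can be specialized again on the next
  // iteration. The number of iterations is bounded by
  // FuncSpecializationMaxIters.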
  SmallVector<Function *, 8> WorkList;
  unsigned I = 0;
  while (FuncSpecializationMaxIters != I++ &&
         FS.specializeFunctions(FuncDecls, WorkList)) {
    LLVM_DEBUG(dbgs() << "FnSpecialization: Finished iteration " << I << "\n");

    // Run the solver for the specialized functions.
    RunSCCPSolver(WorkList);

    // Replace some unresolved constant arguments.
    constantArgPropagation(FuncDecls, M, Solver);

    WorkList.clear();
    Changed = true;
  }

  LLVM_DEBUG(dbgs() << "FnSpecialization: Number of specializations = "
                    << NumFuncSpecialized << "\n");

  // Remove any ssa_copy intrinsics that may have been introduced.
  removeSSACopy(M);
  return Changed;
}