1 //===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H 10 #define LLVM_ANALYSIS_MLINLINEADVISOR_H 11 12 #include "llvm/Analysis/FunctionPropertiesAnalysis.h" 13 #include "llvm/Analysis/InlineAdvisor.h" 14 #include "llvm/Analysis/LazyCallGraph.h" 15 #include "llvm/Analysis/MLModelRunner.h" 16 #include "llvm/IR/PassManager.h" 17 18 #include <deque> 19 #include <map> 20 #include <memory> 21 #include <optional> 22 23 namespace llvm { 24 class DiagnosticInfoOptimizationBase; 25 class Module; 26 class MLInlineAdvice; 27 28 class MLInlineAdvisor : public InlineAdvisor { 29 public: 30 MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM, 31 std::unique_ptr<MLModelRunner> ModelRunner, 32 std::function<bool(CallBase &)> GetDefaultAdvice); 33 34 virtual ~MLInlineAdvisor() = default; 35 36 void onPassEntry(LazyCallGraph::SCC *SCC) override; 37 void onPassExit(LazyCallGraph::SCC *SCC) override; 38 39 int64_t getIRSize(Function &F) const { 40 return getCachedFPI(F).TotalInstructionCount; 41 } 42 void onSuccessfulInlining(const MLInlineAdvice &Advice, 43 bool CalleeWasDeleted); 44 45 bool isForcedToStop() const { return ForceStop; } 46 int64_t getLocalCalls(Function &F); 47 const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); } 48 FunctionPropertiesInfo &getCachedFPI(Function &) const; 49 50 protected: 51 std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override; 52 53 std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB, 54 bool Advice) override; 55 56 virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB); 57 58 virtual std::unique_ptr<MLInlineAdvice> 59 getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE); 60 61 // Get the initial 'level' of the function, or 0 if the function has been 62 // introduced afterwards. 63 // TODO: should we keep this updated? 64 unsigned getInitialFunctionLevel(const Function &F) const; 65 66 std::unique_ptr<MLModelRunner> ModelRunner; 67 std::function<bool(CallBase &)> GetDefaultAdvice; 68 69 private: 70 int64_t getModuleIRSize() const; 71 std::unique_ptr<InlineAdvice> 72 getSkipAdviceIfUnreachableCallsite(CallBase &CB); 73 void print(raw_ostream &OS) const override; 74 75 // Using std::map to benefit from its iterator / reference non-invalidating 76 // semantics, which make it easy to use `getCachedFPI` results from multiple 77 // calls without needing to copy to avoid invalidation effects. 78 mutable std::map<const Function *, FunctionPropertiesInfo> FPICache; 79 80 LazyCallGraph &CG; 81 82 int64_t NodeCount = 0; 83 int64_t EdgeCount = 0; 84 int64_t EdgesOfLastSeenNodes = 0; 85 86 std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels; 87 const int32_t InitialIRSize = 0; 88 int32_t CurrentIRSize = 0; 89 llvm::SmallPtrSet<const LazyCallGraph::Node *, 1> NodesInLastSCC; 90 DenseSet<const LazyCallGraph::Node *> AllNodes; 91 bool ForceStop = false; 92 }; 93 94 /// InlineAdvice that tracks changes post inlining. For that reason, it only 95 /// overrides the "successful inlining" extension points. 96 class MLInlineAdvice : public InlineAdvice { 97 public: 98 MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB, 99 OptimizationRemarkEmitter &ORE, bool Recommendation); 100 virtual ~MLInlineAdvice() = default; 101 102 void recordInliningImpl() override; 103 void recordInliningWithCalleeDeletedImpl() override; 104 void recordUnsuccessfulInliningImpl(const InlineResult &Result) override; 105 void recordUnattemptedInliningImpl() override; 106 107 Function *getCaller() const { return Caller; } 108 Function *getCallee() const { return Callee; } 109 110 const int64_t CallerIRSize; 111 const int64_t CalleeIRSize; 112 const int64_t CallerAndCalleeEdges; 113 void updateCachedCallerFPI(FunctionAnalysisManager &FAM) const; 114 115 private: 116 void reportContextForRemark(DiagnosticInfoOptimizationBase &OR); 117 MLInlineAdvisor *getAdvisor() const { 118 return static_cast<MLInlineAdvisor *>(Advisor); 119 }; 120 // Make a copy of the FPI of the caller right before inlining. If inlining 121 // fails, we can just update the cache with that value. 122 const FunctionPropertiesInfo PreInlineCallerFPI; 123 std::optional<FunctionPropertiesUpdater> FPU; 124 }; 125 126 } // namespace llvm 127 128 #endif // LLVM_ANALYSIS_MLINLINEADVISOR_H 129