1 //===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 9 #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H 10 #define LLVM_ANALYSIS_MLINLINEADVISOR_H 11 12 #include "llvm/Analysis/CallGraph.h" 13 #include "llvm/Analysis/InlineAdvisor.h" 14 #include "llvm/Analysis/MLModelRunner.h" 15 #include "llvm/IR/PassManager.h" 16 17 #include <memory> 18 #include <unordered_map> 19 20 namespace llvm { 21 class Module; 22 class MLInlineAdvice; 23 24 class MLInlineAdvisor : public InlineAdvisor { 25 public: 26 MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM, 27 std::unique_ptr<MLModelRunner> ModelRunner); 28 callGraph()29 CallGraph *callGraph() const { return CG.get(); } 30 virtual ~MLInlineAdvisor() = default; 31 32 void onPassEntry() override; 33 getIRSize(const Function & F)34 int64_t getIRSize(const Function &F) const { return F.getInstructionCount(); } 35 void onSuccessfulInlining(const MLInlineAdvice &Advice, 36 bool CalleeWasDeleted); 37 isForcedToStop()38 bool isForcedToStop() const { return ForceStop; } 39 int64_t getLocalCalls(Function &F); getModelRunner()40 const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); } 41 42 protected: 43 std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override; 44 45 std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB, 46 bool Advice) override; 47 48 virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB); 49 50 virtual std::unique_ptr<MLInlineAdvice> 51 getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE); 52 53 std::unique_ptr<MLModelRunner> ModelRunner; 54 55 private: 56 int64_t getModuleIRSize() const; 57 58 std::unique_ptr<CallGraph> CG; 59 60 int64_t NodeCount = 0; 61 int64_t EdgeCount = 0; 62 std::map<const Function *, unsigned> FunctionLevels; 63 const int32_t InitialIRSize = 0; 64 int32_t CurrentIRSize = 0; 65 66 bool ForceStop = false; 67 }; 68 69 /// InlineAdvice that tracks changes post inlining. For that reason, it only 70 /// overrides the "successful inlining" extension points. 71 class MLInlineAdvice : public InlineAdvice { 72 public: MLInlineAdvice(MLInlineAdvisor * Advisor,CallBase & CB,OptimizationRemarkEmitter & ORE,bool Recommendation)73 MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB, 74 OptimizationRemarkEmitter &ORE, bool Recommendation) 75 : InlineAdvice(Advisor, CB, ORE, Recommendation), 76 CallerIRSize(Advisor->isForcedToStop() ? 0 77 : Advisor->getIRSize(*Caller)), 78 CalleeIRSize(Advisor->isForcedToStop() ? 0 79 : Advisor->getIRSize(*Callee)), 80 CallerAndCalleeEdges(Advisor->isForcedToStop() 81 ? 0 82 : (Advisor->getLocalCalls(*Caller) + 83 Advisor->getLocalCalls(*Callee))) {} 84 virtual ~MLInlineAdvice() = default; 85 86 void recordInliningImpl() override; 87 void recordInliningWithCalleeDeletedImpl() override; 88 void recordUnsuccessfulInliningImpl(const InlineResult &Result) override; 89 void recordUnattemptedInliningImpl() override; 90 getCaller()91 Function *getCaller() const { return Caller; } getCallee()92 Function *getCallee() const { return Callee; } 93 94 const int64_t CallerIRSize; 95 const int64_t CalleeIRSize; 96 const int64_t CallerAndCalleeEdges; 97 98 private: 99 void reportContextForRemark(DiagnosticInfoOptimizationBase &OR); 100 getAdvisor()101 MLInlineAdvisor *getAdvisor() const { 102 return static_cast<MLInlineAdvisor *>(Advisor); 103 }; 104 }; 105 106 } // namespace llvm 107 108 #endif // LLVM_ANALYSIS_MLINLINEADVISOR_H 109