//===- MLInlineAdvisor.h - ML-based InlineAdvisor factories -----*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H
10 #define LLVM_ANALYSIS_MLINLINEADVISOR_H
11 
#include "llvm/Analysis/CallGraph.h"
#include "llvm/Analysis/InlineAdvisor.h"
#include "llvm/Analysis/MLModelRunner.h"
#include "llvm/IR/PassManager.h"

#include <cstdint>
#include <map>
#include <memory>
#include <unordered_map>
19 
20 namespace llvm {
21 class Module;
22 class MLInlineAdvice;
23 
24 class MLInlineAdvisor : public InlineAdvisor {
25 public:
26   MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM,
27                   std::unique_ptr<MLModelRunner> ModelRunner);
28 
callGraph()29   CallGraph *callGraph() const { return CG.get(); }
30   virtual ~MLInlineAdvisor() = default;
31 
32   void onPassEntry() override;
33 
getIRSize(const Function & F)34   int64_t getIRSize(const Function &F) const { return F.getInstructionCount(); }
35   void onSuccessfulInlining(const MLInlineAdvice &Advice,
36                             bool CalleeWasDeleted);
37 
isForcedToStop()38   bool isForcedToStop() const { return ForceStop; }
39   int64_t getLocalCalls(Function &F);
getModelRunner()40   const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); }
41 
42 protected:
43   std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override;
44 
45   std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB,
46                                                    bool Advice) override;
47 
48   virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB);
49 
50   virtual std::unique_ptr<MLInlineAdvice>
51   getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE);
52 
53   std::unique_ptr<MLModelRunner> ModelRunner;
54 
55 private:
56   int64_t getModuleIRSize() const;
57 
58   std::unique_ptr<CallGraph> CG;
59 
60   int64_t NodeCount = 0;
61   int64_t EdgeCount = 0;
62   std::map<const Function *, unsigned> FunctionLevels;
63   const int32_t InitialIRSize = 0;
64   int32_t CurrentIRSize = 0;
65 
66   bool ForceStop = false;
67 };
68 
69 /// InlineAdvice that tracks changes post inlining. For that reason, it only
70 /// overrides the "successful inlining" extension points.
71 class MLInlineAdvice : public InlineAdvice {
72 public:
MLInlineAdvice(MLInlineAdvisor * Advisor,CallBase & CB,OptimizationRemarkEmitter & ORE,bool Recommendation)73   MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
74                  OptimizationRemarkEmitter &ORE, bool Recommendation)
75       : InlineAdvice(Advisor, CB, ORE, Recommendation),
76         CallerIRSize(Advisor->isForcedToStop() ? 0
77                                                : Advisor->getIRSize(*Caller)),
78         CalleeIRSize(Advisor->isForcedToStop() ? 0
79                                                : Advisor->getIRSize(*Callee)),
80         CallerAndCalleeEdges(Advisor->isForcedToStop()
81                                  ? 0
82                                  : (Advisor->getLocalCalls(*Caller) +
83                                     Advisor->getLocalCalls(*Callee))) {}
84   virtual ~MLInlineAdvice() = default;
85 
86   void recordInliningImpl() override;
87   void recordInliningWithCalleeDeletedImpl() override;
88   void recordUnsuccessfulInliningImpl(const InlineResult &Result) override;
89   void recordUnattemptedInliningImpl() override;
90 
getCaller()91   Function *getCaller() const { return Caller; }
getCallee()92   Function *getCallee() const { return Callee; }
93 
94   const int64_t CallerIRSize;
95   const int64_t CalleeIRSize;
96   const int64_t CallerAndCalleeEdges;
97 
98 private:
99   void reportContextForRemark(DiagnosticInfoOptimizationBase &OR);
100 
getAdvisor()101   MLInlineAdvisor *getAdvisor() const {
102     return static_cast<MLInlineAdvisor *>(Advisor);
103   };
104 };
105 
106 } // namespace llvm
107 
108 #endif // LLVM_ANALYSIS_MLINLINEADVISOR_H
109