1 //===- MLInlineAdvisor.h - ML - based InlineAdvisor factories ---*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 
9 #ifndef LLVM_ANALYSIS_MLINLINEADVISOR_H
10 #define LLVM_ANALYSIS_MLINLINEADVISOR_H
11 
12 #include "llvm/Analysis/FunctionPropertiesAnalysis.h"
13 #include "llvm/Analysis/InlineAdvisor.h"
14 #include "llvm/Analysis/LazyCallGraph.h"
15 #include "llvm/Analysis/MLModelRunner.h"
16 #include "llvm/IR/PassManager.h"
17 
18 #include <deque>
19 #include <map>
20 #include <memory>
21 #include <optional>
22 
23 namespace llvm {
24 class DiagnosticInfoOptimizationBase;
25 class Module;
26 class MLInlineAdvice;
27 
28 class MLInlineAdvisor : public InlineAdvisor {
29 public:
30   MLInlineAdvisor(Module &M, ModuleAnalysisManager &MAM,
31                   std::unique_ptr<MLModelRunner> ModelRunner);
32 
33   virtual ~MLInlineAdvisor() = default;
34 
35   void onPassEntry(LazyCallGraph::SCC *SCC) override;
36   void onPassExit(LazyCallGraph::SCC *SCC) override;
37 
38   int64_t getIRSize(Function &F) const {
39     return getCachedFPI(F).TotalInstructionCount;
40   }
41   void onSuccessfulInlining(const MLInlineAdvice &Advice,
42                             bool CalleeWasDeleted);
43 
44   bool isForcedToStop() const { return ForceStop; }
45   int64_t getLocalCalls(Function &F);
46   const MLModelRunner &getModelRunner() const { return *ModelRunner.get(); }
47   FunctionPropertiesInfo &getCachedFPI(Function &) const;
48 
49 protected:
50   std::unique_ptr<InlineAdvice> getAdviceImpl(CallBase &CB) override;
51 
52   std::unique_ptr<InlineAdvice> getMandatoryAdvice(CallBase &CB,
53                                                    bool Advice) override;
54 
55   virtual std::unique_ptr<MLInlineAdvice> getMandatoryAdviceImpl(CallBase &CB);
56 
57   virtual std::unique_ptr<MLInlineAdvice>
58   getAdviceFromModel(CallBase &CB, OptimizationRemarkEmitter &ORE);
59 
60   // Get the initial 'level' of the function, or 0 if the function has been
61   // introduced afterwards.
62   // TODO: should we keep this updated?
63   unsigned getInitialFunctionLevel(const Function &F) const;
64 
65   std::unique_ptr<MLModelRunner> ModelRunner;
66 
67 private:
68   int64_t getModuleIRSize() const;
69   std::unique_ptr<InlineAdvice>
70   getSkipAdviceIfUnreachableCallsite(CallBase &CB);
71   void print(raw_ostream &OS) const override;
72 
73   // Using std::map to benefit from its iterator / reference non-invalidating
74   // semantics, which make it easy to use `getCachedFPI` results from multiple
75   // calls without needing to copy to avoid invalidation effects.
76   mutable std::map<const Function *, FunctionPropertiesInfo> FPICache;
77 
78   LazyCallGraph &CG;
79 
80   int64_t NodeCount = 0;
81   int64_t EdgeCount = 0;
82   int64_t EdgesOfLastSeenNodes = 0;
83 
84   std::map<const LazyCallGraph::Node *, unsigned> FunctionLevels;
85   const int32_t InitialIRSize = 0;
86   int32_t CurrentIRSize = 0;
87   llvm::SmallPtrSet<const LazyCallGraph::Node *, 1> NodesInLastSCC;
88   DenseSet<const LazyCallGraph::Node *> AllNodes;
89   bool ForceStop = false;
90 };
91 
92 /// InlineAdvice that tracks changes post inlining. For that reason, it only
93 /// overrides the "successful inlining" extension points.
94 class MLInlineAdvice : public InlineAdvice {
95 public:
96   MLInlineAdvice(MLInlineAdvisor *Advisor, CallBase &CB,
97                  OptimizationRemarkEmitter &ORE, bool Recommendation);
98   virtual ~MLInlineAdvice() = default;
99 
100   void recordInliningImpl() override;
101   void recordInliningWithCalleeDeletedImpl() override;
102   void recordUnsuccessfulInliningImpl(const InlineResult &Result) override;
103   void recordUnattemptedInliningImpl() override;
104 
105   Function *getCaller() const { return Caller; }
106   Function *getCallee() const { return Callee; }
107 
108   const int64_t CallerIRSize;
109   const int64_t CalleeIRSize;
110   const int64_t CallerAndCalleeEdges;
111   void updateCachedCallerFPI(FunctionAnalysisManager &FAM) const;
112 
113 private:
114   void reportContextForRemark(DiagnosticInfoOptimizationBase &OR);
115   MLInlineAdvisor *getAdvisor() const {
116     return static_cast<MLInlineAdvisor *>(Advisor);
117   };
118   // Make a copy of the FPI of the caller right before inlining. If inlining
119   // fails, we can just update the cache with that value.
120   const FunctionPropertiesInfo PreInlineCallerFPI;
121   std::optional<FunctionPropertiesUpdater> FPU;
122 };
123 
124 } // namespace llvm
125 
126 #endif // LLVM_ANALYSIS_MLINLINEADVISOR_H
127