1 // ATTENTION: The code in this file is highly EXPERIMENTAL.
2 // Adventurous users should note that the APIs will probably change.
3 
4 #pragma once
5 
6 #include <string>
7 #include "onnx/common/ir.h"
8 #include "onnx/onnx_pb.h"
9 
10 namespace ONNX_NAMESPACE {
11 namespace optimization {
12 
// Polymorphic root type for the result object returned by a pass. It holds
// no data of its own; concrete analyses derive from it and add their fields.
struct PostPassAnalysis {
  // Virtual so that a derived analysis is destroyed correctly when it is
  // released through a pointer to this base.
  virtual ~PostPassAnalysis() = default;
};
17 
// Broad category that an optimization pass belongs to.
enum PassType {
  // Passes that fuse operations together.
  Fuse = 0,
  // Passes that remove useless operations.
  Nop = 1,
  // Passes that perform some form of separation. The spelling of the
  // enumerator is kept unchanged because it is part of the public interface.
  Seperate = 2,
  // Immutable pass, sometimes also called an analysis pass.
  Immutable = 3,
  // Any other kind of pass.
  Other = 4,
};
31 
// Identifies which kind of analysis object a pass returns.
enum PassAnalysisType {
  // An empty analysis; most likely a plain PostPassAnalysis is returned.
  Empty = 0,
  // A count based analysis; most likely a CountBasedPassAnalysis is returned.
  CountBased = 1,
};
40 
// States whether a single run of a pass is enough to reach its final result.
enum PassEfficiency {
  // A partially efficient pass gives no guarantee that running it twice in a
  // row yields the same result as running it once.
  Partial = 0,
  // A completely efficient pass guarantees that two consecutive runs are
  // equivalent to a single run.
  Complete = 1,
};
51 
// The resource or property that an optimization pass tries to improve.
enum PassOptimizationType {
  // Nothing is optimized; most likely used by an immutable pass.
  None = 0,
  // The pass optimizes for compute.
  Compute = 1,
  // The pass optimizes for memory.
  Memory = 2,
  // The pass optimizes for both compute and memory.
  ComputeMemory = 3,
  // The pass optimizes for stability (e.g. the log-sum-exp trick).
  Stability = 4,
};
65 
// Describes how the iterator's current node is to be treated after a
// transform has run on it.
enum NodeDestroyType {
  // The node is not destroyed.
  DestroyZero = 0,
  // Same effect as a single call to it.destroyCurrent().
  DestroyOne = 1,
  // Same effect as two calls to it.destroyCurrent().
  DestroyTwo = 2,
};
74 
75 // Base class for all optimizations within ONNX. A pass must contain the
76 // annotations described above. Furthermore each pass is given the ability to
77 // initialize and finalize it's pass. Each pass must have a unique name that
78 // pass managers/registry will use as identification. Finally the pass
79 // implements runPass which completes the pass inplace.
80 class Pass {
81   PassType pass_type;
82   PassEfficiency pass_efficiency;
83   PassOptimizationType pass_optimization_type;
84 
85  public:
86   Pass(
87       PassType pass_type,
88       PassEfficiency pass_efficiency,
89       PassOptimizationType pass_optimization_type);
90   virtual ~Pass();
91 
getPassType()92   PassType getPassType() const {
93     return this->pass_type;
94   }
getPassEfficiency()95   PassEfficiency getPassEfficiency() const {
96     return this->pass_efficiency;
97   }
getPassOptimizationType()98   PassOptimizationType getPassOptimizationType() const {
99     return this->pass_optimization_type;
100   }
101   virtual PassAnalysisType getPassAnalysisType() const = 0;
102   virtual std::string getPassName() const = 0;
103 
initializePass(Graph &)104   virtual bool initializePass(Graph&) {
105     return false;
106   }
finalizePass(Graph &)107   virtual bool finalizePass(Graph&) {
108     return false;
109   }
110   virtual std::shared_ptr<PostPassAnalysis> runPass(Graph& graph) = 0;
111 
112  protected:
113   // Iterates through the elements in the graph and counts the number of times
114   // the transform is succesfully run.
115   unsigned int DescendOnGraphAttributesAndCount(
116       Node* n,
117       std::function<unsigned int(Graph&)> fn);
118   // A more general version of the function above that doesn't constrain the
119   // return type of fn.
120   void DescendOnGraphAttributesUnconstrained(
121       Node* n,
122       std::function<void(Graph&)> fn);
123 };
124 
125 class ImmutablePass : Pass {
126  public:
ImmutablePass()127   explicit ImmutablePass()
128       : Pass(
129             PassType::Immutable,
130             PassEfficiency::Complete,
131             PassOptimizationType::None) {}
132   ~ImmutablePass() override;
133 };
134 
135 // Pass Analysis done after a predicate based pass.
136 struct CountBasedPassAnalysis : PostPassAnalysis {
137   // Have to use raw pointer here. The idea is that the pass will pass <this> as
138   // a parameter to the constructor. We could use std::enable_shared_from_this
139   // but this complicates the memory model. Also since all passes come from
140   // GlobalPassRegistry which already utilizes smart pointers we don't have to
141   // worry about memory leaks from passes.
142   Pass* pass;
143   unsigned int num_positive_transforms;
144   bool initialization_done;
145   bool finalization_done;
146 
147  public:
148   explicit CountBasedPassAnalysis(
149       Pass* pass,
150       unsigned int num_positive_transforms,
151       bool initialization_done,
152       bool finalization_done);
153 
graphChangedCountBasedPassAnalysis154   bool graphChanged() {
155     return this->num_positive_transforms > 0;
156   }
numSucceededTransformsCountBasedPassAnalysis157   bool numSucceededTransforms() {
158     return this->num_positive_transforms;
159   }
160 
161   // Whether or not a repeated application of the pass might be useful.
fixedPointOptimizationNeededCountBasedPassAnalysis162   bool fixedPointOptimizationNeeded() {
163     return this->graphChanged() &&
164         pass->getPassEfficiency() == PassEfficiency::Partial;
165   }
166 };
167 
168 // A pass that is based on pattern matching. The majority of passes will
169 // implement this pass. In order for the pass to work the patternMatchPredicate
170 // function must be implemented witch matches a subgraph to the respective
171 // optimization pass. Lastly the runTransform method must also be implemented
172 // which simply implements the pass on any node which passes
173 // patternMatchPredicate.
174 class PredicateBasedPass : public Pass {
175  public:
PredicateBasedPass(PassType pass_type,PassEfficiency pass_efficiency,PassOptimizationType pass_optimization_type)176   explicit PredicateBasedPass(
177       PassType pass_type,
178       PassEfficiency pass_efficiency,
179       PassOptimizationType pass_optimization_type)
180       : Pass(pass_type, pass_efficiency, pass_optimization_type) {}
181   ~PredicateBasedPass() override;
182 
183   virtual bool patternMatchPredicate(Node* node) = 0;
184   // Run transform is given the current node in the iterator, a reference to the
185   // current graph as well as a reference describing how to treat the current
186   // node in the iterator post transform. Run transform is then responsible for
187   // running the actual transform as well as describing how to treat the
188   // iterator node. By default the current node will not call destroy. Do not
189   // internally delete node instead set the correct destroy_current type.
190   virtual bool
191   runTransform(Node* node, Graph& graph, NodeDestroyType& destroy_current) = 0;
192 
193   std::shared_ptr<PostPassAnalysis> runPass(Graph& graph) override;
194   PassAnalysisType getPassAnalysisType() const override;
195 
196  private:
197   unsigned int _runPassInternal(Graph& graph);
198 };
199 
200 // The most general pass which allows the user to run a pass given only a graph.
201 class FullGraphBasedPass : public Pass {
202  public:
FullGraphBasedPass(PassType pass_type,PassEfficiency pass_efficiency,PassOptimizationType pass_optimization_type)203   explicit FullGraphBasedPass(
204       PassType pass_type,
205       PassEfficiency pass_efficiency,
206       PassOptimizationType pass_optimization_type)
207       : Pass(pass_type, pass_efficiency, pass_optimization_type) {}
208   ~FullGraphBasedPass() override;
209 };
210 
211 } // namespace optimization
212 } // namespace ONNX_NAMESPACE
213