1 // ATTENTION: The code in this file is highly EXPERIMENTAL. 2 // Adventurous users should note that the APIs will probably change. 3 4 #pragma once 5 6 #include <string> 7 #include "onnx/common/ir.h" 8 #include "onnx/onnx_pb.h" 9 10 namespace ONNX_NAMESPACE { 11 namespace optimization { 12 13 // Base struct representing result of a pass. 14 struct PostPassAnalysis { 15 virtual ~PostPassAnalysis() = default; 16 }; 17 18 // Enum that represents the type of optimization it is. 19 enum PassType { 20 // Class of optimizations that fuses operations. 21 Fuse = 0, 22 // Class of optimizations that removes useless operations. 23 Nop = 1, 24 // Class of optimizations that includes some form of seperation. 25 Seperate = 2, 26 // Immutable pass, also sometimes referred to as an analysis pass. 27 Immutable = 3, 28 // Other type of pass. 29 Other = 4 30 }; 31 32 // Enum that represents the return type of the analysis. 33 enum PassAnalysisType { 34 // An empty analysis is returned. Most likely will return PostPassAnalysis. 35 Empty = 0, 36 // A count based analysis is returned. Most likely of type 37 // CountBasedPassAnalysis 38 CountBased = 1 39 }; 40 41 enum PassEfficiency { 42 // A partially efficient optimization pass cannot guarantee that running two 43 // consecutive passes 44 // will return the same result as running a single pass. 45 Partial = 0, 46 // A completely efficient optimization guarantees that running two consecutive 47 // passes is equivalent 48 // to running a single pass. 49 Complete = 1 50 }; 51 52 // Describes what the optimization pass is attempting to optimize. 53 enum PassOptimizationType { 54 // Is not optimizing anything. Most likely will be used in an immutable pass. 55 None = 0, 56 // Optimizes for compute. 57 Compute = 1, 58 // Optimizes for memory. 59 Memory = 2, 60 // Optimizes for both compute and memory. 61 ComputeMemory = 3, 62 // Optimizes for stability (e.g. log-sum-exp trick). 63 Stability = 4 64 }; 65 66 enum NodeDestroyType { 67 // Does not destroy node 68 DestroyZero = 0, 69 // Equivalent to calling it.destroyCurrent() once. 70 DestroyOne = 1, 71 // Equivalent to calling it.destroyCurrent() twice. 72 DestroyTwo = 2 73 }; 74 75 // Base class for all optimizations within ONNX. A pass must contain the 76 // annotations described above. Furthermore each pass is given the ability to 77 // initialize and finalize it's pass. Each pass must have a unique name that 78 // pass managers/registry will use as identification. Finally the pass 79 // implements runPass which completes the pass inplace. 80 class Pass { 81 PassType pass_type; 82 PassEfficiency pass_efficiency; 83 PassOptimizationType pass_optimization_type; 84 85 public: 86 Pass( 87 PassType pass_type, 88 PassEfficiency pass_efficiency, 89 PassOptimizationType pass_optimization_type); 90 virtual ~Pass(); 91 getPassType()92 PassType getPassType() const { 93 return this->pass_type; 94 } getPassEfficiency()95 PassEfficiency getPassEfficiency() const { 96 return this->pass_efficiency; 97 } getPassOptimizationType()98 PassOptimizationType getPassOptimizationType() const { 99 return this->pass_optimization_type; 100 } 101 virtual PassAnalysisType getPassAnalysisType() const = 0; 102 virtual std::string getPassName() const = 0; 103 initializePass(Graph &)104 virtual bool initializePass(Graph&) { 105 return false; 106 } finalizePass(Graph &)107 virtual bool finalizePass(Graph&) { 108 return false; 109 } 110 virtual std::shared_ptr<PostPassAnalysis> runPass(Graph& graph) = 0; 111 112 protected: 113 // Iterates through the elements in the graph and counts the number of times 114 // the transform is succesfully run. 115 unsigned int DescendOnGraphAttributesAndCount( 116 Node* n, 117 std::function<unsigned int(Graph&)> fn); 118 // A more general version of the function above that doesn't constrain the 119 // return type of fn. 120 void DescendOnGraphAttributesUnconstrained( 121 Node* n, 122 std::function<void(Graph&)> fn); 123 }; 124 125 class ImmutablePass : Pass { 126 public: ImmutablePass()127 explicit ImmutablePass() 128 : Pass( 129 PassType::Immutable, 130 PassEfficiency::Complete, 131 PassOptimizationType::None) {} 132 ~ImmutablePass() override; 133 }; 134 135 // Pass Analysis done after a predicate based pass. 136 struct CountBasedPassAnalysis : PostPassAnalysis { 137 // Have to use raw pointer here. The idea is that the pass will pass <this> as 138 // a parameter to the constructor. We could use std::enable_shared_from_this 139 // but this complicates the memory model. Also since all passes come from 140 // GlobalPassRegistry which already utilizes smart pointers we don't have to 141 // worry about memory leaks from passes. 142 Pass* pass; 143 unsigned int num_positive_transforms; 144 bool initialization_done; 145 bool finalization_done; 146 147 public: 148 explicit CountBasedPassAnalysis( 149 Pass* pass, 150 unsigned int num_positive_transforms, 151 bool initialization_done, 152 bool finalization_done); 153 graphChangedCountBasedPassAnalysis154 bool graphChanged() { 155 return this->num_positive_transforms > 0; 156 } numSucceededTransformsCountBasedPassAnalysis157 bool numSucceededTransforms() { 158 return this->num_positive_transforms; 159 } 160 161 // Whether or not a repeated application of the pass might be useful. fixedPointOptimizationNeededCountBasedPassAnalysis162 bool fixedPointOptimizationNeeded() { 163 return this->graphChanged() && 164 pass->getPassEfficiency() == PassEfficiency::Partial; 165 } 166 }; 167 168 // A pass that is based on pattern matching. The majority of passes will 169 // implement this pass. In order for the pass to work the patternMatchPredicate 170 // function must be implemented witch matches a subgraph to the respective 171 // optimization pass. Lastly the runTransform method must also be implemented 172 // which simply implements the pass on any node which passes 173 // patternMatchPredicate. 174 class PredicateBasedPass : public Pass { 175 public: PredicateBasedPass(PassType pass_type,PassEfficiency pass_efficiency,PassOptimizationType pass_optimization_type)176 explicit PredicateBasedPass( 177 PassType pass_type, 178 PassEfficiency pass_efficiency, 179 PassOptimizationType pass_optimization_type) 180 : Pass(pass_type, pass_efficiency, pass_optimization_type) {} 181 ~PredicateBasedPass() override; 182 183 virtual bool patternMatchPredicate(Node* node) = 0; 184 // Run transform is given the current node in the iterator, a reference to the 185 // current graph as well as a reference describing how to treat the current 186 // node in the iterator post transform. Run transform is then responsible for 187 // running the actual transform as well as describing how to treat the 188 // iterator node. By default the current node will not call destroy. Do not 189 // internally delete node instead set the correct destroy_current type. 190 virtual bool 191 runTransform(Node* node, Graph& graph, NodeDestroyType& destroy_current) = 0; 192 193 std::shared_ptr<PostPassAnalysis> runPass(Graph& graph) override; 194 PassAnalysisType getPassAnalysisType() const override; 195 196 private: 197 unsigned int _runPassInternal(Graph& graph); 198 }; 199 200 // The most general pass which allows the user to run a pass given only a graph. 201 class FullGraphBasedPass : public Pass { 202 public: FullGraphBasedPass(PassType pass_type,PassEfficiency pass_efficiency,PassOptimizationType pass_optimization_type)203 explicit FullGraphBasedPass( 204 PassType pass_type, 205 PassEfficiency pass_efficiency, 206 PassOptimizationType pass_optimization_type) 207 : Pass(pass_type, pass_efficiency, pass_optimization_type) {} 208 ~FullGraphBasedPass() override; 209 }; 210 211 } // namespace optimization 212 } // namespace ONNX_NAMESPACE 213