1 #include "onnx/optimizer/pass.h"
2 #include "onnx/common/assertions.h"
3
4 namespace ONNX_NAMESPACE {
5 namespace optimization {
6
Pass(PassType pass_type,PassEfficiency pass_efficiency,PassOptimizationType pass_optimization_type)7 Pass::Pass(
8 PassType pass_type,
9 PassEfficiency pass_efficiency,
10 PassOptimizationType pass_optimization_type) {
11 this->pass_type = pass_type;
12 this->pass_efficiency = pass_efficiency;
13 this->pass_optimization_type = pass_optimization_type;
14 }
15
~Pass()16 Pass::~Pass() {}
17
DescendOnGraphAttributesAndCount(Node * n,std::function<unsigned int (Graph &)> fn)18 unsigned int Pass::DescendOnGraphAttributesAndCount(
19 Node* n,
20 std::function<unsigned int(Graph&)> fn) {
21 unsigned int num_changes = 0;
22 for (auto name : n->attributeNames()) {
23 auto kind = n->kindOf(name);
24 if (kind == AttributeKind::g) {
25 num_changes += fn(*n->g(name));
26 }
27 if (kind == AttributeKind::gs) {
28 for (auto& g : n->gs(name)) {
29 num_changes += fn(*g);
30 }
31 }
32 }
33 return num_changes;
34 }
35
DescendOnGraphAttributesUnconstrained(Node * n,std::function<void (Graph &)> fn)36 void Pass::DescendOnGraphAttributesUnconstrained(
37 Node* n,
38 std::function<void(Graph&)> fn) {
39 for (auto name : n->attributeNames()) {
40 auto kind = n->kindOf(name);
41 if (kind == AttributeKind::g) {
42 fn(*n->g(name));
43 }
44 if (kind == AttributeKind::gs) {
45 for (auto& g : n->gs(name)) {
46 fn(*g);
47 }
48 }
49 }
50 }
51
~PredicateBasedPass()52 PredicateBasedPass::~PredicateBasedPass() {}
53
_runPassInternal(Graph & graph)54 unsigned int PredicateBasedPass::_runPassInternal(Graph& graph) {
55 unsigned int num_changes = false;
56 for (auto it = graph.begin(); it != graph.end(); ++it) {
57 auto* n = *it;
58 num_changes += this->DescendOnGraphAttributesAndCount(
59 n, [this](Graph& g) { return _runPassInternal(g); });
60 if (this->patternMatchPredicate(n)) {
61 NodeDestroyType destroy_type = NodeDestroyType::DestroyZero;
62 num_changes += this->runTransform(n, graph, destroy_type);
63
64 if (destroy_type == NodeDestroyType::DestroyOne) {
65 it.destroyCurrent();
66 }
67 if (destroy_type == NodeDestroyType::DestroyTwo) {
68 it.destroyCurrent();
69 it.destroyCurrent();
70 }
71 }
72 }
73 return num_changes;
74 }
75
getPassAnalysisType() const76 PassAnalysisType PredicateBasedPass::getPassAnalysisType() const {
77 return PassAnalysisType::CountBased;
78 }
79
runPass(Graph & graph)80 std::shared_ptr<PostPassAnalysis> PredicateBasedPass::runPass(Graph& graph) {
81 bool initialized_pass = this->initializePass(graph);
82 unsigned int touched_optimizations = this->_runPassInternal(graph);
83 bool finalized_pass = this->finalizePass(graph);
84
85 return std::shared_ptr<PostPassAnalysis>(new CountBasedPassAnalysis(
86 this, touched_optimizations, initialized_pass, finalized_pass));
87 }
88
CountBasedPassAnalysis(Pass * pass,unsigned int num_positive_transforms,bool initialization_done,bool finalization_done)89 CountBasedPassAnalysis::CountBasedPassAnalysis(
90 Pass* pass,
91 unsigned int num_positive_transforms,
92 bool initialization_done,
93 bool finalization_done) {
94 this->pass = pass;
95 this->num_positive_transforms = num_positive_transforms;
96 this->initialization_done = initialization_done;
97 this->finalization_done = finalization_done;
98 }
99
~FullGraphBasedPass()100 FullGraphBasedPass::~FullGraphBasedPass() {}
101
102 } // namespace optimization
103 } // namespace ONNX_NAMESPACE
104