1 #include "onnx/optimizer/pass.h"
2 #include "onnx/common/assertions.h"
3 
4 namespace ONNX_NAMESPACE {
5 namespace optimization {
6 
Pass(PassType pass_type,PassEfficiency pass_efficiency,PassOptimizationType pass_optimization_type)7 Pass::Pass(
8     PassType pass_type,
9     PassEfficiency pass_efficiency,
10     PassOptimizationType pass_optimization_type) {
11   this->pass_type = pass_type;
12   this->pass_efficiency = pass_efficiency;
13   this->pass_optimization_type = pass_optimization_type;
14 }
15 
~Pass()16 Pass::~Pass() {}
17 
DescendOnGraphAttributesAndCount(Node * n,std::function<unsigned int (Graph &)> fn)18 unsigned int Pass::DescendOnGraphAttributesAndCount(
19     Node* n,
20     std::function<unsigned int(Graph&)> fn) {
21   unsigned int num_changes = 0;
22   for (auto name : n->attributeNames()) {
23     auto kind = n->kindOf(name);
24     if (kind == AttributeKind::g) {
25       num_changes += fn(*n->g(name));
26     }
27     if (kind == AttributeKind::gs) {
28       for (auto& g : n->gs(name)) {
29         num_changes += fn(*g);
30       }
31     }
32   }
33   return num_changes;
34 }
35 
DescendOnGraphAttributesUnconstrained(Node * n,std::function<void (Graph &)> fn)36 void Pass::DescendOnGraphAttributesUnconstrained(
37     Node* n,
38     std::function<void(Graph&)> fn) {
39   for (auto name : n->attributeNames()) {
40     auto kind = n->kindOf(name);
41     if (kind == AttributeKind::g) {
42       fn(*n->g(name));
43     }
44     if (kind == AttributeKind::gs) {
45       for (auto& g : n->gs(name)) {
46         fn(*g);
47       }
48     }
49   }
50 }
51 
~PredicateBasedPass()52 PredicateBasedPass::~PredicateBasedPass() {}
53 
_runPassInternal(Graph & graph)54 unsigned int PredicateBasedPass::_runPassInternal(Graph& graph) {
55   unsigned int num_changes = false;
56   for (auto it = graph.begin(); it != graph.end(); ++it) {
57     auto* n = *it;
58     num_changes += this->DescendOnGraphAttributesAndCount(
59         n, [this](Graph& g) { return _runPassInternal(g); });
60     if (this->patternMatchPredicate(n)) {
61       NodeDestroyType destroy_type = NodeDestroyType::DestroyZero;
62       num_changes += this->runTransform(n, graph, destroy_type);
63 
64       if (destroy_type == NodeDestroyType::DestroyOne) {
65         it.destroyCurrent();
66       }
67       if (destroy_type == NodeDestroyType::DestroyTwo) {
68         it.destroyCurrent();
69         it.destroyCurrent();
70       }
71     }
72   }
73   return num_changes;
74 }
75 
getPassAnalysisType() const76 PassAnalysisType PredicateBasedPass::getPassAnalysisType() const {
77   return PassAnalysisType::CountBased;
78 }
79 
runPass(Graph & graph)80 std::shared_ptr<PostPassAnalysis> PredicateBasedPass::runPass(Graph& graph) {
81   bool initialized_pass = this->initializePass(graph);
82   unsigned int touched_optimizations = this->_runPassInternal(graph);
83   bool finalized_pass = this->finalizePass(graph);
84 
85   return std::shared_ptr<PostPassAnalysis>(new CountBasedPassAnalysis(
86       this, touched_optimizations, initialized_pass, finalized_pass));
87 }
88 
CountBasedPassAnalysis(Pass * pass,unsigned int num_positive_transforms,bool initialization_done,bool finalization_done)89 CountBasedPassAnalysis::CountBasedPassAnalysis(
90     Pass* pass,
91     unsigned int num_positive_transforms,
92     bool initialization_done,
93     bool finalization_done) {
94   this->pass = pass;
95   this->num_positive_transforms = num_positive_transforms;
96   this->initialization_done = initialization_done;
97   this->finalization_done = finalization_done;
98 }
99 
~FullGraphBasedPass()100 FullGraphBasedPass::~FullGraphBasedPass() {}
101 
102 } // namespace optimization
103 } // namespace ONNX_NAMESPACE
104