1 #pragma once
2 
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "onnx/optimizer/pass.h"
5 
6 namespace ONNX_NAMESPACE {
7 namespace optimization {
8 
9 // Lift lexically-scoped references within control operators to be inputs of the
10 // ops themselves. This transformation yields a graph that does not conform to
11 // the ONNX spec.
12 //
13 // The purpose of this pass is to expose the data dependencies within control
14 // blocks for frameworks that use those dependencies to schedule parallel
15 // execution. e.g. caffe2 graph execution.
16 //
17 // Example:
18 // ******************************** Before *************************************
19 // graph test (%X[FLOAT, 5]) {
20 //   %Y = Identity(%X)
21 //   %trip_count = Constant[value = <Scalar Tensor [10]>]()
22 //   %condition = Constant[value = <Scalar Tensor [1]>]()
23 //   %Y2, %Y3 = Loop[body = <graph body_graph>](%trip_count, %condition, %)
24 //   return %Y, %Y2
25 // }
26 //
27 // graph body_graph (%i[INT32, scalar], %cond[BOOL, scalar]) {
28 //   %_Y2 = Identity(%X)
29 //   %_Y3 = Identity(%Y)
30 //   return %cond, %_Y2, %_Y3
31 // }
32 //
33 // ******************************** After **************************************
34 // graph test (%X[FLOAT, 5]) {
35 //   %Y = Identity(%X)
36 //   %trip_count = Constant[value = <Scalar Tensor [10]>]()
37 //   %condition = Constant[value = <Scalar Tensor [1]>]()
//   %Y2, %Y3 = Loop[__control_inputs = ['X', 'Y'],
//                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ (attribute added by this pass)
//                   body = <graph body_graph>](%trip_count, %condition, %)
41 //   return %Y, %Y2
42 // }
43 //
44 // graph body_graph (%i[INT32, scalar], %cond[BOOL, scalar]) {
45 //   %_Y2 = Identity(%X)
46 //   %_Y3 = Identity(%Y)
47 //   return %cond, %_Y2, %_Y3
48 // }
49 //
50 // ******************************** Continue Docs*******************************
51 //
52 // The algorithm is roughly:
53 //  symbol_table_stack = empty stack of symbol tables
54 //
55 //  liftreferences(graph)
56 //      -> a set of unresolved reference strings:
57 //    unresolved_references = {}
58 //
//    symbol_table_stack.push(new symbol table containing the inputs of this
//                            sub-graph)
//    for each node in the graph:
61 //      for input in node.inputs:
62 //        if input is not in this frame:
63 //          unresolved_references.insert(input)
//      if node is a control flow operator:
//        for each sub-graph g:
//          for each output in g's body:
//            if output is defined in this frame or any parent frame:
//              control_inputs.insert(output)
//          refs = liftreferences(g)
//          for each ref in refs:
//            if ref is defined in this frame or any parent frame:
//              control_inputs.insert(ref)
//            else:
//              unresolved_references.insert(ref)
//        if control_inputs is non-empty:
//          set the __control_inputs attribute on the node
//      for output in node.outputs:
//        symbol_table_stack.top()[output] = Value*
//    symbol_table_stack.pop()
//    return unresolved_references
79 struct LiftLexicalReferences : public FullGraphBasedPass {
LiftLexicalReferencesLiftLexicalReferences80   explicit LiftLexicalReferences()
81       : FullGraphBasedPass(
82             PassType::Seperate,
83             PassEfficiency::Complete,
84             PassOptimizationType::Memory) {}
85 
getPassNameLiftLexicalReferences86   std::string getPassName() const override {
87     return "lift_lexical_references";
88   }
getPassAnalysisTypeLiftLexicalReferences89   PassAnalysisType getPassAnalysisType() const override {
90     return PassAnalysisType::Empty;
91   }
92 
93   using ValueTable = std::unordered_map<std::string, Value*>;
94 
95   // Environment stack, please to store value table and
96   // controlled inputs
97   struct Environment {
nextLiftLexicalReferences::Environment98     Environment(std::shared_ptr<Environment> next = nullptr) : next(next) {}
99 
100     std::shared_ptr<Environment> next;
101 
findInThisFrameLiftLexicalReferences::Environment102     Value* findInThisFrame(const std::string& name) {
103       auto it = value_table.find(name);
104       if (it != value_table.end()) {
105         return it->second;
106       }
107       return nullptr;
108     }
109 
findInParentFrameLiftLexicalReferences::Environment110     Value* findInParentFrame(const std::string& name) {
111       return next ? next->findInAnyFrame(name) : nullptr;
112     }
113 
findInAnyFrameLiftLexicalReferences::Environment114     Value* findInAnyFrame(const std::string& name) {
115       for (auto runner = this; runner; runner = runner->next.get()) {
116         if (auto r = runner->findInThisFrame(name)) {
117           return r;
118         }
119       }
120       return nullptr;
121     }
122 
setVarLiftLexicalReferences::Environment123     void setVar(const std::string& name, Value* value) {
124       value_table[name] = value;
125     }
126 
127    private:
128     ValueTable value_table;
129   };
130 
131   std::shared_ptr<Environment> environment_stack;
132 
133   // environment stack helper
pushFrameLiftLexicalReferences134   void pushFrame() {
135     environment_stack = std::make_shared<Environment>(environment_stack);
136   }
137 
popFrameLiftLexicalReferences138   std::shared_ptr<Environment> popFrame() {
139     auto old_frame = environment_stack;
140     environment_stack = environment_stack->next;
141     return old_frame;
142   }
143 
liftReferencesLiftLexicalReferences144   std::set<std::string> liftReferences(Graph* g) {
145     std::set<std::string> unresolved_references;
146     pushFrame();
147     for (auto& inp : g->inputs()) {
148       environment_stack->setVar(inp->uniqueName(), inp);
149     }
150 
151     for (auto* n : g->nodes()) {
152       // Skip optional input/captured value node.
153       if (n->kind() == ONNX_NAMESPACE::kUndefined ||
154           n->kind() == ONNX_NAMESPACE::kCaptured) {
155         continue;
156       }
157       for (auto* inp : n->inputs()) {
158         // Empty string is 0-input variadic argument. Skip that one.
159         if (!inp->uniqueName().empty() &&
160             !environment_stack->findInThisFrame(inp->uniqueName())) {
161           unresolved_references.insert(inp->uniqueName());
162         }
163       }
164 
165       std::set<std::string> local_unresolved;
166 
167       // if a graph body output has already already been emitted outside of the
168       // subgraph scope, then it must be added as an input to the subgraph
169       auto add_subgraph_outputs = [&](Graph* body_graph) {
170         for (auto* out : body_graph->outputs()) {
171           if (environment_stack->findInAnyFrame(out->uniqueName())) {
172             local_unresolved.insert(out->uniqueName());
173           }
174         }
175       };
176 
177       if (n->kind() == ONNX_NAMESPACE::kLoop) {
178         auto* body_graph = n->g(ONNX_NAMESPACE::kbody).get();
179         local_unresolved = liftReferences(body_graph);
180         add_subgraph_outputs(body_graph);
181       } else if (n->kind() == ONNX_NAMESPACE::kIf) {
182         auto* then_graph = n->g(ONNX_NAMESPACE::kthen_branch).get();
183         add_subgraph_outputs(then_graph);
184         auto then_unresolved = liftReferences(then_graph);
185         local_unresolved.insert(then_unresolved.begin(), then_unresolved.end());
186         auto* else_graph = n->g(ONNX_NAMESPACE::kelse_branch).get();
187         add_subgraph_outputs(else_graph);
188         auto else_unresolved = liftReferences(else_graph);
189         local_unresolved.insert(else_unresolved.begin(), else_unresolved.end());
190       }
191 
192       std::vector<std::string> control_inputs;
193       for (auto& unresolved : local_unresolved) {
194         if (environment_stack->findInAnyFrame(unresolved)) {
195           control_inputs.push_back(unresolved);
196         } else {
197           unresolved_references.insert(unresolved);
198         }
199       }
200 
201       // Create this attribute so the backend knows how many of these inputs
202       // are simply there for control dependencies
203       if (!control_inputs.empty()) {
204         n->ss_(ONNX_NAMESPACE::k__control_inputs, std::move(control_inputs));
205       }
206 
207       for (auto* out : n->outputs()) {
208         environment_stack->setVar(out->uniqueName(), out);
209       }
210     }
211 
212     popFrame();
213     return unresolved_references;
214   }
215 
runPassLiftLexicalReferences216   std::shared_ptr<PostPassAnalysis> runPass(Graph& graph) override {
217     auto unresolved = liftReferences(&graph);
218 
219     if (unresolved.size()) {
220       std::string errmsg = "Unresolved value references: ";
221       for (auto& ref : unresolved) {
222         errmsg += ref + ",";
223       }
224       throw std::runtime_error(errmsg);
225     }
226     return std::shared_ptr<PostPassAnalysis>(new PostPassAnalysis());
227   }
228 };
229 
230 } // namespace optimization
231 } // namespace ONNX_NAMESPACE
232