1 #pragma once 2 3 #include <set> 4 #include "onnx/optimizer/pass.h" 5 6 namespace ONNX_NAMESPACE { 7 namespace optimization { 8 9 // Lift lexically-scoped references within control operators to be inputs of the 10 // ops themselves. This transformation yields a graph that does not conform to 11 // the ONNX spec. 12 // 13 // The purpose of this pass is to expose the data dependencies within control 14 // blocks for frameworks that use those dependencies to schedule parallel 15 // execution. e.g. caffe2 graph execution. 16 // 17 // Example: 18 // ******************************** Before ************************************* 19 // graph test (%X[FLOAT, 5]) { 20 // %Y = Identity(%X) 21 // %trip_count = Constant[value = <Scalar Tensor [10]>]() 22 // %condition = Constant[value = <Scalar Tensor [1]>]() 23 // %Y2, %Y3 = Loop[body = <graph body_graph>](%trip_count, %condition, %) 24 // return %Y, %Y2 25 // } 26 // 27 // graph body_graph (%i[INT32, scalar], %cond[BOOL, scalar]) { 28 // %_Y2 = Identity(%X) 29 // %_Y3 = Identity(%Y) 30 // return %cond, %_Y2, %_Y3 31 // } 32 // 33 // ******************************** After ************************************** 34 // graph test (%X[FLOAT, 5]) { 35 // %Y = Identity(%X) 36 // %trip_count = Constant[value = <Scalar Tensor [10]>]() 37 // %condition = Constant[value = <Scalar Tensor [1]>]() 38 // %Y2, %Y3 = Loop[__control_inputs = ['X', 'Y'], body = <graph 39 // body_graph>](%trip_count, %condition, %) 40 // ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ 41 // return %Y, %Y2 42 // } 43 // 44 // graph body_graph (%i[INT32, scalar], %cond[BOOL, scalar]) { 45 // %_Y2 = Identity(%X) 46 // %_Y3 = Identity(%Y) 47 // return %cond, %_Y2, %_Y3 48 // } 49 // 50 // ******************************** Continue Docs******************************* 51 // 52 // The algorithm is roughly: 53 // symbol_table_stack = empty stack of symbol tables 54 // 55 // liftreferences(graph) 56 // -> a set of unresolved reference strings: 57 // unresolved_references = {} 58 // 59 // symbol_table_stack.push(new symbol table containing inputs for this 60 // sub-graph) for each node in the graph: 61 // for input in node.inputs: 62 // if input is not in this frame: 63 // unresolved_references.insert(input) 64 // if node is a control flow operator: 65 // for each sub-graph g: 66 // for each output in g's body: 67 // if output is defined in current scope: 68 // control_inputs.insert(output) 69 // refs = liftreferences(g) 70 // for each ref in refs: 71 // if ref is in this frame or any parent frame (control_inputs): 72 // control_inputs.insert(ref) 73 // else: 74 // unresolved_references.insert(ref) 75 // set the control inputs attribute to the node 76 // for output in node.outputs: 77 // symbol_table_stack.top()[output] = Value* 78 // return unresolved_references 79 struct LiftLexicalReferences : public FullGraphBasedPass { LiftLexicalReferencesLiftLexicalReferences80 explicit LiftLexicalReferences() 81 : FullGraphBasedPass( 82 PassType::Seperate, 83 PassEfficiency::Complete, 84 PassOptimizationType::Memory) {} 85 getPassNameLiftLexicalReferences86 std::string getPassName() const override { 87 return "lift_lexical_references"; 88 } getPassAnalysisTypeLiftLexicalReferences89 PassAnalysisType getPassAnalysisType() const override { 90 return PassAnalysisType::Empty; 91 } 92 93 using ValueTable = std::unordered_map<std::string, Value*>; 94 95 // Environment stack, please to store value table and 96 // controlled inputs 97 struct Environment { nextLiftLexicalReferences::Environment98 Environment(std::shared_ptr<Environment> next = nullptr) : next(next) {} 99 100 std::shared_ptr<Environment> next; 101 findInThisFrameLiftLexicalReferences::Environment102 Value* findInThisFrame(const std::string& name) { 103 auto it = value_table.find(name); 104 if (it != value_table.end()) { 105 return it->second; 106 } 107 return nullptr; 108 } 109 findInParentFrameLiftLexicalReferences::Environment110 Value* findInParentFrame(const std::string& name) { 111 return next ? next->findInAnyFrame(name) : nullptr; 112 } 113 findInAnyFrameLiftLexicalReferences::Environment114 Value* findInAnyFrame(const std::string& name) { 115 for (auto runner = this; runner; runner = runner->next.get()) { 116 if (auto r = runner->findInThisFrame(name)) { 117 return r; 118 } 119 } 120 return nullptr; 121 } 122 setVarLiftLexicalReferences::Environment123 void setVar(const std::string& name, Value* value) { 124 value_table[name] = value; 125 } 126 127 private: 128 ValueTable value_table; 129 }; 130 131 std::shared_ptr<Environment> environment_stack; 132 133 // environment stack helper pushFrameLiftLexicalReferences134 void pushFrame() { 135 environment_stack = std::make_shared<Environment>(environment_stack); 136 } 137 popFrameLiftLexicalReferences138 std::shared_ptr<Environment> popFrame() { 139 auto old_frame = environment_stack; 140 environment_stack = environment_stack->next; 141 return old_frame; 142 } 143 liftReferencesLiftLexicalReferences144 std::set<std::string> liftReferences(Graph* g) { 145 std::set<std::string> unresolved_references; 146 pushFrame(); 147 for (auto& inp : g->inputs()) { 148 environment_stack->setVar(inp->uniqueName(), inp); 149 } 150 151 for (auto* n : g->nodes()) { 152 // Skip optional input/captured value node. 153 if (n->kind() == ONNX_NAMESPACE::kUndefined || 154 n->kind() == ONNX_NAMESPACE::kCaptured) { 155 continue; 156 } 157 for (auto* inp : n->inputs()) { 158 // Empty string is 0-input variadic argument. Skip that one. 159 if (!inp->uniqueName().empty() && 160 !environment_stack->findInThisFrame(inp->uniqueName())) { 161 unresolved_references.insert(inp->uniqueName()); 162 } 163 } 164 165 std::set<std::string> local_unresolved; 166 167 // if a graph body output has already already been emitted outside of the 168 // subgraph scope, then it must be added as an input to the subgraph 169 auto add_subgraph_outputs = [&](Graph* body_graph) { 170 for (auto* out : body_graph->outputs()) { 171 if (environment_stack->findInAnyFrame(out->uniqueName())) { 172 local_unresolved.insert(out->uniqueName()); 173 } 174 } 175 }; 176 177 if (n->kind() == ONNX_NAMESPACE::kLoop) { 178 auto* body_graph = n->g(ONNX_NAMESPACE::kbody).get(); 179 local_unresolved = liftReferences(body_graph); 180 add_subgraph_outputs(body_graph); 181 } else if (n->kind() == ONNX_NAMESPACE::kIf) { 182 auto* then_graph = n->g(ONNX_NAMESPACE::kthen_branch).get(); 183 add_subgraph_outputs(then_graph); 184 auto then_unresolved = liftReferences(then_graph); 185 local_unresolved.insert(then_unresolved.begin(), then_unresolved.end()); 186 auto* else_graph = n->g(ONNX_NAMESPACE::kelse_branch).get(); 187 add_subgraph_outputs(else_graph); 188 auto else_unresolved = liftReferences(else_graph); 189 local_unresolved.insert(else_unresolved.begin(), else_unresolved.end()); 190 } 191 192 std::vector<std::string> control_inputs; 193 for (auto& unresolved : local_unresolved) { 194 if (environment_stack->findInAnyFrame(unresolved)) { 195 control_inputs.push_back(unresolved); 196 } else { 197 unresolved_references.insert(unresolved); 198 } 199 } 200 201 // Create this attribute so the backend knows how many of these inputs 202 // are simply there for control dependencies 203 if (!control_inputs.empty()) { 204 n->ss_(ONNX_NAMESPACE::k__control_inputs, std::move(control_inputs)); 205 } 206 207 for (auto* out : n->outputs()) { 208 environment_stack->setVar(out->uniqueName(), out); 209 } 210 } 211 212 popFrame(); 213 return unresolved_references; 214 } 215 runPassLiftLexicalReferences216 std::shared_ptr<PostPassAnalysis> runPass(Graph& graph) override { 217 auto unresolved = liftReferences(&graph); 218 219 if (unresolved.size()) { 220 std::string errmsg = "Unresolved value references: "; 221 for (auto& ref : unresolved) { 222 errmsg += ref + ","; 223 } 224 throw std::runtime_error(errmsg); 225 } 226 return std::shared_ptr<PostPassAnalysis>(new PostPassAnalysis()); 227 } 228 }; 229 230 } // namespace optimization 231 } // namespace ONNX_NAMESPACE 232