1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file build_subgraph.cc
22 * \brief
23 */
24 #include <nnvm/graph.h>
25 #include <nnvm/pass.h>
26 #include <unordered_set>
27 #include <stack>
28 #include <queue>
29
30 #include "./subgraph_property.h"
31
32 #define DEBUG_SUBGRAPH 0
33
34 namespace nnvm {
35 ObjectPtr CreateVariableNode(const std::string& name);
36 }
37
38 namespace mxnet {
39 namespace op {
40 namespace sg { // sg stands for subgraph
41
42 #if DEBUG_SUBGRAPH
PrintSubgraph(const std::vector<BiDirectedNode * > & simple_nodes)43 void PrintSubgraph(const std::vector<BiDirectedNode*>& simple_nodes) {
44 std::string op_names = "";
45 for (size_t i = 0; i < simple_nodes.size(); ++i) {
46 op_names += simple_nodes[i]->node->attrs.name + ' ';
47 }
48 LOG(INFO) << "Subgraph node names: " << op_names;
49 }
50
PrintNodeEntry(const nnvm::NodeEntry & entry)51 void PrintNodeEntry(const nnvm::NodeEntry& entry) {
52 std::string ret = "NodeEntry: node_name=" + entry.node->attrs.name
53 + ", index=" + std::to_string(entry.index) + ", version=" + std::to_string(entry.version);
54 LOG(INFO) << ret;
55 }
56
PrintNodeEntries(const std::vector<nnvm::NodeEntry * > & entries)57 void PrintNodeEntries(const std::vector<nnvm::NodeEntry*>& entries) {
58 for (size_t i = 0; i < entries.size(); ++i) {
59 PrintNodeEntry(*entries[i]);
60 }
61 }
62 #endif
63
64 /*!
65 * \brief Given a MXNet computational graph, create an undirected graph from it.
66 * \param g the MXNet computational graph
67 * \param simple_nodes the nodes of undirected graph in top sorted order
68 */
CreateSimpleGraph(const nnvm::Graph & g,std::vector<BiDirectedNodePtr> * simple_nodes)69 void CreateSimpleGraph(const nnvm::Graph& g,
70 std::vector<BiDirectedNodePtr>* simple_nodes) {
71 const auto& indexed_graph = g.indexed_graph();
72 simple_nodes->reserve(indexed_graph.num_nodes());
73 DFSVisit(g.outputs, [&](const nnvm::ObjectPtr& node) {
74 BiDirectedNodePtr sn = BiDirectedNode::Create();
75 sn->node = node.get();
76 for (size_t i = 0; i < sn->node->inputs.size(); ++i) {
77 const auto& e = sn->node->inputs[i];
78 const auto input_nid = indexed_graph.node_id(e.node.get());
79 CHECK_LT(input_nid, simple_nodes->size());
80 auto& input_node_outputs = (*simple_nodes)[input_nid]->outputs;
81 auto it = input_node_outputs.find(sn->node);
82 if (it == input_node_outputs.end()) {
83 input_node_outputs.emplace(sn->node, std::vector<size_t>{i});
84 } else {
85 it->second.push_back(i);
86 }
87 }
88 simple_nodes->emplace_back(std::move(sn));
89 });
90 }
91
92 /*!
93 * \brief Reset labels of the subgraph nodes to the original state
94 * and clear the vector of subgraph nodes.
95 */
ResetNodeLabels(const nnvm::Graph & g,const std::vector<BiDirectedNodePtr> & simple_nodes,std::vector<BiDirectedNode * > * subgraph_nodes)96 void ResetNodeLabels(const nnvm::Graph& g,
97 const std::vector<BiDirectedNodePtr>& simple_nodes,
98 std::vector<BiDirectedNode*>* subgraph_nodes) {
99 for (auto n : *subgraph_nodes) {
100 const auto nid = g.indexed_graph().node_id(n->node);
101 simple_nodes[nid]->label = -1;
102 }
103 subgraph_nodes->clear();
104 }
105
106 /*
107 * \brief Prepare NodeAttr for node. NodeAttr will be used in SubgraphSelectorV2.
108 */
PrepareNodeAttr(const nnvm::Graph & g,const BiDirectedNode & node)109 static const std::shared_ptr<NodeAttr> PrepareNodeAttr(const nnvm::Graph& g,
110 const BiDirectedNode& node) {
111 const auto& indexed_graph = g.indexed_graph();
112 if (g.HasAttr("dtype") && g.HasAttr("shape") && g.HasAttr("dispatch_mode")) {
113 const auto& vdtype = g.GetAttr<nnvm::DTypeVector>("dtype");
114 const auto& vshape = g.GetAttr<mxnet::ShapeVector>("shape");
115 const auto& dispatch_modes = g.GetAttr<mxnet::DispatchModeVector>("dispatch_mode");
116 auto ret = std::make_shared<NodeAttr>();
117 ret->dispatch_mode = dispatch_modes[indexed_graph.node_id(node.node)];
118 for (const auto& e : node.node->inputs) {
119 ret->ishape.emplace_back(vshape[indexed_graph.entry_id(e)]);
120 ret->itype.emplace_back(vdtype[indexed_graph.entry_id(e)]);
121 }
122 return ret;
123 } else {
124 return nullptr;
125 }
126 }
127
128 /*!
129 * \brief This function traverses the nodes in a computation graph from a starting
130 * node following the input edges and output edges, and marks all nodes that
131 * can be accessed from the starting node. Before the function returns,
132 * it will conduct checking whether there is a loop between the potential subgraph
133 * and the outside nodes. If so, add the node that should break the loop
134 * in excluded_nodes and return false. Otherwise, return true.
135 * \param g the whole graph
136 * \subgraph_selector determines whether the visited node should be choosen or not
137 * \label the label of the current subgraph
138 * \snid node id of the seed simple node
139 * \simple_nodes all simple nodes in the top sorted order
140 * \subgraph_nodes all the nodes belonging to the same subgraph of seed node
141 * \excluded_nodes set of nodes that should be excluded from the current subgraph
142 */
LabelSubgraph(const nnvm::Graph & g,SubgraphSelectorV2Ptr subgraph_selector,const int label,const size_t snid,const std::vector<BiDirectedNodePtr> & simple_nodes,std::vector<BiDirectedNode * > * subgraph_nodes,std::unordered_set<const BiDirectedNode * > * excluded_nodes)143 bool LabelSubgraph(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector, const int label,
144 const size_t snid, const std::vector<BiDirectedNodePtr>& simple_nodes,
145 std::vector<BiDirectedNode*>* subgraph_nodes,
146 std::unordered_set<const BiDirectedNode*>* excluded_nodes) {
147 const auto& indexed_graph = g.indexed_graph();
148 std::queue<BiDirectedNode*> node_queue;
149 CHECK_EQ(simple_nodes[snid]->label, -1);
150 simple_nodes[snid]->label = label;
151 node_queue.push(simple_nodes[snid].get());
152 // key: nodes that serve as input/output nodes to the subgraph
153 // value: pair of vectors of nodes in the subgraph. The first vector contains the
154 // output nodes of the key in the subgraph, and the second vector contains the
155 // input nodes of the key in the subgraph.
156 // If a non-subgraph node has inputs from the subgraph and the other non-subgraph node
157 // has outputs to the subgraph, and the first non-subgraph node is an ancestor
158 // of the second non-subgraph node, there exits a cycle.
159 // When breaking the cycle, we want to start from removing the node with the largest node id
160 // in the subgraph.
161 std::unordered_map<const nnvm::Node*,
162 std::pair<std::vector<const nnvm::Node*>,
163 std::vector<const nnvm::Node*>>> non_subgraph_node_map;
164 while (!node_queue.empty()) {
165 BiDirectedNode* cur_node = node_queue.front();
166 node_queue.pop();
167 subgraph_nodes->push_back(cur_node);
168 // get qualified adjacent input nodes
169 for (auto& e : cur_node->node->inputs) {
170 const auto node = e.node.get();
171 const auto nid = indexed_graph.node_id(node);
172 auto snode = simple_nodes[nid].get();
173 CHECK_LT(nid, simple_nodes.size());
174 const bool select_input =
175 (snode->label == -1) && (!excluded_nodes || !excluded_nodes->count(snode)) &&
176 subgraph_selector->SelectInput(*cur_node, *snode, PrepareNodeAttr(g, *snode));
177 if (select_input) {
178 // e.node is a subgraph node
179 snode->label = label;
180 node_queue.push(snode);
181 } else if (snode->label == -1) {
182 // e.node is an input node of the subgraph
183 non_subgraph_node_map[e.node.get()].first.push_back(cur_node->node);
184 }
185 }
186 // get qualified output nodes
187 for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
188 const auto nid = indexed_graph.node_id(it->first);
189 auto snode = simple_nodes[nid].get();
190 CHECK_LT(nid, simple_nodes.size());
191 const bool select_output =
192 (snode->label == -1) && (!excluded_nodes || !excluded_nodes->count(snode)) &&
193 subgraph_selector->SelectOutput(*cur_node, *snode, PrepareNodeAttr(g, *snode));
194 if (select_output) {
195 // it->first is a subgraph node
196 snode->label = label;
197 node_queue.push(snode);
198 } else if (snode->label == -1) {
199 // it->first is an output node of the subgraph
200 non_subgraph_node_map[it->first].second.push_back(cur_node->node);
201 }
202 }
203 }
204 // prepare to check if there is a cycle
205 auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
206 return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
207 };
208 std::vector<const nnvm::Node*> non_subgraph_nodes;
209 non_subgraph_nodes.reserve(non_subgraph_node_map.size());
210 for (auto& kv : non_subgraph_node_map) {
211 auto& output_nodes = kv.second.first;
212 std::sort(output_nodes.begin(), output_nodes.end(), node_cmp);
213 auto& input_nodes = kv.second.second;
214 std::sort(input_nodes.begin(), input_nodes.end(), node_cmp);
215 non_subgraph_nodes.push_back(kv.first);
216 }
217 // check whether there is a cycle between the subgraph and its input/output nodes
218 auto is_ancestor = [&](const nnvm::Node* ancestor, const nnvm::Node* descendant,
219 const std::vector<BiDirectedNode*>& snodes) {
220 if (ancestor == descendant) return true;
221 std::unordered_set<nnvm::Node*> snode_set;
222 for (const auto& sn : snodes) {
223 snode_set.insert(sn->node);
224 }
225 std::stack<const nnvm::Node*> s;
226 s.push(descendant);
227 size_t count = 0;
228 while (!s.empty() && count < indexed_graph.num_nodes()) {
229 ++count;
230 const nnvm::Node* top = s.top();
231 s.pop();
232 if (top == ancestor) {
233 return true;
234 }
235 for (const auto& entry : top->inputs) {
236 // when searching for the ancestor, the path cannot cross any subgraph node
237 if (!snode_set.count(entry.node.get())) {
238 s.push(entry.node.get());
239 }
240 }
241 }
242 return false;
243 };
244 std::sort(non_subgraph_nodes.begin(), non_subgraph_nodes.end(), node_cmp);
245 int excluded_node_id = -1;
246 for (size_t i = 0; i < non_subgraph_nodes.size(); ++i) {
247 auto it1 = non_subgraph_node_map.find(non_subgraph_nodes[i]);
248 CHECK(it1 != non_subgraph_node_map.end());
249 auto& output_nodes = it1->second.first; // has been top sorted
250 auto& input_nodes = it1->second.second; // has been top sorted
251 if (!output_nodes.empty() && !input_nodes.empty()) {
252 // there is a loop between node i and the subgraph
253 const auto node_id = std::max(indexed_graph.node_id(output_nodes.back()),
254 indexed_graph.node_id(input_nodes.back()));
255 excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
256 } else if (!input_nodes.empty()) {
257 // node i is an input to the subgraph, find out if there is a node j
258 // which is an output of the subgraph and also a child of node i.
259 for (size_t j = i + 1; j < non_subgraph_nodes.size(); ++j) {
260 auto it2 = non_subgraph_node_map.find(non_subgraph_nodes[j]);
261 CHECK(it2 != non_subgraph_node_map.end());
262 // i is topologically before j, j might be a direct/indirect output node of i
263 CHECK_LT(indexed_graph.node_id(it1->first), indexed_graph.node_id(it2->first));
264 if (!it2->second.first.empty() && is_ancestor(it1->first, it2->first, *subgraph_nodes)) {
265 // found a loop
266 const auto node_id = std::max(indexed_graph.node_id(input_nodes.back()),
267 indexed_graph.node_id(it2->second.first.back()));
268 excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
269 }
270 }
271 }
272 }
273
274 if (excluded_node_id != -1) {
275 CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
276 excluded_nodes->insert(simple_nodes[excluded_node_id].get());
277 ResetNodeLabels(g, simple_nodes, subgraph_nodes);
278 return false;
279 }
280 auto sim_node_cmp = [&] (const BiDirectedNode* node1, const BiDirectedNode* node2) {
281 return indexed_graph.node_id(node1->node) < indexed_graph.node_id(node2->node);
282 };
283 std::sort(subgraph_nodes->begin(), subgraph_nodes->end(), sim_node_cmp);
284 return true;
285 }
286
287 /*!
288 * \brief Finds all the nodes belonging to the same subgraph given a seed node.
289 * \param g the whole graph
290 * \subgraph_selector determines whether the visited node should be choosen or not
291 * \label the label of the current subgraph
292 * \snid node id of the seed simple node
293 * \simple_nodes all simple nodes in the top sorted order
294 * \subgraph_nodes all the nodes belonging to the same subgraph of seed node
295 * \return Subgraph node candidates sorted in the topological order
296 */
PreSelectSubgraphNodes(const nnvm::Graph & g,SubgraphSelectorV2Ptr subgraph_selector,const int label,const size_t snid,const std::vector<BiDirectedNodePtr> & simple_nodes,std::vector<BiDirectedNode * > * subgraph_nodes)297 void PreSelectSubgraphNodes(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector,
298 const int label, const size_t snid,
299 const std::vector<BiDirectedNodePtr>& simple_nodes,
300 std::vector<BiDirectedNode*>* subgraph_nodes) {
301 std::unordered_set<const BiDirectedNode*> excluded_nodes;
302 size_t n_excluded_nodes = 0;
303 const size_t max_num_retry = simple_nodes.size() * simple_nodes.size();
304 size_t count = 0;
305 bool success = false;
306 while (!success && count < max_num_retry) {
307 success = LabelSubgraph(g, subgraph_selector, label, snid, simple_nodes, subgraph_nodes,
308 &excluded_nodes);
309 if (!success) {
310 // Failed to label subgraph due to a cycle
311 // If the number of excluded_nodes didn't change since the last iteration,
312 // this means that there is no possible subgraph for the current node snid, we break
313 // Otherwise, we keep trying (with the excluded nodes tagged)
314 if (excluded_nodes.size() == n_excluded_nodes) {
315 break;
316 }
317 n_excluded_nodes = excluded_nodes.size();
318 std::string excluded_node_names;
319 for (auto node : excluded_nodes) {
320 excluded_node_names += node->node->attrs.name + ", ";
321 }
322 static int verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", 1);
323 if (verbose > 1) {
324 LOG(INFO) << "Found a cycle when BFS from node " << simple_nodes[snid]->node->attrs.name
325 << ". Excluding nodes " << excluded_node_names << "and retrying";
326 }
327 subgraph_selector->Reset();
328 }
329 ++count;
330 }
331 if (!success) {
332 LOG(INFO) << "Tried " << count << " times of finding subgraphs starting from node "
333 << simple_nodes[snid]->node->attrs.name
334 << " without success because a loop "
335 "is always found between the subgraph and some other nodes. Will treat "
336 "seed node "
337 << simple_nodes[snid]->node->attrs.name << "as a subgraph with one node";
338 CHECK(subgraph_nodes->empty());
339 simple_nodes[snid]->label = label;
340 subgraph_nodes->push_back(simple_nodes[snid].get());
341 }
342 }
343
SelectSubgraphNodes(nnvm::Graph * g,SubgraphSelectorV2Ptr subgraph_selector,const std::vector<BiDirectedNodePtr> & simple_nodes,std::vector<std::vector<BiDirectedNode * >> * subgraph_nodes,std::vector<SubgraphSelectorV2Ptr> * subgraph_selectors,const BiDirectedNode * node,const size_t snid,size_t * subgraph_id)344 void SelectSubgraphNodes(nnvm::Graph* g, SubgraphSelectorV2Ptr subgraph_selector,
345 const std::vector<BiDirectedNodePtr>& simple_nodes,
346 std::vector<std::vector<BiDirectedNode*>>* subgraph_nodes,
347 std::vector<SubgraphSelectorV2Ptr>* subgraph_selectors,
348 const BiDirectedNode* node, const size_t snid, size_t* subgraph_id) {
349 const auto& indexed_graph = g->indexed_graph();
350
351 auto node_cmp = [&] (const BiDirectedNode* node1, const BiDirectedNode* node2) {
352 return indexed_graph.node_id(node1->node) < indexed_graph.node_id(node2->node);
353 };
354 if ((simple_nodes[snid]->label == -1) &&
355 subgraph_selector->Select(*node, PrepareNodeAttr(*g, *node))) {
356 // pre-select nodes that can be grouped in a subgraph
357 std::vector<BiDirectedNode*> preselected_nodes;
358 PreSelectSubgraphNodes(*g, subgraph_selector, *subgraph_id, snid, simple_nodes,
359 &preselected_nodes);
360
361 // filter out unqualified pre-selected nodes
362 std::vector<BiDirectedNode*> filtered_nodes = subgraph_selector->Filter(preselected_nodes);
363
364 // reset node labels that are not in filtered nodes
365 for (const auto n : preselected_nodes) {
366 const auto nit = std::find(filtered_nodes.begin(), filtered_nodes.end(), n);
367 if (nit == filtered_nodes.end()) {
368 n->label = -1;
369 }
370 }
371
372 if (filtered_nodes.size()) {
373 // make sure filtered_nodes is a subset of preselected_nodes
374 for (const auto n : filtered_nodes) {
375 const auto nit = std::find(preselected_nodes.begin(), preselected_nodes.end(), n);
376 CHECK(nit != preselected_nodes.end())
377 << "Node " << n->node->attrs.name
378 << " is not found in the pre-selected subgraph nodes."
379 " Please make sure that no new nodes were added in your subgraph"
380 " selector's Filter function";
381 }
382
383 // make sure nodes are sorted
384 std::sort(filtered_nodes.begin(), filtered_nodes.end(), node_cmp);
385 subgraph_nodes->push_back(filtered_nodes);
386 subgraph_selectors->push_back(subgraph_selector);
387 (*subgraph_id)++;
388 }
389 }
390 }
391
392 /*!
393 * \brief Finds subgraphs with all nodes that meet certain criteria.
394 * All nodes in a subgraph are marked with the same label.
395 */
FindSubgraphs(nnvm::Graph * g,const SubgraphProperty & subg_prop,const std::vector<BiDirectedNodePtr> & simple_nodes,std::vector<std::vector<BiDirectedNode * >> * subgraph_nodes,std::vector<SubgraphSelectorV2Ptr> * subgraph_selectors)396 void FindSubgraphs(nnvm::Graph* g,
397 const SubgraphProperty &subg_prop,
398 const std::vector<BiDirectedNodePtr>& simple_nodes,
399 std::vector<std::vector<BiDirectedNode*>>* subgraph_nodes,
400 std::vector<SubgraphSelectorV2Ptr>* subgraph_selectors) {
401 const auto& indexed_graph = g->indexed_graph();
402 CHECK_EQ(indexed_graph.num_nodes(), simple_nodes.size());
403
404 size_t subgraph_id = 0;
405 for (size_t i = 0; i < simple_nodes.size(); ++i) {
406 const auto snode = simple_nodes[i];
407 SubgraphSelectorV2Ptr subgraph_selector = subg_prop.CreateSubgraphSelectorV2();
408 SelectSubgraphNodes(g, subgraph_selector, simple_nodes, subgraph_nodes, subgraph_selectors,
409 snode.get(), i, &subgraph_id);
410 }
411 }
412
413 /*!
414 * \brief Sorts entries according to their topological order.
415 * Note that entry ids cannot be used to sort entries.
416 * \param entry_top_order_map mapping from entry pointer to its topological position in the graph
417 * \param entries Node entries to be sorted
418 */
SortEntries(const std::unordered_map<const nnvm::NodeEntry *,size_t> & entry_top_order_map,std::vector<nnvm::NodeEntry * > * entries)419 void SortEntries(const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry_top_order_map,
420 std::vector<nnvm::NodeEntry*>* entries) {
421 auto entry_cmp = [&](const nnvm::NodeEntry* e1, const nnvm::NodeEntry* e2) {
422 const auto it1 = entry_top_order_map.find(e1);
423 CHECK(it1 != entry_top_order_map.end());
424 const auto it2 = entry_top_order_map.find(e2);
425 CHECK(it2 != entry_top_order_map.end());
426 return it1->second < it2->second;
427 };
428 std::sort(entries->begin(), entries->end(), entry_cmp);
429 }
430
431 /*!
432 * \brief Given a subgraph, find the input entries of a subgraph.
433 * \param g pointer to the whole graph
434 * \param simple_nods vector of simple nodes in top sorted order
435 * \param subgraph_nodes vector of pointers of simples of a subgraph.
436 * \param entry_top_order_map mapping entry pointer to its top sorted position
437 * \param input_entries input entries of the subgraph
438 */
FindInputEntries(const nnvm::Graph & g,const std::vector<BiDirectedNodePtr> & simple_nodes,const std::vector<BiDirectedNode * > & subgraph_nodes,const std::unordered_map<const nnvm::NodeEntry *,size_t> & entry_top_order_map,std::vector<nnvm::NodeEntry * > * input_entries)439 void FindInputEntries(const nnvm::Graph& g,
440 const std::vector<BiDirectedNodePtr>& simple_nodes,
441 const std::vector<BiDirectedNode*>& subgraph_nodes,
442 const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry_top_order_map,
443 std::vector<nnvm::NodeEntry*>* input_entries) {
444 const auto& indexed_graph = g.indexed_graph();
445 int label = -1;
446 for (auto subgraph_node : subgraph_nodes) {
447 if (label == -1) {
448 label = subgraph_node->label;
449 } else {
450 CHECK_EQ(subgraph_node->label, label);
451 }
452 auto& inputs = subgraph_node->node->inputs;
453 for (auto &e : inputs) {
454 if (indexed_graph.exist(e.node.get())) {
455 // e's source node is not a subgraph node
456 const auto nid = indexed_graph.node_id(e.node.get());
457 // this is a node not belonging to the subgraph
458 if (simple_nodes[nid]->label != label) {
459 input_entries->push_back(&e);
460 }
461 } else {
462 // e's source node is a subgraph node.
463 // In this case, two subgraphs are adjacent.
464 input_entries->push_back(&e);
465 }
466 }
467 }
468 SortEntries(entry_top_order_map, input_entries);
469 }
470
471 /*!
472 * \brief Given a subgraph, find the output entries of a subgraph.
473 * \param g pointer to the whole graph
474 * \param simple_nods vector of simple nodes in top sorted order
475 * \param subgraph_nodes vector of pointers of simples of a subgraph.
476 * \param entry_top_order_map mapping entry pointer to its top sorted position
477 * \param output_entries output entries of the subgraph
478 */
FindOutputEntries(nnvm::Graph * g,const std::vector<BiDirectedNodePtr> & simple_nodes,const std::vector<BiDirectedNode * > & subgraph_nodes,const std::unordered_map<const nnvm::NodeEntry *,size_t> & entry_top_order_map,std::vector<nnvm::NodeEntry * > * output_entries)479 void FindOutputEntries(nnvm::Graph* g,
480 const std::vector<BiDirectedNodePtr>& simple_nodes,
481 const std::vector<BiDirectedNode*>& subgraph_nodes,
482 const std::unordered_map<const nnvm::NodeEntry*, size_t>&
483 entry_top_order_map,
484 std::vector<nnvm::NodeEntry*>* output_entries) {
485 if (subgraph_nodes.empty()) return;
486 const auto& indexed_graph = g->indexed_graph();
487 int label = -1;
488 for (auto subgraph_node : subgraph_nodes) {
489 if (label == -1) {
490 label = subgraph_node->label;
491 } else {
492 CHECK_EQ(subgraph_node->label, label);
493 }
494 for (auto &output_node : subgraph_node->outputs) {
495 if (indexed_graph.exist(output_node.first)) {
496 // if the output node is a normal graph node (not a subgraph node)
497 const auto nid = indexed_graph.node_id(output_node.first);
498 // this is a node not belonging to the current subgraph
499 if (simple_nodes[nid]->label != label) {
500 for (auto idx : output_node.second) {
501 auto& e = simple_nodes[nid]->node->inputs[idx];
502 output_entries->push_back(&e);
503 }
504 }
505 } else {
506 // if the output node is a subgraph node
507 // two graphs are adjacent
508 for (auto idx : output_node.second) {
509 output_entries->push_back(&(output_node.first->inputs[idx]));
510 }
511 }
512 }
513 }
514 // Check if current subgraph contains a node which is the last node
515 // of the whole graph. If so, save its corresponding entry as well.
516 for (auto &entry : g->outputs) {
517 // The entry might has been updated as an output of
518 // a subgraph node. In this case, no need
519 // to check its source for the current subgraph. Otherwise,
520 // do the following.
521 if (indexed_graph.exist(entry.node.get())) {
522 const auto nid = indexed_graph.node_id(entry.node.get());
523 if (simple_nodes[nid]->label == label) {
524 output_entries->push_back(&entry);
525 }
526 }
527 }
528 SortEntries(entry_top_order_map, output_entries);
529 }
530
531 /*!
532 * \brief Given a computation graph and a set of input node entries, this function cuts
533 * the node entries and creates new variable nodes as the input nodes of the
534 * subgraph. It returns the nodes that connect to the subgraph directly and
535 * the names of the new variable nodes.
536 */
CutGraphInputs(const std::vector<nnvm::NodeEntry * > & input_entries,std::vector<nnvm::NodeEntry> * orig_entries,std::vector<nnvm::NodeEntry> * unique_orig_entries,std::vector<nnvm::NodeEntry * > * unique_input_entries,const bool skip_var=false,const bool dedup=false)537 void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
538 std::vector<nnvm::NodeEntry> *orig_entries,
539 std::vector<nnvm::NodeEntry> *unique_orig_entries,
540 std::vector<nnvm::NodeEntry*> *unique_input_entries,
541 const bool skip_var = false,
542 const bool dedup = false) {
543 orig_entries->resize(input_entries.size());
544 // map for creating unique var nodes for deduplicating entries from the same node
545 std::unordered_map<std::string, nnvm::NodeEntry> name_map;
546 std::unordered_map<std::string, int> name_count_map;
547
548 for (size_t i = 0; i < input_entries.size(); ++i) {
549 nnvm::NodeEntry *e = input_entries[i];
550 // If the node is a variable itself, we may want to skip the node.
551 if (e->node->is_variable() && skip_var) {
552 continue;
553 }
554 // save all original entries
555 orig_entries->at(i) = *e;
556 // get unique name for this entry
557 nnvm::Symbol sym;
558 sym.outputs.push_back(*e);
559 const auto output_names = sym.ListOutputNames();
560 CHECK_EQ(output_names.size(), 1U);
561 const std::string& var_name = output_names[0];
562 // check if this entry is a duplicate
563 if (name_count_map.count(var_name) == 0) {
564 // first use of this node as input to subgraph
565 name_count_map.emplace(var_name, 0);
566 unique_orig_entries->push_back(*e);
567 unique_input_entries->push_back(e);
568 nnvm::ObjectPtr n = nnvm::CreateVariableNode(var_name + std::to_string(0));
569 name_map.emplace(var_name, nnvm::NodeEntry{n, 0, 0});
570 } else {
571 // other use of same node as input to subgraph
572 name_count_map[var_name]++;
573 }
574
575 if (dedup) {
576 *e = name_map[var_name];
577 } else {
578 nnvm::ObjectPtr n = nnvm::CreateVariableNode(
579 var_name + std::to_string(name_count_map[var_name]));
580 *e = nnvm::NodeEntry{n, 0, 0};
581 }
582 }
583 }
584
585 /*!
586 * \brief This function reattaches the original input nodes that were cut
587 * by CutGraphInputs. This function is used when subgraphs are rejected, it
588 * reattaches the subgraph back to the main graph where it was cut earlier.
589 */
ReattachGraphInputs(const std::vector<nnvm::NodeEntry * > & input_entries,std::vector<nnvm::NodeEntry> * orig_entries)590 void ReattachGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
591 std::vector<nnvm::NodeEntry> *orig_entries) {
592 for (size_t i = 0; i < input_entries.size(); ++i) {
593 nnvm::NodeEntry *e = input_entries[i];
594 *e = orig_entries->at(i);
595 }
596 }
597
598 /*!
599 * \brief Replace a set of nodes belonging to the same subgraph with a subgraph node
600 * and keep the subgraph in the subgraph node.
601 */
CreateSubgraphNode(nnvm::Graph * g,const std::vector<BiDirectedNodePtr> & simple_nodes,const std::vector<BiDirectedNode * > & subgraph_nodes,const SubgraphSelectorV2Ptr & subgraph_selector,const size_t subgraph_id,std::unordered_map<const nnvm::NodeEntry *,size_t> * entry_top_order_map)602 void CreateSubgraphNode(nnvm::Graph* g,
603 const std::vector<BiDirectedNodePtr>& simple_nodes,
604 const std::vector<BiDirectedNode*>& subgraph_nodes,
605 const SubgraphSelectorV2Ptr& subgraph_selector,
606 const size_t subgraph_id,
607 std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
608 #if DEBUG_SUBGRAPH
609 LOG(INFO) << "Searching for input entries...";
610 #endif
611 bool dedup_subgraph = g->HasAttr("dedup_subgraph");
612 std::vector<nnvm::NodeEntry*> input_entries; // nodes that produce inputs to subgraph nodes
613 FindInputEntries(*g, simple_nodes, subgraph_nodes, *entry_top_order_map, &input_entries);
614 std::vector<nnvm::NodeEntry> orig_input_entries; // original input entries (dupes)
615 std::vector<nnvm::NodeEntry> unique_orig_entries; // unique original input entries
616 std::vector<nnvm::NodeEntry*> unique_input_entries; // unique modified subgraph inputs
617 CutGraphInputs(input_entries, &orig_input_entries, &unique_orig_entries,
618 &unique_input_entries, false, dedup_subgraph);
619 #if DEBUG_SUBGRAPH
620 PrintNodeEntries(input_entries);
621 LOG(INFO) << "Searching for output entries...";
622 #endif
623 std::vector<nnvm::NodeEntry*> output_entries;
624 FindOutputEntries(g, simple_nodes, subgraph_nodes, *entry_top_order_map, &output_entries);
625
626 // Create a subgraph for the subgraph node
627 // entries are in topological order, with duplicates being neighbors
628 nnvm::Symbol sym;
629 size_t idx = 0;
630 nnvm::NodeEntryEqual node_equal;
631 sym.outputs.resize(output_entries.size());
632 for (size_t i = 0; i < output_entries.size(); ++i) {
633 if (dedup_subgraph) {
634 if (i == 0) { // add first entry
635 sym.outputs[idx] = *output_entries[i];
636 } else if (!node_equal(sym.outputs[idx], *output_entries[i])) { // compare to see if diff
637 // add new entries
638 idx++;
639 sym.outputs[idx] = *output_entries[i];
640 } // else skip over dupe entries
641 } else {
642 sym.outputs[i] = *output_entries[i];
643 }
644 }
645 if (dedup_subgraph)
646 sym.outputs.resize(idx+1);
647
648 const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property");
649 if (dedup_subgraph)
650 subg_prop->InitSubgraphInputs(&unique_input_entries, &unique_orig_entries);
651 else
652 subg_prop->InitSubgraphInputs(&input_entries, &orig_input_entries);
653 nnvm::ObjectPtr n = subg_prop->CreateSubgraphNode(sym, subgraph_selector, subgraph_id);
654 // CreateSubgraphNode returns NULL if subgraph property determines that subgraph is sub-optimal
655 // In that case, subgraph node is not created and graph is not modified
656 if (n) {
657 // Connect the external nodes to the subgraph node.
658 subg_prop->ConnectSubgraphOutputs(n, &output_entries);
659 if (dedup_subgraph)
660 subg_prop->ConnectSubgraphInputs(n, &unique_input_entries, &unique_orig_entries);
661 else
662 subg_prop->ConnectSubgraphInputs(n, &input_entries, &orig_input_entries);
663
664 const auto& indexed_graph = g->indexed_graph();
665 for (size_t i = 0; i < n->inputs.size(); ++i) {
666 auto& e = n->inputs[i];
667 // update entry_top_order_map with newly created orig_input_entries
668 auto it = entry_top_order_map->find(input_entries[i]);
669 CHECK(it != entry_top_order_map->end());
670 entry_top_order_map->emplace(&e, it->second);
671 // update input entries' source simple nodes' outputs map
672 nnvm::Node* node = e.node.get();
673 if (indexed_graph.exist(node)) {
674 const auto nid = indexed_graph.node_id(node);
675 BiDirectedNode* sn = simple_nodes[nid].get();
676 for (BiDirectedNode* dest_node : subgraph_nodes) {
677 sn->outputs.erase(dest_node->node);
678 }
679 sn->outputs[n.get()].push_back(i);
680 }
681 }
682 } else {
683 ReattachGraphInputs(input_entries, &orig_input_entries);
684 }
685 #if DEBUG_SUBGRAPH
686 if (n)
687 LOG(INFO) << "Subgraph node created and output_entries updated.";
688 else
689 LOG(INFO) << "Subgraph node not created, output_entries not updated.";
690 PrintNodeEntries(output_entries);
691 #endif
692 }
693
694 /*!
695 * \brief Adjust a set of nodes belonging to the same subgraph. No new node is created, but
696 * adjust selected nodes' attributes.
697 * This can be used to implement peephole optimization. For example, adjust calibration information
698 * of quantized nodes.
699 */
AdjustSubgraphNode(nnvm::Graph * g,const std::vector<BiDirectedNode * > & subgraph_nodes,const SubgraphSelectorV2Ptr & subgraph_selector,const size_t subgraph_id)700 void AdjustSubgraphNode(nnvm::Graph* g,
701 const std::vector<BiDirectedNode*>& subgraph_nodes,
702 const SubgraphSelectorV2Ptr& subgraph_selector,
703 const size_t subgraph_id) {
704 std::vector<nnvm::Node*> node_list;
705 for (auto node : subgraph_nodes) {
706 node_list.push_back(node->node);
707 }
708
709 const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property");
710 subg_prop->AdjustSubgraphNode(node_list, subgraph_selector, subgraph_id);
711 }
712
713 } // namespace sg
714
715 /*!
716 * \brief Sort entries of all the nodes' inputs vectors in the topological order.
717 * This is going to be used to sort input/output entries of subgraphs to keep
718 * the topological order unchanged.
719 */
TopSortEntries(const nnvm::Graph & g,std::unordered_map<const nnvm::NodeEntry *,size_t> * entry_top_order_map)720 void TopSortEntries(const nnvm::Graph& g,
721 std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
722 CHECK(entry_top_order_map != nullptr);
723 std::unordered_set<const nnvm::Node*> visited;
724 // tuple: (graph node, index of node's inputs, node entry as the output of the graph node)
725 std::stack<std::tuple<nnvm::Node*, size_t, const nnvm::NodeEntry*>> s;
726 auto in_degree = [] (const nnvm::Node* node)->size_t {
727 if (!node) {
728 return 0;
729 }
730 CHECK_EQ(node->control_deps.size(), 0U);
731 return node->inputs.size();
732 };
733 for (auto& e : g.outputs) {
734 nnvm::Node* node = e.node.get();
735 if (visited.count(node) == 0U) {
736 s.emplace(node, 0U, &e);
737 visited.insert(node);
738 } else {
739 // The entry's source node has been visited before.
740 // Marking the order for it.
741 entry_top_order_map->emplace(&e, entry_top_order_map->size());
742 }
743 while (!s.empty()) {
744 auto& top = s.top();
745 if (std::get<1>(top) == in_degree(std::get<0>(top))) {
746 // The node's inputs has been exhausted.
747 entry_top_order_map->emplace(std::get<2>(top), entry_top_order_map->size());
748 s.pop();
749 } else {
750 // The node still has input entries not visited.
751 CHECK_LT(std::get<1>(top), std::get<0>(top)->inputs.size());
752 auto& entry = std::get<0>(top)->inputs[std::get<1>(top)++];
753 nnvm::Node* input_node = entry.node.get();
754 if (visited.count(input_node) == 0U) {
755 // The entry's source node has not been visited.
756 // Push the entry to the stack for marking order later.
757 s.emplace(input_node, 0U, &entry);
758 visited.insert(input_node);
759 } else {
760 // The entry's source node has been visited before.
761 // Marking the order for it.
762 entry_top_order_map->emplace(&entry, entry_top_order_map->size());
763 }
764 }
765 }
766 }
767 }
768
BuildSubgraph(nnvm::Graph && g)769 nnvm::Graph BuildSubgraph(nnvm::Graph&& g) {
770 static int verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", 1);
771 if (!g.HasAttr("subgraph_property")) { // treat the whole graph as a subgraph
772 if (verbose > 1) {
773 LOG(INFO) << "The graph has no attribute of subgraph_property attached. "
774 "The original graph is returned.";
775 }
776 return g;
777 }
778 using namespace sg;
779
780 const SubgraphPropertyPtr& subg_prop = g.GetAttr<SubgraphPropertyPtr>("subgraph_property");
781 if (verbose > 1) {
782 const std::string& prop_name = subg_prop->HasAttr("property_name")
783 ? subg_prop->GetAttr<std::string>("property_name")
784 : "partition graph";
785 LOG(INFO) << "start to execute " << prop_name << ".";
786 }
787 // top sort NodeEntry of all the nodes' inputs
788 std::unordered_map<const nnvm::NodeEntry*, size_t> entry_top_order_map;
789 TopSortEntries(g, &entry_top_order_map);
790
791 // Create double directional graph for ease of finding subgraphs
792 std::vector<BiDirectedNodePtr> simple_nodes;
793 CreateSimpleGraph(g, &simple_nodes);
794 std::vector<std::vector<BiDirectedNode*>> subgraph_nodes;
795 std::vector<SubgraphSelectorV2Ptr> subgraph_selectors;
796 FindSubgraphs(&g, *subg_prop, simple_nodes, &subgraph_nodes, &subgraph_selectors);
797 CHECK_EQ(subgraph_nodes.size(), subgraph_selectors.size());
798 for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
799 #if DEBUG_SUBGRAPH
800 std::set<BiDirectedNode*> simple_node_set(subgraph_nodes[i].begin(), subgraph_nodes[i].end());
801 CHECK_EQ(simple_node_set.size(), subgraph_nodes[i].size());
802 PrintSubgraph(subgraph_nodes[i]);
803 #endif
804 auto ptype = subg_prop->GetPropertyType();
805 if (ptype == SubgraphProperty::SgPropertyType::kCreate) {
806 CreateSubgraphNode(&g, simple_nodes, subgraph_nodes[i], subgraph_selectors[i], i,
807 &entry_top_order_map);
808 } else {
809 CHECK_EQ(ptype, SubgraphProperty::SgPropertyType::kAdjust);
810 AdjustSubgraphNode(&g, subgraph_nodes[i], subgraph_selectors[i], i);
811 }
812 }
813 return g;
814 }
815
// Register BuildSubgraph as an NNVM pass, using the BuildSubgraph function above as its body.
// set_change_graph(true) flags the pass as one that may modify the graph it is given.
NNVM_REGISTER_PASS(BuildSubgraph)
.describe("Apply a subgraph pass according to the user defined rules "
          "in a derived class of SubgraphProperty")
.set_body(BuildSubgraph)
.set_change_graph(true);
821
822
823 } // namespace op
824 } // namespace mxnet
825