1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file build_subgraph.cc
22  * \brief
23  */
24 #include <nnvm/graph.h>
25 #include <nnvm/pass.h>
26 #include <unordered_set>
27 #include <stack>
28 #include <queue>
29 
30 #include "./subgraph_property.h"
31 
32 #define DEBUG_SUBGRAPH 0
33 
34 namespace nnvm {
35 ObjectPtr CreateVariableNode(const std::string& name);
36 }
37 
38 namespace mxnet {
39 namespace op {
40 namespace sg {  // sg stands for subgraph
41 
42 #if DEBUG_SUBGRAPH
PrintSubgraph(const std::vector<BiDirectedNode * > & simple_nodes)43 void PrintSubgraph(const std::vector<BiDirectedNode*>& simple_nodes) {
44   std::string op_names = "";
45   for (size_t i = 0; i < simple_nodes.size(); ++i) {
46     op_names += simple_nodes[i]->node->attrs.name + ' ';
47   }
48   LOG(INFO) << "Subgraph node names: " << op_names;
49 }
50 
PrintNodeEntry(const nnvm::NodeEntry & entry)51 void PrintNodeEntry(const nnvm::NodeEntry& entry) {
52   std::string ret = "NodeEntry: node_name=" + entry.node->attrs.name
53     + ", index=" + std::to_string(entry.index) + ", version=" + std::to_string(entry.version);
54   LOG(INFO) << ret;
55 }
56 
PrintNodeEntries(const std::vector<nnvm::NodeEntry * > & entries)57 void PrintNodeEntries(const std::vector<nnvm::NodeEntry*>& entries) {
58   for (size_t i = 0; i < entries.size(); ++i) {
59     PrintNodeEntry(*entries[i]);
60   }
61 }
62 #endif
63 
64 /*!
65  * \brief Given a MXNet computational graph, create an undirected graph from it.
66  * \param g the MXNet computational graph
67  * \param simple_nodes the nodes of undirected graph in top sorted order
68  */
CreateSimpleGraph(const nnvm::Graph & g,std::vector<BiDirectedNodePtr> * simple_nodes)69 void CreateSimpleGraph(const nnvm::Graph& g,
70                        std::vector<BiDirectedNodePtr>* simple_nodes) {
71   const auto& indexed_graph = g.indexed_graph();
72   simple_nodes->reserve(indexed_graph.num_nodes());
73   DFSVisit(g.outputs, [&](const nnvm::ObjectPtr& node) {
74     BiDirectedNodePtr sn = BiDirectedNode::Create();
75     sn->node = node.get();
76     for (size_t i = 0; i < sn->node->inputs.size(); ++i) {
77       const auto& e = sn->node->inputs[i];
78       const auto input_nid = indexed_graph.node_id(e.node.get());
79       CHECK_LT(input_nid, simple_nodes->size());
80       auto& input_node_outputs = (*simple_nodes)[input_nid]->outputs;
81       auto it = input_node_outputs.find(sn->node);
82       if (it == input_node_outputs.end()) {
83         input_node_outputs.emplace(sn->node, std::vector<size_t>{i});
84       } else {
85         it->second.push_back(i);
86       }
87     }
88     simple_nodes->emplace_back(std::move(sn));
89   });
90 }
91 
92 /*!
93  * \brief Reset labels of the subgraph nodes to the original state
94  * and clear the vector of subgraph nodes.
95  */
ResetNodeLabels(const nnvm::Graph & g,const std::vector<BiDirectedNodePtr> & simple_nodes,std::vector<BiDirectedNode * > * subgraph_nodes)96 void ResetNodeLabels(const nnvm::Graph& g,
97                      const std::vector<BiDirectedNodePtr>& simple_nodes,
98                      std::vector<BiDirectedNode*>* subgraph_nodes) {
99   for (auto n : *subgraph_nodes) {
100     const auto nid = g.indexed_graph().node_id(n->node);
101     simple_nodes[nid]->label = -1;
102   }
103   subgraph_nodes->clear();
104 }
105 
106 /*
107  * \brief Prepare NodeAttr for node. NodeAttr will be used in SubgraphSelectorV2.
108  */
PrepareNodeAttr(const nnvm::Graph & g,const BiDirectedNode & node)109 static const std::shared_ptr<NodeAttr> PrepareNodeAttr(const nnvm::Graph& g,
110                                                        const BiDirectedNode& node) {
111   const auto& indexed_graph = g.indexed_graph();
112   if (g.HasAttr("dtype") && g.HasAttr("shape") && g.HasAttr("dispatch_mode")) {
113     const auto& vdtype = g.GetAttr<nnvm::DTypeVector>("dtype");
114     const auto& vshape = g.GetAttr<mxnet::ShapeVector>("shape");
115     const auto& dispatch_modes = g.GetAttr<mxnet::DispatchModeVector>("dispatch_mode");
116     auto ret = std::make_shared<NodeAttr>();
117     ret->dispatch_mode = dispatch_modes[indexed_graph.node_id(node.node)];
118     for (const auto& e : node.node->inputs) {
119       ret->ishape.emplace_back(vshape[indexed_graph.entry_id(e)]);
120       ret->itype.emplace_back(vdtype[indexed_graph.entry_id(e)]);
121     }
122     return ret;
123   } else {
124     return nullptr;
125   }
126 }
127 
128 /*!
129  * \brief This function traverses the nodes in a computation graph from a starting
130  * node following the input edges and output edges, and marks all nodes that
131  * can be accessed from the starting node. Before the function returns,
132  * it will conduct checking whether there is a loop between the potential subgraph
133  * and the outside nodes. If so, add the node that should break the loop
134  * in excluded_nodes and return false. Otherwise, return true.
135  * \param g the whole graph
136  * \subgraph_selector determines whether the visited node should be choosen or not
137  * \label the label of the current subgraph
138  * \snid node id of the seed simple node
139  * \simple_nodes all simple nodes in the top sorted order
140  * \subgraph_nodes all the nodes belonging to the same subgraph of seed node
141  * \excluded_nodes set of nodes that should be excluded from the current subgraph
142  */
LabelSubgraph(const nnvm::Graph & g,SubgraphSelectorV2Ptr subgraph_selector,const int label,const size_t snid,const std::vector<BiDirectedNodePtr> & simple_nodes,std::vector<BiDirectedNode * > * subgraph_nodes,std::unordered_set<const BiDirectedNode * > * excluded_nodes)143 bool LabelSubgraph(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector, const int label,
144                    const size_t snid, const std::vector<BiDirectedNodePtr>& simple_nodes,
145                    std::vector<BiDirectedNode*>* subgraph_nodes,
146                    std::unordered_set<const BiDirectedNode*>* excluded_nodes) {
147   const auto& indexed_graph = g.indexed_graph();
148   std::queue<BiDirectedNode*> node_queue;
149   CHECK_EQ(simple_nodes[snid]->label, -1);
150   simple_nodes[snid]->label = label;
151   node_queue.push(simple_nodes[snid].get());
152   // key: nodes that serve as input/output nodes to the subgraph
153   // value: pair of vectors of nodes in the subgraph. The first vector contains the
154   // output nodes of the key in the subgraph, and the second vector contains the
155   // input nodes of the key in the subgraph.
156   // If a non-subgraph node has inputs from the subgraph and the other non-subgraph node
157   // has outputs to the subgraph, and the first non-subgraph node is an ancestor
158   // of the second non-subgraph node, there exits a cycle.
159   // When breaking the cycle, we want to start from removing the node with the largest node id
160   // in the subgraph.
161   std::unordered_map<const nnvm::Node*,
162     std::pair<std::vector<const nnvm::Node*>,
163               std::vector<const nnvm::Node*>>> non_subgraph_node_map;
164   while (!node_queue.empty()) {
165     BiDirectedNode* cur_node = node_queue.front();
166     node_queue.pop();
167     subgraph_nodes->push_back(cur_node);
168     // get qualified adjacent input nodes
169     for (auto& e : cur_node->node->inputs) {
170       const auto node = e.node.get();
171       const auto nid = indexed_graph.node_id(node);
172       auto snode = simple_nodes[nid].get();
173       CHECK_LT(nid, simple_nodes.size());
174       const bool select_input =
175           (snode->label == -1) && (!excluded_nodes || !excluded_nodes->count(snode)) &&
176           subgraph_selector->SelectInput(*cur_node, *snode, PrepareNodeAttr(g, *snode));
177       if (select_input) {
178         // e.node is a subgraph node
179         snode->label = label;
180         node_queue.push(snode);
181       } else if (snode->label == -1) {
182         // e.node is an input node of the subgraph
183         non_subgraph_node_map[e.node.get()].first.push_back(cur_node->node);
184       }
185     }
186     // get qualified output nodes
187     for (auto it = cur_node->outputs.begin(); it != cur_node->outputs.end(); ++it) {
188       const auto nid = indexed_graph.node_id(it->first);
189       auto snode = simple_nodes[nid].get();
190       CHECK_LT(nid, simple_nodes.size());
191       const bool select_output =
192           (snode->label == -1) && (!excluded_nodes || !excluded_nodes->count(snode)) &&
193           subgraph_selector->SelectOutput(*cur_node, *snode, PrepareNodeAttr(g, *snode));
194       if (select_output) {
195         // it->first is a subgraph node
196         snode->label = label;
197         node_queue.push(snode);
198       } else if (snode->label == -1) {
199         // it->first is an output node of the subgraph
200         non_subgraph_node_map[it->first].second.push_back(cur_node->node);
201       }
202     }
203   }
204   // prepare to check if there is a cycle
205   auto node_cmp = [&] (const nnvm::Node* node1, const nnvm::Node* node2) {
206     return indexed_graph.node_id(node1) < indexed_graph.node_id(node2);
207   };
208   std::vector<const nnvm::Node*> non_subgraph_nodes;
209   non_subgraph_nodes.reserve(non_subgraph_node_map.size());
210   for (auto& kv : non_subgraph_node_map) {
211     auto& output_nodes = kv.second.first;
212     std::sort(output_nodes.begin(), output_nodes.end(), node_cmp);
213     auto& input_nodes = kv.second.second;
214     std::sort(input_nodes.begin(), input_nodes.end(), node_cmp);
215     non_subgraph_nodes.push_back(kv.first);
216   }
217   // check whether there is a cycle between the subgraph and its input/output nodes
218   auto is_ancestor = [&](const nnvm::Node* ancestor, const nnvm::Node* descendant,
219                          const std::vector<BiDirectedNode*>& snodes) {
220     if (ancestor == descendant) return true;
221     std::unordered_set<nnvm::Node*> snode_set;
222     for (const auto& sn : snodes) {
223       snode_set.insert(sn->node);
224     }
225     std::stack<const nnvm::Node*> s;
226     s.push(descendant);
227     size_t count = 0;
228     while (!s.empty() && count < indexed_graph.num_nodes()) {
229       ++count;
230       const nnvm::Node* top = s.top();
231       s.pop();
232       if (top == ancestor) {
233         return true;
234       }
235       for (const auto& entry : top->inputs) {
236         // when searching for the ancestor, the path cannot cross any subgraph node
237         if (!snode_set.count(entry.node.get())) {
238           s.push(entry.node.get());
239         }
240       }
241     }
242     return false;
243   };
244   std::sort(non_subgraph_nodes.begin(), non_subgraph_nodes.end(), node_cmp);
245   int excluded_node_id = -1;
246   for (size_t i = 0; i < non_subgraph_nodes.size(); ++i) {
247     auto it1 = non_subgraph_node_map.find(non_subgraph_nodes[i]);
248     CHECK(it1 != non_subgraph_node_map.end());
249     auto& output_nodes = it1->second.first;  // has been top sorted
250     auto& input_nodes = it1->second.second;  // has been top sorted
251     if (!output_nodes.empty() && !input_nodes.empty()) {
252       // there is a loop between node i and the subgraph
253       const auto node_id = std::max(indexed_graph.node_id(output_nodes.back()),
254                                     indexed_graph.node_id(input_nodes.back()));
255       excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
256     } else if (!input_nodes.empty()) {
257       // node i is an input to the subgraph, find out if there is a node j
258       // which is an output of the subgraph and also a child of node i.
259       for (size_t j = i + 1; j < non_subgraph_nodes.size(); ++j) {
260         auto it2 = non_subgraph_node_map.find(non_subgraph_nodes[j]);
261         CHECK(it2 != non_subgraph_node_map.end());
262         // i is topologically before j, j might be a direct/indirect output node of i
263         CHECK_LT(indexed_graph.node_id(it1->first), indexed_graph.node_id(it2->first));
264         if (!it2->second.first.empty() && is_ancestor(it1->first, it2->first, *subgraph_nodes)) {
265           // found a loop
266           const auto node_id = std::max(indexed_graph.node_id(input_nodes.back()),
267                                         indexed_graph.node_id(it2->second.first.back()));
268           excluded_node_id = std::max(excluded_node_id, static_cast<int>(node_id));
269         }
270       }
271     }
272   }
273 
274   if (excluded_node_id != -1) {
275     CHECK_LT(excluded_node_id, static_cast<int>(simple_nodes.size()));
276     excluded_nodes->insert(simple_nodes[excluded_node_id].get());
277     ResetNodeLabels(g, simple_nodes, subgraph_nodes);
278     return false;
279   }
280   auto sim_node_cmp = [&] (const BiDirectedNode* node1, const BiDirectedNode* node2) {
281     return indexed_graph.node_id(node1->node) < indexed_graph.node_id(node2->node);
282   };
283   std::sort(subgraph_nodes->begin(), subgraph_nodes->end(), sim_node_cmp);
284   return true;
285 }
286 
287 /*!
288  * \brief Finds all the nodes belonging to the same subgraph given a seed node.
289  * \param g the whole graph
290  * \subgraph_selector determines whether the visited node should be choosen or not
291  * \label the label of the current subgraph
292  * \snid node id of the seed simple node
293  * \simple_nodes all simple nodes in the top sorted order
294  * \subgraph_nodes all the nodes belonging to the same subgraph of seed node
295  * \return Subgraph node candidates sorted in the topological order
296  */
PreSelectSubgraphNodes(const nnvm::Graph & g,SubgraphSelectorV2Ptr subgraph_selector,const int label,const size_t snid,const std::vector<BiDirectedNodePtr> & simple_nodes,std::vector<BiDirectedNode * > * subgraph_nodes)297 void PreSelectSubgraphNodes(const nnvm::Graph& g, SubgraphSelectorV2Ptr subgraph_selector,
298                             const int label, const size_t snid,
299                             const std::vector<BiDirectedNodePtr>& simple_nodes,
300                             std::vector<BiDirectedNode*>* subgraph_nodes) {
301   std::unordered_set<const BiDirectedNode*> excluded_nodes;
302   size_t n_excluded_nodes = 0;
303   const size_t max_num_retry = simple_nodes.size() * simple_nodes.size();
304   size_t count = 0;
305   bool success = false;
306   while (!success && count < max_num_retry) {
307     success = LabelSubgraph(g, subgraph_selector, label, snid, simple_nodes, subgraph_nodes,
308                             &excluded_nodes);
309     if (!success) {
310       // Failed to label subgraph due to a cycle
311       // If the number of excluded_nodes didn't change since the last iteration,
312       // this means that there is no possible subgraph for the current node snid, we break
313       // Otherwise, we keep trying (with the excluded nodes tagged)
314       if (excluded_nodes.size() == n_excluded_nodes) {
315         break;
316       }
317       n_excluded_nodes = excluded_nodes.size();
318       std::string excluded_node_names;
319       for (auto node : excluded_nodes) {
320         excluded_node_names += node->node->attrs.name + ", ";
321       }
322       static int verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", 1);
323       if (verbose > 1) {
324         LOG(INFO) << "Found a cycle when BFS from node " << simple_nodes[snid]->node->attrs.name
325                   << ". Excluding nodes " << excluded_node_names << "and retrying";
326       }
327       subgraph_selector->Reset();
328     }
329     ++count;
330   }
331   if (!success) {
332     LOG(INFO) << "Tried " << count << " times of finding subgraphs starting from node "
333               << simple_nodes[snid]->node->attrs.name
334               << " without success because a loop "
335                  "is always found between the subgraph and some other nodes. Will treat "
336                  "seed node "
337               << simple_nodes[snid]->node->attrs.name << "as a subgraph with one node";
338     CHECK(subgraph_nodes->empty());
339     simple_nodes[snid]->label = label;
340     subgraph_nodes->push_back(simple_nodes[snid].get());
341   }
342 }
343 
SelectSubgraphNodes(nnvm::Graph * g,SubgraphSelectorV2Ptr subgraph_selector,const std::vector<BiDirectedNodePtr> & simple_nodes,std::vector<std::vector<BiDirectedNode * >> * subgraph_nodes,std::vector<SubgraphSelectorV2Ptr> * subgraph_selectors,const BiDirectedNode * node,const size_t snid,size_t * subgraph_id)344 void SelectSubgraphNodes(nnvm::Graph* g, SubgraphSelectorV2Ptr subgraph_selector,
345                          const std::vector<BiDirectedNodePtr>& simple_nodes,
346                          std::vector<std::vector<BiDirectedNode*>>* subgraph_nodes,
347                          std::vector<SubgraphSelectorV2Ptr>* subgraph_selectors,
348                          const BiDirectedNode* node, const size_t snid, size_t* subgraph_id) {
349   const auto& indexed_graph = g->indexed_graph();
350 
351   auto node_cmp = [&] (const BiDirectedNode* node1, const BiDirectedNode* node2) {
352     return indexed_graph.node_id(node1->node) < indexed_graph.node_id(node2->node);
353   };
354   if ((simple_nodes[snid]->label == -1) &&
355       subgraph_selector->Select(*node, PrepareNodeAttr(*g, *node))) {
356     // pre-select nodes that can be grouped in a subgraph
357     std::vector<BiDirectedNode*> preselected_nodes;
358     PreSelectSubgraphNodes(*g, subgraph_selector, *subgraph_id, snid, simple_nodes,
359                             &preselected_nodes);
360 
361     // filter out unqualified pre-selected nodes
362     std::vector<BiDirectedNode*> filtered_nodes = subgraph_selector->Filter(preselected_nodes);
363 
364     // reset node labels that are not in filtered nodes
365     for (const auto n : preselected_nodes) {
366       const auto nit = std::find(filtered_nodes.begin(), filtered_nodes.end(), n);
367       if (nit == filtered_nodes.end()) {
368         n->label = -1;
369       }
370     }
371 
372     if (filtered_nodes.size()) {
373       // make sure filtered_nodes is a subset of preselected_nodes
374       for (const auto n : filtered_nodes) {
375         const auto nit = std::find(preselected_nodes.begin(), preselected_nodes.end(), n);
376         CHECK(nit != preselected_nodes.end())
377             << "Node " << n->node->attrs.name
378             << " is not found in the pre-selected subgraph nodes."
379                " Please make sure that no new nodes were added in your subgraph"
380                " selector's Filter function";
381       }
382 
383       // make sure nodes are sorted
384       std::sort(filtered_nodes.begin(), filtered_nodes.end(), node_cmp);
385       subgraph_nodes->push_back(filtered_nodes);
386       subgraph_selectors->push_back(subgraph_selector);
387       (*subgraph_id)++;
388     }
389   }
390 }
391 
392 /*!
393  * \brief Finds subgraphs with all nodes that meet certain criteria.
394  * All nodes in a subgraph are marked with the same label.
395  */
FindSubgraphs(nnvm::Graph * g,const SubgraphProperty & subg_prop,const std::vector<BiDirectedNodePtr> & simple_nodes,std::vector<std::vector<BiDirectedNode * >> * subgraph_nodes,std::vector<SubgraphSelectorV2Ptr> * subgraph_selectors)396 void FindSubgraphs(nnvm::Graph* g,
397                    const SubgraphProperty &subg_prop,
398                    const std::vector<BiDirectedNodePtr>& simple_nodes,
399                    std::vector<std::vector<BiDirectedNode*>>* subgraph_nodes,
400                    std::vector<SubgraphSelectorV2Ptr>* subgraph_selectors) {
401   const auto& indexed_graph = g->indexed_graph();
402   CHECK_EQ(indexed_graph.num_nodes(), simple_nodes.size());
403 
404   size_t subgraph_id = 0;
405   for (size_t i = 0; i < simple_nodes.size(); ++i) {
406     const auto snode = simple_nodes[i];
407     SubgraphSelectorV2Ptr subgraph_selector = subg_prop.CreateSubgraphSelectorV2();
408     SelectSubgraphNodes(g, subgraph_selector, simple_nodes, subgraph_nodes, subgraph_selectors,
409                         snode.get(), i, &subgraph_id);
410   }
411 }
412 
413 /*!
414  * \brief Sorts entries according to their topological order.
415  * Note that entry ids cannot be used to sort entries.
416  * \param entry_top_order_map mapping from entry pointer to its topological position in the graph
417  * \param entries Node entries to be sorted
418  */
SortEntries(const std::unordered_map<const nnvm::NodeEntry *,size_t> & entry_top_order_map,std::vector<nnvm::NodeEntry * > * entries)419 void SortEntries(const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry_top_order_map,
420                  std::vector<nnvm::NodeEntry*>* entries) {
421   auto entry_cmp = [&](const nnvm::NodeEntry* e1, const nnvm::NodeEntry* e2) {
422     const auto it1 = entry_top_order_map.find(e1);
423     CHECK(it1 != entry_top_order_map.end());
424     const auto it2 = entry_top_order_map.find(e2);
425     CHECK(it2 != entry_top_order_map.end());
426     return it1->second < it2->second;
427   };
428   std::sort(entries->begin(), entries->end(), entry_cmp);
429 }
430 
431 /*!
432  * \brief Given a subgraph, find the input entries of a subgraph.
433  * \param g pointer to the whole graph
434  * \param simple_nods vector of simple nodes in top sorted order
435  * \param subgraph_nodes vector of pointers of simples of a subgraph.
436  * \param entry_top_order_map mapping entry pointer to its top sorted position
437  * \param input_entries input entries of the subgraph
438  */
FindInputEntries(const nnvm::Graph & g,const std::vector<BiDirectedNodePtr> & simple_nodes,const std::vector<BiDirectedNode * > & subgraph_nodes,const std::unordered_map<const nnvm::NodeEntry *,size_t> & entry_top_order_map,std::vector<nnvm::NodeEntry * > * input_entries)439 void FindInputEntries(const nnvm::Graph& g,
440                       const std::vector<BiDirectedNodePtr>& simple_nodes,
441                       const std::vector<BiDirectedNode*>& subgraph_nodes,
442                       const std::unordered_map<const nnvm::NodeEntry*, size_t>& entry_top_order_map,
443                       std::vector<nnvm::NodeEntry*>* input_entries) {
444   const auto& indexed_graph = g.indexed_graph();
445   int label = -1;
446   for (auto subgraph_node : subgraph_nodes) {
447     if (label == -1) {
448       label = subgraph_node->label;
449     } else {
450       CHECK_EQ(subgraph_node->label, label);
451     }
452     auto& inputs = subgraph_node->node->inputs;
453     for (auto &e : inputs) {
454       if (indexed_graph.exist(e.node.get())) {
455         // e's source node is not a subgraph node
456         const auto nid = indexed_graph.node_id(e.node.get());
457         // this is a node not belonging to the subgraph
458         if (simple_nodes[nid]->label != label) {
459           input_entries->push_back(&e);
460         }
461       } else {
462         // e's source node is a subgraph node.
463         // In this case, two subgraphs are adjacent.
464         input_entries->push_back(&e);
465       }
466     }
467   }
468   SortEntries(entry_top_order_map, input_entries);
469 }
470 
471 /*!
472  * \brief Given a subgraph, find the output entries of a subgraph.
473  * \param g pointer to the whole graph
474  * \param simple_nods vector of simple nodes in top sorted order
475  * \param subgraph_nodes vector of pointers of simples of a subgraph.
476  * \param entry_top_order_map mapping entry pointer to its top sorted position
477  * \param output_entries output entries of the subgraph
478  */
FindOutputEntries(nnvm::Graph * g,const std::vector<BiDirectedNodePtr> & simple_nodes,const std::vector<BiDirectedNode * > & subgraph_nodes,const std::unordered_map<const nnvm::NodeEntry *,size_t> & entry_top_order_map,std::vector<nnvm::NodeEntry * > * output_entries)479 void FindOutputEntries(nnvm::Graph* g,
480                        const std::vector<BiDirectedNodePtr>& simple_nodes,
481                        const std::vector<BiDirectedNode*>& subgraph_nodes,
482                        const std::unordered_map<const nnvm::NodeEntry*, size_t>&
483                          entry_top_order_map,
484                        std::vector<nnvm::NodeEntry*>* output_entries) {
485   if (subgraph_nodes.empty()) return;
486   const auto& indexed_graph = g->indexed_graph();
487   int label = -1;
488   for (auto subgraph_node : subgraph_nodes) {
489     if (label == -1) {
490       label = subgraph_node->label;
491     } else {
492       CHECK_EQ(subgraph_node->label, label);
493     }
494     for (auto &output_node : subgraph_node->outputs) {
495       if (indexed_graph.exist(output_node.first)) {
496         // if the output node is a normal graph node (not a subgraph node)
497         const auto nid = indexed_graph.node_id(output_node.first);
498         // this is a node not belonging to the current subgraph
499         if (simple_nodes[nid]->label != label) {
500           for (auto idx : output_node.second) {
501             auto& e = simple_nodes[nid]->node->inputs[idx];
502             output_entries->push_back(&e);
503           }
504         }
505       } else {
506         // if the output node is a subgraph node
507         // two graphs are adjacent
508         for (auto idx : output_node.second) {
509           output_entries->push_back(&(output_node.first->inputs[idx]));
510         }
511       }
512     }
513   }
514   // Check if current subgraph contains a node which is the last node
515   // of the whole graph. If so, save its corresponding entry as well.
516   for (auto &entry : g->outputs) {
517     // The entry might has been updated as an output of
518     // a subgraph node. In this case, no need
519     // to check its source for the current subgraph. Otherwise,
520     // do the following.
521     if (indexed_graph.exist(entry.node.get())) {
522       const auto nid = indexed_graph.node_id(entry.node.get());
523       if (simple_nodes[nid]->label == label) {
524         output_entries->push_back(&entry);
525       }
526     }
527   }
528   SortEntries(entry_top_order_map, output_entries);
529 }
530 
531 /*!
532  * \brief Given a computation graph and a set of input node entries, this function cuts
533  * the node entries and creates new variable nodes as the input nodes of the
534  * subgraph. It returns the nodes that connect to the subgraph directly and
535  * the names of the new variable nodes.
536  */
CutGraphInputs(const std::vector<nnvm::NodeEntry * > & input_entries,std::vector<nnvm::NodeEntry> * orig_entries,std::vector<nnvm::NodeEntry> * unique_orig_entries,std::vector<nnvm::NodeEntry * > * unique_input_entries,const bool skip_var=false,const bool dedup=false)537 void CutGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
538                     std::vector<nnvm::NodeEntry> *orig_entries,
539                     std::vector<nnvm::NodeEntry> *unique_orig_entries,
540                     std::vector<nnvm::NodeEntry*> *unique_input_entries,
541                     const bool skip_var = false,
542                     const bool dedup = false) {
543   orig_entries->resize(input_entries.size());
544   // map for creating unique var nodes for deduplicating entries from the same node
545   std::unordered_map<std::string, nnvm::NodeEntry> name_map;
546   std::unordered_map<std::string, int> name_count_map;
547 
548   for (size_t i = 0; i < input_entries.size(); ++i) {
549     nnvm::NodeEntry *e = input_entries[i];
550     // If the node is a variable itself, we may want to skip the node.
551     if (e->node->is_variable() && skip_var) {
552       continue;
553     }
554     // save all original entries
555     orig_entries->at(i) = *e;
556     // get unique name for this entry
557     nnvm::Symbol sym;
558     sym.outputs.push_back(*e);
559     const auto output_names = sym.ListOutputNames();
560     CHECK_EQ(output_names.size(), 1U);
561     const std::string& var_name = output_names[0];
562     // check if this entry is a duplicate
563     if (name_count_map.count(var_name) == 0) {
564       // first use of this node as input to subgraph
565       name_count_map.emplace(var_name, 0);
566       unique_orig_entries->push_back(*e);
567       unique_input_entries->push_back(e);
568       nnvm::ObjectPtr n = nnvm::CreateVariableNode(var_name + std::to_string(0));
569       name_map.emplace(var_name, nnvm::NodeEntry{n, 0, 0});
570     } else {
571       // other use of same node as input to subgraph
572       name_count_map[var_name]++;
573     }
574 
575     if (dedup) {
576       *e = name_map[var_name];
577     } else {
578       nnvm::ObjectPtr n = nnvm::CreateVariableNode(
579         var_name + std::to_string(name_count_map[var_name]));
580       *e = nnvm::NodeEntry{n, 0, 0};
581     }
582   }
583 }
584 
585 /*!
586  * \brief This function reattaches the original input nodes that were cut
587  * by CutGraphInputs. This function is used when subgraphs are rejected, it
588  * reattaches the subgraph back to the main graph where it was cut earlier.
589  */
ReattachGraphInputs(const std::vector<nnvm::NodeEntry * > & input_entries,std::vector<nnvm::NodeEntry> * orig_entries)590 void ReattachGraphInputs(const std::vector<nnvm::NodeEntry*> &input_entries,
591                          std::vector<nnvm::NodeEntry> *orig_entries) {
592   for (size_t i = 0; i < input_entries.size(); ++i) {
593     nnvm::NodeEntry *e = input_entries[i];
594     *e = orig_entries->at(i);
595   }
596 }
597 
598 /*!
599  * \brief Replace a set of nodes belonging to the same subgraph with a subgraph node
600  * and keep the subgraph in the subgraph node.
601  */
CreateSubgraphNode(nnvm::Graph * g,const std::vector<BiDirectedNodePtr> & simple_nodes,const std::vector<BiDirectedNode * > & subgraph_nodes,const SubgraphSelectorV2Ptr & subgraph_selector,const size_t subgraph_id,std::unordered_map<const nnvm::NodeEntry *,size_t> * entry_top_order_map)602 void CreateSubgraphNode(nnvm::Graph* g,
603                         const std::vector<BiDirectedNodePtr>& simple_nodes,
604                         const std::vector<BiDirectedNode*>& subgraph_nodes,
605                         const SubgraphSelectorV2Ptr& subgraph_selector,
606                         const size_t subgraph_id,
607                         std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
608 #if DEBUG_SUBGRAPH
609   LOG(INFO) << "Searching for input entries...";
610 #endif
611   bool dedup_subgraph = g->HasAttr("dedup_subgraph");
612   std::vector<nnvm::NodeEntry*> input_entries;  // nodes that produce inputs to subgraph nodes
613   FindInputEntries(*g, simple_nodes, subgraph_nodes, *entry_top_order_map, &input_entries);
614   std::vector<nnvm::NodeEntry> orig_input_entries;  // original input entries (dupes)
615   std::vector<nnvm::NodeEntry> unique_orig_entries;  // unique original input entries
616   std::vector<nnvm::NodeEntry*> unique_input_entries;  // unique modified subgraph inputs
617   CutGraphInputs(input_entries, &orig_input_entries, &unique_orig_entries,
618                  &unique_input_entries, false, dedup_subgraph);
619 #if DEBUG_SUBGRAPH
620   PrintNodeEntries(input_entries);
621   LOG(INFO) << "Searching for output entries...";
622 #endif
623   std::vector<nnvm::NodeEntry*> output_entries;
624   FindOutputEntries(g, simple_nodes, subgraph_nodes, *entry_top_order_map, &output_entries);
625 
626   // Create a subgraph for the subgraph node
627   // entries are in topological order, with duplicates being neighbors
628   nnvm::Symbol sym;
629   size_t idx = 0;
630   nnvm::NodeEntryEqual node_equal;
631   sym.outputs.resize(output_entries.size());
632   for (size_t i = 0; i < output_entries.size(); ++i) {
633     if (dedup_subgraph) {
634       if (i == 0) {  // add first entry
635         sym.outputs[idx] = *output_entries[i];
636       } else if (!node_equal(sym.outputs[idx], *output_entries[i])) {  // compare to see if diff
637         // add new entries
638         idx++;
639         sym.outputs[idx] = *output_entries[i];
640       }  // else skip over dupe entries
641     } else {
642       sym.outputs[i] = *output_entries[i];
643     }
644   }
645   if (dedup_subgraph)
646     sym.outputs.resize(idx+1);
647 
648   const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property");
649   if (dedup_subgraph)
650     subg_prop->InitSubgraphInputs(&unique_input_entries, &unique_orig_entries);
651   else
652     subg_prop->InitSubgraphInputs(&input_entries, &orig_input_entries);
653   nnvm::ObjectPtr n = subg_prop->CreateSubgraphNode(sym, subgraph_selector, subgraph_id);
654   // CreateSubgraphNode returns NULL if subgraph property determines that subgraph is sub-optimal
655   // In that case, subgraph node is not created and graph is not modified
656   if (n) {
657     // Connect the external nodes to the subgraph node.
658     subg_prop->ConnectSubgraphOutputs(n, &output_entries);
659     if (dedup_subgraph)
660       subg_prop->ConnectSubgraphInputs(n, &unique_input_entries, &unique_orig_entries);
661     else
662       subg_prop->ConnectSubgraphInputs(n, &input_entries, &orig_input_entries);
663 
664     const auto& indexed_graph = g->indexed_graph();
665     for (size_t i = 0; i < n->inputs.size(); ++i) {
666       auto& e = n->inputs[i];
667       // update entry_top_order_map with newly created orig_input_entries
668       auto it = entry_top_order_map->find(input_entries[i]);
669       CHECK(it != entry_top_order_map->end());
670       entry_top_order_map->emplace(&e, it->second);
671       // update input entries' source simple nodes' outputs map
672       nnvm::Node* node = e.node.get();
673       if (indexed_graph.exist(node)) {
674         const auto nid = indexed_graph.node_id(node);
675         BiDirectedNode* sn = simple_nodes[nid].get();
676         for (BiDirectedNode* dest_node : subgraph_nodes) {
677           sn->outputs.erase(dest_node->node);
678         }
679         sn->outputs[n.get()].push_back(i);
680       }
681     }
682   } else {
683     ReattachGraphInputs(input_entries, &orig_input_entries);
684   }
685 #if DEBUG_SUBGRAPH
686   if (n)
687     LOG(INFO) << "Subgraph node created and output_entries updated.";
688   else
689     LOG(INFO) << "Subgraph node not created, output_entries not updated.";
690   PrintNodeEntries(output_entries);
691 #endif
692 }
693 
694 /*!
695  * \brief Adjust a set of nodes belonging to the same subgraph. No new node is created, but
696  * adjust selected nodes' attributes.
697  * This can be used to implement peephole optimization. For example, adjust calibration information
698  * of quantized nodes.
699  */
AdjustSubgraphNode(nnvm::Graph * g,const std::vector<BiDirectedNode * > & subgraph_nodes,const SubgraphSelectorV2Ptr & subgraph_selector,const size_t subgraph_id)700 void AdjustSubgraphNode(nnvm::Graph* g,
701                         const std::vector<BiDirectedNode*>& subgraph_nodes,
702                         const SubgraphSelectorV2Ptr& subgraph_selector,
703                         const size_t subgraph_id) {
704   std::vector<nnvm::Node*> node_list;
705   for (auto node : subgraph_nodes) {
706     node_list.push_back(node->node);
707   }
708 
709   const SubgraphPropertyPtr& subg_prop = g->GetAttr<SubgraphPropertyPtr>("subgraph_property");
710   subg_prop->AdjustSubgraphNode(node_list, subgraph_selector, subgraph_id);
711 }
712 
713 }  // namespace sg
714 
715 /*!
716  * \brief Sort entries of all the nodes' inputs vectors in the topological order.
717  * This is going to be used to sort input/output entries of subgraphs to keep
718  * the topological order unchanged.
719  */
TopSortEntries(const nnvm::Graph & g,std::unordered_map<const nnvm::NodeEntry *,size_t> * entry_top_order_map)720 void TopSortEntries(const nnvm::Graph& g,
721                     std::unordered_map<const nnvm::NodeEntry*, size_t>* entry_top_order_map) {
722   CHECK(entry_top_order_map != nullptr);
723   std::unordered_set<const nnvm::Node*> visited;
724   // tuple: (graph node, index of node's inputs, node entry as the output of the graph node)
725   std::stack<std::tuple<nnvm::Node*, size_t, const nnvm::NodeEntry*>> s;
726   auto in_degree = [] (const nnvm::Node* node)->size_t {
727     if (!node) {
728       return 0;
729     }
730     CHECK_EQ(node->control_deps.size(), 0U);
731     return node->inputs.size();
732   };
733   for (auto& e : g.outputs) {
734     nnvm::Node* node = e.node.get();
735     if (visited.count(node) == 0U) {
736       s.emplace(node, 0U, &e);
737       visited.insert(node);
738     } else {
739       // The entry's source node has been visited before.
740       // Marking the order for it.
741       entry_top_order_map->emplace(&e, entry_top_order_map->size());
742     }
743     while (!s.empty()) {
744       auto& top = s.top();
745       if (std::get<1>(top) == in_degree(std::get<0>(top))) {
746         // The node's inputs has been exhausted.
747         entry_top_order_map->emplace(std::get<2>(top), entry_top_order_map->size());
748         s.pop();
749       } else {
750         // The node still has input entries not visited.
751         CHECK_LT(std::get<1>(top), std::get<0>(top)->inputs.size());
752         auto& entry = std::get<0>(top)->inputs[std::get<1>(top)++];
753         nnvm::Node* input_node = entry.node.get();
754         if (visited.count(input_node) == 0U) {
755           // The entry's source node has not been visited.
756           // Push the entry to the stack for marking order later.
757           s.emplace(input_node, 0U, &entry);
758           visited.insert(input_node);
759         } else {
760           // The entry's source node has been visited before.
761           // Marking the order for it.
762           entry_top_order_map->emplace(&entry, entry_top_order_map->size());
763         }
764       }
765     }
766   }
767 }
768 
BuildSubgraph(nnvm::Graph && g)769 nnvm::Graph BuildSubgraph(nnvm::Graph&& g) {
770     static int verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", 1);
771   if (!g.HasAttr("subgraph_property")) {  // treat the whole graph as a subgraph
772     if (verbose > 1) {
773       LOG(INFO) << "The graph has no attribute of subgraph_property attached. "
774                    "The original graph is returned.";
775     }
776     return g;
777   }
778   using namespace sg;
779 
780   const SubgraphPropertyPtr& subg_prop = g.GetAttr<SubgraphPropertyPtr>("subgraph_property");
781   if (verbose > 1) {
782     const std::string& prop_name = subg_prop->HasAttr("property_name")
783                                        ? subg_prop->GetAttr<std::string>("property_name")
784                                        : "partition graph";
785     LOG(INFO) << "start to execute " << prop_name << ".";
786   }
787   // top sort NodeEntry of all the nodes' inputs
788   std::unordered_map<const nnvm::NodeEntry*, size_t> entry_top_order_map;
789   TopSortEntries(g, &entry_top_order_map);
790 
791   // Create double directional graph for ease of finding subgraphs
792   std::vector<BiDirectedNodePtr> simple_nodes;
793   CreateSimpleGraph(g, &simple_nodes);
794   std::vector<std::vector<BiDirectedNode*>> subgraph_nodes;
795   std::vector<SubgraphSelectorV2Ptr> subgraph_selectors;
796   FindSubgraphs(&g, *subg_prop, simple_nodes, &subgraph_nodes, &subgraph_selectors);
797   CHECK_EQ(subgraph_nodes.size(), subgraph_selectors.size());
798   for (size_t i = 0; i < subgraph_nodes.size(); ++i) {
799 #if DEBUG_SUBGRAPH
800     std::set<BiDirectedNode*> simple_node_set(subgraph_nodes[i].begin(), subgraph_nodes[i].end());
801     CHECK_EQ(simple_node_set.size(), subgraph_nodes[i].size());
802     PrintSubgraph(subgraph_nodes[i]);
803 #endif
804     auto ptype = subg_prop->GetPropertyType();
805     if (ptype == SubgraphProperty::SgPropertyType::kCreate) {
806       CreateSubgraphNode(&g, simple_nodes, subgraph_nodes[i], subgraph_selectors[i], i,
807                          &entry_top_order_map);
808     } else {
809       CHECK_EQ(ptype, SubgraphProperty::SgPropertyType::kAdjust);
810       AdjustSubgraphNode(&g, subgraph_nodes[i], subgraph_selectors[i], i);
811     }
812   }
813   return g;
814 }
815 
// Register BuildSubgraph as an NNVM pass under the name "BuildSubgraph". The pass is
// flagged as one that changes the graph.
NNVM_REGISTER_PASS(BuildSubgraph)
.describe("Apply a subgraph pass according to the user defined rules "
          "in a derived class of SubgraphProperty")
.set_body(BuildSubgraph)
.set_change_graph(true);
821 
822 
823 }  // namespace op
824 }  // namespace mxnet
825