/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file graph_executor.cc
 * \brief graph executor
 */
#include <mxnet/base.h>
#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>
#include <vector>
#include <set>
#include <algorithm>

#include "./exec_pass.h"
#include "./graph_executor.h"
#include "./cuda_graphs.h"
#include "../profiler/profiler.h"
#include "../common/utils.h"
#include "../common/exec_utils.h"
#include "../operator/subgraph/subgraph_property.h"
#include "../operator/operator_common.h"

namespace mxnet {
namespace exec {

using namespace mxnet::common;

static const std::string GetDefaultSubgraphBackend() {
#if MXNET_USE_MKLDNN == 1
  return std::string("MKLDNN");
#else
  return std::string();
#endif
}

GraphExecutor::GraphExecutor(const nnvm::Symbol& symbol) {
  log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
  need_grad_ = false;
  is_dynamic_ = false;
  subgraph_property_ = dmlc::GetEnv("MXNET_SUBGRAPH_BACKEND", GetDefaultSubgraphBackend());
  if (subgraph_property_ == "NONE") {
    subgraph_property_ = std::string();
    LOG(INFO) << "MXNET_SUBGRAPH_BACKEND=NONE is detected, subgraph backend is not in use";
  }
  engine_ref_ = Engine::_GetSharedRef();
  symbol_ = symbol.Copy();
}
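// The subgraph backend is selected at runtime through the environment: for example,
// launching a process with MXNET_SUBGRAPH_BACKEND=NONE disables subgraph partitioning,
// while leaving the variable unset falls back to GetDefaultSubgraphBackend() above
// (MKLDNN when built with MXNET_USE_MKLDNN, otherwise no backend).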

GraphExecutor::~GraphExecutor() {
  for (auto& n : op_nodes_) {
    if (n.cached_opr != nullptr) {
      Engine::Get()->DeleteOperator(n.cached_opr);
    }
  }
  // clean up seg ops
  for (auto& seg : cached_seg_opr_) {
    if (seg.opr != nullptr) {
      Engine::Get()->DeleteOperator(seg.opr);
    }
  }
}

void GraphExecutor::Forward(bool is_train) {
  RunOps(is_train, 0, num_forward_nodes_);
}

void GraphExecutor::PartialForward(bool is_train, int step, int *step_left) {
  size_t sstep = static_cast<size_t>(step);
  if (sstep >= num_forward_nodes_) {
    *step_left = 0;
    return;
  }
  RunOps(is_train, sstep, sstep + 1);
  *step_left = static_cast<int>(num_forward_nodes_ - sstep - 1);
}
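// Illustrative sketch (not part of the executor itself): stepping through the forward
// pass one node at a time with PartialForward, assuming `exec` points to an already
// bound GraphExecutor.
//
//   int step_left = 1;
//   for (int step = 0; step_left > 0; ++step) {
//     exec->PartialForward(/*is_train=*/false, step, &step_left);
//   }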

void GraphExecutor::Backward(const std::vector<NDArray>& head_grads, bool is_train) {
  {
    const auto& idx = graph_.indexed_graph();
    if (num_forward_inputs_ != idx.input_nodes().size()) {
      for (size_t i = 0; i < head_grad_array_.size(); ++i) {
        if (!head_grad_array_[i].is_none()) {
          CHECK(i < head_grads.size() && !head_grads[i].is_none())
              << "Because the last operator is not a loss function, "
              << "head_gradient is required when calling backward. "
              << "If you are attempting to minimize the output as "
              << "an objective, please modify your network and "
              << "pass it through the make_loss symbol.";
          const NDArray &from = head_grads[i];
          NDArray &to = head_grad_array_[i];
          if (this->is_dynamic_) {
            to.WaitToRead();
            if (!shape_is_known(to.shape())) {
              to.Init(from.shape());
            }
          }
          CopyFromTo(from, &to);
        }
      }
    }
  }
  if (this->is_dynamic_) {
    graph_ = InferShape(std::move(graph_), {}, "");
    mxnet::ShapeVector rshape = graph_.MoveCopyAttr<mxnet::ShapeVector>("shape");
    const auto& idx = graph_.indexed_graph();
    for (size_t nid = 0; nid < idx.num_nodes(); ++nid) {
      const auto& inode = idx[nid];
      if (inode.source->is_variable()) continue;
      OpNode& opnode = op_nodes_[nid];
      if (opnode.skip_exec_node) continue;
      for (NDArray &array : opnode.exec->in_array) {
        array.WaitToRead();
        if (!shape_is_known(array.shape())) {
          array.SetShapeFromChunk();
        }
      }
      int i = 0;
      for (NDArray &array : opnode.exec->in_array) {
        array.WaitToRead();
        if (!shape_is_known(array.shape())) {
          array.SetShapeFromChunk();
        }
        if (!shape_is_known(array.shape())) {
          mxnet::TShape shape = rshape[idx.entry_id(inode.inputs[i])];
          if (shape_is_known(shape)) {
            array.ReshapeAndAlloc(shape);
          }
        }
        ++i;
      }
      i = 0;
      for (NDArray &array : opnode.exec->out_array) {
        array.WaitToRead();
        if (!shape_is_known(array.shape())) {
          array.SetShapeFromChunk();
        }
        if (!shape_is_known(array.shape())) {
          mxnet::TShape shape = rshape[idx.entry_id(nid, i)];
          if (shape_is_known(shape)) {
            array.ReshapeAndAlloc(shape);
          }
        }
        ++i;
      }
    }
    graph_.attrs["shape"] = std::make_shared<dmlc::any>(rshape);
  }
  const auto& idx = graph_.indexed_graph();
  RunOps(is_train, num_forward_nodes_, idx.num_nodes());
}

void GraphExecutor::Print(std::ostream &os) const {  // NOLINT(*)
  nnvm::Symbol s;
  s.outputs = graph_.outputs;
  s.Print(os);
  // message to be backward compatible with the memonger
  size_t total_bytes = graph_.GetAttr<size_t>("storage_allocated_bytes");
  os << "Total " << (total_bytes >> 20UL) << " MB allocated\n";
  os << "Total " << 11 << " TempSpace resource requested\n";
}

/*!
 * \brief Return the "optimized" symbol contained in the executor graph.
 */
nnvm::Symbol GraphExecutor::GetOptimizedSymbol() {
  Symbol ret;
  ret.outputs = std::vector<nnvm::NodeEntry>(graph_.outputs.begin(),
      graph_.outputs.begin() + num_forward_outputs_);
  return ret.Copy();
}

void GraphExecutor::SetMonitorCallback(const MonitorCallback& callback, bool monitor_all) {
  CHECK(callback) << "invalid callback";
  monitor_callback_ = callback;
  monitor_all_ = monitor_all;
}

const std::vector<NDArray>& GraphExecutor::outputs() const {
  if (this->is_dynamic_) {
    for (const NDArray &array : output_arrays_) {
      array.WaitToRead();
      if (!shape_is_known(array.shape())) {
        const_cast<NDArray &>(array).SetShapeFromChunk();
      }
    }
  }
  return output_arrays_;
}

const std::unordered_map<std::string, NDArray>& GraphExecutor::in_arg_map() const {
  return in_arg_map_;
}

const std::unordered_map<std::string, NDArray>& GraphExecutor::arg_grad_map() const {
  return arg_grad_map_;
}

const std::unordered_map<std::string, NDArray>& GraphExecutor::aux_state_map() const {
  return aux_state_map_;
}

static nnvm::NodeEntry AttrHint(nnvm::NodeEntry src, nnvm::NodeEntry like) {
  static const Op* id_like = Op::Get("_identity_with_attr_like_rhs");
  nnvm::ObjectPtr n = nnvm::Node::Create();
  n->attrs.op = id_like;
  n->attrs.name = src.node->attrs.name + "_id";
  n->inputs = {src, like};
  return nnvm::NodeEntry{n, 0, 0};
}

nnvm::NodeEntry AggregateGradient(std::vector<nnvm::NodeEntry>&& v) {
  using nnvm::Op;
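  // Gradient lists shorter than this cap are combined with a single ElementWiseSum node;
  // lists at or above the cap fall back to the chain of _grad_add nodes built below.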
  static size_t inplace_sum_cap = dmlc::GetEnv("MXNET_EXEC_INPLACE_GRAD_SUM_CAP", 8);
  static const Op* ewise_plus_op = Op::Get("_grad_add");
  static const Op* ewise_sum_op = Op::Get("ElementWiseSum");
  static const Op* identity_op = Op::Get("identity");
  static const Op* zeros_op = Op::Get("_zeros");
  static const Op* zeros_like_op = Op::Get("zeros_like");

  if (v.empty()) {
    nnvm::ObjectPtr ng = nnvm::Node::Create();
    ng->attrs.op = Op::Get("_zeros_without_dtype");
    ng->attrs.name = "zeros_without_dtype";
    ng->attrs.op->attr_parser(&(ng->attrs));
    return nnvm::NodeEntry(std::move(ng), 0, 0);
  }

  // remove zeros from the sum, but keep at least one entry
  auto begin = std::remove_if(v.begin(), v.end(), [](const nnvm::NodeEntry& nodeEntry) {
     CHECK(nodeEntry.node);
     return nodeEntry.node->op() == zeros_op || nodeEntry.node->op() == zeros_like_op;
  });
  if (begin == v.begin()) ++begin;
  v.erase(begin, v.end());
  CHECK(!v.empty());

  if (v.size() == 1) {
    return std::move(v[0]);
  } else {
    if (v.size() < inplace_sum_cap) {
      nnvm::ObjectPtr sum_node = nnvm::Node::Create();
      sum_node->attrs.op = ewise_sum_op;
      sum_node->attrs.name = "sum_grad";
      sum_node->attrs.dict["num_args"] = std::to_string(v.size());
      sum_node->attrs.op->attr_parser(&(sum_node->attrs));
      sum_node->inputs = std::move(v);
      return nnvm::NodeEntry(std::move(sum_node), 0, 0);
    } else {
      // use a chain of plus nodes instead
      nnvm::NodeEntry ret = v[0];
      for (size_t i = 1; i < v.size(); ++i) {
        // Add a control dependency to the previous node.
        // This enforces that the gradients are summed in the inverse
        // order of the forward traversal.
        // NOTE: adding control dependencies can be dangerous and cause cycles in the graph.
        // The current usage is correct because of the following invariant:
        // assert: v[i-1] does not depend on v[i]
        // In plain text: v holds the gradients in the order in which they can be
        // generated, which means that if v[i] is not yet pushed,
        // none of the previous gradients can depend on it.
        // Note: For a symbol like the following:
        // data = mx.sym.Variable('data')
        // sym = data + data + data + data + data + data + data
        // the node entries v passed in here are of the same node of
        // op _identity_with_attr_like_rhs. We should skip adding a node
        // to its own control_deps.
        if (v[i-1].node != v[i].node) {
          v[i].node->control_deps.push_back(ret.node);
        }

        std::ostringstream os;
        os << "sum_grad_" << i;
        nnvm::ObjectPtr x = nnvm::Node::Create();
        x->attrs.op = ewise_plus_op;
        x->attrs.name = os.str();
        x->inputs = {ret, v[i]};
        ret = nnvm::NodeEntry(std::move(x), 0, 0);
      }
      // the identity node avoids exposing the dummy plus node
      // when its output gets assigned to another space.
      nnvm::ObjectPtr id_node = nnvm::Node::Create();
      id_node->attrs.op = identity_op;
      id_node->attrs.name = "sum_grad_final";
      id_node->inputs = {ret};
      return nnvm::NodeEntry{id_node, 0, 0};
    }
  }
}
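// Illustrative sketch of how AggregateGradient behaves (entry names e0..e2 are
// assumptions for the example, not part of this file):
//
//   std::vector<nnvm::NodeEntry> grads{e0, e1, e2};
//   nnvm::NodeEntry summed = AggregateGradient(std::move(grads));
//
// With three entries (fewer than MXNET_EXEC_INPLACE_GRAD_SUM_CAP, default 8) this yields
// a single ElementWiseSum node; with eight or more entries it instead yields a chain of
// _grad_add nodes terminated by an identity node, as implemented above.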

template<typename ValueType>
inline ValueType get_node_attr(
    const nnvm::Node& node,
    const std::string& key, ValueType default_value) {
  auto it = node.attrs.dict.find(key);
  if (it == node.attrs.dict.end()) {
    return default_value;
  } else {
    ValueType ret;
    dmlc::parameter::FieldEntry<ValueType> e;
    e.Init(key, &ret, ret);
    e.Set(&ret, it->second);
    return ret;
  }
}
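// Example of how this helper is used below (see need_mirror in InitFullGraph): parse an
// optional per-node attribute with a typed default value.
//
//   bool force = get_node_attr(node, "__force_mirroring__", false);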

/*!
 * \brief Create the graph for backward pass.
 * This is triggered by both simple_bind and bind flows.
 */
nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
                                         const std::vector<OpReqType>& grad_req_types) {
  using nnvm::ObjectPtr;
  using nnvm::NodeEntry;
  // initial information
  num_forward_outputs_ = symbol.outputs.size();
  num_forward_inputs_ = symbol.ListInputs(nnvm::Symbol::kAll).size();

  nnvm::Graph g;
  g.outputs = symbol.outputs;
  bool do_elim_common_expr = dmlc::GetEnv("MXNET_ELIMINATE_COMMON_EXPR", true);
  if (do_elim_common_expr)
    g = exec::EliminateCommonExpr(std::move(g));
  need_grad_ = false;
  for (OpReqType req : grad_req_types) {
    if (req != kNullOp)
      need_grad_ = true;
  }
  if (!need_grad_) return g;
  for (size_t i = 0; i < g.outputs.size(); ++i) {
    NodeEntry ngrad(nnvm::Node::Create(), 0, 0);
    ngrad.node->attrs.name = "_head_grad_" + std::to_string(i);
    head_grad_entry_.emplace_back(AttrHint(ngrad, g.outputs[i]));
    head_grad_map_[ngrad.node.get()] = i;
  }
  std::vector<ObjectPtr> args = symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs);
  std::vector<NodeEntry> xs;
  for (size_t i = 0; i < grad_req_types.size(); ++i) {
    if (grad_req_types[i] != kNullOp) {
      xs.emplace_back(args[i]);
    }
  }

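  // The need_mirror predicate below is handed to the gradient pass: with
  // MXNET_BACKWARD_DO_MIRROR=1 it marks every operator for mirroring except Dropout and a
  // few heavyweight ones (Convolution, FullyConnected, Concat, SoftmaxOutput), and any
  // node can opt in explicitly via the __force_mirroring__ attribute.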
  int do_mirror = dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0);
  auto need_mirror = [do_mirror](const nnvm::Node& node) -> int {
    if (node.is_variable()) return 0;
    const std::string& type = node.attrs.op->name;
    if (type == "Dropout") return false;
    if (get_node_attr(node, "__force_mirroring__", false)) return true;
    if (do_mirror == 0) return false;
    if (type == "Convolution") return false;
    if (type == "FullyConnected") return false;
    if (type == "Concat") return false;
    if (type == "SoftmaxOutput") return false;
    return true;
  };

  std::vector<const nnvm::Op*> zero_ops;
  zero_ops.push_back(nnvm::Op::Get("zeros_like"));
  zero_ops.push_back(nnvm::Op::Get("_zeros"));

  // take gradient
  nnvm::Graph g_grad = nnvm::pass::MXGradient(
      g, symbol.outputs, xs, head_grad_entry_,
      AggregateGradient, need_mirror, nullptr,
      zero_ops, "_copy");
  CHECK_EQ(g_grad.outputs.size(), xs.size());
  for (const auto &e : g_grad.outputs) {
    g.outputs.push_back(e);
  }

  return g;
}

/*!
 * \brief GraphExecutor initializer for regular bind flow in which
 * input arguments and gradients are provided by users. This initializer
 * uses the user provided NDArrays to populate data entries of the graph.
 */
void GraphExecutor::Init(nnvm::Symbol symbol,
                         const Context& default_ctx,
                         const std::map<std::string, Context>& ctx_map,
                         const std::vector<NDArray>& in_args,
                         const std::vector<NDArray>& arg_grad_store,
                         const std::vector<OpReqType>& grad_req_types,
                         const std::vector<NDArray>& aux_states,
                         Executor* shared_exec,
                         const nnvm::NodeEntryMap<NDArray>& feed_dict) {
  // create in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes
  auto get_ctx1 = [](const NDArray& nd) { return nd.ctx(); };
  auto get_ctx2 = [default_ctx](const NDArray& nd) -> Context {
    if (nd.is_none()) return default_ctx;
    return nd.ctx();
  };
  std::vector<Context> in_arg_ctxes(in_args.size());
  std::transform(in_args.begin(), in_args.end(), in_arg_ctxes.begin(), get_ctx1);
  std::vector<Context> arg_grad_ctxes(arg_grad_store.size());
  std::transform(arg_grad_store.begin(), arg_grad_store.end(), arg_grad_ctxes.begin(), get_ctx2);
  std::vector<Context> aux_state_ctxes(aux_states.size());
  std::transform(aux_states.begin(), aux_states.end(), aux_state_ctxes.begin(), get_ctx1);

  nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes,
                            arg_grad_ctxes, aux_state_ctxes, grad_req_types);

  // create arg_shapes and arg_dtypes for shape and type inferences
  const auto& idx = g.indexed_graph();
  const auto& mutable_nodes = idx.mutable_input_nodes();
  size_t arg_top = 0, aux_top = 0;
  data_entry_.resize(idx.num_node_entries());
  mxnet::ShapeVector arg_shapes;
  nnvm::DTypeVector arg_dtypes;
  StorageTypeVector arg_stypes(idx.num_node_entries(), -1);
  for (size_t i = 0; i < num_forward_inputs_; ++i) {
    const uint32_t nid = idx.input_nodes().at(i);
    const std::string& arg_name = idx[nid].source->attrs.name;
    size_t eid = idx.entry_id(nid, 0);
    if (mutable_nodes.count(nid)) {
      CHECK_LT(aux_top, aux_states.size());
      data_entry_[eid] = aux_states[aux_top];
      arg_shapes.push_back(aux_states[aux_top].shape());
      arg_dtypes.push_back(aux_states[aux_top].dtype());
      arg_stypes[eid] = aux_states[aux_top].storage_type();
      aux_state_map_.emplace(arg_name, aux_states[aux_top]);
      ++aux_top;
    } else {
      CHECK_LT(arg_top, in_args.size());
      data_entry_[eid] = in_args[arg_top];
      arg_shapes.push_back(in_args[arg_top].shape());
      arg_dtypes.push_back(in_args[arg_top].dtype());
      arg_stypes[eid] = in_args[arg_top].storage_type();
      in_arg_map_.emplace(arg_name, in_args[arg_top]);
      if (kNullOp != grad_req_types[arg_top]) {
        auto grad_oid = grad_store_.size() + num_forward_outputs_;
        auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
        arg_stypes[grad_eid] = arg_grad_store[arg_top].storage_type();
        grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_store[arg_top]);
        arg_grad_map_.emplace(arg_name, arg_grad_store[arg_top]);
        if (log_verbose_) {
          LOG(INFO) << "\tassign data entry\t" << grad_eid << " as "
                    << common::stype_string(arg_stypes[grad_eid]) << " (grad)";
        }
      }
      ++arg_top;
    }
    if (log_verbose_) {
      LOG(INFO) << "\tassign data entry\t" << eid << " as "
                << common::stype_string(data_entry_[eid].storage_type()) << " (input)";
    }
  }

  // expand arg_shapes and arg_dtypes to contain backward inputs
  arg_shapes.resize(idx.input_nodes().size(), mxnet::TShape());
  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
    this->is_dynamic_ = true;
  }

  arg_dtypes.resize(idx.input_nodes().size(), -1);
  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
    HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
                         g.GetAttr<nnvm::DTypeVector>("dtype"));
  }

  g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(arg_stypes));
  g = InferStorageType(std::move(g), StorageTypeVector(), "");
  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
    HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
                                g.GetAttr<StorageTypeVector>("storage_type"));
  }

  // Initialize the rest attributes of the graph.
  // This function can be called by regular bind
  // operation flow as well.
  FinishInitGraph(symbol, g, shared_exec, feed_dict);
}

/*!
 * \brief Initialize in_args, arg_grads, and aux_states
 * and their data_entry_ of the executor. This function
 * is called for regular simple_bind flow, i.e. no
 * shared data arrays are provided.
 */
void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
                                  const mxnet::ShapeVector& inferred_shapes,
                                  const nnvm::DTypeVector& inferred_dtypes,
                                  const StorageTypeVector& inferred_stypes,
                                  const std::vector<Context>& in_arg_ctxes,
                                  const std::vector<Context>& arg_grad_ctxes,
                                  const std::vector<Context>& aux_state_ctxes,
                                  const std::vector<OpReqType>& grad_req_types,
                                  std::vector<NDArray>* in_arg_vec,
                                  std::vector<NDArray>* arg_grad_vec,
                                  std::vector<NDArray>* aux_state_vec) {
  // initialize in_args, arg_grads, and aux_states
  // populate grad_store_
  data_entry_.resize(idx.num_node_entries());
  size_t arg_top = 0, aux_top = 0;
  const auto& mutable_nodes = idx.mutable_input_nodes();
  for (size_t i = 0; i < num_forward_inputs_; ++i) {
    const uint32_t nid = idx.input_nodes().at(i);
    const uint32_t eid = idx.entry_id(nid, 0);
    const mxnet::TShape& inferred_shape = inferred_shapes[eid];
    const int inferred_dtype = inferred_dtypes[eid];
    const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
    const std::string& arg_name = idx[nid].source->attrs.name;
    if (mutable_nodes.count(nid)) {  // aux_states
      EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top],
                       inferred_dtype, aux_state_vec);
      data_entry_[eid] = aux_state_vec->back();
      aux_state_map_.emplace(arg_name, aux_state_vec->back());
      ++aux_top;
      if (log_verbose_) {
        LOG(INFO) << "\tassign aux entry\t" << eid << "\t as "
                  << common::stype_string(inferred_stype);
      }
    } else {  // in_args
      EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top],
                       inferred_dtype, in_arg_vec);
      data_entry_[eid] = in_arg_vec->back();
      if (log_verbose_) {
        LOG(INFO) << "\tassign data entry\t" << eid << "\tas "
                  << common::stype_string(inferred_stype);
      }
      // Get the storage type for grad
      if (kNullOp == grad_req_types[arg_top]) {
        arg_grad_vec->emplace_back();
      } else {
        // Init based on storage type
        auto grad_oid = grad_store_.size() + num_forward_outputs_;
        auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
        auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid];
        EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top],
                         inferred_dtype, arg_grad_vec);
        if (log_verbose_) {
          LOG(INFO) << "\tassign grad entry\t" << grad_eid << "\tas "
                    << common::stype_string(grad_stype);
        }
        grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
        arg_grad_map_.emplace(arg_name, arg_grad_vec->back());
      }
      in_arg_map_.emplace(arg_name, in_arg_vec->back());
      ++arg_top;
    }
  }
}

/*!
 * \brief Initialize in_args, arg_grads, and aux_states
 * and their data_entry_ of the executor using
 * shared_buffer from DataParallelExecutorGroup
 * and shared_exec if available.
 */
void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
                                  const mxnet::ShapeVector& inferred_shapes,
                                  const nnvm::DTypeVector& inferred_dtypes,
                                  const StorageTypeVector& inferred_stypes,
                                  const std::vector<Context>& in_arg_ctxes,
                                  const std::vector<Context>& arg_grad_ctxes,
                                  const std::vector<Context>& aux_state_ctxes,
                                  const std::vector<OpReqType>& grad_req_types,
                                  const std::unordered_set<std::string>& shared_arg_names,
                                  const Executor* shared_exec,
                                  std::unordered_map<std::string, NDArray>* shared_buffer,
                                  std::vector<NDArray>* in_arg_vec,
                                  std::vector<NDArray>* arg_grad_vec,
                                  std::vector<NDArray>* aux_state_vec) {
  // initialize in_args, arg_grads, and aux_states and populate grad_store_
  data_entry_.resize(idx.num_node_entries());
  size_t arg_top = 0, aux_top = 0;
  const auto& mutable_nodes = idx.mutable_input_nodes();
  for (size_t i = 0; i < num_forward_inputs_; ++i) {
    const uint32_t nid = idx.input_nodes().at(i);
    const uint32_t eid = idx.entry_id(nid, 0);
    const mxnet::TShape& inferred_shape = inferred_shapes[eid];
    const int inferred_dtype = inferred_dtypes[eid];
    const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
    const std::string& arg_name = idx[nid].source->attrs.name;
    // aux_states
    if (mutable_nodes.count(nid)) {
      if (nullptr != shared_exec) {
        const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name);
        CHECK(inferred_stype == kDefaultStorage && aux_nd.storage_type() == kDefaultStorage)
          << "Non-default storage type detected when creating auxiliary NDArray. The allocated "
          << "memory of shared_exec.aux_array cannot be reused for argument: "
          << arg_name << " for the current executor";
        CHECK_EQ(inferred_shape, aux_nd.shape())
          << "Inferred shape does not match shared_exec.aux_array's shape."
             " Therefore, the allocated memory for shared_exec.aux_array cannot"
             " be reused for creating the auxiliary NDArray of the argument: "
          << arg_name << " for the current executor";
        CHECK_EQ(inferred_dtype, aux_nd.dtype())
          << "Inferred dtype does not match shared_exec.aux_array's dtype."
             " Therefore, the allocated memory for shared_exec.aux_array cannot"
             " be reused for creating the auxiliary NDArray of the argument: "
          << arg_name << " for the current executor";
        aux_state_vec->emplace_back(aux_nd);
      } else {
        EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top],
                         inferred_dtype, aux_state_vec);
      }  // if (has_shared_exec)
      data_entry_[eid] = aux_state_vec->back();
      aux_state_map_.emplace(arg_name, aux_state_vec->back());
      ++aux_top;
    } else {  // in_args and grad for in_args
      if (shared_arg_names.count(arg_name)) {  // model parameter
        // model parameter
        if (nullptr != shared_exec) {
          const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name);
          auto arg_nd_stype = in_arg_nd.storage_type();
          // for model parameter, both default storage and row_sparse storage can be shared
          bool shareable_arg_stype = inferred_stype == kDefaultStorage ||
                                     inferred_stype == kRowSparseStorage;
          // try to reuse memory from shared_exec
          CHECK(shareable_arg_stype) << "Inferred storage type "
            << common::stype_string(inferred_stype)
            << " does not support memory sharing with shared_exec.arg_array";
          CHECK_EQ(inferred_stype, arg_nd_stype)
            << "Inferred stype does not match shared_exec.arg_array's stype."
               " Therefore, the allocated memory for shared_exec.arg_array cannot"
               " be reused for creating the NDArray of the argument "
            << arg_name << " for the current executor";
          CHECK_EQ(inferred_shape, in_arg_nd.shape())
            << "Inferred shape does not match shared_exec.arg_array's shape."
               " Therefore, the allocated memory for shared_exec.arg_array cannot"
               " be reused for creating the NDArray of the argument "
            << arg_name << " for the current executor";
          CHECK_EQ(inferred_dtype, in_arg_nd.dtype())
            << "Inferred dtype does not match shared_exec.arg_array's dtype."
               " Therefore, the allocated memory for shared_exec.arg_array cannot"
               " be reused for creating the NDArray of the argument "
            << arg_name << " for the current executor";
          in_arg_vec->emplace_back(in_arg_nd);
        } else {
          // doesn't have shared_exec, or non-default storage
          EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top],
                           inferred_dtype, in_arg_vec);
        }
        // gradient for model parameter
        if (kNullOp == grad_req_types[arg_top]) {
          arg_grad_vec->emplace_back();
        } else {
          auto grad_oid = grad_store_.size() + num_forward_outputs_;
          auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
          auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid];
          if (nullptr != shared_exec && grad_stype == kDefaultStorage &&
              shared_exec->arg_grad_map().at(arg_name).storage_type() == kDefaultStorage) {
            // try to reuse memory from shared_exec
            arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name));
          } else {
            // no need to reuse memory from shared_exec for gradient of non-default storage
            EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top],
                             inferred_dtype, arg_grad_vec);
          }
          grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
        }
      } else {  // !shared_arg_names.count(arg_name)
        // model parameter, row_sparse ndarray sharing enabled
        bool enable_row_sparse_sharing = true;
        in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype,
                                                 inferred_stype, in_arg_ctxes[arg_top],
                                                 shared_buffer, enable_row_sparse_sharing));
        // gradient for model parameter, row_sparse ndarray sharing disabled
        if (kNullOp == grad_req_types[arg_top]) {
          arg_grad_vec->emplace_back();
        } else {
          auto grad_oid = grad_store_.size() + num_forward_outputs_;
          auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
          auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid];
          bool enable_row_sparse_sharing = false;
          arg_grad_vec->emplace_back(ReshapeOrCreate("grad of " + arg_name, inferred_shape,
                                                     inferred_dtype, grad_stype,
                                                     arg_grad_ctxes[arg_top], shared_buffer,
                                                     enable_row_sparse_sharing));
          grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
        }  // if (kNullOp == grad_req_types[arg_top])
      }  // if (shared_arg_names.count(arg_name))
      in_arg_map_.emplace(arg_name, in_arg_vec->back());
      if (!arg_grad_vec->back().is_none()) {
        arg_grad_map_.emplace(arg_name, arg_grad_vec->back());
      }
      data_entry_[eid] = in_arg_vec->back();
      ++arg_top;
    }
  }
}

/*!
 * \brief Finish graph initialization after shape and dtype inferences.
 * This function is used by both simple_bind and bind flows.
 */
void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
                                    nnvm::Graph g,
                                    Executor* shared_exec,
                                    const nnvm::NodeEntryMap<NDArray>& feed_dict) {
  const auto& idx = g.indexed_graph();
  const auto& vstorage_type = g.GetAttr<StorageTypeVector>("storage_type");

  // data entries for output gradients
  for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) {
    data_entry_[idx.entry_id(idx.outputs()[j])] = grad_store_[j - num_forward_outputs_].second;
  }

  {
    // memory allocator
    nnvm::StorageVector arg_storage_id(idx.num_node_entries(), kBadStorageID);
    for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) {
      arg_storage_id[idx.entry_id(idx.outputs()[j])] = kExternalStorageID;
    }
    for (const auto& kv : feed_dict) {
      uint32_t eid = idx.entry_id(kv.first);
      data_entry_[eid] = kv.second;
      arg_storage_id[eid] = kExternalStorageID;
    }
    for (size_t i = 0; i < idx.num_node_entries(); i++) {
      if (vstorage_type[i] != kDefaultStorage) arg_storage_id[i] = kDynamicStorageID;
    }
    g.attrs["storage"] = std::make_shared<dmlc::any>(std::move(arg_storage_id));
    g = nnvm::ApplyPass(g, "MXPlanMemory");
  }
  g = DetectInplaceAddTo(g);

  // log the static memory plan of the graph
  static bool mem_log_verbose = dmlc::GetEnv("MXNET_MEM_PLAN_VERBOSE_LOGGING", false);
  if (mem_log_verbose) {
    common::LogMemoryPlan(g);
  }

  g = AttachOpExecs(g);
  AttachOpResources(g);
  graph_ = std::move(g);

  if (shared_exec != nullptr) {
    this->InitDataEntryMemory(&(dynamic_cast<GraphExecutor*>(shared_exec)->data_pool_));
  } else {
    this->InitDataEntryMemory(nullptr);
  }

  {
    // initialize output arrays
    auto& idx = graph_.indexed_graph();
    for (size_t i = 0; i < num_forward_outputs_; ++i) {
      auto& e = idx.outputs()[i];
      output_arrays_.push_back(data_entry_[idx.entry_id(e)]);
    }
    // initialize head gradient array
    head_grad_array_.resize(symbol.outputs.size());
    for (size_t i = num_forward_inputs_; i < idx.input_nodes().size(); ++i) {
      uint32_t nid = idx.input_nodes().at(i);
      uint32_t oid = head_grad_map_.at(idx[nid].source);
      head_grad_array_[oid] = data_entry_[idx.entry_id(nid, 0)];
    }
  }
  this->InitCachedOps();
  this->InitOpSegs();
}

/*!
 * \brief GraphExecutor initializer for simple bind flow in
 * which only certain input shapes and dtypes are provided by users.
 * The initializer uses these shapes and dtypes to perform
 * shape and dtype inferences, and then create NDArrays
 * to populate data entries of the graph. The created NDArrays
 * for in_args, arg_grads and aux_states are passed to the
 * front end to attach the created executor.
 * In the front end, if the simple_bind flow is triggered by
 * _bind_ith_exec, the shared data arrays of DataParallelExecutorGroup
 * and the shared executor will be taken into account in creating
 * NDArrays for in_args, arg_grads, and aux_states for reusing
 * already allocated memory.
 */
void GraphExecutor::Init(nnvm::Symbol symbol,
                         const Context& default_ctx,
                         const std::map<std::string, Context>& ctx_map,
                         const std::vector<Context>& in_arg_ctxes,
                         const std::vector<Context>& arg_grad_ctxes,
                         const std::vector<Context>& aux_state_ctxes,
                         const std::unordered_map<std::string, mxnet::TShape>& arg_shape_map,
                         const std::unordered_map<std::string, int>& arg_dtype_map,
                         const std::unordered_map<std::string, int>& arg_stype_map,
                         const std::vector<OpReqType>& grad_req_types,
                         const std::unordered_set<std::string>& shared_arg_names,
                         std::vector<NDArray>* in_arg_vec,
                         std::vector<NDArray>* arg_grad_vec,
                         std::vector<NDArray>* aux_state_vec,
                         std::unordered_map<std::string, NDArray>* shared_buffer,
                         Executor* shared_exec,
                         const nnvm::NodeEntryMap<NDArray>& feed_dict) {
  nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
                            aux_state_ctxes, grad_req_types);

  // The following code of shape and dtype inferences and argument
  // initialization is for simple_bind only. Regular bind operation
  // should do this differently.

  // Initialize arg_shapes and arg_dtypes for shape and type inferences.
  // It contains all in_args and aux_states' shapes and types in a certain order.
  const nnvm::IndexedGraph& idx = g.indexed_graph();
  mxnet::ShapeVector arg_shapes(idx.input_nodes().size(), mxnet::TShape());
  nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1);
  StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage);
  for (size_t i = 0; i < num_forward_inputs_; ++i) {
    const uint32_t nid = idx.input_nodes().at(i);
    const std::string& name = idx[nid].source->attrs.name;
    auto it1 = arg_shape_map.find(name);
    if (arg_shape_map.end() != it1) {
      arg_shapes[i] = it1->second;
    }
    auto it2 = arg_dtype_map.find(name);
    if (arg_dtype_map.end() != it2) {
      arg_dtypes[i] = it2->second;
    }
    auto it3 = arg_stype_map.find(name);
    if (arg_stype_map.end() != it3) {
      arg_stypes[i] = it3->second;
    }
  }
  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
    HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
                          g.GetAttr<mxnet::ShapeVector>("shape"));
  }

  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
    HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
                         g.GetAttr<nnvm::DTypeVector>("dtype"));
  }

  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
    HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
                                g.GetAttr<StorageTypeVector>("storage_type"));
  }

  // Create in_args, arg_grads, and aux_states using
  // the inferred shapes and dtypes.
  if (nullptr == shared_buffer) {  // regular simple bind
    InitArguments(idx, g.GetAttr<mxnet::ShapeVector>("shape"),
                  g.GetAttr<nnvm::DTypeVector>("dtype"),
                  g.GetAttr<StorageTypeVector>("storage_type"),
                  in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
                  grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec);
  } else {  // simple bind using shared data arrays and shared_exec
    InitArguments(idx, g.GetAttr<mxnet::ShapeVector>("shape"),
                  g.GetAttr<nnvm::DTypeVector>("dtype"),
                  g.GetAttr<StorageTypeVector>("storage_type"),
                  in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
                  grad_req_types, shared_arg_names, shared_exec,
                  shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec);
  }
  // The above code of shape and dtype inferences and argument
  // initialization is for simple_bind only. Regular bind operation
  // should do this differently.

  // Initialize the rest attributes of the graph.
  // This function can be called by regular bind
  // operation flow as well.
  FinishInitGraph(symbol, g, shared_exec, feed_dict);
}

/*!
 * \brief Return a new executor with the same symbol and shared memory,
 * but different input/output shapes.
 * For runtime reshaping, variable length sequences, etc.
 * The returned executor shares state with the current one,
 * and cannot be used in parallel with it.
 */
Executor* GraphExecutor::Reshape(const bool partial_shaping,
                                 const bool allow_up_sizing,
                                 const Context& default_ctx,
                                 const std::map<std::string, Context>& ctx_map,
                                 const std::unordered_map<std::string, mxnet::TShape>&
                                   provided_arg_shapes,
                                 std::vector<NDArray>* in_args,
                                 std::vector<NDArray>* arg_grads,
                                 std::vector<NDArray>* aux_states) {
  nnvm::Graph g;
  nnvm::Symbol symbol;
  symbol.outputs = symbol_.outputs;
  g.outputs = symbol_.outputs;
  const nnvm::IndexedGraph& idx = g.indexed_graph();
  mxnet::ShapeVector arg_shapes(idx.input_nodes().size(), mxnet::TShape());
  for (size_t i = 0; i < num_forward_inputs_; ++i) {
    const uint32_t nid = idx.input_nodes().at(i);
    const std::string& name = idx[nid].source->attrs.name;
    auto it = provided_arg_shapes.find(name);
    if (provided_arg_shapes.end() != it) {
      arg_shapes[i] = it->second;
    }
  }
  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
    this->is_dynamic_ = true;
  }
  const mxnet::ShapeVector& shape_vec = g.GetAttr<mxnet::ShapeVector>("shape");
  std::vector<OpReqType> grad_req_types;
  size_t grad_top = 0;
  const size_t num_args = in_arg_map_.size();
  const size_t num_aux = aux_state_map_.size();
  in_args->reserve(num_args);
  grad_req_types.reserve(num_args);
  arg_grads->reserve(num_args);
  aux_states->reserve(num_aux);
  for (uint32_t nid : idx.input_nodes()) {
    std::string name = idx[nid].source->attrs.name;
    const mxnet::TShape& new_shape = shape_vec[idx.entry_id(nid, 0)];
    if (idx.mutable_input_nodes().count(nid) == 0) {
      NDArray& arr = in_arg_map_.at(name);
      auto it = arg_grad_map_.find(name);
      if (partial_shaping || provided_arg_shapes.count(name) || new_shape == arr.shape()) {
        if (new_shape.Size() > arr.shape().Size()) {
          CHECK(allow_up_sizing) << "New shape of arg: " << name << " is larger than original. "
            << "First making a big executor and then down sizing it "
            << "is more efficient than the reverse. "
            << "If you really want to up size, set allow_up_sizing=True "
            << "to enable allocation of new arrays.";
          in_args->emplace_back(new_shape, arr.ctx(), false, arr.dtype());
          if (it != arg_grad_map_.end()) {
            NDArray& darr = it->second;
            arg_grads->emplace_back(new_shape, darr.ctx(), false, darr.dtype());
            grad_req_types.push_back(grad_store_.at(grad_top++).first);
          } else {
            arg_grads->emplace_back();
            grad_req_types.push_back(kNullOp);
          }
        } else {
          in_args->push_back(arr.Reshape(new_shape));
          if (it != arg_grad_map_.end()) {
            NDArray& darr = it->second;
            arg_grads->push_back(darr.Reshape(new_shape));
            grad_req_types.push_back(grad_store_.at(grad_top++).first);
          } else {
            arg_grads->emplace_back();
            grad_req_types.push_back(kNullOp);
          }
        }
      } else {
        LOG(FATAL) << "Shape of unspecified arg: " << name << " changed. "
          << "This can cause the new executor to not share parameters "
          << "with the old one. Please check for errors in the network. "
          << "If this is intended, set partial_shaping=True to suppress this warning.";
      }
    } else {
      NDArray& arr = aux_state_map_.at(name);
      if (partial_shaping || new_shape == arr.shape()) {
        if (new_shape.Size() > arr.shape().Size()) {
          CHECK(allow_up_sizing) << "New shape of arg: " << name << " is larger than original. "
            << "First making a big executor and then down sizing it "
            << "is more efficient than the reverse. "
            << "If you really want to up size, set allow_up_sizing=True "
            << "to enable allocation of new arrays.";
          aux_states->emplace_back(new_shape, arr.ctx(), false, arr.dtype());
        } else {
          aux_states->push_back(arr.Reshape(new_shape));
        }
      } else {
        LOG(FATAL) << "Shape of unspecified arg: " << name << " changed. "
          << "This can cause the new executor to not share parameters "
          << "with the old one. Please check for errors in the network. "
          << "If this is intended, set partial_shaping=True to suppress this warning.";
      }
    }
  }
  auto exec = new GraphExecutor(symbol);
  exec->Init(symbol.Copy(), default_ctx, ctx_map,
             *in_args, *arg_grads, grad_req_types, *aux_states,
             this);
  return exec;
}
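// Illustrative sketch (names, shapes, and the CPU context are assumptions for the
// example): reshaping an already bound GraphExecutor `exec` for a new batch size.
//
//   std::vector<NDArray> in_args, arg_grads, aux_states;
//   std::unordered_map<std::string, mxnet::TShape> new_shapes =
//       {{"data", mxnet::TShape({64, 3, 224, 224})}};
//   Executor* reshaped = exec->Reshape(/*partial_shaping=*/false, /*allow_up_sizing=*/true,
//                                      Context::CPU(), {}, new_shapes,
//                                      &in_args, &arg_grads, &aux_states);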

/*!
 * \brief This function is triggered by both simple_bind
 * and bind flows.
 * Setup backward graph, create device and context
 * attributes in the graph, and calculate the number
 * of forward nodes.
 */
Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
                               const Context& default_ctx,
                               const std::map<std::string, Context>& ctx_map,
                               const std::vector<Context>& in_arg_ctxes,
                               const std::vector<Context>& arg_grad_ctxes,
                               const std::vector<Context>& aux_state_ctxes,
                               const std::vector<OpReqType>& grad_req_types) {
  // setup gradient
  nnvm::Graph g = InitFullGraph(symbol, grad_req_types);

#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32)
  if (default_ctx.dev_mask() == Context::kGPU && dmlc::GetEnv("MXNET_USE_FUSION", true)) {
    nnvm::Graph unoptimized_graph;
    common::CopyGraph(&unoptimized_graph, g, false);

    if (common::CheckForInputNameDuplicates(unoptimized_graph.indexed_graph())) {
      g = exec::FusePointwise(std::move(g), num_forward_outputs_);
      // Check the topological order of inputs
      const auto &original_inputs = unoptimized_graph.indexed_graph().input_nodes();
      const auto &new_inputs = g.indexed_graph().input_nodes();
      if (original_inputs.size() != new_inputs.size()) {
        LOG(WARNING)
          << "Number of inputs after fusion does not match original number of inputs. "
          << "This is most probably a bug. Disabling fusion for this run.";
        g = unoptimized_graph;
      } else {
        for (size_t i = 0; i < new_inputs.size(); ++i) {
          if (unoptimized_graph.indexed_graph()[original_inputs[i]].source->attrs.name !=
              g.indexed_graph()[new_inputs[i]].source->attrs.name) {
            LOG(WARNING) << "Disabling fusion due to altered topological order of inputs.";
            g = unoptimized_graph;
            break;
          }
        }
      }
    } else {
      LOG(WARNING)
        << "Graph contains duplicate names for some of its inputs - fusion is NOT enabled!";
    }
  }
#else
  // Only warn user if MXNET_USE_FUSION env var is explicitly set
  if (default_ctx.dev_mask() == Context::kGPU && dmlc::GetEnv("MXNET_USE_FUSION", false)) {
    WarnFusionNotSupported();
  }
#endif  // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32)

  // create "device" and "context" attrs for the graph
  g = AssignContext(g, default_ctx, ctx_map,
                    in_arg_ctxes,
                    arg_grad_ctxes,
                    aux_state_ctxes,
                    grad_req_types,
                    num_forward_inputs_,
                    num_forward_outputs_);

  const auto& idx = g.indexed_graph();
  // get number of nodes used in forward pass
  num_forward_nodes_ = 0;
  for (size_t i = 0; i < num_forward_outputs_; ++i) {
    num_forward_nodes_ = std::max(
        num_forward_nodes_, static_cast<size_t>(idx.outputs()[i].node_id + 1));
  }
  return g;
}
1060 
1061 // initialize the memory of each entries
InitDataEntryMemory(std::vector<NDArray> * shared_pool)1062 void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
1063   using nnvm::DTypeVector;
1064   using mxnet::ShapeVector;
1065   using nnvm::StorageVector;
1066   // get the graph
1067   const auto& idx = graph_.indexed_graph();
1068   // get the storage
1069   const auto& vdtype = graph_.GetAttr<DTypeVector>("dtype");
1070   const auto& vshape = graph_.GetAttr<mxnet::ShapeVector>("shape");
1071   const auto& vstorage = graph_.GetAttr<StorageVector>("storage_id");
1072   const auto& vstorage_type = graph_.GetAttr<StorageTypeVector>("storage_type");
1073   const auto& vctx = graph_.GetAttr<ContextVector>("context");
1074   CHECK_EQ(idx.num_node_entries(), vshape.size());
1075   CHECK_EQ(idx.num_node_entries(), vdtype.size());
1076   CHECK_EQ(idx.num_node_entries(), vstorage.size());
1077   CHECK_EQ(data_entry_.size(), vshape.size());
1078   std::vector<Context> data_context(idx.num_node_entries());
1079   std::vector<NDArrayStorageType> data_storage_type(idx.num_node_entries(), kUndefinedStorage);
1080   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
1081     for (uint32_t i = 0; i < idx[nid].source->num_outputs(); ++i) {
1082       auto eid = idx.entry_id(nid, i);
1083       data_context[eid] = vctx[nid];
1084       CHECK_NE(vstorage_type[eid], kUndefinedStorage);
1085       data_storage_type[eid] = (NDArrayStorageType) vstorage_type[eid];
1086     }
1087   }
1088 
1089   // information about the pool
1090   struct PoolEntry {
1091     Context ctx;
1092     size_t bytes;
1093     NDArrayStorageType stype;
1094   };
1095   std::vector<PoolEntry> pool_info;
1096 
1097   // assign array to head gradient
1098   for (size_t i = num_forward_inputs_; i < idx.input_nodes().size(); ++i) {
1099     uint32_t nid = idx.input_nodes().at(i);
1100     uint32_t oid = head_grad_map_.at(idx[nid].source);
1101     uint32_t eid = idx.entry_id(idx.outputs()[oid]);
1102     NDArrayStorageType stype = (NDArrayStorageType) vstorage_type[eid];
1103     bool unknown_shape = !shape_is_known(vshape[eid]);
1104     CHECK_NE(vdtype[eid], -1);
1105     auto data_eid = idx.entry_id(nid, 0);
1106     // initialize based on storage_type
1107     if (stype != kDefaultStorage) {
1108       data_entry_[data_eid] = NDArray(stype, vshape[eid], data_context[eid], true, vdtype[eid]);
1109     } else if (!unknown_shape) {
1110       data_entry_[data_eid] = NDArray(vshape[eid], data_context[eid], false, vdtype[eid]);
1111     } else {
1112       data_entry_[data_eid] = NDArray(data_context[eid], vdtype[eid]);
1113     }
1114     if (log_verbose_) {
1115       LOG(INFO) << "\tinit head_grad entry\t" << data_eid << "\tas "
1116                 << common::stype_string(stype);
1117     }
1118   }
1119   // get maximum bytes in each pool
1120   for (size_t i = 0; i < vshape.size(); ++i) {
1121     if (!data_entry_[i].is_none()) continue;
1122     size_t shape_size = 0;
1123     if (shape_is_known(vshape[i])) {
1124       shape_size = vshape[i].Size();
1125     }
1126     size_t bytes = shape_size * mshadow::mshadow_sizeof(vdtype[i]);
1127     int storage_id = vstorage[i];
1128     // skip pool allocation for kBadStorageID, kExternalStorageID and kDynamicStorageID
1129     if (storage_id < 0) continue;
1130     size_t sid = static_cast<size_t>(storage_id);
1131     if (sid >= pool_info.size()) {
1132       pool_info.resize(sid + 1, PoolEntry{Context::CPU(), size_t(0), kUndefinedStorage});
1133     }
1134     PoolEntry& info = pool_info[sid];
1135     if (info.bytes == 0) {
1136       info = PoolEntry{data_context[i], bytes, data_storage_type[i]};
1137     } else {
1138       info.bytes = std::max(info.bytes, bytes);
1139     }
1140   }
1141   // construct the re-use pool, if needed
1142   std::multimap<size_t, NDArray> free_pool;
1143   if (shared_pool != nullptr) {
1144     for (const NDArray& nd : *shared_pool) {
1145       size_t bytes = 0;
1146       if (shape_is_known(nd.shape())) {
1147         bytes = nd.shape().Size() * mshadow::mshadow_sizeof(nd.dtype());
1148       }
1149       free_pool.insert(std::make_pair(bytes, nd));
1150     }
1151   }
1152   // remake the data pool
1153   data_pool_.clear();
1154   data_pool_.resize(pool_info.size());
1155 
1156   // sort the pool info in descending order of size before allocating memory
1157   std::vector<size_t> sorted_pool_index;
1158   for (size_t i = 0; i < pool_info.size(); i++) {
1159     sorted_pool_index.push_back(i);
1160   }
1161   auto pool_comparator = [&pool_info](size_t lhs, size_t rhs){
1162     return pool_info[lhs].bytes > pool_info[rhs].bytes;
1163   };
1164   std::sort(sorted_pool_index.begin(), sorted_pool_index.end(), pool_comparator);
1165 
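  // Allocate pool slots largest-first: reuse the smallest free buffer on the same
  // context that is big enough, otherwise allocate a fresh buffer and (if a shared
  // pool was given) make it available to other executors as well.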
1166   for (size_t i : sorted_pool_index) {
1167     const Context& ctx = pool_info[i].ctx;
1168     size_t bytes = pool_info[i].bytes;
1169     bool allocated = false;
1170     for (auto it = free_pool.lower_bound(bytes); it != free_pool.end(); ++it) {
1171       if (it->second.ctx() == ctx && it->first >= bytes) {
1172         data_pool_[i] = it->second;
1173         free_pool.erase(it);
1174         allocated = true;
1175         break;
1176       }
1177     }
1178     if (!allocated) {
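      // No reusable buffer found: round the byte count up to 32-bit words so the
      // slot can be allocated as a 1-D float buffer and reinterpreted per entry
      // later via AsArray.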
1179       size_t nword = (bytes + 3) / 4;
1180       CHECK_LE(nword, std::numeric_limits<nnvm::dim_t>::max());
1181       // allocate float arrays
1182       mxnet::TShape shape{static_cast<nnvm::dim_t>(nword)};
1183       // TODO(junwu): adding delay_alloc=true to create nd
1184       // is a temporary solution.
1185       NDArray nd(shape, ctx, true);
1186       data_pool_[i] = nd;
1187       // put the newly allocated array into the shared pool
1188       if (shared_pool != nullptr)  {
1189         shared_pool->push_back(nd);
1190       }
1191     }
1192   }
1193   CHECK_EQ(data_pool_.size(), pool_info.size());
1194   // assign the data entries
1195   for (size_t i = 0; i < data_entry_.size(); ++i) {
1196     // skip entries that already have an array assigned
1197     if (!data_entry_[i].is_none()) continue;
1198     // assign allocated array by storage id
1199     int storage_id = vstorage[i];
1200     auto storage_type = (NDArrayStorageType) vstorage_type[i];
1201     if (storage_type == kDefaultStorage) {
1202       if (!shape_is_known(vshape[i])) {
1203         data_entry_[i] = NDArray(data_context[i], vdtype[i]);
1204       } else {
1205         CHECK_GE(storage_id, 0) << "Runtime-shape ops are not supported yet";
1206         const NDArray& src = data_pool_.at(storage_id);
1207         data_entry_[i] = src.AsArray(vshape[i], vdtype[i]);
1208       }
1209     } else {
1210       data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i],
1211                                true, vdtype[i]);
1212     }
1213     if (log_verbose_) {
1214       LOG(INFO) << "\tinit data entry\t" << i << "\tas " << common::stype_string(storage_type);
1215     }
1216   }
1217 }
1218 
1219 
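// Build one cached engine operator per executable node: collect its input/output
// NDArrays, derive each output's write request (write, in-place, add-to, or null)
// from the memory plan, and register the node's read/write variable dependencies
// with the engine.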
1220 void GraphExecutor::InitCachedOps() {
1221   // get the graph
1222   const auto& idx = graph_.indexed_graph();
1223   const auto& vstorage_inplace =
1224       graph_.GetAttr<std::vector<int> >("storage_inplace_index");
1225   const auto& op_execs =
1226       graph_.GetAttr<OpExecVector>("op_execs");
1227   const auto& vctx = graph_.GetAttr<ContextVector>("context");
1228   const auto& addto_entry = graph_.GetAttr<std::vector<int> >("addto_entry");
1229   const auto& skip_plus_node = graph_.GetAttr<std::vector<int> >("skip_plus_node");
1230 
1231   op_nodes_.resize(idx.num_nodes());
1232   // setup the array and requirements.
1233   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
1234     const auto& inode = idx[nid];
1235     if (inode.source->is_variable()) continue;
1236     op_nodes_[nid].opr_name = inode.source->op()->name.c_str();
1237     if (skip_plus_node.at(nid)) {
1238       op_nodes_[nid].skip_exec_node = true; continue;
1239     }
1240 
1241     op_nodes_[nid].exec = op_execs[nid];
1242     op_nodes_[nid].ctx = vctx[nid];
1243     auto& exec = op_nodes_[nid].exec;
1244     CHECK_EQ(exec->in_array.size(), 0U);
1245     CHECK_EQ(exec->out_array.size(), 0U);
1246     for (const auto& e : inode.inputs) {
1247       exec->in_array.push_back(data_entry_[idx.entry_id(e)]);
1248     }
1249     // detect inplace requirement
1250     for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
1251       uint32_t eid = idx.entry_id(nid, index);
1252       exec->out_array.push_back(data_entry_[eid]);
1253       if (addto_entry.at(eid) != 0) {
1254         exec->req.push_back(kAddTo);
1255       } else if (vstorage_inplace[eid] >= 0) {
1256         exec->req.push_back(kWriteInplace);
1257       } else if (vstorage_inplace[eid] == -2) {
1258         // -2 indicates that the entry is never referenced.
1259         exec->req.push_back(kNullOp);
1260       } else {
1261         exec->req.push_back(kWriteTo);
1262       }
1263     }
1264   }
1265   // Note that this modifies the requirement of kWriteInplace
1266   for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) {
1267     auto& e = idx.outputs()[j];
1268     op_nodes_[e.node_id].exec->req[e.index] =
1269         grad_store_[j - num_forward_outputs_].first;
1270   }
1271   for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
1272     const auto& inode = idx[nid];
1273     if (inode.source->is_variable()) continue;
1274     if (op_nodes_[nid].skip_exec_node) continue;
1275     auto& exec = op_nodes_[nid].exec;
1276     bool is_async = op_nodes_[nid].exec->exec_type() == ExecType::kAsync;
1277     bool is_gpu = op_nodes_[nid].ctx.dev_mask() == gpu::kDevMask;
1278 
1279     // the variables
1280     std::vector<Engine::VarHandle> use_vars, mutate_vars;
1281     for (const auto& nd : exec->in_array) {
1282       use_vars.push_back(nd.var());
1283     }
1284     for (const auto& r : exec->op_ctx.requested) {
1285       mutate_vars.push_back(r.var);
1286     }
1287     for (const auto& nd : exec->out_array) {
1288       mutate_vars.push_back(nd.var());
1289     }
1290     if (exec->var() != nullptr) {
1291       mutate_vars.push_back(exec->var());
1292     }
1293     // dedup vars
1294     Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars);
1295     // all vars include both mutate vars and use vars
1296     std::vector<Engine::VarHandle> all_vars(use_vars);
1297     std::copy(mutate_vars.begin(), mutate_vars.end(),
1298               std::inserter(all_vars, all_vars.end()));
1299     // setup exec vars
1300     Engine::Get()->PushAsync(
1301       [exec](RunContext rctx, Engine::CallbackOnComplete on_complete) {
1302         exec->Setup();
1303         on_complete();
1304       }, Context::CPU(), {}, all_vars, FnProperty::kNormal, 0,
1305       "SetupExec");
1306     auto exec_fun = [exec, is_async, is_gpu] (
1307         RunContext ctx, Engine::CallbackOnComplete on_complete) {
1308       if (is_async) {
1309         exec->op_ctx.async_on_complete = on_complete;
1310       }
1311       exec->Run(ctx, is_gpu);
1312       // call on complete only if it is async op
1313       if (!is_async) {
1314         if (is_gpu) {
1315 #if MXNET_USE_CUDA
1316           // Wait for the GPU kernel to finish.
1317           ctx.get_stream<gpu>()->Wait();
1318 #else
1319           LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
1320 #endif
1321         }
1322         on_complete();
1323       }
1324     };
1325     // setup the vars
1326     op_nodes_[nid].cached_opr = Engine::Get()->NewOperator(
1327         exec_fun, use_vars, mutate_vars, FnProperty::kNormal,
1328         op_nodes_[nid].opr_name);
1329     op_nodes_[nid].mutate_vars = mutate_vars;
1330     op_nodes_[nid].use_vars = use_vars;
1331   }
1332 }
1333 
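// Split the graph into bulkable segments so that runs of synchronous ops can be
// pushed to the engine as a single cached operator. Bulking is skipped entirely when
// a monitor callback is installed or the graph has dynamic shapes; training-time
// bulking is additionally disabled when profiler aggregation is enabled.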
1334 void GraphExecutor::InitOpSegs() {
1335   size_t total_num_nodes = graph_.indexed_graph().num_nodes();
1336   cached_seg_opr_.clear();
1337   CachedSegOpr p;
1338   cached_seg_opr_.resize(total_num_nodes, p);
1339   if (monitor_callback_) return;
1340 
1341   // Symbolic bulking is set by the same environment variables as Imperative bulking.
1342   // Generate segments based on the graph structure
1343   bool prefer_bulk_exec_inference = Imperative::PreferBulkExecInference();
1344   // Whether to perform bulk exec for training
1345   const profiler::Profiler *prof = profiler::Profiler::Get();
1346   bool prefer_bulk_exec_train = Imperative::PreferBulkExecTrain()
1347                                 && (!prof || !prof->AggregateEnabled());
1348   if (this->is_dynamic_) {
1349     prefer_bulk_exec_inference = false;
1350     prefer_bulk_exec_train = false;
1351   }
1352   bool is_training = num_forward_nodes_ != total_num_nodes;
1353 
1354   if (prefer_bulk_exec_train && is_training) {
1355     // Bulk the forward portion of the graph per the bulk segment max size for forward training
1356     this->BulkOpSegs(0, num_forward_nodes_, Imperative::BulkExecMaxNodeTrainFwd());
1357     // Bulk the backward portion of the graph per the bulk segment max size for backward training
1358     this->BulkOpSegs(num_forward_nodes_, total_num_nodes, Imperative::BulkExecMaxNodeTrainBwd());
1359   }
1360 
1361   if (prefer_bulk_exec_inference && !is_training) {
1362     // Bulk the entire graph as one bulk segment if possible
1363     this->BulkOpSegs(0, total_num_nodes, total_num_nodes);
1364   }
1365 }
1366 
1367 
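// Scan nodes [from_node, up_to_node) and close the current segment whenever a node
// cannot be bulked (non-sync exec type), the per-segment node limit is reached, or
// the end of the range is hit.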
1368 void GraphExecutor::BulkOpSegs(size_t from_node, size_t up_to_node, size_t segment_num_nodes_max) {
1369   size_t topo_start = from_node;
1370   size_t segment_node_count = 0;
1371   for (size_t nid = from_node; nid < up_to_node; nid++) {
1372     auto &node = graph_.indexed_graph()[nid].source;
1373     auto &op_node = op_nodes_[nid];
1374     // Variables, such as learned weights, are ignored in the segment_node_count
1375     bool ignore_node = node->is_variable() || op_node.skip_exec_node || op_node.exec == nullptr;
1376     if (!ignore_node)
1377       segment_node_count++;
1378     bool can_bulk = ignore_node || op_node.exec->exec_type() == ExecType::kSync;
1379     // check if we need to create the segment based on properties of this node
1380     if (!can_bulk || nid == up_to_node - 1 || segment_node_count >= segment_num_nodes_max) {
1381       // Create a new segment for the previous nodes - also include this node if it's bulkable
1382       cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, can_bulk ? nid + 1 : nid);
1383       topo_start = nid + 1;
1384       segment_node_count = 0;
1385     }
1386   }
1387 }
1388 
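// Invoke the monitor callback on each input of node `nid`: variable inputs are
// reported under their own name, and every input is also reported as
// "<op_name>_<input_name>".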
1389 void GraphExecutor::ExecuteMonInputCallback(size_t nid) {
1390   static const auto& flist_inputs =
1391       nnvm::Op::GetAttr<nnvm::FListInputNames>("FListInputNames");
1392   const auto& idx = graph_.indexed_graph();
1393   std::vector<std::string> input_names;
1394   OpNode& opnode = op_nodes_[nid];
1395   const auto& inode = idx[nid];
1396   const auto& node = idx[nid].source;
1397   if (flist_inputs.count(node->op())) {
1398     input_names = flist_inputs[node->op()](node->attrs);
1399   } else {
1400     for (size_t i = 0; i < node->num_inputs(); ++i) {
1401       input_names.emplace_back("input" + std::to_string(i));
1402     }
1403   }
1404   CHECK_EQ(opnode.exec->in_array.size(), input_names.size());
1405   for (size_t i = 0; i < opnode.exec->in_array.size(); ++i) {
1406     if (node->inputs[i].node->is_variable()) {
1407       // Monitor variable
1408       NDArray *cpy = new NDArray(opnode.exec->in_array[i]);
1409       std::string name = node->inputs[i].node->attrs.name;
1410       this->monitor_callback_(name.c_str(), reinterpret_cast<void*>(cpy));
1411     }
1412     NDArray *cpy = new NDArray(opnode.exec->in_array[i]);
1413     std::string name = inode.source->attrs.name + "_" + input_names[i];
1414     this->monitor_callback_(name.c_str(), reinterpret_cast<void*>(cpy));
1415   }
1416 }
1417 
1418 void GraphExecutor::ExecuteMonOutputCallback(size_t nid) {
1419   const auto& idx = graph_.indexed_graph();
1420   OpNode& opnode = op_nodes_[nid];
1421   const auto& node = idx[nid].source;
1422   for (size_t i = 0; i < opnode.exec->out_array.size(); ++i) {
1423     NDArray *cpy = new NDArray(opnode.exec->out_array[i]);
1424     nnvm::ObjectPtr node_ptr = std::make_shared<nnvm::Node>(*node);
1425     std::string name = GetOutputName({node_ptr, static_cast<uint32_t>(i), 0});
1426     this->monitor_callback_(name.c_str(), reinterpret_cast<void*>(cpy));
1427   }
1428 }
1429 
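// Execute nodes in [topo_start, topo_end): push whole cached bulk segments when
// possible, otherwise push each node's cached operator. kCrossDeviceCopy and
// kSubgraphExec nodes are handled specially, and for dynamic-shape graphs output
// shapes are re-inferred per node and arrays allocated on the fly.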
1430 void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
1431   static auto& finfer_shape = nnvm::Op::GetAttr<mxnet::FInferShape>("FInferShape");
1432   static auto& is_backward = Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
1433   // Update context
1434   const auto& idx = graph_.indexed_graph();
1435   for (size_t nid = topo_start; nid < topo_end; ++nid) {
1436     OpNode& opnode = op_nodes_[nid];
1437     if (opnode.skip_exec_node) continue;
1438     const auto& inode = idx[nid];
1439     if (inode.source->is_variable()) continue;
1440     opnode.exec->op_ctx.is_train = is_train;
1441     opnode.exec->op_ctx.need_grad = need_grad_;
1442   }
1443 
1444   mxnet::ShapeVector rshape = graph_.MoveCopyAttr<mxnet::ShapeVector>("shape");
1445   // Push Ops
1446   for (size_t nid = topo_start; nid < topo_end; ++nid) {
1447     auto seg_op = cached_seg_opr_[nid];
1448     // Check segments first
1449     if (monitor_callback_ == nullptr && seg_op.opr != nullptr && seg_op.topo_end <= topo_end) {
1450       bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
1451       Engine::Get()->Push(seg_op.opr, seg_op.ctx, 0, profiling);
1452       nid = seg_op.topo_end - 1;
1453       continue;
1454     }
1455     // Normal mode
1456     const auto& inode = idx[nid];
1457     const uint32_t num_inputs = inode.inputs.size();
1458     const uint32_t num_outputs = inode.source->num_outputs();
1459     if (inode.source->is_variable()) continue;
1460     OpNode& opnode = op_nodes_[nid];
1461     if (op_nodes_[nid].skip_exec_node) continue;
1462     // Monitor callbacks
1463     if (monitor_callback_ && monitor_all_) {
1464       ExecuteMonInputCallback(nid);
1465     }
1466     if (this->is_dynamic_) {
1467       const auto &op = inode.source->op();
1468       {
1469         for (NDArray &array : opnode.exec->in_array) {
1470           array.WaitToRead();
1471           if (!shape_is_known(array.shape())) {
1472             array.SetShapeFromChunk();
1473           }
1474         }
1475         int i = 0;
1476         for (NDArray &array : opnode.exec->out_array) {
1477           array.WaitToRead();
1478           if (!shape_is_known(array.shape())) {
1479             array.SetShapeFromChunk();
1480           }
1481           if (!shape_is_known(array.shape())) {
1482             mxnet::TShape shape = rshape[idx.entry_id(nid, i)];
1483             if (shape_is_known(shape)) {
1484               array.ReshapeAndAlloc(shape);
1485             }
1486           }
1487           ++i;
1488         }
1489       }
1490       if (finfer_shape.count(op)) {
1491         mxnet::ShapeVector in_shapes;
1492         mxnet::ShapeVector out_shapes;
1493         for (NDArray &array : opnode.exec->in_array) {
1494           in_shapes.push_back(array.shape());
1495         }
1496         for (NDArray &array : opnode.exec->out_array) {
1497           out_shapes.push_back(array.shape());
1498         }
1499         auto finfer = finfer_shape[op];
1500         try {
1501           bool success = finfer(inode.source->attrs, &in_shapes, &out_shapes);
1502           CHECK(success) << "InferShape failed in operator " << inode.source->attrs.name;
1503         } catch (const std::exception& e) {
1504           throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
1505         }
1506         int n_out = out_shapes.size();
1507         for (int i = 0; i < n_out; ++i) {
1508           NDArray &array = opnode.exec->out_array[i];
1509           if (!shape_is_known(array.shape())) {
1510             array.Init(out_shapes[i]);
1511           }
1512         }
1513       } else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) {
1514         CHECK_GE(inode.control_deps.size(), 1U) <<
1515           "BackwardOp need to have control_deps to its forward op";
1516         uint32_t fid = inode.control_deps[0];
1517         const OpNode& fopnode = op_nodes_[fid];
1518         CHECK_EQ(fopnode.exec->in_array.size(), opnode.exec->out_array.size());
1519         int nelem = fopnode.exec->in_array.size();
1520         std::vector<NDArray> &from = fopnode.exec->in_array;
1521         std::vector<NDArray> &to = opnode.exec->out_array;
1522         for (int i = 0; i < nelem; ++i) {
1523           if (!shape_is_known(to[i].shape())) {
1524             to[i].Init(from[i].shape());
1525           }
1526         }
1527       }
1528     }
1529     opnode.exec->op_ctx.is_train = is_train;
1530     opnode.exec->op_ctx.need_grad = need_grad_;
1531     if (opnode.exec->exec_type() == ExecType::kCrossDeviceCopy) {
1532       CHECK_EQ(inode.inputs.size(), 1U);
1533       CHECK_EQ(opnode.exec->in_array.size(), 1U);
1534       CHECK_EQ(opnode.exec->out_array.size(), 1U);
1535       CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0]));
1536     } else if (opnode.exec->exec_type() == ExecType::kSubgraphExec) {
1537       // If the node contains a subgraph, we can't execute it in the engine.
1538       opnode.exec->Run(opnode.exec->op_ctx.run_ctx, false);
1539     } else if (opnode.cached_opr != nullptr) {
1540       bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
1541       Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling);
1542       if (this->is_dynamic_) {
1543         for (NDArray &array : opnode.exec->out_array) {
1544           array.WaitToRead();
1545           if (!shape_is_known(array.shape())) {
1546             array.SetShapeFromChunk();
1547           }
1548         }
1549       }
1550     } else {
1551       LOG(FATAL) << "Not accessed";
1552     }
1553     for (uint32_t i = 0; i < num_inputs; ++i) {
1554       int eid = idx.entry_id(inode.inputs[i]);
1555       if (!shape_is_known(rshape[eid])) {
1556         rshape[eid] = opnode.exec->in_array[i].shape();
1557       }
1558     }
1559     for (uint32_t i = 0; i < num_outputs; ++i) {
1560       int eid = idx.entry_id(nid, i);
1561       if (!shape_is_known(rshape[eid])) {
1562         rshape[eid] = opnode.exec->out_array[i].shape();
1563       }
1564     }
1565     // Monitor callbacks
1566     if (monitor_callback_) {
1567       ExecuteMonOutputCallback(nid);
1568     }
1569   }
1570   graph_.attrs["shape"] = std::make_shared<dmlc::any>(rshape);
1571 }
1572 
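// Try to fuse the nodes in [topo_start, topo_end) into a single cached engine
// operator. The segment is abandoned (the returned segment has no engine operator)
// if any node is asynchronous or if the nodes span more than one context.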
1573 GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, size_t topo_end) {
1574   std::vector<Engine::VarHandle> use_vars;
1575   std::vector<Engine::VarHandle> mutate_vars;
1576   Context *pctx = nullptr;
1577   GraphExecutor::CachedSegOpr ret;
1578   ret.topo_start = topo_start;
1579   ret.topo_end = topo_end;
1580   auto& exec_list = ret.exec_list;
1581   // invalid segment
1582   if (topo_end <= topo_start) {
1583     return ret;
1584   }
1585   std::string opr_names = "[";
1586 
1587   const auto& idx = graph_.indexed_graph();
1588   for (size_t nid = topo_start; nid < topo_end; ++nid) {
1589     std::vector<Engine::VarHandle> all_vars;
1590     const auto& inode = idx[nid];
1591     OpNode& op_node = op_nodes_[nid];
1592     if (op_node.skip_exec_node) continue;
1593     if (inode.source->is_variable()) continue;
1594     if (op_node.exec->exec_type() != ExecType::kSync) {
1595       return ret;
1596     }
1597     if (pctx == nullptr) pctx = &(op_node.ctx);
1598     if (*pctx != op_node.ctx) {
1599       return ret;
1600     }
1601     auto& exec = op_nodes_[nid].exec;
1602     std::copy(op_node.mutate_vars.begin(), op_node.mutate_vars.end(),
1603               std::inserter(mutate_vars, mutate_vars.end()));
1604     std::copy(op_node.use_vars.begin(), op_node.use_vars.end(),
1605               std::inserter(use_vars, use_vars.end()));
1606     ret.exec_list.push_back(exec);
1607     opr_names += inode.source->op()->name + ",";
1608   }
1609 
1610   if (pctx == nullptr)
1611     return ret;
1612   ret.ctx = *pctx;
1613   Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars);
1614 
1615   bool is_gpu = pctx->dev_mask() == gpu::kDevMask;
1616 
1617 #if CUDA_GRAPHS_AVAILABLE
1618   // Provide initialized `cuda_graphs_exec`, which when captured
1619   // by exec_fun, acts like a static variable inside the mutable closure.
1620   cuda_graphs::CudaGraphsExec cuda_graphs_exec(exec_list, is_gpu, opr_names.c_str());
1621   auto exec_fun = [cuda_graphs_exec, exec_list, is_gpu] (
1622       RunContext rctx, Engine::CallbackOnComplete on_complete) mutable {
1623     // Run all opr in the sub-graph with CUDA graphs executor if possible
1624     cuda_graphs_exec.RunAll(exec_list, rctx, is_gpu);
1625 #else
1626   auto exec_fun = [exec_list, is_gpu] (
1627       RunContext rctx, Engine::CallbackOnComplete on_complete) {
1628     // Run all opr in the sub-graph
1629     OpExecutor::RunAll(exec_list, rctx, is_gpu);
1630 #endif
1631     if (is_gpu) {
1632 #if MXNET_USE_CUDA
1633       // Wait for the GPU kernel to finish.
1634       rctx.get_stream<gpu>()->Wait();
1635 #else
1636       LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
1637 #endif
1638     }
1639     on_complete();
1640   };
1641   opr_names.pop_back();
1642   opr_names += "]";
1643   ret.opr = Engine::Get()->NewOperator(
1644     exec_fun, use_vars, mutate_vars, FnProperty::kNormal,
1645     opr_names.c_str());
1646   return ret;
1647 }
1648 
1649 // Infer shapes, dtypes, stypes, contexts for the forward graph
1650 static nnvm::Graph InferForwardAttrs(nnvm::Graph g,
1651                                      mxnet::ShapeVector arg_shapes,
1652                                      nnvm::DTypeVector arg_dtypes,
1653                                      StorageTypeVector arg_stypes,
1654                                      const Context& default_ctx,
1655                                      const std::map<std::string, Context>& ctx_map,
1656                                      const std::vector<Context>& in_arg_ctxes,
1657                                      const std::vector<Context>& aux_state_ctxes,
1658                                      bool partial_shape = false) {
1659   const auto& indexed_graph = g.indexed_graph();
1660   const auto num_forward_inputs = indexed_graph.input_nodes().size();
1661   g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {},
1662                    aux_state_ctxes, {}, num_forward_inputs, g.outputs.size());
1663   g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
1664   if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
1665     if (!partial_shape) {
1666       HandleInferShapeError(num_forward_inputs, indexed_graph,
1667                             g.GetAttr<mxnet::ShapeVector>("shape"));
1668     }
1669   }
1670   g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
1671   if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
1672     HandleInferTypeError(num_forward_inputs, indexed_graph,
1673                          g.GetAttr<nnvm::DTypeVector>("dtype"));
1674   }
1675   g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
1676   if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
1677     HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
1678                                 g.GetAttr<StorageTypeVector>("storage_type"));
1679   }
1680   return g;
1681 }
1682 
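// A subgraph backend is used only when it is not explicitly disabled via its
// "enable" attribute and, if it declares a "context" attribute, that context
// matches the executor's default context.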
1683 static bool SubgraphBackendCheck(const op::SubgraphBackendPtr& backend,
1684                                  const Context& default_ctx,
1685                                  int verbose = 1) {
1686   if (backend->HasAttr("enable") && (backend->GetAttr<bool>("enable") != true)) {
1687     if (verbose > 1) {
1688       LOG(INFO) << "Subgraph backend " << backend->GetName()
1689                 << " isn't activated.";
1690     }
1691     return false;
1692   }
1693   if (backend->HasAttr("context") && backend->GetAttr<Context>("context") != default_ctx) {
1694     if (verbose > 1) {
1695       LOG(INFO) << "Subgraph backend " << backend->GetName()
1696                 << " isn't activated as context mismatch.";
1697     }
1698     return false;
1699   }
1700   return true;
1701 }
1702 
1703 static bool SubgraphPropertyCheck(const std::string& backend_name,
1704                                   const op::SubgraphPropertyPtr& prop, bool need_grad,
1705                                   int verbose = 1) {
1706   auto full_name =
1707       prop->HasAttr("property_name") ? prop->GetAttr<std::string>("property_name") : std::string();
1708   if (prop->HasAttr("disable") && prop->GetAttr<bool>("disable") == true) {
1709     LOG(INFO) << "subgraph property " << full_name << " from backend " << backend_name
1710               << " is disabled.";
1711     return false;
1712   }
1713   if (prop->HasAttr("inference_only") && prop->GetAttr<bool>("inference_only") == true) {
1714     if (need_grad) {
1715       if (verbose > 1) {
1716         LOG(INFO) << "skip partitioning graph with subgraph property " << full_name
1717                   << " from backend " << backend_name << " as it requires `grad_req=null`.";
1718       }
1719       return false;
1720     }
1721   }
1722   return true;
1723 }
1724 
1725 // Given input attr arrays, partition the graph using the given subgraph property.
1726 // This is a common function for bind and simple_bind flows.
1727 static nnvm::Symbol BuildSubgraph(const nnvm::Symbol& src, op::SubgraphPropertyPtr subgraph_prop,
1728                                   const mxnet::ShapeVector& arg_shapes,
1729                                   const nnvm::DTypeVector& arg_dtypes,
1730                                   const StorageTypeVector& arg_stypes, const Context& default_ctx,
1731                                   const std::map<std::string, Context>& ctx_map,
1732                                   const std::vector<Context>& in_arg_ctxes,
1733                                   const std::vector<Context>& aux_state_ctxes) {
1734   nnvm::Symbol ret = src.Copy();
1735   nnvm::Graph g;
1736   g.outputs = ret.outputs;
1737   g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx, ctx_map, in_arg_ctxes,
1738                         aux_state_ctxes, true);
1739   subgraph_prop->SetAttr("graph", g);
1740   g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(subgraph_prop);
1741   g = ApplyPass(std::move(g), "BuildSubgraph");
1742   subgraph_prop->RemoveAttr("graph");
1743   g.attrs.erase("subgraph_property");
1744   ret.outputs = g.outputs;
1745   return ret;
1746 }
1747 
1748 // Given input attr dicts, partition the graph using the backend.
1749 // This is for simple_bind flow.
1750 static nnvm::Symbol BuildSubgraph(
1751     const nnvm::Symbol& src, const op::SubgraphBackendPtr backend,
1752     const std::unordered_map<std::string, mxnet::TShape>& arg_shape_map,
1753     const std::unordered_map<std::string, int>& arg_dtype_map,
1754     const std::unordered_map<std::string, int>& arg_stype_map, const Context& default_ctx,
1755     const std::map<std::string, Context>& ctx_map, std::vector<Context>* in_arg_ctxes,
1756     std::vector<Context>* arg_grad_ctxes, std::vector<OpReqType>* grad_req_types,
1757     std::vector<Context>* aux_state_ctxes, int verbose = 1) {
1758   // setup map for in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes and grad_req_types
1759   std::unordered_map<std::string, Context> in_arg_ctx_map;
1760   std::unordered_map<std::string, Context> arg_grad_ctx_map;
1761   std::unordered_map<std::string, Context> aux_state_ctx_map;
1762   std::unordered_map<std::string, OpReqType> grad_req_type_map;
1763 
1764   auto arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
1765   auto aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
1766   for (size_t i = 0; i < arg_names.size(); ++i) {
1767     const auto& name = arg_names[i];
1768     in_arg_ctx_map[name] = in_arg_ctxes->at(i);
1769     arg_grad_ctx_map[name] = arg_grad_ctxes->at(i);
1770     grad_req_type_map[name] = grad_req_types->at(i);
1771   }
1772 
1773   for (size_t i = 0; i < aux_names.size(); ++i) {
1774     aux_state_ctx_map[aux_names[i]] = aux_state_ctxes->at(i);
1775   }
1776 
1777   bool need_grad = false;
1778   for (OpReqType req : *grad_req_types) {
1779     if (req != kNullOp) {
1780       need_grad = true;
1781       break;
1782     }
1783   }
1784   nnvm::Symbol ret = src.Copy();
1785   std::unordered_set<std::string> op_names_set;
1786   const auto& backend_name = backend->GetName();
1787   const auto it = op::SubgraphPropertyOpNameSet::Get()->find(backend_name);
1788   // assign an op name set to the subgraph property if one has been provided by users
1789   if (it != op::SubgraphPropertyOpNameSet::Get()->end()) {
1790     LOG(INFO) << "SubgraphPropertyOpNameSet for subgraph property " << backend_name
1791               << " has been assigned a value. Please make sure it is initialized"
1792                  " only for the testing purpose.";
1793     op_names_set = it->second;
1794   }
1795 
1796   const auto& subgraph_prop_list = backend->GetSubgraphProperties();
1797   for (auto& subgraph_prop : subgraph_prop_list) {
1798     if (SubgraphPropertyCheck(backend_name, subgraph_prop, need_grad, verbose)) {
1799       subgraph_prop->SetAttr("op_names", op_names_set);
1800       const std::vector<std::string> input_names = ret.ListInputNames(Symbol::kAll);
1801       mxnet::ShapeVector arg_shapes(input_names.size(), mxnet::TShape());
1802       nnvm::DTypeVector arg_dtypes(input_names.size(), -1);
1803       StorageTypeVector arg_stypes(input_names.size(), kUndefinedStorage);
1804       for (size_t i = 0; i < input_names.size(); ++i) {
1805         const auto& input_name = input_names[i];
1806         const auto it1 = arg_shape_map.find(input_name);
1807         if (arg_shape_map.end() != it1) {
1808           arg_shapes[i] = it1->second;
1809         }
1810         const auto it2 = arg_dtype_map.find(input_name);
1811         if (arg_dtype_map.end() != it2) {
1812           arg_dtypes[i] = it2->second;
1813         }
1814         const auto it3 = arg_stype_map.find(input_name);
1815         if (arg_stype_map.end() != it3) {
1816           arg_stypes[i] = it3->second;
1817         }
1818       }
1819       ret = BuildSubgraph(ret, subgraph_prop, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
1820                           ctx_map, *in_arg_ctxes, *aux_state_ctxes);
1821       // Reorder in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes and grad_req_types according to
1822       // partitioned symbol input sequence
1823       in_arg_ctxes->clear();
1824       arg_grad_ctxes->clear();
1825       aux_state_ctxes->clear();
1826       grad_req_types->clear();
1827       auto new_arg_names = ret.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
1828       auto new_aux_names = ret.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
1829       for (const auto& arg_name : new_arg_names) {
1830         CHECK(in_arg_ctx_map.count(arg_name));
1831         in_arg_ctxes->push_back(in_arg_ctx_map[arg_name]);
1832         arg_grad_ctxes->push_back(arg_grad_ctx_map[arg_name]);
1833         grad_req_types->push_back(grad_req_type_map[arg_name]);
1834       }
1835       for (const auto& arg_name : new_aux_names) {
1836         CHECK(aux_state_ctx_map.count(arg_name));
1837         aux_state_ctxes->push_back(aux_state_ctx_map[arg_name]);
1838       }
1839     }
1840   }
1841   return ret;
1842 }
1843 
1844 // Given input ndarrays, partition the graph using backend.
1845 // This is for bind flow.
1846 static nnvm::Symbol BuildSubgraph(const nnvm::Symbol& src, const op::SubgraphBackendPtr backend,
1847                                   const Context& default_ctx,
1848                                   const std::map<std::string, Context>& ctx_map,
1849                                   std::vector<NDArray>* in_args,
1850                                   std::vector<NDArray>* arg_grad_store,
1851                                   std::vector<OpReqType>* grad_req_type,
1852                                   std::vector<NDArray>* aux_states, int verbose = 1) {
1853   // setup map for in_args, arg_grad_store, grad_req_type and aux_states
1854   std::unordered_map<std::string, NDArray> in_args_map;
1855   std::unordered_map<std::string, NDArray> arg_grad_store_map;
1856   std::unordered_map<std::string, OpReqType> grad_req_type_map;
1857   std::unordered_map<std::string, NDArray> aux_states_map;
1858   const std::vector<std::string> arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
1859   const std::vector<std::string> aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
1860   for (size_t i = 0; i < arg_names.size(); ++i) {
1861     in_args_map[arg_names[i]] = in_args->at(i);
1862   }
1863 
1864   for (size_t i = 0; i < aux_names.size(); ++i) {
1865     aux_states_map[aux_names[i]] = aux_states->at(i);
1866   }
1867 
1868   if (arg_grad_store->size()) {
1869     for (size_t i = 0; i < arg_names.size(); ++i) {
1870       const auto& name = arg_names[i];
1871       arg_grad_store_map[name] = arg_grad_store->at(i);
1872       grad_req_type_map[name] = grad_req_type->at(i);
1873     }
1874   }
1875 
1876   bool need_grad = false;
1877   for (OpReqType req : *grad_req_type) {
1878     if (req != kNullOp) {
1879       need_grad = true;
1880       break;
1881     }
1882   }
1883   nnvm::Symbol ret = src.Copy();
1884   std::unordered_set<std::string> op_names_set;
1885   const auto& backend_name = backend->GetName();
1886   auto it = op::SubgraphPropertyOpNameSet::Get()->find(backend_name);
1887   // assign an op name set to the subgraph property if one has been provided by users
1888   if (it != op::SubgraphPropertyOpNameSet::Get()->end()) {
1889     LOG(INFO) << "SubgraphPropertyOpNameSet for subgraph property " << backend_name
1890               << " has been assigned a value. Please make sure it is initialized"
1891                  " only for the testing purpose.";
1892     op_names_set = it->second;
1893   }
1894   const auto& subgraph_prop_list = backend->GetSubgraphProperties();
1895 
1896   for (auto subgraph_prop : subgraph_prop_list) {
1897     if (SubgraphPropertyCheck(backend_name, subgraph_prop, need_grad, verbose)) {
1898       subgraph_prop->SetAttr("op_names", op_names_set);
1899       const std::vector<std::string> input_names = ret.ListInputNames(Symbol::kAll);
1900       const std::vector<std::string> arg_names = ret.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
1901       const std::vector<std::string> aux_names = ret.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
1902       CHECK_EQ(arg_names.size(), in_args_map.size());
1903       CHECK_EQ(aux_names.size(), aux_states_map.size());
1904       mxnet::ShapeVector arg_shapes;  // all input shapes
1905       arg_shapes.reserve(input_names.size());
1906       nnvm::DTypeVector arg_dtypes;  // all input dtypes
1907       arg_dtypes.reserve(input_names.size());
1908       StorageTypeVector arg_stypes;  // all input stypes
1909       arg_stypes.reserve(input_names.size());
1910       std::vector<Context> in_arg_ctxes(in_args_map.size());
1911       std::vector<Context> aux_state_ctxes(aux_states_map.size());
1912 
1913       size_t i1 = 0, i2 = 0;
1914       for (const auto& input_name : input_names) {
1915         if (i2 < aux_names.size() && aux_names[i2] == input_name) {
1916           const auto &aux_st = aux_states_map[input_name];
1917           arg_shapes.push_back(aux_st.shape());
1918           arg_dtypes.push_back(aux_st.dtype());
1919           arg_stypes.push_back(aux_st.storage_type());
1920           aux_state_ctxes[i2] = aux_st.ctx();
1921           ++i2;
1922         } else {
1923           CHECK(i1 < arg_names.size());
1924           CHECK_EQ(arg_names[i1], input_name);
1925           const auto &in_arg = in_args_map[input_name];
1926           arg_shapes.push_back(in_arg.shape());
1927           arg_dtypes.push_back(in_arg.dtype());
1928           arg_stypes.push_back(in_arg.storage_type());
1929           in_arg_ctxes[i1] = in_arg.ctx();
1930           ++i1;
1931         }
1932       }
1933 
1934       ret = BuildSubgraph(ret, subgraph_prop, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
1935                           ctx_map, in_arg_ctxes, aux_state_ctxes);
1936     }
1937   }
1938   // Reorder in_args, arg_grad_store, grad_req_type and aux_states according to partitioned symbol
1939   // input sequence
1940   const auto new_arg_names = ret.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
1941   const auto new_aux_names = ret.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
1942   CHECK_EQ(arg_names.size(), new_arg_names.size());
1943   CHECK_EQ(aux_names.size(), new_aux_names.size());
1944   in_args->clear();
1945   aux_states->clear();
1946   for (const auto& arg_name : new_arg_names) {
1947     CHECK(in_args_map.count(arg_name));
1948     in_args->push_back(in_args_map[arg_name]);
1949   }
1950 
1951   for (const auto& arg_name : new_aux_names) {
1952     CHECK(aux_states_map.count(arg_name));
1953     aux_states->push_back(aux_states_map[arg_name]);
1954   }
1955 
1956   if (arg_grad_store->size()) {
1957     arg_grad_store->clear();
1958     grad_req_type->clear();
1959     for (const auto& arg_name : new_arg_names) {
1960       arg_grad_store->push_back(arg_grad_store_map[arg_name]);
1961       grad_req_type->push_back(grad_req_type_map[arg_name]);
1962     }
1963   }
1964   return ret;
1965 }
1966 }  // namespace exec
1967 
1968 Executor *Executor::SimpleBind(nnvm::Symbol symbol,
1969                                const Context& default_ctx,
1970                                const std::map<std::string, Context>& group2ctx,
1971                                const std::vector<Context>& in_arg_ctxes,
1972                                const std::vector<Context>& arg_grad_ctxes,
1973                                const std::vector<Context>& aux_state_ctxes,
1974                                const std::unordered_map<std::string, mxnet::TShape>& arg_shape_map,
1975                                const std::unordered_map<std::string, int>& arg_dtype_map,
1976                                const std::unordered_map<std::string, int>& arg_stype_map,
1977                                const std::vector<OpReqType>& grad_req_types,
1978                                const std::unordered_set<std::string>& shared_arg_names,
1979                                std::vector<NDArray>* in_args,
1980                                std::vector<NDArray>* arg_grads,
1981                                std::vector<NDArray>* aux_states,
1982                                std::unordered_map<std::string, NDArray>* shared_buffer,
1983                                Executor* shared_exec) {
1984   auto exec = new exec::GraphExecutor(symbol);
1985   bool init = false;
1986   if (!exec->subgraph_property().empty()) {
1987     static int verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", 1);
1988     const auto& backend_name = exec->subgraph_property();
1989     const auto& backend = op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name);
1990     if (exec::SubgraphBackendCheck(backend, default_ctx, verbose)) {
1991       if (verbose) LOG(INFO) << "Subgraph backend " << backend_name << " is activated.";
1992       std::vector<Context> tmp_in_arg_ctxes = in_arg_ctxes;
1993       std::vector<Context> tmp_arg_grad_ctxes = arg_grad_ctxes;
1994       std::vector<Context> tmp_aux_state_ctxes = aux_state_ctxes;
1995       std::vector<OpReqType> tmp_grad_req_types = grad_req_types;
1996       std::vector<NDArray> tmp_in_args;
1997       std::vector<NDArray> tmp_arg_grads;
1998       std::vector<NDArray> tmp_aux_states;
1999       const auto arg_names = symbol.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
2000       const auto aux_names = symbol.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
2001       symbol = exec::BuildSubgraph(symbol, backend, arg_shape_map, arg_dtype_map, arg_stype_map,
2002                                    default_ctx, group2ctx, &tmp_in_arg_ctxes, &tmp_arg_grad_ctxes,
2003                                    &tmp_grad_req_types, &tmp_aux_state_ctxes, verbose);
2004       // Subgraph cannot be recreated from unoptimized symbol
2005       delete exec;
2006       exec = new exec::GraphExecutor(symbol);
2007       exec->Init(symbol.Copy(), default_ctx, group2ctx, tmp_in_arg_ctxes, tmp_arg_grad_ctxes,
2008                  tmp_aux_state_ctxes, arg_shape_map, arg_dtype_map, arg_stype_map,
2009                  tmp_grad_req_types, shared_arg_names, &tmp_in_args, &tmp_arg_grads,
2010                  &tmp_aux_states, shared_buffer, shared_exec);
2011       init = true;
2012       const auto new_arg_names = symbol.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
2013       const auto new_aux_names = symbol.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
2014       std::unordered_map<std::string, size_t> new_arg_names_idx_map;
2015       std::unordered_map<std::string, size_t> new_aux_names_idx_map;
2016       for (size_t i = 0; i != new_arg_names.size(); ++i) {
2017         new_arg_names_idx_map[new_arg_names[i]] = i;
2018       }
2019       for (size_t i = 0; i != new_aux_names.size(); ++i) {
2020         new_aux_names_idx_map[new_aux_names[i]] = i;
2021       }
2022 
2023       in_args->reserve(arg_names.size());
2024       arg_grads->reserve(arg_names.size());
2025       for (size_t i = 0; i != arg_names.size(); ++i) {
2026         const auto& arg_name = arg_names[i];
2027         const auto& it = new_arg_names_idx_map.find(arg_name);
2028         CHECK(it != new_arg_names_idx_map.end())
2029             << "Subgraph doesn't support remove any input node for now.";
2030         in_args->emplace_back(std::move(tmp_in_args[it->second]));
2031         arg_grads->emplace_back(std::move(tmp_arg_grads[it->second]));
2032       }
2033 
2034       aux_states->reserve(aux_names.size());
2035       for (size_t i = 0; i != aux_names.size(); ++i) {
2036         const auto& aux_name = aux_names[i];
2037         const auto& it = new_aux_names_idx_map.find(aux_name);
2038         CHECK(it != new_aux_names_idx_map.end())
2039             << "Subgraph doesn't support remove any input node for now.";
2040         aux_states->emplace_back(std::move(tmp_aux_states[it->second]));
2041       }
2042     }
2043   }
2044   if (!init) {
2045     // init without subgraph
2046     exec->Init(symbol.Copy(), default_ctx, group2ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
2047                arg_shape_map, arg_dtype_map, arg_stype_map, grad_req_types, shared_arg_names,
2048                in_args, arg_grads, aux_states, shared_buffer, shared_exec);
2049   }
2050   return exec;
2051 }
2052 
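// bind entry point. If a subgraph backend is active and applicable, the symbol is
// partitioned (reordering the provided arrays to match the partitioned inputs) and
// the executor is rebuilt from the partitioned symbol before Init.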
2053 Executor *Executor::Bind(nnvm::Symbol symbol,
2054                          const Context& default_ctx,
2055                          const std::map<std::string, Context>& group2ctx,
2056                          const std::vector<NDArray> &in_args,
2057                          const std::vector<NDArray> &arg_grad_store,
2058                          const std::vector<OpReqType> &grad_req_type,
2059                          const std::vector<NDArray> &aux_states,
2060                          Executor* shared_exec) {
2061   auto exec = new exec::GraphExecutor(symbol);
2062   static int verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", 1);
2063   std::vector<NDArray> tmp_in_args = in_args;
2064   std::vector<NDArray> tmp_arg_grad_store = arg_grad_store;
2065   std::vector<OpReqType> tmp_grad_req_type = grad_req_type;
2066   std::vector<NDArray> tmp_aux_states = aux_states;
2067 
2068   if (!exec->subgraph_property().empty()) {
2069     const auto& backend_name = exec->subgraph_property();
2070     const auto& backend = op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name);
2071     if (exec::SubgraphBackendCheck(backend, default_ctx, verbose)) {
2072       if (verbose) LOG(INFO) << "Subgraph backend " << backend_name << " is activated.";
2073       symbol = exec::BuildSubgraph(symbol, backend, default_ctx, group2ctx, &tmp_in_args,
2074                                    &tmp_arg_grad_store, &tmp_grad_req_type, &tmp_aux_states,
2075                                    verbose);
2076       // Subgraph cannot be recreated from unoptimized symbol
2077       delete exec;
2078       exec = new exec::GraphExecutor(symbol);
2079     }
2080   }
2081   exec->Init(symbol.Copy(), default_ctx, group2ctx, tmp_in_args, tmp_arg_grad_store,
2082              tmp_grad_req_type, tmp_aux_states, reinterpret_cast<Executor*>(shared_exec));
2083   return exec;
2084 }
2085 }  // namespace mxnet
2086