/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file graph_executor.cc
 * \brief graph executor
 */
#include <mxnet/base.h>
#include <nnvm/graph.h>
#include <nnvm/pass_functions.h>
#include <vector>
#include <set>
#include <algorithm>

#include "./exec_pass.h"
#include "./graph_executor.h"
#include "./cuda_graphs.h"
#include "../profiler/profiler.h"
#include "../common/utils.h"
#include "../common/exec_utils.h"
#include "../operator/subgraph/subgraph_property.h"
#include "../operator/operator_common.h"

namespace mxnet {
namespace exec {

using namespace mxnet::common;

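// Pick the default subgraph backend: MKLDNN when built with MKL-DNN support,
// otherwise none (an empty string).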
static const std::string GetDefaultSubgraphBackend() {
#if MXNET_USE_MKLDNN == 1
  return std::string("MKLDNN");
#else
  return std::string();
#endif
}

GraphExecutor::GraphExecutor(const nnvm::Symbol& symbol) {
  log_verbose_ = dmlc::GetEnv("MXNET_EXEC_VERBOSE_LOGGING", false);
  need_grad_ = false;
  is_dynamic_ = false;
  subgraph_property_ = dmlc::GetEnv("MXNET_SUBGRAPH_BACKEND", GetDefaultSubgraphBackend());
  if (subgraph_property_ == "NONE") {
    subgraph_property_ = std::string();
    LOG(INFO) << "MXNET_SUBGRAPH_BACKEND=NONE is detected, subgraph backend is not in use";
  }
  engine_ref_ = Engine::_GetSharedRef();
  symbol_ = symbol.Copy();
}

GraphExecutor::~GraphExecutor() {
  for (auto& n : op_nodes_) {
    if (n.cached_opr != nullptr) {
      Engine::Get()->DeleteOperator(n.cached_opr);
    }
  }
  // clean up seg ops
  for (auto& seg : cached_seg_opr_) {
    if (seg.opr != nullptr) {
      Engine::Get()->DeleteOperator(seg.opr);
    }
  }
}

void GraphExecutor::Forward(bool is_train) {
  RunOps(is_train, 0, num_forward_nodes_);
}

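/*!
 * \brief Run the forward pass one node at a time: execute the node at
 * position `step` and report in *step_left how many forward nodes remain.
 */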
void GraphExecutor::PartialForward(bool is_train, int step, int *step_left) {
  size_t sstep = static_cast<size_t>(step);
  if (sstep >= num_forward_nodes_) {
    *step_left = 0;
    return;
  }
  RunOps(is_train, sstep, sstep + 1);
  *step_left = static_cast<int>(num_forward_nodes_ - sstep - 1);
}

void GraphExecutor::Backward(const std::vector<NDArray>& head_grads, bool is_train) {
  {
    const auto& idx = graph_.indexed_graph();
    if (num_forward_inputs_ != idx.input_nodes().size()) {
      for (size_t i = 0; i < head_grad_array_.size(); ++i) {
        if (!head_grad_array_[i].is_none()) {
          CHECK(i < head_grads.size() && !head_grads[i].is_none())
              << "Because the last operator is not a loss function, "
              << "head_gradient is required when calling backward. "
              << "If you are attempting to minimize the output as "
              << "an objective, please modify your network and "
              << "pass it through the make_loss symbol.";
          const NDArray &from = head_grads[i];
          NDArray &to = head_grad_array_[i];
          if (this->is_dynamic_) {
            to.WaitToRead();
            if (!shape_is_known(to.shape())) {
              to.Init(from.shape());
            }
          }
          CopyFromTo(from, &to);
        }
      }
    }
  }
  if (this->is_dynamic_) {
    graph_ = InferShape(std::move(graph_), {}, "");
    mxnet::ShapeVector rshape = graph_.MoveCopyAttr<mxnet::ShapeVector>("shape");
    const auto& idx = graph_.indexed_graph();
    for (size_t nid = 0; nid < idx.num_nodes(); ++nid) {
      const auto& inode = idx[nid];
      if (inode.source->is_variable()) continue;
      OpNode& opnode = op_nodes_[nid];
      if (opnode.skip_exec_node) continue;
      for (NDArray &array : opnode.exec->in_array) {
        array.WaitToRead();
        if (!shape_is_known(array.shape())) {
          array.SetShapeFromChunk();
        }
      }
      int i = 0;
      for (NDArray &array : opnode.exec->in_array) {
        array.WaitToRead();
        if (!shape_is_known(array.shape())) {
          array.SetShapeFromChunk();
        }
        if (!shape_is_known(array.shape())) {
          mxnet::TShape shape = rshape[idx.entry_id(inode.inputs[i])];
          if (shape_is_known(shape)) {
            array.ReshapeAndAlloc(shape);
          }
        }
        ++i;
      }
      i = 0;
      for (NDArray &array : opnode.exec->out_array) {
        array.WaitToRead();
        if (!shape_is_known(array.shape())) {
          array.SetShapeFromChunk();
        }
        if (!shape_is_known(array.shape())) {
          mxnet::TShape shape = rshape[idx.entry_id(nid, i)];
          if (shape_is_known(shape)) {
            array.ReshapeAndAlloc(shape);
          }
        }
        ++i;
      }
    }
    graph_.attrs["shape"] = std::make_shared<dmlc::any>(rshape);
  }
  const auto& idx = graph_.indexed_graph();
  RunOps(is_train, num_forward_nodes_, idx.num_nodes());
}

void GraphExecutor::Print(std::ostream &os) const {  // NOLINT(*)
  nnvm::Symbol s;
  s.outputs = graph_.outputs;
  s.Print(os);
  // message to be backward compatible with the memonger
  size_t total_bytes = graph_.GetAttr<size_t>("storage_allocated_bytes");
  os << "Total " << (total_bytes >> 20UL) << " MB allocated\n";
  os << "Total " << 11 << " TempSpace resource requested\n";
}

/*!
 * \brief Return the "optimized" symbol contained in the executor graph.
 */
nnvm::Symbol GraphExecutor::GetOptimizedSymbol() {
  Symbol ret;
  ret.outputs = std::vector<nnvm::NodeEntry>(graph_.outputs.begin(),
                                             graph_.outputs.begin() + num_forward_outputs_);
  return ret.Copy();
}

void GraphExecutor::SetMonitorCallback(const MonitorCallback& callback, bool monitor_all) {
  CHECK(callback) << "invalid callback";
  monitor_callback_ = callback;
  monitor_all_ = monitor_all;
}

const std::vector<NDArray>& GraphExecutor::outputs() const {
  if (this->is_dynamic_) {
    for (const NDArray &array : output_arrays_) {
      array.WaitToRead();
      if (!shape_is_known(array.shape())) {
        const_cast<NDArray &>(array).SetShapeFromChunk();
      }
    }
  }
  return output_arrays_;
}

const std::unordered_map<std::string, NDArray>& GraphExecutor::in_arg_map() const {
  return in_arg_map_;
}

const std::unordered_map<std::string, NDArray>& GraphExecutor::arg_grad_map() const {
  return arg_grad_map_;
}

const std::unordered_map<std::string, NDArray>& GraphExecutor::aux_state_map() const {
  return aux_state_map_;
}

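// Wrap src in an _identity_with_attr_like_rhs node so that attributes such as
// shape and dtype can be inferred for it from the `like` entry.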
static nnvm::NodeEntry AttrHint(nnvm::NodeEntry src, nnvm::NodeEntry like) {
  static const Op* id_like = Op::Get("_identity_with_attr_like_rhs");
  nnvm::ObjectPtr n = nnvm::Node::Create();
  n->attrs.op = id_like;
  n->attrs.name = src.node->attrs.name + "_id";
  n->inputs = {src, like};
  return nnvm::NodeEntry{n, 0, 0};
}

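/*!
 * \brief Aggregate the gradient entries of one variable into a single node
 * entry. Zero gradients are pruned first; lists below
 * MXNET_EXEC_INPLACE_GRAD_SUM_CAP are summed with a single ElementWiseSum
 * node, while longer lists are accumulated through a chain of _grad_add
 * nodes so the partial sums can be performed in place.
 */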
nnvm::NodeEntry AggregateGradient(std::vector<nnvm::NodeEntry>&& v) {
  using nnvm::Op;
  static size_t inplace_sum_cap = dmlc::GetEnv("MXNET_EXEC_INPLACE_GRAD_SUM_CAP", 8);
  static const Op* ewise_plus_op = Op::Get("_grad_add");
  static const Op* ewise_sum_op = Op::Get("ElementWiseSum");
  static const Op* identity_op = Op::Get("identity");
  static const Op* zeros_op = Op::Get("_zeros");
  static const Op* zeros_like_op = Op::Get("zeros_like");

  if (v.empty()) {
    nnvm::ObjectPtr ng = nnvm::Node::Create();
    ng->attrs.op = Op::Get("_zeros_without_dtype");
    ng->attrs.name = "zeros_without_dtype";
    ng->attrs.op->attr_parser(&(ng->attrs));
    return nnvm::NodeEntry(std::move(ng), 0, 0);
  }

  // remove zeros from the sum, but keep at least one entry
  auto begin = std::remove_if(v.begin(), v.end(), [](const nnvm::NodeEntry& nodeEntry) {
    CHECK(nodeEntry.node);
    return nodeEntry.node->op() == zeros_op || nodeEntry.node->op() == zeros_like_op;
  });
  if (begin == v.begin()) ++begin;
  v.erase(begin, v.end());
  CHECK(!v.empty());

  if (v.size() == 1) {
    return std::move(v[0]);
  } else {
    if (v.size() < inplace_sum_cap) {
      nnvm::ObjectPtr sum_node = nnvm::Node::Create();
      sum_node->attrs.op = ewise_sum_op;
      sum_node->attrs.name = "sum_grad";
      sum_node->attrs.dict["num_args"] = std::to_string(v.size());
      sum_node->attrs.op->attr_parser(&(sum_node->attrs));
      sum_node->inputs = std::move(v);
      return nnvm::NodeEntry(std::move(sum_node), 0, 0);
    } else {
      // use a chain of plus operations instead
      nnvm::NodeEntry ret = v[0];
      for (size_t i = 1; i < v.size(); ++i) {
        // Add a control-flow dependency on the previous node.
        // This enforces that the gradient sum order is the inverse
        // of the forward traversal order.
        // NOTE: adding control dependencies can be dangerous and create
        // cycles in the graph. The current usage is correct because of
        // the following invariant:
        //   assert: v[i-1] does not depend on v[i]
        // In plain terms: v is the gradient vector, pushed in the order in
        // which the gradients can be generated, so if v[i] has not yet been
        // pushed, no earlier gradient can depend on it.
        // Note: For a symbol like the following:
        //   data = mx.sym.Variable('data')
        //   sym = data + data + data + data + data + data + data
        // the node entries v passed in here all refer to the same node of
        // op _identity_with_attr_like_rhs. We should skip adding a node
        // to its own control_deps.
        if (v[i-1].node != v[i].node) {
          v[i].node->control_deps.push_back(ret.node);
        }

        std::ostringstream os;
        os << "sum_grad_" << i;
        nnvm::ObjectPtr x = nnvm::Node::Create();
        x->attrs.op = ewise_plus_op;
        x->attrs.name = os.str();
        x->inputs = {ret, v[i]};
        ret = nnvm::NodeEntry(std::move(x), 0, 0);
      }
      // an identity node is used to avoid exposing the dummy plus node
      // when its output gets assigned to another space.
      nnvm::ObjectPtr id_node = nnvm::Node::Create();
      id_node->attrs.op = identity_op;
      id_node->attrs.name = "sum_grad_final";
      id_node->inputs = {ret};
      return nnvm::NodeEntry{id_node, 0, 0};
    }
  }
}

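/*!
 * \brief Look up a node attribute by key and parse it into ValueType,
 * returning default_value when the key is absent.
 */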
template<typename ValueType>
inline ValueType get_node_attr(
    const nnvm::Node& node,
    const std::string& key, ValueType default_value) {
  auto it = node.attrs.dict.find(key);
  if (it == node.attrs.dict.end()) {
    return default_value;
  } else {
    ValueType ret;
    dmlc::parameter::FieldEntry<ValueType> e;
    e.Init(key, &ret, ret);
    e.Set(&ret, it->second);
    return ret;
  }
}

/*!
 * \brief Create the graph for the backward pass.
 * This is triggered by both the simple_bind and bind flows.
 */
nnvm::Graph GraphExecutor::InitFullGraph(nnvm::Symbol symbol,
                                         const std::vector<OpReqType>& grad_req_types) {
  using nnvm::ObjectPtr;
  using nnvm::NodeEntry;
  // initial information
  num_forward_outputs_ = symbol.outputs.size();
  num_forward_inputs_ = symbol.ListInputs(nnvm::Symbol::kAll).size();

  nnvm::Graph g;
  g.outputs = symbol.outputs;
  bool do_elim_common_expr = dmlc::GetEnv("MXNET_ELIMINATE_COMMON_EXPR", true);
  if (do_elim_common_expr)
    g = exec::EliminateCommonExpr(std::move(g));
  need_grad_ = false;
  for (OpReqType req : grad_req_types) {
    if (req != kNullOp)
      need_grad_ = true;
  }
  if (!need_grad_) return g;
  for (size_t i = 0; i < g.outputs.size(); ++i) {
    NodeEntry ngrad(nnvm::Node::Create(), 0, 0);
    ngrad.node->attrs.name = "_head_grad_" + std::to_string(i);
    head_grad_entry_.emplace_back(AttrHint(ngrad, g.outputs[i]));
    head_grad_map_[ngrad.node.get()] = i;
  }
  std::vector<ObjectPtr> args = symbol.ListInputs(nnvm::Symbol::kReadOnlyArgs);
  std::vector<NodeEntry> xs;
  for (size_t i = 0; i < grad_req_types.size(); ++i) {
    if (grad_req_types[i] != kNullOp) {
      xs.emplace_back(args[i]);
    }
  }

  int do_mirror = dmlc::GetEnv("MXNET_BACKWARD_DO_MIRROR", 0);
  auto need_mirror = [do_mirror](const nnvm::Node& node) -> int {
    if (node.is_variable()) return 0;
    const std::string& type = node.attrs.op->name;
    if (type == "Dropout") return false;
    if (get_node_attr(node, "__force_mirroring__", false)) return true;
    if (do_mirror == 0) return false;
    if (type == "Convolution") return false;
    if (type == "FullyConnected") return false;
    if (type == "Concat") return false;
    if (type == "SoftmaxOutput") return false;
    return true;
  };

  std::vector<const nnvm::Op*> zero_ops;
  zero_ops.push_back(nnvm::Op::Get("zeros_like"));
  zero_ops.push_back(nnvm::Op::Get("_zeros"));

  // take gradient
  nnvm::Graph g_grad = nnvm::pass::MXGradient(
      g, symbol.outputs, xs, head_grad_entry_,
      AggregateGradient, need_mirror, nullptr,
      zero_ops, "_copy");
  CHECK_EQ(g_grad.outputs.size(), xs.size());
  for (const auto &e : g_grad.outputs) {
    g.outputs.push_back(e);
  }

  return g;
}

/*!
 * \brief GraphExecutor initializer for the regular bind flow in which
 * input arguments and gradients are provided by users. This initializer
 * uses the user-provided NDArrays to populate data entries of the graph.
 */
void GraphExecutor::Init(nnvm::Symbol symbol,
                         const Context& default_ctx,
                         const std::map<std::string, Context>& ctx_map,
                         const std::vector<NDArray>& in_args,
                         const std::vector<NDArray>& arg_grad_store,
                         const std::vector<OpReqType>& grad_req_types,
                         const std::vector<NDArray>& aux_states,
                         Executor* shared_exec,
                         const nnvm::NodeEntryMap<NDArray>& feed_dict) {
  // create in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes
  auto get_ctx1 = [](const NDArray& nd) { return nd.ctx(); };
  auto get_ctx2 = [default_ctx](const NDArray& nd) -> Context {
    if (nd.is_none()) return default_ctx;
    return nd.ctx();
  };
  std::vector<Context> in_arg_ctxes(in_args.size());
  std::transform(in_args.begin(), in_args.end(), in_arg_ctxes.begin(), get_ctx1);
  std::vector<Context> arg_grad_ctxes(arg_grad_store.size());
  std::transform(arg_grad_store.begin(), arg_grad_store.end(), arg_grad_ctxes.begin(), get_ctx2);
  std::vector<Context> aux_state_ctxes(aux_states.size());
  std::transform(aux_states.begin(), aux_states.end(), aux_state_ctxes.begin(), get_ctx1);

  nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes,
                            arg_grad_ctxes, aux_state_ctxes, grad_req_types);

  // create arg_shapes and arg_dtypes for shape and type inferences
  const auto& idx = g.indexed_graph();
  const auto& mutable_nodes = idx.mutable_input_nodes();
  size_t arg_top = 0, aux_top = 0;
  data_entry_.resize(idx.num_node_entries());
  mxnet::ShapeVector arg_shapes;
  nnvm::DTypeVector arg_dtypes;
  StorageTypeVector arg_stypes(idx.num_node_entries(), -1);
  for (size_t i = 0; i < num_forward_inputs_; ++i) {
    const uint32_t nid = idx.input_nodes().at(i);
    const std::string& arg_name = idx[nid].source->attrs.name;
    size_t eid = idx.entry_id(nid, 0);
    if (mutable_nodes.count(nid)) {
      CHECK_LT(aux_top, aux_states.size());
      data_entry_[eid] = aux_states[aux_top];
      arg_shapes.push_back(aux_states[aux_top].shape());
      arg_dtypes.push_back(aux_states[aux_top].dtype());
      arg_stypes[eid] = aux_states[aux_top].storage_type();
      aux_state_map_.emplace(arg_name, aux_states[aux_top]);
      ++aux_top;
    } else {
      CHECK_LT(arg_top, in_args.size());
      data_entry_[eid] = in_args[arg_top];
      arg_shapes.push_back(in_args[arg_top].shape());
      arg_dtypes.push_back(in_args[arg_top].dtype());
      arg_stypes[eid] = in_args[arg_top].storage_type();
      in_arg_map_.emplace(arg_name, in_args[arg_top]);
      if (kNullOp != grad_req_types[arg_top]) {
        auto grad_oid = grad_store_.size() + num_forward_outputs_;
        auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
        arg_stypes[grad_eid] = arg_grad_store[arg_top].storage_type();
        grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_store[arg_top]);
        arg_grad_map_.emplace(arg_name, arg_grad_store[arg_top]);
        if (log_verbose_) {
          LOG(INFO) << "\tassign data entry\t" << grad_eid << " as "
                    << common::stype_string(arg_stypes[grad_eid]) << " (grad)";
        }
      }
      ++arg_top;
    }
    if (log_verbose_) {
      LOG(INFO) << "\tassign data entry\t" << eid << " as "
                << common::stype_string(data_entry_[eid].storage_type()) << " (input)";
    }
  }

  // expand arg_shapes and arg_dtypes to contain backward inputs
  arg_shapes.resize(idx.input_nodes().size(), mxnet::TShape());
  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
    this->is_dynamic_ = true;
  }

  arg_dtypes.resize(idx.input_nodes().size(), -1);
  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
    HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
                         g.GetAttr<nnvm::DTypeVector>("dtype"));
  }

  g.attrs["storage_type"] = std::make_shared<dmlc::any>(std::move(arg_stypes));
  g = InferStorageType(std::move(g), StorageTypeVector(), "");
  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
    HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
                                g.GetAttr<StorageTypeVector>("storage_type"));
  }

  // Initialize the remaining attributes of the graph.
  // This function can be called by the regular bind
  // operation flow as well.
  FinishInitGraph(symbol, g, shared_exec, feed_dict);
}

/*!
 * \brief Initialize in_args, arg_grads, and aux_states
 * and the corresponding data_entry_ of the executor. This function
 * is called for the regular simple_bind flow, i.e. no
 * shared data arrays are provided.
 */
void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
                                  const mxnet::ShapeVector& inferred_shapes,
                                  const nnvm::DTypeVector& inferred_dtypes,
                                  const StorageTypeVector& inferred_stypes,
                                  const std::vector<Context>& in_arg_ctxes,
                                  const std::vector<Context>& arg_grad_ctxes,
                                  const std::vector<Context>& aux_state_ctxes,
                                  const std::vector<OpReqType>& grad_req_types,
                                  std::vector<NDArray>* in_arg_vec,
                                  std::vector<NDArray>* arg_grad_vec,
                                  std::vector<NDArray>* aux_state_vec) {
  // initialize in_args, arg_grads, and aux_states
  // populate grad_store_
  data_entry_.resize(idx.num_node_entries());
  size_t arg_top = 0, aux_top = 0;
  const auto& mutable_nodes = idx.mutable_input_nodes();
  for (size_t i = 0; i < num_forward_inputs_; ++i) {
    const uint32_t nid = idx.input_nodes().at(i);
    const uint32_t eid = idx.entry_id(nid, 0);
    const mxnet::TShape& inferred_shape = inferred_shapes[eid];
    const int inferred_dtype = inferred_dtypes[eid];
    const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
    const std::string& arg_name = idx[nid].source->attrs.name;
    if (mutable_nodes.count(nid)) {  // aux_states
      EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top],
                       inferred_dtype, aux_state_vec);
      data_entry_[eid] = aux_state_vec->back();
      aux_state_map_.emplace(arg_name, aux_state_vec->back());
      ++aux_top;
      if (log_verbose_) {
        LOG(INFO) << "\tassign aux entry\t" << eid << "\t as "
                  << common::stype_string(inferred_stype);
      }
    } else {  // in_args
      EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top],
                       inferred_dtype, in_arg_vec);
      data_entry_[eid] = in_arg_vec->back();
      if (log_verbose_) {
        LOG(INFO) << "\tassign data entry\t" << eid << "\tas "
                  << common::stype_string(inferred_stype);
      }
      // Get the storage type for grad
      if (kNullOp == grad_req_types[arg_top]) {
        arg_grad_vec->emplace_back();
      } else {
        // Init based on storage type
        auto grad_oid = grad_store_.size() + num_forward_outputs_;
        auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
        auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid];
        EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top],
                         inferred_dtype, arg_grad_vec);
        if (log_verbose_) {
          LOG(INFO) << "\tassign grad entry\t" << grad_eid << "\tas "
                    << common::stype_string(grad_stype);
        }
        grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
        arg_grad_map_.emplace(arg_name, arg_grad_vec->back());
      }
      in_arg_map_.emplace(arg_name, in_arg_vec->back());
      ++arg_top;
    }
  }
}

/*!
 * \brief Initialize in_args, arg_grads, and aux_states
 * and the corresponding data_entry_ of the executor using
 * the shared_buffer from DataParallelExecutorGroup
 * and shared_exec if available.
 */
void GraphExecutor::InitArguments(const nnvm::IndexedGraph& idx,
                                  const mxnet::ShapeVector& inferred_shapes,
                                  const nnvm::DTypeVector& inferred_dtypes,
                                  const StorageTypeVector& inferred_stypes,
                                  const std::vector<Context>& in_arg_ctxes,
                                  const std::vector<Context>& arg_grad_ctxes,
                                  const std::vector<Context>& aux_state_ctxes,
                                  const std::vector<OpReqType>& grad_req_types,
                                  const std::unordered_set<std::string>& shared_arg_names,
                                  const Executor* shared_exec,
                                  std::unordered_map<std::string, NDArray>* shared_buffer,
                                  std::vector<NDArray>* in_arg_vec,
                                  std::vector<NDArray>* arg_grad_vec,
                                  std::vector<NDArray>* aux_state_vec) {
  // initialize in_args, arg_grads, and aux_states and populate grad_store_
  data_entry_.resize(idx.num_node_entries());
  size_t arg_top = 0, aux_top = 0;
  const auto& mutable_nodes = idx.mutable_input_nodes();
  for (size_t i = 0; i < num_forward_inputs_; ++i) {
    const uint32_t nid = idx.input_nodes().at(i);
    const uint32_t eid = idx.entry_id(nid, 0);
    const mxnet::TShape& inferred_shape = inferred_shapes[eid];
    const int inferred_dtype = inferred_dtypes[eid];
    const NDArrayStorageType inferred_stype = (NDArrayStorageType) inferred_stypes[eid];
    const std::string& arg_name = idx[nid].source->attrs.name;
    // aux_states
    if (mutable_nodes.count(nid)) {
      if (nullptr != shared_exec) {
        const NDArray& aux_nd = shared_exec->aux_state_map().at(arg_name);
        CHECK(inferred_stype == kDefaultStorage && aux_nd.storage_type() == kDefaultStorage)
            << "Non-default storage type detected when creating auxiliary NDArray. The allocated "
            << "memory of shared_exec.aux_array cannot be reused for argument: "
            << arg_name << " for the current executor";
        CHECK_EQ(inferred_shape, aux_nd.shape())
            << "Inferred shape does not match shared_exec.aux_array's shape."
               " Therefore, the allocated memory for shared_exec.aux_array cannot"
               " be reused for creating an auxiliary NDArray for the argument: "
            << arg_name << " for the current executor";
        CHECK_EQ(inferred_dtype, aux_nd.dtype())
            << "Inferred dtype does not match shared_exec.aux_array's dtype."
               " Therefore, the allocated memory for shared_exec.aux_array cannot"
               " be reused for creating an auxiliary NDArray for the argument: "
            << arg_name << " for the current executor";
        aux_state_vec->emplace_back(aux_nd);
      } else {
        EmplaceBackZeros(inferred_stype, inferred_shape, aux_state_ctxes[aux_top],
                         inferred_dtype, aux_state_vec);
      }  // if (has_shared_exec)
      data_entry_[eid] = aux_state_vec->back();
      aux_state_map_.emplace(arg_name, aux_state_vec->back());
      ++aux_top;
    } else {  // in_args and grad for in_args
      if (shared_arg_names.count(arg_name)) {  // model parameter
        if (nullptr != shared_exec) {
          const NDArray& in_arg_nd = shared_exec->in_arg_map().at(arg_name);
          auto arg_nd_stype = in_arg_nd.storage_type();
          // for a model parameter, both default storage and row_sparse storage can be shared
          bool shareable_arg_stype = inferred_stype == kDefaultStorage ||
                                     inferred_stype == kRowSparseStorage;
          // try to reuse memory from shared_exec
          CHECK(shareable_arg_stype) << "Inferred storage type "
              << common::stype_string(inferred_stype)
              << " does not support memory sharing with shared_exec.arg_array";
          CHECK_EQ(inferred_stype, arg_nd_stype)
              << "Inferred stype does not match shared_exec.arg_array's stype."
                 " Therefore, the allocated memory for shared_exec.arg_array cannot"
                 " be reused for creating an NDArray for the argument "
              << arg_name << " for the current executor";
          CHECK_EQ(inferred_shape, in_arg_nd.shape())
              << "Inferred shape does not match shared_exec.arg_array's shape."
                 " Therefore, the allocated memory for shared_exec.arg_array cannot"
                 " be reused for creating an NDArray for the argument "
              << arg_name << " for the current executor";
          CHECK_EQ(inferred_dtype, in_arg_nd.dtype())
              << "Inferred dtype does not match shared_exec.arg_array's dtype."
                 " Therefore, the allocated memory for shared_exec.arg_array cannot"
                 " be reused for creating an NDArray for the argument "
              << arg_name << " for the current executor";
          in_arg_vec->emplace_back(in_arg_nd);
        } else {
          // doesn't have shared_exec, or non-default storage
          EmplaceBackZeros(inferred_stype, inferred_shape, in_arg_ctxes[arg_top],
                           inferred_dtype, in_arg_vec);
        }
        // gradient for model parameter
        if (kNullOp == grad_req_types[arg_top]) {
          arg_grad_vec->emplace_back();
        } else {
          auto grad_oid = grad_store_.size() + num_forward_outputs_;
          auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
          auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid];
          if (nullptr != shared_exec && grad_stype == kDefaultStorage &&
              shared_exec->arg_grad_map().at(arg_name).storage_type() == kDefaultStorage) {
            // try to reuse memory from shared_exec
            arg_grad_vec->emplace_back(shared_exec->arg_grad_map().at(arg_name));
          } else {
            // no need to reuse memory from shared_exec for gradients of non-default storage
            EmplaceBackZeros(grad_stype, inferred_shape, arg_grad_ctxes[arg_top],
                             inferred_dtype, arg_grad_vec);
          }
          grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
        }
      } else {  // !shared_arg_names.count(arg_name)
        // model parameter, row_sparse ndarray sharing enabled
        bool enable_row_sparse_sharing = true;
        in_arg_vec->emplace_back(ReshapeOrCreate(arg_name, inferred_shape, inferred_dtype,
                                                 inferred_stype, in_arg_ctxes[arg_top],
                                                 shared_buffer, enable_row_sparse_sharing));
        // gradient for model parameter, row_sparse ndarray sharing disabled
        if (kNullOp == grad_req_types[arg_top]) {
          arg_grad_vec->emplace_back();
        } else {
          auto grad_oid = grad_store_.size() + num_forward_outputs_;
          auto grad_eid = idx.entry_id(idx.outputs()[grad_oid]);
          auto grad_stype = (NDArrayStorageType) inferred_stypes[grad_eid];
          bool enable_row_sparse_sharing = false;
          arg_grad_vec->emplace_back(ReshapeOrCreate("grad of " + arg_name, inferred_shape,
                                                     inferred_dtype, grad_stype,
                                                     arg_grad_ctxes[arg_top], shared_buffer,
                                                     enable_row_sparse_sharing));
          grad_store_.emplace_back(grad_req_types[arg_top], arg_grad_vec->back());
        }  // if (kNullOp == grad_req_types[arg_top])
      }  // if (shared_arg_names.count(arg_name))
      in_arg_map_.emplace(arg_name, in_arg_vec->back());
      if (!arg_grad_vec->back().is_none()) {
        arg_grad_map_.emplace(arg_name, arg_grad_vec->back());
      }
      data_entry_[eid] = in_arg_vec->back();
      ++arg_top;
    }
  }
}

/*!
 * \brief Finish graph initialization after shape and dtype inferences.
 * This function is used by both the simple_bind and bind flows.
 */
void GraphExecutor::FinishInitGraph(nnvm::Symbol symbol,
                                    nnvm::Graph g,
                                    Executor* shared_exec,
                                    const nnvm::NodeEntryMap<NDArray>& feed_dict) {
  const auto& idx = g.indexed_graph();
  const auto& vstorage_type = g.GetAttr<StorageTypeVector>("storage_type");

  // data entries for output gradients
  for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) {
    data_entry_[idx.entry_id(idx.outputs()[j])] = grad_store_[j - num_forward_outputs_].second;
  }

  {
    // memory allocator
    nnvm::StorageVector arg_storage_id(idx.num_node_entries(), kBadStorageID);
    for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) {
      arg_storage_id[idx.entry_id(idx.outputs()[j])] = kExternalStorageID;
    }
    for (const auto& kv : feed_dict) {
      uint32_t eid = idx.entry_id(kv.first);
      data_entry_[eid] = kv.second;
      arg_storage_id[eid] = kExternalStorageID;
    }
    for (size_t i = 0; i < idx.num_node_entries(); i++) {
      if (vstorage_type[i] != kDefaultStorage) arg_storage_id[i] = kDynamicStorageID;
    }
    g.attrs["storage"] = std::make_shared<dmlc::any>(std::move(arg_storage_id));
    g = nnvm::ApplyPass(g, "MXPlanMemory");
  }
  g = DetectInplaceAddTo(g);

  // log the static memory plan of the graph
  static bool mem_log_verbose = dmlc::GetEnv("MXNET_MEM_PLAN_VERBOSE_LOGGING", false);
  if (mem_log_verbose) {
    common::LogMemoryPlan(g);
  }

  g = AttachOpExecs(g);
  AttachOpResources(g);
  graph_ = std::move(g);

  if (shared_exec != nullptr) {
    this->InitDataEntryMemory(&(dynamic_cast<GraphExecutor*>(shared_exec)->data_pool_));
  } else {
    this->InitDataEntryMemory(nullptr);
  }

  {
    // initialize output arrays
    auto& idx = graph_.indexed_graph();
    for (size_t i = 0; i < num_forward_outputs_; ++i) {
      auto& e = idx.outputs()[i];
      output_arrays_.push_back(data_entry_[idx.entry_id(e)]);
    }
    // initialize head gradient arrays
    head_grad_array_.resize(symbol.outputs.size());
    for (size_t i = num_forward_inputs_; i < idx.input_nodes().size(); ++i) {
      uint32_t nid = idx.input_nodes().at(i);
      uint32_t oid = head_grad_map_.at(idx[nid].source);
      head_grad_array_[oid] = data_entry_[idx.entry_id(nid, 0)];
    }
  }
  this->InitCachedOps();
  this->InitOpSegs();
}

/*!
 * \brief GraphExecutor initializer for the simple_bind flow in
 * which only certain input shapes and dtypes are provided by users.
 * The initializer uses these shapes and dtypes to perform
 * shape and dtype inferences, and then creates NDArrays
 * to populate data entries of the graph. The created NDArrays
 * for in_args, arg_grads and aux_states are passed to the
 * front end to be attached to the created executor.
 * In the front end, if the simple_bind flow is triggered by
 * _bind_ith_exec, the shared data arrays of DataParallelExecutorGroup
 * and the shared executor are taken into account when creating
 * NDArrays for in_args, arg_grads, and aux_states in order to reuse
 * already allocated memory.
 */
void GraphExecutor::Init(nnvm::Symbol symbol,
                         const Context& default_ctx,
                         const std::map<std::string, Context>& ctx_map,
                         const std::vector<Context>& in_arg_ctxes,
                         const std::vector<Context>& arg_grad_ctxes,
                         const std::vector<Context>& aux_state_ctxes,
                         const std::unordered_map<std::string, mxnet::TShape>& arg_shape_map,
                         const std::unordered_map<std::string, int>& arg_dtype_map,
                         const std::unordered_map<std::string, int>& arg_stype_map,
                         const std::vector<OpReqType>& grad_req_types,
                         const std::unordered_set<std::string>& shared_arg_names,
                         std::vector<NDArray>* in_arg_vec,
                         std::vector<NDArray>* arg_grad_vec,
                         std::vector<NDArray>* aux_state_vec,
                         std::unordered_map<std::string, NDArray>* shared_buffer,
                         Executor* shared_exec,
                         const nnvm::NodeEntryMap<NDArray>& feed_dict) {
  nnvm::Graph g = InitGraph(symbol, default_ctx, ctx_map, in_arg_ctxes, arg_grad_ctxes,
                            aux_state_ctxes, grad_req_types);

  // The following code of shape and dtype inferences and argument
  // initialization is for simple_bind only. The regular bind operation
  // should do this differently.

  // Initialize arg_shapes and arg_dtypes for shape and type inferences.
  // It contains all in_args and aux_states' shapes and types in a certain order.
  const nnvm::IndexedGraph& idx = g.indexed_graph();
  mxnet::ShapeVector arg_shapes(idx.input_nodes().size(), mxnet::TShape());
  nnvm::DTypeVector arg_dtypes(idx.input_nodes().size(), -1);
  StorageTypeVector arg_stypes(idx.input_nodes().size(), kUndefinedStorage);
  for (size_t i = 0; i < num_forward_inputs_; ++i) {
    const uint32_t nid = idx.input_nodes().at(i);
    const std::string& name = idx[nid].source->attrs.name;
    auto it1 = arg_shape_map.find(name);
    if (arg_shape_map.end() != it1) {
      arg_shapes[i] = it1->second;
    }
    auto it2 = arg_dtype_map.find(name);
    if (arg_dtype_map.end() != it2) {
      arg_dtypes[i] = it2->second;
    }
    auto it3 = arg_stype_map.find(name);
    if (arg_stype_map.end() != it3) {
      arg_stypes[i] = it3->second;
    }
  }
  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
    HandleInferShapeError(num_forward_inputs_, g.indexed_graph(),
                          g.GetAttr<mxnet::ShapeVector>("shape"));
  }

  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
    HandleInferTypeError(num_forward_inputs_, g.indexed_graph(),
                         g.GetAttr<nnvm::DTypeVector>("dtype"));
  }

  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
    HandleInferStorageTypeError(num_forward_inputs_, g.indexed_graph(),
                                g.GetAttr<StorageTypeVector>("storage_type"));
  }

  // Create in_args, arg_grads, and aux_states using
  // the inferred shapes and dtypes.
  if (nullptr == shared_buffer) {  // regular simple bind
    InitArguments(idx, g.GetAttr<mxnet::ShapeVector>("shape"),
                  g.GetAttr<nnvm::DTypeVector>("dtype"),
                  g.GetAttr<StorageTypeVector>("storage_type"),
                  in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
                  grad_req_types, in_arg_vec, arg_grad_vec, aux_state_vec);
  } else {  // simple bind using shared data arrays and shared_exec
    InitArguments(idx, g.GetAttr<mxnet::ShapeVector>("shape"),
                  g.GetAttr<nnvm::DTypeVector>("dtype"),
                  g.GetAttr<StorageTypeVector>("storage_type"),
                  in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
                  grad_req_types, shared_arg_names, shared_exec,
                  shared_buffer, in_arg_vec, arg_grad_vec, aux_state_vec);
  }
  // The above code of shape and dtype inferences and argument
  // initialization is for simple_bind only. The regular bind operation
  // should do this differently.

  // Initialize the remaining attributes of the graph.
  // This function can be called by the regular bind
  // operation flow as well.
  FinishInitGraph(symbol, g, shared_exec, feed_dict);
}

/*!
 * \brief Return a new executor with the same symbol and shared memory,
 * but different input/output shapes.
 * For runtime reshaping, variable-length sequences, etc.
 * The returned executor shares state with the current one,
 * and cannot be used in parallel with it.
 */
Executor* GraphExecutor::Reshape(const bool partial_shaping,
                                 const bool allow_up_sizing,
                                 const Context& default_ctx,
                                 const std::map<std::string, Context>& ctx_map,
                                 const std::unordered_map<std::string, mxnet::TShape>&
                                     provided_arg_shapes,
                                 std::vector<NDArray>* in_args,
                                 std::vector<NDArray>* arg_grads,
                                 std::vector<NDArray>* aux_states) {
  nnvm::Graph g;
  nnvm::Symbol symbol;
  symbol.outputs = symbol_.outputs;
  g.outputs = symbol_.outputs;
  const nnvm::IndexedGraph& idx = g.indexed_graph();
  mxnet::ShapeVector arg_shapes(idx.input_nodes().size(), mxnet::TShape());
  for (size_t i = 0; i < num_forward_inputs_; ++i) {
    const uint32_t nid = idx.input_nodes().at(i);
    const std::string& name = idx[nid].source->attrs.name;
    auto it = provided_arg_shapes.find(name);
    if (provided_arg_shapes.end() != it) {
      arg_shapes[i] = it->second;
    }
  }
  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U) {
    this->is_dynamic_ = true;
  }
  const mxnet::ShapeVector& shape_vec = g.GetAttr<mxnet::ShapeVector>("shape");
  std::vector<OpReqType> grad_req_types;
  size_t grad_top = 0;
  const size_t num_args = in_arg_map_.size();
  const size_t num_aux = aux_state_map_.size();
  in_args->reserve(num_args);
  grad_req_types.reserve(num_args);
  arg_grads->reserve(num_args);
  aux_states->reserve(num_aux);
  for (uint32_t nid : idx.input_nodes()) {
    std::string name = idx[nid].source->attrs.name;
    const mxnet::TShape& new_shape = shape_vec[idx.entry_id(nid, 0)];
    if (idx.mutable_input_nodes().count(nid) == 0) {
      NDArray& arr = in_arg_map_.at(name);
      auto it = arg_grad_map_.find(name);
      if (partial_shaping || provided_arg_shapes.count(name) || new_shape == arr.shape()) {
        if (new_shape.Size() > arr.shape().Size()) {
          CHECK(allow_up_sizing) << "New shape of arg: " << name << " is larger than original. "
                                 << "First making a big executor and then down sizing it "
                                 << "is more efficient than the reverse. "
                                 << "If you really want to up size, set allow_up_sizing=True "
                                 << "to enable allocation of new arrays.";
          in_args->emplace_back(new_shape, arr.ctx(), false, arr.dtype());
          if (it != arg_grad_map_.end()) {
            NDArray& darr = it->second;
            arg_grads->emplace_back(new_shape, darr.ctx(), false, darr.dtype());
            grad_req_types.push_back(grad_store_.at(grad_top++).first);
          } else {
            arg_grads->emplace_back();
            grad_req_types.push_back(kNullOp);
          }
        } else {
          in_args->push_back(arr.Reshape(new_shape));
          if (it != arg_grad_map_.end()) {
            NDArray& darr = it->second;
            arg_grads->push_back(darr.Reshape(new_shape));
            grad_req_types.push_back(grad_store_.at(grad_top++).first);
          } else {
            arg_grads->emplace_back();
            grad_req_types.push_back(kNullOp);
          }
        }
      } else {
        LOG(FATAL) << "Shape of unspecified arg: " << name << " changed. "
                   << "This can cause the new executor to not share parameters "
                   << "with the old one. Please check for errors in the network. "
                   << "If this is intended, set partial_shaping=True to suppress this warning.";
      }
    } else {
      NDArray& arr = aux_state_map_.at(name);
      if (partial_shaping || new_shape == arr.shape()) {
        if (new_shape.Size() > arr.shape().Size()) {
          CHECK(allow_up_sizing) << "New shape of arg: " << name << " is larger than original. "
                                 << "First making a big executor and then down sizing it "
                                 << "is more efficient than the reverse. "
                                 << "If you really want to up size, set allow_up_sizing=True "
                                 << "to enable allocation of new arrays.";
          aux_states->emplace_back(new_shape, arr.ctx(), false, arr.dtype());
        } else {
          aux_states->push_back(arr.Reshape(new_shape));
        }
      } else {
        LOG(FATAL) << "Shape of unspecified arg: " << name << " changed. "
                   << "This can cause the new executor to not share parameters "
                   << "with the old one. Please check for errors in the network. "
                   << "If this is intended, set partial_shaping=True to suppress this warning.";
      }
    }
  }
  auto exec = new GraphExecutor(symbol);
  exec->Init(symbol.Copy(), default_ctx, ctx_map,
             *in_args, *arg_grads, grad_req_types, *aux_states,
             this);
  return exec;
}

/*!
 * \brief This function is triggered by both the simple_bind
 * and bind flows.
 * Set up the backward graph, create device and context
 * attributes in the graph, and calculate the number
 * of forward nodes.
 */
Graph GraphExecutor::InitGraph(nnvm::Symbol symbol,
                               const Context& default_ctx,
                               const std::map<std::string, Context>& ctx_map,
                               const std::vector<Context>& in_arg_ctxes,
                               const std::vector<Context>& arg_grad_ctxes,
                               const std::vector<Context>& aux_state_ctxes,
                               const std::vector<OpReqType>& grad_req_types) {
  // setup gradient
  nnvm::Graph g = InitFullGraph(symbol, grad_req_types);

#if MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32)
  if (default_ctx.dev_mask() == Context::kGPU && dmlc::GetEnv("MXNET_USE_FUSION", true)) {
    nnvm::Graph unoptimized_graph;
    common::CopyGraph(&unoptimized_graph, g, false);

    if (common::CheckForInputNameDuplicates(unoptimized_graph.indexed_graph())) {
      g = exec::FusePointwise(std::move(g), num_forward_outputs_);
      // Check the topological order of inputs
      const auto &original_inputs = unoptimized_graph.indexed_graph().input_nodes();
      const auto &new_inputs = g.indexed_graph().input_nodes();
      if (original_inputs.size() != new_inputs.size()) {
        LOG(WARNING)
            << "Number of inputs after fusion does not match original number of inputs. "
            << "This is most probably a bug. Disabling fusion for this run.";
        g = unoptimized_graph;
      } else {
        for (size_t i = 0; i < new_inputs.size(); ++i) {
          if (unoptimized_graph.indexed_graph()[original_inputs[i]].source->attrs.name !=
              g.indexed_graph()[new_inputs[i]].source->attrs.name) {
            LOG(WARNING) << "Disabling fusion due to altered topological order of inputs.";
            g = unoptimized_graph;
            break;
          }
        }
      }
    } else {
      LOG(WARNING)
          << "Graph contains duplicate names for some of its inputs - fusion is NOT enabled!";
    }
  }
#else
  // Only warn the user if the MXNET_USE_FUSION env var is explicitly set
  if (default_ctx.dev_mask() == Context::kGPU && dmlc::GetEnv("MXNET_USE_FUSION", false)) {
    WarnFusionNotSupported();
  }
#endif  // MXNET_USE_CUDA && MXNET_ENABLE_CUDA_RTC && !defined(_WIN32)

  // create "device" and "context" attrs for the graph
  g = AssignContext(g, default_ctx, ctx_map,
                    in_arg_ctxes,
                    arg_grad_ctxes,
                    aux_state_ctxes,
                    grad_req_types,
                    num_forward_inputs_,
                    num_forward_outputs_);

  const auto& idx = g.indexed_graph();
  // get the number of nodes used in the forward pass
  num_forward_nodes_ = 0;
  for (size_t i = 0; i < num_forward_outputs_; ++i) {
    num_forward_nodes_ = std::max(
        num_forward_nodes_, static_cast<size_t>(idx.outputs()[i].node_id + 1));
  }
  return g;
}

// initialize the memory of each data entry
void GraphExecutor::InitDataEntryMemory(std::vector<NDArray>* shared_pool) {
  using nnvm::DTypeVector;
  using mxnet::ShapeVector;
  using nnvm::StorageVector;
  // get the graph
  const auto& idx = graph_.indexed_graph();
  // get the storage
  const auto& vdtype = graph_.GetAttr<DTypeVector>("dtype");
  const auto& vshape = graph_.GetAttr<mxnet::ShapeVector>("shape");
  const auto& vstorage = graph_.GetAttr<StorageVector>("storage_id");
  const auto& vstorage_type = graph_.GetAttr<StorageTypeVector>("storage_type");
  const auto& vctx = graph_.GetAttr<ContextVector>("context");
  CHECK_EQ(idx.num_node_entries(), vshape.size());
  CHECK_EQ(idx.num_node_entries(), vdtype.size());
  CHECK_EQ(idx.num_node_entries(), vstorage.size());
  CHECK_EQ(data_entry_.size(), vshape.size());
  std::vector<Context> data_context(idx.num_node_entries());
  std::vector<NDArrayStorageType> data_storage_type(idx.num_node_entries(), kUndefinedStorage);
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    for (uint32_t i = 0; i < idx[nid].source->num_outputs(); ++i) {
      auto eid = idx.entry_id(nid, i);
      data_context[eid] = vctx[nid];
      CHECK_NE(vstorage_type[eid], kUndefinedStorage);
      data_storage_type[eid] = (NDArrayStorageType) vstorage_type[eid];
    }
  }

  // information about the pool
  struct PoolEntry {
    Context ctx;
    size_t bytes;
    NDArrayStorageType stype;
  };
  std::vector<PoolEntry> pool_info;

  // assign arrays to the head gradients
  for (size_t i = num_forward_inputs_; i < idx.input_nodes().size(); ++i) {
    uint32_t nid = idx.input_nodes().at(i);
    uint32_t oid = head_grad_map_.at(idx[nid].source);
    uint32_t eid = idx.entry_id(idx.outputs()[oid]);
    NDArrayStorageType stype = (NDArrayStorageType) vstorage_type[eid];
    bool unknown_shape = !shape_is_known(vshape[eid]);
    CHECK_NE(vdtype[eid], -1);
    auto data_eid = idx.entry_id(nid, 0);
    // initialize based on storage type
    if (stype != kDefaultStorage) {
      data_entry_[data_eid] = NDArray(stype, vshape[eid], data_context[eid], true, vdtype[eid]);
    } else if (!unknown_shape) {
      data_entry_[data_eid] = NDArray(vshape[eid], data_context[eid], false, vdtype[eid]);
    } else {
      data_entry_[data_eid] = NDArray(data_context[eid], vdtype[eid]);
    }
    if (log_verbose_) {
      LOG(INFO) << "\tinit head_grad entry\t" << data_eid << "\tas "
                << common::stype_string(stype);
    }
  }
  // get the maximum bytes in each pool
  for (size_t i = 0; i < vshape.size(); ++i) {
    if (!data_entry_[i].is_none()) continue;
    size_t shape_size = 0;
    if (shape_is_known(vshape[i])) {
      shape_size = vshape[i].Size();
    }
    size_t bytes = shape_size * mshadow::mshadow_sizeof(vdtype[i]);
    int storage_id = vstorage[i];
    // skip pool allocation for kBadStorageID, kExternalStorageID and kDynamicStorageID
    if (storage_id < 0) continue;
    size_t sid = static_cast<size_t>(storage_id);
    if (sid >= pool_info.size()) {
      pool_info.resize(sid + 1, PoolEntry{Context::CPU(), size_t(0), kUndefinedStorage});
    }
    PoolEntry& info = pool_info[sid];
    if (info.bytes == 0) {
      info = PoolEntry{data_context[i], bytes, data_storage_type[i]};
    } else {
      info.bytes = std::max(info.bytes, bytes);
    }
  }
  // construct the re-use pool, if needed
  std::multimap<size_t, NDArray> free_pool;
  if (shared_pool != nullptr) {
    for (const NDArray& nd : *shared_pool) {
      size_t bytes = 0;
      if (shape_is_known(nd.shape())) {
        bytes = nd.shape().Size() * mshadow::mshadow_sizeof(nd.dtype());
      }
      free_pool.insert(std::make_pair(bytes, nd));
    }
  }
  // remake the data pool
  data_pool_.clear();
  data_pool_.resize(pool_info.size());

  // sort the pool info in descending order before allocating memory
  std::vector<size_t> sorted_pool_index;
  for (size_t i = 0; i < pool_info.size(); i++) {
    sorted_pool_index.push_back(i);
  }
  auto pool_comparator = [&pool_info](size_t lhs, size_t rhs){
    return pool_info[lhs].bytes > pool_info[rhs].bytes;
  };
  std::sort(sorted_pool_index.begin(), sorted_pool_index.end(), pool_comparator);

  for (size_t i : sorted_pool_index) {
    const Context& ctx = pool_info[i].ctx;
    size_t bytes = pool_info[i].bytes;
    bool allocated = false;
    for (auto it = free_pool.lower_bound(bytes); it != free_pool.end(); ++it) {
      if (it->second.ctx() == ctx && it->first >= bytes) {
        data_pool_[i] = it->second;
        free_pool.erase(it);
        allocated = true;
        break;
      }
    }
    if (!allocated) {
      size_t nword = (bytes + 3) / 4;
      CHECK_LE(nword, std::numeric_limits<nnvm::dim_t>::max());
      // allocate float arrays
      mxnet::TShape shape{static_cast<nnvm::dim_t>(nword)};
      // TODO(junwu): adding delay_alloc=true to create nd
      // is a temporary solution.
      NDArray nd(shape, ctx, true);
      data_pool_[i] = nd;
      // put the newly allocated arrays into the shared pool
      if (shared_pool != nullptr) {
        shared_pool->push_back(nd);
      }
    }
  }
  CHECK_EQ(data_pool_.size(), pool_info.size());
  // assign the data entries
  for (size_t i = 0; i < data_entry_.size(); ++i) {
    // avoid pre-allocated arrays
    if (!data_entry_[i].is_none()) continue;
    // assign allocated array by storage id
    int storage_id = vstorage[i];
    auto storage_type = (NDArrayStorageType) vstorage_type[i];
    if (storage_type == kDefaultStorage) {
      if (!shape_is_known(vshape[i])) {
        data_entry_[i] = NDArray(data_context[i], vdtype[i]);
      } else {
        CHECK_GE(storage_id, 0) << "Do not support runtime shape op yet";
        const NDArray& src = data_pool_.at(storage_id);
        data_entry_[i] = src.AsArray(vshape[i], vdtype[i]);
      }
    } else {
      data_entry_[i] = NDArray(storage_type, vshape[i], data_context[i],
                               true, vdtype[i]);
    }
    if (log_verbose_) {
      LOG(INFO) << "\tinit data entry\t" << i << "\tas " << common::stype_string(storage_type);
    }
  }
}

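/*!
 * \brief Create a cached engine operator for every executable node:
 * bind input/output arrays, derive each output's write request
 * (write-to, write-inplace, add-to, or null), deduplicate engine
 * variables, and register the execution functor with the engine.
 */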
void GraphExecutor::InitCachedOps() {
  // get the graph
  const auto& idx = graph_.indexed_graph();
  const auto& vstorage_inplace =
      graph_.GetAttr<std::vector<int> >("storage_inplace_index");
  const auto& op_execs =
      graph_.GetAttr<OpExecVector>("op_execs");
  const auto& vctx = graph_.GetAttr<ContextVector>("context");
  const auto& addto_entry = graph_.GetAttr<std::vector<int> >("addto_entry");
  const auto& skip_plus_node = graph_.GetAttr<std::vector<int> >("skip_plus_node");

  op_nodes_.resize(idx.num_nodes());
  // setup the arrays and requirements
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) continue;
    op_nodes_[nid].opr_name = inode.source->op()->name.c_str();
    if (skip_plus_node.at(nid)) {
      op_nodes_[nid].skip_exec_node = true; continue;
    }

    op_nodes_[nid].exec = op_execs[nid];
    op_nodes_[nid].ctx = vctx[nid];
    auto& exec = op_nodes_[nid].exec;
    CHECK_EQ(exec->in_array.size(), 0U);
    CHECK_EQ(exec->out_array.size(), 0U);
    for (const auto& e : inode.inputs) {
      exec->in_array.push_back(data_entry_[idx.entry_id(e)]);
    }
    // detect inplace requirement
    for (uint32_t index = 0; index < inode.source->num_outputs(); ++index) {
      uint32_t eid = idx.entry_id(nid, index);
      exec->out_array.push_back(data_entry_[eid]);
      if (addto_entry.at(eid) != 0) {
        exec->req.push_back(kAddTo);
      } else if (vstorage_inplace[eid] >= 0) {
        exec->req.push_back(kWriteInplace);
      } else if (vstorage_inplace[eid] == -2) {
        // -2 indicates that the entry is never referenced
        exec->req.push_back(kNullOp);
      } else {
        exec->req.push_back(kWriteTo);
      }
    }
  }
  // Note that this modifies the requirement of kWriteInplace
  for (size_t j = num_forward_outputs_; j < idx.outputs().size(); ++j) {
    auto& e = idx.outputs()[j];
    op_nodes_[e.node_id].exec->req[e.index] =
        grad_store_[j - num_forward_outputs_].first;
  }
  for (uint32_t nid = 0; nid < idx.num_nodes(); ++nid) {
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) continue;
    if (op_nodes_[nid].skip_exec_node) continue;
    auto& exec = op_nodes_[nid].exec;
    bool is_async = op_nodes_[nid].exec->exec_type() == ExecType::kAsync;
    bool is_gpu = op_nodes_[nid].ctx.dev_mask() == gpu::kDevMask;

    // the variables
    std::vector<Engine::VarHandle> use_vars, mutate_vars;
    for (const auto& nd : exec->in_array) {
      use_vars.push_back(nd.var());
    }
    for (const auto& r : exec->op_ctx.requested) {
      mutate_vars.push_back(r.var);
    }
    for (const auto& nd : exec->out_array) {
      mutate_vars.push_back(nd.var());
    }
    if (exec->var() != nullptr) {
      mutate_vars.push_back(exec->var());
    }
    // dedup vars
    Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars);
    // all vars include both mutate vars and use vars
    std::vector<Engine::VarHandle> all_vars(use_vars);
    std::copy(mutate_vars.begin(), mutate_vars.end(),
              std::inserter(all_vars, all_vars.end()));
    // setup exec vars
    Engine::Get()->PushAsync(
        [exec](RunContext rctx, Engine::CallbackOnComplete on_complete) {
          exec->Setup();
          on_complete();
        }, Context::CPU(), {}, all_vars, FnProperty::kNormal, 0,
        "SetupExec");
    auto exec_fun = [exec, is_async, is_gpu] (
        RunContext ctx, Engine::CallbackOnComplete on_complete) {
      if (is_async) {
        exec->op_ctx.async_on_complete = on_complete;
      }
      exec->Run(ctx, is_gpu);
      // call on_complete only if it is an async op
      if (!is_async) {
        if (is_gpu) {
#if MXNET_USE_CUDA
          // wait for the GPU kernel to finish
          ctx.get_stream<gpu>()->Wait();
#else
          LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
#endif
        }
        on_complete();
      }
    };
    // setup the vars
    op_nodes_[nid].cached_opr = Engine::Get()->NewOperator(
        exec_fun, use_vars, mutate_vars, FnProperty::kNormal,
        op_nodes_[nid].opr_name);
    op_nodes_[nid].mutate_vars = mutate_vars;
    op_nodes_[nid].use_vars = use_vars;
  }
}

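/*!
 * \brief Prepare bulk execution segments. Depending on whether the executor
 * is used for training or inference (and on the bulking environment
 * variables), contiguous runs of synchronous ops are grouped so they can be
 * pushed to the engine as a single cached operation.
 */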
void GraphExecutor::InitOpSegs() {
  size_t total_num_nodes = graph_.indexed_graph().num_nodes();
  cached_seg_opr_.clear();
  CachedSegOpr p;
  cached_seg_opr_.resize(total_num_nodes, p);
  if (monitor_callback_) return;

  // Symbolic bulking is set by the same environment variables as Imperative bulking.
  // Generate segments based on the graph structure.
  bool prefer_bulk_exec_inference = Imperative::PreferBulkExecInference();
  // Whether to perform bulk exec for training
  const profiler::Profiler *prof = profiler::Profiler::Get();
  bool prefer_bulk_exec_train = Imperative::PreferBulkExecTrain()
                                && (!prof || !prof->AggregateEnabled());
  if (this->is_dynamic_) {
    prefer_bulk_exec_inference = false;
    prefer_bulk_exec_train = false;
  }
  bool is_training = num_forward_nodes_ != total_num_nodes;

  if (prefer_bulk_exec_train && is_training) {
    // Bulk the forward portion of the graph per the bulk segment max size for forward training
    this->BulkOpSegs(0, num_forward_nodes_, Imperative::BulkExecMaxNodeTrainFwd());
    // Bulk the backward portion of the graph per the bulk segment max size for backward training
    this->BulkOpSegs(num_forward_nodes_, total_num_nodes, Imperative::BulkExecMaxNodeTrainBwd());
  }

  if (prefer_bulk_exec_inference && !is_training) {
    // Bulk the entire graph as one bulk segment if possible
    this->BulkOpSegs(0, total_num_nodes, total_num_nodes);
  }
}

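/*!
 * \brief Create bulked segments over the topo-order range [from_node, up_to_node).
 *
 * A segment is cut whenever a node cannot be bulked (non-kSync exec type) or
 * the per-segment node budget is reached. For illustration (a hypothetical
 * five-node graph, budget 2, all nodes sync): nodes {0,1} form one segment,
 * {2,3} the next, and {4} the last.
 */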
void GraphExecutor::BulkOpSegs(size_t from_node, size_t up_to_node, size_t segment_num_nodes_max) {
  size_t topo_start = from_node;
  size_t segment_node_count = 0;
  for (size_t nid = from_node; nid < up_to_node; nid++) {
    auto &node = graph_.indexed_graph()[nid].source;
    auto &op_node = op_nodes_[nid];
    // Variables, such as learned weights, are ignored in the segment_node_count
    bool ignore_node = node->is_variable() || op_node.skip_exec_node || op_node.exec == nullptr;
    if (!ignore_node)
      segment_node_count++;
    bool can_bulk = ignore_node || op_node.exec->exec_type() == ExecType::kSync;
    // check if we need to create the segment based on properties of this node
    if (!can_bulk || nid == up_to_node - 1 || segment_node_count >= segment_num_nodes_max) {
      // Create a new segment for the previous nodes; also include this node if it's bulkable
      cached_seg_opr_[topo_start] = this->CreateCachedSegOpr(topo_start, can_bulk ? nid + 1 : nid);
      topo_start = nid + 1;
      segment_node_count = 0;
    }
  }
}

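/*!
 * \brief Invoke the monitor callback on every input of node nid.
 *
 * Inputs are reported as "<op_name>_<input_name>"; inputs that are variables
 * are additionally reported under the variable's own name. The heap-allocated
 * NDArray copies are handed to the callback, whose receiver is expected to
 * take ownership of them.
 */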
void GraphExecutor::ExecuteMonInputCallback(size_t nid) {
  static const auto& flist_inputs =
      nnvm::Op::GetAttr<nnvm::FListInputNames>("FListInputNames");
  const auto& idx = graph_.indexed_graph();
  std::vector<std::string> input_names;
  OpNode& opnode = op_nodes_[nid];
  const auto& inode = idx[nid];
  const auto& node = idx[nid].source;
  if (flist_inputs.count(node->op())) {
    input_names = flist_inputs[node->op()](node->attrs);
  } else {
    for (size_t i = 0; i < node->num_inputs(); ++i) {
      input_names.emplace_back("input" + std::to_string(i));
    }
  }
  CHECK_EQ(opnode.exec->in_array.size(), input_names.size());
  for (size_t i = 0; i < opnode.exec->in_array.size(); ++i) {
    if (node->inputs[i].node->is_variable()) {
      // Monitor variable
      NDArray *cpy = new NDArray(opnode.exec->in_array[i]);
      std::string name = node->inputs[i].node->attrs.name;
      this->monitor_callback_(name.c_str(), reinterpret_cast<void*>(cpy));
    }
    NDArray *cpy = new NDArray(opnode.exec->in_array[i]);
    std::string name = inode.source->attrs.name + "_" + input_names[i];
    this->monitor_callback_(name.c_str(), reinterpret_cast<void*>(cpy));
  }
}

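/*!
 * \brief Invoke the monitor callback on every output of node nid, using the
 * canonical output name resolved by GetOutputName().
 */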
void GraphExecutor::ExecuteMonOutputCallback(size_t nid) {
  const auto& idx = graph_.indexed_graph();
  OpNode& opnode = op_nodes_[nid];
  const auto& node = idx[nid].source;
  for (size_t i = 0; i < opnode.exec->out_array.size(); ++i) {
    NDArray *cpy = new NDArray(opnode.exec->out_array[i]);
    nnvm::ObjectPtr node_ptr = std::make_shared<nnvm::Node>(*node);
    std::string name = GetOutputName({node_ptr, static_cast<uint32_t>(i), 0});
    this->monitor_callback_(name.c_str(), reinterpret_cast<void*>(cpy));
  }
}

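/*!
 * \brief Execute the nodes in the topo-order range [topo_start, topo_end).
 *
 * Cached bulk segments are pushed to the engine whole when no monitor
 * callback is installed; otherwise nodes are dispatched one at a time.
 * For dynamic-shape graphs, shapes are (re)inferred on the fly and output
 * arrays are allocated lazily before each op runs.
 */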
void GraphExecutor::RunOps(bool is_train, size_t topo_start, size_t topo_end) {
  static auto& finfer_shape = nnvm::Op::GetAttr<mxnet::FInferShape>("FInferShape");
  static auto& is_backward = Op::GetAttr<nnvm::TIsBackward>("TIsBackward");
  // Update the per-op contexts with the current training flags
  const auto& idx = graph_.indexed_graph();
  for (size_t nid = topo_start; nid < topo_end; ++nid) {
    OpNode& opnode = op_nodes_[nid];
    if (opnode.skip_exec_node) continue;
    const auto& inode = idx[nid];
    if (inode.source->is_variable()) continue;
    opnode.exec->op_ctx.is_train = is_train;
    opnode.exec->op_ctx.need_grad = need_grad_;
  }

  mxnet::ShapeVector rshape = graph_.MoveCopyAttr<mxnet::ShapeVector>("shape");
  // Push Ops
  for (size_t nid = topo_start; nid < topo_end; ++nid) {
    auto seg_op = cached_seg_opr_[nid];
    // Check segments first
    if (monitor_callback_ == nullptr && seg_op.opr != nullptr && seg_op.topo_end <= topo_end) {
      bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
      Engine::Get()->Push(seg_op.opr, seg_op.ctx, 0, profiling);
      nid = seg_op.topo_end - 1;
      continue;
    }
    // Normal mode
    const auto& inode = idx[nid];
    const uint32_t num_inputs = inode.inputs.size();
    const uint32_t num_outputs = inode.source->num_outputs();
    if (inode.source->is_variable()) continue;
    OpNode& opnode = op_nodes_[nid];
    if (op_nodes_[nid].skip_exec_node) continue;
    // Monitor callbacks
    if (monitor_callback_ && monitor_all_) {
      ExecuteMonInputCallback(nid);
    }
    if (this->is_dynamic_) {
      const auto &op = inode.source->op();
      {
        for (NDArray &array : opnode.exec->in_array) {
          array.WaitToRead();
          if (!shape_is_known(array.shape())) {
            array.SetShapeFromChunk();
          }
        }
        int i = 0;
        for (NDArray &array : opnode.exec->out_array) {
          array.WaitToRead();
          if (!shape_is_known(array.shape())) {
            array.SetShapeFromChunk();
          }
          if (!shape_is_known(array.shape())) {
            mxnet::TShape shape = rshape[idx.entry_id(nid, i)];
            if (shape_is_known(shape)) {
              array.ReshapeAndAlloc(shape);
            }
          }
          ++i;
        }
      }
      if (finfer_shape.count(op)) {
        mxnet::ShapeVector in_shapes;
        mxnet::ShapeVector out_shapes;
        for (NDArray &array : opnode.exec->in_array) {
          in_shapes.push_back(array.shape());
        }
        for (NDArray &array : opnode.exec->out_array) {
          out_shapes.push_back(array.shape());
        }
        auto finfer = finfer_shape[op];
        try {
          bool success = finfer(inode.source->attrs, &in_shapes, &out_shapes);
          CHECK(success) << "InferShape failed in operator " << inode.source->attrs.name;
        } catch (const std::exception& e) {
          throw dmlc::Error("Error in operator " + inode.source->attrs.name + ": " + e.what());
        }
        int n_out = out_shapes.size();
        for (int i = 0; i < n_out; ++i) {
          NDArray &array = opnode.exec->out_array[i];
          if (!shape_is_known(array.shape())) {
            array.Init(out_shapes[i]);
          }
        }
      } else if (is_backward.get(inode.source->op(), false) && inode.control_deps.size()) {
        CHECK_GE(inode.control_deps.size(), 1U) <<
            "BackwardOp needs to have control_deps to its forward op";
        uint32_t fid = inode.control_deps[0];
        const OpNode& fopnode = op_nodes_[fid];
        CHECK_EQ(fopnode.exec->in_array.size(), opnode.exec->out_array.size());
        int nelem = fopnode.exec->in_array.size();
        std::vector<NDArray> &from = fopnode.exec->in_array;
        std::vector<NDArray> &to = opnode.exec->out_array;
        for (int i = 0; i < nelem; ++i) {
          if (!shape_is_known(to[i].shape())) {
            to[i].Init(from[i].shape());
          }
        }
      }
    }
    opnode.exec->op_ctx.is_train = is_train;
    opnode.exec->op_ctx.need_grad = need_grad_;
    if (opnode.exec->exec_type() == ExecType::kCrossDeviceCopy) {
      CHECK_EQ(inode.inputs.size(), 1U);
      CHECK_EQ(opnode.exec->in_array.size(), 1U);
      CHECK_EQ(opnode.exec->out_array.size(), 1U);
      CopyFromTo(opnode.exec->in_array[0], &(opnode.exec->out_array[0]));
    } else if (opnode.exec->exec_type() == ExecType::kSubgraphExec) {
      // If the node contains a subgraph, we can't execute it in the engine.
      opnode.exec->Run(opnode.exec->op_ctx.run_ctx, false);
    } else if (opnode.cached_opr != nullptr) {
      bool profiling = profiler::Profiler::Get()->GetState() == profiler::Profiler::kRunning;
      Engine::Get()->Push(opnode.cached_opr, opnode.ctx, 0, profiling);
      if (this->is_dynamic_) {
        for (NDArray &array : opnode.exec->out_array) {
          array.WaitToRead();
          if (!shape_is_known(array.shape())) {
            array.SetShapeFromChunk();
          }
        }
      }
    } else {
      LOG(FATAL) << "Unreachable: node " << inode.source->attrs.name
                 << " has no cached operator";
    }
    for (uint32_t i = 0; i < num_inputs; ++i) {
      int eid = idx.entry_id(inode.inputs[i]);
      if (!shape_is_known(rshape[eid])) {
        rshape[eid] = opnode.exec->in_array[i].shape();
      }
    }
    for (uint32_t i = 0; i < num_outputs; ++i) {
      int eid = idx.entry_id(nid, i);
      if (!shape_is_known(rshape[eid])) {
        rshape[eid] = opnode.exec->out_array[i].shape();
      }
    }
    // Monitor callbacks
    if (monitor_callback_) {
      ExecuteMonOutputCallback(nid);
    }
  }
  graph_.attrs["shape"] = std::make_shared<dmlc::any>(rshape);
}

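/*!
 * \brief Try to build a single engine operator covering the nodes in
 * [topo_start, topo_end). Returns early, without creating an engine operator,
 * if the range is empty, contains a non-kSync op, spans more than one
 * context, or contains no executable nodes at all.
 */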
GraphExecutor::CachedSegOpr GraphExecutor::CreateCachedSegOpr(size_t topo_start, size_t topo_end) {
  std::vector<Engine::VarHandle> use_vars;
  std::vector<Engine::VarHandle> mutate_vars;
  Context *pctx = nullptr;
  GraphExecutor::CachedSegOpr ret;
  ret.topo_start = topo_start;
  ret.topo_end = topo_end;
  auto& exec_list = ret.exec_list;
  // invalid segment
  if (topo_end <= topo_start) {
    return ret;
  }
  std::string opr_names = "[";

  const auto& idx = graph_.indexed_graph();
  for (size_t nid = topo_start; nid < topo_end; ++nid) {
    std::vector<Engine::VarHandle> all_vars;
    const auto& inode = idx[nid];
    OpNode& op_node = op_nodes_[nid];
    if (op_node.skip_exec_node) continue;
    if (inode.source->is_variable()) continue;
    if (op_node.exec->exec_type() != ExecType::kSync) {
      return ret;
    }
    if (pctx == nullptr) pctx = &(op_node.ctx);
    if (*pctx != op_node.ctx) {
      return ret;
    }
    auto& exec = op_nodes_[nid].exec;
    std::copy(op_node.mutate_vars.begin(), op_node.mutate_vars.end(),
              std::inserter(mutate_vars, mutate_vars.end()));
    std::copy(op_node.use_vars.begin(), op_node.use_vars.end(),
              std::inserter(use_vars, use_vars.end()));
    ret.exec_list.push_back(exec);
    opr_names += inode.source->op()->name + ",";
  }

  if (pctx == nullptr)
    return ret;
  ret.ctx = *pctx;
  Engine::Get()->DeduplicateVarHandle(&use_vars, &mutate_vars);

  bool is_gpu = pctx->dev_mask() == gpu::kDevMask;

#if CUDA_GRAPHS_AVAILABLE
  // Provide an initialized `cuda_graphs_exec`, which, when captured
  // by exec_fun, acts like a static variable inside the mutable closure.
  cuda_graphs::CudaGraphsExec cuda_graphs_exec(exec_list, is_gpu, opr_names.c_str());
  auto exec_fun = [cuda_graphs_exec, exec_list, is_gpu] (
      RunContext rctx, Engine::CallbackOnComplete on_complete) mutable {
    // Run all ops in the sub-graph with the CUDA graphs executor if possible
    cuda_graphs_exec.RunAll(exec_list, rctx, is_gpu);
#else
  auto exec_fun = [exec_list, is_gpu] (
      RunContext rctx, Engine::CallbackOnComplete on_complete) {
    // Run all ops in the sub-graph
    OpExecutor::RunAll(exec_list, rctx, is_gpu);
#endif
    if (is_gpu) {
#if MXNET_USE_CUDA
      // Wait for the GPU kernels to finish.
      rctx.get_stream<gpu>()->Wait();
#else
      LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR;
#endif
    }
    on_complete();
  };
  opr_names.pop_back();
  opr_names += "]";
  ret.opr = Engine::Get()->NewOperator(
      exec_fun, use_vars, mutate_vars, FnProperty::kNormal,
      opr_names.c_str());
  return ret;
}

// Infer shapes, dtypes, stypes, contexts for the forward graph
static nnvm::Graph InferForwardAttrs(nnvm::Graph g,
                                     mxnet::ShapeVector arg_shapes,
                                     nnvm::DTypeVector arg_dtypes,
                                     StorageTypeVector arg_stypes,
                                     const Context& default_ctx,
                                     const std::map<std::string, Context>& ctx_map,
                                     const std::vector<Context>& in_arg_ctxes,
                                     const std::vector<Context>& aux_state_ctxes,
                                     bool partial_shape = false) {
  const auto& indexed_graph = g.indexed_graph();
  const auto num_forward_inputs = indexed_graph.input_nodes().size();
  g = AssignContext(g, default_ctx, ctx_map, in_arg_ctxes, {},
                    aux_state_ctxes, {}, num_forward_inputs, g.outputs.size());
  g = InferShape(std::move(g), std::move(arg_shapes), "__shape__");
  if (g.GetAttr<size_t>("shape_num_unknown_nodes") != 0U && !partial_shape) {
    HandleInferShapeError(num_forward_inputs, indexed_graph,
                          g.GetAttr<mxnet::ShapeVector>("shape"));
  }
  g = InferType(std::move(g), std::move(arg_dtypes), "__dtype__");
  if (g.GetAttr<size_t>("dtype_num_unknown_nodes") != 0U) {
    HandleInferTypeError(num_forward_inputs, indexed_graph,
                         g.GetAttr<nnvm::DTypeVector>("dtype"));
  }
  g = InferStorageType(std::move(g), std::move(arg_stypes), "__storage_type__");
  if (g.GetAttr<size_t>("storage_type_num_unknown_nodes") != 0U) {
    HandleInferStorageTypeError(num_forward_inputs, indexed_graph,
                                g.GetAttr<StorageTypeVector>("storage_type"));
  }
  return g;
}

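// Returns true if the subgraph backend may be used with the given default
// context. A backend can opt out via its "enable" attribute or pin itself
// to a specific context via its "context" attribute.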
static bool SubgraphBackendCheck(const op::SubgraphBackendPtr& backend,
                                 const Context& default_ctx,
                                 int verbose = 1) {
  if (backend->HasAttr("enable") && !backend->GetAttr<bool>("enable")) {
    if (verbose > 1) {
      LOG(INFO) << "Subgraph backend " << backend->GetName()
                << " isn't activated.";
    }
    return false;
  }
  if (backend->HasAttr("context") && backend->GetAttr<Context>("context") != default_ctx) {
    if (verbose > 1) {
      LOG(INFO) << "Subgraph backend " << backend->GetName()
                << " isn't activated due to a context mismatch.";
    }
    return false;
  }
  return true;
}

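// Returns true if the given subgraph property should be applied: a property
// can be disabled outright ("disable") or restricted to inference-only use
// ("inference_only"), in which case it is skipped when gradients are needed.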
static bool SubgraphPropertyCheck(const std::string& backend_name,
                                  const op::SubgraphPropertyPtr& prop, bool need_grad,
                                  int verbose = 1) {
  auto full_name =
      prop->HasAttr("property_name") ? prop->GetAttr<std::string>("property_name") : std::string();
  if (prop->HasAttr("disable") && prop->GetAttr<bool>("disable")) {
    LOG(INFO) << "subgraph property " << full_name << " from backend " << backend_name
              << " is disabled.";
    return false;
  }
  if (prop->HasAttr("inference_only") && prop->GetAttr<bool>("inference_only") && need_grad) {
    if (verbose > 1) {
      LOG(INFO) << "skip partitioning graph with subgraph property " << full_name
                << " from backend " << backend_name << " as it requires `grad_req=null`.";
    }
    return false;
  }
  return true;
}

// Given input attr arrays, partition the graph using the given subgraph property.
// This is a common function for the bind and simple_bind flows.
static nnvm::Symbol BuildSubgraph(const nnvm::Symbol& src, op::SubgraphPropertyPtr subgraph_prop,
                                  const mxnet::ShapeVector& arg_shapes,
                                  const nnvm::DTypeVector& arg_dtypes,
                                  const StorageTypeVector& arg_stypes, const Context& default_ctx,
                                  const std::map<std::string, Context>& ctx_map,
                                  const std::vector<Context>& in_arg_ctxes,
                                  const std::vector<Context>& aux_state_ctxes) {
  nnvm::Symbol ret = src.Copy();
  nnvm::Graph g;
  g.outputs = ret.outputs;
  g = InferForwardAttrs(g, arg_shapes, arg_dtypes, arg_stypes, default_ctx, ctx_map, in_arg_ctxes,
                        aux_state_ctxes, true);
  subgraph_prop->SetAttr("graph", g);
  g.attrs["subgraph_property"] = std::make_shared<nnvm::any>(subgraph_prop);
  g = ApplyPass(std::move(g), "BuildSubgraph");
  subgraph_prop->RemoveAttr("graph");
  g.attrs.erase("subgraph_property");
  ret.outputs = g.outputs;
  return ret;
}

// Given input attr dicts, partition the graph using the backend.
// This is for the simple_bind flow.
static nnvm::Symbol BuildSubgraph(
    const nnvm::Symbol& src, const op::SubgraphBackendPtr backend,
    const std::unordered_map<std::string, mxnet::TShape>& arg_shape_map,
    const std::unordered_map<std::string, int>& arg_dtype_map,
    const std::unordered_map<std::string, int>& arg_stype_map, const Context& default_ctx,
    const std::map<std::string, Context>& ctx_map, std::vector<Context>* in_arg_ctxes,
    std::vector<Context>* arg_grad_ctxes, std::vector<OpReqType>* grad_req_types,
    std::vector<Context>* aux_state_ctxes, int verbose = 1) {
  // setup map for in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes and grad_req_types
  std::unordered_map<std::string, Context> in_arg_ctx_map;
  std::unordered_map<std::string, Context> arg_grad_ctx_map;
  std::unordered_map<std::string, Context> aux_state_ctx_map;
  std::unordered_map<std::string, OpReqType> grad_req_type_map;

  auto arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
  auto aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
  for (size_t i = 0; i < arg_names.size(); ++i) {
    const auto& name = arg_names[i];
    in_arg_ctx_map[name] = in_arg_ctxes->at(i);
    arg_grad_ctx_map[name] = arg_grad_ctxes->at(i);
    grad_req_type_map[name] = grad_req_types->at(i);
  }

  for (size_t i = 0; i < aux_names.size(); ++i) {
    aux_state_ctx_map[aux_names[i]] = aux_state_ctxes->at(i);
  }

  bool need_grad = false;
  for (OpReqType req : *grad_req_types) {
    if (req != kNullOp) {
      need_grad = true;
      break;
    }
  }
  nnvm::Symbol ret = src.Copy();
  std::unordered_set<std::string> op_names_set;
  const auto& backend_name = backend->GetName();
  const auto it = op::SubgraphPropertyOpNameSet::Get()->find(backend_name);
  // assign an op name set to the subgraph property if it has been provided by users
  if (it != op::SubgraphPropertyOpNameSet::Get()->end()) {
    LOG(INFO) << "SubgraphPropertyOpNameSet for subgraph property " << backend_name
              << " has been assigned a value. Please make sure it is initialized"
                 " only for testing purposes.";
    op_names_set = it->second;
  }

  const auto& subgraph_prop_list = backend->GetSubgraphProperties();
  for (auto& subgraph_prop : subgraph_prop_list) {
    if (SubgraphPropertyCheck(backend_name, subgraph_prop, need_grad, verbose)) {
      subgraph_prop->SetAttr("op_names", op_names_set);
      const std::vector<std::string> input_names = ret.ListInputNames(Symbol::kAll);
      mxnet::ShapeVector arg_shapes(input_names.size(), mxnet::TShape());
      nnvm::DTypeVector arg_dtypes(input_names.size(), -1);
      StorageTypeVector arg_stypes(input_names.size(), kUndefinedStorage);
      for (size_t i = 0; i < input_names.size(); ++i) {
        const auto& input_name = input_names[i];
        const auto it1 = arg_shape_map.find(input_name);
        if (arg_shape_map.end() != it1) {
          arg_shapes[i] = it1->second;
        }
        const auto it2 = arg_dtype_map.find(input_name);
        if (arg_dtype_map.end() != it2) {
          arg_dtypes[i] = it2->second;
        }
        const auto it3 = arg_stype_map.find(input_name);
        if (arg_stype_map.end() != it3) {
          arg_stypes[i] = it3->second;
        }
      }
      ret = BuildSubgraph(ret, subgraph_prop, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
                          ctx_map, *in_arg_ctxes, *aux_state_ctxes);
      // Reorder in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes and grad_req_types according to
      // the partitioned symbol's input sequence
      in_arg_ctxes->clear();
      arg_grad_ctxes->clear();
      aux_state_ctxes->clear();
      grad_req_types->clear();
      auto new_arg_names = ret.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
      auto new_aux_names = ret.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
      for (const auto& arg_name : new_arg_names) {
        CHECK(in_arg_ctx_map.count(arg_name));
        in_arg_ctxes->push_back(in_arg_ctx_map[arg_name]);
        arg_grad_ctxes->push_back(arg_grad_ctx_map[arg_name]);
        grad_req_types->push_back(grad_req_type_map[arg_name]);
      }
      for (const auto& arg_name : new_aux_names) {
        CHECK(aux_state_ctx_map.count(arg_name));
        aux_state_ctxes->push_back(aux_state_ctx_map[arg_name]);
      }
    }
  }
  return ret;
}

// Given input ndarrays, partition the graph using the backend.
// This is for the bind flow.
static nnvm::Symbol BuildSubgraph(const nnvm::Symbol& src, const op::SubgraphBackendPtr backend,
                                  const Context& default_ctx,
                                  const std::map<std::string, Context>& ctx_map,
                                  std::vector<NDArray>* in_args,
                                  std::vector<NDArray>* arg_grad_store,
                                  std::vector<OpReqType>* grad_req_type,
                                  std::vector<NDArray>* aux_states, int verbose = 1) {
  // setup map for in_args, arg_grad_store, grad_req_type and aux_states
  std::unordered_map<std::string, NDArray> in_args_map;
  std::unordered_map<std::string, NDArray> arg_grad_store_map;
  std::unordered_map<std::string, OpReqType> grad_req_type_map;
  std::unordered_map<std::string, NDArray> aux_states_map;
  const std::vector<std::string> arg_names = src.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
  const std::vector<std::string> aux_names = src.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
  for (size_t i = 0; i < arg_names.size(); ++i) {
    in_args_map[arg_names[i]] = in_args->at(i);
  }

  for (size_t i = 0; i < aux_names.size(); ++i) {
    aux_states_map[aux_names[i]] = aux_states->at(i);
  }

  if (arg_grad_store->size()) {
    for (size_t i = 0; i < arg_names.size(); ++i) {
      const auto& name = arg_names[i];
      arg_grad_store_map[name] = arg_grad_store->at(i);
      grad_req_type_map[name] = grad_req_type->at(i);
    }
  }

  bool need_grad = false;
  for (OpReqType req : *grad_req_type) {
    if (req != kNullOp) {
      need_grad = true;
      break;
    }
  }
  nnvm::Symbol ret = src.Copy();
  std::unordered_set<std::string> op_names_set;
  const auto& backend_name = backend->GetName();
  auto it = op::SubgraphPropertyOpNameSet::Get()->find(backend_name);
  // assign an op name set to the subgraph property if it has been provided by users
  if (it != op::SubgraphPropertyOpNameSet::Get()->end()) {
    LOG(INFO) << "SubgraphPropertyOpNameSet for subgraph property " << backend_name
              << " has been assigned a value. Please make sure it is initialized"
                 " only for testing purposes.";
    op_names_set = it->second;
  }
  const auto& subgraph_prop_list = backend->GetSubgraphProperties();

  for (auto subgraph_prop : subgraph_prop_list) {
    if (SubgraphPropertyCheck(backend_name, subgraph_prop, need_grad, verbose)) {
      subgraph_prop->SetAttr("op_names", op_names_set);
      const std::vector<std::string> input_names = ret.ListInputNames(Symbol::kAll);
      const std::vector<std::string> arg_names = ret.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
      const std::vector<std::string> aux_names = ret.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
      CHECK_EQ(arg_names.size(), in_args_map.size());
      CHECK_EQ(aux_names.size(), aux_states_map.size());
      mxnet::ShapeVector arg_shapes;  // all input shapes
      arg_shapes.reserve(input_names.size());
      nnvm::DTypeVector arg_dtypes;  // all input dtypes
      arg_dtypes.reserve(input_names.size());
      StorageTypeVector arg_stypes;  // all input stypes
      arg_stypes.reserve(input_names.size());
      std::vector<Context> in_arg_ctxes(in_args_map.size());
      std::vector<Context> aux_state_ctxes(aux_states_map.size());

      size_t i1 = 0, i2 = 0;
      for (const auto& input_name : input_names) {
        if (i2 < aux_names.size() && aux_names[i2] == input_name) {
          const auto &aux_st = aux_states_map[input_name];
          arg_shapes.push_back(aux_st.shape());
          arg_dtypes.push_back(aux_st.dtype());
          arg_stypes.push_back(aux_st.storage_type());
          aux_state_ctxes[i2] = aux_st.ctx();
          ++i2;
        } else {
          CHECK(i1 < arg_names.size());
          CHECK_EQ(arg_names[i1], input_name);
          const auto &in_arg = in_args_map[input_name];
          arg_shapes.push_back(in_arg.shape());
          arg_dtypes.push_back(in_arg.dtype());
          arg_stypes.push_back(in_arg.storage_type());
          in_arg_ctxes[i1] = in_arg.ctx();
          ++i1;
        }
      }

      ret = BuildSubgraph(ret, subgraph_prop, arg_shapes, arg_dtypes, arg_stypes, default_ctx,
                          ctx_map, in_arg_ctxes, aux_state_ctxes);
    }
  }
  // Reorder in_args, arg_grad_store, grad_req_type and aux_states according to the partitioned
  // symbol's input sequence
  const auto new_arg_names = ret.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
  const auto new_aux_names = ret.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
  CHECK_EQ(arg_names.size(), new_arg_names.size());
  CHECK_EQ(aux_names.size(), new_aux_names.size());
  in_args->clear();
  aux_states->clear();
  for (const auto& arg_name : new_arg_names) {
    CHECK(in_args_map.count(arg_name));
    in_args->push_back(in_args_map[arg_name]);
  }

  for (const auto& arg_name : new_aux_names) {
    CHECK(aux_states_map.count(arg_name));
    aux_states->push_back(aux_states_map[arg_name]);
  }

  if (arg_grad_store->size()) {
    arg_grad_store->clear();
    grad_req_type->clear();
    for (const auto& arg_name : new_arg_names) {
      arg_grad_store->push_back(arg_grad_store_map[arg_name]);
      grad_req_type->push_back(grad_req_type_map[arg_name]);
    }
  }
  return ret;
}
}  // namespace exec

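/*!
 * \brief Create a graph executor from shape/dtype/stype hints rather than
 * pre-allocated arrays (the simple_bind flow). If a subgraph backend is
 * active, the symbol is partitioned first and the executor is rebuilt from
 * the partitioned symbol; the output arrays are then reordered back to the
 * caller's original argument order.
 */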
Executor *Executor::SimpleBind(nnvm::Symbol symbol,
                               const Context& default_ctx,
                               const std::map<std::string, Context>& group2ctx,
                               const std::vector<Context>& in_arg_ctxes,
                               const std::vector<Context>& arg_grad_ctxes,
                               const std::vector<Context>& aux_state_ctxes,
                               const std::unordered_map<std::string, mxnet::TShape>& arg_shape_map,
                               const std::unordered_map<std::string, int>& arg_dtype_map,
                               const std::unordered_map<std::string, int>& arg_stype_map,
                               const std::vector<OpReqType>& grad_req_types,
                               const std::unordered_set<std::string>& shared_arg_names,
                               std::vector<NDArray>* in_args,
                               std::vector<NDArray>* arg_grads,
                               std::vector<NDArray>* aux_states,
                               std::unordered_map<std::string, NDArray>* shared_buffer,
                               Executor* shared_exec) {
  auto exec = new exec::GraphExecutor(symbol);
  bool init = false;
  if (!exec->subgraph_property().empty()) {
    static int verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", 1);
    const auto& backend_name = exec->subgraph_property();
    const auto& backend = op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name);
    if (exec::SubgraphBackendCheck(backend, default_ctx, verbose)) {
      if (verbose) LOG(INFO) << "Subgraph backend " << backend_name << " is activated.";
      std::vector<Context> tmp_in_arg_ctxes = in_arg_ctxes;
      std::vector<Context> tmp_arg_grad_ctxes = arg_grad_ctxes;
      std::vector<Context> tmp_aux_state_ctxes = aux_state_ctxes;
      std::vector<OpReqType> tmp_grad_req_types = grad_req_types;
      std::vector<NDArray> tmp_in_args;
      std::vector<NDArray> tmp_arg_grads;
      std::vector<NDArray> tmp_aux_states;
      const auto arg_names = symbol.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
      const auto aux_names = symbol.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
      symbol = exec::BuildSubgraph(symbol, backend, arg_shape_map, arg_dtype_map, arg_stype_map,
                                   default_ctx, group2ctx, &tmp_in_arg_ctxes, &tmp_arg_grad_ctxes,
                                   &tmp_grad_req_types, &tmp_aux_state_ctxes, verbose);
      // The subgraph executor cannot be recreated from the unoptimized symbol,
      // so rebuild it from the partitioned one.
      delete exec;
      exec = new exec::GraphExecutor(symbol);
      exec->Init(symbol.Copy(), default_ctx, group2ctx, tmp_in_arg_ctxes, tmp_arg_grad_ctxes,
                 tmp_aux_state_ctxes, arg_shape_map, arg_dtype_map, arg_stype_map,
                 tmp_grad_req_types, shared_arg_names, &tmp_in_args, &tmp_arg_grads,
                 &tmp_aux_states, shared_buffer, shared_exec);
      init = true;
      const auto new_arg_names = symbol.ListInputNames(nnvm::Symbol::kReadOnlyArgs);
      const auto new_aux_names = symbol.ListInputNames(nnvm::Symbol::kAuxiliaryStates);
      std::unordered_map<std::string, size_t> new_arg_names_idx_map;
      std::unordered_map<std::string, size_t> new_aux_names_idx_map;
      for (size_t i = 0; i != new_arg_names.size(); ++i) {
        new_arg_names_idx_map[new_arg_names[i]] = i;
      }
      for (size_t i = 0; i != new_aux_names.size(); ++i) {
        new_aux_names_idx_map[new_aux_names[i]] = i;
      }

      in_args->reserve(arg_names.size());
      arg_grads->reserve(arg_names.size());
      for (size_t i = 0; i != arg_names.size(); ++i) {
        const auto& arg_name = arg_names[i];
        const auto& it = new_arg_names_idx_map.find(arg_name);
        CHECK(it != new_arg_names_idx_map.end())
            << "Subgraph backends don't support removing input nodes for now.";
        in_args->emplace_back(std::move(tmp_in_args[it->second]));
        arg_grads->emplace_back(std::move(tmp_arg_grads[it->second]));
      }

      aux_states->reserve(aux_names.size());
      for (size_t i = 0; i != aux_names.size(); ++i) {
        const auto& aux_name = aux_names[i];
        const auto& it = new_aux_names_idx_map.find(aux_name);
        CHECK(it != new_aux_names_idx_map.end())
            << "Subgraph backends don't support removing input nodes for now.";
        aux_states->emplace_back(std::move(tmp_aux_states[it->second]));
      }
    }
  }
  if (!init) {
    // init without subgraph
    exec->Init(symbol.Copy(), default_ctx, group2ctx, in_arg_ctxes, arg_grad_ctxes, aux_state_ctxes,
               arg_shape_map, arg_dtype_map, arg_stype_map, grad_req_types, shared_arg_names,
               in_args, arg_grads, aux_states, shared_buffer, shared_exec);
  }
  return exec;
}

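/*!
 * \brief Create a graph executor from pre-allocated arrays (the bind flow),
 * partitioning the symbol with the active subgraph backend first when one
 * is configured.
 *
 * A minimal caller-side sketch (illustrative only; `net`, `x`, and `dx` are
 * hypothetical values, not defined in this file):
 *
 *   // Executor *e = Executor::Bind(net, Context::CPU(), {}, {x},
 *   //                              {dx}, {kWriteTo}, {}, nullptr);
 *   // e->Forward(true);
 */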
Executor *Executor::Bind(nnvm::Symbol symbol,
                         const Context& default_ctx,
                         const std::map<std::string, Context>& group2ctx,
                         const std::vector<NDArray> &in_args,
                         const std::vector<NDArray> &arg_grad_store,
                         const std::vector<OpReqType> &grad_req_type,
                         const std::vector<NDArray> &aux_states,
                         Executor* shared_exec) {
  auto exec = new exec::GraphExecutor(symbol);
  static int verbose = dmlc::GetEnv("MXNET_SUBGRAPH_VERBOSE", 1);
  std::vector<NDArray> tmp_in_args = in_args;
  std::vector<NDArray> tmp_arg_grad_store = arg_grad_store;
  std::vector<OpReqType> tmp_grad_req_type = grad_req_type;
  std::vector<NDArray> tmp_aux_states = aux_states;

  if (!exec->subgraph_property().empty()) {
    const auto& backend_name = exec->subgraph_property();
    const auto& backend = op::SubgraphBackendRegistry::Get()->GetSubgraphBackend(backend_name);
    if (exec::SubgraphBackendCheck(backend, default_ctx, verbose)) {
      if (verbose) LOG(INFO) << "Subgraph backend " << backend_name << " is activated.";
      symbol = exec::BuildSubgraph(symbol, backend, default_ctx, group2ctx, &tmp_in_args,
                                   &tmp_arg_grad_store, &tmp_grad_req_type, &tmp_aux_states,
                                   verbose);
      // The subgraph executor cannot be recreated from the unoptimized symbol,
      // so rebuild it from the partitioned one.
      delete exec;
      exec = new exec::GraphExecutor(symbol);
    }
  }
  exec->Init(symbol.Copy(), default_ctx, group2ctx, tmp_in_args, tmp_arg_grad_store,
             tmp_grad_req_type, tmp_aux_states, shared_exec);
  return exec;
}
}  // namespace mxnet
