1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 #include "./imperative_utils.h"
21 #include "./cached_op.h"
22 #include "../operator/operator_common.h"
23
24 namespace {
25
NodeInputs(const nnvm::IndexedGraph & idx,const int node_idx,const std::vector<NDArray * > & arrays)26 std::vector<NDArray*> NodeInputs(const nnvm::IndexedGraph& idx,
27 const int node_idx,
28 const std::vector<NDArray*>& arrays) {
29 const nnvm::IndexedGraph::Node& node = idx[node_idx];
30 const size_t num_inputs = node.inputs.size();
31 std::vector<NDArray*> ndinputs;
32 ndinputs.reserve(num_inputs);
33 for (const auto& j : node.inputs) {
34 const size_t eid = idx.entry_id(j);
35 ndinputs.emplace_back(arrays[eid]);
36 }
37 return ndinputs;
38 }
39
NodeOutputs(const nnvm::IndexedGraph & idx,const int node_idx,const std::vector<NDArray * > & arrays)40 std::vector<NDArray*> NodeOutputs(const nnvm::IndexedGraph& idx,
41 const int node_idx,
42 const std::vector<NDArray*>& arrays) {
43 const nnvm::IndexedGraph::Node& node = idx[node_idx];
44 const size_t num_outputs = node.source->num_outputs();
45 std::vector<NDArray*> ndoutputs;
46 ndoutputs.reserve(num_outputs);
47 for (size_t j = 0; j < num_outputs; ++j) {
48 const size_t eid = idx.entry_id(node_idx, j);
49 ndoutputs.emplace_back(arrays[eid]);
50 }
51 return ndoutputs;
52 }
53
NodeReq(const nnvm::IndexedGraph & idx,const int node_idx,const std::vector<OpReqType> & array_reqs)54 std::vector<OpReqType> NodeReq(const nnvm::IndexedGraph& idx,
55 const int node_idx,
56 const std::vector<OpReqType>& array_reqs) {
57 const nnvm::IndexedGraph::Node& node = idx[node_idx];
58 const size_t num_outputs = node.source->num_outputs();
59 std::vector<OpReqType> req;
60 req.reserve(num_outputs);
61 for (size_t j = 0; j < num_outputs; ++j) {
62 const size_t eid = idx.entry_id(node_idx, j);
63 req.push_back(array_reqs[eid]);
64 }
65 return req;
66 }
67
// Executes a single graph node, choosing the appropriate execution path
// based on the operator kind, then releases entry arrays whose reference
// count has dropped to zero.
//
// Execution paths:
//   - "_backward_CachedOp": delegates to CachedOp::Backward, passing the
//     state recorded for the corresponding forward node (control_deps[0]).
//   - ops with FCreateOpState: creates a fresh op state from the input
//     shapes/dtypes, stores it in (*p_states)[node_idx], and invokes with it.
//   - ops marked TIsLayerOpBackward: invokes with the forward node's state.
//   - everything else: invokes with an empty OpStatePtr (stateless op).
//
// `invoke` is the caller-supplied callback that actually launches the op;
// this function only decides which OpStatePtr to hand it.
//
// After invocation, each input entry's ref count is decremented and its
// NDArray is replaced with an empty one once the count reaches zero, freeing
// memory early; output entries whose count is already zero are freed too.
void InvokeOperator(const nnvm::IndexedGraph& idx,
                    const int node_idx,
                    const bool retain_graph,
                    const std::vector<NDArray*>& arrays,
                    Context ctx,
                    std::vector<OpStatePtr>* p_states,
                    const std::vector<NDArray*>& ndinputs,
                    const std::vector<NDArray*>& ndoutputs,
                    std::vector<OpReqType> *p_req,
                    std::vector<uint32_t> *p_ref_count,
                    std::function<void(const OpStatePtr &state)> invoke) {
  static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
  static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
  static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
  std::vector<OpStatePtr>& states = *p_states;
  std::vector<OpReqType> &req = *p_req;
  std::vector<uint32_t> &ref_count = *p_ref_count;

  const nnvm::IndexedGraph::Node& node = idx[node_idx];
  if (node.source->op() == bwd_cached_op) {
    // Backward of a fused CachedOp: reuse the forward node's recorded state.
    const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
    nnvm::Node* fwd_node = node.source->control_deps[0].get();
    auto fwd_node_id = idx.node_id(fwd_node);
    cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
  } else if (createop.count(node.source->op())) {
    // Stateful operator: build shape/dtype vectors from the inputs and
    // create a new state for this node before invoking.
    mxnet::ShapeVector arg_shapes;
    nnvm::DTypeVector arg_dtypes;
    arg_shapes.reserve(ndinputs.size());
    arg_dtypes.reserve(ndinputs.size());
    for (auto& ndinput : ndinputs) {
      arg_shapes.emplace_back(ndinput->shape());
      arg_dtypes.emplace_back(ndinput->dtype());
    }
    states[node_idx] = createop[node.source->op()](node.source->attrs, ctx, arg_shapes, arg_dtypes);
    invoke(states[node_idx]);
  } else if (is_layer_backward.get(node.source->op(), false)) {
    // Layer backward op: share the state created by its forward node.
    nnvm::Node* fwd_node = node.source->control_deps[0].get();
    auto fwd_node_id = idx.node_id(fwd_node);
    invoke(states[fwd_node_id]);
  } else {
    // Stateless operator.
    invoke(OpStatePtr());
  }
  // Release input arrays that are no longer referenced by later nodes.
  for (const auto& j : node.inputs) {
    size_t eid = idx.entry_id(j);
    --ref_count[eid];
    if (ref_count[eid] == 0) {
      *arrays[eid] = NDArray();
    }
  }
  // Release outputs nobody consumes (ref count was already zero).
  for (size_t j = 0; j < ndoutputs.size(); ++j) {
    size_t eid = idx.entry_id(node_idx, j);
    if (ref_count[eid] == 0) {
      *arrays[eid] = NDArray();
    }
  }
}
124
125 } // namespace
126
127 namespace mxnet {
128 namespace imperative {
129
RunGraph(const bool retain_graph,const nnvm::IndexedGraph & idx,const std::vector<NDArray * > & arrays,size_t node_start,size_t node_end,std::vector<OpReqType> && array_reqs,std::vector<uint32_t> && ref_count,std::vector<OpStatePtr> * p_states,const DispatchModeVector & dispatch_modes,bool recording,mxnet::ShapeVector * shapes,const imperative::CachedOpMonCallback & callback,const bool monitor_all)130 void RunGraph(
131 const bool retain_graph,
132 const nnvm::IndexedGraph& idx,
133 const std::vector<NDArray*>& arrays,
134 size_t node_start, size_t node_end,
135 std::vector<OpReqType>&& array_reqs,
136 std::vector<uint32_t>&& ref_count,
137 std::vector<OpStatePtr> *p_states,
138 const DispatchModeVector &dispatch_modes,
139 bool recording,
140 mxnet::ShapeVector *shapes,
141 const imperative::CachedOpMonCallback& callback,
142 const bool monitor_all) {
143 CHECK(shapes == nullptr);
144 for (size_t i = node_start; i < node_end; ++i) {
145 const nnvm::IndexedGraph::Node& node = idx[i];
146 if (node.source->op() == nullptr) {
147 continue;
148 }
149 std::vector<NDArray*> ndinputs = NodeInputs(idx, i, arrays);
150 std::vector<NDArray*> ndoutputs = NodeOutputs(idx, i, arrays);
151 std::vector<OpReqType> req = NodeReq(idx, i, array_reqs);
152 Context ctx = ndoutputs[0]->ctx();
153 if (callback && monitor_all) {
154 mxnet::common::ExecuteMonInputCallback(idx, arrays, i, callback);
155 }
156 auto invoke = [&](const OpStatePtr &state) {
157 const nnvm::IndexedGraph::Node& node = idx[i];
158 DispatchMode dispatch_mode = dispatch_modes[i];
159 Imperative::Get()->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
160 req, dispatch_mode, state);
161 if (recording) {
162 Imperative::Get()->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, state);
163 }
164 };
165 InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs,
166 &req, &ref_count, invoke);
167 if (callback) {
168 mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback);
169 }
170 }
171 }
172
NaiveRunGraph(const bool retain_graph,const Context & default_ctx,const nnvm::IndexedGraph & idx,const std::vector<NDArray * > & arrays,size_t node_start,size_t node_end,std::vector<OpReqType> && array_reqs,std::vector<uint32_t> && ref_count,std::vector<OpStatePtr> * p_states,const DispatchModeVector & dispatch_modes,bool recording,mxnet::ShapeVector * shapes,const imperative::CachedOpMonCallback & callback,const bool monitor_all)173 void NaiveRunGraph(
174 const bool retain_graph,
175 const Context& default_ctx,
176 const nnvm::IndexedGraph& idx,
177 const std::vector<NDArray*>& arrays,
178 size_t node_start, size_t node_end,
179 std::vector<OpReqType>&& array_reqs,
180 std::vector<uint32_t>&& ref_count,
181 std::vector<OpStatePtr> *p_states,
182 const DispatchModeVector &dispatch_modes,
183 bool recording,
184 mxnet::ShapeVector *shapes,
185 const imperative::CachedOpMonCallback& callback,
186 const bool monitor_all) {
187 for (size_t i = node_start; i < node_end; ++i) {
188 const nnvm::IndexedGraph::Node& node = idx[i];
189 if (node.source->op() == nullptr) {
190 continue;
191 }
192 std::vector<NDArray*> ndinputs = NodeInputs(idx, i, arrays);
193 std::vector<NDArray*> ndoutputs = NodeOutputs(idx, i, arrays);
194 std::vector<OpReqType> req;
195 Context ctx = GetContext(node.source->attrs, ndinputs, ndoutputs, default_ctx);
196 if (callback && monitor_all) {
197 mxnet::common::ExecuteMonInputCallback(idx, arrays, i, callback);
198 }
199 auto invoke = [&](const OpStatePtr &state) {
200 const nnvm::IndexedGraph::Node& node = idx[i];
201 DispatchMode dispatch_mode = DispatchMode::kUndefined;
202 SetShapeType(ctx, node.source->attrs, ndinputs, ndoutputs, &dispatch_mode);
203 SetWriteInplaceReq(ndinputs, ndoutputs, &req);
204 Imperative::Get()->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
205 req, dispatch_mode, state);
206 for (size_t j = 0; j < ndoutputs.size(); ++j) {
207 if (mxnet::op::shape_is_none(ndoutputs[j]->shape())) {
208 ndoutputs[j]->WaitToRead();
209 ndoutputs[j]->SetShapeFromChunk();
210 }
211 size_t eid = idx.entry_id(i, j);
212 auto shape = ndoutputs[j]->shape();
213 (*shapes)[eid] = shape;
214 }
215 if (recording) {
216 Imperative::Get()->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, state);
217 }
218 };
219 InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs,
220 &req, &ref_count, invoke);
221 if (callback) {
222 mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback);
223 }
224 }
225 }
226
227 } // namespace imperative
228 } // namespace mxnet
229