1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #include "./imperative_utils.h"
21 #include "./cached_op.h"
22 #include "../operator/operator_common.h"
23 
24 namespace {
25 
NodeInputs(const nnvm::IndexedGraph & idx,const int node_idx,const std::vector<NDArray * > & arrays)26 std::vector<NDArray*> NodeInputs(const nnvm::IndexedGraph& idx,
27                                  const int node_idx,
28                                  const std::vector<NDArray*>& arrays) {
29   const nnvm::IndexedGraph::Node& node = idx[node_idx];
30   const size_t num_inputs = node.inputs.size();
31   std::vector<NDArray*> ndinputs;
32   ndinputs.reserve(num_inputs);
33   for (const auto& j : node.inputs) {
34     const size_t eid = idx.entry_id(j);
35     ndinputs.emplace_back(arrays[eid]);
36   }
37   return ndinputs;
38 }
39 
NodeOutputs(const nnvm::IndexedGraph & idx,const int node_idx,const std::vector<NDArray * > & arrays)40 std::vector<NDArray*> NodeOutputs(const nnvm::IndexedGraph& idx,
41                                   const int node_idx,
42                                   const std::vector<NDArray*>& arrays) {
43   const nnvm::IndexedGraph::Node& node = idx[node_idx];
44   const size_t num_outputs = node.source->num_outputs();
45   std::vector<NDArray*> ndoutputs;
46   ndoutputs.reserve(num_outputs);
47   for (size_t j = 0; j < num_outputs; ++j) {
48     const size_t eid = idx.entry_id(node_idx, j);
49     ndoutputs.emplace_back(arrays[eid]);
50   }
51   return ndoutputs;
52 }
53 
NodeReq(const nnvm::IndexedGraph & idx,const int node_idx,const std::vector<OpReqType> & array_reqs)54 std::vector<OpReqType> NodeReq(const nnvm::IndexedGraph& idx,
55                                const int node_idx,
56                                const std::vector<OpReqType>& array_reqs) {
57   const nnvm::IndexedGraph::Node& node = idx[node_idx];
58   const size_t num_outputs = node.source->num_outputs();
59   std::vector<OpReqType> req;
60   req.reserve(num_outputs);
61   for (size_t j = 0; j < num_outputs; ++j) {
62     const size_t eid = idx.entry_id(node_idx, j);
63     req.push_back(array_reqs[eid]);
64   }
65   return req;
66 }
67 
InvokeOperator(const nnvm::IndexedGraph & idx,const int node_idx,const bool retain_graph,const std::vector<NDArray * > & arrays,Context ctx,std::vector<OpStatePtr> * p_states,const std::vector<NDArray * > & ndinputs,const std::vector<NDArray * > & ndoutputs,std::vector<OpReqType> * p_req,std::vector<uint32_t> * p_ref_count,std::function<void (const OpStatePtr & state)> invoke)68 void InvokeOperator(const nnvm::IndexedGraph& idx,
69                     const int node_idx,
70                     const bool retain_graph,
71                     const std::vector<NDArray*>& arrays,
72                     Context ctx,
73                     std::vector<OpStatePtr>* p_states,
74                     const std::vector<NDArray*>& ndinputs,
75                     const std::vector<NDArray*>& ndoutputs,
76                     std::vector<OpReqType> *p_req,
77                     std::vector<uint32_t> *p_ref_count,
78                     std::function<void(const OpStatePtr &state)> invoke) {
79   static const auto bwd_cached_op = Op::Get("_backward_CachedOp");
80   static auto& createop = nnvm::Op::GetAttr<FCreateOpState>("FCreateOpState");
81   static auto& is_layer_backward = Op::GetAttr<bool>("TIsLayerOpBackward");
82   std::vector<OpStatePtr>& states = *p_states;
83   std::vector<OpReqType> &req = *p_req;
84   std::vector<uint32_t> &ref_count = *p_ref_count;
85 
86   const nnvm::IndexedGraph::Node& node = idx[node_idx];
87   if (node.source->op() == bwd_cached_op) {
88     const auto& cached_op = dmlc::get<CachedOpPtr>(node.source->attrs.parsed);
89     nnvm::Node* fwd_node = node.source->control_deps[0].get();
90     auto fwd_node_id = idx.node_id(fwd_node);
91     cached_op->Backward(retain_graph, states[fwd_node_id], ndinputs, req, ndoutputs);
92   } else if (createop.count(node.source->op())) {
93     mxnet::ShapeVector arg_shapes;
94     nnvm::DTypeVector arg_dtypes;
95     arg_shapes.reserve(ndinputs.size());
96     arg_dtypes.reserve(ndinputs.size());
97     for (auto& ndinput : ndinputs) {
98       arg_shapes.emplace_back(ndinput->shape());
99       arg_dtypes.emplace_back(ndinput->dtype());
100     }
101     states[node_idx] = createop[node.source->op()](node.source->attrs, ctx, arg_shapes, arg_dtypes);
102     invoke(states[node_idx]);
103   } else if (is_layer_backward.get(node.source->op(), false)) {
104     nnvm::Node* fwd_node = node.source->control_deps[0].get();
105     auto fwd_node_id = idx.node_id(fwd_node);
106     invoke(states[fwd_node_id]);
107   } else {
108     invoke(OpStatePtr());
109   }
110   for (const auto& j : node.inputs) {
111     size_t eid = idx.entry_id(j);
112     --ref_count[eid];
113     if (ref_count[eid] == 0) {
114       *arrays[eid] = NDArray();
115     }
116   }
117   for (size_t j = 0; j < ndoutputs.size(); ++j) {
118     size_t eid = idx.entry_id(node_idx, j);
119     if (ref_count[eid] == 0) {
120       *arrays[eid] = NDArray();
121     }
122   }
123 }
124 
125 }  // namespace
126 
127 namespace mxnet {
128 namespace imperative {
129 
/*!
 * \brief Execute nodes [node_start, node_end) of an indexed graph whose
 *        shapes/dtypes/dispatch modes were already inferred.
 *
 * \param retain_graph   passed through to CachedOp backward nodes
 * \param idx            the indexed graph to execute
 * \param arrays         entry-id -> NDArray* map for all graph entries
 * \param node_start     first node index to run (inclusive)
 * \param node_end       last node index to run (exclusive)
 * \param array_reqs     per-entry write-request types (consumed)
 * \param ref_count      per-entry remaining-use counts (consumed; used to
 *                       free intermediates early)
 * \param p_states       per-node op states, filled/read as nodes execute
 * \param dispatch_modes pre-computed dispatch mode per node
 * \param recording      whether to record ops for autograd
 * \param shapes         must be nullptr here; shape tracking is only done
 *                       by NaiveRunGraph (see CHECK below)
 * \param callback       optional monitor callback for node inputs/outputs
 * \param monitor_all    when true, also invoke the callback on node inputs
 */
void RunGraph(
    const bool retain_graph,
    const nnvm::IndexedGraph& idx,
    const std::vector<NDArray*>& arrays,
    size_t node_start, size_t node_end,
    std::vector<OpReqType>&& array_reqs,
    std::vector<uint32_t>&& ref_count,
    std::vector<OpStatePtr> *p_states,
    const DispatchModeVector &dispatch_modes,
    bool recording,
    mxnet::ShapeVector *shapes,
    const imperative::CachedOpMonCallback& callback,
    const bool monitor_all) {
  // This variant does not record output shapes (NaiveRunGraph does).
  CHECK(shapes == nullptr);
  for (size_t i = node_start; i < node_end; ++i) {
    const nnvm::IndexedGraph::Node& node = idx[i];
    // Variable nodes have no op; nothing to execute.
    if (node.source->op() == nullptr) {
      continue;
    }
    std::vector<NDArray*> ndinputs = NodeInputs(idx, i, arrays);
    std::vector<NDArray*> ndoutputs = NodeOutputs(idx, i, arrays);
    std::vector<OpReqType> req = NodeReq(idx, i, array_reqs);
    // Run on the context of the node's first output.
    Context ctx = ndoutputs[0]->ctx();
    if (callback && monitor_all) {
        mxnet::common::ExecuteMonInputCallback(idx, arrays, i, callback);
    }
    // InvokeOperator calls this with the op state appropriate for the node
    // (fresh, forward-node, or empty).
    auto invoke = [&](const OpStatePtr &state) {
      const nnvm::IndexedGraph::Node& node = idx[i];
      DispatchMode dispatch_mode = dispatch_modes[i];
      Imperative::Get()->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
                                  req, dispatch_mode, state);
      if (recording) {
        Imperative::Get()->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, state);
      }
    };
    InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs,
                   &req, &ref_count, invoke);
    if (callback) {
        mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback);
    }
  }
}
172 
/*!
 * \brief Execute nodes [node_start, node_end) of an indexed graph without
 *        relying on pre-computed shapes/dispatch modes: shape, dtype and
 *        dispatch mode are determined per node as it runs, and the
 *        resulting output shapes are written into `shapes`.
 *
 * Differs from RunGraph above in that:
 *   - the context comes from GetContext() per node instead of the outputs;
 *   - `req` is computed inside the invoke lambda via SetWriteInplaceReq;
 *   - outputs with unknown (dynamic) shape are waited on and their shape
 *     recovered from the underlying chunk;
 *   - `shapes` must be non-null and is updated for every output entry.
 *
 * \param retain_graph   passed through to CachedOp backward nodes
 * \param default_ctx    fallback context used by GetContext
 * \param idx            the indexed graph to execute
 * \param arrays         entry-id -> NDArray* map for all graph entries
 * \param node_start     first node index to run (inclusive)
 * \param node_end       last node index to run (exclusive)
 * \param array_reqs     per-entry write-request types (consumed)
 * \param ref_count      per-entry remaining-use counts (consumed)
 * \param p_states       per-node op states, filled/read as nodes execute
 * \param dispatch_modes unused here; dispatch mode is derived per node
 * \param recording      whether to record ops for autograd
 * \param shapes         out: inferred shape per output entry id
 * \param callback       optional monitor callback for node inputs/outputs
 * \param monitor_all    when true, also invoke the callback on node inputs
 */
void NaiveRunGraph(
    const bool retain_graph,
    const Context& default_ctx,
    const nnvm::IndexedGraph& idx,
    const std::vector<NDArray*>& arrays,
    size_t node_start, size_t node_end,
    std::vector<OpReqType>&& array_reqs,
    std::vector<uint32_t>&& ref_count,
    std::vector<OpStatePtr> *p_states,
    const DispatchModeVector &dispatch_modes,
    bool recording,
    mxnet::ShapeVector *shapes,
    const imperative::CachedOpMonCallback& callback,
    const bool monitor_all) {
  for (size_t i = node_start; i < node_end; ++i) {
    const nnvm::IndexedGraph::Node& node = idx[i];
    // Variable nodes have no op; nothing to execute.
    if (node.source->op() == nullptr) {
      continue;
    }
    std::vector<NDArray*> ndinputs = NodeInputs(idx, i, arrays);
    std::vector<NDArray*> ndoutputs = NodeOutputs(idx, i, arrays);
    // Filled lazily by SetWriteInplaceReq inside the invoke lambda.
    std::vector<OpReqType> req;
    Context ctx = GetContext(node.source->attrs, ndinputs, ndoutputs, default_ctx);
    if (callback && monitor_all) {
        mxnet::common::ExecuteMonInputCallback(idx, arrays, i, callback);
    }
    // InvokeOperator calls this with the op state appropriate for the node.
    auto invoke = [&](const OpStatePtr &state) {
      const nnvm::IndexedGraph::Node& node = idx[i];
      // Infer shape/dtype/storage on the fly for this node.
      DispatchMode dispatch_mode = DispatchMode::kUndefined;
      SetShapeType(ctx, node.source->attrs, ndinputs, ndoutputs, &dispatch_mode);
      SetWriteInplaceReq(ndinputs, ndoutputs, &req);
      Imperative::Get()->InvokeOp(ctx, node.source->attrs, ndinputs, ndoutputs,
                                  req, dispatch_mode, state);
      for (size_t j = 0; j < ndoutputs.size(); ++j) {
        // Dynamic-shape op: the shape is only known after execution, so
        // block until the output is ready and read it from the chunk.
        if (mxnet::op::shape_is_none(ndoutputs[j]->shape())) {
          ndoutputs[j]->WaitToRead();
          ndoutputs[j]->SetShapeFromChunk();
        }
        size_t eid = idx.entry_id(i, j);
        auto shape = ndoutputs[j]->shape();
        (*shapes)[eid] = shape;
      }
      if (recording) {
        Imperative::Get()->RecordOp(NodeAttrs(node.source->attrs), ndinputs, ndoutputs, state);
      }
    };
    InvokeOperator(idx, i, retain_graph, arrays, ctx, p_states, ndinputs, ndoutputs,
                   &req, &ref_count, invoke);
    if (callback) {
        mxnet::common::ExecuteMonOutputCallback(idx, arrays, i, callback);
    }
  }
}
226 
227 }  // namespace imperative
228 }  // namespace mxnet
229