1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 #include <mxnet/io.h>
21 #include <mxnet/base.h>
22 #include <mxnet/ndarray.h>
23 #include <mxnet/operator.h>
24 #include <mxnet/operator_util.h>
25 #include <dmlc/logging.h>
26 #include <dmlc/optional.h>
27 #include "./operator_common.h"
28 #include "./elemwise_op_common.h"
29 #include "../imperative/imperative_utils.h"
30 #include "./subgraph_op_common.h"
31
32 namespace mxnet {
33 namespace op {
34
35 struct ForeachParam : public dmlc::Parameter<ForeachParam> {
36 int num_args;
37 int num_outputs;
38 int num_out_data;
39 // The location of states in the subgraph inputs.
40 mxnet::Tuple<dim_t> in_state_locs;
41 // The location of data arrays in the subgraph inputs.
42 mxnet::Tuple<dim_t> in_data_locs;
43 // The location of remaining arrays in the subgraph inputs.
44 mxnet::Tuple<dim_t> remain_locs;
DMLC_DECLARE_PARAMETERmxnet::op::ForeachParam45 DMLC_DECLARE_PARAMETER(ForeachParam) {
46 DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
47 .describe("Number of inputs.");
48 DMLC_DECLARE_FIELD(num_outputs)
49 .describe("The number of outputs of the subgraph.");
50 DMLC_DECLARE_FIELD(num_out_data)
51 .describe("The number of output data of the subgraph.");
52 DMLC_DECLARE_FIELD(in_state_locs)
53 .describe("The locations of loop states among the inputs.");
54 DMLC_DECLARE_FIELD(in_data_locs)
55 .describe("The locations of input data among the inputs.");
56 DMLC_DECLARE_FIELD(remain_locs)
57 .describe("The locations of remaining data among the inputs.");
58 }
59 }; // struct ForeachParam
60
61 DMLC_REGISTER_PARAMETER(ForeachParam);
62
63 class ForeachState: public LoopState {
64 public:
65 ForeachParam params;
66 int num_iterations;
67
ForeachState(const Symbol & g,const ForeachParam & params)68 ForeachState(const Symbol &g, const ForeachParam ¶ms) : LoopState(g) {
69 this->params = params;
70 }
71 };
72
ForeachComputeExCPU(const OpStatePtr & state_ptr,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)73 static void ForeachComputeExCPU(const OpStatePtr& state_ptr,
74 const OpContext& ctx,
75 const std::vector<NDArray>& inputs,
76 const std::vector<OpReqType>& req,
77 const std::vector<NDArray>& outputs) {
78 ForeachState &state = state_ptr.get_state<ForeachState>();
79 const ForeachParam& params = state.params;
80 const size_t iter_dim = 0;
81 CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
82 CHECK_GT(params.in_data_locs.ndim(), 0);
83 size_t len = inputs[0].shape()[iter_dim];
84 state.num_iterations = len;
85 for (int i = 1; i < params.in_data_locs.ndim(); i++)
86 CHECK_EQ(inputs[i].shape()[iter_dim], len);
87 for (size_t i = 0; i < (size_t) params.num_out_data; i++)
88 CHECK_EQ(len, outputs[i].shape()[iter_dim]);
89 for (const auto &arr : outputs)
90 CHECK_EQ(arr.storage_type(), kDefaultStorage)
91 << "The for operator doesn't support the sparse format";
92
93 // Initialize the outputs of the subgraph is a little trickier.
94 // The states from the previous iteration are used as the inputs of the next
95 // iteration, so I have to maintain two arrays, so the inputs and outputs
96 // of the subgraph share the same memory.
97 std::vector<NDArray> subg_outputs1(outputs.size());
98 std::vector<NDArray> subg_outputs2(outputs.size());
99 std::vector<NDArray> *subg_outputs[2]{&subg_outputs1, &subg_outputs2};
100 // If the length is an odd number, the last iteration will use the first set
101 // of outputs. In this way, we don't need to copy the results from the
102 // subgraph to the final outputs of the loop.
103 if (len % 2 == 1) {
104 for (size_t i = params.num_out_data; i < subg_outputs1.size(); i++) {
105 subg_outputs1[i] = outputs[i];
106 subg_outputs2[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true,
107 outputs[i].dtype());
108 }
109 } else {
110 // Otherwise, we'll use the second set of outputs.
111 for (size_t i = params.num_out_data; i < subg_outputs1.size(); i++) {
112 subg_outputs1[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true,
113 outputs[i].dtype());
114 subg_outputs2[i] = outputs[i];
115 }
116 }
117
118 // Initialize the inputs for the subgraph.
119 // In each iteration, we need to update the subgraph inputs for input data
120 // and the loop states.
121 std::vector<NDArray> subg_inputs(inputs.size());
122 // The remaining arrays (other than input data and states) only need to be set once.
123 for (int j = 0; j < params.remain_locs.ndim(); j++) {
124 CHECK_LT(params.remain_locs[j], subg_inputs.size());
125 subg_inputs[params.remain_locs[j]] = inputs[j + params.in_data_locs.ndim()
126 + params.in_state_locs.ndim()];
127 }
128
129 // Here we iterate over the first dimension of the first input array.
130 for (size_t i = 0; i < len; i++) {
131 // Initialize outputs for the subgraph.
132 std::vector<NDArray> *subg_out_curr = subg_outputs[i % 2];
133 std::vector<NDArray> *subg_out_prev = subg_outputs[(i + 1) % 2];
134 for (int j = 0; j < params.num_out_data; j++)
135 (*subg_out_curr)[j] = outputs[j].At(i);
136 // When recording for backward computation, we should make sure
137 // that output arrays are actually different in each iteration.
138 if (ctx.need_grad && i < len - 1) {
139 for (size_t j = params.num_out_data; j < subg_out_curr->size(); j++)
140 (*subg_out_curr)[j] = NDArray(outputs[j].shape(), outputs[j].ctx(),
141 true, outputs[j].dtype());
142 } else if (ctx.need_grad && i == len - 1) {
143 // For the last iteration, we need to write data to the output array
144 // directly.
145 for (size_t j = params.num_out_data; j < subg_out_curr->size(); j++)
146 (*subg_out_curr)[j] = outputs[j];
147 }
148
149 // Initialize inputs for the subgraph.
150 // Get a slice from the input data arrays.
151 for (int j = 0; j < params.in_data_locs.ndim(); j++) {
152 size_t loc = params.in_data_locs[j];
153 subg_inputs[loc] = inputs[j].At(i);
154 }
155 // For the rest of the iterations, the states are the outputs
156 // from the previous iteration.
157 if (i > 0) {
158 for (size_t j = params.num_out_data; j < subg_out_prev->size(); j++) {
159 size_t idx = j - params.num_out_data;
160 CHECK_LT(params.in_state_locs[idx], subg_inputs.size());
161 subg_inputs[params.in_state_locs[idx]] = (*subg_out_prev)[j];
162 }
163 } else {
164 for (int j = 0; j < params.in_state_locs.ndim(); j++) {
165 CHECK_LT(params.in_state_locs[j], subg_inputs.size());
166 subg_inputs[params.in_state_locs[j]] = inputs[j + params.in_data_locs.ndim()];
167 }
168 }
169
170 state.Forward(i, subg_inputs, req, *subg_out_curr, ctx.need_grad);
171 }
172 }
173
ForeachGradComputeExCPU(const OpStatePtr & state_ptr,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)174 static void ForeachGradComputeExCPU(const OpStatePtr& state_ptr,
175 const OpContext& ctx,
176 const std::vector<NDArray>& inputs,
177 const std::vector<OpReqType>& req,
178 const std::vector<NDArray>& outputs) {
179 ForeachState &state = state_ptr.get_state<ForeachState>();
180 const ForeachParam& params = state.params;
181 CHECK_EQ(outputs.size(), (size_t) params.num_args - 1);
182 CHECK_GT(params.in_data_locs.ndim(), 0);
183 for (const auto &arr : outputs)
184 CHECK_EQ(arr.storage_type(), kDefaultStorage)
185 << "The for operator doesn't support the sparse format";
186 int len = state.num_iterations;
187 size_t num_output_data = params.num_out_data;
188
189 // In backward computation, we need to run iterations from backwards.
190 std::vector<NDArray> subg_ograds(params.num_outputs);
191 std::vector<NDArray> subg_igrads(outputs.size());
192 for (size_t i = num_output_data; i < subg_ograds.size(); i++)
193 subg_ograds[i] = inputs[i];
194 std::vector<OpReqType> subg_req(req.size());
195 for (auto r : req)
196 CHECK_NE(r, kWriteInplace);
197
198 // There are three types of arrays in igrads.
199 // * data gradients.
200 // * loop variable gradients.
201 // * remaining variable gradients.
202 // They are in the following order:
203 // [data vars], [loop vars], [remaining vars]
204
205 // [remaining vars]
206 for (int i = 0; i < params.remain_locs.ndim(); i++) {
207 size_t loc = params.remain_locs[i];
208 size_t orig_loc = i + params.in_data_locs.ndim() + params.in_state_locs.ndim();
209 subg_igrads[loc] = outputs[orig_loc];
210 subg_req[loc] = req[orig_loc];
211 }
212
213 for (int iter_num = len - 1; iter_num >= 0; iter_num--) {
214 for (int i = 0; i < params.num_out_data; i++)
215 subg_ograds[i] = inputs[i].At(iter_num);
216 if (iter_num < len - 1) {
217 // For the rest of the iterations, we should add graidents to the
218 // remaining vars.
219 for (int i = 0; i < params.remain_locs.ndim(); i++) {
220 size_t loc = params.remain_locs[i];
221 subg_req[loc] = kAddTo;
222 }
223 }
224
225 // [data vars]
226 for (int i = 0; i < params.in_data_locs.ndim(); i++) {
227 size_t loc = params.in_data_locs[i];
228 subg_igrads[loc] = outputs[i].At(iter_num);
229 subg_req[loc] = req[i];
230 }
231 // [loop vars]
232 for (int i = 0; i < params.in_state_locs.ndim(); i++) {
233 size_t loc = params.in_state_locs[i];
234 const NDArray &output = outputs[i + params.in_data_locs.ndim()];
235 if (iter_num != 0) {
236 // For state gradients, we need to allocate new NDArrays
237 // because intermediate state gradients won't be returned to the users.
238 subg_igrads[loc] = NDArray(output.shape(), output.ctx(), true, output.dtype());
239 } else {
240 subg_igrads[loc] = output;
241 }
242 // For the first iteration, we need to use the request provided by
243 // the user to write state gradients to the outputs.
244 subg_req[loc] = iter_num != 0 ? kWriteTo : req[i + params.in_data_locs.ndim()];
245 }
246 state.Backward(iter_num, subg_ograds, subg_req, subg_igrads);
247
248 size_t num_states = subg_ograds.size() - num_output_data;
249 for (size_t i = 0; i < num_states; i++) {
250 size_t loc = params.in_state_locs[i];
251 CHECK_LT(loc, subg_igrads.size());
252 subg_ograds[i + num_output_data] = subg_igrads[loc];
253 }
254 }
255 state.Cleanup();
256 }
257
258 template<typename T>
remap(const std::vector<T> & op_in,size_t start,const mxnet::Tuple<dim_t> & locs,std::vector<T> * subg_in)259 static void remap(const std::vector<T> &op_in, size_t start,
260 const mxnet::Tuple<dim_t> &locs, std::vector<T> *subg_in) {
261 auto op_in_it = op_in.begin() + start;
262 for (int i = 0; i < locs.ndim(); i++) {
263 dim_t loc = locs[i];
264 subg_in->at(loc) = *(op_in_it + i);
265 }
266 }
267
// Drop the first (iterated) dimension of a shape. A 1-D shape degenerates
// to the scalar-like shape (1,), since a slice of it is a single element.
static inline mxnet::TShape SliceFirstDim(const mxnet::TShape &s) {
  if (s.ndim() > 1) {
    return mxnet::TShape(s.begin() + 1, s.end());
  } else {
    return mxnet::TShape(mshadow::Shape1(1));
  }
}
275
ForeachShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_shape,mxnet::ShapeVector * out_shape)276 static bool ForeachShape(const nnvm::NodeAttrs& attrs,
277 mxnet::ShapeVector *in_shape,
278 mxnet::ShapeVector *out_shape) {
279 const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
280 CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
281 CHECK_EQ(attrs.subgraphs.size(), 1U);
282
283 mxnet::ShapeVector subg_in_shape(in_shape->size());
284 // data shape
285 std::vector<bool> data_1d(params.in_data_locs.ndim(), false);
286 for (int i = 0; i < params.in_data_locs.ndim(); i++) {
287 size_t loc = params.in_data_locs[i];
288 if (in_shape->at(i).ndim() == 1)
289 data_1d[i] = true;
290 subg_in_shape[loc] = SliceFirstDim(in_shape->at(i));
291 }
292 // state shape
293 remap(*in_shape, params.in_data_locs.ndim(), params.in_state_locs,
294 &subg_in_shape);
295 // remaining shape
296 remap(*in_shape, params.in_data_locs.ndim() + params.in_state_locs.ndim(),
297 params.remain_locs, &subg_in_shape);
298
299 mxnet::ShapeVector subg_out_shape = *out_shape;
300 for (int i = 0; i < params.num_out_data; i++) {
301 mxnet::TShape shape = subg_out_shape[i];
302 // If we don't have shape info, we don't need to do anything.
303 if (!mxnet::ndim_is_known(shape))
304 continue;
305 subg_out_shape[i] = SliceFirstDim(shape);
306 }
307
308 bool infer_success = InferSubgraphShape(*attrs.subgraphs[0],
309 &subg_in_shape, &subg_out_shape);
310
311 // After inference, we need to move inferred information back to in_shape and
312 // out_shape.
313
314 // For the shape of output data.
315 size_t len = in_shape->at(0)[0];
316 for (int i = 0; i < params.num_out_data; i++) {
317 // If the output shape isn't inferred, we don't need to propogate the info.
318 const auto& g_out_shape = subg_out_shape[i];
319 if (!mxnet::ndim_is_known(g_out_shape))
320 continue;
321
322 auto out = mxnet::TShape(g_out_shape.ndim() + 1, -1);
323 out[0] = len;
324 for (int i = 1; i < out.ndim(); i++)
325 out[i] = g_out_shape[i - 1];
326 SHAPE_ASSIGN_CHECK(*out_shape, i, out);
327 }
328 // For the shape of output states.
329 for (size_t i = params.num_out_data; i < subg_out_shape.size(); i++)
330 SHAPE_ASSIGN_CHECK(*out_shape, i, subg_out_shape[i]);
331
332 // For the shape of input data.
333 for (int i = 0; i < params.in_data_locs.ndim(); i++) {
334 size_t loc = params.in_data_locs[i];
335 const auto &shape = subg_in_shape[loc];
336 // If the input data shape isn't inferred, we don't need to propogate the
337 // info.
338 if (!mxnet::ndim_is_known(shape))
339 continue;
340
341 if (data_1d[i]) {
342 mxnet::TShape s(1, -1);
343 s[0] = len;
344 SHAPE_ASSIGN_CHECK(*in_shape, i, s);
345 } else {
346 auto in = mxnet::TShape(shape.ndim() + 1, -1);
347 in[0] = len;
348 for (int i = 1; i < in.ndim(); i++)
349 in[i] = shape[i - 1];
350 SHAPE_ASSIGN_CHECK(*in_shape, i, in);
351 }
352 }
353 // For the shape of state.
354 for (int i = 0; i < params.in_state_locs.ndim(); i++) {
355 size_t loc = params.in_state_locs[i];
356 SHAPE_ASSIGN_CHECK(*in_shape, i + params.in_data_locs.ndim(),
357 subg_in_shape[loc]);
358 }
359 // For the shape of remaining data.
360 for (int i = 0; i < params.remain_locs.ndim(); i++) {
361 size_t loc = params.remain_locs[i];
362 SHAPE_ASSIGN_CHECK(*in_shape,
363 i + params.in_data_locs.ndim() + params.in_state_locs.ndim(),
364 subg_in_shape[loc]);
365 }
366
367 if (infer_success) {
368 size_t num_states = out_shape->size() - params.num_out_data;
369 for (size_t i = 0; i < num_states; i++) {
370 CHECK_EQ((*out_shape)[i + params.num_out_data],
371 (*in_shape)[i + params.in_data_locs.ndim()]);
372 }
373 }
374 // Check if we have inferred the shapes correctly.
375 return infer_success;
376 }
377
ForeachType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_type,std::vector<int> * out_type)378 static bool ForeachType(const nnvm::NodeAttrs& attrs,
379 std::vector<int> *in_type, std::vector<int> *out_type) {
380 const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
381 CHECK_EQ(out_type->size(), (size_t) params.num_outputs);
382 CHECK_EQ(attrs.subgraphs.size(), 1U);
383 std::vector<int> subg_in_type(in_type->size(), 0);
384 remap(*in_type, 0, params.in_data_locs, &subg_in_type);
385 remap(*in_type, params.in_data_locs.ndim(), params.in_state_locs, &subg_in_type);
386 remap(*in_type, params.in_data_locs.ndim() + params.in_state_locs.ndim(),
387 params.remain_locs, &subg_in_type);
388 bool success = InferSubgraphDataType(*attrs.subgraphs[0], &subg_in_type, out_type);
389 for (int i = 0; i < params.in_data_locs.ndim(); i++) {
390 size_t loc = params.in_data_locs[i];
391 TYPE_ASSIGN_CHECK(*in_type, i, subg_in_type[loc]);
392 }
393 for (int i = 0; i < params.in_state_locs.ndim(); i++) {
394 size_t loc = params.in_state_locs[i];
395 TYPE_ASSIGN_CHECK(*in_type, i + params.in_data_locs.ndim(), subg_in_type[loc]);
396 }
397 for (int i = 0; i < params.remain_locs.ndim(); i++) {
398 size_t loc = params.remain_locs[i];
399 TYPE_ASSIGN_CHECK(*in_type, i + params.in_data_locs.ndim() + params.in_state_locs.ndim(),
400 subg_in_type[loc]);
401 }
402 return success;
403 }
404
ForeachStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)405 static bool ForeachStorageType(const nnvm::NodeAttrs& attrs,
406 const int dev_mask,
407 DispatchMode* dispatch_mode,
408 std::vector<int> *in_attrs,
409 std::vector<int> *out_attrs) {
410 const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
411 CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs);
412 CHECK_EQ(attrs.subgraphs.size(), 1U);
413 std::vector<int> subg_in_attrs(in_attrs->size(), kUndefinedStorage);
414 remap(*in_attrs, 0, params.in_data_locs, &subg_in_attrs);
415 remap(*in_attrs, params.in_data_locs.ndim(), params.in_state_locs, &subg_in_attrs);
416 remap(*in_attrs, params.in_data_locs.ndim() + params.in_state_locs.ndim(),
417 params.remain_locs, &subg_in_attrs);
418 bool success = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask,
419 dispatch_mode, &subg_in_attrs, out_attrs);
420 for (int i = 0; i < params.in_data_locs.ndim(); i++) {
421 size_t loc = params.in_data_locs[i];
422 STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, i, subg_in_attrs[loc]);
423 }
424 for (int i = 0; i < params.in_state_locs.ndim(); i++) {
425 size_t loc = params.in_state_locs[i];
426 STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, i + params.in_data_locs.ndim(),
427 subg_in_attrs[loc]);
428 }
429 for (int i = 0; i < params.remain_locs.ndim(); i++) {
430 size_t loc = params.remain_locs[i];
431 STORAGE_TYPE_ASSIGN_CHECK(*in_attrs,
432 i + params.in_data_locs.ndim() + params.in_state_locs.ndim(),
433 subg_in_attrs[loc]);
434 }
435 return success;
436 }
437
BackwardForeachStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)438 static bool BackwardForeachStorageType(const nnvm::NodeAttrs& attrs,
439 const int dev_mask,
440 DispatchMode* dispatch_mode,
441 std::vector<int> *in_attrs,
442 std::vector<int> *out_attrs) {
443 const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
444 CHECK_EQ(out_attrs->size(), (size_t) params.num_args - 1);
445 CHECK_EQ(in_attrs->size(), (size_t) params.num_args - 1 + params.num_outputs * 2);
446 CHECK_EQ(attrs.subgraphs.size(), 1U);
447 CachedOp op(*attrs.subgraphs[0],
448 std::vector<std::pair<std::string, std::string> >());
449 // map the operator inputs to the subgraph inputs.
450 std::vector<int> subg_forward_ins(params.num_args - 1, kUndefinedStorage);
451 remap(*in_attrs, params.num_outputs, params.in_data_locs, &subg_forward_ins);
452 remap(*in_attrs, params.num_outputs + params.in_data_locs.ndim(),
453 params.in_state_locs, &subg_forward_ins);
454 remap(*in_attrs, params.num_outputs + params.in_data_locs.ndim() + params.in_state_locs.ndim(),
455 params.remain_locs, &subg_forward_ins);
456
457 // Copy backward input storage to backward subgraph input storage.
458 std::vector<int> subg_in_attrs = *in_attrs;
459 for (size_t i = 0; i < subg_forward_ins.size(); i++)
460 subg_in_attrs[i + params.num_outputs] = subg_forward_ins[i];
461 return op.BackwardStorageType(attrs, dev_mask, dispatch_mode,
462 &subg_in_attrs, out_attrs);
463 }
464
CreateForeachState(const NodeAttrs & attrs,Context ctx,const mxnet::ShapeVector & ishape,const std::vector<int> & itype)465 static OpStatePtr CreateForeachState(const NodeAttrs& attrs,
466 Context ctx,
467 const mxnet::ShapeVector& ishape,
468 const std::vector<int>& itype) {
469 const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
470 return OpStatePtr::Create<ForeachState>(*attrs.subgraphs[0], params);
471 }
472
473 static std::vector<nnvm::NodeEntry>
ForeachGradient(const nnvm::ObjectPtr & n,const std::vector<nnvm::NodeEntry> & ograds)474 ForeachGradient(const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
475 ElemwiseGradUseInOut fgrad{"_backward_foreach"};
476 std::vector<nnvm::NodeEntry> entries = fgrad(n, ograds);
477 entries[0].node->attrs.subgraphs = n->attrs.subgraphs;
478 return entries;
479 }
480
481 struct WhileLoopParam : public dmlc::Parameter<WhileLoopParam> {
482 int num_args;
483 int num_outputs;
484 int num_out_data;
485 int max_iterations;
486 // `cond' and `func' each takes a subset of while_loop's inputs as that to their subgraphs
487 // `cond_input_locs' contains indices of inputs fed to `cond', and
488 // `func_input_locs' contains indices of inputs fed to `func'.
489 // `func_var_locs' are indices in which input "variables" are stored in func's inputs.
490 mxnet::Tuple<dim_t> cond_input_locs;
491 mxnet::Tuple<dim_t> func_input_locs;
492 mxnet::Tuple<dim_t> func_var_locs;
DMLC_DECLARE_PARAMETERmxnet::op::WhileLoopParam493 DMLC_DECLARE_PARAMETER(WhileLoopParam) {
494 DMLC_DECLARE_FIELD(num_args).set_lower_bound(2)
495 .describe("Number of input arguments, including cond and func as two symbol inputs.");
496 DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1)
497 .describe("The number of outputs of the subgraph.");
498 DMLC_DECLARE_FIELD(num_out_data).set_lower_bound(0)
499 .describe("The number of outputs from the function body.");
500 DMLC_DECLARE_FIELD(max_iterations).set_lower_bound(1)
501 .describe("Maximum number of iterations.");
502 DMLC_DECLARE_FIELD(cond_input_locs)
503 .describe("The locations of cond's inputs in the given inputs.");
504 DMLC_DECLARE_FIELD(func_input_locs)
505 .describe("The locations of func's inputs in the given inputs.");
506 DMLC_DECLARE_FIELD(func_var_locs)
507 .describe("The locations of loop_vars among func's inputs.");
508 }
509 template <typename T>
sync_in_outmxnet::op::WhileLoopParam510 bool sync_in_out(std::vector<T> *in,
511 std::vector<T> *out,
512 std::function<bool(const T &)> is_empty) const {
513 for (int i = this->num_out_data; i < this->num_outputs; ++i) {
514 // each out->at(i) is a params, loop_var
515 T &x = in->at(this->func_input_locs[this->func_var_locs[i - this->num_out_data]]);
516 T &y = out->at(i);
517 fill_value(&x, &y, is_empty(x), is_empty(y));
518 }
519 return true;
520 }
521 }; // struct WhileLoopParam
522
523 DMLC_REGISTER_PARAMETER(WhileLoopParam);
524
525 class WhileLoopState: public LoopState {
526 public:
527 WhileLoopParam params;
528 size_t n_iterations; // the actual number of steps taken in this while loop, <= max_iterations
529 CachedOpPtr cond_op;
530 // abbrev for output_input_mapping
531 // indicates to which index the output of `func' will be copied to the input of `cond'
532 std::vector<int> oi_map;
533
WhileLoopState(const WhileLoopParam & params,const Symbol & cond,const Symbol & func)534 WhileLoopState(const WhileLoopParam ¶ms, const Symbol &cond, const Symbol &func) :
535 LoopState(func),
536 params(params),
537 n_iterations(0U),
538 cond_op(LoopState::MakeSharedOp(cond)),
539 oi_map(params.func_var_locs.ndim(), -1) {
540 const mxnet::Tuple<dim_t> &func_input_locs = params.func_input_locs;
541 const mxnet::Tuple<dim_t> &func_var_locs = params.func_var_locs;
542 const mxnet::Tuple<dim_t> &cond_input_locs = params.cond_input_locs;
543 for (int i = 0; i < func_var_locs.ndim(); ++i) {
544 dim_t pos_i = func_input_locs[func_var_locs[i]];
545 for (int j = 0; j < cond_input_locs.ndim(); ++j) {
546 dim_t pos_j = cond_input_locs[j];
547 if (pos_i == pos_j) {
548 this->oi_map[i] = j;
549 }
550 }
551 }
552 }
553 };
554
WhileLoopComputeExCPU(const OpStatePtr & state_ptr,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)555 static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
556 const OpContext& ctx,
557 const std::vector<NDArray>& inputs,
558 const std::vector<OpReqType>& req,
559 const std::vector<NDArray>& outputs) {
560 // The argument `inputs' are loop_vars and other inputs
561 // loop_vars are stored in stored in `loop_vars_locs'
562 // The argument `outputs' are output and new_loop_vars
563 // [0: num_out_data) are outputs at each step.
564 // [num_out_data: ) are new_loop_vars
565 // TODO(Junru): avoid dynamic NDArray allocation
566 WhileLoopState &state = state_ptr.get_state<WhileLoopState>();
567 const WhileLoopParam& params = state.params;
568 // a helper function, converting std::vector<NDArray> to std::vector<NDArray*>
569 const auto to_ptr_vec = [](std::vector<NDArray> &in, std::vector<NDArray*> *out) {
570 out->clear();
571 out->reserve(in.size());
572 std::transform(std::begin(in),
573 std::end(in),
574 std::back_inserter(*out),
575 [](NDArray &a) {return &a;});
576 };
577 // sanity checks
578 CHECK_EQ(inputs.size() + 2U, (size_t) params.num_args);
579 CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
580 CHECK_EQ(outputs.size(), req.size());
581 // construct inputs and outputs for cond
582 std::vector<NDArray> cond_inputs, cond_outputs = {NDArray()};
583 extract_by_loc(inputs, params.cond_input_locs, &cond_inputs);
584 std::vector<NDArray*> cond_input_ptr, cond_output_ptr;
585 to_ptr_vec(cond_inputs, &cond_input_ptr);
586 to_ptr_vec(cond_outputs, &cond_output_ptr);
587 // construct inputs and outputs for func
588 std::vector<NDArray> func_inputs, func_outputs(outputs.size());
589 extract_by_loc(inputs, params.func_input_locs, &func_inputs);
590 for (size_t &step = state.n_iterations = 0; step < (size_t) params.max_iterations; ++step) {
591 state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr);
592 if (!as_bool_scalar(*cond_output_ptr[0])) {
593 break;
594 }
595 // we create func_outputs for the current step:
596 for (size_t i = 0; i < outputs.size(); ++i) {
597 func_outputs[i] = NDArray(outputs[i].ctx(), outputs[i].dtype());
598 }
599 state.Forward(step, func_inputs, req, func_outputs, ctx.need_grad);
600 if (step == 0) {
601 for (int i = 0; i < params.num_out_data; ++i) {
602 func_outputs[i].WaitToRead();
603 if (!shape_is_known(func_outputs[i].shape())) {
604 func_outputs[i].SetShapeFromChunk();
605 }
606 mxnet::TShape step_shape = func_outputs[i].shape();
607 mxnet::TShape shape(step_shape.ndim() + 1, 0);
608 shape[0] = params.max_iterations;
609 for (int j = 0; j < step_shape.ndim(); ++j) {
610 shape[j + 1] = step_shape[j];
611 }
612 const_cast<NDArray &>(outputs[i]).Init(shape);
613 }
614 }
615 for (int i = 0; i < params.num_out_data; ++i) {
616 NDArray first_slot = outputs[i].At(step);
617 mxnet::CopyFromTo(func_outputs[i], &first_slot);
618 }
619 // func_inputs on the next step:
620 // the output (new_loop_vars) will become the new inputs (loop_vars)
621 for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
622 int j = params.func_var_locs[i - params.num_out_data];
623 func_inputs[j] = func_outputs[i];
624 int k = state.oi_map[i - params.num_out_data];
625 if (k != -1) {
626 // I actually don't need to update cond_inputs
627 cond_inputs[k] = func_outputs[i];
628 cond_input_ptr[k] = &func_outputs[i];
629 }
630 }
631 }
632 // copy output data to `outputs'
633 // case 1: at least one step is executed,
634 // the final_loop_vars must be stored in func_inputs
635 // case 2: no step is executed
636 // the final_loop_vars is the same as loop_vars, which are also stored in func_inputs
637 // therefore, we copy func_inputs[:] to outputs[num_out_data: ]
638 for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
639 size_t j = params.func_var_locs[i - params.num_out_data];
640 if (!shape_is_known(outputs[i].shape())) {
641 const_cast<NDArray &>(outputs[i]).Init(func_inputs[j].shape());
642 }
643 mxnet::CopyFromTo(func_inputs[j], &outputs[i]);
644 }
645 for (int i = 0; i < params.num_out_data; ++i) {
646 const_cast<NDArray &>(outputs[i]).SetShapeFromChunk();
647 }
648 if (state.n_iterations == 0) {
649 for (size_t i = 0; i < outputs.size(); ++i) {
650 if (!shape_is_known(outputs[i].shape())) {
651 const_cast<NDArray &>(outputs[i]).ReshapeAndAlloc({1});
652 }
653 }
654 }
655 }
656
WhileLoopGradComputeExCPU(const OpStatePtr & state_ptr,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & _req,const std::vector<NDArray> & _outputs)657 static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr,
658 const OpContext& ctx,
659 const std::vector<NDArray>& inputs,
660 const std::vector<OpReqType>& _req,
661 const std::vector<NDArray>& _outputs) {
662 // inputs are dl / df(x)
663 // outputs are dl / dx
664 // where f is the current function,
665 // x is the input to the current function,
666 // TODO(Junru): avoid dynamic NDArray allocation
667 WhileLoopState &state = state_ptr.get_state<WhileLoopState>();
668 const WhileLoopParam& params = state.params;
669 // sanity checks
670 CHECK_EQ(_outputs.size() + 2U, (size_t) params.num_args);
671 CHECK_EQ(_outputs.size(), _req.size());
672 for (auto x : _req) {
673 CHECK_NE(x, kWriteInplace);
674 }
675 std::vector<NDArray> outputs;
676 std::vector<OpReqType> req;
677 extract_by_loc(_outputs, params.func_input_locs, &outputs);
678 extract_by_loc(_req, params.func_input_locs, &req);
679 if (state.n_iterations == 0) {
680 for (int i = params.num_out_data; i < params.num_outputs; ++i) {
681 int j = params.func_var_locs[i - params.num_out_data];
682 mxnet::CopyFromTo(inputs[i], &outputs[j]);
683 }
684 state.Cleanup();
685 return;
686 }
687 // collect var_locs and out_locs, positions other than var_locs are out_locs, i.e.
688 // [0, var_locs[0])
689 // (var_locs[1], var_locs[2]),
690 // (var_locs[2], var_locs[3]),
691 // ...
692 // (var_locs[-2], var_locs[-1] = params.num_args - 2)
693 std::vector<dim_t> var_locs(params.func_var_locs.begin(), params.func_var_locs.end());
694 var_locs.push_back((dim_t) params.num_args - 2U);
695 sort(var_locs.begin(), var_locs.end());
696 // vectors for the backward loop
697 std::vector<NDArray> ograds(params.num_outputs);
698 std::vector<NDArray> igrads(outputs.size());
699 std::vector<OpReqType> iter_req(req.size());
700 for (int i = params.num_out_data; i < params.num_outputs; ++i)
701 ograds[i] = inputs[i];
702 const int n_iter = state.n_iterations;
703 for (int step = n_iter - 1; step >= 0; --step) {
704 // ograds[ : num_out_data] = inputs[ : num_out_data][step]
705 // ograds[num_out_data: ] is maintained in the end of each loop
706 std::transform(std::begin(inputs),
707 std::begin(inputs) + params.num_out_data,
708 std::begin(ograds),
709 [step] (const NDArray &a) { return a.At(step); } );
710 // igrads[i] =
711 // outputs[i] (step == 0)
712 // outputs[i] (step != 0 && i not in loop_var_locs)
713 // ArrayLike(outputs[i]) (step != 0 && i in loop_var_locs)
714 // iter_req =
715 // kWriteTo (step != 0 && i in loop_var_locs)
716 // req[i] (step == 0 && i in loop_var_locs)
717 // kAddTo (step != n_iters - 1 && i not in loop_var_locs)
718 // req[i] (step == n_iters - 1 && i not in loop_var_locs)
719 {
720 size_t i = 0;
721 for (size_t loc : var_locs) {
722 for ( ; i < loc; ++i) {
723 // locs other that var_locs
724 igrads[i] = outputs[i];
725 iter_req[i] = (step + 1 == n_iter || req[i] == kNullOp)
726 ? req[i]
727 : kAddTo;
728 }
729 if (i < (size_t) params.num_args - 2U) {
730 // a var
731 igrads[i] = (step == 0)
732 ? outputs[i]
733 : NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype());
734 iter_req[i] = (step == 0 || req[i] == kNullOp)
735 ? req[i]
736 : kWriteTo;
737 ++i;
738 } else {
739 break;
740 }
741 }
742 }
743 state.Backward(step, ograds, iter_req, igrads);
744 for (int i = params.num_out_data; i < params.num_outputs; ++i) {
745 size_t j = params.func_var_locs[i - params.num_out_data];
746 ograds[i] = igrads[j];
747 }
748 }
749 state.Cleanup();
750 }
751
WhileLoopType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_type,std::vector<int> * out_type)752 static bool WhileLoopType(const nnvm::NodeAttrs& attrs,
753 std::vector<int> *in_type, std::vector<int> *out_type) {
754 const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
755 static const std::function<bool(const int &)> is_udf = is_type_udf;
756 CHECK_EQ(in_type->size() + 2U, (size_t) params.num_args);
757 CHECK_EQ(out_type->size(), (size_t) params.num_outputs);
758 CHECK_EQ(attrs.subgraphs.size(), 2U);
759 CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
760 std::vector<int> cond_in_type;
761 std::vector<int> func_in_type;
762 extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type);
763 extract_by_loc(*in_type, params.func_input_locs, &func_in_type);
764 std::vector<int> cond_out_type = {-1};
765 CHECK(params.sync_in_out(in_type, out_type, is_udf));
766 bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type);
767 CHECK(params.sync_in_out(in_type, out_type, is_udf));
768 CHECK(sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf));
769 bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &func_in_type, out_type);
770 CHECK(params.sync_in_out(in_type, out_type, is_udf));
771 CHECK(sync_in_in(params.func_input_locs, in_type, &func_in_type, is_udf));
772 return succ_0 && succ_1;
773 }
774
WhileLoopStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)775 static bool WhileLoopStorageType(const nnvm::NodeAttrs& attrs,
776 const int dev_mask,
777 DispatchMode* dispatch_mode,
778 std::vector<int> *in_attrs,
779 std::vector<int> *out_attrs) {
780 const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
781 static const std::function<bool(const int &)> is_udf = is_stype_udf;
782 CHECK_EQ(in_attrs->size() + 2U, (size_t) params.num_args);
783 CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs);
784 CHECK_EQ(attrs.subgraphs.size(), 2U);
785 CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
786 std::vector<int> cond_in_attrs;
787 std::vector<int> func_in_attrs;
788 extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs);
789 extract_by_loc(*in_attrs, params.func_input_locs, &func_in_attrs);
790 std::vector<int> cond_out_attrs = {kDefaultStorage};
791 DispatchMode cond_mode = DispatchMode::kUndefined;
792 DispatchMode func_mode = DispatchMode::kUndefined;
793 *dispatch_mode = DispatchMode::kFComputeEx;
794 CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf));
795 bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \
796 &cond_mode, &cond_in_attrs, &cond_out_attrs);
797 CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf));
798 CHECK(sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf));
799 bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \
800 &func_mode, &func_in_attrs, out_attrs);
801 CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf));
802 CHECK(sync_in_in(params.func_input_locs, in_attrs, &func_in_attrs, is_udf));
803 return succ_0 && succ_1;
804 }
805
BackwardWhileLoopStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)806 static bool BackwardWhileLoopStorageType(const nnvm::NodeAttrs& attrs,
807 const int dev_mask,
808 DispatchMode* dispatch_mode,
809 std::vector<int> *in_attrs,
810 std::vector<int> *out_attrs) {
811 // `cond' is not backwarded, don't check
812 const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
813 CHECK_EQ(out_attrs->size() + 2U, (size_t) params.num_args);
814 CHECK_EQ(attrs.subgraphs.size(), 2U);
815 CachedOp op(*attrs.subgraphs[1], {});
816 return op.BackwardStorageType(attrs, dev_mask, dispatch_mode,
817 in_attrs, out_attrs);
818 }
819
CreateWhileLoopState(const NodeAttrs & attrs,Context ctx,const mxnet::ShapeVector & ishape,const std::vector<int> & itype)820 static OpStatePtr CreateWhileLoopState(const NodeAttrs& attrs,
821 Context ctx,
822 const mxnet::ShapeVector& ishape,
823 const std::vector<int>& itype) {
824 const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
825 return OpStatePtr::Create<WhileLoopState>(params, *attrs.subgraphs[0], *attrs.subgraphs[1]);
826 }
827
828 static std::vector<nnvm::NodeEntry>
WhileLoopGradient(const nnvm::ObjectPtr & n,const std::vector<nnvm::NodeEntry> & ograds)829 WhileLoopGradient(const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
830 ElemwiseGradUseInOut fgrad{"_backward_while_loop"};
831 std::vector<nnvm::NodeEntry> entries = fgrad(n, ograds);
832 entries[0].node->attrs.subgraphs = n->attrs.subgraphs;
833 return entries;
834 }
835
// Parameters of the _cond operator. The operator takes three subgraph
// symbols (cond, then, else) followed by data inputs; the *_input_locs
// tuples map each subgraph's inputs to positions in the operator's inputs.
struct CondParam : public dmlc::Parameter<CondParam> {
  int num_args;     // total inputs, including the 3 subgraph symbols
  int num_outputs;  // number of outputs produced by the selected branch
  mxnet::Tuple<dim_t> cond_input_locs;  // cond's inputs within the op inputs
  mxnet::Tuple<dim_t> then_input_locs;  // then's inputs within the op inputs
  mxnet::Tuple<dim_t> else_input_locs;  // else's inputs within the op inputs
  DMLC_DECLARE_PARAMETER(CondParam) {
    DMLC_DECLARE_FIELD(num_args).set_lower_bound(3)
    .describe("Number of input arguments, including cond, then and else as three symbol inputs.");
    DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1)
    .describe("The number of outputs of the subgraph.");
    DMLC_DECLARE_FIELD(cond_input_locs)
    .describe("The locations of cond's inputs in the given inputs.");
    DMLC_DECLARE_FIELD(then_input_locs)
    .describe("The locations of then's inputs in the given inputs.");
    DMLC_DECLARE_FIELD(else_input_locs)
    .describe("The locations of else's inputs in the given inputs.");
  }
};  // struct CondParam

DMLC_REGISTER_PARAMETER(CondParam);
857
// Runtime state of a _cond node: a CachedOp for the condition plus one
// LoopState per branch, and a record of which branch ran in the last
// forward pass (needed to replay the correct branch in backward).
class CondState {
 public:
  CondParam params;
  CachedOpPtr cond_op;
  LoopState then_branch;
  LoopState else_branch;
  int branch_selection;  // 1 if then branch; 0 if else branch; -1 if undefined

  CondState(const CondParam &params,
            const Symbol &cond,
            const Symbol &then_sym,
            const Symbol &else_sym):
      params(params),
      cond_op(LoopState::MakeSharedOp(cond)),
      then_branch(then_sym),
      else_branch(else_sym),
      branch_selection(-1) {
  }
};
877
CondComputeExCPU(const OpStatePtr & state_ptr,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)878 static void CondComputeExCPU(const OpStatePtr& state_ptr,
879 const OpContext& ctx,
880 const std::vector<NDArray>& inputs,
881 const std::vector<OpReqType>& req,
882 const std::vector<NDArray>& outputs) {
883 // The argument `inputs' are loop_vars and other inputs
884 // loop_vars are stored in stored in `loop_vars_locs'
885 // The argument `outputs' are output and new_loop_vars
886 // [0: num_out_data) are outputs at each step.
887 // [num_out_data: ) are new_loop_vars
888 CondState &state = state_ptr.get_state<CondState>();
889 const CondParam& params = state.params;
890 // a helper function, converting std::vector<NDArray> to std::vector<NDArray*>
891 const auto to_ptr_vec = [](std::vector<NDArray> &in, std::vector<NDArray*> *out) {
892 out->clear();
893 out->reserve(in.size());
894 std::transform(std::begin(in),
895 std::end(in),
896 std::back_inserter(*out),
897 [](NDArray &a) {return &a;});
898 };
899 // sanity checks
900 CHECK_EQ(inputs.size() + 3U, (size_t) params.num_args);
901 CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
902 CHECK_EQ(outputs.size(), req.size());
903 // construct inputs and outputs for cond
904 std::vector<NDArray> cond_inputs;
905 std::vector<NDArray> cond_outputs = {NDArray()};
906 std::vector<NDArray*> cond_input_ptr;
907 std::vector<NDArray*> cond_output_ptr;
908 extract_by_loc(inputs, params.cond_input_locs, &cond_inputs);
909 to_ptr_vec(cond_inputs, &cond_input_ptr);
910 to_ptr_vec(cond_outputs, &cond_output_ptr);
911 int &branch_selection = state.branch_selection;
912 // run cond
913 state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr);
914 branch_selection = as_bool_scalar(*cond_output_ptr[0]);
915 // select the right branch
916 const mxnet::Tuple<dim_t> &func_input_locs = branch_selection
917 ? params.then_input_locs
918 : params.else_input_locs;
919 LoopState &loop_state = branch_selection
920 ? state.then_branch
921 : state.else_branch;
922 // extract inputs for the branch
923 std::vector<NDArray> func_inputs;
924 extract_by_loc(inputs, func_input_locs, &func_inputs);
925 loop_state.Forward(0, func_inputs, req, outputs, ctx.need_grad);
926 }
927
CondGradComputeExCPU(const OpStatePtr & state_ptr,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & _req,const std::vector<NDArray> & outputs)928 static void CondGradComputeExCPU(const OpStatePtr& state_ptr,
929 const OpContext& ctx,
930 const std::vector<NDArray>& inputs,
931 const std::vector<OpReqType>& _req,
932 const std::vector<NDArray>& outputs) {
933 CondState &state = state_ptr.get_state<CondState>();
934 const CondParam& params = state.params;
935 // sanity checks
936 CHECK_EQ(outputs.size() + 3U, (size_t) params.num_args);
937 CHECK_EQ(outputs.size(), _req.size());
938 // select the right branch
939 int branch_selection = state.branch_selection;
940 CHECK_NE(branch_selection, -1);
941 const mxnet::Tuple<dim_t> &func_input_locs = branch_selection
942 ? params.then_input_locs
943 : params.else_input_locs;
944 LoopState &loop_state = branch_selection
945 ? state.then_branch
946 : state.else_branch;
947 // construct parameters
948 std::vector<NDArray> ograds(inputs.begin(), inputs.begin() + params.num_outputs);
949 std::vector<OpReqType> req;
950 extract_by_loc(_req, func_input_locs, &req);
951 std::vector<NDArray> igrads;
952 extract_by_loc(outputs, func_input_locs, &igrads);
953 loop_state.Backward(0, ograds, req, igrads);
954 loop_state.Cleanup();
955 }
956
CondType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_type,std::vector<int> * out_type)957 static bool CondType(const nnvm::NodeAttrs& attrs,
958 std::vector<int> *in_type,
959 std::vector<int> *out_type) {
960 const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
961 static const std::function<bool(const int &)> is_udf = is_type_udf;
962 CHECK_EQ(in_type->size() + 3U, (size_t) params.num_args);
963 CHECK_EQ(out_type->size(), (size_t) params.num_outputs);
964 CHECK_EQ(attrs.subgraphs.size(), 3U);
965 CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
966 CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size());
967 std::vector<int> cond_in_type;
968 std::vector<int> then_in_type;
969 std::vector<int> else_in_type;
970 extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type);
971 extract_by_loc(*in_type, params.then_input_locs, &then_in_type);
972 extract_by_loc(*in_type, params.else_input_locs, &else_in_type);
973 std::vector<int> cond_out_type = {0};
974 bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type);
975 CHECK(sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf));
976 bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &then_in_type, out_type);
977 CHECK(sync_in_in(params.then_input_locs, in_type, &then_in_type, is_udf));
978 bool succ_2 = InferSubgraphDataType(*attrs.subgraphs[2], &else_in_type, out_type);
979 CHECK(sync_in_in(params.else_input_locs, in_type, &else_in_type, is_udf));
980 return succ_0 && succ_1 && succ_2;
981 }
982
CondStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)983 static bool CondStorageType(const nnvm::NodeAttrs& attrs,
984 const int dev_mask,
985 DispatchMode* dispatch_mode,
986 std::vector<int> *in_attrs,
987 std::vector<int> *out_attrs) {
988 const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
989 static const std::function<bool(const int &)> is_udf = is_stype_udf;
990 CHECK_EQ(in_attrs->size() + 3U, (size_t) params.num_args);
991 CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs);
992 CHECK_EQ(attrs.subgraphs.size(), 3U);
993 CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
994 CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size());
995 std::vector<int> cond_in_attrs;
996 std::vector<int> then_in_attrs;
997 std::vector<int> else_in_attrs;
998 extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs);
999 extract_by_loc(*in_attrs, params.then_input_locs, &then_in_attrs);
1000 extract_by_loc(*in_attrs, params.else_input_locs, &else_in_attrs);
1001 std::vector<int> cond_out_attrs = {kDefaultStorage};
1002 DispatchMode cond_mode = DispatchMode::kUndefined;
1003 DispatchMode then_mode = DispatchMode::kUndefined;
1004 DispatchMode else_mode = DispatchMode::kUndefined;
1005 *dispatch_mode = DispatchMode::kFComputeEx;
1006 bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \
1007 &cond_mode, &cond_in_attrs, &cond_out_attrs);
1008 CHECK(sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf));
1009 bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \
1010 &then_mode, &then_in_attrs, out_attrs);
1011 CHECK(sync_in_in(params.then_input_locs, in_attrs, &then_in_attrs, is_udf));
1012 bool succ_2 = InferSubgraphStorage(*attrs.subgraphs[2], dev_mask, \
1013 &else_mode, &else_in_attrs, out_attrs);
1014 CHECK(sync_in_in(params.else_input_locs, in_attrs, &else_in_attrs, is_udf));
1015 return succ_0 && succ_1 && succ_2;
1016 }
1017
BackwardCondStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)1018 static bool BackwardCondStorageType(const nnvm::NodeAttrs& attrs,
1019 const int dev_mask,
1020 DispatchMode* dispatch_mode,
1021 std::vector<int> *in_attrs,
1022 std::vector<int> *out_attrs) {
1023 const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
1024 CHECK_EQ(out_attrs->size() + 3U, (size_t) params.num_args);
1025 CHECK_EQ(attrs.subgraphs.size(), 3U);
1026 static const std::function<bool(const int &)> is_udf = is_stype_udf;
1027 auto sub_pass = [&](const std::shared_ptr<Symbol> &subg, const mxnet::Tuple<dim_t> &input_locs) {
1028 // A. first construct subg_in_attrs
1029 // need subg_in_attrs as subg_bwd_out (copy), subg_fwd_in (extract), subg_fwd_out (copy)
1030 std::vector<int> subg_in_attrs;
1031 size_t num_elts = params.num_outputs * 2 + input_locs.ndim();
1032 subg_in_attrs.reserve(num_elts);
1033 // part 1. subg_bwd_out (copy)
1034 subg_in_attrs.insert(subg_in_attrs.end(),
1035 in_attrs->begin(),
1036 in_attrs->begin() + params.num_outputs);
1037 // part 2. subg_fwd_in (extract)
1038 std::vector<int> fwd_in(in_attrs->begin() + params.num_outputs,
1039 in_attrs->begin() + params.num_outputs + params.num_args - 3);
1040 std::vector<int> subg_fwd_in;
1041 extract_by_loc(fwd_in, input_locs, &subg_fwd_in);
1042 subg_in_attrs.insert(subg_in_attrs.end(),
1043 subg_fwd_in.begin(),
1044 subg_fwd_in.end());
1045 // part 3. subg_fwd_out (copy)
1046 subg_in_attrs.insert(subg_in_attrs.end(),
1047 in_attrs->begin() + params.num_outputs + params.num_args - 3,
1048 in_attrs->end());
1049 // check correctness of the number of elements
1050 CHECK_EQ(subg_in_attrs.size(), num_elts);
1051 // B. then we construct subg_out_attrs by extracting from out_attrs
1052 std::vector<int> subg_out_attrs;
1053 extract_by_loc(*out_attrs, input_locs, &subg_out_attrs);
1054 // then we construct the subgraph and do inference
1055 CachedOp op(*subg, {});
1056 bool ret = op.BackwardStorageType(attrs, dev_mask, dispatch_mode, \
1057 &subg_in_attrs, &subg_out_attrs);
1058 CHECK(sync_in_in(input_locs, out_attrs, &subg_out_attrs, is_udf));
1059 return ret;
1060 };
1061 for (const dim_t &cond_in : params.cond_input_locs) {
1062 (*out_attrs)[cond_in] = kDefaultStorage;
1063 }
1064 bool succ_0 = sub_pass(attrs.subgraphs[1], params.then_input_locs);
1065 bool succ_1 = sub_pass(attrs.subgraphs[2], params.else_input_locs);
1066 return succ_0 && succ_1;
1067 }
1068
CreateCondState(const NodeAttrs & attrs,Context ctx,const mxnet::ShapeVector & ishape,const std::vector<int> & itype)1069 static OpStatePtr CreateCondState(const NodeAttrs& attrs,
1070 Context ctx,
1071 const mxnet::ShapeVector& ishape,
1072 const std::vector<int>& itype) {
1073 const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
1074 return OpStatePtr::Create<CondState>(
1075 params,
1076 *attrs.subgraphs[0],
1077 *attrs.subgraphs[1],
1078 *attrs.subgraphs[2]);
1079 }
1080
1081 static std::vector<nnvm::NodeEntry>
CondGradient(const nnvm::ObjectPtr & n,const std::vector<nnvm::NodeEntry> & ograds)1082 CondGradient(const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
1083 ElemwiseGradUseInOut fgrad{"_backward_cond"};
1084 std::vector<nnvm::NodeEntry> entries = fgrad(n, ograds);
1085 entries[0].node->attrs.subgraphs = n->attrs.subgraphs;
1086 return entries;
1087 }
1088
// Register the forward foreach operator. Input 0 ("fn") is the loop-body
// symbol; the remaining num_args - 1 inputs are data arrays and states.
NNVM_REGISTER_OP(_foreach)
.MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation")
.set_attr_parser(ParamParser<ForeachParam>)
.set_attr<FInferStorageType>("FInferStorageType", ForeachStorageType)
// Input/output counts come from the parsed parameters, not a fixed number.
.set_num_inputs([](const NodeAttrs& attrs) {
  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
  return params.num_args;
})
.set_num_outputs([](const NodeAttrs& attrs) {
  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
  return params.num_outputs;
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
    [](const NodeAttrs& attrs) {
  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
  std::vector<std::string> names;
  // The subgraph symbol comes first, then the data arrays.
  names.emplace_back("fn");
  for (int i = 0; i < params.num_args - 1; i++)
    names.push_back("data" + std::to_string(i));
  return names;
})
// Input 0 is a graph (symbol) rather than an NDArray.
.set_attr<nnvm::FInputGraph>("FInputGraph",
    [](const NodeAttrs& attrs) {
  return std::vector<uint32_t>{0};
})
.set_attr<nnvm::FGradient>("FGradient", ForeachGradient)
.set_attr<FCreateOpState>("FCreateOpState", CreateForeachState)
.set_attr<mxnet::FInferShape>("FInferShape", ForeachShape)
.set_attr<nnvm::FInferType>("FInferType", ForeachType)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", ForeachComputeExCPU)
// Foreach operator works like an executor. Its code will always run on CPU.
// So the same code can be registered for both CPU and GPU.
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", ForeachComputeExCPU)
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
  return ExecType::kSubgraphExec;
})
.set_attr<std::string>("key_var_num_args", "num_args")
.add_argument("fn", "Symbol", "Input graph.")
.add_argument("data", "NDArray-or-Symbol[]",
              "The input arrays that include data arrays and states.")
.add_arguments(ForeachParam::__FIELDS__());

// Register the backward foreach operator. Its inputs are the output
// gradients plus the forward inputs and outputs (ElemwiseGradUseInOut
// layout); its outputs are the gradients of the forward data inputs.
NNVM_REGISTER_OP(_backward_foreach)
.set_num_inputs([](const NodeAttrs& attrs){
  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
  return params.num_outputs * 2 + params.num_args - 1;
})
.set_num_outputs([](const NodeAttrs& attrs){
  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
  return params.num_args - 1;
})
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
  return ExecType::kSubgraphExec;
})
.set_attr<FInferStorageType>("FInferStorageType", BackwardForeachStorageType)
.set_attr_parser(ParamParser<ForeachParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", ForeachGradComputeExCPU)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", ForeachGradComputeExCPU);
1149
1150 NNVM_REGISTER_OP(_while_loop)
1151 .MXNET_DESCRIBE("Run a while loop over with user-defined condition and computation")
1152 .set_attr_parser(ParamParser<WhileLoopParam>)
1153 .set_attr<FInferStorageType>("FInferStorageType", WhileLoopStorageType)
__anona7e01d100f02(const NodeAttrs& attrs) 1154 .set_num_inputs([](const NodeAttrs& attrs) {
1155 const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
1156 return params.num_args;
1157 })
__anona7e01d101002(const NodeAttrs& attrs) 1158 .set_num_outputs([](const NodeAttrs& attrs) {
1159 const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
1160 return params.num_outputs;
1161 })
1162 .set_attr<nnvm::FListInputNames>("FListInputNames",
__anona7e01d101102(const NodeAttrs& attrs) 1163 [](const NodeAttrs& attrs) {
1164 const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
1165 std::vector<std::string> names;
1166 names.reserve(params.num_args);
1167 names.emplace_back("cond");
1168 names.emplace_back("func");
1169 for (int i = 2; i < params.num_args; i++)
1170 names.push_back("data" + std::to_string(i - 2));
1171 return names;
1172 })
1173 .set_attr<nnvm::FInputGraph>("FInputGraph",
__anona7e01d101202(const NodeAttrs& attrs) 1174 [](const NodeAttrs& attrs) {
1175 return std::vector<uint32_t>{0, 1};
1176 })
1177 .set_attr<nnvm::FGradient>("FGradient", WhileLoopGradient)
1178 .set_attr<FCreateOpState>("FCreateOpState", CreateWhileLoopState)
1179 .set_attr<nnvm::FInferType>("FInferType", WhileLoopType)
1180 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", WhileLoopComputeExCPU)
__anona7e01d101302(const NodeAttrs& attrs) 1181 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
1182 return ExecType::kSubgraphExec;
1183 })
1184 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", WhileLoopComputeExCPU)
1185 .set_attr<std::string>("key_var_num_args", "num_args")
1186 .add_argument("cond", "Symbol", "Input graph for the loop condition.")
1187 .add_argument("func", "Symbol", "Input graph for the loop body.")
1188 .add_argument("data", "NDArray-or-Symbol[]",
1189 "The input arrays that include data arrays and states.")
1190 .add_arguments(WhileLoopParam::__FIELDS__());
1191
1192 NNVM_REGISTER_OP(_backward_while_loop)
__anona7e01d101402(const NodeAttrs& attrs)1193 .set_num_inputs([](const NodeAttrs& attrs){
1194 const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
1195 return params.num_outputs * 2 + params.num_args - 2;
1196 })
__anona7e01d101502(const NodeAttrs& attrs)1197 .set_num_outputs([](const NodeAttrs& attrs){
1198 const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
1199 return params.num_args - 2;
1200 })
__anona7e01d101602(const NodeAttrs& attrs) 1201 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
1202 return ExecType::kSubgraphExec;
1203 })
1204 .set_attr<FInferStorageType>("FInferStorageType", BackwardWhileLoopStorageType)
1205 .set_attr_parser(ParamParser<WhileLoopParam>)
1206 .set_attr<bool>("TIsLayerOpBackward", true)
1207 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
1208 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", WhileLoopGradComputeExCPU)
1209 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", WhileLoopGradComputeExCPU);
1210
1211 NNVM_REGISTER_OP(_cond)
1212 .MXNET_DESCRIBE("Run a if-then-else using user-defined condition and computation")
1213 .set_attr_parser(ParamParser<CondParam>)
1214 .set_attr<FInferStorageType>("FInferStorageType", CondStorageType)
__anona7e01d101702(const NodeAttrs& attrs) 1215 .set_num_inputs([](const NodeAttrs& attrs) {
1216 const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
1217 return params.num_args;
1218 })
__anona7e01d101802(const NodeAttrs& attrs) 1219 .set_num_outputs([](const NodeAttrs& attrs) {
1220 const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
1221 return params.num_outputs;
1222 })
1223 .set_attr<nnvm::FListInputNames>("FListInputNames",
__anona7e01d101902(const NodeAttrs& attrs) 1224 [](const NodeAttrs& attrs) {
1225 const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
1226 std::vector<std::string> names;
1227 names.reserve(params.num_args);
1228 names.emplace_back("cond");
1229 names.emplace_back("then_branch");
1230 names.emplace_back("else_branch");
1231 for (int i = 3; i < params.num_args; ++i)
1232 names.push_back("data" + std::to_string(i - 3));
1233 return names;
1234 })
1235 .set_attr<nnvm::FInputGraph>("FInputGraph",
__anona7e01d101a02(const NodeAttrs& attrs) 1236 [](const NodeAttrs& attrs) {
1237 return std::vector<uint32_t>{0, 1, 2};
1238 })
1239 .set_attr<nnvm::FGradient>("FGradient", CondGradient)
1240 .set_attr<FCreateOpState>("FCreateOpState", CreateCondState)
1241 .set_attr<nnvm::FInferType>("FInferType", CondType)
1242 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CondComputeExCPU)
__anona7e01d101b02(const NodeAttrs& attrs) 1243 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
1244 return ExecType::kSubgraphExec;
1245 })
1246 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CondComputeExCPU)
1247 .set_attr<std::string>("key_var_num_args", "num_args")
1248 .add_argument("cond", "Symbol", "Input graph for the condition.")
1249 .add_argument("then_branch", "Symbol", "Input graph for the then branch.")
1250 .add_argument("else_branch", "Symbol", "Input graph for the else branch.")
1251 .add_argument("data", "NDArray-or-Symbol[]",
1252 "The input arrays that include data arrays and states.")
1253 .add_arguments(CondParam::__FIELDS__());
1254
1255 NNVM_REGISTER_OP(_backward_cond)
__anona7e01d101c02(const NodeAttrs& attrs)1256 .set_num_inputs([](const NodeAttrs& attrs){
1257 const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
1258 return params.num_outputs * 2 + params.num_args - 3;
1259 })
__anona7e01d101d02(const NodeAttrs& attrs)1260 .set_num_outputs([](const NodeAttrs& attrs){
1261 const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
1262 return params.num_args - 3;
1263 })
__anona7e01d101e02(const NodeAttrs& attrs) 1264 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
1265 return ExecType::kSubgraphExec;
1266 })
1267 .set_attr<FInferStorageType>("FInferStorageType", BackwardCondStorageType)
1268 .set_attr_parser(ParamParser<CondParam>)
1269 .set_attr<bool>("TIsLayerOpBackward", true)
1270 .set_attr<nnvm::TIsBackward>("TIsBackward", true)
1271 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CondGradComputeExCPU)
1272 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CondGradComputeExCPU);
1273 } // namespace op
1274 } // namespace mxnet
1275