1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 #include <mxnet/io.h>
21 #include <mxnet/base.h>
22 #include <mxnet/ndarray.h>
23 #include <mxnet/operator.h>
24 #include <mxnet/operator_util.h>
25 #include <dmlc/logging.h>
26 #include <dmlc/optional.h>
27 #include "./operator_common.h"
28 #include "./elemwise_op_common.h"
29 #include "../imperative/imperative_utils.h"
30 #include "./subgraph_op_common.h"
31 
32 namespace mxnet {
33 namespace op {
34 
35 struct ForeachParam : public dmlc::Parameter<ForeachParam> {
36   int num_args;
37   int num_outputs;
38   int num_out_data;
39   // The location of states in the subgraph inputs.
40   mxnet::Tuple<dim_t> in_state_locs;
41   // The location of data arrays in the subgraph inputs.
42   mxnet::Tuple<dim_t> in_data_locs;
43   // The location of remaining arrays in the subgraph inputs.
44   mxnet::Tuple<dim_t> remain_locs;
DMLC_DECLARE_PARAMETERmxnet::op::ForeachParam45   DMLC_DECLARE_PARAMETER(ForeachParam) {
46     DMLC_DECLARE_FIELD(num_args).set_lower_bound(1)
47     .describe("Number of inputs.");
48     DMLC_DECLARE_FIELD(num_outputs)
49     .describe("The number of outputs of the subgraph.");
50     DMLC_DECLARE_FIELD(num_out_data)
51     .describe("The number of output data of the subgraph.");
52     DMLC_DECLARE_FIELD(in_state_locs)
53     .describe("The locations of loop states among the inputs.");
54     DMLC_DECLARE_FIELD(in_data_locs)
55     .describe("The locations of input data among the inputs.");
56     DMLC_DECLARE_FIELD(remain_locs)
57     .describe("The locations of remaining data among the inputs.");
58   }
59 };  // struct ForeachParam
60 
61 DMLC_REGISTER_PARAMETER(ForeachParam);
62 
63 class ForeachState: public LoopState {
64  public:
65   ForeachParam params;
66   int num_iterations;
67 
ForeachState(const Symbol & g,const ForeachParam & params)68   ForeachState(const Symbol &g, const ForeachParam &params) : LoopState(g) {
69     this->params = params;
70   }
71 };
72 
ForeachComputeExCPU(const OpStatePtr & state_ptr,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)73 static void ForeachComputeExCPU(const OpStatePtr& state_ptr,
74                                 const OpContext& ctx,
75                                 const std::vector<NDArray>& inputs,
76                                 const std::vector<OpReqType>& req,
77                                 const std::vector<NDArray>& outputs) {
78   ForeachState &state = state_ptr.get_state<ForeachState>();
79   const ForeachParam& params = state.params;
80   const size_t iter_dim = 0;
81   CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
82   CHECK_GT(params.in_data_locs.ndim(), 0);
83   size_t len = inputs[0].shape()[iter_dim];
84   state.num_iterations = len;
85   for (int i = 1; i < params.in_data_locs.ndim(); i++)
86     CHECK_EQ(inputs[i].shape()[iter_dim], len);
87   for (size_t i = 0; i < (size_t) params.num_out_data; i++)
88     CHECK_EQ(len, outputs[i].shape()[iter_dim]);
89   for (const auto &arr : outputs)
90     CHECK_EQ(arr.storage_type(), kDefaultStorage)
91         << "The for operator doesn't support the sparse format";
92 
93   // Initialize the outputs of the subgraph is a little trickier.
94   // The states from the previous iteration are used as the inputs of the next
95   // iteration, so I have to maintain two arrays, so the inputs and outputs
96   // of the subgraph share the same memory.
97   std::vector<NDArray> subg_outputs1(outputs.size());
98   std::vector<NDArray> subg_outputs2(outputs.size());
99   std::vector<NDArray> *subg_outputs[2]{&subg_outputs1, &subg_outputs2};
100   // If the length is an odd number, the last iteration will use the first set
101   // of outputs. In this way, we don't need to copy the results from the
102   // subgraph to the final outputs of the loop.
103   if (len % 2 == 1) {
104     for (size_t i = params.num_out_data; i < subg_outputs1.size(); i++) {
105       subg_outputs1[i] = outputs[i];
106       subg_outputs2[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true,
107                                  outputs[i].dtype());
108     }
109   } else {
110     // Otherwise, we'll use the second set of outputs.
111     for (size_t i = params.num_out_data; i < subg_outputs1.size(); i++) {
112       subg_outputs1[i] = NDArray(outputs[i].shape(), outputs[i].ctx(), true,
113                                  outputs[i].dtype());
114       subg_outputs2[i] = outputs[i];
115     }
116   }
117 
118   // Initialize the inputs for the subgraph.
119   // In each iteration, we need to update the subgraph inputs for input data
120   // and the loop states.
121   std::vector<NDArray> subg_inputs(inputs.size());
122   // The remaining arrays (other than input data and states) only need to be set once.
123   for (int j = 0; j < params.remain_locs.ndim(); j++) {
124     CHECK_LT(params.remain_locs[j], subg_inputs.size());
125     subg_inputs[params.remain_locs[j]] = inputs[j + params.in_data_locs.ndim()
126         + params.in_state_locs.ndim()];
127   }
128 
129   // Here we iterate over the first dimension of the first input array.
130   for (size_t i = 0; i < len; i++) {
131     // Initialize outputs for the subgraph.
132     std::vector<NDArray> *subg_out_curr = subg_outputs[i % 2];
133     std::vector<NDArray> *subg_out_prev = subg_outputs[(i + 1) % 2];
134     for (int j = 0; j < params.num_out_data; j++)
135       (*subg_out_curr)[j] = outputs[j].At(i);
136     // When recording for backward computation, we should make sure
137     // that output arrays are actually different in each iteration.
138     if (ctx.need_grad && i < len - 1) {
139       for (size_t j = params.num_out_data; j < subg_out_curr->size(); j++)
140         (*subg_out_curr)[j] = NDArray(outputs[j].shape(), outputs[j].ctx(),
141                                       true, outputs[j].dtype());
142     } else if (ctx.need_grad && i == len - 1) {
143       // For the last iteration, we need to write data to the output array
144       // directly.
145       for (size_t j = params.num_out_data; j < subg_out_curr->size(); j++)
146         (*subg_out_curr)[j] = outputs[j];
147     }
148 
149     // Initialize inputs for the subgraph.
150     // Get a slice from the input data arrays.
151     for (int j = 0; j < params.in_data_locs.ndim(); j++) {
152       size_t loc = params.in_data_locs[j];
153       subg_inputs[loc] = inputs[j].At(i);
154     }
155     // For the rest of the iterations, the states are the outputs
156     // from the previous iteration.
157     if (i > 0) {
158       for (size_t j = params.num_out_data; j < subg_out_prev->size(); j++) {
159         size_t idx = j - params.num_out_data;
160         CHECK_LT(params.in_state_locs[idx], subg_inputs.size());
161         subg_inputs[params.in_state_locs[idx]] = (*subg_out_prev)[j];
162       }
163     } else {
164       for (int j = 0; j < params.in_state_locs.ndim(); j++) {
165         CHECK_LT(params.in_state_locs[j], subg_inputs.size());
166         subg_inputs[params.in_state_locs[j]] = inputs[j + params.in_data_locs.ndim()];
167       }
168     }
169 
170     state.Forward(i, subg_inputs, req, *subg_out_curr, ctx.need_grad);
171   }
172 }
173 
ForeachGradComputeExCPU(const OpStatePtr & state_ptr,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)174 static void ForeachGradComputeExCPU(const OpStatePtr& state_ptr,
175                                     const OpContext& ctx,
176                                     const std::vector<NDArray>& inputs,
177                                     const std::vector<OpReqType>& req,
178                                     const std::vector<NDArray>& outputs) {
179   ForeachState &state = state_ptr.get_state<ForeachState>();
180   const ForeachParam& params = state.params;
181   CHECK_EQ(outputs.size(), (size_t) params.num_args - 1);
182   CHECK_GT(params.in_data_locs.ndim(), 0);
183   for (const auto &arr : outputs)
184     CHECK_EQ(arr.storage_type(), kDefaultStorage)
185         << "The for operator doesn't support the sparse format";
186   int len = state.num_iterations;
187   size_t num_output_data = params.num_out_data;
188 
189   // In backward computation, we need to run iterations from backwards.
190   std::vector<NDArray> subg_ograds(params.num_outputs);
191   std::vector<NDArray> subg_igrads(outputs.size());
192   for (size_t i = num_output_data; i < subg_ograds.size(); i++)
193     subg_ograds[i] = inputs[i];
194   std::vector<OpReqType> subg_req(req.size());
195   for (auto r : req)
196     CHECK_NE(r, kWriteInplace);
197 
198   // There are three types of arrays in igrads.
199   // * data gradients.
200   // * loop variable gradients.
201   // * remaining variable gradients.
202   // They are in the following order:
203   // [data vars], [loop vars], [remaining vars]
204 
205   // [remaining vars]
206   for (int i = 0; i < params.remain_locs.ndim(); i++) {
207     size_t loc = params.remain_locs[i];
208     size_t orig_loc = i + params.in_data_locs.ndim() + params.in_state_locs.ndim();
209     subg_igrads[loc] = outputs[orig_loc];
210     subg_req[loc] = req[orig_loc];
211   }
212 
213   for (int iter_num = len - 1; iter_num >= 0; iter_num--) {
214     for (int i = 0; i < params.num_out_data; i++)
215       subg_ograds[i] = inputs[i].At(iter_num);
216     if (iter_num < len - 1) {
217       // For the rest of the iterations, we should add graidents to the
218       // remaining vars.
219       for (int i = 0; i < params.remain_locs.ndim(); i++) {
220         size_t loc = params.remain_locs[i];
221         subg_req[loc] = kAddTo;
222       }
223     }
224 
225     // [data vars]
226     for (int i = 0; i < params.in_data_locs.ndim(); i++) {
227       size_t loc = params.in_data_locs[i];
228       subg_igrads[loc] = outputs[i].At(iter_num);
229       subg_req[loc] = req[i];
230     }
231     // [loop vars]
232     for (int i = 0; i < params.in_state_locs.ndim(); i++) {
233       size_t loc = params.in_state_locs[i];
234       const NDArray &output = outputs[i + params.in_data_locs.ndim()];
235       if (iter_num != 0) {
236         // For state gradients, we need to allocate new NDArrays
237         // because intermediate state gradients won't be returned to the users.
238         subg_igrads[loc] = NDArray(output.shape(), output.ctx(), true, output.dtype());
239       } else {
240         subg_igrads[loc] = output;
241       }
242       // For the first iteration, we need to use the request provided by
243       // the user to write state gradients to the outputs.
244       subg_req[loc] = iter_num != 0 ? kWriteTo : req[i + params.in_data_locs.ndim()];
245     }
246     state.Backward(iter_num, subg_ograds, subg_req, subg_igrads);
247 
248     size_t num_states = subg_ograds.size() - num_output_data;
249     for (size_t i = 0; i < num_states; i++) {
250       size_t loc = params.in_state_locs[i];
251       CHECK_LT(loc, subg_igrads.size());
252       subg_ograds[i + num_output_data] = subg_igrads[loc];
253     }
254   }
255   state.Cleanup();
256 }
257 
258 template<typename T>
remap(const std::vector<T> & op_in,size_t start,const mxnet::Tuple<dim_t> & locs,std::vector<T> * subg_in)259 static void remap(const std::vector<T> &op_in, size_t start,
260                   const mxnet::Tuple<dim_t> &locs, std::vector<T> *subg_in) {
261   auto op_in_it = op_in.begin() + start;
262   for (int i = 0; i < locs.ndim(); i++) {
263     dim_t loc = locs[i];
264     subg_in->at(loc) = *(op_in_it + i);
265   }
266 }
267 
// Drops the first (iteration) dimension of a shape; a 1-D shape degenerates
// to the scalar-like shape (1,).
static inline mxnet::TShape SliceFirstDim(const mxnet::TShape &s) {
  if (s.ndim() > 1) {
    return mxnet::TShape(s.begin() + 1, s.end());
  } else {
    return mxnet::TShape(mshadow::Shape1(1));
  }
}
275 
ForeachShape(const nnvm::NodeAttrs & attrs,mxnet::ShapeVector * in_shape,mxnet::ShapeVector * out_shape)276 static bool ForeachShape(const nnvm::NodeAttrs& attrs,
277                          mxnet::ShapeVector *in_shape,
278                          mxnet::ShapeVector *out_shape) {
279   const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
280   CHECK_EQ(out_shape->size(), (size_t) params.num_outputs);
281   CHECK_EQ(attrs.subgraphs.size(), 1U);
282 
283   mxnet::ShapeVector subg_in_shape(in_shape->size());
284   // data shape
285   std::vector<bool> data_1d(params.in_data_locs.ndim(), false);
286   for (int i = 0; i < params.in_data_locs.ndim(); i++) {
287     size_t loc = params.in_data_locs[i];
288     if (in_shape->at(i).ndim() == 1)
289       data_1d[i] = true;
290     subg_in_shape[loc] = SliceFirstDim(in_shape->at(i));
291   }
292   // state shape
293   remap(*in_shape, params.in_data_locs.ndim(), params.in_state_locs,
294         &subg_in_shape);
295   // remaining shape
296   remap(*in_shape, params.in_data_locs.ndim() + params.in_state_locs.ndim(),
297         params.remain_locs, &subg_in_shape);
298 
299   mxnet::ShapeVector subg_out_shape = *out_shape;
300   for (int i = 0; i < params.num_out_data; i++) {
301     mxnet::TShape shape = subg_out_shape[i];
302     // If we don't have shape info, we don't need to do anything.
303     if (!mxnet::ndim_is_known(shape))
304       continue;
305     subg_out_shape[i] = SliceFirstDim(shape);
306   }
307 
308   bool infer_success = InferSubgraphShape(*attrs.subgraphs[0],
309                                           &subg_in_shape, &subg_out_shape);
310 
311   // After inference, we need to move inferred information back to in_shape and
312   // out_shape.
313 
314   // For the shape of output data.
315   size_t len = in_shape->at(0)[0];
316   for (int i = 0; i < params.num_out_data; i++) {
317     // If the output shape isn't inferred, we don't need to propogate the info.
318     const auto& g_out_shape = subg_out_shape[i];
319     if (!mxnet::ndim_is_known(g_out_shape))
320       continue;
321 
322     auto out = mxnet::TShape(g_out_shape.ndim() + 1, -1);
323     out[0] = len;
324     for (int i = 1; i < out.ndim(); i++)
325       out[i] = g_out_shape[i - 1];
326     SHAPE_ASSIGN_CHECK(*out_shape, i, out);
327   }
328   // For the shape of output states.
329   for (size_t i = params.num_out_data; i < subg_out_shape.size(); i++)
330     SHAPE_ASSIGN_CHECK(*out_shape, i, subg_out_shape[i]);
331 
332   // For the shape of input data.
333   for (int i = 0; i < params.in_data_locs.ndim(); i++) {
334     size_t loc = params.in_data_locs[i];
335     const auto &shape = subg_in_shape[loc];
336     // If the input data shape isn't inferred, we don't need to propogate the
337     // info.
338     if (!mxnet::ndim_is_known(shape))
339       continue;
340 
341     if (data_1d[i]) {
342       mxnet::TShape s(1, -1);
343       s[0] = len;
344       SHAPE_ASSIGN_CHECK(*in_shape, i, s);
345     } else {
346       auto in = mxnet::TShape(shape.ndim() + 1, -1);
347       in[0] = len;
348       for (int i = 1; i < in.ndim(); i++)
349         in[i] = shape[i - 1];
350       SHAPE_ASSIGN_CHECK(*in_shape, i, in);
351     }
352   }
353   // For the shape of state.
354   for (int i = 0; i < params.in_state_locs.ndim(); i++) {
355     size_t loc = params.in_state_locs[i];
356     SHAPE_ASSIGN_CHECK(*in_shape, i + params.in_data_locs.ndim(),
357                        subg_in_shape[loc]);
358   }
359   // For the shape of remaining data.
360   for (int i = 0; i < params.remain_locs.ndim(); i++) {
361     size_t loc = params.remain_locs[i];
362     SHAPE_ASSIGN_CHECK(*in_shape,
363                        i + params.in_data_locs.ndim() + params.in_state_locs.ndim(),
364                        subg_in_shape[loc]);
365   }
366 
367   if (infer_success) {
368     size_t num_states = out_shape->size() - params.num_out_data;
369     for (size_t i = 0; i < num_states; i++) {
370       CHECK_EQ((*out_shape)[i + params.num_out_data],
371                (*in_shape)[i + params.in_data_locs.ndim()]);
372     }
373   }
374   // Check if we have inferred the shapes correctly.
375   return infer_success;
376 }
377 
ForeachType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_type,std::vector<int> * out_type)378 static bool ForeachType(const nnvm::NodeAttrs& attrs,
379                         std::vector<int> *in_type, std::vector<int> *out_type) {
380   const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
381   CHECK_EQ(out_type->size(), (size_t) params.num_outputs);
382   CHECK_EQ(attrs.subgraphs.size(), 1U);
383   std::vector<int> subg_in_type(in_type->size(), 0);
384   remap(*in_type, 0, params.in_data_locs, &subg_in_type);
385   remap(*in_type, params.in_data_locs.ndim(), params.in_state_locs, &subg_in_type);
386   remap(*in_type, params.in_data_locs.ndim() + params.in_state_locs.ndim(),
387         params.remain_locs, &subg_in_type);
388   bool success = InferSubgraphDataType(*attrs.subgraphs[0], &subg_in_type, out_type);
389   for (int i = 0; i < params.in_data_locs.ndim(); i++) {
390     size_t loc = params.in_data_locs[i];
391     TYPE_ASSIGN_CHECK(*in_type, i, subg_in_type[loc]);
392   }
393   for (int i = 0; i < params.in_state_locs.ndim(); i++) {
394     size_t loc = params.in_state_locs[i];
395     TYPE_ASSIGN_CHECK(*in_type, i + params.in_data_locs.ndim(), subg_in_type[loc]);
396   }
397   for (int i = 0; i < params.remain_locs.ndim(); i++) {
398     size_t loc = params.remain_locs[i];
399     TYPE_ASSIGN_CHECK(*in_type, i + params.in_data_locs.ndim() + params.in_state_locs.ndim(),
400                       subg_in_type[loc]);
401   }
402   return success;
403 }
404 
ForeachStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)405 static bool ForeachStorageType(const nnvm::NodeAttrs& attrs,
406                                const int dev_mask,
407                                DispatchMode* dispatch_mode,
408                                std::vector<int> *in_attrs,
409                                std::vector<int> *out_attrs) {
410   const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
411   CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs);
412   CHECK_EQ(attrs.subgraphs.size(), 1U);
413   std::vector<int> subg_in_attrs(in_attrs->size(), kUndefinedStorage);
414   remap(*in_attrs, 0, params.in_data_locs, &subg_in_attrs);
415   remap(*in_attrs, params.in_data_locs.ndim(), params.in_state_locs, &subg_in_attrs);
416   remap(*in_attrs, params.in_data_locs.ndim() + params.in_state_locs.ndim(),
417         params.remain_locs, &subg_in_attrs);
418   bool success = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask,
419                                       dispatch_mode, &subg_in_attrs, out_attrs);
420   for (int i = 0; i < params.in_data_locs.ndim(); i++) {
421     size_t loc = params.in_data_locs[i];
422     STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, i, subg_in_attrs[loc]);
423   }
424   for (int i = 0; i < params.in_state_locs.ndim(); i++) {
425     size_t loc = params.in_state_locs[i];
426     STORAGE_TYPE_ASSIGN_CHECK(*in_attrs, i + params.in_data_locs.ndim(),
427                               subg_in_attrs[loc]);
428   }
429   for (int i = 0; i < params.remain_locs.ndim(); i++) {
430     size_t loc = params.remain_locs[i];
431     STORAGE_TYPE_ASSIGN_CHECK(*in_attrs,
432                               i + params.in_data_locs.ndim() + params.in_state_locs.ndim(),
433                               subg_in_attrs[loc]);
434   }
435   return success;
436 }
437 
BackwardForeachStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)438 static bool BackwardForeachStorageType(const nnvm::NodeAttrs& attrs,
439                                        const int dev_mask,
440                                        DispatchMode* dispatch_mode,
441                                        std::vector<int> *in_attrs,
442                                        std::vector<int> *out_attrs) {
443   const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
444   CHECK_EQ(out_attrs->size(), (size_t) params.num_args - 1);
445   CHECK_EQ(in_attrs->size(), (size_t) params.num_args - 1 + params.num_outputs * 2);
446   CHECK_EQ(attrs.subgraphs.size(), 1U);
447   CachedOp op(*attrs.subgraphs[0],
448               std::vector<std::pair<std::string, std::string> >());
449   // map the operator inputs to the subgraph inputs.
450   std::vector<int> subg_forward_ins(params.num_args - 1, kUndefinedStorage);
451   remap(*in_attrs, params.num_outputs, params.in_data_locs, &subg_forward_ins);
452   remap(*in_attrs, params.num_outputs + params.in_data_locs.ndim(),
453         params.in_state_locs, &subg_forward_ins);
454   remap(*in_attrs, params.num_outputs + params.in_data_locs.ndim() + params.in_state_locs.ndim(),
455         params.remain_locs, &subg_forward_ins);
456 
457   // Copy backward input storage to backward subgraph input storage.
458   std::vector<int> subg_in_attrs = *in_attrs;
459   for (size_t i = 0; i < subg_forward_ins.size(); i++)
460     subg_in_attrs[i + params.num_outputs] = subg_forward_ins[i];
461   return op.BackwardStorageType(attrs, dev_mask, dispatch_mode,
462                                 &subg_in_attrs, out_attrs);
463 }
464 
CreateForeachState(const NodeAttrs & attrs,Context ctx,const mxnet::ShapeVector & ishape,const std::vector<int> & itype)465 static OpStatePtr CreateForeachState(const NodeAttrs& attrs,
466                                      Context ctx,
467                                      const mxnet::ShapeVector& ishape,
468                                      const std::vector<int>& itype) {
469   const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
470   return OpStatePtr::Create<ForeachState>(*attrs.subgraphs[0], params);
471 }
472 
473 static std::vector<nnvm::NodeEntry>
ForeachGradient(const nnvm::ObjectPtr & n,const std::vector<nnvm::NodeEntry> & ograds)474 ForeachGradient(const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
475   ElemwiseGradUseInOut fgrad{"_backward_foreach"};
476   std::vector<nnvm::NodeEntry> entries = fgrad(n, ograds);
477   entries[0].node->attrs.subgraphs = n->attrs.subgraphs;
478   return entries;
479 }
480 
481 struct WhileLoopParam : public dmlc::Parameter<WhileLoopParam> {
482   int num_args;
483   int num_outputs;
484   int num_out_data;
485   int max_iterations;
486   // `cond' and `func' each takes a subset of while_loop's inputs as that to their subgraphs
487   // `cond_input_locs' contains indices of inputs fed to `cond', and
488   // `func_input_locs' contains indices of inputs fed to `func'.
489   // `func_var_locs' are indices in which input "variables" are stored in func's inputs.
490   mxnet::Tuple<dim_t> cond_input_locs;
491   mxnet::Tuple<dim_t> func_input_locs;
492   mxnet::Tuple<dim_t> func_var_locs;
DMLC_DECLARE_PARAMETERmxnet::op::WhileLoopParam493   DMLC_DECLARE_PARAMETER(WhileLoopParam) {
494     DMLC_DECLARE_FIELD(num_args).set_lower_bound(2)
495     .describe("Number of input arguments, including cond and func as two symbol inputs.");
496     DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1)
497     .describe("The number of outputs of the subgraph.");
498     DMLC_DECLARE_FIELD(num_out_data).set_lower_bound(0)
499     .describe("The number of outputs from the function body.");
500     DMLC_DECLARE_FIELD(max_iterations).set_lower_bound(1)
501     .describe("Maximum number of iterations.");
502     DMLC_DECLARE_FIELD(cond_input_locs)
503     .describe("The locations of cond's inputs in the given inputs.");
504     DMLC_DECLARE_FIELD(func_input_locs)
505     .describe("The locations of func's inputs in the given inputs.");
506     DMLC_DECLARE_FIELD(func_var_locs)
507     .describe("The locations of loop_vars among func's inputs.");
508   }
509   template <typename T>
sync_in_outmxnet::op::WhileLoopParam510   bool sync_in_out(std::vector<T> *in,
511                    std::vector<T> *out,
512                    std::function<bool(const T &)> is_empty) const {
513     for (int i = this->num_out_data; i < this->num_outputs; ++i) {
514       // each out->at(i) is a params, loop_var
515       T &x = in->at(this->func_input_locs[this->func_var_locs[i - this->num_out_data]]);
516       T &y = out->at(i);
517       fill_value(&x, &y, is_empty(x), is_empty(y));
518     }
519     return true;
520   }
521 };  // struct WhileLoopParam
522 
523 DMLC_REGISTER_PARAMETER(WhileLoopParam);
524 
525 class WhileLoopState: public LoopState {
526  public:
527   WhileLoopParam params;
528   size_t n_iterations;  // the actual number of steps taken in this while loop, <= max_iterations
529   CachedOpPtr cond_op;
530   // abbrev for output_input_mapping
531   // indicates to which index the output of `func' will be copied to the input of `cond'
532   std::vector<int> oi_map;
533 
WhileLoopState(const WhileLoopParam & params,const Symbol & cond,const Symbol & func)534   WhileLoopState(const WhileLoopParam &params, const Symbol &cond, const Symbol &func) :
535                  LoopState(func),
536                  params(params),
537                  n_iterations(0U),
538                  cond_op(LoopState::MakeSharedOp(cond)),
539                  oi_map(params.func_var_locs.ndim(), -1) {
540     const mxnet::Tuple<dim_t> &func_input_locs = params.func_input_locs;
541     const mxnet::Tuple<dim_t> &func_var_locs = params.func_var_locs;
542     const mxnet::Tuple<dim_t> &cond_input_locs = params.cond_input_locs;
543     for (int i = 0; i < func_var_locs.ndim(); ++i) {
544       dim_t pos_i = func_input_locs[func_var_locs[i]];
545       for (int j = 0; j < cond_input_locs.ndim(); ++j) {
546         dim_t pos_j = cond_input_locs[j];
547         if (pos_i == pos_j) {
548           this->oi_map[i] = j;
549         }
550       }
551     }
552   }
553 };
554 
WhileLoopComputeExCPU(const OpStatePtr & state_ptr,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)555 static void WhileLoopComputeExCPU(const OpStatePtr& state_ptr,
556                                   const OpContext& ctx,
557                                   const std::vector<NDArray>& inputs,
558                                   const std::vector<OpReqType>& req,
559                                   const std::vector<NDArray>& outputs) {
560   // The argument `inputs' are loop_vars and other inputs
561   // loop_vars are stored in stored in `loop_vars_locs'
562   // The argument `outputs' are output and new_loop_vars
563   // [0: num_out_data) are outputs at each step.
564   // [num_out_data: ) are new_loop_vars
565   // TODO(Junru): avoid dynamic NDArray allocation
566   WhileLoopState &state = state_ptr.get_state<WhileLoopState>();
567   const WhileLoopParam& params = state.params;
568   // a helper function, converting std::vector<NDArray> to std::vector<NDArray*>
569   const auto to_ptr_vec = [](std::vector<NDArray> &in, std::vector<NDArray*> *out) {
570     out->clear();
571     out->reserve(in.size());
572     std::transform(std::begin(in),
573                    std::end(in),
574                    std::back_inserter(*out),
575                    [](NDArray &a) {return &a;});
576   };
577   // sanity checks
578   CHECK_EQ(inputs.size() + 2U, (size_t) params.num_args);
579   CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
580   CHECK_EQ(outputs.size(), req.size());
581   // construct inputs and outputs for cond
582   std::vector<NDArray> cond_inputs, cond_outputs = {NDArray()};
583   extract_by_loc(inputs, params.cond_input_locs, &cond_inputs);
584   std::vector<NDArray*> cond_input_ptr, cond_output_ptr;
585   to_ptr_vec(cond_inputs, &cond_input_ptr);
586   to_ptr_vec(cond_outputs, &cond_output_ptr);
587   // construct inputs and outputs for func
588   std::vector<NDArray> func_inputs, func_outputs(outputs.size());
589   extract_by_loc(inputs, params.func_input_locs, &func_inputs);
590   for (size_t &step = state.n_iterations = 0; step < (size_t) params.max_iterations; ++step) {
591     state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr);
592     if (!as_bool_scalar(*cond_output_ptr[0])) {
593       break;
594     }
595     // we create func_outputs for the current step:
596     for (size_t i = 0; i < outputs.size(); ++i) {
597       func_outputs[i] = NDArray(outputs[i].ctx(), outputs[i].dtype());
598     }
599     state.Forward(step, func_inputs, req, func_outputs, ctx.need_grad);
600     if (step == 0) {
601       for (int i = 0; i < params.num_out_data; ++i) {
602         func_outputs[i].WaitToRead();
603         if (!shape_is_known(func_outputs[i].shape())) {
604           func_outputs[i].SetShapeFromChunk();
605         }
606         mxnet::TShape step_shape = func_outputs[i].shape();
607         mxnet::TShape shape(step_shape.ndim() + 1, 0);
608         shape[0] = params.max_iterations;
609         for (int j = 0; j < step_shape.ndim(); ++j) {
610           shape[j + 1] = step_shape[j];
611         }
612         const_cast<NDArray &>(outputs[i]).Init(shape);
613       }
614     }
615     for (int i = 0; i < params.num_out_data; ++i) {
616       NDArray first_slot = outputs[i].At(step);
617       mxnet::CopyFromTo(func_outputs[i], &first_slot);
618     }
619     // func_inputs on the next step:
620     // the output (new_loop_vars) will become the new inputs (loop_vars)
621     for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
622       int j = params.func_var_locs[i - params.num_out_data];
623       func_inputs[j] = func_outputs[i];
624       int k = state.oi_map[i - params.num_out_data];
625       if (k != -1) {
626         // I actually don't need to update cond_inputs
627         cond_inputs[k] = func_outputs[i];
628         cond_input_ptr[k] = &func_outputs[i];
629       }
630     }
631   }
632   // copy output data to `outputs'
633   // case 1: at least one step is executed,
634   // the final_loop_vars must be stored in func_inputs
635   // case 2: no step is executed
636   // the final_loop_vars is the same as loop_vars, which are also stored in func_inputs
637   // therefore, we copy func_inputs[:] to outputs[num_out_data: ]
638   for (size_t i = params.num_out_data; i < outputs.size(); ++i) {
639     size_t j = params.func_var_locs[i - params.num_out_data];
640     if (!shape_is_known(outputs[i].shape())) {
641       const_cast<NDArray &>(outputs[i]).Init(func_inputs[j].shape());
642     }
643     mxnet::CopyFromTo(func_inputs[j], &outputs[i]);
644   }
645   for (int i = 0; i < params.num_out_data; ++i) {
646     const_cast<NDArray &>(outputs[i]).SetShapeFromChunk();
647   }
648   if (state.n_iterations == 0) {
649     for (size_t i = 0; i < outputs.size(); ++i) {
650       if (!shape_is_known(outputs[i].shape())) {
651         const_cast<NDArray &>(outputs[i]).ReshapeAndAlloc({1});
652       }
653     }
654   }
655 }
656 
// Backward pass of `_while_loop`.
// `inputs` carry the head gradients: per-step slices for the data outputs
// plus the gradients of the final loop variables. `_outputs` receive the
// gradients w.r.t. every forward input; `state` holds the recorded steps.
static void WhileLoopGradComputeExCPU(const OpStatePtr& state_ptr,
                                      const OpContext& ctx,
                                      const std::vector<NDArray>& inputs,
                                      const std::vector<OpReqType>& _req,
                                      const std::vector<NDArray>& _outputs) {
  // inputs are dl / df(x)
  // outputs are dl / dx
  // where f is the current function,
  // x is the input to the current function,
  // TODO(Junru): avoid dynamic NDArray allocation
  WhileLoopState &state = state_ptr.get_state<WhileLoopState>();
  const WhileLoopParam& params = state.params;
  // sanity checks: `_outputs` excludes the two symbol inputs (cond, func)
  CHECK_EQ(_outputs.size() + 2U, (size_t) params.num_args);
  CHECK_EQ(_outputs.size(), _req.size());
  for (auto x : _req) {
    CHECK_NE(x, kWriteInplace);
  }
  // reorder the gradient buffers/reqs into the func subgraph's input order
  std::vector<NDArray> outputs;
  std::vector<OpReqType> req;
  extract_by_loc(_outputs, params.func_input_locs, &outputs);
  extract_by_loc(_req, params.func_input_locs, &req);
  if (state.n_iterations == 0) {
    // the body never ran: loop-variable gradients pass straight through
    for (int i = params.num_out_data; i < params.num_outputs; ++i) {
      int j = params.func_var_locs[i - params.num_out_data];
      mxnet::CopyFromTo(inputs[i], &outputs[j]);
    }
    state.Cleanup();
    return;
  }
  // collect var_locs and out_locs, positions other than var_locs are out_locs, i.e.
  // [0, var_locs[0])
  // (var_locs[1], var_locs[2]),
  // (var_locs[2], var_locs[3]),
  // ...
  // (var_locs[-2], var_locs[-1] = params.num_args - 2)
  std::vector<dim_t> var_locs(params.func_var_locs.begin(), params.func_var_locs.end());
  // sentinel so the scan below always terminates at num_args - 2
  var_locs.push_back((dim_t) params.num_args - 2U);
  sort(var_locs.begin(), var_locs.end());
  // vectors for the backward loop
  std::vector<NDArray> ograds(params.num_outputs);
  std::vector<NDArray> igrads(outputs.size());
  std::vector<OpReqType> iter_req(req.size());
  // seed the loop-variable ograds from the head gradients
  for (int i = params.num_out_data; i < params.num_outputs; ++i)
    ograds[i] = inputs[i];
  const int n_iter = state.n_iterations;
  // replay the recorded steps in reverse order
  for (int step = n_iter - 1; step >= 0; --step) {
    // ograds[ : num_out_data] = inputs[ : num_out_data][step]
    // ograds[num_out_data: ] is maintained in the end of each loop
    std::transform(std::begin(inputs),
                   std::begin(inputs) + params.num_out_data,
                   std::begin(ograds),
                   [step] (const NDArray &a) { return a.At(step); } );
    // igrads[i] =
    //    outputs[i]            (step == 0)
    //    outputs[i]            (step != 0 && i not in loop_var_locs)
    //    ArrayLike(outputs[i]) (step != 0 && i in loop_var_locs)
    // iter_req =
    //    kWriteTo              (step != 0           && i in loop_var_locs)
    //    req[i]                (step == 0           && i in loop_var_locs)
    //    kAddTo                (step != n_iters - 1 && i not in loop_var_locs)
    //    req[i]                (step == n_iters - 1 && i not in loop_var_locs)
    {
      size_t i = 0;
      for (size_t loc : var_locs) {
        for ( ; i < loc; ++i) {
          // locs other than var_locs: gradients accumulate across steps
          igrads[i] = outputs[i];
          iter_req[i] = (step + 1 == n_iter || req[i] == kNullOp)
                      ? req[i]
                      : kAddTo;
        }
        if (i < (size_t) params.num_args - 2U) {
          // a var: intermediate steps write into a scratch array,
          // only step 0 writes into the caller's buffer
          igrads[i] = (step == 0)
                    ? outputs[i]
                    : NDArray(outputs[i].shape(), outputs[i].ctx(), true, outputs[i].dtype());
          iter_req[i] = (step == 0 || req[i] == kNullOp)
                      ? req[i]
                      : kWriteTo;
          ++i;
        } else {
          break;
        }
      }
    }
    state.Backward(step, ograds, iter_req, igrads);
    // chain: this step's var igrads become the previous step's var ograds
    for (int i = params.num_out_data; i < params.num_outputs; ++i) {
      size_t j = params.func_var_locs[i - params.num_out_data];
      ograds[i] = igrads[j];
    }
  }
  state.Cleanup();
}
751 
WhileLoopType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_type,std::vector<int> * out_type)752 static bool WhileLoopType(const nnvm::NodeAttrs& attrs,
753                           std::vector<int> *in_type, std::vector<int> *out_type) {
754   const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
755   static const std::function<bool(const int &)> is_udf = is_type_udf;
756   CHECK_EQ(in_type->size() + 2U, (size_t) params.num_args);
757   CHECK_EQ(out_type->size(), (size_t) params.num_outputs);
758   CHECK_EQ(attrs.subgraphs.size(), 2U);
759   CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
760   std::vector<int> cond_in_type;
761   std::vector<int> func_in_type;
762   extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type);
763   extract_by_loc(*in_type, params.func_input_locs, &func_in_type);
764   std::vector<int> cond_out_type = {-1};
765   CHECK(params.sync_in_out(in_type, out_type, is_udf));
766   bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type);
767   CHECK(params.sync_in_out(in_type, out_type, is_udf));
768   CHECK(sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf));
769   bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &func_in_type, out_type);
770   CHECK(params.sync_in_out(in_type, out_type, is_udf));
771   CHECK(sync_in_in(params.func_input_locs, in_type, &func_in_type, is_udf));
772   return succ_0 && succ_1;
773 }
774 
WhileLoopStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)775 static bool WhileLoopStorageType(const nnvm::NodeAttrs& attrs,
776                                  const int dev_mask,
777                                  DispatchMode* dispatch_mode,
778                                  std::vector<int> *in_attrs,
779                                  std::vector<int> *out_attrs) {
780   const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
781   static const std::function<bool(const int &)> is_udf = is_stype_udf;
782   CHECK_EQ(in_attrs->size() + 2U, (size_t) params.num_args);
783   CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs);
784   CHECK_EQ(attrs.subgraphs.size(), 2U);
785   CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
786   std::vector<int> cond_in_attrs;
787   std::vector<int> func_in_attrs;
788   extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs);
789   extract_by_loc(*in_attrs, params.func_input_locs, &func_in_attrs);
790   std::vector<int> cond_out_attrs = {kDefaultStorage};
791   DispatchMode cond_mode = DispatchMode::kUndefined;
792   DispatchMode func_mode = DispatchMode::kUndefined;
793   *dispatch_mode = DispatchMode::kFComputeEx;
794   CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf));
795   bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \
796                                      &cond_mode, &cond_in_attrs, &cond_out_attrs);
797   CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf));
798   CHECK(sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf));
799   bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \
800                                      &func_mode, &func_in_attrs, out_attrs);
801   CHECK(params.sync_in_out(in_attrs, out_attrs, is_udf));
802   CHECK(sync_in_in(params.func_input_locs, in_attrs, &func_in_attrs, is_udf));
803   return succ_0 && succ_1;
804 }
805 
BackwardWhileLoopStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)806 static bool BackwardWhileLoopStorageType(const nnvm::NodeAttrs& attrs,
807                                          const int dev_mask,
808                                          DispatchMode* dispatch_mode,
809                                          std::vector<int> *in_attrs,
810                                          std::vector<int> *out_attrs) {
811   // `cond' is not backwarded, don't check
812   const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
813   CHECK_EQ(out_attrs->size() + 2U, (size_t) params.num_args);
814   CHECK_EQ(attrs.subgraphs.size(), 2U);
815   CachedOp op(*attrs.subgraphs[1], {});
816   return op.BackwardStorageType(attrs, dev_mask, dispatch_mode,
817                                 in_attrs, out_attrs);
818 }
819 
CreateWhileLoopState(const NodeAttrs & attrs,Context ctx,const mxnet::ShapeVector & ishape,const std::vector<int> & itype)820 static OpStatePtr CreateWhileLoopState(const NodeAttrs& attrs,
821                                        Context ctx,
822                                        const mxnet::ShapeVector& ishape,
823                                        const std::vector<int>& itype) {
824   const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
825   return OpStatePtr::Create<WhileLoopState>(params, *attrs.subgraphs[0], *attrs.subgraphs[1]);
826 }
827 
828 static std::vector<nnvm::NodeEntry>
WhileLoopGradient(const nnvm::ObjectPtr & n,const std::vector<nnvm::NodeEntry> & ograds)829 WhileLoopGradient(const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
830   ElemwiseGradUseInOut fgrad{"_backward_while_loop"};
831   std::vector<nnvm::NodeEntry> entries = fgrad(n, ograds);
832   entries[0].node->attrs.subgraphs = n->attrs.subgraphs;
833   return entries;
834 }
835 
// Parameters of the `_cond` (if-then-else) operator.
struct CondParam : public dmlc::Parameter<CondParam> {
  int num_args;     // total inputs, including the cond/then/else symbols
  int num_outputs;  // number of outputs produced by the chosen branch
  // The location of cond's inputs among the operator's data inputs.
  mxnet::Tuple<dim_t> cond_input_locs;
  // The location of the then-branch's inputs among the data inputs.
  mxnet::Tuple<dim_t> then_input_locs;
  // The location of the else-branch's inputs among the data inputs.
  mxnet::Tuple<dim_t> else_input_locs;
  DMLC_DECLARE_PARAMETER(CondParam) {
    DMLC_DECLARE_FIELD(num_args).set_lower_bound(3)
    .describe("Number of input arguments, including cond, then and else as three symbol inputs.");
    DMLC_DECLARE_FIELD(num_outputs).set_lower_bound(1)
    .describe("The number of outputs of the subgraph.");
    DMLC_DECLARE_FIELD(cond_input_locs)
    .describe("The locations of cond's inputs in the given inputs.");
    DMLC_DECLARE_FIELD(then_input_locs)
    .describe("The locations of then's inputs in the given inputs.");
    DMLC_DECLARE_FIELD(else_input_locs)
    .describe("The locations of else's inputs in the given inputs.");
  }
};  // struct CondParam
855 
856 DMLC_REGISTER_PARAMETER(CondParam);
857 
// Runtime state of `_cond`: the cached condition executor, both branch
// subgraphs, and which branch the last forward pass took (needed so that
// backward replays the same branch).
class CondState {
 public:
  CondParam params;
  CachedOpPtr cond_op;      // cached executor for the condition subgraph
  LoopState then_branch;    // executor state for the then-subgraph
  LoopState else_branch;    // executor state for the else-subgraph
  int branch_selection;  // 1 if then branch; 0 if else branch; -1 if undefined

  CondState(const CondParam &params,
            const Symbol &cond,
            const Symbol &then_sym,
            const Symbol &else_sym):
            params(params),
            cond_op(LoopState::MakeSharedOp(cond)),
            then_branch(then_sym),
            else_branch(else_sym),
            branch_selection(-1) {
  }
};
877 
CondComputeExCPU(const OpStatePtr & state_ptr,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)878 static void CondComputeExCPU(const OpStatePtr& state_ptr,
879                              const OpContext& ctx,
880                              const std::vector<NDArray>& inputs,
881                              const std::vector<OpReqType>& req,
882                              const std::vector<NDArray>& outputs) {
883   // The argument `inputs' are loop_vars and other inputs
884   // loop_vars are stored in stored in `loop_vars_locs'
885   // The argument `outputs' are output and new_loop_vars
886   // [0: num_out_data) are outputs at each step.
887   // [num_out_data: ) are new_loop_vars
888   CondState &state = state_ptr.get_state<CondState>();
889   const CondParam& params = state.params;
890   // a helper function, converting std::vector<NDArray> to std::vector<NDArray*>
891   const auto to_ptr_vec = [](std::vector<NDArray> &in, std::vector<NDArray*> *out) {
892     out->clear();
893     out->reserve(in.size());
894     std::transform(std::begin(in),
895                    std::end(in),
896                    std::back_inserter(*out),
897                    [](NDArray &a) {return &a;});
898   };
899   // sanity checks
900   CHECK_EQ(inputs.size() + 3U, (size_t) params.num_args);
901   CHECK_EQ(outputs.size(), (size_t) params.num_outputs);
902   CHECK_EQ(outputs.size(), req.size());
903   // construct inputs and outputs for cond
904   std::vector<NDArray> cond_inputs;
905   std::vector<NDArray> cond_outputs = {NDArray()};
906   std::vector<NDArray*> cond_input_ptr;
907   std::vector<NDArray*> cond_output_ptr;
908   extract_by_loc(inputs, params.cond_input_locs, &cond_inputs);
909   to_ptr_vec(cond_inputs, &cond_input_ptr);
910   to_ptr_vec(cond_outputs, &cond_output_ptr);
911   int &branch_selection = state.branch_selection;
912   // run cond
913   state.cond_op->Forward(nullptr, cond_input_ptr, cond_output_ptr);
914   branch_selection = as_bool_scalar(*cond_output_ptr[0]);
915   // select the right branch
916   const mxnet::Tuple<dim_t> &func_input_locs = branch_selection
917                                             ? params.then_input_locs
918                                             : params.else_input_locs;
919   LoopState &loop_state = branch_selection
920                         ? state.then_branch
921                         : state.else_branch;
922   // extract inputs for the branch
923   std::vector<NDArray> func_inputs;
924   extract_by_loc(inputs, func_input_locs, &func_inputs);
925   loop_state.Forward(0, func_inputs, req, outputs, ctx.need_grad);
926 }
927 
CondGradComputeExCPU(const OpStatePtr & state_ptr,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & _req,const std::vector<NDArray> & outputs)928 static void CondGradComputeExCPU(const OpStatePtr& state_ptr,
929                                  const OpContext& ctx,
930                                  const std::vector<NDArray>& inputs,
931                                  const std::vector<OpReqType>& _req,
932                                  const std::vector<NDArray>& outputs) {
933   CondState &state = state_ptr.get_state<CondState>();
934   const CondParam& params = state.params;
935   // sanity checks
936   CHECK_EQ(outputs.size() + 3U, (size_t) params.num_args);
937   CHECK_EQ(outputs.size(), _req.size());
938   // select the right branch
939   int branch_selection = state.branch_selection;
940   CHECK_NE(branch_selection, -1);
941   const mxnet::Tuple<dim_t> &func_input_locs = branch_selection
942                                             ? params.then_input_locs
943                                             : params.else_input_locs;
944   LoopState &loop_state = branch_selection
945                         ? state.then_branch
946                         : state.else_branch;
947   // construct parameters
948   std::vector<NDArray> ograds(inputs.begin(), inputs.begin() + params.num_outputs);
949   std::vector<OpReqType> req;
950   extract_by_loc(_req, func_input_locs, &req);
951   std::vector<NDArray> igrads;
952   extract_by_loc(outputs, func_input_locs, &igrads);
953   loop_state.Backward(0, ograds, req, igrads);
954   loop_state.Cleanup();
955 }
956 
CondType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_type,std::vector<int> * out_type)957 static bool CondType(const nnvm::NodeAttrs& attrs,
958                      std::vector<int> *in_type,
959                      std::vector<int> *out_type) {
960   const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
961   static const std::function<bool(const int &)> is_udf = is_type_udf;
962   CHECK_EQ(in_type->size() + 3U, (size_t) params.num_args);
963   CHECK_EQ(out_type->size(), (size_t) params.num_outputs);
964   CHECK_EQ(attrs.subgraphs.size(), 3U);
965   CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
966   CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size());
967   std::vector<int> cond_in_type;
968   std::vector<int> then_in_type;
969   std::vector<int> else_in_type;
970   extract_by_loc(*in_type, params.cond_input_locs, &cond_in_type);
971   extract_by_loc(*in_type, params.then_input_locs, &then_in_type);
972   extract_by_loc(*in_type, params.else_input_locs, &else_in_type);
973   std::vector<int> cond_out_type = {0};
974   bool succ_0 = InferSubgraphDataType(*attrs.subgraphs[0], &cond_in_type, &cond_out_type);
975   CHECK(sync_in_in(params.cond_input_locs, in_type, &cond_in_type, is_udf));
976   bool succ_1 = InferSubgraphDataType(*attrs.subgraphs[1], &then_in_type, out_type);
977   CHECK(sync_in_in(params.then_input_locs, in_type, &then_in_type, is_udf));
978   bool succ_2 = InferSubgraphDataType(*attrs.subgraphs[2], &else_in_type, out_type);
979   CHECK(sync_in_in(params.else_input_locs, in_type, &else_in_type, is_udf));
980   return succ_0 && succ_1 && succ_2;
981 }
982 
CondStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * in_attrs,std::vector<int> * out_attrs)983 static bool CondStorageType(const nnvm::NodeAttrs& attrs,
984                             const int dev_mask,
985                             DispatchMode* dispatch_mode,
986                             std::vector<int> *in_attrs,
987                             std::vector<int> *out_attrs) {
988   const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
989   static const std::function<bool(const int &)> is_udf = is_stype_udf;
990   CHECK_EQ(in_attrs->size() + 3U, (size_t) params.num_args);
991   CHECK_EQ(out_attrs->size(), (size_t) params.num_outputs);
992   CHECK_EQ(attrs.subgraphs.size(), 3U);
993   CHECK_EQ(attrs.subgraphs[0]->outputs.size(), 1U);
994   CHECK_EQ(attrs.subgraphs[1]->outputs.size(), attrs.subgraphs[2]->outputs.size());
995   std::vector<int> cond_in_attrs;
996   std::vector<int> then_in_attrs;
997   std::vector<int> else_in_attrs;
998   extract_by_loc(*in_attrs, params.cond_input_locs, &cond_in_attrs);
999   extract_by_loc(*in_attrs, params.then_input_locs, &then_in_attrs);
1000   extract_by_loc(*in_attrs, params.else_input_locs, &else_in_attrs);
1001   std::vector<int> cond_out_attrs = {kDefaultStorage};
1002   DispatchMode cond_mode = DispatchMode::kUndefined;
1003   DispatchMode then_mode = DispatchMode::kUndefined;
1004   DispatchMode else_mode = DispatchMode::kUndefined;
1005   *dispatch_mode = DispatchMode::kFComputeEx;
1006   bool succ_0 = InferSubgraphStorage(*attrs.subgraphs[0], dev_mask, \
1007                                      &cond_mode, &cond_in_attrs, &cond_out_attrs);
1008   CHECK(sync_in_in(params.cond_input_locs, in_attrs, &cond_in_attrs, is_udf));
1009   bool succ_1 = InferSubgraphStorage(*attrs.subgraphs[1], dev_mask, \
1010                                      &then_mode, &then_in_attrs, out_attrs);
1011   CHECK(sync_in_in(params.then_input_locs, in_attrs, &then_in_attrs, is_udf));
1012   bool succ_2 = InferSubgraphStorage(*attrs.subgraphs[2], dev_mask, \
1013                                      &else_mode, &else_in_attrs, out_attrs);
1014   CHECK(sync_in_in(params.else_input_locs, in_attrs, &else_in_attrs, is_udf));
1015   return succ_0 && succ_1 && succ_2;
1016 }
1017 
// Storage-type inference for the `_cond` backward node.
// Backward inputs are laid out as:
//   [0, num_outputs)                          : output gradients (ograds)
//   [num_outputs, num_outputs + num_args - 3) : forward inputs
//   [num_outputs + num_args - 3, end)         : forward outputs
// Gradients may flow through either branch, so both are inferred.
static bool BackwardCondStorageType(const nnvm::NodeAttrs& attrs,
                                    const int dev_mask,
                                    DispatchMode* dispatch_mode,
                                    std::vector<int> *in_attrs,
                                    std::vector<int> *out_attrs) {
  // `cond' is not backwarded, don't check
  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
  CHECK_EQ(out_attrs->size() + 3U, (size_t) params.num_args);
  CHECK_EQ(attrs.subgraphs.size(), 3U);
  static const std::function<bool(const int &)> is_udf = is_stype_udf;
  // Run backward storage inference for one branch subgraph, reading the
  // slices of `in_attrs` relevant to that branch and writing the inferred
  // stypes back into `out_attrs` at `input_locs`.
  auto sub_pass = [&](const std::shared_ptr<Symbol> &subg, const mxnet::Tuple<dim_t> &input_locs) {
    // A. first construct subg_in_attrs
    // need subg_in_attrs as subg_bwd_out (copy), subg_fwd_in (extract), subg_fwd_out (copy)
    std::vector<int> subg_in_attrs;
    size_t num_elts = params.num_outputs * 2 + input_locs.ndim();
    subg_in_attrs.reserve(num_elts);
    // part 1. subg_bwd_out (copy)
    subg_in_attrs.insert(subg_in_attrs.end(),
                         in_attrs->begin(),
                         in_attrs->begin() + params.num_outputs);
    // part 2. subg_fwd_in (extract)
    std::vector<int> fwd_in(in_attrs->begin() + params.num_outputs,
                            in_attrs->begin() + params.num_outputs + params.num_args - 3);
    std::vector<int> subg_fwd_in;
    extract_by_loc(fwd_in, input_locs, &subg_fwd_in);
    subg_in_attrs.insert(subg_in_attrs.end(),
                         subg_fwd_in.begin(),
                         subg_fwd_in.end());
    // part 3. subg_fwd_out (copy)
    subg_in_attrs.insert(subg_in_attrs.end(),
                         in_attrs->begin() + params.num_outputs + params.num_args - 3,
                         in_attrs->end());
    // check correctness of the number of elements
    CHECK_EQ(subg_in_attrs.size(), num_elts);
    // B. then we construct subg_out_attrs by extracting from out_attrs
    std::vector<int> subg_out_attrs;
    extract_by_loc(*out_attrs, input_locs, &subg_out_attrs);
    // then we construct the subgraph and do inference
    CachedOp op(*subg, {});
    bool ret = op.BackwardStorageType(attrs, dev_mask, dispatch_mode, \
                                      &subg_in_attrs, &subg_out_attrs);
    CHECK(sync_in_in(input_locs, out_attrs, &subg_out_attrs, is_udf));
    return ret;
  };
  // cond's inputs never receive sparse gradients through this node
  for (const dim_t &cond_in : params.cond_input_locs) {
    (*out_attrs)[cond_in] = kDefaultStorage;
  }
  bool succ_0 = sub_pass(attrs.subgraphs[1], params.then_input_locs);
  bool succ_1 = sub_pass(attrs.subgraphs[2], params.else_input_locs);
  return succ_0 && succ_1;
}
1068 
CreateCondState(const NodeAttrs & attrs,Context ctx,const mxnet::ShapeVector & ishape,const std::vector<int> & itype)1069 static OpStatePtr CreateCondState(const NodeAttrs& attrs,
1070                                   Context ctx,
1071                                   const mxnet::ShapeVector& ishape,
1072                                   const std::vector<int>& itype) {
1073   const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
1074   return OpStatePtr::Create<CondState>(
1075     params,
1076     *attrs.subgraphs[0],
1077     *attrs.subgraphs[1],
1078     *attrs.subgraphs[2]);
1079 }
1080 
1081 static std::vector<nnvm::NodeEntry>
CondGradient(const nnvm::ObjectPtr & n,const std::vector<nnvm::NodeEntry> & ograds)1082 CondGradient(const nnvm::ObjectPtr& n, const std::vector<nnvm::NodeEntry>& ograds) {
1083   ElemwiseGradUseInOut fgrad{"_backward_cond"};
1084   std::vector<nnvm::NodeEntry> entries = fgrad(n, ograds);
1085   entries[0].node->attrs.subgraphs = n->attrs.subgraphs;
1086   return entries;
1087 }
1088 
// Registration of the `_foreach` control-flow operator.
NNVM_REGISTER_OP(_foreach)
.MXNET_DESCRIBE("Run a for loop over an NDArray with user-defined computation")
.set_attr_parser(ParamParser<ForeachParam>)
.set_attr<FInferStorageType>("FInferStorageType", ForeachStorageType)
.set_num_inputs([](const NodeAttrs& attrs) {
  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
  return params.num_args;
})
.set_num_outputs([](const NodeAttrs& attrs) {
  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
  return params.num_outputs;
})
.set_attr<nnvm::FListInputNames>("FListInputNames",
    [](const NodeAttrs& attrs) {
  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
  std::vector<std::string> names;
  // input 0 is the subgraph symbol; the remaining inputs are data/state arrays
  names.emplace_back("fn");
  for (int i = 0; i < params.num_args - 1; i++)
    names.push_back("data" + std::to_string(i));
  return names;
})
.set_attr<nnvm::FInputGraph>("FInputGraph",
    [](const NodeAttrs& attrs) {
  // input 0 (`fn`) is a graph-valued input
  return std::vector<uint32_t>{0};
})
.set_attr<nnvm::FGradient>("FGradient", ForeachGradient)
.set_attr<FCreateOpState>("FCreateOpState", CreateForeachState)
.set_attr<mxnet::FInferShape>("FInferShape", ForeachShape)
.set_attr<nnvm::FInferType>("FInferType", ForeachType)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", ForeachComputeExCPU)
// Foreach operator works like an executor. Its code will always run on CPU.
// So the same code can be registered for both CPU and GPU.
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", ForeachComputeExCPU)
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
  return ExecType::kSubgraphExec;
})
.set_attr<std::string>("key_var_num_args", "num_args")
.add_argument("fn", "Symbol", "Input graph.")
.add_argument("data", "NDArray-or-Symbol[]",
              "The input arrays that include data arrays and states.")
.add_arguments(ForeachParam::__FIELDS__());
1130 
// Registration of the backward node for `_foreach`.
NNVM_REGISTER_OP(_backward_foreach)
.set_num_inputs([](const NodeAttrs& attrs){
  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
  // ograds + forward outputs + forward inputs (minus the subgraph symbol)
  return params.num_outputs * 2 + params.num_args - 1;
})
.set_num_outputs([](const NodeAttrs& attrs){
  const ForeachParam& params = nnvm::get<ForeachParam>(attrs.parsed);
  // one gradient per forward data input (the subgraph symbol gets none)
  return params.num_args - 1;
})
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
  return ExecType::kSubgraphExec;
})
.set_attr<FInferStorageType>("FInferStorageType", BackwardForeachStorageType)
.set_attr_parser(ParamParser<ForeachParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", ForeachGradComputeExCPU)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", ForeachGradComputeExCPU);
1149 
1150 NNVM_REGISTER_OP(_while_loop)
1151 .MXNET_DESCRIBE("Run a while loop over with user-defined condition and computation")
1152 .set_attr_parser(ParamParser<WhileLoopParam>)
1153 .set_attr<FInferStorageType>("FInferStorageType", WhileLoopStorageType)
__anona7e01d100f02(const NodeAttrs& attrs) 1154 .set_num_inputs([](const NodeAttrs& attrs) {
1155   const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
1156   return params.num_args;
1157 })
__anona7e01d101002(const NodeAttrs& attrs) 1158 .set_num_outputs([](const NodeAttrs& attrs) {
1159   const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
1160   return params.num_outputs;
1161 })
1162 .set_attr<nnvm::FListInputNames>("FListInputNames",
__anona7e01d101102(const NodeAttrs& attrs) 1163     [](const NodeAttrs& attrs) {
1164   const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
1165   std::vector<std::string> names;
1166   names.reserve(params.num_args);
1167   names.emplace_back("cond");
1168   names.emplace_back("func");
1169   for (int i = 2; i < params.num_args; i++)
1170     names.push_back("data" + std::to_string(i - 2));
1171   return names;
1172 })
1173 .set_attr<nnvm::FInputGraph>("FInputGraph",
__anona7e01d101202(const NodeAttrs& attrs) 1174     [](const NodeAttrs& attrs) {
1175   return std::vector<uint32_t>{0, 1};
1176 })
1177 .set_attr<nnvm::FGradient>("FGradient", WhileLoopGradient)
1178 .set_attr<FCreateOpState>("FCreateOpState", CreateWhileLoopState)
1179 .set_attr<nnvm::FInferType>("FInferType", WhileLoopType)
1180 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", WhileLoopComputeExCPU)
__anona7e01d101302(const NodeAttrs& attrs) 1181 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
1182   return ExecType::kSubgraphExec;
1183 })
1184 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", WhileLoopComputeExCPU)
1185 .set_attr<std::string>("key_var_num_args", "num_args")
1186 .add_argument("cond", "Symbol", "Input graph for the loop condition.")
1187 .add_argument("func", "Symbol", "Input graph for the loop body.")
1188 .add_argument("data", "NDArray-or-Symbol[]",
1189               "The input arrays that include data arrays and states.")
1190 .add_arguments(WhileLoopParam::__FIELDS__());
1191 
// Registration of the backward node for `_while_loop`.
NNVM_REGISTER_OP(_backward_while_loop)
.set_num_inputs([](const NodeAttrs& attrs){
  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
  // ograds + forward outputs + forward inputs (minus the two subgraph symbols)
  return params.num_outputs * 2 + params.num_args - 2;
})
.set_num_outputs([](const NodeAttrs& attrs){
  const WhileLoopParam& params = nnvm::get<WhileLoopParam>(attrs.parsed);
  // one gradient per forward data input (cond/func symbols get none)
  return params.num_args - 2;
})
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
  return ExecType::kSubgraphExec;
})
.set_attr<FInferStorageType>("FInferStorageType", BackwardWhileLoopStorageType)
.set_attr_parser(ParamParser<WhileLoopParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", WhileLoopGradComputeExCPU)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", WhileLoopGradComputeExCPU);
1210 
1211 NNVM_REGISTER_OP(_cond)
1212 .MXNET_DESCRIBE("Run a if-then-else using user-defined condition and computation")
1213 .set_attr_parser(ParamParser<CondParam>)
1214 .set_attr<FInferStorageType>("FInferStorageType", CondStorageType)
__anona7e01d101702(const NodeAttrs& attrs) 1215 .set_num_inputs([](const NodeAttrs& attrs) {
1216   const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
1217   return params.num_args;
1218 })
__anona7e01d101802(const NodeAttrs& attrs) 1219 .set_num_outputs([](const NodeAttrs& attrs) {
1220   const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
1221   return params.num_outputs;
1222 })
1223 .set_attr<nnvm::FListInputNames>("FListInputNames",
__anona7e01d101902(const NodeAttrs& attrs) 1224     [](const NodeAttrs& attrs) {
1225   const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
1226   std::vector<std::string> names;
1227   names.reserve(params.num_args);
1228   names.emplace_back("cond");
1229   names.emplace_back("then_branch");
1230   names.emplace_back("else_branch");
1231   for (int i = 3; i < params.num_args; ++i)
1232     names.push_back("data" + std::to_string(i - 3));
1233   return names;
1234 })
1235 .set_attr<nnvm::FInputGraph>("FInputGraph",
__anona7e01d101a02(const NodeAttrs& attrs) 1236     [](const NodeAttrs& attrs) {
1237   return std::vector<uint32_t>{0, 1, 2};
1238 })
1239 .set_attr<nnvm::FGradient>("FGradient", CondGradient)
1240 .set_attr<FCreateOpState>("FCreateOpState", CreateCondState)
1241 .set_attr<nnvm::FInferType>("FInferType", CondType)
1242 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CondComputeExCPU)
__anona7e01d101b02(const NodeAttrs& attrs) 1243 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
1244   return ExecType::kSubgraphExec;
1245 })
1246 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CondComputeExCPU)
1247 .set_attr<std::string>("key_var_num_args", "num_args")
1248 .add_argument("cond", "Symbol", "Input graph for the condition.")
1249 .add_argument("then_branch", "Symbol", "Input graph for the then branch.")
1250 .add_argument("else_branch", "Symbol", "Input graph for the else branch.")
1251 .add_argument("data", "NDArray-or-Symbol[]",
1252               "The input arrays that include data arrays and states.")
1253 .add_arguments(CondParam::__FIELDS__());
1254 
// Registration of the backward node for `_cond`.
NNVM_REGISTER_OP(_backward_cond)
.set_num_inputs([](const NodeAttrs& attrs){
  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
  // ograds + forward outputs + forward inputs (minus the three subgraph symbols)
  return params.num_outputs * 2 + params.num_args - 3;
})
.set_num_outputs([](const NodeAttrs& attrs){
  const CondParam& params = nnvm::get<CondParam>(attrs.parsed);
  // one gradient per forward data input (the three symbols get none)
  return params.num_args - 3;
})
.set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
  return ExecType::kSubgraphExec;
})
.set_attr<FInferStorageType>("FInferStorageType", BackwardCondStorageType)
.set_attr_parser(ParamParser<CondParam>)
.set_attr<bool>("TIsLayerOpBackward", true)
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", CondGradComputeExCPU)
.set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", CondGradComputeExCPU);
1273 }  // namespace op
1274 }  // namespace mxnet
1275