1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file legacy_op_util.cc
22  * \brief Utility to adapt OpProperty to the new NNVM registery
23  */
24 #include <dmlc/base.h>
25 #include <dmlc/strtonum.h>
26 #include <mxnet/base.h>
27 #include <mxnet/operator.h>
28 #include <mxnet/op_attr_types.h>
29 #include <mxnet/ndarray.h>
30 #include <nnvm/node.h>
31 #include <nnvm/graph.h>
32 #include <memory>
33 
34 namespace mxnet {
35 namespace op {
36 
37 using nnvm::Op;
38 using nnvm::Node;
39 using nnvm::ObjectPtr;
40 using nnvm::NodeAttrs;
41 using nnvm::NodeEntry;
42 
43 class ParsedOpProp {
44  public:
45   std::shared_ptr<OperatorProperty> ptr;
46   std::vector<std::string> arguments;
47   std::vector<std::string> aux_states;
48   std::vector<std::string> inputs;
49   std::vector<std::string> outputs;
50   // initializer
Init(const NodeAttrs & attrs)51   void Init(const NodeAttrs& attrs) {
52     // For performance, do a reserve first and then copy attrs.dict
53     std::vector<std::pair<std::string, std::string> > kwargs;
54     kwargs.reserve(attrs.dict.size());
55     kwargs.insert(kwargs.end(), attrs.dict.begin(), attrs.dict.end());
56     try {
57       ptr->Init(kwargs);
58     } catch (const dmlc::ParamError& e) {
59       std::ostringstream os;
60       os << e.what();
61       os << ", in operator " << attrs.op->name << "("
62          << "name=\"" << attrs.name << "\"";
63       for (const auto& k : attrs.dict) {
64         os << ", " << k.first << "=\"" << k.second << "\"";
65       }
66       os << ")";
67       throw dmlc::ParamError(os.str());
68     }
69     arguments = ptr->ListArguments();
70     aux_states = ptr->ListAuxiliaryStates();
71     outputs = ptr->ListOutputs();
72     inputs = arguments;
73     inputs.insert(
74         inputs.end(), aux_states.begin(), aux_states.end());
75   }
76 };
77 
// Stateful adapter holding a legacy Operator plus the blob scratch space
// needed to call its Forward/Backward with NNVM-style flat input lists.
class OperatorState {
 public:
  // Takes ownership of `opr` (deleted in the destructor). `prop` is used
  // only during construction, to size the blob vectors and to compute the
  // backward-input layout via BackwardInputs().
  OperatorState(Operator *opr, const OperatorProperty *prop) {
    opr_ = opr;

    in_data_fwd_.resize(prop->ListArguments().size());
    in_data_bwd_.resize(prop->ListArguments().size());
    out_data_.resize(prop->NumOutputs());
    aux_data_.resize(prop->ListAuxiliaryStates().size());
    in_grad_.resize(in_data_fwd_.size());
    out_grad_.resize(prop->NumVisibleOutputs());

    std::vector<TBlob*> out_grad_ptr(out_grad_.size());
    for (size_t i = 0; i < out_grad_.size(); ++i) {
      out_grad_ptr[i] = &out_grad_[i];
    }
    // NOTE: the "in data" pointers deliberately alias in_data_bwd_ (not
    // in_data_fwd_): Backward() later overwrites exactly these slots with
    // the blobs it receives — see the comments in Forward()/Backward().
    std::vector<TBlob*> in_data_ptr(in_data_fwd_.size());
    for (size_t i = 0; i < in_data_fwd_.size(); ++i) {
      in_data_ptr[i] = &in_data_bwd_[i];
    }
    std::vector<TBlob*> out_data_ptr(out_data_.size());
    for (size_t i = 0; i < out_data_.size(); ++i) {
      out_data_ptr[i] = &out_data_[i];
    }
    // arg_data_ptr_ holds, in backward-input order, pointers into
    // out_grad_ / in_data_bwd_ / out_data_ as selected by the operator's
    // declared backward dependencies (via BackwardInputs).
    arg_data_ptr_ = prop->BackwardInputs(
        out_grad_ptr, in_data_ptr, out_data_ptr);
  }

  ~OperatorState() { delete opr_; }

  // Copy the flat NNVM input/output lists into the per-role blob vectors
  // and invoke the legacy Forward. Inputs are laid out as
  // [arguments..., aux_states...].
  void Forward(const OpContext &ctx,
               const std::vector<TBlob>& inputs,
               const std::vector<OpReqType>& req,
               const std::vector<TBlob>& outputs) {
    CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
    CHECK_EQ(outputs.size(), out_data_.size());
    // in_data_bwd_ has the same tblobs as the ones in in_data_fwd_, except that the ones
    // referred by arg_data_ptr_ will be overriden
    for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i];
    for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i];
    for (size_t i = 0; i < aux_data_.size(); ++i) {
      aux_data_[i] = inputs[i + in_data_fwd_.size()];
    }
    for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
    opr_->Forward(ctx, in_data_fwd_, req, out_data_, aux_data_);
  }

  // Scatter the flat backward inputs through arg_data_ptr_ (which aliases
  // out_grad_/in_data_bwd_/out_data_) and invoke the legacy Backward.
  // Inputs are laid out as [declared backward deps..., aux_states...].
  void Backward(const OpContext &ctx,
                const std::vector<TBlob>& inputs,
                const std::vector<OpReqType>& req,
                const std::vector<TBlob>& outputs) {
    CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size());
    // override tblobs pointed by arg_data_ptr_ since they might not contain
    // initialized data during forward pass.
    for (size_t i = 0; i < arg_data_ptr_.size(); ++i) {
      *arg_data_ptr_[i] = inputs[i];
    }
    for (size_t i = 0; i < aux_data_.size(); ++i) {
      aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i];
    }
    CHECK_EQ(outputs.size(), in_grad_.size());
    for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i];
    opr_->Backward(ctx, out_grad_, in_data_bwd_, out_data_, req, in_grad_, aux_data_);
  }

 private:
  // Owned legacy operator, released in the destructor.
  Operator *opr_;
  // input data blobs for forward and backward
  // in_data_fwd_ and in_data_bwd_ will hold different tblobs when StorageFallbackOpExecutor
  // performs storage fallback on a non-default input NDArray. The one in in_data_fwd_ is
  // generated when setting up forward executor, while the one in in_data_bwd_ is generated
  // when setting up backward executor.
  std::vector<TBlob> in_data_fwd_, in_data_bwd_;
  std::vector<TBlob> aux_data_, out_data_, in_grad_, out_grad_;
  // Pointers into the vectors above, ordered as backward inputs.
  std::vector<TBlob*> arg_data_ptr_;
};
154 
LegacyOpForward(const OpStatePtr & state,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)155 void LegacyOpForward(const OpStatePtr& state,
156                      const OpContext& ctx,
157                      const std::vector<TBlob>& inputs,
158                      const std::vector<OpReqType>& req,
159                      const std::vector<TBlob>& outputs) {
160   auto& op = state.get_state<OperatorState>();
161   op.Forward(ctx, inputs, req, outputs);
162 }
163 
LegacyOpBackward(const OpStatePtr & state,const OpContext & ctx,const std::vector<TBlob> & inputs,const std::vector<OpReqType> & req,const std::vector<TBlob> & outputs)164 void LegacyOpBackward(const OpStatePtr& state,
165                       const OpContext& ctx,
166                       const std::vector<TBlob>& inputs,
167                       const std::vector<OpReqType>& req,
168                       const std::vector<TBlob>& outputs) {
169   auto& op = state.get_state<OperatorState>();
170   op.Backward(ctx, inputs, req, outputs);
171 }
172 
173 // function to use operator property to infer attr
174 // get op property from the attribute
OpPropGetOpProperty(const NodeAttrs & attrs)175 const OperatorProperty* OpPropGetOpProperty(const NodeAttrs& attrs) {
176   return nnvm::get<ParsedOpProp>(attrs.parsed).ptr.get();
177 }
178 
179 template<typename AttrType, typename FInfer>
OpPropInferAttr(const NodeAttrs & attrs,std::vector<AttrType> * iattr,std::vector<AttrType> * oattr,FInfer finfer)180 bool OpPropInferAttr(const NodeAttrs& attrs,
181                      std::vector<AttrType> *iattr,
182                      std::vector<AttrType> *oattr,
183                      FInfer finfer) {
184   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
185   CHECK_EQ(prop.inputs.size(), iattr->size())
186       << "op=" << attrs.op->name
187       << ", inputs.size=" << prop.inputs.size()
188       << ", iattr.size=" << iattr->size()
189       << ", arg.size=" << prop.arguments.size();
190   std::vector<AttrType> in_attr(prop.arguments.size());
191   std::vector<AttrType> aux_attr(prop.aux_states.size());
192 
193   for (size_t i = 0; i < prop.arguments.size(); ++i) {
194     in_attr[i] = (*iattr)[i];
195   }
196   for (size_t i = 0; i < prop.aux_states.size(); ++i) {
197     aux_attr[i] = (*iattr)[i + prop.arguments.size()];
198   }
199   if (!finfer(prop.ptr.get(), &in_attr, oattr, &aux_attr)) return false;
200 
201   for (size_t i = 0; i < prop.arguments.size(); ++i) {
202     (*iattr)[i] = in_attr[i];
203   }
204   for (size_t i = 0; i < prop.aux_states.size(); ++i) {
205     (*iattr)[i + prop.arguments.size()] = aux_attr[i];
206   }
207   return true;
208 }
209 
OpPropInferShape(const NodeAttrs & attrs,mxnet::ShapeVector * iattr,mxnet::ShapeVector * oattr)210 bool OpPropInferShape(const NodeAttrs& attrs,
211                       mxnet::ShapeVector *iattr,
212                       mxnet::ShapeVector *oattr) {
213   auto finfer = [](const OperatorProperty* op,
214                    mxnet::ShapeVector *in,
215                    mxnet::ShapeVector *out,
216                    mxnet::ShapeVector *aux) {
217     return op->InferShape(in, out, aux);
218   };
219   return OpPropInferAttr(attrs, iattr, oattr, finfer);
220 }
221 
OpPropInferType(const NodeAttrs & attrs,std::vector<int> * iattr,std::vector<int> * oattr)222 bool OpPropInferType(const NodeAttrs& attrs,
223                       std::vector<int> *iattr,
224                       std::vector<int> *oattr) {
225   auto finfer = [](const OperatorProperty* op,
226                    std::vector<int> *in,
227                    std::vector<int> *out,
228                    std::vector<int> *aux) {
229     return op->InferType(in, out, aux);
230   };
231   return OpPropInferAttr(attrs, iattr, oattr, finfer);
232 }
233 
OpPropNumInputs(const NodeAttrs & attrs)234 inline uint32_t OpPropNumInputs(const NodeAttrs& attrs) {
235   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
236   return static_cast<uint32_t>(prop.inputs.size());
237 }
238 
OpPropNumOutputs(const NodeAttrs & attrs)239 inline uint32_t OpPropNumOutputs(const NodeAttrs& attrs) {
240   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
241   return static_cast<uint32_t>(prop.outputs.size());
242 }
243 
OpPropNumVisibleOutputs(const NodeAttrs & attrs)244 inline uint32_t OpPropNumVisibleOutputs(const NodeAttrs& attrs) {
245   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
246   return static_cast<uint32_t>(prop.ptr->NumVisibleOutputs());
247 }
248 
OpPropListInputNames(const NodeAttrs & attrs)249 std::vector<std::string> OpPropListInputNames(const NodeAttrs& attrs) {
250   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
251   return prop.inputs;
252 }
253 
OpPropListOutputNames(const NodeAttrs & attrs)254 std::vector<std::string> OpPropListOutputNames(const NodeAttrs& attrs) {
255   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
256   return prop.outputs;
257 }
258 
OpPropMutateInputs(const NodeAttrs & attrs)259 std::vector<uint32_t> OpPropMutateInputs(const NodeAttrs& attrs) {
260   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
261   std::vector<uint32_t> ret;
262   for (uint32_t i = 0; i < prop.aux_states.size(); ++i) {
263     ret.push_back(static_cast<uint32_t>(i + prop.arguments.size()));
264   }
265   return ret;
266 }
267 
OpPropInplaceOption(const NodeAttrs & attrs)268 std::vector<std::pair<int, int> > OpPropInplaceOption(const NodeAttrs& attrs) {
269   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
270   std::vector<int> in_data(prop.arguments.size());
271   std::vector<int> out_data(prop.outputs.size());
272   std::vector<void*> out_addr(prop.outputs.size());
273   for (size_t i = 0; i < in_data.size(); ++i) {
274     in_data[i] = static_cast<int>(i);
275   }
276   for (size_t i = 0; i < out_data.size(); ++i) {
277     out_data[i] = static_cast<int>(i);
278     out_addr[i] = &out_data[i];
279   }
280   std::vector<std::pair<int, int> > forward_inplace;
281   for (auto& kv : prop.ptr->ForwardInplaceOption(in_data, out_addr)) {
282     forward_inplace.emplace_back(kv.first, *static_cast<int*>(kv.second));
283   }
284   return forward_inplace;
285 }
286 
OpPropResourceRequest(const NodeAttrs & attrs)287 std::vector<ResourceRequest> OpPropResourceRequest(const NodeAttrs& attrs) {
288   mxnet::ShapeVector ishape;
289   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
290   return prop.ptr->ForwardResource(ishape);
291 }
292 
OpBackResourceRequest(const NodeAttrs & attrs)293 std::vector<ResourceRequest> OpBackResourceRequest(const NodeAttrs& attrs) {
294   mxnet::ShapeVector ishape;
295   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
296   return prop.ptr->BackwardResource(ishape);
297 }
298 
OpPropCreateLayerOp(const NodeAttrs & attrs,Context ctx,const mxnet::ShapeVector & ishape,const std::vector<int> & itype)299 OpStatePtr OpPropCreateLayerOp(const NodeAttrs& attrs,
300                                Context ctx,
301                                const mxnet::ShapeVector& ishape,
302                                const std::vector<int>& itype) {
303   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
304   mxnet::ShapeVector is(ishape.begin(), ishape.begin() + prop.arguments.size());
305   std::vector<int> it(itype.begin(), itype.begin() + prop.arguments.size());
306   return OpStatePtr::Create<OperatorState>(prop.ptr->CreateOperatorEx(ctx, &is, &it),
307                                            prop.ptr.get());
308 }
309 
// Build the backward node for a legacy op and return one gradient entry
// per forward input, ordered [arguments..., aux_states...]; aux states
// get a _NoGradient placeholder instead of a real gradient.
inline std::vector<NodeEntry> OpPropGradient(
    const Op* back_op,
    const ObjectPtr& ptr,
    const std::vector<NodeEntry>& out_grads) {
  auto& prop = nnvm::get<ParsedOpProp>(ptr->attrs.parsed);
  // Entries referring to the forward node's own outputs.
  std::vector<NodeEntry> out_data;
  out_data.reserve(prop.outputs.size());
  for (size_t i = 0; i < prop.outputs.size(); ++i)
    out_data.emplace_back(ptr, i, 0);

  // Forward argument inputs only; aux states are appended separately below.
  std::vector<NodeEntry> in_data(
      ptr->inputs.begin(), ptr->inputs.begin() + prop.arguments.size());
  // Gradients of the visible outputs only.
  std::vector<NodeEntry> ograd(
      out_grads.begin(), out_grads.begin() + prop.ptr->NumVisibleOutputs());
  // Select/order the backward inputs as the operator declares them.
  auto inputs = prop.ptr->BackwardInputs(ograd, in_data, out_data);
  // add all the auxiliary data
  for (size_t i = 0; i < prop.aux_states.size(); ++i) {
    inputs.emplace_back(ptr->inputs[i + prop.arguments.size()]);
  }
  ObjectPtr gnode = Node::Create();
  gnode->inputs = std::move(inputs);
  // Record the forward node as a control dependency of the backward node.
  gnode->control_deps.emplace_back(ptr);
  // Reuse the forward attrs (including the parsed prop) under the backward op.
  gnode->attrs = ptr->attrs;
  gnode->attrs.op = back_op;
  gnode->attrs.name = ptr->attrs.name + "_backward";
  std::vector<NodeEntry> in_grad;
  in_grad.reserve(prop.arguments.size() + prop.aux_states.size());
  for (size_t i = 0; i < prop.arguments.size(); ++i) {
    in_grad.emplace_back(gnode, i, 0);
  }
  // attach no gradient node to forbid gradient on aux_state
  if (prop.aux_states.size() != 0) {
    for (size_t i = 0; i < prop.aux_states.size(); ++i) {
      in_grad.emplace_back(Node::Create(Op::Get("_NoGradient"), "NoGradient"), 0, 0);
    }
  }
  return in_grad;
}
348 
OpBackNumOutputs(const NodeAttrs & attrs)349 inline uint32_t OpBackNumOutputs(const NodeAttrs& attrs) {
350   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
351   return static_cast<uint32_t>(prop.arguments.size());
352 }
353 
OpBackListOutputNames(const NodeAttrs & attrs)354 std::vector<std::string> OpBackListOutputNames(const NodeAttrs& attrs) {
355   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
356   return prop.arguments;
357 }
358 
OpBackMutateInputs(const NodeAttrs & attrs)359 std::vector<uint32_t> OpBackMutateInputs(const NodeAttrs& attrs) {
360   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
361   if (prop.aux_states.size() == 0) return std::vector<uint32_t>{};
362   std::vector<int> out_grad_index(prop.ptr->NumVisibleOutputs());
363   std::vector<int> in_data_index(prop.arguments.size());
364   std::vector<int> out_data_index(prop.outputs.size());
365   size_t arg_size = prop.ptr->DeclareBackwardDependency(
366       out_grad_index, in_data_index, out_data_index).size();
367   std::vector<uint32_t> ret;
368   for (uint32_t i = 0; i < prop.aux_states.size(); ++i) {
369     ret.push_back(static_cast<uint32_t>(i + arg_size));
370   }
371   return ret;
372 }
373 
OpBackInplaceOption(const NodeAttrs & attrs)374 std::vector<std::pair<int, int> > OpBackInplaceOption(const NodeAttrs& attrs) {
375   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
376   std::vector<int> out_grad_index(prop.ptr->NumVisibleOutputs());
377   std::vector<int> in_data_index(prop.arguments.size());
378   std::vector<int> out_data_index(prop.outputs.size());
379 
380   int counter = 0;
381   for (const int& i : in_data_index) {
382     in_data_index[i] = counter++;
383   }
384   for (const int& i : out_grad_index) {
385     out_grad_index[i] = counter++;
386   }
387   for (const int& i : out_data_index) {
388     out_data_index[i] = counter++;
389   }
390 
391   auto args_index = prop.ptr->DeclareBackwardDependency(
392       out_grad_index, in_data_index, out_data_index);
393   std::vector<int> args_array(counter, -1);
394   for (size_t i = 0; i < args_index.size(); ++i) {
395     args_array[args_index[i]] = static_cast<int>(i);
396   }
397 
398   std::vector<void*> in_grad_ptr(in_data_index.size());
399   for (size_t i = 0; i < in_grad_ptr.size(); ++i) {
400     // in data index starts from 0 to num_inputs
401     in_grad_ptr[i] = (void*)&in_data_index[i];  // NOLINT(*)
402   }
403 
404   auto remap_index = prop.ptr->BackwardInplaceOption(
405       out_grad_index, in_data_index, out_data_index, in_grad_ptr);
406   std::vector<std::pair<int, int> > remap(remap_index.size());
407   for (size_t i = 0; i < remap_index.size(); ++i) {
408     if (args_array[remap_index[i].first] == -1) {
409       LOG(FATAL) << "BackwardInplaceOption not consistent with DeclareBackwardDependency";
410     }
411     remap[i].first = args_array[remap_index[i].first];
412     remap[i].second = *static_cast<int*>(remap_index[i].second);
413   }
414   return remap;
415 }
416 
OpExecType(const NodeAttrs & attrs)417 inline ExecType OpExecType(const NodeAttrs& attrs) {
418   auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
419   return prop.ptr->exec_type();
420 }
421 
// Register every legacy OperatorProperty — and a matching "_backward_*"
// op — under the NNVM op registry, wiring each NNVM attribute to the
// OpProp*/OpBack* adapters above.
void RegisterLegacyOpProp() {
  for (auto reg : dmlc::Registry<OperatorPropertyReg>::List()) {
    Op& op = ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(reg->name);
    // Skip ops already registered through the new interface.
    if (op.attr_parser != nullptr) continue;
    auto creator = reg->body;
    // Parse attributes once per node: instantiate the legacy property,
    // initialize it, and cache the result in attrs->parsed.
    auto attr_parser = [creator](NodeAttrs* attrs) {
      if (attrs->parsed.empty()) {
        ParsedOpProp op;
        op.ptr.reset(creator());
        op.Init(*attrs);
        attrs->parsed = std::move(op);
      }
    };
    op.add_arguments(reg->arguments);
    op.describe(reg->description);
    // attribute parser
    op.set_attr_parser(attr_parser);
    op.set_num_inputs(OpPropNumInputs);
    op.set_num_outputs(OpPropNumOutputs);
    op.set_attr<nnvm::FListInputNames>("FListInputNames", OpPropListInputNames);
    op.set_attr<nnvm::FListOutputNames>("FListOutputNames", OpPropListOutputNames);
    op.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs", OpPropNumVisibleOutputs);
    op.set_attr<mxnet::FInferShape>("FInferShape", OpPropInferShape);
    op.set_attr<nnvm::FInferType>("FInferType", OpPropInferType);
    op.set_attr<nnvm::FMutateInputs>("FMutateInputs", OpPropMutateInputs);
    op.set_attr<nnvm::FInplaceOption>("FInplaceOption", OpPropInplaceOption);
    op.set_attr<FResourceRequest>("FResourceRequest", OpPropResourceRequest);
    op.set_attr<FExecType>("FExecType", OpExecType);
    op.set_attr<FCreateOpState>("FCreateOpState", OpPropCreateLayerOp);
    op.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", LegacyOpForward);
    op.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", LegacyOpForward);
    if (reg->key_var_num_args.length() != 0) {
      op.set_attr<std::string>("key_var_num_args", reg->key_var_num_args);
    }

    // register BackwardOps
    std::string back_op_name = "_backward_" + reg->name;
    Op& back_op = ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER__(back_op_name);
    // FGradient creates nodes of back_op; bind the back_op reference now.
    op.set_attr<nnvm::FGradient>("FGradient", std::bind(
        OpPropGradient, &back_op,
        std::placeholders::_1, std::placeholders::_2));
    // The backward op shares the forward attr parser (same ParsedOpProp).
    back_op.set_attr_parser(attr_parser);
    back_op.set_num_inputs(nnvm::kVarg);
    back_op.set_num_outputs(OpBackNumOutputs);
    back_op.set_attr<nnvm::FListOutputNames>("FListOutputNames", OpBackListOutputNames);
    back_op.set_attr<nnvm::FMutateInputs>("FMutateInputs", OpBackMutateInputs);
    back_op.set_attr<nnvm::FInplaceOption>("FInplaceOption", OpBackInplaceOption);
    back_op.set_attr<FResourceRequest>(
        "FResourceRequest", OpBackResourceRequest);
    back_op.set_attr<bool>("TIsLayerOpBackward", true);
    back_op.set_attr<bool>("TIsBackward", true);
    back_op.set_attr<FExecType>("FExecType", OpExecType);
    back_op.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", LegacyOpBackward);
    back_op.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", LegacyOpBackward);
  }
}
479 
// Placeholder op attached by OpPropGradient to auxiliary states, which
// must not receive gradients. Zero inputs, one output.
NNVM_REGISTER_OP(_NoGradient)
.set_num_inputs(0)
.set_num_outputs(1)
.describe("Place holder for variable who cannot perform gradient");
485 
RegisterLegacyNDFunc()486 void RegisterLegacyNDFunc() {
487   for (auto reg : dmlc::Registry<NDArrayFunctionReg>::List()) {
488     if (reg->type_mask & kScalarArgBeforeNDArray) continue;
489     Op& op = ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(reg->name);
490     if (op.attr_parser != nullptr) continue;
491 
492     CHECK_LE(reg->num_scalars + reg->num_use_vars, reg->arguments.size())
493         << reg->name;
494     auto func = reg->body;
495     op.describe(reg->description);
496     op.add_arguments(reg->arguments);
497     op.set_num_inputs(reg->num_use_vars);
498     op.set_num_outputs(reg->num_mutate_vars);
499     op.set_attr_parser([](NodeAttrs* attrs){});
500     op.set_attr<FNDArrayFunction>("FNDArrayFunction", [reg](const nnvm::NodeAttrs& attrs,
501                                                             const std::vector<NDArray>& inputs,
502                                                             std::vector<NDArray>* outputs) {
503         CHECK_EQ(inputs.size(), reg->num_use_vars);
504         CHECK_EQ(outputs->size(), reg->num_mutate_vars);
505 
506         int n_scalars = reg->num_scalars;
507         std::vector<float> scalars;
508         scalars.reserve(n_scalars);
509         auto dict = attrs.dict;
510         for (int i = 0; i < n_scalars; ++i) {
511           const std::string& name = reg->arguments[i+reg->num_use_vars].name;
512           auto s = dict.find(name);
513           CHECK(s != dict.end()) << "Missing scalar param " << name;
514           scalars.push_back(dmlc::stof(s->second));
515           dict.erase(s);
516         }
517 
518         int n_params = dict.size();
519         std::vector<const char*> keys, vals;
520         keys.reserve(n_params);
521         vals.reserve(n_params);
522         for (auto& i : dict) {
523           keys.push_back(dmlc::BeginPtr(i.first));
524           vals.push_back(dmlc::BeginPtr(i.second));
525         }
526         std::vector<NDArray*> input_ptrs, output_ptrs;
527         for (auto& i : inputs) {
528           input_ptrs.push_back(const_cast<NDArray*>(&i));
529         }
530         for (auto& i : *outputs) {
531           output_ptrs.push_back(&i);
532         }
533         reg->body(input_ptrs.data(),
534                   scalars.data(),
535                   output_ptrs.data(),
536                   n_params,
537                   const_cast<char**>(keys.data()),
538                   const_cast<char**>(vals.data()));
539       });
540   }
541 }
542 
543 }  // namespace op
544 }  // namespace mxnet
545