/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file legacy_op_util.cc
 * \brief Utility to adapt OpProperty to the new NNVM registry
 */
#include <dmlc/base.h>
#include <dmlc/strtonum.h>
#include <mxnet/base.h>
#include <mxnet/operator.h>
#include <mxnet/op_attr_types.h>
#include <mxnet/ndarray.h>
#include <nnvm/node.h>
#include <nnvm/graph.h>
#include <memory>

namespace mxnet {
namespace op {

using nnvm::Op;
using nnvm::Node;
using nnvm::ObjectPtr;
using nnvm::NodeAttrs;
using nnvm::NodeEntry;

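// Parsed form of a legacy OperatorProperty, cached in NodeAttrs::parsed so
// that the argument/aux/input/output name lists are computed once per node
// rather than on every attribute query.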
class ParsedOpProp {
 public:
  std::shared_ptr<OperatorProperty> ptr;
  std::vector<std::string> arguments;
  std::vector<std::string> aux_states;
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
  // initializer
  void Init(const NodeAttrs& attrs) {
    // For performance, do a reserve first and then copy attrs.dict
    std::vector<std::pair<std::string, std::string> > kwargs;
    kwargs.reserve(attrs.dict.size());
    kwargs.insert(kwargs.end(), attrs.dict.begin(), attrs.dict.end());
    try {
      ptr->Init(kwargs);
    } catch (const dmlc::ParamError& e) {
      std::ostringstream os;
      os << e.what();
      os << ", in operator " << attrs.op->name << "("
         << "name=\"" << attrs.name << "\"";
      for (const auto& k : attrs.dict) {
        os << ", " << k.first << "=\"" << k.second << "\"";
      }
      os << ")";
      throw dmlc::ParamError(os.str());
    }
    arguments = ptr->ListArguments();
    aux_states = ptr->ListAuxiliaryStates();
    outputs = ptr->ListOutputs();
    inputs = arguments;
    inputs.insert(
        inputs.end(), aux_states.begin(), aux_states.end());
  }
};

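// Stateful adapter that owns a legacy Operator instance plus the TBlob
// buffers needed to route FStatefulCompute inputs/outputs into the old
// Forward/Backward calling convention.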
class OperatorState {
 public:
  OperatorState(Operator *opr, const OperatorProperty *prop) {
    opr_ = opr;

    in_data_fwd_.resize(prop->ListArguments().size());
    in_data_bwd_.resize(prop->ListArguments().size());
    out_data_.resize(prop->NumOutputs());
    aux_data_.resize(prop->ListAuxiliaryStates().size());
    in_grad_.resize(in_data_fwd_.size());
    out_grad_.resize(prop->NumVisibleOutputs());

    std::vector<TBlob*> out_grad_ptr(out_grad_.size());
    for (size_t i = 0; i < out_grad_.size(); ++i) {
      out_grad_ptr[i] = &out_grad_[i];
    }
    std::vector<TBlob*> in_data_ptr(in_data_fwd_.size());
    for (size_t i = 0; i < in_data_fwd_.size(); ++i) {
      in_data_ptr[i] = &in_data_bwd_[i];
    }
    std::vector<TBlob*> out_data_ptr(out_data_.size());
    for (size_t i = 0; i < out_data_.size(); ++i) {
      out_data_ptr[i] = &out_data_[i];
    }
    arg_data_ptr_ = prop->BackwardInputs(
        out_grad_ptr, in_data_ptr, out_data_ptr);
  }

  ~OperatorState() { delete opr_; }

  void Forward(const OpContext &ctx,
               const std::vector<TBlob>& inputs,
               const std::vector<OpReqType>& req,
               const std::vector<TBlob>& outputs) {
    CHECK_EQ(inputs.size(), in_data_fwd_.size() + aux_data_.size());
    CHECK_EQ(outputs.size(), out_data_.size());
    // in_data_bwd_ holds the same tblobs as in_data_fwd_, except that the
    // ones referred to by arg_data_ptr_ will be overridden
    for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_fwd_[i] = inputs[i];
    for (size_t i = 0; i < in_data_fwd_.size(); ++i) in_data_bwd_[i] = inputs[i];
    for (size_t i = 0; i < aux_data_.size(); ++i) {
      aux_data_[i] = inputs[i + in_data_fwd_.size()];
    }
    for (size_t i = 0; i < out_data_.size(); ++i) out_data_[i] = outputs[i];
    opr_->Forward(ctx, in_data_fwd_, req, out_data_, aux_data_);
  }

  void Backward(const OpContext &ctx,
                const std::vector<TBlob>& inputs,
                const std::vector<OpReqType>& req,
                const std::vector<TBlob>& outputs) {
    CHECK_EQ(arg_data_ptr_.size() + aux_data_.size(), inputs.size());
    // overwrite the tblobs pointed to by arg_data_ptr_ since they might not
    // contain initialized data from the forward pass.
    for (size_t i = 0; i < arg_data_ptr_.size(); ++i) {
      *arg_data_ptr_[i] = inputs[i];
    }
    for (size_t i = 0; i < aux_data_.size(); ++i) {
      aux_data_[i] = inputs[inputs.size() - aux_data_.size() + i];
    }
    CHECK_EQ(outputs.size(), in_grad_.size());
    for (size_t i = 0; i < outputs.size(); ++i) in_grad_[i] = outputs[i];
    opr_->Backward(ctx, out_grad_, in_data_bwd_, out_data_, req, in_grad_, aux_data_);
  }

 private:
  Operator *opr_;
  // input data blobs for forward and backward
  // in_data_fwd_ and in_data_bwd_ will hold different tblobs when
  // StorageFallbackOpExecutor performs storage fallback on a non-default
  // input NDArray. The one in in_data_fwd_ is generated when setting up the
  // forward executor, while the one in in_data_bwd_ is generated when
  // setting up the backward executor.
  std::vector<TBlob> in_data_fwd_, in_data_bwd_;
  std::vector<TBlob> aux_data_, out_data_, in_grad_, out_grad_;
  std::vector<TBlob*> arg_data_ptr_;
};

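// FStatefulCompute entry points: unpack the OperatorState stored in the
// OpStatePtr and dispatch to the legacy Forward/Backward implementations.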
void LegacyOpForward(const OpStatePtr& state,
                     const OpContext& ctx,
                     const std::vector<TBlob>& inputs,
                     const std::vector<OpReqType>& req,
                     const std::vector<TBlob>& outputs) {
  auto& op = state.get_state<OperatorState>();
  op.Forward(ctx, inputs, req, outputs);
}

void LegacyOpBackward(const OpStatePtr& state,
                      const OpContext& ctx,
                      const std::vector<TBlob>& inputs,
                      const std::vector<OpReqType>& req,
                      const std::vector<TBlob>& outputs) {
  auto& op = state.get_state<OperatorState>();
  op.Backward(ctx, inputs, req, outputs);
}

// function to use operator property to infer attr
// get op property from the attribute
const OperatorProperty* OpPropGetOpProperty(const NodeAttrs& attrs) {
  return nnvm::get<ParsedOpProp>(attrs.parsed).ptr.get();
}

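// Generic attribute inference. NNVM passes arguments and auxiliary states in
// one flat input vector, while the legacy inference interface expects them
// separately, so split iattr before calling finfer and merge the inferred
// values back afterwards.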
template<typename AttrType, typename FInfer>
bool OpPropInferAttr(const NodeAttrs& attrs,
                     std::vector<AttrType> *iattr,
                     std::vector<AttrType> *oattr,
                     FInfer finfer) {
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  CHECK_EQ(prop.inputs.size(), iattr->size())
      << "op=" << attrs.op->name
      << ", inputs.size=" << prop.inputs.size()
      << ", iattr.size=" << iattr->size()
      << ", arg.size=" << prop.arguments.size();
  std::vector<AttrType> in_attr(prop.arguments.size());
  std::vector<AttrType> aux_attr(prop.aux_states.size());

  for (size_t i = 0; i < prop.arguments.size(); ++i) {
    in_attr[i] = (*iattr)[i];
  }
  for (size_t i = 0; i < prop.aux_states.size(); ++i) {
    aux_attr[i] = (*iattr)[i + prop.arguments.size()];
  }
  if (!finfer(prop.ptr.get(), &in_attr, oattr, &aux_attr)) return false;

  for (size_t i = 0; i < prop.arguments.size(); ++i) {
    (*iattr)[i] = in_attr[i];
  }
  for (size_t i = 0; i < prop.aux_states.size(); ++i) {
    (*iattr)[i + prop.arguments.size()] = aux_attr[i];
  }
  return true;
}

bool OpPropInferShape(const NodeAttrs& attrs,
                      mxnet::ShapeVector *iattr,
                      mxnet::ShapeVector *oattr) {
  auto finfer = [](const OperatorProperty* op,
                   mxnet::ShapeVector *in,
                   mxnet::ShapeVector *out,
                   mxnet::ShapeVector *aux) {
    return op->InferShape(in, out, aux);
  };
  return OpPropInferAttr(attrs, iattr, oattr, finfer);
}

bool OpPropInferType(const NodeAttrs& attrs,
                     std::vector<int> *iattr,
                     std::vector<int> *oattr) {
  auto finfer = [](const OperatorProperty* op,
                   std::vector<int> *in,
                   std::vector<int> *out,
                   std::vector<int> *aux) {
    return op->InferType(in, out, aux);
  };
  return OpPropInferAttr(attrs, iattr, oattr, finfer);
}

inline uint32_t OpPropNumInputs(const NodeAttrs& attrs) {
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  return static_cast<uint32_t>(prop.inputs.size());
}

inline uint32_t OpPropNumOutputs(const NodeAttrs& attrs) {
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  return static_cast<uint32_t>(prop.outputs.size());
}

inline uint32_t OpPropNumVisibleOutputs(const NodeAttrs& attrs) {
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  return static_cast<uint32_t>(prop.ptr->NumVisibleOutputs());
}

std::vector<std::string> OpPropListInputNames(const NodeAttrs& attrs) {
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  return prop.inputs;
}

std::vector<std::string> OpPropListOutputNames(const NodeAttrs& attrs) {
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  return prop.outputs;
}

std::vector<uint32_t> OpPropMutateInputs(const NodeAttrs& attrs) {
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  std::vector<uint32_t> ret;
  for (uint32_t i = 0; i < prop.aux_states.size(); ++i) {
    ret.push_back(static_cast<uint32_t>(i + prop.arguments.size()));
  }
  return ret;
}

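// The legacy ForwardInplaceOption interface identifies outputs by opaque
// void* handles. Each handle here points at the output's integer index, so
// dereferencing it recovers the (input, output) index pair NNVM expects.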
std::vector<std::pair<int, int> > OpPropInplaceOption(const NodeAttrs& attrs) {
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  std::vector<int> in_data(prop.arguments.size());
  std::vector<int> out_data(prop.outputs.size());
  std::vector<void*> out_addr(prop.outputs.size());
  for (size_t i = 0; i < in_data.size(); ++i) {
    in_data[i] = static_cast<int>(i);
  }
  for (size_t i = 0; i < out_data.size(); ++i) {
    out_data[i] = static_cast<int>(i);
    out_addr[i] = &out_data[i];
  }
  std::vector<std::pair<int, int> > forward_inplace;
  for (auto& kv : prop.ptr->ForwardInplaceOption(in_data, out_addr)) {
    forward_inplace.emplace_back(kv.first, *static_cast<int*>(kv.second));
  }
  return forward_inplace;
}

std::vector<ResourceRequest> OpPropResourceRequest(const NodeAttrs& attrs) {
  mxnet::ShapeVector ishape;
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  return prop.ptr->ForwardResource(ishape);
}

std::vector<ResourceRequest> OpBackResourceRequest(const NodeAttrs& attrs) {
  mxnet::ShapeVector ishape;
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  return prop.ptr->BackwardResource(ishape);
}

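// Create the stateful operator. CreateOperatorEx only needs the shapes and
// types of the real arguments, so the auxiliary-state entries at the tail of
// ishape/itype are dropped first.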
OpStatePtr OpPropCreateLayerOp(const NodeAttrs& attrs,
                               Context ctx,
                               const mxnet::ShapeVector& ishape,
                               const std::vector<int>& itype) {
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  mxnet::ShapeVector is(ishape.begin(), ishape.begin() + prop.arguments.size());
  std::vector<int> it(itype.begin(), itype.begin() + prop.arguments.size());
  return OpStatePtr::Create<OperatorState>(prop.ptr->CreateOperatorEx(ctx, &is, &it),
                                           prop.ptr.get());
}

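// Build the backward node for a legacy op: gather the entries declared by
// BackwardInputs (plus all auxiliary states), wire them into a node running
// back_op, and return one gradient entry per argument, with _NoGradient
// entries for the auxiliary states.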
inline std::vector<NodeEntry> OpPropGradient(
    const Op* back_op,
    const ObjectPtr& ptr,
    const std::vector<NodeEntry>& out_grads) {
  auto& prop = nnvm::get<ParsedOpProp>(ptr->attrs.parsed);
  std::vector<NodeEntry> out_data;
  out_data.reserve(prop.outputs.size());
  for (size_t i = 0; i < prop.outputs.size(); ++i)
    out_data.emplace_back(ptr, i, 0);

  std::vector<NodeEntry> in_data(
      ptr->inputs.begin(), ptr->inputs.begin() + prop.arguments.size());
  std::vector<NodeEntry> ograd(
      out_grads.begin(), out_grads.begin() + prop.ptr->NumVisibleOutputs());
  auto inputs = prop.ptr->BackwardInputs(ograd, in_data, out_data);
  // add all the auxiliary data
  for (size_t i = 0; i < prop.aux_states.size(); ++i) {
    inputs.emplace_back(ptr->inputs[i + prop.arguments.size()]);
  }
  ObjectPtr gnode = Node::Create();
  gnode->inputs = std::move(inputs);
  gnode->control_deps.emplace_back(ptr);
  gnode->attrs = ptr->attrs;
  gnode->attrs.op = back_op;
  gnode->attrs.name = ptr->attrs.name + "_backward";
  std::vector<NodeEntry> in_grad;
  in_grad.reserve(prop.arguments.size() + prop.aux_states.size());
  for (size_t i = 0; i < prop.arguments.size(); ++i) {
    in_grad.emplace_back(gnode, i, 0);
  }
  // attach a no-gradient node to forbid gradient on aux_state
  if (prop.aux_states.size() != 0) {
    for (size_t i = 0; i < prop.aux_states.size(); ++i) {
      in_grad.emplace_back(Node::Create(Op::Get("_NoGradient"), "NoGradient"), 0, 0);
    }
  }
  return in_grad;
}

inline uint32_t OpBackNumOutputs(const NodeAttrs& attrs) {
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  return static_cast<uint32_t>(prop.arguments.size());
}

std::vector<std::string> OpBackListOutputNames(const NodeAttrs& attrs) {
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  return prop.arguments;
}

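// Auxiliary states are mutable. On the backward node they are appended after
// the dependencies returned by DeclareBackwardDependency, so their indices
// start at that dependency count.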
std::vector<uint32_t> OpBackMutateInputs(const NodeAttrs& attrs) {
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  if (prop.aux_states.size() == 0) return std::vector<uint32_t>{};
  std::vector<int> out_grad_index(prop.ptr->NumVisibleOutputs());
  std::vector<int> in_data_index(prop.arguments.size());
  std::vector<int> out_data_index(prop.outputs.size());
  size_t arg_size = prop.ptr->DeclareBackwardDependency(
      out_grad_index, in_data_index, out_data_index).size();
  std::vector<uint32_t> ret;
  for (uint32_t i = 0; i < prop.aux_states.size(); ++i) {
    ret.push_back(static_cast<uint32_t>(i + arg_size));
  }
  return ret;
}

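// Map the legacy BackwardInplaceOption result onto the backward node's input
// layout. Unique ids are assigned across in_data, out_grad and out_data
// (in_data first, so an in_data id equals its gradient-output position),
// then remapped through DeclareBackwardDependency to positions in the actual
// backward input vector.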
std::vector<std::pair<int, int> > OpBackInplaceOption(const NodeAttrs& attrs) {
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  std::vector<int> out_grad_index(prop.ptr->NumVisibleOutputs());
  std::vector<int> in_data_index(prop.arguments.size());
  std::vector<int> out_data_index(prop.outputs.size());

  int counter = 0;
  for (int& i : in_data_index) {
    i = counter++;
  }
  for (int& i : out_grad_index) {
    i = counter++;
  }
  for (int& i : out_data_index) {
    i = counter++;
  }

  auto args_index = prop.ptr->DeclareBackwardDependency(
      out_grad_index, in_data_index, out_data_index);
  std::vector<int> args_array(counter, -1);
  for (size_t i = 0; i < args_index.size(); ++i) {
    args_array[args_index[i]] = static_cast<int>(i);
  }

  std::vector<void*> in_grad_ptr(in_data_index.size());
  for (size_t i = 0; i < in_grad_ptr.size(); ++i) {
    // input data indices run from 0 to num_inputs - 1
    in_grad_ptr[i] = (void*)&in_data_index[i];  // NOLINT(*)
  }

  auto remap_index = prop.ptr->BackwardInplaceOption(
      out_grad_index, in_data_index, out_data_index, in_grad_ptr);
  std::vector<std::pair<int, int> > remap(remap_index.size());
  for (size_t i = 0; i < remap_index.size(); ++i) {
    if (args_array[remap_index[i].first] == -1) {
      LOG(FATAL) << "BackwardInplaceOption not consistent with DeclareBackwardDependency";
    }
    remap[i].first = args_array[remap_index[i].first];
    remap[i].second = *static_cast<int*>(remap_index[i].second);
  }
  return remap;
}

inline ExecType OpExecType(const NodeAttrs& attrs) {
  auto& prop = nnvm::get<ParsedOpProp>(attrs.parsed);
  return prop.ptr->exec_type();
}

// register the legacy operator properties under NNVM registry.
void RegisterLegacyOpProp() {
  for (auto reg : dmlc::Registry<OperatorPropertyReg>::List()) {
    Op& op = ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(reg->name);
    if (op.attr_parser != nullptr) continue;
    auto creator = reg->body;
    auto attr_parser = [creator](NodeAttrs* attrs) {
      if (attrs->parsed.empty()) {
        ParsedOpProp op;
        op.ptr.reset(creator());
        op.Init(*attrs);
        attrs->parsed = std::move(op);
      }
    };
    op.add_arguments(reg->arguments);
    op.describe(reg->description);
    // attribute parser
    op.set_attr_parser(attr_parser);
    op.set_num_inputs(OpPropNumInputs);
    op.set_num_outputs(OpPropNumOutputs);
    op.set_attr<nnvm::FListInputNames>("FListInputNames", OpPropListInputNames);
    op.set_attr<nnvm::FListOutputNames>("FListOutputNames", OpPropListOutputNames);
    op.set_attr<nnvm::FNumVisibleOutputs>("FNumVisibleOutputs", OpPropNumVisibleOutputs);
    op.set_attr<mxnet::FInferShape>("FInferShape", OpPropInferShape);
    op.set_attr<nnvm::FInferType>("FInferType", OpPropInferType);
    op.set_attr<nnvm::FMutateInputs>("FMutateInputs", OpPropMutateInputs);
    op.set_attr<nnvm::FInplaceOption>("FInplaceOption", OpPropInplaceOption);
    op.set_attr<FResourceRequest>("FResourceRequest", OpPropResourceRequest);
    op.set_attr<FExecType>("FExecType", OpExecType);
    op.set_attr<FCreateOpState>("FCreateOpState", OpPropCreateLayerOp);
    op.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", LegacyOpForward);
    op.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", LegacyOpForward);
    if (reg->key_var_num_args.length() != 0) {
      op.set_attr<std::string>("key_var_num_args", reg->key_var_num_args);
    }

    // register BackwardOps
    std::string back_op_name = "_backward_" + reg->name;
    Op& back_op = ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER__(back_op_name);
    op.set_attr<nnvm::FGradient>("FGradient", std::bind(
        OpPropGradient, &back_op,
        std::placeholders::_1, std::placeholders::_2));
    back_op.set_attr_parser(attr_parser);
    back_op.set_num_inputs(nnvm::kVarg);
    back_op.set_num_outputs(OpBackNumOutputs);
    back_op.set_attr<nnvm::FListOutputNames>("FListOutputNames", OpBackListOutputNames);
    back_op.set_attr<nnvm::FMutateInputs>("FMutateInputs", OpBackMutateInputs);
    back_op.set_attr<nnvm::FInplaceOption>("FInplaceOption", OpBackInplaceOption);
    back_op.set_attr<FResourceRequest>(
        "FResourceRequest", OpBackResourceRequest);
    back_op.set_attr<bool>("TIsLayerOpBackward", true);
    back_op.set_attr<bool>("TIsBackward", true);
    back_op.set_attr<FExecType>("FExecType", OpExecType);
    back_op.set_attr<FStatefulCompute>("FStatefulCompute<cpu>", LegacyOpBackward);
    back_op.set_attr<FStatefulCompute>("FStatefulCompute<gpu>", LegacyOpBackward);
  }
}

// no gradient operator
NNVM_REGISTER_OP(_NoGradient)
.set_num_inputs(0)
.set_num_outputs(1)
.describe("Placeholder for a variable whose gradient cannot be computed");

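// Expose legacy NDArray functions (imperative ops registered through
// NDArrayFunctionReg) as NNVM ops carrying an FNDArrayFunction attribute.
// Scalar parameters are pulled out of the attribute dict by argument name;
// everything left over is forwarded as string key/value parameters.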
void RegisterLegacyNDFunc() {
  for (auto reg : dmlc::Registry<NDArrayFunctionReg>::List()) {
    if (reg->type_mask & kScalarArgBeforeNDArray) continue;
    Op& op = ::dmlc::Registry<::nnvm::Op>::Get()->__REGISTER_OR_GET__(reg->name);
    if (op.attr_parser != nullptr) continue;

    CHECK_LE(reg->num_scalars + reg->num_use_vars, reg->arguments.size())
        << reg->name;
    op.describe(reg->description);
    op.add_arguments(reg->arguments);
    op.set_num_inputs(reg->num_use_vars);
    op.set_num_outputs(reg->num_mutate_vars);
    op.set_attr_parser([](NodeAttrs* attrs){});
    op.set_attr<FNDArrayFunction>("FNDArrayFunction", [reg](const nnvm::NodeAttrs& attrs,
                                                            const std::vector<NDArray>& inputs,
                                                            std::vector<NDArray>* outputs) {
      CHECK_EQ(inputs.size(), reg->num_use_vars);
      CHECK_EQ(outputs->size(), reg->num_mutate_vars);

      int n_scalars = reg->num_scalars;
      std::vector<float> scalars;
      scalars.reserve(n_scalars);
      auto dict = attrs.dict;
      for (int i = 0; i < n_scalars; ++i) {
        const std::string& name = reg->arguments[i + reg->num_use_vars].name;
        auto s = dict.find(name);
        CHECK(s != dict.end()) << "Missing scalar param " << name;
        scalars.push_back(dmlc::stof(s->second));
        dict.erase(s);
      }

      int n_params = dict.size();
      std::vector<const char*> keys, vals;
      keys.reserve(n_params);
      vals.reserve(n_params);
      for (auto& i : dict) {
        keys.push_back(dmlc::BeginPtr(i.first));
        vals.push_back(dmlc::BeginPtr(i.second));
      }
      std::vector<NDArray*> input_ptrs, output_ptrs;
      for (auto& i : inputs) {
        input_ptrs.push_back(const_cast<NDArray*>(&i));
      }
      for (auto& i : *outputs) {
        output_ptrs.push_back(&i);
      }
      reg->body(input_ptrs.data(),
                scalars.data(),
                output_ptrs.data(),
                n_params,
                const_cast<char**>(keys.data()),
                const_cast<char**>(vals.data()));
    });
  }
}

}  // namespace op
}  // namespace mxnet