1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file custom.cc
22  * \brief
23  * \author Junyuan Xie
24 */
25 #include <mxnet/c_api.h>
26 #include <mxnet/base.h>
27 #include <mxnet/ndarray.h>
28 #include <mxnet/imperative.h>
29 
30 #include "./c_api_common.h"
31 #include "../operator/operator_common.h"
32 #include "../operator/custom/custom-inl.h"
33 
34 namespace mxnet {
35 namespace custom_function {
36 
/*!
 * \brief Parsed attributes attached to a recorded _CustomFunction node.
 *
 * Filled in by MXCustomFunctionRecord and read back by the shape/type
 * inference lambdas and the Backward pass.
 */
struct CustomFunctionParam {
  // Number of inputs (num_args) and outputs (num_outs) of the user function.
  size_t num_args, num_outs;
  // User-supplied callback table; the shared_ptr deleter fires the
  // kCustomFunctionDelete callback when the last reference is released.
  std::shared_ptr<MXCallbackList> info;
  // Output shapes captured at record time, replayed by FInferShape.
  std::vector<mxnet::TShape> out_shapes;
  // Output dtypes captured at record time, replayed by FInferType.
  std::vector<int> out_dtypes;
};
43 
Gradient(const nnvm::ObjectPtr & n,const std::vector<nnvm::NodeEntry> & out_grads)44 std::vector<nnvm::NodeEntry> Gradient(
45     const nnvm::ObjectPtr& n,
46     const std::vector<nnvm::NodeEntry>& out_grads) {
47   const CustomFunctionParam& params = nnvm::get<CustomFunctionParam>(n->attrs.parsed);
48 
49   nnvm::ObjectPtr g = nnvm::Node::Create();
50   g->attrs.op = nnvm::Op::Get("_backward_CustomFunction");
51   g->attrs.name = n->attrs.name + "_backward";
52   g->attrs.parsed = params;
53   g->control_deps.emplace_back(n);
54 
55   g->inputs = out_grads;
56 
57   std::vector<nnvm::NodeEntry> ret;
58   for (uint32_t i = 0; i < g->num_outputs(); ++i) {
59     ret.emplace_back(g, i, 0);
60   }
61 
62   return ret;
63 }
64 
/*!
 * \brief FCreateOpState stub for _CustomFunction.
 *
 * The real OpStatePtr is created directly inside MXCustomFunctionRecord when
 * the function call is recorded, so this registered factory should never be
 * invoked; it aborts if it is.
 */
OpStatePtr CreateState(const nnvm::NodeAttrs& attrs,
                       Context ctx,
                       const mxnet::ShapeVector& ishape,
                       const std::vector<int>& itype) {
  LOG(FATAL) << "Not reached";
  // Unreachable; only present to satisfy the return type.
  return OpStatePtr::Create<void*>(nullptr);
}
72 
/*!
 * \brief FStatefulComputeEx stub for _CustomFunction's forward pass.
 *
 * The forward computation already ran imperatively before the node was
 * recorded (its output shapes/dtypes are captured in CustomFunctionParam),
 * so this registered kernel should never execute; it aborts if it does.
 */
void Forward(const OpStatePtr& state,
             const OpContext& ctx,
             const std::vector<NDArray>& inputs,
             const std::vector<OpReqType>& req,
             const std::vector<NDArray>& outputs) {
  LOG(FATAL) << "Not reached";
}
80 
/*!
 * \brief FStatefulComputeEx for _backward_CustomFunction.
 *
 * Wraps the gradient arrays in C-API handles and pushes the user's backward
 * callback onto the CustomOperator queue (the op is registered with
 * ExecType::kAsync, so completion is signalled by the pushed work itself).
 *
 * \param state   holds the CustomFunctionParam recorded at forward time
 * \param ctx     execution context (device, is_train flag)
 * \param inputs  incoming output-gradients of the forward function
 * \param req     write request type per output, forwarded to the callback
 * \param outputs input-gradients to be filled by the user callback
 */
void Backward(const OpStatePtr& state,
              const OpContext& ctx,
              const std::vector<NDArray>& inputs,
              const std::vector<OpReqType>& req,
              const std::vector<NDArray>& outputs) {
  const CustomFunctionParam& params = state.get_state<CustomFunctionParam>();

  // ptrs: opaque handles handed to the C callback.
  // cpys: shallow copies that keep the underlying storage alive while the
  //       pushed lambda waits to run.
  // tags: 0 marks an input (out-grad), 1 marks an output (in-grad).
  std::vector<NDArrayHandle> ptrs;
  std::vector<NDArray> cpys;
  std::vector<int> tags;
  std::unordered_set<int> input_tags({0});
  std::unordered_set<int> output_tags({1});

  auto dev_id = ctx.run_ctx.ctx.dev_id;

  // NOTE(review): these heap-allocated NDArrays are not freed in this
  // function; presumably ownership passes to the frontend through the
  // handle (freed via the usual NDArray handle-free path) — confirm.
  for (const auto& i : inputs) {
    NDArray* nd = new NDArray(i.data(), dev_id);
    ptrs.push_back(reinterpret_cast<NDArrayHandle>(nd));
    cpys.push_back(*nd);
    tags.push_back(0);
  }
  for (const auto& i : outputs) {
    NDArray* nd = new NDArray(i.data(), dev_id);
    ptrs.push_back(reinterpret_cast<NDArrayHandle>(nd));
    cpys.push_back(*nd);
    tags.push_back(1);
  }

  // Capture by value: the lambda may run after this frame returns.
  // req values are OpReqType enums, reinterpreted as a plain int array for
  // the C callback signature.
  op::custom::CustomOperator::Get()->Push(
    [=]() {
      CHECK(reinterpret_cast<CustomFunctionBwdFunc>(
          params.info->callbacks[kCustomFunctionBackward])(
              inputs.size(), outputs.size(),
              const_cast<NDArrayHandle*>(ptrs.data()),
              reinterpret_cast<const int*>(req.data()), ctx.is_train,
              params.info->contexts[kCustomFunctionBackward]));
    }, ctx, false, ctx.is_train, cpys, tags, output_tags, outputs);
}
119 
InferStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * iattr,std::vector<int> * oattr)120 inline bool InferStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask,
121                              DispatchMode* dispatch_mode,
122                              std::vector<int>* iattr, std::vector<int>* oattr) {
123   using namespace op;
124 
125   for (size_t i = 0; i < iattr->size(); ++i) {
126     STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, kDefaultStorage);
127   }
128   for (size_t i = 0; i < oattr->size(); ++i) {
129     STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, kDefaultStorage);
130   }
131   DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
132   return true;
133 }
134 
135 NNVM_REGISTER_OP(_CustomFunction)
__anon04bb31290202(const NodeAttrs& attrs) 136 .set_num_inputs([](const NodeAttrs& attrs) {
137     const CustomFunctionParam& params = nnvm::get<CustomFunctionParam>(attrs.parsed);
138     return params.num_args;
139   })
__anon04bb31290302(const NodeAttrs& attrs) 140 .set_num_outputs([](const NodeAttrs& attrs) {
141     const CustomFunctionParam& params = nnvm::get<CustomFunctionParam>(attrs.parsed);
142     return params.num_outs;
143   })
144 .set_attr<mxnet::FInferShape>("FInferShape",
145   [](const NodeAttrs& attrs, mxnet::ShapeVector *in_shape,
__anon04bb31290402(const NodeAttrs& attrs, mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape) 146      mxnet::ShapeVector *out_shape) {
147     const CustomFunctionParam& params = nnvm::get<CustomFunctionParam>(attrs.parsed);
148     *out_shape = params.out_shapes;
149     return true;
150   })
151 .set_attr<nnvm::FInferType>("FInferType",
152   [](const NodeAttrs& attrs, std::vector<int> *in_type,
__anon04bb31290502(const NodeAttrs& attrs, std::vector<int> *in_type, std::vector<int> *out_type) 153      std::vector<int> *out_type) {
154     const CustomFunctionParam& params = nnvm::get<CustomFunctionParam>(attrs.parsed);
155     *out_type = params.out_dtypes;
156     return true;
157   })
158 .set_attr<FCreateOpState>("FCreateOpState", CreateState)
159 .set_attr<nnvm::FGradient>("FGradient", Gradient)
160 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Forward)
161 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Forward)
162 .set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
163 
164 
165 NNVM_REGISTER_OP(_backward_CustomFunction)
__anon04bb31290602(const NodeAttrs& attrs) 166 .set_num_inputs([](const NodeAttrs& attrs) {
167     const CustomFunctionParam& params = nnvm::get<CustomFunctionParam>(attrs.parsed);
168     return params.num_outs;
169   })
__anon04bb31290702(const NodeAttrs& attrs) 170 .set_num_outputs([](const NodeAttrs& attrs) {
171     const CustomFunctionParam& params = nnvm::get<CustomFunctionParam>(attrs.parsed);
172     return params.num_args;
173   })
174 .set_attr<bool>("TIsBackward", true)
175 .set_attr<bool>("TIsLayerOpBackward", true)
__anon04bb31290802(const NodeAttrs& attrs) 176 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
177     return ExecType::kAsync;
178   })
179 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Backward)
180 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Backward)
181 .set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
182 
183 }  // namespace custom_function
184 }  // namespace mxnet
185 
MXCustomFunctionRecord(int num_inputs,NDArrayHandle * inputs,int num_outputs,NDArrayHandle * outputs,MXCallbackList * callbacks)186 int MXCustomFunctionRecord(int num_inputs, NDArrayHandle *inputs,
187                            int num_outputs, NDArrayHandle *outputs,
188                            MXCallbackList *callbacks) {
189   using namespace mxnet;
190   using namespace mxnet::custom_function;
191   API_BEGIN();
192   CHECK(Imperative::Get()->is_recording());
193   auto state = OpStatePtr::Create<CustomFunctionParam>();
194   CustomFunctionParam& params = state.get_state<CustomFunctionParam>();
195   params.num_args = num_inputs;
196   params.num_outs = num_outputs;
197   params.info.reset(callbacks, [](MXCallbackList* ptr){
198       reinterpret_cast<CustomFunctionDelFunc>(ptr->callbacks[kCustomFunctionDelete])(
199         ptr->contexts[kCustomFunctionDelete]);
200     });
201   std::vector<NDArray*> ndinputs, ndoutputs;
202   ndinputs.reserve(num_inputs);
203   ndoutputs.reserve(num_outputs);
204   params.out_shapes.reserve(num_outputs);
205   params.out_dtypes.reserve(num_outputs);
206   for (int i = 0; i < num_inputs; ++i) {
207     ndinputs.emplace_back(reinterpret_cast<NDArray*>(inputs[i]));
208   }
209   for (int i = 0; i < num_outputs; ++i) {
210     NDArray* arr = reinterpret_cast<NDArray*>(outputs[i]);
211     ndoutputs.emplace_back(arr);
212     params.out_shapes.emplace_back(arr->shape());
213     params.out_dtypes.emplace_back(arr->dtype());
214   }
215   nnvm::NodeAttrs attrs;
216   attrs.op = nnvm::Op::Get("_CustomFunction");
217   attrs.parsed = params;
218   Imperative::Get()->RecordOp(
219       std::move(attrs), ndinputs, ndoutputs, state);
220 
221   API_END();
222 }
223