1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file custom.cc
 * \brief C API and operators for recording frontend-defined autograd Functions
23 * \author Junyuan Xie
24 */
25 #include <mxnet/c_api.h>
26 #include <mxnet/base.h>
27 #include <mxnet/ndarray.h>
28 #include <mxnet/imperative.h>
29
30 #include "./c_api_common.h"
31 #include "../operator/operator_common.h"
32 #include "../operator/custom/custom-inl.h"
33
34 namespace mxnet {
35 namespace custom_function {
36
37 struct CustomFunctionParam {
38 size_t num_args, num_outs;
39 std::shared_ptr<MXCallbackList> info;
40 std::vector<mxnet::TShape> out_shapes;
41 std::vector<int> out_dtypes;
42 };
43
Gradient(const nnvm::ObjectPtr & n,const std::vector<nnvm::NodeEntry> & out_grads)44 std::vector<nnvm::NodeEntry> Gradient(
45 const nnvm::ObjectPtr& n,
46 const std::vector<nnvm::NodeEntry>& out_grads) {
47 const CustomFunctionParam& params = nnvm::get<CustomFunctionParam>(n->attrs.parsed);
48
49 nnvm::ObjectPtr g = nnvm::Node::Create();
50 g->attrs.op = nnvm::Op::Get("_backward_CustomFunction");
51 g->attrs.name = n->attrs.name + "_backward";
52 g->attrs.parsed = params;
53 g->control_deps.emplace_back(n);
54
55 g->inputs = out_grads;
56
57 std::vector<nnvm::NodeEntry> ret;
58 for (uint32_t i = 0; i < g->num_outputs(); ++i) {
59 ret.emplace_back(g, i, 0);
60 }
61
62 return ret;
63 }
64
CreateState(const nnvm::NodeAttrs & attrs,Context ctx,const mxnet::ShapeVector & ishape,const std::vector<int> & itype)65 OpStatePtr CreateState(const nnvm::NodeAttrs& attrs,
66 Context ctx,
67 const mxnet::ShapeVector& ishape,
68 const std::vector<int>& itype) {
69 LOG(FATAL) << "Not reached";
70 return OpStatePtr::Create<void*>(nullptr);
71 }
72
Forward(const OpStatePtr & state,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)73 void Forward(const OpStatePtr& state,
74 const OpContext& ctx,
75 const std::vector<NDArray>& inputs,
76 const std::vector<OpReqType>& req,
77 const std::vector<NDArray>& outputs) {
78 LOG(FATAL) << "Not reached";
79 }
80
Backward(const OpStatePtr & state,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)81 void Backward(const OpStatePtr& state,
82 const OpContext& ctx,
83 const std::vector<NDArray>& inputs,
84 const std::vector<OpReqType>& req,
85 const std::vector<NDArray>& outputs) {
86 const CustomFunctionParam& params = state.get_state<CustomFunctionParam>();
87
88 std::vector<NDArrayHandle> ptrs;
89 std::vector<NDArray> cpys;
90 std::vector<int> tags;
91 std::unordered_set<int> input_tags({0});
92 std::unordered_set<int> output_tags({1});
93
94 auto dev_id = ctx.run_ctx.ctx.dev_id;
95
96 for (const auto& i : inputs) {
97 NDArray* nd = new NDArray(i.data(), dev_id);
98 ptrs.push_back(reinterpret_cast<NDArrayHandle>(nd));
99 cpys.push_back(*nd);
100 tags.push_back(0);
101 }
102 for (const auto& i : outputs) {
103 NDArray* nd = new NDArray(i.data(), dev_id);
104 ptrs.push_back(reinterpret_cast<NDArrayHandle>(nd));
105 cpys.push_back(*nd);
106 tags.push_back(1);
107 }
108
109 op::custom::CustomOperator::Get()->Push(
110 [=]() {
111 CHECK(reinterpret_cast<CustomFunctionBwdFunc>(
112 params.info->callbacks[kCustomFunctionBackward])(
113 inputs.size(), outputs.size(),
114 const_cast<NDArrayHandle*>(ptrs.data()),
115 reinterpret_cast<const int*>(req.data()), ctx.is_train,
116 params.info->contexts[kCustomFunctionBackward]));
117 }, ctx, false, ctx.is_train, cpys, tags, output_tags, outputs);
118 }
119
InferStorageType(const nnvm::NodeAttrs & attrs,const int dev_mask,DispatchMode * dispatch_mode,std::vector<int> * iattr,std::vector<int> * oattr)120 inline bool InferStorageType(const nnvm::NodeAttrs& attrs, const int dev_mask,
121 DispatchMode* dispatch_mode,
122 std::vector<int>* iattr, std::vector<int>* oattr) {
123 using namespace op;
124
125 for (size_t i = 0; i < iattr->size(); ++i) {
126 STORAGE_TYPE_ASSIGN_CHECK(*iattr, i, kDefaultStorage);
127 }
128 for (size_t i = 0; i < oattr->size(); ++i) {
129 STORAGE_TYPE_ASSIGN_CHECK(*oattr, i, kDefaultStorage);
130 }
131 DISPATCH_MODE_ASSIGN_CHECK(dispatch_mode, 0, DispatchMode::kFComputeEx);
132 return true;
133 }
134
135 NNVM_REGISTER_OP(_CustomFunction)
__anon04bb31290202(const NodeAttrs& attrs) 136 .set_num_inputs([](const NodeAttrs& attrs) {
137 const CustomFunctionParam& params = nnvm::get<CustomFunctionParam>(attrs.parsed);
138 return params.num_args;
139 })
__anon04bb31290302(const NodeAttrs& attrs) 140 .set_num_outputs([](const NodeAttrs& attrs) {
141 const CustomFunctionParam& params = nnvm::get<CustomFunctionParam>(attrs.parsed);
142 return params.num_outs;
143 })
144 .set_attr<mxnet::FInferShape>("FInferShape",
145 [](const NodeAttrs& attrs, mxnet::ShapeVector *in_shape,
__anon04bb31290402(const NodeAttrs& attrs, mxnet::ShapeVector *in_shape, mxnet::ShapeVector *out_shape) 146 mxnet::ShapeVector *out_shape) {
147 const CustomFunctionParam& params = nnvm::get<CustomFunctionParam>(attrs.parsed);
148 *out_shape = params.out_shapes;
149 return true;
150 })
151 .set_attr<nnvm::FInferType>("FInferType",
152 [](const NodeAttrs& attrs, std::vector<int> *in_type,
__anon04bb31290502(const NodeAttrs& attrs, std::vector<int> *in_type, std::vector<int> *out_type) 153 std::vector<int> *out_type) {
154 const CustomFunctionParam& params = nnvm::get<CustomFunctionParam>(attrs.parsed);
155 *out_type = params.out_dtypes;
156 return true;
157 })
158 .set_attr<FCreateOpState>("FCreateOpState", CreateState)
159 .set_attr<nnvm::FGradient>("FGradient", Gradient)
160 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Forward)
161 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Forward)
162 .set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
163
164
165 NNVM_REGISTER_OP(_backward_CustomFunction)
__anon04bb31290602(const NodeAttrs& attrs) 166 .set_num_inputs([](const NodeAttrs& attrs) {
167 const CustomFunctionParam& params = nnvm::get<CustomFunctionParam>(attrs.parsed);
168 return params.num_outs;
169 })
__anon04bb31290702(const NodeAttrs& attrs) 170 .set_num_outputs([](const NodeAttrs& attrs) {
171 const CustomFunctionParam& params = nnvm::get<CustomFunctionParam>(attrs.parsed);
172 return params.num_args;
173 })
174 .set_attr<bool>("TIsBackward", true)
175 .set_attr<bool>("TIsLayerOpBackward", true)
__anon04bb31290802(const NodeAttrs& attrs) 176 .set_attr<FExecType>("FExecType", [](const NodeAttrs& attrs) {
177 return ExecType::kAsync;
178 })
179 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<cpu>", Backward)
180 .set_attr<FStatefulComputeEx>("FStatefulComputeEx<gpu>", Backward)
181 .set_attr<FInferStorageType>("FInferStorageType", InferStorageType);
182
183 } // namespace custom_function
184 } // namespace mxnet
185
MXCustomFunctionRecord(int num_inputs,NDArrayHandle * inputs,int num_outputs,NDArrayHandle * outputs,MXCallbackList * callbacks)186 int MXCustomFunctionRecord(int num_inputs, NDArrayHandle *inputs,
187 int num_outputs, NDArrayHandle *outputs,
188 MXCallbackList *callbacks) {
189 using namespace mxnet;
190 using namespace mxnet::custom_function;
191 API_BEGIN();
192 CHECK(Imperative::Get()->is_recording());
193 auto state = OpStatePtr::Create<CustomFunctionParam>();
194 CustomFunctionParam& params = state.get_state<CustomFunctionParam>();
195 params.num_args = num_inputs;
196 params.num_outs = num_outputs;
197 params.info.reset(callbacks, [](MXCallbackList* ptr){
198 reinterpret_cast<CustomFunctionDelFunc>(ptr->callbacks[kCustomFunctionDelete])(
199 ptr->contexts[kCustomFunctionDelete]);
200 });
201 std::vector<NDArray*> ndinputs, ndoutputs;
202 ndinputs.reserve(num_inputs);
203 ndoutputs.reserve(num_outputs);
204 params.out_shapes.reserve(num_outputs);
205 params.out_dtypes.reserve(num_outputs);
206 for (int i = 0; i < num_inputs; ++i) {
207 ndinputs.emplace_back(reinterpret_cast<NDArray*>(inputs[i]));
208 }
209 for (int i = 0; i < num_outputs; ++i) {
210 NDArray* arr = reinterpret_cast<NDArray*>(outputs[i]);
211 ndoutputs.emplace_back(arr);
212 params.out_shapes.emplace_back(arr->shape());
213 params.out_dtypes.emplace_back(arr->dtype());
214 }
215 nnvm::NodeAttrs attrs;
216 attrs.op = nnvm::Op::Get("_CustomFunction");
217 attrs.parsed = params;
218 Imperative::Get()->RecordOp(
219 std::move(attrs), ndinputs, ndoutputs, state);
220
221 API_END();
222 }
223