1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file fully_connected.cc
22 * \brief fully connect operator
23 */
24 #include "./fully_connected-inl.h"
25 #include "./mkldnn/mkldnn_ops-inl.h"
26 #include "./mkldnn/mkldnn_base-inl.h"
27 #if MXNET_USE_NNPACK == 1
28 #include "../nnpack/nnpack_fully_connected-inl.h"
29 #endif // MXNET_USE_NNPACK
30
31 namespace mxnet {
32 namespace op {
33
SupportMKLDNNFC(const NDArray & input)34 bool SupportMKLDNNFC(const NDArray& input) {
35 int ndim = input.shape().ndim();
36 return (input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16) &&
37 (ndim >= 1 && ndim <= 4) &&
38 input.storage_type() == kDefaultStorage;
39 }
40
// Bidirectional shape inference for FullyConnected.
// Inputs are [data, weight] (+ [bias] when no_bias is false); there is one output.
// Returns false when the data shape is still unknown (inference deferred),
// true when all shapes could be assigned; SHAPE_ASSIGN_CHECK aborts on conflict.
static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs,
                                mxnet::ShapeVector *in_shape,
                                mxnet::ShapeVector *out_shape) {
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  using namespace mshadow;
  if (!param.no_bias) {
    CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]";
  } else {
    CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
  }
  CHECK_EQ(out_shape->size(), 1U);
  mxnet::TShape dshape = (*in_shape)[fullc::kData];
  mxnet::TShape oshape = (*out_shape)[0];
  // require data to be known
  if (!mxnet::ndim_is_known(dshape)) return false;

  // Number of input features seen by the weight matrix: the trailing axis when
  // flatten=false, otherwise the product of all non-batch axes.
  index_t num_input;
  if (!param.flatten) {
    num_input = dshape[dshape.ndim()-1];
  } else {
    num_input = dshape.ProdShape(1, dshape.ndim());
  }
  SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param.num_hidden, num_input));
  if (!param.no_bias) {
    // Bias is accepted either as (num_hidden,) or as (num_hidden, 1).
    if (!shape_assign(&(*in_shape)[fullc::kBias], Shape1(param.num_hidden)) &&
        !shape_assign(&(*in_shape)[fullc::kBias], Shape2(param.num_hidden, 1))) {
      LOG(FATAL) << "Unexpected shape for bias " << (*in_shape)[fullc::kBias];
    }
  }

  // Output: replace the trailing axis with num_hidden (flatten=false), or
  // collapse to (batch, num_hidden) (flatten=true).
  if (!param.flatten) {
    mxnet::TShape result_shape(dshape);
    result_shape[dshape.ndim()-1] = param.num_hidden;
    SHAPE_ASSIGN_CHECK(*out_shape, 0, result_shape);
  } else {
    SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param.num_hidden));
  }
  // Backward inference: a known output batch dimension pins the data batch dim.
  if (oshape.ndim() > 0) {
    dshape[0] = oshape[0];
    SHAPE_ASSIGN_CHECK(*in_shape, fullc::kData, dshape);
  }
  return true;
}
84
// Forward FullyConnected over NDArray inputs (FComputeEx dispatch).
// Three paths:
//  - with MKL-DNN and all-dense tensors: run the MKL-DNN kernel when supported,
//    otherwise fall back to the plain CPU kernel;
//  - dense data/output with dense or row-sparse weight/bias: unwrap to TBlobs
//    and run the plain CPU kernel;
//  - any other storage combination: report as unimplemented.
void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const OpContext &ctx,
                                const std::vector<NDArray> &inputs,
                                const std::vector<OpReqType> &req,
                                const std::vector<NDArray> &outputs) {
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  // data and output must be dense; weight/bias may be dense or row-sparse.
  const bool valid_data = inputs[0].storage_type() == kDefaultStorage;
  const bool valid_weight = inputs[1].storage_type() == kDefaultStorage ||
                            inputs[1].storage_type() == kRowSparseStorage;
  const bool valid_out = outputs[0].storage_type() == kDefaultStorage;
  bool valid_bias = true;
  if (!param.no_bias) {
    valid_bias = inputs[2].storage_type() == kDefaultStorage ||
                 inputs[2].storage_type() == kRowSparseStorage;
  }
#if MXNET_USE_MKLDNN == 1
  if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
      common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
    if (SupportMKLDNNFC(inputs[0])) {
      MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
      MKLDNNRun(MKLDNNFCForward, attrs, ctx, inputs, req, outputs);
      MKLDNN_OPCHECK_RUN(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req,
                         outputs);
    } else {
      FallBackCompute(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req, outputs);
    }
    return;
  } else if (valid_data && valid_weight && valid_bias && valid_out) {
    // inputs
    std::vector<NDArray> temp_ndarrays;
    std::vector<TBlob> in_blobs;
    for (const NDArray& in : inputs) {
      // if ndarray is in default storage and MKLDNN is available,
      // need to make sure cpu layout data is used, instead of MKL layout
      if (in.storage_type() == kDefaultStorage) {
        // temp_ndarrays keeps the reordered copies alive for the kernel call.
        temp_ndarrays.push_back(in.Reorder2Default());
        in_blobs.emplace_back(temp_ndarrays.back().data());
      } else {
        in_blobs.emplace_back(in.data());
      }
    }
    // output
    FullyConnectedCompute<cpu>(attrs, ctx, in_blobs, req, {outputs[0].data()});
  } else {
    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
  }
#else
  if (valid_data && valid_weight && valid_bias && valid_out) {
    // Without MKL-DNN there is no layout concern: unwrap directly to TBlobs.
    std::vector<TBlob> in_blobs(inputs.size());
    for (size_t i = 0; i < in_blobs.size(); i++) in_blobs[i] = inputs[i].data();
    std::vector<TBlob> out_blobs(outputs.size());
    for (size_t i = 0; i < out_blobs.size(); i++) out_blobs[i] = outputs[i].data();
    FullyConnectedCompute<cpu>(attrs, ctx, in_blobs, req, out_blobs);
  } else {
    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
  }
#endif
}
143
144 #if MXNET_USE_MKLDNN == 1
FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)145 void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
146 const OpContext &ctx,
147 const std::vector<NDArray> &inputs,
148 const std::vector<OpReqType> &req,
149 const std::vector<NDArray> &outputs) {
150 if (SupportMKLDNNFC(inputs[0])) {
151 MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
152 MKLDNNRun(MKLDNNFCBackward, attrs, ctx, inputs, req, outputs);
153 MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute<cpu>, attrs, ctx, inputs, req,
154 outputs);
155 return;
156 }
157 FallBackCompute(FullyConnectedGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
158 }
159 #endif
160
// Type inference for FullyConnected: all inputs and the output share one
// dtype, propagated in both directions by ElemwiseAttr.
static bool FullyConnectedType(const nnvm::NodeAttrs& attrs,
                               std::vector<int> *in_type, std::vector<int> *out_type) {
  CHECK_GE(in_type->size(), 1U);
  return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
      attrs, in_type, out_type, -1);
}
167
168 struct FullyConnectedGrad {
169 const char *op_name;
operator ()mxnet::op::FullyConnectedGrad170 std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr& n,
171 const std::vector<nnvm::NodeEntry>& ograds) const {
172 std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
173 heads.push_back(n->inputs[fullc::kData]);
174 heads.push_back(n->inputs[fullc::kWeight]);
175 return MakeGradNode(op_name, n, heads, n->attrs.dict);
176 }
177 };
178
179 struct FullyConnectedGradGrad {
180 const char *op_name;
operator ()mxnet::op::FullyConnectedGradGrad181 std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr& n,
182 const std::vector<nnvm::NodeEntry>& ograds) const {
183 std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
184 heads.push_back(n->inputs[0]); // o_y : head gradient of the output y
185 return MakeGradNode(op_name, n, heads, n->attrs.dict);
186 }
187 };
188
// Storage-type inference and dispatch-mode selection for the forward op.
// Dense data with dense or row-sparse weight/bias dispatches to FComputeEx;
// any other combination falls back.
static bool FCStorageType(const nnvm::NodeAttrs& attrs,
                          const int dev_mask,
                          DispatchMode* dispatch_mode,
                          std::vector<int> *in_attrs,
                          std::vector<int> *out_attrs) {
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  // data must be dense; weight (and bias, when present) may be row-sparse.
  const bool valid_data = in_attrs->at(0) == kDefaultStorage;
  const bool valid_weight = in_attrs->at(1) == kDefaultStorage ||
                            in_attrs->at(1) == kRowSparseStorage;
  bool valid_bias = true;
  uint32_t in_expected = 2;
  if (!param.no_bias) {
    in_expected = 3;
    valid_bias = in_attrs->at(2) == kDefaultStorage || in_attrs->at(2) == kRowSparseStorage;
  }
  CHECK_EQ(in_attrs->size(), in_expected);
  CHECK_EQ(out_attrs->size(), 1);
  // dispatch to kFComputeEx is fine even if all inputs are dense and no MKL is present
  bool dispatched = false;
  if (!dispatched && valid_data && valid_weight && valid_bias) {
    dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
  }
#if MXNET_USE_MKLDNN == 1
  // When the MKLDNN env switch is off, force the fallback mode even if
  // FComputeEx was just assigned above.
  if (!MKLDNNEnvSet())
    *dispatch_mode = DispatchMode::kFComputeFallback;
#endif

  if (!dispatched) {
    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
  }
  return dispatched;
}
222
// Storage-type inference and dispatch-mode selection for the backward op.
// Backward inputs are always [o_grad, data, weight]; outputs are the
// gradients of data and weight (plus bias when no_bias is false).
static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
                                  const int dev_mask,
                                  DispatchMode* dispatch_mode,
                                  std::vector<int> *in_attrs,
                                  std::vector<int> *out_attrs) {
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  uint32_t out_expected = param.no_bias ? 2 : 3;
  CHECK_EQ(in_attrs->size(), 3U);
  CHECK_EQ(out_attrs->size(), out_expected);
  bool dispatched = false;
  // All-dense inputs: dense outputs via FComputeEx.
  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
    dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
  }
  // Any row-sparse input: no sparse backward kernel, so fall back.
  if (!dispatched && common::ContainsStorageType(*in_attrs, mxnet::kRowSparseStorage)) {
    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
  }
  if (!dispatched) {
    dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFCompute);
  }
#if MXNET_USE_MKLDNN == 1
  // When the MKLDNN env switch is off, force the fallback mode regardless of
  // what was assigned above.
  if (!MKLDNNEnvSet())
    *dispatch_mode = DispatchMode::kFComputeFallback;
#endif
  return dispatched;
}
250
DMLC_REGISTER_PARAMETER(FullyConnectedParam);

// Registration of the forward FullyConnected operator: shape/type/storage
// inference, CPU compute kernels, and the gradient builder.
NNVM_REGISTER_OP(FullyConnected)
MXNET_ADD_SPARSE_OP_ALIAS(FullyConnected)
.add_alias("_npx_fully_connected")
.describe(R"code(Applies a linear transformation: :math:`Y = XW^T + b`.

If ``flatten`` is set to be true, then the shapes are:

- **data**: `(batch_size, x1, x2, ..., xn)`
- **weight**: `(num_hidden, x1 * x2 * ... * xn)`
- **bias**: `(num_hidden,)`
- **out**: `(batch_size, num_hidden)`

If ``flatten`` is set to be false, then the shapes are:

- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(num_hidden, input_dim)`
- **bias**: `(num_hidden,)`
- **out**: `(x1, x2, ..., xn, num_hidden)`

The learnable parameters include both ``weight`` and ``bias``.

If ``no_bias`` is set to be true, then the ``bias`` term is ignored.

.. Note::

    The sparse support for FullyConnected is limited to forward evaluation with `row_sparse`
    weight and bias, where the length of `weight.indices` and `bias.indices` must be equal
    to `num_hidden`. This could be useful for model inference with `row_sparse` weights
    trained with importance sampling or noise contrastive estimation.

    To compute linear transformation with 'csr' sparse data, sparse.dot is recommended instead
    of sparse.FullyConnected.

)code" ADD_FILELINE)
// Input count depends on no_bias: [data, weight] or [data, weight, bias].
.set_num_inputs([](const NodeAttrs& attrs) {
  const FullyConnectedParam& params = nnvm::get<FullyConnectedParam>(attrs.parsed);
  return params.no_bias ? 2 : 3;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<FullyConnectedParam>)
.set_attr<FInferStorageType>("FInferStorageType", FCStorageType)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
  const FullyConnectedParam& params = nnvm::get<FullyConnectedParam>(attrs.parsed);
  if (!params.no_bias) {
    return std::vector<std::string>{"data", "weight", "bias"};
  } else {
    return std::vector<std::string>{"data", "weight"};
  }
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
                                  [](const NodeAttrs& attrs) {
  return std::vector<std::string>{"output"};
})
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
// MKL-DNN path may need scratch space for layout conversion.
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<mxnet::FInferShape>("FInferShape", FullyConnectedShape)
.set_attr<nnvm::FInferType>("FInferType", FullyConnectedType)
.set_attr<FCompute>("FCompute<cpu>", FullyConnectedCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", FullyConnectedComputeExCPU)
.set_attr<nnvm::FGradient>("FGradient", FullyConnectedGrad{"_backward_FullyConnected"})
.add_argument("data", "NDArray-or-Symbol", "Input data.")
.add_argument("weight", "NDArray-or-Symbol", "Weight matrix.")
.add_argument("bias", "NDArray-or-Symbol", "Bias parameter.")
.add_arguments(FullyConnectedParam::__FIELDS__());
322
// Registration of the backward op. Inputs are [o_grad, data, weight];
// outputs are [data_grad, weight_grad] (+ [bias_grad] unless no_bias).
NNVM_REGISTER_OP(_backward_FullyConnected)
.set_num_inputs(3)
.set_num_outputs([](const NodeAttrs& attrs) {
  const FullyConnectedParam& params = nnvm::get<FullyConnectedParam>(attrs.parsed);
  return params.no_bias ? 2 : 3;
})
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
// data (input 1) may share memory with data_grad (output 0).
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
  return std::vector<std::pair<int, int> >{{1, 0}};
})
.set_attr<nnvm::FGradient>("FGradient", FullyConnectedGradGrad{"_backward_backward_FullyConnected"})
.set_attr<FInferStorageType>("FInferStorageType", BackwardFCStorageType)
.set_attr_parser(ParamParser<FullyConnectedParam>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", FullyConnectedGradComputeExCPU)
#endif
.set_attr<FCompute>("FCompute<cpu>", FullyConnectedGradCompute<cpu>);
344
// 2nd gradient for fully connected
// Inputs are:
// o_x_grad : head gradient for x_grad
// o_w_grad : head gradient for w_grad
// o_b_grad : if param.no_bias is false
// o_y : head gradient of y
//
// outputs are:
// o_y_grad : not used
// x_grad_grad : o_w_grad * o_y^T
// w_grad_grad : o_x_grad * o_y
//
// For a detailed development of the second gradient see here: TODO(larroy)
NNVM_REGISTER_OP(_backward_backward_FullyConnected)
// Input count matches the comment above: 3 with no_bias, otherwise 4.
.set_num_inputs([](const NodeAttrs& attrs) {
  const FullyConnectedParam& params = nnvm::get<FullyConnectedParam>(attrs.parsed);
  return params.no_bias ? 3 : 4;
})
.set_num_outputs([](const NodeAttrs& attrs) {
  return 3;
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<FullyConnectedParam>)
.set_attr<FCompute>("FCompute<cpu>", FullyConnectedGradGradDTypeDispatch<cpu>);
369
370 } // namespace op
371 } // namespace mxnet
372