/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file fully_connected.cc
 * \brief fully connected operator
*/
#include "./fully_connected-inl.h"
#include "./mkldnn/mkldnn_ops-inl.h"
#include "./mkldnn/mkldnn_base-inl.h"
#if MXNET_USE_NNPACK == 1
#include "../nnpack/nnpack_fully_connected-inl.h"
#endif  // MXNET_USE_NNPACK

namespace mxnet {
namespace op {

bool SupportMKLDNNFC(const NDArray& input) {
  int ndim = input.shape().ndim();
  return (input.dtype() == mshadow::kFloat32 || input.dtype() == mshadow::kBfloat16) &&
         (ndim >= 1 && ndim <= 4) &&
         input.storage_type() == kDefaultStorage;
}

static bool FullyConnectedShape(const nnvm::NodeAttrs& attrs,
                                mxnet::ShapeVector *in_shape,
                                mxnet::ShapeVector *out_shape) {
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  using namespace mshadow;
  if (!param.no_bias) {
    CHECK_EQ(in_shape->size(), 3U) << "Input:[data, weight, bias]";
  } else {
    CHECK_EQ(in_shape->size(), 2U) << "Input:[data, weight]";
  }
  CHECK_EQ(out_shape->size(), 1U);
  mxnet::TShape dshape = (*in_shape)[fullc::kData];
  mxnet::TShape oshape = (*out_shape)[0];
  // require data to be known
  if (!mxnet::ndim_is_known(dshape)) return false;

  index_t num_input;
  if (!param.flatten) {
    num_input = dshape[dshape.ndim()-1];
  } else {
    num_input = dshape.ProdShape(1, dshape.ndim());
  }
  SHAPE_ASSIGN_CHECK(*in_shape, fullc::kWeight, Shape2(param.num_hidden, num_input));
  if (!param.no_bias) {
    if (!shape_assign(&(*in_shape)[fullc::kBias], Shape1(param.num_hidden)) &&
        !shape_assign(&(*in_shape)[fullc::kBias], Shape2(param.num_hidden, 1))) {
      LOG(FATAL) << "Unexpected shape for bias " << (*in_shape)[fullc::kBias];
    }
  }

  if (!param.flatten) {
    mxnet::TShape result_shape(dshape);
    result_shape[dshape.ndim()-1] = param.num_hidden;
    SHAPE_ASSIGN_CHECK(*out_shape, 0, result_shape);
  } else {
    SHAPE_ASSIGN_CHECK(*out_shape, 0, Shape2(dshape[0], param.num_hidden));
  }
  if (oshape.ndim() > 0) {
    dshape[0] = oshape[0];
    SHAPE_ASSIGN_CHECK(*in_shape, fullc::kData, dshape);
  }
  return true;
}
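
// Illustrative example of the shape inference above (comment only, not exercised here):
// with data of shape (32, 3, 4, 5) and num_hidden = 10, flatten = true infers
// weight (10, 60) and out (32, 10), while flatten = false infers
// weight (10, 5) and out (32, 3, 4, 10).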

void FullyConnectedComputeExCPU(const nnvm::NodeAttrs& attrs,
                                const OpContext &ctx,
                                const std::vector<NDArray> &inputs,
                                const std::vector<OpReqType> &req,
                                const std::vector<NDArray> &outputs) {
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  const bool valid_data = inputs[0].storage_type() == kDefaultStorage;
  const bool valid_weight = inputs[1].storage_type() == kDefaultStorage ||
                            inputs[1].storage_type() == kRowSparseStorage;
  const bool valid_out = outputs[0].storage_type() == kDefaultStorage;
  bool valid_bias = true;
  if (!param.no_bias) {
    valid_bias = inputs[2].storage_type() == kDefaultStorage ||
                 inputs[2].storage_type() == kRowSparseStorage;
  }
#if MXNET_USE_MKLDNN == 1
  if (common::ContainsOnlyStorage(inputs, kDefaultStorage) &&
      common::ContainsOnlyStorage(outputs, kDefaultStorage)) {
    if (SupportMKLDNNFC(inputs[0])) {
      MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
      MKLDNNRun(MKLDNNFCForward, attrs, ctx, inputs, req, outputs);
      MKLDNN_OPCHECK_RUN(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req,
                         outputs);
    } else {
      FallBackCompute(FullyConnectedCompute<cpu>, attrs, ctx, inputs, req, outputs);
    }
    return;
  } else if (valid_data && valid_weight && valid_bias && valid_out) {
    // inputs
    std::vector<NDArray> temp_ndarrays;
    std::vector<TBlob> in_blobs;
    for (const NDArray& in : inputs) {
      // If the ndarray is in default storage and MKLDNN is available,
      // make sure CPU-layout data is used instead of the MKLDNN layout.
      if (in.storage_type() == kDefaultStorage) {
        temp_ndarrays.push_back(in.Reorder2Default());
        in_blobs.emplace_back(temp_ndarrays.back().data());
      } else {
        in_blobs.emplace_back(in.data());
      }
    }
    // output
    FullyConnectedCompute<cpu>(attrs, ctx, in_blobs, req, {outputs[0].data()});
  } else {
    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
  }
#else
  if (valid_data && valid_weight && valid_bias && valid_out) {
    std::vector<TBlob> in_blobs(inputs.size());
    for (size_t i = 0; i < in_blobs.size(); i++) in_blobs[i] = inputs[i].data();
    std::vector<TBlob> out_blobs(outputs.size());
    for (size_t i = 0; i < out_blobs.size(); i++) out_blobs[i] = outputs[i].data();
    FullyConnectedCompute<cpu>(attrs, ctx, in_blobs, req, out_blobs);
  } else {
    LogUnimplementedOp(attrs, ctx, inputs, req, outputs);
  }
#endif
}

#if MXNET_USE_MKLDNN == 1
void FullyConnectedGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                    const OpContext &ctx,
                                    const std::vector<NDArray> &inputs,
                                    const std::vector<OpReqType> &req,
                                    const std::vector<NDArray> &outputs) {
  if (SupportMKLDNNFC(inputs[0])) {
    MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
    MKLDNNRun(MKLDNNFCBackward, attrs, ctx, inputs, req, outputs);
    MKLDNN_OPCHECK_RUN(FullyConnectedGradCompute<cpu>, attrs, ctx, inputs, req,
                       outputs);
    return;
  }
  FallBackCompute(FullyConnectedGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
}
#endif

static bool FullyConnectedType(const nnvm::NodeAttrs& attrs,
                               std::vector<int> *in_type, std::vector<int> *out_type) {
  CHECK_GE(in_type->size(), 1U);
  return ElemwiseAttr<int, type_is_none, type_assign, true, type_string>(
      attrs, in_type, out_type, -1);
}

struct FullyConnectedGrad {
  const char *op_name;
  std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr& n,
                                          const std::vector<nnvm::NodeEntry>& ograds) const {
    std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
    heads.push_back(n->inputs[fullc::kData]);
    heads.push_back(n->inputs[fullc::kWeight]);
    return MakeGradNode(op_name, n, heads, n->attrs.dict);
  }
};

struct FullyConnectedGradGrad {
  const char *op_name;
  std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr& n,
                                          const std::vector<nnvm::NodeEntry>& ograds) const {
    std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
    heads.push_back(n->inputs[0]);  // o_y : head gradient of the output y
    return MakeGradNode(op_name, n, heads, n->attrs.dict);
  }
};

static bool FCStorageType(const nnvm::NodeAttrs& attrs,
                          const int dev_mask,
                          DispatchMode* dispatch_mode,
                          std::vector<int> *in_attrs,
                          std::vector<int> *out_attrs) {
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  const bool valid_data = in_attrs->at(0) == kDefaultStorage;
  const bool valid_weight = in_attrs->at(1) == kDefaultStorage ||
                            in_attrs->at(1) == kRowSparseStorage;
  bool valid_bias = true;
  uint32_t in_expected = 2;
  if (!param.no_bias) {
    in_expected = 3;
    valid_bias = in_attrs->at(2) == kDefaultStorage || in_attrs->at(2) == kRowSparseStorage;
  }
  CHECK_EQ(in_attrs->size(), in_expected);
  CHECK_EQ(out_attrs->size(), 1);
  // dispatch to kFComputeEx is fine even if all inputs are dense and no MKL is present
  bool dispatched = false;
  if (!dispatched && valid_data && valid_weight && valid_bias) {
    dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
  }
#if MXNET_USE_MKLDNN == 1
  if (!MKLDNNEnvSet())
    *dispatch_mode = DispatchMode::kFComputeFallback;
#endif

  if (!dispatched) {
    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
  }
  return dispatched;
}

static bool BackwardFCStorageType(const nnvm::NodeAttrs& attrs,
                                  const int dev_mask,
                                  DispatchMode* dispatch_mode,
                                  std::vector<int> *in_attrs,
                                  std::vector<int> *out_attrs) {
  const FullyConnectedParam& param = nnvm::get<FullyConnectedParam>(attrs.parsed);
  uint32_t out_expected = param.no_bias ? 2 : 3;
  CHECK_EQ(in_attrs->size(), 3U);
  CHECK_EQ(out_attrs->size(), out_expected);
  bool dispatched = false;
  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, mxnet::kDefaultStorage)) {
    dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
  }
  if (!dispatched && common::ContainsStorageType(*in_attrs, mxnet::kRowSparseStorage)) {
    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
  }
  if (!dispatched) {
    dispatched = storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFCompute);
  }
#if MXNET_USE_MKLDNN == 1
  if (!MKLDNNEnvSet())
    *dispatch_mode = DispatchMode::kFComputeFallback;
#endif
  return dispatched;
}

DMLC_REGISTER_PARAMETER(FullyConnectedParam);

NNVM_REGISTER_OP(FullyConnected)
MXNET_ADD_SPARSE_OP_ALIAS(FullyConnected)
.add_alias("_npx_fully_connected")
.describe(R"code(Applies a linear transformation: :math:`Y = XW^T + b`.

If ``flatten`` is set to true, then the shapes are:

- **data**: `(batch_size, x1, x2, ..., xn)`
- **weight**: `(num_hidden, x1 * x2 * ... * xn)`
- **bias**: `(num_hidden,)`
- **out**: `(batch_size, num_hidden)`

If ``flatten`` is set to false, then the shapes are:

- **data**: `(x1, x2, ..., xn, input_dim)`
- **weight**: `(num_hidden, input_dim)`
- **bias**: `(num_hidden,)`
- **out**: `(x1, x2, ..., xn, num_hidden)`

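For example, with ``data`` of shape `(32, 3, 4, 5)` and ``num_hidden = 10``, ``flatten=True``
expects ``weight`` of shape `(10, 60)` and produces ``out`` of shape `(32, 10)`, while
``flatten=False`` expects ``weight`` of shape `(10, 5)` and produces ``out`` of shape
`(32, 3, 4, 10)`.
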
The learnable parameters include both ``weight`` and ``bias``.

If ``no_bias`` is set to true, then the ``bias`` term is ignored.

.. Note::

    The sparse support for FullyConnected is limited to forward evaluation with `row_sparse`
    weight and bias, where the length of `weight.indices` and `bias.indices` must be equal
    to `num_hidden`. This could be useful for model inference with `row_sparse` weights
    trained with importance sampling or noise contrastive estimation.

    To compute a linear transformation with 'csr' sparse data, sparse.dot is recommended instead
    of sparse.FullyConnected.

)code" ADD_FILELINE)
.set_num_inputs([](const NodeAttrs& attrs) {
  const FullyConnectedParam& params = nnvm::get<FullyConnectedParam>(attrs.parsed);
  return params.no_bias ? 2 : 3;
})
.set_num_outputs(1)
.set_attr_parser(ParamParser<FullyConnectedParam>)
.set_attr<FInferStorageType>("FInferStorageType", FCStorageType)
.set_attr<nnvm::FListInputNames>("FListInputNames", [](const NodeAttrs& attrs) {
  const FullyConnectedParam& params = nnvm::get<FullyConnectedParam>(attrs.parsed);
  if (!params.no_bias) {
    return std::vector<std::string>{"data", "weight", "bias"};
  } else {
    return std::vector<std::string>{"data", "weight"};
  }
})
.set_attr<nnvm::FListOutputNames>("FListOutputNames",
    [](const NodeAttrs& attrs) {
    return std::vector<std::string>{"output"};
})
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<mxnet::FInferShape>("FInferShape", FullyConnectedShape)
.set_attr<nnvm::FInferType>("FInferType", FullyConnectedType)
.set_attr<FCompute>("FCompute<cpu>", FullyConnectedCompute<cpu>)
.set_attr<FComputeEx>("FComputeEx<cpu>", FullyConnectedComputeExCPU)
.set_attr<nnvm::FGradient>("FGradient", FullyConnectedGrad{"_backward_FullyConnected"})
.add_argument("data", "NDArray-or-Symbol", "Input data.")
.add_argument("weight", "NDArray-or-Symbol", "Weight matrix.")
.add_argument("bias", "NDArray-or-Symbol", "Bias parameter.")
.add_arguments(FullyConnectedParam::__FIELDS__());
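
// Sketch of the first-order gradients computed by the backward op registered below
// (comment only, assuming the standard derivation for Y = X W^T + b with row-major data):
//   dL/dX = dL/dY * W            // shape (batch, input_dim)
//   dL/dW = (dL/dY)^T * X        // shape (num_hidden, input_dim)
//   dL/db = sum of dL/dY over the batch
// The registration wires these up via FullyConnectedGradCompute (and MKLDNNFCBackward
// when MKLDNN is enabled).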

NNVM_REGISTER_OP(_backward_FullyConnected)
.set_num_inputs(3)
.set_num_outputs([](const NodeAttrs& attrs) {
  const FullyConnectedParam& params = nnvm::get<FullyConnectedParam>(attrs.parsed);
  return params.no_bias ? 2 : 3;
})
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<nnvm::FInplaceOption>("FInplaceOption", [](const NodeAttrs& attrs){
  return std::vector<std::pair<int, int> >{{1, 0}};
})
.set_attr<nnvm::FGradient>("FGradient", FullyConnectedGradGrad{"_backward_backward_FullyConnected"})
.set_attr<FInferStorageType>("FInferStorageType", BackwardFCStorageType)
.set_attr_parser(ParamParser<FullyConnectedParam>)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", FullyConnectedGradComputeExCPU)
#endif
.set_attr<FCompute>("FCompute<cpu>", FullyConnectedGradCompute<cpu>);

// 2nd gradient for fully connected
// Inputs are:
// o_x_grad : head gradient for x_grad
// o_w_grad : head gradient for w_grad
// o_b_grad : if param.no_bias is false
// o_y : head gradient of y
//
// Outputs are:
// o_y_grad : not used
// x_grad_grad : o_w_grad * o_y^T
// w_grad_grad : o_x_grad * o_y
//
// For a detailed development of the second gradient see here: TODO(larroy)
NNVM_REGISTER_OP(_backward_backward_FullyConnected)
.set_num_inputs([](const NodeAttrs& attrs) {
  const FullyConnectedParam& params = nnvm::get<FullyConnectedParam>(attrs.parsed);
  return params.no_bias ? 3 : 4;
})
.set_num_outputs([](const NodeAttrs& attrs) {
  return 3;
})
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr_parser(ParamParser<FullyConnectedParam>)
.set_attr<FCompute>("FCompute<cpu>", FullyConnectedGradGradDTypeDispatch<cpu>);

}  // namespace op
}  // namespace mxnet