1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file concat.cc
22  * \brief
23  * \author Bing Xu
24 */
25 
26 #include "./concat-inl.h"
27 #include "./mkldnn/mkldnn_ops-inl.h"
28 #include "./mkldnn/mkldnn_base-inl.h"
29 #include "../../common/utils.h"
30 
31 namespace mxnet {
32 namespace op {
33 
// Shape inference for Concat.
// The output size along `param_.dim` is the sum of the input sizes on that
// axis; every other axis must agree across all inputs and the output.
// Returns true only when all shapes are fully known after inference.
bool ConcatShape(const nnvm::NodeAttrs& attrs,
                 mxnet::ShapeVector *in_shape,
                 mxnet::ShapeVector *out_shape) {
  using namespace mshadow;
  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
  mxnet::TShape dshape;  // unified shape with the concat axis masked out (-1)
  dim_t size = 0;        // running sum of sizes on the concat axis; only
                         // meaningful when no input axis size is unknown
  bool has_unknown_dim_size = false;
  int axis = -1;
  for (int i = 0; i < param_.num_args; ++i) {
    mxnet::TShape tmp = (*in_shape)[i];
    if (tmp.ndim() > 0) {
      axis = CheckAxis(param_.dim, tmp.ndim());
      has_unknown_dim_size = !mxnet::dim_size_is_known(tmp, axis) || has_unknown_dim_size;
      size += tmp[axis];
      // Mask the concat axis so shape_assign only unifies the other axes.
      tmp[axis] = -1;
      shape_assign(&dshape, tmp);
    }
  }

  // The output shape (if known) constrains the non-concat axes as well.
  mxnet::TShape tmp = (*out_shape)[0];
  if (tmp.ndim() > 0) {
    axis = CheckAxis(param_.dim, tmp.ndim());
    tmp[axis] = -1;
    shape_assign(&dshape, tmp);
  }

  // No shape information from any side yet: defer inference.
  if (dshape.ndim() == -1) return false;
  CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated";

  // Push the unified shape back to every input; the concat axis stays -1 in
  // dshape, so each input keeps its own (possibly unknown) size there.
  for (int i = 0; i < param_.num_args; ++i) {
    CHECK(shape_assign(&(*in_shape)[i], dshape))
        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
  }

  // Only set the output's concat-axis size when every input size was known.
  if (!has_unknown_dim_size) dshape[axis] = size;
  CHECK(shape_assign(&(*out_shape)[0], dshape))
      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];

  return shape_is_known(dshape);
}
76 
77 // Concat for RNN param deals with the reverse shape inference from output
78 // for the special case of concatenating RNN parameters.
79 // The first (and sometimes the second) input may be unknown on the target axis.
80 // If the two inputs are unknown, they always have the same shape.
// Shape inference for _rnn_param_concat. Like ConcatShape, but additionally
// performs reverse inference: when the output's concat-axis size is known,
// the leftover size (output minus the known inputs) is distributed evenly
// over the inputs whose concat-axis size is unknown.
static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
                                mxnet::ShapeVector *in_shape,
                                mxnet::ShapeVector *out_shape) {
  using namespace mshadow;
  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
  mxnet::TShape dshape;            // unified shape, concat axis masked out
  index_t size = 0;                // sum of the *known* concat-axis sizes
  std::vector<int> zero_indices;   // inputs with unknown concat-axis size
  int axis = -1;
  for (int i = 0; i < param_.num_args; ++i) {
    mxnet::TShape tmp = (*in_shape)[i];
    if (tmp.ndim() > 0) {
      axis = CheckAxis(param_.dim, tmp.ndim());
      if (!mxnet::dim_size_is_known(tmp, axis)) {
        zero_indices.emplace_back(i);
      } else {
        CHECK_GE(tmp[axis], 0);
        size += tmp[axis];
      }
      // Mask the concat axis so shape_assign only unifies the other axes.
      tmp[axis] = -1;
      shape_assign(&dshape, tmp);
    }
  }

  // The output shape (if known) constrains the non-concat axes as well.
  mxnet::TShape tmp = (*out_shape)[0];
  if (tmp.ndim() > 0) {
    axis = CheckAxis(param_.dim, tmp.ndim());
    tmp[axis] = -1;
    shape_assign(&dshape, tmp);
  }

  if (!mxnet::ndim_is_known(dshape)) return false;

  // Push the unified shape back to every input.
  for (int i = 0; i < param_.num_args; ++i) {
    CHECK(shape_assign(&(*in_shape)[i], dshape))
        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
  }

  if (zero_indices.empty()) dshape[axis] = size;
  CHECK(shape_assign(&(*out_shape)[0], dshape))
      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
  // Reverse inference: output axis size known, some inputs still unknown.
  if ((*out_shape)[0][axis] != -1 && !zero_indices.empty()) {
    int residual = (*out_shape)[0][axis] - size;
    CHECK_GE(residual, 0)
        << "Input size already exceeds output size. Residual: " << residual;
    CHECK(zero_indices.size() <= 2 && zero_indices.size() > 0)
        << "Expecting 1 or 2 inputs that need shape inference. Got: " << zero_indices.size();
    bool need_infer = !shape_is_known((*out_shape)[0]);
    for (int i : zero_indices) {
      // Integer division: assumes the unknown inputs share the residual
      // equally (per the header comment, two unknown inputs always have the
      // same shape). NOTE(review): a residual not divisible by the count
      // would silently under-assign — presumed not to occur for RNN params.
      (*in_shape)[i][axis] = residual / zero_indices.size();
      need_infer = need_infer || !shape_is_known((*in_shape)[i]);
    }
    return !need_infer;
  }

  return shape_is_known(dshape);
}
139 
ConcatType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_type,std::vector<int> * out_type)140 bool ConcatType(const nnvm::NodeAttrs& attrs,
141                 std::vector<int> *in_type,
142                 std::vector<int> *out_type) {
143   const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
144   int dtype = -1;
145 
146   // checks uniformity of input
147   for (size_t i =0; i < in_type->size(); ++i) {
148     if (dtype == -1) {
149       dtype = in_type->at(i);
150     } else {
151       CHECK(in_type->at(i) == dtype || in_type->at(i) == -1)
152           << "Non-uniform data type in "  << attrs.op->name
153           << ", expected data type " << mxnet::op::type_string(dtype)
154           << ", got data type " << mxnet::op::type_string(in_type->at(i))
155           << " for input " << i;
156     }
157   }
158 
159   size_t nin = param_.num_args;
160 
161   // if in types are known out types are unknown
162   if (dtype != -1 && (*out_type)[0] == -1) {
163     (*out_type)[0] = dtype;
164     in_type->clear();
165     for (size_t i = 0; i < nin; ++i) {
166       in_type->push_back(dtype);
167     }
168   // if out types are known in types are unknown
169   } else if ((*out_type)[0] != -1 && dtype == -1) {
170     in_type->clear();
171     for (size_t i = 0; i < nin; ++i) {
172       in_type->push_back((*out_type)[0]);
173     }
174   // if both out_types and in_types are known, and different
175   } else if ((*out_type)[0] != -1 && dtype != -1 && ((*out_type)[0] != dtype)) {
176     std::ostringstream os;
177     os << "Type inconsistent, Provided output type = "
178        << mxnet::op::type_string((*out_type)[0]) << ','
179        << " inferred type = " << mxnet::op::type_string(dtype);
180     throw mxnet::op::InferTypeError(os.str(), 0);
181   }
182   return true;
183 }
184 
// Storage-type inference for the Concat forward pass. Dispatch priority:
//   1. all-CSR inputs concatenated on dim 0  -> CSR output via FComputeEx;
//   2. (MKLDNN builds) all-dense inputs, dim > 0 -> dense via FComputeEx;
//   3. all-dense inputs otherwise            -> dense via plain FCompute;
//   4. anything else                          -> fallback dispatch.
inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                                 const int dev_mask,
                                                 DispatchMode* dispatch_mode,
                                                 std::vector<int> *in_attrs,
                                                 std::vector<int> *out_attrs) {
  CHECK(!in_attrs->empty());
  CHECK_EQ(out_attrs->size(), 1U);
  auto& out_stype = out_attrs->at(0);
  bool dispatched = false;
  const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
  // Sparse concat is only implemented along the first axis.
  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kCSRStorage)
      && param.dim == 0) {
    dispatched = storage_type_assign(&out_stype, kCSRStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
  }
#if MXNET_USE_MKLDNN == 1
  // MKLDNN path: dense CPU tensors, non-zero concat axis.
  if (!dispatched && dev_mask == mshadow::cpu::kDevMask
      && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
      && param.dim > 0) {
    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
  }
#endif  // MXNET_USE_MKLDNN == 1
  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFCompute);
  }
  if (!dispatched) {
    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
  }
#if MXNET_USE_MKLDNN == 1
  // MXNET_MKLDNN_ENABLED=0 in the environment forces the fallback kernel.
  if (!MKLDNNEnvSet())
    *dispatch_mode = DispatchMode::kFComputeFallback;
#endif  // MXNET_USE_MKLDNN == 1
  return dispatched;
}
221 
// Storage-type inference for the Concat backward pass. All gradients are
// dense; only the dispatch mode varies (MKLDNN FComputeEx when available
// and applicable, otherwise the plain FCompute kernel).
inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
                                             const int dev_mask,
                                             DispatchMode* dispatch_mode,
                                             std::vector<int> *in_attrs,
                                             std::vector<int> *out_attrs) {
  DispatchMode wanted_mode;
#if MXNET_USE_MKLDNN == 1
  const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
  // MKLDNN builds pass the output gradient plus all forward inputs, so the
  // backward op has num_args + 1 inputs and num_args outputs.
  CHECK_EQ(out_attrs->size(), in_attrs->size() - 1);
  if (dev_mask == mshadow::cpu::kDevMask
      && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
      && param.dim > 0)
    wanted_mode = DispatchMode::kFComputeEx;
  else
#endif  // MXNET_USE_MKLDNN == 1
    wanted_mode = DispatchMode::kFCompute;
#if MXNET_USE_MKLDNN == 1
  // MXNET_MKLDNN_ENABLED=0 in the environment forces the fallback kernel.
  if (!MKLDNNEnvSet())
    wanted_mode = DispatchMode::kFComputeFallback;
#endif  // MXNET_USE_MKLDNN == 1
  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                             dispatch_mode, wanted_mode);
}
245 #if MXNET_USE_MKLDNN == 1
SupportMKLDNNConcat(const std::vector<NDArray> & arrs)246 bool SupportMKLDNNConcat(const std::vector<NDArray> &arrs) {
247   for (auto &arr : arrs) {
248     if (arr.IsView()) return false;
249     if (!(arr.dtype() == mshadow::kFloat32 || arr.dtype() == mshadow::kBfloat16)) return false;
250     // DO not support zero-size tensors.
251     if (arr.shape().Size() == 0) return false;
252     int ndim = arr.shape().ndim();
253     const int mkldnn_ndims = arr.GetMKLDNNData()->get_desc().data.ndims;
254     if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims) return false;
255   }
256   return true;
257 }
258 #endif  // MXNET_USE_MKLDNN == 1
// FComputeEx entry point for Concat on CPU. Routes to the CSR, MKLDNN, or
// generic dense implementation depending on the input storage types.
static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
                               const OpContext& op_ctx,
                               const std::vector<NDArray>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<NDArray>& outputs) {
  CHECK(!inputs.empty());
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  if (req[0] == kNullOp) return;  // nothing to write
  if (common::ContainsOnlyStorage(inputs, kCSRStorage) &&
      outputs[0].storage_type() == kCSRStorage) {
    // Sparse path (dim == 0 only, enforced by storage-type inference).
    ConcatCSRImpl<cpu>(attrs, op_ctx, inputs, req, outputs);
#if MXNET_USE_MKLDNN == 1
  } else if (SupportMKLDNNConcat(inputs)) {
    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
    MKLDNNRun(MKLDNNConcatForward, attrs, op_ctx, inputs, req, outputs);
    // Optionally compares the MKLDNN result against the reference kernel.
    MKLDNN_OPCHECK_RUN(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
  } else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
    // Dense arrays that MKLDNN cannot handle go through the fallback.
    // (In non-MKLDNN builds, dense inputs are dispatched to FCompute by
    // storage-type inference and presumably never reach this function.)
    FallBackCompute(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
#endif  // MXNET_USE_MKLDNN == 1
  } else {
    LogUnimplementedOp(attrs, op_ctx, inputs, req, outputs);
  }
}
283 
284 #if MXNET_USE_MKLDNN == 1
ConcatGradComputeExCPU(const nnvm::NodeAttrs & attrs,const OpContext & ctx,const std::vector<NDArray> & inputs,const std::vector<OpReqType> & req,const std::vector<NDArray> & outputs)285 static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
286                                    const OpContext& ctx,
287                                    const std::vector<NDArray>& inputs,
288                                    const std::vector<OpReqType>& req,
289                                    const std::vector<NDArray>& outputs) {
290   if (SupportMKLDNNConcat(inputs)) {
291     MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
292     MKLDNNRun(MKLDNNConcatBackward, attrs, ctx, inputs, req, outputs);
293     MKLDNN_OPCHECK_RUN(ConcatGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
294     return;
295   }
296   FallBackCompute(ConcatGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
297 }
298 #endif  // MXNET_USE_MKLDNN == 1
299 
300 struct ConcatGrad {
301   const char *op_name;
operator ()mxnet::op::ConcatGrad302   std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr& n,
303                                           const std::vector<nnvm::NodeEntry>& ograds) const {
304     CHECK_EQ(ograds.size(), 1);
305     std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
306 #if MXNET_USE_MKLDNN == 1
307     for (size_t i = 0; i < n->inputs.size(); i++) {
308       heads.push_back(n->inputs[i]);
309     }
310 #endif  // MXNET_USE_MKLDNN == 1
311     return MakeGradNode(op_name, n, heads, n->attrs.dict);
312   }
313 };
314 
DMLC_REGISTER_PARAMETER(ConcatParam);

// Forward-op attributes shared by the `Concat` and `_rnn_param_concat`
// registrations below: variadic inputs named arg0..argN, one output, type
// and storage-type inference, the CPU kernels, and the gradient builder.
// Shape inference is NOT included here because the two ops differ in it.
// (Comments cannot go inside the macro: a // comment would swallow the
// line-continuation backslash.)
#define CONCAT_FORWARD_ATTRS \
.set_num_inputs([](const NodeAttrs& attrs) { \
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
  return params.num_args; \
}) \
.set_num_outputs(1) \
.set_attr_parser(ParamParser<ConcatParam>) \
.set_attr<nnvm::FListInputNames>("FListInputNames", \
    [](const NodeAttrs& attrs) { \
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
  std::vector<std::string> ret; \
  for (int i = 0; i < params.num_args; ++i) { \
    ret.push_back(std::string("arg") + std::to_string(i)); \
  } \
  return ret; \
}) \
.set_attr<nnvm::FListOutputNames>("FListOutputNames", \
    [](const NodeAttrs& attrs) { \
    return std::vector<std::string>{"output"}; \
}) \
.set_attr<nnvm::FInferType>("FInferType", ConcatType) \
.set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType) \
.set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>) \
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU) \
.set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"}) \
.set_attr<std::string>("key_var_num_args", "num_args")
343 
344 
// Registration of the forward Concat operator (aliased as `concat`).
NNVM_REGISTER_OP(Concat)
MXNET_ADD_SPARSE_OP_ALIAS(concat)
.add_alias("concat")
.describe(R"code(Joins input arrays along a given axis.

.. note:: `Concat` is deprecated. Use `concat` instead.

The dimensions of the input arrays should be the same except the axis along
which they will be concatenated.
The dimension of the output array along the concatenated axis will be equal
to the sum of the corresponding dimensions of the input arrays.

The storage type of ``concat`` output depends on storage types of inputs

- concat(csr, csr, ..., csr, dim=0) = csr
- otherwise, ``concat`` generates output with default storage

Example::

   x = [[1,1],[2,2]]
   y = [[3,3],[4,4],[5,5]]
   z = [[6,6], [7,7],[8,8]]

   concat(x,y,z,dim=0) = [[ 1.,  1.],
                          [ 2.,  2.],
                          [ 3.,  3.],
                          [ 4.,  4.],
                          [ 5.,  5.],
                          [ 6.,  6.],
                          [ 7.,  7.],
                          [ 8.,  8.]]

   Note that you cannot concat x,y,z along dimension 1 since dimension
   0 is not the same for all the input arrays.

   concat(y,z,dim=1) = [[ 3.,  3.,  6.,  6.],
                         [ 4.,  4.,  7.,  7.],
                         [ 5.,  5.,  8.,  8.]]

)code" ADD_FILELINE)
#if MXNET_USE_MKLDNN == 1
// The MKLDNN kernel may need scratch space for layout conversions.
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<bool>("TIsMKLDNN", true)
#endif  // MXNET_USE_MKLDNN == 1
CONCAT_FORWARD_ATTRS
.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());
396 
// Registration of the Concat backward operator.
NNVM_REGISTER_OP(_backward_Concat)
.set_num_inputs([](const NodeAttrs& attrs) {
// MKLDNN builds receive the output gradient plus all forward inputs
// (see ConcatGrad); plain builds receive only the output gradient.
#if MXNET_USE_MKLDNN == 1
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return 1 + params.num_args;
#else
  return 1;
#endif
})
// One gradient output per forward input.
.set_num_outputs([](const NodeAttrs& attrs) {
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return params.num_args;
})
.set_attr_parser(ParamParser<ConcatParam>)
#if MXNET_USE_MKLDNN == 1
// The MKLDNN kernel may need scratch space for layout conversions.
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif  // MXNET_USE_MKLDNN == 1
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BackwardConcatStorageType)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatGradComputeExCPU)
#endif  // MXNET_USE_MKLDNN == 1
.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);
423 
424 // _rnn_param_concat is a custom concat op with specialized infer_shape,
425 // which handles the case where the first one or two inputs may have
426 // unknown shape that can be inferred from output shape.
// Registration of _rnn_param_concat: identical to Concat except that shape
// inference (RNNParamConcatShape) can also run backward from the output.
NNVM_REGISTER_OP(_rnn_param_concat)
.add_alias("_npi_rnn_param_concat")
#if MXNET_USE_MKLDNN == 1
// The MKLDNN kernel may need scratch space for layout conversions.
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif  // MXNET_USE_MKLDNN == 1
CONCAT_FORWARD_ATTRS
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<mxnet::FInferShape>("FInferShape", RNNParamConcatShape)
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());
439 
440 }  // namespace op
441 }  // namespace mxnet
442