/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file concat.cc
 * \brief CPU registration of the Concat operator and its shape/type/storage inference.
 * \author Bing Xu
 */

#include "./concat-inl.h"
#include "./mkldnn/mkldnn_ops-inl.h"
#include "./mkldnn/mkldnn_base-inl.h"
#include "../../common/utils.h"

namespace mxnet {
namespace op {

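// Shape inference for Concat: all inputs and the output must agree on every
// dimension except the concatenation axis, and the output extent on that axis
// is the sum of the input extents. For example, with dim=1, inputs of shape
// (2, 3) and (2, 5) give an output of shape (2, 8).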
bool ConcatShape(const nnvm::NodeAttrs& attrs,
                 mxnet::ShapeVector *in_shape,
                 mxnet::ShapeVector *out_shape) {
  using namespace mshadow;
  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
  mxnet::TShape dshape;
  dim_t size = 0;
  bool has_unknown_dim_size = false;
  int axis = -1;
  for (int i = 0; i < param_.num_args; ++i) {
    mxnet::TShape tmp = (*in_shape)[i];
    if (tmp.ndim() > 0) {
      axis = CheckAxis(param_.dim, tmp.ndim());
      has_unknown_dim_size = !mxnet::dim_size_is_known(tmp, axis) || has_unknown_dim_size;
      size += tmp[axis];
      tmp[axis] = -1;
      shape_assign(&dshape, tmp);
    }
  }

  mxnet::TShape tmp = (*out_shape)[0];
  if (tmp.ndim() > 0) {
    axis = CheckAxis(param_.dim, tmp.ndim());
    tmp[axis] = -1;
    shape_assign(&dshape, tmp);
  }

  if (dshape.ndim() == -1) return false;
  CHECK_NE(dshape.ndim(), 0) << "zero-dimensional arrays cannot be concatenated";

  for (int i = 0; i < param_.num_args; ++i) {
    CHECK(shape_assign(&(*in_shape)[i], dshape))
        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
  }

  if (!has_unknown_dim_size) dshape[axis] = size;
  CHECK(shape_assign(&(*out_shape)[0], dshape))
      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];

  return shape_is_known(dshape);
}

// Shape inference for _rnn_param_concat: handles reverse shape inference from
// the output for the special case of concatenating RNN parameters.
// The first (and sometimes the second) input may be unknown on the target
// axis; when two inputs are unknown, they are assumed to have the same shape.
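// For example, with dim=0, known inputs contributing a total of 9 rows and an
// output shape of (15, 4), the residual 15 - 9 = 6 is split evenly among the
// inputs whose extent on the concatenation axis is still unknown.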
static bool RNNParamConcatShape(const nnvm::NodeAttrs& attrs,
                                mxnet::ShapeVector *in_shape,
                                mxnet::ShapeVector *out_shape) {
  using namespace mshadow;
  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(in_shape->size(), static_cast<size_t>(param_.num_args));
  mxnet::TShape dshape;
  index_t size = 0;
  std::vector<int> zero_indices;
  int axis = -1;
  for (int i = 0; i < param_.num_args; ++i) {
    mxnet::TShape tmp = (*in_shape)[i];
    if (tmp.ndim() > 0) {
      axis = CheckAxis(param_.dim, tmp.ndim());
      if (!mxnet::dim_size_is_known(tmp, axis)) {
        zero_indices.emplace_back(i);
      } else {
        CHECK_GE(tmp[axis], 0);
        size += tmp[axis];
      }
      tmp[axis] = -1;
      shape_assign(&dshape, tmp);
    }
  }

  mxnet::TShape tmp = (*out_shape)[0];
  if (tmp.ndim() > 0) {
    axis = CheckAxis(param_.dim, tmp.ndim());
    tmp[axis] = -1;
    shape_assign(&dshape, tmp);
  }

  if (!mxnet::ndim_is_known(dshape)) return false;

  for (int i = 0; i < param_.num_args; ++i) {
    CHECK(shape_assign(&(*in_shape)[i], dshape))
        << "Incompatible input shape: expected " << dshape << ", got " << (*in_shape)[i];
  }

  if (zero_indices.empty()) dshape[axis] = size;
  CHECK(shape_assign(&(*out_shape)[0], dshape))
      << "Incompatible output shape: expected " << dshape << ", got " << (*out_shape)[0];
  if ((*out_shape)[0][axis] != -1 && !zero_indices.empty()) {
    int residual = (*out_shape)[0][axis] - size;
    CHECK_GE(residual, 0)
        << "Input size already exceeds output size. Residual: " << residual;
    CHECK(zero_indices.size() <= 2 && zero_indices.size() > 0)
        << "Expecting 1 or 2 inputs that need shape inference. Got: " << zero_indices.size();
    bool need_infer = !shape_is_known((*out_shape)[0]);
    for (int i : zero_indices) {
      (*in_shape)[i][axis] = residual / zero_indices.size();
      need_infer = need_infer || !shape_is_known((*in_shape)[i]);
    }
    return !need_infer;
  }

  return shape_is_known(dshape);
}

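// Type inference for Concat: all inputs and the output share one dtype.
// For example, input types (float32, unknown, float32) resolve to float32 and
// the unknown slot is filled in; a known output type that disagrees with the
// inputs raises InferTypeError.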
bool ConcatType(const nnvm::NodeAttrs& attrs,
                std::vector<int> *in_type,
                std::vector<int> *out_type) {
  const ConcatParam& param_ = nnvm::get<ConcatParam>(attrs.parsed);
  int dtype = -1;

  // Check that the known input types are uniform.
  for (size_t i = 0; i < in_type->size(); ++i) {
    if (dtype == -1) {
      dtype = in_type->at(i);
    } else {
      CHECK(in_type->at(i) == dtype || in_type->at(i) == -1)
          << "Non-uniform data type in " << attrs.op->name
          << ", expected data type " << mxnet::op::type_string(dtype)
          << ", got data type " << mxnet::op::type_string(in_type->at(i))
          << " for input " << i;
    }
  }

  size_t nin = param_.num_args;

  // If the input types are known but the output type is not, propagate forward.
  if (dtype != -1 && (*out_type)[0] == -1) {
    (*out_type)[0] = dtype;
    in_type->clear();
    for (size_t i = 0; i < nin; ++i) {
      in_type->push_back(dtype);
    }
  // If the output type is known but the input types are not, propagate backward.
  } else if ((*out_type)[0] != -1 && dtype == -1) {
    in_type->clear();
    for (size_t i = 0; i < nin; ++i) {
      in_type->push_back((*out_type)[0]);
    }
  // If both the output and input types are known but differ, raise an error.
  } else if ((*out_type)[0] != -1 && dtype != -1 && ((*out_type)[0] != dtype)) {
    std::ostringstream os;
    os << "Type inconsistent: provided output type = "
       << mxnet::op::type_string((*out_type)[0]) << ','
       << " inferred type = " << mxnet::op::type_string(dtype);
    throw mxnet::op::InferTypeError(os.str(), 0);
  }
  return true;
}

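// Storage-type dispatch for the forward pass, in priority order: all-CSR
// inputs with dim=0 use the sparse kernel (FComputeEx); on CPU with MKLDNN
// enabled, all-dense inputs with dim > 0 take the MKLDNN path; all-dense
// inputs otherwise use the plain FCompute kernel; anything else falls back.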
inline static bool ConcatForwardInferStorageType(const nnvm::NodeAttrs& attrs,
                                                 const int dev_mask,
                                                 DispatchMode* dispatch_mode,
                                                 std::vector<int> *in_attrs,
                                                 std::vector<int> *out_attrs) {
  CHECK(!in_attrs->empty());
  CHECK_EQ(out_attrs->size(), 1U);
  auto& out_stype = out_attrs->at(0);
  bool dispatched = false;
  const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kCSRStorage)
      && param.dim == 0) {
    dispatched = storage_type_assign(&out_stype, kCSRStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
  }
#if MXNET_USE_MKLDNN == 1
  if (!dispatched && dev_mask == mshadow::cpu::kDevMask
      && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
      && param.dim > 0) {
    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFComputeEx);
  }
#endif  // MXNET_USE_MKLDNN == 1
  if (!dispatched && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)) {
    dispatched = storage_type_assign(&out_stype, kDefaultStorage,
                                     dispatch_mode, DispatchMode::kFCompute);
  }
  if (!dispatched) {
    dispatched = dispatch_fallback(out_attrs, dispatch_mode);
  }
#if MXNET_USE_MKLDNN == 1
  if (!MKLDNNEnvSet())
    *dispatch_mode = DispatchMode::kFComputeFallback;
#endif  // MXNET_USE_MKLDNN == 1
  return dispatched;
}

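// Storage-type dispatch for the backward pass: the gradients are always
// dense. The MKLDNN path is chosen on CPU for all-dense inputs with dim > 0,
// unless MKLDNN is disabled via the environment, which forces the fallback.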
inline static bool BackwardConcatStorageType(const nnvm::NodeAttrs& attrs,
                                             const int dev_mask,
                                             DispatchMode* dispatch_mode,
                                             std::vector<int> *in_attrs,
                                             std::vector<int> *out_attrs) {
  DispatchMode wanted_mode;
#if MXNET_USE_MKLDNN == 1
  const ConcatParam& param = nnvm::get<ConcatParam>(attrs.parsed);
  CHECK_EQ(out_attrs->size(), in_attrs->size() - 1);
  if (dev_mask == mshadow::cpu::kDevMask
      && common::ContainsOnlyStorage(*in_attrs, kDefaultStorage)
      && param.dim > 0)
    wanted_mode = DispatchMode::kFComputeEx;
  else
#endif  // MXNET_USE_MKLDNN == 1
    wanted_mode = DispatchMode::kFCompute;
#if MXNET_USE_MKLDNN == 1
  if (!MKLDNNEnvSet())
    wanted_mode = DispatchMode::kFComputeFallback;
#endif  // MXNET_USE_MKLDNN == 1
  return storage_type_assign(out_attrs, mxnet::kDefaultStorage,
                             dispatch_mode, wanted_mode);
}

#if MXNET_USE_MKLDNN == 1
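// MKLDNN concat handles only non-view, non-empty float32/bfloat16 arrays that
// are 2- or 4-dimensional and whose MKLDNN descriptor rank matches the
// NDArray rank.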
bool SupportMKLDNNConcat(const std::vector<NDArray> &arrs) {
  for (auto &arr : arrs) {
    if (arr.IsView()) return false;
    if (!(arr.dtype() == mshadow::kFloat32 || arr.dtype() == mshadow::kBfloat16)) return false;
    // Do not support zero-size tensors.
    if (arr.shape().Size() == 0) return false;
    int ndim = arr.shape().ndim();
    const int mkldnn_ndims = arr.GetMKLDNNData()->get_desc().data.ndims;
    if (!(ndim == 2 || ndim == 4) || ndim != mkldnn_ndims) return false;
  }
  return true;
}
#endif  // MXNET_USE_MKLDNN == 1
static void ConcatComputeExCPU(const nnvm::NodeAttrs& attrs,
                               const OpContext& op_ctx,
                               const std::vector<NDArray>& inputs,
                               const std::vector<OpReqType>& req,
                               const std::vector<NDArray>& outputs) {
  CHECK(!inputs.empty());
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);
  if (req[0] == kNullOp) return;
  if (common::ContainsOnlyStorage(inputs, kCSRStorage) &&
      outputs[0].storage_type() == kCSRStorage) {
    ConcatCSRImpl<cpu>(attrs, op_ctx, inputs, req, outputs);
#if MXNET_USE_MKLDNN == 1
  } else if (SupportMKLDNNConcat(inputs)) {
    MKLDNN_OPCHECK_INIT(false, outputs.size(), inputs, outputs);
    MKLDNNRun(MKLDNNConcatForward, attrs, op_ctx, inputs, req, outputs);
    MKLDNN_OPCHECK_RUN(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
  } else if (common::ContainsOnlyStorage(inputs, kDefaultStorage)) {
    FallBackCompute(ConcatCompute<cpu>, attrs, op_ctx, inputs, req, outputs);
#endif  // MXNET_USE_MKLDNN == 1
  } else {
    LogUnimplementedOp(attrs, op_ctx, inputs, req, outputs);
  }
}

#if MXNET_USE_MKLDNN == 1
static void ConcatGradComputeExCPU(const nnvm::NodeAttrs& attrs,
                                   const OpContext& ctx,
                                   const std::vector<NDArray>& inputs,
                                   const std::vector<OpReqType>& req,
                                   const std::vector<NDArray>& outputs) {
  if (SupportMKLDNNConcat(inputs)) {
    MKLDNN_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
    MKLDNNRun(MKLDNNConcatBackward, attrs, ctx, inputs, req, outputs);
    MKLDNN_OPCHECK_RUN(ConcatGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
    return;
  }
  FallBackCompute(ConcatGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
}
#endif  // MXNET_USE_MKLDNN == 1

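// Gradient functor for Concat. The backward node always receives the output
// gradient; when MKLDNN is enabled, the forward inputs are appended as well,
// presumably so the backward kernel can recover their shapes and layouts.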
struct ConcatGrad {
  const char *op_name;
  std::vector<nnvm::NodeEntry> operator()(const nnvm::ObjectPtr& n,
                                          const std::vector<nnvm::NodeEntry>& ograds) const {
    CHECK_EQ(ograds.size(), 1);
    std::vector<nnvm::NodeEntry> heads(ograds.begin(), ograds.end());
#if MXNET_USE_MKLDNN == 1
    for (size_t i = 0; i < n->inputs.size(); i++) {
      heads.push_back(n->inputs[i]);
    }
#endif  // MXNET_USE_MKLDNN == 1
    return MakeGradNode(op_name, n, heads, n->attrs.dict);
  }
};

DMLC_REGISTER_PARAMETER(ConcatParam);

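// Forward attributes shared by Concat and _rnn_param_concat; each
// registration supplies its own FInferShape separately.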
#define CONCAT_FORWARD_ATTRS \
  .set_num_inputs([](const NodeAttrs& attrs) { \
    const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
    return params.num_args; \
  }) \
  .set_num_outputs(1) \
  .set_attr_parser(ParamParser<ConcatParam>) \
  .set_attr<nnvm::FListInputNames>("FListInputNames", \
      [](const NodeAttrs& attrs) { \
    const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed); \
    std::vector<std::string> ret; \
    for (int i = 0; i < params.num_args; ++i) { \
      ret.push_back(std::string("arg") + std::to_string(i)); \
    } \
    return ret; \
  }) \
  .set_attr<nnvm::FListOutputNames>("FListOutputNames", \
      [](const NodeAttrs& attrs) { \
    return std::vector<std::string>{"output"}; \
  }) \
  .set_attr<nnvm::FInferType>("FInferType", ConcatType) \
  .set_attr<FInferStorageType>("FInferStorageType", ConcatForwardInferStorageType) \
  .set_attr<FCompute>("FCompute<cpu>", ConcatCompute<cpu>) \
  .set_attr<FComputeEx>("FComputeEx<cpu>", ConcatComputeExCPU) \
  .set_attr<nnvm::FGradient>("FGradient", ConcatGrad{"_backward_Concat"}) \
  .set_attr<std::string>("key_var_num_args", "num_args")

NNVM_REGISTER_OP(Concat)
MXNET_ADD_SPARSE_OP_ALIAS(concat)
.add_alias("concat")
.describe(R"code(Joins input arrays along a given axis.

.. note:: `Concat` is deprecated. Use `concat` instead.

The dimensions of the input arrays should be the same except along the axis
on which they will be concatenated.
The dimension of the output array along the concatenated axis will be equal
to the sum of the corresponding dimensions of the input arrays.

The storage type of ``concat`` output depends on the storage types of the inputs:

- concat(csr, csr, ..., csr, dim=0) = csr
- otherwise, ``concat`` generates output with default storage

Example::

   x = [[1,1],[2,2]]
   y = [[3,3],[4,4],[5,5]]
   z = [[6,6],[7,7],[8,8]]

   concat(x,y,z,dim=0) = [[ 1.,  1.],
                          [ 2.,  2.],
                          [ 3.,  3.],
                          [ 4.,  4.],
                          [ 5.,  5.],
                          [ 6.,  6.],
                          [ 7.,  7.],
                          [ 8.,  8.]]

Note that you cannot concat x, y, z along dimension 1, since dimension
0 is not the same for all the input arrays.

   concat(y,z,dim=1) = [[ 3.,  3.,  6.,  6.],
                        [ 4.,  4.,  7.,  7.],
                        [ 5.,  5.,  8.,  8.]]

)code" ADD_FILELINE)
#if MXNET_USE_MKLDNN == 1
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<bool>("TIsMKLDNN", true)
#endif  // MXNET_USE_MKLDNN == 1
CONCAT_FORWARD_ATTRS
.set_attr<mxnet::FInferShape>("FInferShape", ConcatShape)
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

NNVM_REGISTER_OP(_backward_Concat)
.set_num_inputs([](const NodeAttrs& attrs) {
#if MXNET_USE_MKLDNN == 1
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return 1 + params.num_args;
#else
  return 1;
#endif
})
.set_num_outputs([](const NodeAttrs& attrs) {
  const ConcatParam& params = nnvm::get<ConcatParam>(attrs.parsed);
  return params.num_args;
})
.set_attr_parser(ParamParser<ConcatParam>)
#if MXNET_USE_MKLDNN == 1
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif  // MXNET_USE_MKLDNN == 1
.set_attr<nnvm::TIsBackward>("TIsBackward", true)
.set_attr<FInferStorageType>("FInferStorageType", BackwardConcatStorageType)
#if MXNET_USE_MKLDNN == 1
.set_attr<bool>("TIsMKLDNN", true)
.set_attr<FComputeEx>("FComputeEx<cpu>", ConcatGradComputeExCPU)
#endif  // MXNET_USE_MKLDNN == 1
.set_attr<FCompute>("FCompute<cpu>", ConcatGradCompute<cpu>);

// _rnn_param_concat is a custom concat op with a specialized infer_shape,
// which handles the case where the first one or two inputs may have an
// unknown shape that can be inferred from the output shape.
NNVM_REGISTER_OP(_rnn_param_concat)
.add_alias("_npi_rnn_param_concat")
#if MXNET_USE_MKLDNN == 1
.set_attr<FResourceRequest>("FResourceRequest", [](const NodeAttrs& n) {
  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
#endif  // MXNET_USE_MKLDNN == 1
CONCAT_FORWARD_ATTRS
.set_attr<THasDeterministicOutput>("THasDeterministicOutput", true)
.set_attr<mxnet::FInferShape>("FInferShape", RNNParamConcatShape)
.add_argument("data", "NDArray-or-Symbol[]", "List of arrays to concatenate")
.add_arguments(ConcatParam::__FIELDS__());

}  // namespace op
}  // namespace mxnet