1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18
19 #include "oneapi/dnnl/dnnl.h"
20
21 #include "c_types_map.hpp"
22 #include "concat_pd.hpp"
23 #include "engine.hpp"
24 #include "impl_list_item.hpp"
25 #include "primitive_cache.hpp"
26 #include "primitive_hashing.hpp"
27 #include "type_helpers.hpp"
28 #include "utils.hpp"
29
30 using namespace dnnl::impl;
31 using namespace dnnl::impl::utils;
32 using namespace dnnl::impl::status;
33
34 namespace dnnl {
35 namespace impl {
36
concat_primitive_desc_create(primitive_desc_iface_t ** concat_pd_iface,const memory_desc_t * dst_md,int n,int concat_dim,const memory_desc_t * src_mds,const primitive_attr_t * attr,engine_t * engine)37 status_t concat_primitive_desc_create(primitive_desc_iface_t **concat_pd_iface,
38 const memory_desc_t *dst_md, int n, int concat_dim,
39 const memory_desc_t *src_mds, const primitive_attr_t *attr,
40 engine_t *engine) {
41
42 bool args_ok = !any_null(concat_pd_iface, src_mds) && n > 0;
43 if (!args_ok) return invalid_arguments;
44
45 if (attr == nullptr) attr = &default_attr();
46
47 const int ndims = src_mds[0].ndims;
48 const dims_t &dims = src_mds[0].dims;
49 const data_type_t dt = src_mds[0].data_type;
50 if (memory_desc_wrapper(src_mds[0]).has_runtime_dims_or_strides())
51 return unimplemented;
52
53 int concat_dim_sz = dims[concat_dim];
54 for (int i = 1; i < n; ++i) {
55 if (src_mds[i].ndims != ndims) return invalid_arguments;
56 if (memory_desc_wrapper(src_mds[i]).has_runtime_dims_or_strides())
57 return unimplemented;
58
59 for (int d = 0; d < ndims; ++d) {
60 if (d == concat_dim) continue;
61 if (src_mds[i].dims[d] != dims[d]) return invalid_arguments;
62 }
63 if (src_mds[i].data_type != dt) return invalid_arguments;
64 concat_dim_sz += src_mds[i].dims[concat_dim];
65 }
66
67 memory_desc_t dummy_dst_md;
68 if (dst_md) {
69 if (dst_md->ndims != ndims) return invalid_arguments;
70 if (memory_desc_wrapper(dst_md).has_runtime_dims_or_strides())
71 return unimplemented;
72 for (int d = 0; d < ndims; ++d) {
73 if (dst_md->dims[d] != (d == concat_dim ? concat_dim_sz : dims[d]))
74 return invalid_arguments;
75 }
76 } else {
77 dummy_dst_md = src_mds[0];
78 dummy_dst_md.dims[concat_dim] = concat_dim_sz;
79 dummy_dst_md.format_kind = format_kind::any;
80 dst_md = &dummy_dst_md;
81 }
82
83 dnnl_concat_desc_t desc
84 = {primitive_kind::concat, dst_md, n, concat_dim, src_mds};
85 primitive_hashing::key_t key(
86 engine, reinterpret_cast<op_desc_t *>(&desc), attr, 0, {});
87 auto pd = primitive_cache().get_pd(key);
88
89 if (pd) {
90 return safe_ptr_assign(
91 *concat_pd_iface, new primitive_desc_iface_t(pd, engine));
92 }
93
94 concat_pd_t *concat_pd = nullptr;
95 for (auto c = engine->get_concat_implementation_list(); *c; ++c) {
96 if ((*c)(&concat_pd, engine, attr, dst_md, n, concat_dim, src_mds)
97 == success) {
98 pd.reset(concat_pd);
99 CHECK(safe_ptr_assign(
100 *concat_pd_iface, new primitive_desc_iface_t(pd, engine)));
101 return status::success;
102 }
103 }
104 return unimplemented;
105 }
106
107 } // namespace impl
108 } // namespace dnnl
109
dnnl_concat_primitive_desc_create(primitive_desc_iface_t ** concat_pd_iface,const memory_desc_t * dst_md,int n,int concat_dim,const memory_desc_t * src_mds,const primitive_attr_t * attr,engine_t * engine)110 status_t dnnl_concat_primitive_desc_create(
111 primitive_desc_iface_t **concat_pd_iface, const memory_desc_t *dst_md,
112 int n, int concat_dim, const memory_desc_t *src_mds,
113 const primitive_attr_t *attr, engine_t *engine) {
114 return concat_primitive_desc_create(
115 concat_pd_iface, dst_md, n, concat_dim, src_mds, attr, engine);
116 }
117