1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <assert.h>
18 
19 #include "oneapi/dnnl/dnnl.h"
20 
21 #include "c_types_map.hpp"
22 #include "concat_pd.hpp"
23 #include "engine.hpp"
24 #include "impl_list_item.hpp"
25 #include "primitive_cache.hpp"
26 #include "primitive_hashing.hpp"
27 #include "type_helpers.hpp"
28 #include "utils.hpp"
29 
30 using namespace dnnl::impl;
31 using namespace dnnl::impl::utils;
32 using namespace dnnl::impl::status;
33 
34 namespace dnnl {
35 namespace impl {
36 
concat_primitive_desc_create(primitive_desc_iface_t ** concat_pd_iface,const memory_desc_t * dst_md,int n,int concat_dim,const memory_desc_t * src_mds,const primitive_attr_t * attr,engine_t * engine)37 status_t concat_primitive_desc_create(primitive_desc_iface_t **concat_pd_iface,
38         const memory_desc_t *dst_md, int n, int concat_dim,
39         const memory_desc_t *src_mds, const primitive_attr_t *attr,
40         engine_t *engine) {
41 
42     bool args_ok = !any_null(concat_pd_iface, src_mds) && n > 0;
43     if (!args_ok) return invalid_arguments;
44 
45     if (attr == nullptr) attr = &default_attr();
46 
47     const int ndims = src_mds[0].ndims;
48     const dims_t &dims = src_mds[0].dims;
49     const data_type_t dt = src_mds[0].data_type;
50     if (memory_desc_wrapper(src_mds[0]).has_runtime_dims_or_strides())
51         return unimplemented;
52 
53     int concat_dim_sz = dims[concat_dim];
54     for (int i = 1; i < n; ++i) {
55         if (src_mds[i].ndims != ndims) return invalid_arguments;
56         if (memory_desc_wrapper(src_mds[i]).has_runtime_dims_or_strides())
57             return unimplemented;
58 
59         for (int d = 0; d < ndims; ++d) {
60             if (d == concat_dim) continue;
61             if (src_mds[i].dims[d] != dims[d]) return invalid_arguments;
62         }
63         if (src_mds[i].data_type != dt) return invalid_arguments;
64         concat_dim_sz += src_mds[i].dims[concat_dim];
65     }
66 
67     memory_desc_t dummy_dst_md;
68     if (dst_md) {
69         if (dst_md->ndims != ndims) return invalid_arguments;
70         if (memory_desc_wrapper(dst_md).has_runtime_dims_or_strides())
71             return unimplemented;
72         for (int d = 0; d < ndims; ++d) {
73             if (dst_md->dims[d] != (d == concat_dim ? concat_dim_sz : dims[d]))
74                 return invalid_arguments;
75         }
76     } else {
77         dummy_dst_md = src_mds[0];
78         dummy_dst_md.dims[concat_dim] = concat_dim_sz;
79         dummy_dst_md.format_kind = format_kind::any;
80         dst_md = &dummy_dst_md;
81     }
82 
83     dnnl_concat_desc_t desc
84             = {primitive_kind::concat, dst_md, n, concat_dim, src_mds};
85     primitive_hashing::key_t key(
86             engine, reinterpret_cast<op_desc_t *>(&desc), attr, 0, {});
87     auto pd = primitive_cache().get_pd(key);
88 
89     if (pd) {
90         return safe_ptr_assign(
91                 *concat_pd_iface, new primitive_desc_iface_t(pd, engine));
92     }
93 
94     concat_pd_t *concat_pd = nullptr;
95     for (auto c = engine->get_concat_implementation_list(); *c; ++c) {
96         if ((*c)(&concat_pd, engine, attr, dst_md, n, concat_dim, src_mds)
97                 == success) {
98             pd.reset(concat_pd);
99             CHECK(safe_ptr_assign(
100                     *concat_pd_iface, new primitive_desc_iface_t(pd, engine)));
101             return status::success;
102         }
103     }
104     return unimplemented;
105 }
106 
107 } // namespace impl
108 } // namespace dnnl
109 
dnnl_concat_primitive_desc_create(primitive_desc_iface_t ** concat_pd_iface,const memory_desc_t * dst_md,int n,int concat_dim,const memory_desc_t * src_mds,const primitive_attr_t * attr,engine_t * engine)110 status_t dnnl_concat_primitive_desc_create(
111         primitive_desc_iface_t **concat_pd_iface, const memory_desc_t *dst_md,
112         int n, int concat_dim, const memory_desc_t *src_mds,
113         const primitive_attr_t *attr, engine_t *engine) {
114     return concat_primitive_desc_create(
115             concat_pd_iface, dst_md, n, concat_dim, src_mds, attr, engine);
116 }
117