1 /*******************************************************************************
2 * Copyright 2017-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_SIMPLE_CONCAT_HPP
18 #define CPU_SIMPLE_CONCAT_HPP
19 
20 #include "common/memory_tracking.hpp"
21 #include "common/primitive.hpp"
22 
23 #include "cpu/platform.hpp"
24 
25 #include "cpu/cpu_concat_pd.hpp"
26 
27 namespace dnnl {
28 namespace impl {
29 namespace cpu {
30 
31 template <data_type_t data_type>
32 struct simple_concat_t : public primitive_t {
33     struct pd_t : public cpu_concat_pd_t {
34         using cpu_concat_pd_t::cpu_concat_pd_t;
35 
pd_tdnnl::impl::cpu::simple_concat_t::pd_t36         pd_t(const pd_t &rhs) : cpu_concat_pd_t(rhs) { copy_from(rhs); }
37 
38         DECLARE_CONCAT_PD_T("simple:any", simple_concat_t);
39 
initdnnl::impl::cpu::simple_concat_t::pd_t40         status_t init(engine_t *engine) {
41             const memory_desc_wrapper dst_d(dst_md());
42             bool ok = platform::has_data_type_support(data_type)
43                     && cpu_concat_pd_t::init() == status::success
44                     && dst_d.ndims() <= 6;
45             if (!ok) return status::unimplemented;
46 
47             for (size_t i = 0; i < src_mds_.size(); ++i) {
48                 const memory_desc_wrapper i_d(&src_mds_[i]);
49                 const memory_desc_wrapper o_d(&src_image_mds_[i]);
50 
51                 const bool ignore_strides = true;
52 
53                 ok = ok
54                         && utils::everyone_is(
55                                 data_type, i_d.data_type(), o_d.data_type())
56                         && utils::everyone_is(format_kind::blocked,
57                                 i_d.format_kind(), o_d.format_kind())
58                         && types::blocking_desc_is_equal(
59                                 *i_d.md_, *o_d.md_, ignore_strides)
60                         && types::blocking_desc_is_equal(
61                                 *i_d.md_, *dst_d.md_, ignore_strides)
62                         && !i_d.is_additional_buffer();
63                 if (!ok) return status::unimplemented;
64             }
65 
66             dst_d.compute_blocks(blocks_);
67             format_perm();
68 
69             // start dim is the first dimension after which the concatenation
70             // would happen contiguously
71             const int start_dim = perm_[concat_dim()];
72 
73             // check that contiguous part is indeed contiguous (i.e. dense)
74             if (nelems_to_concat(dst_d)
75                     != dst_d.padded_dims()[concat_dim()] / blocks_[concat_dim()]
76                             * dst_d.blocking_desc().strides[concat_dim()])
77                 return status::unimplemented;
78 
79             // check that all inputs have the same strides for the
80             // contiguous part [concat_dim .. ndims] for the *major* dims.
81             // the block part is already checked above
82             for (size_t i = 0; i < src_mds_.size(); ++i) {
83                 const memory_desc_wrapper i_d(&src_mds_[i]);
84                 for (int d = start_dim; d < dst_d.ndims(); ++d) {
85                     if (dst_d.blocking_desc().strides[iperm_[d]]
86                             != i_d.blocking_desc().strides[iperm_[d]])
87                         return status::unimplemented;
88                 }
89             }
90 
91             init_scratchpad();
92 
93             return status::success;
94         }
95 
96         int perm_[DNNL_MAX_NDIMS] {};
97         int iperm_[DNNL_MAX_NDIMS] {};
98         dims_t blocks_ {};
99 
nelems_to_concatdnnl::impl::cpu::simple_concat_t::pd_t100         dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const {
101             const int ndims = data_d.ndims();
102 
103             dim_t nelems = 1;
104             for (int i = perm_[concat_dim()]; i < ndims; i++)
105                 nelems *= data_d.padded_dims()[iperm_[i]] / blocks_[iperm_[i]];
106             for (int i = 0; i < ndims; i++)
107                 nelems *= blocks_[i];
108 
109             return nelems;
110         }
111 
112     private:
format_permdnnl::impl::cpu::simple_concat_t::pd_t113         void format_perm() {
114             const memory_desc_wrapper dst_d(dst_md());
115             const int ndims = dst_d.ndims();
116 
117             dims_t blocks = {0};
118             dst_d.compute_blocks(blocks);
119 
120             strides_t strides = {0};
121             utils::array_copy(strides, dst_d.blocking_desc().strides, ndims);
122 
123             dims_t ou_blocks = {0};
124             utils::array_copy(ou_blocks, dst_d.padded_dims(), ndims);
125 
126             for (int d = 0; d < ndims; d++) {
127                 iperm_[d] = d;
128                 ou_blocks[d] /= blocks[d];
129             }
130 
131             utils::simultaneous_sort(strides, ou_blocks, iperm_, ndims,
132                     [](stride_t a, stride_t b) { return b - a; });
133 
134             for (int i = 0; i < ndims; i++)
135                 perm_[iperm_[i]] = i;
136         }
137 
init_scratchpaddnnl::impl::cpu::simple_concat_t::pd_t138         void init_scratchpad() {
139             using namespace memory_tracking::names;
140             auto scratchpad = scratchpad_registry().registrar();
141             scratchpad.template book<data_t *>(key_concat_iptrs, n_inputs());
142             scratchpad.template book<data_t *>(key_concat_optrs, n_inputs());
143             scratchpad.template book<dim_t>(key_concat_nelems, n_inputs());
144             scratchpad.template book<strides_t>(
145                     key_concat_istrides, n_inputs());
146         }
147 
copy_fromdnnl::impl::cpu::simple_concat_t::pd_t148         void copy_from(const pd_t &rhs) {
149             int ndims = rhs.dst_md_.ndims;
150             utils::array_copy(perm_, rhs.perm_, ndims);
151             utils::array_copy(iperm_, rhs.iperm_, ndims);
152             utils::array_copy(blocks_, rhs.blocks_, ndims);
153         }
154     };
155 
simple_concat_tdnnl::impl::cpu::simple_concat_t156     simple_concat_t(const pd_t *apd) : primitive_t(apd) {}
157 
158     status_t execute(const exec_ctx_t &ctx) const override;
159 
160     typedef typename prec_traits<data_type>::type data_t;
161 
162 private:
pddnnl::impl::cpu::simple_concat_t163     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
164 };
165 
166 } // namespace cpu
167 } // namespace impl
168 } // namespace dnnl
169 
170 #endif
171