1 /******************************************************************************* 2 * Copyright 2017-2020 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_SIMPLE_CONCAT_HPP 18 #define CPU_SIMPLE_CONCAT_HPP 19 20 #include "common/memory_tracking.hpp" 21 #include "common/primitive.hpp" 22 23 #include "cpu/platform.hpp" 24 25 #include "cpu/cpu_concat_pd.hpp" 26 27 namespace dnnl { 28 namespace impl { 29 namespace cpu { 30 31 template <data_type_t data_type> 32 struct simple_concat_t : public primitive_t { 33 struct pd_t : public cpu_concat_pd_t { 34 using cpu_concat_pd_t::cpu_concat_pd_t; 35 pd_tdnnl::impl::cpu::simple_concat_t::pd_t36 pd_t(const pd_t &rhs) : cpu_concat_pd_t(rhs) { copy_from(rhs); } 37 38 DECLARE_CONCAT_PD_T("simple:any", simple_concat_t); 39 initdnnl::impl::cpu::simple_concat_t::pd_t40 status_t init(engine_t *engine) { 41 const memory_desc_wrapper dst_d(dst_md()); 42 bool ok = platform::has_data_type_support(data_type) 43 && cpu_concat_pd_t::init() == status::success 44 && dst_d.ndims() <= 6; 45 if (!ok) return status::unimplemented; 46 47 for (size_t i = 0; i < src_mds_.size(); ++i) { 48 const memory_desc_wrapper i_d(&src_mds_[i]); 49 const memory_desc_wrapper o_d(&src_image_mds_[i]); 50 51 const bool ignore_strides = true; 52 53 ok = ok 54 && utils::everyone_is( 55 data_type, i_d.data_type(), o_d.data_type()) 56 && utils::everyone_is(format_kind::blocked, 57 i_d.format_kind(), o_d.format_kind()) 58 && types::blocking_desc_is_equal( 59 *i_d.md_, *o_d.md_, ignore_strides) 60 && types::blocking_desc_is_equal( 61 *i_d.md_, *dst_d.md_, ignore_strides) 62 && !i_d.is_additional_buffer(); 63 if (!ok) return status::unimplemented; 64 } 65 66 dst_d.compute_blocks(blocks_); 67 format_perm(); 68 69 // start dim is the first dimension after which the concatenation 70 // would happen contiguously 71 const int start_dim = perm_[concat_dim()]; 72 73 // check that contiguous part is indeed contiguous (i.e. dense) 74 if (nelems_to_concat(dst_d) 75 != dst_d.padded_dims()[concat_dim()] / blocks_[concat_dim()] 76 * dst_d.blocking_desc().strides[concat_dim()]) 77 return status::unimplemented; 78 79 // check that all inputs have the same strides for the 80 // contiguous part [concat_dim .. ndims] for the *major* dims. 81 // the block part is already checked above 82 for (size_t i = 0; i < src_mds_.size(); ++i) { 83 const memory_desc_wrapper i_d(&src_mds_[i]); 84 for (int d = start_dim; d < dst_d.ndims(); ++d) { 85 if (dst_d.blocking_desc().strides[iperm_[d]] 86 != i_d.blocking_desc().strides[iperm_[d]]) 87 return status::unimplemented; 88 } 89 } 90 91 init_scratchpad(); 92 93 return status::success; 94 } 95 96 int perm_[DNNL_MAX_NDIMS] {}; 97 int iperm_[DNNL_MAX_NDIMS] {}; 98 dims_t blocks_ {}; 99 nelems_to_concatdnnl::impl::cpu::simple_concat_t::pd_t100 dim_t nelems_to_concat(const memory_desc_wrapper &data_d) const { 101 const int ndims = data_d.ndims(); 102 103 dim_t nelems = 1; 104 for (int i = perm_[concat_dim()]; i < ndims; i++) 105 nelems *= data_d.padded_dims()[iperm_[i]] / blocks_[iperm_[i]]; 106 for (int i = 0; i < ndims; i++) 107 nelems *= blocks_[i]; 108 109 return nelems; 110 } 111 112 private: format_permdnnl::impl::cpu::simple_concat_t::pd_t113 void format_perm() { 114 const memory_desc_wrapper dst_d(dst_md()); 115 const int ndims = dst_d.ndims(); 116 117 dims_t blocks = {0}; 118 dst_d.compute_blocks(blocks); 119 120 strides_t strides = {0}; 121 utils::array_copy(strides, dst_d.blocking_desc().strides, ndims); 122 123 dims_t ou_blocks = {0}; 124 utils::array_copy(ou_blocks, dst_d.padded_dims(), ndims); 125 126 for (int d = 0; d < ndims; d++) { 127 iperm_[d] = d; 128 ou_blocks[d] /= blocks[d]; 129 } 130 131 utils::simultaneous_sort(strides, ou_blocks, iperm_, ndims, 132 [](stride_t a, stride_t b) { return b - a; }); 133 134 for (int i = 0; i < ndims; i++) 135 perm_[iperm_[i]] = i; 136 } 137 init_scratchpaddnnl::impl::cpu::simple_concat_t::pd_t138 void init_scratchpad() { 139 using namespace memory_tracking::names; 140 auto scratchpad = scratchpad_registry().registrar(); 141 scratchpad.template book<data_t *>(key_concat_iptrs, n_inputs()); 142 scratchpad.template book<data_t *>(key_concat_optrs, n_inputs()); 143 scratchpad.template book<dim_t>(key_concat_nelems, n_inputs()); 144 scratchpad.template book<strides_t>( 145 key_concat_istrides, n_inputs()); 146 } 147 copy_fromdnnl::impl::cpu::simple_concat_t::pd_t148 void copy_from(const pd_t &rhs) { 149 int ndims = rhs.dst_md_.ndims; 150 utils::array_copy(perm_, rhs.perm_, ndims); 151 utils::array_copy(iperm_, rhs.iperm_, ndims); 152 utils::array_copy(blocks_, rhs.blocks_, ndims); 153 } 154 }; 155 simple_concat_tdnnl::impl::cpu::simple_concat_t156 simple_concat_t(const pd_t *apd) : primitive_t(apd) {} 157 158 status_t execute(const exec_ctx_t &ctx) const override; 159 160 typedef typename prec_traits<data_type>::type data_t; 161 162 private: pddnnl::impl::cpu::simple_concat_t163 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 164 }; 165 166 } // namespace cpu 167 } // namespace impl 168 } // namespace dnnl 169 170 #endif 171