1 /******************************************************************************* 2 * Copyright 2017-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_JIT_AVX512_COMMON_1X1_CONVOLUTION_HPP 18 #define CPU_X64_JIT_AVX512_COMMON_1X1_CONVOLUTION_HPP 19 20 #include "common/c_types_map.hpp" 21 #include "common/dnnl_thread.hpp" 22 #include "common/memory_tracking.hpp" 23 #include "common/primitive.hpp" 24 #include "common/primitive_hashing.hpp" 25 #include "common/utils.hpp" 26 27 #include "cpu/cpu_convolution_pd.hpp" 28 #include "cpu/dw_convolution_utils.hpp" 29 #include "cpu/platform.hpp" 30 31 #include "cpu/x64/cpu_reducer.hpp" 32 #include "cpu/x64/jit_avx512_common_1x1_conv_kernel.hpp" 33 #include "cpu/x64/jit_transpose_utils.hpp" 34 #include "cpu/x64/jit_uni_1x1_conv_utils.hpp" 35 #include "cpu/x64/jit_uni_dw_convolution.hpp" 36 37 namespace dnnl { 38 namespace impl { 39 namespace cpu { 40 namespace x64 { 41 42 template <impl::data_type_t src_type, impl::data_type_t wei_type = src_type, 43 impl::data_type_t dst_type = src_type> 44 struct jit_avx512_common_1x1_convolution_fwd_t : public primitive_t { 45 struct pd_t : public cpu_convolution_fwd_pd_t { pd_tdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_fwd_t::pd_t46 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, 47 const typename pd_t::base_class *hint_fwd_pd) 48 : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) 49 , jcp_() 50 , rtus_() {} 51 pd_tdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_fwd_t::pd_t52 pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) { 53 if (copy(other) != status::success) is_initialized_ = false; 54 } 55 56 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), 57 jit_avx512_common_1x1_convolution_fwd_t); 58 initdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_fwd_t::pd_t59 status_t init(engine_t *engine) { 60 using namespace utils; 61 bool ok = true && is_fwd() 62 && set_default_alg_kind(alg_kind::convolution_direct) 63 && expect_data_types(src_type, wei_type, dst_type, dst_type, 64 data_type::undef) 65 && attr()->has_default_values( 66 primitive_attr_t::skip_mask_t::post_ops, dst_type) 67 && !has_zero_dim_memory() && set_default_formats(); 68 if (!ok) return status::unimplemented; 69 70 const convolution_desc_t *conv_d = desc(); 71 const memory_desc_t *src_d = src_md(); 72 rtus_prepare(this, conv_d, src_d, dst_md(), weights_md()); 73 74 status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(jcp_, 75 *conv_d, *src_d, *weights_md(), *dst_md(), *attr(), 76 dnnl_get_max_threads(), rtus_.reduce_src_); 77 if (status != status::success) return status; 78 79 if (jcp_.with_dw_conv) { 80 status = depthwise_po_init(engine); 81 if (status != status::success) return status; 82 } 83 84 auto scratchpad = scratchpad_registry().registrar(); 85 jit_avx512_common_1x1_conv_kernel::init_scratchpad( 86 scratchpad, jcp_); 87 88 rtus_prepare_space_info(this, scratchpad, jcp_.nthr); 89 90 return status::success; 91 } 92 dst_mddnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_fwd_t::pd_t93 const memory_desc_t *dst_md(int index = 0) const override { 94 return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_; 95 } 96 arg_mddnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_fwd_t::pd_t97 const memory_desc_t *arg_md(int index = 0) const override { 98 if (jcp_.with_dw_conv) { 99 switch (index) { 100 case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS: 101 return dw_conv_pd_->weights_md(0); 102 case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS: 103 return dw_conv_pd_->weights_md(1); 104 default: break; 105 } 106 } 107 return convolution_fwd_pd_t::arg_md(index); 108 } 109 arg_usagednnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_fwd_t::pd_t110 arg_usage_t arg_usage(int arg) const override { 111 if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS)) 112 return arg_usage_t::input; 113 114 if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS) 115 && attr_post_op_dw_inputs() > 1) 116 return arg_usage_t::input; 117 118 return convolution_fwd_pd_t::arg_usage(arg); 119 } 120 121 jit_1x1_conv_conf_t jcp_; 122 reduce_to_unit_stride_t rtus_; 123 using dw_pd_t = jit_avx512_common_dw_convolution_fwd_t::pd_t; 124 std::unique_ptr<dw_pd_t> dw_conv_pd_; 125 126 protected: set_default_formatsdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_fwd_t::pd_t127 bool set_default_formats() { 128 using namespace format_tag; 129 130 auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); 131 auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), 132 OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o, OIdhw16i16o, 133 gOIdhw16i16o); 134 135 return set_default_formats_common(dat_tag, wei_tag, dat_tag); 136 } 137 copydnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_fwd_t::pd_t138 status_t copy(const pd_t &other) { 139 jcp_ = other.jcp_; 140 rtus_ = other.rtus_; 141 if (other.dw_conv_pd_) { 142 dw_conv_pd_.reset(other.dw_conv_pd_->clone()); 143 if (!dw_conv_pd_) return status::out_of_memory; 144 } 145 return status::success; 146 } 147 depthwise_po_initdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_fwd_t::pd_t148 status_t depthwise_po_init(engine_t *engine) { 149 150 using namespace memory_tracking; 151 auto &jcp_1x1 = jcp_; 152 primitive_attr_t attr_1x1(*attr()); 153 if (!attr_1x1.is_initialized()) return status::out_of_memory; 154 const auto &src_md = dst_md_; 155 const memory_desc_wrapper src_d(src_md); 156 const auto nthr = dnnl_get_max_threads(); 157 auto l2_cache = platform::get_per_core_cache_size(2) * nthr; 158 159 // Note: A robust fusion implementation would be to check if both 160 // 1x1 conv and dw conv that are considered here for fusion are 161 // optimal independently. This would require creating a new 162 // primitive_desc through primitive_iterator & check if they match. 163 // Due to concern that these creations and/or checks could be heavy, 164 // for 1x1: Check that no better ISA is available. 165 // for dw: Always fuse with same ISA. 166 // Caveat: May be a better dw conv exists. 167 168 // TODO: Add a check if better ISA exists following above note. 169 bool ok = true 170 && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1) 171 // TODO: Below may be further tuned. 172 && (l2_cache * 2 < src_d.size()) 173 // load_grp_count check can be redundant due to l2 check 174 // above. Adding it explicitly as the current driver doesn't 175 // work if this condition fails. 176 && (jcp_1x1.load_grp_count < 2); 177 if (!ok) return status::unimplemented; 178 179 int dw_po_index 180 = attr_1x1.post_ops_.find(primitive_kind::convolution); 181 convolution_desc_t cd_dw; 182 primitive_attr_t attr_dw; 183 CHECK(get_depthwise_conv_desc( 184 cd_dw, src_md, attr_1x1, attr_dw, dw_po_index)); 185 186 CHECK(safe_ptr_assign( 187 dw_conv_pd_, new dw_pd_t(&cd_dw, &attr_dw, nullptr))); 188 CHECK(dw_conv_pd_->init(engine)); 189 auto &jcp_dw = dw_conv_pd_->jcp_; 190 191 ok = true 192 && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0))) 193 && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0) 194 && IMPLICATION( 195 jcp_dw.ow_block, jcp_dw.ow_block == jcp_dw.ow); 196 if (!ok) return status::unimplemented; 197 198 assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any); 199 assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any); 200 assert(IMPLICATION( 201 dw_conv_pd_->weights_md(1)->data_type != data_type::undef, 202 dw_conv_pd_->weights_md(1)->format_kind 203 != format_kind::any)); 204 205 jcp_dw.is_fused_conv = true; 206 // TODO: Support/experiment arbitary oc_work in dw conv. 207 // Until then we keep oc_work perfectly divisible. 208 while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0) 209 --jcp_1x1.nb_load_blocking; 210 jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking; 211 212 while (jcp_1x1.nb_load_blocking % jcp_dw.nb_ch_blocking != 0) 213 --jcp_dw.nb_ch_blocking; 214 215 jcp_dw.dw_conv_buffer_oc 216 = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block; 217 jcp_1x1.bcast_loop_output_step 218 = jcp_1x1.ur * jcp_1x1.load_block * jcp_1x1.typesize_out; 219 220 registrar_t scratchpad(scratchpad_registry_); 221 registrar_t dw_scratchpad(scratchpad, names::prefix_fusion); 222 223 size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw.kh * jcp_dw.iw 224 * jcp_dw.dw_conv_buffer_oc; 225 assert(dw_conv_buffer_size_); 226 dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer, 227 dw_conv_buffer_size_, 228 types::data_type_size(dw_conv_pd_->src_md()->data_type)); 229 230 jit_uni_dw_conv_fwd_kernel<avx512_common, 231 data_type::f32>::init_scratchpad(dw_scratchpad, jcp_dw); 232 233 return status::success; 234 } 235 }; 236 237 template <cpu_isa_t isa, typename conv_t> 238 friend status_t init_rtus_driver(conv_t *self); 239 jit_avx512_common_1x1_convolution_fwd_tdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_fwd_t240 jit_avx512_common_1x1_convolution_fwd_t(const pd_t *apd) 241 : primitive_t(apd) {} 242 243 typedef typename prec_traits<src_type>::type src_data_t; 244 typedef typename prec_traits<wei_type>::type wei_data_t; 245 typedef typename prec_traits<dst_type>::type dst_data_t; 246 initdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_fwd_t247 status_t init(engine_t *engine) override { 248 CHECK(safe_ptr_assign(kernel_, 249 new jit_avx512_common_1x1_conv_kernel( 250 pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); 251 CHECK(kernel_->create_kernel()); 252 253 if (pd()->jcp_.with_dw_conv) { 254 CHECK(safe_ptr_assign(kernel_dw_, 255 new dw_conv_kernel_t( 256 pd()->dw_conv_pd_->jcp_, *pd()->dst_md(0)))); 257 CHECK(kernel_dw_->create_kernel()); 258 } 259 260 CHECK(init_rtus_driver<avx512_common>(this)); 261 return status::success; 262 } 263 executednnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_fwd_t264 status_t execute(const exec_ctx_t &ctx) const override { 265 execute_forward(ctx); 266 return status::success; 267 } 268 269 private: 270 void execute_forward(const exec_ctx_t &ctx) const; 271 void execute_forward_thr(const int ithr, const int nthr, 272 const src_data_t *src, const wei_data_t *weights, 273 const dst_data_t *bias, const wei_data_t *weights_dw, 274 const dst_data_t *bias_dw, dst_data_t *dst, 275 const memory_tracking::grantor_t &scratchpad, 276 const void *post_ops_binary_rhs_arg_vec, 277 const void *post_ops_binary_rhs_arg_vec_dw) const; pddnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_fwd_t278 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 279 280 std::unique_ptr<jit_avx512_common_1x1_conv_kernel> kernel_; 281 std::unique_ptr<rtus_driver_t<avx512_common>> rtus_driver_; 282 using dw_conv_kernel_t = jit_uni_dw_conv_fwd_kernel_f32<avx512_common>; 283 std::unique_ptr<dw_conv_kernel_t> kernel_dw_; 284 }; 285 286 using jit_avx512_common_1x1_convolution_fwd_f32_t 287 = jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>; 288 289 template <impl::data_type_t diff_dst_type, 290 impl::data_type_t wei_type = diff_dst_type, 291 impl::data_type_t diff_src_type = diff_dst_type> 292 struct jit_avx512_common_1x1_convolution_bwd_data_t : public primitive_t { 293 struct pd_t : public cpu_convolution_bwd_data_pd_t { pd_tdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_bwd_data_t::pd_t294 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, 295 const convolution_fwd_pd_t *hint_fwd_pd) 296 : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) 297 , jcp_() 298 , rtus_() {} 299 300 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), 301 jit_avx512_common_1x1_convolution_bwd_data_t); 302 initdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_bwd_data_t::pd_t303 status_t init(engine_t *engine) { 304 bool ok = true && desc()->prop_kind == prop_kind::backward_data 305 && set_default_alg_kind(alg_kind::convolution_direct) 306 && expect_data_types(diff_src_type, wei_type, 307 data_type::undef, diff_dst_type, data_type::undef) 308 && attr()->has_default_values() && !has_zero_dim_memory() 309 && set_default_formats(); 310 if (!ok) return status::unimplemented; 311 312 const convolution_desc_t *conv_d = desc(); 313 const memory_desc_t *diff_src_d = diff_src_md(); 314 rtus_prepare(this, conv_d, diff_src_d, diff_dst_md(), weights_md()); 315 316 status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(jcp_, 317 *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), 318 *attr(), dnnl_get_max_threads(), rtus_.reduce_src_); 319 if (status != status::success) return status; 320 321 auto scratchpad = scratchpad_registry().registrar(); 322 jit_avx512_common_1x1_conv_kernel::init_scratchpad( 323 scratchpad, jcp_); 324 325 rtus_prepare_space_info(this, scratchpad, jcp_.nthr); 326 327 return status::success; 328 } 329 330 // TODO (Roma): structs conf header cleanup 331 jit_1x1_conv_conf_t jcp_; 332 reduce_to_unit_stride_t rtus_; 333 334 protected: set_default_formatsdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_bwd_data_t::pd_t335 bool set_default_formats() { 336 using namespace format_tag; 337 338 auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); 339 auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), 340 IOw16o16i, gIOw16o16i, IOhw16o16i, gIOhw16o16i, IOdhw16o16i, 341 gIOdhw16o16i); 342 343 return set_default_formats_common(dat_tag, wei_tag, dat_tag); 344 } 345 }; 346 347 template <cpu_isa_t isa, typename conv_t> 348 friend status_t init_rtus_driver(conv_t *self); 349 jit_avx512_common_1x1_convolution_bwd_data_tdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_bwd_data_t350 jit_avx512_common_1x1_convolution_bwd_data_t(const pd_t *apd) 351 : primitive_t(apd) {} 352 353 typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t; 354 typedef typename prec_traits<wei_type>::type wei_data_t; 355 typedef typename prec_traits<diff_src_type>::type diff_src_data_t; 356 initdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_bwd_data_t357 status_t init(engine_t *engine) override { 358 CHECK(safe_ptr_assign(kernel_, 359 new jit_avx512_common_1x1_conv_kernel( 360 pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); 361 CHECK(kernel_->create_kernel()); 362 CHECK(init_rtus_driver<avx512_common>(this)); 363 return status::success; 364 } 365 executednnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_bwd_data_t366 status_t execute(const exec_ctx_t &ctx) const override { 367 execute_backward_data(ctx); 368 return status::success; 369 } 370 371 private: 372 void execute_backward_data(const exec_ctx_t &ctx) const; pddnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_bwd_data_t373 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 374 375 std::unique_ptr<jit_avx512_common_1x1_conv_kernel> kernel_; 376 std::unique_ptr<rtus_driver_t<avx512_common>> rtus_driver_; 377 }; 378 379 using jit_avx512_common_1x1_convolution_bwd_data_f32_t 380 = jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>; 381 382 struct jit_avx512_common_1x1_convolution_bwd_weights_t : public primitive_t { 383 struct pd_t : public cpu_convolution_bwd_weights_pd_t { pd_tdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_bwd_weights_t::pd_t384 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, 385 const convolution_fwd_pd_t *hint_fwd_pd) 386 : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) 387 , jcp_() 388 , rtus_() {} 389 390 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""), 391 jit_avx512_common_1x1_convolution_bwd_weights_t); 392 initdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_bwd_weights_t::pd_t393 status_t init(engine_t *engine) { 394 bool ok = true && desc()->prop_kind == prop_kind::backward_weights 395 && set_default_alg_kind(alg_kind::convolution_direct) 396 && expect_data_types(data_type::f32, data_type::f32, 397 data_type::f32, data_type::f32, data_type::f32) 398 && attr()->has_default_values() && !has_zero_dim_memory() 399 && set_default_formats(); 400 if (!ok) return status::unimplemented; 401 402 const convolution_desc_t *conv_d = desc(); 403 const memory_desc_t *src_d = src_md(); 404 rtus_prepare(this, conv_d, src_d, diff_dst_md(), diff_weights_md()); 405 406 status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(jcp_, 407 *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), 408 *attr(), dnnl_get_max_threads(), rtus_.reduce_src_); 409 if (status != status::success) return status; 410 411 init_balancers(); 412 413 auto scratchpad = scratchpad_registry().registrar(); 414 jit_avx512_common_1x1_conv_kernel::init_scratchpad( 415 scratchpad, jcp_); 416 417 auto reducer_bia_scratchpad = memory_tracking::registrar_t( 418 scratchpad, memory_tracking::names::prefix_reducer_bia); 419 reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); 420 421 rtus_prepare_space_info(this, scratchpad, jcp_.nthr); 422 423 return status::success; 424 } 425 426 // TODO (Roma): structs conf header cleanup 427 jit_1x1_conv_conf_t jcp_; 428 cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_; 429 reduce_to_unit_stride_t rtus_; 430 431 protected: set_default_formatsdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_bwd_weights_t::pd_t432 bool set_default_formats() { 433 using namespace format_tag; 434 435 auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c); 436 auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(), 437 OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o, OIdhw16i16o, 438 gOIdhw16i16o); 439 440 return set_default_formats_common(dat_tag, wei_tag, dat_tag); 441 } 442 443 private: init_balancersdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_bwd_weights_t::pd_t444 void init_balancers() { 445 const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16; 446 if (with_bias()) { 447 reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr, 448 jcp_.oc_block, jcp_.ngroups * jcp_.nb_load, jcp_.mb, 449 max_buffer_size, true)); 450 } 451 } 452 }; 453 454 template <cpu_isa_t isa, typename conv_t> 455 friend status_t init_rtus_driver(conv_t *self); 456 jit_avx512_common_1x1_convolution_bwd_weights_tdnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_bwd_weights_t457 jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd) 458 : primitive_t(apd) {} 459 460 typedef typename prec_traits<data_type::f32>::type data_t; 461 462 status_t init(engine_t *engine) override; 463 executednnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_bwd_weights_t464 status_t execute(const exec_ctx_t &ctx) const override { 465 execute_backward_weights(ctx); 466 return status::success; 467 } 468 469 private: 470 void execute_backward_weights(const exec_ctx_t &ctx) const; pddnnl::impl::cpu::x64::jit_avx512_common_1x1_convolution_bwd_weights_t471 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 472 473 std::unique_ptr<jit_avx512_common_1x1_conv_kernel> kernel_; 474 std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_; 475 std::unique_ptr<cpu_reducer_t<data_type::f32>> reducer_bias_; 476 std::unique_ptr<jit_transpose4x16_src> trans_kernel_; 477 std::unique_ptr<rtus_driver_t<avx512_common>> rtus_driver_; 478 }; 479 480 } // namespace x64 481 } // namespace cpu 482 } // namespace impl 483 } // namespace dnnl 484 485 #endif 486