1 /******************************************************************************* 2 * Copyright 2016-2021 Intel Corporation 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 *******************************************************************************/ 16 17 #ifndef CPU_X64_JIT_AVX2_1X1_CONVOLUTION_HPP 18 #define CPU_X64_JIT_AVX2_1X1_CONVOLUTION_HPP 19 20 #include "common/c_types_map.hpp" 21 #include "common/dnnl_thread.hpp" 22 #include "common/memory_tracking.hpp" 23 #include "common/primitive.hpp" 24 #include "common/primitive_hashing.hpp" 25 #include "common/utils.hpp" 26 27 #include "cpu/cpu_convolution_pd.hpp" 28 #include "cpu/dw_convolution_utils.hpp" 29 #include "cpu/platform.hpp" 30 31 #include "cpu/x64/cpu_reducer.hpp" 32 #include "cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp" 33 #include "cpu/x64/jit_uni_1x1_conv_utils.hpp" 34 #include "cpu/x64/jit_uni_dw_convolution.hpp" 35 36 namespace dnnl { 37 namespace impl { 38 namespace cpu { 39 namespace x64 { 40 41 struct jit_avx2_1x1_convolution_fwd_t : public primitive_t { 42 // TODO: (Roma) Code duplication duplication! Remove with templates 43 // (maybe...)! 44 struct pd_t : public cpu_convolution_fwd_pd_t { pd_tdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_fwd_t::pd_t45 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, 46 const typename pd_t::base_class *hint_fwd_pd) 47 : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd) 48 , jcp_() 49 , rtus_() 50 , jcp_dw_(nullptr) {} 51 pd_tdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_fwd_t::pd_t52 pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) { 53 if (copy(other) != status::success) is_initialized_ = false; 54 } 55 56 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", jcp_.isa, ""), 57 jit_avx2_1x1_convolution_fwd_t); 58 initdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_fwd_t::pd_t59 status_t init(engine_t *engine) { 60 bool ok = true && is_fwd() 61 && set_default_alg_kind(alg_kind::convolution_direct) 62 && expect_data_types(data_type::f32, data_type::f32, 63 data_type::f32, data_type::f32, data_type::f32) 64 && attr()->has_default_values( 65 primitive_attr_t::skip_mask_t::post_ops, 66 data_type::f32) 67 && !has_zero_dim_memory() && set_default_formats(); 68 if (!ok) return status::unimplemented; 69 70 const convolution_desc_t *conv_d = desc(); 71 const memory_desc_t *src_d = src_md(); 72 rtus_prepare(this, conv_d, src_d, dst_md(), weights_md()); 73 74 status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf( 75 jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), *attr()); 76 if (status != status::success) return status; 77 78 if (jcp_.with_dw_conv) { 79 status = depthwise_po_init(engine); 80 if (status != status::success) return status; 81 } 82 83 auto scratchpad = scratchpad_registry().registrar(); 84 jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); 85 86 rtus_prepare_space_info(this, scratchpad, jcp_.nthr); 87 88 return status::success; 89 } 90 dst_mddnnl::impl::cpu::x64::jit_avx2_1x1_convolution_fwd_t::pd_t91 const memory_desc_t *dst_md(int index = 0) const override { 92 return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_; 93 } 94 arg_mddnnl::impl::cpu::x64::jit_avx2_1x1_convolution_fwd_t::pd_t95 const memory_desc_t *arg_md(int index = 0) const override { 96 if (jcp_.with_dw_conv) { 97 switch (index) { 98 case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS: 99 return dw_conv_pd_->weights_md(0); 100 case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS: 101 return dw_conv_pd_->weights_md(1); 102 default: break; 103 } 104 } 105 return convolution_fwd_pd_t::arg_md(index); 106 } 107 arg_usagednnl::impl::cpu::x64::jit_avx2_1x1_convolution_fwd_t::pd_t108 arg_usage_t arg_usage(int arg) const override { 109 if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS)) 110 return arg_usage_t::input; 111 112 if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS) 113 && attr_post_op_dw_inputs() > 1) 114 return arg_usage_t::input; 115 116 return convolution_fwd_pd_t::arg_usage(arg); 117 } 118 119 jit_1x1_conv_conf_t jcp_; 120 reduce_to_unit_stride_t rtus_; 121 jit_conv_conf_t *jcp_dw_; 122 std::unique_ptr<cpu_convolution_fwd_pd_t> dw_conv_pd_; 123 124 protected: 125 template <cpu_isa_t isa> 126 using dw_pd_t = typename jit_uni_dw_convolution_fwd_t<isa, 127 data_type::f32>::pd_t; 128 set_default_formatsdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_fwd_t::pd_t129 bool set_default_formats() { 130 using namespace format_tag; 131 132 auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); 133 auto wei_tag = with_groups() 134 ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o) 135 : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o, OIdhw8i8o); 136 137 return set_default_formats_common(dat_tag, wei_tag, dat_tag); 138 } 139 copydnnl::impl::cpu::x64::jit_avx2_1x1_convolution_fwd_t::pd_t140 status_t copy(const pd_t &other) { 141 jcp_ = other.jcp_; 142 rtus_ = other.rtus_; 143 jcp_dw_ = nullptr; 144 if (other.dw_conv_pd_) { 145 dw_conv_pd_.reset(static_cast<cpu_convolution_fwd_pd_t *>( 146 other.dw_conv_pd_->clone())); 147 if (!dw_conv_pd_) return status::out_of_memory; 148 if (jcp_.isa == avx2) { 149 jcp_dw_ = &(static_cast<dw_pd_t<avx2> *>(dw_conv_pd_.get()) 150 ->jcp_); 151 } else { // sse41 152 jcp_dw_ = &(static_cast<dw_pd_t<sse41> *>(dw_conv_pd_.get()) 153 ->jcp_); 154 } 155 } 156 157 return status::success; 158 } 159 depthwise_po_initdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_fwd_t::pd_t160 status_t depthwise_po_init(engine_t *engine) { 161 162 using namespace memory_tracking; 163 auto &jcp_1x1 = jcp_; 164 primitive_attr_t attr_1x1(*attr()); 165 if (!attr_1x1.is_initialized()) return status::out_of_memory; 166 jit_conv_conf_t *jcp_dw = nullptr; 167 168 const auto &src_md = dst_md_; 169 const memory_desc_wrapper src_d(src_md); 170 const auto nthr = dnnl_get_max_threads(); 171 auto l2_cache = platform::get_per_core_cache_size(2) * nthr; 172 173 // Note: A robust fusion implementation would be to check if both 174 // 1x1 conv and dw conv that are considered here for fusion are 175 // optimal independently. This would require creating a new 176 // primitive_desc through primitive_iterator & check if they match. 177 // Due to concern that these creations and/or checks could be heavy, 178 // for 1x1: Check that no better ISA is available. 179 // for dw: Always fuse with same ISA. 180 // Caveat: May be a better dw conv exists. 181 182 bool ok = true && (!mayiuse(avx512_common)) 183 && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1) 184 // TODO: Below may be further tuned. 185 && (l2_cache * 2 < src_d.size()) 186 // load_grp_count check can be redundant due to l2 check 187 // above. Adding it explicitly as the current driver doesn't 188 // work if this condition fails. 189 && (jcp_1x1.load_grp_count < 2); 190 if (!ok) return status::unimplemented; 191 192 int dw_po_index 193 = attr_1x1.post_ops_.find(primitive_kind::convolution); 194 195 convolution_desc_t cd_dw; 196 primitive_attr_t attr_dw; 197 198 CHECK(get_depthwise_conv_desc( 199 cd_dw, src_md, attr_1x1, attr_dw, dw_po_index)); 200 201 if (jcp_1x1.isa == avx2) { 202 std::unique_ptr<dw_pd_t<avx2>> fusable_pd( 203 new dw_pd_t<avx2>(&cd_dw, &attr_dw, nullptr)); 204 CHECK(fusable_pd->init(engine)); 205 jcp_dw = &(fusable_pd->jcp_); 206 dw_conv_pd_ = std::move(fusable_pd); 207 } else { 208 // Special case for this primitive, as we dont have dw<avx>. 209 // In this case fuse with sse41 depthwise conv 210 // NOTE: Currently dw f32 kernel is similar for all ISA and can 211 // be fused regardless of ISA if inter-connecting md_ matches. 212 std::unique_ptr<dw_pd_t<sse41>> fusable_pd( 213 new dw_pd_t<sse41>(&cd_dw, &attr_dw, nullptr)); 214 CHECK(fusable_pd->init(engine)); 215 jcp_dw = &(fusable_pd->jcp_); 216 dw_conv_pd_ = std::move(fusable_pd); 217 } 218 219 ok = true 220 && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0))) 221 && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0) 222 && IMPLICATION( 223 jcp_dw->ow_block, jcp_dw->ow_block == jcp_dw->ow); 224 if (!ok) return status::unimplemented; 225 226 assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any); 227 assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any); 228 assert(IMPLICATION( 229 dw_conv_pd_->weights_md(1)->data_type != data_type::undef, 230 dw_conv_pd_->weights_md(1)->format_kind 231 != format_kind::any)); 232 233 jcp_dw->is_fused_conv = true; 234 // TODO: Support/experiment arbitary oc_work in dw conv. 235 // Until then we keep oc_work perfectly divisible. 236 while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0) 237 --jcp_1x1.nb_load_blocking; 238 jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking; 239 240 while (jcp_1x1.nb_load_blocking % jcp_dw->nb_ch_blocking != 0) 241 --jcp_dw->nb_ch_blocking; 242 243 jcp_dw->dw_conv_buffer_oc 244 = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block; 245 jcp_1x1.bcast_loop_output_step 246 = jcp_1x1.ur * jcp_1x1.load_block * jcp_1x1.typesize_out; 247 248 registrar_t scratchpad(scratchpad_registry_); 249 registrar_t dw_scratchpad(scratchpad, names::prefix_fusion); 250 251 size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw->kh * jcp_dw->iw 252 * jcp_dw->dw_conv_buffer_oc; 253 assert(dw_conv_buffer_size_); 254 dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer, 255 dw_conv_buffer_size_, 256 types::data_type_size(dw_conv_pd_->src_md()->data_type)); 257 258 if (jcp_1x1.isa == avx2) 259 dw_conv_kernel_t<avx2>::init_scratchpad(dw_scratchpad, *jcp_dw); 260 else 261 dw_conv_kernel_t<sse41>::init_scratchpad( 262 dw_scratchpad, *jcp_dw); 263 264 return status::success; 265 } 266 }; 267 268 template <cpu_isa_t isa, typename conv_t> 269 friend status_t init_rtus_driver(conv_t *self); 270 jit_avx2_1x1_convolution_fwd_tdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_fwd_t271 jit_avx2_1x1_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {} 272 initdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_fwd_t273 status_t init(engine_t *engine) override { 274 CHECK(safe_ptr_assign(kernel_, 275 new jit_avx2_1x1_conv_kernel_f32( 276 pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); 277 CHECK(kernel_->create_kernel()); 278 CHECK(init_rtus_driver<avx2>(this)); 279 if (pd()->jcp_.with_dw_conv) { 280 auto &isa = pd()->jcp_.isa; 281 282 if (isa == avx2) { 283 CHECK(safe_ptr_assign(kernel_dw_avx2, 284 new dw_conv_kernel_t<avx2>( 285 *(pd()->jcp_dw_), *pd()->dst_md(0)))); 286 CHECK(kernel_dw_avx2->create_kernel()); 287 } else { 288 CHECK(safe_ptr_assign(kernel_dw_sse41, 289 new dw_conv_kernel_t<sse41>( 290 *(pd()->jcp_dw_), *pd()->dst_md(0)))); 291 CHECK(kernel_dw_sse41->create_kernel()); 292 } 293 } 294 295 return status::success; 296 } 297 298 typedef typename prec_traits<data_type::f32>::type data_t; 299 executednnl::impl::cpu::x64::jit_avx2_1x1_convolution_fwd_t300 status_t execute(const exec_ctx_t &ctx) const override { 301 execute_forward(ctx); 302 return status::success; 303 } 304 305 private: 306 void execute_forward(const exec_ctx_t &ctx) const; 307 void execute_forward_thr(const int ithr, const int nthr, const data_t *src, 308 const data_t *weights, const data_t *bias, const data_t *weights_dw, 309 const data_t *bias_dw, data_t *dst, 310 const memory_tracking::grantor_t &scratchpad, 311 const void *post_ops_binary_rhs_arg_vec, 312 const void *post_ops_binary_rhs_arg_vec_dw) const; pddnnl::impl::cpu::x64::jit_avx2_1x1_convolution_fwd_t313 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 314 315 std::unique_ptr<jit_avx2_1x1_conv_kernel_f32> kernel_; 316 std::unique_ptr<rtus_driver_t<avx2>> rtus_driver_; 317 318 template <cpu_isa_t isa> 319 using dw_conv_kernel_t = jit_uni_dw_conv_fwd_kernel<isa, data_type::f32>; 320 321 std::unique_ptr<dw_conv_kernel_t<avx2>> kernel_dw_avx2; 322 std::unique_ptr<dw_conv_kernel_t<sse41>> kernel_dw_sse41; 323 }; 324 325 struct jit_avx2_1x1_convolution_bwd_data_t : public primitive_t { 326 struct pd_t : public cpu_convolution_bwd_data_pd_t { pd_tdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_data_t::pd_t327 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, 328 const convolution_fwd_pd_t *hint_fwd_pd) 329 : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd) 330 , jcp_() 331 , rtus_() {} 332 333 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), 334 jit_avx2_1x1_convolution_bwd_data_t); 335 initdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_data_t::pd_t336 status_t init(engine_t *engine) { 337 bool ok = true && desc()->prop_kind == prop_kind::backward_data 338 && set_default_alg_kind(alg_kind::convolution_direct) 339 && expect_data_types(data_type::f32, data_type::f32, 340 data_type::undef, data_type::f32, data_type::f32) 341 && attr()->has_default_values() && !has_zero_dim_memory() 342 && set_default_formats(); 343 if (!ok) return status::unimplemented; 344 345 const convolution_desc_t *conv_d = desc(); 346 const memory_desc_t *diff_src_d = diff_src_md(); 347 rtus_prepare(this, conv_d, diff_src_d, diff_dst_md(), weights_md()); 348 349 status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, 350 *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(), 351 *attr()); 352 if (status != status::success) return status; 353 354 auto scratchpad = scratchpad_registry().registrar(); 355 jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); 356 357 rtus_prepare_space_info(this, scratchpad, jcp_.nthr); 358 359 return status::success; 360 } 361 362 jit_1x1_conv_conf_t jcp_; 363 reduce_to_unit_stride_t rtus_; 364 365 protected: set_default_formatsdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_data_t::pd_t366 bool set_default_formats() { 367 using namespace format_tag; 368 369 auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); 370 auto wei_tag = with_groups() 371 ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i) 372 : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i); 373 374 return set_default_formats_common(dat_tag, wei_tag, dat_tag); 375 } 376 }; 377 378 template <cpu_isa_t isa, typename conv_t> 379 friend status_t init_rtus_driver(conv_t *self); 380 jit_avx2_1x1_convolution_bwd_data_tdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_data_t381 jit_avx2_1x1_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {} 382 383 typedef typename prec_traits<data_type::f32>::type data_t; 384 initdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_data_t385 status_t init(engine_t *engine) override { 386 CHECK(safe_ptr_assign(kernel_, 387 new jit_avx2_1x1_conv_kernel_f32( 388 pd()->jcp_, *pd()->attr(), *pd()->dst_md(0)))); 389 CHECK(kernel_->create_kernel()); 390 CHECK(init_rtus_driver<avx2>(this)); 391 return status::success; 392 } 393 executednnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_data_t394 status_t execute(const exec_ctx_t &ctx) const override { 395 execute_backward_data(ctx); 396 return status::success; 397 } 398 399 private: 400 void execute_backward_data(const exec_ctx_t &ctx) const; pddnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_data_t401 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 402 403 std::unique_ptr<jit_avx2_1x1_conv_kernel_f32> kernel_; 404 std::unique_ptr<rtus_driver_t<avx2>> rtus_driver_; 405 }; 406 407 struct jit_avx2_1x1_convolution_bwd_weights_t : public primitive_t { 408 struct pd_t : public cpu_convolution_bwd_weights_pd_t { pd_tdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_weights_t::pd_t409 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr, 410 const convolution_fwd_pd_t *hint_fwd_pd) 411 : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd) 412 , jcp_() 413 , rtus_() {} 414 415 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""), 416 jit_avx2_1x1_convolution_bwd_weights_t); 417 initdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_weights_t::pd_t418 status_t init(engine_t *engine) { 419 bool ok = true && desc()->prop_kind == prop_kind::backward_weights 420 && set_default_alg_kind(alg_kind::convolution_direct) 421 && expect_data_types(data_type::f32, data_type::f32, 422 data_type::f32, data_type::f32, data_type::f32) 423 && attr()->has_default_values() && !has_zero_dim_memory() 424 && set_default_formats(); 425 if (!ok) return status::unimplemented; 426 427 const convolution_desc_t *conv_d = desc(); 428 const memory_desc_t *src_d = src_md(); 429 rtus_prepare(this, conv_d, src_d, diff_dst_md(), diff_weights_md()); 430 431 status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_, 432 *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(), 433 *attr()); 434 if (status != status::success) return status; 435 436 init_balancers(); 437 438 auto scratchpad = scratchpad_registry().registrar(); 439 jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_); 440 441 rtus_prepare_space_info(this, scratchpad, jcp_.nthr); 442 443 auto reducer_bia_scratchpad = memory_tracking::registrar_t( 444 scratchpad, memory_tracking::names::prefix_reducer_bia); 445 reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad); 446 447 auto reducer_wei_scratchpad = memory_tracking::registrar_t( 448 scratchpad, memory_tracking::names::prefix_reducer_wei); 449 reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad); 450 451 return status::success; 452 } 453 454 jit_1x1_conv_conf_t jcp_; 455 cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_; 456 cpu_reducer_2d_t<data_type::f32>::conf_t reducer_wei_conf_; 457 reduce_to_unit_stride_t rtus_; 458 459 protected: set_default_formatsdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_weights_t::pd_t460 bool set_default_formats() { 461 using namespace format_tag; 462 463 auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c); 464 auto wei_tag = with_groups() 465 ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o) 466 : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o, OIdhw8i8o); 467 468 return set_default_formats_common(dat_tag, wei_tag, dat_tag); 469 } 470 471 private: init_balancersdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_weights_t::pd_t472 void init_balancers() { 473 const int ic_block = jcp_.bcast_block; 474 const int nb_ic = jcp_.nb_bcast; 475 const int nb_ic_blocking = jcp_.nb_bcast_blocking; 476 const int bcast_work = utils::div_up(nb_ic, nb_ic_blocking); 477 478 const int oc_block = jcp_.load_block; 479 const int nb_oc = jcp_.nb_load; 480 const int nb_oc_blocking = jcp_.nb_load_blocking; 481 const int load_work = utils::div_up(nb_oc, nb_oc_blocking); 482 483 const int job_size 484 = nb_oc_blocking * nb_ic_blocking * ic_block * oc_block; 485 const int njobs_x = bcast_work; 486 const int njobs_y = jcp_.ngroups * load_work; 487 488 const int max_threads = dnnl_get_max_threads(); 489 const size_t max_buffer_size = (size_t)max_threads * job_size * 8; 490 491 if (with_bias()) { 492 reducer_bia_conf_.init(reduce_balancer_t(max_threads, oc_block, 493 jcp_.ngroups * nb_oc, jcp_.mb, max_buffer_size, true)); 494 } 495 496 reducer_wei_conf_.init( 497 reduce_balancer_t(max_threads, job_size, njobs_y * njobs_x, 498 jcp_.mb * jcp_.nb_reduce, max_buffer_size, true), 499 job_size / nb_oc_blocking, nb_oc_blocking, ic_block, 500 nb_ic * ic_block * oc_block, nb_oc); 501 } 502 }; 503 504 template <cpu_isa_t isa, typename conv_t> 505 friend status_t init_rtus_driver(conv_t *self); 506 jit_avx2_1x1_convolution_bwd_weights_tdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_weights_t507 jit_avx2_1x1_convolution_bwd_weights_t(const pd_t *apd) 508 : primitive_t(apd) {} 509 510 typedef typename prec_traits<data_type::f32>::type data_t; 511 512 status_t init(engine_t *engine) override; 513 executednnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_weights_t514 status_t execute(const exec_ctx_t &ctx) const override { 515 execute_backward_weights(ctx); 516 return status::success; 517 } 518 519 private: 520 void execute_backward_weights(const exec_ctx_t &ctx) const; pddnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_weights_t521 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); } 522 523 std::unique_ptr<jit_avx2_1x1_conv_kernel_f32> kernel_; 524 std::unique_ptr<cpu_reducer_2d_t<data_type::f32>> reducer_weights_; 525 std::unique_ptr<cpu_reducer_t<data_type::f32>> reducer_bias_; 526 std::unique_ptr<rtus_driver_t<avx2>> rtus_driver_; 527 }; 528 529 } // namespace x64 530 } // namespace cpu 531 } // namespace impl 532 } // namespace dnnl 533 534 #endif 535