1 /*******************************************************************************
2 * Copyright 2016-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #ifndef CPU_X64_JIT_AVX2_1X1_CONVOLUTION_HPP
18 #define CPU_X64_JIT_AVX2_1X1_CONVOLUTION_HPP
19 
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/memory_tracking.hpp"
23 #include "common/primitive.hpp"
24 #include "common/primitive_hashing.hpp"
25 #include "common/utils.hpp"
26 
27 #include "cpu/cpu_convolution_pd.hpp"
28 #include "cpu/dw_convolution_utils.hpp"
29 #include "cpu/platform.hpp"
30 
31 #include "cpu/x64/cpu_reducer.hpp"
32 #include "cpu/x64/jit_avx2_1x1_conv_kernel_f32.hpp"
33 #include "cpu/x64/jit_uni_1x1_conv_utils.hpp"
34 #include "cpu/x64/jit_uni_dw_convolution.hpp"
35 
36 namespace dnnl {
37 namespace impl {
38 namespace cpu {
39 namespace x64 {
40 
struct jit_avx2_1x1_convolution_fwd_t : public primitive_t {
    // TODO: (Roma) Code duplication duplication! Remove with templates
    //              (maybe...)!

    // Primitive descriptor for the f32 JIT 1x1 forward convolution. On top
    // of the common conv pd state it carries:
    //  - jcp_:  JIT kernel configuration filled by init_conf()
    //  - rtus_: "reduce to unit stride" driver info (see rtus_prepare())
    //  - dw_conv_pd_ / jcp_dw_: pd of an optionally fused depthwise conv
    //    post-op plus a raw pointer into that pd's jcp_ (only valid when
    //    jcp_.with_dw_conv is set).
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_()
            , jcp_dw_(nullptr) {}

        // Copy ctor must deep-clone dw_conv_pd_ and re-point jcp_dw_ into
        // the clone (see copy()); on failure the pd is flagged uninitialized.
        pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) {
            if (copy(other) != status::success) is_initialized_ = false;
        }

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", jcp_.isa, ""),
                jit_avx2_1x1_convolution_fwd_t);

        // Validates descriptor/attributes, fills jcp_, optionally sets up a
        // fused depthwise conv, and books scratchpad space. Note the &&
        // chain is order-dependent: set_default_alg_kind() and
        // set_default_formats() set defaults as a side effect of checking.
        status_t init(engine_t *engine) {
            bool ok = true && is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(data_type::f32, data_type::f32,
                            data_type::f32, data_type::f32, data_type::f32)
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops,
                            data_type::f32)
                    && !has_zero_dim_memory() && set_default_formats();
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            rtus_prepare(this, conv_d, src_d, dst_md(), weights_md());

            status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(
                    jcp_, *conv_d, *src_d, *weights_md(), *dst_md(), *attr());
            if (status != status::success) return status;

            // init_conf() decides whether a dw conv post-op is fusable.
            if (jcp_.with_dw_conv) {
                status = depthwise_po_init(engine);
                if (status != status::success) return status;
            }

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // With a fused dw conv, the externally visible dst is the dw conv's
        // dst (the 1x1 output is only an intermediate buffer).
        const memory_desc_t *dst_md(int index = 0) const override {
            return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_;
        }

        // Routes queries for the fused dw conv's weights/bias mds to its pd.
        const memory_desc_t *arg_md(int index = 0) const override {
            if (jcp_.with_dw_conv) {
                switch (index) {
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
                        return dw_conv_pd_->weights_md(0);
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
                        return dw_conv_pd_->weights_md(1);
                    default: break;
                }
            }
            return convolution_fwd_pd_t::arg_md(index);
        }

        // The fused dw conv's weights (and bias, when present) are extra
        // inputs of this primitive.
        arg_usage_t arg_usage(int arg) const override {
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
                    && attr_post_op_dw_inputs() > 1)
                return arg_usage_t::input;

            return convolution_fwd_pd_t::arg_usage(arg);
        }

        jit_1x1_conv_conf_t jcp_; // JIT kernel configuration
        reduce_to_unit_stride_t rtus_; // reduce-to-unit-stride driver info
        // Points into dw_conv_pd_'s jcp_; nullptr unless a dw conv is fused.
        jit_conv_conf_t *jcp_dw_;
        std::unique_ptr<cpu_convolution_fwd_pd_t> dw_conv_pd_;

    protected:
        template <cpu_isa_t isa>
        using dw_pd_t = typename jit_uni_dw_convolution_fwd_t<isa,
                data_type::f32>::pd_t;

        // Blocked 8c layouts for data and 8i8o (grouped or plain) weights.
        bool set_default_formats() {
            using namespace format_tag;

            auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
            auto wei_tag = with_groups()
                    ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o)
                    : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o, OIdhw8i8o);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }

        // Copy helper: clones the fused dw conv pd (if any) and re-points
        // jcp_dw_ into the clone, picking the concrete pd type by ISA.
        status_t copy(const pd_t &other) {
            jcp_ = other.jcp_;
            rtus_ = other.rtus_;
            jcp_dw_ = nullptr;
            if (other.dw_conv_pd_) {
                dw_conv_pd_.reset(static_cast<cpu_convolution_fwd_pd_t *>(
                        other.dw_conv_pd_->clone()));
                if (!dw_conv_pd_) return status::out_of_memory;
                if (jcp_.isa == avx2) {
                    jcp_dw_ = &(static_cast<dw_pd_t<avx2> *>(dw_conv_pd_.get())
                                        ->jcp_);
                } else { // sse41
                    jcp_dw_ = &(static_cast<dw_pd_t<sse41> *>(dw_conv_pd_.get())
                                        ->jcp_);
                }
            }

            return status::success;
        }

        // Creates the pd of the depthwise conv post-op and co-tunes the 1x1
        // blocking so both kernels can exchange data through a shared
        // per-thread buffer booked in the fusion scratchpad.
        status_t depthwise_po_init(engine_t *engine) {

            using namespace memory_tracking;
            auto &jcp_1x1 = jcp_;
            primitive_attr_t attr_1x1(*attr());
            if (!attr_1x1.is_initialized()) return status::out_of_memory;
            jit_conv_conf_t *jcp_dw = nullptr;

            // The dw conv consumes the 1x1 conv's dst as its src.
            const auto &src_md = dst_md_;
            const memory_desc_wrapper src_d(src_md);
            const auto nthr = dnnl_get_max_threads();
            auto l2_cache = platform::get_per_core_cache_size(2) * nthr;

            // Note: A robust fusion implementation would be to check if both
            // 1x1 conv and dw conv that are considered here for fusion are
            // optimal independently. This would require creating a new
            // primitive_desc through primitive_iterator & check if they match.
            // Due to concern that these creations and/or checks could be heavy,
            // for 1x1: Check that no better ISA is available.
            // for dw: Always fuse with same ISA.
            // Caveat: May be a better dw conv exists.

            bool ok = true && (!mayiuse(avx512_common))
                    && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
                    // TODO: Below may be further tuned.
                    && (l2_cache * 2 < src_d.size())
                    // load_grp_count check can be redundant due to l2 check
                    // above. Adding it explicitly as the current driver doesn't
                    // work if this condition fails.
                    && (jcp_1x1.load_grp_count < 2);
            if (!ok) return status::unimplemented;

            int dw_po_index
                    = attr_1x1.post_ops_.find(primitive_kind::convolution);

            convolution_desc_t cd_dw;
            primitive_attr_t attr_dw;

            // Derive the dw conv's desc/attr from the dw post-op entry.
            CHECK(get_depthwise_conv_desc(
                    cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));

            if (jcp_1x1.isa == avx2) {
                std::unique_ptr<dw_pd_t<avx2>> fusable_pd(
                        new dw_pd_t<avx2>(&cd_dw, &attr_dw, nullptr));
                CHECK(fusable_pd->init(engine));
                jcp_dw = &(fusable_pd->jcp_);
                dw_conv_pd_ = std::move(fusable_pd);
            } else {
                // Special case for this primitive, as we dont have dw<avx>.
                // In this case fuse with sse41 depthwise conv
                // NOTE: Currently dw f32 kernel is similar for all ISA and can
                // be fused regardless of ISA if inter-connecting md_ matches.
                std::unique_ptr<dw_pd_t<sse41>> fusable_pd(
                        new dw_pd_t<sse41>(&cd_dw, &attr_dw, nullptr));
                CHECK(fusable_pd->init(engine));
                jcp_dw = &(fusable_pd->jcp_);
                dw_conv_pd_ = std::move(fusable_pd);
            }

            // Fusion is only valid if the dw conv consumes exactly the 1x1
            // dst and the output-channel blocking divides evenly.
            ok = true
                    && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
                    && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
                    && IMPLICATION(
                            jcp_dw->ow_block, jcp_dw->ow_block == jcp_dw->ow);
            if (!ok) return status::unimplemented;

            assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
            assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
            assert(IMPLICATION(
                    dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
                    dw_conv_pd_->weights_md(1)->format_kind
                            != format_kind::any));

            jcp_dw->is_fused_conv = true;
            // TODO: Support/experiment arbitary oc_work in dw conv.
            // Until then we keep oc_work perfectly divisible.
            while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
                --jcp_1x1.nb_load_blocking;
            jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;

            while (jcp_1x1.nb_load_blocking % jcp_dw->nb_ch_blocking != 0)
                --jcp_dw->nb_ch_blocking;

            jcp_dw->dw_conv_buffer_oc
                    = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;
            jcp_1x1.bcast_loop_output_step
                    = jcp_1x1.ur * jcp_1x1.load_block * jcp_1x1.typesize_out;

            registrar_t scratchpad(scratchpad_registry_);
            registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);

            // Per-thread slab (kh rows x iw x buffered oc) of the 1x1 output
            // that feeds the dw conv.
            size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw->kh * jcp_dw->iw
                    * jcp_dw->dw_conv_buffer_oc;
            assert(dw_conv_buffer_size_);
            dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
                    dw_conv_buffer_size_,
                    types::data_type_size(dw_conv_pd_->src_md()->data_type));

            if (jcp_1x1.isa == avx2)
                dw_conv_kernel_t<avx2>::init_scratchpad(dw_scratchpad, *jcp_dw);
            else
                dw_conv_kernel_t<sse41>::init_scratchpad(
                        dw_scratchpad, *jcp_dw);

            return status::success;
        }
    };

    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx2_1x1_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}

    // Builds the 1x1 kernel, the rtus driver, and (if a dw conv is fused)
    // the dw kernel matching pd()->jcp_.isa.
    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx2_1x1_conv_kernel_f32(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());
        CHECK(init_rtus_driver<avx2>(this));
        if (pd()->jcp_.with_dw_conv) {
            auto &isa = pd()->jcp_.isa;

            if (isa == avx2) {
                CHECK(safe_ptr_assign(kernel_dw_avx2,
                        new dw_conv_kernel_t<avx2>(
                                *(pd()->jcp_dw_), *pd()->dst_md(0))));
                CHECK(kernel_dw_avx2->create_kernel());
            } else {
                CHECK(safe_ptr_assign(kernel_dw_sse41,
                        new dw_conv_kernel_t<sse41>(
                                *(pd()->jcp_dw_), *pd()->dst_md(0))));
                CHECK(kernel_dw_sse41->create_kernel());
            }
        }

        return status::success;
    }

    typedef typename prec_traits<data_type::f32>::type data_t;

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_forward(ctx);
        return status::success;
    }

private:
    // Implementations live in the corresponding .cpp file.
    void execute_forward(const exec_ctx_t &ctx) const;
    void execute_forward_thr(const int ithr, const int nthr, const data_t *src,
            const data_t *weights, const data_t *bias, const data_t *weights_dw,
            const data_t *bias_dw, data_t *dst,
            const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec,
            const void *post_ops_binary_rhs_arg_vec_dw) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx2_1x1_conv_kernel_f32> kernel_;
    std::unique_ptr<rtus_driver_t<avx2>> rtus_driver_;

    template <cpu_isa_t isa>
    using dw_conv_kernel_t = jit_uni_dw_conv_fwd_kernel<isa, data_type::f32>;

    // Only the kernel matching jcp_.isa is created (see init()).
    std::unique_ptr<dw_conv_kernel_t<avx2>> kernel_dw_avx2;
    std::unique_ptr<dw_conv_kernel_t<sse41>> kernel_dw_sse41;
};
324 
325 struct jit_avx2_1x1_convolution_bwd_data_t : public primitive_t {
326     struct pd_t : public cpu_convolution_bwd_data_pd_t {
pd_tdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_data_t::pd_t327         pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
328                 const convolution_fwd_pd_t *hint_fwd_pd)
329             : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd)
330             , jcp_()
331             , rtus_() {}
332 
333         DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""),
334                 jit_avx2_1x1_convolution_bwd_data_t);
335 
initdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_data_t::pd_t336         status_t init(engine_t *engine) {
337             bool ok = true && desc()->prop_kind == prop_kind::backward_data
338                     && set_default_alg_kind(alg_kind::convolution_direct)
339                     && expect_data_types(data_type::f32, data_type::f32,
340                             data_type::undef, data_type::f32, data_type::f32)
341                     && attr()->has_default_values() && !has_zero_dim_memory()
342                     && set_default_formats();
343             if (!ok) return status::unimplemented;
344 
345             const convolution_desc_t *conv_d = desc();
346             const memory_desc_t *diff_src_d = diff_src_md();
347             rtus_prepare(this, conv_d, diff_src_d, diff_dst_md(), weights_md());
348 
349             status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_,
350                     *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(),
351                     *attr());
352             if (status != status::success) return status;
353 
354             auto scratchpad = scratchpad_registry().registrar();
355             jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);
356 
357             rtus_prepare_space_info(this, scratchpad, jcp_.nthr);
358 
359             return status::success;
360         }
361 
362         jit_1x1_conv_conf_t jcp_;
363         reduce_to_unit_stride_t rtus_;
364 
365     protected:
set_default_formatsdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_data_t::pd_t366         bool set_default_formats() {
367             using namespace format_tag;
368 
369             auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
370             auto wei_tag = with_groups()
371                     ? utils::pick(ndims() - 3, gOIw8o8i, gOIhw8o8i, gOIdhw8o8i)
372                     : utils::pick(ndims() - 3, OIw8o8i, OIhw8o8i, OIdhw8o8i);
373 
374             return set_default_formats_common(dat_tag, wei_tag, dat_tag);
375         }
376     };
377 
378     template <cpu_isa_t isa, typename conv_t>
379     friend status_t init_rtus_driver(conv_t *self);
380 
jit_avx2_1x1_convolution_bwd_data_tdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_data_t381     jit_avx2_1x1_convolution_bwd_data_t(const pd_t *apd) : primitive_t(apd) {}
382 
383     typedef typename prec_traits<data_type::f32>::type data_t;
384 
initdnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_data_t385     status_t init(engine_t *engine) override {
386         CHECK(safe_ptr_assign(kernel_,
387                 new jit_avx2_1x1_conv_kernel_f32(
388                         pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
389         CHECK(kernel_->create_kernel());
390         CHECK(init_rtus_driver<avx2>(this));
391         return status::success;
392     }
393 
executednnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_data_t394     status_t execute(const exec_ctx_t &ctx) const override {
395         execute_backward_data(ctx);
396         return status::success;
397     }
398 
399 private:
400     void execute_backward_data(const exec_ctx_t &ctx) const;
pddnnl::impl::cpu::x64::jit_avx2_1x1_convolution_bwd_data_t401     const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
402 
403     std::unique_ptr<jit_avx2_1x1_conv_kernel_f32> kernel_;
404     std::unique_ptr<rtus_driver_t<avx2>> rtus_driver_;
405 };
406 
struct jit_avx2_1x1_convolution_bwd_weights_t : public primitive_t {
    // Primitive descriptor for the f32 AVX2 1x1 backward-weights convolution.
    struct pd_t : public cpu_convolution_bwd_weights_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_() {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx2, ""),
                jit_avx2_1x1_convolution_bwd_weights_t);

        // Validates the descriptor, fills jcp_, configures the bias/weights
        // reducers, and books scratchpad space for the kernel, the rtus
        // driver and both reducers (booking order matters for layout).
        status_t init(engine_t *engine) {
            // Order-dependent chain: set_default_alg_kind() and
            // set_default_formats() set defaults as they check.
            bool ok = true && desc()->prop_kind == prop_kind::backward_weights
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(data_type::f32, data_type::f32,
                            data_type::f32, data_type::f32, data_type::f32)
                    && attr()->has_default_values() && !has_zero_dim_memory()
                    && set_default_formats();
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            rtus_prepare(this, conv_d, src_d, diff_dst_md(), diff_weights_md());

            status_t status = jit_avx2_1x1_conv_kernel_f32::init_conf(jcp_,
                    *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(),
                    *attr());
            if (status != status::success) return status;

            // Reducer configs depend on the blocking chosen by init_conf().
            init_balancers();

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx2_1x1_conv_kernel_f32::init_scratchpad(scratchpad, jcp_);

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            auto reducer_bia_scratchpad = memory_tracking::registrar_t(
                    scratchpad, memory_tracking::names::prefix_reducer_bia);
            reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);

            auto reducer_wei_scratchpad = memory_tracking::registrar_t(
                    scratchpad, memory_tracking::names::prefix_reducer_wei);
            reducer_wei_conf_.init_scratchpad(reducer_wei_scratchpad);

            return status::success;
        }

        jit_1x1_conv_conf_t jcp_; // JIT kernel configuration
        cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
        cpu_reducer_2d_t<data_type::f32>::conf_t reducer_wei_conf_;
        reduce_to_unit_stride_t rtus_; // reduce-to-unit-stride driver info

    protected:
        // Blocked 8c layouts for data and 8i8o (grouped or plain) weights.
        bool set_default_formats() {
            using namespace format_tag;

            auto dat_tag = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
            auto wei_tag = with_groups()
                    ? utils::pick(ndims() - 3, gOIw8i8o, gOIhw8i8o, gOIdhw8i8o)
                    : utils::pick(ndims() - 3, OIw8i8o, OIhw8i8o, OIdhw8i8o);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }

    private:
        // Configures the reduce balancers for diff_bias and diff_weights,
        // i.e. how the reduction work is split across threads.
        void init_balancers() {
            // Bcast (ic) side of the decomposition.
            const int ic_block = jcp_.bcast_block;
            const int nb_ic = jcp_.nb_bcast;
            const int nb_ic_blocking = jcp_.nb_bcast_blocking;
            const int bcast_work = utils::div_up(nb_ic, nb_ic_blocking);

            // Load (oc) side of the decomposition.
            const int oc_block = jcp_.load_block;
            const int nb_oc = jcp_.nb_load;
            const int nb_oc_blocking = jcp_.nb_load_blocking;
            const int load_work = utils::div_up(nb_oc, nb_oc_blocking);

            // One job covers an (oc blocking) x (ic blocking) weights tile.
            const int job_size
                    = nb_oc_blocking * nb_ic_blocking * ic_block * oc_block;
            const int njobs_x = bcast_work;
            const int njobs_y = jcp_.ngroups * load_work;

            const int max_threads = dnnl_get_max_threads();
            // NOTE(review): the factor 8 looks like a heuristic cap on the
            // per-thread reduction buffer (8 jobs' worth) -- confirm.
            const size_t max_buffer_size = (size_t)max_threads * job_size * 8;

            if (with_bias()) {
                reducer_bia_conf_.init(reduce_balancer_t(max_threads, oc_block,
                        jcp_.ngroups * nb_oc, jcp_.mb, max_buffer_size, true));
            }

            reducer_wei_conf_.init(
                    reduce_balancer_t(max_threads, job_size, njobs_y * njobs_x,
                            jcp_.mb * jcp_.nb_reduce, max_buffer_size, true),
                    job_size / nb_oc_blocking, nb_oc_blocking, ic_block,
                    nb_ic * ic_block * oc_block, nb_oc);
        }
    };

    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx2_1x1_convolution_bwd_weights_t(const pd_t *apd)
        : primitive_t(apd) {}

    typedef typename prec_traits<data_type::f32>::type data_t;

    // Defined out of line (unlike fwd/bwd_data) in the .cpp file.
    status_t init(engine_t *engine) override;

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_backward_weights(ctx);
        return status::success;
    }

private:
    // Implementation lives in the corresponding .cpp file.
    void execute_backward_weights(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx2_1x1_conv_kernel_f32> kernel_;
    std::unique_ptr<cpu_reducer_2d_t<data_type::f32>> reducer_weights_;
    std::unique_ptr<cpu_reducer_t<data_type::f32>> reducer_bias_;
    std::unique_ptr<rtus_driver_t<avx2>> rtus_driver_;
};
528 
529 } // namespace x64
530 } // namespace cpu
531 } // namespace impl
532 } // namespace dnnl
533 
534 #endif
535