/*******************************************************************************
* Copyright 2017-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef CPU_X64_JIT_AVX512_COMMON_1X1_CONVOLUTION_HPP
#define CPU_X64_JIT_AVX512_COMMON_1X1_CONVOLUTION_HPP

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/memory_tracking.hpp"
#include "common/primitive.hpp"
#include "common/primitive_hashing.hpp"
#include "common/utils.hpp"

#include "cpu/cpu_convolution_pd.hpp"
#include "cpu/dw_convolution_utils.hpp"
#include "cpu/platform.hpp"

#include "cpu/x64/cpu_reducer.hpp"
#include "cpu/x64/jit_avx512_common_1x1_conv_kernel.hpp"
#include "cpu/x64/jit_transpose_utils.hpp"
#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"
#include "cpu/x64/jit_uni_dw_convolution.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

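// Forward 1x1 direct convolution for AVX-512, parameterized over the source,
// weights, and destination data types. The compute is done by the JIT kernel
// (jit_avx512_common_1x1_conv_kernel); this class handles descriptor checks,
// scratchpad booking, and the optional fused depthwise convolution post-op.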
template <impl::data_type_t src_type, impl::data_type_t wei_type = src_type,
        impl::data_type_t dst_type = src_type>
struct jit_avx512_common_1x1_convolution_fwd_t : public primitive_t {
    struct pd_t : public cpu_convolution_fwd_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const typename pd_t::base_class *hint_fwd_pd)
            : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_() {}

        pd_t(const pd_t &other) : cpu_convolution_fwd_pd_t(other) {
            if (copy(other) != status::success) is_initialized_ = false;
        }

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
                jit_avx512_common_1x1_convolution_fwd_t);

        status_t init(engine_t *engine) {
            using namespace utils;
            bool ok = true && is_fwd()
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(src_type, wei_type, dst_type, dst_type,
                            data_type::undef)
                    && attr()->has_default_values(
                            primitive_attr_t::skip_mask_t::post_ops, dst_type)
                    && !has_zero_dim_memory() && set_default_formats();
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
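            // rtus ("reduce to unit stride"): for strided 1x1 convolutions the
            // relevant source pixels are gathered into a dense scratchpad so
            // the JIT kernel can always operate with unit stride.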
            rtus_prepare(this, conv_d, src_d, dst_md(), weights_md());

            status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(jcp_,
                    *conv_d, *src_d, *weights_md(), *dst_md(), *attr(),
                    dnnl_get_max_threads(), rtus_.reduce_src_);
            if (status != status::success) return status;

            if (jcp_.with_dw_conv) {
                status = depthwise_po_init(engine);
                if (status != status::success) return status;
            }

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx512_common_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_);

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

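        // With a fused depthwise post-op the primitive's visible destination is
        // the dw convolution's destination, and the dw weights/bias are exposed
        // through the DNNL_ARG_ATTR_POST_OP_DW argument ids handled below.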
        const memory_desc_t *dst_md(int index = 0) const override {
            return jcp_.with_dw_conv ? dw_conv_pd_->dst_md(index) : &dst_md_;
        }

        const memory_desc_t *arg_md(int index = 0) const override {
            if (jcp_.with_dw_conv) {
                switch (index) {
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS:
                        return dw_conv_pd_->weights_md(0);
                    case DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS:
                        return dw_conv_pd_->weights_md(1);
                    default: break;
                }
            }
            return convolution_fwd_pd_t::arg_md(index);
        }

        arg_usage_t arg_usage(int arg) const override {
            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS))
                return arg_usage_t::input;

            if (arg == (DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS)
                    && attr_post_op_dw_inputs() > 1)
                return arg_usage_t::input;

            return convolution_fwd_pd_t::arg_usage(arg);
        }

        jit_1x1_conv_conf_t jcp_;
        reduce_to_unit_stride_t rtus_;
        using dw_pd_t = jit_avx512_common_dw_convolution_fwd_t::pd_t;
        std::unique_ptr<dw_pd_t> dw_conv_pd_;

    protected:
        bool set_default_formats() {
            using namespace format_tag;

            auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
            auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
                    OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o, OIdhw16i16o,
                    gOIdhw16i16o);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }

        status_t copy(const pd_t &other) {
            jcp_ = other.jcp_;
            rtus_ = other.rtus_;
            if (other.dw_conv_pd_) {
                dw_conv_pd_.reset(other.dw_conv_pd_->clone());
                if (!dw_conv_pd_) return status::out_of_memory;
            }
            return status::success;
        }

        status_t depthwise_po_init(engine_t *engine) {

            using namespace memory_tracking;
            auto &jcp_1x1 = jcp_;
            primitive_attr_t attr_1x1(*attr());
            if (!attr_1x1.is_initialized()) return status::out_of_memory;
            const auto &src_md = dst_md_;
            const memory_desc_wrapper src_d(src_md);
            const auto nthr = dnnl_get_max_threads();
            auto l2_cache = platform::get_per_core_cache_size(2) * nthr;

            // Note: A robust fusion implementation would check whether the 1x1
            // conv and the dw conv considered here for fusion are each optimal
            // on their own. That would require creating new primitive_descs
            // through the primitive iterator and checking that they match.
            // Since those creations and/or checks could be expensive, instead:
            // for 1x1: check that no better ISA is available;
            // for dw: always fuse with the same ISA.
            // Caveat: a better dw conv may exist.

            // TODO: Add a check if better ISA exists following above note.
            bool ok = true
                    && (attr_1x1.post_ops_.find(primitive_kind::sum) == -1)
                    // TODO: Below may be further tuned.
                    && (l2_cache * 2 < src_d.size())
                    // load_grp_count check can be redundant due to l2 check
                    // above. Adding it explicitly as the current driver doesn't
                    // work if this condition fails.
                    && (jcp_1x1.load_grp_count < 2);
            if (!ok) return status::unimplemented;

            int dw_po_index
                    = attr_1x1.post_ops_.find(primitive_kind::convolution);
            convolution_desc_t cd_dw;
            primitive_attr_t attr_dw;
            CHECK(get_depthwise_conv_desc(
                    cd_dw, src_md, attr_1x1, attr_dw, dw_po_index));

            CHECK(safe_ptr_assign(
                    dw_conv_pd_, new dw_pd_t(&cd_dw, &attr_dw, nullptr)));
            CHECK(dw_conv_pd_->init(engine));
            auto &jcp_dw = dw_conv_pd_->jcp_;

            ok = true
                    && (dnnl_memory_desc_equal(&src_md, dw_conv_pd_->src_md(0)))
                    && (jcp_1x1.oc_without_padding % jcp_1x1.oc_block == 0)
                    && IMPLICATION(
                            jcp_dw.ow_block, jcp_dw.ow_block == jcp_dw.ow);
            if (!ok) return status::unimplemented;

            assert(dw_conv_pd_->dst_md(0)->format_kind != format_kind::any);
            assert(dw_conv_pd_->weights_md(0)->format_kind != format_kind::any);
            assert(IMPLICATION(
                    dw_conv_pd_->weights_md(1)->data_type != data_type::undef,
                    dw_conv_pd_->weights_md(1)->format_kind
                            != format_kind::any));

            jcp_dw.is_fused_conv = true;
            // TODO: Support/experiment with arbitrary oc_work in the dw conv.
            // Until then we keep oc_work perfectly divisible.
            while (jcp_1x1.nb_load % jcp_1x1.nb_load_blocking != 0)
                --jcp_1x1.nb_load_blocking;
            jcp_1x1.nb_load_blocking_max = jcp_1x1.nb_load_blocking;

            while (jcp_1x1.nb_load_blocking % jcp_dw.nb_ch_blocking != 0)
                --jcp_dw.nb_ch_blocking;
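            // Example with hypothetical values: nb_load = 12 and
            // nb_load_blocking = 5 gets reduced to 4 (12 % 4 == 0); a dw
            // nb_ch_blocking of 3 is then reduced to 2 (4 % 2 == 0).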

            jcp_dw.dw_conv_buffer_oc
                    = jcp_1x1.nb_load_blocking * jcp_1x1.oc_block;
            jcp_1x1.bcast_loop_output_step
                    = jcp_1x1.ur * jcp_1x1.load_block * jcp_1x1.typesize_out;

            registrar_t scratchpad(scratchpad_registry_);
            registrar_t dw_scratchpad(scratchpad, names::prefix_fusion);

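            // Intermediate buffer between the 1x1 and dw kernels: each thread
            // gets kh rows of iw pixels with dw_conv_buffer_oc channels of 1x1
            // output for the fused dw convolution to consume.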
            size_t dw_conv_buffer_size_ = (size_t)nthr * jcp_dw.kh * jcp_dw.iw
                    * jcp_dw.dw_conv_buffer_oc;
            assert(dw_conv_buffer_size_);
            dw_scratchpad.book(memory_tracking::names::key_fusion_inout_buffer,
                    dw_conv_buffer_size_,
                    types::data_type_size(dw_conv_pd_->src_md()->data_type));

            jit_uni_dw_conv_fwd_kernel<avx512_common,
                    data_type::f32>::init_scratchpad(dw_scratchpad, jcp_dw);

            return status::success;
        }
    };

    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx512_common_1x1_convolution_fwd_t(const pd_t *apd)
        : primitive_t(apd) {}

    typedef typename prec_traits<src_type>::type src_data_t;
    typedef typename prec_traits<wei_type>::type wei_data_t;
    typedef typename prec_traits<dst_type>::type dst_data_t;

    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx512_common_1x1_conv_kernel(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());

        if (pd()->jcp_.with_dw_conv) {
            CHECK(safe_ptr_assign(kernel_dw_,
                    new dw_conv_kernel_t(
                            pd()->dw_conv_pd_->jcp_, *pd()->dst_md(0))));
            CHECK(kernel_dw_->create_kernel());
        }

        CHECK(init_rtus_driver<avx512_common>(this));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_forward(ctx);
        return status::success;
    }

private:
    void execute_forward(const exec_ctx_t &ctx) const;
    void execute_forward_thr(const int ithr, const int nthr,
            const src_data_t *src, const wei_data_t *weights,
            const dst_data_t *bias, const wei_data_t *weights_dw,
            const dst_data_t *bias_dw, dst_data_t *dst,
            const memory_tracking::grantor_t &scratchpad,
            const void *post_ops_binary_rhs_arg_vec,
            const void *post_ops_binary_rhs_arg_vec_dw) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx512_common_1x1_conv_kernel> kernel_;
    std::unique_ptr<rtus_driver_t<avx512_common>> rtus_driver_;
    using dw_conv_kernel_t = jit_uni_dw_conv_fwd_kernel_f32<avx512_common>;
    std::unique_ptr<dw_conv_kernel_t> kernel_dw_;
};

using jit_avx512_common_1x1_convolution_fwd_f32_t
        = jit_avx512_common_1x1_convolution_fwd_t<data_type::f32>;

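// Backward-data 1x1 convolution: computes diff_src from diff_dst and the
// weights, using the same JIT kernel class with transposed weight layouts
// (IOw16o16i and friends picked in set_default_formats below).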
template <impl::data_type_t diff_dst_type,
        impl::data_type_t wei_type = diff_dst_type,
        impl::data_type_t diff_src_type = diff_dst_type>
struct jit_avx512_common_1x1_convolution_bwd_data_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_data_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_data_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_() {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
                jit_avx512_common_1x1_convolution_bwd_data_t);

        status_t init(engine_t *engine) {
            bool ok = true && desc()->prop_kind == prop_kind::backward_data
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(diff_src_type, wei_type,
                            data_type::undef, diff_dst_type, data_type::undef)
                    && attr()->has_default_values() && !has_zero_dim_memory()
                    && set_default_formats();
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *diff_src_d = diff_src_md();
            rtus_prepare(this, conv_d, diff_src_d, diff_dst_md(), weights_md());

            status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(jcp_,
                    *conv_d, *diff_src_d, *weights_md(), *diff_dst_md(),
                    *attr(), dnnl_get_max_threads(), rtus_.reduce_src_);
            if (status != status::success) return status;

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx512_common_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_);

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // TODO (Roma): structs conf header cleanup
        jit_1x1_conv_conf_t jcp_;
        reduce_to_unit_stride_t rtus_;

    protected:
        bool set_default_formats() {
            using namespace format_tag;

            auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
            auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
                    IOw16o16i, gIOw16o16i, IOhw16o16i, gIOhw16o16i, IOdhw16o16i,
                    gIOdhw16o16i);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }
    };

    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx512_common_1x1_convolution_bwd_data_t(const pd_t *apd)
        : primitive_t(apd) {}

    typedef typename prec_traits<diff_dst_type>::type diff_dst_data_t;
    typedef typename prec_traits<wei_type>::type wei_data_t;
    typedef typename prec_traits<diff_src_type>::type diff_src_data_t;

    status_t init(engine_t *engine) override {
        CHECK(safe_ptr_assign(kernel_,
                new jit_avx512_common_1x1_conv_kernel(
                        pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
        CHECK(kernel_->create_kernel());
        CHECK(init_rtus_driver<avx512_common>(this));
        return status::success;
    }

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_backward_data(ctx);
        return status::success;
    }

private:
    void execute_backward_data(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx512_common_1x1_conv_kernel> kernel_;
    std::unique_ptr<rtus_driver_t<avx512_common>> rtus_driver_;
};

using jit_avx512_common_1x1_convolution_bwd_data_f32_t
        = jit_avx512_common_1x1_convolution_bwd_data_t<data_type::f32>;

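// Backward-weights 1x1 convolution (f32 only): computes diff_weights from src
// and diff_dst; diff_bias, when requested, is accumulated through the
// cpu_reducer / reduce_balancer machinery set up in init_balancers().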
struct jit_avx512_common_1x1_convolution_bwd_weights_t : public primitive_t {
    struct pd_t : public cpu_convolution_bwd_weights_pd_t {
        pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
                const convolution_fwd_pd_t *hint_fwd_pd)
            : cpu_convolution_bwd_weights_pd_t(adesc, attr, hint_fwd_pd)
            , jcp_()
            , rtus_() {}

        DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_1x1:", avx512_common, ""),
                jit_avx512_common_1x1_convolution_bwd_weights_t);

        status_t init(engine_t *engine) {
            bool ok = true && desc()->prop_kind == prop_kind::backward_weights
                    && set_default_alg_kind(alg_kind::convolution_direct)
                    && expect_data_types(data_type::f32, data_type::f32,
                            data_type::f32, data_type::f32, data_type::f32)
                    && attr()->has_default_values() && !has_zero_dim_memory()
                    && set_default_formats();
            if (!ok) return status::unimplemented;

            const convolution_desc_t *conv_d = desc();
            const memory_desc_t *src_d = src_md();
            rtus_prepare(this, conv_d, src_d, diff_dst_md(), diff_weights_md());

            status_t status = jit_avx512_common_1x1_conv_kernel::init_conf(jcp_,
                    *conv_d, *src_d, *diff_weights_md(), *diff_dst_md(),
                    *attr(), dnnl_get_max_threads(), rtus_.reduce_src_);
            if (status != status::success) return status;

            init_balancers();

            auto scratchpad = scratchpad_registry().registrar();
            jit_avx512_common_1x1_conv_kernel::init_scratchpad(
                    scratchpad, jcp_);

            auto reducer_bia_scratchpad = memory_tracking::registrar_t(
                    scratchpad, memory_tracking::names::prefix_reducer_bia);
            reducer_bia_conf_.init_scratchpad(reducer_bia_scratchpad);

            rtus_prepare_space_info(this, scratchpad, jcp_.nthr);

            return status::success;
        }

        // TODO (Roma): structs conf header cleanup
        jit_1x1_conv_conf_t jcp_;
        cpu_reducer_t<data_type::f32>::conf_t reducer_bia_conf_;
        reduce_to_unit_stride_t rtus_;

    protected:
        bool set_default_formats() {
            using namespace format_tag;

            auto dat_tag = utils::pick(ndims() - 3, nCw16c, nChw16c, nCdhw16c);
            auto wei_tag = utils::pick(2 * ndims() - 6 + with_groups(),
                    OIw16i16o, gOIw16i16o, OIhw16i16o, gOIhw16i16o, OIdhw16i16o,
                    gOIdhw16i16o);

            return set_default_formats_common(dat_tag, wei_tag, dat_tag);
        }

    private:
        void init_balancers() {
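            // Cap (in elements) on the per-thread reduction buffer handed to
            // the balancer; the constants appear to be a tuning heuristic
            // rather than a derived requirement.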
            const size_t max_buffer_size = jcp_.nthr * 3 * 5 * 5 * 16 * 16;
            if (with_bias()) {
                reducer_bia_conf_.init(reduce_balancer_t(jcp_.nthr,
                        jcp_.oc_block, jcp_.ngroups * jcp_.nb_load, jcp_.mb,
                        max_buffer_size, true));
            }
        }
    };

    template <cpu_isa_t isa, typename conv_t>
    friend status_t init_rtus_driver(conv_t *self);

    jit_avx512_common_1x1_convolution_bwd_weights_t(const pd_t *apd)
        : primitive_t(apd) {}

    typedef typename prec_traits<data_type::f32>::type data_t;

    status_t init(engine_t *engine) override;

    status_t execute(const exec_ctx_t &ctx) const override {
        execute_backward_weights(ctx);
        return status::success;
    }

private:
    void execute_backward_weights(const exec_ctx_t &ctx) const;
    const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }

    std::unique_ptr<jit_avx512_common_1x1_conv_kernel> kernel_;
    std::unique_ptr<cpu_accumulator_1d_t<data_type::f32>> acc_ker_;
    std::unique_ptr<cpu_reducer_t<data_type::f32>> reducer_bias_;
    std::unique_ptr<jit_transpose4x16_src> trans_kernel_;
    std::unique_ptr<rtus_driver_t<avx512_common>> rtus_driver_;
};

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl

#endif