1 /*******************************************************************************
2 * Copyright 2018-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "common/c_types_map.hpp"
18 #include "common/dnnl_thread.hpp"
19 #include "common/type_helpers.hpp"
20 #include "common/utils.hpp"
21 
22 #include "cpu/x64/jit_generator.hpp"
23 
24 #include "cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
25 
26 #include "cpu/cpu_primitive.hpp"
27 
28 namespace dnnl {
29 namespace impl {
30 namespace cpu {
31 namespace x64 {
32 
33 using namespace dnnl::impl::status;
34 using namespace dnnl::impl::memory_tracking::names;
35 using namespace dnnl::impl::utils;
36 
37 /* convolution forward */
38 template <data_type_t src_type, data_type_t dst_type>
39 status_t jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type,
execute_forward(const exec_ctx_t & ctx) const40         dst_type>::execute_forward(const exec_ctx_t &ctx) const {
41     auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
42     auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
43     auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
44     auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
45     auto weights_dw = CTX_IN_MEM(
46             const wei_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
47     auto bias_dw = CTX_IN_MEM(
48             const char *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
49     const auto post_ops_binary_rhs_arg_vec
50             = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
51     const auto post_ops_binary_rhs_arg_vec_dw = pd()->jcp_dw_
52             ? binary_injector::prepare_binary_args(pd()->jcp_dw_->post_ops, ctx,
53                     pd()->jcp_.post_ops.entry_.size() + 1)
54             : std::vector<const void *> {};
55 
56     DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
57     DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);
58 
59     auto scratchpad = ctx.get_scratchpad_grantor();
60 
61     if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) {
62         auto local_scales
63                 = scratchpad.template get<float>(key_conv_adjusted_scales);
64         auto scales = pd()->attr()->output_scales_.scales_;
65         size_t count = pd()->attr()->output_scales_.count_;
66         float factor = 1.f / pd()->jcp_.wei_adj_scale;
67         if (count == 1) {
68             utils::array_set(
69                     local_scales, scales[0] * factor, pd()->jcp_.ic_block);
70         } else {
71             for (size_t c = 0; c < count; c++)
72                 local_scales[c] = scales[c] * factor;
73         }
74     }
75 
76     if (pd()->jcp_.with_dw_conv) {
77         auto jcp_dw = pd()->jcp_dw_;
78         if (jcp_dw->signed_input && jcp_dw->ver != ver_vnni) {
79             memory_tracking::grantor_t dw_scratchpad(
80                     scratchpad, memory_tracking::names::prefix_fusion);
81             auto attr_dw = pd()->dw_conv_pd_->attr();
82 
83             auto local_scales = dw_scratchpad.template get<float>(
84                     key_conv_adjusted_scales);
85             auto scales = attr_dw->output_scales_.scales_;
86             size_t count = attr_dw->output_scales_.count_;
87             float factor = 1.f / jcp_dw->wei_adj_scale;
88             if (count == 1) {
89                 utils::array_set(
90                         local_scales, scales[0] * factor, pd()->jcp_.ic_block);
91             } else {
92                 for (size_t c = 0; c < count; c++)
93                     local_scales[c] = scales[c] * factor;
94             }
95         }
96     }
97     parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) {
98         execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
99                 dst, src_zero_point, dst_zero_point, scratchpad,
100                 post_ops_binary_rhs_arg_vec.data(),
101                 post_ops_binary_rhs_arg_vec_dw.data());
102     });
103     return status::success;
104 }
105 
// Per-thread driver: executes this thread's share of the 1x1 convolution,
// optionally interleaved with a fused depthwise convolution that consumes
// the 1x1 output rows through a small ring buffer (`pbuf`).
template <data_type_t src_type, data_type_t dst_type>
void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type,
        dst_type>::execute_forward_thr(const int ithr, const int nthr,
        const src_data_t *src, const wei_data_t *weights, const char *bias,
        const wei_data_t *weights_dw, const char *bias_dw, dst_data_t *dst,
        const int32_t *src_zero_point, const int32_t *dst_zero_point,
        const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec,
        const void *post_ops_binary_rhs_arg_vec_dw) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper dw_weights_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS));

    const auto &jcp = pd()->jcp_;

    // Bias is stored type-erased (char *); element size depends on bias dtype.
    const size_t bia_dt_size = pd()->with_bias()
            ? types::data_type_size(pd()->desc()->bias_desc.data_type)
            : 0;

    // Scratch area used by the rtus (reduce-through-unit-stride) driver when
    // strided source needs to be gathered into a dense buffer first.
    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<src_data_t>(key_conv_rtus_space)
            : nullptr;

    auto local_scales = scratchpad.get<float>(key_conv_adjusted_scales);

    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;

    const bool is_2d = pd()->ndims() == 4;
    const bool is_3d = pd()->ndims() == 5;

    const int stride_d = pd()->KSD();
    const int stride_h = pd()->KSH();
    const int stride_w = pd()->KSW();

    // Pick adjusted scales (non-VNNI signed input) or the attribute's
    // original output scales.
    float *oscales {nullptr};
    if (jcp.signed_input && jcp.ver != ver_vnni)
        oscales = scratchpad.get<float>(key_conv_adjusted_scales);
    else
        oscales = pd()->attr()->output_scales_.scales_;

    // Compensation values live in an additional buffer appended after the
    // weights payload: first the signed-input compensation (ngroups * oc
    // int32 values), then the source zero-point compensation.
    auto offset = weights_d.size() - weights_d.additional_buffer_size();
    wei_data_t *w = const_cast<wei_data_t *>(weights);
    const int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(w + offset)
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
            : nullptr;

    // Kernel call-parameter structs; mutated in place by the init_* lambdas
    // below and passed to the jitted kernels repeatedly.
    auto p = jit_1x1_conv_call_s();

    auto rp = rtus_driver_t<avx512_common>::call_params_t();
    const int nb_oc = jcp.nb_load;
    const int nb_ic = jcp.nb_reduce;
    // override some constants for fused dw_conv
    const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block;
    const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast;
    const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking;
    const int nb_bcast_blocking_max
            = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max;
    const int nb_load_blocking = jcp.nb_load_blocking;
    const int nb_load_blocking_max = jcp.with_dw_conv
            ? jcp.nb_load_blocking
            : jcp.nb_load_blocking_max;

    // Begin: declare Variables needed for dw conv.
    const auto jcp_dw = pd()->jcp_dw_;
    const auto &dw_pd = pd()->dw_conv_pd_;
    memory_tracking::grantor_t dw_scratchpad(
            scratchpad, memory_tracking::names::prefix_fusion);

    const size_t dw_bia_dt_size = jcp_dw && jcp_dw->with_bias
            ? types::data_type_size(dw_pd->desc()->bias_desc.data_type)
            : 0;

    float *dw_oscales {nullptr};
    int32_t *compensation_dw {nullptr};
    if (jcp.with_dw_conv) {
        // Depthwise weights carry their own appended compensation buffer.
        offset = dw_weights_d.size() - dw_weights_d.additional_buffer_size();
        w = const_cast<wei_data_t *>(weights_dw);
        compensation_dw = (jcp_dw->signed_input)
                ? reinterpret_cast<int32_t *>(w + offset)
                : nullptr;
        if (jcp_dw->signed_input && jcp_dw->ver != ver_vnni)
            dw_oscales = dw_scratchpad.get<float>(key_conv_adjusted_scales);
        else
            dw_oscales = dw_pd->attr()->output_scales_.scales_;
    }

    // pbuf: this thread's ring buffer of jcp_dw->kh 1x1-output rows feeding
    // the depthwise kernel; row_offset is the stride between rows in it.
    dst_data_t *pbuf {nullptr};
    size_t row_offset {};
    const int nb_buffer = jcp.nb_load_blocking;
    std::vector<dst_data_t *> addrs;
    // End

    // Clamp a blocking step to the remaining work.
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    // Decompose a flat bcast-iteration index into (mb, group, spatial) and
    // set the corresponding fields of `p` and `rp`.
    auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g,
                              int &bcast_step, int &od, int &oh, int &ow,
                              int &id, int &ih, int &iw) {
        int osb {0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast);
        bcast_step = step(
                nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        const int depth_orthogonal_area = jcp.ow * jcp.oh;
        od = os / depth_orthogonal_area;
        oh = (os % depth_orthogonal_area) / jcp.ow;
        ow = (os % depth_orthogonal_area) % jcp.ow;

        id = od * stride_d;
        ih = oh * stride_h;
        iw = ow * stride_w;
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

    // Set up the output-channel (load) dimension for the current ocb chunk.
    auto init_load = [&](int ocb, int ocb_end, int &load_step) {
        load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max);
        p.load_dim = this_block_size(ocb * jcp.oc_block, ocb_end * jcp.oc_block,
                load_step * jcp.oc_block);

        if (ocb + load_step >= nb_oc)
            p.first_last_flag |= FLAG_OC_LAST;
        else
            p.first_last_flag &= ~FLAG_OC_LAST;
    };

    // The whole (unpadded) input-channel range is reduced in one shot.
    auto init_reduce = [&]() {
        p.reduce_dim = this_block_size(
                0, jcp.ic_without_padding, jcp.ic_without_padding);
        rp.icb = p.reduce_dim;
    };

    // Fill the remaining call parameters and invoke the jitted 1x1 kernel
    // for one (ocb, spatial-block) piece of work.
    auto ker_1x1 = [&](int ocb, int ocb_start, int n, int g, int od, int oh,
                           int ow, int id, int ih, int iw) {
        const int icb = 0; // Start from the first IC block
        const int _ocb = g * nb_oc + ocb;
        const int _icb = g * nb_ic + icb;

        const size_t dst_off = is_3d
                ? dst_d.blk_off(n, _ocb * jcp.oc_block, od, oh, ow)
                : is_2d ? dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow)
                        : dst_d.blk_off(n, _ocb * jcp.oc_block, ow);

        // With fused dw conv, output goes to the row ring buffer instead of
        // the final destination tensor.
        p.output_data = jcp.with_dw_conv ? pbuf + (oh % jcp_dw->kh) * row_offset
                                         : &dst[dst_off];
        p.load_data
                = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb)
                                               : weights_d.blk_off(ocb, icb)];
        p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
        p.compensation = (jcp.signed_input) ? &compensation[_ocb * jcp.oc_block]
                                            : nullptr;
        p.zp_compensation = jcp.src_zero_point
                ? zp_compensation + _ocb * jcp.oc_block
                : nullptr;
        p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
        p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr;
        p.scales = (jcp.signed_input && jcp.ver != ver_vnni)
                ? &local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block]
                : &oscales[jcp.is_oc_scale * _ocb * jcp.oc_block];
        const size_t src_off = is_3d
                ? src_d.blk_off(n, _icb * jcp.ic_block, id, ih, iw)
                : is_2d ? src_d.blk_off(n, _icb * jcp.ic_block, ih, iw)
                        : src_d.blk_off(n, _icb * jcp.ic_block, iw);
        if (pd()->rtus_.reduce_src_) {
            rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
                    + _icb * jcp.is * jcp.ic_block;
            // The rtus gather is shared across ocb iterations; run it only
            // on the first one.
            if (ocb == ocb_start) {
                rp.src = src + src_off;
                (*rtus_driver_)(&rp);
            }
            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + src_off;

        p.dst_l_off = dst_off;
        p.oc_l_off = _ocb * jcp.oc_block;
        p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
        p.dst_orig = dst;

        (*kernel_)(&p);
    };

    // Iterate the 1x1 kernel over the assigned [bcast_start, bcast_end) x
    // [ocb_start, ocb_end) range in the loop order chosen at init time
    // (r = reduce, l = load/oc, b = bcast/spatial).
    auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start,
                            int ocb_end) {
        if (bcast_start >= bcast_end || ocb_start >= ocb_end) return;
        if (jcp.loop_order == loop_rlb) {
            init_reduce();
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n, g, bcast_step, od, oh, ow, id, ih, iw;
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_lbr) {
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n, g, bcast_step, od, oh, ow, id, ih, iw;
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    init_reduce();
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_rbl) {
            init_reduce();
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, od, oh, ow, id, ih, iw;
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else if (jcp.loop_order == loop_blr) {
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, od, oh, ow, id, ih, iw;
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    init_reduce();
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else {
            assert(!"unsupported loop order");
        }
    };

    // Run the depthwise kernel for one output row `dw_oh`, reading its kh
    // input rows from the ring buffer (wrapped modulo jcp_dw->kh).
    auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) {
        int oh_1x1 = dw_oh * jcp_dw->stride_h - jcp_dw->t_pad;
        int oh_1x1_begin = nstl::max(oh_1x1, 0);

        for (int i = 0; i < jcp_dw->kh; ++i)
            addrs[i] = pbuf + ((oh_1x1_begin++) % jcp_dw->kh) * row_offset;

        const auto ocb_end = ocb_start + load_step;
        const size_t src_ch_stride = jcp_dw->nb_ch_blocking * jcp_dw->ch_block;
        auto par_conv_dw = jit_conv_call_s();

        // Top/bottom padding rows that fall outside the 1x1 output.
        par_conv_dw.t_overflow = nstl::min(jcp_dw->kh, nstl::max(0, -oh_1x1));
        par_conv_dw.b_overflow = nstl::min(
                jcp_dw->kh, nstl::max(0, oh_1x1 - jcp.oh + jcp_dw->kh));
        par_conv_dw.kh_padding = nstl::max<int>(0,
                jcp_dw->kh - par_conv_dw.t_overflow - par_conv_dw.b_overflow);

        const size_t dst_offset = n * jcp_dw->ngroups * jcp_dw->oh * jcp_dw->ow
                + dw_oh * jcp_dw->ow * jcp_dw->ngroups;

        const auto wht_h_stride = dw_weights_d.blk_off(0, 0, 0, 1);
        const auto wei_stride = (!jcp_dw->signed_input) * par_conv_dw.t_overflow
                * wht_h_stride;
        for (int ocb = ocb_start; ocb < ocb_end;
                ocb += jcp_dw->nb_ch_blocking) {

            par_conv_dw.src = addrs.data();
            par_conv_dw.dst = &dst[(dst_offset + jcp_dw->ch_block * ocb)
                    * jcp_dw->typesize_out];

            par_conv_dw.filt
                    = weights_dw + dw_weights_d.blk_off(ocb, 0) + wei_stride;
            par_conv_dw.bias
                    = &bias_dw[ocb * jcp_dw->ch_block * dw_bia_dt_size];
            par_conv_dw.ur_w = (size_t)(jcp_dw->ow);
            par_conv_dw.owb = jcp_dw->ow;
            par_conv_dw.oc_blocks = ocb;
            par_conv_dw.compensation = compensation_dw
                    ? &compensation_dw[ocb * jcp_dw->ch_block]
                    : nullptr;
            par_conv_dw.scales = dw_oscales
                    ? &dw_oscales[jcp_dw->is_oc_scale * ocb * jcp_dw->ch_block]
                    : nullptr;

            par_conv_dw.oc_l_off = ocb * jcp_dw->ch_block;
            par_conv_dw.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec_dw;
            par_conv_dw.dst_orig = dst;

            (*kernel_dw_)(&par_conv_dw);

            // Advance the row pointers to the next channel block.
            for (int i = 0; i < jcp_dw->kh; ++i)
                addrs[i] += src_ch_stride;
        }
    };

    // Fused path: for each dw output row, compute just the 1x1 rows it still
    // needs, then consume them with the depthwise kernel.
    auto conv_dw = [&]() {
        auto jcp_dw = pd()->jcp_dw_;
        auto dw_conv_buffer
                = dw_scratchpad.get<dst_data_t>(key_fusion_inout_buffer);

        const auto dw_conv_buffer_size_
                = (size_t)jcp_dw->kh * jcp.ow * nb_buffer * jcp.oc_block;
        pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
        row_offset = dw_conv_buffer_size_ / jcp_dw->kh;
        addrs.resize(jcp_dw->kh);

        // Partition work across threads over (mb * groups * dw-oh) x ocb.
        int bcast_start {0}, bcast_end {0}, ocb_start, ocb_end;
        balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start,
                bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count);

        while (ocb_start < ocb_end) {
            int load_step;
            init_load(ocb_start, ocb_end, load_step);

            int oh_1x1 = 0;
            auto bcast_iter = bcast_start;
            while (bcast_iter < bcast_end) {
                int n, g, oh_dw;
                nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw,
                        jcp_dw->oh);
                if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary
                const int oh_1x1_range
                        = oh_dw * jcp_dw->stride_h - jcp_dw->t_pad;
                const int oh_1x1_begin = nstl::max(oh_1x1_range, 0);
                const int oh_1x1_end
                        = nstl::min(oh_1x1_range + jcp_dw->kh, jcp.oh);
                oh_1x1 = nstl::max(
                        oh_1x1_begin, oh_1x1); // Skip rows computed previously

                // dw_spatial to 1x1 spatial conversion. if jcp.oh != jcp_dw.oh
                const int bcast_start_1x1
                        = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1;
                const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end;

                conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start,
                        ocb_start + load_step);
                oh_1x1 = oh_1x1_end;
                ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw);

                bcast_iter += nb_bcast_blocking;
            }
            ocb_start += load_step;
        }
    };

    if (jcp.with_dw_conv) {
        conv_dw();
    } else {
        // Plain 1x1 path: split (mb * groups * spatial blocks) x oc chunks
        // across threads, then run the chosen loop order.
        int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
        balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
                jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end,
                jcp.load_grp_count);
        // Balancing was done in units of load chunks; scale back to blocks.
        if (jcp.nb_load_chunk > 1) {
            ocb_start *= jcp.nb_load_chunk;
            ocb_end *= jcp.nb_load_chunk;
        }
        conv_1x1(bcast_start, bcast_end, ocb_start, ocb_end);
    }
}
492 
using namespace data_type;
// Explicit instantiations for every supported src (u8/s8) x dst
// (u8/s8/s32/f32) data type combination.
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, u8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, u8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s32>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s32>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, f32>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, f32>;
502 
503 } // namespace x64
504 } // namespace cpu
505 } // namespace impl
506 } // namespace dnnl
507