/*******************************************************************************
* Copyright 2018-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "common/c_types_map.hpp"
#include "common/dnnl_thread.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp"

#include "cpu/cpu_primitive.hpp"

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace dnnl::impl::status;
using namespace dnnl::impl::memory_tracking::names;
using namespace dnnl::impl::utils;

/* convolution forward */

// Top-level forward entry point: fetches the execution-context memories,
// pre-computes adjusted output scales in the scratchpad (needed when the
// kernel cannot use VNNI and signed-input weights were pre-scaled by
// wei_adj_scale), then fans the work out across threads via parallel().
template <data_type_t src_type, data_type_t dst_type>
status_t jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type,
        dst_type>::execute_forward(const exec_ctx_t &ctx) const {
    auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
    auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
    auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
    auto dst = CTX_OUT_MEM(dst_data_t *, DNNL_ARG_DST);
    // Weights/bias of the fused depthwise convolution post-op (if any).
    auto weights_dw = CTX_IN_MEM(
            const wei_data_t *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS);
    auto bias_dw = CTX_IN_MEM(
            const char *, DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_BIAS);
    // Right-hand-side arguments for binary post-ops; the dw variant starts
    // its post-op indexing after the 1x1 conv's entries (hence the +1).
    const auto post_ops_binary_rhs_arg_vec
            = binary_injector::prepare_binary_args(pd()->jcp_.post_ops, ctx);
    const auto post_ops_binary_rhs_arg_vec_dw = pd()->jcp_dw_
            ? binary_injector::prepare_binary_args(pd()->jcp_dw_->post_ops, ctx,
                    pd()->jcp_.post_ops.entry_.size() + 1)
            : std::vector<const void *> {};

    DEFINE_ZERO_POINTS_BUFFER(src_zero_point, DNNL_ARG_SRC);
    DEFINE_ZERO_POINTS_BUFFER(dst_zero_point, DNNL_ARG_DST);

    auto scratchpad = ctx.get_scratchpad_grantor();

    // Without VNNI, s8 weights were pre-multiplied by wei_adj_scale at
    // reorder time; undo that here by folding 1/wei_adj_scale into the
    // output scales. Must happen before the parallel region since the
    // worker threads read key_conv_adjusted_scales.
    if (pd()->jcp_.signed_input && pd()->jcp_.ver != ver_vnni) {
        auto local_scales
                = scratchpad.template get<float>(key_conv_adjusted_scales);
        auto scales = pd()->attr()->output_scales_.scales_;
        size_t count = pd()->attr()->output_scales_.count_;
        float factor = 1.f / pd()->jcp_.wei_adj_scale;
        if (count == 1) {
            // Common scale: broadcast across one block's worth of entries.
            utils::array_set(
                    local_scales, scales[0] * factor, pd()->jcp_.ic_block);
        } else {
            for (size_t c = 0; c < count; c++)
                local_scales[c] = scales[c] * factor;
        }
    }

    // Same scale adjustment for the fused depthwise conv, using its own
    // attributes and the nested (prefix_fusion) scratchpad region.
    if (pd()->jcp_.with_dw_conv) {
        auto jcp_dw = pd()->jcp_dw_;
        if (jcp_dw->signed_input && jcp_dw->ver != ver_vnni) {
            memory_tracking::grantor_t dw_scratchpad(
                    scratchpad, memory_tracking::names::prefix_fusion);
            auto attr_dw = pd()->dw_conv_pd_->attr();

            auto local_scales = dw_scratchpad.template get<float>(
                    key_conv_adjusted_scales);
            auto scales = attr_dw->output_scales_.scales_;
            size_t count = attr_dw->output_scales_.count_;
            float factor = 1.f / jcp_dw->wei_adj_scale;
            if (count == 1) {
                utils::array_set(
                        local_scales, scales[0] * factor, pd()->jcp_.ic_block);
            } else {
                for (size_t c = 0; c < count; c++)
                    local_scales[c] = scales[c] * factor;
            }
        }
    }
    parallel(pd()->jcp_.nthr, [&](const int ithr, const int nthr) {
        execute_forward_thr(ithr, nthr, src, weights, bias, weights_dw, bias_dw,
                dst, src_zero_point, dst_zero_point, scratchpad,
                post_ops_binary_rhs_arg_vec.data(),
                post_ops_binary_rhs_arg_vec_dw.data());
    });
    return status::success;
}

// Per-thread work: partitions the (mb x groups x spatial-blocks) x
// (output-channel-blocks) iteration space with balance2D, then drives the
// JIT 1x1 kernel — optionally through a reduced-src (rtus) staging buffer —
// and, when a depthwise conv post-op is fused, pipelines 1x1 output rows
// through a ring buffer of jcp_dw->kh rows feeding the dw kernel.
template <data_type_t src_type, data_type_t dst_type>
void jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<src_type,
        dst_type>::execute_forward_thr(const int ithr, const int nthr,
        const src_data_t *src, const wei_data_t *weights, const char *bias,
        const wei_data_t *weights_dw, const char *bias_dw, dst_data_t *dst,
        const int32_t *src_zero_point, const int32_t *dst_zero_point,
        const memory_tracking::grantor_t &scratchpad,
        const void *post_ops_binary_rhs_arg_vec,
        const void *post_ops_binary_rhs_arg_vec_dw) const {
    const memory_desc_wrapper src_d(pd()->src_md());
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const memory_desc_wrapper weights_d(pd()->weights_md(0));
    const memory_desc_wrapper dw_weights_d(
            pd()->arg_md(DNNL_ARG_ATTR_POST_OP_DW | DNNL_ARG_WEIGHTS));

    const auto &jcp = pd()->jcp_;

    // bias is passed as char*; element size depends on the bias data type.
    const size_t bia_dt_size = pd()->with_bias()
            ? types::data_type_size(pd()->desc()->bias_desc.data_type)
            : 0;

    // Staging area for the "reduce src" optimization (strided 1x1 convs
    // gather their input rows here first; see rtus_driver_ below).
    auto rtus_space = pd()->rtus_.reduce_src_
            ? scratchpad.get<src_data_t>(key_conv_rtus_space)
            : nullptr;

    auto local_scales = scratchpad.get<float>(key_conv_adjusted_scales);

    const int work_amount = jcp.mb * jcp.ngroups * jcp.nb_bcast;

    const bool is_2d = pd()->ndims() == 4;
    const bool is_3d = pd()->ndims() == 5;

    const int stride_d = pd()->KSD();
    const int stride_h = pd()->KSH();
    const int stride_w = pd()->KSW();

    // Pick adjusted scales (filled by execute_forward) when the no-VNNI
    // signed-input path applies; otherwise use the attributes directly.
    float *oscales {nullptr};
    if (jcp.signed_input && jcp.ver != ver_vnni)
        oscales = scratchpad.get<float>(key_conv_adjusted_scales);
    else
        oscales = pd()->attr()->output_scales_.scales_;

    // Compensation (for signed input) and zero-point compensation are
    // stored as int32 past the end of the weights payload; `offset` is the
    // element offset of that trailing buffer. Zero-point compensation
    // follows the signed-input compensation when both are present.
    auto offset = weights_d.size() - weights_d.additional_buffer_size();
    wei_data_t *w = const_cast<wei_data_t *>(weights);
    const int32_t *compensation = (jcp.signed_input)
            ? reinterpret_cast<int32_t *>(w + offset)
            : nullptr;
    const int32_t *zp_compensation = jcp.src_zero_point
            ? reinterpret_cast<int32_t *>(&w[offset])
                    + (jcp.signed_input ? jcp.ngroups * jcp.oc : 0)
            : nullptr;

    auto p = jit_1x1_conv_call_s();

    auto rp = rtus_driver_t<avx512_common>::call_params_t();
    const int nb_oc = jcp.nb_load;
    const int nb_ic = jcp.nb_reduce;
    // override some constants for fused dw_conv: with fusion the 1x1 conv
    // is driven one output row at a time (os_block = ow, nb_bcast = oh).
    const int os_block = jcp.with_dw_conv ? jcp.ow : jcp.bcast_block;
    const int nb_bcast = jcp.with_dw_conv ? jcp.oh : jcp.nb_bcast;
    const int nb_bcast_blocking = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking;
    const int nb_bcast_blocking_max
            = jcp.with_dw_conv ? 1 : jcp.nb_bcast_blocking_max;
    const int nb_load_blocking = jcp.nb_load_blocking;
    const int nb_load_blocking_max = jcp.with_dw_conv
            ? jcp.nb_load_blocking
            : jcp.nb_load_blocking_max;

    // Begin: declare Variables needed for dw conv.
    const auto jcp_dw = pd()->jcp_dw_;
    const auto &dw_pd = pd()->dw_conv_pd_;
    memory_tracking::grantor_t dw_scratchpad(
            scratchpad, memory_tracking::names::prefix_fusion);

    const size_t dw_bia_dt_size = jcp_dw && jcp_dw->with_bias
            ? types::data_type_size(dw_pd->desc()->bias_desc.data_type)
            : 0;

    float *dw_oscales {nullptr};
    int32_t *compensation_dw {nullptr};
    if (jcp.with_dw_conv) {
        // Same trailing-buffer layout as the 1x1 weights: compensation
        // lives past the dw weights payload. Note `offset`/`w` are reused.
        offset = dw_weights_d.size() - dw_weights_d.additional_buffer_size();
        w = const_cast<wei_data_t *>(weights_dw);
        compensation_dw = (jcp_dw->signed_input)
                ? reinterpret_cast<int32_t *>(w + offset)
                : nullptr;
        if (jcp_dw->signed_input && jcp_dw->ver != ver_vnni)
            dw_oscales = dw_scratchpad.get<float>(key_conv_adjusted_scales);
        else
            dw_oscales = dw_pd->attr()->output_scales_.scales_;
    }

    // pbuf: this thread's slice of the 1x1->dw row ring buffer (set up in
    // conv_dw()); row_offset is the stride between its kh rows; addrs holds
    // the kh row pointers passed to the dw kernel.
    dst_data_t *pbuf {nullptr};
    size_t row_offset {};
    const int nb_buffer = jcp.nb_load_blocking;
    std::vector<dst_data_t *> addrs;
    // End

    // Blocking step: full default step, clipped to the remainder when the
    // remainder is below the tail threshold.
    auto step = [](int default_step, int remaining, int tail_step) {
        assert(default_step <= tail_step);
        return remaining < tail_step ? remaining : default_step;
    };

    // Decode a linear bcast-work index into (n, g, spatial block) and fill
    // the kernel/rtus call params for the bcast dimension.
    auto init_bcast = [&](int iwork, int bcast_end, int &n, int &g,
                              int &bcast_step, int &od, int &oh, int &ow,
                              int &id, int &ih, int &iw) {
        int osb {0};
        nd_iterator_init(iwork, n, jcp.mb, g, jcp.ngroups, osb, nb_bcast);
        bcast_step = step(
                nb_bcast_blocking, nb_bcast - osb, nb_bcast_blocking_max);
        bcast_step = nstl::min(bcast_step, bcast_end - iwork);

        const int os = osb * os_block;
        const int depth_orthogonal_area = jcp.ow * jcp.oh;
        od = os / depth_orthogonal_area;
        oh = (os % depth_orthogonal_area) / jcp.ow;
        ow = (os % depth_orthogonal_area) % jcp.ow;

        // For a 1x1 conv the input position is just output * stride.
        id = od * stride_d;
        ih = oh * stride_h;
        iw = ow * stride_w;
        rp.iw_start = iw;

        p.bcast_dim = this_block_size(os, jcp.os, bcast_step * os_block);
        rp.os = p.bcast_dim;
    };

    // Fill kernel params for a block of output channels; flags the last
    // oc block so the kernel can apply tail handling / post-ops correctly.
    auto init_load = [&](int ocb, int ocb_end, int &load_step) {
        load_step = step(nb_load_blocking, ocb_end - ocb, nb_load_blocking_max);
        p.load_dim = this_block_size(ocb * jcp.oc_block, ocb_end * jcp.oc_block,
                load_step * jcp.oc_block);

        if (ocb + load_step >= nb_oc)
            p.first_last_flag |= FLAG_OC_LAST;
        else
            p.first_last_flag &= ~FLAG_OC_LAST;
    };

    // The whole (unpadded) IC range is reduced in one shot.
    auto init_reduce = [&]() {
        p.reduce_dim = this_block_size(
                0, jcp.ic_without_padding, jcp.ic_without_padding);
        rp.icb = p.reduce_dim;
    };

    // Invoke the JIT 1x1 kernel for one (oc block, spatial block) tile.
    auto ker_1x1 = [&](int ocb, int ocb_start, int n, int g, int od, int oh,
                           int ow, int id, int ih, int iw) {
        const int icb = 0; // Start from the first IC block
        const int _ocb = g * nb_oc + ocb;
        const int _icb = g * nb_ic + icb;

        const size_t dst_off = is_3d
                ? dst_d.blk_off(n, _ocb * jcp.oc_block, od, oh, ow)
                : is_2d ? dst_d.blk_off(n, _ocb * jcp.oc_block, oh, ow)
                        : dst_d.blk_off(n, _ocb * jcp.oc_block, ow);

        // With dw fusion, write into the ring buffer row (oh mod kh)
        // instead of the final destination.
        p.output_data = jcp.with_dw_conv ? pbuf + (oh % jcp_dw->kh) * row_offset
                                         : &dst[dst_off];
        p.load_data
                = &weights[pd()->with_groups() ? weights_d.blk_off(g, ocb, icb)
                                               : weights_d.blk_off(ocb, icb)];
        p.bias_data = &bias[_ocb * jcp.oc_block * bia_dt_size];
        p.compensation = (jcp.signed_input) ? &compensation[_ocb * jcp.oc_block]
                                            : nullptr;
        p.zp_compensation = jcp.src_zero_point
                ? zp_compensation + _ocb * jcp.oc_block
                : nullptr;
        p.src_zero_point = jcp.src_zero_point ? src_zero_point : nullptr;
        p.dst_zero_point = jcp.dst_zero_point ? dst_zero_point : nullptr;
        // is_oc_scale is 0 for a common scale, 1 for per-oc scales, so the
        // multiply selects either scales[0] or the per-channel entry.
        p.scales = (jcp.signed_input && jcp.ver != ver_vnni)
                ? &local_scales[jcp.is_oc_scale * _ocb * jcp.oc_block]
                : &oscales[jcp.is_oc_scale * _ocb * jcp.oc_block];
        const size_t src_off = is_3d
                ? src_d.blk_off(n, _icb * jcp.ic_block, id, ih, iw)
                : is_2d ? src_d.blk_off(n, _icb * jcp.ic_block, ih, iw)
                        : src_d.blk_off(n, _icb * jcp.ic_block, iw);
        if (pd()->rtus_.reduce_src_) {
            // Gather strided input rows into this thread's rtus buffer;
            // only needs to run once per spatial block (first oc block).
            rp.ws = rtus_space + ithr * pd()->rtus_.space_per_thread_
                    + _icb * jcp.is * jcp.ic_block;
            if (ocb == ocb_start) {
                rp.src = src + src_off;
                (*rtus_driver_)(&rp);
            }
            p.bcast_data = rp.ws;
        } else
            p.bcast_data = src + src_off;

        p.dst_l_off = dst_off;
        p.oc_l_off = _ocb * jcp.oc_block;
        p.post_ops_binary_rhs_arg_vec = post_ops_binary_rhs_arg_vec;
        p.dst_orig = dst;

        (*kernel_)(&p);
    };

    // Drive the 1x1 kernel over [bcast_start, bcast_end) x
    // [ocb_start, ocb_end) in the loop order chosen at pd creation time
    // (r = reduce, l = load/oc, b = bcast/spatial).
    auto conv_1x1 = [&](int bcast_start, int bcast_end, int ocb_start,
                            int ocb_end) {
        if (bcast_start >= bcast_end || ocb_start >= ocb_end) return;
        if (jcp.loop_order == loop_rlb) {
            init_reduce();
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n, g, bcast_step, od, oh, ow, id, ih, iw;
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_lbr) {
            int ocb = ocb_start;
            while (ocb < ocb_end) {
                int load_step;
                init_load(ocb, ocb_end, load_step);
                int iwork = bcast_start;
                while (iwork < bcast_end) {
                    int n, g, bcast_step, od, oh, ow, id, ih, iw;
                    init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow,
                            id, ih, iw);
                    init_reduce();
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    iwork += bcast_step;
                }
                ocb += load_step;
            }
        } else if (jcp.loop_order == loop_rbl) {
            init_reduce();
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, od, oh, ow, id, ih, iw;
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else if (jcp.loop_order == loop_blr) {
            int iwork = bcast_start;
            while (iwork < bcast_end) {
                int n, g, bcast_step, od, oh, ow, id, ih, iw;
                init_bcast(iwork, bcast_end, n, g, bcast_step, od, oh, ow, id,
                        ih, iw);
                int ocb = ocb_start;
                while (ocb < ocb_end) {
                    int load_step;
                    init_load(ocb, ocb_end, load_step);
                    init_reduce();
                    ker_1x1(ocb, ocb_start, n, g, od, oh, ow, id, ih, iw);
                    ocb += load_step;
                }
                iwork += bcast_step;
            }
        } else {
            assert(!"unsupported loop order");
        }
    };

    // Run the fused depthwise kernel for one dw output row, consuming kh
    // rows of 1x1 output from the ring buffer.
    auto ker_dw = [&](int n, int ocb_start, int load_step, int &dw_oh) {
        int oh_1x1 = dw_oh * jcp_dw->stride_h - jcp_dw->t_pad;
        int oh_1x1_begin = nstl::max(oh_1x1, 0);

        // Map each dw filter row to its ring-buffer row (mod kh recycling).
        for (int i = 0; i < jcp_dw->kh; ++i)
            addrs[i] = pbuf + ((oh_1x1_begin++) % jcp_dw->kh) * row_offset;

        const auto ocb_end = ocb_start + load_step;
        const size_t src_ch_stride = jcp_dw->nb_ch_blocking * jcp_dw->ch_block;
        auto par_conv_dw = jit_conv_call_s();

        // Rows of the dw filter falling into top/bottom padding.
        par_conv_dw.t_overflow = nstl::min(jcp_dw->kh, nstl::max(0, -oh_1x1));
        par_conv_dw.b_overflow = nstl::min(
                jcp_dw->kh, nstl::max(0, oh_1x1 - jcp.oh + jcp_dw->kh));
        par_conv_dw.kh_padding = nstl::max<int>(0,
                jcp_dw->kh - par_conv_dw.t_overflow - par_conv_dw.b_overflow);

        const size_t dst_offset = n * jcp_dw->ngroups * jcp_dw->oh * jcp_dw->ow
                + dw_oh * jcp_dw->ow * jcp_dw->ngroups;

        const auto wht_h_stride = dw_weights_d.blk_off(0, 0, 0, 1);
        // For non-signed input, skip the filter rows that fall into the
        // top padding (signed-input kernels handle padding internally).
        const auto wei_stride = (!jcp_dw->signed_input) * par_conv_dw.t_overflow
                * wht_h_stride;
        for (int ocb = ocb_start; ocb < ocb_end;
                ocb += jcp_dw->nb_ch_blocking) {

            par_conv_dw.src = addrs.data();
            par_conv_dw.dst = &dst[(dst_offset + jcp_dw->ch_block * ocb)
                    * jcp_dw->typesize_out];

            par_conv_dw.filt
                    = weights_dw + dw_weights_d.blk_off(ocb, 0) + wei_stride;
            par_conv_dw.bias
                    = &bias_dw[ocb * jcp_dw->ch_block * dw_bia_dt_size];
            par_conv_dw.ur_w = (size_t)(jcp_dw->ow);
            par_conv_dw.owb = jcp_dw->ow;
            par_conv_dw.oc_blocks = ocb;
            par_conv_dw.compensation = compensation_dw
                    ? &compensation_dw[ocb * jcp_dw->ch_block]
                    : nullptr;
            par_conv_dw.scales = dw_oscales
                    ? &dw_oscales[jcp_dw->is_oc_scale * ocb * jcp_dw->ch_block]
                    : nullptr;

            par_conv_dw.oc_l_off = ocb * jcp_dw->ch_block;
            par_conv_dw.post_ops_binary_rhs_arg_vec
                    = post_ops_binary_rhs_arg_vec_dw;
            par_conv_dw.dst_orig = dst;

            (*kernel_dw_)(&par_conv_dw);

            // Advance the kh row pointers to the next channel block.
            for (int i = 0; i < jcp_dw->kh; ++i)
                addrs[i] += src_ch_stride;
        }
    };

    // Fused path: interleave 1x1 row production with dw row consumption so
    // that at most kh rows of intermediate output live in the scratchpad.
    auto conv_dw = [&]() {
        // Note: shadows the outer jcp_dw on purpose.
        auto jcp_dw = pd()->jcp_dw_;
        auto dw_conv_buffer
                = dw_scratchpad.get<dst_data_t>(key_fusion_inout_buffer);

        const auto dw_conv_buffer_size_
                = (size_t)jcp_dw->kh * jcp.ow * nb_buffer * jcp.oc_block;
        pbuf = dw_conv_buffer + ithr * dw_conv_buffer_size_;
        row_offset = dw_conv_buffer_size_ / jcp_dw->kh;
        addrs.resize(jcp_dw->kh);

        // Balance over dw output rows (not 1x1 spatial blocks).
        int bcast_start {0}, bcast_end {0}, ocb_start, ocb_end;
        balance2D(nthr, ithr, jcp.mb * jcp.ngroups * jcp_dw->oh, bcast_start,
                bcast_end, nb_oc, ocb_start, ocb_end, jcp.load_grp_count);

        while (ocb_start < ocb_end) {
            int load_step;
            init_load(ocb_start, ocb_end, load_step);

            int oh_1x1 = 0;
            auto bcast_iter = bcast_start;
            while (bcast_iter < bcast_end) {
                int n, g, oh_dw;
                nd_iterator_init(bcast_iter, n, jcp.mb, g, jcp.ngroups, oh_dw,
                        jcp_dw->oh);
                if (oh_dw == 0) oh_1x1 = 0; // Reset over mb boundary
                const int oh_1x1_range
                        = oh_dw * jcp_dw->stride_h - jcp_dw->t_pad;
                const int oh_1x1_begin = nstl::max(oh_1x1_range, 0);
                const int oh_1x1_end
                        = nstl::min(oh_1x1_range + jcp_dw->kh, jcp.oh);
                oh_1x1 = nstl::max(
                        oh_1x1_begin, oh_1x1); // Skip rows computed previously

                // dw_spatial to 1x1 spatial conversion. if jcp.oh != jcp_dw.oh
                const int bcast_start_1x1
                        = n * jcp.ngroups * jcp.oh + g * jcp.oh + oh_1x1;
                const int bcast_end_1x1 = bcast_start_1x1 - oh_1x1 + oh_1x1_end;

                // Produce the 1x1 rows this dw row needs, then consume them.
                conv_1x1(bcast_start_1x1, bcast_end_1x1, ocb_start,
                        ocb_start + load_step);
                oh_1x1 = oh_1x1_end;
                ker_dw(n, g * nb_oc + ocb_start, load_step, oh_dw);

                bcast_iter += nb_bcast_blocking;
            }
            ocb_start += load_step;
        }
    };

    if (jcp.with_dw_conv) {
        conv_dw();
    } else {
        // Plain 1x1: split work 2D over (bcast work, oc chunk groups).
        int bcast_start {0}, bcast_end {0}, ocb_start {0}, ocb_end {0};
        balance2D(nthr, ithr, work_amount, bcast_start, bcast_end,
                jcp.nb_load / jcp.nb_load_chunk, ocb_start, ocb_end,
                jcp.load_grp_count);
        if (jcp.nb_load_chunk > 1) {
            ocb_start *= jcp.nb_load_chunk;
            ocb_end *= jcp.nb_load_chunk;
        }
        conv_1x1(bcast_start, bcast_end, ocb_start, ocb_end);
    }
}

// Explicit instantiations for all supported src/dst data-type pairs.
using namespace data_type;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, u8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, u8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s8>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, s32>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, s32>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<u8, f32>;
template struct jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t<s8, f32>;

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl