1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include <string.h>
19
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/memory_tracking.hpp"
23 #include "common/type_helpers.hpp"
24 #include "common/utils.hpp"
25
26 #include "cpu/platform.hpp"
27
28 #include "cpu/x64/jit_avx512_core_u8s8s32x_wino_convolution.hpp"
29 #include "cpu/x64/jit_generator.hpp"
30 #include "cpu/x64/jit_primitive_conf.hpp"
31
32 namespace dnnl {
33 namespace impl {
34 namespace cpu {
35 namespace x64 {
36
37 using namespace dnnl::impl::memory_tracking::names;
38 using namespace dnnl::impl::utils;
39 using namespace dnnl::impl::data_type;
40 using namespace Xbyak;
41
namespace {
// The Winograd transforms of this implementation may amplify values:
// the source transform by up to 4x and the weights transform by up to 9/4x.
// The reciprocal factors below are applied to the source and weights data
// accordingly.
constexpr float adj_src_scale = 1.f / 4.f;
constexpr float adj_wei_scale = 4.f / 9.f;
// Channel granularity of the transforms: ic and oc must be multiples of 16.
constexpr int load_block = 16;
} // namespace
52
53 /// SRC TRANSFORMS /////////////////////////////////////////////////////////////
54 struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t : public jit_generator {
55 DECLARE_CPU_JIT_AUX_FUNCTIONS(
56 jit_avx512_core_u8s8s32x_wino_conv_src_trans_t)
57
58 jit_conv_conf_2x3_wino_t jcp;
59 const primitive_attr_t &attr_;
60
61 struct call_params_t {
62 const void *src;
63 const void *wino_src;
64 const void *v_y_masks;
65 const void *v_x_masks;
66 };
67
jit_avx512_core_u8s8s32x_wino_conv_src_trans_tdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_src_trans_t68 jit_avx512_core_u8s8s32x_wino_conv_src_trans_t(
69 jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
70 : jcp(ajcp), attr_(attr), unsign_val_in_wino_domain(5) {}
71 void generate() override;
72
reg_inp_inddnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_src_trans_t73 int reg_inp_ind(int i) const {
74 assert(i < jcp.alpha * jcp.alpha);
75 return (31 - i);
76 }
77
vreg_inpdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_src_trans_t78 Xmm vreg_inp(int i) const { return Xmm(reg_inp_ind(i)); }
79
zmm_inpdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_src_trans_t80 Zmm zmm_inp(int i) const { return Zmm(reg_inp_ind(i)); }
81
vreg_tmpdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_src_trans_t82 Xmm vreg_tmp(int i) const {
83 assert(i < jcp.alpha * jcp.alpha);
84 return Xmm(15 - i);
85 }
vreg_outdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_src_trans_t86 Xmm vreg_out(int i) const {
87 assert(i < jcp.alpha * jcp.alpha);
88 return Xmm(31 - i);
89 }
90
91 Opmask y_mask = Opmask(1);
92 Opmask r_mask = Opmask(2);
x_maskdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_src_trans_t93 Opmask x_mask(int id) {
94 assert(id < 4);
95 return Opmask(3 + id);
96 }
97
98 Reg64 reg_ptr_src = r14;
99 Reg64 reg_ptr_dst = r13;
100
101 Reg64 reg_ptr_v_y_masks = r12;
102 Reg64 reg_ptr_v_x_masks = r11;
103
104 Reg64 reg_aux_ptr_src = r10;
105 Reg64 reg_aux_ptr_dst = r9;
106
107 Reg64 reg_ic_block = r8;
108
109 int unsign_val_in_wino_domain;
110
111 Reg64 reg_scratch_src_alpha = rdx;
112 Xmm xmm_src_alpha = Xmm(0);
113 Zmm zmm_src_alpha = Zmm(0);
114
115 Reg64 reg_shift = rax;
116 Xmm xmm_shift = Xmm(1);
117 Xmm xmm_zero = Xmm(0);
118
119 Reg64 reg_maskx = rbx;
120 Reg64 reg_masky = rsi;
121 Reg64 reg_nomask = reg_maskx;
122 };
123
generate()124 void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
125 Label ic_block_label;
126 Label end_label;
127 Label mask_label;
128 Label nomask_label;
129
130 auto load_src = [=](bool mask) {
131 for (int y = 0; y < jcp.alpha; y++) {
132 if (mask)
133 kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]);
134 for (int x = 0; x < jcp.alpha; x++) {
135 Zmm zmm_i = zmm_inp(y * jcp.alpha + x);
136 Xmm vreg_i = vreg_inp(y * jcp.alpha + x);
137 int inp_offset = sizeof(uint8_t)
138 * ((-jcp.t_pad + y) * jcp.iw * jcp.ic
139 + (-jcp.l_pad + x) * jcp.ic);
140 if (mask) {
141 kandw(r_mask, y_mask, x_mask(x));
142 vmovdqu8(vreg_i | r_mask | T_z,
143 EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
144 } else {
145 vmovdqu8(vreg_i,
146 EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
147 }
148 vpmovzxbd(zmm_i, vreg_i); // to int32
149 vcvtdq2ps(zmm_i, zmm_i); // to fp32
150 vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha
151 vcvtps2dq(zmm_i, zmm_i); // to int32
152 vpmovusdb(vreg_i, zmm_i); // to u8
153 }
154 }
155 };
156
157 preamble();
158
159 #define READ_PARAM(reg, field) \
160 mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
161 READ_PARAM(reg_ptr_src, src);
162 READ_PARAM(reg_ptr_dst, wino_src);
163 READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
164 READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
165 #undef READ_PARAM
166
167 mov(reg_maskx, ptr[reg_ptr_v_x_masks]);
168 mov(reg_masky, ptr[reg_ptr_v_y_masks]);
169 test(reg_maskx, reg_maskx);
170 jz(end_label, T_NEAR); // skip kernel if x mask is all 0's
171 test(reg_masky, reg_masky);
172 jz(end_label, T_NEAR); // skip kernel if y mask is all 0's
173 and_(reg_maskx, reg_masky);
174 mov(reg_nomask, reg_maskx);
175 not_(reg_nomask); // zero if x and y masks are all 1's
176
177 xor_(reg_shift, reg_shift);
178 mov(reg_shift.cvt8(), (int8_t)-128);
179
180 mov(reg_aux_ptr_src, reg_ptr_src);
181 mov(reg_aux_ptr_dst, reg_ptr_dst);
182
183 for (int i = 0; i < jcp.alpha; i++) {
184 kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
185 }
186
187 mov(reg_scratch_src_alpha, float2int(adj_src_scale));
188
189 mov(reg_ic_block, jcp.ic / load_block);
190 L(ic_block_label);
191 {
192 vmovq(xmm_src_alpha, reg_scratch_src_alpha);
193 vbroadcastss(zmm_src_alpha, xmm_src_alpha);
194
195 test(reg_nomask, reg_nomask);
196 jz(nomask_label, T_NEAR);
197 load_src(true);
198 jmp(mask_label, T_NEAR);
199 L(nomask_label);
200 load_src(false);
201 L(mask_label);
202
203 for (int y = 0; y < 4; y++) {
204 vpsubb(vreg_tmp(y * 4 + 0), vreg_inp(y * 4 + 0),
205 vreg_inp(y * 4 + 2));
206 vpaddb(vreg_tmp(y * 4 + 1), vreg_inp(y * 4 + 1),
207 vreg_inp(y * 4 + 2));
208 vpsubb(vreg_tmp(y * 4 + 2), vreg_inp(y * 4 + 2),
209 vreg_inp(y * 4 + 1));
210 vpsubb(vreg_tmp(y * 4 + 3), vreg_inp(y * 4 + 1),
211 vreg_inp(y * 4 + 3));
212 }
213 for (int x = 0; x < 4; x++) {
214 vpsubb(vreg_out(x + 0 * 4), vreg_tmp(x + 4 * 0),
215 vreg_tmp(x + 4 * 2));
216 vpaddb(vreg_out(x + 1 * 4), vreg_tmp(x + 4 * 1),
217 vreg_tmp(x + 4 * 2));
218 vpsubb(vreg_out(x + 2 * 4), vreg_tmp(x + 4 * 2),
219 vreg_tmp(x + 4 * 1));
220 vpsubb(vreg_out(x + 3 * 4), vreg_tmp(x + 4 * 1),
221 vreg_tmp(x + 4 * 3));
222 }
223
224 vmovd(xmm_shift, reg_shift.cvt32());
225 vpxor(xmm_zero, xmm_zero, xmm_zero);
226 vpshufb(xmm_shift, xmm_shift, xmm_zero);
227
228 for (int i = 0; i < 16; i++) {
229 int out_offset = sizeof(uint8_t) * (jcp.inp_stride * i);
230 if (i != unsign_val_in_wino_domain)
231 vpsubb(vreg_out(i), vreg_out(i), Xmm(1));
232 vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset),
233 vreg_out(i));
234 }
235
236 add(reg_aux_ptr_src, sizeof(uint8_t) * load_block);
237 add(reg_aux_ptr_dst, sizeof(uint8_t) * load_block);
238 }
239 dec(reg_ic_block);
240 jnz(ic_block_label, T_NEAR);
241
242 L(end_label);
243 postamble();
244 }
245
246 /// DST TRANSFORMS /////////////////////////////////////////////////////////////
247 struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t : public jit_generator {
248 DECLARE_CPU_JIT_AUX_FUNCTIONS(
249 jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t)
250
251 jit_conv_conf_2x3_wino_t jcp;
252 const primitive_attr_t &attr_;
253
254 struct call_params_t {
255 const void *wino_dst;
256 const void *dst;
257 const void *v_y_masks;
258 const void *v_x_masks;
259
260 const void *bias;
261 const void *scales;
262 };
263
jit_avx512_core_u8s8s32x_wino_conv_dst_trans_tdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t264 jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
265 jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
266 : jcp(ajcp), attr_(attr) {}
267
268 void generate() override;
269 bool maybe_relu(int position);
270
vreg_inpdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t271 Zmm vreg_inp(int i) const { // 16
272 assert(i < jcp.alpha * jcp.alpha);
273 return Zmm(31 - i);
274 }
vreg_stgdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t275 Zmm vreg_stg(int id) const { // 8
276 const int id_reg_stg = jcp.alpha * jcp.alpha + id;
277 assert(id < 8);
278 return Zmm(31 - id_reg_stg);
279 }
vreg_outdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t280 Zmm vreg_out(int id) const { // 4
281 const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
282 assert(id < 4);
283 return Zmm(31 - id_reg_out);
284 }
xmm_outdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t285 Xmm xmm_out(int id) const { // 4
286 const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
287 assert(id < 4);
288 return Xmm(31 - id_reg_out);
289 }
vreg_tmpdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t290 Zmm vreg_tmp(int id) const { // 2
291 const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id;
292 assert(id < 2);
293 return Zmm(31 - id_reg_tmp);
294 }
295
296 Zmm vreg_zero = Zmm(0);
297 Zmm vreg_bias = Zmm(1);
298 Zmm vreg_prev_dst = Zmm(2);
299 Zmm zmm_bias_alpha = Zmm(2);
300 Xmm xmm_bias_alpha = Xmm(2);
301 Zmm vreg_saturation_ubound = Zmm(3);
302 Zmm vreg_sum_zp = Zmm(4);
303
304 Opmask y_mask = Opmask(1);
305 Opmask r_mask = Opmask(2);
x_maskdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t306 Opmask x_mask(int id) {
307 assert(id < 4);
308 return Opmask(3 + id);
309 }
310
311 Reg64 reg_scratch_bias_alpha = r15;
312
313 Reg64 reg_ptr_src = r14;
314 Reg64 reg_ptr_dst = r13;
315
316 Reg64 reg_ptr_v_y_masks = r12;
317 Reg64 reg_ptr_v_x_masks = r11;
318
319 Reg64 reg_aux_ptr_src = r10;
320 Reg64 reg_aux_ptr_dst = r9;
321
322 Reg64 reg_oc_block = r8;
323
324 Reg64 reg_ptr_bias = rbx;
325 Reg64 reg_ptr_scales = abi_not_param1;
326 Reg64 reg_ptr_sum_scale = rdx;
327 Reg64 reg_ptr_sum_zp = rdx;
328 Reg64 reg_ptr_saturation_ubound = rax;
329 };
330
maybe_relu(int position)331 bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) {
332 using namespace primitive_kind;
333 const auto &p = attr_.post_ops_;
334
335 if (position == 0) {
336 /* relu before sum */
337 return false || p.contain(eltwise, 0)
338 || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
339 } else if (position == 1) {
340 /* relu after sum */
341 const int sum_idx
342 = p.contain(sum, 0) ? 0 : (p.contain(sum, 1) ? 1 : -1);
343 if (sum_idx == -1) return false;
344
345 return false || p.contain(eltwise, sum_idx + 1)
346 || jcp.dst_dt == data_type::u8;
347 }
348
349 return false;
350 }
351
generate()352 void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() {
353 Label oc_block_label;
354
355 auto loop_body = [=]() {
356 const auto &p = attr_.post_ops_;
357 const int sum_idx = p.find(primitive_kind::sum);
358 const float *p_sum_scale
359 = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
360 const int32_t *p_sum_zp
361 = (sum_idx != -1) ? &p.entry_[sum_idx].sum.zero_point : nullptr;
362 if (p_sum_scale) {
363 if (*p_sum_zp != 0) {
364 mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
365 vcvtdq2ps(vreg_sum_zp, ptr_b[reg_ptr_sum_zp]);
366 }
367 if (*p_sum_scale != 1.f)
368 mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
369 }
370 for (int i = 0; i < 16; i++) {
371 int internal_offset = sizeof(int32_t) * jcp.out_stride * i;
372 vmovups(vreg_inp(i),
373 EVEX_compress_addr(reg_aux_ptr_src, internal_offset));
374 }
375 for (int y = 0; y < jcp.alpha; y++) {
376 vpaddd(vreg_tmp(0), vreg_inp(y * 4 + 0), vreg_inp(y * 4 + 1));
377 vpaddd(vreg_stg(y * 2), vreg_tmp(0), vreg_inp(y * 4 + 2));
378
379 vpsubd(vreg_tmp(1), vreg_inp(y * 4 + 1), vreg_inp(y * 4 + 2));
380 vpsubd(vreg_stg(y * 2 + 1), vreg_tmp(1), vreg_inp(y * 4 + 3));
381 }
382 for (int x = 0; x < jcp.m; x++) {
383 vpaddd(vreg_tmp(0), vreg_stg(x), vreg_stg(x + 2 * 1));
384 vpaddd(vreg_out(x), vreg_tmp(0), vreg_stg(x + 2 * 2));
385
386 vpsubd(vreg_tmp(1), vreg_stg(x + 2 * 1), vreg_stg(x + 2 * 2));
387 vpsubd(vreg_out(x + 2), vreg_tmp(1), vreg_stg(x + 2 * 3));
388 }
389
390 if (jcp.with_bias) {
391 vmovq(xmm_bias_alpha, reg_scratch_bias_alpha);
392 vbroadcastss(zmm_bias_alpha, xmm_bias_alpha);
393
394 auto bias_addr = ptr[reg_ptr_bias];
395 switch (jcp.bia_dt) {
396 case data_type::f32:
397 case data_type::s32: vmovups(vreg_bias, bias_addr); break;
398 case data_type::s8: vpmovsxbd(vreg_bias, bias_addr); break;
399 case data_type::u8: vpmovzxbd(vreg_bias, bias_addr); break;
400 default: assert(!"unsupported dst data type");
401 }
402 if (jcp.bia_dt != data_type::f32) vcvtdq2ps(vreg_bias, vreg_bias);
403 vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha
404 }
405
406 auto sum_dt = p.get_sum_dt(jcp.dst_dt);
407
408 init_saturate_f32(vreg_zero, vreg_saturation_ubound,
409 reg_ptr_saturation_ubound, f32, jcp.dst_dt);
410 for (int y = 0; y < jcp.m; y++) {
411 kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]);
412 for (int x = 0; x < jcp.m; x++) {
413 kandw(r_mask, y_mask, x_mask(x));
414
415 int i = y * jcp.m + x;
416 int offset
417 = jcp.typesize_out * (y * jcp.ow * jcp.oc + x * jcp.oc);
418 Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset);
419
420 Zmm zmm = vreg_out(i);
421 Xmm xmm = xmm_out(i);
422 vcvtdq2ps(zmm, zmm);
423 if (jcp.with_bias) vaddps(zmm, zmm, vreg_bias);
424 vmulps(zmm, zmm, ptr[reg_ptr_scales]);
425 if (maybe_relu(0)) vmaxps(zmm, vreg_zero, zmm);
426 if (p_sum_scale) { // post_op: sum
427 vpxord(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst);
428 switch (sum_dt) {
429 case data_type::f32:
430 case data_type::s32:
431 vmovups(vreg_prev_dst | r_mask, addr);
432 break;
433 case data_type::s8:
434 vpmovsxbd(vreg_prev_dst | r_mask, addr);
435 break;
436 case data_type::u8:
437 vpmovzxbd(vreg_prev_dst | r_mask, addr);
438 break;
439 default: assert(!"unknown sum_dt");
440 }
441 if (sum_dt != data_type::f32)
442 vcvtdq2ps(vreg_prev_dst, vreg_prev_dst);
443 if (*p_sum_zp != 0) vsubps(vreg_prev_dst, vreg_sum_zp);
444 if (*p_sum_scale == 1.f)
445 vaddps(zmm, vreg_prev_dst);
446 else
447 vfmadd231ps(
448 zmm, vreg_prev_dst, zword_b[reg_ptr_sum_scale]);
449 }
450 // we skip max if dst_dt == u8 as it will happen in saturation
451 if (maybe_relu(1) && (jcp.dst_dt != u8))
452 vmaxps(zmm, vreg_zero, zmm);
453
454 if (utils::one_of(jcp.dst_dt, u8, s8, s32)) {
455 saturate_f32(
456 zmm, vreg_zero, vreg_saturation_ubound, jcp.dst_dt);
457 vcvtps2dq(zmm, zmm);
458 }
459 switch (jcp.dst_dt) {
460 case data_type::f32:
461 case data_type::s32: vmovups(addr, zmm | r_mask); break;
462 case data_type::s8:
463 vpmovsdb(xmm, zmm);
464 vmovups(addr, xmm | r_mask);
465 break;
466 case data_type::u8:
467 vpmovusdb(xmm, zmm);
468 vmovups(addr, xmm | r_mask);
469 break;
470 default: assert(!"unknown dst_dt");
471 }
472 }
473 }
474 };
475
476 preamble();
477
478 #define READ_PARAM(reg, field) \
479 mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
480 READ_PARAM(reg_ptr_src, wino_dst);
481 READ_PARAM(reg_ptr_dst, dst);
482 READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
483 READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
484 READ_PARAM(reg_ptr_bias, bias);
485 READ_PARAM(reg_ptr_scales, scales);
486 #undef READ_PARAM
487
488 if (jcp.with_bias)
489 mov(reg_scratch_bias_alpha, float2int(adj_src_scale * adj_wei_scale));
490
491 mov(reg_aux_ptr_src, reg_ptr_src);
492 mov(reg_aux_ptr_dst, reg_ptr_dst);
493
494 vpxord(vreg_zero, vreg_zero, vreg_zero);
495
496 for (int i = 0; i < jcp.m; i++)
497 kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
498
499 int oc_blocks = jcp.oc / load_block;
500 mov(reg_oc_block, oc_blocks);
501 L(oc_block_label);
502 {
503 loop_body();
504 add(reg_aux_ptr_src, sizeof(int32_t) * load_block);
505 add(reg_aux_ptr_dst, jcp.typesize_out * load_block);
506
507 add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
508 add(reg_ptr_bias, sizeof(jcp.typesize_bia) * load_block);
509 }
510 dec(reg_oc_block);
511 jnz(oc_block_label, T_NEAR);
512
513 postamble();
514 }
515
516 /// GEMM kernel ////////////////////////////////////////////////////////////////
517 struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t : public jit_generator {
518 DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t)
519 jit_conv_conf_2x3_wino_t jcp;
520 const primitive_attr_t &attr_;
521
522 struct call_params_t {
523 const void *src;
524 const void *dst;
525 const void *wei;
526 const void *dst_b;
527 };
528
529 void generate() override;
530 static bool post_ops_ok(
531 jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr);
532
jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_tdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t533 jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
534 jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
535 : jcp(ajcp), attr_(attr) {}
536
537 static status_t init_conf(jit_conv_conf_2x3_wino_t &jcp,
538 const convolution_desc_t &cd, memory_desc_t &src_md,
539 memory_desc_t &weights_md, memory_desc_t &dst_md,
540 memory_desc_t &bias_md, const primitive_attr_t &attr);
541
vreg_outdnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t542 Zmm vreg_out(int n, int m) const {
543 const int id_reg_out = n * jcp.m_block + m;
544 assert(id_reg_out < jcp.n2_block * jcp.m_block);
545 return Zmm(31 - id_reg_out);
546 }
vreg_weidnnl::impl::cpu::x64::jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t547 Zmm vreg_wei(int i) const {
548 assert(31 - jcp.n2_block * jcp.m_block - i
549 > (jcp.ver == ver_vnni ? 0 : 2));
550 return Zmm(31 - jcp.n2_block * jcp.m_block - i);
551 }
552
553 Zmm vreg_src = Zmm(0);
554 Zmm vreg_one = Zmm(1);
555 Zmm vreg_tmp = Zmm(2);
556
557 Reg64 reg_ptr_src = r15;
558
559 Reg64 reg_aux_dst_b = r13;
560 Reg64 reg_aux_dst = r12;
561 Reg64 reg_aux_dst2 = r11;
562 Reg64 reg_aux_wei = r10;
563 Reg64 reg_aux_wei2 = r9;
564 Reg64 reg_aux_src = r8;
565 Reg64 reg_aux_src2 = rax;
566 Reg64 reg_mb = rbx;
567 Reg64 reg_nnb = abi_not_param1;
568 Reg64 reg_scratch = rdx;
569 Reg64 reg_K = rsi;
570 };
571
post_ops_ok(jit_conv_conf_2x3_wino_t & jcp,const primitive_attr_t & attr)572 bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok(
573 jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) {
574 using namespace primitive_kind;
575 const auto &p = attr.post_ops_;
576
577 auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
578
579 switch (p.len()) {
580 case 0: return true;
581 case 1: return is_relu(0) || p.contain(sum, 0);
582 case 2:
583 return (p.contain(sum, 0) && is_relu(1))
584 || (p.contain(sum, 1) && is_relu(0));
585 case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
586 default: return false;
587 }
588
589 return false;
590 }
591
generate()592 void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() {
593 Label nnb_loop_label, K_loop_label, mb_loop_label;
594
595 auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
596 if (jcp.ver == ver_vnni) {
597 vpdpbusd(vreg_acc, vreg_src, vreg_wei);
598 } else {
599 vpmaddubsw(vreg_tmp, vreg_src, vreg_wei);
600 vpmaddwd(vreg_tmp, vreg_tmp, vreg_one);
601 vpaddd(vreg_acc, vreg_acc, vreg_tmp);
602 }
603 };
604
605 preamble();
606 #define READ_PARAM(reg, field) \
607 mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
608 READ_PARAM(reg_ptr_src, src);
609 READ_PARAM(reg_aux_dst, dst);
610 READ_PARAM(reg_aux_wei, wei);
611 READ_PARAM(reg_aux_dst_b, dst_b);
612 #undef READ_PARAM
613
614 if (jcp.ver != ver_vnni) {
615 xor_(reg_scratch, reg_scratch);
616 Reg16 _t = reg_scratch.cvt16();
617 mov(_t, 0x1);
618 vpbroadcastw(vreg_one, _t);
619 }
620
621 if (!jcp.small_mb) {
622 mov(reg_nnb, jcp.n_chunks);
623 L(nnb_loop_label);
624 }
625 mov(reg_aux_dst2, reg_aux_dst);
626 mov(reg_aux_src, reg_ptr_src);
627 mov(reg_mb, jcp.M / jcp.m_block);
628 L(mb_loop_label);
629 {
630 for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
631 for (int m = 0; m < jcp.m_block; m++) {
632 int offset = jcp.typesize_acc * nb2 * jcp.n_block;
633 vmovups(vreg_out(nb2, m),
634 EVEX_compress_addr(reg_aux_dst_b, offset));
635 }
636 }
637 mov(reg_aux_src2, reg_aux_src);
638 mov(reg_aux_wei2, reg_aux_wei);
639 mov(reg_K, jcp.k_chunks);
640 L(K_loop_label);
641 {
642 for (int k = 0; k < jcp.k2_block; k += 4) {
643 for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
644 int wei_offset
645 = jcp.typesize_in * (nb2 * jcp.n_block * jcp.K);
646 vmovups(vreg_wei(nb2),
647 EVEX_compress_addr(reg_aux_wei2, wei_offset));
648 }
649 for (int m = 0; m < jcp.m_block; m++) {
650 int inp_offset = jcp.typesize_in * m * jcp.K;
651 vpbroadcastd(vreg_src,
652 EVEX_compress_addr(reg_aux_src2, inp_offset));
653 for (int nb2 = 0; nb2 < jcp.n2_block; nb2++)
654 compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src);
655 }
656 add(reg_aux_src2, jcp.typesize_in * 4);
657 add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block);
658 }
659 }
660 dec(reg_K);
661 jnz(K_loop_label, T_NEAR);
662
663 for (int m = 0; m < jcp.m_block; m++) {
664 for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
665 int offset = jcp.typesize_acc * (m * jcp.N + nb2 * jcp.n_block);
666 vmovups(EVEX_compress_addr(reg_aux_dst2, offset),
667 vreg_out(nb2, m));
668 }
669 }
670 add(reg_aux_src, jcp.typesize_in * jcp.m_block * jcp.K);
671 add(reg_aux_dst2, jcp.typesize_acc * jcp.m_block * jcp.N);
672 }
673 dec(reg_mb);
674 jnz(mb_loop_label, T_NEAR);
675
676 if (!jcp.small_mb) {
677 add(reg_aux_dst, jcp.typesize_acc * jcp.n2_block * jcp.n_block);
678 add(reg_aux_dst_b, jcp.typesize_acc * jcp.n2_block * jcp.n_block);
679 add(reg_aux_wei, jcp.typesize_in * jcp.n2_block * jcp.n_block * jcp.K);
680
681 dec(reg_nnb);
682 jnz(nnb_loop_label, T_NEAR);
683 }
684
685 postamble();
686 }
687 namespace {
is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t & jcp)688 bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
689 if (jcp.ver == ver_vnni) {
690 return (jcp.mb <= jcp.nthr
691 && (jcp.mb > 4 && jcp.ic > 64
692 && !(jcp.oc > 128 && jcp.ih < 14)))
693 || jcp.mb > jcp.nthr;
694 }
695 return true;
696 }
697 } // namespace
698
init_conf(jit_conv_conf_2x3_wino_t & jcp,const convolution_desc_t & cd,memory_desc_t & src_md,memory_desc_t & wei_md,memory_desc_t & dst_md,memory_desc_t & bias_md,const primitive_attr_t & attr)699 status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf(
700 jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
701 memory_desc_t &src_md, memory_desc_t &wei_md, memory_desc_t &dst_md,
702 memory_desc_t &bias_md, const primitive_attr_t &attr) {
703 const memory_desc_wrapper src_d(&src_md);
704 const memory_desc_wrapper wei_d(&wei_md);
705 const memory_desc_wrapper dst_d(&dst_md);
706 const memory_desc_wrapper bias_d(&bias_md);
707
708 // This kernel only supports 2D convolutions.
709 if (src_d.ndims() != 4) return status::unimplemented;
710
711 const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;
712
713 jcp.nthr = dnnl_get_max_threads();
714
715 jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
716 jcp.mb = src_d.dims()[0];
717 jcp.oc = dst_d.dims()[1] / jcp.ngroups;
718 jcp.ic = src_d.dims()[1] / jcp.ngroups;
719 jcp.ih = src_d.dims()[2];
720 jcp.iw = src_d.dims()[3];
721 jcp.oh = dst_d.dims()[2];
722 jcp.ow = dst_d.dims()[3];
723 jcp.kh = wei_d.dims()[with_groups + 2];
724 jcp.kw = wei_d.dims()[with_groups + 3];
725 jcp.t_pad = cd.padding[0][0];
726 jcp.l_pad = cd.padding[0][1];
727 jcp.stride_h = cd.strides[0];
728 jcp.stride_w = cd.strides[1];
729 jcp.dilate_h = cd.dilates[0];
730 jcp.dilate_w = cd.dilates[1];
731
732 const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
733 const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
734 jcp.r_pad = calculate_end_padding(
735 jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
736 jcp.b_pad = calculate_end_padding(
737 jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
738
739 jcp.ver = ver_avx512_core;
740 if (!(mayiuse(avx512_core) && src_d.data_type() == data_type::u8
741 && wei_d.data_type() == data_type::s8
742 && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
743 data_type::s8, data_type::u8)))
744 return status::unimplemented;
745 if (mayiuse(avx512_core_vnni)) jcp.ver = ver_vnni;
746
747 if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
748 is_winograd_faster_than_direct(jcp)))
749 return status::unimplemented;
750
751 // block sizes needed for GEMM kernel
752 jcp.ic_block = 4;
753 jcp.oc_block = 16;
754
755 bool ok = true && jcp.ngroups == 1 && jcp.oc % load_block == 0
756 && jcp.ic % load_block == 0 && jcp.oc % jcp.oc_block == 0
757 && jcp.ic % jcp.ic_block == 0 && everyone_is(3, jcp.kh, jcp.kw)
758 && everyone_is(1, jcp.stride_h, jcp.stride_w)
759 && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
760 && jcp.t_pad == jcp.b_pad && jcp.l_pad == jcp.r_pad
761 && one_of(jcp.t_pad, 0, 1) && one_of(jcp.l_pad, 0, 1);
762 if (!ok) return status::unimplemented;
763
764 format_tag_t dat_tag = format_tag::nhwc;
765 if (!src_d.matches_tag(dat_tag)) return status::unimplemented;
766 if (!dst_d.matches_tag(dat_tag)) return status::unimplemented;
767
768 jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
769
770 if (!post_ops_ok(jcp, attr)) return status::unimplemented;
771
772 jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
773 jcp.dst_dt = cd.dst_desc.data_type;
774
775 jcp.typesize_in = types::data_type_size(src_d.data_type());
776 jcp.typesize_out = types::data_type_size(dst_d.data_type());
777 jcp.typesize_acc = sizeof(int32_t);
778 jcp.typesize_bia
779 = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
780
781 jcp.nb_oc = jcp.oc / jcp.oc_block;
782 jcp.nb_ic = jcp.ic / jcp.ic_block;
783
784 jcp.m = 2;
785 jcp.r = 3;
786 jcp.alpha = jcp.m + jcp.r - 1;
787
788 int aa = jcp.alpha * jcp.alpha;
789 int L1_cap = platform::get_per_core_cache_size(1);
790 int L2_cap = platform::get_per_core_cache_size(2);
791 // need 1 extra reg for bcast, and 2 tmp regs for non-vnni
792 int free_regs = jcp.ver == ver_vnni ? 31 : 29;
793
794 auto get_thr_eff = [&](int small_mb, int ix, int iy, int n2_b) {
795 float thr_eff;
796 float Z = (float)jcp.ic + jcp.oc;
797 float Y = (float)jcp.ic * jcp.oc;
798 if (small_mb == 0) { // outer par
799 int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix);
800 thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);
801 } else { // inner par
802 int tranw = iy * ix / jcp.alpha;
803 int gemmw = aa * (jcp.nb_oc / n2_b);
804 int tranw_r = rnd_up(tranw, jcp.nthr);
805 int gemmw_r = rnd_up(gemmw, jcp.nthr);
806 thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y);
807 }
808 return thr_eff;
809 };
810
811 auto get_mem_eff = [&](int small_mb, int ix, int iy, int n2_b) {
812 float mem_eff, req_mem;
813 int M = ix * iy / jcp.alpha;
814 if (small_mb == 0) { // outer parallelization strategy
815 // memory for wino transforms (other memory has poor reuse)
816 req_mem = (float)aa * M * (jcp.ic + jcp.typesize_acc * jcp.oc);
817 mem_eff = req_mem < L1_cap ? 1.f : req_mem < L2_cap ? 0.5f : 0.f;
818 } else { // inner parallelization strategy
819 // memory used during gemm
820 int N = jcp.oc_block * n2_b;
821 req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N;
822 mem_eff = nstl::min(1.f, L2_cap / req_mem);
823 // memory used during wino transforms
824 int M_per_thr = div_up(M, jcp.nthr);
825 req_mem = (float)aa * M_per_thr
826 * (jcp.ic + jcp.typesize_acc * jcp.oc);
827 if (req_mem > L2_cap) mem_eff = 0.1f;
828 }
829 return mem_eff;
830 };
831
832 auto get_tot_eff = [&](int small_mb, float thr_eff, float work_eff,
833 float mem_eff, float reg_eff) {
834 // these coefficients are chosen empirically
835 float mem_fac = 0.1f, reg_fac = 0.2f;
836 // normalized overhead relative to memory and register components
837 float tot_eff = 1.f + mem_fac * mem_eff + reg_fac * reg_eff;
838 // thread and work components affect all others
839 tot_eff *= thr_eff * work_eff;
840 return tot_eff;
841 };
842
843 auto find_m_n2_blocks
844 = [&](bool small_mb, int ix, int iy, float work_eff, int &m_block,
845 int &n2_block, float &tot_eff) {
846 int M = (ix * iy) / jcp.alpha;
847 int max_m_block = nstl::min(M, free_regs);
848 int max_n2_block = nstl::min(jcp.nb_oc, free_regs);
849 tot_eff = 0.f;
850 for (int im = max_m_block; im > 0; im--) {
851 if (M % im) continue;
852 for (int in2 = max_n2_block; in2 > 0; in2--) {
853 int used_regs = (im + 1) * in2;
854 float mem_eff = get_mem_eff(small_mb, ix, iy, in2);
855 float reg_eff = (float)(im * in2) / (im + in2);
856 float thr_eff = get_thr_eff(small_mb, ix, iy, in2);
857 float cur_tot_eff = get_tot_eff(small_mb, thr_eff,
858 work_eff, mem_eff, reg_eff);
859 if (jcp.nb_oc % in2 || used_regs > free_regs
860 || cur_tot_eff <= tot_eff)
861 continue;
862 tot_eff = cur_tot_eff;
863 m_block = im;
864 n2_block = in2;
865 }
866 }
867 };
868
869 /* Selecting xb and yb blocking */
870 int min_yb = jcp.m;
871 int min_xb = jcp.m;
872 int max_yb = nstl::max(min_yb, rnd_up(jcp.oh, 2));
873 int max_xb = nstl::max(min_xb, rnd_up(jcp.ow, 2));
874 float best_eff = 0.f;
875 for (int ix = min_xb; ix <= max_xb; ix += 2) {
876 assert(rnd_up(jcp.ow, ix) >= jcp.iw - 2);
877 for (int iy = max_yb; iy >= min_yb; iy -= 2) {
878 assert(rnd_up(jcp.oh, iy) >= jcp.ih - 2);
879
880 int m_b[2];
881 int n2_b[2];
882 bool small_mb;
883 float inner_eff, outer_eff, work_eff;
884
885 int tiled_area = rnd_up(jcp.oh, iy) * rnd_up(jcp.ow, ix);
886 work_eff = (float)jcp.oh * jcp.ow / tiled_area;
887 if (best_eff > 0.f && work_eff < 4.f / 9.f)
888 continue; // no gain from Winograd transformation
889
890 /* outer parallelization */
891 find_m_n2_blocks(
892 false, ix, iy, work_eff, m_b[0], n2_b[0], outer_eff);
893
894 /* inner parallelization */
895 find_m_n2_blocks(
896 true, ix, iy, work_eff, m_b[1], n2_b[1], inner_eff);
897
898 small_mb = inner_eff > outer_eff;
899 float eff = small_mb ? inner_eff : outer_eff;
900 if (eff > best_eff) {
901 best_eff = eff;
902 jcp.yb = iy;
903 jcp.xb = ix;
904 jcp.m_block = m_b[small_mb];
905 jcp.n2_block = n2_b[small_mb];
906 jcp.small_mb = small_mb;
907 }
908 }
909 }
910
911 assert((jcp.m_block + 1) * jcp.n2_block <= free_regs);
912 assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0);
913
914 jcp.mb_block = 1;
915 if (jcp.small_mb) {
916 // For small mb harness, set mb_block as large as possible subject to
917 // the constraint that winograd activations fit into available L3 cache
918 int L3_cap = platform::get_per_core_cache_size(3);
919 int M = jcp.xb * jcp.yb / 4;
920 int wino_src_size = 16 * M * jcp.ic * jcp.typesize_in;
921 int wino_dst_size = 16 * M * jcp.oc * jcp.typesize_acc;
922 int max_mb_block = nstl::min(
923 jcp.mb, jcp.nthr * L3_cap / (wino_src_size + wino_dst_size));
924 for (int i = max_mb_block; i > 1; i--) {
925 if (jcp.mb % i == 0) {
926 jcp.mb_block = i;
927 break;
928 }
929 }
930 }
931 jcp.nb_mb = jcp.mb / jcp.mb_block;
932
933 jcp.M = jcp.mb_block * jcp.xb * jcp.yb / 4;
934 jcp.N = jcp.oc;
935 jcp.K = jcp.ic;
936
937 jcp.inp_stride = jcp.M * jcp.ic;
938 jcp.out_stride = jcp.M * jcp.oc;
939 jcp.wei_stride = jcp.ic * jcp.oc;
940 jcp.bia_stride = jcp.oc;
941
942 jcp.n_block = jcp.oc_block;
943 jcp.k_block = jcp.ic_block;
944
945 jcp.n_chunks = (jcp.N / jcp.n_block) / jcp.n2_block;
946
947 // We need jcp.k2_block to be a multiple of jcp.k_block = jcp.ic_block = 4
948 // and jcp.K = jcp.ic to be a multiple of jcp.k2_block. Since jcp.ic is
949 // a multiple of load_block = 16, we just use that for now.
950 jcp.k2_block = load_block;
951 jcp.k_chunks = jcp.K / jcp.k2_block;
952
953 const auto &oscales = attr.output_scales_;
954 jcp.is_oc_scale = oscales.mask_ == 1 << 1;
955
956 // only common and per-oc-channel scales are supported
957 const bool oscales_ok = one_of(oscales.mask_, 0, 1 << 1);
958 if (!oscales_ok) return status::unimplemented;
959
960 /* re-create weights primitive descriptor
961 and set weights wino_blocking */
962 memory_desc_t expect_wei_md = wei_md;
963
964 expect_wei_md.format_kind = format_kind::wino;
965 expect_wei_md.data_type = data_type::s8;
966 dnnl_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc;
967 wd.wino_format = dnnl_wino_wei_aaOIoi;
968 wd.r = jcp.r;
969 wd.alpha = jcp.alpha;
970 wd.ic = jcp.ic;
971 wd.oc = jcp.oc;
972 wd.ic_block = jcp.ic_block;
973 wd.oc_block = jcp.oc_block;
974 wd.oc2_block = jcp.n2_block;
975 wd.ic2_block = 1;
976 wd.adj_scale = adj_wei_scale;
977
978 size_t max_size = types::data_type_size(data_type::s8) * jcp.alpha
979 * jcp.alpha * jcp.ic * jcp.oc;
980 max_size += types::data_type_size(data_type::s32) * jcp.alpha * jcp.alpha
981 * jcp.oc;
982 wd.size = max_size;
983
984 if (wei_md.format_kind == format_kind::any) wei_md = expect_wei_md;
985 if (wei_md != expect_wei_md) return status::unimplemented;
986
987 const int tilesize = jcp.alpha * jcp.alpha;
988 const int numtiles = jcp.M;
989 const int alltiles = numtiles * tilesize;
990
991 jcp.size_wino_src
992 = utils::rnd_up(jcp.typesize_in * alltiles * jcp.ic, PAGE_4K)
993 / jcp.typesize_in;
994 jcp.size_wino_wei = tilesize * jcp.oc * jcp.ic;
995 jcp.size_wino_dst = alltiles * jcp.oc;
996
997 return status::success;
998 }
999 ////////////////////////////////////////////////////////////////////////////////
1000
jit_conf()1001 status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::pd_t::jit_conf() {
1002 return jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf(jcp_,
1003 *this->desc(), this->src_md_, this->weights_md_, this->dst_md_,
1004 this->bias_md_, *this->attr());
1005 }
1006
init_scratchpad()1007 void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::pd_t::init_scratchpad() {
1008 auto scratchpad = this->scratchpad_registry().registrar();
1009
1010 int nthr_multiplier = jcp_.small_mb ? 1 : jcp_.nthr;
1011 scratchpad.template book<src_data_t>(
1012 key_wino_V, jcp_.size_wino_src * nthr_multiplier, PAGE_4K);
1013 scratchpad.template book<acc_data_t>(
1014 key_wino_M, jcp_.size_wino_dst * nthr_multiplier, PAGE_4K);
1015
1016 dim_t scale_count = attr()->output_scales_.count_;
1017 scratchpad.template book<float>(
1018 key_conv_adjusted_scales, nstl::max<dim_t>(scale_count, 16));
1019 }
1020
1021 jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::
jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t * apd)1022 jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd)
1023 : primitive_t(apd) {}
1024
init(engine_t * engine)1025 status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::init(
1026 engine_t *engine) {
1027 CHECK(safe_ptr_assign(kernel_,
1028 new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
1029 pd()->jcp_, *pd()->attr())));
1030 CHECK(safe_ptr_assign(src_trans_,
1031 new jit_avx512_core_u8s8s32x_wino_conv_src_trans_t(
1032 pd()->jcp_, *pd()->attr())));
1033 CHECK(safe_ptr_assign(dst_trans_,
1034 new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
1035 pd()->jcp_, *pd()->attr())));
1036 CHECK(kernel_->create_kernel());
1037 CHECK(src_trans_->create_kernel());
1038 CHECK(dst_trans_->create_kernel());
1039 return status::success;
1040 }
1041
1042 jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::
1043 ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t()
1044 = default;
1045
adjust_oscales(const memory_tracking::grantor_t & scratchpad) const1046 const float *jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::adjust_oscales(
1047 const memory_tracking::grantor_t &scratchpad) const {
1048 const float *oscales = pd()->attr()->output_scales_.scales_;
1049 auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales);
1050 size_t count = pd()->attr()->output_scales_.count_;
1051 float factor = 1.f / (adj_src_scale * adj_wei_scale);
1052 if (count == 1)
1053 utils::array_set(loc_scales, oscales[0] * factor, 16);
1054 else
1055 for (size_t c = 0; c < count; c++)
1056 loc_scales[c] = oscales[c] * factor;
1057 return loc_scales;
1058 }
1059
execute_forward(const exec_ctx_t & ctx) const1060 void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t ::execute_forward(
1061 const exec_ctx_t &ctx) const {
1062 auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
1063 auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
1064 auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
1065 auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
1066
1067 const auto &jcp = kernel_->jcp;
1068 if (jcp.small_mb)
1069 execute_forward_small_mb(
1070 src, weights, bias, dst, ctx.get_scratchpad_grantor());
1071 else
1072 execute_forward_mbN(
1073 src, weights, bias, dst, ctx.get_scratchpad_grantor());
1074 }
1075
execute_forward_mbN(const src_data_t * src,const wei_data_t * wei,const char * bia,char * dst,const memory_tracking::grantor_t & scratchpad) const1076 void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::execute_forward_mbN(
1077 const src_data_t *src, const wei_data_t *wei, const char *bia,
1078 char *dst, const memory_tracking::grantor_t &scratchpad) const {
1079 const auto &jcp = kernel_->jcp;
1080 const memory_desc_wrapper dst_d(pd()->dst_md());
1081 const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
1082 const float *oscales = adjust_oscales(scratchpad);
1083
1084 auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
1085 auto wino_src_base = scratchpad.template get<src_data_t>(key_wino_V);
1086 auto wino_dst_base = scratchpad.template get<acc_data_t>(key_wino_M);
1087
1088 parallel_nd_ext(jcp.nthr, jcp.mb, div_up(jcp.oh, jcp.yb),
1089 div_up(jcp.ow, jcp.xb),
1090 [&](dim_t ithr, dim_t nthr, dim_t mb, dim_t tile_y_b,
1091 dim_t tile_x_b) {
1092 assert(nthr <= jcp.nthr);
1093 MAYBE_UNUSED(nthr);
1094
1095 int tile_y = tile_y_b * jcp.yb;
1096 int tile_x = tile_x_b * jcp.xb;
1097
1098 auto wino_src = wino_src_base + jcp.size_wino_src * ithr;
1099 auto wino_dst = wino_dst_base + jcp.size_wino_dst * ithr;
1100
1101 auto src_trans_p
1102 = jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::
1103 call_params_t();
1104 auto dst_trans_p
1105 = jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::
1106 call_params_t();
1107 auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::
1108 call_params_t();
1109
1110 /* transformation of input tensor to winograd domain */
1111 for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
1112 for (int x_in_block = 0; x_in_block < jcp.xb;
1113 x_in_block += 2) {
1114 uint16_t v_y_masks[4], v_x_masks[4];
1115
1116 int y = y_in_block + tile_y;
1117 int x = x_in_block + tile_x;
1118 int m = (y_in_block / 2) * (jcp.xb / 2)
1119 + (x_in_block / 2);
1120
1121 int v_ys = nstl::max(0, jcp.t_pad - y);
1122 int v_ye = nstl::min(jcp.alpha,
1123 nstl::max(0, jcp.ih + jcp.t_pad - y));
1124
1125 int v_xs = nstl::max(0, jcp.l_pad - x);
1126 int v_xe = nstl::min(jcp.alpha,
1127 nstl::max(0, jcp.iw + jcp.l_pad - x));
1128
1129 #pragma unroll(4)
1130 for (int i = 0; i < jcp.alpha; i++) {
1131 v_y_masks[i] = uint16_t(
1132 i < v_ys || i >= v_ye ? 0 : 0xffff);
1133 v_x_masks[i] = uint16_t(
1134 i < v_xs || i >= v_xe ? 0 : 0xffff);
1135 }
1136 auto local_s = src
1137 + (dim_t)mb * jcp.ih * jcp.iw * jcp.ic
1138 + y * jcp.iw * jcp.ic + x * jcp.ic;
1139 auto local_w = wino_src + m * jcp.ic;
1140
1141 src_trans_p.src = local_s;
1142 src_trans_p.wino_src = local_w;
1143 src_trans_p.v_y_masks = v_y_masks;
1144 src_trans_p.v_x_masks = v_x_masks;
1145
1146 (*src_trans_)(&src_trans_p);
1147 }
1148 }
1149 /* gemms */
1150 for (int tile_ij = 0; tile_ij < 16; tile_ij++) {
1151 // start threads at different GEMMs to help bring weights into LLC
1152 int offset = (tile_ij + ithr) % 16;
1153 gemm_p.src = wino_src + jcp.inp_stride * offset;
1154 gemm_p.dst = wino_dst + jcp.out_stride * offset;
1155 gemm_p.wei = wei + jcp.wei_stride * offset;
1156 gemm_p.dst_b = dst_bias + jcp.bia_stride * offset;
1157
1158 (*kernel_)(&gemm_p);
1159 }
1160
1161 auto dst_loc
1162 = dst + dst_dt_size * mb * jcp.oh * jcp.ow * jcp.oc;
1163 /* transformation from winograd domain to output tensor */
1164 for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
1165 for (int x_in_block = 0; x_in_block < jcp.xb;
1166 x_in_block += 2) {
1167 uint16_t v_y_masks[2], v_x_masks[2];
1168
1169 int y = y_in_block + tile_y;
1170 int x = x_in_block + tile_x;
1171 int m = (y_in_block / 2) * (jcp.xb / 2)
1172 + (x_in_block / 2);
1173
1174 #pragma unroll(2)
1175 for (int i = 0; i < jcp.m; i++) {
1176 v_x_masks[i]
1177 = uint16_t(x + i < jcp.ow ? 0xffff : 0);
1178 v_y_masks[i]
1179 = uint16_t(y + i < jcp.oh ? 0xffff : 0);
1180 }
1181 auto local_d = dst_loc
1182 + dst_dt_size
1183 * (y * jcp.ow * jcp.oc + x * jcp.oc);
1184 auto local_w = wino_dst + m * jcp.oc;
1185
1186 auto scales = oscales;
1187 dst_trans_p.dst = local_d;
1188 dst_trans_p.wino_dst = local_w;
1189 dst_trans_p.v_y_masks = v_y_masks;
1190 dst_trans_p.v_x_masks = v_x_masks;
1191
1192 dst_trans_p.scales = scales;
1193 dst_trans_p.bias = bia;
1194
1195 (*dst_trans_)(&dst_trans_p);
1196 }
1197 }
1198 });
1199 }
1200
execute_forward_small_mb(const src_data_t * src,const wei_data_t * wei,const char * bia,char * dst,const memory_tracking::grantor_t & scratchpad) const1201 void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::execute_forward_small_mb(
1202 const src_data_t *src, const wei_data_t *wei, const char *bia,
1203 char *dst, const memory_tracking::grantor_t &scratchpad) const {
1204 const auto &jcp = kernel_->jcp;
1205 const memory_desc_wrapper dst_d(pd()->dst_md());
1206 const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
1207 const float *oscales = adjust_oscales(scratchpad);
1208
1209 auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
1210 auto wino_src = scratchpad.template get<src_data_t>(key_wino_V);
1211 auto wino_dst = scratchpad.template get<acc_data_t>(key_wino_M);
1212
1213 for_(int mbb = 0; mbb < jcp.nb_mb; mbb++)
1214 for_(int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb)
1215 for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
1216 /* transformation of input tensor to winograd domain */
1217 parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
1218 [&](dim_t y_in_block_b, dim_t x_in_block_b, dim_t mb) {
1219 int y_in_block = y_in_block_b * 2;
1220 int x_in_block = x_in_block_b * 2;
1221
1222 auto src_trans_p
1223 = jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::
1224 call_params_t();
1225
1226 uint16_t v_y_masks[4], v_x_masks[4];
1227
1228 int y = y_in_block + tile_y;
1229 int x = x_in_block + tile_x;
1230 int m = (mb * (jcp.yb / 2) + (y_in_block / 2))
1231 * (jcp.xb / 2)
1232 + (x_in_block / 2);
1233
1234 int v_ys = nstl::max(0, jcp.t_pad - y);
1235 int v_ye = nstl::min(
1236 jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y));
1237
1238 int v_xs = nstl::max(0, jcp.l_pad - x);
1239 int v_xe = nstl::min(
1240 jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x));
1241
1242 #pragma unroll(4)
1243 for (int i = 0; i < jcp.alpha; i++) {
1244 v_y_masks[i]
1245 = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
1246 v_x_masks[i]
1247 = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
1248 }
1249 auto local_s = src
1250 + ((dim_t)mbb * jcp.mb_block + mb) * jcp.ih * jcp.iw
1251 * jcp.ic
1252 + y * jcp.iw * jcp.ic + x * jcp.ic;
1253 auto local_w = wino_src + m * jcp.ic;
1254
1255 src_trans_p.src = local_s;
1256 src_trans_p.wino_src = local_w;
1257 src_trans_p.v_y_masks = v_y_masks;
1258 src_trans_p.v_x_masks = v_x_masks;
1259
1260 (*src_trans_)(&src_trans_p);
1261 });
1262
1263 /* gemms */
1264 parallel_nd(16, jcp.n_chunks, [&](dim_t tile_ij, dim_t nnb) {
1265 auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::
1266 call_params_t();
1267
1268 gemm_p.src = wino_src + jcp.inp_stride * tile_ij;
1269 gemm_p.dst = wino_dst + jcp.out_stride * tile_ij
1270 + nnb * jcp.n2_block * jcp.n_block;
1271 gemm_p.wei = wei + jcp.wei_stride * tile_ij
1272 + nnb * jcp.n2_block * jcp.n_block * jcp.K;
1273 gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij
1274 + nnb * jcp.n2_block * jcp.n_block;
1275
1276 (*kernel_)(&gemm_p);
1277 });
1278
1279 /* transformation from winograd domain to output tensor */
1280 parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
1281 [&](dim_t y_in_block_b, dim_t x_in_block_b, dim_t mb) {
1282 int y_in_block = y_in_block_b * 2;
1283 int x_in_block = x_in_block_b * 2;
1284
1285 auto dst_trans_p
1286 = jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::
1287 call_params_t();
1288
1289 uint16_t v_y_masks[2], v_x_masks[2];
1290
1291 int y = y_in_block + tile_y;
1292 int x = x_in_block + tile_x;
1293 int m = (mb * (jcp.yb / 2) + (y_in_block / 2))
1294 * (jcp.xb / 2)
1295 + (x_in_block / 2);
1296
1297 #pragma unroll(2)
1298 for (int i = 0; i < jcp.m; i++) {
1299 v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0);
1300 v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0);
1301 }
1302 auto local_d = dst
1303 + dst_dt_size
1304 * ((mbb * jcp.mb_block + mb) * jcp.oh
1305 * jcp.ow * jcp.oc
1306 + y * jcp.ow * jcp.oc + x * jcp.oc);
1307 auto local_w = wino_dst + m * jcp.oc;
1308
1309 auto scales = oscales;
1310 dst_trans_p.dst = local_d;
1311 dst_trans_p.wino_dst = local_w;
1312 dst_trans_p.v_y_masks = v_y_masks;
1313 dst_trans_p.v_x_masks = v_x_masks;
1314
1315 dst_trans_p.scales = scales;
1316 dst_trans_p.bias = bia;
1317
1318 (*dst_trans_)(&dst_trans_p);
1319 });
1320 }
1321 }
1322
1323 } // namespace x64
1324 } // namespace cpu
1325 } // namespace impl
1326 } // namespace dnnl
1327