1 /*******************************************************************************
2 * Copyright 2018-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <assert.h>
18 #include <string.h>
19 
20 #include "common/c_types_map.hpp"
21 #include "common/dnnl_thread.hpp"
22 #include "common/memory_tracking.hpp"
23 #include "common/type_helpers.hpp"
24 #include "common/utils.hpp"
25 
26 #include "cpu/platform.hpp"
27 
28 #include "cpu/x64/jit_avx512_core_u8s8s32x_wino_convolution.hpp"
29 #include "cpu/x64/jit_generator.hpp"
30 #include "cpu/x64/jit_primitive_conf.hpp"
31 
32 namespace dnnl {
33 namespace impl {
34 namespace cpu {
35 namespace x64 {
36 
37 using namespace dnnl::impl::memory_tracking::names;
38 using namespace dnnl::impl::utils;
39 using namespace dnnl::impl::data_type;
40 using namespace Xbyak;
41 
namespace {
// Below scales are applied to source and weights data accordingly
// because this winograd implementation
// transforms source which may increase values up to 4x
// and transforms weights which may increase values up to 9/4x
// (the bias is multiplied by adj_src_scale * adj_wei_scale in the dst
// transform so that all terms end up in the same scale).
const float adj_src_scale = 1.f / 4.f;
const float adj_wei_scale = 4.f / 9.f;
// Winograd transforms need ic and oc to be multiples of 16
const int load_block = 16;
} // namespace
52 
53 /// SRC TRANSFORMS /////////////////////////////////////////////////////////////
// JIT kernel performing the Winograd input transform on u8 source data:
// it reads a 4x4 (jcp.alpha x jcp.alpha) spatial tile of load_block (16)
// input channels, pre-scales it by adj_src_scale, applies the input
// transform with byte arithmetic and stores the 16 transformed vectors.
struct jit_avx512_core_u8s8s32x_wino_conv_src_trans_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(
            jit_avx512_core_u8s8s32x_wino_conv_src_trans_t)

    jit_conv_conf_2x3_wino_t jcp;
    const primitive_attr_t &attr_;

    // Run-time arguments of the generated kernel (read via abi_param1).
    struct call_params_t {
        const void *src; // u8 source tile base pointer
        const void *wino_src; // destination in the Winograd domain
        const void *v_y_masks; // jcp.alpha 16-bit row masks (0 where padded)
        const void *v_x_masks; // jcp.alpha 16-bit column masks (0 where padded)
    };

    jit_avx512_core_u8s8s32x_wino_conv_src_trans_t(
            jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
        : jcp(ajcp), attr_(attr), unsign_val_in_wino_domain(5) {}
    void generate() override;

    // Input tile registers count down from 31: input i lives in reg (31 - i).
    int reg_inp_ind(int i) const {
        assert(i < jcp.alpha * jcp.alpha);
        return (31 - i);
    }

    Xmm vreg_inp(int i) const { return Xmm(reg_inp_ind(i)); }

    Zmm zmm_inp(int i) const { return Zmm(reg_inp_ind(i)); }

    // Row-pass temporaries occupy xmm15 downwards.
    Xmm vreg_tmp(int i) const {
        assert(i < jcp.alpha * jcp.alpha);
        return Xmm(15 - i);
    }
    // Outputs reuse the input registers (same 31 - i indices); by the time
    // an output is written, the inputs have been folded into vreg_tmp.
    Xmm vreg_out(int i) const {
        assert(i < jcp.alpha * jcp.alpha);
        return Xmm(31 - i);
    }

    Opmask y_mask = Opmask(1); // mask of the current tile row
    Opmask r_mask = Opmask(2); // combined row & column mask for one load
    // Per-column masks are preloaded into k3..k6.
    Opmask x_mask(int id) {
        assert(id < 4);
        return Opmask(3 + id);
    }

    Reg64 reg_ptr_src = r14;
    Reg64 reg_ptr_dst = r13;

    Reg64 reg_ptr_v_y_masks = r12;
    Reg64 reg_ptr_v_x_masks = r11;

    Reg64 reg_aux_ptr_src = r10;
    Reg64 reg_aux_ptr_dst = r9;

    Reg64 reg_ic_block = r8;

    // Index of the Winograd-domain component that is NOT shifted by -128 on
    // store (set to 5 in the constructor) — presumably the component that is
    // already non-negative there; see generate().
    int unsign_val_in_wino_domain;

    Reg64 reg_scratch_src_alpha = rdx;
    Xmm xmm_src_alpha = Xmm(0);
    Zmm zmm_src_alpha = Zmm(0);

    Reg64 reg_shift = rax;
    Xmm xmm_shift = Xmm(1);
    Xmm xmm_zero = Xmm(0); // NOTE: shares xmm0 with xmm_src_alpha

    Reg64 reg_maskx = rbx;
    Reg64 reg_masky = rsi;
    Reg64 reg_nomask = reg_maskx; // aliases reg_maskx once masks are combined
};
123 
// Emits the source-transform kernel. Flow:
//   - read call_params_t, early-exit if either mask word is all zeros,
//   - loop over ic in chunks of load_block (16) channels,
//   - per chunk: load the 4x4 tile (masked loads only when padding is
//     present), pre-scale by adj_src_scale, apply the separable 2D input
//     transform in byte arithmetic, and store the 16 resulting vectors
//     re-biased into the u8 range.
void jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::generate() {
    Label ic_block_label;
    Label end_label;
    Label mask_label;
    Label nomask_label;

    // Loads the 4x4 tile of 16 u8 channels into vreg_inp and multiplies it
    // by adj_src_scale via an fp32 round-trip (u8 -> s32 -> f32 -> s32 -> u8)
    // so subsequent byte adds/subs cannot overflow.
    auto load_src = [=](bool mask) {
        for (int y = 0; y < jcp.alpha; y++) {
            if (mask)
                kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]);
            for (int x = 0; x < jcp.alpha; x++) {
                Zmm zmm_i = zmm_inp(y * jcp.alpha + x);
                Xmm vreg_i = vreg_inp(y * jcp.alpha + x);
                // Offset of point (y, x) of the tile relative to the tile
                // origin, accounting for top/left padding (nhwc layout).
                int inp_offset = sizeof(uint8_t)
                        * ((-jcp.t_pad + y) * jcp.iw * jcp.ic
                                + (-jcp.l_pad + x) * jcp.ic);
                if (mask) {
                    // Zero-masked load: padded positions become 0 (T_z).
                    kandw(r_mask, y_mask, x_mask(x));
                    vmovdqu8(vreg_i | r_mask | T_z,
                            EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
                } else {
                    vmovdqu8(vreg_i,
                            EVEX_compress_addr(reg_aux_ptr_src, inp_offset));
                }
                vpmovzxbd(zmm_i, vreg_i); // to int32
                vcvtdq2ps(zmm_i, zmm_i); // to fp32
                vmulps(zmm_i, zmm_i, zmm_src_alpha); // *alpha
                vcvtps2dq(zmm_i, zmm_i); // to int32
                vpmovusdb(vreg_i, zmm_i); // to u8
            }
        }
    };

    preamble();

#define READ_PARAM(reg, field) \
    mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
    READ_PARAM(reg_ptr_src, src);
    READ_PARAM(reg_ptr_dst, wino_src);
    READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
    READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
#undef READ_PARAM

    mov(reg_maskx, ptr[reg_ptr_v_x_masks]);
    mov(reg_masky, ptr[reg_ptr_v_y_masks]);
    test(reg_maskx, reg_maskx);
    jz(end_label, T_NEAR); // skip kernel if x mask is all 0's
    test(reg_masky, reg_masky);
    jz(end_label, T_NEAR); // skip kernel if y mask is all 0's
    and_(reg_maskx, reg_masky);
    mov(reg_nomask, reg_maskx);
    not_(reg_nomask); // zero if x and y masks are all 1's

    // Byte constant -128 used to shift the transformed values into u8 range.
    xor_(reg_shift, reg_shift);
    mov(reg_shift.cvt8(), (int8_t)-128);

    mov(reg_aux_ptr_src, reg_ptr_src);
    mov(reg_aux_ptr_dst, reg_ptr_dst);

    // Preload the per-column masks into k3..k6.
    for (int i = 0; i < jcp.alpha; i++) {
        kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
    }

    mov(reg_scratch_src_alpha, float2int(adj_src_scale));

    // Main loop over input-channel blocks of load_block (16) channels.
    mov(reg_ic_block, jcp.ic / load_block);
    L(ic_block_label);
    {
        // Re-broadcast alpha each iteration: zmm0 is clobbered below
        // (xmm_zero aliases it).
        vmovq(xmm_src_alpha, reg_scratch_src_alpha);
        vbroadcastss(zmm_src_alpha, xmm_src_alpha);

        // Use the cheaper unmasked loads when no padding is involved.
        test(reg_nomask, reg_nomask);
        jz(nomask_label, T_NEAR);
        load_src(true);
        jmp(mask_label, T_NEAR);
        L(nomask_label);
        load_src(false);
        L(mask_label);

        // Row pass of the input transform: inputs -> vreg_tmp.
        for (int y = 0; y < 4; y++) {
            vpsubb(vreg_tmp(y * 4 + 0), vreg_inp(y * 4 + 0),
                    vreg_inp(y * 4 + 2));
            vpaddb(vreg_tmp(y * 4 + 1), vreg_inp(y * 4 + 1),
                    vreg_inp(y * 4 + 2));
            vpsubb(vreg_tmp(y * 4 + 2), vreg_inp(y * 4 + 2),
                    vreg_inp(y * 4 + 1));
            vpsubb(vreg_tmp(y * 4 + 3), vreg_inp(y * 4 + 1),
                    vreg_inp(y * 4 + 3));
        }
        // Column pass: vreg_tmp -> vreg_out (which aliases vreg_inp).
        for (int x = 0; x < 4; x++) {
            vpsubb(vreg_out(x + 0 * 4), vreg_tmp(x + 4 * 0),
                    vreg_tmp(x + 4 * 2));
            vpaddb(vreg_out(x + 1 * 4), vreg_tmp(x + 4 * 1),
                    vreg_tmp(x + 4 * 2));
            vpsubb(vreg_out(x + 2 * 4), vreg_tmp(x + 4 * 2),
                    vreg_tmp(x + 4 * 1));
            vpsubb(vreg_out(x + 3 * 4), vreg_tmp(x + 4 * 1),
                    vreg_tmp(x + 4 * 3));
        }

        // Broadcast the -128 shift into all 16 bytes of xmm_shift.
        vmovd(xmm_shift, reg_shift.cvt32());
        vpxor(xmm_zero, xmm_zero, xmm_zero);
        vpshufb(xmm_shift, xmm_shift, xmm_zero);

        for (int i = 0; i < 16; i++) {
            int out_offset = sizeof(uint8_t) * (jcp.inp_stride * i);
            // Subtracting -128 (i.e. adding 128) re-biases the signed bytes
            // into the u8 domain; component unsign_val_in_wino_domain is
            // stored as is (per its name it is already unsigned there).
            if (i != unsign_val_in_wino_domain)
                vpsubb(vreg_out(i), vreg_out(i), Xmm(1));
            vmovups(EVEX_compress_addr(reg_aux_ptr_dst, out_offset),
                    vreg_out(i));
        }

        add(reg_aux_ptr_src, sizeof(uint8_t) * load_block);
        add(reg_aux_ptr_dst, sizeof(uint8_t) * load_block);
    }
    dec(reg_ic_block);
    jnz(ic_block_label, T_NEAR);

    L(end_label);
    postamble();
}
245 
246 /// DST TRANSFORMS /////////////////////////////////////////////////////////////
// JIT kernel performing the output (inverse) Winograd transform: it reads
// 16 int32 Winograd-domain vectors per 16-channel output block, reduces
// them to a jcp.m x jcp.m spatial tile, then applies bias, output scales,
// post-ops (relu / sum), saturation, and stores in the destination type.
struct jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(
            jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t)

    jit_conv_conf_2x3_wino_t jcp;
    const primitive_attr_t &attr_;

    // Run-time arguments of the generated kernel (read via abi_param1).
    struct call_params_t {
        const void *wino_dst; // int32 GEMM output in the Winograd domain
        const void *dst; // final destination tensor
        const void *v_y_masks; // per-row 16-bit masks for tile borders
        const void *v_x_masks; // per-column 16-bit masks for tile borders

        const void *bias;
        const void *scales;
    };

    jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
            jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
        : jcp(ajcp), attr_(attr) {}

    void generate() override;
    // Returns true if relu must be applied before (position == 0) or after
    // (position == 1) the sum post-op; see the definition below.
    bool maybe_relu(int position);

    // Register budget, counted down from zmm31:
    // 16 inputs + 8 staging + 4 outputs + 2 temporaries = 30 registers.
    Zmm vreg_inp(int i) const { // 16
        assert(i < jcp.alpha * jcp.alpha);
        return Zmm(31 - i);
    }
    Zmm vreg_stg(int id) const { // 8
        const int id_reg_stg = jcp.alpha * jcp.alpha + id;
        assert(id < 8);
        return Zmm(31 - id_reg_stg);
    }
    Zmm vreg_out(int id) const { // 4
        const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
        assert(id < 4);
        return Zmm(31 - id_reg_out);
    }
    // Xmm view of vreg_out(id), used for the narrowing s8/u8 stores.
    Xmm xmm_out(int id) const { // 4
        const int id_reg_out = jcp.alpha * jcp.alpha + 8 + id;
        assert(id < 4);
        return Xmm(31 - id_reg_out);
    }
    Zmm vreg_tmp(int id) const { // 2
        const int id_reg_tmp = jcp.alpha * jcp.alpha + 12 + id;
        assert(id < 2);
        return Zmm(31 - id_reg_tmp);
    }

    Zmm vreg_zero = Zmm(0);
    Zmm vreg_bias = Zmm(1);
    Zmm vreg_prev_dst = Zmm(2); // NOTE: shares zmm2 with zmm_bias_alpha
    Zmm zmm_bias_alpha = Zmm(2);
    Xmm xmm_bias_alpha = Xmm(2);
    Zmm vreg_saturation_ubound = Zmm(3);
    Zmm vreg_sum_zp = Zmm(4);

    Opmask y_mask = Opmask(1); // mask of the current output row
    Opmask r_mask = Opmask(2); // combined row & column mask for one store
    // Per-column masks are preloaded into k3..k6.
    Opmask x_mask(int id) {
        assert(id < 4);
        return Opmask(3 + id);
    }

    Reg64 reg_scratch_bias_alpha = r15;

    Reg64 reg_ptr_src = r14;
    Reg64 reg_ptr_dst = r13;

    Reg64 reg_ptr_v_y_masks = r12;
    Reg64 reg_ptr_v_x_masks = r11;

    Reg64 reg_aux_ptr_src = r10;
    Reg64 reg_aux_ptr_dst = r9;

    Reg64 reg_oc_block = r8;

    Reg64 reg_ptr_bias = rbx;
    Reg64 reg_ptr_scales = abi_not_param1;
    Reg64 reg_ptr_sum_scale = rdx; // NOTE: shares rdx with reg_ptr_sum_zp
    Reg64 reg_ptr_sum_zp = rdx;
    Reg64 reg_ptr_saturation_ubound = rax;
};
330 
maybe_relu(int position)331 bool jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::maybe_relu(int position) {
332     using namespace primitive_kind;
333     const auto &p = attr_.post_ops_;
334 
335     if (position == 0) {
336         /* relu before sum */
337         return false || p.contain(eltwise, 0)
338                 || (jcp.dst_dt == data_type::u8 && !p.contain(sum, 0));
339     } else if (position == 1) {
340         /* relu after sum */
341         const int sum_idx
342                 = p.contain(sum, 0) ? 0 : (p.contain(sum, 1) ? 1 : -1);
343         if (sum_idx == -1) return false;
344 
345         return false || p.contain(eltwise, sum_idx + 1)
346                 || jcp.dst_dt == data_type::u8;
347     }
348 
349     return false;
350 }
351 
generate()352 void jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::generate() {
353     Label oc_block_label;
354 
355     auto loop_body = [=]() {
356         const auto &p = attr_.post_ops_;
357         const int sum_idx = p.find(primitive_kind::sum);
358         const float *p_sum_scale
359                 = (sum_idx != -1) ? &p.entry_[sum_idx].sum.scale : nullptr;
360         const int32_t *p_sum_zp
361                 = (sum_idx != -1) ? &p.entry_[sum_idx].sum.zero_point : nullptr;
362         if (p_sum_scale) {
363             if (*p_sum_zp != 0) {
364                 mov(reg_ptr_sum_zp, reinterpret_cast<size_t>(p_sum_zp));
365                 vcvtdq2ps(vreg_sum_zp, ptr_b[reg_ptr_sum_zp]);
366             }
367             if (*p_sum_scale != 1.f)
368                 mov(reg_ptr_sum_scale, reinterpret_cast<size_t>(p_sum_scale));
369         }
370         for (int i = 0; i < 16; i++) {
371             int internal_offset = sizeof(int32_t) * jcp.out_stride * i;
372             vmovups(vreg_inp(i),
373                     EVEX_compress_addr(reg_aux_ptr_src, internal_offset));
374         }
375         for (int y = 0; y < jcp.alpha; y++) {
376             vpaddd(vreg_tmp(0), vreg_inp(y * 4 + 0), vreg_inp(y * 4 + 1));
377             vpaddd(vreg_stg(y * 2), vreg_tmp(0), vreg_inp(y * 4 + 2));
378 
379             vpsubd(vreg_tmp(1), vreg_inp(y * 4 + 1), vreg_inp(y * 4 + 2));
380             vpsubd(vreg_stg(y * 2 + 1), vreg_tmp(1), vreg_inp(y * 4 + 3));
381         }
382         for (int x = 0; x < jcp.m; x++) {
383             vpaddd(vreg_tmp(0), vreg_stg(x), vreg_stg(x + 2 * 1));
384             vpaddd(vreg_out(x), vreg_tmp(0), vreg_stg(x + 2 * 2));
385 
386             vpsubd(vreg_tmp(1), vreg_stg(x + 2 * 1), vreg_stg(x + 2 * 2));
387             vpsubd(vreg_out(x + 2), vreg_tmp(1), vreg_stg(x + 2 * 3));
388         }
389 
390         if (jcp.with_bias) {
391             vmovq(xmm_bias_alpha, reg_scratch_bias_alpha);
392             vbroadcastss(zmm_bias_alpha, xmm_bias_alpha);
393 
394             auto bias_addr = ptr[reg_ptr_bias];
395             switch (jcp.bia_dt) {
396                 case data_type::f32:
397                 case data_type::s32: vmovups(vreg_bias, bias_addr); break;
398                 case data_type::s8: vpmovsxbd(vreg_bias, bias_addr); break;
399                 case data_type::u8: vpmovzxbd(vreg_bias, bias_addr); break;
400                 default: assert(!"unsupported dst data type");
401             }
402             if (jcp.bia_dt != data_type::f32) vcvtdq2ps(vreg_bias, vreg_bias);
403             vmulps(vreg_bias, vreg_bias, zmm_bias_alpha); // *alpha
404         }
405 
406         auto sum_dt = p.get_sum_dt(jcp.dst_dt);
407 
408         init_saturate_f32(vreg_zero, vreg_saturation_ubound,
409                 reg_ptr_saturation_ubound, f32, jcp.dst_dt);
410         for (int y = 0; y < jcp.m; y++) {
411             kmovw(y_mask, ptr[reg_ptr_v_y_masks + sizeof(uint16_t) * y]);
412             for (int x = 0; x < jcp.m; x++) {
413                 kandw(r_mask, y_mask, x_mask(x));
414 
415                 int i = y * jcp.m + x;
416                 int offset
417                         = jcp.typesize_out * (y * jcp.ow * jcp.oc + x * jcp.oc);
418                 Address addr = EVEX_compress_addr(reg_aux_ptr_dst, offset);
419 
420                 Zmm zmm = vreg_out(i);
421                 Xmm xmm = xmm_out(i);
422                 vcvtdq2ps(zmm, zmm);
423                 if (jcp.with_bias) vaddps(zmm, zmm, vreg_bias);
424                 vmulps(zmm, zmm, ptr[reg_ptr_scales]);
425                 if (maybe_relu(0)) vmaxps(zmm, vreg_zero, zmm);
426                 if (p_sum_scale) { // post_op: sum
427                     vpxord(vreg_prev_dst, vreg_prev_dst, vreg_prev_dst);
428                     switch (sum_dt) {
429                         case data_type::f32:
430                         case data_type::s32:
431                             vmovups(vreg_prev_dst | r_mask, addr);
432                             break;
433                         case data_type::s8:
434                             vpmovsxbd(vreg_prev_dst | r_mask, addr);
435                             break;
436                         case data_type::u8:
437                             vpmovzxbd(vreg_prev_dst | r_mask, addr);
438                             break;
439                         default: assert(!"unknown sum_dt");
440                     }
441                     if (sum_dt != data_type::f32)
442                         vcvtdq2ps(vreg_prev_dst, vreg_prev_dst);
443                     if (*p_sum_zp != 0) vsubps(vreg_prev_dst, vreg_sum_zp);
444                     if (*p_sum_scale == 1.f)
445                         vaddps(zmm, vreg_prev_dst);
446                     else
447                         vfmadd231ps(
448                                 zmm, vreg_prev_dst, zword_b[reg_ptr_sum_scale]);
449                 }
450                 // we skip max if dst_dt == u8 as it will happen in saturation
451                 if (maybe_relu(1) && (jcp.dst_dt != u8))
452                     vmaxps(zmm, vreg_zero, zmm);
453 
454                 if (utils::one_of(jcp.dst_dt, u8, s8, s32)) {
455                     saturate_f32(
456                             zmm, vreg_zero, vreg_saturation_ubound, jcp.dst_dt);
457                     vcvtps2dq(zmm, zmm);
458                 }
459                 switch (jcp.dst_dt) {
460                     case data_type::f32:
461                     case data_type::s32: vmovups(addr, zmm | r_mask); break;
462                     case data_type::s8:
463                         vpmovsdb(xmm, zmm);
464                         vmovups(addr, xmm | r_mask);
465                         break;
466                     case data_type::u8:
467                         vpmovusdb(xmm, zmm);
468                         vmovups(addr, xmm | r_mask);
469                         break;
470                     default: assert(!"unknown dst_dt");
471                 }
472             }
473         }
474     };
475 
476     preamble();
477 
478 #define READ_PARAM(reg, field) \
479     mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
480     READ_PARAM(reg_ptr_src, wino_dst);
481     READ_PARAM(reg_ptr_dst, dst);
482     READ_PARAM(reg_ptr_v_y_masks, v_y_masks);
483     READ_PARAM(reg_ptr_v_x_masks, v_x_masks);
484     READ_PARAM(reg_ptr_bias, bias);
485     READ_PARAM(reg_ptr_scales, scales);
486 #undef READ_PARAM
487 
488     if (jcp.with_bias)
489         mov(reg_scratch_bias_alpha, float2int(adj_src_scale * adj_wei_scale));
490 
491     mov(reg_aux_ptr_src, reg_ptr_src);
492     mov(reg_aux_ptr_dst, reg_ptr_dst);
493 
494     vpxord(vreg_zero, vreg_zero, vreg_zero);
495 
496     for (int i = 0; i < jcp.m; i++)
497         kmovw(x_mask(i), ptr[reg_ptr_v_x_masks + sizeof(uint16_t) * i]);
498 
499     int oc_blocks = jcp.oc / load_block;
500     mov(reg_oc_block, oc_blocks);
501     L(oc_block_label);
502     {
503         loop_body();
504         add(reg_aux_ptr_src, sizeof(int32_t) * load_block);
505         add(reg_aux_ptr_dst, jcp.typesize_out * load_block);
506 
507         add(reg_ptr_scales, jcp.is_oc_scale * sizeof(float) * load_block);
508         add(reg_ptr_bias, sizeof(jcp.typesize_bia) * load_block);
509     }
510     dec(reg_oc_block);
511     jnz(oc_block_label, T_NEAR);
512 
513     postamble();
514 }
515 
516 /// GEMM kernel ////////////////////////////////////////////////////////////////
// JIT GEMM-style kernel multiplying transformed u8 source by transformed s8
// weights in the Winograd domain, accumulating into int32. Uses vpdpbusd on
// VNNI hardware and a vpmaddubsw + vpmaddwd + vpaddd emulation otherwise.
struct jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t : public jit_generator {
    DECLARE_CPU_JIT_AUX_FUNCTIONS(jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t)
    jit_conv_conf_2x3_wino_t jcp;
    const primitive_attr_t &attr_;

    // Run-time arguments of the generated kernel (read via abi_param1).
    struct call_params_t {
        const void *src; // transformed source (u8)
        const void *dst; // int32 accumulator output
        const void *wei; // transformed weights (s8)
        const void *dst_b; // initial accumulator values per output block
    };

    void generate() override;
    // Checks that the requested post-ops chain (relu/sum combinations) is
    // supported by this implementation.
    static bool post_ops_ok(
            jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr);

    jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
            jit_conv_conf_2x3_wino_t ajcp, const primitive_attr_t &attr)
        : jcp(ajcp), attr_(attr) {}

    static status_t init_conf(jit_conv_conf_2x3_wino_t &jcp,
            const convolution_desc_t &cd, memory_desc_t &src_md,
            memory_desc_t &weights_md, memory_desc_t &dst_md,
            memory_desc_t &bias_md, const primitive_attr_t &attr);

    // Accumulator tile of n2_block x m_block registers, counted down
    // from zmm31.
    Zmm vreg_out(int n, int m) const {
        const int id_reg_out = n * jcp.m_block + m;
        assert(id_reg_out < jcp.n2_block * jcp.m_block);
        return Zmm(31 - id_reg_out);
    }
    // Weight registers follow the accumulator tile; the assert keeps them
    // clear of the fixed scratch registers (zmm0 for bcast, plus zmm1/zmm2
    // on the non-vnni path).
    Zmm vreg_wei(int i) const {
        assert(31 - jcp.n2_block * jcp.m_block - i
                > (jcp.ver == ver_vnni ? 0 : 2));
        return Zmm(31 - jcp.n2_block * jcp.m_block - i);
    }

    Zmm vreg_src = Zmm(0); // broadcasted source dword
    Zmm vreg_one = Zmm(1); // 0x0001 words for vpmaddwd (non-vnni path only)
    Zmm vreg_tmp = Zmm(2); // scratch for the non-vnni multiply-add

    Reg64 reg_ptr_src = r15;

    Reg64 reg_aux_dst_b = r13;
    Reg64 reg_aux_dst = r12;
    Reg64 reg_aux_dst2 = r11;
    Reg64 reg_aux_wei = r10;
    Reg64 reg_aux_wei2 = r9;
    Reg64 reg_aux_src = r8;
    Reg64 reg_aux_src2 = rax;
    Reg64 reg_mb = rbx;
    Reg64 reg_nnb = abi_not_param1;
    Reg64 reg_scratch = rdx;
    Reg64 reg_K = rsi;
};
571 
post_ops_ok(jit_conv_conf_2x3_wino_t & jcp,const primitive_attr_t & attr)572 bool jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::post_ops_ok(
573         jit_conv_conf_2x3_wino_t &jcp, const primitive_attr_t &attr) {
574     using namespace primitive_kind;
575     const auto &p = attr.post_ops_;
576 
577     auto is_relu = [&](int idx) { return p.entry_[idx].is_relu(); };
578 
579     switch (p.len()) {
580         case 0: return true;
581         case 1: return is_relu(0) || p.contain(sum, 0);
582         case 2:
583             return (p.contain(sum, 0) && is_relu(1))
584                     || (p.contain(sum, 1) && is_relu(0));
585         case 3: return is_relu(0) && p.contain(sum, 1) && is_relu(2);
586         default: return false;
587     }
588 
589     return false;
590 }
591 
// Emits the Winograd-domain GEMM kernel. Loop nest:
//   [n-chunk loop, only when !jcp.small_mb] -> M-block loop -> K loop.
// Accumulators are seeded from dst_b and the finished int32 tile is
// written back to dst.
void jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::generate() {
    Label nnb_loop_label, K_loop_label, mb_loop_label;

    // One u8 x s8 -> s32 multiply-accumulate over 4 interleaved channels.
    auto compute = [=](Zmm vreg_acc, Zmm vreg_wei, Zmm vreg_src) {
        if (jcp.ver == ver_vnni) {
            vpdpbusd(vreg_acc, vreg_src, vreg_wei);
        } else {
            // Emulate vpdpbusd: u8*s8 -> s16 pairs, then s16*1 -> s32 sums.
            vpmaddubsw(vreg_tmp, vreg_src, vreg_wei);
            vpmaddwd(vreg_tmp, vreg_tmp, vreg_one);
            vpaddd(vreg_acc, vreg_acc, vreg_tmp);
        }
    };

    preamble();
#define READ_PARAM(reg, field) \
    mov(reg, ptr[abi_param1 + offsetof(call_params_t, field)])
    READ_PARAM(reg_ptr_src, src);
    READ_PARAM(reg_aux_dst, dst);
    READ_PARAM(reg_aux_wei, wei);
    READ_PARAM(reg_aux_dst_b, dst_b);
#undef READ_PARAM

    // Non-vnni path needs a vector of 16-bit ones for vpmaddwd.
    if (jcp.ver != ver_vnni) {
        xor_(reg_scratch, reg_scratch);
        Reg16 _t = reg_scratch.cvt16();
        mov(_t, 0x1);
        vpbroadcastw(vreg_one, _t);
    }

    // Outer loop over n-chunks exists only for the large-batch strategy.
    if (!jcp.small_mb) {
        mov(reg_nnb, jcp.n_chunks);
        L(nnb_loop_label);
    }
    mov(reg_aux_dst2, reg_aux_dst);
    mov(reg_aux_src, reg_ptr_src);
    mov(reg_mb, jcp.M / jcp.m_block);
    L(mb_loop_label);
    {
        // Seed the accumulator tile from dst_b.
        for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
            for (int m = 0; m < jcp.m_block; m++) {
                int offset = jcp.typesize_acc * nb2 * jcp.n_block;
                vmovups(vreg_out(nb2, m),
                        EVEX_compress_addr(reg_aux_dst_b, offset));
            }
        }
        mov(reg_aux_src2, reg_aux_src);
        mov(reg_aux_wei2, reg_aux_wei);
        mov(reg_K, jcp.k_chunks);
        L(K_loop_label);
        {
            // Reduction over K in steps of 4 bytes (one s32 lane per step).
            for (int k = 0; k < jcp.k2_block; k += 4) {
                // Load one weight vector per n-sub-block.
                for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
                    int wei_offset
                            = jcp.typesize_in * (nb2 * jcp.n_block * jcp.K);
                    vmovups(vreg_wei(nb2),
                            EVEX_compress_addr(reg_aux_wei2, wei_offset));
                }
                // Broadcast 4 source bytes per m and accumulate.
                for (int m = 0; m < jcp.m_block; m++) {
                    int inp_offset = jcp.typesize_in * m * jcp.K;
                    vpbroadcastd(vreg_src,
                            EVEX_compress_addr(reg_aux_src2, inp_offset));
                    for (int nb2 = 0; nb2 < jcp.n2_block; nb2++)
                        compute(vreg_out(nb2, m), vreg_wei(nb2), vreg_src);
                }
                add(reg_aux_src2, jcp.typesize_in * 4);
                add(reg_aux_wei2, jcp.typesize_in * 4 * jcp.n_block);
            }
        }
        dec(reg_K);
        jnz(K_loop_label, T_NEAR);

        // Store the finished int32 accumulator tile.
        for (int m = 0; m < jcp.m_block; m++) {
            for (int nb2 = 0; nb2 < jcp.n2_block; nb2++) {
                int offset = jcp.typesize_acc * (m * jcp.N + nb2 * jcp.n_block);
                vmovups(EVEX_compress_addr(reg_aux_dst2, offset),
                        vreg_out(nb2, m));
            }
        }
        add(reg_aux_src, jcp.typesize_in * jcp.m_block * jcp.K);
        add(reg_aux_dst2, jcp.typesize_acc * jcp.m_block * jcp.N);
    }
    dec(reg_mb);
    jnz(mb_loop_label, T_NEAR);

    if (!jcp.small_mb) {
        // Advance to the next n-chunk of output blocks.
        add(reg_aux_dst, jcp.typesize_acc * jcp.n2_block * jcp.n_block);
        add(reg_aux_dst_b, jcp.typesize_acc * jcp.n2_block * jcp.n_block);
        add(reg_aux_wei, jcp.typesize_in * jcp.n2_block * jcp.n_block * jcp.K);

        dec(reg_nnb);
        jnz(nnb_loop_label, T_NEAR);
    }

    postamble();
}
687 namespace {
is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t & jcp)688 bool is_winograd_faster_than_direct(const jit_conv_conf_2x3_wino_t &jcp) {
689     if (jcp.ver == ver_vnni) {
690         return (jcp.mb <= jcp.nthr
691                        && (jcp.mb > 4 && jcp.ic > 64
692                                && !(jcp.oc > 128 && jcp.ih < 14)))
693                 || jcp.mb > jcp.nthr;
694     }
695     return true;
696 }
697 } // namespace
698 
init_conf(jit_conv_conf_2x3_wino_t & jcp,const convolution_desc_t & cd,memory_desc_t & src_md,memory_desc_t & wei_md,memory_desc_t & dst_md,memory_desc_t & bias_md,const primitive_attr_t & attr)699 status_t jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf(
700         jit_conv_conf_2x3_wino_t &jcp, const convolution_desc_t &cd,
701         memory_desc_t &src_md, memory_desc_t &wei_md, memory_desc_t &dst_md,
702         memory_desc_t &bias_md, const primitive_attr_t &attr) {
703     const memory_desc_wrapper src_d(&src_md);
704     const memory_desc_wrapper wei_d(&wei_md);
705     const memory_desc_wrapper dst_d(&dst_md);
706     const memory_desc_wrapper bias_d(&bias_md);
707 
708     // This kernel only supports 2D convolutions.
709     if (src_d.ndims() != 4) return status::unimplemented;
710 
711     const bool with_groups = wei_d.ndims() == src_d.ndims() + 1;
712 
713     jcp.nthr = dnnl_get_max_threads();
714 
715     jcp.ngroups = with_groups ? wei_d.dims()[0] : 1;
716     jcp.mb = src_d.dims()[0];
717     jcp.oc = dst_d.dims()[1] / jcp.ngroups;
718     jcp.ic = src_d.dims()[1] / jcp.ngroups;
719     jcp.ih = src_d.dims()[2];
720     jcp.iw = src_d.dims()[3];
721     jcp.oh = dst_d.dims()[2];
722     jcp.ow = dst_d.dims()[3];
723     jcp.kh = wei_d.dims()[with_groups + 2];
724     jcp.kw = wei_d.dims()[with_groups + 3];
725     jcp.t_pad = cd.padding[0][0];
726     jcp.l_pad = cd.padding[0][1];
727     jcp.stride_h = cd.strides[0];
728     jcp.stride_w = cd.strides[1];
729     jcp.dilate_h = cd.dilates[0];
730     jcp.dilate_w = cd.dilates[1];
731 
732     const int ext_kw = calculate_extended_filter_size(jcp.kw, jcp.dilate_w);
733     const int ext_kh = calculate_extended_filter_size(jcp.kh, jcp.dilate_h);
734     jcp.r_pad = calculate_end_padding(
735             jcp.l_pad, jcp.ow, jcp.iw, jcp.stride_w, ext_kw);
736     jcp.b_pad = calculate_end_padding(
737             jcp.t_pad, jcp.oh, jcp.ih, jcp.stride_h, ext_kh);
738 
739     jcp.ver = ver_avx512_core;
740     if (!(mayiuse(avx512_core) && src_d.data_type() == data_type::u8
741                 && wei_d.data_type() == data_type::s8
742                 && one_of(dst_d.data_type(), data_type::f32, data_type::s32,
743                         data_type::s8, data_type::u8)))
744         return status::unimplemented;
745     if (mayiuse(avx512_core_vnni)) jcp.ver = ver_vnni;
746 
747     if (!IMPLICATION(cd.alg_kind == alg_kind::convolution_auto,
748                 is_winograd_faster_than_direct(jcp)))
749         return status::unimplemented;
750 
751     // block sizes needed for GEMM kernel
752     jcp.ic_block = 4;
753     jcp.oc_block = 16;
754 
755     bool ok = true && jcp.ngroups == 1 && jcp.oc % load_block == 0
756             && jcp.ic % load_block == 0 && jcp.oc % jcp.oc_block == 0
757             && jcp.ic % jcp.ic_block == 0 && everyone_is(3, jcp.kh, jcp.kw)
758             && everyone_is(1, jcp.stride_h, jcp.stride_w)
759             && everyone_is(0, jcp.dilate_h, jcp.dilate_w)
760             && jcp.t_pad == jcp.b_pad && jcp.l_pad == jcp.r_pad
761             && one_of(jcp.t_pad, 0, 1) && one_of(jcp.l_pad, 0, 1);
762     if (!ok) return status::unimplemented;
763 
764     format_tag_t dat_tag = format_tag::nhwc;
765     if (!src_d.matches_tag(dat_tag)) return status::unimplemented;
766     if (!dst_d.matches_tag(dat_tag)) return status::unimplemented;
767 
768     jcp.with_bias = cd.bias_desc.format_kind != format_kind::undef;
769 
770     if (!post_ops_ok(jcp, attr)) return status::unimplemented;
771 
772     jcp.bia_dt = jcp.with_bias ? cd.bias_desc.data_type : data_type::undef;
773     jcp.dst_dt = cd.dst_desc.data_type;
774 
775     jcp.typesize_in = types::data_type_size(src_d.data_type());
776     jcp.typesize_out = types::data_type_size(dst_d.data_type());
777     jcp.typesize_acc = sizeof(int32_t);
778     jcp.typesize_bia
779             = jcp.with_bias ? types::data_type_size(bias_d.data_type()) : 0;
780 
781     jcp.nb_oc = jcp.oc / jcp.oc_block;
782     jcp.nb_ic = jcp.ic / jcp.ic_block;
783 
784     jcp.m = 2;
785     jcp.r = 3;
786     jcp.alpha = jcp.m + jcp.r - 1;
787 
788     int aa = jcp.alpha * jcp.alpha;
789     int L1_cap = platform::get_per_core_cache_size(1);
790     int L2_cap = platform::get_per_core_cache_size(2);
791     // need 1 extra reg for bcast, and 2 tmp regs for non-vnni
792     int free_regs = jcp.ver == ver_vnni ? 31 : 29;
793 
794     auto get_thr_eff = [&](int small_mb, int ix, int iy, int n2_b) {
795         float thr_eff;
796         float Z = (float)jcp.ic + jcp.oc;
797         float Y = (float)jcp.ic * jcp.oc;
798         if (small_mb == 0) { // outer par
799             int nblocks = jcp.mb * div_up(jcp.oh, iy) * div_up(jcp.ow, ix);
800             thr_eff = (float)nblocks / rnd_up(nblocks, jcp.nthr);
801         } else { // inner par
802             int tranw = iy * ix / jcp.alpha;
803             int gemmw = aa * (jcp.nb_oc / n2_b);
804             int tranw_r = rnd_up(tranw, jcp.nthr);
805             int gemmw_r = rnd_up(gemmw, jcp.nthr);
806             thr_eff = (Z * tranw / tranw_r + Y * gemmw / gemmw_r) / (Z + Y);
807         }
808         return thr_eff;
809     };
810 
811     auto get_mem_eff = [&](int small_mb, int ix, int iy, int n2_b) {
812         float mem_eff, req_mem;
813         int M = ix * iy / jcp.alpha;
814         if (small_mb == 0) { // outer parallelization strategy
815             // memory for wino transforms (other memory has poor reuse)
816             req_mem = (float)aa * M * (jcp.ic + jcp.typesize_acc * jcp.oc);
817             mem_eff = req_mem < L1_cap ? 1.f : req_mem < L2_cap ? 0.5f : 0.f;
818         } else { // inner parallelization strategy
819             // memory used during gemm
820             int N = jcp.oc_block * n2_b;
821             req_mem = (float)jcp.ic * (M + N) + jcp.typesize_acc * M * N;
822             mem_eff = nstl::min(1.f, L2_cap / req_mem);
823             // memory used during wino transforms
824             int M_per_thr = div_up(M, jcp.nthr);
825             req_mem = (float)aa * M_per_thr
826                     * (jcp.ic + jcp.typesize_acc * jcp.oc);
827             if (req_mem > L2_cap) mem_eff = 0.1f;
828         }
829         return mem_eff;
830     };
831 
832     auto get_tot_eff = [&](int small_mb, float thr_eff, float work_eff,
833                                float mem_eff, float reg_eff) {
834         // these coefficients are chosen empirically
835         float mem_fac = 0.1f, reg_fac = 0.2f;
836         // normalized overhead relative to memory and register components
837         float tot_eff = 1.f + mem_fac * mem_eff + reg_fac * reg_eff;
838         // thread and work components affect all others
839         tot_eff *= thr_eff * work_eff;
840         return tot_eff;
841     };
842 
843     auto find_m_n2_blocks
844             = [&](bool small_mb, int ix, int iy, float work_eff, int &m_block,
845                       int &n2_block, float &tot_eff) {
846                   int M = (ix * iy) / jcp.alpha;
847                   int max_m_block = nstl::min(M, free_regs);
848                   int max_n2_block = nstl::min(jcp.nb_oc, free_regs);
849                   tot_eff = 0.f;
850                   for (int im = max_m_block; im > 0; im--) {
851                       if (M % im) continue;
852                       for (int in2 = max_n2_block; in2 > 0; in2--) {
853                           int used_regs = (im + 1) * in2;
854                           float mem_eff = get_mem_eff(small_mb, ix, iy, in2);
855                           float reg_eff = (float)(im * in2) / (im + in2);
856                           float thr_eff = get_thr_eff(small_mb, ix, iy, in2);
857                           float cur_tot_eff = get_tot_eff(small_mb, thr_eff,
858                                   work_eff, mem_eff, reg_eff);
859                           if (jcp.nb_oc % in2 || used_regs > free_regs
860                                   || cur_tot_eff <= tot_eff)
861                               continue;
862                           tot_eff = cur_tot_eff;
863                           m_block = im;
864                           n2_block = in2;
865                       }
866                   }
867               };
868 
869     /* Selecting xb and yb blocking */
870     int min_yb = jcp.m;
871     int min_xb = jcp.m;
872     int max_yb = nstl::max(min_yb, rnd_up(jcp.oh, 2));
873     int max_xb = nstl::max(min_xb, rnd_up(jcp.ow, 2));
874     float best_eff = 0.f;
875     for (int ix = min_xb; ix <= max_xb; ix += 2) {
876         assert(rnd_up(jcp.ow, ix) >= jcp.iw - 2);
877         for (int iy = max_yb; iy >= min_yb; iy -= 2) {
878             assert(rnd_up(jcp.oh, iy) >= jcp.ih - 2);
879 
880             int m_b[2];
881             int n2_b[2];
882             bool small_mb;
883             float inner_eff, outer_eff, work_eff;
884 
885             int tiled_area = rnd_up(jcp.oh, iy) * rnd_up(jcp.ow, ix);
886             work_eff = (float)jcp.oh * jcp.ow / tiled_area;
887             if (best_eff > 0.f && work_eff < 4.f / 9.f)
888                 continue; // no gain from Winograd transformation
889 
890             /* outer parallelization */
891             find_m_n2_blocks(
892                     false, ix, iy, work_eff, m_b[0], n2_b[0], outer_eff);
893 
894             /* inner parallelization */
895             find_m_n2_blocks(
896                     true, ix, iy, work_eff, m_b[1], n2_b[1], inner_eff);
897 
898             small_mb = inner_eff > outer_eff;
899             float eff = small_mb ? inner_eff : outer_eff;
900             if (eff > best_eff) {
901                 best_eff = eff;
902                 jcp.yb = iy;
903                 jcp.xb = ix;
904                 jcp.m_block = m_b[small_mb];
905                 jcp.n2_block = n2_b[small_mb];
906                 jcp.small_mb = small_mb;
907             }
908         }
909     }
910 
911     assert((jcp.m_block + 1) * jcp.n2_block <= free_regs);
912     assert(jcp.xb % 2 == 0 && jcp.yb % 2 == 0);
913 
914     jcp.mb_block = 1;
915     if (jcp.small_mb) {
916         // For small mb harness, set mb_block as large as possible subject to
917         // the constraint that winograd activations fit into available L3 cache
918         int L3_cap = platform::get_per_core_cache_size(3);
919         int M = jcp.xb * jcp.yb / 4;
920         int wino_src_size = 16 * M * jcp.ic * jcp.typesize_in;
921         int wino_dst_size = 16 * M * jcp.oc * jcp.typesize_acc;
922         int max_mb_block = nstl::min(
923                 jcp.mb, jcp.nthr * L3_cap / (wino_src_size + wino_dst_size));
924         for (int i = max_mb_block; i > 1; i--) {
925             if (jcp.mb % i == 0) {
926                 jcp.mb_block = i;
927                 break;
928             }
929         }
930     }
931     jcp.nb_mb = jcp.mb / jcp.mb_block;
932 
933     jcp.M = jcp.mb_block * jcp.xb * jcp.yb / 4;
934     jcp.N = jcp.oc;
935     jcp.K = jcp.ic;
936 
937     jcp.inp_stride = jcp.M * jcp.ic;
938     jcp.out_stride = jcp.M * jcp.oc;
939     jcp.wei_stride = jcp.ic * jcp.oc;
940     jcp.bia_stride = jcp.oc;
941 
942     jcp.n_block = jcp.oc_block;
943     jcp.k_block = jcp.ic_block;
944 
945     jcp.n_chunks = (jcp.N / jcp.n_block) / jcp.n2_block;
946 
947     // We need jcp.k2_block to be a multiple of jcp.k_block = jcp.ic_block = 4
948     // and jcp.K = jcp.ic to be a multiple of jcp.k2_block. Since jcp.ic is
949     // a multiple of load_block = 16, we just use that for now.
950     jcp.k2_block = load_block;
951     jcp.k_chunks = jcp.K / jcp.k2_block;
952 
953     const auto &oscales = attr.output_scales_;
954     jcp.is_oc_scale = oscales.mask_ == 1 << 1;
955 
956     // only common and per-oc-channel scales are supported
957     const bool oscales_ok = one_of(oscales.mask_, 0, 1 << 1);
958     if (!oscales_ok) return status::unimplemented;
959 
960     /* re-create weights primitive descriptor
961                                     and set weights wino_blocking */
962     memory_desc_t expect_wei_md = wei_md;
963 
964     expect_wei_md.format_kind = format_kind::wino;
965     expect_wei_md.data_type = data_type::s8;
966     dnnl_wino_desc_t &wd = expect_wei_md.format_desc.wino_desc;
967     wd.wino_format = dnnl_wino_wei_aaOIoi;
968     wd.r = jcp.r;
969     wd.alpha = jcp.alpha;
970     wd.ic = jcp.ic;
971     wd.oc = jcp.oc;
972     wd.ic_block = jcp.ic_block;
973     wd.oc_block = jcp.oc_block;
974     wd.oc2_block = jcp.n2_block;
975     wd.ic2_block = 1;
976     wd.adj_scale = adj_wei_scale;
977 
978     size_t max_size = types::data_type_size(data_type::s8) * jcp.alpha
979             * jcp.alpha * jcp.ic * jcp.oc;
980     max_size += types::data_type_size(data_type::s32) * jcp.alpha * jcp.alpha
981             * jcp.oc;
982     wd.size = max_size;
983 
984     if (wei_md.format_kind == format_kind::any) wei_md = expect_wei_md;
985     if (wei_md != expect_wei_md) return status::unimplemented;
986 
987     const int tilesize = jcp.alpha * jcp.alpha;
988     const int numtiles = jcp.M;
989     const int alltiles = numtiles * tilesize;
990 
991     jcp.size_wino_src
992             = utils::rnd_up(jcp.typesize_in * alltiles * jcp.ic, PAGE_4K)
993             / jcp.typesize_in;
994     jcp.size_wino_wei = tilesize * jcp.oc * jcp.ic;
995     jcp.size_wino_dst = alltiles * jcp.oc;
996 
997     return status::success;
998 }
999 ////////////////////////////////////////////////////////////////////////////////
1000 
jit_conf()1001 status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::pd_t::jit_conf() {
1002     return jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::init_conf(jcp_,
1003             *this->desc(), this->src_md_, this->weights_md_, this->dst_md_,
1004             this->bias_md_, *this->attr());
1005 }
1006 
init_scratchpad()1007 void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::pd_t::init_scratchpad() {
1008     auto scratchpad = this->scratchpad_registry().registrar();
1009 
1010     int nthr_multiplier = jcp_.small_mb ? 1 : jcp_.nthr;
1011     scratchpad.template book<src_data_t>(
1012             key_wino_V, jcp_.size_wino_src * nthr_multiplier, PAGE_4K);
1013     scratchpad.template book<acc_data_t>(
1014             key_wino_M, jcp_.size_wino_dst * nthr_multiplier, PAGE_4K);
1015 
1016     dim_t scale_count = attr()->output_scales_.count_;
1017     scratchpad.template book<float>(
1018             key_conv_adjusted_scales, nstl::max<dim_t>(scale_count, 16));
1019 }
1020 
// Trivial constructor: only stores the primitive descriptor; the JIT
// kernels are created later in init().
jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::
        jit_avx512_core_u8s8s32x_wino_convolution_fwd_t(const pd_t *apd)
    : primitive_t(apd) {}
1024 
// Instantiates and JIT-compiles the three generators this primitive uses:
// the fused GEMM kernel plus the source and destination winograd
// transforms. Any allocation or code-generation failure is propagated to
// the caller via CHECK's early return.
status_t jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::init(
        engine_t *engine) {
    CHECK(safe_ptr_assign(kernel_,
            new jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t(
                    pd()->jcp_, *pd()->attr())));
    CHECK(safe_ptr_assign(src_trans_,
            new jit_avx512_core_u8s8s32x_wino_conv_src_trans_t(
                    pd()->jcp_, *pd()->attr())));
    CHECK(safe_ptr_assign(dst_trans_,
            new jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t(
                    pd()->jcp_, *pd()->attr())));
    // Generate the machine code for each kernel only after all three
    // objects were successfully constructed.
    CHECK(kernel_->create_kernel());
    CHECK(src_trans_->create_kernel());
    CHECK(dst_trans_->create_kernel());
    return status::success;
}
1041 
// Defaulted destructor; defined out-of-line so the unique_ptr members'
// complete types are visible at destruction time.
jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::
        ~jit_avx512_core_u8s8s32x_wino_convolution_fwd_t()
        = default;
1045 
adjust_oscales(const memory_tracking::grantor_t & scratchpad) const1046 const float *jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::adjust_oscales(
1047         const memory_tracking::grantor_t &scratchpad) const {
1048     const float *oscales = pd()->attr()->output_scales_.scales_;
1049     auto loc_scales = scratchpad.template get<float>(key_conv_adjusted_scales);
1050     size_t count = pd()->attr()->output_scales_.count_;
1051     float factor = 1.f / (adj_src_scale * adj_wei_scale);
1052     if (count == 1)
1053         utils::array_set(loc_scales, oscales[0] * factor, 16);
1054     else
1055         for (size_t c = 0; c < count; c++)
1056             loc_scales[c] = oscales[c] * factor;
1057     return loc_scales;
1058 }
1059 
execute_forward(const exec_ctx_t & ctx) const1060 void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t ::execute_forward(
1061         const exec_ctx_t &ctx) const {
1062     auto src = CTX_IN_MEM(const src_data_t *, DNNL_ARG_SRC);
1063     auto weights = CTX_IN_MEM(const wei_data_t *, DNNL_ARG_WEIGHTS);
1064     auto bias = CTX_IN_MEM(const char *, DNNL_ARG_BIAS);
1065     auto dst = CTX_OUT_MEM(char *, DNNL_ARG_DST);
1066 
1067     const auto &jcp = kernel_->jcp;
1068     if (jcp.small_mb)
1069         execute_forward_small_mb(
1070                 src, weights, bias, dst, ctx.get_scratchpad_grantor());
1071     else
1072         execute_forward_mbN(
1073                 src, weights, bias, dst, ctx.get_scratchpad_grantor());
1074 }
1075 
// Outer-parallel ("mbN") harness: the (mb, oh-block, ow-block) tile grid
// is distributed across threads, and each thread runs the whole pipeline
// -- source transform, the 16 per-component GEMMs, destination transform
// -- on its own private slice of the winograd scratch buffers.
void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::execute_forward_mbN(
        const src_data_t *src, const wei_data_t *wei, const char *bia,
        char *dst, const memory_tracking::grantor_t &scratchpad) const {
    const auto &jcp = kernel_->jcp;
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
    // Output scales with the src/wei transform pre-scaling factored out.
    const float *oscales = adjust_oscales(scratchpad);

    // The s32 bias compensation terms are laid out directly after the
    // transformed weights in the weights blob (see wd.size in init_conf).
    auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
    auto wino_src_base = scratchpad.template get<src_data_t>(key_wino_V);
    auto wino_dst_base = scratchpad.template get<acc_data_t>(key_wino_M);

    parallel_nd_ext(jcp.nthr, jcp.mb, div_up(jcp.oh, jcp.yb),
            div_up(jcp.ow, jcp.xb),
            [&](dim_t ithr, dim_t nthr, dim_t mb, dim_t tile_y_b,
                    dim_t tile_x_b) {
                assert(nthr <= jcp.nthr);
                MAYBE_UNUSED(nthr);

                // Top-left output coordinate of this (yb x xb) tile block.
                int tile_y = tile_y_b * jcp.yb;
                int tile_x = tile_x_b * jcp.xb;

                // Each thread works in its own slice of the scratch
                // buffers (booked with an nthr multiplier for this path).
                auto wino_src = wino_src_base + jcp.size_wino_src * ithr;
                auto wino_dst = wino_dst_base + jcp.size_wino_dst * ithr;

                auto src_trans_p
                        = jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::
                                call_params_t();
                auto dst_trans_p
                        = jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::
                                call_params_t();
                auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::
                        call_params_t();

                /* transformation of input tensor to winograd domain */
                for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
                    for (int x_in_block = 0; x_in_block < jcp.xb;
                            x_in_block += 2) {
                        uint16_t v_y_masks[4], v_x_masks[4];

                        int y = y_in_block + tile_y;
                        int x = x_in_block + tile_x;
                        // Linear index of this 2x2-output tile within the
                        // current (yb x xb) block.
                        int m = (y_in_block / 2) * (jcp.xb / 2)
                                + (x_in_block / 2);

                        // Valid [v_ys, v_ye) x [v_xs, v_xe) window of the
                        // alpha x alpha input patch; rows/columns outside
                        // it fall into padding and are masked off below.
                        int v_ys = nstl::max(0, jcp.t_pad - y);
                        int v_ye = nstl::min(jcp.alpha,
                                nstl::max(0, jcp.ih + jcp.t_pad - y));

                        int v_xs = nstl::max(0, jcp.l_pad - x);
                        int v_xe = nstl::min(jcp.alpha,
                                nstl::max(0, jcp.iw + jcp.l_pad - x));

// alpha == 4 for this F(2x2, 3x3) configuration, hence 4-entry masks.
#pragma unroll(4)
                        for (int i = 0; i < jcp.alpha; i++) {
                            v_y_masks[i] = uint16_t(
                                    i < v_ys || i >= v_ye ? 0 : 0xffff);
                            v_x_masks[i] = uint16_t(
                                    i < v_xs || i >= v_xe ? 0 : 0xffff);
                        }
                        // Source is nhwc; step to the patch's top-left
                        // pixel of this mb image.
                        auto local_s = src
                                + (dim_t)mb * jcp.ih * jcp.iw * jcp.ic
                                + y * jcp.iw * jcp.ic + x * jcp.ic;
                        auto local_w = wino_src + m * jcp.ic;

                        src_trans_p.src = local_s;
                        src_trans_p.wino_src = local_w;
                        src_trans_p.v_y_masks = v_y_masks;
                        src_trans_p.v_x_masks = v_x_masks;

                        (*src_trans_)(&src_trans_p);
                    }
                }
                /* gemms */
                for (int tile_ij = 0; tile_ij < 16; tile_ij++) {
                    // start threads at different GEMMs to help bring weights into LLC
                    int offset = (tile_ij + ithr) % 16;
                    gemm_p.src = wino_src + jcp.inp_stride * offset;
                    gemm_p.dst = wino_dst + jcp.out_stride * offset;
                    gemm_p.wei = wei + jcp.wei_stride * offset;
                    gemm_p.dst_b = dst_bias + jcp.bia_stride * offset;

                    (*kernel_)(&gemm_p);
                }

                // Base of the destination image for this mb (nhwc layout).
                auto dst_loc
                        = dst + dst_dt_size * mb * jcp.oh * jcp.ow * jcp.oc;
                /* transformation from winograd domain to output tensor */
                for (int y_in_block = 0; y_in_block < jcp.yb; y_in_block += 2) {
                    for (int x_in_block = 0; x_in_block < jcp.xb;
                            x_in_block += 2) {
                        uint16_t v_y_masks[2], v_x_masks[2];

                        int y = y_in_block + tile_y;
                        int x = x_in_block + tile_x;
                        int m = (y_in_block / 2) * (jcp.xb / 2)
                                + (x_in_block / 2);

                        // Mask off the m x m output points that fall past
                        // the right/bottom edge of the destination.
#pragma unroll(2)
                        for (int i = 0; i < jcp.m; i++) {
                            v_x_masks[i]
                                    = uint16_t(x + i < jcp.ow ? 0xffff : 0);
                            v_y_masks[i]
                                    = uint16_t(y + i < jcp.oh ? 0xffff : 0);
                        }
                        auto local_d = dst_loc
                                + dst_dt_size
                                        * (y * jcp.ow * jcp.oc + x * jcp.oc);
                        auto local_w = wino_dst + m * jcp.oc;

                        auto scales = oscales;
                        dst_trans_p.dst = local_d;
                        dst_trans_p.wino_dst = local_w;
                        dst_trans_p.v_y_masks = v_y_masks;
                        dst_trans_p.v_x_masks = v_x_masks;

                        dst_trans_p.scales = scales;
                        dst_trans_p.bias = bia;

                        (*dst_trans_)(&dst_trans_p);
                    }
                }
            });
}
1200 
// Inner-parallel ("small mb") harness: tile blocks are walked serially and
// each stage (source transform, GEMMs, destination transform) is
// parallelized internally, with all threads sharing one set of winograd
// buffers. Chosen when the batch is too small for outer parallelism.
void jit_avx512_core_u8s8s32x_wino_convolution_fwd_t::execute_forward_small_mb(
        const src_data_t *src, const wei_data_t *wei, const char *bia,
        char *dst, const memory_tracking::grantor_t &scratchpad) const {
    const auto &jcp = kernel_->jcp;
    const memory_desc_wrapper dst_d(pd()->dst_md());
    const size_t dst_dt_size = types::data_type_size(dst_d.data_type());
    // Output scales with the src/wei transform pre-scaling factored out.
    const float *oscales = adjust_oscales(scratchpad);

    // The s32 bias compensation terms are laid out directly after the
    // transformed weights in the weights blob (see wd.size in init_conf).
    auto dst_bias = (const acc_data_t *)(wei + jcp.size_wino_wei);
    auto wino_src = scratchpad.template get<src_data_t>(key_wino_V);
    auto wino_dst = scratchpad.template get<acc_data_t>(key_wino_M);

    // Serial walk over mb blocks and (yb x xb) spatial tile blocks.
    for_(int mbb = 0; mbb < jcp.nb_mb; mbb++)
    for_(int tile_y = 0; tile_y < jcp.oh; tile_y += jcp.yb)
    for (int tile_x = 0; tile_x < jcp.ow; tile_x += jcp.xb) {
        /* transformation of input tensor to winograd domain */
        parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
                [&](dim_t y_in_block_b, dim_t x_in_block_b, dim_t mb) {
                    int y_in_block = y_in_block_b * 2;
                    int x_in_block = x_in_block_b * 2;

                    auto src_trans_p
                            = jit_avx512_core_u8s8s32x_wino_conv_src_trans_t::
                                    call_params_t();

                    uint16_t v_y_masks[4], v_x_masks[4];

                    int y = y_in_block + tile_y;
                    int x = x_in_block + tile_x;
                    // Linear tile index within the current
                    // (mb_block x yb x xb) block.
                    int m = (mb * (jcp.yb / 2) + (y_in_block / 2))
                                    * (jcp.xb / 2)
                            + (x_in_block / 2);

                    // Valid window of the alpha x alpha input patch;
                    // out-of-range rows/columns belong to padding and get
                    // a zero mask below.
                    int v_ys = nstl::max(0, jcp.t_pad - y);
                    int v_ye = nstl::min(
                            jcp.alpha, nstl::max(0, jcp.ih + jcp.t_pad - y));

                    int v_xs = nstl::max(0, jcp.l_pad - x);
                    int v_xe = nstl::min(
                            jcp.alpha, nstl::max(0, jcp.iw + jcp.l_pad - x));

#pragma unroll(4)
                    for (int i = 0; i < jcp.alpha; i++) {
                        v_y_masks[i]
                                = uint16_t(i < v_ys || i >= v_ye ? 0 : 0xffff);
                        v_x_masks[i]
                                = uint16_t(i < v_xs || i >= v_xe ? 0 : 0xffff);
                    }
                    // Source is nhwc; mbb*mb_block+mb is the absolute
                    // image index.
                    auto local_s = src
                            + ((dim_t)mbb * jcp.mb_block + mb) * jcp.ih * jcp.iw
                                    * jcp.ic
                            + y * jcp.iw * jcp.ic + x * jcp.ic;
                    auto local_w = wino_src + m * jcp.ic;

                    src_trans_p.src = local_s;
                    src_trans_p.wino_src = local_w;
                    src_trans_p.v_y_masks = v_y_masks;
                    src_trans_p.v_x_masks = v_x_masks;

                    (*src_trans_)(&src_trans_p);
                });

        /* gemms */
        // Parallelize over the 16 winograd components and the oc chunks.
        parallel_nd(16, jcp.n_chunks, [&](dim_t tile_ij, dim_t nnb) {
            auto gemm_p = jit_avx512_core_u8s8s32x_wino_conv_fwd_ker_t::
                    call_params_t();

            gemm_p.src = wino_src + jcp.inp_stride * tile_ij;
            gemm_p.dst = wino_dst + jcp.out_stride * tile_ij
                    + nnb * jcp.n2_block * jcp.n_block;
            gemm_p.wei = wei + jcp.wei_stride * tile_ij
                    + nnb * jcp.n2_block * jcp.n_block * jcp.K;
            gemm_p.dst_b = dst_bias + jcp.bia_stride * tile_ij
                    + nnb * jcp.n2_block * jcp.n_block;

            (*kernel_)(&gemm_p);
        });

        /* transformation from winograd domain to output tensor */
        parallel_nd(div_up(jcp.yb, 2), div_up(jcp.xb, 2), jcp.mb_block,
                [&](dim_t y_in_block_b, dim_t x_in_block_b, dim_t mb) {
                    int y_in_block = y_in_block_b * 2;
                    int x_in_block = x_in_block_b * 2;

                    auto dst_trans_p
                            = jit_avx512_core_u8s8s32x_wino_conv_dst_trans_t::
                                    call_params_t();

                    uint16_t v_y_masks[2], v_x_masks[2];

                    int y = y_in_block + tile_y;
                    int x = x_in_block + tile_x;
                    int m = (mb * (jcp.yb / 2) + (y_in_block / 2))
                                    * (jcp.xb / 2)
                            + (x_in_block / 2);

                    // Mask off the m x m output points that fall past the
                    // right/bottom edge of the destination.
#pragma unroll(2)
                    for (int i = 0; i < jcp.m; i++) {
                        v_x_masks[i] = uint16_t(x + i < jcp.ow ? 0xffff : 0);
                        v_y_masks[i] = uint16_t(y + i < jcp.oh ? 0xffff : 0);
                    }
                    auto local_d = dst
                            + dst_dt_size
                                    * ((mbb * jcp.mb_block + mb) * jcp.oh
                                                    * jcp.ow * jcp.oc
                                            + y * jcp.ow * jcp.oc + x * jcp.oc);
                    auto local_w = wino_dst + m * jcp.oc;

                    auto scales = oscales;
                    dst_trans_p.dst = local_d;
                    dst_trans_p.wino_dst = local_w;
                    dst_trans_p.v_y_masks = v_y_masks;
                    dst_trans_p.v_x_masks = v_x_masks;

                    dst_trans_p.scales = scales;
                    dst_trans_p.bias = bia;

                    (*dst_trans_)(&dst_trans_p);
                });
    }
}
1322 
1323 } // namespace x64
1324 } // namespace cpu
1325 } // namespace impl
1326 } // namespace dnnl
1327