1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved.                      *
3 * This file is part of the LIBXSMM library.                                   *
4 *                                                                             *
5 * For information on the license, see the LICENSE file.                       *
6 * Further information: https://github.com/hfp/libxsmm/                        *
7 * SPDX-License-Identifier: BSD-3-Clause                                       *
8 ******************************************************************************/
9 /* Alexander Heinecke, Evangelos Georganas (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_dnn_fullyconnected_backward_weight_update.h"
12 #include "libxsmm_main.h"
13 
14 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
15 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
16 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_bf16_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
17 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
18 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
19 
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)20 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
21 void bf16_vnni_transpose_16x16(void* source_void, void* dest_void, int source_stride, int dest_stride)
22 {
23 #if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
24   libxsmm_bfloat16 *source = (libxsmm_bfloat16*)source_void;
25   libxsmm_bfloat16 *dest = (libxsmm_bfloat16*)dest_void;
26   __m512i zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7;
27   __m512i tmp0, tmp1, tmp2, tmp3;
28   const __m512i abcdefgh_to_abefcdgh = _mm512_set4_epi32(0x0f0e0b0a, 0x0d0c0908, 0x07060302, 0x05040100);
29 
30   zmm0 = _mm512_loadu_si512(source);
31   zmm1 = _mm512_loadu_si512(source + source_stride);
32   zmm2 = _mm512_loadu_si512(source + source_stride*2);
33   zmm3 = _mm512_loadu_si512(source + source_stride*3);
34   zmm4 = _mm512_loadu_si512(source + source_stride*4);
35   zmm5 = _mm512_loadu_si512(source + source_stride*5);
36   zmm6 = _mm512_loadu_si512(source + source_stride*6);
37   zmm7 = _mm512_loadu_si512(source + source_stride*7);
38 
39   zmm0 = _mm512_shuffle_epi8(zmm0, abcdefgh_to_abefcdgh);
40   zmm1 = _mm512_shuffle_epi8(zmm1, abcdefgh_to_abefcdgh);
41   zmm2 = _mm512_shuffle_epi8(zmm2, abcdefgh_to_abefcdgh);
42   zmm3 = _mm512_shuffle_epi8(zmm3, abcdefgh_to_abefcdgh);
43   zmm4 = _mm512_shuffle_epi8(zmm4, abcdefgh_to_abefcdgh);
44   zmm5 = _mm512_shuffle_epi8(zmm5, abcdefgh_to_abefcdgh);
45   zmm6 = _mm512_shuffle_epi8(zmm6, abcdefgh_to_abefcdgh);
46   zmm7 = _mm512_shuffle_epi8(zmm7, abcdefgh_to_abefcdgh);
47 
48   tmp0 = _mm512_unpacklo_epi64(zmm0, zmm1);
49   tmp1 = _mm512_unpackhi_epi64(zmm0, zmm1);
50   tmp2 = _mm512_unpacklo_epi64(zmm2, zmm3);
51   tmp3 = _mm512_unpackhi_epi64(zmm2, zmm3);
52   zmm0 = _mm512_unpacklo_epi64(zmm4, zmm5);
53   zmm1 = _mm512_unpackhi_epi64(zmm4, zmm5);
54   zmm2 = _mm512_unpacklo_epi64(zmm6, zmm7);
55   zmm3 = _mm512_unpackhi_epi64(zmm6, zmm7);
56 
57   zmm4 = _mm512_shuffle_i32x4(tmp0, tmp2, 0x88);
58   zmm6 = _mm512_shuffle_i32x4(tmp0, tmp2, 0xdd);
59   zmm5 = _mm512_shuffle_i32x4(tmp1, tmp3, 0x88);
60   zmm7 = _mm512_shuffle_i32x4(tmp1, tmp3, 0xdd);
61   tmp0 = _mm512_shuffle_i32x4(zmm0, zmm2, 0x88);
62   tmp1 = _mm512_shuffle_i32x4(zmm0, zmm2, 0xdd);
63   tmp2 = _mm512_shuffle_i32x4(zmm1, zmm3, 0x88);
64   tmp3 = _mm512_shuffle_i32x4(zmm1, zmm3, 0xdd);
65 
66   zmm0 = _mm512_shuffle_i32x4(zmm4, tmp0, 0x88);
67   zmm1 = _mm512_shuffle_i32x4(zmm5, tmp2, 0x88);
68   zmm2 = _mm512_shuffle_i32x4(zmm6, tmp1, 0x88);
69   zmm3 = _mm512_shuffle_i32x4(zmm7, tmp3, 0x88);
70   zmm4 = _mm512_shuffle_i32x4(zmm4, tmp0, 0xdd);
71   zmm5 = _mm512_shuffle_i32x4(zmm5, tmp2, 0xdd);
72   zmm6 = _mm512_shuffle_i32x4(zmm6, tmp1, 0xdd);
73   zmm7 = _mm512_shuffle_i32x4(zmm7, tmp3, 0xdd);
74 
75   _mm512_storeu_si512(dest, zmm0);
76   _mm512_storeu_si512(dest + dest_stride, zmm1);
77   _mm512_storeu_si512(dest + dest_stride * 2, zmm2);
78   _mm512_storeu_si512(dest + dest_stride * 3, zmm3);
79   _mm512_storeu_si512(dest + dest_stride * 4, zmm4);
80   _mm512_storeu_si512(dest + dest_stride * 5, zmm5);
81   _mm512_storeu_si512(dest + dest_stride * 6, zmm6);
82   _mm512_storeu_si512(dest + dest_stride * 7, zmm7);
83 #else
84   LIBXSMM_UNUSED(source_void); LIBXSMM_UNUSED(dest_void); LIBXSMM_UNUSED(source_stride); LIBXSMM_UNUSED(dest_stride);
85 #endif
86 }
87 
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)88 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
89 void bf16_vnni_transpose(libxsmm_bfloat16* src, libxsmm_bfloat16* dst, int M, int N, int ld_in, int ld_out)
90 {
91 #if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
92   const int _M = M/16, _N = N/16;
93   int i = 0, j = 0;
94   for (i = 0; i < _N; i++) {
95     for (j = 0; j < _M; j++) {
96       bf16_vnni_transpose_16x16((libxsmm_bfloat16*) src+i*16*ld_in+j*32, (libxsmm_bfloat16*) dst+j*16*ld_out+i*32, ld_in*2, ld_out*2);
97     }
98   }
99 #else
100   LIBXSMM_UNUSED(src); LIBXSMM_UNUSED(dst); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
101 #endif
102 }
103 
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
/* Transposes a 32x16 bf16 panel (32 elements down the leading dimension,
 * 16 columns) into a 16x32 panel. 16 full 512-bit rows are loaded, shuffled
 * through a classic unpack/shuffle/permute network, and the result is written
 * out as 32 rows of 16 bf16 elements (one 256-bit store each).
 * ld_in/ld_out are leading dimensions in bf16 elements. */
void bf16_transpose_32x16(libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf;
  __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf;
  const int in_width=ld_in, out_width=ld_out;
  /* 64-bit element selectors for the final cross-register interleave. */
  const __m512i idx_lo         = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
  const __m512i idx_hi         = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10);

  /* Load the 16 input rows (32 bf16 elements each). */
  r0 = _mm512_loadu_si512(in + 0*in_width);
  r1 = _mm512_loadu_si512(in + 1*in_width);
  r2 = _mm512_loadu_si512(in + 2*in_width);
  r3 = _mm512_loadu_si512(in + 3*in_width);
  r4 = _mm512_loadu_si512(in + 4*in_width);
  r5 = _mm512_loadu_si512(in + 5*in_width);
  r6 = _mm512_loadu_si512(in + 6*in_width);
  r7 = _mm512_loadu_si512(in + 7*in_width);
  r8 = _mm512_loadu_si512(in + 8*in_width);
  r9 = _mm512_loadu_si512(in + 9*in_width);
  ra = _mm512_loadu_si512(in + 10*in_width);
  rb = _mm512_loadu_si512(in + 11*in_width);
  rc = _mm512_loadu_si512(in + 12*in_width);
  rd = _mm512_loadu_si512(in + 13*in_width);
  re = _mm512_loadu_si512(in + 14*in_width);
  rf = _mm512_loadu_si512(in + 15*in_width);

  /* Stage 1: interleave 16-bit elements of adjacent row pairs. */
  t0 = _mm512_unpacklo_epi16(r0,r1);
  t1 = _mm512_unpackhi_epi16(r0,r1);
  t2 = _mm512_unpacklo_epi16(r2,r3);
  t3 = _mm512_unpackhi_epi16(r2,r3);
  t4 = _mm512_unpacklo_epi16(r4,r5);
  t5 = _mm512_unpackhi_epi16(r4,r5);
  t6 = _mm512_unpacklo_epi16(r6,r7);
  t7 = _mm512_unpackhi_epi16(r6,r7);
  t8 = _mm512_unpacklo_epi16(r8,r9);
  t9 = _mm512_unpackhi_epi16(r8,r9);
  ta = _mm512_unpacklo_epi16(ra,rb);
  tb = _mm512_unpackhi_epi16(ra,rb);
  tc = _mm512_unpacklo_epi16(rc,rd);
  td = _mm512_unpackhi_epi16(rc,rd);
  te = _mm512_unpacklo_epi16(re,rf);
  tf = _mm512_unpackhi_epi16(re,rf);

  /* Stage 2: interleave 32-bit groups (pairs of rows 4 apart). */
  r0 = _mm512_unpacklo_epi32(t0,t2);
  r1 = _mm512_unpackhi_epi32(t0,t2);
  r2 = _mm512_unpacklo_epi32(t1,t3);
  r3 = _mm512_unpackhi_epi32(t1,t3);
  r4 = _mm512_unpacklo_epi32(t4,t6);
  r5 = _mm512_unpackhi_epi32(t4,t6);
  r6 = _mm512_unpacklo_epi32(t5,t7);
  r7 = _mm512_unpackhi_epi32(t5,t7);
  r8 = _mm512_unpacklo_epi32(t8,ta);
  r9 = _mm512_unpackhi_epi32(t8,ta);
  ra = _mm512_unpacklo_epi32(t9,tb);
  rb = _mm512_unpackhi_epi32(t9,tb);
  rc = _mm512_unpacklo_epi32(tc,te);
  rd = _mm512_unpackhi_epi32(tc,te);
  re = _mm512_unpacklo_epi32(td,tf);
  rf = _mm512_unpackhi_epi32(td,tf);

  /* Stage 3: interleave 64-bit groups (pairs of rows 8 apart). */
  t0 = _mm512_unpacklo_epi64(r0,r4);
  t1 = _mm512_unpackhi_epi64(r0,r4);
  t2 = _mm512_unpacklo_epi64(r1,r5);
  t3 = _mm512_unpackhi_epi64(r1,r5);
  t4 = _mm512_unpacklo_epi64(r2,r6);
  t5 = _mm512_unpackhi_epi64(r2,r6);
  t6 = _mm512_unpacklo_epi64(r3,r7);
  t7 = _mm512_unpackhi_epi64(r3,r7);
  t8 = _mm512_unpacklo_epi64(r8,rc);
  t9 = _mm512_unpackhi_epi64(r8,rc);
  ta = _mm512_unpacklo_epi64(r9,rd);
  tb = _mm512_unpackhi_epi64(r9,rd);
  tc = _mm512_unpacklo_epi64(ra,re);
  td = _mm512_unpackhi_epi64(ra,re);
  te = _mm512_unpacklo_epi64(rb,rf);
  tf = _mm512_unpackhi_epi64(rb,rf);

  /* Stage 4: 128-bit lane shuffles within each half of the register set. */
  r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
  r1 = _mm512_shuffle_i32x4(t2, t3, 0x88);
  r2 = _mm512_shuffle_i32x4(t4, t5, 0x88);
  r3 = _mm512_shuffle_i32x4(t6, t7, 0x88);
  r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
  r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
  r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd);
  r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd);
  r8 = _mm512_shuffle_i32x4(t8, t9, 0x88);
  r9 = _mm512_shuffle_i32x4(ta, tb, 0x88);
  ra = _mm512_shuffle_i32x4(tc, td, 0x88);
  rb = _mm512_shuffle_i32x4(te, tf, 0x88);
  rc = _mm512_shuffle_i32x4(t8, t9, 0xdd);
  rd = _mm512_shuffle_i32x4(ta, tb, 0xdd);
  re = _mm512_shuffle_i32x4(tc, td, 0xdd);
  rf = _mm512_shuffle_i32x4(te, tf, 0xdd);

  /* Stage 5: combine halves via 64-bit cross-register permutes. */
  t0 = _mm512_permutex2var_epi64(r0, idx_lo, r8);
  t1 = _mm512_permutex2var_epi64(r1, idx_lo, r9);
  t2 = _mm512_permutex2var_epi64(r2, idx_lo, ra);
  t3 = _mm512_permutex2var_epi64(r3, idx_lo, rb);
  t4 = _mm512_permutex2var_epi64(r4, idx_lo, rc);
  t5 = _mm512_permutex2var_epi64(r5, idx_lo, rd);
  t6 = _mm512_permutex2var_epi64(r6, idx_lo, re);
  t7 = _mm512_permutex2var_epi64(r7, idx_lo, rf);
  t8 = _mm512_permutex2var_epi64(r8, idx_hi, r0);
  t9 = _mm512_permutex2var_epi64(r9, idx_hi, r1);
  ta = _mm512_permutex2var_epi64(ra, idx_hi, r2);
  tb = _mm512_permutex2var_epi64(rb, idx_hi, r3);
  tc = _mm512_permutex2var_epi64(rc, idx_hi, r4);
  td = _mm512_permutex2var_epi64(rd, idx_hi, r5);
  te = _mm512_permutex2var_epi64(re, idx_hi, r6);
  tf = _mm512_permutex2var_epi64(rf, idx_hi, r7);

  /* Store: each 512-bit result holds two output rows (256 bits each). */
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 0*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 1*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 2*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 3*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 4*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 5*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 6*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 7*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 8*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 9*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 10*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 11*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 12*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 13*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 14*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 15*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 16*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 17*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 18*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 19*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 20*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 21*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 22*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 23*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 24*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 25*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 26*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 27*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 28*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 29*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 30*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 31*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 1));
#else
 LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
#endif
}
252 
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
/* Remainder variant of bf16_transpose_32x16: transposes a 32 x col bf16 panel
 * where 1 <= col <= 15 (callers pass N%16). Only the first `col` input rows
 * are loaded; the remaining registers keep undefined contents that flow
 * through the shuffle network but are discarded by the masked stores
 * (store_mask keeps exactly `col` elements per output row). */
void bf16_transpose_32xcols(libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int col, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  __m512i r0 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r1 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r2 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r3 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r4 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rf = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
  __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf;
  const int in_width=ld_in, out_width=ld_out;
  /* 64-bit element selectors for the final cross-register interleave. */
  const __m512i idx_lo         = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
  const __m512i idx_hi         = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10);
  /* Keep the lowest `col` 16-bit elements on each output-row store. */
  __mmask16 store_mask         = LIBXSMM_INTRINSICS_MM512_CVTU32_MASK16(((unsigned int)1 << col) - 1);

  /* Load exactly `col` input rows; rows >= col stay undefined on purpose. */
  if (col == 15) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
    r1 = _mm512_loadu_si512(in + 1*in_width);
    r2 = _mm512_loadu_si512(in + 2*in_width);
    r3 = _mm512_loadu_si512(in + 3*in_width);
    r4 = _mm512_loadu_si512(in + 4*in_width);
    r5 = _mm512_loadu_si512(in + 5*in_width);
    r6 = _mm512_loadu_si512(in + 6*in_width);
    r7 = _mm512_loadu_si512(in + 7*in_width);
    r8 = _mm512_loadu_si512(in + 8*in_width);
    r9 = _mm512_loadu_si512(in + 9*in_width);
    ra = _mm512_loadu_si512(in + 10*in_width);
    rb = _mm512_loadu_si512(in + 11*in_width);
    rc = _mm512_loadu_si512(in + 12*in_width);
    rd = _mm512_loadu_si512(in + 13*in_width);
    re = _mm512_loadu_si512(in + 14*in_width);
  } else if (col == 14) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
    r1 = _mm512_loadu_si512(in + 1*in_width);
    r2 = _mm512_loadu_si512(in + 2*in_width);
    r3 = _mm512_loadu_si512(in + 3*in_width);
    r4 = _mm512_loadu_si512(in + 4*in_width);
    r5 = _mm512_loadu_si512(in + 5*in_width);
    r6 = _mm512_loadu_si512(in + 6*in_width);
    r7 = _mm512_loadu_si512(in + 7*in_width);
    r8 = _mm512_loadu_si512(in + 8*in_width);
    r9 = _mm512_loadu_si512(in + 9*in_width);
    ra = _mm512_loadu_si512(in + 10*in_width);
    rb = _mm512_loadu_si512(in + 11*in_width);
    rc = _mm512_loadu_si512(in + 12*in_width);
    rd = _mm512_loadu_si512(in + 13*in_width);
  } else if (col == 13) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
    r1 = _mm512_loadu_si512(in + 1*in_width);
    r2 = _mm512_loadu_si512(in + 2*in_width);
    r3 = _mm512_loadu_si512(in + 3*in_width);
    r4 = _mm512_loadu_si512(in + 4*in_width);
    r5 = _mm512_loadu_si512(in + 5*in_width);
    r6 = _mm512_loadu_si512(in + 6*in_width);
    r7 = _mm512_loadu_si512(in + 7*in_width);
    r8 = _mm512_loadu_si512(in + 8*in_width);
    r9 = _mm512_loadu_si512(in + 9*in_width);
    ra = _mm512_loadu_si512(in + 10*in_width);
    rb = _mm512_loadu_si512(in + 11*in_width);
    rc = _mm512_loadu_si512(in + 12*in_width);
  } else if (col == 12) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
    r1 = _mm512_loadu_si512(in + 1*in_width);
    r2 = _mm512_loadu_si512(in + 2*in_width);
    r3 = _mm512_loadu_si512(in + 3*in_width);
    r4 = _mm512_loadu_si512(in + 4*in_width);
    r5 = _mm512_loadu_si512(in + 5*in_width);
    r6 = _mm512_loadu_si512(in + 6*in_width);
    r7 = _mm512_loadu_si512(in + 7*in_width);
    r8 = _mm512_loadu_si512(in + 8*in_width);
    r9 = _mm512_loadu_si512(in + 9*in_width);
    ra = _mm512_loadu_si512(in + 10*in_width);
    rb = _mm512_loadu_si512(in + 11*in_width);
  } else if (col == 11) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
    r1 = _mm512_loadu_si512(in + 1*in_width);
    r2 = _mm512_loadu_si512(in + 2*in_width);
    r3 = _mm512_loadu_si512(in + 3*in_width);
    r4 = _mm512_loadu_si512(in + 4*in_width);
    r5 = _mm512_loadu_si512(in + 5*in_width);
    r6 = _mm512_loadu_si512(in + 6*in_width);
    r7 = _mm512_loadu_si512(in + 7*in_width);
    r8 = _mm512_loadu_si512(in + 8*in_width);
    r9 = _mm512_loadu_si512(in + 9*in_width);
    ra = _mm512_loadu_si512(in + 10*in_width);
  } else if (col == 10) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
    r1 = _mm512_loadu_si512(in + 1*in_width);
    r2 = _mm512_loadu_si512(in + 2*in_width);
    r3 = _mm512_loadu_si512(in + 3*in_width);
    r4 = _mm512_loadu_si512(in + 4*in_width);
    r5 = _mm512_loadu_si512(in + 5*in_width);
    r6 = _mm512_loadu_si512(in + 6*in_width);
    r7 = _mm512_loadu_si512(in + 7*in_width);
    r8 = _mm512_loadu_si512(in + 8*in_width);
    r9 = _mm512_loadu_si512(in + 9*in_width);
  } else if (col == 9) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
    r1 = _mm512_loadu_si512(in + 1*in_width);
    r2 = _mm512_loadu_si512(in + 2*in_width);
    r3 = _mm512_loadu_si512(in + 3*in_width);
    r4 = _mm512_loadu_si512(in + 4*in_width);
    r5 = _mm512_loadu_si512(in + 5*in_width);
    r6 = _mm512_loadu_si512(in + 6*in_width);
    r7 = _mm512_loadu_si512(in + 7*in_width);
    r8 = _mm512_loadu_si512(in + 8*in_width);
  } else if (col == 8) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
    r1 = _mm512_loadu_si512(in + 1*in_width);
    r2 = _mm512_loadu_si512(in + 2*in_width);
    r3 = _mm512_loadu_si512(in + 3*in_width);
    r4 = _mm512_loadu_si512(in + 4*in_width);
    r5 = _mm512_loadu_si512(in + 5*in_width);
    r6 = _mm512_loadu_si512(in + 6*in_width);
    r7 = _mm512_loadu_si512(in + 7*in_width);
  } else if (col == 7) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
    r1 = _mm512_loadu_si512(in + 1*in_width);
    r2 = _mm512_loadu_si512(in + 2*in_width);
    r3 = _mm512_loadu_si512(in + 3*in_width);
    r4 = _mm512_loadu_si512(in + 4*in_width);
    r5 = _mm512_loadu_si512(in + 5*in_width);
    r6 = _mm512_loadu_si512(in + 6*in_width);
  } else if (col == 6) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
    r1 = _mm512_loadu_si512(in + 1*in_width);
    r2 = _mm512_loadu_si512(in + 2*in_width);
    r3 = _mm512_loadu_si512(in + 3*in_width);
    r4 = _mm512_loadu_si512(in + 4*in_width);
    r5 = _mm512_loadu_si512(in + 5*in_width);
  } else if (col == 5) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
    r1 = _mm512_loadu_si512(in + 1*in_width);
    r2 = _mm512_loadu_si512(in + 2*in_width);
    r3 = _mm512_loadu_si512(in + 3*in_width);
    r4 = _mm512_loadu_si512(in + 4*in_width);
  } else if (col == 4) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
    r1 = _mm512_loadu_si512(in + 1*in_width);
    r2 = _mm512_loadu_si512(in + 2*in_width);
    r3 = _mm512_loadu_si512(in + 3*in_width);
  } else if (col == 3) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
    r1 = _mm512_loadu_si512(in + 1*in_width);
    r2 = _mm512_loadu_si512(in + 2*in_width);
  } else if (col == 2) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
    r1 = _mm512_loadu_si512(in + 1*in_width);
  } else if (col == 1) {
    r0 = _mm512_loadu_si512(in + 0*in_width);
  }

  /* Stage 1: interleave 16-bit elements of adjacent row pairs. */
  t0 = _mm512_unpacklo_epi16(r0,r1);
  t1 = _mm512_unpackhi_epi16(r0,r1);
  t2 = _mm512_unpacklo_epi16(r2,r3);
  t3 = _mm512_unpackhi_epi16(r2,r3);
  t4 = _mm512_unpacklo_epi16(r4,r5);
  t5 = _mm512_unpackhi_epi16(r4,r5);
  t6 = _mm512_unpacklo_epi16(r6,r7);
  t7 = _mm512_unpackhi_epi16(r6,r7);
  t8 = _mm512_unpacklo_epi16(r8,r9);
  t9 = _mm512_unpackhi_epi16(r8,r9);
  ta = _mm512_unpacklo_epi16(ra,rb);
  tb = _mm512_unpackhi_epi16(ra,rb);
  tc = _mm512_unpacklo_epi16(rc,rd);
  td = _mm512_unpackhi_epi16(rc,rd);
  te = _mm512_unpacklo_epi16(re,rf);
  tf = _mm512_unpackhi_epi16(re,rf);

  /* Stage 2: interleave 32-bit groups. */
  r0 = _mm512_unpacklo_epi32(t0,t2);
  r1 = _mm512_unpackhi_epi32(t0,t2);
  r2 = _mm512_unpacklo_epi32(t1,t3);
  r3 = _mm512_unpackhi_epi32(t1,t3);
  r4 = _mm512_unpacklo_epi32(t4,t6);
  r5 = _mm512_unpackhi_epi32(t4,t6);
  r6 = _mm512_unpacklo_epi32(t5,t7);
  r7 = _mm512_unpackhi_epi32(t5,t7);
  r8 = _mm512_unpacklo_epi32(t8,ta);
  r9 = _mm512_unpackhi_epi32(t8,ta);
  ra = _mm512_unpacklo_epi32(t9,tb);
  rb = _mm512_unpackhi_epi32(t9,tb);
  rc = _mm512_unpacklo_epi32(tc,te);
  rd = _mm512_unpackhi_epi32(tc,te);
  re = _mm512_unpacklo_epi32(td,tf);
  rf = _mm512_unpackhi_epi32(td,tf);

  /* Stage 3: interleave 64-bit groups. */
  t0 = _mm512_unpacklo_epi64(r0,r4);
  t1 = _mm512_unpackhi_epi64(r0,r4);
  t2 = _mm512_unpacklo_epi64(r1,r5);
  t3 = _mm512_unpackhi_epi64(r1,r5);
  t4 = _mm512_unpacklo_epi64(r2,r6);
  t5 = _mm512_unpackhi_epi64(r2,r6);
  t6 = _mm512_unpacklo_epi64(r3,r7);
  t7 = _mm512_unpackhi_epi64(r3,r7);
  t8 = _mm512_unpacklo_epi64(r8,rc);
  t9 = _mm512_unpackhi_epi64(r8,rc);
  ta = _mm512_unpacklo_epi64(r9,rd);
  tb = _mm512_unpackhi_epi64(r9,rd);
  tc = _mm512_unpacklo_epi64(ra,re);
  td = _mm512_unpackhi_epi64(ra,re);
  te = _mm512_unpacklo_epi64(rb,rf);
  tf = _mm512_unpackhi_epi64(rb,rf);

  /* Stage 4: 128-bit lane shuffles. */
  r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
  r1 = _mm512_shuffle_i32x4(t2, t3, 0x88);
  r2 = _mm512_shuffle_i32x4(t4, t5, 0x88);
  r3 = _mm512_shuffle_i32x4(t6, t7, 0x88);
  r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
  r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
  r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd);
  r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd);
  r8 = _mm512_shuffle_i32x4(t8, t9, 0x88);
  r9 = _mm512_shuffle_i32x4(ta, tb, 0x88);
  ra = _mm512_shuffle_i32x4(tc, td, 0x88);
  rb = _mm512_shuffle_i32x4(te, tf, 0x88);
  rc = _mm512_shuffle_i32x4(t8, t9, 0xdd);
  rd = _mm512_shuffle_i32x4(ta, tb, 0xdd);
  re = _mm512_shuffle_i32x4(tc, td, 0xdd);
  rf = _mm512_shuffle_i32x4(te, tf, 0xdd);

  /* Stage 5: combine halves via 64-bit cross-register permutes. */
  t0 = _mm512_permutex2var_epi64(r0, idx_lo, r8);
  t1 = _mm512_permutex2var_epi64(r1, idx_lo, r9);
  t2 = _mm512_permutex2var_epi64(r2, idx_lo, ra);
  t3 = _mm512_permutex2var_epi64(r3, idx_lo, rb);
  t4 = _mm512_permutex2var_epi64(r4, idx_lo, rc);
  t5 = _mm512_permutex2var_epi64(r5, idx_lo, rd);
  t6 = _mm512_permutex2var_epi64(r6, idx_lo, re);
  t7 = _mm512_permutex2var_epi64(r7, idx_lo, rf);
  t8 = _mm512_permutex2var_epi64(r8, idx_hi, r0);
  t9 = _mm512_permutex2var_epi64(r9, idx_hi, r1);
  ta = _mm512_permutex2var_epi64(ra, idx_hi, r2);
  tb = _mm512_permutex2var_epi64(rb, idx_hi, r3);
  tc = _mm512_permutex2var_epi64(rc, idx_hi, r4);
  td = _mm512_permutex2var_epi64(rd, idx_hi, r5);
  te = _mm512_permutex2var_epi64(re, idx_hi, r6);
  tf = _mm512_permutex2var_epi64(rf, idx_hi, r7);

  /* Masked stores: each 512-bit result holds two output rows; only the
   * first `col` elements of each row are written. */
  _mm256_mask_storeu_epi16(out + 0*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 0));
  _mm256_mask_storeu_epi16(out + 1*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 1));
  _mm256_mask_storeu_epi16(out + 2*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 0));
  _mm256_mask_storeu_epi16(out + 3*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 1));
  _mm256_mask_storeu_epi16(out + 4*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 0));
  _mm256_mask_storeu_epi16(out + 5*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 1));
  _mm256_mask_storeu_epi16(out + 6*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 0));
  _mm256_mask_storeu_epi16(out + 7*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 1));
  _mm256_mask_storeu_epi16(out + 8*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 0));
  _mm256_mask_storeu_epi16(out + 9*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 1));
  _mm256_mask_storeu_epi16(out + 10*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 0));
  _mm256_mask_storeu_epi16(out + 11*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 1));
  _mm256_mask_storeu_epi16(out + 12*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 0));
  _mm256_mask_storeu_epi16(out + 13*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 1));
  _mm256_mask_storeu_epi16(out + 14*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 0));
  _mm256_mask_storeu_epi16(out + 15*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 1));
  _mm256_mask_storeu_epi16(out + 16*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 0));
  _mm256_mask_storeu_epi16(out + 17*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 1));
  _mm256_mask_storeu_epi16(out + 18*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 0));
  _mm256_mask_storeu_epi16(out + 19*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 1));
  _mm256_mask_storeu_epi16(out + 20*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 0));
  _mm256_mask_storeu_epi16(out + 21*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 1));
  _mm256_mask_storeu_epi16(out + 22*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 0));
  _mm256_mask_storeu_epi16(out + 23*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 1));
  _mm256_mask_storeu_epi16(out + 24*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 0));
  _mm256_mask_storeu_epi16(out + 25*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 1));
  _mm256_mask_storeu_epi16(out + 26*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 0));
  _mm256_mask_storeu_epi16(out + 27*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 1));
  _mm256_mask_storeu_epi16(out + 28*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 0));
  _mm256_mask_storeu_epi16(out + 29*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 1));
  _mm256_mask_storeu_epi16(out + 30*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 0));
  _mm256_mask_storeu_epi16(out + 31*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 1));
#else
 LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out); LIBXSMM_UNUSED(col);
#endif
}
522 
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)523 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
524 void bf16_transpose(libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int M, int N, int ld_in, int ld_out){
525 #if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
526   int i, j;
527   int full16_chunks = N/16;
528   int remainder_cols = N%16;
529   int _N = N - remainder_cols;
530 
531   if (full16_chunks) {
532     for (i=0; i<M; i+=32) {
533       for (j=0; j<_N; j+=16) {
534         bf16_transpose_32x16((libxsmm_bfloat16*)in + i + ld_in*j, (libxsmm_bfloat16*)out + j + i*ld_out, ld_in, ld_out);
535       }
536     }
537   }
538 
539   if (remainder_cols) {
540     for (i=0; i<M; i+=32) {
541       bf16_transpose_32xcols((libxsmm_bfloat16*)in + i + ld_in*full16_chunks*16, (libxsmm_bfloat16*)out + full16_chunks*16 + i*ld_out, remainder_cols, ld_in, ld_out);
542     }
543   }
544 #else
545  LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
546 #endif
547 }
548 
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)549 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
550 void bf16_vnni_reformat(libxsmm_bfloat16 *_in, libxsmm_bfloat16 *_out, int M, int N, int ld_in, int ld_out) {
551 #if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
552   int n_full_pairs = N/2, n_pair, m;
553   int half_n_pair = N%2;
554   libxsmm_bfloat16 *in = _in, *out = _out;
555   const __m512i selector = LIBXSMM_INTRINSICS_MM512_SET_EPI16(32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0);
556   const __m512i offsets_lo = LIBXSMM_INTRINSICS_MM512_SET_EPI16(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0);
557   const __m512i offsets_hi = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 31, 30, 30, 29, 29, 28, 28, 27, 27, 26, 26, 25, 25, 24, 24, 23, 23, 22, 22, 21, 21, 20, 20, 19, 19, 18, 18, 17, 17, 16, 16);
558   const __m512i idx_lo =  _mm512_or_epi32(selector, offsets_lo);
559   const __m512i idx_hi =  _mm512_or_epi32(selector, offsets_hi);
560   const __m512i zero_reg = _mm512_setzero_si512();
561   __m512i n0, n1, out_lo, out_hi;
562   LIBXSMM_UNUSED(ld_out);
563   for (n_pair = 0; n_pair < n_full_pairs; n_pair++) {
564     for (m = 0; m < M; m+=32) {
565       n0 = _mm512_loadu_si512((const libxsmm_bfloat16*)in+m);
566       n1 = _mm512_loadu_si512((const libxsmm_bfloat16*)in+m+ld_in);
567       out_lo = _mm512_permutex2var_epi16(n0, idx_lo, n1);
568       out_hi = _mm512_permutex2var_epi16(n0, idx_hi, n1);
569       _mm512_storeu_si512((libxsmm_bfloat16*)out+m*2, out_lo);
570       _mm512_storeu_si512((libxsmm_bfloat16*)out+m*2+32, out_hi);
571     }
572     in += 2*ld_in;
573     out += 2*ld_in;
574   }
575   if (half_n_pair == 1) {
576     for (m = 0; m < M; m+=32) {
577       n0 = _mm512_loadu_si512((const libxsmm_bfloat16*)in+m);
578       n1 = zero_reg;
579       out_lo = _mm512_permutex2var_epi16(n0, idx_lo, n1);
580       out_hi = _mm512_permutex2var_epi16(n0, idx_lo, n1);
581       _mm512_storeu_si512((libxsmm_bfloat16*)out+m*2, out_lo);
582       _mm512_storeu_si512((libxsmm_bfloat16*)out+m*2+32, out_hi);
583     }
584   }
585 #else
586  LIBXSMM_UNUSED(_in); LIBXSMM_UNUSED(_out); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
587 #endif
588 }
589 
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  /* FP32 backward/weight-update for the "custom" tensor layout.
   * Dispatches one small GEMM per pass (BWD and UPD) and runs the generic
   * template body; 'kind', 'start_thread' and 'tid' are consumed inside the
   * included template. Only FUSE_NONE is supported here. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  /* Leading dimensions for the two GEMM flavors (BWD: dinput; UPD: dfilter). */
  libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock;
  libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K;
  libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C;
  libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K;
  libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N;
  libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock;
  element_input_type alpha = (element_input_type)1;
  element_input_type beta = (element_input_type)0;

  if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
    typedef libxsmm_smmfunction gemm_function;
    gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL);
    gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL);
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c"
  } else {
    status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
621 
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_bf16_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  /* BF16-in/FP32-out backward/weight-update for the "custom" layout.
   * Uses FP32 GEMM kernels; the BF16<->FP32 handling is enabled inside the
   * included template via the BWD/UPD_BF16_F32 macros. Only FUSE_NONE is
   * supported here. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef float element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  typedef libxsmm_smmfunction gemm_function;
  libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock;
  libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K;
  libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C;
  libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K;
  libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N;
  libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock;
  /* NOTE(review): the (element_input_type) casts look unintended (alpha/beta
   * are float and the F32 variant above casts to its own float typedef);
   * the resulting values are still 1.0f and 0.0f, so behavior is unaffected. */
  float alpha = (element_input_type)1;
  float beta = (element_input_type)0;

  if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
    gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL);
    gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL);
# define LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32
# define LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32
# undef LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32
  } else {
    status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
657 
658 
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  /* FP32 backward/weight-update for the NCNC(activations)/KCCK(filter)
   * blocked layout. Selects the template instantiation matching the fused
   * operation (bias/relu/sigmoid and combinations); the included template
   * consumes 'kind', 'start_thread' and 'tid'. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  /* Batch-reduce strided GEMM kernels pre-dispatched on the handle; the
   * "zerobeta" variants come from the secondary gemm_* slots (presumably
   * beta=0, i.e. overwriting C — named accordingly). */
  libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.smrs;
  libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.smrs;
  libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.smrs;
  libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.smrs;

#define LIBXSMM_DNN_FC_BWD_USE_AVX512
  if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else {
    status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
  }
#undef LIBXSMM_DNN_FC_BWD_USE_AVX512
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
709 
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  /* BF16 backward/weight-update for the NCNC/KCCK layout on AVX512-CORE,
   * i.e. without native BF16 instructions — the included macro template
   * provides emulated BF16<->FP32 conversion. One template instantiation
   * per supported fusion. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* Pre-dispatched batch-reduce strided kernels from the handle; the
   * "zerobeta" variants come from the secondary gemm_* slots. */
  libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.bsmrs;
  libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.bmrs;
  libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.bsmrs;
  libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.bmrs;

  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"

  if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else {
    status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
  }

# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
763 
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  /* BF16 backward/weight-update for the NCNC/KCCK layout using native
   * AVX512-BF16 (CPX) instructions; mirrors the _emu variant but defines
   * LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI so the macro template emits native
   * conversions instead of emulation. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* Pre-dispatched batch-reduce strided kernels from the handle. */
  libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.bsmrs;
  libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.bmrs;
  libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.bsmrs;
  libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.bmrs;

#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 conversion */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"

  if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else {
    status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
  }

# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* Compiler cannot target AVX512-BF16: fall back to the emulated BF16 path. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  return libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid );
}
#endif
827 
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  /* Entry point for backward/weight-update in the "custom" layout.
   * Validates that the tensors required by the requested pass are bound,
   * then dispatches to an AVX512-specialized implementation when available,
   * otherwise runs the generic template inline (architecture fallback). */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check if all required tensors are bound */
  if ( kind == LIBXSMM_DNN_COMPUTE_KIND_BWD ) {
    if (handle->grad_input == 0 || handle->grad_output == 0 ||
        handle->reg_filter == 0 || handle->scratch == 0         ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  } else if ( kind == LIBXSMM_DNN_COMPUTE_KIND_UPD ) {
    if (handle->reg_input == 0   || handle->grad_output == 0 ||
        handle->grad_filter == 0 || handle->scratch == 0         ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  } else {
    /* combined BWD+UPD: needs the union of both tensor sets */
    if (handle->grad_input == 0 || handle->grad_output == 0 ||
        handle->reg_input  == 0 || handle->grad_filter == 0 ||
        handle->reg_filter == 0 || handle->scratch == 0         ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  }

  /* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_fullyconnected_st_bwdupd_custom_f32_f32( handle, kind, start_thread, tid);
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__*/
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_fullyconnected_st_bwdupd_custom_bf16_f32( handle, kind, start_thread, tid);
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  {
    /* generic fallback: same template bodies, no AVX512 specialization */
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      typedef libxsmm_smmfunction gemm_function;
      libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock;
      libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K;
      libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C;
      libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K;
      libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N;
      libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock;
      element_input_type alpha = (element_input_type)1;
      element_input_type beta = (element_input_type)0;

      if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
        gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL);
        gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL);
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c"
      } else {
        status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
      }
    } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef libxsmm_bfloat16 element_input_type;
      typedef float element_output_type;
      typedef libxsmm_bfloat16 element_filter_type;
      typedef libxsmm_smmfunction gemm_function;
      libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock;
      libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K;
      libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C;
      libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K;
      libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N;
      libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock;
      /* NOTE(review): integer-looking casts, but values are still 1.0f/0.0f */
      float alpha = (element_input_type)1;
      float beta = (element_input_type)0;

      if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
        gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL);
        gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL);
# define LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32
# define LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c"
# undef LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32
# undef LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32
      } else {
        status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
      }
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }

  return status;
}
926 
927 
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  /* Entry point for backward/weight-update in the NCNC/KCCK blocked layout.
   * Validates tensor bindings (including fused bias/relu buffers), then
   * dispatches by datatype and target architecture: F32 AVX512, BF16 native
   * (CPX) or emulated (CORE), or a generic template fallback. */
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check if all required tensors are bound */
  if ( kind == LIBXSMM_DNN_COMPUTE_KIND_BWD ) {
    if (handle->grad_input == 0 || handle->grad_output == 0 ||
        handle->reg_filter == 0 || handle->scratch == 0         ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  } else if ( kind == LIBXSMM_DNN_COMPUTE_KIND_UPD ) {
    if (handle->reg_input == 0   || handle->grad_output == 0 ||
        handle->grad_filter == 0 || handle->scratch == 0         ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  } else {
    /* combined BWD+UPD: needs the union of both tensor sets */
    if (handle->grad_input == 0 || handle->grad_output == 0 ||
        handle->reg_input  == 0 || handle->grad_filter == 0 ||
        handle->reg_filter == 0 || handle->scratch == 0         ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  }

  /* fused ops need their auxiliary tensors bound as well */
  if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) != 0) && ( handle->grad_bias == 0 ) )  {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) != 0) && ( handle->relumask == 0 ) )  {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }

  /* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_f32_f32( handle, kind, start_thread, tid);
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
    /* BF16: emulated path on CORE..<CPX, native path on >=CPX */
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX) {
      status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid);
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX ) {
      status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16( handle, kind, start_thread, tid);
    }
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE ) {
      status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid);
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  {
    /* generic fallback: F32 only, instantiated per fused operation */
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.smrs;
      libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.smrs;
      libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.smrs;
      libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.smrs;

      if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
      } else {
        status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
      }
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }

  return status;
}
1033 
1034 
libxsmm_dnn_fullyconnected_st_bwdupd_nhwc(libxsmm_dnn_fullyconnected * handle,libxsmm_dnn_compute_kind kind,int start_thread,int tid)1035 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_nhwc(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
1036 {
1037   libxsmm_dnn_err_t status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
1038   LIBXSMM_UNUSED( handle );
1039   LIBXSMM_UNUSED( kind );
1040   LIBXSMM_UNUSED( start_thread );
1041   LIBXSMM_UNUSED( tid );
1042   return status;
1043 }
1044 
1045