1 /******************************************************************************
2 * Copyright (c) Intel Corporation - All rights reserved. *
3 * This file is part of the LIBXSMM library. *
4 * *
5 * For information on the license, see the LICENSE file. *
6 * Further information: https://github.com/hfp/libxsmm/ *
7 * SPDX-License-Identifier: BSD-3-Clause *
8 ******************************************************************************/
9 /* Alexander Heinecke, Evangelos Georganas (Intel Corp.)
10 ******************************************************************************/
11 #include "libxsmm_dnn_fullyconnected_backward_weight_update.h"
12 #include "libxsmm_main.h"
13
14 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
15 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
16 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_bf16_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
17 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
18 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid);
19
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
/* Transposes one 16x16 bf16 tile held in VNNI layout (rows interleaved in
 * pairs) into the destination, again in VNNI layout.
 * source_void/dest_void: tile base pointers (bf16 elements).
 * source_stride/dest_stride: distance between consecutive row *pairs*, in bf16
 * elements (the caller bf16_vnni_transpose passes ld*2).
 * Only 8 loads/stores cover the 16x16 tile because each 512-bit register holds
 * two interleaved rows of 16 bf16 values. */
void bf16_vnni_transpose_16x16(void* source_void, void* dest_void, int source_stride, int dest_stride)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  libxsmm_bfloat16 *source = (libxsmm_bfloat16*)source_void;
  libxsmm_bfloat16 *dest = (libxsmm_bfloat16*)dest_void;
  __m512i zmm0, zmm1, zmm2, zmm3, zmm4, zmm5, zmm6, zmm7;
  __m512i tmp0, tmp1, tmp2, tmp3;
  /* byte shuffle reordering the 16-bit elements within each 128-bit lane;
   * the name encodes the permutation (abcdefgh -> abefcdgh) */
  const __m512i abcdefgh_to_abefcdgh = _mm512_set4_epi32(0x0f0e0b0a, 0x0d0c0908, 0x07060302, 0x05040100);

  /* load the 8 row-pairs of the tile */
  zmm0 = _mm512_loadu_si512(source);
  zmm1 = _mm512_loadu_si512(source + source_stride);
  zmm2 = _mm512_loadu_si512(source + source_stride*2);
  zmm3 = _mm512_loadu_si512(source + source_stride*3);
  zmm4 = _mm512_loadu_si512(source + source_stride*4);
  zmm5 = _mm512_loadu_si512(source + source_stride*5);
  zmm6 = _mm512_loadu_si512(source + source_stride*6);
  zmm7 = _mm512_loadu_si512(source + source_stride*7);

  /* in-lane 16-bit pre-permutation so the following 64-bit unpacks move
   * matching element pairs together */
  zmm0 = _mm512_shuffle_epi8(zmm0, abcdefgh_to_abefcdgh);
  zmm1 = _mm512_shuffle_epi8(zmm1, abcdefgh_to_abefcdgh);
  zmm2 = _mm512_shuffle_epi8(zmm2, abcdefgh_to_abefcdgh);
  zmm3 = _mm512_shuffle_epi8(zmm3, abcdefgh_to_abefcdgh);
  zmm4 = _mm512_shuffle_epi8(zmm4, abcdefgh_to_abefcdgh);
  zmm5 = _mm512_shuffle_epi8(zmm5, abcdefgh_to_abefcdgh);
  zmm6 = _mm512_shuffle_epi8(zmm6, abcdefgh_to_abefcdgh);
  zmm7 = _mm512_shuffle_epi8(zmm7, abcdefgh_to_abefcdgh);

  /* interleave at 64-bit granularity */
  tmp0 = _mm512_unpacklo_epi64(zmm0, zmm1);
  tmp1 = _mm512_unpackhi_epi64(zmm0, zmm1);
  tmp2 = _mm512_unpacklo_epi64(zmm2, zmm3);
  tmp3 = _mm512_unpackhi_epi64(zmm2, zmm3);
  zmm0 = _mm512_unpacklo_epi64(zmm4, zmm5);
  zmm1 = _mm512_unpackhi_epi64(zmm4, zmm5);
  zmm2 = _mm512_unpacklo_epi64(zmm6, zmm7);
  zmm3 = _mm512_unpackhi_epi64(zmm6, zmm7);

  /* two rounds of 128-bit lane shuffles complete the transpose network */
  zmm4 = _mm512_shuffle_i32x4(tmp0, tmp2, 0x88);
  zmm6 = _mm512_shuffle_i32x4(tmp0, tmp2, 0xdd);
  zmm5 = _mm512_shuffle_i32x4(tmp1, tmp3, 0x88);
  zmm7 = _mm512_shuffle_i32x4(tmp1, tmp3, 0xdd);
  tmp0 = _mm512_shuffle_i32x4(zmm0, zmm2, 0x88);
  tmp1 = _mm512_shuffle_i32x4(zmm0, zmm2, 0xdd);
  tmp2 = _mm512_shuffle_i32x4(zmm1, zmm3, 0x88);
  tmp3 = _mm512_shuffle_i32x4(zmm1, zmm3, 0xdd);

  zmm0 = _mm512_shuffle_i32x4(zmm4, tmp0, 0x88);
  zmm1 = _mm512_shuffle_i32x4(zmm5, tmp2, 0x88);
  zmm2 = _mm512_shuffle_i32x4(zmm6, tmp1, 0x88);
  zmm3 = _mm512_shuffle_i32x4(zmm7, tmp3, 0x88);
  zmm4 = _mm512_shuffle_i32x4(zmm4, tmp0, 0xdd);
  zmm5 = _mm512_shuffle_i32x4(zmm5, tmp2, 0xdd);
  zmm6 = _mm512_shuffle_i32x4(zmm6, tmp1, 0xdd);
  zmm7 = _mm512_shuffle_i32x4(zmm7, tmp3, 0xdd);

  /* store the 8 transposed row-pairs */
  _mm512_storeu_si512(dest, zmm0);
  _mm512_storeu_si512(dest + dest_stride, zmm1);
  _mm512_storeu_si512(dest + dest_stride * 2, zmm2);
  _mm512_storeu_si512(dest + dest_stride * 3, zmm3);
  _mm512_storeu_si512(dest + dest_stride * 4, zmm4);
  _mm512_storeu_si512(dest + dest_stride * 5, zmm5);
  _mm512_storeu_si512(dest + dest_stride * 6, zmm6);
  _mm512_storeu_si512(dest + dest_stride * 7, zmm7);
#else
  LIBXSMM_UNUSED(source_void); LIBXSMM_UNUSED(dest_void); LIBXSMM_UNUSED(source_stride); LIBXSMM_UNUSED(dest_stride);
#endif
}
87
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)88 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
89 void bf16_vnni_transpose(libxsmm_bfloat16* src, libxsmm_bfloat16* dst, int M, int N, int ld_in, int ld_out)
90 {
91 #if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
92 const int _M = M/16, _N = N/16;
93 int i = 0, j = 0;
94 for (i = 0; i < _N; i++) {
95 for (j = 0; j < _M; j++) {
96 bf16_vnni_transpose_16x16((libxsmm_bfloat16*) src+i*16*ld_in+j*32, (libxsmm_bfloat16*) dst+j*16*ld_out+i*32, ld_in*2, ld_out*2);
97 }
98 }
99 #else
100 LIBXSMM_UNUSED(src); LIBXSMM_UNUSED(dst); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
101 #endif
102 }
103
LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
/* Transposes a bf16 tile: reads 16 rows of 32 bf16 elements from `in`
 * (leading dimension ld_in, in elements) and writes 32 rows of 16 bf16
 * elements to `out` (leading dimension ld_out, in elements).
 * Classic 16-register AVX512 transpose network: 16-bit, 32-bit and 64-bit
 * unpacks, 128-bit lane shuffles, cross-register 64-bit permutes, and 32
 * stores of 256-bit register halves. */
void bf16_transpose_32x16(libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int ld_in, int ld_out)
{
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
  __m512i r0, r1, r2, r3, r4, r5, r6, r7, r8, r9, ra, rb, rc, rd, re, rf;
  __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf;
  const int in_width=ld_in, out_width=ld_out;
  /* 64-bit permute indices combining the lower/upper halves of two registers */
  const __m512i idx_lo = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
  const __m512i idx_hi = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10);

  /* load 16 input rows (32 bf16 = one zmm each) */
  r0 = _mm512_loadu_si512(in + 0*in_width);
  r1 = _mm512_loadu_si512(in + 1*in_width);
  r2 = _mm512_loadu_si512(in + 2*in_width);
  r3 = _mm512_loadu_si512(in + 3*in_width);
  r4 = _mm512_loadu_si512(in + 4*in_width);
  r5 = _mm512_loadu_si512(in + 5*in_width);
  r6 = _mm512_loadu_si512(in + 6*in_width);
  r7 = _mm512_loadu_si512(in + 7*in_width);
  r8 = _mm512_loadu_si512(in + 8*in_width);
  r9 = _mm512_loadu_si512(in + 9*in_width);
  ra = _mm512_loadu_si512(in + 10*in_width);
  rb = _mm512_loadu_si512(in + 11*in_width);
  rc = _mm512_loadu_si512(in + 12*in_width);
  rd = _mm512_loadu_si512(in + 13*in_width);
  re = _mm512_loadu_si512(in + 14*in_width);
  rf = _mm512_loadu_si512(in + 15*in_width);

  /* stage 1: interleave adjacent rows at 16-bit granularity */
  t0 = _mm512_unpacklo_epi16(r0,r1);
  t1 = _mm512_unpackhi_epi16(r0,r1);
  t2 = _mm512_unpacklo_epi16(r2,r3);
  t3 = _mm512_unpackhi_epi16(r2,r3);
  t4 = _mm512_unpacklo_epi16(r4,r5);
  t5 = _mm512_unpackhi_epi16(r4,r5);
  t6 = _mm512_unpacklo_epi16(r6,r7);
  t7 = _mm512_unpackhi_epi16(r6,r7);
  t8 = _mm512_unpacklo_epi16(r8,r9);
  t9 = _mm512_unpackhi_epi16(r8,r9);
  ta = _mm512_unpacklo_epi16(ra,rb);
  tb = _mm512_unpackhi_epi16(ra,rb);
  tc = _mm512_unpacklo_epi16(rc,rd);
  td = _mm512_unpackhi_epi16(rc,rd);
  te = _mm512_unpacklo_epi16(re,rf);
  tf = _mm512_unpackhi_epi16(re,rf);

  /* stage 2: interleave at 32-bit granularity */
  r0 = _mm512_unpacklo_epi32(t0,t2);
  r1 = _mm512_unpackhi_epi32(t0,t2);
  r2 = _mm512_unpacklo_epi32(t1,t3);
  r3 = _mm512_unpackhi_epi32(t1,t3);
  r4 = _mm512_unpacklo_epi32(t4,t6);
  r5 = _mm512_unpackhi_epi32(t4,t6);
  r6 = _mm512_unpacklo_epi32(t5,t7);
  r7 = _mm512_unpackhi_epi32(t5,t7);
  r8 = _mm512_unpacklo_epi32(t8,ta);
  r9 = _mm512_unpackhi_epi32(t8,ta);
  ra = _mm512_unpacklo_epi32(t9,tb);
  rb = _mm512_unpackhi_epi32(t9,tb);
  rc = _mm512_unpacklo_epi32(tc,te);
  rd = _mm512_unpackhi_epi32(tc,te);
  re = _mm512_unpacklo_epi32(td,tf);
  rf = _mm512_unpackhi_epi32(td,tf);

  /* stage 3: interleave at 64-bit granularity */
  t0 = _mm512_unpacklo_epi64(r0,r4);
  t1 = _mm512_unpackhi_epi64(r0,r4);
  t2 = _mm512_unpacklo_epi64(r1,r5);
  t3 = _mm512_unpackhi_epi64(r1,r5);
  t4 = _mm512_unpacklo_epi64(r2,r6);
  t5 = _mm512_unpackhi_epi64(r2,r6);
  t6 = _mm512_unpacklo_epi64(r3,r7);
  t7 = _mm512_unpackhi_epi64(r3,r7);
  t8 = _mm512_unpacklo_epi64(r8,rc);
  t9 = _mm512_unpackhi_epi64(r8,rc);
  ta = _mm512_unpacklo_epi64(r9,rd);
  tb = _mm512_unpackhi_epi64(r9,rd);
  tc = _mm512_unpacklo_epi64(ra,re);
  td = _mm512_unpackhi_epi64(ra,re);
  te = _mm512_unpacklo_epi64(rb,rf);
  tf = _mm512_unpackhi_epi64(rb,rf);

  /* stage 4: gather matching 128-bit lanes */
  r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
  r1 = _mm512_shuffle_i32x4(t2, t3, 0x88);
  r2 = _mm512_shuffle_i32x4(t4, t5, 0x88);
  r3 = _mm512_shuffle_i32x4(t6, t7, 0x88);
  r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
  r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
  r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd);
  r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd);
  r8 = _mm512_shuffle_i32x4(t8, t9, 0x88);
  r9 = _mm512_shuffle_i32x4(ta, tb, 0x88);
  ra = _mm512_shuffle_i32x4(tc, td, 0x88);
  rb = _mm512_shuffle_i32x4(te, tf, 0x88);
  rc = _mm512_shuffle_i32x4(t8, t9, 0xdd);
  rd = _mm512_shuffle_i32x4(ta, tb, 0xdd);
  re = _mm512_shuffle_i32x4(tc, td, 0xdd);
  rf = _mm512_shuffle_i32x4(te, tf, 0xdd);

  /* stage 5: combine halves across the two register groups */
  t0 = _mm512_permutex2var_epi64(r0, idx_lo, r8);
  t1 = _mm512_permutex2var_epi64(r1, idx_lo, r9);
  t2 = _mm512_permutex2var_epi64(r2, idx_lo, ra);
  t3 = _mm512_permutex2var_epi64(r3, idx_lo, rb);
  t4 = _mm512_permutex2var_epi64(r4, idx_lo, rc);
  t5 = _mm512_permutex2var_epi64(r5, idx_lo, rd);
  t6 = _mm512_permutex2var_epi64(r6, idx_lo, re);
  t7 = _mm512_permutex2var_epi64(r7, idx_lo, rf);
  t8 = _mm512_permutex2var_epi64(r8, idx_hi, r0);
  t9 = _mm512_permutex2var_epi64(r9, idx_hi, r1);
  ta = _mm512_permutex2var_epi64(ra, idx_hi, r2);
  tb = _mm512_permutex2var_epi64(rb, idx_hi, r3);
  tc = _mm512_permutex2var_epi64(rc, idx_hi, r4);
  td = _mm512_permutex2var_epi64(rd, idx_hi, r5);
  te = _mm512_permutex2var_epi64(re, idx_hi, r6);
  tf = _mm512_permutex2var_epi64(rf, idx_hi, r7);

  /* write the 32 output rows: each 256-bit register half is one 16-element row */
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 0*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 1*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 2*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 3*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 4*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 5*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 6*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 7*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 8*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 9*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 10*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 11*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 12*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 13*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 14*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 15*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 16*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 17*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 18*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 19*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 20*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 21*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 22*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 23*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 24*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 25*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 26*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 27*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 28*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 29*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 1));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 30*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 0));
  LIBXSMM_INTRINSICS_MM256_STORE_EPI32(out + 31*out_width, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 1));
#else
  LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
#endif
}
252
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)253 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
254 void bf16_transpose_32xcols(libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int col, int ld_in, int ld_out)
255 {
256 #if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
257 __m512i r0 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r1 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r2 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r3 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r4 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r5 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r6 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r7 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r8 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), r9 = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), ra = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rb = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rc = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rd = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), re = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32(), rf = LIBXSMM_INTRINSICS_MM512_UNDEFINED_EPI32();
258 __m512i t0, t1, t2, t3, t4, t5, t6, t7, t8, t9, ta, tb, tc, td, te, tf;
259 const int in_width=ld_in, out_width=ld_out;
260 const __m512i idx_lo = _mm512_set_epi64(13, 12, 5, 4, 9, 8, 1, 0);
261 const __m512i idx_hi = _mm512_set_epi64(7, 6, 15, 14, 3, 2, 11, 10);
262 __mmask16 store_mask = LIBXSMM_INTRINSICS_MM512_CVTU32_MASK16(((unsigned int)1 << col) - 1);
263
264 if (col == 15) {
265 r0 = _mm512_loadu_si512(in + 0*in_width);
266 r1 = _mm512_loadu_si512(in + 1*in_width);
267 r2 = _mm512_loadu_si512(in + 2*in_width);
268 r3 = _mm512_loadu_si512(in + 3*in_width);
269 r4 = _mm512_loadu_si512(in + 4*in_width);
270 r5 = _mm512_loadu_si512(in + 5*in_width);
271 r6 = _mm512_loadu_si512(in + 6*in_width);
272 r7 = _mm512_loadu_si512(in + 7*in_width);
273 r8 = _mm512_loadu_si512(in + 8*in_width);
274 r9 = _mm512_loadu_si512(in + 9*in_width);
275 ra = _mm512_loadu_si512(in + 10*in_width);
276 rb = _mm512_loadu_si512(in + 11*in_width);
277 rc = _mm512_loadu_si512(in + 12*in_width);
278 rd = _mm512_loadu_si512(in + 13*in_width);
279 re = _mm512_loadu_si512(in + 14*in_width);
280 } else if (col == 14) {
281 r0 = _mm512_loadu_si512(in + 0*in_width);
282 r1 = _mm512_loadu_si512(in + 1*in_width);
283 r2 = _mm512_loadu_si512(in + 2*in_width);
284 r3 = _mm512_loadu_si512(in + 3*in_width);
285 r4 = _mm512_loadu_si512(in + 4*in_width);
286 r5 = _mm512_loadu_si512(in + 5*in_width);
287 r6 = _mm512_loadu_si512(in + 6*in_width);
288 r7 = _mm512_loadu_si512(in + 7*in_width);
289 r8 = _mm512_loadu_si512(in + 8*in_width);
290 r9 = _mm512_loadu_si512(in + 9*in_width);
291 ra = _mm512_loadu_si512(in + 10*in_width);
292 rb = _mm512_loadu_si512(in + 11*in_width);
293 rc = _mm512_loadu_si512(in + 12*in_width);
294 rd = _mm512_loadu_si512(in + 13*in_width);
295 } else if (col == 13) {
296 r0 = _mm512_loadu_si512(in + 0*in_width);
297 r1 = _mm512_loadu_si512(in + 1*in_width);
298 r2 = _mm512_loadu_si512(in + 2*in_width);
299 r3 = _mm512_loadu_si512(in + 3*in_width);
300 r4 = _mm512_loadu_si512(in + 4*in_width);
301 r5 = _mm512_loadu_si512(in + 5*in_width);
302 r6 = _mm512_loadu_si512(in + 6*in_width);
303 r7 = _mm512_loadu_si512(in + 7*in_width);
304 r8 = _mm512_loadu_si512(in + 8*in_width);
305 r9 = _mm512_loadu_si512(in + 9*in_width);
306 ra = _mm512_loadu_si512(in + 10*in_width);
307 rb = _mm512_loadu_si512(in + 11*in_width);
308 rc = _mm512_loadu_si512(in + 12*in_width);
309 } else if (col == 12) {
310 r0 = _mm512_loadu_si512(in + 0*in_width);
311 r1 = _mm512_loadu_si512(in + 1*in_width);
312 r2 = _mm512_loadu_si512(in + 2*in_width);
313 r3 = _mm512_loadu_si512(in + 3*in_width);
314 r4 = _mm512_loadu_si512(in + 4*in_width);
315 r5 = _mm512_loadu_si512(in + 5*in_width);
316 r6 = _mm512_loadu_si512(in + 6*in_width);
317 r7 = _mm512_loadu_si512(in + 7*in_width);
318 r8 = _mm512_loadu_si512(in + 8*in_width);
319 r9 = _mm512_loadu_si512(in + 9*in_width);
320 ra = _mm512_loadu_si512(in + 10*in_width);
321 rb = _mm512_loadu_si512(in + 11*in_width);
322 } else if (col == 11) {
323 r0 = _mm512_loadu_si512(in + 0*in_width);
324 r1 = _mm512_loadu_si512(in + 1*in_width);
325 r2 = _mm512_loadu_si512(in + 2*in_width);
326 r3 = _mm512_loadu_si512(in + 3*in_width);
327 r4 = _mm512_loadu_si512(in + 4*in_width);
328 r5 = _mm512_loadu_si512(in + 5*in_width);
329 r6 = _mm512_loadu_si512(in + 6*in_width);
330 r7 = _mm512_loadu_si512(in + 7*in_width);
331 r8 = _mm512_loadu_si512(in + 8*in_width);
332 r9 = _mm512_loadu_si512(in + 9*in_width);
333 ra = _mm512_loadu_si512(in + 10*in_width);
334 } else if (col == 10) {
335 r0 = _mm512_loadu_si512(in + 0*in_width);
336 r1 = _mm512_loadu_si512(in + 1*in_width);
337 r2 = _mm512_loadu_si512(in + 2*in_width);
338 r3 = _mm512_loadu_si512(in + 3*in_width);
339 r4 = _mm512_loadu_si512(in + 4*in_width);
340 r5 = _mm512_loadu_si512(in + 5*in_width);
341 r6 = _mm512_loadu_si512(in + 6*in_width);
342 r7 = _mm512_loadu_si512(in + 7*in_width);
343 r8 = _mm512_loadu_si512(in + 8*in_width);
344 r9 = _mm512_loadu_si512(in + 9*in_width);
345 } else if (col == 9) {
346 r0 = _mm512_loadu_si512(in + 0*in_width);
347 r1 = _mm512_loadu_si512(in + 1*in_width);
348 r2 = _mm512_loadu_si512(in + 2*in_width);
349 r3 = _mm512_loadu_si512(in + 3*in_width);
350 r4 = _mm512_loadu_si512(in + 4*in_width);
351 r5 = _mm512_loadu_si512(in + 5*in_width);
352 r6 = _mm512_loadu_si512(in + 6*in_width);
353 r7 = _mm512_loadu_si512(in + 7*in_width);
354 r8 = _mm512_loadu_si512(in + 8*in_width);
355 } else if (col == 8) {
356 r0 = _mm512_loadu_si512(in + 0*in_width);
357 r1 = _mm512_loadu_si512(in + 1*in_width);
358 r2 = _mm512_loadu_si512(in + 2*in_width);
359 r3 = _mm512_loadu_si512(in + 3*in_width);
360 r4 = _mm512_loadu_si512(in + 4*in_width);
361 r5 = _mm512_loadu_si512(in + 5*in_width);
362 r6 = _mm512_loadu_si512(in + 6*in_width);
363 r7 = _mm512_loadu_si512(in + 7*in_width);
364 } else if (col == 7) {
365 r0 = _mm512_loadu_si512(in + 0*in_width);
366 r1 = _mm512_loadu_si512(in + 1*in_width);
367 r2 = _mm512_loadu_si512(in + 2*in_width);
368 r3 = _mm512_loadu_si512(in + 3*in_width);
369 r4 = _mm512_loadu_si512(in + 4*in_width);
370 r5 = _mm512_loadu_si512(in + 5*in_width);
371 r6 = _mm512_loadu_si512(in + 6*in_width);
372 } else if (col == 6) {
373 r0 = _mm512_loadu_si512(in + 0*in_width);
374 r1 = _mm512_loadu_si512(in + 1*in_width);
375 r2 = _mm512_loadu_si512(in + 2*in_width);
376 r3 = _mm512_loadu_si512(in + 3*in_width);
377 r4 = _mm512_loadu_si512(in + 4*in_width);
378 r5 = _mm512_loadu_si512(in + 5*in_width);
379 } else if (col == 5) {
380 r0 = _mm512_loadu_si512(in + 0*in_width);
381 r1 = _mm512_loadu_si512(in + 1*in_width);
382 r2 = _mm512_loadu_si512(in + 2*in_width);
383 r3 = _mm512_loadu_si512(in + 3*in_width);
384 r4 = _mm512_loadu_si512(in + 4*in_width);
385 } else if (col == 4) {
386 r0 = _mm512_loadu_si512(in + 0*in_width);
387 r1 = _mm512_loadu_si512(in + 1*in_width);
388 r2 = _mm512_loadu_si512(in + 2*in_width);
389 r3 = _mm512_loadu_si512(in + 3*in_width);
390 } else if (col == 3) {
391 r0 = _mm512_loadu_si512(in + 0*in_width);
392 r1 = _mm512_loadu_si512(in + 1*in_width);
393 r2 = _mm512_loadu_si512(in + 2*in_width);
394 } else if (col == 2) {
395 r0 = _mm512_loadu_si512(in + 0*in_width);
396 r1 = _mm512_loadu_si512(in + 1*in_width);
397 } else if (col == 1) {
398 r0 = _mm512_loadu_si512(in + 0*in_width);
399 }
400
401 t0 = _mm512_unpacklo_epi16(r0,r1);
402 t1 = _mm512_unpackhi_epi16(r0,r1);
403 t2 = _mm512_unpacklo_epi16(r2,r3);
404 t3 = _mm512_unpackhi_epi16(r2,r3);
405 t4 = _mm512_unpacklo_epi16(r4,r5);
406 t5 = _mm512_unpackhi_epi16(r4,r5);
407 t6 = _mm512_unpacklo_epi16(r6,r7);
408 t7 = _mm512_unpackhi_epi16(r6,r7);
409 t8 = _mm512_unpacklo_epi16(r8,r9);
410 t9 = _mm512_unpackhi_epi16(r8,r9);
411 ta = _mm512_unpacklo_epi16(ra,rb);
412 tb = _mm512_unpackhi_epi16(ra,rb);
413 tc = _mm512_unpacklo_epi16(rc,rd);
414 td = _mm512_unpackhi_epi16(rc,rd);
415 te = _mm512_unpacklo_epi16(re,rf);
416 tf = _mm512_unpackhi_epi16(re,rf);
417
418 r0 = _mm512_unpacklo_epi32(t0,t2);
419 r1 = _mm512_unpackhi_epi32(t0,t2);
420 r2 = _mm512_unpacklo_epi32(t1,t3);
421 r3 = _mm512_unpackhi_epi32(t1,t3);
422 r4 = _mm512_unpacklo_epi32(t4,t6);
423 r5 = _mm512_unpackhi_epi32(t4,t6);
424 r6 = _mm512_unpacklo_epi32(t5,t7);
425 r7 = _mm512_unpackhi_epi32(t5,t7);
426 r8 = _mm512_unpacklo_epi32(t8,ta);
427 r9 = _mm512_unpackhi_epi32(t8,ta);
428 ra = _mm512_unpacklo_epi32(t9,tb);
429 rb = _mm512_unpackhi_epi32(t9,tb);
430 rc = _mm512_unpacklo_epi32(tc,te);
431 rd = _mm512_unpackhi_epi32(tc,te);
432 re = _mm512_unpacklo_epi32(td,tf);
433 rf = _mm512_unpackhi_epi32(td,tf);
434
435 t0 = _mm512_unpacklo_epi64(r0,r4);
436 t1 = _mm512_unpackhi_epi64(r0,r4);
437 t2 = _mm512_unpacklo_epi64(r1,r5);
438 t3 = _mm512_unpackhi_epi64(r1,r5);
439 t4 = _mm512_unpacklo_epi64(r2,r6);
440 t5 = _mm512_unpackhi_epi64(r2,r6);
441 t6 = _mm512_unpacklo_epi64(r3,r7);
442 t7 = _mm512_unpackhi_epi64(r3,r7);
443 t8 = _mm512_unpacklo_epi64(r8,rc);
444 t9 = _mm512_unpackhi_epi64(r8,rc);
445 ta = _mm512_unpacklo_epi64(r9,rd);
446 tb = _mm512_unpackhi_epi64(r9,rd);
447 tc = _mm512_unpacklo_epi64(ra,re);
448 td = _mm512_unpackhi_epi64(ra,re);
449 te = _mm512_unpacklo_epi64(rb,rf);
450 tf = _mm512_unpackhi_epi64(rb,rf);
451
452 r0 = _mm512_shuffle_i32x4(t0, t1, 0x88);
453 r1 = _mm512_shuffle_i32x4(t2, t3, 0x88);
454 r2 = _mm512_shuffle_i32x4(t4, t5, 0x88);
455 r3 = _mm512_shuffle_i32x4(t6, t7, 0x88);
456 r4 = _mm512_shuffle_i32x4(t0, t1, 0xdd);
457 r5 = _mm512_shuffle_i32x4(t2, t3, 0xdd);
458 r6 = _mm512_shuffle_i32x4(t4, t5, 0xdd);
459 r7 = _mm512_shuffle_i32x4(t6, t7, 0xdd);
460 r8 = _mm512_shuffle_i32x4(t8, t9, 0x88);
461 r9 = _mm512_shuffle_i32x4(ta, tb, 0x88);
462 ra = _mm512_shuffle_i32x4(tc, td, 0x88);
463 rb = _mm512_shuffle_i32x4(te, tf, 0x88);
464 rc = _mm512_shuffle_i32x4(t8, t9, 0xdd);
465 rd = _mm512_shuffle_i32x4(ta, tb, 0xdd);
466 re = _mm512_shuffle_i32x4(tc, td, 0xdd);
467 rf = _mm512_shuffle_i32x4(te, tf, 0xdd);
468
469 t0 = _mm512_permutex2var_epi64(r0, idx_lo, r8);
470 t1 = _mm512_permutex2var_epi64(r1, idx_lo, r9);
471 t2 = _mm512_permutex2var_epi64(r2, idx_lo, ra);
472 t3 = _mm512_permutex2var_epi64(r3, idx_lo, rb);
473 t4 = _mm512_permutex2var_epi64(r4, idx_lo, rc);
474 t5 = _mm512_permutex2var_epi64(r5, idx_lo, rd);
475 t6 = _mm512_permutex2var_epi64(r6, idx_lo, re);
476 t7 = _mm512_permutex2var_epi64(r7, idx_lo, rf);
477 t8 = _mm512_permutex2var_epi64(r8, idx_hi, r0);
478 t9 = _mm512_permutex2var_epi64(r9, idx_hi, r1);
479 ta = _mm512_permutex2var_epi64(ra, idx_hi, r2);
480 tb = _mm512_permutex2var_epi64(rb, idx_hi, r3);
481 tc = _mm512_permutex2var_epi64(rc, idx_hi, r4);
482 td = _mm512_permutex2var_epi64(rd, idx_hi, r5);
483 te = _mm512_permutex2var_epi64(re, idx_hi, r6);
484 tf = _mm512_permutex2var_epi64(rf, idx_hi, r7);
485
486 _mm256_mask_storeu_epi16(out + 0*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 0));
487 _mm256_mask_storeu_epi16(out + 1*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t0, 1));
488 _mm256_mask_storeu_epi16(out + 2*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 0));
489 _mm256_mask_storeu_epi16(out + 3*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t1, 1));
490 _mm256_mask_storeu_epi16(out + 4*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 0));
491 _mm256_mask_storeu_epi16(out + 5*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t2, 1));
492 _mm256_mask_storeu_epi16(out + 6*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 0));
493 _mm256_mask_storeu_epi16(out + 7*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t3, 1));
494 _mm256_mask_storeu_epi16(out + 8*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 0));
495 _mm256_mask_storeu_epi16(out + 9*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t4, 1));
496 _mm256_mask_storeu_epi16(out + 10*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 0));
497 _mm256_mask_storeu_epi16(out + 11*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t5, 1));
498 _mm256_mask_storeu_epi16(out + 12*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 0));
499 _mm256_mask_storeu_epi16(out + 13*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t6, 1));
500 _mm256_mask_storeu_epi16(out + 14*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 0));
501 _mm256_mask_storeu_epi16(out + 15*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t7, 1));
502 _mm256_mask_storeu_epi16(out + 16*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 0));
503 _mm256_mask_storeu_epi16(out + 17*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t8, 1));
504 _mm256_mask_storeu_epi16(out + 18*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 0));
505 _mm256_mask_storeu_epi16(out + 19*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(t9, 1));
506 _mm256_mask_storeu_epi16(out + 20*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 0));
507 _mm256_mask_storeu_epi16(out + 21*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(ta, 1));
508 _mm256_mask_storeu_epi16(out + 22*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 0));
509 _mm256_mask_storeu_epi16(out + 23*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tb, 1));
510 _mm256_mask_storeu_epi16(out + 24*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 0));
511 _mm256_mask_storeu_epi16(out + 25*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tc, 1));
512 _mm256_mask_storeu_epi16(out + 26*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 0));
513 _mm256_mask_storeu_epi16(out + 27*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(td, 1));
514 _mm256_mask_storeu_epi16(out + 28*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 0));
515 _mm256_mask_storeu_epi16(out + 29*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(te, 1));
516 _mm256_mask_storeu_epi16(out + 30*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 0));
517 _mm256_mask_storeu_epi16(out + 31*out_width, store_mask, LIBXSMM_INTRINSICS_MM512_EXTRACTI64X4_EPI64(tf, 1));
518 #else
519 LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out); LIBXSMM_UNUSED(col);
520 #endif
521 }
522
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)523 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
524 void bf16_transpose(libxsmm_bfloat16 *in, libxsmm_bfloat16 *out, int M, int N, int ld_in, int ld_out){
525 #if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
526 int i, j;
527 int full16_chunks = N/16;
528 int remainder_cols = N%16;
529 int _N = N - remainder_cols;
530
531 if (full16_chunks) {
532 for (i=0; i<M; i+=32) {
533 for (j=0; j<_N; j+=16) {
534 bf16_transpose_32x16((libxsmm_bfloat16*)in + i + ld_in*j, (libxsmm_bfloat16*)out + j + i*ld_out, ld_in, ld_out);
535 }
536 }
537 }
538
539 if (remainder_cols) {
540 for (i=0; i<M; i+=32) {
541 bf16_transpose_32xcols((libxsmm_bfloat16*)in + i + ld_in*full16_chunks*16, (libxsmm_bfloat16*)out + full16_chunks*16 + i*ld_out, remainder_cols, ld_in, ld_out);
542 }
543 }
544 #else
545 LIBXSMM_UNUSED(in); LIBXSMM_UNUSED(out); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
546 #endif
547 }
548
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)549 LIBXSMM_API_INLINE LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
550 void bf16_vnni_reformat(libxsmm_bfloat16 *_in, libxsmm_bfloat16 *_out, int M, int N, int ld_in, int ld_out) {
551 #if defined(LIBXSMM_INTRINSICS_AVX512_CORE)
552 int n_full_pairs = N/2, n_pair, m;
553 int half_n_pair = N%2;
554 libxsmm_bfloat16 *in = _in, *out = _out;
555 const __m512i selector = LIBXSMM_INTRINSICS_MM512_SET_EPI16(32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0, 32, 0);
556 const __m512i offsets_lo = LIBXSMM_INTRINSICS_MM512_SET_EPI16(15, 15, 14, 14, 13, 13, 12, 12, 11, 11, 10, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 4, 4, 3, 3, 2, 2, 1, 1, 0, 0);
557 const __m512i offsets_hi = LIBXSMM_INTRINSICS_MM512_SET_EPI16(31, 31, 30, 30, 29, 29, 28, 28, 27, 27, 26, 26, 25, 25, 24, 24, 23, 23, 22, 22, 21, 21, 20, 20, 19, 19, 18, 18, 17, 17, 16, 16);
558 const __m512i idx_lo = _mm512_or_epi32(selector, offsets_lo);
559 const __m512i idx_hi = _mm512_or_epi32(selector, offsets_hi);
560 const __m512i zero_reg = _mm512_setzero_si512();
561 __m512i n0, n1, out_lo, out_hi;
562 LIBXSMM_UNUSED(ld_out);
563 for (n_pair = 0; n_pair < n_full_pairs; n_pair++) {
564 for (m = 0; m < M; m+=32) {
565 n0 = _mm512_loadu_si512((const libxsmm_bfloat16*)in+m);
566 n1 = _mm512_loadu_si512((const libxsmm_bfloat16*)in+m+ld_in);
567 out_lo = _mm512_permutex2var_epi16(n0, idx_lo, n1);
568 out_hi = _mm512_permutex2var_epi16(n0, idx_hi, n1);
569 _mm512_storeu_si512((libxsmm_bfloat16*)out+m*2, out_lo);
570 _mm512_storeu_si512((libxsmm_bfloat16*)out+m*2+32, out_hi);
571 }
572 in += 2*ld_in;
573 out += 2*ld_in;
574 }
575 if (half_n_pair == 1) {
576 for (m = 0; m < M; m+=32) {
577 n0 = _mm512_loadu_si512((const libxsmm_bfloat16*)in+m);
578 n1 = zero_reg;
579 out_lo = _mm512_permutex2var_epi16(n0, idx_lo, n1);
580 out_hi = _mm512_permutex2var_epi16(n0, idx_lo, n1);
581 _mm512_storeu_si512((libxsmm_bfloat16*)out+m*2, out_lo);
582 _mm512_storeu_si512((libxsmm_bfloat16*)out+m*2+32, out_hi);
583 }
584 }
585 #else
586 LIBXSMM_UNUSED(_in); LIBXSMM_UNUSED(_out); LIBXSMM_UNUSED(M); LIBXSMM_UNUSED(N); LIBXSMM_UNUSED(ld_in); LIBXSMM_UNUSED(ld_out);
587 #endif
588 }
589
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
/* FP32 backward-data / weight-update pass for a fully-connected layer in
 * custom (blocked) format. Dispatches two SGEMM kernels -- one for the
 * backward pass (d_input = W^T * d_output) and one for the update pass
 * (d_W = d_output * input^T) -- then runs the shared template, which
 * consumes the local typedefs and variables below BY NAME; do not rename them.
 * start_thread/tid select this thread's share of the work inside the template.
 * Returns LIBXSMM_DNN_SUCCESS, or an error code for unsupported fusion/arch. */
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  /* the template is type-generic; these typedefs fix it to FP32 */
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  /* leading dimensions for the bwd (d_input) GEMM */
  libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock;
  libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K;
  libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C;
  /* leading dimensions for the upd (d_filter) GEMM */
  libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K;
  libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N;
  libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock;
  element_input_type alpha = (element_input_type)1;
  element_input_type beta = (element_input_type)0;  /* C is overwritten, not accumulated */

  if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
    typedef libxsmm_smmfunction gemm_function;
    gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL);
    gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL);
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c"
  } else {
    status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
  }
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
621
LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)622 LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
623 libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom_bf16_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
624 {
625 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
626 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
627 typedef libxsmm_bfloat16 element_input_type;
628 typedef float element_output_type;
629 typedef libxsmm_bfloat16 element_filter_type;
630 typedef libxsmm_smmfunction gemm_function;
631 libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock;
632 libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K;
633 libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C;
634 libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K;
635 libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N;
636 libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock;
637 float alpha = (element_input_type)1;
638 float beta = (element_input_type)0;
639
640 if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
641 gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL);
642 gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL);
643 # define LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32
644 # define LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32
645 # include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c"
646 # undef LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32
647 # undef LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32
648 } else {
649 status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
650 }
651 #else /* should not happen */
652 LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
653 status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
654 #endif
655 return status;
656 }
657
658
/* F32 backward/weight-update for the blocked NCNC (activations) / KCCK
 * (filters) format. The strided batch-reduce GEMM kernels are taken
 * ready-made from the handle (regular and zero-beta variants, per their
 * naming); this routine only selects the template specialization that
 * matches the requested fusion. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_f32_f32(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  typedef float element_input_type;
  typedef float element_output_type;
  typedef float element_filter_type;
  libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.smrs;
  libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.smrs;
  libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.smrs;
  libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.smrs;

/* enable the AVX512-specific code paths inside the included template */
#define LIBXSMM_DNN_FC_BWD_USE_AVX512
  if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else {
    status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
  }
#undef LIBXSMM_DNN_FC_BWD_USE_AVX512
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
709
/* BF16 backward/weight-update for the blocked NCNC/KCCK format on
 * AVX512_CORE targets: BF16<->FP32 conversions are handled by the portable
 * emulation macros (no native AVX512-BF16 instructions; contrast the CPX
 * variant below, which defines LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI).
 * Batch-reduce kernels come ready-made from the handle. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* strided batch-reduce kernels (regular and zero-beta variants) */
  libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.bsmrs;
  libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.bmrs;
  libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.bsmrs;
  libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.bmrs;

  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"

  if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else {
    status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
  }

# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
763
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX)
/* BF16 backward/weight-update for the blocked NCNC/KCCK format using native
 * AVX512-BF16 (CPX) conversion instructions, selected via
 * LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI. Structure mirrors the _emu variant
 * above; only the conversion macros differ. */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CPX)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
  typedef libxsmm_bfloat16 element_input_type;
  typedef libxsmm_bfloat16 element_output_type;
  typedef libxsmm_bfloat16 element_filter_type;
  /* strided batch-reduce kernels (regular and zero-beta variants) */
  libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.bsmrs;
  libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.bmrs;
  libxsmm_bsmmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.bsmrs;
  libxsmm_bmmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.bmrs;

/* select the native AVX512-BF16 conversion path inside the macros/template */
#define LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
  /* some portable macros for BF16 <-> FP32 */
# include "template/libxsmm_dnn_bf16_macros_define.tpl.c"

  if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic_bf16.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
  } else {
    status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
  }

# include "template/libxsmm_dnn_bf16_macros_undefine.tpl.c"
#undef LIBXSMM_DNN_BF16_USE_CPX_AVX512_NI
#else /* should not happen */
  LIBXSMM_UNUSED(handle); LIBXSMM_UNUSED(kind); LIBXSMM_UNUSED(start_thread); LIBXSMM_UNUSED(tid);
  status = LIBXSMM_DNN_ERR_UNSUPPORTED_ARCH;
#endif
  return status;
}
#else
/* compiler lacks AVX512-BF16 support: forward to the emulation path so the
 * symbol stays available regardless of build flags */
LIBXSMM_API_INTERN LIBXSMM_INTRINSICS(LIBXSMM_X86_AVX512_CORE)
libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  return libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid );
}
#endif
827
libxsmm_dnn_fullyconnected_st_bwdupd_custom(libxsmm_dnn_fullyconnected * handle,libxsmm_dnn_compute_kind kind,int start_thread,int tid)828 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_custom(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
829 {
830 libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;
831
832 /* check if all required tensors are bound */
833 if ( kind == LIBXSMM_DNN_COMPUTE_KIND_BWD ) {
834 if (handle->grad_input == 0 || handle->grad_output == 0 ||
835 handle->reg_filter == 0 || handle->scratch == 0 ) {
836 status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
837 return status;
838 }
839 } else if ( kind == LIBXSMM_DNN_COMPUTE_KIND_UPD ) {
840 if (handle->reg_input == 0 || handle->grad_output == 0 ||
841 handle->grad_filter == 0 || handle->scratch == 0 ) {
842 status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
843 return status;
844 }
845 } else {
846 if (handle->grad_input == 0 || handle->grad_output == 0 ||
847 handle->reg_input == 0 || handle->grad_filter == 0 ||
848 handle->reg_filter == 0 || handle->scratch == 0 ) {
849 status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
850 return status;
851 }
852 }
853
854 /* check if we are on an AVX512 platform */
855 #if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
856 if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
857 if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
858 status = libxsmm_dnn_fullyconnected_st_bwdupd_custom_f32_f32( handle, kind, start_thread, tid);
859 }
860 #if defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__*/
861 else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
862 status = libxsmm_dnn_fullyconnected_st_bwdupd_custom_bf16_f32( handle, kind, start_thread, tid);
863 }
864 #endif
865 else {
866 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
867 return status;
868 }
869 } else
870 #endif
871 {
872 if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
873 typedef float element_input_type;
874 typedef float element_output_type;
875 typedef float element_filter_type;
876 typedef libxsmm_smmfunction gemm_function;
877 libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock;
878 libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K;
879 libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C;
880 libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K;
881 libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N;
882 libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock;
883 element_input_type alpha = (element_input_type)1;
884 element_input_type beta = (element_input_type)0;
885
886 if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
887 gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL);
888 gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL);
889 # include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c"
890 } else {
891 status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
892 }
893 } else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
894 typedef libxsmm_bfloat16 element_input_type;
895 typedef float element_output_type;
896 typedef libxsmm_bfloat16 element_filter_type;
897 typedef libxsmm_smmfunction gemm_function;
898 libxsmm_blasint lda_bwd = (libxsmm_blasint)handle->ifmblock;
899 libxsmm_blasint ldb_bwd = (libxsmm_blasint)handle->desc.K;
900 libxsmm_blasint ldc_bwd = (libxsmm_blasint)handle->desc.C;
901 libxsmm_blasint lda_upd = (libxsmm_blasint)handle->desc.K;
902 libxsmm_blasint ldb_upd = (libxsmm_blasint)handle->desc.N;
903 libxsmm_blasint ldc_upd = (libxsmm_blasint)handle->ofmblock;
904 float alpha = (element_input_type)1;
905 float beta = (element_input_type)0;
906
907 if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
908 gemm_function gemm_kernel_bwd = libxsmm_smmdispatch(handle->ifmblock, handle->desc.N, handle->desc.K, &lda_bwd, &ldb_bwd, &ldc_bwd, &alpha, &beta, NULL, NULL);
909 gemm_function gemm_kernel_upd = libxsmm_smmdispatch(handle->ofmblock, handle->ifmblock, handle->desc.N, &lda_upd, &ldb_upd, &ldc_upd, &alpha, &beta, NULL, NULL);
910 # define LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32
911 # define LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32
912 # include "template/libxsmm_dnn_fullyconnected_st_bwdupd_custom_generic.tpl.c"
913 # undef LIBXSMM_DNN_FULLYCONNECTED_UPD_BF16_F32
914 # undef LIBXSMM_DNN_FULLYCONNECTED_BWD_BF16_F32
915 } else {
916 status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
917 }
918 } else {
919 status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
920 return status;
921 }
922 }
923
924 return status;
925 }
926
927
/* Dispatcher for the blocked NCNC/KCCK backward/weight-update pass:
 * validates tensor bindings (including grad_bias when bias fusion is
 * requested and the relu mask when relu fusion is requested), then selects
 * an architecture/datatype specific implementation. BF16 is only available
 * on AVX512_CORE and newer (emu) or AVX512-BF16 (native); the generic
 * fallback supports F32 only. */
LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
{
  libxsmm_dnn_err_t status = LIBXSMM_DNN_SUCCESS;

  /* check if all required tensors are bound */
  if ( kind == LIBXSMM_DNN_COMPUTE_KIND_BWD ) {
    if (handle->grad_input == 0 || handle->grad_output == 0 ||
        handle->reg_filter == 0 || handle->scratch == 0 ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  } else if ( kind == LIBXSMM_DNN_COMPUTE_KIND_UPD ) {
    if (handle->reg_input == 0 || handle->grad_output == 0 ||
        handle->grad_filter == 0 || handle->scratch == 0 ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  } else {
    /* combined BWD+UPD requires the union of both tensor sets */
    if (handle->grad_input == 0 || handle->grad_output == 0 ||
        handle->reg_input == 0 || handle->grad_filter == 0 ||
        handle->reg_filter == 0 || handle->scratch == 0 ) {
      status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
      return status;
    }
  }

  /* fused bias needs a bound bias-gradient tensor */
  if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) != 0) && ( handle->grad_bias == 0 ) ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }
  /* fused relu needs the forward-pass relu mask */
  if ( ((handle->desc.fuse_ops & LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) != 0) && ( handle->relumask == 0 ) ) {
    status = LIBXSMM_DNN_ERR_DATA_NOT_BOUND;
    return status;
  }

  /* check if we are on an AVX512 platform */
#if defined(LIBXSMM_INTRINSICS_AVX512) /*__AVX512F__*/
  if ( libxsmm_target_archid >= LIBXSMM_X86_AVX512 ) {
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_f32_f32( handle, kind, start_thread, tid);
    }
#if defined(LIBXSMM_INTRINSICS_AVX512_CPX) /*__AVX512F__,__AVX512BW__,__AVX512DQ__,__AVX512BF16__*/
    /* emulation on [CORE, CPX), native BF16 instructions from CPX onward */
    else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE && libxsmm_target_archid < LIBXSMM_X86_AVX512_CPX) {
      status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid);
    } else if ( handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CPX ) {
      status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16( handle, kind, start_thread, tid);
    }
#elif defined(LIBXSMM_INTRINSICS_AVX512_CORE) /*__AVX512F__,__AVX512BW__,__AVX512DQ__*/
    /* compiler lacks AVX512-BF16: always use the emulated conversions */
    else if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_BF16 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_BF16 && libxsmm_target_archid >= LIBXSMM_X86_AVX512_CORE ) {
      status = libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_bf16_bf16_emu( handle, kind, start_thread, tid);
    }
#endif
    else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  } else
#endif
  {
    /* generic (non-AVX512) fallback: F32 only, via the shared template */
    if (handle->desc.datatype_in == LIBXSMM_DNN_DATATYPE_F32 && handle->desc.datatype_out == LIBXSMM_DNN_DATATYPE_F32 ) {
      typedef float element_input_type;
      typedef float element_output_type;
      typedef float element_filter_type;
      libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd = handle->gemm_bwd.xgemm.smrs;
      libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_bwd_zerobeta = handle->gemm_bwd2.xgemm.smrs;
      libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd = handle->gemm_upd.xgemm.smrs;
      libxsmm_smmfunction_reducebatch_strd batchreduce_kernel_upd_zerobeta = handle->gemm_upd2.xgemm.smrs;

      if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_NONE ) {
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_RELU ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_RELU
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_RELU
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
      } else if ( handle->desc.fuse_ops == LIBXSMM_DNN_FULLYCONNECTED_FUSE_BIAS_SIGMOID ) {
#define LIBXSMM_DNN_FC_BWD_FUSE_BIAS
#define LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
# include "template/libxsmm_dnn_fullyconnected_st_bwdupd_ncnc_kcck_generic.tpl.c"
#undef LIBXSMM_DNN_FC_BWD_FUSE_SIGMOID
#undef LIBXSMM_DNN_FC_BWD_FUSE_BIAS
      } else {
        status = LIBXSMM_DNN_ERR_FC_UNSUPPORTED_FUSION;
      }
    } else {
      status = LIBXSMM_DNN_ERR_UNSUPPORTED_DATATYPE;
      return status;
    }
  }

  return status;
}
1033
1034
libxsmm_dnn_fullyconnected_st_bwdupd_nhwc(libxsmm_dnn_fullyconnected * handle,libxsmm_dnn_compute_kind kind,int start_thread,int tid)1035 LIBXSMM_API_INTERN libxsmm_dnn_err_t libxsmm_dnn_fullyconnected_st_bwdupd_nhwc(libxsmm_dnn_fullyconnected* handle, libxsmm_dnn_compute_kind kind, int start_thread, int tid)
1036 {
1037 libxsmm_dnn_err_t status = LIBXSMM_DNN_ERR_NOT_IMPLEMENTED;
1038 LIBXSMM_UNUSED( handle );
1039 LIBXSMM_UNUSED( kind );
1040 LIBXSMM_UNUSED( start_thread );
1041 LIBXSMM_UNUSED( tid );
1042 return status;
1043 }
1044
1045