/***************************************************************************
Copyright (c) 2021, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/

#include <immintrin.h>
#include "common.h"
#include <stdio.h>

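/*
 * AVX-512 small-matrix SGEMM kernel (column-major C). With B0 defined
 * this is the beta == 0 variant (no beta argument, plain stores into C);
 * otherwise beta * C is folded into every store.
 *
 * Building blocks: each accumulator result##M##N covers one C row
 * (i + M) across 16 consecutive columns. Aval##M broadcasts a single
 * element of A to all 16 lanes, Bval##N loads 16 elements of B (masked
 * at the column tail), and MATMUL_512 combines them with one FMA.
 */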
#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps()
#define BROADCAST_LOAD_A_512(M, N) __m512 Aval##M = _mm512_broadcastss_ps(_mm_load_ss(&A[k  + lda * (i+M)]))
#define LOAD_B_512(M,N)  __m512 Bval##N = _mm512_loadu_ps(&B[ldb * k + j + (N*16)])
#define MASK_LOAD_B_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[ldb * k + j + (N*16)])
#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N)

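/*
 * Store paths. B0: scale by alpha and store/scatter straight into C.
 * Otherwise: additionally fold in beta * C, either via an inline-asm
 * vfmadd231ps against the memory operand (contiguous 8- and 4-wide
 * stores) or via a gather + FMA on the scatter path.
 */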
#if defined(B0)
#define STORE_8xy(v, N, x, y) _mm256_storeu_ps(&C[(j + N*16 + x + y*8)*ldc + i], v)
#define STORE_4xy(v, N, x, y) _mm_mask_storeu_ps(&C[(j + N*16 + x + y*4)*ldc + i], mask8, v)
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
				_mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4);
#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
				    _mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4);
#else
#define STORE_8xy(v, N, x, y) \
	asm("vfmadd231ps (%1), %2, %0": "+v"(v): "r"(&C[(j + N*16 + x + y*8)*ldc + i]), "v"(beta_256)); \
	_mm256_storeu_ps(&C[(j + N*16 + x + y*8)*ldc + i], v)
#define STORE_4xy(v, N, x, y) \
	asm("vfmadd231ps (%1), %2, %0": "+v"(v): "r"(&C[(j + N*16 + x + y*4)*ldc + i]), "v"(beta_128)); \
	_mm_mask_storeu_ps(&C[(j + N*16 + x + y*4)*ldc + i], mask8, v)
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
				__m512 tmp##M##N = _mm512_i32gather_ps(vindex_n, &C[(j + N*16)*ldc + i + M], 4); \
				result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \
				_mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4);
#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
				__m512 tmp##M##N = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, vindex_n, &C[(j + N*16)*ldc + i + M], 4); \
				result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \
				_mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4);
#endif

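/*
 * REORDER_8x16 is an in-register transpose of the 8x16 accumulator tile
 * from register order (one C row per register) to store order (one
 * 8-element column chunk per 256-bit half), scaled by alpha at the end:
 * unpacklo/unpackhi interleave row pairs, shuffle + masked blends
 * (kc = 0xcccc, k3 = 0x3333) finish the 4x4 sub-blocks within 128-bit
 * lanes, and permutex2var merges 128-bit quarters across register pairs.
 */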
#define REORDER_8x16(r0, r1, r2, r3, r4, r5, r6, r7) \
	__m512 t0, t1, t2, t3, t4, t5, t6, t7, v; \
	t0 = _mm512_unpacklo_ps(r0, r1); \
	t1 = _mm512_unpackhi_ps(r0, r1); \
	t2 = _mm512_unpacklo_ps(r2, r3); \
	t3 = _mm512_unpackhi_ps(r2, r3); \
	t4 = _mm512_unpacklo_ps(r4, r5); \
	t5 = _mm512_unpackhi_ps(r4, r5); \
	t6 = _mm512_unpacklo_ps(r6, r7); \
	t7 = _mm512_unpackhi_ps(r6, r7); \
	v = _mm512_shuffle_ps(t0, t2, 0x4E);  \
	r0 = _mm512_mask_blend_ps(kc, t0, v); \
	r1 = _mm512_mask_blend_ps(k3, t2, v); \
	v = _mm512_shuffle_ps(t1, t3, 0x4E);  \
	r2 = _mm512_mask_blend_ps(kc, t1, v); \
	r3 = _mm512_mask_blend_ps(k3, t3, v); \
	v = _mm512_shuffle_ps(t4, t6, 0x4E);  \
	r4 = _mm512_mask_blend_ps(kc, t4, v); \
	r5 = _mm512_mask_blend_ps(k3, t6, v); \
	v = _mm512_shuffle_ps(t5, t7, 0x4E);  \
	r6 = _mm512_mask_blend_ps(kc, t5, v); \
	r7 = _mm512_mask_blend_ps(k3, t7, v); \
	t0 = _mm512_permutex2var_ps(r0, idx_lo, r4); \
	t1 = _mm512_permutex2var_ps(r1, idx_lo, r5); \
	t2 = _mm512_permutex2var_ps(r2, idx_lo, r6); \
	t3 = _mm512_permutex2var_ps(r3, idx_lo, r7); \
	t4 = _mm512_permutex2var_ps(r0, idx_hi, r4); \
	t5 = _mm512_permutex2var_ps(r1, idx_hi, r5); \
	t6 = _mm512_permutex2var_ps(r2, idx_hi, r6); \
	t7 = _mm512_permutex2var_ps(r3, idx_hi, r7); \
	t0 = _mm512_mul_ps(t0, alpha_512); \
	t1 = _mm512_mul_ps(t1, alpha_512); \
	t2 = _mm512_mul_ps(t2, alpha_512); \
	t3 = _mm512_mul_ps(t3, alpha_512); \
	t4 = _mm512_mul_ps(t4, alpha_512); \
	t5 = _mm512_mul_ps(t5, alpha_512); \
	t6 = _mm512_mul_ps(t6, alpha_512); \
	t7 = _mm512_mul_ps(t7, alpha_512);

#define SAVE_8(N, x, y) {\
	__m256 v8 = _mm512_extractf32x8_ps(t##x, y); \
	STORE_8xy(v8, N, x, y); \
}

#define REORDER_STORE_8x16(N) {\
	REORDER_8x16(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \
	SAVE_8(N, 0, 0); SAVE_8(N, 1, 0); SAVE_8(N, 2, 0); SAVE_8(N, 3, 0); SAVE_8(N, 4, 0); SAVE_8(N, 5, 0); SAVE_8(N, 6, 0); SAVE_8(N, 7, 0); \
	SAVE_8(N, 0, 1); SAVE_8(N, 1, 1); SAVE_8(N, 2, 1); SAVE_8(N, 3, 1); SAVE_8(N, 4, 1); SAVE_8(N, 5, 1); SAVE_8(N, 6, 1); SAVE_8(N, 7, 1); \
}

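/* Store only the nn valid columns; the switch deliberately falls
 * through, working back from the last column chunk to the first. */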
#define MASK_SAVE_8() \
	switch (nn) { \
		case 16: SAVE_8(0, 7, 1); \
		case 15: SAVE_8(0, 6, 1); \
		case 14: SAVE_8(0, 5, 1); \
		case 13: SAVE_8(0, 4, 1); \
		case 12: SAVE_8(0, 3, 1); \
		case 11: SAVE_8(0, 2, 1); \
		case 10: SAVE_8(0, 1, 1); \
		case 9: SAVE_8(0, 0, 1); \
		case 8: SAVE_8(0, 7, 0); \
		case 7: SAVE_8(0, 6, 0); \
		case 6: SAVE_8(0, 5, 0); \
		case 5: SAVE_8(0, 4, 0); \
		case 4: SAVE_8(0, 3, 0); \
		case 3: SAVE_8(0, 2, 0); \
		case 2: SAVE_8(0, 1, 0); \
		case 1: SAVE_8(0, 0, 0); \
	}

#define MASK_REORDER_STORE_8x16(N) {\
	REORDER_8x16(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \
	MASK_SAVE_8(); \
}

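/*
 * 4-row counterparts of the 8x16 path: REORDER_4x16 stops after the
 * blend stage (each t register then holds four 4-element column chunks,
 * extracted with extractf32x4), and MASK_SAVE_4 again relies on switch
 * fall-through for the column tail.
 */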
#define REORDER_4x16(r0, r1, r2, r3) \
	__m512 t0, t1, t2, t3, v; \
	t0 = _mm512_unpacklo_ps(r0, r1); \
	t1 = _mm512_unpackhi_ps(r0, r1); \
	t2 = _mm512_unpacklo_ps(r2, r3); \
	t3 = _mm512_unpackhi_ps(r2, r3); \
	v = _mm512_shuffle_ps(t0, t2, 0x4E);  \
	r0 = _mm512_mask_blend_ps(kc, t0, v); \
	r1 = _mm512_mask_blend_ps(k3, t2, v); \
	v = _mm512_shuffle_ps(t1, t3, 0x4E);  \
	r2 = _mm512_mask_blend_ps(kc, t1, v); \
	r3 = _mm512_mask_blend_ps(k3, t3, v); \
	t0 = _mm512_mul_ps(r0, alpha_512); \
	t1 = _mm512_mul_ps(r1, alpha_512); \
	t2 = _mm512_mul_ps(r2, alpha_512); \
	t3 = _mm512_mul_ps(r3, alpha_512);

#define SAVE_4(N, x, y) {\
	__m128 v4 = _mm512_extractf32x4_ps(t##x, y); \
	STORE_4xy(v4, N, x, y); \
}

#define REORDER_STORE_4x16(N) {\
	REORDER_4x16(result0##N, result1##N, result2##N, result3##N); \
	SAVE_4(N, 0, 0); SAVE_4(N, 1, 0); SAVE_4(N, 2, 0); SAVE_4(N, 3, 0); \
	SAVE_4(N, 0, 1); SAVE_4(N, 1, 1); SAVE_4(N, 2, 1); SAVE_4(N, 3, 1); \
	SAVE_4(N, 0, 2); SAVE_4(N, 1, 2); SAVE_4(N, 2, 2); SAVE_4(N, 3, 2); \
	SAVE_4(N, 0, 3); SAVE_4(N, 1, 3); SAVE_4(N, 2, 3); SAVE_4(N, 3, 3); \
}

#define MASK_SAVE_4() \
	switch (nn) { \
		case 16: SAVE_4(0, 3, 3); \
		case 15: SAVE_4(0, 2, 3); \
		case 14: SAVE_4(0, 1, 3); \
		case 13: SAVE_4(0, 0, 3); \
		case 12: SAVE_4(0, 3, 2); \
		case 11: SAVE_4(0, 2, 2); \
		case 10: SAVE_4(0, 1, 2); \
		case 9: SAVE_4(0, 0, 2); \
		case 8: SAVE_4(0, 3, 1); \
		case 7: SAVE_4(0, 2, 1); \
		case 6: SAVE_4(0, 1, 1); \
		case 5: SAVE_4(0, 0, 1); \
		case 4: SAVE_4(0, 3, 0); \
		case 3: SAVE_4(0, 2, 0); \
		case 2: SAVE_4(0, 1, 0); \
		case 1: SAVE_4(0, 0, 0); \
	}

#define MASK_REORDER_STORE_4x16(N) {\
	REORDER_4x16(result0##N, result1##N, result2##N, result3##N); \
	MASK_SAVE_4(); \
}


#if defined(B0)
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc)
#else
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
#endif
{
	// column major
	BLASLONG i, j, k;

	BLASLONG m8 = M & ~7;
	BLASLONG m4 = M & ~3;
	BLASLONG m2 = M & ~1;

	BLASLONG n64 = N & ~63;
	BLASLONG n32 = N & ~31;

	__m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha));
#if !defined(B0)
	__m256 beta_256 = _mm256_broadcastss_ps(_mm_load_ss(&beta));
	__m128 beta_128 = _mm_broadcastss_ps(_mm_load_ss(&beta));
#endif
	// permutex2var indices for the 8x16 transpose: idx_lo/idx_hi pick
	// alternating 128-bit quarters from the two source registers.
	int permute_table[] = {
		0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13, 0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b,
		0x4, 0x5, 0x6, 0x7, 0x14, 0x15, 0x16, 0x17, 0xc, 0xd, 0xe, 0xf, 0x1c, 0x1d, 0x1e, 0x1f,
	};
	__m512i idx_lo = _mm512_loadu_si512(permute_table);
	__m512i idx_hi = _mm512_loadu_si512(permute_table + 16);
	__mmask16 kc = 0xcccc;
	__mmask16 k3 = 0x3333;
	__mmask8 mask8 = 0xff;  // all lanes set: forces an AVX-128 masked store instead of an SSE store

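	/*
	 * Main tier: blocks of 8 rows. Columns are processed 32 at a time
	 * (two 16-wide accumulator sets), then a masked 16-wide tail.
	 */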
	for (i = 0; i < m8; i += 8) {
		for (j = 0; j < n32; j += 32) {
			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
			DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0);

			DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1);
			DECLARE_RESULT_512(4, 1); DECLARE_RESULT_512(5, 1); DECLARE_RESULT_512(6, 1); DECLARE_RESULT_512(7, 1);
			for (k = 0; k < K; k++) {
				BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x);
				BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x);
				LOAD_B_512(x, 0); LOAD_B_512(x, 1);
				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
				MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0);
				MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
				MATMUL_512(4, 1); MATMUL_512(5, 1); MATMUL_512(6, 1); MATMUL_512(7, 1);
			}
			REORDER_STORE_8x16(0);
			REORDER_STORE_8x16(1);
		}
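		/* Remaining columns in 16-wide chunks; the last chunk may be
		 * partial, covering nn columns via the lane mask. */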
		__mmask16 mask = 0xffff;
		int nn = 16;
		for (; j < N; j += 16) {
			if (N - j < 16) {
				nn = N - j;
				mask = (1UL << nn) - 1;
			}
			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
			DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0);
			for (k = 0; k < K; k++) {
				BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x);
				BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x);
				MASK_LOAD_B_512(x, 0);
				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
				MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0);
			}
			MASK_REORDER_STORE_8x16(0);
		}
	}
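	/* Second tier: blocks of 4 rows, with wider (64-column) blocking
	 * since only 4 accumulators are needed per 16 columns. */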
	for (; i < m4; i += 4) {
		for (j = 0; j < n64; j += 64) {
			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
			DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1);
			DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2);
			DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3);
			for (k = 0; k < K; k++) {
				BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x);
				LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3);
				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
				MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
				MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2);
				MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3);
			}
			REORDER_STORE_4x16(0);
			REORDER_STORE_4x16(1);
			REORDER_STORE_4x16(2);
			REORDER_STORE_4x16(3);
		}
		for (; j < n32; j += 32) {
			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
			DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1);
			for (k = 0; k < K; k++) {
				BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x);
				LOAD_B_512(x, 0); LOAD_B_512(x, 1);
				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
				MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
			}
			REORDER_STORE_4x16(0);
			REORDER_STORE_4x16(1);
		}
		__mmask16 mask = 0xffff;
		int nn = 16;
		for (; j < N; j += 16) {
			if (N - j < 16) {
				nn = N - j;
				mask = (1UL << nn) - 1;
			}
			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
			for (k = 0; k < K; k++) {
				BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x);
				MASK_LOAD_B_512(x, 0);
				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
			}
			MASK_REORDER_STORE_4x16(0);
		}
	}
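	/*
	 * Remaining 1-2 rows: too few to amortize the in-register transpose,
	 * so results are written with scatter stores through vindex_n
	 * (lane n maps to a column offset of n * ldc).
	 */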
	if (i < M) {
		int index_n[16];
		for (int ii = 0; ii < 16; ii++) {
			index_n[ii] = ii * ldc;
		}
		__m512i vindex_n = _mm512_loadu_si512(index_n);
#if !defined(B0)
		__m512 beta_512 = _mm512_broadcastss_ps(_mm_load_ss(&beta));
#endif
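		/* Two rows at a time, with 64/32/16-column blocking. */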
		for (; i < m2; i += 2) {
			for (j = 0; j < n64; j += 64) {
				DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
				DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1);
				DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2);
				DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3);
				for (k = 0; k < K; k++) {
					BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x);
					LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3);
					MATMUL_512(0, 0); MATMUL_512(1, 0);
					MATMUL_512(0, 1); MATMUL_512(1, 1);
					MATMUL_512(0, 2); MATMUL_512(1, 2);
					MATMUL_512(0, 3); MATMUL_512(1, 3);
				}
				SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0);
				SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1);
				SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2);
				SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3);
			}
			for (; j < n32; j += 32) {
				DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
				DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1);
				for (k = 0; k < K; k++) {
					BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x);
					LOAD_B_512(x, 0); LOAD_B_512(x, 1);
					MATMUL_512(0, 0); MATMUL_512(1, 0);
					MATMUL_512(0, 1); MATMUL_512(1, 1);
				}
				SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0);
				SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1);
			}
			__mmask16 mask = 0xffff;
			int nn = 16;
			for (; j < N; j += 16) {
				if (N - j < 16) {
					nn = N - j;
					mask = (1UL << nn) - 1;
				}
				DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
				for (k = 0; k < K; k++) {
					BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x);
					MASK_LOAD_B_512(x, 0);
					MATMUL_512(0, 0); MATMUL_512(1, 0);
				}
				MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0);
			}
		}
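		/* Final single row, same column blocking. */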
		for (; i < M; i += 1) {
			for (j = 0; j < n64; j += 64) {
				DECLARE_RESULT_512(0, 0);
				DECLARE_RESULT_512(0, 1);
				DECLARE_RESULT_512(0, 2);
				DECLARE_RESULT_512(0, 3);
				for (k = 0; k < K; k++) {
					BROADCAST_LOAD_A_512(0, x);
					LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3);
					MATMUL_512(0, 0);
					MATMUL_512(0, 1);
					MATMUL_512(0, 2);
					MATMUL_512(0, 3);
				}
				SCATTER_STORE_512(0, 0);
				SCATTER_STORE_512(0, 1);
				SCATTER_STORE_512(0, 2);
				SCATTER_STORE_512(0, 3);
			}
			for (; j < n32; j += 32) {
				DECLARE_RESULT_512(0, 0);
				DECLARE_RESULT_512(0, 1);
				for (k = 0; k < K; k++) {
					BROADCAST_LOAD_A_512(0, x);
					LOAD_B_512(x, 0); LOAD_B_512(x, 1);
					MATMUL_512(0, 0);
					MATMUL_512(0, 1);
				}
				SCATTER_STORE_512(0, 0);
				SCATTER_STORE_512(0, 1);
			}
			__mmask16 mask = 0xffff;
			int nn = 16;
			for (; j < N; j += 16) {
				if (N - j < 16) {
					nn = N - j;
					mask = (1UL << nn) - 1;
				}
				DECLARE_RESULT_512(0, 0);
				for (k = 0; k < K; k++) {
					BROADCAST_LOAD_A_512(0, x);
					MASK_LOAD_B_512(x, 0);
					MATMUL_512(0, 0);
				}
				MASK_SCATTER_STORE_512(0, 0);
			}
		}
	}
	return 0;
}