/***************************************************************************
Copyright (c) 2021, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
27
28 #include <immintrin.h>
29 #include "common.h"
30 #include <stdio.h>
31
/* Accumulator for one C tile: result<M><N> covers row i+M of 16-column group N. */
#define DECLARE_RESULT_512(M, N) __m512 result##M##N = _mm512_setzero_ps()
/* Broadcast the scalar A[k + lda*(i+M)] into all 16 lanes of Aval<M>
 * (K is the unit-stride index of A; lda strides the M index). */
#define BROADCAST_LOAD_A_512(M, N) __m512 Aval##M = _mm512_broadcastss_ps(_mm_load_ss(&A[k + lda * (i+M)]))
/* Load 16 consecutive B elements for depth k, column group N (unaligned). */
#define LOAD_B_512(M,N) __m512 Bval##N = _mm512_loadu_ps(&B[ldb * k + j + (N*16)])
/* Same as LOAD_B_512 but masked/zeroed for a partial (< 16 columns) N tail. */
#define MASK_LOAD_B_512(M, N) __m512 Bval##N = _mm512_maskz_loadu_ps(mask, &B[ldb * k + j + (N*16)])
/* One step of the rank-1 update: result<M><N> += Aval<M> * Bval<N>. */
#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_ps(Aval##M, Bval##N, result##M##N)
37
#if defined(B0)
/* B0 variant: beta == 0, so stores overwrite C (results already alpha-scaled
 * by the REORDER_* macros before these are invoked). */
/* Store one 256-bit slice of a reordered tile: 8 consecutive rows of C in
 * column j + N*16 + x + y*8. */
#define STORE_8xy(v, N, x, y) _mm256_storeu_ps(&C[(j + N*16 + x + y*8)*ldc + i], v)
/* 4-row variant; mask8 is all-ones (see CNAME) so this is a full store. */
#define STORE_4xy(v, N, x, y) _mm_mask_storeu_ps(&C[(j + N*16 + x + y*4)*ldc + i], mask8, v)
/* Scale result<M><N> by alpha and scatter its 16 lanes across 16 consecutive
 * columns of C; vindex_n holds the column byte-offsets {0, ldc, 2*ldc, ...}. */
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
	_mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4);
/* As SCATTER_STORE_512, but only the columns enabled in `mask` (N tail). */
#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
	_mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4);
#else
/* beta != 0 variant: fold beta * C(old) into the value before storing.
 * The inline asm issues  v += beta * C[...]  as a single vfmadd231ps with a
 * memory source operand (AT&T order: src3 = memory, src2 = beta, dst = v). */
#define STORE_8xy(v, N, x, y) \
	asm("vfmadd231ps (%1), %2, %0": "+v"(v): "r"(&C[(j + N*16 + x + y*8)*ldc + i]), "v"(beta_256)); \
	_mm256_storeu_ps(&C[(j + N*16 + x + y*8)*ldc + i], v)
#define STORE_4xy(v, N, x, y) \
	asm("vfmadd231ps (%1), %2, %0": "+v"(v): "r"(&C[(j + N*16 + x + y*4)*ldc + i]), "v"(beta_128)); \
	_mm_mask_storeu_ps(&C[(j + N*16 + x + y*4)*ldc + i], mask8, v)
/* Scatter path with beta: gather the old C values through the same index
 * vector, fmadd beta in, then scatter the result back. */
#define SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
	__m512 tmp##M##N = _mm512_i32gather_ps(vindex_n, &C[(j + N*16)*ldc + i + M], 4); \
	result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \
	_mm512_i32scatter_ps(&C[(j + N*16)*ldc + i + M], vindex_n, result##M##N, 4);
/* Masked gather/scatter for the N tail; disabled lanes gather zero and are
 * never written back. */
#define MASK_SCATTER_STORE_512(M, N) result##M##N = _mm512_mul_ps(result##M##N, alpha_512); \
	__m512 tmp##M##N = _mm512_mask_i32gather_ps(_mm512_setzero_ps(), mask, vindex_n, &C[(j + N*16)*ldc + i + M], 4); \
	result##M##N = _mm512_fmadd_ps(tmp##M##N, beta_512, result##M##N); \
	_mm512_mask_i32scatter_ps(&C[(j + N*16)*ldc + i + M], mask, vindex_n, result##M##N, 4);
#endif
61
/*
 * In-register transpose of the 8 (rows) x 16 (cols) fp32 tile held in r0..r7
 * (one accumulator per M row), followed by an alpha scale into t0..t7.
 * As consumed by SAVE_8: the low 256-bit half of t<x> holds the 8 M-elements
 * of column x of the 16-column group, the high half those of column x+8 —
 * so each half can be stored contiguously into column-major C.
 * Uses the blend masks kc/k3 and permute vectors idx_lo/idx_hi declared in
 * CNAME.  Declares locals t0..t7 and v, so it may appear once per scope.
 */
#define REORDER_8x16(r0, r1, r2, r3, r4, r5, r6, r7) \
	__m512 t0, t1, t2, t3, t4, t5, t6, t7, v; \
	t0 = _mm512_unpacklo_ps(r0, r1); \
	t1 = _mm512_unpackhi_ps(r0, r1); \
	t2 = _mm512_unpacklo_ps(r2, r3); \
	t3 = _mm512_unpackhi_ps(r2, r3); \
	t4 = _mm512_unpacklo_ps(r4, r5); \
	t5 = _mm512_unpackhi_ps(r4, r5); \
	t6 = _mm512_unpacklo_ps(r6, r7); \
	t7 = _mm512_unpackhi_ps(r6, r7); \
	v = _mm512_shuffle_ps(t0, t2, 0x4E); \
	r0 = _mm512_mask_blend_ps(kc, t0, v); \
	r1 = _mm512_mask_blend_ps(k3, t2, v); \
	v = _mm512_shuffle_ps(t1, t3, 0x4E); \
	r2 = _mm512_mask_blend_ps(kc, t1, v); \
	r3 = _mm512_mask_blend_ps(k3, t3, v); \
	v = _mm512_shuffle_ps(t4, t6, 0x4E); \
	r4 = _mm512_mask_blend_ps(kc, t4, v); \
	r5 = _mm512_mask_blend_ps(k3, t6, v); \
	v = _mm512_shuffle_ps(t5, t7, 0x4E); \
	r6 = _mm512_mask_blend_ps(kc, t5, v); \
	r7 = _mm512_mask_blend_ps(k3, t7, v); \
	t0 = _mm512_permutex2var_ps(r0, idx_lo, r4); \
	t1 = _mm512_permutex2var_ps(r1, idx_lo, r5); \
	t2 = _mm512_permutex2var_ps(r2, idx_lo, r6); \
	t3 = _mm512_permutex2var_ps(r3, idx_lo, r7); \
	t4 = _mm512_permutex2var_ps(r0, idx_hi, r4); \
	t5 = _mm512_permutex2var_ps(r1, idx_hi, r5); \
	t6 = _mm512_permutex2var_ps(r2, idx_hi, r6); \
	t7 = _mm512_permutex2var_ps(r3, idx_hi, r7); \
	t0 = _mm512_mul_ps(t0, alpha_512); \
	t1 = _mm512_mul_ps(t1, alpha_512); \
	t2 = _mm512_mul_ps(t2, alpha_512); \
	t3 = _mm512_mul_ps(t3, alpha_512); \
	t4 = _mm512_mul_ps(t4, alpha_512); \
	t5 = _mm512_mul_ps(t5, alpha_512); \
	t6 = _mm512_mul_ps(t6, alpha_512); \
	t7 = _mm512_mul_ps(t7, alpha_512);
100
/* Extract 256-bit half y of reordered register t<x> and store it: 8
 * consecutive rows of C in column j + N*16 + x + y*8 (via STORE_8xy). */
#define SAVE_8(N, x, y) {\
	__m256 v8 = _mm512_extractf32x8_ps(t##x, y); \
	STORE_8xy(v8, N, x, y); \
}

/* Reorder one 8x16 accumulator block for column group N and store all 16
 * columns of it (halves y=0 then y=1). */
#define REORDER_STORE_8x16(N) {\
	REORDER_8x16(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \
	SAVE_8(N, 0, 0); SAVE_8(N, 1, 0); SAVE_8(N, 2, 0); SAVE_8(N, 3, 0); SAVE_8(N, 4, 0); SAVE_8(N, 5, 0); SAVE_8(N, 6, 0); SAVE_8(N, 7, 0); \
	SAVE_8(N, 0, 1); SAVE_8(N, 1, 1); SAVE_8(N, 2, 1); SAVE_8(N, 3, 1); SAVE_8(N, 4, 1); SAVE_8(N, 5, 1); SAVE_8(N, 6, 1); SAVE_8(N, 7, 1); \
}

/* N-tail store: write only the first nn (1..16) columns of the reordered
 * tile.  The cases intentionally fall through (no break): entering at
 * case nn stores columns nn-1 down to 0. */
#define MASK_SAVE_8() \
	switch (nn) { \
	case 16: SAVE_8(0, 7, 1); /* fallthrough */ \
	case 15: SAVE_8(0, 6, 1); /* fallthrough */ \
	case 14: SAVE_8(0, 5, 1); /* fallthrough */ \
	case 13: SAVE_8(0, 4, 1); /* fallthrough */ \
	case 12: SAVE_8(0, 3, 1); /* fallthrough */ \
	case 11: SAVE_8(0, 2, 1); /* fallthrough */ \
	case 10: SAVE_8(0, 1, 1); /* fallthrough */ \
	case 9: SAVE_8(0, 0, 1); /* fallthrough */ \
	case 8: SAVE_8(0, 7, 0); /* fallthrough */ \
	case 7: SAVE_8(0, 6, 0); /* fallthrough */ \
	case 6: SAVE_8(0, 5, 0); /* fallthrough */ \
	case 5: SAVE_8(0, 4, 0); /* fallthrough */ \
	case 4: SAVE_8(0, 3, 0); /* fallthrough */ \
	case 3: SAVE_8(0, 2, 0); /* fallthrough */ \
	case 2: SAVE_8(0, 1, 0); /* fallthrough */ \
	case 1: SAVE_8(0, 0, 0); \
	}

/* Reorder an 8x16 accumulator block and store its nn-column tail. */
#define MASK_REORDER_STORE_8x16(N) {\
	REORDER_8x16(result0##N, result1##N, result2##N, result3##N, result4##N, result5##N, result6##N, result7##N); \
	MASK_SAVE_8(); \
}
136
/*
 * 4-row variant of the tile transpose: reorders the 4x16 fp32 tile in r0..r3
 * and scales it by alpha into t0..t3.  As consumed by SAVE_4, 128-bit lane y
 * of t<x> holds the 4 M-elements of column x + y*4 of the 16-column group.
 * Uses kc/k3 from CNAME; declares locals t0..t3 and v (once per scope).
 */
#define REORDER_4x16(r0, r1, r2, r3) \
	__m512 t0, t1, t2, t3, v; \
	t0 = _mm512_unpacklo_ps(r0, r1); \
	t1 = _mm512_unpackhi_ps(r0, r1); \
	t2 = _mm512_unpacklo_ps(r2, r3); \
	t3 = _mm512_unpackhi_ps(r2, r3); \
	v = _mm512_shuffle_ps(t0, t2, 0x4E); \
	r0 = _mm512_mask_blend_ps(kc, t0, v); \
	r1 = _mm512_mask_blend_ps(k3, t2, v); \
	v = _mm512_shuffle_ps(t1, t3, 0x4E); \
	r2 = _mm512_mask_blend_ps(kc, t1, v); \
	r3 = _mm512_mask_blend_ps(k3, t3, v); \
	t0 = _mm512_mul_ps(r0, alpha_512); \
	t1 = _mm512_mul_ps(r1, alpha_512); \
	t2 = _mm512_mul_ps(r2, alpha_512); \
	t3 = _mm512_mul_ps(r3, alpha_512);

/* Extract 128-bit lane y of t<x> and store it: 4 consecutive rows of C in
 * column j + N*16 + x + y*4 (via STORE_4xy). */
#define SAVE_4(N, x, y) {\
	__m128 v4 = _mm512_extractf32x4_ps(t##x, y); \
	STORE_4xy(v4, N, x, y); \
}
158
/* Reorder one 4x16 accumulator block for column group N and store all 16
 * columns (lanes y = 0..3, four columns per pass). */
#define REORDER_STORE_4x16(N) {\
	REORDER_4x16(result0##N, result1##N, result2##N, result3##N); \
	SAVE_4(N, 0, 0); SAVE_4(N, 1, 0); SAVE_4(N, 2, 0); SAVE_4(N, 3, 0); \
	SAVE_4(N, 0, 1); SAVE_4(N, 1, 1); SAVE_4(N, 2, 1); SAVE_4(N, 3, 1); \
	SAVE_4(N, 0, 2); SAVE_4(N, 1, 2); SAVE_4(N, 2, 2); SAVE_4(N, 3, 2); \
	SAVE_4(N, 0, 3); SAVE_4(N, 1, 3); SAVE_4(N, 2, 3); SAVE_4(N, 3, 3); \
}

/* N-tail store for the 4-row path: writes only the first nn (1..16) columns.
 * Intentional fallthrough — entering at case nn stores columns nn-1 .. 0. */
#define MASK_SAVE_4() \
	switch (nn) { \
	case 16: SAVE_4(0, 3, 3); /* fallthrough */ \
	case 15: SAVE_4(0, 2, 3); /* fallthrough */ \
	case 14: SAVE_4(0, 1, 3); /* fallthrough */ \
	case 13: SAVE_4(0, 0, 3); /* fallthrough */ \
	case 12: SAVE_4(0, 3, 2); /* fallthrough */ \
	case 11: SAVE_4(0, 2, 2); /* fallthrough */ \
	case 10: SAVE_4(0, 1, 2); /* fallthrough */ \
	case 9: SAVE_4(0, 0, 2); /* fallthrough */ \
	case 8: SAVE_4(0, 3, 1); /* fallthrough */ \
	case 7: SAVE_4(0, 2, 1); /* fallthrough */ \
	case 6: SAVE_4(0, 1, 1); /* fallthrough */ \
	case 5: SAVE_4(0, 0, 1); /* fallthrough */ \
	case 4: SAVE_4(0, 3, 0); /* fallthrough */ \
	case 3: SAVE_4(0, 2, 0); /* fallthrough */ \
	case 2: SAVE_4(0, 1, 0); /* fallthrough */ \
	case 1: SAVE_4(0, 0, 0); \
	}

/* Reorder a 4x16 accumulator block and store its nn-column tail. */
#define MASK_REORDER_STORE_4x16(N) {\
	REORDER_4x16(result0##N, result1##N, result2##N, result3##N); \
	MASK_SAVE_4(); \
}
191
192
/*
 * Small-matrix SGEMM kernel (AVX-512).
 * Computes C = alpha * A * B (+ beta * C unless B0 is defined), where the
 * access patterns are:
 *   A read as A[k + lda*(i+M)]  — K is unit-stride, lda strides the M index;
 *   B read as B[ldb*k + j]      — N is unit-stride, ldb strides the K index;
 *   C written as C[col*ldc + row] (column major, per the comment below).
 * M is blocked 8 / 4 / 2 / 1; N is blocked 32 (or 64 for small M) with a
 * masked 16-wide tail.  Blocks of >= 4 rows are transposed in registers
 * (REORDER_*) so C can be stored contiguously; 1-2 row remainders use
 * gather/scatter stores through vindex_n instead.
 */
#if defined(B0)
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc)
#else
int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
#endif
{
	// column major
	BLASLONG i, j, k;

	/* Largest multiples of each blocking factor that fit in M / N. */
	BLASLONG m8 = M & ~7;
	BLASLONG m4 = M & ~3;
	BLASLONG m2 = M & ~1;

	BLASLONG n64 = N & ~63;
	BLASLONG n32 = N & ~31;

	__m512 alpha_512 = _mm512_broadcastss_ps(_mm_load_ss(&alpha));
#if !defined(B0)
	/* beta broadcast at the widths used by STORE_8xy / STORE_4xy. */
	__m256 beta_256 = _mm256_broadcastss_ps(_mm_load_ss(&beta));
	__m128 beta_128 = _mm_broadcastss_ps(_mm_load_ss(&beta));
#endif
	/* Cross-lane permute patterns for REORDER_8x16: indices >= 0x10 select
	 * from the second source operand of _mm512_permutex2var_ps. */
	int permute_table[] = {
		0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13, 0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b,
		0x4, 0x5, 0x6, 0x7, 0x14, 0x15, 0x16, 0x17, 0xc, 0xd, 0xe, 0xf, 0x1c, 0x1d, 0x1e, 0x1f,
	};
	__m512i idx_lo = _mm512_loadu_si512(permute_table);
	__m512i idx_hi = _mm512_loadu_si512(permute_table + 16);
	/* Blend masks consumed by the REORDER_* macros. */
	__mmask16 kc = 0xcccc;
	__mmask16 k3 = 0x3333;
	__mmask8 mask8 = 0xff; // force use AVX128 instead of SSE

	/* ---- 8-row blocks of M ---- */
	for (i = 0; i < m8; i += 8) {
		/* 32 columns (two 16-wide groups) per iteration. */
		for (j = 0; j < n32; j += 32) {
			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
			DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0);

			DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1);
			DECLARE_RESULT_512(4, 1); DECLARE_RESULT_512(5, 1); DECLARE_RESULT_512(6, 1); DECLARE_RESULT_512(7, 1);
			for (k = 0; k < K; k++) {
				BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x);
				BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x);
				LOAD_B_512(x, 0); LOAD_B_512(x, 1);
				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
				MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0);
				MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
				MATMUL_512(4, 1); MATMUL_512(5, 1); MATMUL_512(6, 1); MATMUL_512(7, 1);
			}
			REORDER_STORE_8x16(0);
			REORDER_STORE_8x16(1);
		}
		/* N tail: 16 columns at a time; the final group may be partial
		 * (nn < 16), handled by `mask` for B loads and MASK_SAVE_8. */
		__mmask16 mask = 0xffff;
		int nn = 16;
		for (; j < N; j += 16) {
			if (N - j < 16) {
				nn = N - j;
				mask = (1UL << nn) - 1;
			}
			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
			DECLARE_RESULT_512(4, 0); DECLARE_RESULT_512(5, 0); DECLARE_RESULT_512(6, 0); DECLARE_RESULT_512(7, 0);
			for (k = 0; k < K; k++) {
				BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x);
				BROADCAST_LOAD_A_512(4, x); BROADCAST_LOAD_A_512(5, x); BROADCAST_LOAD_A_512(6, x); BROADCAST_LOAD_A_512(7, x);
				MASK_LOAD_B_512(x, 0);
				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
				MATMUL_512(4, 0); MATMUL_512(5, 0); MATMUL_512(6, 0); MATMUL_512(7, 0);
			}
			MASK_REORDER_STORE_8x16(0);
		}
	}
	/* ---- 4-row block of M (at most one iteration) ---- */
	for (; i < m4; i += 4) {
		/* With fewer rows, more registers are free: process 64 columns. */
		for (j = 0; j < n64; j += 64) {
			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
			DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1);
			DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2);
			DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3);
			for (k = 0; k < K; k++) {
				BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x);
				LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3);
				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
				MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
				MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2);
				MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3);
			}
			REORDER_STORE_4x16(0);
			REORDER_STORE_4x16(1);
			REORDER_STORE_4x16(2);
			REORDER_STORE_4x16(3);
		}
		for (; j < n32; j += 32) {
			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
			DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1);
			for (k = 0; k < K; k++) {
				BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x);
				LOAD_B_512(x, 0); LOAD_B_512(x, 1);
				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
				MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
			}
			REORDER_STORE_4x16(0);
			REORDER_STORE_4x16(1);
		}
		/* Masked 16-wide N tail, as in the 8-row path. */
		__mmask16 mask = 0xffff;
		int nn = 16;
		for (; j < N; j += 16) {
			if (N - j < 16) {
				nn = N - j;
				mask = (1UL << nn) - 1;
			}
			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
			for (k = 0; k < K; k++) {
				BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x); BROADCAST_LOAD_A_512(2, x); BROADCAST_LOAD_A_512(3, x);
				MASK_LOAD_B_512(x, 0);
				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
			}
			MASK_REORDER_STORE_4x16(0);
		}
	}
	/* ---- M remainder (1-3 rows): scatter-store path ---- */
	if (i < M) {
		/* vindex_n: element offsets of 16 consecutive C columns for one row.
		 * NOTE(review): ii * ldc is computed in int — assumes ldc * 15 fits
		 * in 32 bits for the sizes this kernel targets; confirm. */
		int index_n[16];
		for (int ii = 0; ii < 16; ii++) {
			index_n[ii] = ii * ldc;
		}
		__m512i vindex_n = _mm512_loadu_si512(index_n);
#if !defined(B0)
		/* Full-width beta for the gather/fmadd/scatter stores. */
		__m512 beta_512 = _mm512_broadcastss_ps(_mm_load_ss(&beta));
#endif
		/* Two remaining rows. */
		for (; i < m2; i += 2) {
			for (j = 0; j < n64; j += 64) {
				DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
				DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1);
				DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2);
				DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3);
				for (k = 0; k < K; k++) {
					BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x);
					LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3);
					MATMUL_512(0, 0); MATMUL_512(1, 0);
					MATMUL_512(0, 1); MATMUL_512(1, 1);
					MATMUL_512(0, 2); MATMUL_512(1, 2);
					MATMUL_512(0, 3); MATMUL_512(1, 3);
				}
				SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0);
				SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1);
				SCATTER_STORE_512(0, 2); SCATTER_STORE_512(1, 2);
				SCATTER_STORE_512(0, 3); SCATTER_STORE_512(1, 3);
			}
			for (; j < n32; j += 32) {
				DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
				DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1);
				for (k = 0; k < K; k++) {
					BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x);
					LOAD_B_512(x, 0); LOAD_B_512(x, 1);
					MATMUL_512(0, 0); MATMUL_512(1, 0);
					MATMUL_512(0, 1); MATMUL_512(1, 1);
				}
				SCATTER_STORE_512(0, 0); SCATTER_STORE_512(1, 0);
				SCATTER_STORE_512(0, 1); SCATTER_STORE_512(1, 1);
			}
			/* Masked 16-wide N tail with masked scatter stores. */
			__mmask16 mask = 0xffff;
			int nn = 16;
			for (; j < N; j += 16) {
				if (N - j < 16) {
					nn = N - j;
					mask = (1UL << nn) - 1;
				}
				DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
				for (k = 0; k < K; k++) {
					BROADCAST_LOAD_A_512(0, x); BROADCAST_LOAD_A_512(1, x);
					MASK_LOAD_B_512(x, 0);
					MATMUL_512(0, 0); MATMUL_512(1, 0);
				}
				MASK_SCATTER_STORE_512(0, 0); MASK_SCATTER_STORE_512(1, 0);
			}
		}
		/* Final single row, if M is odd. */
		for (; i < M; i += 1) {
			for (j = 0; j < n64; j += 64) {
				DECLARE_RESULT_512(0, 0);
				DECLARE_RESULT_512(0, 1);
				DECLARE_RESULT_512(0, 2);
				DECLARE_RESULT_512(0, 3);
				for (k = 0; k < K; k++) {
					BROADCAST_LOAD_A_512(0, x);
					LOAD_B_512(x, 0); LOAD_B_512(x, 1); LOAD_B_512(x, 2); LOAD_B_512(x, 3);
					MATMUL_512(0, 0);
					MATMUL_512(0, 1);
					MATMUL_512(0, 2);
					MATMUL_512(0, 3);
				}
				SCATTER_STORE_512(0, 0);
				SCATTER_STORE_512(0, 1);
				SCATTER_STORE_512(0, 2);
				SCATTER_STORE_512(0, 3);
			}
			for (; j < n32; j += 32) {
				DECLARE_RESULT_512(0, 0);
				DECLARE_RESULT_512(0, 1);
				for (k = 0; k < K; k++) {
					BROADCAST_LOAD_A_512(0, x);
					LOAD_B_512(x, 0); LOAD_B_512(x, 1);
					MATMUL_512(0, 0);
					MATMUL_512(0, 1);
				}
				SCATTER_STORE_512(0, 0);
				SCATTER_STORE_512(0, 1);
			}
			__mmask16 mask = 0xffff;
			int nn = 16;
			for (; j < N; j += 16) {
				if (N - j < 16) {
					nn = N - j;
					mask = (1UL << nn) - 1;
				}
				DECLARE_RESULT_512(0, 0);
				for (k = 0; k < K; k++) {
					BROADCAST_LOAD_A_512(0, x);
					MASK_LOAD_B_512(x, 0);
					MATMUL_512(0, 0);
				}
				MASK_SCATTER_STORE_512(0, 0);
			}
		}
	}
	return 0;
}
415