/***************************************************************************
Copyright (c) 2021, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/

#include <immintrin.h>
#include "common.h"
#include <stdio.h>
#include <memory.h>

/*
 * Building blocks for the register-tiled kernel below.  Every macro pastes
 * its M / N argument into a variable name (result<M><N>, Aval<M>, Bval<N>)
 * and expects A, B, lda, ldb, i, j and k -- plus `mask` for the masked
 * forms -- to be visible where it is expanded.  The argument that is not
 * pasted (N for the A loads, M for the B loads) is ignored; call sites pass
 * a dummy `x` for it.
 */

/* Declare the zero-initialised 8-lane accumulator of tile element (M, N). */
#define DECLARE_RESULT_512(M, N) __m512d result##M##N = _mm512_setzero_pd()
/* result<M><N> += Aval<M> * Bval<N>, lane by lane, as one fused multiply-add. */
#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_pd(Aval##M, Bval##N, result##M##N)

/* Unaligned load of the 8 doubles starting at A[(i + M)*lda + k]. */
#define LOAD_KA_512(M, N) __m512d Aval##M = _mm512_loadu_pd(&A[(i + M)*lda + k])
/* Unaligned load of the 8 doubles starting at B[(j + N)*ldb + k]. */
#define LOAD_KB_512(M, N) __m512d Bval##N = _mm512_loadu_pd(&B[(j + N)*ldb + k])
/* As above, but only lanes whose bit is set in `mask` are read; the rest are zero. */
#define MASK_LOAD_KA_512(M, N) __m512d Aval##M = _mm512_maskz_loadu_pd(mask, &A[(i + M)*lda + k])
#define MASK_LOAD_KB_512(M, N) __m512d Bval##N = _mm512_maskz_loadu_pd(mask, &B[(j + N)*ldb + k])

/*
 * Horizontally sum four 8-lane accumulators in one go.
 *
 * The unpack step interleaves even/odd lanes of (rr0, rr1) and (rr2, rr3);
 * the two-source permutes with idx_lo / idx_hi then regroup them so that
 * each 4-lane half holds one partial term of rr0, rr1, rr2, rr3 in that
 * order.  Adding the four permuted vectors and finally the two 256-bit
 * halves leaves
 *
 *     s0 = alpha * { sum(rr0), sum(rr1), sum(rr2), sum(rr3) }
 *
 * The macro declares r0..r3, t0..t3, s0 and s1, so every expansion needs a
 * scope of its own (the STORE_REDUCE_* wrappers provide one), and it reads
 * idx_lo, idx_hi and alpha_256 from the enclosing function.
 */
#define REDUCE_4(rr0, rr1, rr2, rr3) \
	__m512d r0, r1, r2, r3, t0, t1, t2, t3;\
	r0 = _mm512_unpacklo_pd(rr0, rr1); r1 = _mm512_unpackhi_pd(rr0, rr1); \
	r2 = _mm512_unpacklo_pd(rr2, rr3); r3 = _mm512_unpackhi_pd(rr2, rr3); \
	t0 = _mm512_permutex2var_pd(r0, idx_lo, r2); t1 = _mm512_permutex2var_pd(r1, idx_lo, r3); \
	t2 = _mm512_permutex2var_pd(r0, idx_hi, r2); t3 = _mm512_permutex2var_pd(r1, idx_hi, r3); \
	r0 = _mm512_add_pd(t0, t1); r1 = _mm512_add_pd(t2, t3); t0 = _mm512_add_pd(r0, r1); \
	__m256d s0, s1; \
	s0 = _mm512_extractf64x4_pd(t0, 0); s1 = _mm512_extractf64x4_pd(t0, 1); \
	s0 = _mm256_add_pd(s0, s1); s0 = _mm256_mul_pd(alpha_256, s0);

/* Reduce the four accumulators result0N..result3N (fixed N, M = 0..3). */
#define REDUCE_M4(N) REDUCE_4(result0##N, result1##N, result2##N, result3##N)
/* Reduce the four accumulators resultM0..resultM3 (fixed M, N = 0..3). */
#define REDUCE_N4(M) REDUCE_4(result##M##0, result##M##1, result##M##2, result##M##3)

/*
 * Write-back helpers.  They expect C, ldc, i, j, alpha (and beta / beta_256
 * when B0 is not defined) plus vindex_n in the enclosing scope.
 *
 *   STORE_REDUCE(M, N)  scalar path: horizontal sum of result<M><N>, scaled
 *                       by alpha, written to C[(j+N)*ldc + i + M].
 *   STORE_M4(N, s0)     the four doubles of s0 go to the contiguous
 *                       elements starting at C[(j+N)*ldc + i].
 *   STORE_N4(M, s0)     the four doubles of s0 go to C[j*ldc + i + M] and
 *                       the three elements ldc, 2*ldc, 3*ldc further on
 *                       (offsets taken from vindex_n, scale 8 bytes).
 *
 * With B0 defined the previous contents of C are ignored; otherwise
 * beta * C is added before the store.
 */
#if defined(B0)
#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N)
#define STORE_M4(N, s0) _mm256_storeu_pd(&C[(j + N)*ldc + i], s0);
#define STORE_N4(M, s0) _mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8);
#else
#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N) + beta * C[(j+N)*ldc + i + M]
/*
 * s0 = C * beta + s0 as a single fused multiply-add, written with the same
 * intrinsic STORE_N4 uses so the compiler can see the read of C (a raw asm
 * statement fed only a register pointer hides that memory access).
 */
#define STORE_M4(N, s0) \
	s0 = _mm256_fmadd_pd(_mm256_loadu_pd(&C[(j + N)*ldc + i]), beta_256, s0); \
	_mm256_storeu_pd(&C[(j + N)*ldc + i], s0);

#define STORE_N4(M, s0) \
	s0 = _mm256_fmadd_pd(_mm256_i64gather_pd(&C[j*ldc + i + M], vindex_n, 8), beta_256, s0); \
	_mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8);
#endif
/*
 * Reduce-then-store wrappers.  The braces give the temporaries declared by
 * REDUCE_4 their own scope so the wrappers can be used back to back.
 */
#define STORE_REDUCE_M4(N) {\
	REDUCE_M4(N) \
	STORE_M4(N, s0) \
}
#define STORE_REDUCE_N4(M) {\
	REDUCE_N4(M) \
	STORE_N4(M, s0) \
}


79 #if defined(B0)
CNAME(BLASLONG M,BLASLONG N,BLASLONG K,FLOAT * A,BLASLONG lda,FLOAT alpha,FLOAT * B,BLASLONG ldb,FLOAT * C,BLASLONG ldc)80 int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc)
81 #else
82 int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
83 #endif
84 {
85 // column major
86 BLASLONG i, j, k;
87
88 BLASLONG m4 = M & ~3;
89 BLASLONG m2 = M & ~1;
90
91 BLASLONG n4 = N & ~3;
92 BLASLONG n2 = N & ~1;
93
94 BLASLONG k8 = K & ~7;
95
96 __mmask8 mask;
97
98 __m256i vindex_n = _mm256_set_epi64x(ldc*3, ldc*2, ldc, 0);
99 __m256d alpha_256 = _mm256_broadcast_sd(&alpha);
100 #if !defined(B0)
101 __m256d beta_256 = _mm256_broadcast_sd(&beta);
102 #endif
103
104 long long permute_table[] = {
105 0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8,
106 2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8,
107 };
108 __m512i idx_lo = _mm512_loadu_si512(permute_table);
109 __m512i idx_hi = _mm512_loadu_si512(permute_table + 8);
110
111 for (i = 0; i < m4; i += 4) {
112 for (j = 0; j < n4; j += 4) {
113 DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
114 DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1);
115 DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2);
116 DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3);
117 for (k = 0; k < k8; k += 8) {
118 LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x);
119 LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3);
120
121 MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
122 MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
123 MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2);
124 MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3);
125 }
126 int remains = K - k;
127 if (remains) {
128 mask = (1UL << remains) - 1;
129 MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x);
130 MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3);
131
132 MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
133 MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
134 MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2);
135 MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3);
136 }
137 STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); STORE_REDUCE_M4(2); STORE_REDUCE_M4(3);
138 }
139 for (; j < n2; j += 2) {
140 DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
141 DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1);
142 for (k = 0; k < k8; k += 8) {
143 LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x);
144 LOAD_KB_512(x, 0); LOAD_KB_512(x, 1);
145
146 MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
147 MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
148 }
149 int remains = K - k;
150 if (remains) {
151 mask = (1UL << remains) - 1;
152 MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x);
153 MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1);
154
155 MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
156 MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
157 }
158 STORE_REDUCE_M4(0); STORE_REDUCE_M4(1);
159 }
160 for (; j < N; j += 1) {
161 DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
162 for (k = 0; k < k8; k += 8) {
163 LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x);
164 LOAD_KB_512(x, 0);
165
166 MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
167 }
168 int remains = K - k;
169 if (remains) {
170 mask = (1UL << remains) - 1;
171 MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x);
172 MASK_LOAD_KB_512(x, 0);
173
174 MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
175 }
176 STORE_REDUCE_M4(0);
177 }
178
179 }
180 for (; i < m2; i += 2) {
181 for (j = 0; j < n4; j += 4) {
182 DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
183 DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1);
184 DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2);
185 DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3);
186 for (k = 0; k < k8; k += 8) {
187 LOAD_KA_512(0, x); LOAD_KA_512(1, x);
188 LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3);
189
190 MATMUL_512(0, 0); MATMUL_512(1, 0);
191 MATMUL_512(0, 1); MATMUL_512(1, 1);
192 MATMUL_512(0, 2); MATMUL_512(1, 2);
193 MATMUL_512(0, 3); MATMUL_512(1, 3);
194 }
195 int remains = K - k;
196 if (remains) {
197 mask = (1UL << remains) - 1;
198 MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x);
199 MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3);
200
201 MATMUL_512(0, 0); MATMUL_512(1, 0);
202 MATMUL_512(0, 1); MATMUL_512(1, 1);
203 MATMUL_512(0, 2); MATMUL_512(1, 2);
204 MATMUL_512(0, 3); MATMUL_512(1, 3);
205 }
206 STORE_REDUCE_N4(0); STORE_REDUCE_N4(1);
207 }
208 for (; j < n2; j += 2) {
209 DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
210 DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1);
211 for (k = 0; k < k8; k += 8) {
212 LOAD_KA_512(0, x); LOAD_KA_512(1, x);
213 LOAD_KB_512(x, 0); LOAD_KB_512(x, 1);
214
215 MATMUL_512(0, 0); MATMUL_512(1, 0);
216 MATMUL_512(0, 1); MATMUL_512(1, 1);
217 }
218 int remains = K - k;
219 if (remains) {
220 mask = (1UL << remains) - 1;
221 MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x);
222 MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1);
223
224 MATMUL_512(0, 0); MATMUL_512(1, 0);
225 MATMUL_512(0, 1); MATMUL_512(1, 1);
226 }
227 STORE_REDUCE(0, 0); STORE_REDUCE(1, 0);
228 STORE_REDUCE(0, 1); STORE_REDUCE(1, 1);
229
230 }
231 for (; j < N; j += 1) {
232 DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
233 for (k = 0; k < k8; k += 8) {
234 LOAD_KA_512(0, x); LOAD_KA_512(1, x);
235 LOAD_KB_512(x, 0);
236
237 MATMUL_512(0, 0); MATMUL_512(1, 0);
238 }
239 int remains = K - k;
240 if (remains) {
241 mask = (1UL << remains) - 1;
242 MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x);
243 MASK_LOAD_KB_512(x, 0);
244
245 MATMUL_512(0, 0); MATMUL_512(1, 0);
246 }
247 STORE_REDUCE(0, 0); STORE_REDUCE(1, 0);
248 }
249 }
250 for (; i < M; i += 1) {
251 for (j = 0; j < n4; j += 4) {
252 DECLARE_RESULT_512(0, 0);
253 DECLARE_RESULT_512(0, 1);
254 DECLARE_RESULT_512(0, 2);
255 DECLARE_RESULT_512(0, 3);
256 for (k = 0; k < k8; k += 8) {
257 LOAD_KA_512(0, x);
258 LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3);
259
260 MATMUL_512(0, 0);
261 MATMUL_512(0, 1);
262 MATMUL_512(0, 2);
263 MATMUL_512(0, 3);
264 }
265 int remains = K - k;
266 if (remains) {
267 mask = (1UL << remains) - 1;
268 MASK_LOAD_KA_512(0, x);
269 MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3);
270
271
272 MATMUL_512(0, 0);
273 MATMUL_512(0, 1);
274 MATMUL_512(0, 2);
275 MATMUL_512(0, 3);
276 }
277 STORE_REDUCE_N4(0);
278 }
279 for (; j < n2; j += 2) {
280 DECLARE_RESULT_512(0, 0);
281 DECLARE_RESULT_512(0, 1);
282 for (k = 0; k < k8; k += 8) {
283 LOAD_KA_512(0, x);
284 LOAD_KB_512(x, 0); LOAD_KB_512(x, 1);
285
286 MATMUL_512(0, 0);
287 MATMUL_512(0, 1);
288 }
289 int remains = K - k;
290 if (remains) {
291 mask = (1UL << remains) - 1;
292 MASK_LOAD_KA_512(0, x);
293 MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1);
294
295 MATMUL_512(0, 0);
296 MATMUL_512(0, 1);
297 }
298 STORE_REDUCE(0, 0);
299 STORE_REDUCE(0, 1);
300
301 }
302 for (; j < N; j += 1) {
303 DECLARE_RESULT_512(0, 0);
304 for (k = 0; k < k8; k += 8) {
305 LOAD_KA_512(0, x);
306 LOAD_KB_512(x, 0);
307
308 MATMUL_512(0, 0);
309 }
310 int remains = K - k;
311 if (remains) {
312 mask = (1UL << remains) - 1;
313 MASK_LOAD_KA_512(0, x);
314 MASK_LOAD_KB_512(x, 0);
315
316 MATMUL_512(0, 0);
317 }
318 STORE_REDUCE(0, 0);
319 }
320 }
321 return 0;
322 }
323