/***************************************************************************
Copyright (c) 2021, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
27 
28 #include <immintrin.h>
29 #include "common.h"
30 #include <stdio.h>
31 #include <memory.h>
32 
/*
 * Per-tile building blocks. Each macro pastes its M / N argument into an
 * identifier, so the arguments must be literal digits (the unused argument
 * is a placeholder and is ignored).
 *
 *   DECLARE_RESULT_512(M, N)  declare the zeroed accumulator result<M><N>
 *   MATMUL_512(M, N)          result<M><N> += Aval<M> * Bval<N>  (fused)
 *
 * None of these macros carries a trailing semicolon; the call site supplies it.
 */
#define DECLARE_RESULT_512(M, N) __m512d result##M##N = _mm512_setzero_pd()
#define MATMUL_512(M, N) result##M##N = _mm512_fmadd_pd(Aval##M, Bval##N, result##M##N)

/*
 * Loads of eight consecutive k-elements starting at the current `k`.
 * They read the enclosing scope's A, B, lda, ldb, i, j and k by name.
 * The MASK_ variants additionally read `mask`; lanes whose mask bit is
 * clear are zeroed instead of being loaded.
 */
#define LOAD_KA_512(M, N) __m512d Aval##M = _mm512_loadu_pd(&A[(i + M)*lda + k])
#define LOAD_KB_512(M, N) __m512d Bval##N = _mm512_loadu_pd(&B[(j + N)*ldb + k])
#define MASK_LOAD_KA_512(M, N) __m512d Aval##M = _mm512_maskz_loadu_pd(mask, &A[(i + M)*lda + k])
#define MASK_LOAD_KB_512(M, N) __m512d Bval##N = _mm512_maskz_loadu_pd(mask, &B[(j + N)*ldb + k])
40 
/*
 * REDUCE_4(rr0, rr1, rr2, rr3)
 *
 * Horizontally sums four 8-lane accumulators at once and scales by alpha.
 * After expansion the variable `s0` (declared here, in the caller's scope)
 * holds
 *     { alpha*sum(rr0), alpha*sum(rr1), alpha*sum(rr2), alpha*sum(rr3) }.
 *
 * Steps: interleave even/odd lanes of each register pair, use the
 * idx_lo / idx_hi permutes to line up one element from every input in each
 * 4-lane group, add the four partial vectors, then fold the upper 256 bits
 * onto the lower 256 bits.
 *
 * Reads idx_lo, idx_hi and alpha_256 from the enclosing scope. Because it
 * declares variables, it must be expanded inside its own brace scope, and
 * it deliberately ends with ';' so it can be followed directly by STORE_*.
 */
#define REDUCE_4(rr0, rr1, rr2, rr3) \
	__m512d even01 = _mm512_unpacklo_pd(rr0, rr1); \
	__m512d odd01 = _mm512_unpackhi_pd(rr0, rr1); \
	__m512d even23 = _mm512_unpacklo_pd(rr2, rr3); \
	__m512d odd23 = _mm512_unpackhi_pd(rr2, rr3); \
	__m512d even_lo = _mm512_permutex2var_pd(even01, idx_lo, even23); \
	__m512d odd_lo = _mm512_permutex2var_pd(odd01, idx_lo, odd23); \
	__m512d even_hi = _mm512_permutex2var_pd(even01, idx_hi, even23); \
	__m512d odd_hi = _mm512_permutex2var_pd(odd01, idx_hi, odd23); \
	__m512d sum_lo = _mm512_add_pd(even_lo, odd_lo); \
	__m512d sum_hi = _mm512_add_pd(even_hi, odd_hi); \
	__m512d sum_all = _mm512_add_pd(sum_lo, sum_hi); \
	__m256d s0 = _mm512_extractf64x4_pd(sum_all, 0); \
	__m256d s1 = _mm512_extractf64x4_pd(sum_all, 1); \
	s0 = _mm256_add_pd(s0, s1); \
	s0 = _mm256_mul_pd(alpha_256, s0);

/* Reduce the four row accumulators result0<N>..result3<N> of column N. */
#define REDUCE_M4(N) REDUCE_4(result0##N, result1##N, result2##N, result3##N)
/* Reduce the four column accumulators result<M>0..result<M>3 of row M. */
#define REDUCE_N4(M) REDUCE_4(result##M##0, result##M##1, result##M##2, result##M##3)
54 
/*
 * Write-back helpers. They read C, ldc, i, j (and alpha / beta / beta_256 /
 * vindex_n) from the enclosing scope.
 *
 *   STORE_REDUCE(M, N)  scalar store of one element from result<M><N>
 *   STORE_M4(N, s0)     store 4 contiguous elements C[(j+N)*ldc + i .. +3]
 *   STORE_N4(M, s0)     store 4 elements of row i+M, one per column j..j+3,
 *                       using vindex_n as the per-column element offsets
 *
 * With B0 defined, C is only written and never read. Otherwise the existing
 * contents of C are scaled by beta and added to the new value.
 *
 * STORE_M4 and STORE_N4 intentionally end with ';' because STORE_REDUCE_M4 /
 * STORE_REDUCE_N4 expand them without adding one.
 */
#if defined(B0)
#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N)
#define STORE_M4(N, s0) _mm256_storeu_pd(&C[(j + N)*ldc + i], s0);
#define STORE_N4(M, s0) _mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8);
#else
#define STORE_REDUCE(M, N) C[(j+N)*ldc + i + M] = alpha * _mm512_reduce_add_pd(result##M##N) + beta * C[(j+N)*ldc + i + M]
/*
 * s0 = C * beta + s0 as a single fused multiply-add, expressed with
 * intrinsics so the compiler sees the read of C (the same form STORE_N4 uses).
 */
#define STORE_M4(N, s0) \
	s0 = _mm256_fmadd_pd(_mm256_loadu_pd(&C[(j + N)*ldc + i]), beta_256, s0); \
	_mm256_storeu_pd(&C[(j + N)*ldc + i], s0);

#define STORE_N4(M, s0) \
	s0 = _mm256_fmadd_pd(_mm256_i64gather_pd(&C[j*ldc + i + M], vindex_n, 8), beta_256, s0); \
	_mm256_i64scatter_pd(&C[j*ldc + i + M], vindex_n, s0, 8);
#endif
/*
 * Reduce four accumulators and write the four results to C in one step.
 * The body is wrapped in its own scope because REDUCE_* declares temporaries
 * (including s0) that would otherwise clash between successive expansions.
 * REDUCE_* and STORE_* already end with ';', so none is added between them.
 */
#define STORE_REDUCE_M4(N) do { \
	REDUCE_M4(N) \
	STORE_M4(N, s0) \
} while (0)
#define STORE_REDUCE_N4(M) do { \
	REDUCE_N4(M) \
	STORE_N4(M, s0) \
} while (0)
77 
78 
79 #if defined(B0)
CNAME(BLASLONG M,BLASLONG N,BLASLONG K,FLOAT * A,BLASLONG lda,FLOAT alpha,FLOAT * B,BLASLONG ldb,FLOAT * C,BLASLONG ldc)80 int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT * C, BLASLONG ldc)
81 #else
82 int CNAME(BLASLONG M, BLASLONG N, BLASLONG K, FLOAT * A, BLASLONG lda, FLOAT alpha, FLOAT * B, BLASLONG ldb, FLOAT beta, FLOAT * C, BLASLONG ldc)
83 #endif
84 {
85 	// column major
86 	BLASLONG i, j, k;
87 
88 	BLASLONG m4 = M & ~3;
89 	BLASLONG m2 = M & ~1;
90 
91 	BLASLONG n4 = N & ~3;
92 	BLASLONG n2 = N & ~1;
93 
94 	BLASLONG k8 = K & ~7;
95 
96 	__mmask8 mask;
97 
98 	__m256i vindex_n = _mm256_set_epi64x(ldc*3, ldc*2, ldc, 0);
99 	__m256d alpha_256 = _mm256_broadcast_sd(&alpha);
100 #if !defined(B0)
101 	__m256d beta_256 = _mm256_broadcast_sd(&beta);
102 #endif
103 
104 	long long permute_table[] = {
105 		0, 1, 0|8, 1|8, 4, 5, 4|8, 5|8,
106 		2, 3, 2|8, 3|8, 6, 7, 6|8, 7|8,
107 	};
108 	__m512i idx_lo = _mm512_loadu_si512(permute_table);
109 	__m512i idx_hi = _mm512_loadu_si512(permute_table + 8);
110 
111 	for (i = 0; i < m4; i += 4) {
112 		for (j = 0; j < n4; j += 4) {
113 			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
114 			DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1);
115 			DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2); DECLARE_RESULT_512(2, 2); DECLARE_RESULT_512(3, 2);
116 			DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3); DECLARE_RESULT_512(2, 3); DECLARE_RESULT_512(3, 3);
117 			for (k = 0; k < k8; k += 8) {
118 				LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x);
119 				LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3);
120 
121 				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
122 				MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
123 				MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2);
124 				MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3);
125 			}
126 			int remains = K - k;
127 			if (remains) {
128 				mask = (1UL << remains) - 1;
129 				MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x);
130 				MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3);
131 
132 				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
133 				MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
134 				MATMUL_512(0, 2); MATMUL_512(1, 2); MATMUL_512(2, 2); MATMUL_512(3, 2);
135 				MATMUL_512(0, 3); MATMUL_512(1, 3); MATMUL_512(2, 3); MATMUL_512(3, 3);
136 			}
137 			STORE_REDUCE_M4(0); STORE_REDUCE_M4(1); STORE_REDUCE_M4(2); STORE_REDUCE_M4(3);
138 		}
139 		for (; j < n2; j += 2) {
140 			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
141 			DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1); DECLARE_RESULT_512(2, 1); DECLARE_RESULT_512(3, 1);
142 			for (k = 0; k < k8; k += 8) {
143 				LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x);
144 				LOAD_KB_512(x, 0); LOAD_KB_512(x, 1);
145 
146 				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
147 				MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
148 			}
149 			int remains = K - k;
150 			if (remains) {
151 				mask = (1UL << remains) - 1;
152 				MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x);
153 				MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1);
154 
155 				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
156 				MATMUL_512(0, 1); MATMUL_512(1, 1); MATMUL_512(2, 1); MATMUL_512(3, 1);
157 			}
158 			STORE_REDUCE_M4(0); STORE_REDUCE_M4(1);
159 		}
160 		for (; j < N; j += 1) {
161 			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0); DECLARE_RESULT_512(2, 0); DECLARE_RESULT_512(3, 0);
162 			for (k = 0; k < k8; k += 8) {
163 				LOAD_KA_512(0, x); LOAD_KA_512(1, x); LOAD_KA_512(2, x); LOAD_KA_512(3, x);
164 				LOAD_KB_512(x, 0);
165 
166 				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
167 			}
168 			int remains = K - k;
169 			if (remains) {
170 				mask = (1UL << remains) - 1;
171 				MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x); MASK_LOAD_KA_512(2, x); MASK_LOAD_KA_512(3, x);
172 				MASK_LOAD_KB_512(x, 0);
173 
174 				MATMUL_512(0, 0); MATMUL_512(1, 0); MATMUL_512(2, 0); MATMUL_512(3, 0);
175 			}
176 			STORE_REDUCE_M4(0);
177 		}
178 
179 	}
180 	for (; i < m2; i += 2) {
181 		for (j = 0; j < n4; j += 4) {
182 			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
183 			DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1);
184 			DECLARE_RESULT_512(0, 2); DECLARE_RESULT_512(1, 2);
185 			DECLARE_RESULT_512(0, 3); DECLARE_RESULT_512(1, 3);
186 			for (k = 0; k < k8; k += 8) {
187 				LOAD_KA_512(0, x); LOAD_KA_512(1, x);
188 				LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3);
189 
190 				MATMUL_512(0, 0); MATMUL_512(1, 0);
191 				MATMUL_512(0, 1); MATMUL_512(1, 1);
192 				MATMUL_512(0, 2); MATMUL_512(1, 2);
193 				MATMUL_512(0, 3); MATMUL_512(1, 3);
194 			}
195 			int remains = K - k;
196 			if (remains) {
197 				mask = (1UL << remains) - 1;
198 				MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x);
199 				MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3);
200 
201 				MATMUL_512(0, 0); MATMUL_512(1, 0);
202 				MATMUL_512(0, 1); MATMUL_512(1, 1);
203 				MATMUL_512(0, 2); MATMUL_512(1, 2);
204 				MATMUL_512(0, 3); MATMUL_512(1, 3);
205 			}
206 			STORE_REDUCE_N4(0); STORE_REDUCE_N4(1);
207 		}
208 		for (; j < n2; j += 2) {
209 			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
210 			DECLARE_RESULT_512(0, 1); DECLARE_RESULT_512(1, 1);
211 			for (k = 0; k < k8; k += 8) {
212 				LOAD_KA_512(0, x); LOAD_KA_512(1, x);
213 				LOAD_KB_512(x, 0); LOAD_KB_512(x, 1);
214 
215 				MATMUL_512(0, 0); MATMUL_512(1, 0);
216 				MATMUL_512(0, 1); MATMUL_512(1, 1);
217 			}
218 			int remains = K - k;
219 			if (remains) {
220 				mask = (1UL << remains) - 1;
221 				MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x);
222 				MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1);
223 
224 				MATMUL_512(0, 0); MATMUL_512(1, 0);
225 				MATMUL_512(0, 1); MATMUL_512(1, 1);
226 			}
227 			STORE_REDUCE(0, 0); STORE_REDUCE(1, 0);
228 			STORE_REDUCE(0, 1); STORE_REDUCE(1, 1);
229 
230 		}
231 		for (; j < N; j += 1) {
232 			DECLARE_RESULT_512(0, 0); DECLARE_RESULT_512(1, 0);
233 			for (k = 0; k < k8; k += 8) {
234 				LOAD_KA_512(0, x); LOAD_KA_512(1, x);
235 				LOAD_KB_512(x, 0);
236 
237 				MATMUL_512(0, 0); MATMUL_512(1, 0);
238 			}
239 			int remains = K - k;
240 			if (remains) {
241 				mask = (1UL << remains) - 1;
242 				MASK_LOAD_KA_512(0, x); MASK_LOAD_KA_512(1, x);
243 				MASK_LOAD_KB_512(x, 0);
244 
245 				MATMUL_512(0, 0); MATMUL_512(1, 0);
246 			}
247 			STORE_REDUCE(0, 0); STORE_REDUCE(1, 0);
248 		}
249 	}
250 	for (; i < M; i += 1) {
251 		for (j = 0; j < n4; j += 4) {
252 			DECLARE_RESULT_512(0, 0);
253 			DECLARE_RESULT_512(0, 1);
254 			DECLARE_RESULT_512(0, 2);
255 			DECLARE_RESULT_512(0, 3);
256 			for (k = 0; k < k8; k += 8) {
257 				LOAD_KA_512(0, x);
258 				LOAD_KB_512(x, 0); LOAD_KB_512(x, 1); LOAD_KB_512(x, 2); LOAD_KB_512(x, 3);
259 
260 				MATMUL_512(0, 0);
261 				MATMUL_512(0, 1);
262 				MATMUL_512(0, 2);
263 				MATMUL_512(0, 3);
264 			}
265 			int remains = K - k;
266 			if (remains) {
267 				mask = (1UL << remains) - 1;
268 				MASK_LOAD_KA_512(0, x);
269 				MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1); MASK_LOAD_KB_512(x, 2); MASK_LOAD_KB_512(x, 3);
270 
271 
272 				MATMUL_512(0, 0);
273 				MATMUL_512(0, 1);
274 				MATMUL_512(0, 2);
275 				MATMUL_512(0, 3);
276 			}
277 			STORE_REDUCE_N4(0);
278 		}
279 		for (; j < n2; j += 2) {
280 			DECLARE_RESULT_512(0, 0);
281 			DECLARE_RESULT_512(0, 1);
282 			for (k = 0; k < k8; k += 8) {
283 				LOAD_KA_512(0, x);
284 				LOAD_KB_512(x, 0); LOAD_KB_512(x, 1);
285 
286 				MATMUL_512(0, 0);
287 				MATMUL_512(0, 1);
288 			}
289 			int remains = K - k;
290 			if (remains) {
291 				mask = (1UL << remains) - 1;
292 				MASK_LOAD_KA_512(0, x);
293 				MASK_LOAD_KB_512(x, 0); MASK_LOAD_KB_512(x, 1);
294 
295 				MATMUL_512(0, 0);
296 				MATMUL_512(0, 1);
297 			}
298 			STORE_REDUCE(0, 0);
299 			STORE_REDUCE(0, 1);
300 
301 		}
302 		for (; j < N; j += 1) {
303 			DECLARE_RESULT_512(0, 0);
304 			for (k = 0; k < k8; k += 8) {
305 				LOAD_KA_512(0, x);
306 				LOAD_KB_512(x, 0);
307 
308 				MATMUL_512(0, 0);
309 			}
310 			int remains = K - k;
311 			if (remains) {
312 				mask = (1UL << remains) - 1;
313 				MASK_LOAD_KA_512(0, x);
314 				MASK_LOAD_KB_512(x, 0);
315 
316 				MATMUL_512(0, 0);
317 			}
318 			STORE_REDUCE(0, 0);
319 		}
320 	}
321 	return 0;
322 }
323