1 /***************************************************************************
2 Copyright (c) 2021, The OpenBLAS Project
3 All rights reserved.
4 Redistribution and use in source and binary forms, with or without
5 modification, are permitted provided that the following conditions are
6 met:
7 1. Redistributions of source code must retain the above copyright
8 notice, this list of conditions and the following disclaimer.
9 2. Redistributions in binary form must reproduce the above copyright
10 notice, this list of conditions and the following disclaimer in
11 the documentation and/or other materials provided with the
12 distribution.
13 3. Neither the name of the OpenBLAS project nor the names of
14 its contributors may be used to endorse or promote products
15 derived from this software without specific prior written permission.
16 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
17 AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18 IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
19 ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
20 LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
21 DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
22 SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
23 CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
24 OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
25 USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
26 *****************************************************************************/
27 
#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>
#include "common.h"
31 
/*
 * result = vshufps(in1, in2, imm8) applied to 512-bit integer vectors.
 *
 * Within every 128-bit lane, the two low 32-bit results are selected from
 * in1 and the two high ones from in2 by the 2-bit fields of imm8.  This is
 * the same operation as _mm512_shuffle_ps(in1, in2, imm8), but it is
 * applied directly to __m512i values without cast intrinsics.
 *
 * AT&T operand order is "imm, src2, src1, dst", so src1 = in1 and
 * src2 = in2.  `__asm__` is used rather than `asm` because GCC disables the
 * plain `asm` keyword under strict ISO modes (-std=c99 / -std=c11).
 * imm8 must be a compile-time constant in [0, 255] (the "N" constraint).
 */
#define _MM512_SHUFFLE_i32(result, in1, in2, imm8) \
	__asm__("vshufps %[imm], %[src2], %[src1], %[dst]" \
		: [dst] "=v"(result) \
		: [src1] "v"(in1), [src2] "v"(in2), [imm] "N"(imm8))
34 
/*
 * REORDER_8x32(t0..t7): transpose the eight vectors r0..r7 at 32-bit
 * granularity.  r0..r7 are taken from the enclosing scope; they are not
 * macro parameters.  As CNAME uses it, each 32-bit element is a pair of
 * adjacent 16-bit IFLOAT values from one source vector.
 *
 * Result: for k = 0..7, the low 256 bits of t<k> hold 32-bit element k of
 * r0..r7 (in that order), and the high 256 bits hold element k + 8.
 *
 * Requirements and side effects:
 *   - r0..r7 are overwritten with intermediate values (clobbered);
 *   - kc (0xcccc), k3 (0x3333), idx_lo and idx_hi must be in scope;
 *   - the expansion is a braced block with its own scratch vector `v`.
 */
#define REORDER_8x32(t0, t1, t2, t3, t4, t5, t6, t7) { \
	__m512i v; \
	/* Step 1: interleave 32-bit elements of (r0,r1), (r2,r3), (r4,r5), */ \
	/* (r6,r7) inside each 128-bit lane.                                 */ \
	t0 = _mm512_unpacklo_epi32(r0, r1); \
	t1 = _mm512_unpackhi_epi32(r0, r1); \
	t2 = _mm512_unpacklo_epi32(r2, r3); \
	t3 = _mm512_unpackhi_epi32(r2, r3); \
	t4 = _mm512_unpacklo_epi32(r4, r5); \
	t5 = _mm512_unpackhi_epi32(r4, r5); \
	t6 = _mm512_unpacklo_epi32(r6, r7); \
	t7 = _mm512_unpackhi_epi32(r6, r7); \
	/* Step 2: vshufps 0x4E swaps 64-bit halves across the pair; blending */ \
	/* with kc (take dwords 2,3 from v) and k3 (take dwords 0,1 from v)   */ \
	/* completes a 4x4 transpose per 128-bit lane: afterwards r0..r3 hold */ \
	/* dwords 4L+0..4L+3 of the first four inputs, r4..r7 of the last.    */ \
	_MM512_SHUFFLE_i32(v, t0, t2, 0x4E); \
	r0 = _mm512_mask_blend_epi32(kc, t0, v); \
	r1 = _mm512_mask_blend_epi32(k3, t2, v); \
	_MM512_SHUFFLE_i32(v, t1, t3, 0x4E); \
	r2 = _mm512_mask_blend_epi32(kc, t1, v); \
	r3 = _mm512_mask_blend_epi32(k3, t3, v); \
	_MM512_SHUFFLE_i32(v, t4, t6, 0x4E); \
	r4 = _mm512_mask_blend_epi32(kc, t4, v); \
	r5 = _mm512_mask_blend_epi32(k3, t6, v); \
	_MM512_SHUFFLE_i32(v, t5, t7, 0x4E); \
	r6 = _mm512_mask_blend_epi32(kc, t5, v); \
	r7 = _mm512_mask_blend_epi32(k3, t7, v); \
	/* Step 3: gather 128-bit lanes from the (r0..r3, r4..r7) halves:     */ \
	/* idx_lo takes lanes {a0, b0, a2, b2}, idx_hi takes {a1, b1, a3, b3}. */ \
	t0 = _mm512_permutex2var_epi32(r0, idx_lo, r4); \
	t1 = _mm512_permutex2var_epi32(r1, idx_lo, r5); \
	t2 = _mm512_permutex2var_epi32(r2, idx_lo, r6); \
	t3 = _mm512_permutex2var_epi32(r3, idx_lo, r7); \
	t4 = _mm512_permutex2var_epi32(r0, idx_hi, r4); \
	t5 = _mm512_permutex2var_epi32(r1, idx_hi, r5); \
	t6 = _mm512_permutex2var_epi32(r2, idx_hi, r6); \
	t7 = _mm512_permutex2var_epi32(r3, idx_hi, r7); \
}
66 
/*
 * Output helpers for the packed buffer.
 *
 * t0<x> and t1<x> are the t<x> results of the two REORDER_8x32 passes
 * (first and second group of eight source vectors).  idx_lo2 / idx_hi2 take
 * the low / high 256-bit halves of both and concatenate them.  One 512-bit
 * result therefore holds a single 32-bit element position for all sixteen
 * source vectors.
 *
 *   STORE_512_LO(x) / STORE_512_HI(x)
 *       Write the full 32-IFLOAT chunk with index x / x + 8, at
 *       boffset0 + chunk * 32.
 *   MASK_STORE_512_LO(x) / MASK_STORE_512_HI(x)
 *       The same, for a narrow panel of remain_n vectors: chunks are
 *       2 * remain_n IFLOATs apart, and only the first remain_n 32-bit
 *       lanes (nmask) are written.
 *   STORE_512(h, x) / MASK_STORE_512(h, x)
 *       Dispatch to the _LO (h == 0) or _HI (h != 0) form.
 *
 * x must be a single literal token, because it is pasted into t0##x / t1##x.
 * The macros read t0x, t1x, idx_lo2, idx_hi2 and boffset0 from the
 * enclosing scope; the masked forms also read remain_n and nmask.  Each
 * macro expands to exactly one statement (do { } while (0)), so it is safe
 * inside unbraced if/else.
 */
#define STORE_512_LO(x) do { \
	__m512i v = _mm512_permutex2var_epi64(t0##x, idx_lo2, t1##x); \
	_mm512_storeu_si512(boffset0 + (x) * 32, v); \
} while (0)

#define STORE_512_HI(x) do { \
	__m512i v = _mm512_permutex2var_epi64(t0##x, idx_hi2, t1##x); \
	_mm512_storeu_si512(boffset0 + ((x) + 8) * 32, v); \
} while (0)

#define MASK_STORE_512_LO(x) do { \
	__m512i v = _mm512_permutex2var_epi64(t0##x, idx_lo2, t1##x); \
	_mm512_mask_storeu_epi32(boffset0 + 2 * (x) * remain_n, nmask, v); \
} while (0)

#define MASK_STORE_512_HI(x) do { \
	__m512i v = _mm512_permutex2var_epi64(t0##x, idx_hi2, t1##x); \
	_mm512_mask_storeu_epi32(boffset0 + 2 * ((x) + 8) * remain_n, nmask, v); \
} while (0)

#define STORE_512(x, y) do { \
	if ((x) == 0) { STORE_512_LO(y); } \
	else { STORE_512_HI(y); } \
} while (0)

#define MASK_STORE_512(x, y) do { \
	if ((x) == 0) { MASK_STORE_512_LO(y); } \
	else { MASK_STORE_512_HI(y); } \
} while (0)
94 
/*
 * SET_TAIL(h, x): assign to `tail` the 512-bit vector for chunk x (h == 0)
 * or chunk x + 8 (h != 0).  It is built exactly like the store helpers:
 * 64-bit permute of t0<x> and t1<x> with idx_lo2 or idx_hi2.  x must be a
 * literal token, because it is pasted into t0##x / t1##x.
 */
#define SET_TAIL(y, x) do { \
	if ((y) != 0) tail = _mm512_permutex2var_epi64(t0##x, idx_hi2, t1##x); \
	else          tail = _mm512_permutex2var_epi64(t0##x, idx_lo2, t1##x); \
} while (0)

/*
 * GET_TAIL(): load into `tail` the chunk whose index equals n_store, i.e.
 * the chunk directly after the n_store pairs already written.  n_store is
 * read from the enclosing scope; values outside 0..15 leave `tail`
 * untouched.  CNAME uses it only when m is odd, to fetch the last element
 * of each source vector.
 */
#define GET_TAIL() do { \
	switch (n_store) { \
		case  0: SET_TAIL(0, 0); break; \
		case  1: SET_TAIL(0, 1); break; \
		case  2: SET_TAIL(0, 2); break; \
		case  3: SET_TAIL(0, 3); break; \
		case  4: SET_TAIL(0, 4); break; \
		case  5: SET_TAIL(0, 5); break; \
		case  6: SET_TAIL(0, 6); break; \
		case  7: SET_TAIL(0, 7); break; \
		case  8: SET_TAIL(1, 0); break; \
		case  9: SET_TAIL(1, 1); break; \
		case 10: SET_TAIL(1, 2); break; \
		case 11: SET_TAIL(1, 3); break; \
		case 12: SET_TAIL(1, 4); break; \
		case 13: SET_TAIL(1, 5); break; \
		case 14: SET_TAIL(1, 6); break; \
		case 15: SET_TAIL(1, 7); break; \
	} \
} while (0)
119 
120 
/*
 * Pack an m x n block of 16-bit IFLOAT values from `a` into `b`.
 *
 * Source: n vectors of m contiguous elements each; vector j starts at
 * a + j * lda (the columns of a column-major matrix with leading
 * dimension lda).
 *
 * Destination: vectors are processed in panels of width w = 16, followed by
 * one narrower panel with w = n % 16 if n is not a multiple of 16.  For each
 * panel, in order:
 *   - for every pair of m indices (2p, 2p + 1), 2 * w values:
 *         b[2k + 0] = vec_k[2p], b[2k + 1] = vec_k[2p + 1]   (k = 0..w-1)
 *     after which b advances by 2 * w;
 *   - if m is odd, w more values vec_k[m - 1] (k = 0..w-1).
 *
 * Uses AVX-512 intrinsics, including AVX512BW masked 16-bit loads and an
 * AVX512BW/VL 256-bit masked 16-bit store.  Always returns 0.
 */
int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){
	BLASLONG i, j;

	IFLOAT *boffset0;
	IFLOAT *aoffset;
	IFLOAT *aoffset00, *aoffset01, *aoffset02, *aoffset03, *aoffset04, *aoffset05, *aoffset06, *aoffset07;
	IFLOAT *aoffset10, *aoffset11, *aoffset12, *aoffset13, *aoffset14, *aoffset15, *aoffset16, *aoffset17;
	aoffset = a;
	boffset0 = b;

	/* Number of vectors in full 16-wide panels, and of elements in full 32-element blocks. */
	BLASLONG n16 = n & ~15;
	BLASLONG m32 = m & ~31;

	/* 32-bit indices for permutex2var_epi32: idx_lo gathers 128-bit lanes
	 * {a0, b0, a2, b2} of its two sources, idx_hi gathers {a1, b1, a3, b3}. */
	static const int permute_table[32] = {
		0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13, 0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b,
		0x4, 0x5, 0x6, 0x7, 0x14, 0x15, 0x16, 0x17, 0xc, 0xd, 0xe, 0xf, 0x1c, 0x1d, 0x1e, 0x1f,
	};
	/* 64-bit indices for permutex2var_epi64: idx_lo2 concatenates the low
	 * 256 bits of both sources, idx_hi2 the high 256 bits. */
	static const uint64_t permute_table2[16] = {
		0x00, 0x01, 0x02, 0x03, 8|0x0, 8|0x1, 8|0x2, 8|0x3,
		0x04, 0x05, 0x06, 0x07, 8|0x4, 8|0x5, 8|0x6, 8|0x7,
	};
	__m512i idx_lo = _mm512_loadu_si512(permute_table);
	__m512i idx_hi = _mm512_loadu_si512(permute_table + 16);
	__m512i idx_lo2 = _mm512_loadu_si512(permute_table2);
	__m512i idx_hi2 = _mm512_loadu_si512(permute_table2 + 8);
	__mmask16 kc = 0xcccc;
	__mmask16 k3 = 0x3333;
	/* REORDER_8x32 reads all of r0..r7, even the rows the narrow remainder
	 * panel never loads; those lanes are masked off at store time.  Zero
	 * the registers once so that no indeterminate values are read when
	 * n < 16 and the full-panel loop never ran. */
	__m512i zero = _mm512_setzero_si512();
	__m512i r0 = zero, r1 = zero, r2 = zero, r3 = zero;
	__m512i r4 = zero, r5 = zero, r6 = zero, r7 = zero;
	__m512i t00, t01, t02, t03, t04, t05, t06, t07;
	__m512i t10, t11, t12, t13, t14, t15, t16, t17;

	/* ---- Full panels of 16 vectors. ---- */
	for (j = 0; j < n16; j += 16) {
		aoffset00 = aoffset;
		aoffset01 = aoffset00 + lda;
		aoffset02 = aoffset01 + lda;
		aoffset03 = aoffset02 + lda;
		aoffset04 = aoffset03 + lda;
		aoffset05 = aoffset04 + lda;
		aoffset06 = aoffset05 + lda;
		aoffset07 = aoffset06 + lda;
		aoffset10 = aoffset07 + lda;
		aoffset11 = aoffset10 + lda;
		aoffset12 = aoffset11 + lda;
		aoffset13 = aoffset12 + lda;
		aoffset14 = aoffset13 + lda;
		aoffset15 = aoffset14 + lda;
		aoffset16 = aoffset15 + lda;
		aoffset17 = aoffset16 + lda;
		aoffset += 16 * lda;
		/* 32 elements (16 pairs) of every vector per iteration. */
		for (i = 0; i < m32; i += 32) {
			r0 = _mm512_loadu_si512(aoffset00 + i);
			r1 = _mm512_loadu_si512(aoffset01 + i);
			r2 = _mm512_loadu_si512(aoffset02 + i);
			r3 = _mm512_loadu_si512(aoffset03 + i);
			r4 = _mm512_loadu_si512(aoffset04 + i);
			r5 = _mm512_loadu_si512(aoffset05 + i);
			r6 = _mm512_loadu_si512(aoffset06 + i);
			r7 = _mm512_loadu_si512(aoffset07 + i);
			REORDER_8x32(t00, t01, t02, t03, t04, t05, t06, t07);
			r0 = _mm512_loadu_si512(aoffset10 + i);
			r1 = _mm512_loadu_si512(aoffset11 + i);
			r2 = _mm512_loadu_si512(aoffset12 + i);
			r3 = _mm512_loadu_si512(aoffset13 + i);
			r4 = _mm512_loadu_si512(aoffset14 + i);
			r5 = _mm512_loadu_si512(aoffset15 + i);
			r6 = _mm512_loadu_si512(aoffset16 + i);
			r7 = _mm512_loadu_si512(aoffset17 + i);
			REORDER_8x32(t10, t11, t12, t13, t14, t15, t16, t17);
			STORE_512(0, 0); STORE_512(0, 1); STORE_512(0, 2); STORE_512(0, 3);
			STORE_512(0, 4); STORE_512(0, 5); STORE_512(0, 6); STORE_512(0, 7);
			STORE_512(1, 0); STORE_512(1, 1); STORE_512(1, 2); STORE_512(1, 3);
			STORE_512(1, 4); STORE_512(1, 5); STORE_512(1, 6); STORE_512(1, 7);
			boffset0 += 16 * 32;
		}
		/* Remaining 1..31 elements: masked loads zero-fill past m. */
		if (i < m) {
			int remain_m = m - i;
			__mmask32 mmask = (1UL << remain_m) - 1;
			r0 = _mm512_maskz_loadu_epi16(mmask, aoffset00 + i);
			r1 = _mm512_maskz_loadu_epi16(mmask, aoffset01 + i);
			r2 = _mm512_maskz_loadu_epi16(mmask, aoffset02 + i);
			r3 = _mm512_maskz_loadu_epi16(mmask, aoffset03 + i);
			r4 = _mm512_maskz_loadu_epi16(mmask, aoffset04 + i);
			r5 = _mm512_maskz_loadu_epi16(mmask, aoffset05 + i);
			r6 = _mm512_maskz_loadu_epi16(mmask, aoffset06 + i);
			r7 = _mm512_maskz_loadu_epi16(mmask, aoffset07 + i);
			REORDER_8x32(t00, t01, t02, t03, t04, t05, t06, t07);
			r0 = _mm512_maskz_loadu_epi16(mmask, aoffset10 + i);
			r1 = _mm512_maskz_loadu_epi16(mmask, aoffset11 + i);
			r2 = _mm512_maskz_loadu_epi16(mmask, aoffset12 + i);
			r3 = _mm512_maskz_loadu_epi16(mmask, aoffset13 + i);
			r4 = _mm512_maskz_loadu_epi16(mmask, aoffset14 + i);
			r5 = _mm512_maskz_loadu_epi16(mmask, aoffset15 + i);
			r6 = _mm512_maskz_loadu_epi16(mmask, aoffset16 + i);
			r7 = _mm512_maskz_loadu_epi16(mmask, aoffset17 + i);
			REORDER_8x32(t10, t11, t12, t13, t14, t15, t16, t17);
			/* Store the n_store complete pairs, highest chunk first (deliberate fallthrough). */
			int n_store = remain_m/2;
			switch (n_store) {
				case 15: STORE_512(1, 6); /* fallthrough */
				case 14: STORE_512(1, 5); /* fallthrough */
				case 13: STORE_512(1, 4); /* fallthrough */
				case 12: STORE_512(1, 3); /* fallthrough */
				case 11: STORE_512(1, 2); /* fallthrough */
				case 10: STORE_512(1, 1); /* fallthrough */
				case  9: STORE_512(1, 0); /* fallthrough */
				case  8: STORE_512(0, 7); /* fallthrough */
				case  7: STORE_512(0, 6); /* fallthrough */
				case  6: STORE_512(0, 5); /* fallthrough */
				case  5: STORE_512(0, 4); /* fallthrough */
				case  4: STORE_512(0, 3); /* fallthrough */
				case  3: STORE_512(0, 2); /* fallthrough */
				case  2: STORE_512(0, 1); /* fallthrough */
				case  1: STORE_512(0, 0); break;
				default: break; /* n_store == 0: only the odd tail remains */
			}
			boffset0 += n_store * 32;
			/* Odd m: the last chunk has each element in the low 16 bits of a
			 * 32-bit lane (its partner was zero-filled); narrow to 16 values. */
			if (m & 0x1) {
				__m512i tail;
				GET_TAIL();
				_mm256_storeu_si256((void *)boffset0, _mm512_cvtepi32_epi16(tail));
				boffset0 += 16;
			}
		}

	}

	/* ---- Narrow remainder panel of remain_n = n % 16 vectors. ---- */
	if (j < n) {
		int remain_n = n - j;
		__mmask16 nmask = (1UL << remain_n) - 1;
		/* Vectors loaded into the first (load0) and second (load1) group of eight. */
		int load0, load1;
		if (remain_n > 8) {
			load0 = 8;
			load1 = remain_n - 8;
		} else {
			load0 = remain_n;
			load1 = 0;
		}
		/* NOTE(review): pointers for vectors >= remain_n are computed but
		 * never dereferenced. */
		aoffset00 = aoffset;
		aoffset01 = aoffset00 + lda;
		aoffset02 = aoffset01 + lda;
		aoffset03 = aoffset02 + lda;
		aoffset04 = aoffset03 + lda;
		aoffset05 = aoffset04 + lda;
		aoffset06 = aoffset05 + lda;
		aoffset07 = aoffset06 + lda;
		aoffset10 = aoffset07 + lda;
		aoffset11 = aoffset10 + lda;
		aoffset12 = aoffset11 + lda;
		aoffset13 = aoffset12 + lda;
		aoffset14 = aoffset13 + lda;
		aoffset15 = aoffset14 + lda;
		aoffset16 = aoffset15 + lda;
		aoffset17 = aoffset16 + lda;
		aoffset += 16 * lda;
		for (i = 0; i < m32; i += 32) {
			/* Load only existing vectors; the rest keep earlier (defined) contents. */
			switch (load0) {
				case 8: r7 = _mm512_loadu_si512(aoffset07 + i); /* fallthrough */
				case 7: r6 = _mm512_loadu_si512(aoffset06 + i); /* fallthrough */
				case 6: r5 = _mm512_loadu_si512(aoffset05 + i); /* fallthrough */
				case 5: r4 = _mm512_loadu_si512(aoffset04 + i); /* fallthrough */
				case 4: r3 = _mm512_loadu_si512(aoffset03 + i); /* fallthrough */
				case 3: r2 = _mm512_loadu_si512(aoffset02 + i); /* fallthrough */
				case 2: r1 = _mm512_loadu_si512(aoffset01 + i); /* fallthrough */
				case 1: r0 = _mm512_loadu_si512(aoffset00 + i); break;
				default: break;
			}
			REORDER_8x32(t00, t01, t02, t03, t04, t05, t06, t07);
			switch (load1) {
				case 8: r7 = _mm512_loadu_si512(aoffset17 + i); /* fallthrough */
				case 7: r6 = _mm512_loadu_si512(aoffset16 + i); /* fallthrough */
				case 6: r5 = _mm512_loadu_si512(aoffset15 + i); /* fallthrough */
				case 5: r4 = _mm512_loadu_si512(aoffset14 + i); /* fallthrough */
				case 4: r3 = _mm512_loadu_si512(aoffset13 + i); /* fallthrough */
				case 3: r2 = _mm512_loadu_si512(aoffset12 + i); /* fallthrough */
				case 2: r1 = _mm512_loadu_si512(aoffset11 + i); /* fallthrough */
				case 1: r0 = _mm512_loadu_si512(aoffset10 + i); break;
				default: break; /* load1 == 0: second group is empty */
			}
			REORDER_8x32(t10, t11, t12, t13, t14, t15, t16, t17);
			MASK_STORE_512(0, 0); MASK_STORE_512(0, 1); MASK_STORE_512(0, 2); MASK_STORE_512(0, 3);
			MASK_STORE_512(0, 4); MASK_STORE_512(0, 5); MASK_STORE_512(0, 6); MASK_STORE_512(0, 7);
			MASK_STORE_512(1, 0); MASK_STORE_512(1, 1); MASK_STORE_512(1, 2); MASK_STORE_512(1, 3);
			MASK_STORE_512(1, 4); MASK_STORE_512(1, 5); MASK_STORE_512(1, 6); MASK_STORE_512(1, 7);
			boffset0 += remain_n * 32;
		}
		if (i < m) {
			int remain_m = m - i;
			__mmask32 mmask = (1UL << remain_m) - 1;
			switch (load0) {
				case 8: r7 = _mm512_maskz_loadu_epi16(mmask, aoffset07 + i); /* fallthrough */
				case 7: r6 = _mm512_maskz_loadu_epi16(mmask, aoffset06 + i); /* fallthrough */
				case 6: r5 = _mm512_maskz_loadu_epi16(mmask, aoffset05 + i); /* fallthrough */
				case 5: r4 = _mm512_maskz_loadu_epi16(mmask, aoffset04 + i); /* fallthrough */
				case 4: r3 = _mm512_maskz_loadu_epi16(mmask, aoffset03 + i); /* fallthrough */
				case 3: r2 = _mm512_maskz_loadu_epi16(mmask, aoffset02 + i); /* fallthrough */
				case 2: r1 = _mm512_maskz_loadu_epi16(mmask, aoffset01 + i); /* fallthrough */
				case 1: r0 = _mm512_maskz_loadu_epi16(mmask, aoffset00 + i); break;
				default: break;
			}
			REORDER_8x32(t00, t01, t02, t03, t04, t05, t06, t07);
			switch (load1) {
				case 8: r7 = _mm512_maskz_loadu_epi16(mmask, aoffset17 + i); /* fallthrough */
				case 7: r6 = _mm512_maskz_loadu_epi16(mmask, aoffset16 + i); /* fallthrough */
				case 6: r5 = _mm512_maskz_loadu_epi16(mmask, aoffset15 + i); /* fallthrough */
				case 5: r4 = _mm512_maskz_loadu_epi16(mmask, aoffset14 + i); /* fallthrough */
				case 4: r3 = _mm512_maskz_loadu_epi16(mmask, aoffset13 + i); /* fallthrough */
				case 3: r2 = _mm512_maskz_loadu_epi16(mmask, aoffset12 + i); /* fallthrough */
				case 2: r1 = _mm512_maskz_loadu_epi16(mmask, aoffset11 + i); /* fallthrough */
				case 1: r0 = _mm512_maskz_loadu_epi16(mmask, aoffset10 + i); break;
				default: break; /* load1 == 0: second group is empty */
			}
			REORDER_8x32(t10, t11, t12, t13, t14, t15, t16, t17);
			int n_store = remain_m/2;
			switch (n_store) {
				case 15: MASK_STORE_512(1, 6); /* fallthrough */
				case 14: MASK_STORE_512(1, 5); /* fallthrough */
				case 13: MASK_STORE_512(1, 4); /* fallthrough */
				case 12: MASK_STORE_512(1, 3); /* fallthrough */
				case 11: MASK_STORE_512(1, 2); /* fallthrough */
				case 10: MASK_STORE_512(1, 1); /* fallthrough */
				case  9: MASK_STORE_512(1, 0); /* fallthrough */
				case  8: MASK_STORE_512(0, 7); /* fallthrough */
				case  7: MASK_STORE_512(0, 6); /* fallthrough */
				case  6: MASK_STORE_512(0, 5); /* fallthrough */
				case  5: MASK_STORE_512(0, 4); /* fallthrough */
				case  4: MASK_STORE_512(0, 3); /* fallthrough */
				case  3: MASK_STORE_512(0, 2); /* fallthrough */
				case  2: MASK_STORE_512(0, 1); /* fallthrough */
				case  1: MASK_STORE_512(0, 0); break;
				default: break; /* n_store == 0: only the odd tail remains */
			}
			boffset0 += n_store * remain_n * 2;
			if (m & 0x1) {
				__m512i tail;
				GET_TAIL();
				_mm256_mask_storeu_epi16((void *)boffset0, nmask, _mm512_cvtepi32_epi16(tail));
			}
		}
	}
	return 0;
}
354