/***************************************************************************
Copyright (c) 2021, The OpenBLAS Project
All rights reserved.
Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are
met:
1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.
3. Neither the name of the OpenBLAS project nor the names of
its contributors may be used to endorse or promote products
derived from this software without specific prior written permission.
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE
ARE DISCLAIMED. IN NO EVENT SHALL THE OPENBLAS PROJECT OR CONTRIBUTORS BE
LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*****************************************************************************/
27
#include <stdio.h>
#include <stdint.h>
#include <immintrin.h>
#include "common.h"
31
/* result = vshufps(in1, in2, imm8) on 512-bit registers of 32-bit lanes.
 * Written as inline asm because the _mm512_shuffle_ps intrinsic takes
 * __m512 (float) operands while every caller here works on __m512i;
 * the "v" constraint allows any zmm register, "N" keeps imm8 immediate. */
#define _MM512_SHUFFLE_i32(result, in1, in2, imm8) \
    asm("vshufps %3, %2, %1, %0": "=v"(result): "v"(in1), "v"(in2), "N"(imm8))
34
/* Transpose eight 512-bit rows r0..r7 (16 dwords each; every dword holds a
 * pair of adjacent 16-bit elements, so effectively 8 rows x 32 bf16) into
 * the eight macro arguments t0..t7.  Three stages:
 *   1. unpacklo/unpackhi_epi32 - interleave dwords of row pairs;
 *   2. _MM512_SHUFFLE_i32 + masked blends - exchange dword pairs between
 *      register pairs (kc = 0xcccc and k3 = 0x3333 are the complementary
 *      blend masks declared in the enclosing function);
 *   3. permutex2var_epi32 with idx_lo/idx_hi - merge quarters across the
 *      two register groups (tables also from the enclosing function).
 * r0..r7 are clobbered; `v` is a scratch register local to the macro. */
#define REORDER_8x32(t0, t1, t2, t3, t4, t5, t6, t7) { \
    __m512i v; \
    t0 = _mm512_unpacklo_epi32(r0, r1); \
    t1 = _mm512_unpackhi_epi32(r0, r1); \
    t2 = _mm512_unpacklo_epi32(r2, r3); \
    t3 = _mm512_unpackhi_epi32(r2, r3); \
    t4 = _mm512_unpacklo_epi32(r4, r5); \
    t5 = _mm512_unpackhi_epi32(r4, r5); \
    t6 = _mm512_unpacklo_epi32(r6, r7); \
    t7 = _mm512_unpackhi_epi32(r6, r7); \
    _MM512_SHUFFLE_i32(v, t0, t2, 0x4E); \
    r0 = _mm512_mask_blend_epi32(kc, t0, v); \
    r1 = _mm512_mask_blend_epi32(k3, t2, v); \
    _MM512_SHUFFLE_i32(v, t1, t3, 0x4E); \
    r2 = _mm512_mask_blend_epi32(kc, t1, v); \
    r3 = _mm512_mask_blend_epi32(k3, t3, v); \
    _MM512_SHUFFLE_i32(v, t4, t6, 0x4E); \
    r4 = _mm512_mask_blend_epi32(kc, t4, v); \
    r5 = _mm512_mask_blend_epi32(k3, t6, v); \
    _MM512_SHUFFLE_i32(v, t5, t7, 0x4E); \
    r6 = _mm512_mask_blend_epi32(kc, t5, v); \
    r7 = _mm512_mask_blend_epi32(k3, t7, v); \
    t0 = _mm512_permutex2var_epi32(r0, idx_lo, r4); \
    t1 = _mm512_permutex2var_epi32(r1, idx_lo, r5); \
    t2 = _mm512_permutex2var_epi32(r2, idx_lo, r6); \
    t3 = _mm512_permutex2var_epi32(r3, idx_lo, r7); \
    t4 = _mm512_permutex2var_epi32(r0, idx_hi, r4); \
    t5 = _mm512_permutex2var_epi32(r1, idx_hi, r5); \
    t6 = _mm512_permutex2var_epi32(r2, idx_hi, r6); \
    t7 = _mm512_permutex2var_epi32(r3, idx_hi, r7); \
}
66
/* Each STORE/MASK_STORE variant merges matching 256-bit halves of t0<x>
 * and t1<x> (the outputs of the two REORDER_8x32 calls) with a 64-bit
 * cross-register permute, then writes one packed output row.  All variants
 * expect `v`, `boffset0` and idx_lo2/idx_hi2 (plus `remain_n`/`nmask` for
 * the masked forms) to be visible at the expansion site. */

/* Row x of the first eight: low 256-bit halves of t0<x> and t1<x>. */
#define STORE_512_LO(x) \
    v = _mm512_permutex2var_epi64(t0##x, idx_lo2, t1##x); \
    _mm512_storeu_si512(boffset0 + x*32, v);

/* Row x of the second eight: high 256-bit halves of t0<x> and t1<x>. */
#define STORE_512_HI(x) \
    v = _mm512_permutex2var_epi64(t0##x, idx_hi2, t1##x); \
    _mm512_storeu_si512(boffset0 + (x + 8)*32, v);

/* Masked counterparts for the column tail: each packed row is only
 * remain_n dwords (2*remain_n 16-bit elements) wide, hence the
 * 2*...*remain_n offsets and the nmask-guarded dword store. */
#define MASK_STORE_512_LO(x) \
    v = _mm512_permutex2var_epi64(t0##x, idx_lo2, t1##x); \
    _mm512_mask_storeu_epi32(boffset0 + 2*x*remain_n, nmask, v);

#define MASK_STORE_512_HI(x) \
    v = _mm512_permutex2var_epi64(t0##x, idx_hi2, t1##x); \
    _mm512_mask_storeu_epi32(boffset0 + 2*(x + 8)*remain_n, nmask, v);

/* x selects the LO (0) / HI (1) group, y the row within the group; both
 * are compile-time constants, so the branch folds away. */
#define STORE_512(x, y) {\
    __m512i v; \
    if (x == 0) { STORE_512_LO(y); } \
    else { STORE_512_HI(y); } \
}

#define MASK_STORE_512(x, y) {\
    __m512i v; \
    if (x == 0) { MASK_STORE_512_LO(y); } \
    else { MASK_STORE_512_HI(y); } \
}

/* Capture one output row into `tail` instead of storing it; used for the
 * final unpaired source row when m is odd. */
#define SET_TAIL(y, x) {\
    if (y == 0) tail = _mm512_permutex2var_epi64(t0##x, idx_lo2, t1##x); \
    else tail = _mm512_permutex2var_epi64(t0##x, idx_hi2, t1##x); \
}
99
/* After the two REORDER_8x32 calls, select the (n_store + 1)-th of the 16
 * packed output rows into `tail`: rows 1-8 come from the low halves
 * (SET_TAIL(0, x) over t0x/t1x), rows 9-16 from the high halves.
 * Requires `n_store`, `tail`, t00..t07/t10..t17 and idx_lo2/idx_hi2 in
 * scope at the expansion site. */
#define GET_TAIL() \
    switch (n_store + 1) { \
        case 16: SET_TAIL(1, 7); break; \
        case 15: SET_TAIL(1, 6); break; \
        case 14: SET_TAIL(1, 5); break; \
        case 13: SET_TAIL(1, 4); break; \
        case 12: SET_TAIL(1, 3); break; \
        case 11: SET_TAIL(1, 2); break; \
        case 10: SET_TAIL(1, 1); break; \
        case 9:  SET_TAIL(1, 0); break; \
        case 8:  SET_TAIL(0, 7); break; \
        case 7:  SET_TAIL(0, 6); break; \
        case 6:  SET_TAIL(0, 5); break; \
        case 5:  SET_TAIL(0, 4); break; \
        case 4:  SET_TAIL(0, 3); break; \
        case 3:  SET_TAIL(0, 2); break; \
        case 2:  SET_TAIL(0, 1); break; \
        case 1:  SET_TAIL(0, 0); break; \
    }
119
120
/*
 * Pack an m x n panel of the 16-bit matrix `a` (16 columns at a time,
 * column stride `lda`) into the contiguous buffer `b` in the layout the
 * 16-wide SBGEMM micro-kernel consumes: each 16x32 tile is transposed by
 * two REORDER_8x32 calls so that every packed dword carries two vertically
 * adjacent 16-bit elements of one column.  Leftover columns (n % 16) and
 * rows (m % 32) are handled with AVX512BW masked loads/stores; a final
 * unpaired row (m odd) is narrowed back to 16-bit with cvtepi32_epi16.
 *
 * m, n : number of rows / columns to pack
 * a    : source matrix, IFLOAT (16-bit, e.g. bfloat16) elements
 * lda  : distance between consecutive columns of a
 * b    : destination packing buffer
 * Returns 0 always.
 */
int CNAME(BLASLONG m, BLASLONG n, IFLOAT *a, BLASLONG lda, IFLOAT *b){
  BLASLONG i, j;

  IFLOAT *boffset0;
  IFLOAT *aoffset;
  IFLOAT *aoffset00, *aoffset01, *aoffset02, *aoffset03, *aoffset04, *aoffset05, *aoffset06, *aoffset07;
  IFLOAT *aoffset10, *aoffset11, *aoffset12, *aoffset13, *aoffset14, *aoffset15, *aoffset16, *aoffset17;
  aoffset = a;
  boffset0 = b;

  BLASLONG n16 = n & ~15;  /* columns covered by the full 16-wide path */
  BLASLONG m32 = m & ~31;  /* rows covered by the full 32-deep path */

  /* Dword indices for _mm512_permutex2var_epi32: interleave quarters of
   * two registers (entries with bit 4 set select from the second source). */
  int permute_table[] = {
    0x0, 0x1, 0x2, 0x3, 0x10, 0x11, 0x12, 0x13, 0x8, 0x9, 0xa, 0xb, 0x18, 0x19, 0x1a, 0x1b,
    0x4, 0x5, 0x6, 0x7, 0x14, 0x15, 0x16, 0x17, 0xc, 0xd, 0xe, 0xf, 0x1c, 0x1d, 0x1e, 0x1f,
  };
  /* Qword indices: first row merges the low 256-bit halves of two
   * registers, second row the high halves.  uint64_t (C99 <stdint.h>)
   * instead of the non-portable BSD u_int64_t alias. */
  uint64_t permute_table2[] = {
    0x00, 0x01, 0x02, 0x03, 8|0x0, 8|0x1, 8|0x2, 8|0x3,
    0x04, 0x05, 0x06, 0x07, 8|0x4, 8|0x5, 8|0x6, 8|0x7,
  };
  __m512i idx_lo = _mm512_loadu_si512(permute_table);
  __m512i idx_hi = _mm512_loadu_si512(permute_table + 16);
  __m512i idx_lo2 = _mm512_loadu_si512(permute_table2);
  __m512i idx_hi2 = _mm512_loadu_si512(permute_table2 + 8);
  __mmask16 kc = 0xcccc;  /* complementary blend masks for REORDER_8x32 */
  __mmask16 k3 = 0x3333;
  __m512i r0, r1, r2, r3, r4, r5, r6, r7;
  __m512i t00, t01, t02, t03, t04, t05, t06, t07;
  __m512i t10, t11, t12, t13, t14, t15, t16, t17;

  for (j = 0; j < n16; j += 16) {
    /* 16 column pointers, in two groups of 8 (one per REORDER_8x32). */
    aoffset00 = aoffset;
    aoffset01 = aoffset00 + lda;
    aoffset02 = aoffset01 + lda;
    aoffset03 = aoffset02 + lda;
    aoffset04 = aoffset03 + lda;
    aoffset05 = aoffset04 + lda;
    aoffset06 = aoffset05 + lda;
    aoffset07 = aoffset06 + lda;
    aoffset10 = aoffset07 + lda;
    aoffset11 = aoffset10 + lda;
    aoffset12 = aoffset11 + lda;
    aoffset13 = aoffset12 + lda;
    aoffset14 = aoffset13 + lda;
    aoffset15 = aoffset14 + lda;
    aoffset16 = aoffset15 + lda;
    aoffset17 = aoffset16 + lda;
    aoffset += 16 * lda;
    for (i = 0; i < m32; i += 32) {
      /* Full 16x32 tile: load 32 elements from each of the 16 columns,
       * transpose in two 8-column halves, emit 16 packed rows. */
      r0 = _mm512_loadu_si512(aoffset00 + i);
      r1 = _mm512_loadu_si512(aoffset01 + i);
      r2 = _mm512_loadu_si512(aoffset02 + i);
      r3 = _mm512_loadu_si512(aoffset03 + i);
      r4 = _mm512_loadu_si512(aoffset04 + i);
      r5 = _mm512_loadu_si512(aoffset05 + i);
      r6 = _mm512_loadu_si512(aoffset06 + i);
      r7 = _mm512_loadu_si512(aoffset07 + i);
      REORDER_8x32(t00, t01, t02, t03, t04, t05, t06, t07);
      r0 = _mm512_loadu_si512(aoffset10 + i);
      r1 = _mm512_loadu_si512(aoffset11 + i);
      r2 = _mm512_loadu_si512(aoffset12 + i);
      r3 = _mm512_loadu_si512(aoffset13 + i);
      r4 = _mm512_loadu_si512(aoffset14 + i);
      r5 = _mm512_loadu_si512(aoffset15 + i);
      r6 = _mm512_loadu_si512(aoffset16 + i);
      r7 = _mm512_loadu_si512(aoffset17 + i);
      REORDER_8x32(t10, t11, t12, t13, t14, t15, t16, t17);
      STORE_512(0, 0); STORE_512(0, 1); STORE_512(0, 2); STORE_512(0, 3);
      STORE_512(0, 4); STORE_512(0, 5); STORE_512(0, 6); STORE_512(0, 7);
      STORE_512(1, 0); STORE_512(1, 1); STORE_512(1, 2); STORE_512(1, 3);
      STORE_512(1, 4); STORE_512(1, 5); STORE_512(1, 6); STORE_512(1, 7);
      boffset0 += 16 * 32;
    }
    if (i < m) {
      /* Row tail: masked 16-bit loads zero the lanes beyond row m. */
      int remain_m = m - i;
      __mmask32 mmask = (1UL << remain_m) - 1;
      r0 = _mm512_maskz_loadu_epi16(mmask, aoffset00 + i);
      r1 = _mm512_maskz_loadu_epi16(mmask, aoffset01 + i);
      r2 = _mm512_maskz_loadu_epi16(mmask, aoffset02 + i);
      r3 = _mm512_maskz_loadu_epi16(mmask, aoffset03 + i);
      r4 = _mm512_maskz_loadu_epi16(mmask, aoffset04 + i);
      r5 = _mm512_maskz_loadu_epi16(mmask, aoffset05 + i);
      r6 = _mm512_maskz_loadu_epi16(mmask, aoffset06 + i);
      r7 = _mm512_maskz_loadu_epi16(mmask, aoffset07 + i);
      REORDER_8x32(t00, t01, t02, t03, t04, t05, t06, t07);
      r0 = _mm512_maskz_loadu_epi16(mmask, aoffset10 + i);
      r1 = _mm512_maskz_loadu_epi16(mmask, aoffset11 + i);
      r2 = _mm512_maskz_loadu_epi16(mmask, aoffset12 + i);
      r3 = _mm512_maskz_loadu_epi16(mmask, aoffset13 + i);
      r4 = _mm512_maskz_loadu_epi16(mmask, aoffset14 + i);
      r5 = _mm512_maskz_loadu_epi16(mmask, aoffset15 + i);
      r6 = _mm512_maskz_loadu_epi16(mmask, aoffset16 + i);
      r7 = _mm512_maskz_loadu_epi16(mmask, aoffset17 + i);
      REORDER_8x32(t10, t11, t12, t13, t14, t15, t16, t17);
      /* Each packed row covers two source rows, so only remain_m/2 full
       * rows are emitted; the switch falls through intentionally. */
      int n_store = remain_m/2;
      switch (n_store) {
        case 15: STORE_512(1, 6); /* fallthrough */
        case 14: STORE_512(1, 5); /* fallthrough */
        case 13: STORE_512(1, 4); /* fallthrough */
        case 12: STORE_512(1, 3); /* fallthrough */
        case 11: STORE_512(1, 2); /* fallthrough */
        case 10: STORE_512(1, 1); /* fallthrough */
        case 9:  STORE_512(1, 0); /* fallthrough */
        case 8:  STORE_512(0, 7); /* fallthrough */
        case 7:  STORE_512(0, 6); /* fallthrough */
        case 6:  STORE_512(0, 5); /* fallthrough */
        case 5:  STORE_512(0, 4); /* fallthrough */
        case 4:  STORE_512(0, 3); /* fallthrough */
        case 3:  STORE_512(0, 2); /* fallthrough */
        case 2:  STORE_512(0, 1); /* fallthrough */
        case 1:  STORE_512(0, 0);
      }
      boffset0 += n_store * 32;
      if (m & 0x1) {
        /* Odd final row: its elements sit in the low 16 bits of each dword
         * (the masked load zeroed the missing partner), so truncating the
         * dwords recovers the 16 values of the last row. */
        __m512i tail;
        GET_TAIL();
        _mm256_storeu_si256((void *)boffset0, _mm512_cvtepi32_epi16(tail));
        boffset0 += 16;
      }
    }

  }
  if (j < n) {
    /* Column tail: fewer than 16 columns remain.  load0/load1 count the
     * live columns in the first/second group of 8; the store side narrows
     * every packed row to remain_n dwords via nmask. */
    int remain_n = n - j;
    __mmask16 nmask = (1UL << remain_n) - 1;
    int load0, load1;
    if (remain_n > 8) {
      load0 = 8;
      load1 = remain_n - 8;
    } else {
      load0 = remain_n;
      load1 = 0;
    }
    aoffset00 = aoffset;
    aoffset01 = aoffset00 + lda;
    aoffset02 = aoffset01 + lda;
    aoffset03 = aoffset02 + lda;
    aoffset04 = aoffset03 + lda;
    aoffset05 = aoffset04 + lda;
    aoffset06 = aoffset05 + lda;
    aoffset07 = aoffset06 + lda;
    aoffset10 = aoffset07 + lda;
    aoffset11 = aoffset10 + lda;
    aoffset12 = aoffset11 + lda;
    aoffset13 = aoffset12 + lda;
    aoffset14 = aoffset13 + lda;
    aoffset15 = aoffset14 + lda;
    aoffset16 = aoffset15 + lda;
    aoffset17 = aoffset16 + lda;
    aoffset += 16 * lda;
    for (i = 0; i < m32; i += 32) {
      /* Load only the live columns (fallthrough by design); dead r
       * registers keep stale data but are masked out on store. */
      switch (load0) {
        case 8: r7 = _mm512_loadu_si512(aoffset07 + i); /* fallthrough */
        case 7: r6 = _mm512_loadu_si512(aoffset06 + i); /* fallthrough */
        case 6: r5 = _mm512_loadu_si512(aoffset05 + i); /* fallthrough */
        case 5: r4 = _mm512_loadu_si512(aoffset04 + i); /* fallthrough */
        case 4: r3 = _mm512_loadu_si512(aoffset03 + i); /* fallthrough */
        case 3: r2 = _mm512_loadu_si512(aoffset02 + i); /* fallthrough */
        case 2: r1 = _mm512_loadu_si512(aoffset01 + i); /* fallthrough */
        case 1: r0 = _mm512_loadu_si512(aoffset00 + i);
      }
      REORDER_8x32(t00, t01, t02, t03, t04, t05, t06, t07);
      switch (load1) {
        case 8: r7 = _mm512_loadu_si512(aoffset17 + i); /* fallthrough */
        case 7: r6 = _mm512_loadu_si512(aoffset16 + i); /* fallthrough */
        case 6: r5 = _mm512_loadu_si512(aoffset15 + i); /* fallthrough */
        case 5: r4 = _mm512_loadu_si512(aoffset14 + i); /* fallthrough */
        case 4: r3 = _mm512_loadu_si512(aoffset13 + i); /* fallthrough */
        case 3: r2 = _mm512_loadu_si512(aoffset12 + i); /* fallthrough */
        case 2: r1 = _mm512_loadu_si512(aoffset11 + i); /* fallthrough */
        case 1: r0 = _mm512_loadu_si512(aoffset10 + i);
      }
      REORDER_8x32(t10, t11, t12, t13, t14, t15, t16, t17);
      MASK_STORE_512(0, 0); MASK_STORE_512(0, 1); MASK_STORE_512(0, 2); MASK_STORE_512(0, 3);
      MASK_STORE_512(0, 4); MASK_STORE_512(0, 5); MASK_STORE_512(0, 6); MASK_STORE_512(0, 7);
      MASK_STORE_512(1, 0); MASK_STORE_512(1, 1); MASK_STORE_512(1, 2); MASK_STORE_512(1, 3);
      MASK_STORE_512(1, 4); MASK_STORE_512(1, 5); MASK_STORE_512(1, 6); MASK_STORE_512(1, 7);
      boffset0 += remain_n * 32;
    }
    if (i < m) {
      /* Combined row and column tail: masked loads on both axes. */
      int remain_m = m - i;
      __mmask32 mmask = (1UL << remain_m) - 1;
      switch (load0) {
        case 8: r7 = _mm512_maskz_loadu_epi16(mmask, aoffset07 + i); /* fallthrough */
        case 7: r6 = _mm512_maskz_loadu_epi16(mmask, aoffset06 + i); /* fallthrough */
        case 6: r5 = _mm512_maskz_loadu_epi16(mmask, aoffset05 + i); /* fallthrough */
        case 5: r4 = _mm512_maskz_loadu_epi16(mmask, aoffset04 + i); /* fallthrough */
        case 4: r3 = _mm512_maskz_loadu_epi16(mmask, aoffset03 + i); /* fallthrough */
        case 3: r2 = _mm512_maskz_loadu_epi16(mmask, aoffset02 + i); /* fallthrough */
        case 2: r1 = _mm512_maskz_loadu_epi16(mmask, aoffset01 + i); /* fallthrough */
        case 1: r0 = _mm512_maskz_loadu_epi16(mmask, aoffset00 + i);
      }
      REORDER_8x32(t00, t01, t02, t03, t04, t05, t06, t07);
      switch (load1) {
        case 8: r7 = _mm512_maskz_loadu_epi16(mmask, aoffset17 + i); /* fallthrough */
        case 7: r6 = _mm512_maskz_loadu_epi16(mmask, aoffset16 + i); /* fallthrough */
        case 6: r5 = _mm512_maskz_loadu_epi16(mmask, aoffset15 + i); /* fallthrough */
        case 5: r4 = _mm512_maskz_loadu_epi16(mmask, aoffset14 + i); /* fallthrough */
        case 4: r3 = _mm512_maskz_loadu_epi16(mmask, aoffset13 + i); /* fallthrough */
        case 3: r2 = _mm512_maskz_loadu_epi16(mmask, aoffset12 + i); /* fallthrough */
        case 2: r1 = _mm512_maskz_loadu_epi16(mmask, aoffset11 + i); /* fallthrough */
        case 1: r0 = _mm512_maskz_loadu_epi16(mmask, aoffset10 + i);
      }
      REORDER_8x32(t10, t11, t12, t13, t14, t15, t16, t17);
      int n_store = remain_m/2;
      switch (n_store) {
        case 15: MASK_STORE_512(1, 6); /* fallthrough */
        case 14: MASK_STORE_512(1, 5); /* fallthrough */
        case 13: MASK_STORE_512(1, 4); /* fallthrough */
        case 12: MASK_STORE_512(1, 3); /* fallthrough */
        case 11: MASK_STORE_512(1, 2); /* fallthrough */
        case 10: MASK_STORE_512(1, 1); /* fallthrough */
        case 9:  MASK_STORE_512(1, 0); /* fallthrough */
        case 8:  MASK_STORE_512(0, 7); /* fallthrough */
        case 7:  MASK_STORE_512(0, 6); /* fallthrough */
        case 6:  MASK_STORE_512(0, 5); /* fallthrough */
        case 5:  MASK_STORE_512(0, 4); /* fallthrough */
        case 4:  MASK_STORE_512(0, 3); /* fallthrough */
        case 3:  MASK_STORE_512(0, 2); /* fallthrough */
        case 2:  MASK_STORE_512(0, 1); /* fallthrough */
        case 1:  MASK_STORE_512(0, 0);
      }
      boffset0 += n_store * remain_n * 2;
      if (m & 0x1) {
        /* Odd final row of the column tail: narrow to 16-bit and store
         * only remain_n elements via nmask. */
        __m512i tail;
        GET_TAIL();
        _mm256_mask_storeu_epi16((void *)boffset0, nmask, _mm512_cvtepi32_epi16(tail));
      }
    }
  }
  return 0;
}
354