1 /*
2
3 BLIS
4 An object-based framework for developing high-performance BLAS-like
5 libraries.
6
7 Copyright (C) 2018, The University of Texas at Austin
8 Copyright (C) 2016 - 2018, Advanced Micro Devices, Inc.
9
10 Redistribution and use in source and binary forms, with or without
11 modification, are permitted provided that the following conditions are
12 met:
13 - Redistributions of source code must retain the above copyright
14 notice, this list of conditions and the following disclaimer.
15 - Redistributions in binary form must reproduce the above copyright
16 notice, this list of conditions and the following disclaimer in the
17 documentation and/or other materials provided with the distribution.
18 - Neither the name(s) of the copyright holder(s) nor the names of its
19 contributors may be used to endorse or promote products derived
20 from this software without specific prior written permission.
21
22 THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
23 "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
24 LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
25 A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
26 HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
27 SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
28 LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
29 DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
30 THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
31 (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
32 OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
33
34 */
35
36 #include "immintrin.h"
37 #include "blis.h"
38
/* Union data structure to access AVX registers.
   One 256-bit AVX register holds 8 SP (single-precision) elements.
   Use .v with vector intrinsics, or .f[] for per-element scalar access. */
typedef union
{
	__m256 v;
	float  f[8] __attribute__((aligned(64)));
} v8sf_t;
46
/* Union data structure to access AVX registers.
 * One 256-bit AVX register holds 4 DP (double-precision) elements.
 * Use .v with vector intrinsics, or .d[] for per-element scalar access. */
typedef union
{
	__m256d v;
	double  d[4] __attribute__((aligned(64)));
} v4df_t;
54
55 // -----------------------------------------------------------------------------
56
// Fused single-precision axpyf kernel for Zen (AVX2/FMA):
//   y := y + alpha * A * x
// where A is m x b_n with column stride lda and row stride inca, and the
// fusing factor is 8. conja/conjx are accepted for interface uniformity;
// conjugation is a no-op in the real domain.
void bli_saxpyf_zen_int_8
     (
       conj_t           conja,
       conj_t           conjx,
       dim_t            m,
       dim_t            b_n,
       float*  restrict alpha,
       float*  restrict a, inc_t inca, inc_t lda,
       float*  restrict x, inc_t incx,
       float*  restrict y, inc_t incy,
       cntx_t* restrict cntx
     )
{
	const dim_t fuse_fac       = 8;
	const dim_t n_elem_per_reg = 8;
	const dim_t n_iter_unroll  = 1;

	// Nothing to do for an empty problem or a zero scaling factor.
	if ( bli_zero_dim2( m, b_n ) || PASTEMAC(s,eq0)( *alpha ) ) return;

	// When the column count differs from the fusing factor, delegate each
	// column to the context's axpyv kernel instead.
	if ( b_n != fuse_fac )
	{
		saxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_AXPYV_KER, cntx );

		for ( dim_t j = 0; j < b_n; ++j )
		{
			float* restrict aj   = a + (0  )*inca + (j  )*lda;
			float* restrict chij = x + (j  )*incx;
			float* restrict y1   = y + (0  )*incy;
			float           alpha_chij;

			// alpha_chij = alpha * conjx( x[j] ).
			PASTEMAC(s,copycjs)( conjx, *chij, alpha_chij );
			PASTEMAC(s,scals)( *alpha, alpha_chij );

			f
			(
			  conja,
			  m,
			  &alpha_chij,
			  aj, inca,
			  y1, incy,
			  cntx
			);
		}

		return;
	}

	// From here on, b_n is exactly the fusing factor (8).

	// Partition m into full-register iterations and a scalar remainder.
	dim_t m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );
	dim_t m_left  = ( m ) % ( n_elem_per_reg * n_iter_unroll );

	// Contiguous vector loads/stores require unit strides; otherwise
	// process every row with scalar code.
	if ( inca != 1 || incy != 1 )
	{
		m_viter = 0;
		m_left  = m;
	}

	// One walking pointer per column of A, plus one into y.
	float* restrict ap[8];
	for ( dim_t j = 0; j < 8; ++j ) ap[j] = a + j*lda;

	float* restrict yp = y;

	// Gather x and pre-scale each element by alpha.
	float chi[8];
	for ( dim_t j = 0; j < 8; ++j )
	{
		chi[j] = *( x + j*incx );
		PASTEMAC(s,scals)( *alpha, chi[j] );
	}

	// Broadcast each (alpha*chi) scalar across a full vector register.
	v8sf_t chiv[8];
	for ( dim_t j = 0; j < 8; ++j ) chiv[j].v = _mm256_broadcast_ss( &chi[j] );

	// Vectorized portion: one 8-element panel of y per iteration, with
	// eight fused multiply-adds (one per column).
	for ( dim_t i = 0; i < m_viter; ++i )
	{
		v8sf_t yv;
		yv.v = _mm256_loadu_ps( yp );

		// Accumulate columns in ascending order so the floating-point
		// rounding sequence matches the unrolled reference exactly.
		for ( dim_t j = 0; j < 8; ++j )
		{
			v8sf_t av;
			av.v = _mm256_loadu_ps( ap[j] );
			yv.v = _mm256_fmadd_ps( av.v, chiv[j].v, yv.v );
		}

		_mm256_storeu_ps( yp, yv.v );

		yp += n_elem_per_reg;
		for ( dim_t j = 0; j < 8; ++j ) ap[j] += n_elem_per_reg;
	}

	// Scalar cleanup for leftover rows (or all rows when strided).
	for ( dim_t i = 0; i < m_left; ++i )
	{
		float acc = *yp;

		// Same ascending column order as the vector loop above.
		for ( dim_t j = 0; j < 8; ++j )
		{
			acc   += chi[j] * ( *ap[j] );
			ap[j] += inca;
		}

		*yp  = acc;
		yp  += incy;
	}
}
263
264 // -----------------------------------------------------------------------------
265
// Fused double-precision axpyf kernel for Zen (AVX2/FMA):
//   y := y + alpha * A * x
// where A is m x b_n with column stride lda and row stride inca, and the
// fusing factor is 8. conja/conjx are accepted for interface uniformity;
// conjugation is a no-op in the real domain.
void bli_daxpyf_zen_int_8
     (
       conj_t           conja,
       conj_t           conjx,
       dim_t            m,
       dim_t            b_n,
       double* restrict alpha,
       double* restrict a, inc_t inca, inc_t lda,
       double* restrict x, inc_t incx,
       double* restrict y, inc_t incy,
       cntx_t* restrict cntx
     )
{
	const dim_t fuse_fac       = 8;
	const dim_t n_elem_per_reg = 4;
	const dim_t n_iter_unroll  = 1;

	// Nothing to do for an empty problem or a zero scaling factor.
	if ( bli_zero_dim2( m, b_n ) || PASTEMAC(d,eq0)( *alpha ) ) return;

	// When the column count differs from the fusing factor, delegate each
	// column to the context's axpyv kernel instead.
	if ( b_n != fuse_fac )
	{
		daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx );

		for ( dim_t j = 0; j < b_n; ++j )
		{
			double* restrict aj   = a + (0  )*inca + (j  )*lda;
			double* restrict chij = x + (j  )*incx;
			double* restrict y1   = y + (0  )*incy;
			double           alpha_chij;

			// alpha_chij = alpha * conjx( x[j] ).
			PASTEMAC(d,copycjs)( conjx, *chij, alpha_chij );
			PASTEMAC(d,scals)( *alpha, alpha_chij );

			f
			(
			  conja,
			  m,
			  &alpha_chij,
			  aj, inca,
			  y1, incy,
			  cntx
			);
		}

		return;
	}

	// From here on, b_n is exactly the fusing factor (8).

	// Partition m into full-register iterations and a scalar remainder.
	dim_t m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );
	dim_t m_left  = ( m ) % ( n_elem_per_reg * n_iter_unroll );

	// Contiguous vector loads/stores require unit strides; otherwise
	// process every row with scalar code.
	if ( inca != 1 || incy != 1 )
	{
		m_viter = 0;
		m_left  = m;
	}

	// One walking pointer per column of A, plus one into y.
	double* restrict ap[8];
	for ( dim_t j = 0; j < 8; ++j ) ap[j] = a + j*lda;

	double* restrict yp = y;

	// Gather x and pre-scale each element by alpha.
	double chi[8];
	for ( dim_t j = 0; j < 8; ++j )
	{
		chi[j] = *( x + j*incx );
		PASTEMAC(d,scals)( *alpha, chi[j] );
	}

	// Broadcast each (alpha*chi) scalar across a full vector register.
	v4df_t chiv[8];
	for ( dim_t j = 0; j < 8; ++j ) chiv[j].v = _mm256_broadcast_sd( &chi[j] );

	// Vectorized portion: one 4-element panel of y per iteration, with
	// eight fused multiply-adds (one per column).
	for ( dim_t i = 0; i < m_viter; ++i )
	{
		v4df_t yv;
		yv.v = _mm256_loadu_pd( yp );

		// Accumulate columns in ascending order so the floating-point
		// rounding sequence matches the unrolled reference exactly.
		for ( dim_t j = 0; j < 8; ++j )
		{
			v4df_t av;
			av.v = _mm256_loadu_pd( ap[j] );
			yv.v = _mm256_fmadd_pd( av.v, chiv[j].v, yv.v );
		}

		_mm256_storeu_pd( yp, yv.v );

		yp += n_elem_per_reg;
		for ( dim_t j = 0; j < 8; ++j ) ap[j] += n_elem_per_reg;
	}

	// Scalar cleanup for leftover rows (or all rows when strided).
	for ( dim_t i = 0; i < m_left; ++i )
	{
		double acc = *yp;

		// Same ascending column order as the vector loop above.
		for ( dim_t j = 0; j < 8; ++j )
		{
			acc   += chi[j] * ( *ap[j] );
			ap[j] += inca;
		}

		*yp  = acc;
		yp  += incy;
	}
}
472
473