/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2018, The University of Texas at Austin
   Copyright (C) 2016 - 2018, Advanced Micro Devices, Inc.

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name(s) of the copyright holder(s) nor the names of its
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include "immintrin.h"
#include "blis.h"

/* Union data structure to access AVX registers.
   One 256-bit AVX register holds 8 SP elements. */
typedef union
{
	__m256  v;
	float   f[8] __attribute__((aligned(64)));
} v8sf_t;

/* Union data structure to access AVX registers.
   One 256-bit AVX register holds 4 DP elements. */
typedef union
{
	__m256d v;
	double  d[4] __attribute__((aligned(64)));
} v4df_t;
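
/* A minimal usage sketch of the unions above (illustrative only, not part
   of the kernels): a register is written with intrinsics through the
   vector member and then read back lane by lane through the array member.

       v4df_t t;
       t.v = _mm256_set1_pd( 1.0 );                     // all four lanes = 1.0
       double sum = t.d[0] + t.d[1] + t.d[2] + t.d[3];  // horizontal sum = 4.0
*/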

// -----------------------------------------------------------------------------

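// bli_saxpyf_zen_int_8 implements the level-1f axpyf operation for single
// precision with a fusing factor of 8:
//
//     y := y + alpha * conja(A) * conjx(x)
//
// where A is an m x b_n matrix, x is a vector of length b_n, and y is a
// vector of length m. (Conjugation is a no-op in the real domain.)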
void bli_saxpyf_zen_int_8
     (
       conj_t           conja,
       conj_t           conjx,
       dim_t            m,
       dim_t            b_n,
       float*  restrict alpha,
       float*  restrict a, inc_t inca, inc_t lda,
       float*  restrict x, inc_t incx,
       float*  restrict y, inc_t incy,
       cntx_t* restrict cntx
     )
{
	const dim_t      fuse_fac       = 8;

	const dim_t      n_elem_per_reg = 8;
	const dim_t      n_iter_unroll  = 1;

	dim_t            i;
	dim_t            m_viter;
	dim_t            m_left;

	float*  restrict a0;
	float*  restrict a1;
	float*  restrict a2;
	float*  restrict a3;
	float*  restrict a4;
	float*  restrict a5;
	float*  restrict a6;
	float*  restrict a7;

	float*  restrict y0;

	v8sf_t           chi0v, chi1v, chi2v, chi3v;
	v8sf_t           chi4v, chi5v, chi6v, chi7v;

	v8sf_t           a0v, a1v, a2v, a3v;
	v8sf_t           a4v, a5v, a6v, a7v;
	v8sf_t           y0v;

	float            chi0, chi1, chi2, chi3;
	float            chi4, chi5, chi6, chi7;

	// If either dimension is zero, or if alpha is zero, return early.
	if ( bli_zero_dim2( m, b_n ) || PASTEMAC(s,eq0)( *alpha ) ) return;

	// If b_n is not equal to the fusing factor, then perform the entire
	// operation as a loop over axpyv.
	if ( b_n != fuse_fac )
	{
		saxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_AXPYV_KER, cntx );

		for ( i = 0; i < b_n; ++i )
		{
			float* a1   = a + (0  )*inca + (i  )*lda;
			float* chi1 = x + (i  )*incx;
			float* y1   = y + (0  )*incy;
			float  alpha_chi1;

			PASTEMAC(s,copycjs)( conjx, *chi1, alpha_chi1 );
			PASTEMAC(s,scals)( *alpha, alpha_chi1 );
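
			// Each call to the queried axpyv kernel then computes
			// y1 := y1 + alpha_chi1 * a1 over a single column of A.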

			f
			(
			  conja,
			  m,
			  &alpha_chi1,
			  a1, inca,
			  y1, incy,
			  cntx
			);
		}

		return;
	}

	// At this point, we know that b_n is exactly equal to the fusing factor.

	// Use the unrolling factor and the number of elements per register
	// to compute the number of vectorized and leftover iterations.
	m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );
	m_left  = ( m ) % ( n_elem_per_reg * n_iter_unroll );
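
	// For example, with m = 21, eight SP elements per register, and no
	// additional unrolling, this yields m_viter = 2 vector iterations
	// (covering 16 rows) and m_left = 5 scalar cleanup iterations.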

	// If there is anything that would interfere with our use of contiguous
	// vector loads/stores, override m_viter and m_left to use scalar code
	// for all iterations.
	if ( inca != 1 || incy != 1 )
	{
		m_viter = 0;
		m_left  = m;
	}

	a0   = a + 0*lda;
	a1   = a + 1*lda;
	a2   = a + 2*lda;
	a3   = a + 3*lda;
	a4   = a + 4*lda;
	a5   = a + 5*lda;
	a6   = a + 6*lda;
	a7   = a + 7*lda;
	y0   = y;

	chi0 = *( x + 0*incx );
	chi1 = *( x + 1*incx );
	chi2 = *( x + 2*incx );
	chi3 = *( x + 3*incx );
	chi4 = *( x + 4*incx );
	chi5 = *( x + 5*incx );
	chi6 = *( x + 6*incx );
	chi7 = *( x + 7*incx );

	// Scale each chi scalar by alpha.
	PASTEMAC(s,scals)( *alpha, chi0 );
	PASTEMAC(s,scals)( *alpha, chi1 );
	PASTEMAC(s,scals)( *alpha, chi2 );
	PASTEMAC(s,scals)( *alpha, chi3 );
	PASTEMAC(s,scals)( *alpha, chi4 );
	PASTEMAC(s,scals)( *alpha, chi5 );
	PASTEMAC(s,scals)( *alpha, chi6 );
	PASTEMAC(s,scals)( *alpha, chi7 );

	// Broadcast the (alpha*chi?) scalars to all elements of vector registers.
	chi0v.v = _mm256_broadcast_ss( &chi0 );
	chi1v.v = _mm256_broadcast_ss( &chi1 );
	chi2v.v = _mm256_broadcast_ss( &chi2 );
	chi3v.v = _mm256_broadcast_ss( &chi3 );
	chi4v.v = _mm256_broadcast_ss( &chi4 );
	chi5v.v = _mm256_broadcast_ss( &chi5 );
	chi6v.v = _mm256_broadcast_ss( &chi6 );
	chi7v.v = _mm256_broadcast_ss( &chi7 );

	// If there are vectorized iterations, perform them with vector
	// instructions.
	for ( i = 0; i < m_viter; ++i )
	{
		// Load the input values.
		y0v.v = _mm256_loadu_ps( y0 + 0*n_elem_per_reg );
		a0v.v = _mm256_loadu_ps( a0 + 0*n_elem_per_reg );
		a1v.v = _mm256_loadu_ps( a1 + 0*n_elem_per_reg );
		a2v.v = _mm256_loadu_ps( a2 + 0*n_elem_per_reg );
		a3v.v = _mm256_loadu_ps( a3 + 0*n_elem_per_reg );
		a4v.v = _mm256_loadu_ps( a4 + 0*n_elem_per_reg );
		a5v.v = _mm256_loadu_ps( a5 + 0*n_elem_per_reg );
		a6v.v = _mm256_loadu_ps( a6 + 0*n_elem_per_reg );
		a7v.v = _mm256_loadu_ps( a7 + 0*n_elem_per_reg );

		// Accumulate y += chi0*a0 + chi1*a1 + ... + chi7*a7, where each
		// chi scalar was already scaled by alpha above. Each fmadd
		// computes the elementwise a*b + c.
		y0v.v = _mm256_fmadd_ps( a0v.v, chi0v.v, y0v.v );
		y0v.v = _mm256_fmadd_ps( a1v.v, chi1v.v, y0v.v );
		y0v.v = _mm256_fmadd_ps( a2v.v, chi2v.v, y0v.v );
		y0v.v = _mm256_fmadd_ps( a3v.v, chi3v.v, y0v.v );
		y0v.v = _mm256_fmadd_ps( a4v.v, chi4v.v, y0v.v );
		y0v.v = _mm256_fmadd_ps( a5v.v, chi5v.v, y0v.v );
		y0v.v = _mm256_fmadd_ps( a6v.v, chi6v.v, y0v.v );
		y0v.v = _mm256_fmadd_ps( a7v.v, chi7v.v, y0v.v );

		// Store the output.
		_mm256_storeu_ps( (y0 + 0*n_elem_per_reg), y0v.v );

		y0 += n_elem_per_reg;
		a0 += n_elem_per_reg;
		a1 += n_elem_per_reg;
		a2 += n_elem_per_reg;
		a3 += n_elem_per_reg;
		a4 += n_elem_per_reg;
		a5 += n_elem_per_reg;
		a6 += n_elem_per_reg;
		a7 += n_elem_per_reg;
	}

	// If there are leftover iterations, perform them with scalar code.
	for ( i = 0; i < m_left ; ++i )
	{
		float       y0c = *y0;

		const float a0c = *a0;
		const float a1c = *a1;
		const float a2c = *a2;
		const float a3c = *a3;
		const float a4c = *a4;
		const float a5c = *a5;
		const float a6c = *a6;
		const float a7c = *a7;

		y0c += chi0 * a0c;
		y0c += chi1 * a1c;
		y0c += chi2 * a2c;
		y0c += chi3 * a3c;
		y0c += chi4 * a4c;
		y0c += chi5 * a5c;
		y0c += chi6 * a6c;
		y0c += chi7 * a7c;

		*y0 = y0c;

		a0 += inca;
		a1 += inca;
		a2 += inca;
		a3 += inca;
		a4 += inca;
		a5 += inca;
		a6 += inca;
		a7 += inca;
		y0 += incy;
	}
}

// -----------------------------------------------------------------------------

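// bli_daxpyf_zen_int_8 implements the same fused axpyf operation as the
// single-precision kernel above, but in double precision:
//
//     y := y + alpha * conja(A) * conjx(x)
//
// with a fusing factor of 8 and four DP elements per 256-bit register.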
void bli_daxpyf_zen_int_8
     (
       conj_t           conja,
       conj_t           conjx,
       dim_t            m,
       dim_t            b_n,
       double* restrict alpha,
       double* restrict a, inc_t inca, inc_t lda,
       double* restrict x, inc_t incx,
       double* restrict y, inc_t incy,
       cntx_t* restrict cntx
     )
{
	const dim_t      fuse_fac       = 8;

	const dim_t      n_elem_per_reg = 4;
	const dim_t      n_iter_unroll  = 1;

	dim_t            i;
	dim_t            m_viter;
	dim_t            m_left;

	double* restrict a0;
	double* restrict a1;
	double* restrict a2;
	double* restrict a3;
	double* restrict a4;
	double* restrict a5;
	double* restrict a6;
	double* restrict a7;

	double* restrict y0;

	v4df_t           chi0v, chi1v, chi2v, chi3v;
	v4df_t           chi4v, chi5v, chi6v, chi7v;

	v4df_t           a0v, a1v, a2v, a3v;
	v4df_t           a4v, a5v, a6v, a7v;
	v4df_t           y0v;

	double           chi0, chi1, chi2, chi3;
	double           chi4, chi5, chi6, chi7;

	// If either dimension is zero, or if alpha is zero, return early.
	if ( bli_zero_dim2( m, b_n ) || PASTEMAC(d,eq0)( *alpha ) ) return;

	// If b_n is not equal to the fusing factor, then perform the entire
	// operation as a loop over axpyv.
	if ( b_n != fuse_fac )
	{
		daxpyv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_AXPYV_KER, cntx );

		for ( i = 0; i < b_n; ++i )
		{
			double* a1   = a + (0  )*inca + (i  )*lda;
			double* chi1 = x + (i  )*incx;
			double* y1   = y + (0  )*incy;
			double  alpha_chi1;

			PASTEMAC(d,copycjs)( conjx, *chi1, alpha_chi1 );
			PASTEMAC(d,scals)( *alpha, alpha_chi1 );

			f
			(
			  conja,
			  m,
			  &alpha_chi1,
			  a1, inca,
			  y1, incy,
			  cntx
			);
		}

		return;
	}

	// At this point, we know that b_n is exactly equal to the fusing factor.

	// Use the unrolling factor and the number of elements per register
	// to compute the number of vectorized and leftover iterations.
	m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );
	m_left  = ( m ) % ( n_elem_per_reg * n_iter_unroll );
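
	// For example, with m = 10 and four DP elements per register, this
	// yields m_viter = 2 vector iterations (covering 8 rows) and
	// m_left = 2 scalar cleanup iterations.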

	// If there is anything that would interfere with our use of contiguous
	// vector loads/stores, override m_viter and m_left to use scalar code
	// for all iterations.
	if ( inca != 1 || incy != 1 )
	{
		m_viter = 0;
		m_left  = m;
	}

	a0   = a + 0*lda;
	a1   = a + 1*lda;
	a2   = a + 2*lda;
	a3   = a + 3*lda;
	a4   = a + 4*lda;
	a5   = a + 5*lda;
	a6   = a + 6*lda;
	a7   = a + 7*lda;
	y0   = y;

	chi0 = *( x + 0*incx );
	chi1 = *( x + 1*incx );
	chi2 = *( x + 2*incx );
	chi3 = *( x + 3*incx );
	chi4 = *( x + 4*incx );
	chi5 = *( x + 5*incx );
	chi6 = *( x + 6*incx );
	chi7 = *( x + 7*incx );

	// Scale each chi scalar by alpha.
	PASTEMAC(d,scals)( *alpha, chi0 );
	PASTEMAC(d,scals)( *alpha, chi1 );
	PASTEMAC(d,scals)( *alpha, chi2 );
	PASTEMAC(d,scals)( *alpha, chi3 );
	PASTEMAC(d,scals)( *alpha, chi4 );
	PASTEMAC(d,scals)( *alpha, chi5 );
	PASTEMAC(d,scals)( *alpha, chi6 );
	PASTEMAC(d,scals)( *alpha, chi7 );

	// Broadcast the (alpha*chi?) scalars to all elements of vector registers.
	chi0v.v = _mm256_broadcast_sd( &chi0 );
	chi1v.v = _mm256_broadcast_sd( &chi1 );
	chi2v.v = _mm256_broadcast_sd( &chi2 );
	chi3v.v = _mm256_broadcast_sd( &chi3 );
	chi4v.v = _mm256_broadcast_sd( &chi4 );
	chi5v.v = _mm256_broadcast_sd( &chi5 );
	chi6v.v = _mm256_broadcast_sd( &chi6 );
	chi7v.v = _mm256_broadcast_sd( &chi7 );

	// If there are vectorized iterations, perform them with vector
	// instructions.
	for ( i = 0; i < m_viter; ++i )
	{
		// Load the input values.
		y0v.v = _mm256_loadu_pd( y0 + 0*n_elem_per_reg );
		a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
		a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
		a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
		a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
		a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg );
		a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg );
		a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg );
		a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg );

		// Accumulate y += chi0*a0 + chi1*a1 + ... + chi7*a7, where each
		// chi scalar was already scaled by alpha above. Each fmadd
		// computes the elementwise a*b + c.
		y0v.v = _mm256_fmadd_pd( a0v.v, chi0v.v, y0v.v );
		y0v.v = _mm256_fmadd_pd( a1v.v, chi1v.v, y0v.v );
		y0v.v = _mm256_fmadd_pd( a2v.v, chi2v.v, y0v.v );
		y0v.v = _mm256_fmadd_pd( a3v.v, chi3v.v, y0v.v );
		y0v.v = _mm256_fmadd_pd( a4v.v, chi4v.v, y0v.v );
		y0v.v = _mm256_fmadd_pd( a5v.v, chi5v.v, y0v.v );
		y0v.v = _mm256_fmadd_pd( a6v.v, chi6v.v, y0v.v );
		y0v.v = _mm256_fmadd_pd( a7v.v, chi7v.v, y0v.v );

		// Store the output.
		_mm256_storeu_pd( (y0 + 0*n_elem_per_reg), y0v.v );

		y0 += n_elem_per_reg;
		a0 += n_elem_per_reg;
		a1 += n_elem_per_reg;
		a2 += n_elem_per_reg;
		a3 += n_elem_per_reg;
		a4 += n_elem_per_reg;
		a5 += n_elem_per_reg;
		a6 += n_elem_per_reg;
		a7 += n_elem_per_reg;
	}

	// If there are leftover iterations, perform them with scalar code.
	for ( i = 0; i < m_left ; ++i )
	{
		double       y0c = *y0;

		const double a0c = *a0;
		const double a1c = *a1;
		const double a2c = *a2;
		const double a3c = *a3;
		const double a4c = *a4;
		const double a5c = *a5;
		const double a6c = *a6;
		const double a7c = *a7;

		y0c += chi0 * a0c;
		y0c += chi1 * a1c;
		y0c += chi2 * a2c;
		y0c += chi3 * a3c;
		y0c += chi4 * a4c;
		y0c += chi5 * a5c;
		y0c += chi6 * a6c;
		y0c += chi7 * a7c;

		*y0 = y0c;

		a0 += inca;
		a1 += inca;
		a2 += inca;
		a3 += inca;
		a4 += inca;
		a5 += inca;
		a6 += inca;
		a7 += inca;
		y0 += incy;
	}
}