/*

   BLIS
   An object-based framework for developing high-performance BLAS-like
   libraries.

   Copyright (C) 2018, The University of Texas at Austin
   Copyright (C) 2016 - 2018, Advanced Micro Devices, Inc.

   Redistribution and use in source and binary forms, with or without
   modification, are permitted provided that the following conditions are
   met:
    - Redistributions of source code must retain the above copyright
      notice, this list of conditions and the following disclaimer.
    - Redistributions in binary form must reproduce the above copyright
      notice, this list of conditions and the following disclaimer in the
      documentation and/or other materials provided with the distribution.
    - Neither the name(s) of the copyright holder(s) nor the names of its
      contributors may be used to endorse or promote products derived
      from this software without specific prior written permission.

   THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
   "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
   LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
   A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
   HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
   SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
   LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
   DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
   THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
   (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
   OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

*/

#include "immintrin.h"
#include "blis.h"

/* Union data structure to access AVX registers
   One 256-bit AVX register holds 8 SP elements. */
typedef union
{
	__m256  v;
	float   f[8] __attribute__((aligned(64)));
} v8sf_t;

/* Union data structure to access AVX registers
   One 256-bit AVX register holds 4 DP elements. */
typedef union
{
	__m256d v;
	double  d[4] __attribute__((aligned(64)));
} v4df_t;

// -----------------------------------------------------------------------------

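/*
   Computes the fused dot-product operation

     y := beta * y + alpha * A^T * x

   where A is an m x b_n matrix with element (i,j) stored at a[ i*inca + j*lda ],
   x is a vector of length m with stride incx, and y is a vector of length b_n
   with stride incy. The conjugation arguments are no-ops in the real domain.
   This kernel is optimized for b_n equal to the fusing factor of 8; all other
   values of b_n are handled by a loop over the dotxv kernel.
*/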
void bli_sdotxf_zen_int_8
     (
       conj_t           conjat,
       conj_t           conjx,
       dim_t            m,
       dim_t            b_n,
       float*  restrict alpha,
       float*  restrict a, inc_t inca, inc_t lda,
       float*  restrict x, inc_t incx,
       float*  restrict beta,
       float*  restrict y, inc_t incy,
       cntx_t* restrict cntx
     )
{
	const dim_t fuse_fac       = 8;
	const dim_t n_elem_per_reg = 8;

	// If the b_n dimension is zero, y is empty and there is no computation.
	if ( bli_zero_dim1( b_n ) ) return;

	// If the m dimension is zero, or if alpha is zero, the computation
	// simplifies to updating y.
	if ( bli_zero_dim1( m ) || PASTEMAC(s,eq0)( *alpha ) )
	{
		sscalv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_SCALV_KER, cntx );

		f
		(
		  BLIS_NO_CONJUGATE,
		  b_n,
		  beta,
		  y, incy,
		  cntx
		);
		return;
	}

	// If b_n is not equal to the fusing factor, then perform the entire
	// operation as a loop over dotxv.
	if ( b_n != fuse_fac )
	{
		sdotxv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_FLOAT, BLIS_DOTXV_KER, cntx );

		for ( dim_t i = 0; i < b_n; ++i )
		{
			float* a1   = a + (0  )*inca + (i  )*lda;
			float* x1   = x + (0  )*incx;
			float* psi1 = y + (i  )*incy;

			f
			(
			  conjat,
			  conjx,
			  m,
			  alpha,
			  a1, inca,
			  x1, incx,
			  beta,
			  psi1,
			  cntx
			);
		}
		return;
	}

	// At this point, we know that b_n is exactly equal to the fusing factor.
	// However, m may not be a multiple of the number of elements per vector.

	// Going forward, we handle two possible storage formats of A explicitly:
	// (1) A is stored by columns, or (2) A is stored by rows. Either case is
	// further split into two subproblems along the m dimension:
	// (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m.
	// (b) a scalar part, starting at m' and ending at m. If no vectorization
	//     is possible then m' == 0 and thus the scalar part is the entire
	//     problem. If 0 < m', then the a and x pointers and m variable will
	//     be adjusted accordingly for the second subproblem.
	// Note: since parts (b) for both (1) and (2) are so similar, they are
	// factored out into one code block after the following conditional, which
	// distinguishes between (1) and (2).

	// Intermediate variables to hold the completed dot products
	float rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0,
	      rho4 = 0, rho5 = 0, rho6 = 0, rho7 = 0;

	if ( inca == 1 && incx == 1 )
	{
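		// Column-storage case: inca == 1 and incx == 1 mean that elements
		// within each column of A, and within x, are contiguous, so full
		// vectors of eight elements can be loaded from each column and
		// from x.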
		const dim_t n_iter_unroll = 1;

		// Use the unrolling factor and the number of elements per register
		// to compute the number of vectorized and leftover iterations.
		dim_t m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );

		// Set up pointers for x and the b_n columns of A (rows of A^T).
		float* restrict x0 = x;
		float* restrict a0 = a + 0*lda;
		float* restrict a1 = a + 1*lda;
		float* restrict a2 = a + 2*lda;
		float* restrict a3 = a + 3*lda;
		float* restrict a4 = a + 4*lda;
		float* restrict a5 = a + 5*lda;
		float* restrict a6 = a + 6*lda;
		float* restrict a7 = a + 7*lda;

		// Initialize b_n rho vector accumulators to zero.
		v8sf_t rho0v; rho0v.v = _mm256_setzero_ps();
		v8sf_t rho1v; rho1v.v = _mm256_setzero_ps();
		v8sf_t rho2v; rho2v.v = _mm256_setzero_ps();
		v8sf_t rho3v; rho3v.v = _mm256_setzero_ps();
		v8sf_t rho4v; rho4v.v = _mm256_setzero_ps();
		v8sf_t rho5v; rho5v.v = _mm256_setzero_ps();
		v8sf_t rho6v; rho6v.v = _mm256_setzero_ps();
		v8sf_t rho7v; rho7v.v = _mm256_setzero_ps();

		v8sf_t x0v;
		v8sf_t a0v, a1v, a2v, a3v, a4v, a5v, a6v, a7v;

		// If there are vectorized iterations, perform them with vector
		// instructions.
		for ( dim_t i = 0; i < m_viter; ++i )
		{
			// Load the input values.
			x0v.v = _mm256_loadu_ps( x0 + 0*n_elem_per_reg );

			a0v.v = _mm256_loadu_ps( a0 + 0*n_elem_per_reg );
			a1v.v = _mm256_loadu_ps( a1 + 0*n_elem_per_reg );
			a2v.v = _mm256_loadu_ps( a2 + 0*n_elem_per_reg );
			a3v.v = _mm256_loadu_ps( a3 + 0*n_elem_per_reg );
			a4v.v = _mm256_loadu_ps( a4 + 0*n_elem_per_reg );
			a5v.v = _mm256_loadu_ps( a5 + 0*n_elem_per_reg );
			a6v.v = _mm256_loadu_ps( a6 + 0*n_elem_per_reg );
			a7v.v = _mm256_loadu_ps( a7 + 0*n_elem_per_reg );

			// perform: rho?v += a?v * x0v;
			rho0v.v = _mm256_fmadd_ps( a0v.v, x0v.v, rho0v.v );
			rho1v.v = _mm256_fmadd_ps( a1v.v, x0v.v, rho1v.v );
			rho2v.v = _mm256_fmadd_ps( a2v.v, x0v.v, rho2v.v );
			rho3v.v = _mm256_fmadd_ps( a3v.v, x0v.v, rho3v.v );
			rho4v.v = _mm256_fmadd_ps( a4v.v, x0v.v, rho4v.v );
			rho5v.v = _mm256_fmadd_ps( a5v.v, x0v.v, rho5v.v );
			rho6v.v = _mm256_fmadd_ps( a6v.v, x0v.v, rho6v.v );
			rho7v.v = _mm256_fmadd_ps( a7v.v, x0v.v, rho7v.v );

			x0 += n_elem_per_reg * n_iter_unroll;
			a0 += n_elem_per_reg * n_iter_unroll;
			a1 += n_elem_per_reg * n_iter_unroll;
			a2 += n_elem_per_reg * n_iter_unroll;
			a3 += n_elem_per_reg * n_iter_unroll;
			a4 += n_elem_per_reg * n_iter_unroll;
			a5 += n_elem_per_reg * n_iter_unroll;
			a6 += n_elem_per_reg * n_iter_unroll;
			a7 += n_elem_per_reg * n_iter_unroll;
		}

#if 0
		rho0 += rho0v.f[0] + rho0v.f[1] + rho0v.f[2] + rho0v.f[3] +
		        rho0v.f[4] + rho0v.f[5] + rho0v.f[6] + rho0v.f[7];
		rho1 += rho1v.f[0] + rho1v.f[1] + rho1v.f[2] + rho1v.f[3] +
		        rho1v.f[4] + rho1v.f[5] + rho1v.f[6] + rho1v.f[7];
		rho2 += rho2v.f[0] + rho2v.f[1] + rho2v.f[2] + rho2v.f[3] +
		        rho2v.f[4] + rho2v.f[5] + rho2v.f[6] + rho2v.f[7];
		rho3 += rho3v.f[0] + rho3v.f[1] + rho3v.f[2] + rho3v.f[3] +
		        rho3v.f[4] + rho3v.f[5] + rho3v.f[6] + rho3v.f[7];
		rho4 += rho4v.f[0] + rho4v.f[1] + rho4v.f[2] + rho4v.f[3] +
		        rho4v.f[4] + rho4v.f[5] + rho4v.f[6] + rho4v.f[7];
		rho5 += rho5v.f[0] + rho5v.f[1] + rho5v.f[2] + rho5v.f[3] +
		        rho5v.f[4] + rho5v.f[5] + rho5v.f[6] + rho5v.f[7];
		rho6 += rho6v.f[0] + rho6v.f[1] + rho6v.f[2] + rho6v.f[3] +
		        rho6v.f[4] + rho6v.f[5] + rho6v.f[6] + rho6v.f[7];
		rho7 += rho7v.f[0] + rho7v.f[1] + rho7v.f[2] + rho7v.f[3] +
		        rho7v.f[4] + rho7v.f[5] + rho7v.f[6] + rho7v.f[7];
#else
		// Now we need to sum the elements within each vector.

		v8sf_t onev; onev.v = _mm256_set1_ps( 1.0f );

		// Sum the elements of a given rho?v by dotting it with 1. The '1' in
		// '0xf1' stores the sum of the upper four and lower four values to
		// the low elements of each lane: elements 4 and 0, respectively. (The
		// 'f' in '0xf1' means include all four elements of each lane in the
		// summation.)
		rho0v.v = _mm256_dp_ps( rho0v.v, onev.v, 0xf1 );
		rho1v.v = _mm256_dp_ps( rho1v.v, onev.v, 0xf1 );
		rho2v.v = _mm256_dp_ps( rho2v.v, onev.v, 0xf1 );
		rho3v.v = _mm256_dp_ps( rho3v.v, onev.v, 0xf1 );
		rho4v.v = _mm256_dp_ps( rho4v.v, onev.v, 0xf1 );
		rho5v.v = _mm256_dp_ps( rho5v.v, onev.v, 0xf1 );
		rho6v.v = _mm256_dp_ps( rho6v.v, onev.v, 0xf1 );
		rho7v.v = _mm256_dp_ps( rho7v.v, onev.v, 0xf1 );

		// Manually add the results from above to finish the sum.
		rho0    = rho0v.f[0] + rho0v.f[4];
		rho1    = rho1v.f[0] + rho1v.f[4];
		rho2    = rho2v.f[0] + rho2v.f[4];
		rho3    = rho3v.f[0] + rho3v.f[4];
		rho4    = rho4v.f[0] + rho4v.f[4];
		rho5    = rho5v.f[0] + rho5v.f[4];
		rho6    = rho6v.f[0] + rho6v.f[4];
		rho7    = rho7v.f[0] + rho7v.f[4];
#endif

		// Adjust for scalar subproblem.
		m -= n_elem_per_reg * n_iter_unroll * m_viter;
		a += n_elem_per_reg * n_iter_unroll * m_viter /* * inca */;
		x += n_elem_per_reg * n_iter_unroll * m_viter /* * incx */;
	}
	else if ( lda == 1 )
	{
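		// Row-storage case: lda == 1 means that the eight elements of each
		// row of A (one element per fused dot product) are contiguous, so
		// each vector load below gathers a full row of A while the
		// corresponding element of x is broadcast across the register.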
		const dim_t n_iter_unroll = 4;

		// Use the unrolling factor and the number of elements per register
		// to compute the number of vectorized and leftover iterations.
		dim_t m_viter = ( m ) / ( n_iter_unroll );

		// Initialize pointers for x and A.
		float* restrict x0 = x;
		float* restrict a0 = a;

		// Initialize rho vector accumulators to zero.
		v8sf_t rho0v; rho0v.v = _mm256_setzero_ps();
		v8sf_t rho1v; rho1v.v = _mm256_setzero_ps();
		v8sf_t rho2v; rho2v.v = _mm256_setzero_ps();
		v8sf_t rho3v; rho3v.v = _mm256_setzero_ps();

		v8sf_t x0v, x1v, x2v, x3v;
		v8sf_t a0v, a1v, a2v, a3v;

		for ( dim_t i = 0; i < m_viter; ++i )
		{
			// Load the input values.
			a0v.v = _mm256_loadu_ps( a0 + 0*inca );
			a1v.v = _mm256_loadu_ps( a0 + 1*inca );
			a2v.v = _mm256_loadu_ps( a0 + 2*inca );
			a3v.v = _mm256_loadu_ps( a0 + 3*inca );

			x0v.v = _mm256_broadcast_ss( x0 + 0*incx );
			x1v.v = _mm256_broadcast_ss( x0 + 1*incx );
			x2v.v = _mm256_broadcast_ss( x0 + 2*incx );
			x3v.v = _mm256_broadcast_ss( x0 + 3*incx );

			// perform: rho?v += a?v * x?v;
			rho0v.v = _mm256_fmadd_ps( a0v.v, x0v.v, rho0v.v );
			rho1v.v = _mm256_fmadd_ps( a1v.v, x1v.v, rho1v.v );
			rho2v.v = _mm256_fmadd_ps( a2v.v, x2v.v, rho2v.v );
			rho3v.v = _mm256_fmadd_ps( a3v.v, x3v.v, rho3v.v );

			x0 += incx * n_iter_unroll;
			a0 += inca * n_iter_unroll;
		}

		// Combine the four accumulators into one vector register; element j
		// then holds the partial dot product for column j of A.
		rho0v.v = _mm256_add_ps( rho0v.v, rho1v.v );
		rho2v.v = _mm256_add_ps( rho2v.v, rho3v.v );
		rho0v.v = _mm256_add_ps( rho0v.v, rho2v.v );

		// Write vector components to scalar values.
		rho0 = rho0v.f[0];
		rho1 = rho0v.f[1];
		rho2 = rho0v.f[2];
		rho3 = rho0v.f[3];
		rho4 = rho0v.f[4];
		rho5 = rho0v.f[5];
		rho6 = rho0v.f[6];
		rho7 = rho0v.f[7];

		// Adjust for scalar subproblem.
		m -= n_iter_unroll * m_viter;
		a += n_iter_unroll * m_viter * inca;
		x += n_iter_unroll * m_viter * incx;
	}
	else
	{
		// No vectorization possible; use scalar iterations for the entire
		// problem.
	}

	// Scalar edge case.
	{
		// Initialize pointers for x and the b_n columns of A (rows of A^T).
		float* restrict x0 = x;
		float* restrict a0 = a + 0*lda;
		float* restrict a1 = a + 1*lda;
		float* restrict a2 = a + 2*lda;
		float* restrict a3 = a + 3*lda;
		float* restrict a4 = a + 4*lda;
		float* restrict a5 = a + 5*lda;
		float* restrict a6 = a + 6*lda;
		float* restrict a7 = a + 7*lda;

		// If there are leftover iterations, perform them with scalar code.
		for ( dim_t i = 0; i < m ; ++i )
		{
			const float x0c = *x0;

			const float a0c = *a0;
			const float a1c = *a1;
			const float a2c = *a2;
			const float a3c = *a3;
			const float a4c = *a4;
			const float a5c = *a5;
			const float a6c = *a6;
			const float a7c = *a7;

			rho0 += a0c * x0c;
			rho1 += a1c * x0c;
			rho2 += a2c * x0c;
			rho3 += a3c * x0c;
			rho4 += a4c * x0c;
			rho5 += a5c * x0c;
			rho6 += a6c * x0c;
			rho7 += a7c * x0c;

			x0 += incx;
			a0 += inca;
			a1 += inca;
			a2 += inca;
			a3 += inca;
			a4 += inca;
			a5 += inca;
			a6 += inca;
			a7 += inca;
		}
	}

	// Now prepare the final rho values to output/accumulate back into
	// the y vector.

	v8sf_t rho0v, y0v;

	// Insert the scalar rho values into a single vector.
	rho0v.f[0] = rho0;
	rho0v.f[1] = rho1;
	rho0v.f[2] = rho2;
	rho0v.f[3] = rho3;
	rho0v.f[4] = rho4;
	rho0v.f[5] = rho5;
	rho0v.f[6] = rho6;
	rho0v.f[7] = rho7;

	// Broadcast the alpha scalar.
	v8sf_t alphav; alphav.v = _mm256_broadcast_ss( alpha );

	// We know at this point that alpha is nonzero; however, beta may still
	// be zero. If beta is indeed zero, we must overwrite y rather than scale
	// by beta (in case y contains NaN or Inf).
	if ( PASTEMAC(s,eq0)( *beta ) )
	{
		// Apply alpha to the accumulated dot product in rho:
		//   y := alpha * rho
		y0v.v = _mm256_mul_ps( alphav.v, rho0v.v );
	}
	else
	{
		// Broadcast the beta scalar.
		v8sf_t betav; betav.v = _mm256_broadcast_ss( beta );

		// Load y.
		if ( incy == 1 )
		{
			y0v.v = _mm256_loadu_ps( y + 0*n_elem_per_reg );
		}
		else
		{
			y0v.f[0] = *(y + 0*incy); y0v.f[1] = *(y + 1*incy);
			y0v.f[2] = *(y + 2*incy); y0v.f[3] = *(y + 3*incy);
			y0v.f[4] = *(y + 4*incy); y0v.f[5] = *(y + 5*incy);
			y0v.f[6] = *(y + 6*incy); y0v.f[7] = *(y + 7*incy);
		}

		// Apply beta to y and alpha to the accumulated dot product in rho:
		//   y := beta * y + alpha * rho
		y0v.v = _mm256_mul_ps( betav.v, y0v.v );
		y0v.v = _mm256_fmadd_ps( alphav.v, rho0v.v, y0v.v );
	}

	// Store the output.
	if ( incy == 1 )
	{
		_mm256_storeu_ps( (y + 0*n_elem_per_reg), y0v.v );
	}
	else
	{
		*(y + 0*incy) = y0v.f[0]; *(y + 1*incy) = y0v.f[1];
		*(y + 2*incy) = y0v.f[2]; *(y + 3*incy) = y0v.f[3];
		*(y + 4*incy) = y0v.f[4]; *(y + 5*incy) = y0v.f[5];
		*(y + 6*incy) = y0v.f[6]; *(y + 7*incy) = y0v.f[7];
	}
}

// -----------------------------------------------------------------------------

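/*
   Double-precision variant of the fused operation above:

     y := beta * y + alpha * A^T * x

   where A is an m x b_n matrix with element (i,j) stored at a[ i*inca + j*lda ].
   Here a 256-bit register holds four doubles, so each group of eight fused dot
   products spans two registers.
*/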
void bli_ddotxf_zen_int_8
     (
       conj_t           conjat,
       conj_t           conjx,
       dim_t            m,
       dim_t            b_n,
       double* restrict alpha,
       double* restrict a, inc_t inca, inc_t lda,
       double* restrict x, inc_t incx,
       double* restrict beta,
       double* restrict y, inc_t incy,
       cntx_t* restrict cntx
     )
{
	const dim_t      fuse_fac       = 8;
	const dim_t      n_elem_per_reg = 4;

	// If the b_n dimension is zero, y is empty and there is no computation.
	if ( bli_zero_dim1( b_n ) ) return;

	// If the m dimension is zero, or if alpha is zero, the computation
	// simplifies to updating y.
	if ( bli_zero_dim1( m ) || PASTEMAC(d,eq0)( *alpha ) )
	{
		dscalv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_SCALV_KER, cntx );

		f
		(
		  BLIS_NO_CONJUGATE,
		  b_n,
		  beta,
		  y, incy,
		  cntx
		);
		return;
	}

	// If b_n is not equal to the fusing factor, then perform the entire
	// operation as a loop over dotxv.
	if ( b_n != fuse_fac )
	{
		ddotxv_ker_ft f = bli_cntx_get_l1v_ker_dt( BLIS_DOUBLE, BLIS_DOTXV_KER, cntx );

		for ( dim_t i = 0; i < b_n; ++i )
		{
			double* a1   = a + (0  )*inca + (i  )*lda;
			double* x1   = x + (0  )*incx;
			double* psi1 = y + (i  )*incy;

			f
			(
			  conjat,
			  conjx,
			  m,
			  alpha,
			  a1, inca,
			  x1, incx,
			  beta,
			  psi1,
			  cntx
			);
		}
		return;
	}

	// At this point, we know that b_n is exactly equal to the fusing factor.
	// However, m may not be a multiple of the number of elements per vector.

	// Going forward, we handle two possible storage formats of A explicitly:
	// (1) A is stored by columns, or (2) A is stored by rows. Either case is
	// further split into two subproblems along the m dimension:
	// (a) a vectorized part, starting at m = 0 and ending at any 0 <= m' <= m.
	// (b) a scalar part, starting at m' and ending at m. If no vectorization
	//     is possible then m' == 0 and thus the scalar part is the entire
	//     problem. If 0 < m', then the a and x pointers and m variable will
	//     be adjusted accordingly for the second subproblem.
	// Note: since parts (b) for both (1) and (2) are so similar, they are
	// factored out into one code block after the following conditional, which
	// distinguishes between (1) and (2).

	// Intermediate variables to hold the completed dot products
	double rho0 = 0, rho1 = 0, rho2 = 0, rho3 = 0,
	       rho4 = 0, rho5 = 0, rho6 = 0, rho7 = 0;

	if ( inca == 1 && incx == 1 )
	{
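		// Column-storage case: inca == 1 and incx == 1 mean that elements
		// within each column of A, and within x, are contiguous, so full
		// vectors of four doubles can be loaded from each column and from x.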
		const dim_t n_iter_unroll = 1;

		// Use the unrolling factor and the number of elements per register
		// to compute the number of vectorized and leftover iterations.
		dim_t m_viter = ( m ) / ( n_elem_per_reg * n_iter_unroll );

		// Set up pointers for x and the b_n columns of A (rows of A^T).
		double* restrict x0 = x;
		double* restrict a0 = a + 0*lda;
		double* restrict a1 = a + 1*lda;
		double* restrict a2 = a + 2*lda;
		double* restrict a3 = a + 3*lda;
		double* restrict a4 = a + 4*lda;
		double* restrict a5 = a + 5*lda;
		double* restrict a6 = a + 6*lda;
		double* restrict a7 = a + 7*lda;

		// Initialize b_n rho vector accumulators to zero.
		v4df_t rho0v; rho0v.v = _mm256_setzero_pd();
		v4df_t rho1v; rho1v.v = _mm256_setzero_pd();
		v4df_t rho2v; rho2v.v = _mm256_setzero_pd();
		v4df_t rho3v; rho3v.v = _mm256_setzero_pd();
		v4df_t rho4v; rho4v.v = _mm256_setzero_pd();
		v4df_t rho5v; rho5v.v = _mm256_setzero_pd();
		v4df_t rho6v; rho6v.v = _mm256_setzero_pd();
		v4df_t rho7v; rho7v.v = _mm256_setzero_pd();

		v4df_t x0v;
		v4df_t a0v, a1v, a2v, a3v, a4v, a5v, a6v, a7v;

		// If there are vectorized iterations, perform them with vector
		// instructions.
		for ( dim_t i = 0; i < m_viter; ++i )
		{
			// Load the input values.
			x0v.v = _mm256_loadu_pd( x0 + 0*n_elem_per_reg );

			a0v.v = _mm256_loadu_pd( a0 + 0*n_elem_per_reg );
			a1v.v = _mm256_loadu_pd( a1 + 0*n_elem_per_reg );
			a2v.v = _mm256_loadu_pd( a2 + 0*n_elem_per_reg );
			a3v.v = _mm256_loadu_pd( a3 + 0*n_elem_per_reg );
			a4v.v = _mm256_loadu_pd( a4 + 0*n_elem_per_reg );
			a5v.v = _mm256_loadu_pd( a5 + 0*n_elem_per_reg );
			a6v.v = _mm256_loadu_pd( a6 + 0*n_elem_per_reg );
			a7v.v = _mm256_loadu_pd( a7 + 0*n_elem_per_reg );

			// perform: rho?v += a?v * x0v;
			rho0v.v = _mm256_fmadd_pd( a0v.v, x0v.v, rho0v.v );
			rho1v.v = _mm256_fmadd_pd( a1v.v, x0v.v, rho1v.v );
			rho2v.v = _mm256_fmadd_pd( a2v.v, x0v.v, rho2v.v );
			rho3v.v = _mm256_fmadd_pd( a3v.v, x0v.v, rho3v.v );
			rho4v.v = _mm256_fmadd_pd( a4v.v, x0v.v, rho4v.v );
			rho5v.v = _mm256_fmadd_pd( a5v.v, x0v.v, rho5v.v );
			rho6v.v = _mm256_fmadd_pd( a6v.v, x0v.v, rho6v.v );
			rho7v.v = _mm256_fmadd_pd( a7v.v, x0v.v, rho7v.v );

			x0 += n_elem_per_reg * n_iter_unroll;
			a0 += n_elem_per_reg * n_iter_unroll;
			a1 += n_elem_per_reg * n_iter_unroll;
			a2 += n_elem_per_reg * n_iter_unroll;
			a3 += n_elem_per_reg * n_iter_unroll;
			a4 += n_elem_per_reg * n_iter_unroll;
			a5 += n_elem_per_reg * n_iter_unroll;
			a6 += n_elem_per_reg * n_iter_unroll;
			a7 += n_elem_per_reg * n_iter_unroll;
		}

#if 0
		rho0 += rho0v.d[0] + rho0v.d[1] + rho0v.d[2] + rho0v.d[3];
		rho1 += rho1v.d[0] + rho1v.d[1] + rho1v.d[2] + rho1v.d[3];
		rho2 += rho2v.d[0] + rho2v.d[1] + rho2v.d[2] + rho2v.d[3];
		rho3 += rho3v.d[0] + rho3v.d[1] + rho3v.d[2] + rho3v.d[3];
		rho4 += rho4v.d[0] + rho4v.d[1] + rho4v.d[2] + rho4v.d[3];
		rho5 += rho5v.d[0] + rho5v.d[1] + rho5v.d[2] + rho5v.d[3];
		rho6 += rho6v.d[0] + rho6v.d[1] + rho6v.d[2] + rho6v.d[3];
		rho7 += rho7v.d[0] + rho7v.d[1] + rho7v.d[2] + rho7v.d[3];
#else
		// Sum the elements of a given rho?v. This computes the sum of
		// elements within lanes and stores the sum to both elements.
		rho0v.v = _mm256_hadd_pd( rho0v.v, rho0v.v );
		rho1v.v = _mm256_hadd_pd( rho1v.v, rho1v.v );
		rho2v.v = _mm256_hadd_pd( rho2v.v, rho2v.v );
		rho3v.v = _mm256_hadd_pd( rho3v.v, rho3v.v );
		rho4v.v = _mm256_hadd_pd( rho4v.v, rho4v.v );
		rho5v.v = _mm256_hadd_pd( rho5v.v, rho5v.v );
		rho6v.v = _mm256_hadd_pd( rho6v.v, rho6v.v );
		rho7v.v = _mm256_hadd_pd( rho7v.v, rho7v.v );

		// Manually add the results from above to finish the sum.
		rho0 = rho0v.d[0] + rho0v.d[2];
		rho1 = rho1v.d[0] + rho1v.d[2];
		rho2 = rho2v.d[0] + rho2v.d[2];
		rho3 = rho3v.d[0] + rho3v.d[2];
		rho4 = rho4v.d[0] + rho4v.d[2];
		rho5 = rho5v.d[0] + rho5v.d[2];
		rho6 = rho6v.d[0] + rho6v.d[2];
		rho7 = rho7v.d[0] + rho7v.d[2];
#endif
		// Adjust for scalar subproblem.
		m -= n_elem_per_reg * n_iter_unroll * m_viter;
		a += n_elem_per_reg * n_iter_unroll * m_viter /* * inca */;
		x += n_elem_per_reg * n_iter_unroll * m_viter /* * incx */;
	}
	else if ( lda == 1 )
	{
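		// Row-storage case: lda == 1 means that the eight elements of each
		// row of A are contiguous. A full row spans two registers of four
		// doubles each (n_reg_per_row), and each element of x is broadcast
		// across both registers.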
		const dim_t n_iter_unroll = 3;
		const dim_t n_reg_per_row = 2; // fuse_fac / n_elem_per_reg;

		// Use the unrolling factor and the number of elements per register
		// to compute the number of vectorized and leftover iterations.
		dim_t m_viter = ( m ) / ( n_reg_per_row * n_iter_unroll );

		// Initialize pointers for x and A.
		double* restrict x0 = x;
		double* restrict a0 = a;

		// Initialize rho vector accumulators to zero.
		v4df_t rho0v; rho0v.v = _mm256_setzero_pd();
		v4df_t rho1v; rho1v.v = _mm256_setzero_pd();
		v4df_t rho2v; rho2v.v = _mm256_setzero_pd();
		v4df_t rho3v; rho3v.v = _mm256_setzero_pd();
		v4df_t rho4v; rho4v.v = _mm256_setzero_pd();
		v4df_t rho5v; rho5v.v = _mm256_setzero_pd();

		v4df_t x0v, x1v, x2v;
		v4df_t a0v, a1v, a2v, a3v, a4v, a5v;

		for ( dim_t i = 0; i < m_viter; ++i )
		{
			// Load the input values.
			a0v.v = _mm256_loadu_pd( a0 + 0*inca + 0*n_elem_per_reg );
			a1v.v = _mm256_loadu_pd( a0 + 0*inca + 1*n_elem_per_reg );
			a2v.v = _mm256_loadu_pd( a0 + 1*inca + 0*n_elem_per_reg );
			a3v.v = _mm256_loadu_pd( a0 + 1*inca + 1*n_elem_per_reg );
			a4v.v = _mm256_loadu_pd( a0 + 2*inca + 0*n_elem_per_reg );
			a5v.v = _mm256_loadu_pd( a0 + 2*inca + 1*n_elem_per_reg );

			x0v.v = _mm256_broadcast_sd( x0 + 0*incx );
			x1v.v = _mm256_broadcast_sd( x0 + 1*incx );
			x2v.v = _mm256_broadcast_sd( x0 + 2*incx );

			// perform: rho?v += a?v * x?v;
			rho0v.v = _mm256_fmadd_pd( a0v.v, x0v.v, rho0v.v );
			rho1v.v = _mm256_fmadd_pd( a1v.v, x0v.v, rho1v.v );
			rho2v.v = _mm256_fmadd_pd( a2v.v, x1v.v, rho2v.v );
			rho3v.v = _mm256_fmadd_pd( a3v.v, x1v.v, rho3v.v );
			rho4v.v = _mm256_fmadd_pd( a4v.v, x2v.v, rho4v.v );
			rho5v.v = _mm256_fmadd_pd( a5v.v, x2v.v, rho5v.v );

			x0 += incx * n_iter_unroll;
			a0 += inca * n_iter_unroll;
		}

		// Combine the six accumulators into two vector registers: rho0v
		// holds the partial dot products for columns 0-3 and rho1v holds
		// those for columns 4-7.
		rho0v.v = _mm256_add_pd( rho0v.v, rho2v.v );
		rho0v.v = _mm256_add_pd( rho0v.v, rho4v.v );
		rho1v.v = _mm256_add_pd( rho1v.v, rho3v.v );
		rho1v.v = _mm256_add_pd( rho1v.v, rho5v.v );

		// Write vector components to scalar values.
		rho0 = rho0v.d[0];
		rho1 = rho0v.d[1];
		rho2 = rho0v.d[2];
		rho3 = rho0v.d[3];
		rho4 = rho1v.d[0];
		rho5 = rho1v.d[1];
		rho6 = rho1v.d[2];
		rho7 = rho1v.d[3];

		// Adjust for scalar subproblem.
		m -= n_iter_unroll * m_viter;
		a += n_iter_unroll * m_viter * inca;
		x += n_iter_unroll * m_viter * incx;
	}
	else
	{
		// No vectorization possible; use scalar iterations for the entire
		// problem.
	}

	// Scalar edge case.
	{
		// Initialize pointers for x and the b_n columns of A (rows of A^T).
		double* restrict x0 = x;
		double* restrict a0 = a + 0*lda;
		double* restrict a1 = a + 1*lda;
		double* restrict a2 = a + 2*lda;
		double* restrict a3 = a + 3*lda;
		double* restrict a4 = a + 4*lda;
		double* restrict a5 = a + 5*lda;
		double* restrict a6 = a + 6*lda;
		double* restrict a7 = a + 7*lda;

		// If there are leftover iterations, perform them with scalar code.
		for ( dim_t i = 0; i < m ; ++i )
		{
			const double x0c = *x0;

			const double a0c = *a0;
			const double a1c = *a1;
			const double a2c = *a2;
			const double a3c = *a3;
			const double a4c = *a4;
			const double a5c = *a5;
			const double a6c = *a6;
			const double a7c = *a7;

			rho0 += a0c * x0c;
			rho1 += a1c * x0c;
			rho2 += a2c * x0c;
			rho3 += a3c * x0c;
			rho4 += a4c * x0c;
			rho5 += a5c * x0c;
			rho6 += a6c * x0c;
			rho7 += a7c * x0c;

			x0 += incx;
			a0 += inca;
			a1 += inca;
			a2 += inca;
			a3 += inca;
			a4 += inca;
			a5 += inca;
			a6 += inca;
			a7 += inca;
		}
	}

	// Now prepare the final rho values to output/accumulate back into
	// the y vector.

	v4df_t rho0v, rho1v, y0v, y1v;

	// Insert the scalar rho values into a single vector.
	rho0v.d[0] = rho0;
	rho0v.d[1] = rho1;
	rho0v.d[2] = rho2;
	rho0v.d[3] = rho3;
	rho1v.d[0] = rho4;
	rho1v.d[1] = rho5;
	rho1v.d[2] = rho6;
	rho1v.d[3] = rho7;

	// Broadcast the alpha scalar.
	v4df_t alphav; alphav.v = _mm256_broadcast_sd( alpha );

	// We know at this point that alpha is nonzero; however, beta may still
	// be zero. If beta is indeed zero, we must overwrite y rather than scale
	// by beta (in case y contains NaN or Inf).
	if ( PASTEMAC(d,eq0)( *beta ) )
	{
		// Apply alpha to the accumulated dot product in rho:
		//   y := alpha * rho
		y0v.v = _mm256_mul_pd( alphav.v, rho0v.v );
		y1v.v = _mm256_mul_pd( alphav.v, rho1v.v );
	}
	else
	{
		// Broadcast the beta scalar.
		v4df_t betav; betav.v = _mm256_broadcast_sd( beta );

		// Load y.
		if ( incy == 1 )
		{
			y0v.v = _mm256_loadu_pd( y + 0*n_elem_per_reg );
			y1v.v = _mm256_loadu_pd( y + 1*n_elem_per_reg );
		}
		else
		{
			y0v.d[0] = *(y + 0*incy); y0v.d[1] = *(y + 1*incy);
			y0v.d[2] = *(y + 2*incy); y0v.d[3] = *(y + 3*incy);
			y1v.d[0] = *(y + 4*incy); y1v.d[1] = *(y + 5*incy);
			y1v.d[2] = *(y + 6*incy); y1v.d[3] = *(y + 7*incy);
		}

		// Apply beta to y and alpha to the accumulated dot product in rho:
		//   y := beta * y + alpha * rho
		y0v.v = _mm256_mul_pd( betav.v, y0v.v );
		y1v.v = _mm256_mul_pd( betav.v, y1v.v );
		y0v.v = _mm256_fmadd_pd( alphav.v, rho0v.v, y0v.v );
		y1v.v = _mm256_fmadd_pd( alphav.v, rho1v.v, y1v.v );
	}

	if ( incy == 1 )
	{
		// Store the output.
		_mm256_storeu_pd( (y + 0*n_elem_per_reg), y0v.v );
		_mm256_storeu_pd( (y + 1*n_elem_per_reg), y1v.v );
	}
	else
	{
		*(y + 0*incy) = y0v.d[0]; *(y + 1*incy) = y0v.d[1];
		*(y + 2*incy) = y0v.d[2]; *(y + 3*incy) = y0v.d[3];
		*(y + 4*incy) = y1v.d[0]; *(y + 5*incy) = y1v.d[1];
		*(y + 6*incy) = y1v.d[2]; *(y + 7*incy) = y1v.d[3];
	}
}