1 //=================================================================================================
2 /*!
3 // \file blaze/math/dense/MMM.h
4 // \brief Header file for the dense matrix multiplication kernels
5 //
6 // Copyright (C) 2012-2020 Klaus Iglberger - All Rights Reserved
7 //
8 // This file is part of the Blaze library. You can redistribute it and/or modify it under
9 // the terms of the New (Revised) BSD License. Redistribution and use in source and binary
10 // forms, with or without modification, are permitted provided that the following conditions
11 // are met:
12 //
13 // 1. Redistributions of source code must retain the above copyright notice, this list of
14 // conditions and the following disclaimer.
15 // 2. Redistributions in binary form must reproduce the above copyright notice, this list
16 // of conditions and the following disclaimer in the documentation and/or other materials
17 // provided with the distribution.
18 // 3. Neither the names of the Blaze development group nor the names of its contributors
19 // may be used to endorse or promote products derived from this software without specific
20 // prior written permission.
21 //
22 // THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY
23 // EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES
24 // OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT
25 // SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
26 // INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED
27 // TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
28 // BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
29 // CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
30 // ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH
31 // DAMAGE.
32 */
33 //=================================================================================================
34
35 #ifndef _BLAZE_MATH_DENSE_MMM_H_
36 #define _BLAZE_MATH_DENSE_MMM_H_
37
38
39 //*************************************************************************************************
40 // Includes
41 //*************************************************************************************************
42
43 #include <blaze/math/Aliases.h>
44 #include <blaze/math/constraints/Adaptor.h>
45 #include <blaze/math/constraints/ColumnMajorMatrix.h>
46 #include <blaze/math/constraints/Computation.h>
47 #include <blaze/math/constraints/DenseMatrix.h>
48 #include <blaze/math/constraints/Hermitian.h>
49 #include <blaze/math/constraints/RowMajorMatrix.h>
50 #include <blaze/math/constraints/SIMDCombinable.h>
51 #include <blaze/math/constraints/StrictlyLower.h>
52 #include <blaze/math/constraints/StrictlyUpper.h>
53 #include <blaze/math/constraints/Symmetric.h>
54 #include <blaze/math/constraints/UniLower.h>
55 #include <blaze/math/constraints/UniUpper.h>
56 #include <blaze/math/constraints/Lower.h>
57 #include <blaze/math/constraints/Upper.h>
58 #include <blaze/math/dense/DynamicMatrix.h>
59 #include <blaze/math/expressions/DenseMatrix.h>
60 #include <blaze/math/shims/IsDefault.h>
61 #include <blaze/math/shims/IsOne.h>
62 #include <blaze/math/shims/PrevMultiple.h>
63 #include <blaze/math/shims/Serial.h>
64 #include <blaze/math/SIMD.h>
65 #include <blaze/math/typetraits/IsLower.h>
66 #include <blaze/math/typetraits/IsPadded.h>
67 #include <blaze/math/typetraits/IsUpper.h>
68 #include <blaze/math/views/Check.h>
69 #include <blaze/math/views/Submatrix.h>
70 #include <blaze/system/Blocking.h>
71 #include <blaze/util/algorithms/Min.h>
72 #include <blaze/util/Assert.h>
73 #include <blaze/util/StaticAssert.h>
74 #include <blaze/util/Types.h>
75 #include <blaze/util/typetraits/IsFloatingPoint.h>
76
77
78 namespace blaze {
79
80 //=================================================================================================
81 //
82 // GENERAL DENSE MATRIX MULTIPLICATION KERNELS
83 //
84 //=================================================================================================
85
//*************************************************************************************************
/*! \cond BLAZE_INTERNAL */
/*!\brief Compute kernel for a general dense matrix/dense matrix multiplication
//        (\f$ C=\alpha*A*B+\beta*C \f$).
// \ingroup dense_matrix
//
// \param C The target left-hand side row-major dense matrix.
// \param A The left-hand side multiplication operand.
// \param B The right-hand side multiplication operand.
// \param alpha The scaling factor for \f$ A*B \f$.
// \param beta The scaling factor for \f$ C \f$.
// \return void
//
// This function implements the compute kernel for a general dense matrix/dense matrix
// multiplication of the form \f$ C=\alpha*A*B+\beta*C \f$. Both \a A and \a B must
// be non-expression dense matrix types, \a C must be a non-expression, non-adaptor,
// row-major dense matrix type. The element types of all three matrices must be SIMD
// combinable, i.e. must provide a common SIMD interface.
//
// Outline of the algorithm:
//  - \a C is first scaled by \a beta: it is reset if \a beta is the default value and left
//    untouched if \a beta is one.
//  - The inner dimension is processed in blocks of \c KBLOCK indices. For every block the
//    required rows of \a A are copied into the row-major buffer \c A2 and, per block of
//    \c JBLOCK columns, the required part of \a B is copied into the column-major buffer
//    \c B2. Both buffers are therefore contiguous along the inner index, which is the
//    direction of all SIMD loads.
//  - Register-blocked micro-kernels accumulate SIMD partial products along the inner index.
//    They use 5x2 tiles for floating point element types and 4x2 tiles otherwise, followed
//    by 2x4, 2x2, 2x1, 1x2 and 1x1 tails. The horizontally summed, \a alpha scaled results
//    are added to \a C.
//  - If either operand is not padded, the SIMD kernels only process multiples of
//    \c SIMDSIZE inner indices. The remaining (fewer than \c SIMDSIZE) indices are handled
//    by a final scalar pass.
//  - A lower/upper \a A restricts the rows of \a A that are processed. A lower/upper \a B
//    skips column blocks of \a B that are zero within the current inner block.
*/
template< typename MT1, typename MT2, typename MT3, typename ST >
void mmm( DenseMatrix<MT1,false>& C, const MT2& A, const MT3& B, ST alpha, ST beta )
{
   using ET1 = ElementType_t<MT1>;
   using ET2 = ElementType_t<MT2>;
   using ET3 = ElementType_t<MT3>;
   using SIMDType = SIMDTrait_t<ET1>;

   BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT1 );
   BLAZE_CONSTRAINT_MUST_BE_ROW_MAJOR_MATRIX_TYPE( MT1 );
   BLAZE_CONSTRAINT_MUST_NOT_BE_ADAPTOR_TYPE ( MT1 );
   BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE ( MT1 );

   BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT2 );
   BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT2 );

   BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT3 );
   BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT3 );

   BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
   BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );

   // Number of ET1 elements per SIMD vector
   constexpr size_t SIMDSIZE( SIMDTrait<ET1>::size );

   // True if at least one operand is unpadded. The SIMD kernels then only process multiples
   // of SIMDSIZE inner indices, and a scalar pass at the end handles the leftover ones.
   constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );

   // KBLOCK: block size along the inner dimension (scaled by 16/sizeof(ET1))
   // JBLOCK: block size along the columns of B/C
   constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/sizeof(ET1) ) );
   constexpr size_t JBLOCK( MMM_INNER_BLOCK_SIZE );

   BLAZE_STATIC_ASSERT( KBLOCK >= SIMDSIZE && KBLOCK % SIMDSIZE == 0UL );
   BLAZE_STATIC_ASSERT( JBLOCK >= SIMDSIZE && JBLOCK % SIMDSIZE == 0UL );

   const size_t M( A.rows() );
   const size_t N( B.columns() );
   const size_t K( A.columns() );

   BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );

   // Packing buffers. A2 is a row-major copy of the current panel of A and B2 a column-major
   // copy of the current block of B, so both are contiguous along the inner index.
   DynamicMatrix<ET2,false> A2( M, KBLOCK );
   DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );

   // Apply beta up front; afterwards the kernels only ever add alpha*A*B into C
   if( isDefault( beta ) ) {
      reset( *C );
   }
   else if( !isOne( beta ) ) {
      (*C) *= beta;
   }

   size_t kk( 0UL );
   size_t kblock( 0UL );

   // Loop over blocks of the inner dimension. With remainder handling, a block is only started
   // while at least SIMDSIZE inner indices are left.
   while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
   {
      // With remainder handling the last block is cut down to a multiple of SIMDSIZE. Without
      // it, kblock may be a non-multiple and the SIMD loads below run into the padding of A2/B2.
      // NOTE(review): this relies on the padding of the DynamicMatrix buffers being zero.
      if( remainder ) {
         kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( prevMultiple( K - kk, SIMDSIZE ) ) );
      }
      else {
         kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
      }

      // Within columns [kk,kk+kblock), a lower A can only be non-zero in rows >= kk and an
      // upper A only in rows < kk+kblock
      const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
      const size_t iend ( IsUpper_v<MT2> ? kk+kblock : M );
      const size_t isize ( iend - ibegin );

      // Pack A(ibegin:iend, kk:kk+kblock) into the row-major buffer A2
      A2 = serial( submatrix< remainder ? unaligned : aligned >( A, ibegin, kk, isize, kblock, unchecked ) );

      size_t jj( 0UL );
      size_t jblock( 0UL );

      while( jj < N )
      {
         jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );

         // Skip column blocks of B that are zero in rows [kk,kk+kblock): for a lower B the
         // columns >= kk+kblock, for an upper B the columns < kk
         if( ( IsLower_v<MT3> && kk+kblock <= jj ) ||
             ( IsUpper_v<MT3> && jj+jblock <= kk ) ) {
            jj += jblock;
            continue;
         }

         // Pack B(kk:kk+kblock, jj:jj+jblock) into the column-major buffer B2
         B2 = serial( submatrix< remainder ? unaligned : aligned >( B, kk, jj, kblock, jblock, unchecked ) );

         size_t i( 0UL );

         // Floating point element types use 5 rows per micro-kernel, all others 4 rows
         if( IsFloatingPoint_v<ET1> )
         {
            for( ; (i+5UL) <= isize; i+=5UL )
            {
               size_t j( 0UL );

               // 5x2 tile: rows i..i+4 of A2 against columns j, j+1 of B2
               for( ; (j+2UL) <= jblock; j+=2UL )
               {
                  SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;

                  for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
                  {
                     const SIMDType a1( A2.load(i ,k) );
                     const SIMDType a2( A2.load(i+1UL,k) );
                     const SIMDType a3( A2.load(i+2UL,k) );
                     const SIMDType a4( A2.load(i+3UL,k) );
                     const SIMDType a5( A2.load(i+4UL,k) );

                     const SIMDType b1( B2.load(k,j ) );
                     const SIMDType b2( B2.load(k,j+1UL) );

                     xmm1 += a1 * b1;
                     xmm2 += a1 * b2;
                     xmm3 += a2 * b1;
                     xmm4 += a2 * b2;
                     xmm5 += a3 * b1;
                     xmm6 += a3 * b2;
                     xmm7 += a4 * b1;
                     xmm8 += a4 * b2;
                     xmm9 += a5 * b1;
                     xmm10 += a5 * b2;
                  }

                  // Horizontal reduction of each accumulator into one element of C
                  (*C)(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
                  (*C)(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
                  (*C)(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
                  (*C)(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
                  (*C)(ibegin+i+2UL,jj+j ) += sum( xmm5 ) * alpha;
                  (*C)(ibegin+i+2UL,jj+j+1UL) += sum( xmm6 ) * alpha;
                  (*C)(ibegin+i+3UL,jj+j ) += sum( xmm7 ) * alpha;
                  (*C)(ibegin+i+3UL,jj+j+1UL) += sum( xmm8 ) * alpha;
                  (*C)(ibegin+i+4UL,jj+j ) += sum( xmm9 ) * alpha;
                  (*C)(ibegin+i+4UL,jj+j+1UL) += sum( xmm10 ) * alpha;
               }

               // 5x1 tail for an odd number of columns in the block
               if( j<jblock )
               {
                  SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;

                  for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
                  {
                     const SIMDType a1( A2.load(i ,k) );
                     const SIMDType a2( A2.load(i+1UL,k) );
                     const SIMDType a3( A2.load(i+2UL,k) );
                     const SIMDType a4( A2.load(i+3UL,k) );
                     const SIMDType a5( A2.load(i+4UL,k) );

                     const SIMDType b1( B2.load(k,j) );

                     xmm1 += a1 * b1;
                     xmm2 += a2 * b1;
                     xmm3 += a3 * b1;
                     xmm4 += a4 * b1;
                     xmm5 += a5 * b1;
                  }

                  (*C)(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
                  (*C)(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
                  (*C)(ibegin+i+2UL,jj+j) += sum( xmm3 ) * alpha;
                  (*C)(ibegin+i+3UL,jj+j) += sum( xmm4 ) * alpha;
                  (*C)(ibegin+i+4UL,jj+j) += sum( xmm5 ) * alpha;
               }
            }
         }
         else
         {
            for( ; (i+4UL) <= isize; i+=4UL )
            {
               size_t j( 0UL );

               // 4x2 tile: rows i..i+3 of A2 against columns j, j+1 of B2
               for( ; (j+2UL) <= jblock; j+=2UL )
               {
                  SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;

                  for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
                  {
                     const SIMDType a1( A2.load(i ,k) );
                     const SIMDType a2( A2.load(i+1UL,k) );
                     const SIMDType a3( A2.load(i+2UL,k) );
                     const SIMDType a4( A2.load(i+3UL,k) );

                     const SIMDType b1( B2.load(k,j ) );
                     const SIMDType b2( B2.load(k,j+1UL) );

                     xmm1 += a1 * b1;
                     xmm2 += a1 * b2;
                     xmm3 += a2 * b1;
                     xmm4 += a2 * b2;
                     xmm5 += a3 * b1;
                     xmm6 += a3 * b2;
                     xmm7 += a4 * b1;
                     xmm8 += a4 * b2;
                  }

                  (*C)(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
                  (*C)(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
                  (*C)(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
                  (*C)(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
                  (*C)(ibegin+i+2UL,jj+j ) += sum( xmm5 ) * alpha;
                  (*C)(ibegin+i+2UL,jj+j+1UL) += sum( xmm6 ) * alpha;
                  (*C)(ibegin+i+3UL,jj+j ) += sum( xmm7 ) * alpha;
                  (*C)(ibegin+i+3UL,jj+j+1UL) += sum( xmm8 ) * alpha;
               }

               // 4x1 tail for an odd number of columns in the block
               if( j<jblock )
               {
                  SIMDType xmm1, xmm2, xmm3, xmm4;

                  for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
                  {
                     const SIMDType a1( A2.load(i ,k) );
                     const SIMDType a2( A2.load(i+1UL,k) );
                     const SIMDType a3( A2.load(i+2UL,k) );
                     const SIMDType a4( A2.load(i+3UL,k) );

                     const SIMDType b1( B2.load(k,j) );

                     xmm1 += a1 * b1;
                     xmm2 += a2 * b1;
                     xmm3 += a3 * b1;
                     xmm4 += a4 * b1;
                  }

                  (*C)(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
                  (*C)(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
                  (*C)(ibegin+i+2UL,jj+j) += sum( xmm3 ) * alpha;
                  (*C)(ibegin+i+3UL,jj+j) += sum( xmm4 ) * alpha;
               }
            }
         }

         // Remaining row pairs: 2x4, 2x2 and 2x1 tiles
         for( ; (i+2UL) <= isize; i+=2UL )
         {
            size_t j( 0UL );

            for( ; (j+4UL) <= jblock; j+=4UL )
            {
               SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;

               for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
               {
                  const SIMDType a1( A2.load(i ,k) );
                  const SIMDType a2( A2.load(i+1UL,k) );

                  const SIMDType b1( B2.load(k,j ) );
                  const SIMDType b2( B2.load(k,j+1UL) );
                  const SIMDType b3( B2.load(k,j+2UL) );
                  const SIMDType b4( B2.load(k,j+3UL) );

                  xmm1 += a1 * b1;
                  xmm2 += a1 * b2;
                  xmm3 += a1 * b3;
                  xmm4 += a1 * b4;
                  xmm5 += a2 * b1;
                  xmm6 += a2 * b2;
                  xmm7 += a2 * b3;
                  xmm8 += a2 * b4;
               }

               (*C)(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
               (*C)(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
               (*C)(ibegin+i ,jj+j+2UL) += sum( xmm3 ) * alpha;
               (*C)(ibegin+i ,jj+j+3UL) += sum( xmm4 ) * alpha;
               (*C)(ibegin+i+1UL,jj+j ) += sum( xmm5 ) * alpha;
               (*C)(ibegin+i+1UL,jj+j+1UL) += sum( xmm6 ) * alpha;
               (*C)(ibegin+i+1UL,jj+j+2UL) += sum( xmm7 ) * alpha;
               (*C)(ibegin+i+1UL,jj+j+3UL) += sum( xmm8 ) * alpha;
            }

            for( ; (j+2UL) <= jblock; j+=2UL )
            {
               SIMDType xmm1, xmm2, xmm3, xmm4;

               for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
               {
                  const SIMDType a1( A2.load(i ,k) );
                  const SIMDType a2( A2.load(i+1UL,k) );

                  const SIMDType b1( B2.load(k,j ) );
                  const SIMDType b2( B2.load(k,j+1UL) );

                  xmm1 += a1 * b1;
                  xmm2 += a1 * b2;
                  xmm3 += a2 * b1;
                  xmm4 += a2 * b2;
               }

               (*C)(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
               (*C)(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
               (*C)(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
               (*C)(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
            }

            if( j<jblock )
            {
               SIMDType xmm1, xmm2;

               for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
               {
                  const SIMDType a1( A2.load(i ,k) );
                  const SIMDType a2( A2.load(i+1UL,k) );

                  const SIMDType b1( B2.load(k,j) );

                  xmm1 += a1 * b1;
                  xmm2 += a2 * b1;
               }

               (*C)(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
               (*C)(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
            }
         }

         // Last single row (if isize is odd after the tiles above): 1x2 and 1x1 tiles
         if( i<isize )
         {
            size_t j( 0UL );

            for( ; (j+2UL) <= jblock; j+=2UL )
            {
               SIMDType xmm1, xmm2;

               for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
               {
                  const SIMDType a1( A2.load(i,k) );

                  xmm1 += a1 * B2.load(k,j );
                  xmm2 += a1 * B2.load(k,j+1UL);
               }

               (*C)(ibegin+i,jj+j ) += sum( xmm1 ) * alpha;
               (*C)(ibegin+i,jj+j+1UL) += sum( xmm2 ) * alpha;
            }

            if( j<jblock )
            {
               SIMDType xmm1;

               for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
               {
                  const SIMDType a1( A2.load(i,k) );

                  xmm1 += a1 * B2.load(k,j);
               }

               (*C)(ibegin+i,jj+j) += sum( xmm1 ) * alpha;
            }
         }

         jj += jblock;
      }

      kk += kblock;
   }

   // Scalar pass over the last K-kk (< SIMDSIZE) inner indices that the SIMD kernels could not
   // process because at least one operand is unpadded
   if( remainder && kk < K )
   {
      const size_t ksize( K - kk );

      // Only the lower structure of A can exclude leading rows here; the upper bound of an
      // upper A would be kk+ksize == K, i.e. all rows
      const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
      const size_t isize ( M - ibegin );

      A2 = serial( submatrix( A, ibegin, kk, isize, ksize, unchecked ) );

      size_t jj( 0UL );
      size_t jblock( 0UL );

      while( jj < N )
      {
         jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );

         // Columns < kk of an upper B are zero in rows [kk,K)
         if( IsUpper_v<MT3> && jj+jblock <= kk ) {
            jj += jblock;
            continue;
         }

         B2 = serial( submatrix( B, kk, jj, ksize, jblock, unchecked ) );

         size_t i( 0UL );

         // Same tiling as the SIMD kernels above, but with scalar element access
         if( IsFloatingPoint_v<ET1> )
         {
            for( ; (i+5UL) <= isize; i+=5UL )
            {
               size_t j( 0UL );

               for( ; (j+2UL) <= jblock; j+=2UL ) {
                  for( size_t k=0UL; k<ksize; ++k ) {
                     (*C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
                     (*C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
                     (*C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
                     (*C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
                     (*C)(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
                     (*C)(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
                     (*C)(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
                     (*C)(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
                     (*C)(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
                     (*C)(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
                  }
               }

               if( j<jblock ) {
                  for( size_t k=0UL; k<ksize; ++k ) {
                     (*C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
                     (*C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
                     (*C)(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
                     (*C)(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
                     (*C)(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
                  }
               }
            }
         }
         else
         {
            for( ; (i+4UL) <= isize; i+=4UL )
            {
               size_t j( 0UL );

               for( ; (j+2UL) <= jblock; j+=2UL ) {
                  for( size_t k=0UL; k<ksize; ++k ) {
                     (*C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
                     (*C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
                     (*C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
                     (*C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
                     (*C)(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
                     (*C)(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
                     (*C)(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
                     (*C)(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
                  }
               }

               if( j<jblock ) {
                  for( size_t k=0UL; k<ksize; ++k ) {
                     (*C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
                     (*C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
                     (*C)(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
                     (*C)(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
                  }
               }
            }
         }

         for( ; (i+2UL) <= isize; i+=2UL )
         {
            size_t j( 0UL );

            for( ; (j+2UL) <= jblock; j+=2UL ) {
               for( size_t k=0UL; k<ksize; ++k ) {
                  (*C)(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
                  (*C)(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
                  (*C)(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
                  (*C)(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
               }
            }

            if( j<jblock ) {
               for( size_t k=0UL; k<ksize; ++k ) {
                  (*C)(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
                  (*C)(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
               }
            }
         }

         if( i<isize )
         {
            size_t j( 0UL );

            for( ; (j+2UL) <= jblock; j+=2UL ) {
               for( size_t k=0UL; k<ksize; ++k ) {
                  (*C)(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
                  (*C)(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
               }
            }

            if( j<jblock ) {
               for( size_t k=0UL; k<ksize; ++k ) {
                  (*C)(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
               }
            }
         }

         jj += jblock;
      }
   }
}
/*! \endcond */
//*************************************************************************************************
584
585
586 //*************************************************************************************************
587 /*! \cond BLAZE_INTERNAL */
588 /*!\brief Compute kernel for a general dense matrix/dense matrix multiplication
589 // (\f$ C=\alpha*A*B+\beta*C \f$).
590 // \ingroup dense_matrix
591 //
592 // \param C The target left-hand side column-major dense matrix.
593 // \param A The left-hand side multiplication operand.
594 // \param B The right-hand side multiplication operand.
595 // \param alpha The scaling factor for \f$ A*B \f$.
596 // \param beta The scaling factor for \f$ C \f$.
597 // \return void
598 //
599 // This function implements the compute kernel for a general dense matrix/dense matrix
600 // multiplication of the form \f$ C=\alpha*A*B+\beta*C \f$. Both \a A and \a B must
601 // be non-expression dense matrix types, \a C must be a non-expression, non-adaptor,
602 // column-major dense matrix type. The element types of all three matrices must be SIMD
603 // combinable, i.e. must provide a common SIMD interface.
604 */
605 template< typename MT1, typename MT2, typename MT3, typename ST >
mmm(DenseMatrix<MT1,true> & C,const MT2 & A,const MT3 & B,ST alpha,ST beta)606 void mmm( DenseMatrix<MT1,true>& C, const MT2& A, const MT3& B, ST alpha, ST beta )
607 {
608 using ET1 = ElementType_t<MT1>;
609 using ET2 = ElementType_t<MT2>;
610 using ET3 = ElementType_t<MT3>;
611 using SIMDType = SIMDTrait_t<ET1>;
612
613 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT1 );
614 BLAZE_CONSTRAINT_MUST_BE_COLUMN_MAJOR_MATRIX_TYPE( MT1 );
615 BLAZE_CONSTRAINT_MUST_NOT_BE_ADAPTOR_TYPE ( MT1 );
616 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE ( MT1 );
617
618 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT2 );
619 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT2 );
620
621 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT3 );
622 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT3 );
623
624 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
625 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );
626
627 constexpr size_t SIMDSIZE( SIMDTrait<ET1>::size );
628
629 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
630
631 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/sizeof(ET1) ) );
632 constexpr size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
633
634 BLAZE_STATIC_ASSERT( KBLOCK >= SIMDSIZE && KBLOCK % SIMDSIZE == 0UL );
635 BLAZE_STATIC_ASSERT( IBLOCK >= SIMDSIZE && IBLOCK % SIMDSIZE == 0UL );
636
637 const size_t M( A.rows() );
638 const size_t N( B.columns() );
639 const size_t K( A.columns() );
640
641 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
642
643 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
644 DynamicMatrix<ET3,true> B2( KBLOCK, N );
645
646 if( isDefault( beta ) ) {
647 reset( *C );
648 }
649 else if( !isOne( beta ) ) {
650 (*C) *= beta;
651 }
652
653 size_t kk( 0UL );
654 size_t kblock( 0UL );
655
656 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
657 {
658 if( remainder ) {
659 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( prevMultiple( K - kk, SIMDSIZE ) ) );
660 }
661 else {
662 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
663 }
664
665 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
666 const size_t jend ( IsLower_v<MT3> ? kk+kblock : N );
667 const size_t jsize ( jend - jbegin );
668
669 B2 = serial( submatrix< remainder ? unaligned : aligned >( B, kk, jbegin, kblock, jsize, unchecked ) );
670
671 size_t ii( 0UL );
672 size_t iblock( 0UL );
673
674 while( ii < M )
675 {
676 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
677
678 if( ( IsLower_v<MT2> && ii+iblock <= kk ) ||
679 ( IsUpper_v<MT2> && kk+kblock <= ii ) ) {
680 ii += iblock;
681 continue;
682 }
683
684 A2 = serial( submatrix< remainder ? unaligned : aligned >( A, ii, kk, iblock, kblock, unchecked ) );
685
686 size_t j( 0UL );
687
688 if( IsFloatingPoint_v<ET3> )
689 {
690 for( ; (j+5UL) <= jsize; j+=5UL )
691 {
692 size_t i( 0UL );
693
694 for( ; (i+2UL) <= iblock; i+=2UL )
695 {
696 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
697
698 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
699 {
700 const SIMDType a1( A2.load(i ,k) );
701 const SIMDType a2( A2.load(i+1UL,k) );
702
703 const SIMDType b1( B2.load(k,j ) );
704 const SIMDType b2( B2.load(k,j+1UL) );
705 const SIMDType b3( B2.load(k,j+2UL) );
706 const SIMDType b4( B2.load(k,j+3UL) );
707 const SIMDType b5( B2.load(k,j+4UL) );
708
709 xmm1 += a1 * b1;
710 xmm2 += a1 * b2;
711 xmm3 += a1 * b3;
712 xmm4 += a1 * b4;
713 xmm5 += a1 * b5;
714 xmm6 += a2 * b1;
715 xmm7 += a2 * b2;
716 xmm8 += a2 * b3;
717 xmm9 += a2 * b4;
718 xmm10 += a2 * b5;
719 }
720
721 (*C)(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
722 (*C)(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
723 (*C)(ii+i ,jbegin+j+2UL) += sum( xmm3 ) * alpha;
724 (*C)(ii+i ,jbegin+j+3UL) += sum( xmm4 ) * alpha;
725 (*C)(ii+i ,jbegin+j+4UL) += sum( xmm5 ) * alpha;
726 (*C)(ii+i+1UL,jbegin+j ) += sum( xmm6 ) * alpha;
727 (*C)(ii+i+1UL,jbegin+j+1UL) += sum( xmm7 ) * alpha;
728 (*C)(ii+i+1UL,jbegin+j+2UL) += sum( xmm8 ) * alpha;
729 (*C)(ii+i+1UL,jbegin+j+3UL) += sum( xmm9 ) * alpha;
730 (*C)(ii+i+1UL,jbegin+j+4UL) += sum( xmm10 ) * alpha;
731 }
732
733 if( i<iblock )
734 {
735 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
736
737 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
738 {
739 const SIMDType a1( A2.load(i,k) );
740
741 xmm1 += a1 * B2.load(k,j );
742 xmm2 += a1 * B2.load(k,j+1UL);
743 xmm3 += a1 * B2.load(k,j+2UL);
744 xmm4 += a1 * B2.load(k,j+3UL);
745 xmm5 += a1 * B2.load(k,j+4UL);
746 }
747
748 (*C)(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
749 (*C)(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
750 (*C)(ii+i,jbegin+j+2UL) += sum( xmm3 ) * alpha;
751 (*C)(ii+i,jbegin+j+3UL) += sum( xmm4 ) * alpha;
752 (*C)(ii+i,jbegin+j+4UL) += sum( xmm5 ) * alpha;
753 }
754 }
755 }
756 else
757 {
758 for( ; (j+4UL) <= jsize; j+=4UL )
759 {
760 size_t i( 0UL );
761
762 for( ; (i+2UL) <= iblock; i+=2UL )
763 {
764 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
765
766 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
767 {
768 const SIMDType a1( A2.load(i ,k) );
769 const SIMDType a2( A2.load(i+1UL,k) );
770
771 const SIMDType b1( B2.load(k,j ) );
772 const SIMDType b2( B2.load(k,j+1UL) );
773 const SIMDType b3( B2.load(k,j+2UL) );
774 const SIMDType b4( B2.load(k,j+3UL) );
775
776 xmm1 += a1 * b1;
777 xmm2 += a1 * b2;
778 xmm3 += a1 * b3;
779 xmm4 += a1 * b4;
780 xmm5 += a2 * b1;
781 xmm6 += a2 * b2;
782 xmm7 += a2 * b3;
783 xmm8 += a2 * b4;
784 }
785
786 (*C)(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
787 (*C)(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
788 (*C)(ii+i ,jbegin+j+2UL) += sum( xmm3 ) * alpha;
789 (*C)(ii+i ,jbegin+j+3UL) += sum( xmm4 ) * alpha;
790 (*C)(ii+i+1UL,jbegin+j ) += sum( xmm5 ) * alpha;
791 (*C)(ii+i+1UL,jbegin+j+1UL) += sum( xmm6 ) * alpha;
792 (*C)(ii+i+1UL,jbegin+j+2UL) += sum( xmm7 ) * alpha;
793 (*C)(ii+i+1UL,jbegin+j+3UL) += sum( xmm8 ) * alpha;
794 }
795
796 if( i<iblock )
797 {
798 SIMDType xmm1, xmm2, xmm3, xmm4;
799
800 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
801 {
802 const SIMDType a1( A2.load(i,k) );
803
804 xmm1 += a1 * B2.load(k,j );
805 xmm2 += a1 * B2.load(k,j+1UL);
806 xmm3 += a1 * B2.load(k,j+2UL);
807 xmm4 += a1 * B2.load(k,j+3UL);
808 }
809
810 (*C)(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
811 (*C)(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
812 (*C)(ii+i,jbegin+j+2UL) += sum( xmm3 ) * alpha;
813 (*C)(ii+i,jbegin+j+3UL) += sum( xmm4 ) * alpha;
814 }
815 }
816 }
817
818 for( ; (j+2UL) <= jsize; j+=2UL )
819 {
820 size_t i( 0UL );
821
822 for( ; (i+4UL) <= iblock; i+=4UL )
823 {
824 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
825
826 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
827 {
828 const SIMDType a1( A2.load(i ,k) );
829 const SIMDType a2( A2.load(i+1UL,k) );
830 const SIMDType a3( A2.load(i+2UL,k) );
831 const SIMDType a4( A2.load(i+3UL,k) );
832
833 const SIMDType b1( B2.load(k,j ) );
834 const SIMDType b2( B2.load(k,j+1UL) );
835
836 xmm1 += a1 * b1;
837 xmm2 += a1 * b2;
838 xmm3 += a2 * b1;
839 xmm4 += a2 * b2;
840 xmm5 += a3 * b1;
841 xmm6 += a3 * b2;
842 xmm7 += a4 * b1;
843 xmm8 += a4 * b2;
844 }
845
846 (*C)(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
847 (*C)(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
848 (*C)(ii+i+1UL,jbegin+j ) += sum( xmm3 ) * alpha;
849 (*C)(ii+i+1UL,jbegin+j+1UL) += sum( xmm4 ) * alpha;
850 (*C)(ii+i+2UL,jbegin+j ) += sum( xmm5 ) * alpha;
851 (*C)(ii+i+2UL,jbegin+j+1UL) += sum( xmm6 ) * alpha;
852 (*C)(ii+i+3UL,jbegin+j ) += sum( xmm7 ) * alpha;
853 (*C)(ii+i+3UL,jbegin+j+1UL) += sum( xmm8 ) * alpha;
854 }
855
856 for( ; (i+2UL) <= iblock; i+=2UL )
857 {
858 SIMDType xmm1, xmm2, xmm3, xmm4;
859
860 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
861 {
862 const SIMDType a1( A2.load(i ,k) );
863 const SIMDType a2( A2.load(i+1UL,k) );
864
865 const SIMDType b1( B2.load(k,j ) );
866 const SIMDType b2( B2.load(k,j+1UL) );
867
868 xmm1 += a1 * b1;
869 xmm2 += a1 * b2;
870 xmm3 += a2 * b1;
871 xmm4 += a2 * b2;
872 }
873
874 (*C)(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
875 (*C)(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
876 (*C)(ii+i+1UL,jbegin+j ) += sum( xmm3 ) * alpha;
877 (*C)(ii+i+1UL,jbegin+j+1UL) += sum( xmm4 ) * alpha;
878 }
879
880 if( i<iblock )
881 {
882 SIMDType xmm1, xmm2;
883
884 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
885 {
886 const SIMDType a1( A2.load(i,k) );
887
888 xmm1 += a1 * B2.load(k,j );
889 xmm2 += a1 * B2.load(k,j+1UL);
890 }
891
892 (*C)(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
893 (*C)(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
894 }
895 }
896
897 if( j<jsize )
898 {
899 size_t i( 0UL );
900
901 for( ; (i+2UL) <= iblock; i+=2UL )
902 {
903 SIMDType xmm1, xmm2;
904
905 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
906 {
907 const SIMDType b1( B2.load(k,j) );
908
909 xmm1 += A2.load(i ,k) * b1;
910 xmm2 += A2.load(i+1UL,k) * b1;
911 }
912
913 (*C)(ii+i ,jbegin+j) += sum( xmm1 ) * alpha;
914 (*C)(ii+i+1UL,jbegin+j) += sum( xmm2 ) * alpha;
915 }
916
917 if( i<iblock )
918 {
919 SIMDType xmm1;
920
921 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
922 {
923 xmm1 += A2.load(i,k) * B2.load(k,j);
924 }
925
926 (*C)(ii+i,jbegin+j) += sum( xmm1 ) * alpha;
927 }
928 }
929
930 ii += iblock;
931 }
932
933 kk += kblock;
934 }
935
936 if( remainder && kk < K )
937 {
938 const size_t ksize( K - kk );
939
940 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
941 const size_t jsize ( N - jbegin );
942
943 B2 = serial( submatrix( B, kk, jbegin, ksize, jsize, unchecked ) );
944
945 size_t ii( 0UL );
946 size_t iblock( 0UL );
947
948 while( ii < M )
949 {
950 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
951
952 if( IsLower_v<MT2> && ii+iblock <= kk ) {
953 ii += iblock;
954 continue;
955 }
956
957 A2 = serial( submatrix( A, ii, kk, iblock, ksize, unchecked ) );
958
959 size_t j( 0UL );
960
961 if( IsFloatingPoint_v<ET1> )
962 {
963 for( ; (j+5UL) <= jsize; j+=5UL )
964 {
965 size_t i( 0UL );
966
967 for( ; (i+2UL) <= iblock; i+=2UL ) {
968 for( size_t k=0UL; k<ksize; ++k ) {
969 (*C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
970 (*C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
971 (*C)(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
972 (*C)(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
973 (*C)(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
974 (*C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
975 (*C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
976 (*C)(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
977 (*C)(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
978 (*C)(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
979 }
980 }
981
982 if( i<iblock ) {
983 for( size_t k=0UL; k<ksize; ++k ) {
984 (*C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
985 (*C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
986 (*C)(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
987 (*C)(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
988 (*C)(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
989 }
990 }
991 }
992 }
993 else
994 {
995 for( ; (j+4UL) <= jsize; j+=4UL )
996 {
997 size_t i( 0UL );
998
999 for( ; (i+2UL) <= iblock; i+=2UL ) {
1000 for( size_t k=0UL; k<ksize; ++k ) {
1001 (*C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
1002 (*C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1003 (*C)(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
1004 (*C)(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
1005 (*C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1006 (*C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1007 (*C)(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
1008 (*C)(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
1009 }
1010 }
1011
1012 if( i<iblock ) {
1013 for( size_t k=0UL; k<ksize; ++k ) {
1014 (*C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
1015 (*C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
1016 (*C)(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
1017 (*C)(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
1018 }
1019 }
1020 }
1021 }
1022
1023 for( ; (j+2UL) <= jsize; j+=2UL )
1024 {
1025 size_t i( 0UL );
1026
1027 for( ; (i+2UL) <= iblock; i+=2UL ) {
1028 for( size_t k=0UL; k<ksize; ++k ) {
1029 (*C)(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
1030 (*C)(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1031 (*C)(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1032 (*C)(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1033 }
1034 }
1035
1036 if( i<iblock ) {
1037 for( size_t k=0UL; k<ksize; ++k ) {
1038 (*C)(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
1039 (*C)(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
1040 }
1041 }
1042 }
1043
1044 if( j<jsize )
1045 {
1046 size_t i( 0UL );
1047
1048 for( ; (i+2UL) <= iblock; i+=2UL ) {
1049 for( size_t k=0UL; k<ksize; ++k ) {
1050 (*C)(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
1051 (*C)(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1052 }
1053 }
1054
1055 if( i<iblock ) {
1056 for( size_t k=0UL; k<ksize; ++k ) {
1057 (*C)(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
1058 }
1059 }
1060 }
1061
1062 ii += iblock;
1063 }
1064 }
1065 }
1066 /*! \endcond */
1067 //*************************************************************************************************
1068
1069
1070 //*************************************************************************************************
1071 /*! \cond BLAZE_INTERNAL */
1072 /*!\brief Compute kernel for a general dense matrix/dense matrix multiplication (\f$ C=A*B \f$).
1073 // \ingroup dense_matrix
1074 //
1075 // \param C The target left-hand side column-major dense matrix.
1076 // \param A The left-hand side multiplication operand.
1077 // \param B The right-hand side multiplication operand.
1078 // \return void
1079 //
1080 // This function implements the compute kernel for a general dense matrix/dense matrix
1081 // multiplication of the form \f$ C=A*B \f$. Both \a A and \a B must be non-expression
1082 // dense matrix types, \a C must be a non-expression, non-adaptor, row-major dense matrix
1083 // type. The element types of all three matrices must be SIMD combinable, i.e. must
1084 // provide a common SIMD interface.
1085 */
1086 template< typename MT1, typename MT2, typename MT3 >
mmm(MT1 & C,const MT2 & A,const MT3 & B)1087 inline void mmm( MT1& C, const MT2& A, const MT3& B )
1088 {
1089 using ET1 = ElementType_t<MT1>;
1090 using ET2 = ElementType_t<MT2>;
1091 using ET3 = ElementType_t<MT3>;
1092
1093 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
1094 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );
1095
1096 mmm( C, A, B, ET1(1), ET1(0) );
1097 }
1098 /*! \endcond */
1099 //*************************************************************************************************
1100
1101
1102
1103
1104 //=================================================================================================
1105 //
1106 // LOWER DENSE MATRIX MULTIPLICATION KERNELS
1107 //
1108 //=================================================================================================
1109
1110 //*************************************************************************************************
1111 /*! \cond BLAZE_INTERNAL */
1112 /*!\brief Compute kernel for a lower dense matrix/dense matrix multiplication
1113 // (\f$ C=\alpha*A*B+\beta*C \f$).
1114 // \ingroup dense_matrix
1115 //
1116 // \param C The target left-hand side row-major dense matrix.
1117 // \param A The left-hand side multiplication operand.
1118 // \param B The right-hand side multiplication operand.
1119 // \param alpha The scaling factor for \f$ A*B \f$.
1120 // \param beta The scaling factor for \f$ C \f$.
1121 // \return void
1122 //
1123 // This function implements the compute kernel for a lower dense matrix/dense matrix
1124 // multiplication of the form \f$ C=\alpha*A*B+\beta*C \f$. Both \a A and \a B must
1125 // be non-expression dense matrix types, \a C must be a non-expression, non-adaptor,
1126 // row-major dense matrix type. The element types of all three matrices must be SIMD
1127 // combinable, i.e. must provide a common SIMD interface.
1128 */
1129 template< typename MT1, typename MT2, typename MT3, typename ST >
lmmm(DenseMatrix<MT1,false> & C,const MT2 & A,const MT3 & B,ST alpha,ST beta)1130 void lmmm( DenseMatrix<MT1,false>& C, const MT2& A, const MT3& B, ST alpha, ST beta )
1131 {
1132 using ET1 = ElementType_t<MT1>;
1133 using ET2 = ElementType_t<MT2>;
1134 using ET3 = ElementType_t<MT3>;
1135 using SIMDType = SIMDTrait_t<ET1>;
1136
1137 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT1 );
1138 BLAZE_CONSTRAINT_MUST_BE_ROW_MAJOR_MATRIX_TYPE ( MT1 );
1139 BLAZE_CONSTRAINT_MUST_NOT_BE_SYMMETRIC_MATRIX_TYPE ( MT1 );
1140 BLAZE_CONSTRAINT_MUST_NOT_BE_HERMITIAN_MATRIX_TYPE ( MT1 );
1141 BLAZE_CONSTRAINT_MUST_NOT_BE_UNILOWER_MATRIX_TYPE ( MT1 );
1142 BLAZE_CONSTRAINT_MUST_NOT_BE_STRICTLY_LOWER_MATRIX_TYPE( MT1 );
1143 BLAZE_CONSTRAINT_MUST_NOT_BE_UPPER_MATRIX_TYPE ( MT1 );
1144 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE ( MT1 );
1145
1146 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT2 );
1147 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT2 );
1148
1149 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT3 );
1150 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT3 );
1151
1152 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
1153 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );
1154
1155 constexpr size_t SIMDSIZE( SIMDTrait<ET1>::size );
1156
1157 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
1158
1159 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/sizeof(ET1) ) );
1160 constexpr size_t JBLOCK( MMM_INNER_BLOCK_SIZE );
1161
1162 BLAZE_STATIC_ASSERT( KBLOCK >= SIMDSIZE && KBLOCK % SIMDSIZE == 0UL );
1163 BLAZE_STATIC_ASSERT( JBLOCK >= SIMDSIZE && JBLOCK % SIMDSIZE == 0UL );
1164
1165 const size_t M( A.rows() );
1166 const size_t N( B.columns() );
1167 const size_t K( A.columns() );
1168
1169 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
1170
1171 DynamicMatrix<ET2,false> A2( M, KBLOCK );
1172 DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );
1173
1174 decltype(auto) c( derestrict( *C ) );
1175
1176 if( isDefault( beta ) ) {
1177 reset( c );
1178 }
1179 else if( !isOne( beta ) ) {
1180 c *= beta;
1181 }
1182
1183 size_t kk( 0UL );
1184 size_t kblock( 0UL );
1185
1186 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
1187 {
1188 if( remainder ) {
1189 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( prevMultiple( K - kk, SIMDSIZE ) ) );
1190 }
1191 else {
1192 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
1193 }
1194
1195 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
1196 const size_t iend ( IsUpper_v<MT2> ? kk+kblock : M );
1197 const size_t isize ( iend - ibegin );
1198
1199 A2 = serial( submatrix< remainder ? unaligned : aligned >( A, ibegin, kk, isize, kblock, unchecked ) );
1200
1201 size_t jj( 0UL );
1202 size_t jblock( 0UL );
1203
1204 while( jj < N )
1205 {
1206 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
1207
1208 if( ( IsLower_v<MT3> && kk+kblock <= jj ) ||
1209 ( IsUpper_v<MT3> && jj+jblock <= kk ) ) {
1210 jj += jblock;
1211 continue;
1212 }
1213
1214 B2 = serial( submatrix< remainder ? unaligned : aligned >( B, kk, jj, kblock, jblock, unchecked ) );
1215
1216 size_t i( 0UL );
1217
1218 if( IsFloatingPoint_v<ET1> )
1219 {
1220 for( ; (i+5UL) <= isize; i+=5UL )
1221 {
1222 if( jj > ibegin+i+4UL ) continue;
1223
1224 const size_t jend( min( ibegin+i-jj+5UL, jblock ) );
1225 size_t j( 0UL );
1226
1227 for( ; (j+2UL) <= jend; j+=2UL )
1228 {
1229 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
1230
1231 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1232 {
1233 const SIMDType a1( A2.load(i ,k) );
1234 const SIMDType a2( A2.load(i+1UL,k) );
1235 const SIMDType a3( A2.load(i+2UL,k) );
1236 const SIMDType a4( A2.load(i+3UL,k) );
1237 const SIMDType a5( A2.load(i+4UL,k) );
1238
1239 const SIMDType b1( B2.load(k,j ) );
1240 const SIMDType b2( B2.load(k,j+1UL) );
1241
1242 xmm1 += a1 * b1;
1243 xmm2 += a1 * b2;
1244 xmm3 += a2 * b1;
1245 xmm4 += a2 * b2;
1246 xmm5 += a3 * b1;
1247 xmm6 += a3 * b2;
1248 xmm7 += a4 * b1;
1249 xmm8 += a4 * b2;
1250 xmm9 += a5 * b1;
1251 xmm10 += a5 * b2;
1252 }
1253
1254 c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
1255 c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
1256 c(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
1257 c(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
1258 c(ibegin+i+2UL,jj+j ) += sum( xmm5 ) * alpha;
1259 c(ibegin+i+2UL,jj+j+1UL) += sum( xmm6 ) * alpha;
1260 c(ibegin+i+3UL,jj+j ) += sum( xmm7 ) * alpha;
1261 c(ibegin+i+3UL,jj+j+1UL) += sum( xmm8 ) * alpha;
1262 c(ibegin+i+4UL,jj+j ) += sum( xmm9 ) * alpha;
1263 c(ibegin+i+4UL,jj+j+1UL) += sum( xmm10 ) * alpha;
1264 }
1265
1266 if( j<jend )
1267 {
1268 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
1269
1270 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1271 {
1272 const SIMDType a1( A2.load(i ,k) );
1273 const SIMDType a2( A2.load(i+1UL,k) );
1274 const SIMDType a3( A2.load(i+2UL,k) );
1275 const SIMDType a4( A2.load(i+3UL,k) );
1276 const SIMDType a5( A2.load(i+4UL,k) );
1277
1278 const SIMDType b1( B2.load(k,j) );
1279
1280 xmm1 += a1 * b1;
1281 xmm2 += a2 * b1;
1282 xmm3 += a3 * b1;
1283 xmm4 += a4 * b1;
1284 xmm5 += a5 * b1;
1285 }
1286
1287 c(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
1288 c(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
1289 c(ibegin+i+2UL,jj+j) += sum( xmm3 ) * alpha;
1290 c(ibegin+i+3UL,jj+j) += sum( xmm4 ) * alpha;
1291 c(ibegin+i+4UL,jj+j) += sum( xmm5 ) * alpha;
1292 }
1293 }
1294 }
1295 else
1296 {
1297 for( ; (i+4UL) <= isize; i+=4UL )
1298 {
1299 if( jj > ibegin+i+3UL ) continue;
1300
1301 const size_t jend( min( ibegin+i-jj+4UL, jblock ) );
1302 size_t j( 0UL );
1303
1304 for( ; (j+2UL) <= jend; j+=2UL )
1305 {
1306 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1307
1308 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1309 {
1310 const SIMDType a1( A2.load(i ,k) );
1311 const SIMDType a2( A2.load(i+1UL,k) );
1312 const SIMDType a3( A2.load(i+2UL,k) );
1313 const SIMDType a4( A2.load(i+3UL,k) );
1314
1315 const SIMDType b1( B2.load(k,j ) );
1316 const SIMDType b2( B2.load(k,j+1UL) );
1317
1318 xmm1 += a1 * b1;
1319 xmm2 += a1 * b2;
1320 xmm3 += a2 * b1;
1321 xmm4 += a2 * b2;
1322 xmm5 += a3 * b1;
1323 xmm6 += a3 * b2;
1324 xmm7 += a4 * b1;
1325 xmm8 += a4 * b2;
1326 }
1327
1328 c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
1329 c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
1330 c(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
1331 c(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
1332 c(ibegin+i+2UL,jj+j ) += sum( xmm5 ) * alpha;
1333 c(ibegin+i+2UL,jj+j+1UL) += sum( xmm6 ) * alpha;
1334 c(ibegin+i+3UL,jj+j ) += sum( xmm7 ) * alpha;
1335 c(ibegin+i+3UL,jj+j+1UL) += sum( xmm8 ) * alpha;
1336 }
1337
1338 if( j<jend )
1339 {
1340 SIMDType xmm1, xmm2, xmm3, xmm4;
1341
1342 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1343 {
1344 const SIMDType a1( A2.load(i ,k) );
1345 const SIMDType a2( A2.load(i+1UL,k) );
1346 const SIMDType a3( A2.load(i+2UL,k) );
1347 const SIMDType a4( A2.load(i+3UL,k) );
1348
1349 const SIMDType b1( B2.load(k,j) );
1350
1351 xmm1 += a1 * b1;
1352 xmm2 += a2 * b1;
1353 xmm3 += a3 * b1;
1354 xmm4 += a4 * b1;
1355 }
1356
1357 c(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
1358 c(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
1359 c(ibegin+i+2UL,jj+j) += sum( xmm3 ) * alpha;
1360 c(ibegin+i+3UL,jj+j) += sum( xmm4 ) * alpha;
1361 }
1362 }
1363 }
1364
1365 for( ; (i+2UL) <= isize; i+=2UL )
1366 {
1367 if( jj > ibegin+i+1UL ) continue;
1368
1369 const size_t jend( min( ibegin+i-jj+2UL, jblock ) );
1370 size_t j( 0UL );
1371
1372 for( ; (j+4UL) <= jend; j+=4UL )
1373 {
1374 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1375
1376 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1377 {
1378 const SIMDType a1( A2.load(i ,k) );
1379 const SIMDType a2( A2.load(i+1UL,k) );
1380
1381 const SIMDType b1( B2.load(k,j ) );
1382 const SIMDType b2( B2.load(k,j+1UL) );
1383 const SIMDType b3( B2.load(k,j+2UL) );
1384 const SIMDType b4( B2.load(k,j+3UL) );
1385
1386 xmm1 += a1 * b1;
1387 xmm2 += a1 * b2;
1388 xmm3 += a1 * b3;
1389 xmm4 += a1 * b4;
1390 xmm5 += a2 * b1;
1391 xmm6 += a2 * b2;
1392 xmm7 += a2 * b3;
1393 xmm8 += a2 * b4;
1394 }
1395
1396 c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
1397 c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
1398 c(ibegin+i ,jj+j+2UL) += sum( xmm3 ) * alpha;
1399 c(ibegin+i ,jj+j+3UL) += sum( xmm4 ) * alpha;
1400 c(ibegin+i+1UL,jj+j ) += sum( xmm5 ) * alpha;
1401 c(ibegin+i+1UL,jj+j+1UL) += sum( xmm6 ) * alpha;
1402 c(ibegin+i+1UL,jj+j+2UL) += sum( xmm7 ) * alpha;
1403 c(ibegin+i+1UL,jj+j+3UL) += sum( xmm8 ) * alpha;
1404 }
1405
1406 for( ; (j+2UL) <= jend; j+=2UL )
1407 {
1408 SIMDType xmm1, xmm2, xmm3, xmm4;
1409
1410 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1411 {
1412 const SIMDType a1( A2.load(i ,k) );
1413 const SIMDType a2( A2.load(i+1UL,k) );
1414
1415 const SIMDType b1( B2.load(k,j ) );
1416 const SIMDType b2( B2.load(k,j+1UL) );
1417
1418 xmm1 += a1 * b1;
1419 xmm2 += a1 * b2;
1420 xmm3 += a2 * b1;
1421 xmm4 += a2 * b2;
1422 }
1423
1424 c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
1425 c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
1426 c(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
1427 c(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
1428 }
1429
1430 if( j<jend )
1431 {
1432 SIMDType xmm1, xmm2;
1433
1434 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1435 {
1436 const SIMDType a1( A2.load(i ,k) );
1437 const SIMDType a2( A2.load(i+1UL,k) );
1438
1439 const SIMDType b1( B2.load(k,j) );
1440
1441 xmm1 += a1 * b1;
1442 xmm2 += a2 * b1;
1443 }
1444
1445 c(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
1446 c(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
1447 }
1448 }
1449
1450 if( i<isize && jj <= ibegin+i )
1451 {
1452 const size_t jend( min( ibegin+i-jj+2UL, jblock ) );
1453 size_t j( 0UL );
1454
1455 for( ; (j+2UL) <= jend; j+=2UL )
1456 {
1457 SIMDType xmm1, xmm2;
1458
1459 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1460 {
1461 const SIMDType a1( A2.load(i,k) );
1462
1463 xmm1 += a1 * B2.load(k,j );
1464 xmm2 += a1 * B2.load(k,j+1UL);
1465 }
1466
1467 c(ibegin+i,jj+j ) += sum( xmm1 ) * alpha;
1468 c(ibegin+i,jj+j+1UL) += sum( xmm2 ) * alpha;
1469 }
1470
1471 if( j<jend )
1472 {
1473 SIMDType xmm1;
1474
1475 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1476 {
1477 const SIMDType a1( A2.load(i,k) );
1478
1479 xmm1 += a1 * B2.load(k,j);
1480 }
1481
1482 c(ibegin+i,jj+j) += sum( xmm1 ) * alpha;
1483 }
1484 }
1485
1486 jj += jblock;
1487 }
1488
1489 kk += kblock;
1490 }
1491
1492 if( remainder && kk < K )
1493 {
1494 const size_t ksize( K - kk );
1495
1496 const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
1497 const size_t isize ( M - ibegin );
1498
1499 A2 = serial( submatrix( A, ibegin, kk, isize, ksize, unchecked ) );
1500
1501 size_t jj( 0UL );
1502 size_t jblock( 0UL );
1503
1504 while( jj < N )
1505 {
1506 jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );
1507
1508 if( IsUpper_v<MT3> && jj+jblock <= kk ) {
1509 jj += jblock;
1510 continue;
1511 }
1512
1513 B2 = serial( submatrix( B, kk, jj, ksize, jblock, unchecked ) );
1514
1515 size_t i( 0UL );
1516
1517 if( IsFloatingPoint_v<ET1> )
1518 {
1519 for( ; (i+5UL) <= isize; i+=5UL )
1520 {
1521 if( jj > ibegin+i+4UL ) continue;
1522
1523 const size_t jend( min( ibegin+i-jj+5UL, jblock ) );
1524 size_t j( 0UL );
1525
1526 for( ; (j+2UL) <= jend; j+=2UL ) {
1527 for( size_t k=0UL; k<ksize; ++k ) {
1528 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1529 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1530 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1531 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1532 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
1533 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
1534 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
1535 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
1536 c(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
1537 c(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
1538 }
1539 }
1540
1541 if( j<jend ) {
1542 for( size_t k=0UL; k<ksize; ++k ) {
1543 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1544 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1545 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
1546 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
1547 c(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
1548 }
1549 }
1550 }
1551 }
1552 else
1553 {
1554 for( ; (i+4UL) <= isize; i+=4UL )
1555 {
1556 if( jj > ibegin+i+3UL ) continue;
1557
1558 const size_t jend( min( ibegin+i-jj+4UL, jblock ) );
1559 size_t j( 0UL );
1560
1561 for( ; (j+2UL) <= jend; j+=2UL ) {
1562 for( size_t k=0UL; k<ksize; ++k ) {
1563 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1564 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1565 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1566 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1567 c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
1568 c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
1569 c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
1570 c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
1571 }
1572 }
1573
1574 if( j<jend ) {
1575 for( size_t k=0UL; k<ksize; ++k ) {
1576 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1577 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1578 c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
1579 c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
1580 }
1581 }
1582 }
1583 }
1584
1585 for( ; (i+2UL) <= isize; i+=2UL )
1586 {
1587 if( jj > ibegin+i+1UL ) continue;
1588
1589 const size_t jend( min( ibegin+i-jj+2UL, jblock ) );
1590 size_t j( 0UL );
1591
1592 for( ; (j+2UL) <= jend; j+=2UL ) {
1593 for( size_t k=0UL; k<ksize; ++k ) {
1594 c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
1595 c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
1596 c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
1597 c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
1598 }
1599 }
1600
1601 if( j<jend ) {
1602 for( size_t k=0UL; k<ksize; ++k ) {
1603 c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
1604 c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
1605 }
1606 }
1607 }
1608
1609 if( i<isize && jj <= ibegin+i )
1610 {
1611 const size_t jend( min( ibegin+i-jj+2UL, jblock ) );
1612 size_t j( 0UL );
1613
1614 for( ; (j+2UL) <= jend; j+=2UL ) {
1615 for( size_t k=0UL; k<ksize; ++k ) {
1616 c(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
1617 c(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
1618 }
1619 }
1620
1621 if( j<jend ) {
1622 for( size_t k=0UL; k<ksize; ++k ) {
1623 c(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
1624 }
1625 }
1626 }
1627
1628 jj += jblock;
1629 }
1630 }
1631 }
1632 /*! \endcond */
1633 //*************************************************************************************************
1634
1635
1636 //*************************************************************************************************
1637 /*! \cond BLAZE_INTERNAL */
1638 /*!\brief Compute kernel for a lower dense matrix/dense matrix multiplication
1639 // (\f$ C=\alpha*A*B+\beta*C \f$).
1640 // \ingroup dense_matrix
1641 //
1642 // \param C The target left-hand side column-major dense matrix.
1643 // \param A The left-hand side multiplication operand.
1644 // \param B The right-hand side multiplication operand.
1645 // \param alpha The scaling factor for \f$ A*B \f$.
1646 // \param beta The scaling factor for \f$ C \f$.
1647 // \return void
1648 //
1649 // This function implements the compute kernel for a lower dense matrix/dense matrix
1650 // multiplication of the form \f$ C=\alpha*A*B+\beta*C \f$. Both \a A and \a B must
1651 // be non-expression dense matrix types, \a C must be a non-expression, non-adaptor,
1652 // column-major dense matrix type. The element types of all three matrices must be SIMD
1653 // combinable, i.e. must provide a common SIMD interface.
1654 */
1655 template< typename MT1, typename MT2, typename MT3, typename ST >
lmmm(DenseMatrix<MT1,true> & C,const MT2 & A,const MT3 & B,ST alpha,ST beta)1656 void lmmm( DenseMatrix<MT1,true>& C, const MT2& A, const MT3& B, ST alpha, ST beta )
1657 {
1658 using ET1 = ElementType_t<MT1>;
1659 using ET2 = ElementType_t<MT2>;
1660 using ET3 = ElementType_t<MT3>;
1661 using SIMDType = SIMDTrait_t<ET1>;
1662
1663 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT1 );
1664 BLAZE_CONSTRAINT_MUST_BE_COLUMN_MAJOR_MATRIX_TYPE ( MT1 );
1665 BLAZE_CONSTRAINT_MUST_NOT_BE_SYMMETRIC_MATRIX_TYPE ( MT1 );
1666 BLAZE_CONSTRAINT_MUST_NOT_BE_HERMITIAN_MATRIX_TYPE ( MT1 );
1667 BLAZE_CONSTRAINT_MUST_NOT_BE_UNILOWER_MATRIX_TYPE ( MT1 );
1668 BLAZE_CONSTRAINT_MUST_NOT_BE_STRICTLY_LOWER_MATRIX_TYPE( MT1 );
1669 BLAZE_CONSTRAINT_MUST_NOT_BE_UPPER_MATRIX_TYPE ( MT1 );
1670 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE ( MT1 );
1671
1672 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT2 );
1673 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT2 );
1674
1675 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT3 );
1676 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT3 );
1677
1678 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
1679 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );
1680
1681 constexpr size_t SIMDSIZE( SIMDTrait<ET1>::size );
1682
1683 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
1684
1685 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/sizeof(ET1) ) );
1686 constexpr size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
1687
1688 BLAZE_STATIC_ASSERT( KBLOCK >= SIMDSIZE && KBLOCK % SIMDSIZE == 0UL );
1689 BLAZE_STATIC_ASSERT( IBLOCK >= SIMDSIZE && IBLOCK % SIMDSIZE == 0UL );
1690
1691 const size_t M( A.rows() );
1692 const size_t N( B.columns() );
1693 const size_t K( A.columns() );
1694
1695 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
1696
1697 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
1698 DynamicMatrix<ET3,true> B2( KBLOCK, N );
1699
1700 decltype(auto) c( derestrict( *C ) );
1701
1702 if( isDefault( beta ) ) {
1703 reset( c );
1704 }
1705 else if( !isOne( beta ) ) {
1706 c *= beta;
1707 }
1708
1709 size_t kk( 0UL );
1710 size_t kblock( 0UL );
1711
1712 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
1713 {
1714 if( remainder ) {
1715 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( prevMultiple( K - kk, SIMDSIZE ) ) );
1716 }
1717 else {
1718 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
1719 }
1720
1721 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
1722 const size_t jend ( IsLower_v<MT3> ? kk+kblock : N );
1723 const size_t jsize ( jend - jbegin );
1724
1725 B2 = serial( submatrix< remainder ? unaligned : aligned >( B, kk, jbegin, kblock, jsize, unchecked ) );
1726
1727 size_t ii( 0UL );
1728 size_t iblock( 0UL );
1729
1730 while( ii < M )
1731 {
1732 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
1733
1734 if( ( IsLower_v<MT2> && ii+iblock <= kk ) ||
1735 ( IsUpper_v<MT2> && kk+kblock <= ii ) ) {
1736 ii += iblock;
1737 continue;
1738 }
1739
1740 A2 = serial( submatrix< remainder ? unaligned : aligned >( A, ii, kk, iblock, kblock, unchecked ) );
1741
1742 size_t j( 0UL );
1743
1744 if( IsFloatingPoint_v<ET3> )
1745 {
1746 for( ; (j+5UL) <= jsize; j+=5UL )
1747 {
1748 if( ii+iblock < jbegin ) continue;
1749
1750 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1751
1752 for( ; (i+2UL) <= iblock; i+=2UL )
1753 {
1754 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
1755
1756 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1757 {
1758 const SIMDType a1( A2.load(i ,k) );
1759 const SIMDType a2( A2.load(i+1UL,k) );
1760
1761 const SIMDType b1( B2.load(k,j ) );
1762 const SIMDType b2( B2.load(k,j+1UL) );
1763 const SIMDType b3( B2.load(k,j+2UL) );
1764 const SIMDType b4( B2.load(k,j+3UL) );
1765 const SIMDType b5( B2.load(k,j+4UL) );
1766
1767 xmm1 += a1 * b1;
1768 xmm2 += a1 * b2;
1769 xmm3 += a1 * b3;
1770 xmm4 += a1 * b4;
1771 xmm5 += a1 * b5;
1772 xmm6 += a2 * b1;
1773 xmm7 += a2 * b2;
1774 xmm8 += a2 * b3;
1775 xmm9 += a2 * b4;
1776 xmm10 += a2 * b5;
1777 }
1778
1779 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
1780 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
1781 c(ii+i ,jbegin+j+2UL) += sum( xmm3 ) * alpha;
1782 c(ii+i ,jbegin+j+3UL) += sum( xmm4 ) * alpha;
1783 c(ii+i ,jbegin+j+4UL) += sum( xmm5 ) * alpha;
1784 c(ii+i+1UL,jbegin+j ) += sum( xmm6 ) * alpha;
1785 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm7 ) * alpha;
1786 c(ii+i+1UL,jbegin+j+2UL) += sum( xmm8 ) * alpha;
1787 c(ii+i+1UL,jbegin+j+3UL) += sum( xmm9 ) * alpha;
1788 c(ii+i+1UL,jbegin+j+4UL) += sum( xmm10 ) * alpha;
1789 }
1790
1791 if( i<iblock )
1792 {
1793 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
1794
1795 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1796 {
1797 const SIMDType a1( A2.load(i,k) );
1798
1799 xmm1 += a1 * B2.load(k,j );
1800 xmm2 += a1 * B2.load(k,j+1UL);
1801 xmm3 += a1 * B2.load(k,j+2UL);
1802 xmm4 += a1 * B2.load(k,j+3UL);
1803 xmm5 += a1 * B2.load(k,j+4UL);
1804 }
1805
1806 c(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
1807 c(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
1808 c(ii+i,jbegin+j+2UL) += sum( xmm3 ) * alpha;
1809 c(ii+i,jbegin+j+3UL) += sum( xmm4 ) * alpha;
1810 c(ii+i,jbegin+j+4UL) += sum( xmm5 ) * alpha;
1811 }
1812 }
1813 }
1814 else
1815 {
1816 for( ; (j+4UL) <= jsize; j+=4UL )
1817 {
1818 if( ii+iblock < jbegin ) continue;
1819
1820 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1821
1822 for( ; (i+2UL) <= iblock; i+=2UL )
1823 {
1824 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1825
1826 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1827 {
1828 const SIMDType a1( A2.load(i ,k) );
1829 const SIMDType a2( A2.load(i+1UL,k) );
1830
1831 const SIMDType b1( B2.load(k,j ) );
1832 const SIMDType b2( B2.load(k,j+1UL) );
1833 const SIMDType b3( B2.load(k,j+2UL) );
1834 const SIMDType b4( B2.load(k,j+3UL) );
1835
1836 xmm1 += a1 * b1;
1837 xmm2 += a1 * b2;
1838 xmm3 += a1 * b3;
1839 xmm4 += a1 * b4;
1840 xmm5 += a2 * b1;
1841 xmm6 += a2 * b2;
1842 xmm7 += a2 * b3;
1843 xmm8 += a2 * b4;
1844 }
1845
1846 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
1847 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
1848 c(ii+i ,jbegin+j+2UL) += sum( xmm3 ) * alpha;
1849 c(ii+i ,jbegin+j+3UL) += sum( xmm4 ) * alpha;
1850 c(ii+i+1UL,jbegin+j ) += sum( xmm5 ) * alpha;
1851 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm6 ) * alpha;
1852 c(ii+i+1UL,jbegin+j+2UL) += sum( xmm7 ) * alpha;
1853 c(ii+i+1UL,jbegin+j+3UL) += sum( xmm8 ) * alpha;
1854 }
1855
1856 if( i<iblock )
1857 {
1858 SIMDType xmm1, xmm2, xmm3, xmm4;
1859
1860 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1861 {
1862 const SIMDType a1( A2.load(i,k) );
1863
1864 xmm1 += a1 * B2.load(k,j );
1865 xmm2 += a1 * B2.load(k,j+1UL);
1866 xmm3 += a1 * B2.load(k,j+2UL);
1867 xmm4 += a1 * B2.load(k,j+3UL);
1868 }
1869
1870 c(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
1871 c(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
1872 c(ii+i,jbegin+j+2UL) += sum( xmm3 ) * alpha;
1873 c(ii+i,jbegin+j+3UL) += sum( xmm4 ) * alpha;
1874 }
1875 }
1876 }
1877
1878 for( ; (j+2UL) <= jsize; j+=2UL )
1879 {
1880 if( ii+iblock < jbegin ) continue;
1881
1882 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1883
1884 for( ; (i+4UL) <= iblock; i+=4UL )
1885 {
1886 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
1887
1888 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1889 {
1890 const SIMDType a1( A2.load(i ,k) );
1891 const SIMDType a2( A2.load(i+1UL,k) );
1892 const SIMDType a3( A2.load(i+2UL,k) );
1893 const SIMDType a4( A2.load(i+3UL,k) );
1894
1895 const SIMDType b1( B2.load(k,j ) );
1896 const SIMDType b2( B2.load(k,j+1UL) );
1897
1898 xmm1 += a1 * b1;
1899 xmm2 += a1 * b2;
1900 xmm3 += a2 * b1;
1901 xmm4 += a2 * b2;
1902 xmm5 += a3 * b1;
1903 xmm6 += a3 * b2;
1904 xmm7 += a4 * b1;
1905 xmm8 += a4 * b2;
1906 }
1907
1908 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
1909 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
1910 c(ii+i+1UL,jbegin+j ) += sum( xmm3 ) * alpha;
1911 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm4 ) * alpha;
1912 c(ii+i+2UL,jbegin+j ) += sum( xmm5 ) * alpha;
1913 c(ii+i+2UL,jbegin+j+1UL) += sum( xmm6 ) * alpha;
1914 c(ii+i+3UL,jbegin+j ) += sum( xmm7 ) * alpha;
1915 c(ii+i+3UL,jbegin+j+1UL) += sum( xmm8 ) * alpha;
1916 }
1917
1918 for( ; (i+2UL) <= iblock; i+=2UL )
1919 {
1920 SIMDType xmm1, xmm2, xmm3, xmm4;
1921
1922 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1923 {
1924 const SIMDType a1( A2.load(i ,k) );
1925 const SIMDType a2( A2.load(i+1UL,k) );
1926
1927 const SIMDType b1( B2.load(k,j ) );
1928 const SIMDType b2( B2.load(k,j+1UL) );
1929
1930 xmm1 += a1 * b1;
1931 xmm2 += a1 * b2;
1932 xmm3 += a2 * b1;
1933 xmm4 += a2 * b2;
1934 }
1935
1936 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
1937 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
1938 c(ii+i+1UL,jbegin+j ) += sum( xmm3 ) * alpha;
1939 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm4 ) * alpha;
1940 }
1941
1942 if( i<iblock )
1943 {
1944 SIMDType xmm1, xmm2;
1945
1946 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1947 {
1948 const SIMDType a1( A2.load(i,k) );
1949
1950 xmm1 += a1 * B2.load(k,j );
1951 xmm2 += a1 * B2.load(k,j+1UL);
1952 }
1953
1954 c(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
1955 c(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
1956 }
1957 }
1958
1959 if( j<jsize && ii+iblock >= jbegin )
1960 {
1961 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
1962
1963 for( ; (i+2UL) <= iblock; i+=2UL )
1964 {
1965 SIMDType xmm1, xmm2;
1966
1967 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1968 {
1969 const SIMDType b1( B2.load(k,j) );
1970
1971 xmm1 += A2.load(i ,k) * b1;
1972 xmm2 += A2.load(i+1UL,k) * b1;
1973 }
1974
1975 c(ii+i ,jbegin+j) += sum( xmm1 ) * alpha;
1976 c(ii+i+1UL,jbegin+j) += sum( xmm2 ) * alpha;
1977 }
1978
1979 if( i<iblock )
1980 {
1981 SIMDType xmm1;
1982
1983 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
1984 {
1985 xmm1 += A2.load(i,k) * B2.load(k,j);
1986 }
1987
1988 c(ii+i,jbegin+j) += sum( xmm1 ) * alpha;
1989 }
1990 }
1991
1992 ii += iblock;
1993 }
1994
1995 kk += kblock;
1996 }
1997
1998 if( remainder && kk < K )
1999 {
2000 const size_t ksize( K - kk );
2001
2002 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
2003 const size_t jsize ( N - jbegin );
2004
2005 B2 = serial( submatrix( B, kk, jbegin, ksize, jsize, unchecked ) );
2006
2007 size_t ii( 0UL );
2008 size_t iblock( 0UL );
2009
2010 while( ii < M )
2011 {
2012 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
2013
2014 if( IsLower_v<MT2> && ii+iblock <= kk ) {
2015 ii += iblock;
2016 continue;
2017 }
2018
2019 A2 = serial( submatrix( A, ii, kk, iblock, ksize, unchecked ) );
2020
2021 size_t j( 0UL );
2022
2023 if( IsFloatingPoint_v<ET1> )
2024 {
2025 for( ; (j+5UL) <= jsize; j+=5UL )
2026 {
2027 if( ii+iblock < jbegin ) continue;
2028
2029 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2030
2031 for( ; (i+2UL) <= iblock; i+=2UL ) {
2032 for( size_t k=0UL; k<ksize; ++k ) {
2033 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2034 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2035 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
2036 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
2037 c(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
2038 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2039 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2040 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
2041 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
2042 c(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
2043 }
2044 }
2045
2046 if( i<iblock ) {
2047 for( size_t k=0UL; k<ksize; ++k ) {
2048 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2049 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2050 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
2051 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
2052 c(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
2053 }
2054 }
2055 }
2056 }
2057 else
2058 {
2059 for( ; (j+4UL) <= jsize; j+=4UL )
2060 {
2061 if( ii+iblock < jbegin ) continue;
2062
2063 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2064
2065 for( ; (i+2UL) <= iblock; i+=2UL ) {
2066 for( size_t k=0UL; k<ksize; ++k ) {
2067 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2068 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2069 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
2070 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
2071 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2072 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2073 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
2074 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
2075 }
2076 }
2077
2078 if( i<iblock ) {
2079 for( size_t k=0UL; k<ksize; ++k ) {
2080 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2081 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2082 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
2083 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
2084 }
2085 }
2086 }
2087 }
2088
2089 for( ; (j+2UL) <= jsize; j+=2UL )
2090 {
2091 if( ii+iblock < jbegin ) continue;
2092
2093 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2094
2095 for( ; (i+2UL) <= iblock; i+=2UL ) {
2096 for( size_t k=0UL; k<ksize; ++k ) {
2097 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
2098 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
2099 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
2100 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
2101 }
2102 }
2103
2104 if( i<iblock ) {
2105 for( size_t k=0UL; k<ksize; ++k ) {
2106 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
2107 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
2108 }
2109 }
2110 }
2111
2112 if( j<jsize )
2113 {
2114 if( ii+iblock < jbegin ) continue;
2115
2116 size_t i( ( ii > jbegin+j )?( 0UL ):( jbegin+j-ii ) );
2117
2118 for( ; (i+2UL) <= iblock; i+=2UL ) {
2119 for( size_t k=0UL; k<ksize; ++k ) {
2120 c(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
2121 c(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
2122 }
2123 }
2124
2125 if( i<iblock ) {
2126 for( size_t k=0UL; k<ksize; ++k ) {
2127 c(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
2128 }
2129 }
2130 }
2131
2132 ii += iblock;
2133 }
2134 }
2135 }
2136 /*! \endcond */
2137 //*************************************************************************************************
2138
2139
2140 //*************************************************************************************************
2141 /*! \cond BLAZE_INTERNAL */
2142 /*!\brief Compute kernel for a lower dense matrix/dense matrix multiplication (\f$ C=A*B \f$).
2143 // \ingroup dense_matrix
2144 //
// \param C The target left-hand side dense matrix.
// \param A The left-hand side multiplication operand.
// \param B The right-hand side multiplication operand.
// \return void
//
// This function implements the compute kernel for a lower dense matrix/dense matrix
// multiplication of the form \f$ C=A*B \f$. Both \a A and \a B must be non-expression
// dense matrix types, \a C must be a non-expression, non-adaptor dense matrix type.
// The element types of all three matrices must be SIMD combinable, i.e. must provide
// a common SIMD interface.
2155 */
2156 template< typename MT1, typename MT2, typename MT3 >
lmmm(MT1 & C,const MT2 & A,const MT3 & B)2157 inline void lmmm( MT1& C, const MT2& A, const MT3& B )
2158 {
2159 using ET1 = ElementType_t<MT1>;
2160 using ET2 = ElementType_t<MT2>;
2161 using ET3 = ElementType_t<MT3>;
2162
2163 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
2164 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );
2165
2166 lmmm( C, A, B, ET1(1), ET1(0) );
2167 }
2168 /*! \endcond */
2169 //*************************************************************************************************
2170
2171
2172
2173
2174 //=================================================================================================
2175 //
2176 // UPPER DENSE MATRIX MULTIPLICATION KERNELS
2177 //
2178 //=================================================================================================
2179
2180 //*************************************************************************************************
2181 /*! \cond BLAZE_INTERNAL */
/*!\brief Compute kernel for an upper dense matrix/dense matrix multiplication
2183 // (\f$ C=\alpha*A*B+\beta*C \f$).
2184 // \ingroup dense_matrix
2185 //
2186 // \param C The target left-hand side row-major dense matrix.
2187 // \param A The left-hand side multiplication operand.
2188 // \param B The right-hand side multiplication operand.
2189 // \param alpha The scaling factor for \f$ A*B \f$.
2190 // \param beta The scaling factor for \f$ C \f$.
2191 // \return void
2192 //
// This function implements the compute kernel for an upper dense matrix/dense matrix
2194 // multiplication of the form \f$ C=\alpha*A*B+\beta*C \f$. Both \a A and \a B must
2195 // be non-expression dense matrix types, \a C must be a non-expression, non-adaptor,
2196 // row-major dense matrix type. The element types of all three matrices must be SIMD
2197 // combinable, i.e. must provide a common SIMD interface.
2198 */
template< typename MT1, typename MT2, typename MT3, typename ST >
void ummm( DenseMatrix<MT1,false>& C, const MT2& A, const MT3& B, ST alpha, ST beta )
{
   // Computes C = alpha*A*B + beta*C for a row-major target. C is scaled by beta up
   // front, then alpha*A*B is accumulated block-wise: KBLOCK columns of A (rows of B)
   // at a time, and JBLOCK columns of B within each k-block. For every row group the
   // column loop starts at the diagonal of the group's first row, because the product
   // is upper; columns to the left of that diagonal are never touched.

   using ET1 = ElementType_t<MT1>;
   using ET2 = ElementType_t<MT2>;
   using ET3 = ElementType_t<MT3>;
   using SIMDType = SIMDTrait_t<ET1>;

   BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT1 );
   BLAZE_CONSTRAINT_MUST_BE_ROW_MAJOR_MATRIX_TYPE ( MT1 );
   BLAZE_CONSTRAINT_MUST_NOT_BE_SYMMETRIC_MATRIX_TYPE ( MT1 );
   BLAZE_CONSTRAINT_MUST_NOT_BE_HERMITIAN_MATRIX_TYPE ( MT1 );
   BLAZE_CONSTRAINT_MUST_NOT_BE_LOWER_MATRIX_TYPE ( MT1 );
   BLAZE_CONSTRAINT_MUST_NOT_BE_UNIUPPER_MATRIX_TYPE ( MT1 );
   BLAZE_CONSTRAINT_MUST_NOT_BE_STRICTLY_UPPER_MATRIX_TYPE( MT1 );
   BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE ( MT1 );

   BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT2 );
   BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT2 );

   BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT3 );
   BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT3 );

   BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
   BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );

   constexpr size_t SIMDSIZE( SIMDTrait<ET1>::size );

   // If either operand is unpadded, the SIMD section may only process k-blocks whose
   // width is a multiple of SIMDSIZE; the final K-kk < SIMDSIZE columns are then
   // handled by the scalar remainder section after the main loop.
   constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );

   constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/sizeof(ET1) ) );
   constexpr size_t JBLOCK( MMM_INNER_BLOCK_SIZE );

   BLAZE_STATIC_ASSERT( KBLOCK >= SIMDSIZE && KBLOCK % SIMDSIZE == 0UL );
   BLAZE_STATIC_ASSERT( JBLOCK >= SIMDSIZE && JBLOCK % SIMDSIZE == 0UL );

   const size_t M( A.rows() );
   const size_t N( B.columns() );
   const size_t K( A.columns() );

   BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );

   // Serial copies of the current blocks: A2 is row-major and B2 column-major, so the
   // rows of A2 and the columns of B2 are contiguous along k for the SIMD dot products.
   DynamicMatrix<ET2,false> A2( M, KBLOCK );
   DynamicMatrix<ET3,true> B2( KBLOCK, JBLOCK );

   // Unrestricted handle to C used for all element updates below.
   // NOTE(review): presumably derestrict() bypasses adaptor write checks -- confirm.
   decltype(auto) c( derestrict( *C ) );

   // Apply beta: reset C when beta is the default (zero) value, scale it unless beta is one.
   if( isDefault( beta ) ) {
      reset( c );
   }
   else if( !isOne( beta ) ) {
      c *= beta;
   }

   size_t kk( 0UL );
   size_t kblock( 0UL );

   // Main SIMD loop over k-blocks. With remainder handling it stops as soon as fewer
   // than SIMDSIZE columns of A are left.
   while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
   {
      if( remainder ) {
         kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( prevMultiple( K - kk, SIMDSIZE ) ) );
      }
      else {
         // NOTE(review): the last block may not be a multiple of SIMDSIZE here; the SIMD
         // loads then presumably rely on the zero padding of A2/B2 -- confirm.
         kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
      }

      // Only rows of A that can be nonzero in columns [kk,kk+kblock) are copied:
      // for a lower A the rows above kk are zero there, for an upper A the rows
      // from kk+kblock on are zero there.
      const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
      const size_t iend ( IsUpper_v<MT2> ? kk+kblock : M );
      const size_t isize ( iend - ibegin );

      A2 = serial( submatrix< remainder ? unaligned : aligned >( A, ibegin, kk, isize, kblock, unchecked ) );

      size_t jj( 0UL );
      size_t jblock( 0UL );

      while( jj < N )
      {
         jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );

         // Skip column blocks in which the current k-block of a triangular B is all zero.
         if( ( IsLower_v<MT3> && kk+kblock <= jj ) ||
             ( IsUpper_v<MT3> && jj+jblock <= kk ) ) {
            jj += jblock;
            continue;
         }

         B2 = serial( submatrix< remainder ? unaligned : aligned >( B, kk, jj, kblock, jblock, unchecked ) );

         size_t i( 0UL );

         // Register blocking over rows: groups of 5 rows for floating-point elements,
         // groups of 4 otherwise; leftovers are handled by the 2-row and 1-row tails.
         if( IsFloatingPoint_v<ET1> )
         {
            for( ; (i+5UL) <= isize; i+=5UL )
            {
               // Shortcut: the whole column block lies left of row ibegin. This check is
               // loop-invariant; the diagonal start below alone would also skip all work.
               if( jj+jblock < ibegin ) continue;

               // Local start column = diagonal of row ibegin+i. The other rows of the group
               // start at the same column, so a few entries just below their diagonals are
               // also computed.
               size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );

               for( ; (j+2UL) <= jblock; j+=2UL )
               {
                  // Per-entry SIMD accumulators; the code relies on default construction
                  // yielding zero vectors.
                  SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;

                  for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
                  {
                     const SIMDType a1( A2.load(i ,k) );
                     const SIMDType a2( A2.load(i+1UL,k) );
                     const SIMDType a3( A2.load(i+2UL,k) );
                     const SIMDType a4( A2.load(i+3UL,k) );
                     const SIMDType a5( A2.load(i+4UL,k) );

                     const SIMDType b1( B2.load(k,j ) );
                     const SIMDType b2( B2.load(k,j+1UL) );

                     xmm1 += a1 * b1;
                     xmm2 += a1 * b2;
                     xmm3 += a2 * b1;
                     xmm4 += a2 * b2;
                     xmm5 += a3 * b1;
                     xmm6 += a3 * b2;
                     xmm7 += a4 * b1;
                     xmm8 += a4 * b2;
                     xmm9 += a5 * b1;
                     xmm10 += a5 * b2;
                  }

                  // Horizontal sums give the dot products; scale by alpha and accumulate into C.
                  c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
                  c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
                  c(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
                  c(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
                  c(ibegin+i+2UL,jj+j ) += sum( xmm5 ) * alpha;
                  c(ibegin+i+2UL,jj+j+1UL) += sum( xmm6 ) * alpha;
                  c(ibegin+i+3UL,jj+j ) += sum( xmm7 ) * alpha;
                  c(ibegin+i+3UL,jj+j+1UL) += sum( xmm8 ) * alpha;
                  c(ibegin+i+4UL,jj+j ) += sum( xmm9 ) * alpha;
                  c(ibegin+i+4UL,jj+j+1UL) += sum( xmm10 ) * alpha;
               }

               // Single leftover column for this 5-row group
               if( j<jblock )
               {
                  SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;

                  for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
                  {
                     const SIMDType a1( A2.load(i ,k) );
                     const SIMDType a2( A2.load(i+1UL,k) );
                     const SIMDType a3( A2.load(i+2UL,k) );
                     const SIMDType a4( A2.load(i+3UL,k) );
                     const SIMDType a5( A2.load(i+4UL,k) );

                     const SIMDType b1( B2.load(k,j) );

                     xmm1 += a1 * b1;
                     xmm2 += a2 * b1;
                     xmm3 += a3 * b1;
                     xmm4 += a4 * b1;
                     xmm5 += a5 * b1;
                  }

                  c(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
                  c(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
                  c(ibegin+i+2UL,jj+j) += sum( xmm3 ) * alpha;
                  c(ibegin+i+3UL,jj+j) += sum( xmm4 ) * alpha;
                  c(ibegin+i+4UL,jj+j) += sum( xmm5 ) * alpha;
               }
            }
         }
         else
         {
            // Same scheme as above with groups of 4 rows (non-floating-point elements)
            for( ; (i+4UL) <= isize; i+=4UL )
            {
               if( jj+jblock < ibegin ) continue;

               size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );

               for( ; (j+2UL) <= jblock; j+=2UL )
               {
                  SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;

                  for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
                  {
                     const SIMDType a1( A2.load(i ,k) );
                     const SIMDType a2( A2.load(i+1UL,k) );
                     const SIMDType a3( A2.load(i+2UL,k) );
                     const SIMDType a4( A2.load(i+3UL,k) );

                     const SIMDType b1( B2.load(k,j ) );
                     const SIMDType b2( B2.load(k,j+1UL) );

                     xmm1 += a1 * b1;
                     xmm2 += a1 * b2;
                     xmm3 += a2 * b1;
                     xmm4 += a2 * b2;
                     xmm5 += a3 * b1;
                     xmm6 += a3 * b2;
                     xmm7 += a4 * b1;
                     xmm8 += a4 * b2;
                  }

                  c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
                  c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
                  c(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
                  c(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
                  c(ibegin+i+2UL,jj+j ) += sum( xmm5 ) * alpha;
                  c(ibegin+i+2UL,jj+j+1UL) += sum( xmm6 ) * alpha;
                  c(ibegin+i+3UL,jj+j ) += sum( xmm7 ) * alpha;
                  c(ibegin+i+3UL,jj+j+1UL) += sum( xmm8 ) * alpha;
               }

               if( j<jblock )
               {
                  SIMDType xmm1, xmm2, xmm3, xmm4;

                  for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
                  {
                     const SIMDType a1( A2.load(i ,k) );
                     const SIMDType a2( A2.load(i+1UL,k) );
                     const SIMDType a3( A2.load(i+2UL,k) );
                     const SIMDType a4( A2.load(i+3UL,k) );

                     const SIMDType b1( B2.load(k,j) );

                     xmm1 += a1 * b1;
                     xmm2 += a2 * b1;
                     xmm3 += a3 * b1;
                     xmm4 += a4 * b1;
                  }

                  c(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
                  c(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
                  c(ibegin+i+2UL,jj+j) += sum( xmm3 ) * alpha;
                  c(ibegin+i+3UL,jj+j) += sum( xmm4 ) * alpha;
               }
            }
         }

         // Remaining pairs of rows: 4, then 2, then 1 column(s) at a time
         for( ; (i+2UL) <= isize; i+=2UL )
         {
            if( jj+jblock < ibegin ) continue;

            size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );

            for( ; (j+4UL) <= jblock; j+=4UL )
            {
               SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;

               for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
               {
                  const SIMDType a1( A2.load(i ,k) );
                  const SIMDType a2( A2.load(i+1UL,k) );

                  const SIMDType b1( B2.load(k,j ) );
                  const SIMDType b2( B2.load(k,j+1UL) );
                  const SIMDType b3( B2.load(k,j+2UL) );
                  const SIMDType b4( B2.load(k,j+3UL) );

                  xmm1 += a1 * b1;
                  xmm2 += a1 * b2;
                  xmm3 += a1 * b3;
                  xmm4 += a1 * b4;
                  xmm5 += a2 * b1;
                  xmm6 += a2 * b2;
                  xmm7 += a2 * b3;
                  xmm8 += a2 * b4;
               }

               c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
               c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
               c(ibegin+i ,jj+j+2UL) += sum( xmm3 ) * alpha;
               c(ibegin+i ,jj+j+3UL) += sum( xmm4 ) * alpha;
               c(ibegin+i+1UL,jj+j ) += sum( xmm5 ) * alpha;
               c(ibegin+i+1UL,jj+j+1UL) += sum( xmm6 ) * alpha;
               c(ibegin+i+1UL,jj+j+2UL) += sum( xmm7 ) * alpha;
               c(ibegin+i+1UL,jj+j+3UL) += sum( xmm8 ) * alpha;
            }

            for( ; (j+2UL) <= jblock; j+=2UL )
            {
               SIMDType xmm1, xmm2, xmm3, xmm4;

               for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
               {
                  const SIMDType a1( A2.load(i ,k) );
                  const SIMDType a2( A2.load(i+1UL,k) );

                  const SIMDType b1( B2.load(k,j ) );
                  const SIMDType b2( B2.load(k,j+1UL) );

                  xmm1 += a1 * b1;
                  xmm2 += a1 * b2;
                  xmm3 += a2 * b1;
                  xmm4 += a2 * b2;
               }

               c(ibegin+i ,jj+j ) += sum( xmm1 ) * alpha;
               c(ibegin+i ,jj+j+1UL) += sum( xmm2 ) * alpha;
               c(ibegin+i+1UL,jj+j ) += sum( xmm3 ) * alpha;
               c(ibegin+i+1UL,jj+j+1UL) += sum( xmm4 ) * alpha;
            }

            if( j<jblock )
            {
               SIMDType xmm1, xmm2;

               for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
               {
                  const SIMDType a1( A2.load(i ,k) );
                  const SIMDType a2( A2.load(i+1UL,k) );

                  const SIMDType b1( B2.load(k,j) );

                  xmm1 += a1 * b1;
                  xmm2 += a2 * b1;
               }

               c(ibegin+i ,jj+j) += sum( xmm1 ) * alpha;
               c(ibegin+i+1UL,jj+j) += sum( xmm2 ) * alpha;
            }
         }

         // Last single row of the k-block, if any
         if( i<isize && jj+jblock >= ibegin )
         {
            size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );

            for( ; (j+2UL) <= jblock; j+=2UL )
            {
               SIMDType xmm1, xmm2;

               for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
               {
                  const SIMDType a1( A2.load(i,k) );

                  xmm1 += a1 * B2.load(k,j );
                  xmm2 += a1 * B2.load(k,j+1UL);
               }

               c(ibegin+i,jj+j ) += sum( xmm1 ) * alpha;
               c(ibegin+i,jj+j+1UL) += sum( xmm2 ) * alpha;
            }

            if( j<jblock )
            {
               SIMDType xmm1;

               for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
               {
                  const SIMDType a1( A2.load(i,k) );

                  xmm1 += a1 * B2.load(k,j);
               }

               c(ibegin+i,jj+j) += sum( xmm1 ) * alpha;
            }
         }

         jj += jblock;
      }

      kk += kblock;
   }

   // Scalar tail: the last K-kk (< SIMDSIZE) columns of A / rows of B that the SIMD
   // loop left over because an operand is unpadded. The row/column blocking mirrors
   // the SIMD section; note that rows are not bounded above here for an upper A.
   if( remainder && kk < K )
   {
      const size_t ksize( K - kk );

      const size_t ibegin( IsLower_v<MT2> ? kk : 0UL );
      const size_t isize ( M - ibegin );

      A2 = serial( submatrix( A, ibegin, kk, isize, ksize, unchecked ) );

      size_t jj( 0UL );
      size_t jblock( 0UL );

      while( jj < N )
      {
         jblock = ( ( jj+JBLOCK <= N )?( JBLOCK ):( N - jj ) );

         // For an upper B, columns left of kk are zero in rows [kk,K)
         if( IsUpper_v<MT3> && jj+jblock <= kk ) {
            jj += jblock;
            continue;
         }

         B2 = serial( submatrix( B, kk, jj, ksize, jblock, unchecked ) );

         size_t i( 0UL );

         if( IsFloatingPoint_v<ET1> )
         {
            for( ; (i+5UL) <= isize; i+=5UL )
            {
               if( jj+jblock < ibegin ) continue;

               size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );

               for( ; (j+2UL) <= jblock; j+=2UL ) {
                  for( size_t k=0UL; k<ksize; ++k ) {
                     c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
                     c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
                     c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
                     c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
                     c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
                     c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
                     c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
                     c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
                     c(ibegin+i+4UL,jj+j ) += A2(i+4UL,k) * B2(k,j ) * alpha;
                     c(ibegin+i+4UL,jj+j+1UL) += A2(i+4UL,k) * B2(k,j+1UL) * alpha;
                  }
               }

               if( j<jblock ) {
                  for( size_t k=0UL; k<ksize; ++k ) {
                     c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
                     c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
                     c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
                     c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
                     c(ibegin+i+4UL,jj+j) += A2(i+4UL,k) * B2(k,j) * alpha;
                  }
               }
            }
         }
         else
         {
            for( ; (i+4UL) <= isize; i+=4UL )
            {
               if( jj+jblock < ibegin ) continue;

               size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );

               for( ; (j+2UL) <= jblock; j+=2UL ) {
                  for( size_t k=0UL; k<ksize; ++k ) {
                     c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
                     c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
                     c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
                     c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
                     c(ibegin+i+2UL,jj+j ) += A2(i+2UL,k) * B2(k,j ) * alpha;
                     c(ibegin+i+2UL,jj+j+1UL) += A2(i+2UL,k) * B2(k,j+1UL) * alpha;
                     c(ibegin+i+3UL,jj+j ) += A2(i+3UL,k) * B2(k,j ) * alpha;
                     c(ibegin+i+3UL,jj+j+1UL) += A2(i+3UL,k) * B2(k,j+1UL) * alpha;
                  }
               }

               if( j<jblock ) {
                  for( size_t k=0UL; k<ksize; ++k ) {
                     c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
                     c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
                     c(ibegin+i+2UL,jj+j) += A2(i+2UL,k) * B2(k,j) * alpha;
                     c(ibegin+i+3UL,jj+j) += A2(i+3UL,k) * B2(k,j) * alpha;
                  }
               }
            }
         }

         for( ; (i+2UL) <= isize; i+=2UL )
         {
            if( jj+jblock < ibegin ) continue;

            size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );

            for( ; (j+2UL) <= jblock; j+=2UL ) {
               for( size_t k=0UL; k<ksize; ++k ) {
                  c(ibegin+i ,jj+j ) += A2(i ,k) * B2(k,j ) * alpha;
                  c(ibegin+i ,jj+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
                  c(ibegin+i+1UL,jj+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
                  c(ibegin+i+1UL,jj+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
               }
            }

            if( j<jblock ) {
               for( size_t k=0UL; k<ksize; ++k ) {
                  c(ibegin+i ,jj+j) += A2(i ,k) * B2(k,j) * alpha;
                  c(ibegin+i+1UL,jj+j) += A2(i+1UL,k) * B2(k,j) * alpha;
               }
            }
         }

         if( i<isize && jj+jblock >= ibegin )
         {
            size_t j( ( jj > ibegin+i )?( 0UL ):( ibegin+i-jj ) );

            for( ; (j+2UL) <= jblock; j+=2UL ) {
               for( size_t k=0UL; k<ksize; ++k ) {
                  c(ibegin+i,jj+j ) += A2(i,k) * B2(k,j ) * alpha;
                  c(ibegin+i,jj+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
               }
            }

            if( j<jblock ) {
               for( size_t k=0UL; k<ksize; ++k ) {
                  c(ibegin+i,jj+j) += A2(i,k) * B2(k,j) * alpha;
               }
            }
         }

         jj += jblock;
      }
   }
}
2694 /*! \endcond */
2695 //*************************************************************************************************
2696
2697
2698 //*************************************************************************************************
2699 /*! \cond BLAZE_INTERNAL */
/*!\brief Compute kernel for an upper dense matrix/dense matrix multiplication
2701 // (\f$ C=\alpha*A*B+\beta*C \f$).
2702 // \ingroup dense_matrix
2703 //
2704 // \param C The target left-hand side column-major dense matrix.
2705 // \param A The left-hand side multiplication operand.
2706 // \param B The right-hand side multiplication operand.
2707 // \param alpha The scaling factor for \f$ A*B \f$.
2708 // \param beta The scaling factor for \f$ C \f$.
2709 // \return void
2710 //
// This function implements the compute kernel for an upper dense matrix/dense matrix
2712 // multiplication of the form \f$ C=\alpha*A*B+\beta*C \f$. Both \a A and \a B must
2713 // be non-expression dense matrix types, \a C must be a non-expression, non-adaptor,
2714 // column-major dense matrix type. The element types of all three matrices must be SIMD
2715 // combinable, i.e. must provide a common SIMD interface.
2716 */
2717 template< typename MT1, typename MT2, typename MT3, typename ST >
ummm(DenseMatrix<MT1,true> & C,const MT2 & A,const MT3 & B,ST alpha,ST beta)2718 void ummm( DenseMatrix<MT1,true>& C, const MT2& A, const MT3& B, ST alpha, ST beta )
2719 {
2720 using ET1 = ElementType_t<MT1>;
2721 using ET2 = ElementType_t<MT2>;
2722 using ET3 = ElementType_t<MT3>;
2723 using SIMDType = SIMDTrait_t<ET1>;
2724
2725 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT1 );
2726 BLAZE_CONSTRAINT_MUST_BE_COLUMN_MAJOR_MATRIX_TYPE ( MT1 );
2727 BLAZE_CONSTRAINT_MUST_NOT_BE_SYMMETRIC_MATRIX_TYPE ( MT1 );
2728 BLAZE_CONSTRAINT_MUST_NOT_BE_HERMITIAN_MATRIX_TYPE ( MT1 );
2729 BLAZE_CONSTRAINT_MUST_NOT_BE_LOWER_MATRIX_TYPE ( MT1 );
2730 BLAZE_CONSTRAINT_MUST_NOT_BE_UNIUPPER_MATRIX_TYPE ( MT1 );
2731 BLAZE_CONSTRAINT_MUST_NOT_BE_STRICTLY_UPPER_MATRIX_TYPE( MT1 );
2732 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE ( MT1 );
2733
2734 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT2 );
2735 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT2 );
2736
2737 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT3 );
2738 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT3 );
2739
2740 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
2741 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );
2742
2743 constexpr size_t SIMDSIZE( SIMDTrait<ET1>::size );
2744
2745 constexpr bool remainder( !IsPadded_v<MT2> || !IsPadded_v<MT3> );
2746
2747 constexpr size_t KBLOCK( MMM_OUTER_BLOCK_SIZE * ( 16UL/sizeof(ET1) ) );
2748 constexpr size_t IBLOCK( MMM_INNER_BLOCK_SIZE );
2749
2750 BLAZE_STATIC_ASSERT( KBLOCK >= SIMDSIZE && KBLOCK % SIMDSIZE == 0UL );
2751 BLAZE_STATIC_ASSERT( IBLOCK >= SIMDSIZE && IBLOCK % SIMDSIZE == 0UL );
2752
2753 const size_t M( A.rows() );
2754 const size_t N( B.columns() );
2755 const size_t K( A.columns() );
2756
2757 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
2758
2759 DynamicMatrix<ET2,false> A2( IBLOCK, KBLOCK );
2760 DynamicMatrix<ET3,true> B2( KBLOCK, N );
2761
2762 decltype(auto) c( derestrict( *C ) );
2763
2764 if( isDefault( beta ) ) {
2765 reset( c );
2766 }
2767 else if( !isOne( beta ) ) {
2768 c *= beta;
2769 }
2770
2771 size_t kk( 0UL );
2772 size_t kblock( 0UL );
2773
2774 while( kk + ( remainder ? SIMDSIZE-1UL : 0UL ) < K )
2775 {
2776 if( remainder ) {
2777 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( prevMultiple( K - kk, SIMDSIZE ) ) );
2778 }
2779 else {
2780 kblock = ( ( kk+KBLOCK <= K )?( KBLOCK ):( K - kk ) );
2781 }
2782
2783 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
2784 const size_t jend ( IsLower_v<MT3> ? kk+kblock : N );
2785 const size_t jsize ( jend - jbegin );
2786
2787 B2 = serial( submatrix< remainder ? unaligned : aligned >( B, kk, jbegin, kblock, jsize, unchecked ) );
2788
2789 size_t ii( 0UL );
2790 size_t iblock( 0UL );
2791
2792 while( ii < M )
2793 {
2794 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
2795
2796 if( ( IsLower_v<MT2> && ii+iblock <= kk ) ||
2797 ( IsUpper_v<MT2> && kk+kblock <= ii ) ) {
2798 ii += iblock;
2799 continue;
2800 }
2801
2802 A2 = serial( submatrix< remainder ? unaligned : aligned >( A, ii, kk, iblock, kblock, unchecked ) );
2803
2804 size_t j( 0UL );
2805
2806 if( IsFloatingPoint_v<ET3> )
2807 {
2808 for( ; (j+5UL) <= jsize; j+=5UL )
2809 {
2810 if( ii > jbegin+j+4UL ) continue;
2811
2812 const size_t iend( min( iblock, jbegin+j-ii+5UL ) );
2813 size_t i( 0UL );
2814
2815 for( ; (i+2UL) <= iend; i+=2UL )
2816 {
2817 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8, xmm9, xmm10;
2818
2819 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2820 {
2821 const SIMDType a1( A2.load(i ,k) );
2822 const SIMDType a2( A2.load(i+1UL,k) );
2823
2824 const SIMDType b1( B2.load(k,j ) );
2825 const SIMDType b2( B2.load(k,j+1UL) );
2826 const SIMDType b3( B2.load(k,j+2UL) );
2827 const SIMDType b4( B2.load(k,j+3UL) );
2828 const SIMDType b5( B2.load(k,j+4UL) );
2829
2830 xmm1 += a1 * b1;
2831 xmm2 += a1 * b2;
2832 xmm3 += a1 * b3;
2833 xmm4 += a1 * b4;
2834 xmm5 += a1 * b5;
2835 xmm6 += a2 * b1;
2836 xmm7 += a2 * b2;
2837 xmm8 += a2 * b3;
2838 xmm9 += a2 * b4;
2839 xmm10 += a2 * b5;
2840 }
2841
2842 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
2843 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
2844 c(ii+i ,jbegin+j+2UL) += sum( xmm3 ) * alpha;
2845 c(ii+i ,jbegin+j+3UL) += sum( xmm4 ) * alpha;
2846 c(ii+i ,jbegin+j+4UL) += sum( xmm5 ) * alpha;
2847 c(ii+i+1UL,jbegin+j ) += sum( xmm6 ) * alpha;
2848 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm7 ) * alpha;
2849 c(ii+i+1UL,jbegin+j+2UL) += sum( xmm8 ) * alpha;
2850 c(ii+i+1UL,jbegin+j+3UL) += sum( xmm9 ) * alpha;
2851 c(ii+i+1UL,jbegin+j+4UL) += sum( xmm10 ) * alpha;
2852 }
2853
2854 if( i<iend )
2855 {
2856 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5;
2857
2858 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2859 {
2860 const SIMDType a1( A2.load(i,k) );
2861
2862 xmm1 += a1 * B2.load(k,j );
2863 xmm2 += a1 * B2.load(k,j+1UL);
2864 xmm3 += a1 * B2.load(k,j+2UL);
2865 xmm4 += a1 * B2.load(k,j+3UL);
2866 xmm5 += a1 * B2.load(k,j+4UL);
2867 }
2868
2869 c(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
2870 c(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
2871 c(ii+i,jbegin+j+2UL) += sum( xmm3 ) * alpha;
2872 c(ii+i,jbegin+j+3UL) += sum( xmm4 ) * alpha;
2873 c(ii+i,jbegin+j+4UL) += sum( xmm5 ) * alpha;
2874 }
2875 }
2876 }
2877 else
2878 {
2879 for( ; (j+4UL) <= jsize; j+=4UL )
2880 {
2881 if( ii > jbegin+j+3UL ) continue;
2882
2883 const size_t iend( min( iblock, jbegin+j-ii+4UL ) );
2884 size_t i( 0UL );
2885
2886 for( ; (i+2UL) <= iend; i+=2UL )
2887 {
2888 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2889
2890 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2891 {
2892 const SIMDType a1( A2.load(i ,k) );
2893 const SIMDType a2( A2.load(i+1UL,k) );
2894
2895 const SIMDType b1( B2.load(k,j ) );
2896 const SIMDType b2( B2.load(k,j+1UL) );
2897 const SIMDType b3( B2.load(k,j+2UL) );
2898 const SIMDType b4( B2.load(k,j+3UL) );
2899
2900 xmm1 += a1 * b1;
2901 xmm2 += a1 * b2;
2902 xmm3 += a1 * b3;
2903 xmm4 += a1 * b4;
2904 xmm5 += a2 * b1;
2905 xmm6 += a2 * b2;
2906 xmm7 += a2 * b3;
2907 xmm8 += a2 * b4;
2908 }
2909
2910 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
2911 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
2912 c(ii+i ,jbegin+j+2UL) += sum( xmm3 ) * alpha;
2913 c(ii+i ,jbegin+j+3UL) += sum( xmm4 ) * alpha;
2914 c(ii+i+1UL,jbegin+j ) += sum( xmm5 ) * alpha;
2915 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm6 ) * alpha;
2916 c(ii+i+1UL,jbegin+j+2UL) += sum( xmm7 ) * alpha;
2917 c(ii+i+1UL,jbegin+j+3UL) += sum( xmm8 ) * alpha;
2918 }
2919
2920 if( i<iend )
2921 {
2922 SIMDType xmm1, xmm2, xmm3, xmm4;
2923
2924 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2925 {
2926 const SIMDType a1( A2.load(i,k) );
2927
2928 xmm1 += a1 * B2.load(k,j );
2929 xmm2 += a1 * B2.load(k,j+1UL);
2930 xmm3 += a1 * B2.load(k,j+2UL);
2931 xmm4 += a1 * B2.load(k,j+3UL);
2932 }
2933
2934 c(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
2935 c(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
2936 c(ii+i,jbegin+j+2UL) += sum( xmm3 ) * alpha;
2937 c(ii+i,jbegin+j+3UL) += sum( xmm4 ) * alpha;
2938 }
2939 }
2940 }
2941
2942 for( ; (j+2UL) <= jsize; j+=2UL )
2943 {
2944 if( ii > jbegin+j+1UL ) continue;
2945
2946 const size_t iend( min( iblock, jbegin+j-ii+2UL ) );
2947 size_t i( 0UL );
2948
2949 for( ; (i+4UL) <= iend; i+=4UL )
2950 {
2951 SIMDType xmm1, xmm2, xmm3, xmm4, xmm5, xmm6, xmm7, xmm8;
2952
2953 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2954 {
2955 const SIMDType a1( A2.load(i ,k) );
2956 const SIMDType a2( A2.load(i+1UL,k) );
2957 const SIMDType a3( A2.load(i+2UL,k) );
2958 const SIMDType a4( A2.load(i+3UL,k) );
2959
2960 const SIMDType b1( B2.load(k,j ) );
2961 const SIMDType b2( B2.load(k,j+1UL) );
2962
2963 xmm1 += a1 * b1;
2964 xmm2 += a1 * b2;
2965 xmm3 += a2 * b1;
2966 xmm4 += a2 * b2;
2967 xmm5 += a3 * b1;
2968 xmm6 += a3 * b2;
2969 xmm7 += a4 * b1;
2970 xmm8 += a4 * b2;
2971 }
2972
2973 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
2974 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
2975 c(ii+i+1UL,jbegin+j ) += sum( xmm3 ) * alpha;
2976 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm4 ) * alpha;
2977 c(ii+i+2UL,jbegin+j ) += sum( xmm5 ) * alpha;
2978 c(ii+i+2UL,jbegin+j+1UL) += sum( xmm6 ) * alpha;
2979 c(ii+i+3UL,jbegin+j ) += sum( xmm7 ) * alpha;
2980 c(ii+i+3UL,jbegin+j+1UL) += sum( xmm8 ) * alpha;
2981 }
2982
2983 for( ; (i+2UL) <= iend; i+=2UL )
2984 {
2985 SIMDType xmm1, xmm2, xmm3, xmm4;
2986
2987 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
2988 {
2989 const SIMDType a1( A2.load(i ,k) );
2990 const SIMDType a2( A2.load(i+1UL,k) );
2991
2992 const SIMDType b1( B2.load(k,j ) );
2993 const SIMDType b2( B2.load(k,j+1UL) );
2994
2995 xmm1 += a1 * b1;
2996 xmm2 += a1 * b2;
2997 xmm3 += a2 * b1;
2998 xmm4 += a2 * b2;
2999 }
3000
3001 c(ii+i ,jbegin+j ) += sum( xmm1 ) * alpha;
3002 c(ii+i ,jbegin+j+1UL) += sum( xmm2 ) * alpha;
3003 c(ii+i+1UL,jbegin+j ) += sum( xmm3 ) * alpha;
3004 c(ii+i+1UL,jbegin+j+1UL) += sum( xmm4 ) * alpha;
3005 }
3006
3007 if( i<iend )
3008 {
3009 SIMDType xmm1, xmm2;
3010
3011 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
3012 {
3013 const SIMDType a1( A2.load(i,k) );
3014
3015 xmm1 += a1 * B2.load(k,j );
3016 xmm2 += a1 * B2.load(k,j+1UL);
3017 }
3018
3019 c(ii+i,jbegin+j ) += sum( xmm1 ) * alpha;
3020 c(ii+i,jbegin+j+1UL) += sum( xmm2 ) * alpha;
3021 }
3022 }
3023
3024 if( j<jsize && ii <= jbegin+j )
3025 {
3026 const size_t iend( min( iblock, jbegin+j-ii+2UL ) );
3027 size_t i( 0UL );
3028
3029 for( ; (i+2UL) <= iend; i+=2UL )
3030 {
3031 SIMDType xmm1, xmm2;
3032
3033 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
3034 {
3035 const SIMDType b1( B2.load(k,j) );
3036
3037 xmm1 += A2.load(i ,k) * b1;
3038 xmm2 += A2.load(i+1UL,k) * b1;
3039 }
3040
3041 c(ii+i ,jbegin+j) += sum( xmm1 ) * alpha;
3042 c(ii+i+1UL,jbegin+j) += sum( xmm2 ) * alpha;
3043 }
3044
3045 if( i<iend )
3046 {
3047 SIMDType xmm1;
3048
3049 for( size_t k=0UL; k<kblock; k+=SIMDSIZE )
3050 {
3051 xmm1 += A2.load(i,k) * B2.load(k,j);
3052 }
3053
3054 c(ii+i,jbegin+j) += sum( xmm1 ) * alpha;
3055 }
3056 }
3057
3058 ii += iblock;
3059 }
3060
3061 kk += kblock;
3062 }
3063
3064 if( remainder && kk < K )
3065 {
3066 const size_t ksize( K - kk );
3067
3068 const size_t jbegin( IsUpper_v<MT3> ? kk : 0UL );
3069 const size_t jsize ( N - jbegin );
3070
3071 B2 = serial( submatrix( B, kk, jbegin, ksize, jsize, unchecked ) );
3072
3073 size_t ii( 0UL );
3074 size_t iblock( 0UL );
3075
3076 while( ii < M )
3077 {
3078 iblock = ( ( ii+IBLOCK <= M )?( IBLOCK ):( M - ii ) );
3079
3080 if( IsLower_v<MT2> && ii+iblock <= kk ) {
3081 ii += iblock;
3082 continue;
3083 }
3084
3085 A2 = serial( submatrix( A, ii, kk, iblock, ksize, unchecked ) );
3086
3087 size_t j( 0UL );
3088
3089 if( IsFloatingPoint_v<ET1> )
3090 {
3091 for( ; (j+5UL) <= jsize; j+=5UL )
3092 {
3093 if( ii > jbegin+j+4UL ) continue;
3094
3095 const size_t iend( min( iblock, jbegin+j-ii+5UL ) );
3096 size_t i( 0UL );
3097
3098 for( ; (i+2UL) <= iend; i+=2UL ) {
3099 for( size_t k=0UL; k<ksize; ++k ) {
3100 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3101 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3102 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
3103 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
3104 c(ii+i ,jbegin+j+4UL) += A2(i ,k) * B2(k,j+4UL) * alpha;
3105 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3106 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3107 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
3108 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
3109 c(ii+i+1UL,jbegin+j+4UL) += A2(i+1UL,k) * B2(k,j+4UL) * alpha;
3110 }
3111 }
3112
3113 if( i<iend ) {
3114 for( size_t k=0UL; k<ksize; ++k ) {
3115 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3116 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3117 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
3118 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
3119 c(ii+i,jbegin+j+4UL) += A2(i,k) * B2(k,j+4UL) * alpha;
3120 }
3121 }
3122 }
3123 }
3124 else
3125 {
3126 for( ; (j+4UL) <= jsize; j+=4UL )
3127 {
3128 if( ii > jbegin+j+3UL ) continue;
3129
3130 const size_t iend( min( iblock, jbegin+j-ii+4UL ) );
3131 size_t i( 0UL );
3132
3133 for( ; (i+2UL) <= iend; i+=2UL ) {
3134 for( size_t k=0UL; k<ksize; ++k ) {
3135 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3136 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3137 c(ii+i ,jbegin+j+2UL) += A2(i ,k) * B2(k,j+2UL) * alpha;
3138 c(ii+i ,jbegin+j+3UL) += A2(i ,k) * B2(k,j+3UL) * alpha;
3139 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3140 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3141 c(ii+i+1UL,jbegin+j+2UL) += A2(i+1UL,k) * B2(k,j+2UL) * alpha;
3142 c(ii+i+1UL,jbegin+j+3UL) += A2(i+1UL,k) * B2(k,j+3UL) * alpha;
3143 }
3144 }
3145
3146 if( i<iend ) {
3147 for( size_t k=0UL; k<ksize; ++k ) {
3148 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3149 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3150 c(ii+i,jbegin+j+2UL) += A2(i,k) * B2(k,j+2UL) * alpha;
3151 c(ii+i,jbegin+j+3UL) += A2(i,k) * B2(k,j+3UL) * alpha;
3152 }
3153 }
3154 }
3155 }
3156
3157 for( ; (j+2UL) <= jsize; j+=2UL )
3158 {
3159 if( ii > jbegin+j+1UL ) continue;
3160
3161 const size_t iend( min( iblock, jbegin+j-ii+2UL ) );
3162 size_t i( 0UL );
3163
3164 for( ; (i+2UL) <= iend; i+=2UL ) {
3165 for( size_t k=0UL; k<ksize; ++k ) {
3166 c(ii+i ,jbegin+j ) += A2(i ,k) * B2(k,j ) * alpha;
3167 c(ii+i ,jbegin+j+1UL) += A2(i ,k) * B2(k,j+1UL) * alpha;
3168 c(ii+i+1UL,jbegin+j ) += A2(i+1UL,k) * B2(k,j ) * alpha;
3169 c(ii+i+1UL,jbegin+j+1UL) += A2(i+1UL,k) * B2(k,j+1UL) * alpha;
3170 }
3171 }
3172
3173 if( i<iend ) {
3174 for( size_t k=0UL; k<ksize; ++k ) {
3175 c(ii+i,jbegin+j ) += A2(i,k) * B2(k,j ) * alpha;
3176 c(ii+i,jbegin+j+1UL) += A2(i,k) * B2(k,j+1UL) * alpha;
3177 }
3178 }
3179 }
3180
3181 if( j<jsize && ii <= jbegin+j )
3182 {
3183 const size_t iend( min( iblock, jbegin+j-ii+2UL ) );
3184 size_t i( 0UL );
3185
3186 for( ; (i+2UL) <= iend; i+=2UL ) {
3187 for( size_t k=0UL; k<ksize; ++k ) {
3188 c(ii+i ,jbegin+j) += A2(i ,k) * B2(k,j) * alpha;
3189 c(ii+i+1UL,jbegin+j) += A2(i+1UL,k) * B2(k,j) * alpha;
3190 }
3191 }
3192
3193 if( i<iend ) {
3194 for( size_t k=0UL; k<ksize; ++k ) {
3195 c(ii+i,jbegin+j) += A2(i,k) * B2(k,j) * alpha;
3196 }
3197 }
3198 }
3199
3200 ii += iblock;
3201 }
3202 }
3203 }
3204 /*! \endcond */
3205 //*************************************************************************************************
3206
3207
3208 //*************************************************************************************************
3209 /*! \cond BLAZE_INTERNAL */
3210 /*!\brief Compute kernel for a upper dense matrix/dense matrix multiplication (\f$ C=A*B \f$).
3211 // \ingroup dense_matrix
3212 //
3213 // \param C The target left-hand side column-major dense matrix.
3214 // \param A The left-hand side multiplication operand.
3215 // \param B The right-hand side multiplication operand.
3216 // \return void
3217 //
3218 // This function implements the compute kernel for a upper dense matrix/dense matrix
3219 // multiplication of the form \f$ C=A*B \f$. Both \a A and \a B must be non-expression
3220 // dense matrix types, \a C must be a non-expression, non-adaptor, row-major dense matrix
3221 // type. The element types of all three matrices must be SIMD combinable, i.e. must
3222 // provide a common SIMD interface.
3223 */
3224 template< typename MT1, typename MT2, typename MT3 >
ummm(MT1 & C,const MT2 & A,const MT3 & B)3225 inline void ummm( MT1& C, const MT2& A, const MT3& B )
3226 {
3227 using ET1 = ElementType_t<MT1>;
3228 using ET2 = ElementType_t<MT2>;
3229 using ET3 = ElementType_t<MT3>;
3230
3231 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
3232 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );
3233
3234 ummm( C, A, B, ET1(1), ET1(0) );
3235 }
3236 /*! \endcond */
3237 //*************************************************************************************************
3238
3239
3240
3241
3242 //=================================================================================================
3243 //
3244 // SYMMETRIC DENSE MATRIX MULTIPLICATION KERNELS
3245 //
3246 //=================================================================================================
3247
3248 //*************************************************************************************************
3249 /*! \cond BLAZE_INTERNAL */
3250 /*!\brief Compute kernel for a symmetric dense matrix/dense matrix multiplication
3251 // (\f$ C=\alpha*A*B \f$).
3252 // \ingroup dense_matrix
3253 //
3254 // \param C The target left-hand side row-major dense matrix.
3255 // \param A The left-hand side multiplication operand.
3256 // \param B The right-hand side multiplication operand.
3257 // \param alpha The scaling factor for \f$ A*B \f$.
3258 // \return void
3259 //
3260 // This function implements the compute kernel for a symmetric dense matrix/dense matrix
3261 // multiplication of the form \f$ C=\alpha*A*B \f$. Both \a A and \a B must be non-expression
3262 // dense matrix types, \a C must be a non-expression, non-adaptor, row-major dense matrix type.
3263 // The element types of all three matrices must be SIMD combinable, i.e. must provide a common
3264 // SIMD interface.
3265 */
3266 template< typename MT1, typename MT2, typename MT3, typename ST >
smmm(DenseMatrix<MT1,false> & C,const MT2 & A,const MT3 & B,ST alpha)3267 void smmm( DenseMatrix<MT1,false>& C, const MT2& A, const MT3& B, ST alpha )
3268 {
3269 using ET1 = ElementType_t<MT1>;
3270 using ET2 = ElementType_t<MT2>;
3271 using ET3 = ElementType_t<MT3>;
3272
3273 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT1 );
3274 BLAZE_CONSTRAINT_MUST_BE_ROW_MAJOR_MATRIX_TYPE( MT1 );
3275 BLAZE_CONSTRAINT_MUST_NOT_BE_ADAPTOR_TYPE ( MT1 );
3276 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE ( MT1 );
3277
3278 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT2 );
3279 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT2 );
3280
3281 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT3 );
3282 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT3 );
3283
3284 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
3285 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );
3286
3287 const size_t M( A.rows() );
3288 const size_t N( B.columns() );
3289
3290 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
3291
3292 lmmm( C, A, B, alpha, ST(0) );
3293
3294 for( size_t ii=0UL; ii<M; ii+=BLOCK_SIZE )
3295 {
3296 const size_t iend( min( M, ii+BLOCK_SIZE ) );
3297
3298 for( size_t i=ii; i<iend; ++i ) {
3299 for( size_t j=i+1UL; j<iend; ++j ) {
3300 (*C)(i,j) = (*C)(j,i);
3301 }
3302 }
3303
3304 for( size_t jj=ii+BLOCK_SIZE; jj<N; jj+=BLOCK_SIZE ) {
3305 const size_t jend( min( N, jj+BLOCK_SIZE ) );
3306 for( size_t i=ii; i<iend; ++i ) {
3307 for( size_t j=jj; j<jend; ++j ) {
3308 (*C)(i,j) = (*C)(j,i);
3309 }
3310 }
3311 }
3312 }
3313 }
3314 /*! \endcond */
3315 //*************************************************************************************************
3316
3317
3318 //*************************************************************************************************
3319 /*! \cond BLAZE_INTERNAL */
3320 /*!\brief Compute kernel for a symmetric dense matrix/dense matrix multiplication
3321 // (\f$ C=\alpha*A*B \f$).
3322 // \ingroup dense_matrix
3323 //
3324 // \param C The target left-hand side column-major dense matrix.
3325 // \param A The left-hand side multiplication operand.
3326 // \param B The right-hand side multiplication operand.
3327 // \param alpha The scaling factor for \f$ A*B \f$.
3328 // \return void
3329 //
3330 // This function implements the compute kernel for a symmetric dense matrix/dense matrix
3331 // multiplication of the form \f$ C=\alpha*A*B \f$. Both \a A and \a B must be non-expression
3332 // dense matrix types, \a C must be a non-expression, non-adaptor, column-major dense matrix
3333 // type. The element types of all three matrices must be SIMD combinable, i.e. must provide
3334 // a common SIMD interface.
3335 */
3336 template< typename MT1, typename MT2, typename MT3, typename ST >
smmm(DenseMatrix<MT1,true> & C,const MT2 & A,const MT3 & B,ST alpha)3337 void smmm( DenseMatrix<MT1,true>& C, const MT2& A, const MT3& B, ST alpha )
3338 {
3339 using ET1 = ElementType_t<MT1>;
3340 using ET2 = ElementType_t<MT2>;
3341 using ET3 = ElementType_t<MT3>;
3342
3343 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT1 );
3344 BLAZE_CONSTRAINT_MUST_BE_COLUMN_MAJOR_MATRIX_TYPE( MT1 );
3345 BLAZE_CONSTRAINT_MUST_NOT_BE_ADAPTOR_TYPE ( MT1 );
3346 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE ( MT1 );
3347
3348 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT2 );
3349 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT2 );
3350
3351 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT3 );
3352 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT3 );
3353
3354 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
3355 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );
3356
3357 const size_t M( A.rows() );
3358 const size_t N( B.columns() );
3359
3360 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
3361
3362 ummm( C, A, B, alpha, ST(0) );
3363
3364 for( size_t jj=0UL; jj<N; jj+=BLOCK_SIZE )
3365 {
3366 const size_t jend( min( N, jj+BLOCK_SIZE ) );
3367
3368 for( size_t j=jj; j<jend; ++j ) {
3369 for( size_t i=jj+1UL; i<jend; ++i ) {
3370 (*C)(i,j) = (*C)(j,i);
3371 }
3372 }
3373
3374 for( size_t ii=jj+BLOCK_SIZE; ii<M; ii+=BLOCK_SIZE ) {
3375 const size_t iend( min( M, ii+BLOCK_SIZE ) );
3376 for( size_t j=jj; j<jend; ++j ) {
3377 for( size_t i=ii; i<iend; ++i ) {
3378 (*C)(i,j) = (*C)(j,i);
3379 }
3380 }
3381 }
3382 }
3383 }
3384 /*! \endcond */
3385 //*************************************************************************************************
3386
3387
3388 //*************************************************************************************************
3389 /*! \cond BLAZE_INTERNAL */
3390 /*!\brief Compute kernel for a symmetric dense matrix/dense matrix multiplication (\f$ C=A*B \f$).
3391 // \ingroup dense_matrix
3392 //
3393 // \param C The target left-hand side column-major dense matrix.
3394 // \param A The left-hand side multiplication operand.
3395 // \param B The right-hand side multiplication operand.
3396 // \return void
3397 //
3398 // This function implements the compute kernel for a symmetric dense matrix/dense matrix
3399 // multiplication of the form \f$ C=A*B \f$. Both \a A and \a B must be non-expression
3400 // dense matrix types, \a C must be a non-expression, non-adaptor, row-major dense matrix
3401 // type. The element types of all three matrices must be SIMD combinable, i.e. must
3402 // provide a common SIMD interface.
3403 */
3404 template< typename MT1, typename MT2, typename MT3 >
smmm(MT1 & C,const MT2 & A,const MT3 & B)3405 inline void smmm( MT1& C, const MT2& A, const MT3& B )
3406 {
3407 using ET1 = ElementType_t<MT1>;
3408 using ET2 = ElementType_t<MT2>;
3409 using ET3 = ElementType_t<MT3>;
3410
3411 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
3412 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );
3413
3414 smmm( C, A, B, ET1(1) );
3415 }
3416 /*! \endcond */
3417 //*************************************************************************************************
3418
3419
3420
3421
3422 //=================================================================================================
3423 //
3424 // HERMITIAN DENSE MATRIX MULTIPLICATION KERNELS
3425 //
3426 //=================================================================================================
3427
3428 //*************************************************************************************************
3429 /*! \cond BLAZE_INTERNAL */
3430 /*!\brief Compute kernel for a Hermitian dense matrix/dense matrix multiplication
3431 // (\f$ C=\alpha*A*B \f$).
3432 // \ingroup dense_matrix
3433 //
3434 // \param C The target left-hand side row-major dense matrix.
3435 // \param A The left-hand side multiplication operand.
3436 // \param B The right-hand side multiplication operand.
3437 // \param alpha The scaling factor for \f$ A*B \f$.
3438 // \return void
3439 //
3440 // This function implements the compute kernel for a Hermitian dense matrix/dense matrix
3441 // multiplication of the form \f$ C=\alpha*A*B \f$. Both \a A and \a B must be non-expression
3442 // dense matrix types, \a C must be a non-expression, non-adaptor, row-major dense matrix type.
3443 // The element types of all three matrices must be SIMD combinable, i.e. must provide a common
3444 // SIMD interface.
3445 */
3446 template< typename MT1, typename MT2, typename MT3, typename ST >
hmmm(DenseMatrix<MT1,false> & C,const MT2 & A,const MT3 & B,ST alpha)3447 void hmmm( DenseMatrix<MT1,false>& C, const MT2& A, const MT3& B, ST alpha )
3448 {
3449 using ET1 = ElementType_t<MT1>;
3450 using ET2 = ElementType_t<MT2>;
3451 using ET3 = ElementType_t<MT3>;
3452
3453 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT1 );
3454 BLAZE_CONSTRAINT_MUST_BE_ROW_MAJOR_MATRIX_TYPE( MT1 );
3455 BLAZE_CONSTRAINT_MUST_NOT_BE_ADAPTOR_TYPE ( MT1 );
3456 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE ( MT1 );
3457
3458 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT2 );
3459 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT2 );
3460
3461 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT3 );
3462 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT3 );
3463
3464 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
3465 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );
3466
3467 const size_t M( A.rows() );
3468 const size_t N( B.columns() );
3469
3470 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
3471
3472 lmmm( C, A, B, alpha, ST(0) );
3473
3474 for( size_t ii=0UL; ii<M; ii+=BLOCK_SIZE )
3475 {
3476 const size_t iend( min( M, ii+BLOCK_SIZE ) );
3477
3478 for( size_t i=ii; i<iend; ++i ) {
3479 for( size_t j=i+1UL; j<iend; ++j ) {
3480 (*C)(i,j) = conj( (*C)(j,i) );
3481 }
3482 }
3483
3484 for( size_t jj=ii+BLOCK_SIZE; jj<N; jj+=BLOCK_SIZE ) {
3485 const size_t jend( min( N, jj+BLOCK_SIZE ) );
3486 for( size_t i=ii; i<iend; ++i ) {
3487 for( size_t j=jj; j<jend; ++j ) {
3488 (*C)(i,j) = conj( (*C)(j,i) );
3489 }
3490 }
3491 }
3492 }
3493 }
3494 /*! \endcond */
3495 //*************************************************************************************************
3496
3497
3498 //*************************************************************************************************
3499 /*! \cond BLAZE_INTERNAL */
3500 /*!\brief Compute kernel for a Hermitian dense matrix/dense matrix multiplication
3501 // (\f$ C=\alpha*A*B \f$).
3502 // \ingroup dense_matrix
3503 //
3504 // \param C The target left-hand side column-major dense matrix.
3505 // \param A The left-hand side multiplication operand.
3506 // \param B The right-hand side multiplication operand.
3507 // \param alpha The scaling factor for \f$ A*B \f$.
3508 // \return void
3509 //
3510 // This function implements the compute kernel for a Hermitian dense matrix/dense matrix
3511 // multiplication of the form \f$ C=\alpha*A*B \f$. Both \a A and \a B must be non-expression
3512 // dense matrix types, \a C must be a non-expression, non-adaptor, column-major dense matrix
3513 // type. The element types of all three matrices must be SIMD combinable, i.e. must provide
3514 // a common SIMD interface.
3515 */
3516 template< typename MT1, typename MT2, typename MT3, typename ST >
hmmm(DenseMatrix<MT1,true> & C,const MT2 & A,const MT3 & B,ST alpha)3517 void hmmm( DenseMatrix<MT1,true>& C, const MT2& A, const MT3& B, ST alpha )
3518 {
3519 using ET1 = ElementType_t<MT1>;
3520 using ET2 = ElementType_t<MT2>;
3521 using ET3 = ElementType_t<MT3>;
3522
3523 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT1 );
3524 BLAZE_CONSTRAINT_MUST_BE_COLUMN_MAJOR_MATRIX_TYPE( MT1 );
3525 BLAZE_CONSTRAINT_MUST_NOT_BE_ADAPTOR_TYPE ( MT1 );
3526 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE ( MT1 );
3527
3528 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT2 );
3529 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT2 );
3530
3531 BLAZE_CONSTRAINT_MUST_BE_DENSE_MATRIX_TYPE ( MT3 );
3532 BLAZE_CONSTRAINT_MUST_NOT_BE_COMPUTATION_TYPE( MT3 );
3533
3534 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
3535 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );
3536
3537 const size_t M( A.rows() );
3538 const size_t N( B.columns() );
3539
3540 BLAZE_INTERNAL_ASSERT( A.columns() == B.rows(), "Invalid matrix sizes detected" );
3541
3542 ummm( C, A, B, alpha, ST(0) );
3543
3544 for( size_t jj=0UL; jj<N; jj+=BLOCK_SIZE )
3545 {
3546 const size_t jend( min( N, jj+BLOCK_SIZE ) );
3547
3548 for( size_t j=jj; j<jend; ++j ) {
3549 for( size_t i=jj+1UL; i<jend; ++i ) {
3550 (*C)(i,j) = conj( (*C)(j,i) );
3551 }
3552 }
3553
3554 for( size_t ii=jj+BLOCK_SIZE; ii<M; ii+=BLOCK_SIZE ) {
3555 const size_t iend( min( M, ii+BLOCK_SIZE ) );
3556 for( size_t j=jj; j<jend; ++j ) {
3557 for( size_t i=ii; i<iend; ++i ) {
3558 (*C)(i,j) = conj( (*C)(j,i) );
3559 }
3560 }
3561 }
3562 }
3563 }
3564 /*! \endcond */
3565 //*************************************************************************************************
3566
3567
3568 //*************************************************************************************************
3569 /*! \cond BLAZE_INTERNAL */
3570 /*!\brief Compute kernel for a Hermitian dense matrix/dense matrix multiplication (\f$ C=A*B \f$).
3571 // \ingroup dense_matrix
3572 //
3573 // \param C The target left-hand side column-major dense matrix.
3574 // \param A The left-hand side multiplication operand.
3575 // \param B The right-hand side multiplication operand.
3576 // \return void
3577 //
3578 // This function implements the compute kernel for a Hermitian dense matrix/dense matrix
3579 // multiplication of the form \f$ C=A*B \f$. Both \a A and \a B must be non-expression
3580 // dense matrix types, \a C must be a non-expression, non-adaptor, row-major dense matrix
3581 // type. The element types of all three matrices must be SIMD combinable, i.e. must
3582 // provide a common SIMD interface.
3583 */
3584 template< typename MT1, typename MT2, typename MT3 >
hmmm(MT1 & C,const MT2 & A,const MT3 & B)3585 inline void hmmm( MT1& C, const MT2& A, const MT3& B )
3586 {
3587 using ET1 = ElementType_t<MT1>;
3588 using ET2 = ElementType_t<MT2>;
3589 using ET3 = ElementType_t<MT3>;
3590
3591 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET2 );
3592 BLAZE_CONSTRAINT_MUST_BE_SIMD_COMBINABLE_TYPES( ET1, ET3 );
3593
3594 hmmm( C, A, B, ET1(1) );
3595 }
3596 /*! \endcond */
3597 //*************************************************************************************************
3598
3599 } // namespace blaze
3600
3601 #endif
3602