1 /*
2 Copyright (c) 2009-2014, Jack Poulson
3 All rights reserved.
4
5 This file is part of Elemental and is under the BSD 2-Clause License,
6 which can be found in the LICENSE file in the root directory, or at
7 http://opensource.org/licenses/BSD-2-Clause
8 */
9 #ifndef ELEM_TRR2K_LOCAL_HPP
10 #define ELEM_TRR2K_LOCAL_HPP
11
12 #include ELEM_AXPYTRIANGLE_INC
13 #include ELEM_SCALETRAPEZOID_INC
14 #include ELEM_GEMM_INC
15
16 namespace elem {
17
18 namespace trr2k {
19
20 #ifndef ELEM_RELEASE
21
EnsureSame(const Grid & gA,const Grid & gB,const Grid & gC,const Grid & gD,const Grid & gE)22 void EnsureSame
23 ( const Grid& gA, const Grid& gB, const Grid& gC,
24 const Grid& gD, const Grid& gE )
25 {
26 if( gA != gB || gB != gC || gC != gD || gD != gE )
27 LogicError("Grids must be the same");
28 }
29
30 template<typename T>
EnsureConformal(const DistMatrix<T,MC,STAR> & A,const DistMatrix<T> & E,std::string name)31 void EnsureConformal
32 ( const DistMatrix<T,MC,STAR>& A, const DistMatrix<T>& E, std::string name )
33 {
34 if( A.Height() != E.Height() || A.ColAlign() != E.ColAlign() )
35 LogicError(name," not conformal with E");
36 }
37
38 template<typename T>
EnsureConformal(const DistMatrix<T,STAR,MC> & A,const DistMatrix<T> & E,std::string name)39 void EnsureConformal
40 ( const DistMatrix<T,STAR,MC>& A, const DistMatrix<T>& E, std::string name )
41 {
42 if( A.Width() != E.Height() || A.RowAlign() != E.ColAlign() )
43 LogicError(name," not conformal with E");
44 }
45
46 template<typename T>
EnsureConformal(const DistMatrix<T,MR,STAR> & A,const DistMatrix<T> & E,std::string name)47 void EnsureConformal
48 ( const DistMatrix<T,MR,STAR>& A, const DistMatrix<T>& E, std::string name )
49 {
50 if( A.Height() != E.Width() || A.ColAlign() != E.RowAlign() )
51 LogicError(name," not conformal with E");
52 }
53
54 template<typename T>
EnsureConformal(const DistMatrix<T,STAR,MR> & A,const DistMatrix<T> & E,std::string name)55 void EnsureConformal
56 ( const DistMatrix<T,STAR,MR>& A, const DistMatrix<T>& E, std::string name )
57 {
58 if( A.Width() != E.Width() || A.RowAlign() != E.RowAlign() )
59 LogicError(name," not conformal with E");
60 }
61
62 template<typename T,Distribution UA,Distribution VA,
63 Distribution UB,Distribution VB,
64 Distribution UC,Distribution VC,
65 Distribution UD,Distribution VD>
CheckInput(const DistMatrix<T,UA,VA> & A,const DistMatrix<T,UB,VB> & B,const DistMatrix<T,UC,VC> & C,const DistMatrix<T,UD,VD> & D,const DistMatrix<T> & E)66 void CheckInput
67 ( const DistMatrix<T,UA,VA>& A, const DistMatrix<T,UB,VB>& B,
68 const DistMatrix<T,UC,VC>& C, const DistMatrix<T,UD,VD>& D,
69 const DistMatrix<T>& E )
70 {
71 EnsureSame( A.Grid(), B.Grid(), C.Grid(), D.Grid(), E.Grid() );
72 EnsureConformal( A, E, "A" );
73 EnsureConformal( B, E, "B" );
74 EnsureConformal( C, E, "C" );
75 EnsureConformal( D, E, "D" );
76 }
77
78 #endif // ifndef ELEM_RELEASE
79
// E := alpha (A B + C D) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [MC,STAR], B [STAR,MR], C [MC,STAR], D [STAR,MR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo,
  T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
           const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,MC,STAR> AT(g),  CT(g),
                          AB(g),  CB(g);
    DistMatrix<T,STAR,MR> BL(g), BR(g),
                          DL(g), DR(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionDown( A, AT, AB, half );
    LockedPartitionRight( B, BL, BR, half );
    LockedPartitionDown( C, CT, CB, half );
    LockedPartitionRight( D, DL, DR, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( NORMAL, NORMAL, alpha, AB, BL, T(1), EBL );
        LocalGemm( NORMAL, NORMAL, alpha, CB, DL, T(1), EBL );
    }
    else
    {
        LocalGemm( NORMAL, NORMAL, alpha, AT, BR, T(1), ETR );
        LocalGemm( NORMAL, NORMAL, alpha, CT, DR, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( NORMAL, NORMAL, alpha, AT, BL, FTL );
    LocalGemm( NORMAL, NORMAL, alpha, CT, DL, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( NORMAL, NORMAL, alpha, AB, BR, FBR );
    LocalGemm( NORMAL, NORMAL, alpha, CB, DR, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
134
// E := alpha (A B + C D^{T/H}) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - orientationOfD: orientation passed to LocalGemm for every product
//    involving D.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [MC,STAR], B [STAR,MR], C [MC,STAR], D [MR,STAR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo, Orientation orientationOfD,
  T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
           const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,MC,STAR> AT(g),  CT(g),
                          AB(g),  CB(g);
    DistMatrix<T,MR,STAR> DT(g),
                          DB(g);
    DistMatrix<T,STAR,MR> BL(g), BR(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionDown( A, AT, AB, half );
    LockedPartitionRight( B, BL, BR, half );
    LockedPartitionDown( C, CT, CB, half );
    LockedPartitionDown( D, DT, DB, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( NORMAL, NORMAL, alpha, AB, BL, T(1), EBL );
        LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, T(1), EBL );
    }
    else
    {
        LocalGemm( NORMAL, NORMAL, alpha, AT, BR, T(1), ETR );
        LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( NORMAL, NORMAL, alpha, AT, BL, FTL );
    LocalGemm( NORMAL, orientationOfD, alpha, CT, DT, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( NORMAL, NORMAL, alpha, AB, BR, FBR );
    LocalGemm( NORMAL, orientationOfD, alpha, CB, DB, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
190
// E := alpha (A B + C^{T/H} D) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - orientationOfC: orientation passed to LocalGemm for every product
//    involving C.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [MC,STAR], B [STAR,MR], C [STAR,MC], D [STAR,MR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo, Orientation orientationOfC,
  T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
           const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,MC,STAR> AT(g), AB(g);
    DistMatrix<T,STAR,MC> CL(g), CR(g);
    DistMatrix<T,STAR,MR> BL(g), BR(g),
                          DL(g), DR(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionDown( A, AT, AB, half );
    LockedPartitionRight( B, BL, BR, half );
    LockedPartitionRight( C, CL, CR, half );
    LockedPartitionRight( D, DL, DR, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( NORMAL, NORMAL, alpha, AB, BL, T(1), EBL );
        LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
    }
    else
    {
        LocalGemm( NORMAL, NORMAL, alpha, AT, BR, T(1), ETR );
        LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( NORMAL, NORMAL, alpha, AT, BL, FTL );
    LocalGemm( orientationOfC, NORMAL, alpha, CL, DL, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( NORMAL, NORMAL, alpha, AB, BR, FBR );
    LocalGemm( orientationOfC, NORMAL, alpha, CR, DR, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
245
// E := alpha (A B + C^{T/H} D^{T/H}) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - orientationOfC, orientationOfD: orientations passed to LocalGemm for
//    the C/D products.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [MC,STAR], B [STAR,MR], C [STAR,MC], D [MR,STAR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo, Orientation orientationOfC, Orientation orientationOfD,
  T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
           const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,MC,STAR> AT(g), AB(g);
    DistMatrix<T,STAR,MR> BL(g), BR(g);
    DistMatrix<T,STAR,MC> CL(g), CR(g);
    DistMatrix<T,MR,STAR> DT(g), DB(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionDown( A, AT, AB, half );
    LockedPartitionRight( B, BL, BR, half );
    LockedPartitionRight( C, CL, CR, half );
    LockedPartitionDown( D, DT, DB, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( NORMAL, NORMAL, alpha, AB, BL, T(1), EBL );
        LocalGemm( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
    }
    else
    {
        LocalGemm( NORMAL, NORMAL, alpha, AT, BR, T(1), ETR );
        LocalGemm( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( NORMAL, NORMAL, alpha, AT, BL, FTL );
    LocalGemm( orientationOfC, orientationOfD, alpha, CL, DT, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( NORMAL, NORMAL, alpha, AB, BR, FBR );
    LocalGemm( orientationOfC, orientationOfD, alpha, CR, DB, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
300
// E := alpha (A B^{T/H} + C D) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - orientationOfB: orientation passed to LocalGemm for every product
//    involving B.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [MC,STAR], B [MR,STAR], C [MC,STAR], D [STAR,MR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo, Orientation orientationOfB,
  T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
           const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,MC,STAR> AT(g),  CT(g),
                          AB(g),  CB(g);
    DistMatrix<T,MR,STAR> BT(g), BB(g);
    DistMatrix<T,STAR,MR> DL(g), DR(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionDown( A, AT, AB, half );
    LockedPartitionDown( B, BT, BB, half );
    LockedPartitionDown( C, CT, CB, half );
    LockedPartitionRight( D, DL, DR, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, T(1), EBL );
        LocalGemm( NORMAL, NORMAL, alpha, CB, DL, T(1), EBL );
    }
    else
    {
        LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, T(1), ETR );
        LocalGemm( NORMAL, NORMAL, alpha, CT, DR, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( NORMAL, orientationOfB, alpha, AT, BT, FTL );
    LocalGemm( NORMAL, NORMAL, alpha, CT, DL, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( NORMAL, orientationOfB, alpha, AB, BB, FBR );
    LocalGemm( NORMAL, NORMAL, alpha, CB, DR, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
355
// E := alpha (A B^{T/H} + C D^{T/H}) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - orientationOfB, orientationOfD: orientations passed to LocalGemm for
//    the products involving B and D respectively.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [MC,STAR], B [MR,STAR], C [MC,STAR], D [MR,STAR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo, Orientation orientationOfB, Orientation orientationOfD,
  T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
           const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,MC,STAR> AT(g),  CT(g),
                          AB(g),  CB(g);
    DistMatrix<T,MR,STAR> BT(g),  DT(g),
                          BB(g),  DB(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionDown( A, AT, AB, half );
    LockedPartitionDown( B, BT, BB, half );
    LockedPartitionDown( C, CT, CB, half );
    LockedPartitionDown( D, DT, DB, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, T(1), EBL );
        LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, T(1), EBL );
    }
    else
    {
        LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, T(1), ETR );
        LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( NORMAL, orientationOfB, alpha, AT, BT, FTL );
    LocalGemm( NORMAL, orientationOfD, alpha, CT, DT, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( NORMAL, orientationOfB, alpha, AB, BB, FBR );
    LocalGemm( NORMAL, orientationOfD, alpha, CB, DB, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
410
// E := alpha (A B^{T/H} + C^{T/H} D) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - orientationOfB, orientationOfC: orientations passed to LocalGemm for
//    the products involving B and C respectively.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [MC,STAR], B [MR,STAR], C [STAR,MC], D [STAR,MR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo, Orientation orientationOfB, Orientation orientationOfC,
  T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
           const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,MC,STAR> AT(g), AB(g);
    DistMatrix<T,MR,STAR> BT(g), BB(g);
    DistMatrix<T,STAR,MC> CL(g), CR(g);
    DistMatrix<T,STAR,MR> DL(g), DR(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionDown( A, AT, AB, half );
    LockedPartitionDown( B, BT, BB, half );
    LockedPartitionRight( C, CL, CR, half );
    LockedPartitionRight( D, DL, DR, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, T(1), EBL );
        LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
    }
    else
    {
        LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, T(1), ETR );
        LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( NORMAL, orientationOfB, alpha, AT, BT, FTL );
    LocalGemm( orientationOfC, NORMAL, alpha, CL, DL, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( NORMAL, orientationOfB, alpha, AB, BB, FBR );
    LocalGemm( orientationOfC, NORMAL, alpha, CR, DR, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
465
// E := alpha (A B^{T/H} + C^{T/H} D^{T/H}) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - orientationOfB, orientationOfC, orientationOfD: orientations passed to
//    LocalGemm for the products involving B, C and D respectively.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [MC,STAR], B [MR,STAR], C [STAR,MC], D [MR,STAR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo,
  Orientation orientationOfB,
  Orientation orientationOfC,
  Orientation orientationOfD,
  T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
           const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,MC,STAR> AT(g), AB(g);
    DistMatrix<T,MR,STAR> BT(g), BB(g);
    DistMatrix<T,STAR,MC> CL(g), CR(g);
    DistMatrix<T,MR,STAR> DT(g), DB(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionDown( A, AT, AB, half );
    LockedPartitionDown( B, BT, BB, half );
    LockedPartitionRight( C, CL, CR, half );
    LockedPartitionDown( D, DT, DB, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, T(1), EBL );
        LocalGemm( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
    }
    else
    {
        LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, T(1), ETR );
        LocalGemm( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( NORMAL, orientationOfB, alpha, AT, BT, FTL );
    LocalGemm( orientationOfC, orientationOfD, alpha, CL, DT, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( NORMAL, orientationOfB, alpha, AB, BB, FBR );
    LocalGemm( orientationOfC, orientationOfD, alpha, CR, DB, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
523
// E := alpha (A^{T/H} B + C D) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - orientationOfA: orientation passed to LocalGemm for every product
//    involving A.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [STAR,MC], B [STAR,MR], C [MC,STAR], D [STAR,MR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo, Orientation orientationOfA,
  T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
           const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,STAR,MC> AL(g), AR(g);
    DistMatrix<T,MC,STAR> CT(g), CB(g);
    DistMatrix<T,STAR,MR> BL(g), BR(g),
                          DL(g), DR(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionRight( A, AL, AR, half );
    LockedPartitionRight( B, BL, BR, half );
    LockedPartitionDown( C, CT, CB, half );
    LockedPartitionRight( D, DL, DR, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, T(1), EBL );
        LocalGemm( NORMAL, NORMAL, alpha, CB, DL, T(1), EBL );
    }
    else
    {
        LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, T(1), ETR );
        LocalGemm( NORMAL, NORMAL, alpha, CT, DR, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( orientationOfA, NORMAL, alpha, AL, BL, FTL );
    LocalGemm( NORMAL, NORMAL, alpha, CT, DL, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( orientationOfA, NORMAL, alpha, AR, BR, FBR );
    LocalGemm( NORMAL, NORMAL, alpha, CB, DR, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
578
// E := alpha (A^{T/H} B + C D^{T/H}) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - orientationOfA, orientationOfD: orientations passed to LocalGemm for
//    the products involving A and D respectively.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [STAR,MC], B [STAR,MR], C [MC,STAR], D [MR,STAR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo, Orientation orientationOfA, Orientation orientationOfD,
  T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
           const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,STAR,MC> AL(g), AR(g);
    DistMatrix<T,STAR,MR> BL(g), BR(g);
    DistMatrix<T,MC,STAR> CT(g), CB(g);
    DistMatrix<T,MR,STAR> DT(g), DB(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionRight( A, AL, AR, half );
    LockedPartitionRight( B, BL, BR, half );
    LockedPartitionDown( C, CT, CB, half );
    LockedPartitionDown( D, DT, DB, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, T(1), EBL );
        LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, T(1), EBL );
    }
    else
    {
        LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, T(1), ETR );
        LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( orientationOfA, NORMAL, alpha, AL, BL, FTL );
    LocalGemm( NORMAL, orientationOfD, alpha, CT, DT, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( orientationOfA, NORMAL, alpha, AR, BR, FBR );
    LocalGemm( NORMAL, orientationOfD, alpha, CB, DB, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
633
// E := alpha (A^{T/H} B + C^{T/H} D) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - orientationOfA, orientationOfC: orientations passed to LocalGemm for
//    the products involving A and C respectively.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [STAR,MC], B [STAR,MR], C [STAR,MC], D [STAR,MR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo, Orientation orientationOfA, Orientation orientationOfC,
  T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
           const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,STAR,MC> AL(g), AR(g),
                          CL(g), CR(g);
    DistMatrix<T,STAR,MR> BL(g), BR(g),
                          DL(g), DR(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionRight( A, AL, AR, half );
    LockedPartitionRight( B, BL, BR, half );
    LockedPartitionRight( C, CL, CR, half );
    LockedPartitionRight( D, DL, DR, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, T(1), EBL );
        LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
    }
    else
    {
        LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, T(1), ETR );
        LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( orientationOfA, NORMAL, alpha, AL, BL, FTL );
    LocalGemm( orientationOfC, NORMAL, alpha, CL, DL, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( orientationOfA, NORMAL, alpha, AR, BR, FBR );
    LocalGemm( orientationOfC, NORMAL, alpha, CR, DR, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
688
// E := alpha (A^{T/H} B + C^{T/H} D^{T/H}) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - orientationOfA, orientationOfC, orientationOfD: orientations passed to
//    LocalGemm for the products involving A, C and D respectively.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [STAR,MC], B [STAR,MR], C [STAR,MC], D [MR,STAR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo,
  Orientation orientationOfA,
  Orientation orientationOfC,
  Orientation orientationOfD,
  T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
           const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,STAR,MC> AL(g), AR(g),
                          CL(g), CR(g);
    DistMatrix<T,STAR,MR> BL(g), BR(g);
    DistMatrix<T,MR,STAR> DT(g), DB(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionRight( A, AL, AR, half );
    LockedPartitionRight( B, BL, BR, half );
    LockedPartitionRight( C, CL, CR, half );
    LockedPartitionDown( D, DT, DB, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, T(1), EBL );
        LocalGemm( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
    }
    else
    {
        LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, T(1), ETR );
        LocalGemm( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( orientationOfA, NORMAL, alpha, AL, BL, FTL );
    LocalGemm( orientationOfC, orientationOfD, alpha, CL, DT, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( orientationOfA, NORMAL, alpha, AR, BR, FBR );
    LocalGemm( orientationOfC, orientationOfD, alpha, CR, DB, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
746
// E := alpha (A^{T/H} B^{T/H} + C D) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - orientationOfA, orientationOfB: orientations passed to LocalGemm for
//    the A/B products.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [STAR,MC], B [MR,STAR], C [MC,STAR], D [STAR,MR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo, Orientation orientationOfA, Orientation orientationOfB,
  T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
           const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,STAR,MC> AL(g), AR(g);
    DistMatrix<T,MR,STAR> BT(g), BB(g);
    DistMatrix<T,MC,STAR> CT(g), CB(g);
    DistMatrix<T,STAR,MR> DL(g), DR(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionRight( A, AL, AR, half );
    LockedPartitionDown( B, BT, BB, half );
    LockedPartitionDown( C, CT, CB, half );
    LockedPartitionRight( D, DL, DR, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( orientationOfA, orientationOfB, alpha, AR, BT, T(1), EBL );
        LocalGemm( NORMAL, NORMAL, alpha, CB, DL, T(1), EBL );
    }
    else
    {
        LocalGemm( orientationOfA, orientationOfB, alpha, AL, BB, T(1), ETR );
        LocalGemm( NORMAL, NORMAL, alpha, CT, DR, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( orientationOfA, orientationOfB, alpha, AL, BT, FTL );
    LocalGemm( NORMAL, NORMAL, alpha, CT, DL, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( orientationOfA, orientationOfB, alpha, AR, BB, FBR );
    LocalGemm( NORMAL, NORMAL, alpha, CB, DR, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
801
// E := alpha (A^{T/H} B^{T/H} + C D^{T/H}) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - orientationOfA, orientationOfB, orientationOfD: orientations passed to
//    LocalGemm for the products involving A, B and D respectively.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [STAR,MC], B [MR,STAR], C [MC,STAR], D [MR,STAR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo,
  Orientation orientationOfA,
  Orientation orientationOfB,
  Orientation orientationOfD,
  T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
           const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,STAR,MC> AL(g), AR(g);
    DistMatrix<T,MR,STAR> BT(g), BB(g);
    DistMatrix<T,MC,STAR> CT(g), CB(g);
    DistMatrix<T,MR,STAR> DT(g), DB(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionRight( A, AL, AR, half );
    LockedPartitionDown( B, BT, BB, half );
    LockedPartitionDown( C, CT, CB, half );
    LockedPartitionDown( D, DT, DB, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( orientationOfA, orientationOfB, alpha, AR, BT, T(1), EBL );
        LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, T(1), EBL );
    }
    else
    {
        LocalGemm( orientationOfA, orientationOfB, alpha, AL, BB, T(1), ETR );
        LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( orientationOfA, orientationOfB, alpha, AL, BT, FTL );
    LocalGemm( NORMAL, orientationOfD, alpha, CT, DT, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( orientationOfA, orientationOfB, alpha, AR, BB, FBR );
    LocalGemm( NORMAL, orientationOfD, alpha, CB, DB, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
859
// E := alpha (A^{T/H} B^{T/H} + C^{T/H} D) + beta E
//
// Updates E in place from the supplied pieces of A, B, C and D.
//  - uplo: selects which off-diagonal quadrant of E is updated (EBL when
//    LOWER, ETR otherwise) and is forwarded to ScaleTrapezoid/AxpyTriangle.
//  - orientationOfA, orientationOfB, orientationOfC: orientations passed to
//    LocalGemm for the products involving A, B and C respectively.
//  - alpha: scale applied to both products.
//  - beta: scale handed to ScaleTrapezoid for E.
//  - A [STAR,MC], B [MR,STAR], C [STAR,MC], D [STAR,MR]: read-only inputs.
//  - E: updated in place.
// Inside DEBUG_ONLY, CheckInput validates the operands first.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo,
  Orientation orientationOfA,
  Orientation orientationOfB,
  Orientation orientationOfC,
  T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
           const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,STAR,MC> AL(g), AR(g),
                          CL(g), CR(g);
    DistMatrix<T,MR,STAR> BT(g), BB(g);
    DistMatrix<T,STAR,MR> DL(g), DR(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    DistMatrix<T> FTL(g), FBR(g);

    // Split point shared by every partition below.
    const Int half = E.Height()/2;
    // Apply beta to E up front; every later update uses a unit scale on E.
    ScaleTrapezoid( beta, uplo, E );
    // Split the operands at `half` to match the 2x2 quadrants of E.
    LockedPartitionRight( A, AL, AR, half );
    LockedPartitionDown( B, BT, BB, half );
    LockedPartitionRight( C, CL, CR, half );
    LockedPartitionRight( D, DL, DR, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal quadrant: accumulate both products directly into it.
    if( uplo == LOWER )
    {
        LocalGemm( orientationOfA, orientationOfB, alpha, AR, BT, T(1), EBL );
        LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
    }
    else
    {
        LocalGemm( orientationOfA, orientationOfB, alpha, AL, BB, T(1), ETR );
        LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
    }

    // Top-left diagonal quadrant: form the sum in a temporary aligned with
    // ETL, then add it in through AxpyTriangle using uplo.
    FTL.AlignWith( ETL );
    LocalGemm( orientationOfA, orientationOfB, alpha, AL, BT, FTL );
    LocalGemm( orientationOfC, NORMAL, alpha, CL, DL, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    // Bottom-right diagonal quadrant: same scheme with FBR/EBR.
    FBR.AlignWith( EBR );
    LocalGemm( orientationOfA, orientationOfB, alpha, AR, BB, FBR );
    LocalGemm( orientationOfC, NORMAL, alpha, CR, DR, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
917
918 // E := alpha (A^{T/H} B^{T/H} + C^{T/H} D^{T/H}) + beta C
919 template<typename T>
920 inline void
LocalTrr2kKernel(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfB,Orientation orientationOfC,Orientation orientationOfD,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)921 LocalTrr2kKernel
922 ( UpperOrLower uplo,
923 Orientation orientationOfA, Orientation orientationOfB,
924 Orientation orientationOfC, Orientation orientationOfD,
925 T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
926 const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
927 T beta, DistMatrix<T>& E )
928 {
929 DEBUG_ONLY(
930 CallStackEntry cse("LocalTrr2kKernel");
931 CheckInput( A, B, C, D, E );
932 )
933 const Grid& g = E.Grid();
934
935 DistMatrix<T,STAR,MC> AL(g), AR(g),
936 CL(g), CR(g);
937 DistMatrix<T,MR,STAR> BT(g), DT(g),
938 BB(g), DB(g);
939 DistMatrix<T> ETL(g), ETR(g),
940 EBL(g), EBR(g);
941 DistMatrix<T> FTL(g), FBR(g);
942
943 const Int half = E.Height()/2;
944 ScaleTrapezoid( beta, uplo, E );
945 LockedPartitionRight( A, AL, AR, half );
946 LockedPartitionDown( B, BT, BB, half );
947 LockedPartitionRight( C, CL, CR, half );
948 LockedPartitionDown( D, DT, DB, half );
949 PartitionDownDiagonal
950 ( E, ETL, ETR,
951 EBL, EBR, half );
952
953 if( uplo == LOWER )
954 {
955 LocalGemm( orientationOfA, orientationOfB, alpha, AR, BT, T(1), EBL );
956 LocalGemm( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
957 }
958 else
959 {
960 LocalGemm( orientationOfA, orientationOfB, alpha, AL, BB, T(1), ETR );
961 LocalGemm( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
962 }
963
964 FTL.AlignWith( ETL );
965 LocalGemm( orientationOfA, orientationOfB, alpha, AL, BT, FTL );
966 LocalGemm( orientationOfC, orientationOfD, alpha, CL, DT, T(1), FTL );
967 AxpyTriangle( uplo, T(1), FTL, ETL );
968
969 FBR.AlignWith( EBR );
970 LocalGemm( orientationOfA, orientationOfB, alpha, AR, BB, FBR );
971 LocalGemm( orientationOfC, orientationOfD, alpha, CR, DB, T(1), FBR );
972 AxpyTriangle( uplo, T(1), FBR, EBR );
973 }
974
975 } // namespace trr2k
976
977 // E := alpha (A B + C D) + beta E
978 template<typename T>
LocalTrr2k(UpperOrLower uplo,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)979 void LocalTrr2k
980 ( UpperOrLower uplo,
981 T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
982 const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
983 T beta, DistMatrix<T>& E )
984 {
985 using namespace trr2k;
986 DEBUG_ONLY(
987 CallStackEntry cse("LocalTrr2k");
988 CheckInput( A, B, C, D, E );
989 )
990 const Grid& g = E.Grid();
991
992 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
993 {
994 LocalTrr2kKernel( uplo, alpha, A, B, C, D, beta, E );
995 }
996 else
997 {
998 // Split E in four roughly equal pieces, perform a large gemm on corner
999 // and recurse on ETL and EBR.
1000 DistMatrix<T,MC,STAR> AT(g), CT(g),
1001 AB(g), CB(g);
1002 DistMatrix<T,STAR,MR> BL(g), BR(g),
1003 DL(g), DR(g);
1004 DistMatrix<T> ETL(g), ETR(g),
1005 EBL(g), EBR(g);
1006
1007 const Int half = E.Height() / 2;
1008 LockedPartitionDown( A, AT, AB, half );
1009 LockedPartitionRight( B, BL, BR, half );
1010 LockedPartitionDown( C, CT, CB, half );
1011 LockedPartitionRight( D, DL, DR, half );
1012 PartitionDownDiagonal
1013 ( E, ETL, ETR,
1014 EBL, EBR, half );
1015
1016 if( uplo == LOWER )
1017 {
1018 LocalGemm( NORMAL, NORMAL, alpha, AB, BL, beta, EBL );
1019 LocalGemm( NORMAL, NORMAL, alpha, CB, DL, T(1), EBL );
1020 }
1021 else
1022 {
1023 LocalGemm( NORMAL, NORMAL, alpha, AT, BR, beta, ETR );
1024 LocalGemm( NORMAL, NORMAL, alpha, CT, DR, T(1), ETR );
1025 }
1026
1027 // Recurse
1028 LocalTrr2k( uplo, alpha, AT, BL, CT, DL, beta, ETL );
1029 LocalTrr2k( uplo, alpha, AB, BR, CB, DR, beta, EBR );
1030 }
1031 }
1032
1033 // E := alpha (A B + C D^{T/H}) + beta E
1034 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfD,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)1035 void LocalTrr2k
1036 ( UpperOrLower uplo, Orientation orientationOfD,
1037 T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
1038 const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
1039 T beta, DistMatrix<T>& E )
1040 {
1041 using namespace trr2k;
1042 DEBUG_ONLY(
1043 CallStackEntry cse("LocalTrr2k");
1044 CheckInput( A, B, C, D, E );
1045 )
1046 const Grid& g = E.Grid();
1047
1048 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1049 {
1050 LocalTrr2kKernel( uplo, orientationOfD, alpha, A, B, C, D, beta, E );
1051 }
1052 else
1053 {
1054 // Split E in four roughly equal pieces, perform a large gemm on corner
1055 // and recurse on ETL and EBR.
1056 DistMatrix<T,MC,STAR> AT(g), CT(g),
1057 AB(g), CB(g);
1058 DistMatrix<T,STAR,MR> BL(g), BR(g);
1059 DistMatrix<T,MR,STAR> DT(g), DB(g);
1060 DistMatrix<T> ETL(g), ETR(g),
1061 EBL(g), EBR(g);
1062
1063 const Int half = E.Height() / 2;
1064 LockedPartitionDown( A, AT, AB, half );
1065 LockedPartitionRight( B, BL, BR, half );
1066 LockedPartitionDown( C, CT, CB, half );
1067 LockedPartitionDown( D, DT, DB, half );
1068 PartitionDownDiagonal
1069 ( E, ETL, ETR,
1070 EBL, EBR, half );
1071
1072 if( uplo == LOWER )
1073 {
1074 LocalGemm( NORMAL, NORMAL, alpha, AB, BL, T(1), EBL );
1075 LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, beta, EBL );
1076 }
1077 else
1078 {
1079 LocalGemm( NORMAL, NORMAL, alpha, AT, BR, T(1), ETR );
1080 LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, beta, ETR );
1081 }
1082
1083 // Recurse
1084 LocalTrr2k( uplo, orientationOfD, alpha, AT, BL, CT, DT, beta, ETL );
1085 LocalTrr2k( uplo, orientationOfD, alpha, AB, BR, CB, DB, beta, EBR );
1086 }
1087 }
1088
1089 // E := alpha (A B + C^{T/H} D) + beta E
1090 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfC,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)1091 void LocalTrr2k
1092 ( UpperOrLower uplo, Orientation orientationOfC,
1093 T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
1094 const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
1095 T beta, DistMatrix<T>& E )
1096 {
1097 using namespace trr2k;
1098 DEBUG_ONLY(
1099 CallStackEntry cse("LocalTrr2k");
1100 CheckInput( A, B, C, D, E );
1101 )
1102 const Grid& g = E.Grid();
1103
1104 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1105 {
1106 LocalTrr2kKernel( uplo, orientationOfC, alpha, A, B, C, D, beta, E );
1107 }
1108 else
1109 {
1110 // Split E in four roughly equal pieces, perform a large gemm on corner
1111 // and recurse on ETL and EBR.
1112 DistMatrix<T,MC,STAR> AT(g), AB(g);
1113 DistMatrix<T,STAR,MR> BL(g), BR(g),
1114 DL(g), DR(g);
1115 DistMatrix<T,STAR,MC> CL(g), CR(g);
1116 DistMatrix<T> ETL(g), ETR(g),
1117 EBL(g), EBR(g);
1118
1119 const Int half = E.Height() / 2;
1120 LockedPartitionDown( A, AT, AB, half );
1121 LockedPartitionRight( B, BL, BR, half );
1122 LockedPartitionRight( C, CL, CR, half );
1123 LockedPartitionRight( D, DL, DR, half );
1124 PartitionDownDiagonal
1125 ( E, ETL, ETR,
1126 EBL, EBR, half );
1127
1128 if( uplo == LOWER )
1129 {
1130 LocalGemm( NORMAL, NORMAL, alpha, AB, BL, beta, EBL );
1131 LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
1132 }
1133 else
1134 {
1135 LocalGemm( NORMAL, NORMAL, alpha, AT, BR, beta, ETR );
1136 LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
1137 }
1138
1139 // Recurse
1140 LocalTrr2k( uplo, orientationOfC, alpha, AT, BL, CL, DL, beta, ETL );
1141 LocalTrr2k( uplo, orientationOfC, alpha, AB, BR, CR, DR, beta, EBR );
1142 }
1143 }
1144
1145 // E := alpha (A B + C^{T/H} D^{T/H}) + beta E
1146 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfC,Orientation orientationOfD,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)1147 void LocalTrr2k
1148 ( UpperOrLower uplo, Orientation orientationOfC, Orientation orientationOfD,
1149 T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
1150 const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
1151 T beta, DistMatrix<T>& E )
1152 {
1153 using namespace trr2k;
1154 DEBUG_ONLY(
1155 CallStackEntry cse("LocalTrr2k");
1156 CheckInput( A, B, C, D, E );
1157 )
1158 const Grid& g = E.Grid();
1159
1160 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1161 {
1162 LocalTrr2kKernel
1163 ( uplo, orientationOfC, orientationOfD, alpha, A, B, C, D, beta, E );
1164 }
1165 else
1166 {
1167 // Split E in four roughly equal pieces, perform a large gemm on corner
1168 // and recurse on ETL and EBR.
1169 DistMatrix<T,MC,STAR> AT(g), AB(g);
1170 DistMatrix<T,STAR,MR> BL(g), BR(g);
1171 DistMatrix<T,STAR,MC> CL(g), CR(g);
1172 DistMatrix<T,MR,STAR> DT(g), DB(g);
1173 DistMatrix<T> ETL(g), ETR(g),
1174 EBL(g), EBR(g);
1175
1176 const Int half = E.Height() / 2;
1177 LockedPartitionDown( A, AT, AB, half );
1178 LockedPartitionRight( B, BL, BR, half );
1179 LockedPartitionRight( C, CL, CR, half );
1180 LockedPartitionDown( D, DT, DB, half );
1181 PartitionDownDiagonal
1182 ( E, ETL, ETR,
1183 EBL, EBR, half );
1184
1185 if( uplo == LOWER )
1186 {
1187 LocalGemm( NORMAL, NORMAL, alpha, AB, BL, beta, EBL );
1188 LocalGemm
1189 ( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
1190 }
1191 else
1192 {
1193 LocalGemm( NORMAL, NORMAL, alpha, AT, BR, beta, ETR );
1194 LocalGemm
1195 ( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
1196 }
1197
1198 // Recurse
1199 LocalTrr2k
1200 ( uplo, orientationOfC, orientationOfD,
1201 alpha, AT, BL, CL, DT, beta, ETL );
1202 LocalTrr2k
1203 ( uplo, orientationOfC, orientationOfD,
1204 alpha, AB, BR, CR, DB, beta, EBR );
1205 }
1206 }
1207
1208 // E := alpha (A B^{T/H} + C D) + beta E
1209 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfB,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)1210 void LocalTrr2k
1211 ( UpperOrLower uplo, Orientation orientationOfB,
1212 T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
1213 const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
1214 T beta, DistMatrix<T>& E )
1215 {
1216 using namespace trr2k;
1217 DEBUG_ONLY(
1218 CallStackEntry cse("LocalTrr2k");
1219 CheckInput( A, B, C, D, E );
1220 )
1221 const Grid& g = E.Grid();
1222
1223 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1224 {
1225 LocalTrr2kKernel( uplo, orientationOfB, alpha, A, B, C, D, beta, E );
1226 }
1227 else
1228 {
1229 // Split E in four roughly equal pieces, perform a large gemm on corner
1230 // and recurse on ETL and EBR.
1231 DistMatrix<T,MC,STAR> AT(g), CT(g),
1232 AB(g), CB(g);
1233 DistMatrix<T,MR,STAR> BT(g), BB(g);
1234 DistMatrix<T,STAR,MR> DL(g), DR(g);
1235 DistMatrix<T> ETL(g), ETR(g),
1236 EBL(g), EBR(g);
1237
1238 const Int half = E.Height() / 2;
1239 LockedPartitionDown( A, AT, AB, half );
1240 LockedPartitionDown( B, BT, BB, half );
1241 LockedPartitionDown( C, CT, CB, half );
1242 LockedPartitionRight( D, DL, DR, half );
1243 PartitionDownDiagonal
1244 ( E, ETL, ETR,
1245 EBL, EBR, half );
1246
1247 if( uplo == LOWER )
1248 {
1249 LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, T(1), EBL );
1250 LocalGemm( NORMAL, NORMAL, alpha, CB, DL, beta, EBL );
1251 }
1252 else
1253 {
1254 LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, T(1), ETR );
1255 LocalGemm( NORMAL, NORMAL, alpha, CT, DR, beta, ETR );
1256 }
1257
1258 // Recurse
1259 LocalTrr2k( uplo, orientationOfB, alpha, AT, BT, CT, DL, beta, ETL );
1260 LocalTrr2k( uplo, orientationOfB, alpha, AB, BB, CB, DR, beta, EBR );
1261 }
1262 }
1263
1264 // E := alpha (A B^{T/H} + C D^{T/H}) + beta E
1265 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfB,Orientation orientationOfD,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)1266 void LocalTrr2k
1267 ( UpperOrLower uplo, Orientation orientationOfB, Orientation orientationOfD,
1268 T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
1269 const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
1270 T beta, DistMatrix<T>& E )
1271 {
1272 using namespace trr2k;
1273 DEBUG_ONLY(
1274 CallStackEntry cse("LocalTrr2k");
1275 CheckInput( A, B, C, D, E );
1276 )
1277 const Grid& g = E.Grid();
1278
1279 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1280 {
1281 LocalTrr2kKernel
1282 ( uplo, orientationOfB, orientationOfD, alpha, A, B, C, D, beta, E );
1283 }
1284 else
1285 {
1286 // Split E in four roughly equal pieces, perform a large gemm on corner
1287 // and recurse on ETL and EBR.
1288 DistMatrix<T,MC,STAR> AT(g), CT(g),
1289 AB(g), CB(g);
1290 DistMatrix<T,MR,STAR> BT(g), DT(g),
1291 BB(g), DB(g);
1292 DistMatrix<T> ETL(g), ETR(g),
1293 EBL(g), EBR(g);
1294
1295 const Int half = E.Height() / 2;
1296 LockedPartitionDown( A, AT, AB, half );
1297 LockedPartitionDown( B, BT, BB, half );
1298 LockedPartitionDown( C, CT, CB, half );
1299 LockedPartitionDown( D, DT, DB, half );
1300 PartitionDownDiagonal
1301 ( E, ETL, ETR,
1302 EBL, EBR, half );
1303
1304 if( uplo == LOWER )
1305 {
1306 LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, beta, EBL );
1307 LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, T(1), EBL );
1308 }
1309 else
1310 {
1311 LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, beta, ETR );
1312 LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, T(1), ETR );
1313 }
1314
1315 // Recurse
1316 LocalTrr2k
1317 ( uplo, orientationOfB, orientationOfD,
1318 alpha, AT, BT, CT, DT, beta, ETL );
1319 LocalTrr2k
1320 ( uplo, orientationOfB, orientationOfD,
1321 alpha, AB, BB, CB, DB, beta, EBR );
1322 }
1323 }
1324
1325 // E := alpha (A B^{T/H} + C^{T/H} D) + beta E
1326 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfB,Orientation orientationOfC,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)1327 void LocalTrr2k
1328 ( UpperOrLower uplo, Orientation orientationOfB, Orientation orientationOfC,
1329 T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
1330 const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
1331 T beta, DistMatrix<T>& E )
1332 {
1333 using namespace trr2k;
1334 DEBUG_ONLY(
1335 CallStackEntry cse("LocalTrr2k");
1336 CheckInput( A, B, C, D, E );
1337 )
1338 const Grid& g = E.Grid();
1339
1340 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1341 {
1342 LocalTrr2kKernel
1343 ( uplo, orientationOfB, orientationOfC, alpha, A, B, C, D, beta, E );
1344 }
1345 else
1346 {
1347 // Split E in four roughly equal pieces, perform a large gemm on corner
1348 // and recurse on ETL and EBR.
1349 DistMatrix<T,MC,STAR> AT(g), AB(g);
1350 DistMatrix<T,MR,STAR> BT(g), BB(g);
1351 DistMatrix<T,STAR,MC> CL(g), CR(g);
1352 DistMatrix<T,STAR,MR> DL(g), DR(g);
1353 DistMatrix<T> ETL(g), ETR(g),
1354 EBL(g), EBR(g);
1355
1356 const Int half = E.Height() / 2;
1357 LockedPartitionDown( A, AT, AB, half );
1358 LockedPartitionDown( B, BT, BB, half );
1359 LockedPartitionRight( C, CL, CR, half );
1360 LockedPartitionRight( D, DL, DR, half );
1361 PartitionDownDiagonal
1362 ( E, ETL, ETR,
1363 EBL, EBR, half );
1364
1365 if( uplo == LOWER )
1366 {
1367 LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, beta, EBL );
1368 LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
1369 }
1370 else
1371 {
1372 LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, beta, ETR );
1373 LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
1374 }
1375
1376 // Recurse
1377 LocalTrr2k
1378 ( uplo, orientationOfB, orientationOfC,
1379 alpha, AT, BT, CL, DL, beta, ETL );
1380 LocalTrr2k
1381 ( uplo, orientationOfB, orientationOfC,
1382 alpha, AB, BB, CR, DR, beta, EBR );
1383 }
1384 }
1385
1386 // E := alpha (A B^{T/H} + C^{T/H} D^{T/H}) + beta E
1387 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfB,Orientation orientationOfC,Orientation orientationOfD,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)1388 void LocalTrr2k
1389 ( UpperOrLower uplo,
1390 Orientation orientationOfB,
1391 Orientation orientationOfC,
1392 Orientation orientationOfD,
1393 T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
1394 const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
1395 T beta, DistMatrix<T>& E )
1396 {
1397 using namespace trr2k;
1398 DEBUG_ONLY(
1399 CallStackEntry cse("LocalTrr2k");
1400 CheckInput( A, B, C, D, E );
1401 )
1402 const Grid& g = E.Grid();
1403
1404 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1405 {
1406 LocalTrr2kKernel
1407 ( uplo, orientationOfB, orientationOfC, orientationOfD,
1408 alpha, A, B, C, D, beta, E );
1409 }
1410 else
1411 {
1412 // Split E in four roughly equal pieces, perform a large gemm on corner
1413 // and recurse on ETL and EBR.
1414 DistMatrix<T,MC,STAR> AT(g), AB(g);
1415 DistMatrix<T,MR,STAR> BT(g), DT(g),
1416 BB(g), DB(g);
1417 DistMatrix<T,STAR,MC> CL(g), CR(g);
1418 DistMatrix<T> ETL(g), ETR(g),
1419 EBL(g), EBR(g);
1420
1421 const Int half = E.Height() / 2;
1422 LockedPartitionDown( A, AT, AB, half );
1423 LockedPartitionDown( B, BT, BB, half );
1424 LockedPartitionRight( C, CL, CR, half );
1425 LockedPartitionDown( D, DT, DB, half );
1426 PartitionDownDiagonal
1427 ( E, ETL, ETR,
1428 EBL, EBR, half );
1429
1430 if( uplo == LOWER )
1431 {
1432 LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, beta, EBL );
1433 LocalGemm
1434 ( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
1435 }
1436 else
1437 {
1438 LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, beta, ETR );
1439 LocalGemm
1440 ( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
1441 }
1442
1443 // Recurse
1444 LocalTrr2k
1445 ( uplo, orientationOfB, orientationOfC, orientationOfD,
1446 alpha, AT, BT, CL, DT, beta, ETL );
1447 LocalTrr2k
1448 ( uplo, orientationOfB, orientationOfC, orientationOfD,
1449 alpha, AB, BB, CR, DB, beta, EBR );
1450 }
1451 }
1452
1453 // E := alpha (A^{T/H} B + C D) + beta E
1454 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfA,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)1455 void LocalTrr2k
1456 ( UpperOrLower uplo, Orientation orientationOfA,
1457 T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
1458 const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
1459 T beta, DistMatrix<T>& E )
1460 {
1461 using namespace trr2k;
1462 DEBUG_ONLY(
1463 CallStackEntry cse("LocalTrr2k");
1464 CheckInput( A, B, C, D, E );
1465 )
1466 const Grid& g = E.Grid();
1467
1468 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1469 {
1470 LocalTrr2kKernel( uplo, orientationOfA, alpha, A, B, C, D, beta, E );
1471 }
1472 else
1473 {
1474 // Split E in four roughly equal pieces, perform a large gemm on corner
1475 // and recurse on ETL and EBR.
1476 DistMatrix<T,STAR,MC> AL(g), AR(g);
1477 DistMatrix<T,STAR,MR> BL(g), BR(g),
1478 DL(g), DR(g);
1479 DistMatrix<T,MC,STAR> CT(g), CB(g);
1480 DistMatrix<T> ETL(g), ETR(g),
1481 EBL(g), EBR(g);
1482
1483 const Int half = E.Height() / 2;
1484 LockedPartitionRight( A, AL, AR, half );
1485 LockedPartitionRight( B, BL, BR, half );
1486 LockedPartitionDown( C, CT, CB, half );
1487 LockedPartitionRight( D, DL, DR, half );
1488 PartitionDownDiagonal
1489 ( E, ETL, ETR,
1490 EBL, EBR, half );
1491
1492 if( uplo == LOWER )
1493 {
1494 LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, beta, EBL );
1495 LocalGemm( NORMAL, NORMAL, alpha, CB, DL, T(1), EBL );
1496 }
1497 else
1498 {
1499 LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, beta, ETR );
1500 LocalGemm( NORMAL, NORMAL, alpha, CT, DR, T(1), ETR );
1501 }
1502
1503 // Recurse
1504 LocalTrr2k( uplo, orientationOfA, alpha, AL, BL, CT, DL, beta, ETL );
1505 LocalTrr2k( uplo, orientationOfA, alpha, AR, BR, CB, DR, beta, EBR );
1506 }
1507 }
1508
1509 // E := alpha (A^{T/H} B + C D^{T/H}) + beta E
1510 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfD,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)1511 void LocalTrr2k
1512 ( UpperOrLower uplo, Orientation orientationOfA, Orientation orientationOfD,
1513 T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
1514 const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
1515 T beta, DistMatrix<T>& E )
1516 {
1517 using namespace trr2k;
1518 DEBUG_ONLY(
1519 CallStackEntry cse("LocalTrr2k");
1520 CheckInput( A, B, C, D, E );
1521 )
1522 const Grid& g = E.Grid();
1523
1524 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1525 {
1526 LocalTrr2kKernel
1527 ( uplo, orientationOfA, orientationOfD, alpha, A, B, C, D, beta, E );
1528 }
1529 else
1530 {
1531 // Split E in four roughly equal pieces, perform a large gemm on corner
1532 // and recurse on ETL and EBR.
1533 DistMatrix<T,STAR,MC> AL(g), AR(g);
1534 DistMatrix<T,STAR,MR> BL(g), BR(g);
1535 DistMatrix<T,MC,STAR> CT(g), CB(g);
1536 DistMatrix<T,MR,STAR> DT(g), DB(g);
1537 DistMatrix<T> ETL(g), ETR(g),
1538 EBL(g), EBR(g);
1539
1540 const Int half = E.Height() / 2;
1541 LockedPartitionRight( A, AL, AR, half );
1542 LockedPartitionRight( B, BL, BR, half );
1543 LockedPartitionDown( C, CT, CB, half );
1544 LockedPartitionDown( D, DT, DB, half );
1545 PartitionDownDiagonal
1546 ( E, ETL, ETR,
1547 EBL, EBR, half );
1548
1549 if( uplo == LOWER )
1550 {
1551 LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, beta, EBL );
1552 LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, T(1), EBL );
1553 }
1554 else
1555 {
1556 LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, beta, ETR );
1557 LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, T(1), ETR );
1558 }
1559
1560 // Recurse
1561 LocalTrr2k
1562 ( uplo, orientationOfA, orientationOfD,
1563 alpha, AL, BL, CT, DT, beta, ETL );
1564 LocalTrr2k
1565 ( uplo, orientationOfA, orientationOfD,
1566 alpha, AR, BR, CB, DB, beta, EBR );
1567 }
1568 }
1569
1570 // E := alpha (A^{T/H} B + C^{T/H} D) + beta E
1571 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfC,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)1572 void LocalTrr2k
1573 ( UpperOrLower uplo, Orientation orientationOfA, Orientation orientationOfC,
1574 T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
1575 const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
1576 T beta, DistMatrix<T>& E )
1577 {
1578 using namespace trr2k;
1579 DEBUG_ONLY(
1580 CallStackEntry cse("LocalTrr2k");
1581 CheckInput( A, B, C, D, E );
1582 )
1583 const Grid& g = E.Grid();
1584
1585 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1586 {
1587 LocalTrr2kKernel
1588 ( uplo, orientationOfA, orientationOfC, alpha, A, B, C, D, beta, E );
1589 }
1590 else
1591 {
1592 // Split E in four roughly equal pieces, perform a large gemm on corner
1593 // and recurse on ETL and EBR.
1594 DistMatrix<T,STAR,MC> AL(g), AR(g),
1595 CL(g), CR(g);
1596 DistMatrix<T,STAR,MR> BL(g), BR(g),
1597 DL(g), DR(g);
1598 DistMatrix<T> ETL(g), ETR(g),
1599 EBL(g), EBR(g);
1600
1601 const Int half = E.Height() / 2;
1602 LockedPartitionRight( A, AL, AR, half );
1603 LockedPartitionRight( B, BL, BR, half );
1604 LockedPartitionRight( C, CL, CR, half );
1605 LockedPartitionRight( D, DL, DR, half );
1606 PartitionDownDiagonal
1607 ( E, ETL, ETR,
1608 EBL, EBR, half );
1609
1610 if( uplo == LOWER )
1611 {
1612 LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, beta, EBL );
1613 LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
1614 }
1615 else
1616 {
1617 LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, beta, ETR );
1618 LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
1619 }
1620
1621 // Recurse
1622 LocalTrr2k
1623 ( uplo, orientationOfA, orientationOfC,
1624 alpha, AL, BL, CL, DL, beta, ETL );
1625 LocalTrr2k
1626 ( uplo, orientationOfA, orientationOfC,
1627 alpha, AR, BR, CR, DR, beta, EBR );
1628 }
1629 }
1630
1631 // E := alpha (A^{T/H} B + C^{T/H} D^{T/H}) + beta E
1632 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfC,Orientation orientationOfD,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)1633 void LocalTrr2k
1634 ( UpperOrLower uplo,
1635 Orientation orientationOfA,
1636 Orientation orientationOfC,
1637 Orientation orientationOfD,
1638 T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
1639 const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
1640 T beta, DistMatrix<T>& E )
1641 {
1642 using namespace trr2k;
1643 DEBUG_ONLY(
1644 CallStackEntry cse("LocalTrr2k");
1645 CheckInput( A, B, C, D, E );
1646 )
1647 const Grid& g = E.Grid();
1648
1649 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1650 {
1651 LocalTrr2kKernel
1652 ( uplo, orientationOfA, orientationOfC, orientationOfD,
1653 alpha, A, B, C, D, beta, E );
1654 }
1655 else
1656 {
1657 // Split E in four roughly equal pieces, perform a large gemm on corner
1658 // and recurse on ETL and EBR.
1659 DistMatrix<T,STAR,MC> AL(g), AR(g),
1660 CL(g), CR(g);
1661 DistMatrix<T,STAR,MR> BL(g), BR(g);
1662 DistMatrix<T,MR,STAR> DT(g), DB(g);
1663 DistMatrix<T> ETL(g), ETR(g),
1664 EBL(g), EBR(g);
1665
1666 const Int half = E.Height() / 2;
1667 LockedPartitionRight( A, AL, AR, half );
1668 LockedPartitionRight( B, BL, BR, half );
1669 LockedPartitionRight( C, CL, CR, half );
1670 LockedPartitionDown( D, DT, DB, half );
1671 PartitionDownDiagonal
1672 ( E, ETL, ETR,
1673 EBL, EBR, half );
1674
1675 if( uplo == LOWER )
1676 {
1677 LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, beta, EBL );
1678 LocalGemm
1679 ( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
1680 }
1681 else
1682 {
1683 LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, beta, ETR );
1684 LocalGemm
1685 ( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
1686 }
1687
1688 // Recurse
1689 LocalTrr2k
1690 ( uplo, orientationOfA, orientationOfC, orientationOfD,
1691 alpha, AL, BL, CL, DT, beta, ETL );
1692 LocalTrr2k
1693 ( uplo, orientationOfA, orientationOfC, orientationOfD,
1694 alpha, AR, BR, CR, DB, beta, EBR );
1695 }
1696 }
1697
1698 // E := alpha (A^{T/H} B^{T/H} + C D) + beta E
1699 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfB,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)1700 void LocalTrr2k
1701 ( UpperOrLower uplo, Orientation orientationOfA, Orientation orientationOfB,
1702 T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
1703 const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
1704 T beta, DistMatrix<T>& E )
1705 {
1706 using namespace trr2k;
1707 DEBUG_ONLY(
1708 CallStackEntry cse("LocalTrr2k");
1709 CheckInput( A, B, C, D, E );
1710 )
1711 const Grid& g = E.Grid();
1712
1713 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1714 {
1715 LocalTrr2kKernel
1716 ( uplo, orientationOfA, orientationOfB, alpha, A, B, C, D, beta, E );
1717 }
1718 else
1719 {
1720 // Split E in four roughly equal pieces, perform a large gemm on corner
1721 // and recurse on ETL and EBR.
1722 DistMatrix<T,STAR,MC> AL(g), AR(g);
1723 DistMatrix<T,MR,STAR> BT(g), BB(g);
1724 DistMatrix<T,MC,STAR> CT(g), CB(g);
1725 DistMatrix<T,STAR,MR> DL(g), DR(g);
1726 DistMatrix<T> ETL(g), ETR(g),
1727 EBL(g), EBR(g);
1728
1729 const Int half = E.Height() / 2;
1730 LockedPartitionRight( A, AL, AR, half );
1731 LockedPartitionDown( B, BT, BB, half );
1732 LockedPartitionDown( C, CT, CB, half );
1733 LockedPartitionRight( D, DL, DR, half );
1734 PartitionDownDiagonal
1735 ( E, ETL, ETR,
1736 EBL, EBR, half );
1737
1738 if( uplo == LOWER )
1739 {
1740 LocalGemm
1741 ( orientationOfA, orientationOfB, alpha, AR, BT, beta, EBL );
1742 LocalGemm( NORMAL, NORMAL, alpha, CB, DL, T(1), EBL );
1743 }
1744 else
1745 {
1746 LocalGemm
1747 ( orientationOfA, orientationOfB, alpha, AL, BB, beta, ETR );
1748 LocalGemm( NORMAL, NORMAL, alpha, CT, DR, T(1), ETR );
1749 }
1750
1751 // Recurse
1752 LocalTrr2k
1753 ( uplo, orientationOfA, orientationOfB,
1754 alpha, AL, BT, CT, DL, beta, ETL );
1755 LocalTrr2k
1756 ( uplo, orientationOfA, orientationOfB,
1757 alpha, AR, BB, CB, DR, beta, EBR );
1758 }
1759 }
1760
1761 // E := alpha (A^{T/H} B^{T/H} + C D^{T/H}) + beta E
1762 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfB,Orientation orientationOfD,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)1763 void LocalTrr2k
1764 ( UpperOrLower uplo,
1765 Orientation orientationOfA,
1766 Orientation orientationOfB,
1767 Orientation orientationOfD,
1768 T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
1769 const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
1770 T beta, DistMatrix<T>& E )
1771 {
1772 using namespace trr2k;
1773 DEBUG_ONLY(
1774 CallStackEntry cse("LocalTrr2k");
1775 CheckInput( A, B, C, D, E );
1776 )
1777 const Grid& g = E.Grid();
1778
1779 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1780 {
1781 LocalTrr2kKernel
1782 ( uplo, orientationOfA, orientationOfB, orientationOfD,
1783 alpha, A, B, C, D, beta, E );
1784 }
1785 else
1786 {
1787 // Split E in four roughly equal pieces, perform a large gemm on corner
1788 // and recurse on ETL and EBR.
1789 DistMatrix<T,STAR,MC> AL(g), AR(g);
1790 DistMatrix<T,MR,STAR> BT(g), DT(g),
1791 BB(g), DB(g);
1792 DistMatrix<T,MC,STAR> CT(g), CB(g);
1793 DistMatrix<T> ETL(g), ETR(g),
1794 EBL(g), EBR(g);
1795
1796 const Int half = E.Height() / 2;
1797 LockedPartitionRight( A, AL, AR, half );
1798 LockedPartitionDown( B, BT, BB, half );
1799 LockedPartitionDown( C, CT, CB, half );
1800 LockedPartitionDown( D, DT, DB, half );
1801 PartitionDownDiagonal
1802 ( E, ETL, ETR,
1803 EBL, EBR, half );
1804
1805 if( uplo == LOWER )
1806 {
1807 LocalGemm
1808 ( orientationOfA, orientationOfB, alpha, AR, BT, beta, EBL );
1809 LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, T(1), EBL );
1810 }
1811 else
1812 {
1813 LocalGemm
1814 ( orientationOfA, orientationOfB, alpha, AL, BB, beta, ETR );
1815 LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, T(1), ETR );
1816 }
1817
1818 // Recurse
1819 LocalTrr2k
1820 ( uplo, orientationOfA, orientationOfB, orientationOfD,
1821 alpha, AL, BT, CT, DT, beta, ETL );
1822 LocalTrr2k
1823 ( uplo, orientationOfA, orientationOfB, orientationOfD,
1824 alpha, AR, BB, CB, DB, beta, EBR );
1825 }
1826 }
1827
1828 // E := alpha (A^{T/H} B^{T/H} + C^{T/H} D) + beta E
1829 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfB,Orientation orientationOfC,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)1830 void LocalTrr2k
1831 ( UpperOrLower uplo,
1832 Orientation orientationOfA,
1833 Orientation orientationOfB,
1834 Orientation orientationOfC,
1835 T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
1836 const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
1837 T beta, DistMatrix<T>& E )
1838 {
1839 using namespace trr2k;
1840 DEBUG_ONLY(
1841 CallStackEntry cse("LocalTrr2k");
1842 CheckInput( A, B, C, D, E );
1843 )
1844 const Grid& g = E.Grid();
1845
1846 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1847 {
1848 LocalTrr2kKernel
1849 ( uplo, orientationOfA, orientationOfB, orientationOfC,
1850 alpha, A, B, C, D, beta, E );
1851 }
1852 else
1853 {
1854 // Split E in four roughly equal pieces, perform a large gemm on corner
1855 // and recurse on ETL and EBR.
1856 DistMatrix<T,STAR,MC> AL(g), AR(g),
1857 CL(g), CR(g);
1858 DistMatrix<T,MR,STAR> BT(g), BB(g);
1859 DistMatrix<T,STAR,MR> DL(g), DR(g);
1860 DistMatrix<T> ETL(g), ETR(g),
1861 EBL(g), EBR(g);
1862
1863 const Int half = E.Height() / 2;
1864 LockedPartitionRight( A, AL, AR, half );
1865 LockedPartitionDown( B, BT, BB, half );
1866 LockedPartitionRight( C, CL, CR, half );
1867 LockedPartitionRight( D, DL, DR, half );
1868 PartitionDownDiagonal
1869 ( E, ETL, ETR,
1870 EBL, EBR, half );
1871
1872 if( uplo == LOWER )
1873 {
1874 LocalGemm
1875 ( orientationOfA, orientationOfB, alpha, AR, BT, beta, EBL );
1876 LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
1877 }
1878 else
1879 {
1880 LocalGemm
1881 ( orientationOfA, orientationOfB, alpha, AL, BB, beta, ETR );
1882 LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
1883 }
1884
1885 // Recurse
1886 LocalTrr2k
1887 ( uplo, orientationOfA, orientationOfB, orientationOfC,
1888 alpha, AL, BT, CL, DL, beta, ETL );
1889 LocalTrr2k
1890 ( uplo, orientationOfA, orientationOfB, orientationOfC,
1891 alpha, AR, BB, CR, DR, beta, EBR );
1892 }
1893 }
1894
1895 // E := alpha (A^{T/H} B^{T/H} + C^{T/H} D^{T/H}) + beta E
1896 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfB,Orientation orientationOfC,Orientation orientationOfD,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)1897 void LocalTrr2k
1898 ( UpperOrLower uplo,
1899 Orientation orientationOfA,
1900 Orientation orientationOfB,
1901 Orientation orientationOfC,
1902 Orientation orientationOfD,
1903 T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
1904 const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
1905 T beta, DistMatrix<T>& E )
1906 {
1907 using namespace trr2k;
1908 DEBUG_ONLY(
1909 CallStackEntry cse("LocalTrr2k");
1910 CheckInput( A, B, C, D, E );
1911 )
1912 const Grid& g = E.Grid();
1913
1914 if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1915 {
1916 LocalTrr2kKernel
1917 ( uplo, orientationOfA, orientationOfB, orientationOfC, orientationOfD,
1918 alpha, A, B, C, D, beta, E );
1919 }
1920 else
1921 {
1922 // Split E in four roughly equal pieces, perform a large gemm on corner
1923 // and recurse on ETL and EBR.
1924 DistMatrix<T,STAR,MC> AL(g), AR(g),
1925 CL(g), CR(g);
1926 DistMatrix<T,MR,STAR> BT(g), DT(g),
1927 BB(g), DB(g);
1928 DistMatrix<T> ETL(g), ETR(g),
1929 EBL(g), EBR(g);
1930
1931 const Int half = E.Height() / 2;
1932 LockedPartitionRight( A, AL, AR, half );
1933 LockedPartitionDown( B, BT, BB, half );
1934 LockedPartitionRight( C, CL, CR, half );
1935 LockedPartitionDown( D, DT, DB, half );
1936 PartitionDownDiagonal
1937 ( E, ETL, ETR,
1938 EBL, EBR, half );
1939
1940 if( uplo == LOWER )
1941 {
1942 LocalGemm
1943 ( orientationOfA, orientationOfB, alpha, AR, BT, beta, EBL );
1944 LocalGemm
1945 ( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
1946 }
1947 else
1948 {
1949 LocalGemm
1950 ( orientationOfA, orientationOfB, alpha, AL, BB, beta, ETR );
1951 LocalGemm
1952 ( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
1953 }
1954
1955 // Recurse
1956 LocalTrr2k
1957 ( uplo,
1958 orientationOfA, orientationOfB, orientationOfC, orientationOfD,
1959 alpha, AL, BT, CL, DT, beta, ETL );
1960
1961 LocalTrr2k
1962 ( uplo,
1963 orientationOfA, orientationOfB, orientationOfC, orientationOfD,
1964 alpha, AR, BB, CR, DB, beta, EBR );
1965 }
1966 }
1967
1968 } // namespace elem
1969
1970 #endif // ifndef ELEM_TRR2K_LOCAL_HPP
1971