1 /*
2    Copyright (c) 2009-2014, Jack Poulson
3    All rights reserved.
4 
5    This file is part of Elemental and is under the BSD 2-Clause License,
6    which can be found in the LICENSE file in the root directory, or at
7    http://opensource.org/licenses/BSD-2-Clause
8 */
9 #ifndef ELEM_TRR2K_LOCAL_HPP
10 #define ELEM_TRR2K_LOCAL_HPP
11 
12 #include ELEM_AXPYTRIANGLE_INC
13 #include ELEM_SCALETRAPEZOID_INC
14 #include ELEM_GEMM_INC
15 
16 namespace elem {
17 
18 namespace trr2k {
19 
20 #ifndef ELEM_RELEASE
21 
EnsureSame(const Grid & gA,const Grid & gB,const Grid & gC,const Grid & gD,const Grid & gE)22 void EnsureSame
23 ( const Grid& gA, const Grid& gB, const Grid& gC,
24   const Grid& gD, const Grid& gE )
25 {
26     if( gA != gB || gB != gC || gC != gD || gD != gE )
27         LogicError("Grids must be the same");
28 }
29 
30 template<typename T>
EnsureConformal(const DistMatrix<T,MC,STAR> & A,const DistMatrix<T> & E,std::string name)31 void EnsureConformal
32 ( const DistMatrix<T,MC,STAR>& A, const DistMatrix<T>& E, std::string name )
33 {
34     if( A.Height() != E.Height() || A.ColAlign() != E.ColAlign() )
35         LogicError(name," not conformal with E");
36 }
37 
38 template<typename T>
EnsureConformal(const DistMatrix<T,STAR,MC> & A,const DistMatrix<T> & E,std::string name)39 void EnsureConformal
40 ( const DistMatrix<T,STAR,MC>& A, const DistMatrix<T>& E, std::string name )
41 {
42     if( A.Width() != E.Height() || A.RowAlign() != E.ColAlign() )
43         LogicError(name," not conformal with E");
44 }
45 
46 template<typename T>
EnsureConformal(const DistMatrix<T,MR,STAR> & A,const DistMatrix<T> & E,std::string name)47 void EnsureConformal
48 ( const DistMatrix<T,MR,STAR>& A, const DistMatrix<T>& E, std::string name )
49 {
50     if( A.Height() != E.Width() || A.ColAlign() != E.RowAlign() )
51         LogicError(name," not conformal with E");
52 }
53 
54 template<typename T>
EnsureConformal(const DistMatrix<T,STAR,MR> & A,const DistMatrix<T> & E,std::string name)55 void EnsureConformal
56 ( const DistMatrix<T,STAR,MR>& A, const DistMatrix<T>& E, std::string name )
57 {
58     if( A.Width() != E.Width() || A.RowAlign() != E.RowAlign() )
59         LogicError(name," not conformal with E");
60 }
61 
62 template<typename T,Distribution UA,Distribution VA,
63                     Distribution UB,Distribution VB,
64                     Distribution UC,Distribution VC,
65                     Distribution UD,Distribution VD>
CheckInput(const DistMatrix<T,UA,VA> & A,const DistMatrix<T,UB,VB> & B,const DistMatrix<T,UC,VC> & C,const DistMatrix<T,UD,VD> & D,const DistMatrix<T> & E)66 void CheckInput
67 ( const DistMatrix<T,UA,VA>& A, const DistMatrix<T,UB,VB>& B,
68   const DistMatrix<T,UC,VC>& C, const DistMatrix<T,UD,VD>& D,
69   const DistMatrix<T>& E )
70 {
71     EnsureSame( A.Grid(), B.Grid(), C.Grid(), D.Grid(), E.Grid() );
72     EnsureConformal( A, E, "A" );
73     EnsureConformal( B, E, "B" );
74     EnsureConformal( C, E, "C" );
75     EnsureConformal( D, E, "D" );
76 }
77 
78 #endif // ifndef ELEM_RELEASE
79 
80 // E := alpha (A B + C D) + beta E
81 template<typename T>
82 inline void
LocalTrr2kKernel(UpperOrLower uplo,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)83 LocalTrr2kKernel
84 ( UpperOrLower uplo,
85   T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
86            const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
87   T beta,        DistMatrix<T>& E )
88 {
89     DEBUG_ONLY(
90         CallStackEntry cse("LocalTrr2kKernel");
91         CheckInput( A, B, C, D, E );
92     )
93     const Grid& g = E.Grid();
94 
95     DistMatrix<T,MC,STAR> AT(g),  CT(g),
96                           AB(g),  CB(g);
97     DistMatrix<T,STAR,MR> BL(g), BR(g),
98                           DL(g), DR(g);
99     DistMatrix<T> ETL(g), ETR(g),
100                   EBL(g), EBR(g);
101     DistMatrix<T> FTL(g), FBR(g);
102 
103     const Int half = E.Height()/2;
104     ScaleTrapezoid( beta, uplo, E );
105     LockedPartitionDown( A, AT, AB, half );
106     LockedPartitionRight( B, BL, BR, half );
107     LockedPartitionDown( C, CT, CB, half );
108     LockedPartitionRight( D, DL, DR, half );
109     PartitionDownDiagonal
110     ( E, ETL, ETR,
111          EBL, EBR, half );
112 
113     if( uplo == LOWER )
114     {
115         LocalGemm( NORMAL, NORMAL, alpha, AB, BL, T(1), EBL );
116         LocalGemm( NORMAL, NORMAL, alpha, CB, DL, T(1), EBL );
117     }
118     else
119     {
120         LocalGemm( NORMAL, NORMAL, alpha, AT, BR, T(1), ETR );
121         LocalGemm( NORMAL, NORMAL, alpha, CT, DR, T(1), ETR );
122     }
123 
124     FTL.AlignWith( ETL );
125     LocalGemm( NORMAL, NORMAL, alpha, AT, BL, FTL );
126     LocalGemm( NORMAL, NORMAL, alpha, CT, DL, T(1), FTL );
127     AxpyTriangle( uplo, T(1), FTL, ETL );
128 
129     FBR.AlignWith( EBR );
130     LocalGemm( NORMAL, NORMAL, alpha, AB, BR, FBR );
131     LocalGemm( NORMAL, NORMAL, alpha, CB, DR, T(1), FBR );
132     AxpyTriangle( uplo, T(1), FBR, EBR );
133 }
134 
135 // E := alpha (A B + C D^{T/H}) + beta C
136 template<typename T>
137 inline void
LocalTrr2kKernel(UpperOrLower uplo,Orientation orientationOfD,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)138 LocalTrr2kKernel
139 ( UpperOrLower uplo, Orientation orientationOfD,
140   T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
141            const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
142   T beta,        DistMatrix<T>& E )
143 {
144     DEBUG_ONLY(
145         CallStackEntry cse("LocalTrr2kKernel");
146         CheckInput( A, B, C, D, E );
147     )
148     const Grid& g = E.Grid();
149 
150     DistMatrix<T,MC,STAR> AT(g),  CT(g),
151                           AB(g),  CB(g);
152     DistMatrix<T,MR,STAR> DT(g),
153                           DB(g);
154     DistMatrix<T,STAR,MR> BL(g), BR(g);
155     DistMatrix<T> ETL(g), ETR(g),
156                   EBL(g), EBR(g);
157     DistMatrix<T> FTL(g), FBR(g);
158 
159     const Int half = E.Height()/2;
160     ScaleTrapezoid( beta, uplo, E );
161     LockedPartitionDown( A, AT, AB, half );
162     LockedPartitionRight( B, BL, BR, half );
163     LockedPartitionDown( C, CT, CB, half );
164     LockedPartitionDown( D, DT, DB, half );
165     PartitionDownDiagonal
166     ( E, ETL, ETR,
167          EBL, EBR, half );
168 
169     if( uplo == LOWER )
170     {
171         LocalGemm( NORMAL, NORMAL, alpha, AB, BL, T(1), EBL );
172         LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, T(1), EBL );
173     }
174     else
175     {
176         LocalGemm( NORMAL, NORMAL, alpha, AT, BR, T(1), ETR );
177         LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, T(1), ETR );
178     }
179 
180     FTL.AlignWith( ETL );
181     LocalGemm( NORMAL, NORMAL, alpha, AT, BL, FTL );
182     LocalGemm( NORMAL, orientationOfD, alpha, CT, DT, T(1), FTL );
183     AxpyTriangle( uplo, T(1), FTL, ETL );
184 
185     FBR.AlignWith( EBR );
186     LocalGemm( NORMAL, NORMAL, alpha, AB, BR, FBR );
187     LocalGemm( NORMAL, orientationOfD, alpha, CB, DB, T(1), FBR );
188     AxpyTriangle( uplo, T(1), FBR, EBR );
189 }
190 
191 // E := alpha (A B + C^{T/H} D) + beta E
192 template<typename T>
193 inline void
LocalTrr2kKernel(UpperOrLower uplo,Orientation orientationOfC,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)194 LocalTrr2kKernel
195 ( UpperOrLower uplo, Orientation orientationOfC,
196   T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
197            const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
198   T beta,        DistMatrix<T>& E )
199 {
200     DEBUG_ONLY(
201         CallStackEntry cse("LocalTrr2kKernel");
202         CheckInput( A, B, C, D, E );
203     )
204     const Grid& g = E.Grid();
205 
206     DistMatrix<T,MC,STAR> AT(g), AB(g);
207     DistMatrix<T,STAR,MC> CL(g), CR(g);
208     DistMatrix<T,STAR,MR> BL(g), BR(g),
209                           DL(g), DR(g);
210     DistMatrix<T> ETL(g), ETR(g),
211                   EBL(g), EBR(g);
212     DistMatrix<T> FTL(g), FBR(g);
213 
214     const Int half = E.Height()/2;
215     ScaleTrapezoid( beta, uplo, E );
216     LockedPartitionDown( A, AT, AB, half );
217     LockedPartitionRight( B, BL, BR, half );
218     LockedPartitionRight( C, CL, CR, half );
219     LockedPartitionRight( D, DL, DR, half );
220     PartitionDownDiagonal
221     ( E, ETL, ETR,
222          EBL, EBR, half );
223 
224     if( uplo == LOWER )
225     {
226         LocalGemm( NORMAL, NORMAL, alpha, AB, BL, T(1), EBL );
227         LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
228     }
229     else
230     {
231         LocalGemm( NORMAL, NORMAL, alpha, AT, BR, T(1), ETR );
232         LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
233     }
234 
235     FTL.AlignWith( ETL );
236     LocalGemm( NORMAL, NORMAL, alpha, AT, BL, FTL );
237     LocalGemm( orientationOfC, NORMAL, alpha, CL, DL, T(1), FTL );
238     AxpyTriangle( uplo, T(1), FTL, ETL );
239 
240     FBR.AlignWith( EBR );
241     LocalGemm( NORMAL, NORMAL, alpha, AB, BR, FBR );
242     LocalGemm( orientationOfC, NORMAL, alpha, CR, DR, T(1), FBR );
243     AxpyTriangle( uplo, T(1), FBR, EBR );
244 }
245 
246 // E := alpha (A B + C^{T/H} D^{T/H}) + beta E
247 template<typename T>
248 inline void
LocalTrr2kKernel(UpperOrLower uplo,Orientation orientationOfC,Orientation orientationOfD,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)249 LocalTrr2kKernel
250 ( UpperOrLower uplo, Orientation orientationOfC, Orientation orientationOfD,
251   T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
252            const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
253   T beta,        DistMatrix<T>& E )
254 {
255     DEBUG_ONLY(
256         CallStackEntry cse("LocalTrr2kKernel");
257         CheckInput( A, B, C, D, E );
258     )
259     const Grid& g = E.Grid();
260 
261     DistMatrix<T,MC,STAR> AT(g), AB(g);
262     DistMatrix<T,STAR,MR> BL(g), BR(g);
263     DistMatrix<T,STAR,MC> CL(g), CR(g);
264     DistMatrix<T,MR,STAR> DT(g), DB(g);
265     DistMatrix<T> ETL(g), ETR(g),
266                   EBL(g), EBR(g);
267     DistMatrix<T> FTL(g), FBR(g);
268 
269     const Int half = E.Height()/2;
270     ScaleTrapezoid( beta, uplo, E );
271     LockedPartitionDown( A, AT, AB, half );
272     LockedPartitionRight( B, BL, BR, half );
273     LockedPartitionRight( C, CL, CR, half );
274     LockedPartitionDown( D, DT, DB, half );
275     PartitionDownDiagonal
276     ( E, ETL, ETR,
277          EBL, EBR, half );
278 
279     if( uplo == LOWER )
280     {
281         LocalGemm( NORMAL, NORMAL, alpha, AB, BL, T(1), EBL );
282         LocalGemm( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
283     }
284     else
285     {
286         LocalGemm( NORMAL, NORMAL, alpha, AT, BR, T(1), ETR );
287         LocalGemm( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
288     }
289 
290     FTL.AlignWith( ETL );
291     LocalGemm( NORMAL, NORMAL, alpha, AT, BL, FTL );
292     LocalGemm( orientationOfC, orientationOfD, alpha, CL, DT, T(1), FTL );
293     AxpyTriangle( uplo, T(1), FTL, ETL );
294 
295     FBR.AlignWith( EBR );
296     LocalGemm( NORMAL, NORMAL, alpha, AB, BR, FBR );
297     LocalGemm( orientationOfC, orientationOfD, alpha, CR, DB, T(1), FBR );
298     AxpyTriangle( uplo, T(1), FBR, EBR );
299 }
300 
301 // E := alpha (A B^{T/H} + C D) + beta C
302 template<typename T>
303 inline void
LocalTrr2kKernel(UpperOrLower uplo,Orientation orientationOfB,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)304 LocalTrr2kKernel
305 ( UpperOrLower uplo, Orientation orientationOfB,
306   T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
307            const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
308   T beta,        DistMatrix<T>& E )
309 {
310     DEBUG_ONLY(
311         CallStackEntry cse("LocalTrr2kKernel");
312         CheckInput( A, B, C, D, E );
313     )
314     const Grid& g = E.Grid();
315 
316     DistMatrix<T,MC,STAR> AT(g),  CT(g),
317                           AB(g),  CB(g);
318     DistMatrix<T,MR,STAR> BT(g), BB(g);
319     DistMatrix<T,STAR,MR> DL(g), DR(g);
320     DistMatrix<T> ETL(g), ETR(g),
321                   EBL(g), EBR(g);
322     DistMatrix<T> FTL(g), FBR(g);
323 
324     const Int half = E.Height()/2;
325     ScaleTrapezoid( beta, uplo, E );
326     LockedPartitionDown( A, AT, AB, half );
327     LockedPartitionDown( B, BT, BB, half );
328     LockedPartitionDown( C, CT, CB, half );
329     LockedPartitionRight( D, DL, DR, half );
330     PartitionDownDiagonal
331     ( E, ETL, ETR,
332          EBL, EBR, half );
333 
334     if( uplo == LOWER )
335     {
336         LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, T(1), EBL );
337         LocalGemm( NORMAL, NORMAL, alpha, CB, DL, T(1), EBL );
338     }
339     else
340     {
341         LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, T(1), ETR );
342         LocalGemm( NORMAL, NORMAL, alpha, CT, DR, T(1), ETR );
343     }
344 
345     FTL.AlignWith( ETL );
346     LocalGemm( NORMAL, orientationOfB, alpha, AT, BT, FTL );
347     LocalGemm( NORMAL, NORMAL, alpha, CT, DL, T(1), FTL );
348     AxpyTriangle( uplo, T(1), FTL, ETL );
349 
350     FBR.AlignWith( EBR );
351     LocalGemm( NORMAL, orientationOfB, alpha, AB, BB, FBR );
352     LocalGemm( NORMAL, NORMAL, alpha, CB, DR, T(1), FBR );
353     AxpyTriangle( uplo, T(1), FBR, EBR );
354 }
355 
356 // E := alpha (A B^{T/H} + C D^{T/H}) + beta C
357 template<typename T>
358 inline void
LocalTrr2kKernel(UpperOrLower uplo,Orientation orientationOfB,Orientation orientationOfD,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)359 LocalTrr2kKernel
360 ( UpperOrLower uplo, Orientation orientationOfB, Orientation orientationOfD,
361   T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
362            const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
363   T beta,        DistMatrix<T>& E )
364 {
365     DEBUG_ONLY(
366         CallStackEntry cse("LocalTrr2kKernel");
367         CheckInput( A, B, C, D, E );
368     )
369     const Grid& g = E.Grid();
370 
371     DistMatrix<T,MC,STAR> AT(g),  CT(g),
372                           AB(g),  CB(g);
373     DistMatrix<T,MR,STAR> BT(g),  DT(g),
374                           BB(g),  DB(g);
375     DistMatrix<T> ETL(g), ETR(g),
376                   EBL(g), EBR(g);
377     DistMatrix<T> FTL(g), FBR(g);
378 
379     const Int half = E.Height()/2;
380     ScaleTrapezoid( beta, uplo, E );
381     LockedPartitionDown( A, AT, AB, half );
382     LockedPartitionDown( B, BT, BB, half );
383     LockedPartitionDown( C, CT, CB, half );
384     LockedPartitionDown( D, DT, DB, half );
385     PartitionDownDiagonal
386     ( E, ETL, ETR,
387          EBL, EBR, half );
388 
389     if( uplo == LOWER )
390     {
391         LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, T(1), EBL );
392         LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, T(1), EBL );
393     }
394     else
395     {
396         LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, T(1), ETR );
397         LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, T(1), ETR );
398     }
399 
400     FTL.AlignWith( ETL );
401     LocalGemm( NORMAL, orientationOfB, alpha, AT, BT, FTL );
402     LocalGemm( NORMAL, orientationOfD, alpha, CT, DT, T(1), FTL );
403     AxpyTriangle( uplo, T(1), FTL, ETL );
404 
405     FBR.AlignWith( EBR );
406     LocalGemm( NORMAL, orientationOfB, alpha, AB, BB, FBR );
407     LocalGemm( NORMAL, orientationOfD, alpha, CB, DB, T(1), FBR );
408     AxpyTriangle( uplo, T(1), FBR, EBR );
409 }
410 
411 // E := alpha (A B^{T/H} + C^{T/H} D) + beta E
412 template<typename T>
413 inline void
LocalTrr2kKernel(UpperOrLower uplo,Orientation orientationOfB,Orientation orientationOfC,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)414 LocalTrr2kKernel
415 ( UpperOrLower uplo, Orientation orientationOfB, Orientation orientationOfC,
416   T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
417            const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
418   T beta,        DistMatrix<T>& E )
419 {
420     DEBUG_ONLY(
421         CallStackEntry cse("LocalTrr2kKernel");
422         CheckInput( A, B, C, D, E );
423     )
424     const Grid& g = E.Grid();
425 
426     DistMatrix<T,MC,STAR> AT(g), AB(g);
427     DistMatrix<T,MR,STAR> BT(g), BB(g);
428     DistMatrix<T,STAR,MC> CL(g), CR(g);
429     DistMatrix<T,STAR,MR> DL(g), DR(g);
430     DistMatrix<T> ETL(g), ETR(g),
431                   EBL(g), EBR(g);
432     DistMatrix<T> FTL(g), FBR(g);
433 
434     const Int half = E.Height()/2;
435     ScaleTrapezoid( beta, uplo, E );
436     LockedPartitionDown( A, AT, AB, half );
437     LockedPartitionDown( B, BT, BB, half );
438     LockedPartitionRight( C, CL, CR, half );
439     LockedPartitionRight( D, DL, DR, half );
440     PartitionDownDiagonal
441     ( E, ETL, ETR,
442          EBL, EBR, half );
443 
444     if( uplo == LOWER )
445     {
446         LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, T(1), EBL );
447         LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
448     }
449     else
450     {
451         LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, T(1), ETR );
452         LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
453     }
454 
455     FTL.AlignWith( ETL );
456     LocalGemm( NORMAL, orientationOfB, alpha, AT, BT, FTL );
457     LocalGemm( orientationOfC, NORMAL, alpha, CL, DL, T(1), FTL );
458     AxpyTriangle( uplo, T(1), FTL, ETL );
459 
460     FBR.AlignWith( EBR );
461     LocalGemm( NORMAL, orientationOfB, alpha, AB, BB, FBR );
462     LocalGemm( orientationOfC, NORMAL, alpha, CR, DR, T(1), FBR );
463     AxpyTriangle( uplo, T(1), FBR, EBR );
464 }
465 
466 // E := alpha (A B^{T/H} + C^{T/H} D^{T/H}) + beta C
467 template<typename T>
468 inline void
LocalTrr2kKernel(UpperOrLower uplo,Orientation orientationOfB,Orientation orientationOfC,Orientation orientationOfD,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)469 LocalTrr2kKernel
470 ( UpperOrLower uplo,
471   Orientation orientationOfB,
472   Orientation orientationOfC,
473   Orientation orientationOfD,
474   T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
475            const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
476   T beta,        DistMatrix<T>& E )
477 {
478     DEBUG_ONLY(
479         CallStackEntry cse("LocalTrr2kKernel");
480         CheckInput( A, B, C, D, E );
481     )
482     const Grid& g = E.Grid();
483 
484     DistMatrix<T,MC,STAR> AT(g), AB(g);
485     DistMatrix<T,MR,STAR> BT(g), BB(g);
486     DistMatrix<T,STAR,MC> CL(g), CR(g);
487     DistMatrix<T,MR,STAR> DT(g), DB(g);
488     DistMatrix<T> ETL(g), ETR(g),
489                   EBL(g), EBR(g);
490     DistMatrix<T> FTL(g), FBR(g);
491 
492     const Int half = E.Height()/2;
493     ScaleTrapezoid( beta, uplo, E );
494     LockedPartitionDown( A, AT, AB, half );
495     LockedPartitionDown( B, BT, BB, half );
496     LockedPartitionRight( C, CL, CR, half );
497     LockedPartitionDown( D, DT, DB, half );
498     PartitionDownDiagonal
499     ( E, ETL, ETR,
500          EBL, EBR, half );
501 
502     if( uplo == LOWER )
503     {
504         LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, T(1), EBL );
505         LocalGemm( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
506     }
507     else
508     {
509         LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, T(1), ETR );
510         LocalGemm( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
511     }
512 
513     FTL.AlignWith( ETL );
514     LocalGemm( NORMAL, orientationOfB, alpha, AT, BT, FTL );
515     LocalGemm( orientationOfC, orientationOfD, alpha, CL, DT, T(1), FTL );
516     AxpyTriangle( uplo, T(1), FTL, ETL );
517 
518     FBR.AlignWith( EBR );
519     LocalGemm( NORMAL, orientationOfB, alpha, AB, BB, FBR );
520     LocalGemm( orientationOfC, orientationOfD, alpha, CR, DB, T(1), FBR );
521     AxpyTriangle( uplo, T(1), FBR, EBR );
522 }
523 
524 // E := alpha (A^{T/H} B + C D) + beta E
525 template<typename T>
526 inline void
LocalTrr2kKernel(UpperOrLower uplo,Orientation orientationOfA,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)527 LocalTrr2kKernel
528 ( UpperOrLower uplo, Orientation orientationOfA,
529   T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
530            const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
531   T beta,        DistMatrix<T>& E )
532 {
533     DEBUG_ONLY(
534         CallStackEntry cse("LocalTrr2kKernel");
535         CheckInput( A, B, C, D, E );
536     )
537     const Grid& g = E.Grid();
538 
539     DistMatrix<T,STAR,MC> AL(g), AR(g);
540     DistMatrix<T,MC,STAR> CT(g), CB(g);
541     DistMatrix<T,STAR,MR> BL(g), BR(g),
542                           DL(g), DR(g);
543     DistMatrix<T> ETL(g), ETR(g),
544                   EBL(g), EBR(g);
545     DistMatrix<T> FTL(g), FBR(g);
546 
547     const Int half = E.Height()/2;
548     ScaleTrapezoid( beta, uplo, E );
549     LockedPartitionRight( A, AL, AR, half );
550     LockedPartitionRight( B, BL, BR, half );
551     LockedPartitionDown( C, CT, CB, half );
552     LockedPartitionRight( D, DL, DR, half );
553     PartitionDownDiagonal
554     ( E, ETL, ETR,
555          EBL, EBR, half );
556 
557     if( uplo == LOWER )
558     {
559         LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, T(1), EBL );
560         LocalGemm( NORMAL, NORMAL, alpha, CB, DL, T(1), EBL );
561     }
562     else
563     {
564         LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, T(1), ETR );
565         LocalGemm( NORMAL, NORMAL, alpha, CT, DR, T(1), ETR );
566     }
567 
568     FTL.AlignWith( ETL );
569     LocalGemm( orientationOfA, NORMAL, alpha, AL, BL, FTL );
570     LocalGemm( NORMAL, NORMAL, alpha, CT, DL, T(1), FTL );
571     AxpyTriangle( uplo, T(1), FTL, ETL );
572 
573     FBR.AlignWith( EBR );
574     LocalGemm( orientationOfA, NORMAL, alpha, AR, BR, FBR );
575     LocalGemm( NORMAL, NORMAL, alpha, CB, DR, T(1), FBR );
576     AxpyTriangle( uplo, T(1), FBR, EBR );
577 }
578 
579 // E := alpha (A^{T/H} B + C D^{T/H}) + beta E
580 template<typename T>
581 inline void
LocalTrr2kKernel(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfD,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)582 LocalTrr2kKernel
583 ( UpperOrLower uplo, Orientation orientationOfA, Orientation orientationOfD,
584   T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
585            const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
586   T beta,        DistMatrix<T>& E )
587 {
588     DEBUG_ONLY(
589         CallStackEntry cse("LocalTrr2kKernel");
590         CheckInput( A, B, C, D, E );
591     )
592     const Grid& g = E.Grid();
593 
594     DistMatrix<T,STAR,MC> AL(g), AR(g);
595     DistMatrix<T,STAR,MR> BL(g), BR(g);
596     DistMatrix<T,MC,STAR> CT(g), CB(g);
597     DistMatrix<T,MR,STAR> DT(g), DB(g);
598     DistMatrix<T> ETL(g), ETR(g),
599                   EBL(g), EBR(g);
600     DistMatrix<T> FTL(g), FBR(g);
601 
602     const Int half = E.Height()/2;
603     ScaleTrapezoid( beta, uplo, E );
604     LockedPartitionRight( A, AL, AR, half );
605     LockedPartitionRight( B, BL, BR, half );
606     LockedPartitionDown( C, CT, CB, half );
607     LockedPartitionDown( D, DT, DB, half );
608     PartitionDownDiagonal
609     ( E, ETL, ETR,
610          EBL, EBR, half );
611 
612     if( uplo == LOWER )
613     {
614         LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, T(1), EBL );
615         LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, T(1), EBL );
616     }
617     else
618     {
619         LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, T(1), ETR );
620         LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, T(1), ETR );
621     }
622 
623     FTL.AlignWith( ETL );
624     LocalGemm( orientationOfA, NORMAL, alpha, AL, BL, FTL );
625     LocalGemm( NORMAL, orientationOfD, alpha, CT, DT, T(1), FTL );
626     AxpyTriangle( uplo, T(1), FTL, ETL );
627 
628     FBR.AlignWith( EBR );
629     LocalGemm( orientationOfA, NORMAL, alpha, AR, BR, FBR );
630     LocalGemm( NORMAL, orientationOfD, alpha, CB, DB, T(1), FBR );
631     AxpyTriangle( uplo, T(1), FBR, EBR );
632 }
633 
634 // E := alpha (A^{T/H} B + C^{T/H} D) + beta E
635 template<typename T>
636 inline void
LocalTrr2kKernel(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfC,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)637 LocalTrr2kKernel
638 ( UpperOrLower uplo, Orientation orientationOfA, Orientation orientationOfC,
639   T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
640            const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
641   T beta,        DistMatrix<T>& E )
642 {
643     DEBUG_ONLY(
644         CallStackEntry cse("LocalTrr2kKernel");
645         CheckInput( A, B, C, D, E );
646     )
647     const Grid& g = E.Grid();
648 
649     DistMatrix<T,STAR,MC> AL(g), AR(g),
650                           CL(g), CR(g);
651     DistMatrix<T,STAR,MR> BL(g), BR(g),
652                           DL(g), DR(g);
653     DistMatrix<T> ETL(g), ETR(g),
654                   EBL(g), EBR(g);
655     DistMatrix<T> FTL(g), FBR(g);
656 
657     const Int half = E.Height()/2;
658     ScaleTrapezoid( beta, uplo, E );
659     LockedPartitionRight( A, AL, AR, half );
660     LockedPartitionRight( B, BL, BR, half );
661     LockedPartitionRight( C, CL, CR, half );
662     LockedPartitionRight( D, DL, DR, half );
663     PartitionDownDiagonal
664     ( E, ETL, ETR,
665          EBL, EBR, half );
666 
667     if( uplo == LOWER )
668     {
669         LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, T(1), EBL );
670         LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
671     }
672     else
673     {
674         LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, T(1), ETR );
675         LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
676     }
677 
678     FTL.AlignWith( ETL );
679     LocalGemm( orientationOfA, NORMAL, alpha, AL, BL, FTL );
680     LocalGemm( orientationOfC, NORMAL, alpha, CL, DL, T(1), FTL );
681     AxpyTriangle( uplo, T(1), FTL, ETL );
682 
683     FBR.AlignWith( EBR );
684     LocalGemm( orientationOfA, NORMAL, alpha, AR, BR, FBR );
685     LocalGemm( orientationOfC, NORMAL, alpha, CR, DR, T(1), FBR );
686     AxpyTriangle( uplo, T(1), FBR, EBR );
687 }
688 
689 // E := alpha (A^{T/H} B + C^{T/H} D^{T/H}) + beta E
690 template<typename T>
691 inline void
LocalTrr2kKernel(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfC,Orientation orientationOfD,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)692 LocalTrr2kKernel
693 ( UpperOrLower uplo,
694   Orientation orientationOfA,
695   Orientation orientationOfC,
696   Orientation orientationOfD,
697   T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
698            const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
699   T beta,        DistMatrix<T>& E )
700 {
701     DEBUG_ONLY(
702         CallStackEntry cse("LocalTrr2kKernel");
703         CheckInput( A, B, C, D, E );
704     )
705     const Grid& g = E.Grid();
706 
707     DistMatrix<T,STAR,MC> AL(g), AR(g),
708                           CL(g), CR(g);
709     DistMatrix<T,STAR,MR> BL(g), BR(g);
710     DistMatrix<T,MR,STAR> DT(g), DB(g);
711     DistMatrix<T> ETL(g), ETR(g),
712                   EBL(g), EBR(g);
713     DistMatrix<T> FTL(g), FBR(g);
714 
715     const Int half = E.Height()/2;
716     ScaleTrapezoid( beta, uplo, E );
717     LockedPartitionRight( A, AL, AR, half );
718     LockedPartitionRight( B, BL, BR, half );
719     LockedPartitionRight( C, CL, CR, half );
720     LockedPartitionDown( D, DT, DB, half );
721     PartitionDownDiagonal
722     ( E, ETL, ETR,
723          EBL, EBR, half );
724 
725     if( uplo == LOWER )
726     {
727         LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, T(1), EBL );
728         LocalGemm( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
729     }
730     else
731     {
732         LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, T(1), ETR );
733         LocalGemm( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
734     }
735 
736     FTL.AlignWith( ETL );
737     LocalGemm( orientationOfA, NORMAL, alpha, AL, BL, FTL );
738     LocalGemm( orientationOfC, orientationOfD, alpha, CL, DT, T(1), FTL );
739     AxpyTriangle( uplo, T(1), FTL, ETL );
740 
741     FBR.AlignWith( EBR );
742     LocalGemm( orientationOfA, NORMAL, alpha, AR, BR, FBR );
743     LocalGemm( orientationOfC, orientationOfD, alpha, CR, DB, T(1), FBR );
744     AxpyTriangle( uplo, T(1), FBR, EBR );
745 }
746 
747 // E := alpha (A^{T/H} B^{T/H} + C D) + beta E
// Unblocked kernel for op(A) = A^{T/H}, op(B) = B^{T/H}, C and D untransposed.
// Only the `uplo` triangle of E is updated; all work is local LocalGemm calls.
// Because op(A) is a transpose, the rows of E correspond to the COLUMNS of A,
// so A is split with LockedPartitionRight while C is split with
// LockedPartitionDown (and analogously for B vs. D).
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo, Orientation orientationOfA, Orientation orientationOfB,
  T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
           const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,STAR,MC> AL(g), AR(g);
    DistMatrix<T,MR,STAR> BT(g), BB(g);
    DistMatrix<T,MC,STAR> CT(g), CB(g);
    DistMatrix<T,STAR,MR> DL(g), DR(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    // Temporaries for the full (square) products on the diagonal blocks.
    DistMatrix<T> FTL(g), FBR(g);

    const Int half = E.Height()/2;
    // Apply beta once, up front; every update below accumulates with T(1).
    ScaleTrapezoid( beta, uplo, E );
    LockedPartitionRight( A, AL, AR, half );
    LockedPartitionDown( B, BT, BB, half );
    LockedPartitionDown( C, CT, CB, half );
    LockedPartitionRight( D, DL, DR, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // The off-diagonal block lying inside the `uplo` triangle is updated in
    // full, directly in place.
    if( uplo == LOWER )
    {
        LocalGemm( orientationOfA, orientationOfB, alpha, AR, BT, T(1), EBL );
        LocalGemm( NORMAL, NORMAL, alpha, CB, DL, T(1), EBL );
    }
    else
    {
        LocalGemm( orientationOfA, orientationOfB, alpha, AL, BB, T(1), ETR );
        LocalGemm( NORMAL, NORMAL, alpha, CT, DR, T(1), ETR );
    }

    // Diagonal blocks: form the whole product in a temporary aligned with the
    // target block, then fold only its `uplo` triangle into E via AxpyTriangle.
    FTL.AlignWith( ETL );
    LocalGemm( orientationOfA, orientationOfB, alpha, AL, BT, FTL );
    LocalGemm( NORMAL, NORMAL, alpha, CT, DL, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    FBR.AlignWith( EBR );
    LocalGemm( orientationOfA, orientationOfB, alpha, AR, BB, FBR );
    LocalGemm( NORMAL, NORMAL, alpha, CB, DR, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
801 
// E := alpha (A^{T/H} B^{T/H} + C D^{T/H}) + beta E
// Unblocked kernel for op(A) = A^{T/H}, op(B) = B^{T/H}, C untransposed and
// op(D) = D^{T/H}; only the `uplo` triangle of E is updated, with local
// LocalGemm calls only. Transposed operands whose rows map to E's rows (A) are
// split by columns; those whose rows map to E's columns (B, D) are split by rows.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo,
  Orientation orientationOfA,
  Orientation orientationOfB,
  Orientation orientationOfD,
  T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
           const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,STAR,MC> AL(g), AR(g);
    DistMatrix<T,MR,STAR> BT(g), BB(g);
    DistMatrix<T,MC,STAR> CT(g), CB(g);
    DistMatrix<T,MR,STAR> DT(g), DB(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    // Temporaries for the full products on the two diagonal blocks.
    DistMatrix<T> FTL(g), FBR(g);

    const Int half = E.Height()/2;
    // Apply beta once, up front; every update below accumulates with T(1).
    ScaleTrapezoid( beta, uplo, E );
    LockedPartitionRight( A, AL, AR, half );
    LockedPartitionDown( B, BT, BB, half );
    LockedPartitionDown( C, CT, CB, half );
    LockedPartitionDown( D, DT, DB, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal block inside the `uplo` triangle: update it in place.
    if( uplo == LOWER )
    {
        LocalGemm( orientationOfA, orientationOfB, alpha, AR, BT, T(1), EBL );
        LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, T(1), EBL );
    }
    else
    {
        LocalGemm( orientationOfA, orientationOfB, alpha, AL, BB, T(1), ETR );
        LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, T(1), ETR );
    }

    // Diagonal blocks: compute the full product in an aligned temporary and
    // add only its `uplo` triangle into E.
    FTL.AlignWith( ETL );
    LocalGemm( orientationOfA, orientationOfB, alpha, AL, BT, FTL );
    LocalGemm( NORMAL, orientationOfD, alpha, CT, DT, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    FBR.AlignWith( EBR );
    LocalGemm( orientationOfA, orientationOfB, alpha, AR, BB, FBR );
    LocalGemm( NORMAL, orientationOfD, alpha, CB, DB, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
859 
860 // E := alpha (A^{T/H} B^{T/H} + C^{T/H} D) + beta E
// Unblocked kernel for op(A) = A^{T/H}, op(B) = B^{T/H}, op(C) = C^{T/H} and
// D untransposed; only the `uplo` triangle of E is updated, with local
// LocalGemm calls only. A and C are [STAR,MC], so the rows of E correspond to
// their columns and both are split with LockedPartitionRight.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo,
  Orientation orientationOfA,
  Orientation orientationOfB,
  Orientation orientationOfC,
  T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
           const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,STAR,MC> AL(g), AR(g),
                          CL(g), CR(g);
    DistMatrix<T,MR,STAR> BT(g), BB(g);
    DistMatrix<T,STAR,MR> DL(g), DR(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    // Temporaries for the full products on the two diagonal blocks.
    DistMatrix<T> FTL(g), FBR(g);

    const Int half = E.Height()/2;
    // Apply beta once, up front; every update below accumulates with T(1).
    ScaleTrapezoid( beta, uplo, E );
    LockedPartitionRight( A, AL, AR, half );
    LockedPartitionDown( B, BT, BB, half );
    LockedPartitionRight( C, CL, CR, half );
    LockedPartitionRight( D, DL, DR, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal block inside the `uplo` triangle: update it in place.
    if( uplo == LOWER )
    {
        LocalGemm( orientationOfA, orientationOfB, alpha, AR, BT, T(1), EBL );
        LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
    }
    else
    {
        LocalGemm( orientationOfA, orientationOfB, alpha, AL, BB, T(1), ETR );
        LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
    }

    // Diagonal blocks: compute the full product in an aligned temporary and
    // add only its `uplo` triangle into E.
    FTL.AlignWith( ETL );
    LocalGemm( orientationOfA, orientationOfB, alpha, AL, BT, FTL );
    LocalGemm( orientationOfC, NORMAL, alpha, CL, DL, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    FBR.AlignWith( EBR );
    LocalGemm( orientationOfA, orientationOfB, alpha, AR, BB, FBR );
    LocalGemm( orientationOfC, NORMAL, alpha, CR, DR, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
917 
// E := alpha (A^{T/H} B^{T/H} + C^{T/H} D^{T/H}) + beta E
// Unblocked kernel where all four operands are used transposed/adjointed:
// op(A) = A^{T/H}, op(B) = B^{T/H}, op(C) = C^{T/H}, op(D) = D^{T/H}. Only the
// `uplo` triangle of E is updated, with local LocalGemm calls only.
// A and C ([STAR,MC]) are split by columns; B and D ([MR,STAR]) by rows.
template<typename T>
inline void
LocalTrr2kKernel
( UpperOrLower uplo,
  Orientation orientationOfA, Orientation orientationOfB,
  Orientation orientationOfC, Orientation orientationOfD,
  T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
           const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
  T beta,        DistMatrix<T>& E )
{
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2kKernel");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    DistMatrix<T,STAR,MC> AL(g), AR(g),
                          CL(g), CR(g);
    DistMatrix<T,MR,STAR> BT(g),  DT(g),
                          BB(g),  DB(g);
    DistMatrix<T> ETL(g), ETR(g),
                  EBL(g), EBR(g);
    // Temporaries for the full products on the two diagonal blocks.
    DistMatrix<T> FTL(g), FBR(g);

    const Int half = E.Height()/2;
    // Apply beta once, up front; every update below accumulates with T(1).
    ScaleTrapezoid( beta, uplo, E );
    LockedPartitionRight( A, AL, AR, half );
    LockedPartitionDown( B, BT, BB, half );
    LockedPartitionRight( C, CL, CR, half );
    LockedPartitionDown( D, DT, DB, half );
    PartitionDownDiagonal
    ( E, ETL, ETR,
         EBL, EBR, half );

    // Off-diagonal block inside the `uplo` triangle: update it in place.
    if( uplo == LOWER )
    {
        LocalGemm( orientationOfA, orientationOfB, alpha, AR, BT, T(1), EBL );
        LocalGemm( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
    }
    else
    {
        LocalGemm( orientationOfA, orientationOfB, alpha, AL, BB, T(1), ETR );
        LocalGemm( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
    }

    // Diagonal blocks: compute the full product in an aligned temporary and
    // add only its `uplo` triangle into E.
    FTL.AlignWith( ETL );
    LocalGemm( orientationOfA, orientationOfB, alpha, AL, BT, FTL );
    LocalGemm( orientationOfC, orientationOfD, alpha, CL, DT, T(1), FTL );
    AxpyTriangle( uplo, T(1), FTL, ETL );

    FBR.AlignWith( EBR );
    LocalGemm( orientationOfA, orientationOfB, alpha, AR, BB, FBR );
    LocalGemm( orientationOfC, orientationOfD, alpha, CR, DB, T(1), FBR );
    AxpyTriangle( uplo, T(1), FBR, EBR );
}
974 
975 } // namespace trr2k
976 
977 // E := alpha (A B + C D) + beta E
// Recursive driver. When E.Height() is below g.Width()*LocalTrr2kBlocksize<T>()
// the matching trr2k::LocalTrr2kKernel overload does all the work. Otherwise E
// is split at `half` into a 2x2 block: the off-diagonal block that lies in the
// `uplo` triangle is updated with two local GEMMs, and the two diagonal blocks
// are handled recursively (each recursive call applies its own beta).
template<typename T>
void LocalTrr2k
( UpperOrLower uplo,
  T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
           const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
  T beta,        DistMatrix<T>& E )
{
    using namespace trr2k;
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2k");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
    {
        LocalTrr2kKernel( uplo, alpha, A, B, C, D, beta, E );
    }
    else
    {
        // Split E in four roughly equal pieces, perform a large gemm on corner
        // and recurse on ETL and EBR.
        DistMatrix<T,MC,STAR> AT(g),  CT(g),
                              AB(g),  CB(g);
        DistMatrix<T,STAR,MR> BL(g), BR(g),
                              DL(g), DR(g);
        DistMatrix<T> ETL(g), ETR(g),
                      EBL(g), EBR(g);

        const Int half = E.Height() / 2;
        LockedPartitionDown( A, AT, AB, half );
        LockedPartitionRight( B, BL, BR, half );
        LockedPartitionDown( C, CT, CB, half );
        LockedPartitionRight( D, DL, DR, half );
        PartitionDownDiagonal
        ( E, ETL, ETR,
             EBL, EBR, half );

        // beta must be applied by the FIRST GEMM into the corner block only;
        // the second accumulates with T(1) so the original entries are scaled
        // exactly once.
        if( uplo == LOWER )
        {
            LocalGemm( NORMAL, NORMAL, alpha, AB, BL, beta, EBL );
            LocalGemm( NORMAL, NORMAL, alpha, CB, DL, T(1), EBL );
        }
        else
        {
            LocalGemm( NORMAL, NORMAL, alpha, AT, BR, beta, ETR );
            LocalGemm( NORMAL, NORMAL, alpha, CT, DR, T(1), ETR );
        }

        // Recurse
        LocalTrr2k( uplo, alpha, AT, BL, CT, DL, beta, ETL );
        LocalTrr2k( uplo, alpha, AB, BR, CB, DR, beta, EBR );
    }
}
1032 
1033 // E := alpha (A B + C D^{T/H}) + beta E
1034 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfD,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)1035 void LocalTrr2k
1036 ( UpperOrLower uplo, Orientation orientationOfD,
1037   T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
1038            const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
1039   T beta,        DistMatrix<T>& E  )
1040 {
1041     using namespace trr2k;
1042     DEBUG_ONLY(
1043         CallStackEntry cse("LocalTrr2k");
1044         CheckInput( A, B, C, D, E );
1045     )
1046     const Grid& g = E.Grid();
1047 
1048     if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1049     {
1050         LocalTrr2kKernel( uplo, orientationOfD, alpha, A, B, C, D, beta, E );
1051     }
1052     else
1053     {
1054         // Split E in four roughly equal pieces, perform a large gemm on corner
1055         // and recurse on ETL and EBR.
1056         DistMatrix<T,MC,STAR> AT(g),  CT(g),
1057                               AB(g),  CB(g);
1058         DistMatrix<T,STAR,MR> BL(g), BR(g);
1059         DistMatrix<T,MR,STAR> DT(g), DB(g);
1060         DistMatrix<T> ETL(g), ETR(g),
1061                       EBL(g), EBR(g);
1062 
1063         const Int half = E.Height() / 2;
1064         LockedPartitionDown( A, AT, AB, half );
1065         LockedPartitionRight( B, BL, BR, half );
1066         LockedPartitionDown( C, CT, CB, half );
1067         LockedPartitionDown( D, DT, DB, half );
1068         PartitionDownDiagonal
1069         ( E, ETL, ETR,
1070              EBL, EBR, half );
1071 
1072         if( uplo == LOWER )
1073         {
1074             LocalGemm( NORMAL, NORMAL, alpha, AB, BL, T(1), EBL );
1075             LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, beta, EBL );
1076         }
1077         else
1078         {
1079             LocalGemm( NORMAL, NORMAL, alpha, AT, BR, T(1), ETR );
1080             LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, beta, ETR );
1081         }
1082 
1083         // Recurse
1084         LocalTrr2k( uplo, orientationOfD, alpha, AT, BL, CT, DT, beta, ETL );
1085         LocalTrr2k( uplo, orientationOfD, alpha, AB, BR, CB, DB, beta, EBR );
1086     }
1087 }
1088 
1089 // E := alpha (A B + C^{T/H} D) + beta E
// Recursive driver for E := alpha (A B + C^{T/H} D) + beta E on the `uplo`
// triangle. C is [STAR,MC] and used transposed, so its columns map to E's rows
// and it is split with LockedPartitionRight. Small problems go to the matching
// trr2k::LocalTrr2kKernel overload; larger ones update the off-diagonal corner
// block with local GEMMs and recurse on ETL and EBR.
template<typename T>
void LocalTrr2k
( UpperOrLower uplo, Orientation orientationOfC,
  T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
           const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
  T beta,        DistMatrix<T>& E  )
{
    using namespace trr2k;
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2k");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
    {
        LocalTrr2kKernel( uplo, orientationOfC, alpha, A, B, C, D, beta, E );
    }
    else
    {
        // Split E in four roughly equal pieces, perform a large gemm on corner
        // and recurse on ETL and EBR.
        DistMatrix<T,MC,STAR> AT(g), AB(g);
        DistMatrix<T,STAR,MR> BL(g), BR(g),
                              DL(g), DR(g);
        DistMatrix<T,STAR,MC> CL(g), CR(g);
        DistMatrix<T> ETL(g), ETR(g),
                      EBL(g), EBR(g);

        const Int half = E.Height() / 2;
        LockedPartitionDown( A, AT, AB, half );
        LockedPartitionRight( B, BL, BR, half );
        LockedPartitionRight( C, CL, CR, half );
        LockedPartitionRight( D, DL, DR, half );
        PartitionDownDiagonal
        ( E, ETL, ETR,
             EBL, EBR, half );

        // beta goes on the first GEMM only, so the corner block's original
        // entries are scaled exactly once.
        if( uplo == LOWER )
        {
            LocalGemm( NORMAL, NORMAL, alpha, AB, BL, beta, EBL );
            LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
        }
        else
        {
            LocalGemm( NORMAL, NORMAL, alpha, AT, BR, beta, ETR );
            LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
        }

        // Recurse
        LocalTrr2k( uplo, orientationOfC, alpha, AT, BL, CL, DL, beta, ETL );
        LocalTrr2k( uplo, orientationOfC, alpha, AB, BR, CR, DR, beta, EBR );
    }
}
1144 
1145 // E := alpha (A B + C^{T/H} D^{T/H}) + beta E
// Recursive driver for E := alpha (A B + C^{T/H} D^{T/H}) + beta E on the
// `uplo` triangle. Small problems go to the matching trr2k::LocalTrr2kKernel
// overload; larger ones update the off-diagonal corner block with local GEMMs
// and recurse on the two diagonal blocks.
template<typename T>
void LocalTrr2k
( UpperOrLower uplo, Orientation orientationOfC, Orientation orientationOfD,
  T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,STAR,MR>& B,
           const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
  T beta,        DistMatrix<T>& E  )
{
    using namespace trr2k;
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2k");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
    {
        LocalTrr2kKernel
        ( uplo, orientationOfC, orientationOfD, alpha, A, B, C, D, beta, E );
    }
    else
    {
        // Split E in four roughly equal pieces, perform a large gemm on corner
        // and recurse on ETL and EBR.
        DistMatrix<T,MC,STAR> AT(g), AB(g);
        DistMatrix<T,STAR,MR> BL(g), BR(g);
        DistMatrix<T,STAR,MC> CL(g), CR(g);
        DistMatrix<T,MR,STAR> DT(g), DB(g);
        DistMatrix<T> ETL(g), ETR(g),
                      EBL(g), EBR(g);

        const Int half = E.Height() / 2;
        LockedPartitionDown( A, AT, AB, half );
        LockedPartitionRight( B, BL, BR, half );
        LockedPartitionRight( C, CL, CR, half );
        LockedPartitionDown( D, DT, DB, half );
        PartitionDownDiagonal
        ( E, ETL, ETR,
             EBL, EBR, half );

        // beta goes on the first GEMM only, so the corner block's original
        // entries are scaled exactly once.
        if( uplo == LOWER )
        {
            LocalGemm( NORMAL, NORMAL, alpha, AB, BL, beta, EBL );
            LocalGemm
            ( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
        }
        else
        {
            LocalGemm( NORMAL, NORMAL, alpha, AT, BR, beta, ETR );
            LocalGemm
            ( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
        }

        // Recurse
        LocalTrr2k
        ( uplo, orientationOfC, orientationOfD,
          alpha, AT, BL, CL, DT, beta, ETL );
        LocalTrr2k
        ( uplo, orientationOfC, orientationOfD,
          alpha, AB, BR, CR, DB, beta, EBR );
    }
}
1207 
1208 // E := alpha (A B^{T/H} + C D) + beta E
1209 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfB,T alpha,const DistMatrix<T,MC,STAR> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)1210 void LocalTrr2k
1211 ( UpperOrLower uplo, Orientation orientationOfB,
1212   T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
1213            const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
1214   T beta,        DistMatrix<T>& E  )
1215 {
1216     using namespace trr2k;
1217     DEBUG_ONLY(
1218         CallStackEntry cse("LocalTrr2k");
1219         CheckInput( A, B, C, D, E );
1220     )
1221     const Grid& g = E.Grid();
1222 
1223     if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1224     {
1225         LocalTrr2kKernel( uplo, orientationOfB, alpha, A, B, C, D, beta, E );
1226     }
1227     else
1228     {
1229         // Split E in four roughly equal pieces, perform a large gemm on corner
1230         // and recurse on ETL and EBR.
1231         DistMatrix<T,MC,STAR> AT(g),  CT(g),
1232                               AB(g),  CB(g);
1233         DistMatrix<T,MR,STAR> BT(g), BB(g);
1234         DistMatrix<T,STAR,MR> DL(g), DR(g);
1235         DistMatrix<T> ETL(g), ETR(g),
1236                       EBL(g), EBR(g);
1237 
1238         const Int half = E.Height() / 2;
1239         LockedPartitionDown( A, AT, AB, half );
1240         LockedPartitionDown( B, BT, BB, half );
1241         LockedPartitionDown( C, CT, CB, half );
1242         LockedPartitionRight( D, DL, DR, half );
1243         PartitionDownDiagonal
1244         ( E, ETL, ETR,
1245              EBL, EBR, half );
1246 
1247         if( uplo == LOWER )
1248         {
1249             LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, T(1), EBL );
1250             LocalGemm( NORMAL, NORMAL, alpha, CB, DL, beta, EBL );
1251         }
1252         else
1253         {
1254             LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, T(1), ETR );
1255             LocalGemm( NORMAL, NORMAL, alpha, CT, DR, beta, ETR );
1256         }
1257 
1258         // Recurse
1259         LocalTrr2k( uplo, orientationOfB, alpha, AT, BT, CT, DL, beta, ETL );
1260         LocalTrr2k( uplo, orientationOfB, alpha, AB, BB, CB, DR, beta, EBR );
1261     }
1262 }
1263 
1264 // E := alpha (A B^{T/H} + C D^{T/H}) + beta E
// Recursive driver for E := alpha (A B^{T/H} + C D^{T/H}) + beta E on the
// `uplo` triangle. All four operands are split by rows at `half`. Small
// problems go to the matching trr2k::LocalTrr2kKernel overload; larger ones
// update the off-diagonal corner block with local GEMMs and recurse on the
// two diagonal blocks.
template<typename T>
void LocalTrr2k
( UpperOrLower uplo, Orientation orientationOfB, Orientation orientationOfD,
  T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
           const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
  T beta,        DistMatrix<T>& E )
{
    using namespace trr2k;
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2k");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
    {
        LocalTrr2kKernel
        ( uplo, orientationOfB, orientationOfD, alpha, A, B, C, D, beta, E );
    }
    else
    {
        // Split E in four roughly equal pieces, perform a large gemm on corner
        // and recurse on ETL and EBR.
        DistMatrix<T,MC,STAR> AT(g),  CT(g),
                              AB(g),  CB(g);
        DistMatrix<T,MR,STAR> BT(g),  DT(g),
                              BB(g),  DB(g);
        DistMatrix<T> ETL(g), ETR(g),
                      EBL(g), EBR(g);

        const Int half = E.Height() / 2;
        LockedPartitionDown( A, AT, AB, half );
        LockedPartitionDown( B, BT, BB, half );
        LockedPartitionDown( C, CT, CB, half );
        LockedPartitionDown( D, DT, DB, half );
        PartitionDownDiagonal
        ( E, ETL, ETR,
             EBL, EBR, half );

        // beta goes on the first GEMM only, so the corner block's original
        // entries are scaled exactly once.
        if( uplo == LOWER )
        {
            LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, beta, EBL );
            LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, T(1), EBL );
        }
        else
        {
            LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, beta, ETR );
            LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, T(1), ETR );
        }

        // Recurse
        LocalTrr2k
        ( uplo, orientationOfB, orientationOfD,
          alpha, AT, BT, CT, DT, beta, ETL );
        LocalTrr2k
        ( uplo, orientationOfB, orientationOfD,
          alpha, AB, BB, CB, DB, beta, EBR );
    }
}
1324 
1325 // E := alpha (A B^{T/H} + C^{T/H} D) + beta E
// Recursive driver for E := alpha (A B^{T/H} + C^{T/H} D) + beta E on the
// `uplo` triangle. Small problems go to the matching trr2k::LocalTrr2kKernel
// overload; larger ones update the off-diagonal corner block with local GEMMs
// and recurse on the two diagonal blocks.
template<typename T>
void LocalTrr2k
( UpperOrLower uplo, Orientation orientationOfB, Orientation orientationOfC,
  T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
           const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
  T beta,        DistMatrix<T>& E  )
{
    using namespace trr2k;
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2k");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
    {
        LocalTrr2kKernel
        ( uplo, orientationOfB, orientationOfC, alpha, A, B, C, D, beta, E );
    }
    else
    {
        // Split E in four roughly equal pieces, perform a large gemm on corner
        // and recurse on ETL and EBR.
        DistMatrix<T,MC,STAR> AT(g), AB(g);
        DistMatrix<T,MR,STAR> BT(g), BB(g);
        DistMatrix<T,STAR,MC> CL(g), CR(g);
        DistMatrix<T,STAR,MR> DL(g), DR(g);
        DistMatrix<T> ETL(g), ETR(g),
                      EBL(g), EBR(g);

        const Int half = E.Height() / 2;
        LockedPartitionDown( A, AT, AB, half );
        LockedPartitionDown( B, BT, BB, half );
        LockedPartitionRight( C, CL, CR, half );
        LockedPartitionRight( D, DL, DR, half );
        PartitionDownDiagonal
        ( E, ETL, ETR,
             EBL, EBR, half );

        // beta goes on the first GEMM only, so the corner block's original
        // entries are scaled exactly once.
        if( uplo == LOWER )
        {
            LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, beta, EBL );
            LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
        }
        else
        {
            LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, beta, ETR );
            LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
        }

        // Recurse
        LocalTrr2k
        ( uplo, orientationOfB, orientationOfC,
          alpha, AT, BT, CL, DL, beta, ETL );
        LocalTrr2k
        ( uplo, orientationOfB, orientationOfC,
          alpha, AB, BB, CR, DR, beta, EBR );
    }
}
1385 
1386 // E := alpha (A B^{T/H} + C^{T/H} D^{T/H}) + beta E
// Recursive driver for E := alpha (A B^{T/H} + C^{T/H} D^{T/H}) + beta E on
// the `uplo` triangle. Small problems go to the matching
// trr2k::LocalTrr2kKernel overload; larger ones update the off-diagonal corner
// block with local GEMMs and recurse on the two diagonal blocks.
template<typename T>
void LocalTrr2k
( UpperOrLower uplo,
  Orientation orientationOfB,
  Orientation orientationOfC,
  Orientation orientationOfD,
  T alpha, const DistMatrix<T,MC,STAR>& A, const DistMatrix<T,MR,STAR>& B,
           const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
  T beta,        DistMatrix<T>& E  )
{
    using namespace trr2k;
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2k");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
    {
        LocalTrr2kKernel
        ( uplo, orientationOfB, orientationOfC, orientationOfD,
          alpha, A, B, C, D, beta, E );
    }
    else
    {
        // Split E in four roughly equal pieces, perform a large gemm on corner
        // and recurse on ETL and EBR.
        DistMatrix<T,MC,STAR> AT(g), AB(g);
        DistMatrix<T,MR,STAR> BT(g),  DT(g),
                              BB(g),  DB(g);
        DistMatrix<T,STAR,MC> CL(g), CR(g);
        DistMatrix<T> ETL(g), ETR(g),
                      EBL(g), EBR(g);

        const Int half = E.Height() / 2;
        LockedPartitionDown( A, AT, AB, half );
        LockedPartitionDown( B, BT, BB, half );
        LockedPartitionRight( C, CL, CR, half );
        LockedPartitionDown( D, DT, DB, half );
        PartitionDownDiagonal
        ( E, ETL, ETR,
             EBL, EBR, half );

        // beta goes on the first GEMM only, so the corner block's original
        // entries are scaled exactly once.
        if( uplo == LOWER )
        {
            LocalGemm( NORMAL, orientationOfB, alpha, AB, BT, beta, EBL );
            LocalGemm
            ( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
        }
        else
        {
            LocalGemm( NORMAL, orientationOfB, alpha, AT, BB, beta, ETR );
            LocalGemm
            ( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
        }

        // Recurse
        LocalTrr2k
        ( uplo, orientationOfB, orientationOfC, orientationOfD,
          alpha, AT, BT, CL, DT, beta, ETL );
        LocalTrr2k
        ( uplo, orientationOfB, orientationOfC, orientationOfD,
          alpha, AB, BB, CR, DB, beta, EBR );
    }
}
1452 
1453 // E := alpha (A^{T/H} B + C D) + beta E
// Recursive driver for E := alpha (A^{T/H} B + C D) + beta E on the `uplo`
// triangle. A is [STAR,MC] and used transposed, so its columns map to E's rows
// and it is split with LockedPartitionRight. Small problems go to the matching
// trr2k::LocalTrr2kKernel overload; larger ones update the off-diagonal corner
// block with local GEMMs and recurse on the two diagonal blocks.
template<typename T>
void LocalTrr2k
( UpperOrLower uplo, Orientation orientationOfA,
  T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
           const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
  T beta,        DistMatrix<T>& E  )
{
    using namespace trr2k;
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2k");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
    {
        LocalTrr2kKernel( uplo, orientationOfA, alpha, A, B, C, D, beta, E );
    }
    else
    {
        // Split E in four roughly equal pieces, perform a large gemm on corner
        // and recurse on ETL and EBR.
        DistMatrix<T,STAR,MC> AL(g), AR(g);
        DistMatrix<T,STAR,MR> BL(g), BR(g),
                              DL(g), DR(g);
        DistMatrix<T,MC,STAR> CT(g), CB(g);
        DistMatrix<T> ETL(g), ETR(g),
                      EBL(g), EBR(g);

        const Int half = E.Height() / 2;
        LockedPartitionRight( A, AL, AR, half );
        LockedPartitionRight( B, BL, BR, half );
        LockedPartitionDown( C, CT, CB, half );
        LockedPartitionRight( D, DL, DR, half );
        PartitionDownDiagonal
        ( E, ETL, ETR,
             EBL, EBR, half );

        // beta goes on the first GEMM only, so the corner block's original
        // entries are scaled exactly once.
        if( uplo == LOWER )
        {
            LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, beta, EBL );
            LocalGemm( NORMAL, NORMAL, alpha, CB, DL, T(1), EBL );
        }
        else
        {
            LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, beta, ETR );
            LocalGemm( NORMAL, NORMAL, alpha, CT, DR, T(1), ETR );
        }

        // Recurse
        LocalTrr2k( uplo, orientationOfA, alpha, AL, BL, CT, DL, beta, ETL );
        LocalTrr2k( uplo, orientationOfA, alpha, AR, BR, CB, DR, beta, EBR );
    }
}
1508 
1509 // E := alpha (A^{T/H} B + C D^{T/H}) + beta E
// Recursive driver for E := alpha (A^{T/H} B + C D^{T/H}) + beta E on the
// `uplo` triangle. Small problems go to the matching trr2k::LocalTrr2kKernel
// overload; larger ones update the off-diagonal corner block with local GEMMs
// and recurse on the two diagonal blocks.
template<typename T>
void LocalTrr2k
( UpperOrLower uplo, Orientation orientationOfA, Orientation orientationOfD,
  T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
           const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
  T beta,        DistMatrix<T>& E  )
{
    using namespace trr2k;
    DEBUG_ONLY(
        CallStackEntry cse("LocalTrr2k");
        CheckInput( A, B, C, D, E );
    )
    const Grid& g = E.Grid();

    if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
    {
        LocalTrr2kKernel
        ( uplo, orientationOfA, orientationOfD, alpha, A, B, C, D, beta, E );
    }
    else
    {
        // Split E in four roughly equal pieces, perform a large gemm on corner
        // and recurse on ETL and EBR.
        DistMatrix<T,STAR,MC> AL(g), AR(g);
        DistMatrix<T,STAR,MR> BL(g), BR(g);
        DistMatrix<T,MC,STAR> CT(g), CB(g);
        DistMatrix<T,MR,STAR> DT(g), DB(g);
        DistMatrix<T> ETL(g), ETR(g),
                      EBL(g), EBR(g);

        const Int half = E.Height() / 2;
        LockedPartitionRight( A, AL, AR, half );
        LockedPartitionRight( B, BL, BR, half );
        LockedPartitionDown( C, CT, CB, half );
        LockedPartitionDown( D, DT, DB, half );
        PartitionDownDiagonal
        ( E, ETL, ETR,
             EBL, EBR, half );

        // beta goes on the first GEMM only, so the corner block's original
        // entries are scaled exactly once.
        if( uplo == LOWER )
        {
            LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, beta, EBL );
            LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, T(1), EBL );
        }
        else
        {
            LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, beta, ETR );
            LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, T(1), ETR );
        }

        // Recurse
        LocalTrr2k
        ( uplo, orientationOfA, orientationOfD,
          alpha, AL, BL, CT, DT, beta, ETL );
        LocalTrr2k
        ( uplo, orientationOfA, orientationOfD,
          alpha, AR, BR, CB, DB, beta, EBR );
    }
}
1569 
1570 // E := alpha (A^{T/H} B + C^{T/H} D) + beta E
1571 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfC,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)1572 void LocalTrr2k
1573 ( UpperOrLower uplo, Orientation orientationOfA, Orientation orientationOfC,
1574   T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
1575            const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
1576   T beta,        DistMatrix<T>& E  )
1577 {
1578     using namespace trr2k;
1579     DEBUG_ONLY(
1580         CallStackEntry cse("LocalTrr2k");
1581         CheckInput( A, B, C, D, E );
1582     )
1583     const Grid& g = E.Grid();
1584 
1585     if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1586     {
1587         LocalTrr2kKernel
1588         ( uplo, orientationOfA, orientationOfC, alpha, A, B, C, D, beta, E );
1589     }
1590     else
1591     {
1592         // Split E in four roughly equal pieces, perform a large gemm on corner
1593         // and recurse on ETL and EBR.
1594         DistMatrix<T,STAR,MC> AL(g), AR(g),
1595                               CL(g), CR(g);
1596         DistMatrix<T,STAR,MR> BL(g), BR(g),
1597                               DL(g), DR(g);
1598         DistMatrix<T> ETL(g), ETR(g),
1599                       EBL(g), EBR(g);
1600 
1601         const Int half = E.Height() / 2;
1602         LockedPartitionRight( A, AL, AR, half );
1603         LockedPartitionRight( B, BL, BR, half );
1604         LockedPartitionRight( C, CL, CR, half );
1605         LockedPartitionRight( D, DL, DR, half );
1606         PartitionDownDiagonal
1607         ( E, ETL, ETR,
1608              EBL, EBR, half );
1609 
1610         if( uplo == LOWER )
1611         {
1612             LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, beta, EBL );
1613             LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
1614         }
1615         else
1616         {
1617             LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, beta, ETR );
1618             LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
1619         }
1620 
1621         // Recurse
1622         LocalTrr2k
1623         ( uplo, orientationOfA, orientationOfC,
1624           alpha, AL, BL, CL, DL, beta, ETL );
1625         LocalTrr2k
1626         ( uplo, orientationOfA, orientationOfC,
1627           alpha, AR, BR, CR, DR, beta, EBR );
1628     }
1629 }
1630 
1631 // E := alpha (A^{T/H} B + C^{T/H} D^{T/H}) + beta E
1632 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfC,Orientation orientationOfD,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,STAR,MR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)1633 void LocalTrr2k
1634 ( UpperOrLower uplo,
1635   Orientation orientationOfA,
1636   Orientation orientationOfC,
1637   Orientation orientationOfD,
1638   T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,STAR,MR>& B,
1639            const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
1640   T beta,        DistMatrix<T>& E  )
1641 {
1642     using namespace trr2k;
1643     DEBUG_ONLY(
1644         CallStackEntry cse("LocalTrr2k");
1645         CheckInput( A, B, C, D, E );
1646     )
1647     const Grid& g = E.Grid();
1648 
1649     if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1650     {
1651         LocalTrr2kKernel
1652         ( uplo, orientationOfA, orientationOfC, orientationOfD,
1653           alpha, A, B, C, D, beta, E );
1654     }
1655     else
1656     {
1657         // Split E in four roughly equal pieces, perform a large gemm on corner
1658         // and recurse on ETL and EBR.
1659         DistMatrix<T,STAR,MC> AL(g), AR(g),
1660                               CL(g), CR(g);
1661         DistMatrix<T,STAR,MR> BL(g), BR(g);
1662         DistMatrix<T,MR,STAR> DT(g), DB(g);
1663         DistMatrix<T> ETL(g), ETR(g),
1664                       EBL(g), EBR(g);
1665 
1666         const Int half = E.Height() / 2;
1667         LockedPartitionRight( A, AL, AR, half );
1668         LockedPartitionRight( B, BL, BR, half );
1669         LockedPartitionRight( C, CL, CR, half );
1670         LockedPartitionDown( D, DT, DB, half );
1671         PartitionDownDiagonal
1672         ( E, ETL, ETR,
1673              EBL, EBR, half );
1674 
1675         if( uplo == LOWER )
1676         {
1677             LocalGemm( orientationOfA, NORMAL, alpha, AR, BL, beta, EBL );
1678             LocalGemm
1679             ( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
1680         }
1681         else
1682         {
1683             LocalGemm( orientationOfA, NORMAL, alpha, AL, BR, beta, ETR );
1684             LocalGemm
1685             ( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
1686         }
1687 
1688         // Recurse
1689         LocalTrr2k
1690         ( uplo, orientationOfA, orientationOfC, orientationOfD,
1691           alpha, AL, BL, CL, DT, beta, ETL );
1692         LocalTrr2k
1693         ( uplo, orientationOfA, orientationOfC, orientationOfD,
1694           alpha, AR, BR, CR, DB, beta, EBR );
1695     }
1696 }
1697 
1698 // E := alpha (A^{T/H} B^{T/H} + C D) + beta E
1699 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfB,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)1700 void LocalTrr2k
1701 ( UpperOrLower uplo, Orientation orientationOfA, Orientation orientationOfB,
1702   T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
1703            const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,STAR,MR>& D,
1704   T beta,        DistMatrix<T>& E  )
1705 {
1706     using namespace trr2k;
1707     DEBUG_ONLY(
1708         CallStackEntry cse("LocalTrr2k");
1709         CheckInput( A, B, C, D, E );
1710     )
1711     const Grid& g = E.Grid();
1712 
1713     if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1714     {
1715         LocalTrr2kKernel
1716         ( uplo, orientationOfA, orientationOfB, alpha, A, B, C, D, beta, E );
1717     }
1718     else
1719     {
1720         // Split E in four roughly equal pieces, perform a large gemm on corner
1721         // and recurse on ETL and EBR.
1722         DistMatrix<T,STAR,MC> AL(g), AR(g);
1723         DistMatrix<T,MR,STAR> BT(g), BB(g);
1724         DistMatrix<T,MC,STAR> CT(g), CB(g);
1725         DistMatrix<T,STAR,MR> DL(g), DR(g);
1726         DistMatrix<T> ETL(g), ETR(g),
1727                       EBL(g), EBR(g);
1728 
1729         const Int half = E.Height() / 2;
1730         LockedPartitionRight( A, AL, AR, half );
1731         LockedPartitionDown( B, BT, BB, half );
1732         LockedPartitionDown( C, CT, CB, half );
1733         LockedPartitionRight( D, DL, DR, half );
1734         PartitionDownDiagonal
1735         ( E, ETL, ETR,
1736              EBL, EBR, half );
1737 
1738         if( uplo == LOWER )
1739         {
1740             LocalGemm
1741             ( orientationOfA, orientationOfB, alpha, AR, BT, beta, EBL );
1742             LocalGemm( NORMAL, NORMAL, alpha, CB, DL, T(1), EBL );
1743         }
1744         else
1745         {
1746             LocalGemm
1747             ( orientationOfA, orientationOfB, alpha, AL, BB, beta, ETR );
1748             LocalGemm( NORMAL, NORMAL, alpha, CT, DR, T(1), ETR );
1749         }
1750 
1751         // Recurse
1752         LocalTrr2k
1753         ( uplo, orientationOfA, orientationOfB,
1754           alpha, AL, BT, CT, DL, beta, ETL );
1755         LocalTrr2k
1756         ( uplo, orientationOfA, orientationOfB,
1757           alpha, AR, BB, CB, DR, beta, EBR );
1758     }
1759 }
1760 
1761 // E := alpha (A^{T/H} B^{T/H} + C D^{T/H}) + beta E
1762 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfB,Orientation orientationOfD,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,MC,STAR> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)1763 void LocalTrr2k
1764 ( UpperOrLower uplo,
1765   Orientation orientationOfA,
1766   Orientation orientationOfB,
1767   Orientation orientationOfD,
1768   T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
1769            const DistMatrix<T,MC,STAR>& C, const DistMatrix<T,MR,STAR>& D,
1770   T beta,        DistMatrix<T>& E  )
1771 {
1772     using namespace trr2k;
1773     DEBUG_ONLY(
1774         CallStackEntry cse("LocalTrr2k");
1775         CheckInput( A, B, C, D, E );
1776     )
1777     const Grid& g = E.Grid();
1778 
1779     if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1780     {
1781         LocalTrr2kKernel
1782         ( uplo, orientationOfA, orientationOfB, orientationOfD,
1783           alpha, A, B, C, D, beta, E );
1784     }
1785     else
1786     {
1787         // Split E in four roughly equal pieces, perform a large gemm on corner
1788         // and recurse on ETL and EBR.
1789         DistMatrix<T,STAR,MC> AL(g), AR(g);
1790         DistMatrix<T,MR,STAR> BT(g),  DT(g),
1791                               BB(g),  DB(g);
1792         DistMatrix<T,MC,STAR> CT(g), CB(g);
1793         DistMatrix<T> ETL(g), ETR(g),
1794                       EBL(g), EBR(g);
1795 
1796         const Int half = E.Height() / 2;
1797         LockedPartitionRight( A, AL, AR, half );
1798         LockedPartitionDown( B, BT, BB, half );
1799         LockedPartitionDown( C, CT, CB, half );
1800         LockedPartitionDown( D, DT, DB, half );
1801         PartitionDownDiagonal
1802         ( E, ETL, ETR,
1803              EBL, EBR, half );
1804 
1805         if( uplo == LOWER )
1806         {
1807             LocalGemm
1808             ( orientationOfA, orientationOfB, alpha, AR, BT, beta, EBL );
1809             LocalGemm( NORMAL, orientationOfD, alpha, CB, DT, T(1), EBL );
1810         }
1811         else
1812         {
1813             LocalGemm
1814             ( orientationOfA, orientationOfB, alpha, AL, BB, beta, ETR );
1815             LocalGemm( NORMAL, orientationOfD, alpha, CT, DB, T(1), ETR );
1816         }
1817 
1818         // Recurse
1819         LocalTrr2k
1820         ( uplo, orientationOfA, orientationOfB, orientationOfD,
1821           alpha, AL, BT, CT, DT, beta, ETL );
1822         LocalTrr2k
1823         ( uplo, orientationOfA, orientationOfB, orientationOfD,
1824           alpha, AR, BB, CB, DB, beta, EBR );
1825     }
1826 }
1827 
1828 // E := alpha (A^{T/H} B^{T/H} + C^{T/H} D) + beta E
1829 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfB,Orientation orientationOfC,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,STAR,MR> & D,T beta,DistMatrix<T> & E)1830 void LocalTrr2k
1831 ( UpperOrLower uplo,
1832   Orientation orientationOfA,
1833   Orientation orientationOfB,
1834   Orientation orientationOfC,
1835   T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
1836            const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,STAR,MR>& D,
1837   T beta,        DistMatrix<T>& E  )
1838 {
1839     using namespace trr2k;
1840     DEBUG_ONLY(
1841         CallStackEntry cse("LocalTrr2k");
1842         CheckInput( A, B, C, D, E );
1843     )
1844     const Grid& g = E.Grid();
1845 
1846     if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1847     {
1848         LocalTrr2kKernel
1849         ( uplo, orientationOfA, orientationOfB, orientationOfC,
1850           alpha, A, B, C, D, beta, E );
1851     }
1852     else
1853     {
1854         // Split E in four roughly equal pieces, perform a large gemm on corner
1855         // and recurse on ETL and EBR.
1856         DistMatrix<T,STAR,MC> AL(g), AR(g),
1857                               CL(g), CR(g);
1858         DistMatrix<T,MR,STAR> BT(g), BB(g);
1859         DistMatrix<T,STAR,MR> DL(g), DR(g);
1860         DistMatrix<T> ETL(g), ETR(g),
1861                       EBL(g), EBR(g);
1862 
1863         const Int half = E.Height() / 2;
1864         LockedPartitionRight( A, AL, AR, half );
1865         LockedPartitionDown( B, BT, BB, half );
1866         LockedPartitionRight( C, CL, CR, half );
1867         LockedPartitionRight( D, DL, DR, half );
1868         PartitionDownDiagonal
1869         ( E, ETL, ETR,
1870              EBL, EBR, half );
1871 
1872         if( uplo == LOWER )
1873         {
1874             LocalGemm
1875             ( orientationOfA, orientationOfB, alpha, AR, BT, beta, EBL );
1876             LocalGemm( orientationOfC, NORMAL, alpha, CR, DL, T(1), EBL );
1877         }
1878         else
1879         {
1880             LocalGemm
1881             ( orientationOfA, orientationOfB, alpha, AL, BB, beta, ETR );
1882             LocalGemm( orientationOfC, NORMAL, alpha, CL, DR, T(1), ETR );
1883         }
1884 
1885         // Recurse
1886         LocalTrr2k
1887         ( uplo, orientationOfA, orientationOfB, orientationOfC,
1888           alpha, AL, BT, CL, DL, beta, ETL );
1889         LocalTrr2k
1890         ( uplo, orientationOfA, orientationOfB, orientationOfC,
1891           alpha, AR, BB, CR, DR, beta, EBR );
1892     }
1893 }
1894 
1895 // E := alpha (A^{T/H} B^{T/H} + C^{T/H} D^{T/H}) + beta E
1896 template<typename T>
LocalTrr2k(UpperOrLower uplo,Orientation orientationOfA,Orientation orientationOfB,Orientation orientationOfC,Orientation orientationOfD,T alpha,const DistMatrix<T,STAR,MC> & A,const DistMatrix<T,MR,STAR> & B,const DistMatrix<T,STAR,MC> & C,const DistMatrix<T,MR,STAR> & D,T beta,DistMatrix<T> & E)1897 void LocalTrr2k
1898 ( UpperOrLower uplo,
1899   Orientation orientationOfA,
1900   Orientation orientationOfB,
1901   Orientation orientationOfC,
1902   Orientation orientationOfD,
1903   T alpha, const DistMatrix<T,STAR,MC>& A, const DistMatrix<T,MR,STAR>& B,
1904            const DistMatrix<T,STAR,MC>& C, const DistMatrix<T,MR,STAR>& D,
1905   T beta,        DistMatrix<T>& E  )
1906 {
1907     using namespace trr2k;
1908     DEBUG_ONLY(
1909         CallStackEntry cse("LocalTrr2k");
1910         CheckInput( A, B, C, D, E );
1911     )
1912     const Grid& g = E.Grid();
1913 
1914     if( E.Height() < g.Width()*LocalTrr2kBlocksize<T>() )
1915     {
1916         LocalTrr2kKernel
1917         ( uplo, orientationOfA, orientationOfB, orientationOfC, orientationOfD,
1918           alpha, A, B, C, D, beta, E );
1919     }
1920     else
1921     {
1922         // Split E in four roughly equal pieces, perform a large gemm on corner
1923         // and recurse on ETL and EBR.
1924         DistMatrix<T,STAR,MC> AL(g), AR(g),
1925                               CL(g), CR(g);
1926         DistMatrix<T,MR,STAR> BT(g),  DT(g),
1927                               BB(g),  DB(g);
1928         DistMatrix<T> ETL(g), ETR(g),
1929                       EBL(g), EBR(g);
1930 
1931         const Int half = E.Height() / 2;
1932         LockedPartitionRight( A, AL, AR, half );
1933         LockedPartitionDown( B, BT, BB, half );
1934         LockedPartitionRight( C, CL, CR, half );
1935         LockedPartitionDown( D, DT, DB, half );
1936         PartitionDownDiagonal
1937         ( E, ETL, ETR,
1938              EBL, EBR, half );
1939 
1940         if( uplo == LOWER )
1941         {
1942             LocalGemm
1943             ( orientationOfA, orientationOfB, alpha, AR, BT, beta, EBL );
1944             LocalGemm
1945             ( orientationOfC, orientationOfD, alpha, CR, DT, T(1), EBL );
1946         }
1947         else
1948         {
1949             LocalGemm
1950             ( orientationOfA, orientationOfB, alpha, AL, BB, beta, ETR );
1951             LocalGemm
1952             ( orientationOfC, orientationOfD, alpha, CL, DB, T(1), ETR );
1953         }
1954 
1955         // Recurse
1956         LocalTrr2k
1957         ( uplo,
1958           orientationOfA, orientationOfB, orientationOfC, orientationOfD,
1959           alpha, AL, BT, CL, DT, beta, ETL );
1960 
1961         LocalTrr2k
1962         ( uplo,
1963           orientationOfA, orientationOfB, orientationOfC, orientationOfD,
1964           alpha, AR, BB, CR, DB, beta, EBR );
1965     }
1966 }
1967 
1968 } // namespace elem
1969 
1970 #endif // ifndef ELEM_TRR2K_LOCAL_HPP
1971