1 /*
2    Copyright (c) 2009-2014, Jack Poulson
3    All rights reserved.
4 
5    Copyright (c) 2013, The University of Texas at Austin
6    All rights reserved.
7 
8    This file is part of Elemental and is under the BSD 2-Clause License,
9    which can be found in the LICENSE file in the root directory, or at
10    http://opensource.org/licenses/BSD-2-Clause
11 */
12 #pragma once
13 #ifndef ELEM_TRMM_LLN_HPP
14 #define ELEM_TRMM_LLN_HPP
15 
16 #include ELEM_AXPY_INC
17 #include ELEM_MAKETRIANGULAR_INC
18 #include ELEM_SCALE_INC
19 #include ELEM_SETDIAGONAL_INC
20 #include ELEM_TRANSPOSE_INC
21 
22 #include ELEM_GEMM_INC
23 
24 #include ELEM_ZEROS_INC
25 
26 namespace elem {
27 namespace trmm {
28 
29 template<typename T>
30 inline void
LocalAccumulateLLN(Orientation orientation,UnitOrNonUnit diag,T alpha,const DistMatrix<T,MC,MR> & L,const DistMatrix<T,STAR,MR> & XTrans,DistMatrix<T,MC,STAR> & Z)31 LocalAccumulateLLN
32 ( Orientation orientation, UnitOrNonUnit diag, T alpha,
33   const DistMatrix<T,MC,  MR  >& L,
34   const DistMatrix<T,STAR,MR  >& XTrans,
35         DistMatrix<T,MC,  STAR>& Z )
36 {
37     DEBUG_ONLY(
38         CallStackEntry cse("trmm::LocalAccumulateLLN");
39         if( L.Grid() != XTrans.Grid() ||
40             XTrans.Grid() != Z.Grid() )
41             LogicError("{L,X,Z} must be distributed over the same grid");
42         if( L.Height() != L.Width() ||
43             L.Height() != XTrans.Width() ||
44             L.Height() != Z.Height() ||
45             XTrans.Height() != Z.Width() )
46             LogicError
47             ("Nonconformal: \n",
48              "  L ~ ",L.Height()," x ",L.Width(),"\n",
49              "  X^H/T[* ,MR] ~ ",XTrans.Height()," x ",
50                                  XTrans.Width(),"\n",
51              "  Z[MC,* ] ~ ",Z.Height()," x ",Z.Width());
52         if( XTrans.RowAlign() != L.RowAlign() ||
53             Z.ColAlign() != L.ColAlign() )
54             LogicError("Partial matrix distributions are misaligned");
55     )
56     const Int m = Z.Height();
57     const Int n = Z.Width();
58     const Int bsize = Blocksize();
59     const Grid& g = L.Grid();
60 
61     DistMatrix<T> D11(g);
62     const Int ratio = Max( g.Height(), g.Width() );
63     for( Int k=0; k<m; k+=ratio*bsize )
64     {
65         const Int nb = Min(ratio*bsize,m-k);
66 
67         auto L11 = LockedViewRange( L, k,    k, k+nb, k+nb );
68         auto L21 = LockedViewRange( L, k+nb, k, m,    k+nb );
69 
70         auto X1Trans = LockedViewRange( XTrans, 0, k, n, k+nb );
71 
72         auto Z1 = ViewRange( Z, k,    0, k+nb, n );
73         auto Z2 = ViewRange( Z, k+nb, 0, m,    n );
74 
75         D11.AlignWith( L11 );
76         D11 = L11;
77         MakeTriangular( LOWER, D11 );
78         if( diag == UNIT )
79             SetDiagonal( D11, T(1) );
80         LocalGemm( NORMAL, orientation, alpha, D11, X1Trans, T(1), Z1 );
81         LocalGemm( NORMAL, orientation, alpha, L21, X1Trans, T(1), Z2 );
82     }
83 }
84 
85 template<typename T>
86 inline void
LLNA(UnitOrNonUnit diag,const DistMatrix<T> & L,DistMatrix<T> & X)87 LLNA( UnitOrNonUnit diag, const DistMatrix<T>& L, DistMatrix<T>& X )
88 {
89     DEBUG_ONLY(
90         CallStackEntry cse("trmm::LLNA");
91         if( L.Grid() != X.Grid() )
92             LogicError("L and X must be distributed over the same grid");
93         if( L.Height() != L.Width() || L.Width() != X.Height() )
94             LogicError
95             ("Nonconformal: \n"
96              "  L ~ ",L.Height()," x ",L.Width(),"\n",
97              "  X ~ ",X.Height()," x ",X.Width());
98     )
99     const Int m = X.Height();
100     const Int n = X.Width();
101     const Int bsize = Blocksize();
102     const Grid& g = L.Grid();
103 
104     DistMatrix<T,VR,  STAR> X1_VR_STAR(g);
105     DistMatrix<T,STAR,MR  > X1Trans_STAR_MR(g);
106     DistMatrix<T,MC,  STAR> Z1_MC_STAR(g);
107 
108     X1_VR_STAR.AlignWith( L );
109     X1Trans_STAR_MR.AlignWith( L );
110     Z1_MC_STAR.AlignWith( L );
111 
112     for( Int k=0; k<n; k+=bsize )
113     {
114         const Int nb = Min(bsize,n-k);
115 
116         auto X1 = ViewRange( X, 0, k, m, k+nb );
117 
118         X1_VR_STAR = X1;
119         X1_VR_STAR.TransposePartialColAllGather( X1Trans_STAR_MR );
120         Zeros( Z1_MC_STAR, m, nb );
121         LocalAccumulateLLN
122         ( TRANSPOSE, diag, T(1), L, X1Trans_STAR_MR, Z1_MC_STAR );
123         X1.RowSumScatterFrom( Z1_MC_STAR );
124     }
125 }
126 
127 template<typename T>
128 inline void
LLNCOld(UnitOrNonUnit diag,const DistMatrix<T> & L,DistMatrix<T> & X)129 LLNCOld( UnitOrNonUnit diag, const DistMatrix<T>& L, DistMatrix<T>& X )
130 {
131     DEBUG_ONLY(
132         CallStackEntry cse("trmm::LLNCOld");
133         if( L.Grid() != X.Grid() )
134             LogicError("L and X must be distributed over the same grid");
135         if( L.Height() != L.Width() || L.Width() != X.Height() )
136             LogicError
137             ("Nonconformal: \n",
138              "  L ~ ",L.Height()," x ",L.Width(),"\n",
139              "  X ~ ",X.Height()," x ",X.Width());
140     )
141     const Int m = X.Height();
142     const Int n = X.Width();
143     const Int bsize = Blocksize();
144     const Grid& g = L.Grid();
145 
146     DistMatrix<T,STAR,MC  > L10_STAR_MC(g);
147     DistMatrix<T,STAR,STAR> L11_STAR_STAR(g);
148     DistMatrix<T,STAR,VR  > X1_STAR_VR(g);
149     DistMatrix<T,MR,  STAR> D1Trans_MR_STAR(g);
150     DistMatrix<T,MR,  MC  > D1Trans_MR_MC(g);
151     DistMatrix<T,MC,  MR  > D1(g);
152 
153     const Int kLast = LastOffset( m, bsize );
154     for( Int k=kLast; k>=0; k-=bsize )
155     {
156         const Int nb = Min(bsize,m-k);
157 
158         auto L10 = LockedViewRange( L, k, 0, k+nb, k    );
159         auto L11 = LockedViewRange( L, k, k, k+nb, k+nb );
160 
161         auto X0 = ViewRange( X, 0, 0, k,    n );
162         auto X1 = ViewRange( X, k, 0, k+nb, n );
163 
164         L11_STAR_STAR = L11;
165         X1_STAR_VR = X1;
166         LocalTrmm( LEFT, LOWER, NORMAL, diag, T(1), L11_STAR_STAR, X1_STAR_VR );
167         X1 = X1_STAR_VR;
168 
169         L10_STAR_MC.AlignWith( X0 );
170         L10_STAR_MC = L10;
171         D1Trans_MR_STAR.AlignWith( X1 );
172         LocalGemm
173         ( TRANSPOSE, TRANSPOSE, T(1), X0, L10_STAR_MC, D1Trans_MR_STAR );
174         D1Trans_MR_MC.AlignWith( X1 );
175         D1Trans_MR_MC.RowSumScatterFrom( D1Trans_MR_STAR );
176         D1.AlignWith( X1 );
177         Zeros( D1, nb, n );
178         Transpose( D1Trans_MR_MC.Matrix(), D1.Matrix() );
179         Axpy( T(1), D1, X1 );
180     }
181 }
182 
183 template<typename T>
184 inline void
LLNC(UnitOrNonUnit diag,const DistMatrix<T> & L,DistMatrix<T> & X)185 LLNC( UnitOrNonUnit diag, const DistMatrix<T>& L, DistMatrix<T>& X )
186 {
187     DEBUG_ONLY(
188         CallStackEntry cse("trmm::LLNC");
189         if( L.Grid() != X.Grid() )
190             LogicError("L and X must be distributed over the same grid");
191         if( L.Height() != L.Width() || L.Width() != X.Height() )
192             LogicError
193             ("Nonconformal: \n",
194              "  L ~ ",L.Height()," x ",L.Width(),"\n",
195              "  X ~ ",X.Height()," x ",X.Width());
196     )
197     const Int m = X.Height();
198     const Int n = X.Width();
199     const Int bsize = Blocksize();
200     const Grid& g = L.Grid();
201 
202     DistMatrix<T,MC,  STAR> L21_MC_STAR(g);
203     DistMatrix<T,STAR,STAR> L11_STAR_STAR(g);
204     DistMatrix<T,STAR,VR  > X1_STAR_VR(g);
205     DistMatrix<T,MR,  STAR> X1Trans_MR_STAR(g);
206 
207     const Int kLast = LastOffset( m, bsize );
208     for( Int k=kLast; k>=0; k-=bsize )
209     {
210         const Int nb = Min(bsize,m-k);
211 
212         auto L11 = LockedViewRange( L, k,    k, k+nb, k+nb );
213         auto L21 = LockedViewRange( L, k+nb, k, m,    k+nb );
214 
215         auto X1 = ViewRange( X, k,    0, k+nb, n );
216         auto X2 = ViewRange( X, k+nb, 0, m,    n );
217 
218         L21_MC_STAR.AlignWith( X2 );
219         L21_MC_STAR = L21;
220         X1Trans_MR_STAR.AlignWith( X2 );
221         X1.TransposeColAllGather( X1Trans_MR_STAR );
222         LocalGemm
223         ( NORMAL, TRANSPOSE, T(1), L21_MC_STAR, X1Trans_MR_STAR, T(1), X2 );
224 
225         L11_STAR_STAR = L11;
226         X1_STAR_VR.AlignWith( X1 );
227         X1_STAR_VR.TransposePartialRowFilterFrom( X1Trans_MR_STAR );
228         LocalTrmm( LEFT, LOWER, NORMAL, diag, T(1), L11_STAR_STAR, X1_STAR_VR );
229         X1 = X1_STAR_VR;
230     }
231 }
232 
233 // Left Lower Normal (Non)Unit Trmm
234 //   X := tril(L)  X, or
235 //   X := trilu(L) X
236 template<typename T>
237 inline void
LLN(UnitOrNonUnit diag,const DistMatrix<T> & L,DistMatrix<T> & X)238 LLN( UnitOrNonUnit diag, const DistMatrix<T>& L, DistMatrix<T>& X )
239 {
240     DEBUG_ONLY(CallStackEntry cse("trmm::LLN"))
241     // TODO: Come up with a better routing mechanism
242     if( L.Height() > 5*X.Width() )
243         LLNA( diag, L, X );
244     else
245         LLNC( diag, L, X );
246 }
247 
248 } // namespace trmm
249 } // namespace elem
250 
251 #endif // ifndef ELEM_TRMM_LLN_HPP
252