1 /*
2 Copyright (c) 2009-2014, Jack Poulson
3 All rights reserved.
4
5 Copyright (c) 2013, The University of Texas at Austin
6 All rights reserved.
7
8 This file is part of Elemental and is under the BSD 2-Clause License,
9 which can be found in the LICENSE file in the root directory, or at
10 http://opensource.org/licenses/BSD-2-Clause
11 */
12 #pragma once
13 #ifndef ELEM_TRMM_LLN_HPP
14 #define ELEM_TRMM_LLN_HPP
15
16 #include ELEM_AXPY_INC
17 #include ELEM_MAKETRIANGULAR_INC
18 #include ELEM_SCALE_INC
19 #include ELEM_SETDIAGONAL_INC
20 #include ELEM_TRANSPOSE_INC
21
22 #include ELEM_GEMM_INC
23
24 #include ELEM_ZEROS_INC
25
26 namespace elem {
27 namespace trmm {
28
29 template<typename T>
30 inline void
LocalAccumulateLLN(Orientation orientation,UnitOrNonUnit diag,T alpha,const DistMatrix<T,MC,MR> & L,const DistMatrix<T,STAR,MR> & XTrans,DistMatrix<T,MC,STAR> & Z)31 LocalAccumulateLLN
32 ( Orientation orientation, UnitOrNonUnit diag, T alpha,
33 const DistMatrix<T,MC, MR >& L,
34 const DistMatrix<T,STAR,MR >& XTrans,
35 DistMatrix<T,MC, STAR>& Z )
36 {
37 DEBUG_ONLY(
38 CallStackEntry cse("trmm::LocalAccumulateLLN");
39 if( L.Grid() != XTrans.Grid() ||
40 XTrans.Grid() != Z.Grid() )
41 LogicError("{L,X,Z} must be distributed over the same grid");
42 if( L.Height() != L.Width() ||
43 L.Height() != XTrans.Width() ||
44 L.Height() != Z.Height() ||
45 XTrans.Height() != Z.Width() )
46 LogicError
47 ("Nonconformal: \n",
48 " L ~ ",L.Height()," x ",L.Width(),"\n",
49 " X^H/T[* ,MR] ~ ",XTrans.Height()," x ",
50 XTrans.Width(),"\n",
51 " Z[MC,* ] ~ ",Z.Height()," x ",Z.Width());
52 if( XTrans.RowAlign() != L.RowAlign() ||
53 Z.ColAlign() != L.ColAlign() )
54 LogicError("Partial matrix distributions are misaligned");
55 )
56 const Int m = Z.Height();
57 const Int n = Z.Width();
58 const Int bsize = Blocksize();
59 const Grid& g = L.Grid();
60
61 DistMatrix<T> D11(g);
62 const Int ratio = Max( g.Height(), g.Width() );
63 for( Int k=0; k<m; k+=ratio*bsize )
64 {
65 const Int nb = Min(ratio*bsize,m-k);
66
67 auto L11 = LockedViewRange( L, k, k, k+nb, k+nb );
68 auto L21 = LockedViewRange( L, k+nb, k, m, k+nb );
69
70 auto X1Trans = LockedViewRange( XTrans, 0, k, n, k+nb );
71
72 auto Z1 = ViewRange( Z, k, 0, k+nb, n );
73 auto Z2 = ViewRange( Z, k+nb, 0, m, n );
74
75 D11.AlignWith( L11 );
76 D11 = L11;
77 MakeTriangular( LOWER, D11 );
78 if( diag == UNIT )
79 SetDiagonal( D11, T(1) );
80 LocalGemm( NORMAL, orientation, alpha, D11, X1Trans, T(1), Z1 );
81 LocalGemm( NORMAL, orientation, alpha, L21, X1Trans, T(1), Z2 );
82 }
83 }
84
85 template<typename T>
86 inline void
LLNA(UnitOrNonUnit diag,const DistMatrix<T> & L,DistMatrix<T> & X)87 LLNA( UnitOrNonUnit diag, const DistMatrix<T>& L, DistMatrix<T>& X )
88 {
89 DEBUG_ONLY(
90 CallStackEntry cse("trmm::LLNA");
91 if( L.Grid() != X.Grid() )
92 LogicError("L and X must be distributed over the same grid");
93 if( L.Height() != L.Width() || L.Width() != X.Height() )
94 LogicError
95 ("Nonconformal: \n"
96 " L ~ ",L.Height()," x ",L.Width(),"\n",
97 " X ~ ",X.Height()," x ",X.Width());
98 )
99 const Int m = X.Height();
100 const Int n = X.Width();
101 const Int bsize = Blocksize();
102 const Grid& g = L.Grid();
103
104 DistMatrix<T,VR, STAR> X1_VR_STAR(g);
105 DistMatrix<T,STAR,MR > X1Trans_STAR_MR(g);
106 DistMatrix<T,MC, STAR> Z1_MC_STAR(g);
107
108 X1_VR_STAR.AlignWith( L );
109 X1Trans_STAR_MR.AlignWith( L );
110 Z1_MC_STAR.AlignWith( L );
111
112 for( Int k=0; k<n; k+=bsize )
113 {
114 const Int nb = Min(bsize,n-k);
115
116 auto X1 = ViewRange( X, 0, k, m, k+nb );
117
118 X1_VR_STAR = X1;
119 X1_VR_STAR.TransposePartialColAllGather( X1Trans_STAR_MR );
120 Zeros( Z1_MC_STAR, m, nb );
121 LocalAccumulateLLN
122 ( TRANSPOSE, diag, T(1), L, X1Trans_STAR_MR, Z1_MC_STAR );
123 X1.RowSumScatterFrom( Z1_MC_STAR );
124 }
125 }
126
127 template<typename T>
128 inline void
LLNCOld(UnitOrNonUnit diag,const DistMatrix<T> & L,DistMatrix<T> & X)129 LLNCOld( UnitOrNonUnit diag, const DistMatrix<T>& L, DistMatrix<T>& X )
130 {
131 DEBUG_ONLY(
132 CallStackEntry cse("trmm::LLNCOld");
133 if( L.Grid() != X.Grid() )
134 LogicError("L and X must be distributed over the same grid");
135 if( L.Height() != L.Width() || L.Width() != X.Height() )
136 LogicError
137 ("Nonconformal: \n",
138 " L ~ ",L.Height()," x ",L.Width(),"\n",
139 " X ~ ",X.Height()," x ",X.Width());
140 )
141 const Int m = X.Height();
142 const Int n = X.Width();
143 const Int bsize = Blocksize();
144 const Grid& g = L.Grid();
145
146 DistMatrix<T,STAR,MC > L10_STAR_MC(g);
147 DistMatrix<T,STAR,STAR> L11_STAR_STAR(g);
148 DistMatrix<T,STAR,VR > X1_STAR_VR(g);
149 DistMatrix<T,MR, STAR> D1Trans_MR_STAR(g);
150 DistMatrix<T,MR, MC > D1Trans_MR_MC(g);
151 DistMatrix<T,MC, MR > D1(g);
152
153 const Int kLast = LastOffset( m, bsize );
154 for( Int k=kLast; k>=0; k-=bsize )
155 {
156 const Int nb = Min(bsize,m-k);
157
158 auto L10 = LockedViewRange( L, k, 0, k+nb, k );
159 auto L11 = LockedViewRange( L, k, k, k+nb, k+nb );
160
161 auto X0 = ViewRange( X, 0, 0, k, n );
162 auto X1 = ViewRange( X, k, 0, k+nb, n );
163
164 L11_STAR_STAR = L11;
165 X1_STAR_VR = X1;
166 LocalTrmm( LEFT, LOWER, NORMAL, diag, T(1), L11_STAR_STAR, X1_STAR_VR );
167 X1 = X1_STAR_VR;
168
169 L10_STAR_MC.AlignWith( X0 );
170 L10_STAR_MC = L10;
171 D1Trans_MR_STAR.AlignWith( X1 );
172 LocalGemm
173 ( TRANSPOSE, TRANSPOSE, T(1), X0, L10_STAR_MC, D1Trans_MR_STAR );
174 D1Trans_MR_MC.AlignWith( X1 );
175 D1Trans_MR_MC.RowSumScatterFrom( D1Trans_MR_STAR );
176 D1.AlignWith( X1 );
177 Zeros( D1, nb, n );
178 Transpose( D1Trans_MR_MC.Matrix(), D1.Matrix() );
179 Axpy( T(1), D1, X1 );
180 }
181 }
182
183 template<typename T>
184 inline void
LLNC(UnitOrNonUnit diag,const DistMatrix<T> & L,DistMatrix<T> & X)185 LLNC( UnitOrNonUnit diag, const DistMatrix<T>& L, DistMatrix<T>& X )
186 {
187 DEBUG_ONLY(
188 CallStackEntry cse("trmm::LLNC");
189 if( L.Grid() != X.Grid() )
190 LogicError("L and X must be distributed over the same grid");
191 if( L.Height() != L.Width() || L.Width() != X.Height() )
192 LogicError
193 ("Nonconformal: \n",
194 " L ~ ",L.Height()," x ",L.Width(),"\n",
195 " X ~ ",X.Height()," x ",X.Width());
196 )
197 const Int m = X.Height();
198 const Int n = X.Width();
199 const Int bsize = Blocksize();
200 const Grid& g = L.Grid();
201
202 DistMatrix<T,MC, STAR> L21_MC_STAR(g);
203 DistMatrix<T,STAR,STAR> L11_STAR_STAR(g);
204 DistMatrix<T,STAR,VR > X1_STAR_VR(g);
205 DistMatrix<T,MR, STAR> X1Trans_MR_STAR(g);
206
207 const Int kLast = LastOffset( m, bsize );
208 for( Int k=kLast; k>=0; k-=bsize )
209 {
210 const Int nb = Min(bsize,m-k);
211
212 auto L11 = LockedViewRange( L, k, k, k+nb, k+nb );
213 auto L21 = LockedViewRange( L, k+nb, k, m, k+nb );
214
215 auto X1 = ViewRange( X, k, 0, k+nb, n );
216 auto X2 = ViewRange( X, k+nb, 0, m, n );
217
218 L21_MC_STAR.AlignWith( X2 );
219 L21_MC_STAR = L21;
220 X1Trans_MR_STAR.AlignWith( X2 );
221 X1.TransposeColAllGather( X1Trans_MR_STAR );
222 LocalGemm
223 ( NORMAL, TRANSPOSE, T(1), L21_MC_STAR, X1Trans_MR_STAR, T(1), X2 );
224
225 L11_STAR_STAR = L11;
226 X1_STAR_VR.AlignWith( X1 );
227 X1_STAR_VR.TransposePartialRowFilterFrom( X1Trans_MR_STAR );
228 LocalTrmm( LEFT, LOWER, NORMAL, diag, T(1), L11_STAR_STAR, X1_STAR_VR );
229 X1 = X1_STAR_VR;
230 }
231 }
232
233 // Left Lower Normal (Non)Unit Trmm
234 // X := tril(L) X, or
235 // X := trilu(L) X
236 template<typename T>
237 inline void
LLN(UnitOrNonUnit diag,const DistMatrix<T> & L,DistMatrix<T> & X)238 LLN( UnitOrNonUnit diag, const DistMatrix<T>& L, DistMatrix<T>& X )
239 {
240 DEBUG_ONLY(CallStackEntry cse("trmm::LLN"))
241 // TODO: Come up with a better routing mechanism
242 if( L.Height() > 5*X.Width() )
243 LLNA( diag, L, X );
244 else
245 LLNC( diag, L, X );
246 }
247
248 } // namespace trmm
249 } // namespace elem
250
251 #endif // ifndef ELEM_TRMM_LLN_HPP
252