1 /*
2    Copyright (c) 2009-2014, Jack Poulson
3    All rights reserved.
4 
5    Copyright (c) 2013, The University of Texas at Austin
6    All rights reserved.
7 
8    This file is part of Elemental and is under the BSD 2-Clause License,
9    which can be found in the LICENSE file in the root directory, or at
10    http://opensource.org/licenses/BSD-2-Clause
11 */
12 #pragma once
13 #ifndef ELEM_TRMM_RLN_HPP
14 #define ELEM_TRMM_RLN_HPP
15 
16 #include ELEM_AXPY_INC
17 #include ELEM_MAKETRIANGULAR_INC
18 #include ELEM_SCALE_INC
19 #include ELEM_SETDIAGONAL_INC
20 #include ELEM_TRANSPOSE_INC
21 
22 #include ELEM_GEMM_INC
23 
24 #include ELEM_ZEROS_INC
25 
26 namespace elem {
27 namespace trmm {
28 
29 template<typename T>
30 inline void
LocalAccumulateRLN(Orientation orientation,UnitOrNonUnit diag,T alpha,const DistMatrix<T,MC,MR> & L,const DistMatrix<T,STAR,MC> & X,DistMatrix<T,MR,STAR> & ZTrans)31 LocalAccumulateRLN
32 ( Orientation orientation, UnitOrNonUnit diag, T alpha,
33   const DistMatrix<T,MC,  MR  >& L,
34   const DistMatrix<T,STAR,MC  >& X,
35         DistMatrix<T,MR,  STAR>& ZTrans )
36 {
37     DEBUG_ONLY(
38         CallStackEntry cse("trmm::LocalAccumulateRLN");
39         if( L.Grid() != X.Grid() ||
40             X.Grid() != ZTrans.Grid() )
41             LogicError("{L,X,Z} must be distributed over the same grid");
42         if( L.Height() != L.Width() ||
43             L.Height() != X.Width() ||
44             L.Height() != ZTrans.Height() )
45             LogicError
46             ("Nonconformal:\n",
47              DimsString(L,"L"),"\n",
48              DimsString(X,"X[* ,MC]"),"\n",
49              DimsString(ZTrans,"Z'[MR,* ]"));
50         if( X.RowAlign() != L.ColAlign() ||
51             ZTrans.ColAlign() != L.RowAlign() )
52             LogicError("Partial matrix distributions are misaligned");
53     )
54     const Int m = ZTrans.Height();
55     const Int n = ZTrans.Width();
56     const Int bsize = Blocksize();
57     const Grid& g = L.Grid();
58 
59     DistMatrix<T> D11(g);
60 
61     const Int ratio = Max( g.Height(), g.Width() );
62     for( Int k=0; k<m; k+=ratio*bsize )
63     {
64         const Int nb = Min(ratio*bsize,m-k);
65 
66         auto L11 = LockedViewRange( L, k,    k, k+nb, k+nb );
67         auto L21 = LockedViewRange( L, k+nb, k, m,    k+nb );
68 
69         auto X1 = LockedViewRange( X, 0, k,    n, k+nb );
70         auto X2 = LockedViewRange( X, 0, k+nb, n, m    );
71 
72         auto Z1Trans = ViewRange( ZTrans, k, 0, k+nb, n );
73 
74         D11.AlignWith( L11 );
75         D11 = L11;
76         MakeTriangular( LOWER, D11 );
77         if( diag == UNIT )
78             SetDiagonal( D11, T(1) );
79         LocalGemm( orientation, orientation, alpha, D11, X1, T(1), Z1Trans );
80         LocalGemm( orientation, orientation, alpha, L21, X2, T(1), Z1Trans );
81     }
82 }
83 
84 template<typename T>
85 inline void
RLNA(UnitOrNonUnit diag,const DistMatrix<T> & L,DistMatrix<T> & X)86 RLNA( UnitOrNonUnit diag, const DistMatrix<T>& L, DistMatrix<T>& X )
87 {
88     DEBUG_ONLY(
89         CallStackEntry cse("trmm::RLNA");
90         if( L.Grid() != X.Grid() )
91             LogicError("{L,X} must be distributed over the same grid");
92     )
93     const Int m = X.Height();
94     const Int n = X.Width();
95     const Int bsize = Blocksize();
96     const Grid& g = L.Grid();
97 
98     DistMatrix<T,STAR,VC  > X1_STAR_VC(g);
99     DistMatrix<T,STAR,MC  > X1_STAR_MC(g);
100     DistMatrix<T,MR,  STAR> Z1Trans_MR_STAR(g);
101     DistMatrix<T,MR,  MC  > Z1Trans_MR_MC(g);
102 
103     X1_STAR_VC.AlignWith( L );
104     X1_STAR_MC.AlignWith( L );
105     Z1Trans_MR_STAR.AlignWith( L );
106 
107     for( Int k=0; k<m; k+=bsize )
108     {
109         const Int nb = Min(bsize,m-k);
110 
111         auto X1 = ViewRange( X, k, 0, k+nb, n );
112 
113         X1_STAR_VC = X1;
114         X1_STAR_MC = X1_STAR_VC;
115 
116         Zeros( Z1Trans_MR_STAR, n, nb );
117         LocalAccumulateRLN
118         ( TRANSPOSE, diag, T(1), L, X1_STAR_MC, Z1Trans_MR_STAR );
119 
120         Z1Trans_MR_MC.AlignWith( X1 );
121         Z1Trans_MR_MC.RowSumScatterFrom( Z1Trans_MR_STAR );
122         Transpose( Z1Trans_MR_MC.Matrix(), X1.Matrix() );
123     }
124 }
125 
126 template<typename T>
127 inline void
RLNCOld(UnitOrNonUnit diag,const DistMatrix<T> & L,DistMatrix<T> & X)128 RLNCOld( UnitOrNonUnit diag, const DistMatrix<T>& L, DistMatrix<T>& X )
129 {
130     DEBUG_ONLY(
131         CallStackEntry cse("trmm::RLNCOld");
132         if( L.Grid() != X.Grid() )
133             LogicError
134             ("L and X must be distributed over the same grid");
135         if( L.Height() != L.Width() || X.Width() != L.Height() )
136             LogicError
137             ("Nonconformal:\n",DimsString(L,"L"),"\n",DimsString(X,"X"));
138     )
139     const Int m = X.Height();
140     const Int n = X.Width();
141     const Int bsize = Blocksize();
142     const Grid& g = L.Grid();
143 
144     DistMatrix<T,STAR,STAR> L11_STAR_STAR(g);
145     DistMatrix<T,MR,  STAR> L21_MR_STAR(g);
146     DistMatrix<T,VC,  STAR> X1_VC_STAR(g);
147     DistMatrix<T,MC,  STAR> D1_MC_STAR(g);
148 
149     for( Int k=0; k<n; k+=bsize )
150     {
151         const Int nb = Min(bsize,n-k);
152 
153         auto L11 = LockedViewRange( L, k,    k, k+nb, k+nb );
154         auto L21 = LockedViewRange( L, k+nb, k, n,    k+nb );
155 
156         auto X1 = ViewRange( X, 0, k,    m, k+nb );
157         auto X2 = ViewRange( X, 0, k+nb, m, n    );
158 
159         X1_VC_STAR = X1;
160         L11_STAR_STAR = L11;
161         LocalTrmm
162         ( RIGHT, LOWER, NORMAL, diag, T(1), L11_STAR_STAR, X1_VC_STAR );
163         X1 = X1_VC_STAR;
164 
165         L21_MR_STAR.AlignWith( X2 );
166         L21_MR_STAR = L21;
167         D1_MC_STAR.AlignWith( X1 );
168         LocalGemm( NORMAL, NORMAL, T(1), X2, L21_MR_STAR, D1_MC_STAR );
169         X1.RowSumScatterUpdate( T(1), D1_MC_STAR );
170     }
171 }
172 
173 template<typename T>
174 inline void
RLNC(UnitOrNonUnit diag,const DistMatrix<T> & L,DistMatrix<T> & X)175 RLNC( UnitOrNonUnit diag, const DistMatrix<T>& L, DistMatrix<T>& X )
176 {
177     DEBUG_ONLY(
178         CallStackEntry cse("trmm::RLNC");
179         if( L.Grid() != X.Grid() )
180             LogicError("L and X must be distributed over the same grid");
181         if( L.Height() != L.Width() || X.Width() != L.Height() )
182             LogicError
183             ("Nonconformal:\n",DimsString(L,"L"),"\n",DimsString(X,"X"));
184     )
185     const Int m = X.Height();
186     const Int n = X.Width();
187     const Int bsize = Blocksize();
188     const Grid& g = L.Grid();
189 
190     DistMatrix<T,STAR,STAR> L11_STAR_STAR(g);
191     DistMatrix<T,MR,  STAR> L10Trans_MR_STAR(g);
192     DistMatrix<T,VC,  STAR> X1_VC_STAR(g);
193     DistMatrix<T,MC,  STAR> X1_MC_STAR(g);
194 
195     for( Int k=0; k<n; k+=bsize )
196     {
197         const Int nb = Min(bsize,n-k);
198 
199         auto L10 = LockedViewRange( L, k, 0, k+nb, k    );
200         auto L11 = LockedViewRange( L, k, k, k+nb, k+nb );
201 
202         auto X0 = ViewRange( X, 0, 0, m, k    );
203         auto X1 = ViewRange( X, 0, k, m, k+nb );
204 
205         X1_MC_STAR.AlignWith( X0 );
206         X1_MC_STAR = X1;
207         L10Trans_MR_STAR.AlignWith( X0 );
208         L10.TransposeColAllGather( L10Trans_MR_STAR );
209         LocalGemm
210         ( NORMAL, TRANSPOSE, T(1), X1_MC_STAR, L10Trans_MR_STAR, T(1), X0 );
211 
212         L11_STAR_STAR = L11;
213         X1_VC_STAR.AlignWith( X1 );
214         X1_VC_STAR = X1_MC_STAR;
215         LocalTrmm
216         ( RIGHT, LOWER, NORMAL, diag, T(1), L11_STAR_STAR, X1_VC_STAR );
217         X1 = X1_VC_STAR;
218     }
219 }
220 
221 // Right Lower Normal (Non)Unit Trmm
222 //   X := X tril(L), and
223 //   X := X trilu(L)
224 template<typename T>
225 inline void
RLN(UnitOrNonUnit diag,const DistMatrix<T> & L,DistMatrix<T> & X)226 RLN( UnitOrNonUnit diag, const DistMatrix<T>& L, DistMatrix<T>& X )
227 {
228     DEBUG_ONLY(CallStackEntry cse("trmm::RLN"))
229     // TODO: Come up with a better routing mechanism
230     if( L.Height() > 5*X.Height() )
231         RLNA( diag, L, X );
232     else
233         RLNC( diag, L, X );
234 }
235 
236 } // namespace trmm
237 } // namespace elem
238 
239 #endif // ifndef ELEM_TRMM_RLN_HPP
240