1 /*
2 Copyright (c) 2009-2014, Jack Poulson
3 All rights reserved.
4
5 This file is part of Elemental and is under the BSD 2-Clause License,
6 which can be found in the LICENSE file in the root directory, or at
7 http://opensource.org/licenses/BSD-2-Clause
8 */
9 #pragma once
10 #ifndef ELEM_MULTISHIFTQUASITRSM_LLN_HPP
11 #define ELEM_MULTISHIFTQUASITRSM_LLN_HPP
12
13 #include ELEM_GEMM_INC
14
15 namespace elem {
16 namespace msquasitrsm {
17
// NOTE: The less stable blas::Givens is used instead of lapack::Givens
// because the latter, in order to avoid a thread-safety issue, recomputes
// (rather than caches) an expensive-to-compute function of machine
// constants on every call.
22
23 template<typename F>
24 inline void
LLNUnb(const Matrix<F> & L,const Matrix<F> & shifts,Matrix<F> & X)25 LLNUnb( const Matrix<F>& L, const Matrix<F>& shifts, Matrix<F>& X )
26 {
27 DEBUG_ONLY(CallStackEntry cse("msquasitrsm::LLNUnb"))
28 typedef Base<F> Real;
29 const Int m = X.Height();
30 const Int n = X.Width();
31
32 const F* LBuf = L.LockedBuffer();
33 F* XBuf = X.Buffer();
34 const Int ldl = L.LDim();
35 const Int ldx = X.LDim();
36
37 Int k=0;
38 while( k < m )
39 {
40 const bool in2x2 = ( k+1<m && LBuf[k+(k+1)*ldl] != F(0) );
41 if( in2x2 )
42 {
43 // Solve the 2x2 linear systems via 2x2 LQ decompositions produced
44 // by the Givens rotation
45 // | L(k,k)-shift L(k,k+1) | | c -conj(s) | = | gamma11 0 |
46 // | s c |
47 // and by also forming the bottom two entries of the 2x2 resulting
48 // lower-triangular matrix, say gamma21 and gamma22
49 //
50 // Extract the constant part of the 2x2 diagonal block, D
51 const F delta12 = LBuf[ k +(k+1)*ldl];
52 const F delta21 = LBuf[(k+1)+ k *ldl];
53 for( Int j=0; j<n; ++j )
54 {
55 const F delta11 = LBuf[ k + k *ldl] - shifts.Get(j,0);
56 const F delta22 = LBuf[(k+1)+(k+1)*ldl] - shifts.Get(j,0);
57 // Decompose D = L Q
58 Real c; F s;
59 const F gamma11 = blas::Givens( delta11, delta12, &c, &s );
60 const F gamma21 = c*delta21 + s*delta22;
61 const F gamma22 = -Conj(s)*delta21 + c*delta22;
62
63 F* xBuf = &XBuf[j*ldx];
64
65 // Solve against L
66 xBuf[k ] /= gamma11;
67 xBuf[k+1] -= gamma21*xBuf[k];
68 xBuf[k+1] /= gamma22;
69
70 // Solve against Q
71 const F chi1 = xBuf[k ];
72 const F chi2 = xBuf[k+1];
73 xBuf[k ] = c*chi1 - Conj(s)*chi2;
74 xBuf[k+1] = s*chi1 + c*chi2;
75
76 // Update x2 := x2 - L21 x1
77 blas::Axpy
78 ( m-(k+2), -xBuf[k ],
79 &LBuf[(k+2)+ k *ldl], 1, &xBuf[k+2], 1 );
80 blas::Axpy
81 ( m-(k+2), -xBuf[k+1],
82 &LBuf[(k+2)+(k+1)*ldl], 1, &xBuf[k+2], 1 );
83 }
84
85 k += 2;
86 }
87 else
88 {
89 for( Int j=0; j<n; ++j )
90 {
91 F* xBuf = &XBuf[j*ldx];
92 xBuf[k] /= LBuf[k+k*ldl] - shifts.Get(j,0);
93 blas::Axpy
94 ( m-(k+1), -xBuf[k], &LBuf[(k+1)+k*ldl], 1, &xBuf[k+1], 1 );
95 }
96 k += 1;
97 }
98 }
99 }
100
101 template<typename F>
102 inline void
LLN(const Matrix<F> & L,const Matrix<F> & shifts,Matrix<F> & X)103 LLN( const Matrix<F>& L, const Matrix<F>& shifts, Matrix<F>& X )
104 {
105 DEBUG_ONLY(CallStackEntry cse("msquasitrsm::LLN"))
106 const Int m = X.Height();
107 const Int n = X.Width();
108 const Int bsize = Blocksize();
109
110 for( Int k=0; k<m; k+=bsize )
111 {
112 const Int nbProp = Min(bsize,m-k);
113 const bool in2x2 = ( k+nbProp<m && L.Get(k+nbProp-1,k+nbProp) != F(0) );
114 const Int nb = ( in2x2 ? nbProp+1 : nbProp );
115
116 auto L11 = LockedViewRange( L, k, k, k+nb, k+nb );
117 auto L21 = LockedViewRange( L, k+nb, k, m, k+nb );
118
119 auto X1 = ViewRange( X, k, 0, k+nb, n );
120 auto X2 = ViewRange( X, k+nb, 0, m, n );
121
122 LLNUnb( L11, shifts, X1 );
123 Gemm( NORMAL, NORMAL, F(-1), L21, X1, F(1), X2 );
124 }
125 }
126
127 // For large numbers of RHS's, e.g., width(X) >> p
128 template<typename F>
129 inline void
LLNLarge(const DistMatrix<F> & L,const DistMatrix<F,VR,STAR> & shifts,DistMatrix<F> & X)130 LLNLarge
131 ( const DistMatrix<F>& L, const DistMatrix<F,VR,STAR>& shifts,
132 DistMatrix<F>& X )
133 {
134 DEBUG_ONLY(CallStackEntry cse("msquasitrsm::LLNLarge"))
135 const Int m = X.Height();
136 const Int n = X.Width();
137 const Int bsize = Blocksize();
138 const Grid& g = L.Grid();
139
140 DistMatrix<F,STAR,STAR> L11_STAR_STAR(g);
141 DistMatrix<F,MC, STAR> L21_MC_STAR(g);
142 DistMatrix<F,STAR,MR > X1_STAR_MR(g);
143 DistMatrix<F,STAR,VR > X1_STAR_VR(g);
144
145 for( Int k=0; k<m; k+=bsize )
146 {
147 const Int nbProp = Min(bsize,m-k);
148 const bool in2x2 = ( k+nbProp<m && L.Get(k+nbProp-1,k+nbProp) != F(0) );
149 const Int nb = ( in2x2 ? nbProp+1 : nbProp );
150
151 auto L11 = LockedViewRange( L, k, k, k+nb, k+nb );
152 auto L21 = LockedViewRange( L, k+nb, k, m, k+nb );
153
154 auto X1 = ViewRange( X, k, 0, k+nb, n );
155 auto X2 = ViewRange( X, k+nb, 0, m, n );
156
157 L11_STAR_STAR = L11; // L11[* ,* ] <- L11[MC,MR]
158 X1_STAR_VR.AlignWith( shifts );
159 X1_STAR_VR = X1; // X1[* ,VR] <- X1[MC,MR]
160
161 // X1[* ,VR] := L11^-1[* ,* ] X1[* ,VR]
162 LocalMultiShiftQuasiTrsm
163 ( LEFT, LOWER, NORMAL, F(1), L11_STAR_STAR, shifts, X1_STAR_VR );
164
165 X1_STAR_MR.AlignWith( X2 );
166 X1_STAR_MR = X1_STAR_VR; // X1[* ,MR] <- X1[* ,VR]
167 X1 = X1_STAR_MR; // X1[MC,MR] <- X1[* ,MR]
168 L21_MC_STAR.AlignWith( X2 );
169 L21_MC_STAR = L21; // L21[MC,* ] <- L21[MC,MR]
170
171 // X2[MC,MR] -= L21[MC,* ] X1[* ,MR]
172 LocalGemm( NORMAL, NORMAL, F(-1), L21_MC_STAR, X1_STAR_MR, F(1), X2 );
173 }
174 }
175
176 // For medium numbers of RHS's, e.g., width(X) ~= p
177 template<typename F,Dist shiftColDist,Dist shiftRowDist>
178 inline void
LLNMedium(const DistMatrix<F> & L,const DistMatrix<F,shiftColDist,shiftRowDist> & shifts,DistMatrix<F> & X)179 LLNMedium
180 ( const DistMatrix<F>& L,
181 const DistMatrix<F,shiftColDist,shiftRowDist>& shifts,
182 DistMatrix<F>& X )
183 {
184 DEBUG_ONLY(CallStackEntry cse("msquasitrsm::LLNMedium"))
185 const Int m = X.Height();
186 const Int n = X.Width();
187 const Int bsize = Blocksize();
188 const Grid& g = L.Grid();
189
190 DistMatrix<F,STAR,STAR> L11_STAR_STAR(g);
191 DistMatrix<F,MC, STAR> L21_MC_STAR(g);
192 DistMatrix<F,MR, STAR> X1Trans_MR_STAR(g);
193
194 DistMatrix<F,MR, STAR> shifts_MR_STAR( shifts ),
195 shifts_MR_STAR_Align(g);
196
197 for( Int k=0; k<m; k+=bsize )
198 {
199 const Int nbProp = Min(bsize,m-k);
200 const bool in2x2 = ( k+nbProp<m && L.Get(k+nbProp-1,k+nbProp) != F(0) );
201 const Int nb = ( in2x2 ? nbProp+1 : nbProp );
202
203 auto L11 = LockedViewRange( L, k, k, k+nb, k+nb );
204 auto L21 = LockedViewRange( L, k+nb, k, m, k+nb );
205
206 auto X1 = ViewRange( X, k, 0, k+nb, n );
207 auto X2 = ViewRange( X, k+nb, 0, m, n );
208
209 L11_STAR_STAR = L11; // L11[* ,* ] <- L11[MC,MR]
210 X1Trans_MR_STAR.AlignWith( X2 );
211 X1.TransposeColAllGather( X1Trans_MR_STAR ); // X1[* ,MR] <- X1[MC,MR]
212
213 // X1^T[MR,* ] := X1^T[MR,* ] L11^-T[* ,* ]
214 // = (L11^-1[* ,* ] X1[* ,MR])^T
215 shifts_MR_STAR_Align.AlignWith( X1Trans_MR_STAR );
216 shifts_MR_STAR_Align = shifts_MR_STAR;
217 LocalMultiShiftQuasiTrsm
218 ( RIGHT, LOWER, TRANSPOSE,
219 F(1), L11_STAR_STAR, shifts_MR_STAR_Align, X1Trans_MR_STAR );
220
221 X1.TransposeColFilterFrom( X1Trans_MR_STAR ); // X1[MC,MR] <- X1[* ,MR]
222 L21_MC_STAR.AlignWith( X2 );
223 L21_MC_STAR = L21; // L21[MC,* ] <- L21[MC,MR]
224
225 // X2[MC,MR] -= L21[MC,* ] X1[* ,MR]
226 LocalGemm
227 ( NORMAL, TRANSPOSE, F(-1), L21_MC_STAR, X1Trans_MR_STAR, F(1), X2 );
228 }
229 }
230
231 // For small numbers of RHS's, e.g., width(X) < p
232 template<typename F,Dist colDist,Dist shiftColDist,Dist shiftRowDist>
233 inline void
LLNSmall(const DistMatrix<F,colDist,STAR> & L,const DistMatrix<F,shiftColDist,shiftRowDist> & shifts,DistMatrix<F,colDist,STAR> & X)234 LLNSmall
235 ( const DistMatrix<F, colDist,STAR >& L,
236 const DistMatrix<F,shiftColDist,shiftRowDist>& shifts,
237 DistMatrix<F, colDist,STAR >& X )
238 {
239 DEBUG_ONLY(
240 CallStackEntry cse("msquasitrsm::LLNSmall");
241 if( L.ColAlign() != X.ColAlign() )
242 LogicError("L and X are assumed to be aligned");
243 )
244 const Int m = X.Height();
245 const Int n = X.Width();
246 const Int bsize = Blocksize();
247 const Grid& g = L.Grid();
248
249 DistMatrix<F,STAR,STAR> L11_STAR_STAR(g), X1_STAR_STAR(g),
250 shifts_STAR_STAR(shifts);
251
252 for( Int k=0; k<m; k+=bsize )
253 {
254 const Int nbProp = Min(bsize,m-k);
255 const bool in2x2 = ( k+nbProp<m && L.Get(k+nbProp-1,k+nbProp) != F(0) );
256 const Int nb = ( in2x2 ? nbProp+1 : nbProp );
257
258 auto L11 = LockedViewRange( L, k, k, k+nb, k+nb );
259 auto L21 = LockedViewRange( L, k+nb, k, m, k+nb );
260
261 auto X1 = ViewRange( X, k, 0, k+nb, n );
262 auto X2 = ViewRange( X, k+nb, 0, m, n );
263
264 L11_STAR_STAR = L11; // L11[* ,* ] <- L11[VC,* ]
265 X1_STAR_STAR = X1; // X1[* ,* ] <- X1[VC,* ]
266
267 // X1[* ,* ] := (L11[* ,* ])^-1 X1[* ,* ]
268 LocalMultiShiftQuasiTrsm
269 ( LEFT, LOWER, NORMAL,
270 F(1), L11_STAR_STAR, shifts_STAR_STAR, X1_STAR_STAR );
271
272 // X2[VC,* ] -= L21[VC,* ] X1[* ,* ]
273 LocalGemm( NORMAL, NORMAL, F(-1), L21, X1_STAR_STAR, F(1), X2 );
274 }
275 }
276
277 } // namespace msquasitrsm
278 } // namespace elem
279
280 #endif // ifndef ELEM_MULTISHIFTQUASITRSM_LLN_HPP
281