1 /*
2    Copyright (c) 2009-2014, Jack Poulson
3    All rights reserved.
4 
5    This file is part of Elemental and is under the BSD 2-Clause License,
6    which can be found in the LICENSE file in the root directory, or at
7    http://opensource.org/licenses/BSD-2-Clause
8 */
9 #pragma once
10 #ifndef ELEM_MULTISHIFTQUASITRSM_LLN_HPP
11 #define ELEM_MULTISHIFTQUASITRSM_LLN_HPP
12 
13 #include ELEM_GEMM_INC
14 
15 namespace elem {
16 namespace msquasitrsm {
17 
// NOTE: The less stable blas::Givens is used instead of lapack::Givens
//       because the latter recomputes an expensive-to-compute function of
//       machine constants on every call (rather than caching it) in order to
//       avoid a thread safety issue.
22 
23 template<typename F>
24 inline void
LLNUnb(const Matrix<F> & L,const Matrix<F> & shifts,Matrix<F> & X)25 LLNUnb( const Matrix<F>& L, const Matrix<F>& shifts, Matrix<F>& X )
26 {
27     DEBUG_ONLY(CallStackEntry cse("msquasitrsm::LLNUnb"))
28     typedef Base<F> Real;
29     const Int m = X.Height();
30     const Int n = X.Width();
31 
32     const F* LBuf = L.LockedBuffer();
33           F* XBuf = X.Buffer();
34     const Int ldl = L.LDim();
35     const Int ldx = X.LDim();
36 
37     Int k=0;
38     while( k < m )
39     {
40         const bool in2x2 = ( k+1<m && LBuf[k+(k+1)*ldl] != F(0) );
41         if( in2x2 )
42         {
43             // Solve the 2x2 linear systems via 2x2 LQ decompositions produced
44             // by the Givens rotation
45             //    | L(k,k)-shift L(k,k+1) | | c -conj(s) | = | gamma11 0 |
46             //                              | s    c     |
47             // and by also forming the bottom two entries of the 2x2 resulting
48             // lower-triangular matrix, say gamma21 and gamma22
49             //
50             // Extract the constant part of the 2x2 diagonal block, D
51             const F delta12 = LBuf[ k   +(k+1)*ldl];
52             const F delta21 = LBuf[(k+1)+ k   *ldl];
53             for( Int j=0; j<n; ++j )
54             {
55                 const F delta11 = LBuf[ k   + k   *ldl] - shifts.Get(j,0);
56                 const F delta22 = LBuf[(k+1)+(k+1)*ldl] - shifts.Get(j,0);
57                 // Decompose D = L Q
58                 Real c; F s;
59                 const F gamma11 = blas::Givens( delta11, delta12, &c, &s );
60                 const F gamma21 =        c*delta21 + s*delta22;
61                 const F gamma22 = -Conj(s)*delta21 + c*delta22;
62 
63                 F* xBuf = &XBuf[j*ldx];
64 
65                 // Solve against L
66                 xBuf[k  ] /= gamma11;
67                 xBuf[k+1] -= gamma21*xBuf[k];
68                 xBuf[k+1] /= gamma22;
69 
70                 // Solve against Q
71                 const F chi1 = xBuf[k  ];
72                 const F chi2 = xBuf[k+1];
73                 xBuf[k  ] = c*chi1 - Conj(s)*chi2;
74                 xBuf[k+1] = s*chi1 +       c*chi2;
75 
76                 // Update x2 := x2 - L21 x1
77                 blas::Axpy
78                 ( m-(k+2), -xBuf[k  ],
79                   &LBuf[(k+2)+ k   *ldl], 1, &xBuf[k+2], 1 );
80                 blas::Axpy
81                 ( m-(k+2), -xBuf[k+1],
82                   &LBuf[(k+2)+(k+1)*ldl], 1, &xBuf[k+2], 1 );
83             }
84 
85             k += 2;
86         }
87         else
88         {
89             for( Int j=0; j<n; ++j )
90             {
91                 F* xBuf = &XBuf[j*ldx];
92                 xBuf[k] /= LBuf[k+k*ldl] - shifts.Get(j,0);
93                 blas::Axpy
94                 ( m-(k+1), -xBuf[k], &LBuf[(k+1)+k*ldl], 1, &xBuf[k+1], 1 );
95             }
96             k += 1;
97         }
98     }
99 }
100 
101 template<typename F>
102 inline void
LLN(const Matrix<F> & L,const Matrix<F> & shifts,Matrix<F> & X)103 LLN( const Matrix<F>& L, const Matrix<F>& shifts, Matrix<F>& X )
104 {
105     DEBUG_ONLY(CallStackEntry cse("msquasitrsm::LLN"))
106     const Int m = X.Height();
107     const Int n = X.Width();
108     const Int bsize = Blocksize();
109 
110     for( Int k=0; k<m; k+=bsize )
111     {
112         const Int nbProp = Min(bsize,m-k);
113         const bool in2x2 = ( k+nbProp<m && L.Get(k+nbProp-1,k+nbProp) != F(0) );
114         const Int nb = ( in2x2 ? nbProp+1 : nbProp );
115 
116         auto L11 = LockedViewRange( L, k,    k, k+nb, k+nb );
117         auto L21 = LockedViewRange( L, k+nb, k, m,    k+nb );
118 
119         auto X1 = ViewRange( X, k,    0, k+nb, n );
120         auto X2 = ViewRange( X, k+nb, 0, m,    n );
121 
122         LLNUnb( L11, shifts, X1 );
123         Gemm( NORMAL, NORMAL, F(-1), L21, X1, F(1), X2 );
124     }
125 }
126 
127 // For large numbers of RHS's, e.g., width(X) >> p
128 template<typename F>
129 inline void
LLNLarge(const DistMatrix<F> & L,const DistMatrix<F,VR,STAR> & shifts,DistMatrix<F> & X)130 LLNLarge
131 ( const DistMatrix<F>& L, const DistMatrix<F,VR,STAR>& shifts,
132   DistMatrix<F>& X )
133 {
134     DEBUG_ONLY(CallStackEntry cse("msquasitrsm::LLNLarge"))
135     const Int m = X.Height();
136     const Int n = X.Width();
137     const Int bsize = Blocksize();
138     const Grid& g = L.Grid();
139 
140     DistMatrix<F,STAR,STAR> L11_STAR_STAR(g);
141     DistMatrix<F,MC,  STAR> L21_MC_STAR(g);
142     DistMatrix<F,STAR,MR  > X1_STAR_MR(g);
143     DistMatrix<F,STAR,VR  > X1_STAR_VR(g);
144 
145     for( Int k=0; k<m; k+=bsize )
146     {
147         const Int nbProp = Min(bsize,m-k);
148         const bool in2x2 = ( k+nbProp<m && L.Get(k+nbProp-1,k+nbProp) != F(0) );
149         const Int nb = ( in2x2 ? nbProp+1 : nbProp );
150 
151         auto L11 = LockedViewRange( L, k,    k, k+nb, k+nb );
152         auto L21 = LockedViewRange( L, k+nb, k, m,    k+nb );
153 
154         auto X1 = ViewRange( X, k,    0, k+nb, n );
155         auto X2 = ViewRange( X, k+nb, 0, m,    n );
156 
157         L11_STAR_STAR = L11; // L11[* ,* ] <- L11[MC,MR]
158         X1_STAR_VR.AlignWith( shifts );
159         X1_STAR_VR    = X1;  // X1[* ,VR] <- X1[MC,MR]
160 
161         // X1[* ,VR] := L11^-1[* ,* ] X1[* ,VR]
162         LocalMultiShiftQuasiTrsm
163         ( LEFT, LOWER, NORMAL, F(1), L11_STAR_STAR, shifts, X1_STAR_VR );
164 
165         X1_STAR_MR.AlignWith( X2 );
166         X1_STAR_MR  = X1_STAR_VR; // X1[* ,MR]  <- X1[* ,VR]
167         X1          = X1_STAR_MR; // X1[MC,MR] <- X1[* ,MR]
168         L21_MC_STAR.AlignWith( X2 );
169         L21_MC_STAR = L21;        // L21[MC,* ] <- L21[MC,MR]
170 
171         // X2[MC,MR] -= L21[MC,* ] X1[* ,MR]
172         LocalGemm( NORMAL, NORMAL, F(-1), L21_MC_STAR, X1_STAR_MR, F(1), X2 );
173     }
174 }
175 
176 // For medium numbers of RHS's, e.g., width(X) ~= p
177 template<typename F,Dist shiftColDist,Dist shiftRowDist>
178 inline void
LLNMedium(const DistMatrix<F> & L,const DistMatrix<F,shiftColDist,shiftRowDist> & shifts,DistMatrix<F> & X)179 LLNMedium
180 ( const DistMatrix<F>& L,
181   const DistMatrix<F,shiftColDist,shiftRowDist>& shifts,
182         DistMatrix<F>& X )
183 {
184     DEBUG_ONLY(CallStackEntry cse("msquasitrsm::LLNMedium"))
185     const Int m = X.Height();
186     const Int n = X.Width();
187     const Int bsize = Blocksize();
188     const Grid& g = L.Grid();
189 
190     DistMatrix<F,STAR,STAR> L11_STAR_STAR(g);
191     DistMatrix<F,MC,  STAR> L21_MC_STAR(g);
192     DistMatrix<F,MR,  STAR> X1Trans_MR_STAR(g);
193 
194     DistMatrix<F,MR,  STAR> shifts_MR_STAR( shifts ),
195                             shifts_MR_STAR_Align(g);
196 
197     for( Int k=0; k<m; k+=bsize )
198     {
199         const Int nbProp = Min(bsize,m-k);
200         const bool in2x2 = ( k+nbProp<m && L.Get(k+nbProp-1,k+nbProp) != F(0) );
201         const Int nb = ( in2x2 ? nbProp+1 : nbProp );
202 
203         auto L11 = LockedViewRange( L, k,    k, k+nb, k+nb );
204         auto L21 = LockedViewRange( L, k+nb, k, m,    k+nb );
205 
206         auto X1 = ViewRange( X, k,    0, k+nb, n );
207         auto X2 = ViewRange( X, k+nb, 0, m,    n );
208 
209         L11_STAR_STAR = L11; // L11[* ,* ] <- L11[MC,MR]
210         X1Trans_MR_STAR.AlignWith( X2 );
211         X1.TransposeColAllGather( X1Trans_MR_STAR ); // X1[* ,MR] <- X1[MC,MR]
212 
213         // X1^T[MR,* ] := X1^T[MR,* ] L11^-T[* ,* ]
214         //              = (L11^-1[* ,* ] X1[* ,MR])^T
215         shifts_MR_STAR_Align.AlignWith( X1Trans_MR_STAR );
216         shifts_MR_STAR_Align = shifts_MR_STAR;
217         LocalMultiShiftQuasiTrsm
218         ( RIGHT, LOWER, TRANSPOSE,
219           F(1), L11_STAR_STAR, shifts_MR_STAR_Align, X1Trans_MR_STAR );
220 
221         X1.TransposeColFilterFrom( X1Trans_MR_STAR ); // X1[MC,MR] <- X1[* ,MR]
222         L21_MC_STAR.AlignWith( X2 );
223         L21_MC_STAR = L21;                   // L21[MC,* ] <- L21[MC,MR]
224 
225         // X2[MC,MR] -= L21[MC,* ] X1[* ,MR]
226         LocalGemm
227         ( NORMAL, TRANSPOSE, F(-1), L21_MC_STAR, X1Trans_MR_STAR, F(1), X2 );
228     }
229 }
230 
231 // For small numbers of RHS's, e.g., width(X) < p
232 template<typename F,Dist colDist,Dist shiftColDist,Dist shiftRowDist>
233 inline void
LLNSmall(const DistMatrix<F,colDist,STAR> & L,const DistMatrix<F,shiftColDist,shiftRowDist> & shifts,DistMatrix<F,colDist,STAR> & X)234 LLNSmall
235 ( const DistMatrix<F,     colDist,STAR        >& L,
236   const DistMatrix<F,shiftColDist,shiftRowDist>& shifts,
237         DistMatrix<F,     colDist,STAR        >& X )
238 {
239     DEBUG_ONLY(
240         CallStackEntry cse("msquasitrsm::LLNSmall");
241         if( L.ColAlign() != X.ColAlign() )
242             LogicError("L and X are assumed to be aligned");
243     )
244     const Int m = X.Height();
245     const Int n = X.Width();
246     const Int bsize = Blocksize();
247     const Grid& g = L.Grid();
248 
249     DistMatrix<F,STAR,STAR> L11_STAR_STAR(g), X1_STAR_STAR(g),
250                             shifts_STAR_STAR(shifts);
251 
252     for( Int k=0; k<m; k+=bsize )
253     {
254         const Int nbProp = Min(bsize,m-k);
255         const bool in2x2 = ( k+nbProp<m && L.Get(k+nbProp-1,k+nbProp) != F(0) );
256         const Int nb = ( in2x2 ? nbProp+1 : nbProp );
257 
258         auto L11 = LockedViewRange( L, k,    k, k+nb, k+nb );
259         auto L21 = LockedViewRange( L, k+nb, k, m,    k+nb );
260 
261         auto X1 = ViewRange( X, k,    0, k+nb, n );
262         auto X2 = ViewRange( X, k+nb, 0, m,    n );
263 
264         L11_STAR_STAR = L11; // L11[* ,* ] <- L11[VC,* ]
265         X1_STAR_STAR = X1;   // X1[* ,* ] <- X1[VC,* ]
266 
267         // X1[* ,* ] := (L11[* ,* ])^-1 X1[* ,* ]
268         LocalMultiShiftQuasiTrsm
269         ( LEFT, LOWER, NORMAL,
270           F(1), L11_STAR_STAR, shifts_STAR_STAR, X1_STAR_STAR );
271 
272         // X2[VC,* ] -= L21[VC,* ] X1[* ,* ]
273         LocalGemm( NORMAL, NORMAL, F(-1), L21, X1_STAR_STAR, F(1), X2 );
274     }
275 }
276 
277 } // namespace msquasitrsm
278 } // namespace elem
279 
280 #endif // ifndef ELEM_MULTISHIFTQUASITRSM_LLN_HPP
281