1 /*
2 Copyright (c) 2009-2014, Jack Poulson
3 All rights reserved.
4
5 This file is part of Elemental and is under the BSD 2-Clause License,
6 which can be found in the LICENSE file in the root directory, or at
7 http://opensource.org/licenses/BSD-2-Clause
8 */
9 #pragma once
10 #ifndef ELEM_MULTISHIFTHESSSOLVE_HPP
11 #define ELEM_MULTISHIFTHESSSOLVE_HPP
12
13 // NOTE: These algorithms are adaptations and/or extensions of Alg. 2 from
14 // Greg Henry's "The shifted Hessenberg system solve computation".
15 // It is important to note that the Givens rotation definition in
16 // said paper is the adjoint of the LAPACK definition (as well as
17 // leaving out a conjugation necessary for the complex case).
18
19 namespace elem {
20 namespace mshs {
21
22 template<typename F>
23 inline void
LN(F alpha,const Matrix<F> & H,const Matrix<F> & shifts,Matrix<F> & X)24 LN( F alpha, const Matrix<F>& H, const Matrix<F>& shifts, Matrix<F>& X )
25 {
26 DEBUG_ONLY(CallStackEntry cse("mshs::LN"))
27 Scale( alpha, X );
28
29 const Int m = X.Height();
30 const Int n = X.Width();
31 if( m == 0 )
32 return;
33
34 // Initialize storage for Givens rotations
35 typedef Base<F> Real;
36 Matrix<Real> C(m,n);
37 Matrix<F> S(m,n);
38
39 // Initialize the workspace for shifted columns of H
40 Matrix<F> W(m,n);
41 for( Int j=0; j<n; ++j )
42 {
43 MemCopy( W.Buffer(0,j), H.LockedBuffer(), m );
44 W.Update( 0, j, -shifts.Get(j,0) );
45 }
46
47 // Simultaneously find the LQ factorization and solve against L
48 for( Int k=0; k<m-1; ++k )
49 {
50 auto hB = LockedViewRange( H, k+2, k+1, m, k+2 );
51 const F etakkp1 = H.Get(k,k+1);
52 const F etakp1kp1 = H.Get(k+1,k+1);
53 for( Int j=0; j<n; ++j )
54 {
55 // Find the Givens rotation needed to zero H(k,k+1),
56 // | c s | | H(k,k) | = | gamma |
57 // | -conj(s) c | | H(k,k+1) | | 0 |
58 Real c; F s;
59 lapack::Givens( W.Get(k,j), etakkp1, &c, &s );
60 C.Set( k, j, c );
61 S.Set( k, j, s );
62
63 // The new diagonal value of L
64 const F lambdakk = c*W.Get(k,j) + s*etakkp1;
65
66 // Divide our current entry of x by the diagonal value of L
67 X.Set( k, j, X.Get(k,j)/lambdakk );
68
69 // x(k+1:end) -= x(k) * L(k+1:end,k), where
70 // L(k+1:end,k) = c H(k+1:end,k) + s H(k+1:end,k+1). We express this
71 // more concisely as xB -= x(k) * ( c wB + s hB ).
72 // Note that we carefully handle updating the k+1'th entry since
73 // it is shift-dependent.
74 const F mu = shifts.Get( j, 0 );
75 const F xc = X.Get(k,j)*c;
76 const F xs = X.Get(k,j)*s;
77 X.Update( k+1, j, -xc*W.Get(k+1,j)-xs*(etakp1kp1-mu) );
78 blas::Axpy
79 ( m-(k+2), -xc, W.LockedBuffer(k+2,j), 1, X.Buffer(k+2,j), 1 );
80 blas::Axpy
81 ( m-(k+2), -xs, hB.LockedBuffer(), 1, X.Buffer(k+2,j), 1 );
82
83 // Change the working vector, wB, from representing a fully-updated
84 // portion of the k'th column of H from the end of the last
85 // to a fully-updated portion of the k+1'th column of this iteration
86 //
87 // w(k+1:end) := -conj(s) H(k+1:end,k) + c H(k+1:end,k+1)
88 W.Set( k+1, j, -Conj(s)*W.Get(k+1,j)+c*(etakp1kp1-mu) );
89 blas::Scal( m-(k+2), -Conj(s), W.Buffer(k+2,j), 1 );
90 blas::Axpy( m-(k+2), c, hB.LockedBuffer(), 1, W.Buffer(k+2,j), 1 );
91 }
92 }
93 // Divide x(end) by L(end,end)
94 for( Int j=0; j<n; ++j )
95 X.Set( m-1, j, X.Get(m-1,j)/W.Get(m-1,j) );
96
97 // Solve against Q
98 for( Int j=0; j<n; ++j )
99 {
100 F* x = X.Buffer(0,j);
101 const Real* c = C.LockedBuffer(0,j);
102 const F* s = S.LockedBuffer(0,j);
103 F tau0 = x[m-1];
104 for( Int k=m-2; k>=0; --k )
105 {
106 F tau1 = x[k];
107 x[k+1] = c[k] *tau0 + s[k]*tau1;
108 tau0 = -Conj(s[k])*tau0 + c[k]*tau1;
109 }
110 x[0] = tau0;
111 }
112 }
113
114 template<typename F>
115 inline void
UN(F alpha,const Matrix<F> & H,const Matrix<F> & shifts,Matrix<F> & X)116 UN( F alpha, const Matrix<F>& H, const Matrix<F>& shifts, Matrix<F>& X )
117 {
118 DEBUG_ONLY(CallStackEntry cse("mshs::UN"))
119 Scale( alpha, X );
120
121 const Int m = X.Height();
122 const Int n = X.Width();
123 if( m == 0 )
124 return;
125
126 // Initialize storage for Givens rotations
127 typedef Base<F> Real;
128 Matrix<Real> C(m,n);
129 Matrix<F> S(m,n);
130
131 // Initialize the workspace for shifted columns of H
132 Matrix<F> W(m,n);
133 for( Int j=0; j<n; ++j )
134 {
135 MemCopy( W.Buffer(0,j), H.LockedBuffer(0,m-1), m );
136 W.Update( m-1, j, -shifts.Get(j,0) );
137 }
138
139 // Simultaneously form the RQ factorization and solve against R
140 for( Int k=m-1; k>0; --k )
141 {
142 auto hT = LockedView( H, 0, k-1, k-1, 1 );
143 const F etakkm1 = H.Get(k,k-1);
144 const F etakm1km1 = H.Get(k-1,k-1);
145 for( Int j=0; j<n; ++j )
146 {
147 // Find the Givens rotation needed to zero H(k,k-1),
148 // | c s | | H(k,k) | = | gamma |
149 // | -conj(s) c | | H(k,k-1) | | 0 |
150 Real c; F s;
151 lapack::Givens( W.Get(k,j), etakkm1, &c, &s );
152 C.Set( k, j, c );
153 S.Set( k, j, s );
154
155 // The new diagonal value of R
156 const F rhokk = c*W.Get(k,j) + s*etakkm1;
157
158 // Divide our current entry of x by the diagonal value of R
159 X.Set( k, j, X.Get(k,j)/rhokk );
160
161 // x(0:k-1) -= x(k) * R(0:k-1,k), where
162 // R(0:k-1,k) = c H(0:k-1,k) + s H(0:k-1,k-1). We express this
163 // more concisely as xT -= x(k) * ( c wT + s hT ).
164 // Note that we carefully handle updating the k-1'th entry since
165 // it is shift-dependent.
166 const F mu = shifts.Get( j, 0 );
167 const F xc = X.Get(k,j)*c;
168 const F xs = X.Get(k,j)*s;
169 blas::Axpy( k-1, -xc, W.LockedBuffer(0,j), 1, X.Buffer(0,j), 1 );
170 blas::Axpy( k-1, -xs, hT.LockedBuffer(), 1, X.Buffer(0,j), 1 );
171 X.Update( k-1, j, -xc*W.Get(k-1,j)-xs*(etakm1km1-mu) );
172
173 // Change the working vector, wT, from representing a fully-updated
174 // portion of the k'th column of H from the end of the last
175 // to a fully-updated portion of the k-1'th column of this iteration
176 //
177 // w(0:k-1) := -conj(s) H(0:k-1,k) + c H(0:k-1,k-1)
178 blas::Scal( k-1, -Conj(s), W.Buffer(0,j), 1 );
179 blas::Axpy( k-1, c, hT.LockedBuffer(), 1, W.Buffer(0,j), 1 );
180 W.Set( k-1, j, -Conj(s)*W.Get(k-1,j)+c*(etakm1km1-mu) );
181 }
182 }
183 // Divide x(0) by R(0,0)
184 for( Int j=0; j<n; ++j )
185 X.Set( 0, j, X.Get(0,j)/W.Get(0,j) );
186
187 // Solve against Q
188 for( Int j=0; j<n; ++j )
189 {
190 F* x = X.Buffer(0,j);
191 const Real* c = C.LockedBuffer(0,j);
192 const F* s = S.LockedBuffer(0,j);
193 F tau0 = x[0];
194 for( Int k=1; k<m; ++k )
195 {
196 F tau1 = x[k];
197 x[k-1] = c[k] *tau0 + s[k]*tau1;
198 tau0 = -Conj(s[k])*tau0 + c[k]*tau1;
199 }
200 x[m-1] = tau0;
201 }
202 }
203
204 // NOTE: A [VC,* ] distribution might be most appropriate for the
205 // Hessenberg matrices since whole columns will need to be formed
206 // on every process and this distribution will keep the communication
207 // balanced.
208
209 template<typename F,Dist UH,Dist VH,Dist VX>
210 inline void
LN(F alpha,const DistMatrix<F,UH,VH> & H,const DistMatrix<F,VX,STAR> & shifts,DistMatrix<F,STAR,VX> & X)211 LN
212 ( F alpha, const DistMatrix<F,UH,VH>& H, const DistMatrix<F,VX,STAR>& shifts,
213 DistMatrix<F,STAR,VX>& X )
214 {
215 DEBUG_ONLY(
216 CallStackEntry cse("mshs::LN");
217 if( shifts.ColAlign() != X.RowAlign() )
218 LogicError("shifts and X are not aligned");
219 )
220 Scale( alpha, X );
221
222 const Int m = X.Height();
223 const Int nLoc = X.LocalWidth();
224 if( m == 0 )
225 return;
226
227 // Initialize storage for Givens rotations
228 typedef Base<F> Real;
229 Matrix<Real> C(m,nLoc);
230 Matrix<F> S(m,nLoc);
231
232 // Initialize the workspace for shifted columns of H
233 Matrix<F> W(m,nLoc);
234 {
235 auto h0 = LockedView( H, 0, 0, m, 1 );
236 DistMatrix<F,STAR,STAR> h0_STAR_STAR( h0 );
237 for( Int jLoc=0; jLoc<nLoc; ++jLoc )
238 {
239 MemCopy( W.Buffer(0,jLoc), h0_STAR_STAR.LockedBuffer(), m );
240 W.Update( 0, jLoc, -shifts.GetLocal(jLoc,0) );
241 }
242 }
243
244 // Simultaneously find the LQ factorization and solve against L
245 DistMatrix<F,STAR,STAR> hB_STAR_STAR( H.Grid() );
246 for( Int k=0; k<m-1; ++k )
247 {
248 auto hB = LockedViewRange( H, k+2, k+1, m, k+2 );
249 hB_STAR_STAR = hB;
250 const F etakkp1 = H.Get(k,k+1);
251 const F etakp1kp1 = H.Get(k+1,k+1);
252 for( Int jLoc=0; jLoc<nLoc; ++jLoc )
253 {
254 // Find the Givens rotation needed to zero H(k,k+1),
255 // | c s | | H(k,k) | = | gamma |
256 // | -conj(s) c | | H(k,k+1) | | 0 |
257 Real c; F s;
258 lapack::Givens( W.Get(k,jLoc), etakkp1, &c, &s );
259 C.Set( k, jLoc, c );
260 S.Set( k, jLoc, s );
261
262 // The new diagonal value of L
263 const F lambdakk = c*W.Get(k,jLoc) + s*etakkp1;
264
265 // Divide our current entry of x by the diagonal value of L
266 X.SetLocal( k, jLoc, X.GetLocal(k,jLoc)/lambdakk );
267
268 // x(k+1:end) -= x(k) * L(k+1:end,k), where
269 // L(k+1:end,k) = c H(k+1:end,k) + s H(k+1:end,k+1). We express this
270 // more concisely as xB -= x(k) * ( c wB + s hB ).
271 // Note that we carefully handle updating the k+1'th entry since
272 // it is shift-dependent.
273 const F mu = shifts.GetLocal( jLoc, 0 );
274 const F xc = X.GetLocal(k,jLoc)*c;
275 const F xs = X.GetLocal(k,jLoc)*s;
276 X.UpdateLocal( k+1, jLoc, -xc*W.Get(k+1,jLoc)-xs*(etakp1kp1-mu) );
277 blas::Axpy
278 ( m-(k+2), -xc, W.LockedBuffer(k+2,jLoc), 1,
279 X.Buffer(k+2,jLoc), 1 );
280 blas::Axpy
281 ( m-(k+2), -xs, hB_STAR_STAR.LockedBuffer(), 1,
282 X.Buffer(k+2,jLoc), 1 );
283
284 // Change the working vector, wB, from representing a fully-updated
285 // portion of the k'th column of H from the end of the last
286 // to a fully-updated portion of the k+1'th column of this iteration
287 //
288 // w(k+1:end) := -conj(s) H(k+1:end,k) + c H(k+1:end,k+1)
289 W.Set( k+1, jLoc, -Conj(s)*W.Get(k+1,jLoc)+c*(etakp1kp1-mu) );
290 blas::Scal( m-(k+2), -Conj(s), W.Buffer(k+2,jLoc), 1 );
291 blas::Axpy
292 ( m-(k+2), c, hB_STAR_STAR.LockedBuffer(), 1,
293 W.Buffer(k+2,jLoc), 1 );
294 }
295 }
296 // Divide x(end) by L(end,end)
297 for( Int jLoc=0; jLoc<nLoc; ++jLoc )
298 X.SetLocal( m-1, jLoc, X.GetLocal(m-1,jLoc)/W.Get(m-1,jLoc) );
299
300 // Solve against Q
301 for( Int jLoc=0; jLoc<nLoc; ++jLoc )
302 {
303 F* x = X.Buffer(0,jLoc);
304 const Real* c = C.LockedBuffer(0,jLoc);
305 const F* s = S.LockedBuffer(0,jLoc);
306 F tau0 = x[m-1];
307 for( Int k=m-2; k>=0; --k )
308 {
309 F tau1 = x[k];
310 x[k+1] = c[k] *tau0 + s[k]*tau1;
311 tau0 = -Conj(s[k])*tau0 + c[k]*tau1;
312 }
313 x[0] = tau0;
314 }
315 }
316
317 template<typename F,Dist UH,Dist VH,Dist VX>
318 inline void
UN(F alpha,const DistMatrix<F,UH,VH> & H,const DistMatrix<F,VX,STAR> & shifts,DistMatrix<F,STAR,VX> & X)319 UN
320 ( F alpha, const DistMatrix<F,UH,VH>& H, const DistMatrix<F,VX,STAR>& shifts,
321 DistMatrix<F,STAR,VX>& X )
322 {
323 DEBUG_ONLY(
324 CallStackEntry cse("mshs::UN");
325 if( shifts.ColAlign() != X.RowAlign() )
326 LogicError("shifts and X are not aligned");
327 )
328 Scale( alpha, X );
329
330 const Int m = X.Height();
331 const Int nLoc = X.LocalWidth();
332 if( m == 0 )
333 return;
334
335 // Initialize storage for Givens rotations
336 typedef Base<F> Real;
337 Matrix<Real> C(m,nLoc);
338 Matrix<F> S(m,nLoc);
339
340 // Initialize the workspace for shifted columns of H
341 Matrix<F> W(m,nLoc);
342 {
343 auto hLast = LockedView( H, 0, m-1, m, 1 );
344 DistMatrix<F,STAR,STAR> hLast_STAR_STAR( hLast );
345 for( Int jLoc=0; jLoc<nLoc; ++jLoc )
346 {
347 MemCopy( W.Buffer(0,jLoc), hLast_STAR_STAR.LockedBuffer(), m );
348 W.Update( m-1, jLoc, -shifts.GetLocal(jLoc,0) );
349 }
350 }
351
352 // Simultaneously form the RQ factorization and solve against R
353 DistMatrix<F,STAR,STAR> hT_STAR_STAR( H.Grid() );
354 for( Int k=m-1; k>0; --k )
355 {
356 auto hT = LockedView( H, 0, k-1, k-1, 1 );
357 hT_STAR_STAR = hT;
358 const F etakkm1 = H.Get(k,k-1);
359 const F etakm1km1 = H.Get(k-1,k-1);
360 for( Int jLoc=0; jLoc<nLoc; ++jLoc )
361 {
362 // Find the Givens rotation needed to zero H(k,k-1),
363 // | c s | | H(k,k) | = | gamma |
364 // | -conj(s) c | | H(k,k-1) | | 0 |
365 Real c; F s;
366 lapack::Givens( W.Get(k,jLoc), etakkm1, &c, &s );
367 C.Set( k, jLoc, c );
368 S.Set( k, jLoc, s );
369
370 // The new diagonal value of R
371 const F rhokk = c*W.Get(k,jLoc) + s*etakkm1;
372
373 // Divide our current entry of x by the diagonal value of R
374 X.SetLocal( k, jLoc, X.GetLocal(k,jLoc)/rhokk );
375
376 // x(0:k-1) -= x(k) * R(0:k-1,k), where
377 // R(0:k-1,k) = c H(0:k-1,k) + s H(0:k-1,k-1). We express this
378 // more concisely as xT -= x(k) * ( c wT + s hT ).
379 // Note that we carefully handle updating the k-1'th entry since
380 // it is shift-dependent.
381 const F mu = shifts.GetLocal( jLoc, 0 );
382 const F xc = X.GetLocal(k,jLoc)*c;
383 const F xs = X.GetLocal(k,jLoc)*s;
384 blas::Axpy
385 ( k-1, -xc, W.LockedBuffer(0,jLoc), 1, X.Buffer(0,jLoc), 1 );
386 blas::Axpy
387 ( k-1, -xs, hT_STAR_STAR.LockedBuffer(), 1, X.Buffer(0,jLoc), 1 );
388 X.UpdateLocal( k-1, jLoc, -xc*W.Get(k-1,jLoc)-xs*(etakm1km1-mu) );
389
390 // Change the working vector, wT, from representing a fully-updated
391 // portion of the k'th column of H from the end of the last
392 // to a fully-updated portion of the k-1'th column of this iteration
393 //
394 // w(0:k-1) := -conj(s) H(0:k-1,k) + c H(0:k-1,k-1)
395 blas::Scal( k-1, -Conj(s), W.Buffer(0,jLoc), 1 );
396 blas::Axpy( k-1, c, hT_STAR_STAR.LockedBuffer(), 1,
397 W.Buffer(0,jLoc), 1 );
398 W.Set( k-1, jLoc, -Conj(s)*W.Get(k-1,jLoc)+c*(etakm1km1-mu) );
399 }
400 }
401 for( Int jLoc=0; jLoc<nLoc; ++jLoc )
402 X.SetLocal( 0, jLoc, X.GetLocal(0,jLoc)/W.Get(0,jLoc) );
403
404 // Solve against Q
405 for( Int jLoc=0; jLoc<nLoc; ++jLoc )
406 {
407 F* x = X.Buffer(0,jLoc);
408 const Real* c = C.LockedBuffer(0,jLoc);
409 const F* s = S.LockedBuffer(0,jLoc);
410 F tau0 = x[0];
411 for( Int k=1; k<m; ++k )
412 {
413 F tau1 = x[k];
414 x[k-1] = c[k] *tau0 + s[k]*tau1;
415 tau0 = -Conj(s[k])*tau0 + c[k]*tau1;
416 }
417 x[m-1] = tau0;
418 }
419 }
420
421 // TODO: UT and LT
422
423 } // namespace mshs
424
425 template<typename F>
426 inline void
MultiShiftHessSolve(UpperOrLower uplo,Orientation orientation,F alpha,const Matrix<F> & H,const Matrix<F> & shifts,Matrix<F> & X)427 MultiShiftHessSolve
428 ( UpperOrLower uplo, Orientation orientation,
429 F alpha, const Matrix<F>& H, const Matrix<F>& shifts, Matrix<F>& X )
430 {
431 DEBUG_ONLY(CallStackEntry cse("MultiShiftHessSolve"))
432 if( uplo == LOWER )
433 {
434 if( orientation == NORMAL )
435 mshs::LN( alpha, H, shifts, X );
436 else
437 LogicError("This option is not yet supported");
438 }
439 else
440 {
441 if( orientation == NORMAL )
442 mshs::UN( alpha, H, shifts, X );
443 else
444 LogicError("This option is not yet supported");
445 }
446 }
447
448 template<typename F,Dist UH,Dist VH,Dist VX>
449 inline void
MultiShiftHessSolve(UpperOrLower uplo,Orientation orientation,F alpha,const DistMatrix<F,UH,VH> & H,const DistMatrix<F,VX,STAR> & shifts,DistMatrix<F,STAR,VX> & X)450 MultiShiftHessSolve
451 ( UpperOrLower uplo, Orientation orientation,
452 F alpha, const DistMatrix<F,UH,VH>& H, const DistMatrix<F,VX,STAR>& shifts,
453 DistMatrix<F,STAR,VX>& X )
454 {
455 DEBUG_ONLY(CallStackEntry cse("MultiShiftHessSolve"))
456 if( uplo == LOWER )
457 {
458 if( orientation == NORMAL )
459 mshs::LN( alpha, H, shifts, X );
460 else
461 LogicError("This option is not yet supported");
462 }
463 else
464 {
465 if( orientation == NORMAL )
466 mshs::UN( alpha, H, shifts, X );
467 else
468 LogicError("This option is not yet supported");
469 }
470 }
471
472 } // namespace elem
473
474 #endif // ifndef ELEM_MULTISHIFTHESSSOLVE_HPP
475