/*
   Copyright (c) 2009-2014, Jack Poulson
   All rights reserved.

   This file is part of Elemental and is under the BSD 2-Clause License,
   which can be found in the LICENSE file in the root directory, or at
   http://opensource.org/licenses/BSD-2-Clause
*/
9 #pragma once
10 #ifndef ELEM_MULTISHIFTHESSSOLVE_HPP
11 #define ELEM_MULTISHIFTHESSSOLVE_HPP
12 
// NOTE: These algorithms are adaptations and/or extensions of Alg. 2 from
//       Greg Henry's "The shifted Hessenberg system solve computation".
//       It is important to note that the Givens rotation definition in
//       said paper is the adjoint of the LAPACK definition (and that the
//       paper leaves out a conjugation necessary for the complex case).
18 
19 namespace elem {
20 namespace mshs {
21 
22 template<typename F>
23 inline void
LN(F alpha,const Matrix<F> & H,const Matrix<F> & shifts,Matrix<F> & X)24 LN( F alpha, const Matrix<F>& H, const Matrix<F>& shifts, Matrix<F>& X )
25 {
26     DEBUG_ONLY(CallStackEntry cse("mshs::LN"))
27     Scale( alpha, X );
28 
29     const Int m = X.Height();
30     const Int n = X.Width();
31     if( m == 0 )
32         return;
33 
34     // Initialize storage for Givens rotations
35     typedef Base<F> Real;
36     Matrix<Real> C(m,n);
37     Matrix<F> S(m,n);
38 
39     // Initialize the workspace for shifted columns of H
40     Matrix<F> W(m,n);
41     for( Int j=0; j<n; ++j )
42     {
43         MemCopy( W.Buffer(0,j), H.LockedBuffer(), m );
44         W.Update( 0, j, -shifts.Get(j,0) );
45     }
46 
47     // Simultaneously find the LQ factorization and solve against L
48     for( Int k=0; k<m-1; ++k )
49     {
50         auto hB = LockedViewRange( H, k+2, k+1, m, k+2 );
51         const F etakkp1 = H.Get(k,k+1);
52         const F etakp1kp1 = H.Get(k+1,k+1);
53         for( Int j=0; j<n; ++j )
54         {
55             // Find the Givens rotation needed to zero H(k,k+1),
56             //   | c        s | | H(k,k)   | = | gamma |
57             //   | -conj(s) c | | H(k,k+1) |   | 0     |
58             Real c; F s;
59             lapack::Givens( W.Get(k,j), etakkp1, &c, &s );
60             C.Set( k, j, c );
61             S.Set( k, j, s );
62 
63             // The new diagonal value of L
64             const F lambdakk = c*W.Get(k,j) + s*etakkp1;
65 
66             // Divide our current entry of x by the diagonal value of L
67             X.Set( k, j, X.Get(k,j)/lambdakk );
68 
69             // x(k+1:end) -= x(k) * L(k+1:end,k), where
70             // L(k+1:end,k) = c H(k+1:end,k) + s H(k+1:end,k+1). We express this
71             // more concisely as xB -= x(k) * ( c wB + s hB ).
72             // Note that we carefully handle updating the k+1'th entry since
73             // it is shift-dependent.
74             const F mu = shifts.Get( j, 0 );
75             const F xc = X.Get(k,j)*c;
76             const F xs  = X.Get(k,j)*s;
77             X.Update( k+1, j, -xc*W.Get(k+1,j)-xs*(etakp1kp1-mu) );
78             blas::Axpy
79             ( m-(k+2), -xc, W.LockedBuffer(k+2,j), 1, X.Buffer(k+2,j), 1 );
80             blas::Axpy
81             ( m-(k+2), -xs, hB.LockedBuffer(),     1, X.Buffer(k+2,j), 1 );
82 
83             // Change the working vector, wB, from representing a fully-updated
84             // portion of the k'th column of H from the end of the last
85             // to a fully-updated portion of the k+1'th column of this iteration
86             //
87             // w(k+1:end) := -conj(s) H(k+1:end,k) + c H(k+1:end,k+1)
88             W.Set( k+1, j, -Conj(s)*W.Get(k+1,j)+c*(etakp1kp1-mu) );
89             blas::Scal( m-(k+2), -Conj(s), W.Buffer(k+2,j), 1 );
90             blas::Axpy( m-(k+2), c, hB.LockedBuffer(), 1, W.Buffer(k+2,j), 1 );
91         }
92     }
93     // Divide x(end) by L(end,end)
94     for( Int j=0; j<n; ++j )
95         X.Set( m-1, j, X.Get(m-1,j)/W.Get(m-1,j) );
96 
97     // Solve against Q
98     for( Int j=0; j<n; ++j )
99     {
100         F* x = X.Buffer(0,j);
101         const Real* c = C.LockedBuffer(0,j);
102         const F*    s = S.LockedBuffer(0,j);
103         F tau0 = x[m-1];
104         for( Int k=m-2; k>=0; --k )
105         {
106             F tau1 = x[k];
107             x[k+1] =       c[k] *tau0 + s[k]*tau1;
108             tau0   = -Conj(s[k])*tau0 + c[k]*tau1;
109         }
110         x[0] = tau0;
111     }
112 }
113 
114 template<typename F>
115 inline void
UN(F alpha,const Matrix<F> & H,const Matrix<F> & shifts,Matrix<F> & X)116 UN( F alpha, const Matrix<F>& H, const Matrix<F>& shifts, Matrix<F>& X )
117 {
118     DEBUG_ONLY(CallStackEntry cse("mshs::UN"))
119     Scale( alpha, X );
120 
121     const Int m = X.Height();
122     const Int n = X.Width();
123     if( m == 0 )
124         return;
125 
126     // Initialize storage for Givens rotations
127     typedef Base<F> Real;
128     Matrix<Real> C(m,n);
129     Matrix<F> S(m,n);
130 
131     // Initialize the workspace for shifted columns of H
132     Matrix<F> W(m,n);
133     for( Int j=0; j<n; ++j )
134     {
135         MemCopy( W.Buffer(0,j), H.LockedBuffer(0,m-1), m );
136         W.Update( m-1, j, -shifts.Get(j,0) );
137     }
138 
139     // Simultaneously form the RQ factorization and solve against R
140     for( Int k=m-1; k>0; --k )
141     {
142         auto hT = LockedView( H, 0, k-1, k-1, 1 );
143         const F etakkm1 = H.Get(k,k-1);
144         const F etakm1km1 = H.Get(k-1,k-1);
145         for( Int j=0; j<n; ++j )
146         {
147             // Find the Givens rotation needed to zero H(k,k-1),
148             //   | c        s | | H(k,k)   | = | gamma |
149             //   | -conj(s) c | | H(k,k-1) |   | 0     |
150             Real c; F s;
151             lapack::Givens( W.Get(k,j), etakkm1, &c, &s );
152             C.Set( k, j, c );
153             S.Set( k, j, s );
154 
155             // The new diagonal value of R
156             const F rhokk = c*W.Get(k,j) + s*etakkm1;
157 
158             // Divide our current entry of x by the diagonal value of R
159             X.Set( k, j, X.Get(k,j)/rhokk );
160 
161             // x(0:k-1) -= x(k) * R(0:k-1,k), where
162             // R(0:k-1,k) = c H(0:k-1,k) + s H(0:k-1,k-1). We express this
163             // more concisely as xT -= x(k) * ( c wT + s hT ).
164             // Note that we carefully handle updating the k-1'th entry since
165             // it is shift-dependent.
166             const F mu = shifts.Get( j, 0 );
167             const F xc = X.Get(k,j)*c;
168             const F xs  = X.Get(k,j)*s;
169             blas::Axpy( k-1, -xc, W.LockedBuffer(0,j), 1, X.Buffer(0,j), 1 );
170             blas::Axpy( k-1, -xs, hT.LockedBuffer(),   1, X.Buffer(0,j), 1 );
171             X.Update( k-1, j, -xc*W.Get(k-1,j)-xs*(etakm1km1-mu) );
172 
173             // Change the working vector, wT, from representing a fully-updated
174             // portion of the k'th column of H from the end of the last
175             // to a fully-updated portion of the k-1'th column of this iteration
176             //
177             // w(0:k-1) := -conj(s) H(0:k-1,k) + c H(0:k-1,k-1)
178             blas::Scal( k-1, -Conj(s), W.Buffer(0,j), 1 );
179             blas::Axpy( k-1, c, hT.LockedBuffer(), 1, W.Buffer(0,j), 1 );
180             W.Set( k-1, j, -Conj(s)*W.Get(k-1,j)+c*(etakm1km1-mu) );
181         }
182     }
183     // Divide x(0) by R(0,0)
184     for( Int j=0; j<n; ++j )
185         X.Set( 0, j, X.Get(0,j)/W.Get(0,j) );
186 
187     // Solve against Q
188     for( Int j=0; j<n; ++j )
189     {
190         F* x = X.Buffer(0,j);
191         const Real* c = C.LockedBuffer(0,j);
192         const F*    s = S.LockedBuffer(0,j);
193         F tau0 = x[0];
194         for( Int k=1; k<m; ++k )
195         {
196             F tau1 = x[k];
197             x[k-1] =       c[k] *tau0 + s[k]*tau1;
198             tau0   = -Conj(s[k])*tau0 + c[k]*tau1;
199         }
200         x[m-1] = tau0;
201     }
202 }
203 
// NOTE: A [VC,* ] distribution might be most appropriate for the
//       Hessenberg matrices, since whole columns will need to be formed
//       on every process and this distribution will keep the communication
//       balanced.
208 
209 template<typename F,Dist UH,Dist VH,Dist VX>
210 inline void
LN(F alpha,const DistMatrix<F,UH,VH> & H,const DistMatrix<F,VX,STAR> & shifts,DistMatrix<F,STAR,VX> & X)211 LN
212 ( F alpha, const DistMatrix<F,UH,VH>& H, const DistMatrix<F,VX,STAR>& shifts,
213   DistMatrix<F,STAR,VX>& X )
214 {
215     DEBUG_ONLY(
216         CallStackEntry cse("mshs::LN");
217         if( shifts.ColAlign() != X.RowAlign() )
218             LogicError("shifts and X are not aligned");
219     )
220     Scale( alpha, X );
221 
222     const Int m = X.Height();
223     const Int nLoc = X.LocalWidth();
224     if( m == 0 )
225         return;
226 
227     // Initialize storage for Givens rotations
228     typedef Base<F> Real;
229     Matrix<Real> C(m,nLoc);
230     Matrix<F> S(m,nLoc);
231 
232     // Initialize the workspace for shifted columns of H
233     Matrix<F> W(m,nLoc);
234     {
235         auto h0 = LockedView( H, 0, 0, m, 1 );
236         DistMatrix<F,STAR,STAR> h0_STAR_STAR( h0 );
237         for( Int jLoc=0; jLoc<nLoc; ++jLoc )
238         {
239             MemCopy( W.Buffer(0,jLoc), h0_STAR_STAR.LockedBuffer(), m );
240             W.Update( 0, jLoc, -shifts.GetLocal(jLoc,0) );
241         }
242     }
243 
244     // Simultaneously find the LQ factorization and solve against L
245     DistMatrix<F,STAR,STAR> hB_STAR_STAR( H.Grid() );
246     for( Int k=0; k<m-1; ++k )
247     {
248         auto hB = LockedViewRange( H, k+2, k+1, m, k+2 );
249         hB_STAR_STAR = hB;
250         const F etakkp1 = H.Get(k,k+1);
251         const F etakp1kp1 = H.Get(k+1,k+1);
252         for( Int jLoc=0; jLoc<nLoc; ++jLoc )
253         {
254             // Find the Givens rotation needed to zero H(k,k+1),
255             //   | c        s | | H(k,k)   | = | gamma |
256             //   | -conj(s) c | | H(k,k+1) |   | 0     |
257             Real c; F s;
258             lapack::Givens( W.Get(k,jLoc), etakkp1, &c, &s );
259             C.Set( k, jLoc, c );
260             S.Set( k, jLoc, s );
261 
262             // The new diagonal value of L
263             const F lambdakk = c*W.Get(k,jLoc) + s*etakkp1;
264 
265             // Divide our current entry of x by the diagonal value of L
266             X.SetLocal( k, jLoc, X.GetLocal(k,jLoc)/lambdakk );
267 
268             // x(k+1:end) -= x(k) * L(k+1:end,k), where
269             // L(k+1:end,k) = c H(k+1:end,k) + s H(k+1:end,k+1). We express this
270             // more concisely as xB -= x(k) * ( c wB + s hB ).
271             // Note that we carefully handle updating the k+1'th entry since
272             // it is shift-dependent.
273             const F mu = shifts.GetLocal( jLoc, 0 );
274             const F xc = X.GetLocal(k,jLoc)*c;
275             const F xs  = X.GetLocal(k,jLoc)*s;
276             X.UpdateLocal( k+1, jLoc, -xc*W.Get(k+1,jLoc)-xs*(etakp1kp1-mu) );
277             blas::Axpy
278             ( m-(k+2), -xc, W.LockedBuffer(k+2,jLoc), 1,
279                             X.Buffer(k+2,jLoc),       1 );
280             blas::Axpy
281             ( m-(k+2), -xs, hB_STAR_STAR.LockedBuffer(), 1,
282                             X.Buffer(k+2,jLoc),          1 );
283 
284             // Change the working vector, wB, from representing a fully-updated
285             // portion of the k'th column of H from the end of the last
286             // to a fully-updated portion of the k+1'th column of this iteration
287             //
288             // w(k+1:end) := -conj(s) H(k+1:end,k) + c H(k+1:end,k+1)
289             W.Set( k+1, jLoc, -Conj(s)*W.Get(k+1,jLoc)+c*(etakp1kp1-mu) );
290             blas::Scal( m-(k+2), -Conj(s), W.Buffer(k+2,jLoc), 1 );
291             blas::Axpy
292             ( m-(k+2), c, hB_STAR_STAR.LockedBuffer(), 1,
293                           W.Buffer(k+2,jLoc),          1 );
294         }
295     }
296     // Divide x(end) by L(end,end)
297     for( Int jLoc=0; jLoc<nLoc; ++jLoc )
298         X.SetLocal( m-1, jLoc, X.GetLocal(m-1,jLoc)/W.Get(m-1,jLoc) );
299 
300     // Solve against Q
301     for( Int jLoc=0; jLoc<nLoc; ++jLoc )
302     {
303         F* x = X.Buffer(0,jLoc);
304         const Real* c = C.LockedBuffer(0,jLoc);
305         const F*    s = S.LockedBuffer(0,jLoc);
306         F tau0 = x[m-1];
307         for( Int k=m-2; k>=0; --k )
308         {
309             F tau1 = x[k];
310             x[k+1] =       c[k] *tau0 + s[k]*tau1;
311             tau0   = -Conj(s[k])*tau0 + c[k]*tau1;
312         }
313         x[0] = tau0;
314     }
315 }
316 
317 template<typename F,Dist UH,Dist VH,Dist VX>
318 inline void
UN(F alpha,const DistMatrix<F,UH,VH> & H,const DistMatrix<F,VX,STAR> & shifts,DistMatrix<F,STAR,VX> & X)319 UN
320 ( F alpha, const DistMatrix<F,UH,VH>& H, const DistMatrix<F,VX,STAR>& shifts,
321   DistMatrix<F,STAR,VX>& X )
322 {
323     DEBUG_ONLY(
324         CallStackEntry cse("mshs::UN");
325         if( shifts.ColAlign() != X.RowAlign() )
326             LogicError("shifts and X are not aligned");
327     )
328     Scale( alpha, X );
329 
330     const Int m = X.Height();
331     const Int nLoc = X.LocalWidth();
332     if( m == 0 )
333         return;
334 
335     // Initialize storage for Givens rotations
336     typedef Base<F> Real;
337     Matrix<Real> C(m,nLoc);
338     Matrix<F> S(m,nLoc);
339 
340     // Initialize the workspace for shifted columns of H
341     Matrix<F> W(m,nLoc);
342     {
343         auto hLast = LockedView( H, 0, m-1, m, 1 );
344         DistMatrix<F,STAR,STAR> hLast_STAR_STAR( hLast );
345         for( Int jLoc=0; jLoc<nLoc; ++jLoc )
346         {
347             MemCopy( W.Buffer(0,jLoc), hLast_STAR_STAR.LockedBuffer(), m );
348             W.Update( m-1, jLoc, -shifts.GetLocal(jLoc,0) );
349         }
350     }
351 
352     // Simultaneously form the RQ factorization and solve against R
353     DistMatrix<F,STAR,STAR> hT_STAR_STAR( H.Grid() );
354     for( Int k=m-1; k>0; --k )
355     {
356         auto hT = LockedView( H, 0, k-1, k-1, 1 );
357         hT_STAR_STAR = hT;
358         const F etakkm1 = H.Get(k,k-1);
359         const F etakm1km1 = H.Get(k-1,k-1);
360         for( Int jLoc=0; jLoc<nLoc; ++jLoc )
361         {
362             // Find the Givens rotation needed to zero H(k,k-1),
363             //   | c        s | | H(k,k)   | = | gamma |
364             //   | -conj(s) c | | H(k,k-1) |   | 0     |
365             Real c; F s;
366             lapack::Givens( W.Get(k,jLoc), etakkm1, &c, &s );
367             C.Set( k, jLoc, c );
368             S.Set( k, jLoc, s );
369 
370             // The new diagonal value of R
371             const F rhokk = c*W.Get(k,jLoc) + s*etakkm1;
372 
373             // Divide our current entry of x by the diagonal value of R
374             X.SetLocal( k, jLoc, X.GetLocal(k,jLoc)/rhokk );
375 
376             // x(0:k-1) -= x(k) * R(0:k-1,k), where
377             // R(0:k-1,k) = c H(0:k-1,k) + s H(0:k-1,k-1). We express this
378             // more concisely as xT -= x(k) * ( c wT + s hT ).
379             // Note that we carefully handle updating the k-1'th entry since
380             // it is shift-dependent.
381             const F mu = shifts.GetLocal( jLoc, 0 );
382             const F xc = X.GetLocal(k,jLoc)*c;
383             const F xs  = X.GetLocal(k,jLoc)*s;
384             blas::Axpy
385             ( k-1, -xc, W.LockedBuffer(0,jLoc),      1, X.Buffer(0,jLoc), 1 );
386             blas::Axpy
387             ( k-1, -xs, hT_STAR_STAR.LockedBuffer(), 1, X.Buffer(0,jLoc), 1 );
388             X.UpdateLocal( k-1, jLoc, -xc*W.Get(k-1,jLoc)-xs*(etakm1km1-mu) );
389 
390             // Change the working vector, wT, from representing a fully-updated
391             // portion of the k'th column of H from the end of the last
392             // to a fully-updated portion of the k-1'th column of this iteration
393             //
394             // w(0:k-1) := -conj(s) H(0:k-1,k) + c H(0:k-1,k-1)
395             blas::Scal( k-1, -Conj(s), W.Buffer(0,jLoc), 1 );
396             blas::Axpy( k-1, c, hT_STAR_STAR.LockedBuffer(), 1,
397                                 W.Buffer(0,jLoc),            1 );
398             W.Set( k-1, jLoc, -Conj(s)*W.Get(k-1,jLoc)+c*(etakm1km1-mu) );
399         }
400     }
401     for( Int jLoc=0; jLoc<nLoc; ++jLoc )
402         X.SetLocal( 0, jLoc, X.GetLocal(0,jLoc)/W.Get(0,jLoc) );
403 
404     // Solve against Q
405     for( Int jLoc=0; jLoc<nLoc; ++jLoc )
406     {
407         F* x = X.Buffer(0,jLoc);
408         const Real* c = C.LockedBuffer(0,jLoc);
409         const F*    s = S.LockedBuffer(0,jLoc);
410         F tau0 = x[0];
411         for( Int k=1; k<m; ++k )
412         {
413             F tau1 = x[k];
414             x[k-1] =       c[k] *tau0 + s[k]*tau1;
415             tau0   = -Conj(s[k])*tau0 + c[k]*tau1;
416         }
417         x[m-1] = tau0;
418     }
419 }
420 
// TODO: Implement the UT and LT variants
422 
423 } // namespace mshs
424 
425 template<typename F>
426 inline void
MultiShiftHessSolve(UpperOrLower uplo,Orientation orientation,F alpha,const Matrix<F> & H,const Matrix<F> & shifts,Matrix<F> & X)427 MultiShiftHessSolve
428 ( UpperOrLower uplo, Orientation orientation,
429   F alpha, const Matrix<F>& H, const Matrix<F>& shifts, Matrix<F>& X )
430 {
431     DEBUG_ONLY(CallStackEntry cse("MultiShiftHessSolve"))
432     if( uplo == LOWER )
433     {
434         if( orientation == NORMAL )
435             mshs::LN( alpha, H, shifts, X );
436         else
437             LogicError("This option is not yet supported");
438     }
439     else
440     {
441         if( orientation == NORMAL )
442             mshs::UN( alpha, H, shifts, X );
443         else
444             LogicError("This option is not yet supported");
445     }
446 }
447 
448 template<typename F,Dist UH,Dist VH,Dist VX>
449 inline void
MultiShiftHessSolve(UpperOrLower uplo,Orientation orientation,F alpha,const DistMatrix<F,UH,VH> & H,const DistMatrix<F,VX,STAR> & shifts,DistMatrix<F,STAR,VX> & X)450 MultiShiftHessSolve
451 ( UpperOrLower uplo, Orientation orientation,
452   F alpha, const DistMatrix<F,UH,VH>& H, const DistMatrix<F,VX,STAR>& shifts,
453   DistMatrix<F,STAR,VX>& X )
454 {
455     DEBUG_ONLY(CallStackEntry cse("MultiShiftHessSolve"))
456     if( uplo == LOWER )
457     {
458         if( orientation == NORMAL )
459             mshs::LN( alpha, H, shifts, X );
460         else
461             LogicError("This option is not yet supported");
462     }
463     else
464     {
465         if( orientation == NORMAL )
466             mshs::UN( alpha, H, shifts, X );
467         else
468             LogicError("This option is not yet supported");
469     }
470 }
471 
472 } // namespace elem
473 
474 #endif // ifndef ELEM_MULTISHIFTHESSSOLVE_HPP
475