1 /*
2    Copyright (c) 2009-2014, Jack Poulson
3    All rights reserved.
4 
5    This file is part of Elemental and is under the BSD 2-Clause License,
6    which can be found in the LICENSE file in the root directory, or at
7    http://opensource.org/licenses/BSD-2-Clause
8 */
9 #pragma once
10 #ifndef ELEM_QUASIDIAGONALSOLVE_HPP
11 #define ELEM_QUASIDIAGONALSOLVE_HPP
12 
13 #include "./Symmetric2x2Solve.hpp"
14 
15 namespace elem {
16 
17 template<typename F,typename FMain>
18 inline void
QuasiDiagonalSolve(LeftOrRight side,UpperOrLower uplo,const Matrix<FMain> & d,const Matrix<F> & dSub,Matrix<F> & X,bool conjugated=false)19 QuasiDiagonalSolve
20 ( LeftOrRight side, UpperOrLower uplo,
21   const Matrix<FMain>& d, const Matrix<F>& dSub,
22   Matrix<F>& X, bool conjugated=false )
23 {
24     DEBUG_ONLY(CallStackEntry cse("QuasiDiagonalSolve"))
25     const Int m = X.Height();
26     const Int n = X.Width();
27     Matrix<F> D( 2, 2 );
28     if( side == LEFT && uplo == LOWER )
29     {
30         Int i=0;
31         while( i < m )
32         {
33             Int nb;
34             if( i < m-1 && Abs(dSub.Get(i,0)) > 0 )
35                 nb = 2;
36             else
37                 nb = 1;
38 
39             if( nb == 1 )
40             {
41                 auto xRow = View( X, i, 0, nb, n );
42                 Scale( F(1)/d.Get(i,0), xRow );
43             }
44             else
45             {
46                 D.Set(0,0,d.Get(i,0));
47                 D.Set(1,1,d.Get(i+1,0));
48                 D.Set(1,0,dSub.Get(i,0));
49                 auto XRow = View( X, i, 0, nb, n );
50                 Symmetric2x2Solve( LEFT, LOWER, D, XRow, conjugated );
51             }
52 
53             i += nb;
54         }
55     }
56     else if( side == RIGHT && uplo == LOWER )
57     {
58         Int j=0;
59         while( j < n )
60         {
61             Int nb;
62             if( j < n-1 && Abs(dSub.Get(j,0)) > 0 )
63                 nb = 2;
64             else
65                 nb = 1;
66 
67             if( nb == 1 )
68             {
69                 auto xCol = View( X, 0, j, m, nb );
70                 Scale( F(1)/d.Get(j,0), xCol );
71             }
72             else
73             {
74                 D.Set(0,0,d.Get(j,0));
75                 D.Set(1,1,d.Get(j+1,0));
76                 D.Set(1,0,dSub.Get(j,0));
77                 auto XCol = View( X, 0, j, m, nb );
78                 Symmetric2x2Solve( RIGHT, LOWER, D, XCol, conjugated );
79             }
80 
81             j += nb;
82         }
83     }
84     else
85         LogicError("This option not yet supported");
86 }
87 
88 template<typename F,typename FMain,Dist U,Dist V>
89 inline void
LeftQuasiDiagonalSolve(UpperOrLower uplo,const DistMatrix<FMain,U,STAR> d,const DistMatrix<FMain,U,STAR> dPrev,const DistMatrix<FMain,U,STAR> dNext,const DistMatrix<FMain,U,STAR> dSub,const DistMatrix<FMain,U,STAR> dSubPrev,const DistMatrix<FMain,U,STAR> dSubNext,DistMatrix<F,U,V> & X,const DistMatrix<F,U,V> & XPrev,const DistMatrix<F,U,V> & XNext,bool conjugated=false)90 LeftQuasiDiagonalSolve
91 ( UpperOrLower uplo,
92   const DistMatrix<FMain,U,STAR> d,
93   const DistMatrix<FMain,U,STAR> dPrev,
94   const DistMatrix<FMain,U,STAR> dNext,
95   const DistMatrix<FMain,U,STAR> dSub,
96   const DistMatrix<FMain,U,STAR> dSubPrev,
97   const DistMatrix<FMain,U,STAR> dSubNext,
98         DistMatrix<F,U,V>& X,
99   const DistMatrix<F,U,V>& XPrev,
100   const DistMatrix<F,U,V>& XNext,
101   bool conjugated=false )
102 {
103     DEBUG_ONLY(CallStackEntry cse("LeftQuasiDiagonalSolve"))
104     if( uplo == UPPER )
105         LogicError("This option not yet supported");
106     const Int m = X.Height();
107     const Int mLocal = X.LocalHeight();
108     const Int nLocal = X.LocalWidth();
109     const Int colStride = X.ColStride();
110     DEBUG_ONLY(
111         const Int colAlignPrev = Mod(X.ColAlign()+1,colStride);
112         const Int colAlignNext = Mod(X.ColAlign()-1,colStride);
113         if( d.ColAlign() != X.ColAlign() || dSub.ColAlign() != X.ColAlign() )
114             LogicError("data is not properly aligned");
115         if( XPrev.ColAlign() != colAlignPrev ||
116             dPrev.ColAlign() != colAlignPrev ||
117             dSubPrev.ColAlign() != colAlignPrev )
118             LogicError("'previous' data is not properly aligned");
119         if( XNext.ColAlign() != colAlignNext ||
120             dNext.ColAlign() != colAlignNext ||
121             dSubNext.ColAlign() != colAlignNext )
122             LogicError("'next' data is not properly aligned");
123     )
124     const Int prevOff = ( XPrev.ColShift()==X.ColShift()-1 ? 0 : -1 );
125     const Int nextOff = ( XNext.ColShift()==X.ColShift()+1 ? 0 : +1 );
126     if( !X.Participating() )
127         return;
128 
129     // It is best to separate the case where colStride is 1
130     if( colStride == 1 )
131     {
132         QuasiDiagonalSolve
133         ( LEFT, uplo, d.LockedMatrix(), dSub.LockedMatrix(), X.Matrix(),
134           conjugated );
135         return;
136     }
137 
138     Matrix<F> D11( 2, 2 );
139     for( Int iLoc=0; iLoc<mLocal; ++iLoc )
140     {
141         const Int i = X.GlobalRow(iLoc);
142         const Int iLocPrev = iLoc + prevOff;
143         const Int iLocNext = iLoc + nextOff;
144 
145         auto x1Loc = View( X.Matrix(), iLoc, 0, 1, nLocal );
146 
147         if( i<m-1 && dSub.GetLocal(iLoc,0) != F(0) )
148         {
149             // Handle 2x2 starting at i
150             D11.Set( 0, 0, d.GetLocal(iLoc,0) );
151             D11.Set( 1, 1, dNext.GetLocal(iLocNext,0) );
152             D11.Set( 1, 0, dSub.GetLocal(iLoc,0) );
153 
154             auto x1NextLoc =
155                 LockedView( XNext.LockedMatrix(), iLocNext, 0, 1, nLocal );
156             FirstHalfOfSymmetric2x2Solve
157             ( LEFT, LOWER, D11, x1Loc, x1NextLoc, conjugated );
158         }
159         else if( i>0 && dSubPrev.GetLocal(iLocPrev,0) != F(0) )
160         {
161             // Handle 2x2 starting at i-1
162             D11.Set( 0, 0, dPrev.GetLocal(iLocPrev,0) );
163             D11.Set( 1, 1, d.GetLocal(iLoc,0) );
164             D11.Set( 1, 0, dSubPrev.GetLocal(iLocPrev,0) );
165 
166             auto x1PrevLoc =
167                 LockedView( XPrev.LockedMatrix(), iLocPrev, 0, 1, nLocal );
168             SecondHalfOfSymmetric2x2Solve
169             ( LEFT, LOWER, D11, x1PrevLoc, x1Loc, conjugated );
170         }
171         else
172         {
173             // Handle 1x1
174             Scale( F(1)/d.GetLocal(iLoc,0), x1Loc );
175         }
176     }
177 }
178 
179 template<typename F,typename FMain,Dist U,Dist V>
180 inline void
RightQuasiDiagonalSolve(UpperOrLower uplo,const DistMatrix<FMain,V,STAR> d,const DistMatrix<FMain,V,STAR> dPrev,const DistMatrix<FMain,V,STAR> dNext,const DistMatrix<FMain,V,STAR> dSub,const DistMatrix<FMain,V,STAR> dSubPrev,const DistMatrix<FMain,V,STAR> dSubNext,DistMatrix<F,U,V> & X,const DistMatrix<F,U,V> & XPrev,const DistMatrix<F,U,V> & XNext,bool conjugated=false)181 RightQuasiDiagonalSolve
182 ( UpperOrLower uplo,
183   const DistMatrix<FMain,V,STAR> d,
184   const DistMatrix<FMain,V,STAR> dPrev,
185   const DistMatrix<FMain,V,STAR> dNext,
186   const DistMatrix<FMain,V,STAR> dSub,
187   const DistMatrix<FMain,V,STAR> dSubPrev,
188   const DistMatrix<FMain,V,STAR> dSubNext,
189         DistMatrix<F,U,V>& X,
190   const DistMatrix<F,U,V>& XPrev,
191   const DistMatrix<F,U,V>& XNext,
192   bool conjugated=false )
193 {
194     DEBUG_ONLY(CallStackEntry cse("LeftQuasiDiagonalSolve"))
195     if( uplo == UPPER )
196         LogicError("This option not yet supported");
197     const Int n = X.Width();
198     const Int mLocal = X.LocalHeight();
199     const Int nLocal = X.LocalWidth();
200     const Int rowStride = X.RowStride();
201     DEBUG_ONLY(
202         const Int rowAlignPrev = Mod(X.RowAlign()+1,rowStride);
203         const Int rowAlignNext = Mod(X.RowAlign()-1,rowStride);
204         if( d.ColAlign() != X.RowAlign() || dSub.RowAlign() != X.RowAlign() )
205             LogicError("data is not properly aligned");
206         if( XPrev.RowAlign() != rowAlignPrev ||
207             dPrev.ColAlign() != rowAlignPrev ||
208             dSubPrev.ColAlign() != rowAlignPrev )
209             LogicError("'previous' data is not properly aligned");
210         if( XNext.RowAlign() != rowAlignNext ||
211             dNext.ColAlign() != rowAlignNext ||
212             dSubNext.ColAlign() != rowAlignNext )
213             LogicError("'next' data is not properly aligned");
214     )
215     const Int prevOff = ( XPrev.RowShift()==X.RowShift()-1 ? 0 : -1 );
216     const Int nextOff = ( XNext.RowShift()==X.RowShift()+1 ? 0 : +1 );
217     if( !X.Participating() )
218         return;
219 
220     // It is best to separate the case where rowStride is 1
221     if( rowStride == 1 )
222     {
223         QuasiDiagonalSolve
224         ( LEFT, uplo, d.LockedMatrix(), dSub.LockedMatrix(), X.Matrix(),
225           conjugated );
226         return;
227     }
228 
229     Matrix<F> D11( 2, 2 );
230     for( Int jLoc=0; jLoc<nLocal; ++jLoc )
231     {
232         const Int j = X.GlobalCol(jLoc);
233         const Int jLocPrev = jLoc + prevOff;
234         const Int jLocNext = jLoc + nextOff;
235 
236         auto x1Loc = View( X.Matrix(), 0, jLoc, mLocal, 1 );
237 
238         if( j<n-1 && dSub.GetLocal(jLoc,0) != F(0) )
239         {
240             // Handle 2x2 starting at j
241             D11.Set( 0, 0, d.GetLocal(jLoc,0) );
242             D11.Set( 1, 1, dNext.GetLocal(jLocNext,0) );
243             D11.Set( 1, 0, dSub.GetLocal(jLoc,0) );
244 
245             auto x1NextLoc =
246                 LockedView( XNext.LockedMatrix(), 0, jLocNext, mLocal, 1 );
247             FirstHalfOfSymmetric2x2Solve
248             ( RIGHT, LOWER, D11, x1Loc, x1NextLoc, conjugated );
249         }
250         else if( j>0 && dSubPrev.GetLocal(jLocPrev,0) != F(0) )
251         {
252             // Handle 2x2 starting at j-1
253             D11.Set( 0, 0, dPrev.GetLocal(jLocPrev,0) );
254             D11.Set( 1, 1, d.GetLocal(jLoc,0) );
255             D11.Set( 1, 0, dSubPrev.GetLocal(jLocPrev,0) );
256 
257             auto x1PrevLoc =
258                 LockedView( XPrev.LockedMatrix(), 0, jLocPrev, mLocal, 1 );
259             SecondHalfOfSymmetric2x2Solve
260             ( RIGHT, LOWER, D11, x1PrevLoc, x1Loc, conjugated );
261         }
262         else
263         {
264             // Handle 1x1
265             Scale( F(1)/d.GetLocal(jLoc,0), x1Loc );
266         }
267     }
268 }
269 
270 template<typename F,typename FMain,Dist U1,Dist V1,
271                                    Dist U2,Dist V2>
272 inline void
QuasiDiagonalSolve(LeftOrRight side,UpperOrLower uplo,const DistMatrix<FMain,U1,V1> & d,const DistMatrix<F,U1,V1> & dSub,DistMatrix<F,U2,V2> & X,bool conjugated=false)273 QuasiDiagonalSolve
274 ( LeftOrRight side, UpperOrLower uplo,
275   const DistMatrix<FMain,U1,V1>& d, const DistMatrix<F,U1,V1>& dSub,
276   DistMatrix<F,U2,V2>& X, bool conjugated=false )
277 {
278     DEBUG_ONLY(CallStackEntry cse("QuasiDiagonalSolve"))
279     const Grid& g = X.Grid();
280     const Int colAlign = X.ColAlign();
281     const Int rowAlign = X.RowAlign();
282     if( side == LEFT )
283     {
284         const Int colStride = X.ColStride();
285         DistMatrix<FMain,U2,STAR> d_U2_STAR(g);
286         DistMatrix<F,U2,STAR> dSub_U2_STAR(g);
287         d_U2_STAR.AlignWith( X );
288         dSub_U2_STAR.AlignWith( X );
289         d_U2_STAR = d;
290         dSub_U2_STAR = dSub;
291         if( colStride == 1 )
292         {
293             QuasiDiagonalSolve
294             ( side, uplo, d_U2_STAR.LockedMatrix(), dSub_U2_STAR.LockedMatrix(),
295               X.Matrix(), conjugated );
296             return;
297         }
298 
299         DistMatrix<FMain,U2,STAR> dPrev_U2_STAR(g), dNext_U2_STAR(g);
300         DistMatrix<F,U2,STAR> dSubPrev_U2_STAR(g), dSubNext_U2_STAR(g);
301         DistMatrix<F,U2,V2> XPrev(g), XNext(g);
302         const Int colAlignPrev = Mod(colAlign+1,colStride);
303         const Int colAlignNext = Mod(colAlign-1,colStride);
304         dPrev_U2_STAR.AlignCols( colAlignPrev );
305         dNext_U2_STAR.AlignCols( colAlignNext );
306         dSubPrev_U2_STAR.AlignCols( colAlignPrev );
307         dSubNext_U2_STAR.AlignCols( colAlignNext );
308         XPrev.Align( colAlignPrev, rowAlign );
309         XNext.Align( colAlignNext, rowAlign );
310         dPrev_U2_STAR = d;
311         dNext_U2_STAR = d;
312         dSubPrev_U2_STAR = dSub;
313         dSubNext_U2_STAR = dSub;
314         XPrev = X;
315         XNext = X;
316         LeftQuasiDiagonalSolve
317         ( uplo, d_U2_STAR, dPrev_U2_STAR, dNext_U2_STAR,
318           dSub_U2_STAR, dSubPrev_U2_STAR, dSubNext_U2_STAR,
319           X, XPrev, XNext, conjugated );
320     }
321     else
322     {
323         const Int rowStride = X.RowStride();
324         DistMatrix<FMain,V2,STAR> d_V2_STAR(g);
325         DistMatrix<F,V2,STAR> dSub_V2_STAR(g);
326         d_V2_STAR.AlignWith( X );
327         dSub_V2_STAR.AlignWith( X );
328         d_V2_STAR = d;
329         dSub_V2_STAR = dSub;
330         if( rowStride == 1 )
331         {
332             QuasiDiagonalSolve
333             ( side, uplo,
334               d_V2_STAR.LockedMatrix(), dSub_V2_STAR.LockedMatrix(),
335               X.Matrix(), conjugated );
336             return;
337         }
338 
339         DistMatrix<FMain,V2,STAR> dPrev_V2_STAR(g), dNext_V2_STAR(g);
340         DistMatrix<F,V2,STAR> dSubPrev_V2_STAR(g), dSubNext_V2_STAR(g);
341         DistMatrix<F,U2,V2> XPrev(g), XNext(g);
342         const Int rowAlignPrev = Mod(rowAlign+1,rowStride);
343         const Int rowAlignNext = Mod(rowAlign-1,rowStride);
344         dPrev_V2_STAR.AlignCols( rowAlignPrev );
345         dNext_V2_STAR.AlignCols( rowAlignNext );
346         dSubPrev_V2_STAR.AlignCols( rowAlignPrev );
347         dSubNext_V2_STAR.AlignCols( rowAlignNext );
348         XPrev.Align( colAlign, rowAlignPrev );
349         XNext.Align( colAlign, rowAlignNext );
350         dPrev_V2_STAR = d;
351         dNext_V2_STAR = d;
352         dSubPrev_V2_STAR = dSub;
353         dSubNext_V2_STAR = dSub;
354         XPrev = X;
355         XNext = X;
356         RightQuasiDiagonalSolve
357         ( uplo, d_V2_STAR, dPrev_V2_STAR, dNext_V2_STAR,
358           dSub_V2_STAR, dSubPrev_V2_STAR, dSubNext_V2_STAR,
359           X, XPrev, XNext, conjugated );
360     }
361 }
362 
363 } // namespace elem
364 
365 #endif // ifndef ELEM_QUASIDIAGONALSOLVE_HPP
366