1 /*
2 Copyright (c) 2009-2014, Jack Poulson
3 All rights reserved.
4
5 This file is part of Elemental and is under the BSD 2-Clause License,
6 which can be found in the LICENSE file in the root directory, or at
7 http://opensource.org/licenses/BSD-2-Clause
8 */
9 #pragma once
10 #ifndef ELEM_QUASIDIAGONALSOLVE_HPP
11 #define ELEM_QUASIDIAGONALSOLVE_HPP
12
13 #include "./Symmetric2x2Solve.hpp"
14
15 namespace elem {
16
17 template<typename F,typename FMain>
18 inline void
QuasiDiagonalSolve(LeftOrRight side,UpperOrLower uplo,const Matrix<FMain> & d,const Matrix<F> & dSub,Matrix<F> & X,bool conjugated=false)19 QuasiDiagonalSolve
20 ( LeftOrRight side, UpperOrLower uplo,
21 const Matrix<FMain>& d, const Matrix<F>& dSub,
22 Matrix<F>& X, bool conjugated=false )
23 {
24 DEBUG_ONLY(CallStackEntry cse("QuasiDiagonalSolve"))
25 const Int m = X.Height();
26 const Int n = X.Width();
27 Matrix<F> D( 2, 2 );
28 if( side == LEFT && uplo == LOWER )
29 {
30 Int i=0;
31 while( i < m )
32 {
33 Int nb;
34 if( i < m-1 && Abs(dSub.Get(i,0)) > 0 )
35 nb = 2;
36 else
37 nb = 1;
38
39 if( nb == 1 )
40 {
41 auto xRow = View( X, i, 0, nb, n );
42 Scale( F(1)/d.Get(i,0), xRow );
43 }
44 else
45 {
46 D.Set(0,0,d.Get(i,0));
47 D.Set(1,1,d.Get(i+1,0));
48 D.Set(1,0,dSub.Get(i,0));
49 auto XRow = View( X, i, 0, nb, n );
50 Symmetric2x2Solve( LEFT, LOWER, D, XRow, conjugated );
51 }
52
53 i += nb;
54 }
55 }
56 else if( side == RIGHT && uplo == LOWER )
57 {
58 Int j=0;
59 while( j < n )
60 {
61 Int nb;
62 if( j < n-1 && Abs(dSub.Get(j,0)) > 0 )
63 nb = 2;
64 else
65 nb = 1;
66
67 if( nb == 1 )
68 {
69 auto xCol = View( X, 0, j, m, nb );
70 Scale( F(1)/d.Get(j,0), xCol );
71 }
72 else
73 {
74 D.Set(0,0,d.Get(j,0));
75 D.Set(1,1,d.Get(j+1,0));
76 D.Set(1,0,dSub.Get(j,0));
77 auto XCol = View( X, 0, j, m, nb );
78 Symmetric2x2Solve( RIGHT, LOWER, D, XCol, conjugated );
79 }
80
81 j += nb;
82 }
83 }
84 else
85 LogicError("This option not yet supported");
86 }
87
88 template<typename F,typename FMain,Dist U,Dist V>
89 inline void
LeftQuasiDiagonalSolve(UpperOrLower uplo,const DistMatrix<FMain,U,STAR> d,const DistMatrix<FMain,U,STAR> dPrev,const DistMatrix<FMain,U,STAR> dNext,const DistMatrix<FMain,U,STAR> dSub,const DistMatrix<FMain,U,STAR> dSubPrev,const DistMatrix<FMain,U,STAR> dSubNext,DistMatrix<F,U,V> & X,const DistMatrix<F,U,V> & XPrev,const DistMatrix<F,U,V> & XNext,bool conjugated=false)90 LeftQuasiDiagonalSolve
91 ( UpperOrLower uplo,
92 const DistMatrix<FMain,U,STAR> d,
93 const DistMatrix<FMain,U,STAR> dPrev,
94 const DistMatrix<FMain,U,STAR> dNext,
95 const DistMatrix<FMain,U,STAR> dSub,
96 const DistMatrix<FMain,U,STAR> dSubPrev,
97 const DistMatrix<FMain,U,STAR> dSubNext,
98 DistMatrix<F,U,V>& X,
99 const DistMatrix<F,U,V>& XPrev,
100 const DistMatrix<F,U,V>& XNext,
101 bool conjugated=false )
102 {
103 DEBUG_ONLY(CallStackEntry cse("LeftQuasiDiagonalSolve"))
104 if( uplo == UPPER )
105 LogicError("This option not yet supported");
106 const Int m = X.Height();
107 const Int mLocal = X.LocalHeight();
108 const Int nLocal = X.LocalWidth();
109 const Int colStride = X.ColStride();
110 DEBUG_ONLY(
111 const Int colAlignPrev = Mod(X.ColAlign()+1,colStride);
112 const Int colAlignNext = Mod(X.ColAlign()-1,colStride);
113 if( d.ColAlign() != X.ColAlign() || dSub.ColAlign() != X.ColAlign() )
114 LogicError("data is not properly aligned");
115 if( XPrev.ColAlign() != colAlignPrev ||
116 dPrev.ColAlign() != colAlignPrev ||
117 dSubPrev.ColAlign() != colAlignPrev )
118 LogicError("'previous' data is not properly aligned");
119 if( XNext.ColAlign() != colAlignNext ||
120 dNext.ColAlign() != colAlignNext ||
121 dSubNext.ColAlign() != colAlignNext )
122 LogicError("'next' data is not properly aligned");
123 )
124 const Int prevOff = ( XPrev.ColShift()==X.ColShift()-1 ? 0 : -1 );
125 const Int nextOff = ( XNext.ColShift()==X.ColShift()+1 ? 0 : +1 );
126 if( !X.Participating() )
127 return;
128
129 // It is best to separate the case where colStride is 1
130 if( colStride == 1 )
131 {
132 QuasiDiagonalSolve
133 ( LEFT, uplo, d.LockedMatrix(), dSub.LockedMatrix(), X.Matrix(),
134 conjugated );
135 return;
136 }
137
138 Matrix<F> D11( 2, 2 );
139 for( Int iLoc=0; iLoc<mLocal; ++iLoc )
140 {
141 const Int i = X.GlobalRow(iLoc);
142 const Int iLocPrev = iLoc + prevOff;
143 const Int iLocNext = iLoc + nextOff;
144
145 auto x1Loc = View( X.Matrix(), iLoc, 0, 1, nLocal );
146
147 if( i<m-1 && dSub.GetLocal(iLoc,0) != F(0) )
148 {
149 // Handle 2x2 starting at i
150 D11.Set( 0, 0, d.GetLocal(iLoc,0) );
151 D11.Set( 1, 1, dNext.GetLocal(iLocNext,0) );
152 D11.Set( 1, 0, dSub.GetLocal(iLoc,0) );
153
154 auto x1NextLoc =
155 LockedView( XNext.LockedMatrix(), iLocNext, 0, 1, nLocal );
156 FirstHalfOfSymmetric2x2Solve
157 ( LEFT, LOWER, D11, x1Loc, x1NextLoc, conjugated );
158 }
159 else if( i>0 && dSubPrev.GetLocal(iLocPrev,0) != F(0) )
160 {
161 // Handle 2x2 starting at i-1
162 D11.Set( 0, 0, dPrev.GetLocal(iLocPrev,0) );
163 D11.Set( 1, 1, d.GetLocal(iLoc,0) );
164 D11.Set( 1, 0, dSubPrev.GetLocal(iLocPrev,0) );
165
166 auto x1PrevLoc =
167 LockedView( XPrev.LockedMatrix(), iLocPrev, 0, 1, nLocal );
168 SecondHalfOfSymmetric2x2Solve
169 ( LEFT, LOWER, D11, x1PrevLoc, x1Loc, conjugated );
170 }
171 else
172 {
173 // Handle 1x1
174 Scale( F(1)/d.GetLocal(iLoc,0), x1Loc );
175 }
176 }
177 }
178
179 template<typename F,typename FMain,Dist U,Dist V>
180 inline void
RightQuasiDiagonalSolve(UpperOrLower uplo,const DistMatrix<FMain,V,STAR> d,const DistMatrix<FMain,V,STAR> dPrev,const DistMatrix<FMain,V,STAR> dNext,const DistMatrix<FMain,V,STAR> dSub,const DistMatrix<FMain,V,STAR> dSubPrev,const DistMatrix<FMain,V,STAR> dSubNext,DistMatrix<F,U,V> & X,const DistMatrix<F,U,V> & XPrev,const DistMatrix<F,U,V> & XNext,bool conjugated=false)181 RightQuasiDiagonalSolve
182 ( UpperOrLower uplo,
183 const DistMatrix<FMain,V,STAR> d,
184 const DistMatrix<FMain,V,STAR> dPrev,
185 const DistMatrix<FMain,V,STAR> dNext,
186 const DistMatrix<FMain,V,STAR> dSub,
187 const DistMatrix<FMain,V,STAR> dSubPrev,
188 const DistMatrix<FMain,V,STAR> dSubNext,
189 DistMatrix<F,U,V>& X,
190 const DistMatrix<F,U,V>& XPrev,
191 const DistMatrix<F,U,V>& XNext,
192 bool conjugated=false )
193 {
194 DEBUG_ONLY(CallStackEntry cse("LeftQuasiDiagonalSolve"))
195 if( uplo == UPPER )
196 LogicError("This option not yet supported");
197 const Int n = X.Width();
198 const Int mLocal = X.LocalHeight();
199 const Int nLocal = X.LocalWidth();
200 const Int rowStride = X.RowStride();
201 DEBUG_ONLY(
202 const Int rowAlignPrev = Mod(X.RowAlign()+1,rowStride);
203 const Int rowAlignNext = Mod(X.RowAlign()-1,rowStride);
204 if( d.ColAlign() != X.RowAlign() || dSub.RowAlign() != X.RowAlign() )
205 LogicError("data is not properly aligned");
206 if( XPrev.RowAlign() != rowAlignPrev ||
207 dPrev.ColAlign() != rowAlignPrev ||
208 dSubPrev.ColAlign() != rowAlignPrev )
209 LogicError("'previous' data is not properly aligned");
210 if( XNext.RowAlign() != rowAlignNext ||
211 dNext.ColAlign() != rowAlignNext ||
212 dSubNext.ColAlign() != rowAlignNext )
213 LogicError("'next' data is not properly aligned");
214 )
215 const Int prevOff = ( XPrev.RowShift()==X.RowShift()-1 ? 0 : -1 );
216 const Int nextOff = ( XNext.RowShift()==X.RowShift()+1 ? 0 : +1 );
217 if( !X.Participating() )
218 return;
219
220 // It is best to separate the case where rowStride is 1
221 if( rowStride == 1 )
222 {
223 QuasiDiagonalSolve
224 ( LEFT, uplo, d.LockedMatrix(), dSub.LockedMatrix(), X.Matrix(),
225 conjugated );
226 return;
227 }
228
229 Matrix<F> D11( 2, 2 );
230 for( Int jLoc=0; jLoc<nLocal; ++jLoc )
231 {
232 const Int j = X.GlobalCol(jLoc);
233 const Int jLocPrev = jLoc + prevOff;
234 const Int jLocNext = jLoc + nextOff;
235
236 auto x1Loc = View( X.Matrix(), 0, jLoc, mLocal, 1 );
237
238 if( j<n-1 && dSub.GetLocal(jLoc,0) != F(0) )
239 {
240 // Handle 2x2 starting at j
241 D11.Set( 0, 0, d.GetLocal(jLoc,0) );
242 D11.Set( 1, 1, dNext.GetLocal(jLocNext,0) );
243 D11.Set( 1, 0, dSub.GetLocal(jLoc,0) );
244
245 auto x1NextLoc =
246 LockedView( XNext.LockedMatrix(), 0, jLocNext, mLocal, 1 );
247 FirstHalfOfSymmetric2x2Solve
248 ( RIGHT, LOWER, D11, x1Loc, x1NextLoc, conjugated );
249 }
250 else if( j>0 && dSubPrev.GetLocal(jLocPrev,0) != F(0) )
251 {
252 // Handle 2x2 starting at j-1
253 D11.Set( 0, 0, dPrev.GetLocal(jLocPrev,0) );
254 D11.Set( 1, 1, d.GetLocal(jLoc,0) );
255 D11.Set( 1, 0, dSubPrev.GetLocal(jLocPrev,0) );
256
257 auto x1PrevLoc =
258 LockedView( XPrev.LockedMatrix(), 0, jLocPrev, mLocal, 1 );
259 SecondHalfOfSymmetric2x2Solve
260 ( RIGHT, LOWER, D11, x1PrevLoc, x1Loc, conjugated );
261 }
262 else
263 {
264 // Handle 1x1
265 Scale( F(1)/d.GetLocal(jLoc,0), x1Loc );
266 }
267 }
268 }
269
270 template<typename F,typename FMain,Dist U1,Dist V1,
271 Dist U2,Dist V2>
272 inline void
QuasiDiagonalSolve(LeftOrRight side,UpperOrLower uplo,const DistMatrix<FMain,U1,V1> & d,const DistMatrix<F,U1,V1> & dSub,DistMatrix<F,U2,V2> & X,bool conjugated=false)273 QuasiDiagonalSolve
274 ( LeftOrRight side, UpperOrLower uplo,
275 const DistMatrix<FMain,U1,V1>& d, const DistMatrix<F,U1,V1>& dSub,
276 DistMatrix<F,U2,V2>& X, bool conjugated=false )
277 {
278 DEBUG_ONLY(CallStackEntry cse("QuasiDiagonalSolve"))
279 const Grid& g = X.Grid();
280 const Int colAlign = X.ColAlign();
281 const Int rowAlign = X.RowAlign();
282 if( side == LEFT )
283 {
284 const Int colStride = X.ColStride();
285 DistMatrix<FMain,U2,STAR> d_U2_STAR(g);
286 DistMatrix<F,U2,STAR> dSub_U2_STAR(g);
287 d_U2_STAR.AlignWith( X );
288 dSub_U2_STAR.AlignWith( X );
289 d_U2_STAR = d;
290 dSub_U2_STAR = dSub;
291 if( colStride == 1 )
292 {
293 QuasiDiagonalSolve
294 ( side, uplo, d_U2_STAR.LockedMatrix(), dSub_U2_STAR.LockedMatrix(),
295 X.Matrix(), conjugated );
296 return;
297 }
298
299 DistMatrix<FMain,U2,STAR> dPrev_U2_STAR(g), dNext_U2_STAR(g);
300 DistMatrix<F,U2,STAR> dSubPrev_U2_STAR(g), dSubNext_U2_STAR(g);
301 DistMatrix<F,U2,V2> XPrev(g), XNext(g);
302 const Int colAlignPrev = Mod(colAlign+1,colStride);
303 const Int colAlignNext = Mod(colAlign-1,colStride);
304 dPrev_U2_STAR.AlignCols( colAlignPrev );
305 dNext_U2_STAR.AlignCols( colAlignNext );
306 dSubPrev_U2_STAR.AlignCols( colAlignPrev );
307 dSubNext_U2_STAR.AlignCols( colAlignNext );
308 XPrev.Align( colAlignPrev, rowAlign );
309 XNext.Align( colAlignNext, rowAlign );
310 dPrev_U2_STAR = d;
311 dNext_U2_STAR = d;
312 dSubPrev_U2_STAR = dSub;
313 dSubNext_U2_STAR = dSub;
314 XPrev = X;
315 XNext = X;
316 LeftQuasiDiagonalSolve
317 ( uplo, d_U2_STAR, dPrev_U2_STAR, dNext_U2_STAR,
318 dSub_U2_STAR, dSubPrev_U2_STAR, dSubNext_U2_STAR,
319 X, XPrev, XNext, conjugated );
320 }
321 else
322 {
323 const Int rowStride = X.RowStride();
324 DistMatrix<FMain,V2,STAR> d_V2_STAR(g);
325 DistMatrix<F,V2,STAR> dSub_V2_STAR(g);
326 d_V2_STAR.AlignWith( X );
327 dSub_V2_STAR.AlignWith( X );
328 d_V2_STAR = d;
329 dSub_V2_STAR = dSub;
330 if( rowStride == 1 )
331 {
332 QuasiDiagonalSolve
333 ( side, uplo,
334 d_V2_STAR.LockedMatrix(), dSub_V2_STAR.LockedMatrix(),
335 X.Matrix(), conjugated );
336 return;
337 }
338
339 DistMatrix<FMain,V2,STAR> dPrev_V2_STAR(g), dNext_V2_STAR(g);
340 DistMatrix<F,V2,STAR> dSubPrev_V2_STAR(g), dSubNext_V2_STAR(g);
341 DistMatrix<F,U2,V2> XPrev(g), XNext(g);
342 const Int rowAlignPrev = Mod(rowAlign+1,rowStride);
343 const Int rowAlignNext = Mod(rowAlign-1,rowStride);
344 dPrev_V2_STAR.AlignCols( rowAlignPrev );
345 dNext_V2_STAR.AlignCols( rowAlignNext );
346 dSubPrev_V2_STAR.AlignCols( rowAlignPrev );
347 dSubNext_V2_STAR.AlignCols( rowAlignNext );
348 XPrev.Align( colAlign, rowAlignPrev );
349 XNext.Align( colAlign, rowAlignNext );
350 dPrev_V2_STAR = d;
351 dNext_V2_STAR = d;
352 dSubPrev_V2_STAR = dSub;
353 dSubNext_V2_STAR = dSub;
354 XPrev = X;
355 XNext = X;
356 RightQuasiDiagonalSolve
357 ( uplo, d_V2_STAR, dPrev_V2_STAR, dNext_V2_STAR,
358 dSub_V2_STAR, dSubPrev_V2_STAR, dSubNext_V2_STAR,
359 X, XPrev, XNext, conjugated );
360 }
361 }
362
363 } // namespace elem
364
365 #endif // ifndef ELEM_QUASIDIAGONALSOLVE_HPP
366