1 // =============================================================================
2 // === spqr_rsolve =============================================================
3 // =============================================================================
4 
5 // Solve X = E*(R\B) or X=R\B using the QR factorization from spqr_1factor
6 // (including the singleton-row R and the multifrontal R).
7 
8 /*
   Let A be an m-by-n matrix.
10 
11    If just A was factorized, then the QR factorization in QRsym and QRnum
12    contains the global R factor:
13 
14         R22
15          0
16 
17    where n1cols = 0, n1rows = 0, A is m-by-n, R22 is (QRnum->rank1)-by-n.
18    R22 is the multifrontal part of R.
19 
20    If [A Binput] was factorized and no singletons were removed prior to
21    factorization, then the QR factorization in QRsym and QRnum contains the
22    global R factor:
23 
24         R22 C2
25          0  C3
26 
   where R22 is (QRnum->rank1)-by-n, A is m-by-n, and [C2 ; C3] is
   m-by-bncols.
29 
30    If [A Binput] was factorized and singletons were removed prior to
31    factorization, then the QR factorization of the [S2 B2] matrix was computed
32    (the global R):
33 
34        R11 R12 C1
35         0  R22 C2
36         0   0  C3
37 
38     where R11 has n1cols columns and n1rows rows.  [R11 R12] is the singleton-
39     row part of R.  The R1 = [R11 R12] matrix (not in QRsym and QRnum) contains
40     the singleton rows.  The QR factorization in QRsym and QRnum contains only
41     the R factor:
42 
43         R22 C2
44          0  C3
45 
46     To solve with R22, the columns of R22 must be shifted by n1cols and then
47     permuted via Q1fill to get a column index in the range n1cols to n-1, where
48     A is m-by-n.  However, the QR factorization (and Q1fill) also contain column
49     indices >= n.  If a column index in the QR factorization is >= n, then it
50     refers to a column of C and is ignored in this solve phase.
51 
52     The row indices of R22 must also be shifted, by n1rows.
53 
54     Note that in all cases, the number of rows of R22 is given by QRnum->rank1.
55 */
56 
57 #include "spqr.hpp"
58 
// spqr_rsolve: back-solve with the upper triangular R factor stored in QR.
//
// Computes X = R\B, or X = E*(R\B) when use_Q1fill is TRUE, where E is the
// column permutation given by QR->Q1fill.  Two phases:
//
//   1. the multifrontal rows of R (R22, stored front-by-front in
//      QRnum->Rblock) are solved, from the last front (nf-1) down to front 0;
//   2. the singleton rows of R (QR->R1p, QR->R1j, QR->R1x) are then solved,
//      from row n1rows-1 down to row 0.
//
// Rows of B at global index >= QR->rank are treated as zero, and entries of
// X for dead pivot columns are left at zero, so a rank-deficient R yields
// the "basic" solution.  All divisions go through spqr_divide (..., cc).
//
// B and X are column-major; B has leading dimension ldb, X has leading
// dimension n (= QR->nacols).  Rcolp, Rlive and W are caller-supplied
// workspace sized from QRnum->maxfrank, as noted on each parameter.
template <typename Entry> void spqr_rsolve
(
    // inputs
    SuiteSparseQR_factorization <Entry> *QR,
    int use_Q1fill,         // if TRUE, do X=E*(R\B), otherwise do X=R\B

    Long nrhs,              // number of columns of B
    Long ldb,               // leading dimension of B
    Entry *B,               // size m-by-nrhs with leading dimension ldb

    // output
    Entry *X,               // size n-by-nrhs with leading dimension n

    // workspace
    Entry **Rcolp,          // size QRnum->maxfrank
    Long *Rlive,            // size QRnum->maxfrank
    Entry *W,               // size QRnum->maxfrank * nrhs

    cholmod_common *cc      // passed through to spqr_divide
)
{
    spqr_symbolic *QRsym ;
    spqr_numeric <Entry> *QRnum ;
    Long n1rows, n1cols, n ;
    Long *Q1fill, *R1p, *R1j ;
    Entry *R1x ;

    Entry xi ;
    Entry **Rblock, *R, *W1, *B1, *X1 ;
    Long *Rp, *Rj, *Super, *HStair, *Hm, *Stair ;
    char *Rdead ;
    Long nf, // m,
        rank, j, f, col1, col2, fp, pr, fn, rm, k, i, row1, row2, ii,
        keepH, fm, h, t, live, kk ;

    // -------------------------------------------------------------------------
    // get the contents of the QR object
    // -------------------------------------------------------------------------

    QRsym = QR->QRsym ;
    QRnum = QR->QRnum ;
    n1rows = QR->n1rows ;
    n1cols = QR->n1cols ;
    n = QR->nacols ;        // number of rows of X (columns of A)
    // a NULL Q1fill means column indices are used unpermuted (X = R\B)
    Q1fill = use_Q1fill ? QR->Q1fill : NULL ;
    R1p = QR->R1p ;
    R1j = QR->R1j ;
    R1x = QR->R1x ;

    keepH = QRnum->keepH ;
    PR (("rsolve keepH %ld\n", keepH)) ;
    nf = QRsym->nf ;
    // m = QRsym->m ;
    Rblock = QRnum->Rblock ;
    Rp = QRsym->Rp ;
    Rj = QRsym->Rj ;
    Super = QRsym->Super ;
    Rdead = QRnum->Rdead ;
    rank = QR->rank ;   // R22 is R(n1rows:rank-1,n1cols:n-1) of
                        // the global R.
    HStair = QRnum->HStair ;
    Hm = QRnum->Hm ;

    // -------------------------------------------------------------------------
    // X = 0
    // -------------------------------------------------------------------------

    X1 = X ;
    for (kk = 0 ; kk < nrhs ; kk++)
    {
        for (i = 0 ; i < n ; i++)
        {
            X1 [i] = 0 ;
        }
        X1 += n ;
    }

    // =========================================================================
    // === solve with the multifrontal rows of R ===============================
    // =========================================================================

    // Stair, fm, h and t describe the R+H block layout and are only read when
    // keepH is true; they are initialized here so they are never read unset.
    Stair = NULL ;
    fm = 0 ;
    h = 0 ;
    t = 0 ;

    // start with row2 = QRnum->rank + n1rows, one past the last row of the
    // combined R factor of [A Binput] (each front covers rows row1:row2-1)

    row2 = QRnum->rank + n1rows ;
    for (f = nf-1 ; f >= 0 ; f--)
    {

        // ---------------------------------------------------------------------
        // get the R block for front F
        // ---------------------------------------------------------------------

        R = Rblock [f] ;
        col1 = Super [f] ;                  // first pivot column in front F
        col2 = Super [f+1] ;                // col2-1 is last pivot col
        fp = col2 - col1 ;                  // number of pivots in front F
        pr = Rp [f] ;                       // pointer to row indices for F
        fn = Rp [f+1] - pr ;                // # of columns in front F

        if (keepH)
        {
            Stair = HStair + pr ;           // staircase of front F
            fm = Hm [f] ;                   // # of rows in front F
            h = 0 ;                         // H vector starts in row h
        }

        // ---------------------------------------------------------------------
        // find the live pivot columns in this R or RH block
        // ---------------------------------------------------------------------

        // R walks the block column by column.  Each stored column holds rm
        // entries of R, followed (only when keepH) by t-h entries of its H
        // vector; see the pointer advance at the bottom of this loop.
        rm = 0 ;                            // number of rows in R block
        for (k = 0 ; k < fp ; k++)
        {
            j = col1 + k ;
            ASSERT (Rj [pr + k] == j) ;
            if (keepH)
            {
                t = Stair [k] ;             // length of R+H vector
                ASSERT (t >= 0 && t <= fm) ;
                if (t == 0)
                {
                    live = FALSE ;          // column k is dead
                    t = rm ;                // dead col, R only, no H
                    h = rm ;
                }
                else
                {
                    live = (rm < fm) ;      // k is live, unless we hit the wall
                    h = rm + 1 ;            // H vector starts in row h
                }
                ASSERT (t >= h) ;
            }
            else
            {
                live = (!Rdead [j])  ;
            }

            if (live)
            {
                // R (rm,k) is a "diagonal"; rm and k are local indices.
                // Keep track of a pointer to the first entry R(0,k)
                Rcolp [rm] = R ;
                Rlive [rm] = j ;
                rm++ ;
            }
            else
            {
                // compute the basic solution; dead columns are zero
                // (X was already cleared above; this restates it explicitly)
                ii = Q1fill ? Q1fill [j+n1cols] : j+n1cols ;
                if (ii < n)
                {
                    for (kk = 0 ; kk < nrhs ; kk++)
                    {
                        // X (ii,kk) = 0; note this is stride n
                        X [INDEX (ii,kk,n)] = 0 ;
                    }
                }
            }

            // advance to the next column of R in the R block
            // (rm has already been incremented if this column was live)
            R += rm + (keepH ? (t-h) : 0) ;
        }

        // There are rm rows in this R block, corresponding to the rm live
        // columns in the range col1:col2-1.  The list of live global column
        // indices is given in Rlive [0:rm-1].  Pointers to the numerical
        // entries for each of these columns in this R block are given in
        // Rcolp [0:rm-1].   The rm rows in this R block correspond to
        // row1:row2-1 of R and b.

        row1 = row2 - rm ;

        // ---------------------------------------------------------------------
        // get the right-hand sides for these rm equations
        // ---------------------------------------------------------------------

        // W = B (row1:row2-1,:), with rows at or beyond rank treated as zero.
        // W is stored as rm-by-nrhs with leading dimension rm.
        ASSERT (rm <= QRnum->maxfrank) ;
        W1 = W ;
        B1 = B ;
        for (kk = 0 ; kk < nrhs ; kk++)
        {
            for (i = 0 ; i < rm ; i++)
            {
                ii = row1 + i ;
                ASSERT (ii >= n1rows) ;
                W1 [i] = (ii < rank) ? B1 [ii] : 0 ;
            }
            W1 += rm ;
            B1 += ldb ;
        }

        // ---------------------------------------------------------------------
        // solve with the rectangular part of R (W = W - R2*x2)
        // ---------------------------------------------------------------------

        // k deliberately continues from fp: columns fp:fn-1 of this block are
        // the non-pivotal columns (global index >= col2), whose X entries were
        // already computed by later fronts.  R still points just past the
        // last pivot column.
        for ( ; k < fn ; k++)
        {
            j = Rj [pr + k] ;
            ASSERT (j >= col2 && j < QRsym->n) ;
            ii = Q1fill ? Q1fill [j+n1cols] : j+n1cols ;
            ASSERT ((ii < n) == (j+n1cols < n)) ;
            // break if past the last column of A in QR of [A Binput]
            if (ii >= n) break ;

            if (!Rdead [j])
            {
                // global column j is live
                W1 = W ;
                for (kk = 0 ; kk < nrhs ; kk++)
                {
                    xi = X [INDEX (ii,kk,n)] ;        // xi = X (ii,kk)
                    if (xi != (Entry) 0)
                    {
                        FLOP_COUNT (2*rm) ;
                        for (i = 0 ; i < rm ; i++)
                        {
                            W1 [i] -= R [i] * xi ;
                        }
                    }
                    W1 += rm ;
                }
            }

            // go to the next column of R
            R += rm ;
            if (keepH)
            {
                t = Stair [k] ;             // length of R+H vector
                ASSERT (t >= 0 && t <= fm) ;
                h = MIN (h+1, fm) ;         // H vector starts in row h
                ASSERT (t >= h) ;
                R += (t-h) ;
            }
        }

        // ---------------------------------------------------------------------
        // solve with the squeezed upper triangular part of R
        // ---------------------------------------------------------------------

        // Column-oriented back substitution over the rm live pivot columns:
        // solve for row k, then subtract its contribution from rows 0:k-1.
        for (k = rm-1 ; k >= 0 ; k--)
        {
            R = Rcolp [k] ;                 // kth live pivot column
            j = Rlive [k] ;                 // is jth global column
            ii = Q1fill ? Q1fill [j+n1cols] : j+n1cols ;
            ASSERT ((ii < n) == (j+n1cols < n)) ;
            if (ii < n)
            {
                W1 = W ;
                for (kk = 0 ; kk < nrhs ; kk++)
                {
                    // divide by the "diagonal"
                    // xi = W1 [k] / R [k] ;
                    xi = spqr_divide (W1 [k], R [k], cc) ;
                    FLOP_COUNT (1) ;
                    X [INDEX(ii,kk,n)] = xi ;
                    if (xi != (Entry) 0)
                    {
                        FLOP_COUNT (2*k) ;
                        for (i = 0 ; i < k ; i++)
                        {
                            W1 [i] -= R [i] * xi ;
                        }
                    }
                    W1 += rm ;
                }
            }
        }

        // ---------------------------------------------------------------------
        // prepare for the R block for front f-1
        // ---------------------------------------------------------------------

        row2 = row1 ;
    }
    ASSERT (row2 == n1rows) ;

    // =========================================================================
    // === solve with the singleton rows of R ==================================
    // =========================================================================

    FLOP_COUNT ((n1rows <= 0) ? 0 :
        nrhs * (n1rows + (2 * (R1p [n1rows] - n1rows)))) ;

    // B and X are local copies of the caller's pointers; they are advanced
    // one column per right-hand side at the bottom of this loop.
    for (kk = 0 ; kk < nrhs ; kk++)
    {
        for (i = n1rows-1 ; i >= 0 ; i--)
        {
            // get the right-hand side for this ith singleton row
            Entry x = B [i] ;
            // solve with the "off-diagonal" entries, x = x-R(i,:)*x2
            // (the first entry of each row, at R1p [i], is the diagonal)
            for (Long p = R1p [i] + 1 ; p < R1p [i+1] ; p++)
            {
                Long jnew = R1j [p] ;
                ASSERT (jnew >= i && jnew < n) ;
                Long jold = Q1fill ? Q1fill [jnew] : jnew ;
                ASSERT (jold >= 0 && jold < n) ;
                x -= R1x [p] * X [jold] ;
            }
            // divide by the "diagonal" (the singleton entry itself)
            Long p = R1p [i] ;
            Long jnew = R1j [p] ;
            Long jold = Q1fill ? Q1fill [jnew] : jnew ;
            ASSERT (jold >= 0 && jold < n) ;
            // X [jold] = x / R1x [p] ; using cc->complex_divide
            X [jold] = spqr_divide (x, R1x [p], cc) ;
        }
        B += ldb ;
        X += n ;
    }
}
375 
376 // =============================================================================
377 
378 template void spqr_rsolve <double>
379 (
380     // inputs
381     SuiteSparseQR_factorization <double> *QR,
382     int use_Q1fill,
383 
384     Long nrhs,              // number of columns of B
385     Long ldb,               // leading dimension of B
386     double *B,              // size m-by-nrhs with leading dimesion ldb
387 
388     // output
389     double *X,              // size n-by-nrhs with leading dimension n
390 
391     // workspace
392     double **Rcolp,
393     Long *Rlive,
394     double *W,
395 
396     cholmod_common *cc
397 ) ;
398 
399 
400 template void spqr_rsolve <Complex>
401 (
402     // inputs
403     SuiteSparseQR_factorization <Complex> *QR,
404     int use_Q1fill,
405 
406     Long nrhs,              // number of columns of B
407     Long ldb,               // leading dimension of B
408     Complex *B,             // size m-by-nrhs with leading dimesion ldb
409 
410     // output
411     Complex *X,             // size n-by-nrhs with leading dimension n
412 
413     // workspace
414     Complex **Rcolp,
415     Long *Rlive,
416     Complex *W,
417 
418     cholmod_common *cc
419 ) ;
420 
421