1 /*  findPivot.c  */
2 
3 #include "../Chv.h"
4 
5 #define MYDEBUG 0
6 
7 /*--------------------------------------------------------------------*/
8 static int findPivotSH ( Chv *chv, DV *workDV, double tau,
9    int ndelay, int *pirow, int *pjcol, int *pntest ) ;
10 static int findPivotN ( Chv *chv, DV *workDV, double tau,
11    int ndelay, int *pirow, int *pjcol, int *pntest ) ;
12 static int nonsym1x1 ( Chv *chv, int irow, int jcol, double tau,
13                        double rowmaxes[], double colmaxes[] ) ;
14 static int sym1x1 ( Chv *chv, int irow,
15                     double tau, double rowmaxes[] );
16 static int sym2x2 ( Chv *chv, int irow, int jcol,
17                     double tau, double rowmaxes[] ) ;
18 /*--------------------------------------------------------------------*/
19 /*
20    ------------------------------------------------------------------
21    purpose -- find and test a pivot
22 
23    workDV -- object that contains work vectors
24    tau    -- upper bound on magnitude of factor entries
25    ndelay -- number of delayed rows and columns on input
26    pirow  -- pointer to be filled with pivot row
27    pjcol  -- pointer to be filled with pivot column
28    pntest -- pointer to be incremented with the number of pivot tests
29 
30    return value -- size of pivot
31      0 --> pivot not found
32      1 --> 1x1 pivot in row *pirow and column *pjcol
33      2 --> 2x2 pivot in rows and columns *pirow and *pjcol,
34            symmetric front only
35 
36    created -- 98jan24, cca
37    ------------------------------------------------------------------
38 */
39 int
Chv_findPivot(Chv * chv,DV * workDV,double tau,int ndelay,int * pirow,int * pjcol,int * pntest)40 Chv_findPivot (
41    Chv     *chv,
42    DV       *workDV,
43    double   tau,
44    int      ndelay,
45    int      *pirow,
46    int      *pjcol,
47    int      *pntest
48 ) {
49 int   rc ;
50 /*
51    ---------------
52    check the input
53    ---------------
54 */
55 if (  chv == NULL || workDV == NULL || tau < 1.0 || ndelay < 0
56    || pirow == NULL || pjcol == NULL || pntest == NULL ) {
57    fprintf(stderr,
58            "\n fatal error in Chv_findPivot(%p,%p,%f,%d,%p,%p,%p)"
59            "\n bad input\n",
60            chv, workDV, tau, ndelay, pirow, pjcol, pntest) ;
61    exit(-1) ;
62 }
63 if ( !(CHV_IS_REAL(chv) || CHV_IS_COMPLEX(chv)) ) {
64    fprintf(stderr,
65            "\n fatal error in Chv_findPivot(%p,%p,%f,%d,%p,%p,%p)"
66            "\n bad type %d, must be SPOOLES_REAL or SPOOLES_COMPLEX\n",
67            chv, workDV, tau, ndelay, pirow, pjcol, pntest, chv->type) ;
68    exit(-1) ;
69 }
70 if ( !(CHV_IS_SYMMETRIC(chv) || CHV_IS_HERMITIAN(chv)
71         || CHV_IS_NONSYMMETRIC(chv)) ) {
72    fprintf(stderr,
73         "\n fatal error in Chv_findPivot(%p,%p,%f,%d,%p,%p,%p)"
74         "\n bad symflag %d"
75         "\n must be SPOOLES_SYMMETRIC, SPOOLES_HERMITIAN or CHV_NONSYMMETRIC\n",
76         chv, workDV, tau, ndelay, pirow, pjcol, pntest, chv->symflag) ;
77    exit(-1) ;
78 }
79 if ( CHV_IS_SYMMETRIC(chv) || CHV_IS_HERMITIAN(chv) ) {
80    rc = findPivotSH(chv, workDV, tau, ndelay, pirow, pjcol, pntest) ;
81 } else if ( CHV_IS_NONSYMMETRIC(chv) ) {
82    rc = findPivotN(chv, workDV, tau, ndelay, pirow, pjcol, pntest) ;
83 } else {
84    fprintf(stderr,
85            "\n fatal error in Chv_findPivot(%p,%p,%f,%d,%p,%p,%p)"
86            "\n bad symflag %d\n", chv, workDV, tau, ndelay, pirow,
87            pjcol, pntest, chv->symflag) ;
88    exit(-1) ;
89 }
90 return(rc) ; }
91 
92 /*--------------------------------------------------------------------*/
93 /*
94    --------------------------------------------------------------------
95    purpose -- find and test a pivot for a symmetric or hermitian matrix
96 
97    workDV -- object that contains work vectors
98    tau    -- upper bound on magnitude of factor entries
99    ndelay -- number of delayed rows and columns on input
100    pirow  -- pointer to be filled with pivot row
101    pjcol  -- pointer to be filled with pivot column
102    pntest -- pointer to be incremented with the number of pivot tests
103 
104    return value -- size of pivot
105      0 --> pivot not found
106      1 --> 1x1 pivot in row *pirow and column *pjcol
107      2 --> 2x2 pivot in rows and columns *pirow and *pjcol,
108            symmetric front only
109 
110    created -- 98apr18, cca
111    --------------------------------------------------------------------
112 */
113 static int
findPivotSH(Chv * chv,DV * workDV,double tau,int ndelay,int * pirow,int * pjcol,int * pntest)114 findPivotSH (
115    Chv     *chv,
116    DV       *workDV,
117    double   tau,
118    int      ndelay,
119    int      *pirow,
120    int      *pjcol,
121    int      *pntest
122 ) {
123 double   maxval ;
124 double   *rowmaxes ;
125 int      ii, irow, jrow, krow, ncand, nD,
126          ndouble, ntest, pivotsize, tag, untag ;
127 int      *rowids, *rowmark ;
128 
129 untag  = 0 ;
130 tag    = 1 ;
131 nD     = chv->nD ;
132 #if MYDEBUG > 0
133 fprintf(stdout,
134 "\n %% findPivotSH, id = %d, nD = %d, nL = %d, nU = %d, ndelay = %d",
135 chv->id, chv->nD, chv->nL, chv->nU, ndelay) ;
136 fflush(stdout) ;
137 #endif
138 *pirow = *pjcol = -1 ;
139 ntest  = *pntest ;
140 /*
141    ------------------------------------
142    symmetric front, set up work vectors
143    ------------------------------------
144 */
145 if ( sizeof(int) == sizeof(double) ) {
146    ndouble = 3*nD ;
147 } else if ( 2*sizeof(int) == sizeof(double) ) {
148    ndouble = 2*nD ;
149 }
150 DV_setSize(workDV, ndouble) ;
151 rowmaxes = DV_entries(workDV) ;
152 DVfill(nD, rowmaxes, 0.0) ;
153 rowmark  = (int *) (rowmaxes + nD) ;
154 rowids   = rowmark + nD ;
155 if ( ndelay > 0 ) {
156    IVfill(ndelay, rowmark, untag) ;
157    IVfill(nD - ndelay, rowmark + ndelay, tag) ;
158 } else {
159    IVfill(nD, rowmark, tag) ;
160 }
161 ncand = 0 ;
162 do {
163    pivotsize = 0 ;
164 /*
165    --------------------
166    find candidate pivot
167    --------------------
168 */
169    Chv_fastBunchParlettPivot(chv, rowmark, tag, &irow, &jrow) ;
170 #if MYDEBUG > 0
171    fprintf(stdout, "\n\n %% FBP: irow = %d, jrow = %d",
172            irow, jrow) ;
173    if ( irow != -1 ) {
174       double   imag, real ;
175       Chv_entry(chv, irow, irow, &real, &imag) ;
176       fprintf(stdout, "\n%%  entry(%d,%d) = %20.12e + %20.12e*i",
177               irow, irow, real, imag) ;
178       if ( jrow != irow ) {
179          Chv_entry(chv, irow, jrow, &real, &imag) ;
180          fprintf(stdout, "\n%%  entry(%d,%d) = %20.12e + %20.12e*i",
181                  irow, jrow, real, imag) ;
182          Chv_entry(chv, jrow, jrow, &real, &imag) ;
183          fprintf(stdout, "\n%%  entry(%d,%d) = %20.12e + %20.12e*i",
184                  jrow, jrow, real, imag) ;
185       }
186    }
187    fflush(stdout) ;
188 #endif
189    if ( irow == -1 ) {
190 /*
191       ----------------------------------------------
192       unable to find pivot, break out of search loop
193       ----------------------------------------------
194 */
195       pivotsize = 0 ; break ;
196    } else {
197 /*
198       -------------------------------
199       (irow,jrow) is a possible pivot
200       mark as visited and get row max
201       -------------------------------
202 */
203       Chv_maxabsInRow(chv, irow, &maxval) ;
204       rowmaxes[irow] = maxval ;
205       rowmark[irow]  = untag ;
206       if ( irow != jrow ) {
207          Chv_maxabsInRow(chv, jrow, &maxval) ;
208          rowmaxes[jrow] = maxval ;
209          rowmark[jrow] = untag ;
210       }
211       if ( irow == jrow ) {
212 /*
213          ------------------
214          test the 1x1 pivot
215          ------------------
216 */
217          pivotsize = sym1x1(chv, irow, tau, rowmaxes) ;
218          ntest++ ;
219 #if MYDEBUG > 0
220          fprintf(stdout,
221                  "\n\n %% pivotsize from sym1x1 = %d", pivotsize) ;
222 #endif
223          if ( pivotsize == 1 ) {
224             *pirow = irow ; *pjcol = jrow ;
225          } else {
226             for ( ii = 0 ; ii < ncand ; ii++ ) {
227 /*
228                ----------------------------------
229                test the 2x2 pivot (irow, krow)
230                where krow is a previous candidate
231                ----------------------------------
232 */
233                krow = rowids[ii] ;
234                pivotsize = sym2x2(chv, irow, krow, tau, rowmaxes) ;
235                ntest++ ;
236                if ( pivotsize == 2 ) {
237                   *pirow = irow ; *pjcol = krow ; break ;
238                }
239             }
240          }
241       } else {
242 /*
243          -------------------------------
244          test the 2x2 pivot (irow, jrow)
245          -------------------------------
246 */
247          pivotsize = sym2x2(chv, irow, jrow, tau, rowmaxes) ;
248          ntest++ ;
249          if ( pivotsize == 2 ) {
250             *pirow = irow ; *pjcol = jrow ;
251          } else {
252             for ( ii = 0 ; ii < ncand ; ii++ ) {
253                krow = rowids[ii] ;
254 /*
255                ----------------------------------
256                test the 2x2 pivot (irow, krow)
257                where krow is a previous candidate
258                ----------------------------------
259 */
260                pivotsize = sym2x2(chv, irow, krow, tau, rowmaxes) ;
261                ntest++ ;
262                if ( pivotsize == 2 ) {
263                   *pirow = irow ; *pjcol = krow ; break ;
264                }
265 /*
266                ----------------------------------
267                test the 2x2 pivot (jrow, krow)
268                where krow is a previous candidate
269                ----------------------------------
270 */
271                pivotsize = sym2x2(chv, jrow, krow, tau, rowmaxes) ;
272                ntest++ ;
273                if ( pivotsize == 2 ) {
274                   *pirow = jrow ; *pjcol = krow ; break ;
275                }
276             }
277          }
278       }
279       if ( pivotsize == 0 ) {
280 /*
281          ------------------------
282          add new candidate row(s)
283          ------------------------
284 */
285          rowids[ncand++] = irow ;
286          if ( irow != jrow ) {
287             rowids[ncand++] = jrow ;
288          }
289       }
290    }
291 } while ( pivotsize == 0 ) ;
292 *pntest = ntest ;
293 
294 return(pivotsize) ; }
295 
296 /*--------------------------------------------------------------------*/
297 /*
298    ------------------------------------------------------------------
299    purpose -- find and test a pivot for a nonsymmetric matrix
300 
301    workDV -- object that contains work vectors
302    tau    -- upper bound on magnitude of factor entries
303    ndelay -- number of delayed rows and columns on input
304    pirow  -- pointer to be filled with pivot row
305    pjcol  -- pointer to be filled with pivot column
306    pntest -- pointer to be incremented with the number of pivot tests
307 
308    return value -- size of pivot
309      0 --> pivot not found
310      1 --> 1x1 pivot in row *pirow and column *pjcol
311      2 --> 2x2 pivot in rows and columns *pirow and *pjcol,
312            symmetric front only
313 
314    created -- 98jan24, cca
315    ------------------------------------------------------------------
316 */
317 static int
findPivotN(Chv * chv,DV * workDV,double tau,int ndelay,int * pirow,int * pjcol,int * pntest)318 findPivotN (
319    Chv     *chv,
320    DV       *workDV,
321    double   tau,
322    int      ndelay,
323    int      *pirow,
324    int      *pjcol,
325    int      *pntest
326 ) {
327 double   maxval ;
328 double   *colmaxes, *rowmaxes ;
329 int      icol, ii, irow, jcol, jrow, ncand, nD,
330          ndouble, ntest, pivotsize, tag, untag ;
331 int      *colids, *colmark, *rowids, *rowmark ;
332 
333 untag  = 0 ;
334 tag    = 1 ;
335 nD     = chv->nD ;
336 #if MYDEBUG > 0
337 fprintf(stdout,
338 "\n %% Chv_findPivot, id = %d, nD = %d, nL = %d, nU = %d, ndelay = %d",
339 chv->id, chv->nD, chv->nL, chv->nU, ndelay) ;
340 fflush(stdout) ;
341 #endif
342 *pirow = *pjcol = -1 ;
343 ntest  = *pntest ;
344 /*
345    -------------------
346    set up work vectors
347    -------------------
348 */
349 if ( sizeof(int) == sizeof(double) ) {
350    ndouble = 6*nD ;
351 } else if ( 2*sizeof(int) == sizeof(double) ) {
352    ndouble = 4*nD ;
353 }
354 DV_setSize(workDV, ndouble) ;
355 rowmaxes = DV_entries(workDV) ;
356 colmaxes = rowmaxes + nD ;
357 DVfill(nD, rowmaxes, 0.0) ;
358 DVfill(nD, colmaxes, 0.0) ;
359 rowmark  = (int *) (colmaxes + nD) ;
360 colmark  = rowmark + nD ;
361 rowids   = colmark + nD ;
362 colids   = rowids + nD ;
363 if ( ndelay > 0 ) {
364    IVfill(ndelay, rowmark, untag) ;
365    IVfill(nD - ndelay, rowmark + ndelay, tag) ;
366    IVfill(ndelay, colmark, untag) ;
367    IVfill(nD - ndelay, colmark + ndelay, tag) ;
368 } else {
369    IVfill(nD, rowmark, tag) ;
370    IVfill(nD, colmark, tag) ;
371 }
372 ncand = 0 ;
373 do {
374    pivotsize = 0 ;
375 /*
376    --------------------
377    find candidate pivot
378    --------------------
379 */
380    Chv_quasimax(chv, rowmark, colmark, tag, &irow, &jcol) ;
381 #if MYDEBUG > 0
382    fprintf(stdout, "\n\n %% quasimax: irow = %d, jcol = %d",
383            irow, jcol) ;
384    if ( irow != -1 ) {
385       double   imag, real ;
386       Chv_entry(chv, irow, jcol, &real, &imag) ;
387       fprintf(stdout, "\n%%  entry(%d,%d) = %20.12e + %20.12e*i",
388               irow, jcol, real, imag) ;
389    }
390    fflush(stdout) ;
391 #endif
392    if ( irow == -1 ) {
393 /*
394       ----------------------------------------------
395       unable to find pivot, break out of search loop
396       ----------------------------------------------
397 */
398       break ;
399    } else {
400 /*
401       ------------------------------------------------------------
402       find the row max for row irow and column max for column jcol
403       ------------------------------------------------------------
404 */
405       Chv_maxabsInRow(chv, irow, &maxval) ;
406       rowmaxes[irow] = maxval ;
407       Chv_maxabsInColumn(chv, jcol, &maxval) ;
408       colmaxes[jcol] = maxval ;
409       rowmark[irow]  = untag ;
410       colmark[jcol]  = untag ;
411 /*
412       -------------------------------------
413       test the (irow,jcol) entry as a pivot
414       -------------------------------------
415 */
416       pivotsize = nonsym1x1(chv, irow, jcol, tau, rowmaxes, colmaxes) ;
417       ntest++ ;
418       if ( pivotsize == 1 ) {
419          *pirow = irow ; *pjcol = jcol ;
420       } else {
421 /*
422          ---------------------------------------
423          test the other matrix entries as pivots
424          ---------------------------------------
425 */
426          for ( ii = 0 ; ii < ncand ; ii++ ) {
427             jrow = rowids[ii] ;
428             icol = colids[ii] ;
429 /*
430             --------------------------
431             test the (irow,icol) entry
432             --------------------------
433 */
434             pivotsize = nonsym1x1(chv, irow, icol, tau,
435                                   rowmaxes, colmaxes) ;
436             ntest++ ;
437             if ( pivotsize == 1 ) {
438                *pirow = irow ; *pjcol = icol ; break ;
439             }
440 /*
441             --------------------------
442             test the (jrow,jcol) entry
443             --------------------------
444 */
445             pivotsize = nonsym1x1(chv, jrow, jcol, tau,
446                                   rowmaxes, colmaxes) ;
447             ntest++ ;
448             if ( pivotsize == 1 ) {
449                *pirow = jrow ; *pjcol = jcol ; break ;
450             }
451          }
452 /*
453          ----------------------------------------------
454          no pivots found, add irow to candidate row ids
455          and add jcol to candidate column ids
456          ----------------------------------------------
457 */
458          rowids[ncand] = irow ;
459          colids[ncand] = jcol ;
460          ncand++ ;
461       }
462    }
463 } while ( pivotsize == 0 ) ;
464 *pntest = ntest ;
465 
466 return(pivotsize) ; }
467 
468 /*--------------------------------------------------------------------*/
469 /*
470    ---------------------------------------------
471    return 1 if the nonsymmetric 1x1 pivot passes
472    return 0 otherwise
473 
474    created -- 98jan24, cca
475    ---------------------------------------------
476 */
477 static int
nonsym1x1(Chv * chv,int irow,int jcol,double tau,double rowmaxes[],double colmaxes[])478 nonsym1x1 (
479    Chv   *chv,
480    int    irow,
481    int    jcol,
482    double   tau,
483    double   rowmaxes[],
484    double   colmaxes[]
485 ) {
486 double   cutoff, magn ;
487 int      rc ;
488 
489 if ( CHV_IS_REAL(chv) ) {
490    double   value ;
491    Chv_realEntry(chv, irow, jcol, &value) ;
492    magn = fabs(value) ;
493 } else if ( CHV_IS_COMPLEX(chv) ) {
494    double   imag, real ;
495    Chv_complexEntry(chv, irow, jcol, &real, &imag) ;
496    magn = Zabs(real, imag) ;
497 }
498 cutoff = tau * magn ;
499 #if MYDEBUG > 0
500 fprintf(stdout, "\n %% magn = %12.4e, cutoff = %12.4e", magn, cutoff) ;
501 fprintf(stdout, "\n %% rowmaxes[%d] = %12.4e, colmaxes[%d] = %12.4e",
502         irow, rowmaxes[irow], jcol, colmaxes[jcol]) ;
503 #endif
504 if ( rowmaxes[irow] <= cutoff && colmaxes[jcol] <= cutoff ) {
505    rc = 1 ;
506 } else {
507    rc = 0 ;
508 }
509 return(rc) ; }
510 
511 /*--------------------------------------------------------------------*/
512 /*
513    ------------------------------------------
514    return 1 if the symmetric 1x1 pivot passes
515    return 0 otherwise
516 
517    created -- 98jan24, cca
518    ------------------------------------------
519 */
520 static int
sym1x1(Chv * chv,int irow,double tau,double rowmaxes[])521 sym1x1 (
522    Chv     *chv,
523    int      irow,
524    double   tau,
525    double   rowmaxes[]
526 ) {
527 double   cutoff ;
528 int      rc ;
529 
530 if ( CHV_IS_REAL(chv) ) {
531    double   value ;
532    Chv_realEntry(chv, irow, irow, &value) ;
533    cutoff = tau * fabs(value) ;
534 } else if ( CHV_IS_COMPLEX(chv) ) {
535    double   imag, real ;
536    Chv_complexEntry(chv, irow, irow, &real, &imag) ;
537    cutoff = tau * Zabs(real, imag) ;
538 }
539 #if MYDEBUG > 0
540 fprintf(stdout, "\n %% cutoff = %12.4e, rowmaxes[%d] = %12.4e",
541         cutoff, irow, rowmaxes[irow]) ;
542 #endif
543 if ( rowmaxes[irow] <= cutoff ) {
544    rc = 1 ;
545 } else {
546    rc = 0 ;
547 }
548 return(rc) ; }
549 
550 /*--------------------------------------------------------------------*/
551 /*
552    ------------------------------------------
553    return 2 if the symmetric 2x2 pivot passes
554    return 0 otherwise
555 
556    created -- 98jan24, cca
557    ------------------------------------------
558 */
559 static int
sym2x2(Chv * chv,int irow,int jcol,double tau,double rowmaxes[])560 sym2x2 (
561    Chv     *chv,
562    int      irow,
563    int      jcol,
564    double   tau,
565    double   rowmaxes[]
566 ) {
567 double   amag, bmag, cmag, denom, val1, val2 ;
568 int      rc ;
569 
570 if ( CHV_IS_REAL(chv) ) {
571    double  a, b, c ;
572 
573    Chv_realEntry(chv, irow, irow, &a) ;
574    Chv_realEntry(chv, irow, jcol, &b) ;
575    Chv_realEntry(chv, jcol, jcol, &c) ;
576    amag  = fabs(a) ;
577    bmag  = fabs(b) ;
578    cmag  = fabs(c) ;
579    denom = fabs(a*c - b*b) ;
580 } else if ( CHV_IS_COMPLEX(chv) ) {
581    double   aimag, areal, bimag, breal, cimag, creal, imag, real ;
582 
583    Chv_complexEntry(chv, irow, irow, &areal, &aimag) ;
584    Chv_complexEntry(chv, irow, jcol, &breal, &bimag) ;
585    Chv_complexEntry(chv, jcol, jcol, &creal, &cimag) ;
586    if ( CHV_IS_HERMITIAN(chv) ) {
587       amag  = fabs(areal) ;
588       bmag  = Zabs(breal, bimag) ;
589       cmag  = fabs(creal) ;
590       denom = fabs(areal*creal - breal*breal - bimag*bimag) ;
591    } else if ( CHV_IS_SYMMETRIC(chv) ) {
592       amag  = Zabs(areal, aimag) ;
593       bmag  = Zabs(breal, bimag) ;
594       cmag  = Zabs(creal, cimag) ;
595       real  = areal*creal - aimag*cimag - breal*breal + bimag*bimag ;
596       imag  = areal*cimag + aimag*creal - 2*breal*bimag ;
597       denom = Zabs(real, imag) ;
598    }
599 }
600 #if MYDEBUG > 0
601       fprintf(stdout,
602               "\n amag = %20.12e ; "
603               "\n bmag = %20.12e ; "
604               "\n cmag = %20.12e ; ", amag, bmag, cmag) ;
605 #endif
606 if ( denom == 0.0 ) {
607    return(0) ;
608 }
609 val1 = (cmag*rowmaxes[irow] + bmag*rowmaxes[jcol])/denom ;
610 val2 = (bmag*rowmaxes[irow] + amag*rowmaxes[jcol])/denom ;
611 #if MYDEBUG > 0
612 fprintf(stdout, "\n %% sym2x2"
613         "\n rowmax1 = %20.12e"
614         "\n rowmax2 = %20.12e"
615         "\n val1 = %20.12e"
616         "\n val2 = %20.12e"
617         "\n denom = %20.12e",
618         rowmaxes[irow], rowmaxes[jcol], val1, val2, denom) ;
619 #endif
620 if (  val1 <= tau && val2 <= tau ) {
621    rc = 2 ;
622 } else {
623    rc = 0 ;
624 }
625 return(rc) ; }
626 
627 /*--------------------------------------------------------------------*/
628