1 /******************************************************************************
2  * Copyright 1998-2019 Lawrence Livermore National Security, LLC and other
3  * HYPRE Project Developers. See the top-level COPYRIGHT file for details.
4  *
5  * SPDX-License-Identifier: (Apache-2.0 OR MIT)
6  ******************************************************************************/
7 
8 /******************************************************************************
9  *
10  * HYPRE_LSI_DSuperLU interface
11  *
12  *****************************************************************************/
13 
14 #include <stdlib.h>
15 #include <stdio.h>
16 #include <math.h>
17 
18 #include "utilities/_hypre_utilities.h"
19 #include "HYPRE.h"
20 #include "IJ_mv/HYPRE_IJ_mv.h"
21 #include "parcsr_mv/_hypre_parcsr_mv.h"
22 #include "parcsr_ls/HYPRE_parcsr_ls.h"
23 
24 /*---------------------------------------------------------------------------
25  * Distributed SUPERLU include files
26  *-------------------------------------------------------------------------*/
27 
28 #ifdef HYPRE_USING_DSUPERLU
29 #include "superlu_ddefs.h"
30 
31 typedef struct HYPRE_LSI_DSuperLU_Struct
32 {
33    MPI_Comm           comm_;
34    HYPRE_ParCSRMatrix Amat_;
35    superlu_dist_options_t  options_;
36    SuperMatrix        sluAmat_;
37    ScalePermstruct_t  ScalePermstruct_;
38    SuperLUStat_t      stat_;
39    LUstruct_t         LUstruct_;
40    SOLVEstruct_t      SOLVEstruct_;
41    int                globalNRows_;
42    int                localNRows_;
43    int                startRow_;
44    int                outputLevel_;
45    double             *berr_;
46    gridinfo_t         sluGrid_;
47    int                setupFlag_;
48 }
49 HYPRE_LSI_DSuperLU;
50 
51 int HYPRE_LSI_DSuperLUGenMatrix(HYPRE_Solver solver);
52 
53 /***************************************************************************
54  * HYPRE_LSI_DSuperLUCreate - Return a DSuperLU object "solver".
55  *--------------------------------------------------------------------------*/
56 
/* Allocate a DSuperLU solver object bound to communicator `comm`.
 *
 * solver : out-parameter; receives the new object, or NULL on failure.
 *
 * No SuperLU_DIST data is created here (setupFlag_ = 0); that happens in
 * HYPRE_LSI_DSuperLUSetup.
 *
 * Returns 0 on success, 1 if `solver` is NULL or an allocation fails.
 */
int HYPRE_LSI_DSuperLUCreate( MPI_Comm comm, HYPRE_Solver *solver )
{
   HYPRE_LSI_DSuperLU *sluPtr;

   if (solver == NULL) return 1;
   *solver = NULL;

   /* hypre_assert may be compiled out, so test the allocation explicitly */
   sluPtr = hypre_TAlloc(HYPRE_LSI_DSuperLU, 1, HYPRE_MEMORY_HOST);
   if (sluPtr == NULL) return 1;

   sluPtr->comm_        = comm;
   sluPtr->Amat_        = NULL;
   sluPtr->localNRows_  = 0;
   sluPtr->globalNRows_ = 0;
   sluPtr->startRow_    = 0;
   sluPtr->outputLevel_ = 0;
   sluPtr->setupFlag_   = 0;

   /* backward-error buffer for pdgssvx; one entry since only one
      right-hand side is ever solved at a time */
   sluPtr->berr_ = hypre_TAlloc(double, 1, HYPRE_MEMORY_HOST);
   if (sluPtr->berr_ == NULL)
   {
      hypre_TFree(sluPtr, HYPRE_MEMORY_HOST);
      return 1;
   }
   sluPtr->berr_[0] = 0.0;

   *solver = (HYPRE_Solver) sluPtr;
   return 0;
}
73 
74 /***************************************************************************
75  * HYPRE_LSI_DSuperLUDestroy - Destroy a DSuperLU object.
76  *--------------------------------------------------------------------------*/
77 
/* Free a DSuperLU solver object.
 *
 * If Setup has been done (setupFlag_ == 1), the SuperLU_DIST statistics,
 * local matrix, scaling/permutation data, LU factors, solve data (when
 * initialized) and process grid are released first. The matrix passed to
 * Setup is not freed; only the reference to it is dropped.
 *
 * A NULL solver is accepted and ignored. Always returns 0.
 */
int HYPRE_LSI_DSuperLUDestroy( HYPRE_Solver solver )
{
   HYPRE_LSI_DSuperLU *sluPtr = (HYPRE_LSI_DSuperLU *) solver;

   if (sluPtr == NULL) return 0;

   sluPtr->Amat_ = NULL;
   if (sluPtr->setupFlag_ == 1)
   {
      PStatFree(&(sluPtr->stat_));
      Destroy_CompRowLoc_Matrix_dist(&(sluPtr->sluAmat_));
      ScalePermstructFree(&(sluPtr->ScalePermstruct_));
      /* Destroy_LU needs the grid, so it runs before superlu_gridexit */
      Destroy_LU(sluPtr->globalNRows_, &(sluPtr->sluGrid_), &(sluPtr->LUstruct_));
      LUstructFree(&(sluPtr->LUstruct_));
      if (sluPtr->options_.SolveInitialized)
      {
         dSolveFinalize(&(sluPtr->options_), &(sluPtr->SOLVEstruct_));
      }
      superlu_gridexit(&(sluPtr->sluGrid_));
      sluPtr->setupFlag_ = 0;
   }
   hypre_TFree(sluPtr->berr_, HYPRE_MEMORY_HOST);
   hypre_TFree(sluPtr, HYPRE_MEMORY_HOST);
   return 0;
}
98 
99 /***************************************************************************
100  * HYPRE_LSI_DSuperLUSetOutputLevel - Set debug level
101  *--------------------------------------------------------------------------*/
102 
/* Store the diagnostic level; Setup and Solve print extra information
 * (and SuperLU statistics) when it is 2 or more. Always returns 0. */
int HYPRE_LSI_DSuperLUSetOutputLevel(HYPRE_Solver solver, int level)
{
   ((HYPRE_LSI_DSuperLU *) solver)->outputLevel_ = level;
   return 0;
}
109 
110 /***************************************************************************
111  * HYPRE_LSI_DSuperLUSetup - Set up function for LSI_DSuperLU.
112  *--------------------------------------------------------------------------*/
113 
HYPRE_LSI_DSuperLUSetup(HYPRE_Solver solver,HYPRE_ParCSRMatrix A_csr,HYPRE_ParVector b,HYPRE_ParVector x)114 int HYPRE_LSI_DSuperLUSetup(HYPRE_Solver solver, HYPRE_ParCSRMatrix A_csr,
115                             HYPRE_ParVector b, HYPRE_ParVector x )
116 {
117    int                nprocs, mypid, nprow, npcol, info, iZero=0;
118    HYPRE_LSI_DSuperLU *sluPtr = (HYPRE_LSI_DSuperLU *) solver;
119    MPI_Comm           mpiComm;
120 
121    /* ---------------------------------------------------------------- */
122    /* get machine information                                          */
123    /* ---------------------------------------------------------------- */
124 
125    mpiComm = sluPtr->comm_;
126    MPI_Comm_size(mpiComm, &nprocs);
127    MPI_Comm_rank(mpiComm, &mypid);
128 
129    /* ---------------------------------------------------------------- */
130    /* compute grid information                                         */
131    /* ---------------------------------------------------------------- */
132 
133    nprow = sluPtr->sluGrid_.nprow = 1;
134    npcol = sluPtr->sluGrid_.npcol = nprocs;
135    superlu_gridinit(mpiComm, nprow, npcol, &(sluPtr->sluGrid_));
136    if (mypid != sluPtr->sluGrid_.iam)
137    {
138       printf("DSuperLU ERROR: mismatched mypid and SuperLU iam.\n");
139       exit(1);
140    }
141 
142    /* ---------------------------------------------------------------- */
143    /* get whole matrix and compose SuperLU matrix                      */
144    /* ---------------------------------------------------------------- */
145 
146    sluPtr->Amat_ = A_csr;
147    HYPRE_LSI_DSuperLUGenMatrix(solver);
148 
149    /* ---------------------------------------------------------------- */
150    /* set solver options                                               */
151    /* ---------------------------------------------------------------- */
152 
153    set_default_options_dist(&(sluPtr->options_));
154    /* options->Fact              = DOFACT (SamePattern,FACTORED}
155       options->Equil             = YES (NO, ROW, COL, BOTH)
156                                    (YES not robust)
157       options->ParSymbFact       = NO;
158       options->ColPerm           = MMD_AT_PLUS_A (NATURAL, MMD_ATA,
159                                    METIS_AT_PLUS_A, PARMETIS, MY_PERMC}
160                                    (MMD_AT_PLUS_A the fastest, a factor
161                                     of 3+ better than MMD_ATA, which in
162                                     turn is 25% better than NATURAL)
163       options->RowPerm           = LargeDiag (NOROWPERM, MY_PERMR)
164       options->ReplaceTinyPivot  = YES (NO)
165       options->IterRefine        = DOUBLE (NOREFINE, SINGLE, EXTRA)
166                                    (EXTRA not supported, DOUBLE more
167                                     accurate)
168       options->Trans             = NOTRANS (TRANS, CONJ)
169       options->SolveInitialized  = NO;
170       options->RefineInitialized = NO;
171       options->PrintStat         = YES;
172    */
173    sluPtr->options_.Fact = DOFACT;
174    sluPtr->options_.Equil = YES;
175    sluPtr->options_.IterRefine = SLU_DOUBLE;
176    sluPtr->options_.ColPerm = MMD_AT_PLUS_A;
177    sluPtr->options_.DiagPivotThresh = 1.0;
178    sluPtr->options_.ReplaceTinyPivot = NO;
179    if (sluPtr->outputLevel_ < 2) sluPtr->options_.PrintStat = NO;
180    ScalePermstructInit(sluPtr->globalNRows_, sluPtr->globalNRows_,
181                        &(sluPtr->ScalePermstruct_));
182 //   LUstructInit(sluPtr->globalNRows_, sluPtr->globalNRows_,
183 //                &(sluPtr->LUstruct_));
184    LUstructInit(sluPtr->globalNRows_, &(sluPtr->LUstruct_));
185    sluPtr->berr_[0] = 0.0;
186    PStatInit(&(sluPtr->stat_));
187    pdgssvx(&(sluPtr->options_), &(sluPtr->sluAmat_),
188            &(sluPtr->ScalePermstruct_), NULL, sluPtr->localNRows_, iZero,
189            &(sluPtr->sluGrid_), &(sluPtr->LUstruct_),
190            &(sluPtr->SOLVEstruct_), sluPtr->berr_, &(sluPtr->stat_), &info);
191    sluPtr->options_.Fact = FACTORED;
192    if (sluPtr->outputLevel_ >= 2)
193       PStatPrint(&(sluPtr->options_),&(sluPtr->stat_),&(sluPtr->sluGrid_));
194 
195    sluPtr->setupFlag_ = 1;
196 
197    if (mypid == 0 && sluPtr->outputLevel_ >=2)
198    {
199       printf("DSuperLUSetup: diagScale = %d\n",
200              sluPtr->ScalePermstruct_.DiagScale);
201       printf("DSuperLUSetup: berr = %e\n", sluPtr->berr_[0]);
202       printf("DSuperLUSetup: info = %d\n", info);
203    }
204    return 0;
205 }
206 
207 /***************************************************************************
208  * HYPRE_LSI_DSuperLUSolve - Solve function for DSuperLU.
209  *--------------------------------------------------------------------------*/
210 
/* Solve with the factors computed by HYPRE_LSI_DSuperLUSetup.
 *
 * solver : object on which Setup has been called
 * A      : not used; the matrix built during Setup (sluAmat_) is used
 * b      : right-hand side; only its local part is read
 * x      : receives the solution in its local part (localNRows_ entries)
 *
 * Returns 0 on success, 1 if solver is NULL or Setup has not been done,
 * otherwise the nonzero `info` reported by pdgssvx.
 */
int HYPRE_LSI_DSuperLUSolve( HYPRE_Solver solver, HYPRE_ParCSRMatrix A,
                             HYPRE_ParVector b, HYPRE_ParVector x )
{
   int                localNRows, irow, iOne = 1, info = 0, mypid;
   double             *rhs, *soln;
   HYPRE_LSI_DSuperLU *sluPtr = (HYPRE_LSI_DSuperLU *) solver;

   (void) A;

   /* pdgssvx with Fact = FACTORED reads the grid, matrix and factors made
      by Setup; before Setup those structures are uninitialized */
   if (sluPtr == NULL || sluPtr->setupFlag_ != 1) return 1;

   /* ---------------------------------------------------------------- */
   /* machine, matrix, and vector information                          */
   /* ---------------------------------------------------------------- */

   MPI_Comm_rank(sluPtr->comm_, &mypid);
   localNRows = sluPtr->localNRows_;
   rhs  = hypre_VectorData(hypre_ParVectorLocalVector((hypre_ParVector *) b));
   soln = hypre_VectorData(hypre_ParVectorLocalVector((hypre_ParVector *) x));

   /* pdgssvx overwrites its B argument with the solution, so copy the
      right-hand side into x and solve in place there */
   for (irow = 0; irow < localNRows; irow++) soln[irow] = rhs[irow];

   /* ---------------------------------------------------------------- */
   /* solve (one right-hand side)                                      */
   /* ---------------------------------------------------------------- */

   pdgssvx(&(sluPtr->options_), &(sluPtr->sluAmat_),
           &(sluPtr->ScalePermstruct_), soln, localNRows, iOne,
           &(sluPtr->sluGrid_), &(sluPtr->LUstruct_),
           &(sluPtr->SOLVEstruct_), sluPtr->berr_, &(sluPtr->stat_), &info);

   /* ---------------------------------------------------------------- */
   /* diagnostics                                                      */
   /* ---------------------------------------------------------------- */

   if (mypid == 0 && sluPtr->outputLevel_ >= 2)
   {
      printf("DSuperLUSolve: info = %d\n", info);
      printf("DSuperLUSolve: diagScale = %d\n",
             sluPtr->ScalePermstruct_.DiagScale);
   }
   return info;
}
249 
250 /****************************************************************************
251  * Create SuperLU matrix in CSR
252  *--------------------------------------------------------------------------*/
253 
/* Build sluPtr->sluAmat_, a SuperLU_DIST SLU_NR_loc matrix (each process
 * stores the rows it owns in compressed-row form), from the locally owned
 * rows of sluPtr->Amat_. Also records startRow_, localNRows_ and
 * globalNRows_ on the solver.
 *
 * The row-pointer, column-index and value arrays are allocated with
 * SuperLU_DIST's allocators and handed to dCreate_CompRowLoc_Matrix_dist;
 * HYPRE_LSI_DSuperLUDestroy releases sluAmat_ with
 * Destroy_CompRowLoc_Matrix_dist.
 *
 * Column indices are stored in the order HYPRE_ParCSRMatrixGetRow returns
 * them; no sorting is done.
 *
 * Returns 0 on success, 1 if the partitioning or an allocation fails.
 */
int HYPRE_LSI_DSuperLUGenMatrix(HYPRE_Solver solver)
{
   int                nprocs, mypid;
   HYPRE_Int          rowSize, jcol;
   /* hypre hands back partitions and column indices as HYPRE_BigInt, which
      is wider than int in mixed/big-int builds */
   HYPRE_BigInt       *procNRows = NULL, *colInd, startRow, endRow, irow;
   HYPRE_Complex      *colVal;
   /* SuperLU_DIST index arrays are int_t (64-bit in some builds) */
   int_t              localNRows, localNNZ, *csrIA, *csrJA;
   double             *csrAA;
   HYPRE_LSI_DSuperLU *sluPtr = (HYPRE_LSI_DSuperLU *) solver;
   HYPRE_ParCSRMatrix Amat;
   MPI_Comm           mpiComm;

   /* ---------------------------------------------------------------- */
   /* parallel machine parameters                                      */
   /* ---------------------------------------------------------------- */

   mpiComm = sluPtr->comm_;
   MPI_Comm_rank(mpiComm, &mypid);
   MPI_Comm_size(mpiComm, &nprocs);

   /* ---------------------------------------------------------------- */
   /* row ownership: this process owns global rows [startRow, endRow)  */
   /* ---------------------------------------------------------------- */

   Amat = sluPtr->Amat_;
   HYPRE_ParCSRMatrixGetRowPartitioning(Amat, &procNRows);
   if (procNRows == NULL) return 1;
   startRow   = procNRows[mypid];
   endRow     = procNRows[mypid+1];
   localNRows = (int_t) (endRow - startRow);
   sluPtr->startRow_    = (int) startRow;
   sluPtr->localNRows_  = (int) localNRows;
   sluPtr->globalNRows_ = (int) procNRows[nprocs];
   hypre_TFree(procNRows, HYPRE_MEMORY_HOST);

   /* ---------------------------------------------------------------- */
   /* first pass: count local nonzeros                                 */
   /* ---------------------------------------------------------------- */

   localNNZ = 0;
   for (irow = startRow; irow < endRow; irow++)
   {
      HYPRE_ParCSRMatrixGetRow(Amat, irow, &rowSize, &colInd, &colVal);
      localNNZ += rowSize;
      HYPRE_ParCSRMatrixRestoreRow(Amat, irow, &rowSize, &colInd, &colVal);
   }

   csrIA = intMalloc_dist(localNRows + 1);
   csrJA = intMalloc_dist(localNNZ);
   csrAA = doubleMalloc_dist(localNNZ);
   /* a zero-sized request may legitimately yield NULL for JA/AA */
   if (csrIA == NULL || (localNNZ > 0 && (csrJA == NULL || csrAA == NULL)))
   {
      if (csrIA != NULL) SUPERLU_FREE(csrIA);
      if (csrJA != NULL) SUPERLU_FREE(csrJA);
      if (csrAA != NULL) SUPERLU_FREE(csrAA);
      return 1;
   }

   /* ---------------------------------------------------------------- */
   /* second pass: copy rows into the local CSR arrays                 */
   /* ---------------------------------------------------------------- */

   localNNZ = 0;
   csrIA[0] = 0;
   for (irow = startRow; irow < endRow; irow++)
   {
      HYPRE_ParCSRMatrixGetRow(Amat, irow, &rowSize, &colInd, &colVal);
      for (jcol = 0; jcol < rowSize; jcol++)
      {
         csrJA[localNNZ] = (int_t) colInd[jcol];
         csrAA[localNNZ] = colVal[jcol];
         localNNZ++;
      }
      csrIA[irow - startRow + 1] = localNNZ;
      HYPRE_ParCSRMatrixRestoreRow(Amat, irow, &rowSize, &colInd, &colVal);
   }

   /* ---------------------------------------------------------------- */
   /* create the SuperLU matrix (takes the three arrays above)         */
   /* ---------------------------------------------------------------- */

   dCreate_CompRowLoc_Matrix_dist(&(sluPtr->sluAmat_),
            (int_t) sluPtr->globalNRows_, (int_t) sluPtr->globalNRows_,
            localNNZ, localNRows, (int_t) startRow, csrAA,
            csrJA, csrIA, SLU_NR_loc, SLU_D, SLU_GE);
   return 0;
}
320 #else
/* ISO C requires a translation unit to contain at least one declaration.
   A typedef satisfies that without emitting an external symbol, whereas a
   global `int bogus;` repeated in several such files is a multiple
   definition at link time under -fno-common (the GCC >= 10 / Clang >= 11
   default). */
typedef int HYPRE_LSI_DSuperLU_Unavailable;
322 #endif
323 
324