1 /******************************************************************************
2  * Copyright 1998-2019 Lawrence Livermore National Security, LLC and other
3  * HYPRE Project Developers. See the top-level COPYRIGHT file for details.
4  *
5  * SPDX-License-Identifier: (Apache-2.0 OR MIT)
6  ******************************************************************************/
7 
8 /* ****************************************************************************
9  * -- SuperLU routine (version 1.1) --
10  * Univ. of California Berkeley, Xerox Palo Alto Research Center,
11  * and Lawrence Berkeley National Lab.
12  * ************************************************************************* */
13 
14 #ifdef MLI_SUPERLU
15 
16 #include <string.h>
17 #include "mli_solver_superlu.h"
18 
19 /* ****************************************************************************
20  * constructor
21  * --------------------------------------------------------------------------*/
22 
MLI_Solver_SuperLU(char * name)23 MLI_Solver_SuperLU::MLI_Solver_SuperLU(char *name) : MLI_Solver(name)
24 {
25    permR_      = NULL;
26    permC_      = NULL;
27    mliAmat_    = NULL;
28    factorized_ = 0;
29 }
30 
31 /* ****************************************************************************
32  * destructor
33  * --------------------------------------------------------------------------*/
34 
~MLI_Solver_SuperLU()35 MLI_Solver_SuperLU::~MLI_Solver_SuperLU()
36 {
37    if ( permR_ != NULL )
38    {
39       Destroy_SuperNode_Matrix(&superLU_Lmat);
40       Destroy_CompCol_Matrix(&superLU_Umat);
41    }
42    if ( permR_ != NULL ) delete [] permR_;
43    if ( permC_ != NULL ) delete [] permC_;
44 }
45 
46 /* ****************************************************************************
47  * setup
48  * --------------------------------------------------------------------------*/
49 
setup(MLI_Matrix * Amat)50 int MLI_Solver_SuperLU::setup( MLI_Matrix *Amat )
51 {
52    int      globalNRows, localNRows, startRow, localNnz, globalNnz;
53    int      *csrIA, *csrJA, *gcsrIA, *gcsrJA, *gcscJA, *gcscIA;
54    int      nnz, row_num, irow, i, j, rowSize, *cols, *recvCntArray;
55    int      *dispArray, itemp, *cntArray, icol, colNum, index;
56    int      *etree, permcSpec, lwork, panel_size, relax, info, mypid, nprocs;
57    double   *vals, *csrAA, *gcsrAA, *gcscAA, diagPivotThresh;
58    MPI_Comm mpiComm;
59    hypre_ParCSRMatrix *hypreA;
60    SuperMatrix        AC;
61    superlu_options_t  slu_options;
62    SuperLUStat_t      slu_stat;
63    GlobalLU_t         Glu;
64 
65    /* ---------------------------------------------------------------
66     * fetch matrix
67     * -------------------------------------------------------------*/
68 
69    mliAmat_ = Amat;
70    if ( strcmp( mliAmat_->getName(), "HYPRE_ParCSR" ) )
71    {
72       printf("MLI_Solver_SuperLU::setup ERROR - not HYPRE_ParCSR.\n");
73       exit(1);
74    }
75    hypreA = (hypre_ParCSRMatrix *) mliAmat_->getMatrix();
76 
77    /* ---------------------------------------------------------------
78     * fetch matrix
79     * -------------------------------------------------------------*/
80 
81    mpiComm     = hypre_ParCSRMatrixComm( hypreA );
82    MPI_Comm_rank( mpiComm, &mypid );
83    MPI_Comm_size( mpiComm, &nprocs );
84    globalNRows = hypre_ParCSRMatrixGlobalNumRows( hypreA );
85    localNRows  = hypre_ParCSRMatrixNumRows( hypreA );
86    startRow    = hypre_ParCSRMatrixFirstRowIndex( hypreA );
87    localNnz    = 0;
88    for ( irow = 0; irow < localNRows; irow++ )
89    {
90       row_num = startRow + irow;
91       hypre_ParCSRMatrixGetRow(hypreA, row_num, &rowSize, &cols, NULL);
92       localNnz += rowSize;
93       hypre_ParCSRMatrixRestoreRow(hypreA, row_num, &rowSize, &cols, NULL);
94    }
95    MPI_Allreduce(&localNnz, &globalNnz, 1, MPI_INT, MPI_SUM, mpiComm );
96    csrIA    = new int[localNRows+1];
97    if ( localNnz > 0 ) csrJA = new int[localNnz];
98    else                csrJA = NULL;
99    if ( localNnz > 0 ) csrAA = new double[localNnz];
100    else                csrAA = NULL;
101    nnz      = 0;
102    csrIA[0] = nnz;
103    for ( irow = 0; irow < localNRows; irow++ )
104    {
105       row_num = startRow + irow;
106       hypre_ParCSRMatrixGetRow(hypreA, row_num, &rowSize, &cols, &vals);
107       for ( i = 0; i < rowSize; i++ )
108       {
109          csrJA[nnz] = cols[i];
110          csrAA[nnz++] = vals[i];
111       }
112       hypre_ParCSRMatrixRestoreRow(hypreA, row_num, &rowSize, &cols, &vals);
113       csrIA[irow+1] = nnz;
114    }
115 
116    /* ---------------------------------------------------------------
117     * collect the whole matrix
118     * -------------------------------------------------------------*/
119 
120    gcsrIA = new int[globalNRows+1];
121    gcsrJA = new int[globalNnz];
122    gcsrAA = new double[globalNnz];
123    recvCntArray = new int[nprocs];
124    dispArray    = new int[nprocs];
125 
126    MPI_Allgather(&localNRows,1,MPI_INT,recvCntArray,1,MPI_INT,mpiComm);
127    dispArray[0] = 0;
128    for ( i = 1; i < nprocs; i++ )
129        dispArray[i] = dispArray[i-1] + recvCntArray[i-1];
130    csrIA[0] = csrIA[localNRows];
131    MPI_Allgatherv(csrIA, localNRows, MPI_INT, gcsrIA, recvCntArray,
132                   dispArray, MPI_INT, mpiComm);
133    nnz = 0;
134    row_num = 0;
135    for ( i = 0; i < nprocs; i++ )
136    {
137       if ( recvCntArray[i] > 0 )
138       {
139          itemp = gcsrIA[row_num];
140          gcsrIA[row_num] = 0;
141          for ( j = 0; j < recvCntArray[i]; j++ )
142             gcsrIA[row_num+j] += nnz;
143          nnz += itemp;
144          row_num += recvCntArray[i];
145       }
146    }
147    gcsrIA[globalNRows] = nnz;
148 
149    MPI_Allgather(&localNnz, 1, MPI_INT, recvCntArray, 1, MPI_INT, mpiComm);
150    dispArray[0] = 0;
151    for ( i = 1; i < nprocs; i++ )
152       dispArray[i] = dispArray[i-1] + recvCntArray[i-1];
153    MPI_Allgatherv(csrJA, localNnz, MPI_INT, gcsrJA, recvCntArray,
154                   dispArray, MPI_INT, mpiComm);
155 
156    MPI_Allgatherv(csrAA, localNnz, MPI_DOUBLE, gcsrAA, recvCntArray,
157                   dispArray, MPI_DOUBLE, mpiComm);
158 
159    delete [] recvCntArray;
160    delete [] dispArray;
161    delete [] csrIA;
162    if ( csrJA != NULL ) delete [] csrJA;
163    if ( csrAA != NULL ) delete [] csrAA;
164 
165    /* ---------------------------------------------------------------
166     * conversion from CSR to CSC
167     * -------------------------------------------------------------*/
168 
169    cntArray = new int[globalNRows];
170    for ( irow = 0; irow < globalNRows; irow++ ) cntArray[irow] = 0;
171    for ( irow = 0; irow < globalNRows; irow++ )
172    {
173       for ( i = gcsrIA[irow]; i < gcsrIA[irow+1]; i++ )
174          if ( gcsrJA[i] >= 0 && gcsrJA[i] < globalNRows )
175             cntArray[gcsrJA[i]]++;
176          else
177          {
178             printf("%d : MLI_Solver_SuperLU ERROR : gcsrJA %d %d = %d(%d)\n",
179                    mypid, irow, i, gcsrJA[i], globalNRows);
180             exit(1);
181          }
182    }
183    gcscJA = hypre_TAlloc(int,  (globalNRows+1) , HYPRE_MEMORY_HOST);
184    gcscIA = hypre_TAlloc(int,  globalNnz , HYPRE_MEMORY_HOST);
185    gcscAA = hypre_TAlloc(double,  globalNnz , HYPRE_MEMORY_HOST);
186    gcscJA[0] = 0;
187    nnz = 0;
188    for ( icol = 1; icol <= globalNRows; icol++ )
189    {
190       nnz += cntArray[icol-1];
191       gcscJA[icol] = nnz;
192    }
193    for ( irow = 0; irow < globalNRows; irow++ )
194    {
195       for ( i = gcsrIA[irow]; i < gcsrIA[irow+1]; i++ )
196       {
197          colNum = gcsrJA[i];
198          index   = gcscJA[colNum]++;
199          gcscIA[index] = irow;
200          gcscAA[index] = gcsrAA[i];
201       }
202    }
203    gcscJA[0] = 0;
204    nnz = 0;
205    for ( icol = 1; icol <= globalNRows; icol++ )
206    {
207       nnz += cntArray[icol-1];
208       gcscJA[icol] = nnz;
209    }
210    delete [] cntArray;
211    delete [] gcsrIA;
212    delete [] gcsrJA;
213    delete [] gcsrAA;
214 
215    /* ---------------------------------------------------------------
216     * make SuperMatrix
217     * -------------------------------------------------------------*/
218 
219    dCreate_CompCol_Matrix(&superLU_Amat, globalNRows, globalNRows,
220                           gcscJA[globalNRows], gcscAA, gcscIA, gcscJA,
221                           SLU_NC, SLU_D, SLU_GE);
222    etree   = new int[globalNRows];
223    permC_  = new int[globalNRows];
224    permR_  = new int[globalNRows];
225    permcSpec = 0;
226    get_perm_c(permcSpec, &superLU_Amat, permC_);
227    slu_options.Fact = DOFACT;
228    slu_options.SymmetricMode = NO;
229    sp_preorder(&slu_options, &superLU_Amat, permC_, etree, &AC);
230    diagPivotThresh = 1.0;
231    panel_size = sp_ienv(1);
232    relax = sp_ienv(2);
233    StatInit(&slu_stat);
234    lwork = 0;
235    slu_options.ColPerm = MY_PERMC;
236    slu_options.DiagPivotThresh = diagPivotThresh;
237 
238 //   dgstrf(&slu_options, &AC, dropTol, relax, panel_size,
239 //          etree, NULL, lwork, permC_, permR_, &superLU_Lmat,
240 //          &superLU_Umat, &slu_stat, &info);
241    dgstrf(&slu_options, &AC, relax, panel_size,
242           etree, NULL, lwork, permC_, permR_, &superLU_Lmat,
243           &superLU_Umat, &Glu, &slu_stat, &info);
244    Destroy_CompCol_Permuted(&AC);
245    Destroy_CompCol_Matrix(&superLU_Amat);
246    delete [] etree;
247    factorized_ = 1;
248    StatFree(&slu_stat);
249    return 0;
250 }
251 
252 /* ****************************************************************************
 * This subroutine calls the SuperLU subroutine dgstrs to perform the
 * forward (L) and backward (U) triangular substitutions with the LU factors
255  * --------------------------------------------------------------------------*/
256 
solve(MLI_Vector * f_in,MLI_Vector * u_in)257 int MLI_Solver_SuperLU::solve( MLI_Vector *f_in, MLI_Vector *u_in )
258 {
259    int             globalNRows, localNRows, startRow, *recvCntArray;
260    int             i, irow, nprocs, *dispArray, info;
261    double          *fGlobal;
262    hypre_ParVector *f, *u;
263    double          *uData, *fData;
264    SuperMatrix     B;
265    MPI_Comm        mpiComm;
266    hypre_ParCSRMatrix *hypreA;
267    SuperLUStat_t      slu_stat;
268    trans_t            trans;
269 
270    /* -------------------------------------------------------------
271     * check that the factorization has been called
272     * -----------------------------------------------------------*/
273 
274    if ( ! factorized_ )
275    {
276       printf("MLI_Solver_SuperLU::Solve ERROR - not factorized yet.\n");
277       exit(1);
278    }
279 
280    /* -------------------------------------------------------------
281     * fetch matrix and vector parameters
282     * -----------------------------------------------------------*/
283 
284    hypreA      = (hypre_ParCSRMatrix *) mliAmat_->getMatrix();
285    mpiComm     = hypre_ParCSRMatrixComm( hypreA );
286    globalNRows = hypre_ParCSRMatrixGlobalNumRows( hypreA );
287    localNRows  = hypre_ParCSRMatrixNumRows( hypreA );
288    startRow    = hypre_ParCSRMatrixFirstRowIndex( hypreA );
289    u           = (hypre_ParVector *) u_in->getVector();
290    uData       = hypre_VectorData(hypre_ParVectorLocalVector(u));
291    f           = (hypre_ParVector *) f_in->getVector();
292    fData       = hypre_VectorData(hypre_ParVectorLocalVector(f));
293 
294    /* -------------------------------------------------------------
295     * collect global vector and create a SuperLU dense matrix
296     * -----------------------------------------------------------*/
297 
298    MPI_Comm_size( mpiComm, &nprocs );
299    recvCntArray = new int[nprocs];
300    dispArray    = new int[nprocs];
301    fGlobal      = new double[globalNRows];
302 
303    MPI_Allgather(&localNRows,1,MPI_INT,recvCntArray,1,MPI_INT,mpiComm);
304    dispArray[0] = 0;
305    for ( i = 1; i < nprocs; i++ )
306        dispArray[i] = dispArray[i-1] + recvCntArray[i-1];
307    MPI_Allgatherv(fData, localNRows, MPI_DOUBLE, fGlobal, recvCntArray,
308                   dispArray, MPI_DOUBLE, mpiComm);
309    dCreate_Dense_Matrix(&B, globalNRows,1,fGlobal,globalNRows,SLU_DN,
310                         SLU_D,SLU_GE);
311 
312    /* -------------------------------------------------------------
313     * solve the problem
314     * -----------------------------------------------------------*/
315 
316    trans = NOTRANS;
317    StatInit(&slu_stat);
318    dgstrs (trans, &superLU_Lmat, &superLU_Umat, permC_, permR_, &B,
319            &slu_stat, &info);
320 
321    /* -------------------------------------------------------------
322     * fetch the solution
323     * -----------------------------------------------------------*/
324 
325    for ( irow = 0; irow < localNRows; irow++ )
326       uData[irow] = fGlobal[startRow+irow];
327 
328    /* -------------------------------------------------------------
329     * clean up
330     * -----------------------------------------------------------*/
331 
332    delete [] fGlobal;
333    delete [] recvCntArray;
334    delete [] dispArray;
335    Destroy_SuperMatrix_Store(&B);
336    StatFree(&slu_stat);
337 
338    return info;
339 }
340 
341 #endif
342 
343