1 /******************************************************************************
2 * Copyright 1998-2019 Lawrence Livermore National Security, LLC and other
3 * HYPRE Project Developers. See the top-level COPYRIGHT file for details.
4 *
5 * SPDX-License-Identifier: (Apache-2.0 OR MIT)
6 ******************************************************************************/
7
8 /* ****************************************************************************
9 * -- SuperLU routine (version 1.1) --
10 * Univ. of California Berkeley, Xerox Palo Alto Research Center,
11 * and Lawrence Berkeley National Lab.
12 * ************************************************************************* */
13
14 #ifdef MLI_SUPERLU
15
16 #include <string.h>
17 #include "mli_solver_superlu.h"
18
19 /* ****************************************************************************
20 * constructor
21 * --------------------------------------------------------------------------*/
22
MLI_Solver_SuperLU(char * name)23 MLI_Solver_SuperLU::MLI_Solver_SuperLU(char *name) : MLI_Solver(name)
24 {
25 permR_ = NULL;
26 permC_ = NULL;
27 mliAmat_ = NULL;
28 factorized_ = 0;
29 }
30
31 /* ****************************************************************************
32 * destructor
33 * --------------------------------------------------------------------------*/
34
~MLI_Solver_SuperLU()35 MLI_Solver_SuperLU::~MLI_Solver_SuperLU()
36 {
37 if ( permR_ != NULL )
38 {
39 Destroy_SuperNode_Matrix(&superLU_Lmat);
40 Destroy_CompCol_Matrix(&superLU_Umat);
41 }
42 if ( permR_ != NULL ) delete [] permR_;
43 if ( permC_ != NULL ) delete [] permC_;
44 }
45
46 /* ****************************************************************************
47 * setup
48 * --------------------------------------------------------------------------*/
49
setup(MLI_Matrix * Amat)50 int MLI_Solver_SuperLU::setup( MLI_Matrix *Amat )
51 {
52 int globalNRows, localNRows, startRow, localNnz, globalNnz;
53 int *csrIA, *csrJA, *gcsrIA, *gcsrJA, *gcscJA, *gcscIA;
54 int nnz, row_num, irow, i, j, rowSize, *cols, *recvCntArray;
55 int *dispArray, itemp, *cntArray, icol, colNum, index;
56 int *etree, permcSpec, lwork, panel_size, relax, info, mypid, nprocs;
57 double *vals, *csrAA, *gcsrAA, *gcscAA, diagPivotThresh;
58 MPI_Comm mpiComm;
59 hypre_ParCSRMatrix *hypreA;
60 SuperMatrix AC;
61 superlu_options_t slu_options;
62 SuperLUStat_t slu_stat;
63 GlobalLU_t Glu;
64
65 /* ---------------------------------------------------------------
66 * fetch matrix
67 * -------------------------------------------------------------*/
68
69 mliAmat_ = Amat;
70 if ( strcmp( mliAmat_->getName(), "HYPRE_ParCSR" ) )
71 {
72 printf("MLI_Solver_SuperLU::setup ERROR - not HYPRE_ParCSR.\n");
73 exit(1);
74 }
75 hypreA = (hypre_ParCSRMatrix *) mliAmat_->getMatrix();
76
77 /* ---------------------------------------------------------------
78 * fetch matrix
79 * -------------------------------------------------------------*/
80
81 mpiComm = hypre_ParCSRMatrixComm( hypreA );
82 MPI_Comm_rank( mpiComm, &mypid );
83 MPI_Comm_size( mpiComm, &nprocs );
84 globalNRows = hypre_ParCSRMatrixGlobalNumRows( hypreA );
85 localNRows = hypre_ParCSRMatrixNumRows( hypreA );
86 startRow = hypre_ParCSRMatrixFirstRowIndex( hypreA );
87 localNnz = 0;
88 for ( irow = 0; irow < localNRows; irow++ )
89 {
90 row_num = startRow + irow;
91 hypre_ParCSRMatrixGetRow(hypreA, row_num, &rowSize, &cols, NULL);
92 localNnz += rowSize;
93 hypre_ParCSRMatrixRestoreRow(hypreA, row_num, &rowSize, &cols, NULL);
94 }
95 MPI_Allreduce(&localNnz, &globalNnz, 1, MPI_INT, MPI_SUM, mpiComm );
96 csrIA = new int[localNRows+1];
97 if ( localNnz > 0 ) csrJA = new int[localNnz];
98 else csrJA = NULL;
99 if ( localNnz > 0 ) csrAA = new double[localNnz];
100 else csrAA = NULL;
101 nnz = 0;
102 csrIA[0] = nnz;
103 for ( irow = 0; irow < localNRows; irow++ )
104 {
105 row_num = startRow + irow;
106 hypre_ParCSRMatrixGetRow(hypreA, row_num, &rowSize, &cols, &vals);
107 for ( i = 0; i < rowSize; i++ )
108 {
109 csrJA[nnz] = cols[i];
110 csrAA[nnz++] = vals[i];
111 }
112 hypre_ParCSRMatrixRestoreRow(hypreA, row_num, &rowSize, &cols, &vals);
113 csrIA[irow+1] = nnz;
114 }
115
116 /* ---------------------------------------------------------------
117 * collect the whole matrix
118 * -------------------------------------------------------------*/
119
120 gcsrIA = new int[globalNRows+1];
121 gcsrJA = new int[globalNnz];
122 gcsrAA = new double[globalNnz];
123 recvCntArray = new int[nprocs];
124 dispArray = new int[nprocs];
125
126 MPI_Allgather(&localNRows,1,MPI_INT,recvCntArray,1,MPI_INT,mpiComm);
127 dispArray[0] = 0;
128 for ( i = 1; i < nprocs; i++ )
129 dispArray[i] = dispArray[i-1] + recvCntArray[i-1];
130 csrIA[0] = csrIA[localNRows];
131 MPI_Allgatherv(csrIA, localNRows, MPI_INT, gcsrIA, recvCntArray,
132 dispArray, MPI_INT, mpiComm);
133 nnz = 0;
134 row_num = 0;
135 for ( i = 0; i < nprocs; i++ )
136 {
137 if ( recvCntArray[i] > 0 )
138 {
139 itemp = gcsrIA[row_num];
140 gcsrIA[row_num] = 0;
141 for ( j = 0; j < recvCntArray[i]; j++ )
142 gcsrIA[row_num+j] += nnz;
143 nnz += itemp;
144 row_num += recvCntArray[i];
145 }
146 }
147 gcsrIA[globalNRows] = nnz;
148
149 MPI_Allgather(&localNnz, 1, MPI_INT, recvCntArray, 1, MPI_INT, mpiComm);
150 dispArray[0] = 0;
151 for ( i = 1; i < nprocs; i++ )
152 dispArray[i] = dispArray[i-1] + recvCntArray[i-1];
153 MPI_Allgatherv(csrJA, localNnz, MPI_INT, gcsrJA, recvCntArray,
154 dispArray, MPI_INT, mpiComm);
155
156 MPI_Allgatherv(csrAA, localNnz, MPI_DOUBLE, gcsrAA, recvCntArray,
157 dispArray, MPI_DOUBLE, mpiComm);
158
159 delete [] recvCntArray;
160 delete [] dispArray;
161 delete [] csrIA;
162 if ( csrJA != NULL ) delete [] csrJA;
163 if ( csrAA != NULL ) delete [] csrAA;
164
165 /* ---------------------------------------------------------------
166 * conversion from CSR to CSC
167 * -------------------------------------------------------------*/
168
169 cntArray = new int[globalNRows];
170 for ( irow = 0; irow < globalNRows; irow++ ) cntArray[irow] = 0;
171 for ( irow = 0; irow < globalNRows; irow++ )
172 {
173 for ( i = gcsrIA[irow]; i < gcsrIA[irow+1]; i++ )
174 if ( gcsrJA[i] >= 0 && gcsrJA[i] < globalNRows )
175 cntArray[gcsrJA[i]]++;
176 else
177 {
178 printf("%d : MLI_Solver_SuperLU ERROR : gcsrJA %d %d = %d(%d)\n",
179 mypid, irow, i, gcsrJA[i], globalNRows);
180 exit(1);
181 }
182 }
183 gcscJA = hypre_TAlloc(int, (globalNRows+1) , HYPRE_MEMORY_HOST);
184 gcscIA = hypre_TAlloc(int, globalNnz , HYPRE_MEMORY_HOST);
185 gcscAA = hypre_TAlloc(double, globalNnz , HYPRE_MEMORY_HOST);
186 gcscJA[0] = 0;
187 nnz = 0;
188 for ( icol = 1; icol <= globalNRows; icol++ )
189 {
190 nnz += cntArray[icol-1];
191 gcscJA[icol] = nnz;
192 }
193 for ( irow = 0; irow < globalNRows; irow++ )
194 {
195 for ( i = gcsrIA[irow]; i < gcsrIA[irow+1]; i++ )
196 {
197 colNum = gcsrJA[i];
198 index = gcscJA[colNum]++;
199 gcscIA[index] = irow;
200 gcscAA[index] = gcsrAA[i];
201 }
202 }
203 gcscJA[0] = 0;
204 nnz = 0;
205 for ( icol = 1; icol <= globalNRows; icol++ )
206 {
207 nnz += cntArray[icol-1];
208 gcscJA[icol] = nnz;
209 }
210 delete [] cntArray;
211 delete [] gcsrIA;
212 delete [] gcsrJA;
213 delete [] gcsrAA;
214
215 /* ---------------------------------------------------------------
216 * make SuperMatrix
217 * -------------------------------------------------------------*/
218
219 dCreate_CompCol_Matrix(&superLU_Amat, globalNRows, globalNRows,
220 gcscJA[globalNRows], gcscAA, gcscIA, gcscJA,
221 SLU_NC, SLU_D, SLU_GE);
222 etree = new int[globalNRows];
223 permC_ = new int[globalNRows];
224 permR_ = new int[globalNRows];
225 permcSpec = 0;
226 get_perm_c(permcSpec, &superLU_Amat, permC_);
227 slu_options.Fact = DOFACT;
228 slu_options.SymmetricMode = NO;
229 sp_preorder(&slu_options, &superLU_Amat, permC_, etree, &AC);
230 diagPivotThresh = 1.0;
231 panel_size = sp_ienv(1);
232 relax = sp_ienv(2);
233 StatInit(&slu_stat);
234 lwork = 0;
235 slu_options.ColPerm = MY_PERMC;
236 slu_options.DiagPivotThresh = diagPivotThresh;
237
238 // dgstrf(&slu_options, &AC, dropTol, relax, panel_size,
239 // etree, NULL, lwork, permC_, permR_, &superLU_Lmat,
240 // &superLU_Umat, &slu_stat, &info);
241 dgstrf(&slu_options, &AC, relax, panel_size,
242 etree, NULL, lwork, permC_, permR_, &superLU_Lmat,
243 &superLU_Umat, &Glu, &slu_stat, &info);
244 Destroy_CompCol_Permuted(&AC);
245 Destroy_CompCol_Matrix(&superLU_Amat);
246 delete [] etree;
247 factorized_ = 1;
248 StatFree(&slu_stat);
249 return 0;
250 }
251
252 /* ****************************************************************************
 * This subroutine calls the SuperLU routine dgstrs to perform the forward
 * and backward triangular solves with the LU factors computed in setup()
255 * --------------------------------------------------------------------------*/
256
solve(MLI_Vector * f_in,MLI_Vector * u_in)257 int MLI_Solver_SuperLU::solve( MLI_Vector *f_in, MLI_Vector *u_in )
258 {
259 int globalNRows, localNRows, startRow, *recvCntArray;
260 int i, irow, nprocs, *dispArray, info;
261 double *fGlobal;
262 hypre_ParVector *f, *u;
263 double *uData, *fData;
264 SuperMatrix B;
265 MPI_Comm mpiComm;
266 hypre_ParCSRMatrix *hypreA;
267 SuperLUStat_t slu_stat;
268 trans_t trans;
269
270 /* -------------------------------------------------------------
271 * check that the factorization has been called
272 * -----------------------------------------------------------*/
273
274 if ( ! factorized_ )
275 {
276 printf("MLI_Solver_SuperLU::Solve ERROR - not factorized yet.\n");
277 exit(1);
278 }
279
280 /* -------------------------------------------------------------
281 * fetch matrix and vector parameters
282 * -----------------------------------------------------------*/
283
284 hypreA = (hypre_ParCSRMatrix *) mliAmat_->getMatrix();
285 mpiComm = hypre_ParCSRMatrixComm( hypreA );
286 globalNRows = hypre_ParCSRMatrixGlobalNumRows( hypreA );
287 localNRows = hypre_ParCSRMatrixNumRows( hypreA );
288 startRow = hypre_ParCSRMatrixFirstRowIndex( hypreA );
289 u = (hypre_ParVector *) u_in->getVector();
290 uData = hypre_VectorData(hypre_ParVectorLocalVector(u));
291 f = (hypre_ParVector *) f_in->getVector();
292 fData = hypre_VectorData(hypre_ParVectorLocalVector(f));
293
294 /* -------------------------------------------------------------
295 * collect global vector and create a SuperLU dense matrix
296 * -----------------------------------------------------------*/
297
298 MPI_Comm_size( mpiComm, &nprocs );
299 recvCntArray = new int[nprocs];
300 dispArray = new int[nprocs];
301 fGlobal = new double[globalNRows];
302
303 MPI_Allgather(&localNRows,1,MPI_INT,recvCntArray,1,MPI_INT,mpiComm);
304 dispArray[0] = 0;
305 for ( i = 1; i < nprocs; i++ )
306 dispArray[i] = dispArray[i-1] + recvCntArray[i-1];
307 MPI_Allgatherv(fData, localNRows, MPI_DOUBLE, fGlobal, recvCntArray,
308 dispArray, MPI_DOUBLE, mpiComm);
309 dCreate_Dense_Matrix(&B, globalNRows,1,fGlobal,globalNRows,SLU_DN,
310 SLU_D,SLU_GE);
311
312 /* -------------------------------------------------------------
313 * solve the problem
314 * -----------------------------------------------------------*/
315
316 trans = NOTRANS;
317 StatInit(&slu_stat);
318 dgstrs (trans, &superLU_Lmat, &superLU_Umat, permC_, permR_, &B,
319 &slu_stat, &info);
320
321 /* -------------------------------------------------------------
322 * fetch the solution
323 * -----------------------------------------------------------*/
324
325 for ( irow = 0; irow < localNRows; irow++ )
326 uData[irow] = fGlobal[startRow+irow];
327
328 /* -------------------------------------------------------------
329 * clean up
330 * -----------------------------------------------------------*/
331
332 delete [] fGlobal;
333 delete [] recvCntArray;
334 delete [] dispArray;
335 Destroy_SuperMatrix_Store(&B);
336 StatFree(&slu_stat);
337
338 return info;
339 }
340
341 #endif
342
343