1 /*! \file
2 Copyright (c) 2003, The Regents of the University of California, through
3 Lawrence Berkeley National Laboratory (subject to receipt of any required
4 approvals from U.S. Dept. of Energy)
5
6 All rights reserved.
7
8 The source code is distributed under BSD license, see the file License.txt
9 at the top-level directory.
10 */
11 /*
12 * -- SuperLU routine (version 5.2) --
13 * Univ. of California Berkeley, Xerox Palo Alto Research Center,
14 * and Lawrence Berkeley National Lab.
15 * June 30, 2009
16 *
17 * Modified:
18 * September 25, 2011, compatible with 64-bit integer in R2006b
19 * January 17, 2016, compatible with SuperLU_5.0 interface
20 */
21 #include <stdio.h>
22 #include "mex.h"
23 #include "matrix.h"
24
25 #include "slu_ddefs.h"
26
27 #define MatlabMatrix mxArray
28
29
30 /* Aliases for input and output arguments */
31 #define A_in prhs[0]
32 #define Pc_in prhs[1]
33 #define L_out plhs[0]
34 #define U_out plhs[1]
35 #define Pr_out plhs[2]
36 #define Pc_out plhs[3]
37
38 void LUextract(SuperMatrix *, SuperMatrix *, double *, mwIndex *, mwIndex *,
39 double *, mwIndex *, mwIndex *, int *, int*);
40
41 #define verbose (SPUMONI>0)
42 #define babble (SPUMONI>1)
43 #define burble (SPUMONI>2)
44
/*
 * MATLAB gateway: [L, U, prow, pcol] = <mex>(A, Pc)
 *
 * Factors the real sparse matrix A with SuperLU's dgstrf, using the column
 * permutation given by the row indices of the sparse matrix Pc, and returns
 *   plhs[0]  L     -- sparse lower factor with unit diagonal
 *   plhs[1]  U     -- sparse upper factor
 *   plhs[2]  prow  -- row permutation perm_r, 1-based, m-by-1 double
 *   plhs[3]  pcol  -- column permutation perm_c, 1-based, n-by-1 double
 *
 * All errors are reported through mexErrMsgTxt, which does not return.
 * Arrays obtained with mxMalloc/mxCalloc are released by MATLAB in that
 * case, so only SuperLU-owned memory is freed explicitly before erroring.
 */
void mexFunction(
    int nlhs,                   /* number of expected outputs */
    MatlabMatrix *plhs[],       /* matrix pointer array returning outputs */
    int nrhs,                   /* number of inputs */
    const MatlabMatrix *prhs[]  /* matrix pointer array for inputs */
    )
{
    int SPUMONI;            /* sparse monitor flag, read by verbose/babble */

    /* Arguments to C dgstrf(). */
    superlu_options_t options;
    SuperMatrix Ac;         /* Matrix postmultiplied by Pc */
    SuperMatrix L, U;
    GlobalLU_t Glu;         /* Not needed on return. */
    int panel_size, relax;
    int *etree, *perm_r, *perm_c;
    SuperLUStat_t stat;

    /* other local variables */
    SuperMatrix A;
    int m, n, nnz;
    double *val;
    int *rowind;
    int *colptr;
    mwIndex *perm_c_64;     /* mxGetIr/mxGetJc return mwIndex*, not mwSize* */
    mwIndex *rowind_64;
    mwIndex *colptr_64;
    double thresh = 1.0;    /* diagonal pivoting threshold (reported only) */
    int info;
    MatlabMatrix *X, *Y;    /* args to calls back to Matlab */
    int i;
    double *dp;
    double *Lval, *Uval;
    mwIndex *Lrow_64, *Lcol_64, *Urow_64, *Ucol_64;
    int nnzL, nnzU, snnzL, snnzU;

    /* Check number and kind of arguments passed from Matlab. */
    if (nrhs != 2) {
        mexErrMsgTxt("SUPERLU requires 2 input arguments.");
    } else if (nlhs != 4) {
        mexErrMsgTxt("SUPERLU requires 4 output arguments.");
    }
    if (!mxIsSparse(A_in) || !mxIsDouble(A_in) || mxIsComplex(A_in)) {
        mexErrMsgTxt("SUPERLU requires A to be a real sparse double matrix.");
    }
    if (!mxIsSparse(Pc_in)) {
        mexErrMsgTxt("SUPERLU requires the column permutation to be sparse.");
    }

    /* Read the Sparse Monitor Flag; stay quiet if it cannot be obtained. */
    X = mxCreateString("spumoni");
    if (mexCallMATLAB(1, &Y, 1, &X, "sparsfun") == 0) {
        SPUMONI = (int) mxGetScalar(Y);
        mxDestroyArray(Y);
    } else {
        SPUMONI = 0;
    }
    mxDestroyArray(X);

    m = (int) mxGetM(A_in);
    n = (int) mxGetN(A_in);
    val = mxGetPr(A_in);
    rowind_64 = mxGetIr(A_in);
    colptr_64 = mxGetJc(A_in);
    nnz = (int) colptr_64[n];
    if ( verbose ) mexPrintf("m = %d, n = %d, nnz = %d\n", m, n, nnz);

    /* perm_c is read from the first n row indices of Pc. */
    if (mxGetJc(Pc_in)[mxGetN(Pc_in)] < (mwIndex) n) {
        mexErrMsgTxt("SUPERLU: column permutation has fewer than n entries.");
    }
    perm_c_64 = mxGetIr(Pc_in);

    etree = (int *) mxCalloc(n, sizeof(int));
    perm_r = (int *) mxCalloc(m, sizeof(int));
    perm_c = (int *) mxMalloc(n * sizeof(int));
    rowind = (int *) mxMalloc(nnz * sizeof(int));
    colptr = (int *) mxMalloc((n+1) * sizeof(int));

    /* SuperLU works with int indices; narrow MATLAB's mwIndex arrays. */
    for (i = 0; i < n; ++i) {
        if (perm_c_64[i] >= (mwIndex) n) {
            mexErrMsgTxt("SUPERLU: column permutation index out of range.");
        }
        perm_c[i] = (int) perm_c_64[i];
        colptr[i] = (int) colptr_64[i];
    }
    colptr[n] = (int) colptr_64[n];
    for (i = 0; i < nnz; ++i) rowind[i] = (int) rowind_64[i];

    /* A borrows val/rowind/colptr; only its Store header is SuperLU-owned. */
    dCreate_CompCol_Matrix(&A, m, n, nnz, val, rowind, colptr,
                           SLU_NC, SLU_D, SLU_GE);
    panel_size = sp_ienv(1);
    relax = sp_ienv(2);

    set_default_options(&options);
    StatInit(&stat);

    if ( verbose ) mexPrintf("Apply column perm to A and compute etree...\n");
    sp_preorder(&options, &A, perm_c, etree, &Ac);

    if ( verbose ) {
        mexPrintf("LU factorization...\n");
        mexPrintf("\tpanel_size %d, relax %d, diag_pivot_thresh %.2g\n",
                  panel_size, relax, thresh);
    }

    dgstrf(&options, &Ac, relax, panel_size, etree,
           NULL, 0, perm_c, perm_r, &L, &U, &Glu, &stat, &info);

    if ( verbose ) mexPrintf("INFO from dgstrf %d\n", info);

    /* Ac and A's Store header are not needed past this point on any path;
       release them before a possible mexErrMsgTxt (which does not return). */
    Destroy_CompCol_Permuted(&Ac);
    Destroy_SuperMatrix_Store(&A);

    if ( info < 0 || info > n ) {
        StatFree(&stat);
        mexErrMsgTxt("Error returned from C dgstrf().");
    }

    /* Construct output arguments for Matlab. */
    Pr_out = mxCreateDoubleMatrix(m, 1, mxREAL);    /* output row perm */
    dp = mxGetPr(Pr_out);
    for (i = 0; i < m; ++i) dp[i] = (double) perm_r[i] + 1;

    Pc_out = mxCreateDoubleMatrix(n, 1, mxREAL);    /* output col perm */
    dp = mxGetPr(Pc_out);
    for (i = 0; i < n; ++i) dp[i] = (double) perm_c[i] + 1;  /* n entries */

    /* Now for L and U */
    nnzL = ((SCformat *) L.Store)->nnz; /* count diagonals */
    nnzU = ((NCformat *) U.Store)->nnz;
    L_out = mxCreateSparse(m, n, nnzL, mxREAL);
    Lval = mxGetPr(L_out);
    Lrow_64 = mxGetIr(L_out);
    Lcol_64 = mxGetJc(L_out);
    U_out = mxCreateSparse(m, n, nnzU, mxREAL);
    Uval = mxGetPr(U_out);
    Urow_64 = mxGetIr(U_out);
    Ucol_64 = mxGetJc(U_out);

    LUextract(&L, &U, Lval, Lrow_64, Lcol_64, Uval, Urow_64, Ucol_64,
              &snnzL, &snnzU);

    /* mwIndex is not int: cast for the %d conversions below. */
    if ( babble ) {
        for (i = 0; i <= n; ++i)
            mexPrintf("Lcol_64[%d] %d\n", i, (int) Lcol_64[i]);
    }
    if ( verbose ) mexPrintf("nnzL = %d, nnzU = %d\n", nnzL, nnzU);
    if ( babble ) {
        /* Only the first snnzL / snnzU entries were filled by LUextract. */
        for (i = 0; i < snnzL; ++i)
            mexPrintf("Lrow_64[%d] %d\n", i, (int) Lrow_64[i]);
        for (i = 0; i < snnzU; ++i)
            mexPrintf("Urow_64[%d] = %d\n", i, (int) Urow_64[i]);
    }

    Destroy_SuperNode_Matrix(&L);
    Destroy_CompCol_Matrix(&U);

    if (verbose) mexPrintf("factor nonzeros: %d unsqueezed, %d squeezed.\n",
                           nnzL + nnzU, snnzL + snnzU);

    mxFree(etree);
    mxFree(perm_r);
    mxFree(perm_c);
    mxFree(rowind);
    mxFree(colptr);

    StatFree(&stat);
}
213
214 void
LUextract(SuperMatrix * L,SuperMatrix * U,double * Lval,mwIndex * Lrow,mwIndex * Lcol,double * Uval,mwIndex * Urow,mwIndex * Ucol,int * snnzL,int * snnzU)215 LUextract(SuperMatrix *L, SuperMatrix *U, double *Lval, mwIndex *Lrow,
216 mwIndex *Lcol, double *Uval, mwIndex *Urow, mwIndex *Ucol,
217 int *snnzL, int *snnzU)
218 {
219 int i, j, k;
220 int upper;
221 int fsupc, istart, nsupr;
222 int lastl = 0, lastu = 0;
223 SCformat *Lstore;
224 NCformat *Ustore;
225 double *SNptr;
226
227 Lstore = L->Store;
228 Ustore = U->Store;
229 Lcol[0] = 0;
230 Ucol[0] = 0;
231
232 /* for each supernode */
233 for (k = 0; k <= Lstore->nsuper; ++k) {
234
235 fsupc = L_FST_SUPC(k);
236 istart = L_SUB_START(fsupc);
237 nsupr = L_SUB_START(fsupc+1) - istart;
238 upper = 1;
239
240 /* for each column in the supernode */
241 for (j = fsupc; j < L_FST_SUPC(k+1); ++j) {
242 SNptr = &((double*)Lstore->nzval)[L_NZ_START(j)];
243
244 /* Extract U */
245 for (i = U_NZ_START(j); i < U_NZ_START(j+1); ++i) {
246 Uval[lastu] = ((double*)Ustore->nzval)[i];
247 /* Matlab doesn't like explicit zero. */
248 if (Uval[lastu] != 0.0) Urow[lastu++] = (mwIndex) U_SUB(i);
249 }
250 for (i = 0; i < upper; ++i) { /* upper triangle in the supernode */
251 Uval[lastu] = SNptr[i];
252 /* Matlab doesn't like explicit zero. */
253 if (Uval[lastu] != 0.0) Urow[lastu++] = (mwIndex)L_SUB(istart+i);
254 }
255 Ucol[j+1] = lastu;
256
257 /* Extract L */
258 Lval[lastl] = 1.0; /* unit diagonal */
259 Lrow[lastl++] = L_SUB(istart + upper - 1);
260 for (i = upper; i < nsupr; ++i) {
261 Lval[lastl] = SNptr[i];
262 /* Matlab doesn't like explicit zero. */
263 if (Lval[lastl] != 0.0) Lrow[lastl++] = (mwIndex)L_SUB(istart+i);
264 }
265 Lcol[j+1] = lastl;
266
267 ++upper;
268
269 } /* for j ... */
270
271 } /* for k ... */
272
273 *snnzL = lastl;
274 *snnzU = lastu;
275 }
276