1 /* -*-c-*-  */
2 /*
3  * _superlu module
4  *
5  * Python interface to SuperLU decompositions.
6  */
7 
8 /* Copyright 1999 Travis Oliphant
9  *
10  * Permission to copy and modified this file is granted under
11  * the revised BSD license. No warranty is expressed or IMPLIED
12  */
13 
14 #include <Python.h>
15 
16 #define PY_ARRAY_UNIQUE_SYMBOL _scipy_sparse_superlu_ARRAY_API
17 #include <numpy/ndarrayobject.h>
18 
19 #include "_superluobject.h"
20 
21 
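/*
 * NULL-safe destruction helpers: each one releases a SuperLU structure and
 * then resets its pointer to NULL, so calling it again, or on a structure
 * that was only zero-initialized, is harmless.
 */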
/*
 * Release A->Store via Destroy_SuperMatrix_Store and clear the pointer.
 *
 * No NULL test is made before the call: Destroy_SuperMatrix_Store is already
 * safe to call when A->Store is NULL.
 */
void XDestroy_SuperMatrix_Store(SuperMatrix * A)
{
    Destroy_SuperMatrix_Store(A);	/* tolerates a NULL Store */

    A->Store = NULL;			/* make a repeated call a no-op */
}
30 
/*
 * Free a supernodal matrix (used for the L factor in this file) if one is
 * attached, leaving A->Store == NULL in every case.
 */
void XDestroy_SuperNode_Matrix(SuperMatrix * A)
{
    if (A->Store == NULL) {
        return;			/* nothing attached: Store is already NULL */
    }

    Destroy_SuperNode_Matrix(A);
    A->Store = NULL;
}
38 
/*
 * Free a compressed-column matrix (used for the U factor in this file) if
 * one is attached, leaving A->Store == NULL in every case.
 */
void XDestroy_CompCol_Matrix(SuperMatrix * A)
{
    if (A->Store == NULL) {
        return;			/* nothing attached: Store is already NULL */
    }

    Destroy_CompCol_Matrix(A);
    A->Store = NULL;
}
46 
/*
 * Free a column-permuted matrix via Destroy_CompCol_Permuted if one is
 * attached, leaving A->Store == NULL in every case.
 */
void XDestroy_CompCol_Permuted(SuperMatrix * A)
{
    if (A->Store == NULL) {
        return;			/* nothing attached: Store is already NULL */
    }

    Destroy_CompCol_Permuted(A);
    A->Store = NULL;
}
54 
/*
 * Release the statistics buffers through StatFree, but only when stat->ops
 * is set; afterwards stat->ops is always NULL so a second call does nothing.
 */
void XStatFree(SuperLUStat_t * stat)
{
    if (stat->ops == NULL) {
        return;			/* never initialized, or already freed */
    }

    StatFree(stat);
    stat->ops = NULL;
}
62 
63 
/*
 * Data-type dependent implementations for Xgssv and Xgstrf.
 *
 * These have to be included from separate files because of the SuperLU
 * include structure.
 */
70 
/*
 * Python entry point:
 *     gssv(N, nnz, nzvals, colind, rowptr, B, csc=0, options=None)
 *
 * Solve A X = B for X.  A is N-by-N with nnz stored values.  It is wrapped
 * (not copied) in SuperLU NR format when csc == 0 and in NC format
 * otherwise.  B is copied into a new Fortran-contiguous array with nzvals'
 * dtype.  That array is handed to SuperLU as the right-hand side and is
 * returned as X.
 *
 * Returns (X, info) with info as set by gssv(), or NULL with a Python
 * exception set.
 *
 * Locals are volatile because control can come back to the setjmp() below
 * via longjmp() on the buffer from superlu_python_jmpbuf().  C only
 * guarantees the values of modified automatic variables across such a jump
 * when they are volatile.
 */
static PyObject *Py_gssv(PyObject * self, PyObject * args,
			 PyObject * kwdict)
{
    volatile PyObject *Py_B = NULL;
    volatile PyArrayObject *Py_X = NULL;
    volatile PyArrayObject *nzvals = NULL;
    volatile PyArrayObject *colind = NULL, *rowptr = NULL;
    volatile int N, nnz;
    volatile int info = 0;
    volatile int csc = 0;
    volatile int aborted = 0;	/* set when we return via longjmp() */
    volatile int *perm_r = NULL, *perm_c = NULL;
    volatile SuperMatrix A = { 0 }, B = { 0 }, L = { 0 }, U = { 0 };
    volatile superlu_options_t options = { 0 };
    volatile SuperLUStat_t stat = { 0 };
    volatile PyObject *option_dict = NULL;
    volatile int type;
    volatile jmp_buf *jmpbuf_ptr;
    SLU_BEGIN_THREADS_DEF;

    static char *kwlist[] = {
        "N", "nnz", "nzvals", "colind", "rowptr", "B", "csc",
        "options", NULL
    };

    /* Get input arguments */
    if (!PyArg_ParseTupleAndKeywords(args, kwdict, "iiO!O!O!O|iO", kwlist,
                                     &N, &nnz, &PyArray_Type, &nzvals,
                                     &PyArray_Type, &colind, &PyArray_Type,
                                     &rowptr, &Py_B, &csc, &option_dict)) {
        return NULL;
    }

    if (!_CHECK_INTEGER(colind) || !_CHECK_INTEGER(rowptr)) {
        PyErr_SetString(PyExc_TypeError,
                        "colind and rowptr must be of type cint");
        return NULL;
    }

    type = PyArray_TYPE((PyArrayObject*)nzvals);
    if (!CHECK_SLU_TYPE(type)) {
        PyErr_SetString(PyExc_TypeError,
                        "nzvals is not of a type supported by SuperLU");
        return NULL;
    }

    if (!set_superlu_options_from_dict((superlu_options_t*)&options, 0,
                                       (PyObject*)option_dict, NULL, NULL)) {
        return NULL;
    }

    /* Create space for the output: a private, Fortran-ordered copy of B. */
    Py_X = (PyArrayObject*)PyArray_FROMANY(
        (PyObject*)Py_B, type, 1, 2,
        NPY_ARRAY_F_CONTIGUOUS | NPY_ARRAY_ENSURECOPY);
    if (Py_X == NULL) {
        return NULL;
    }

    if (PyArray_DIM((PyArrayObject*)Py_X, 0) != N) {
        PyErr_SetString(PyExc_ValueError,
                        "b array has invalid shape");
        Py_DECREF(Py_X);
        return NULL;
    }

    if (csc) {
        if (NCFormat_from_spMatrix((SuperMatrix*)&A, N, N, nnz,
                                   (PyArrayObject *)nzvals, (PyArrayObject *)colind,
                                   (PyArrayObject *)rowptr, type)) {
            Py_DECREF(Py_X);
            return NULL;
        }
    }
    else {
        if (NRFormat_from_spMatrix((SuperMatrix*)&A, N, N, nnz, (PyArrayObject *)nzvals,
                                   (PyArrayObject *)colind, (PyArrayObject *)rowptr,
                                   type)) {
            Py_DECREF(Py_X);
            return NULL;
        }
    }

    if (DenseSuper_from_Numeric((SuperMatrix*)&B, (PyObject*)Py_X)) {
        Destroy_SuperMatrix_Store((SuperMatrix*)&A);
        Py_DECREF(Py_X);
        return NULL;
    }

    /* B and Py_X share the same data now, but Py_X "owns" it. */

    jmpbuf_ptr = (volatile jmp_buf *)superlu_python_jmpbuf();
    SLU_BEGIN_THREADS;
    if (setjmp(*(jmp_buf*)jmpbuf_ptr)) {
        /* Reached through longjmp(): the solve was abandoned. */
        SLU_END_THREADS;
        aborted = 1;
    }
    else {
        perm_c = intMalloc(N);
        perm_r = intMalloc(N);
        StatInit((SuperLUStat_t*)&stat);

        /* Factor A and solve A X = B; the solution overwrites B (Py_X's data). */
        gssv(type, (superlu_options_t*)&options, (SuperMatrix*)&A, (int*)perm_c, (int*)perm_r,
             (SuperMatrix*)&L, (SuperMatrix*)&U, (SuperMatrix*)&B, (SuperLUStat_t*)&stat,
             (int*)&info);
        SLU_END_THREADS;
    }

    /*
     * Cleanup shared by the normal and the aborted path.  The NULL-safe X*
     * helpers are used on both paths.  If gssv() returns or aborts before L,
     * U or the statistics were set up, their Store/ops pointers are still
     * NULL from the zero-initialisation above.  Unlike
     * Destroy_SuperMatrix_Store, the plain Destroy_SuperNode_Matrix and
     * Destroy_CompCol_Matrix routines are not NULL-safe.
     */
    SUPERLU_FREE((void*)perm_r);
    SUPERLU_FREE((void*)perm_c);
    XDestroy_SuperMatrix_Store((SuperMatrix*)&A);	/* holds just a pointer to the data */
    XDestroy_SuperMatrix_Store((SuperMatrix*)&B);
    XDestroy_SuperNode_Matrix((SuperMatrix*)&L);
    XDestroy_CompCol_Matrix((SuperMatrix*)&U);
    XStatFree((SuperLUStat_t*)&stat);

    if (aborted) {
        Py_XDECREF(Py_X);
        return NULL;
    }

    /* "N" hands our reference to Py_X over to the result tuple. */
    return Py_BuildValue("Ni", (PyObject*)Py_X, (int)info);
}
201 
/*
 * Python entry point:
 *     gstrf(N, nnz, nzvals, colind, rowptr, csc_construct_func,
 *           options=None, ilu=0)
 *
 * Wraps the caller's compressed-sparse-column arrays in a SuperLU NC matrix
 * and passes it to newSuperLUObject() to build the factorization object.
 * The keyword "colind" lands in `rowind` and "rowptr" in `colptr`.
 *
 * Returns the new object, or NULL with a Python exception set.
 */
static PyObject *Py_gstrf(PyObject * self, PyObject * args,
			  PyObject * keywds)
{
    static char *kwlist[] = { "N", "nnz", "nzvals", "colind", "rowptr",
        "csc_construct_func", "options", "ilu",
        NULL
    };
    PyArrayObject *nzvals, *rowind, *colptr;
    PyObject *construct_func = NULL;
    PyObject *opts = NULL;
    PyObject *result = NULL;
    SuperMatrix A = { 0 };
    int N, nnz, type;
    int ilu = 0;

    if (!PyArg_ParseTupleAndKeywords(args, keywds, "iiO!O!O!O|Oi", kwlist,
                                     &N, &nnz,
                                     &PyArray_Type, &nzvals,
                                     &PyArray_Type, &rowind,
                                     &PyArray_Type, &colptr,
                                     &construct_func,
                                     &opts,
                                     &ilu)) {
        return NULL;
    }

    if (!_CHECK_INTEGER(colptr) || !_CHECK_INTEGER(rowind)) {
        PyErr_SetString(PyExc_TypeError,
                        "rowind and colptr must be of type cint");
        return NULL;
    }

    type = PyArray_TYPE(nzvals);
    if (!CHECK_SLU_TYPE(type)) {
        PyErr_SetString(PyExc_TypeError,
                        "nzvals is not of a type supported by SuperLU");
        return NULL;
    }

    /* Build the object only if wrapping the arrays succeeded (returns 0). */
    if (NCFormat_from_spMatrix(&A, N, N, nnz, nzvals, rowind, colptr,
                               type) == 0) {
        result = newSuperLUObject(&A, opts, type, ilu, construct_func);
    }

    /*
     * Single exit for success and failure alike.  Only the Store header is
     * released; the arrays of the input matrix are not freed.
     */
    XDestroy_SuperMatrix_Store(&A);
    return result;
}
265 
/* Docstring for _superlu.gssv; describes the actual argument list and result. */
static char gssv_doc[] =
    "gssv(N, nnz, nzvals, colind, rowptr, B, csc=0, options=None)\n"
    "\n"
    "Direct solution of the sparse linear system A*X = B for X.\n"
    "\n"
    "A is an N-by-N sparse matrix with nnz stored values.  When csc is 0,\n"
    "(nzvals, colind, rowptr) describe A in compressed sparse row form;\n"
    "otherwise they describe it in compressed sparse column form, with\n"
    "colind holding row indices and rowptr holding column pointers.\n"
    "B is a 1-D or 2-D array with N rows; it is copied, not modified.\n"
    "\n"
    "Returns a tuple (X, info), where info is the status reported by SuperLU.";
268 
/* Docstring for _superlu.gstrf; argument names match the keyword list. */
static char gstrf_doc[] = "gstrf(N, nnz, nzvals, colind, rowptr, csc_construct_func,\n\
      options=None, ilu=False)\n\
\n\
Performs a factorization of the sparse matrix A and returns a SuperLU object.\n\
\n\
arguments\n\
---------\n\
\n\
The matrix to be factorized is given in compressed sparse column (CSC) form\n\
as separate arguments.  Despite their keyword names, `colind` holds the\n\
row indices and `rowptr` holds the column pointers.\n\
\n\
N                   number of rows and columns\n\
nnz                 number of non-zero elements\n\
nzvals              non-zero values\n\
colind              row index of each non-zero value (same size as nzvals)\n\
rowptr              index into colind of the first non-zero value in each\n\
                    column; size is (N+1), and the last value should be nnz.\n\
csc_construct_func  object passed through to the constructed SuperLU object\n\
\n\
additional keyword arguments:\n\
-----------------------------\n\
options             specifies additional options for SuperLU\n\
                    (same keys and values as in superlu_options_t C structure,\n\
                    and additionally 'Relax' and 'PanelSize')\n\
\n\
ilu                 whether to perform an incomplete LU decomposition\n\
                    (default: false)\n\
";
296 
297 
298 /*
299  * Main SuperLU module
300  */
301 
302 static PyMethodDef SuperLU_Methods[] = {
303     {"gssv", (PyCFunction) Py_gssv, METH_VARARGS | METH_KEYWORDS,
304      gssv_doc},
305     {"gstrf", (PyCFunction) Py_gstrf, METH_VARARGS | METH_KEYWORDS,
306      gstrf_doc},
307     {NULL, NULL}
308 };
309 
310 static struct PyModuleDef moduledef = {
311     PyModuleDef_HEAD_INIT,
312     "_superlu",
313     NULL,
314     -1,
315     SuperLU_Methods,
316     NULL,
317     NULL,
318     NULL,
319     NULL
320 };
321 
PyInit__superlu(void)322 PyObject *PyInit__superlu(void)
323 {
324     PyObject *m, *d;
325 
326     import_array();
327 
328     if (PyType_Ready(&SuperLUType) < 0) {
329         return NULL;
330     }
331 
332     if (PyType_Ready(&SuperLUGlobalType) < 0) {
333     	return NULL;
334     }
335 
336     m = PyModule_Create(&moduledef);
337     d = PyModule_GetDict(m);
338 
339     Py_INCREF(&PyArrayFlags_Type);
340     PyDict_SetItemString(d, "SuperLU",
341 			 (PyObject *) &SuperLUType);
342 
343     if (PyErr_Occurred())
344 	Py_FatalError("can't initialize module _superlu");
345 
346     return m;
347 }
348