1 #include "cs_mex.h"
2 /* check MATLAB input argument */
cs_mex_check(csi nel,csi m,csi n,csi square,csi sparse,csi values,const mxArray * A)3 void cs_mex_check (csi nel, csi m, csi n, csi square, csi sparse, csi values,
4 const mxArray *A)
5 {
6 csi nnel, mm = mxGetM (A), nn = mxGetN (A) ;
7 if (values)
8 {
9 if (mxIsComplex (A))
10 {
11 mexErrMsgTxt ("matrix must be real; try CXSparse instead") ;
12 }
13 }
14 if (sparse && !mxIsSparse (A)) mexErrMsgTxt ("matrix must be sparse") ;
15 if (!sparse)
16 {
17 if (mxIsSparse (A)) mexErrMsgTxt ("matrix must be full") ;
18 if (values && !mxIsDouble (A)) mexErrMsgTxt ("matrix must be double") ;
19 }
20 if (nel)
21 {
22 /* check number of elements */
23 nnel = mxGetNumberOfElements (A) ;
24 if (m >= 0 && n >= 0 && m*n != nnel) mexErrMsgTxt ("wrong length") ;
25 }
26 else
27 {
28 /* check row and/or column dimensions */
29 if (m >= 0 && m != mm) mexErrMsgTxt ("wrong dimension") ;
30 if (n >= 0 && n != nn) mexErrMsgTxt ("wrong dimension") ;
31 }
32 if (square && mm != nn) mexErrMsgTxt ("matrix must be square") ;
33 }
34
35 /* get a MATLAB sparse matrix and convert to cs */
cs_mex_get_sparse(cs * A,csi square,csi values,const mxArray * Amatlab)36 cs *cs_mex_get_sparse (cs *A, csi square, csi values, const mxArray *Amatlab)
37 {
38 cs_mex_check (0, -1, -1, square, 1, values, Amatlab) ;
39 A->m = mxGetM (Amatlab) ;
40 A->n = mxGetN (Amatlab) ;
41 A->p = (csi *) mxGetJc (Amatlab) ;
42 A->i = (csi *) mxGetIr (Amatlab) ;
43 A->x = values ? mxGetPr (Amatlab) : NULL ;
44 A->nzmax = mxGetNzmax (Amatlab) ;
45 A->nz = -1 ; /* denotes a compressed-col matrix, instead of triplet */
46 return (A) ;
47 }
48
49 /* return a sparse matrix to MATLAB */
cs_mex_put_sparse(cs ** Ahandle)50 mxArray *cs_mex_put_sparse (cs **Ahandle)
51 {
52 cs *A ;
53 mxArray *Amatlab ;
54 if (!Ahandle || !CS_CSC ((*Ahandle))) mexErrMsgTxt ("invalid sparse matrix") ;
55 A = *Ahandle ;
56 Amatlab = mxCreateSparse (0, 0, 0, mxREAL) ;
57 mxSetM (Amatlab, A->m) ;
58 mxSetN (Amatlab, A->n) ;
59 mxSetNzmax (Amatlab, A->nzmax) ;
60 cs_free (mxGetJc (Amatlab)) ;
61 cs_free (mxGetIr (Amatlab)) ;
62 mxSetJc (Amatlab, (mwIndex *) A->p) ; /* assign A->p pointer to MATLAB A */
63 mxSetIr (Amatlab, (mwIndex *) A->i) ;
64 cs_free (mxGetPr (Amatlab)) ;
65 if (A->x == NULL)
66 {
67 /* A is a pattern only matrix; return all 1's to MATLAB */
68 csi i, nz ;
69 nz = A->p [A->n] ;
70 A->x = cs_malloc (CS_MAX (nz,1), sizeof (double)) ;
71 for (i = 0 ; i < nz ; i++)
72 {
73 A->x [i] = 1 ;
74 }
75 }
76 mxSetPr (Amatlab, A->x) ;
77 mexMakeMemoryPersistent (A->p) ; /* ensure MATLAB does not free A->p */
78 mexMakeMemoryPersistent (A->i) ;
79 mexMakeMemoryPersistent (A->x) ;
80 cs_free (A) ; /* frees A struct only, not A->p, etc */
81 *Ahandle = NULL ;
82 return (Amatlab) ;
83 }
84
85 /* get a MATLAB dense column vector */
cs_mex_get_double(csi n,const mxArray * X)86 double *cs_mex_get_double (csi n, const mxArray *X)
87 {
88 cs_mex_check (0, n, 1, 0, 0, 1, X) ;
89 return (mxGetPr (X)) ;
90 }
91
92 /* return a double vector to MATLAB */
cs_mex_put_double(csi n,const double * b,mxArray ** X)93 double *cs_mex_put_double (csi n, const double *b, mxArray **X)
94 {
95 double *x ;
96 csi k ;
97 *X = mxCreateDoubleMatrix (n, 1, mxREAL) ; /* create x */
98 x = mxGetPr (*X) ;
99 for (k = 0 ; k < n ; k++) x [k] = b [k] ; /* copy x = b */
100 return (x) ;
101 }
102
103 /* get a MATLAB flint array and convert to csi */
cs_mex_get_int(csi n,const mxArray * Imatlab,csi * imax,csi lo)104 csi *cs_mex_get_int (csi n, const mxArray *Imatlab, csi *imax, csi lo)
105 {
106 double *p ;
107 csi i, k, *C = cs_malloc (n, sizeof (csi)) ;
108 cs_mex_check (1, n, 1, 0, 0, 1, Imatlab) ;
109 p = mxGetPr (Imatlab) ;
110 *imax = 0 ;
111 for (k = 0 ; k < n ; k++)
112 {
113 i = p [k] ;
114 C [k] = i - 1 ;
115 if (i < lo) mexErrMsgTxt ("index out of bounds") ;
116 *imax = CS_MAX (*imax, i) ;
117 }
118 return (C) ;
119 }
120
121 /* return an csi array to MATLAB as a flint row vector */
cs_mex_put_int(csi * p,csi n,csi offset,csi do_free)122 mxArray *cs_mex_put_int (csi *p, csi n, csi offset, csi do_free)
123 {
124 mxArray *X = mxCreateDoubleMatrix (1, n, mxREAL) ;
125 double *x = mxGetPr (X) ;
126 csi k ;
127 for (k = 0 ; k < n ; k++) x [k] = (p ? p [k] : k) + offset ;
128 if (do_free) cs_free (p) ;
129 return (X) ;
130 }
131