1 #include "cs_mex.h"
2 /* find an augmenting path starting at column j and extend the match if found */
3 static
augment(CS_INT k,cs_dl * A,CS_INT * jmatch,CS_INT * cheap,CS_INT * w,CS_INT j)4 CS_INT augment (CS_INT k, cs_dl *A, CS_INT *jmatch, CS_INT *cheap, CS_INT *w,
5 CS_INT j)
6 {
7 CS_INT found = 0, p, i = -1, *Ap = A->p, *Ai = A->i ;
8 /* --- Start depth-first-search at node j ------------------------------- */
9 w [j] = k ; /* mark j as visited for kth path */
10 for (p = cheap [j] ; p < Ap [j+1] && !found ; p++)
11 {
12 i = Ai [p] ; /* try a cheap assignment (i,j) */
13 found = (jmatch [i] == -1) ;
14 }
15 cheap [j] = p ; /* start here next time for j */
16 /* --- Depth-first-search of neighbors of j ----------------------------- */
17 for (p = Ap [j] ; p < Ap [j+1] && !found ; p++)
18 {
19 i = Ai [p] ; /* consider row i */
20 if (w [jmatch [i]] == k) continue ; /* skip col jmatch [i] if marked */
21 found = augment (k, A, jmatch, cheap, w, jmatch [i]) ;
22 }
23 if (found) jmatch [i] = j ; /* augment jmatch if path found */
24 return (found) ;
25 }
26
27 /* find a maximum transveral */
28 static
maxtrans(cs_dl * A)29 CS_INT *maxtrans (cs_dl *A) /* returns jmatch [0..m-1] */
30 {
31 CS_INT i, j, k, n, m, *Ap, *jmatch, *w, *cheap ;
32 if (!A) return (NULL) ; /* check inputs */
33 n = A->n ; m = A->m ; Ap = A->p ;
34 jmatch = cs_dl_malloc (m, sizeof (CS_INT)) ; /* allocate result */
35 w = cs_dl_malloc (2*n, sizeof (CS_INT)) ; /* allocate workspace */
36 if (!w || !jmatch) return (cs_dl_idone (jmatch, NULL, w, 0)) ;
37 cheap = w + n ;
38 for (j = 0 ; j < n ; j++) cheap [j] = Ap [j] ; /* for cheap assignment */
39 for (j = 0 ; j < n ; j++) w [j] = -1 ; /* all columns unflagged */
40 for (i = 0 ; i < m ; i++) jmatch [i] = -1 ; /* no rows matched yet */
41 for (k = 0 ; k < n ; k++) augment (k, A, jmatch, cheap, w, k) ;
42 return (cs_dl_idone (jmatch, NULL, w, 1)) ;
43 }
44
45 /* invert a maximum matching */
invmatch(CS_INT * jmatch,CS_INT m,CS_INT n)46 static CS_INT *invmatch (CS_INT *jmatch, CS_INT m, CS_INT n)
47 {
48 CS_INT i, j, *imatch ;
49 if (!jmatch) return (NULL) ;
50 imatch = cs_dl_malloc (n, sizeof (CS_INT)) ;
51 if (!imatch) return (NULL) ;
52 for (j = 0 ; j < n ; j++) imatch [j] = -1 ;
53 for (i = 0 ; i < m ; i++) if (jmatch [i] >= 0) imatch [jmatch [i]] = i ;
54 return (imatch) ;
55 }
56
/* MATLAB gateway: p = cs_maxtransr (A)
 *
 * Computes a maximum transversal of the sparse matrix A with the recursive
 * augmenting-path search above.  Returns a 1-by-n row vector p with
 * p (j) = i (1-based) if row i is matched to column j, and p (j) = 0 if
 * column j is unmatched.
 */
void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{
    cs_dl *A, Amatrix ;
    double *x ;
    CS_INT j, m, n, *imatch, *jmatch ;

    if (nargout > 1 || nargin != 1)
    {
        mexErrMsgTxt ("Usage: p = cs_maxtransr(A)") ;
    }

    /* get inputs (pattern only; numerical values are not needed) */
    A = cs_dl_mex_get_sparse (&Amatrix, 0, 0, pargin [0]) ;
    m = A->m ;
    n = A->n ;

    jmatch = maxtrans (A) ;                 /* jmatch [i] = column of row i */
    imatch = invmatch (jmatch, m, n) ;      /* imatch = inverse of jmatch */
    cs_dl_free (jmatch) ;                   /* no longer needed */
    if (imatch == NULL)
    {
        /* maxtrans or invmatch failed to allocate; never dereference NULL */
        mexErrMsgTxt ("cs_maxtransr: out of memory") ;
    }

    pargout [0] = mxCreateDoubleMatrix (1, n, mxREAL) ;
    x = mxGetPr (pargout [0]) ;
    for (j = 0 ; j < n ; j++)
    {
        /* shift to 1-based; unmatched columns (-1) become 0 */
        x [j] = (double) (imatch [j] + 1) ;
    }

    cs_dl_free (imatch) ;
}