1 #include "cs_mex.h"
2 /* find an augmenting path starting at column j and extend the match if found */
3 static
augment(CS_INT k,cs_dl * A,CS_INT * jmatch,CS_INT * cheap,CS_INT * w,CS_INT j)4 CS_INT augment (CS_INT k, cs_dl *A, CS_INT *jmatch, CS_INT *cheap, CS_INT *w,
5     CS_INT j)
6 {
7     CS_INT found = 0, p, i = -1, *Ap = A->p, *Ai = A->i ;
8     /* --- Start depth-first-search at node j ------------------------------- */
9     w [j] = k ;                             /* mark j as visited for kth path */
10     for (p = cheap [j] ; p < Ap [j+1] && !found ; p++)
11     {
12         i = Ai [p] ;                        /* try a cheap assignment (i,j) */
13         found = (jmatch [i] == -1) ;
14     }
15     cheap [j] = p ;                         /* start here next time for j */
16     /* --- Depth-first-search of neighbors of j ----------------------------- */
17     for (p = Ap [j] ; p < Ap [j+1] && !found ; p++)
18     {
19         i = Ai [p] ;                        /* consider row i */
20         if (w [jmatch [i]] == k) continue ; /* skip col jmatch [i] if marked */
21         found = augment (k, A, jmatch, cheap, w, jmatch [i]) ;
22     }
23     if (found) jmatch [i] = j ;             /* augment jmatch if path found */
24     return (found) ;
25 }
26 
27 /* find a maximum transveral */
28 static
maxtrans(cs_dl * A)29 CS_INT *maxtrans (cs_dl *A)   /* returns jmatch [0..m-1] */
30 {
31     CS_INT i, j, k, n, m, *Ap, *jmatch, *w, *cheap ;
32     if (!A) return (NULL) ;                         /* check inputs */
33     n = A->n ; m = A->m ; Ap = A->p ;
34     jmatch = cs_dl_malloc (m, sizeof (CS_INT)) ;    /* allocate result */
35     w = cs_dl_malloc (2*n, sizeof (CS_INT)) ;       /* allocate workspace */
36     if (!w || !jmatch) return (cs_dl_idone (jmatch, NULL, w, 0)) ;
37     cheap = w + n ;
38     for (j = 0 ; j < n ; j++) cheap [j] = Ap [j] ;  /* for cheap assignment */
39     for (j = 0 ; j < n ; j++) w [j] = -1 ;          /* all columns unflagged */
40     for (i = 0 ; i < m ; i++) jmatch [i] = -1 ;     /* no rows matched yet */
41     for (k = 0 ; k < n ; k++) augment (k, A, jmatch, cheap, w, k) ;
42     return (cs_dl_idone (jmatch, NULL, w, 1)) ;
43 }
44 
45 /* invert a maximum matching */
invmatch(CS_INT * jmatch,CS_INT m,CS_INT n)46 static CS_INT *invmatch (CS_INT *jmatch, CS_INT m, CS_INT n)
47 {
48     CS_INT i, j, *imatch ;
49     if (!jmatch) return (NULL) ;
50     imatch = cs_dl_malloc (n, sizeof (CS_INT)) ;
51     if (!imatch) return (NULL) ;
52     for (j = 0 ; j < n ; j++) imatch [j] = -1 ;
53     for (i = 0 ; i < m ; i++) if (jmatch [i] >= 0) imatch [jmatch [i]] = i ;
54     return (imatch) ;
55 }
56 
/* MATLAB gateway:  p = cr_maxtransr (A)
 *
 * Computes a maximum matching of the sparse matrix A with the recursive
 * maxtrans above.  Returns a 1-by-n row vector p in which p(j) is the
 * 1-based row matched to column j, or 0 if column j is unmatched.
 *
 * NOTE(review): the usage message names "cr_maxtransr"; confirm this matches
 * the name the MEX file is built and called under.
 */
void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{
    cs_dl *A, Amatrix ;
    double *x ;
    CS_INT j, m, n, *imatch, *jmatch ;

    if (nargout > 1 || nargin != 1)
    {
        mexErrMsgTxt ("Usage: p = cr_maxtransr(A)") ;
    }

    /* get inputs */
    A = cs_dl_mex_get_sparse (&Amatrix, 0, 0, pargin [0]) ;
    m = A->m ;
    n = A->n ;

    jmatch = maxtrans (A) ;             /* jmatch [i] = column of row i */
    imatch = invmatch (jmatch, m, n) ;  /* imatch = inverse of jmatch */
    cs_dl_free (jmatch) ;               /* only imatch is needed from here on */

    /* maxtrans and invmatch both return NULL on allocation failure; report it
     * rather than dereferencing NULL below.  jmatch is already released, so
     * the error exit leaks nothing. */
    if (!imatch)
    {
        mexErrMsgTxt ("out of memory") ;
    }

    pargout [0] = mxCreateDoubleMatrix (1, n, mxREAL) ;
    x = mxGetPr (pargout [0]) ;
    for (j = 0 ; j < n ; j++)
    {
        x [j] = (double) (imatch [j] + 1) ;    /* -1 (unmatched) becomes 0 */
    }

    cs_dl_free (imatch) ;               /* pair with cs_dl_malloc */
}
89