1 //------------------------------------------------------------------------------
2 // GB_mex_band: C = tril (triu (A,lo), hi), or with A'
3 //------------------------------------------------------------------------------
4
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7
8 //------------------------------------------------------------------------------
9
10 // Apply a select operator to a matrix
11
12 #include "GB_mex.h"
13
#define USAGE "C = GB_mex_band (A, lo, hi, atranspose)"

// FREE_ALL: free every GraphBLAS object this mexFunction may have created,
// then call GB_mx_put_global (true).  It expects the handles Thunk, C, A,
// Thunk_type, op and desc to be locals in scope at the expansion site.  It is
// used on error paths where some handles are still NULL, so each free method
// is presumably a no-op on a NULL handle.
//
// Each handle is released with the free method of its own type.  In
// particular, Thunk_type is a GrB_Type and must not be passed to
// GxB_Scalar_free_, which expects a GxB_Scalar handle.
#define FREE_ALL                                \
do                                              \
{                                               \
    GxB_Scalar_free_(&Thunk) ;                  \
    GrB_Matrix_free_(&C) ;                      \
    GrB_Matrix_free_(&A) ;                      \
    GrB_Type_free_(&Thunk_type) ;               \
    GxB_SelectOp_free_(&op) ;                   \
    GrB_Descriptor_free_(&desc) ;               \
    GB_mx_put_global (true) ;                   \
}                                               \
while (0)
26
// OK (method): evaluate a GraphBLAS call and store its result in the local
// variable `info`, which must be in scope.  If the result is anything other
// than GrB_SUCCESS, release everything with FREE_ALL and then report
// "GraphBLAS failed" through mexErrMsgTxt.
#define OK(method)                              \
do                                              \
{                                               \
    info = (method) ;                           \
    if (info == GrB_SUCCESS) break ;            \
    FREE_ALL ;                                  \
    mexErrMsgTxt ("GraphBLAS failed") ;         \
}                                               \
while (0)
36
37 typedef struct
38 {
39 int64_t lo ;
40 int64_t hi ;
41 } LoHi_type ;
42
43 bool LoHi_band (GrB_Index i, GrB_Index j,
44 /* x is unused: */ const void *x, const LoHi_type *thunk) ;
45
LoHi_band(GrB_Index i,GrB_Index j,const void * x,const LoHi_type * thunk)46 bool LoHi_band (GrB_Index i, GrB_Index j,
47 /* x is unused: */ const void *x, const LoHi_type *thunk)
48 {
49 int64_t i2 = (int64_t) i ;
50 int64_t j2 = (int64_t) j ;
51 // printf ("i %lld j %lld lo %lld hi %lld\n", i2, j2, thunk->lo, thunk->hi) ;
52 // printf (" j-i %lld\n", j2-i2) ;
53 return ((thunk->lo <= (j2-i2)) && ((j2-i2) <= thunk->hi)) ;
54 }
55
// mexFunction: C = GB_mex_band (A, lo, hi, atranspose)
//
// Applies the user-defined select operator LoHi_band to A (or to A' if
// atranspose is true).  It keeps the entries A(i,j) with lo <= j-i <= hi and
// returns the GrB_FP64 result C to MATLAB in pargout [0].
//
// pargin [0]: the input matrix A (taken as a shallow copy)
// pargin [1]: lo, the lowest diagonal to keep
// pargin [2]: hi, the highest diagonal to keep
// pargin [3]: optional; if nonzero, select from A' instead of A

void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    bool malloc_debug = GB_mx_get_global (true) ;
    GrB_Matrix C = NULL ;
    GrB_Matrix A = NULL ;
    GxB_SelectOp op = NULL ;
    GrB_Info info ;
    GrB_Descriptor desc = NULL ;
    GxB_Scalar Thunk = NULL ;
    GrB_Type Thunk_type = NULL ;

    // Nothing needs a deep copy here.  These empty hooks presumably exist
    // because the METHOD macro (GB_mex.h) expands them -- TODO confirm.
    #define GET_DEEP_COPY ;
    #define FREE_DEEP_COPY ;

    // check inputs
    if (nargout > 1 || nargin < 3 || nargin > 4)
    {
        // release the global state acquired above, as the other error path
        // in this function does (all object handles are still NULL here)
        FREE_ALL ;
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // get A (shallow copy)
    A = GB_mx_mxArray_to_Matrix (pargin [0], "A input", false, true) ;
    if (A == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }

    // create the Thunk type: a user-defined type that holds the band limits
    LoHi_type bandwidth ;
    OK (GrB_Type_new (&Thunk_type, sizeof (LoHi_type))) ;

    // get lo and hi
    bandwidth.lo = (int64_t) mxGetScalar (pargin [1]) ;
    bandwidth.hi = (int64_t) mxGetScalar (pargin [2]) ;

    // Thunk: a scalar of type Thunk_type holding {lo, hi}
    OK (GxB_Scalar_new (&Thunk, Thunk_type)) ;
    OK (GxB_Scalar_setElement_UDT (Thunk, (void *) &bandwidth)) ;
    OK (GxB_Scalar_wait_(&Thunk)) ;

    // get atranspose; if set, a descriptor transposes the input A
    bool atranspose = false ;
    if (nargin > 3) atranspose = (bool) mxGetScalar (pargin [3]) ;
    if (atranspose)
    {
        OK (GrB_Descriptor_new (&desc)) ;
        OK (GxB_Desc_set (desc, GrB_INP0, GrB_TRAN)) ;
    }

    GB_MEX_TIC ;

    // create operator
    // use the user-defined operator, from the LoHi_band function
    METHOD (GxB_SelectOp_new (&op, (GxB_select_function) LoHi_band,
        NULL, Thunk_type)) ;

    // get the dimensions of A
    GrB_Index nrows, ncols ;
    OK (GrB_Matrix_nrows (&nrows, A)) ;
    OK (GrB_Matrix_ncols (&ncols, A)) ;

    // print the operator, but only for one specific case (lo = hi = 0 on a
    // 10-by-10 matrix) so the output appears once rather than on every call
    if (bandwidth.lo == 0 && bandwidth.hi == 0 && nrows == 10 && ncols == 10)
    {
        GxB_SelectOp_fprint_ (op, GxB_COMPLETE, NULL) ;
    }

    // Create the result matrix C with the dimensions of A, or of A' if
    // transposed.  The logical nrows/ncols are used rather than the internal
    // A->vlen/A->vdim, which equal nrows/ncols only when A is held by column.
    if (atranspose)
    {
        OK (GrB_Matrix_new (&C, GrB_FP64, ncols, nrows)) ;
    }
    else
    {
        OK (GrB_Matrix_new (&C, GrB_FP64, nrows, ncols)) ;
    }

    // C = select (A) or C = select (A'), with no mask and no accumulator
    if (!atranspose && ncols == 1)
    {
        // this is just to test the Vector version
        OK (GxB_Vector_select_((GrB_Vector) C, NULL, NULL, op, (GrB_Vector) A,
            Thunk, NULL)) ;
    }
    else
    {
        OK (GxB_Matrix_select_(C, NULL, NULL, op, A, Thunk, desc)) ;
    }

    GB_MEX_TOC ;

    // return C to MATLAB as a sparse matrix and free the GraphBLAS C
    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C output", false) ;

    FREE_ALL ;
}
156
157