1 //------------------------------------------------------------------------------
2 // GB_mex_band: C = tril (triu (A,lo), hi), or with A'
3 //------------------------------------------------------------------------------
4 
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7 
8 //------------------------------------------------------------------------------
9 
10 // Apply a select operator to a matrix
11 
12 #include "GB_mex.h"
13 
14 #define USAGE "C = GB_mex_band (A, lo, hi, atranspose)"
15 
// FREE_ALL: free every GraphBLAS object this mexFunction may hold, then call
// GB_mx_put_global (true).  Each handle is released with the free method for
// its own kind of object: Thunk_type is a GrB_Type, so it must be freed with
// GrB_Type_free_ (not GxB_Scalar_free_, which is only for the GxB_Scalar Thunk).
#define FREE_ALL                        \
{                                       \
    GxB_Scalar_free_(&Thunk) ;          \
    GrB_Matrix_free_(&C) ;              \
    GrB_Matrix_free_(&A) ;              \
    GrB_Type_free_(&Thunk_type) ;       \
    GxB_SelectOp_free_(&op) ;           \
    GrB_Descriptor_free_(&desc) ;       \
    GB_mx_put_global (true) ;           \
}
26 
// OK: evaluate a GraphBLAS method, storing its status in the local variable
// info.  Any status other than GrB_SUCCESS frees everything via FREE_ALL and
// reports "GraphBLAS failed" through mexErrMsgTxt.
#define OK(method)                                          \
{                                                           \
    if ((info = (method)) != GrB_SUCCESS)                   \
    {                                                       \
        FREE_ALL ;                                          \
        mexErrMsgTxt ("GraphBLAS failed") ;                 \
    }                                                       \
}
36 
// LoHi_type: the user-defined thunk for the LoHi_band select function.
// An entry in row i and column j is kept when lo <= j-i <= hi.
typedef struct
{
    int64_t lo, hi ;    // inclusive bounds on the diagonal index j-i
}
LoHi_type ;
42 
43 bool LoHi_band (GrB_Index i, GrB_Index j,
44     /* x is unused: */ const void *x, const LoHi_type *thunk) ;
45 
LoHi_band(GrB_Index i,GrB_Index j,const void * x,const LoHi_type * thunk)46 bool LoHi_band (GrB_Index i, GrB_Index j,
47     /* x is unused: */ const void *x, const LoHi_type *thunk)
48 {
49     int64_t i2 = (int64_t) i ;
50     int64_t j2 = (int64_t) j ;
51 //  printf ("i %lld j %lld lo %lld hi %lld\n", i2, j2, thunk->lo, thunk->hi) ;
52 //  printf ("   j-i %lld\n", j2-i2) ;
53     return ((thunk->lo <= (j2-i2)) && ((j2-i2) <= thunk->hi)) ;
54 }
55 
// mexFunction: C = GB_mex_band (A, lo, hi, atranspose)
//
// Keeps the entries of A (or of A' when atranspose is true) whose diagonal
// index j-i lies in [lo, hi], using the user-defined select function
// LoHi_band with a user-defined Thunk holding lo and hi.  C is created as a
// GrB_FP64 matrix with the dimensions of A (or A') and returned to MATLAB.
void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    bool malloc_debug = GB_mx_get_global (true) ;
    GrB_Matrix C = NULL ;
    GrB_Matrix A = NULL ;
    GxB_SelectOp op = NULL ;
    GrB_Info info ;
    GrB_Descriptor desc = NULL ;
    GxB_Scalar Thunk = NULL ;
    GrB_Type Thunk_type = NULL ;

    // no deep copy is needed here; these hooks are presumably referenced by
    // the METHOD macro from GB_mex.h (NOTE(review): confirm)
    #define GET_DEEP_COPY ;
    #define FREE_DEEP_COPY ;

    // check inputs
    if (nargout > 1 || nargin < 3 || nargin > 4)
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // get A (shallow copy)
    A = GB_mx_mxArray_to_Matrix (pargin [0], "A input", false, true) ;
    if (A == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }

    // create the Thunk type: a user-defined type holding lo and hi
    LoHi_type bandwidth ;
    OK (GrB_Type_new (&Thunk_type, sizeof (LoHi_type))) ;

    // get lo and hi
    bandwidth.lo = (int64_t) mxGetScalar (pargin [1]) ;
    bandwidth.hi = (int64_t) mxGetScalar (pargin [2]) ;

    // the Thunk scalar holds a copy of the bandwidth
    OK (GxB_Scalar_new (&Thunk, Thunk_type)) ;
    OK (GxB_Scalar_setElement_UDT (Thunk, (void *) &bandwidth)) ;
    OK (GxB_Scalar_wait_(&Thunk)) ;

    // get atranspose; if true, the descriptor transposes the input A
    bool atranspose = false ;
    if (nargin > 3) atranspose = (bool) mxGetScalar (pargin [3]) ;
    if (atranspose)
    {
        OK (GrB_Descriptor_new (&desc)) ;
        OK (GxB_Desc_set (desc, GrB_INP0, GrB_TRAN)) ;
    }

    GB_MEX_TIC ;

    // create operator
    // use the user-defined operator, from the LoHi_band function
    METHOD (GxB_SelectOp_new (&op, (GxB_select_function) LoHi_band,
        NULL, Thunk_type)) ;

    // query the dimensions of A through the API, and check the results.
    // The internal A->vlen and A->vdim are the vector length and count; they
    // match nrows and ncols only when A is held by column.
    GrB_Index nrows, ncols ;
    OK (GrB_Matrix_nrows (&nrows, A)) ;
    OK (GrB_Matrix_ncols (&ncols, A)) ;
    if (bandwidth.lo == 0 && bandwidth.hi == 0 && nrows == 10 && ncols == 10)
    {
        // print the operator for this one small test case
        GxB_SelectOp_fprint_ (op, 3, NULL) ;
    }

    // create result matrix C, with the dimensions of A' or A
    if (atranspose)
    {
        OK (GrB_Matrix_new (&C, GrB_FP64, ncols, nrows)) ;
    }
    else
    {
        OK (GrB_Matrix_new (&C, GrB_FP64, nrows, ncols)) ;
    }

    // C<Mask> = accum(C,op(A)), here with no Mask and no accum
    if (GB_NCOLS (C) == 1 && !atranspose)
    {
        // this is just to test the Vector version
        OK (GxB_Vector_select_((GrB_Vector) C, NULL, NULL, op, (GrB_Vector) A,
            Thunk, NULL)) ;
    }
    else
    {
        OK (GxB_Matrix_select_(C, NULL, NULL, op, A, Thunk, desc)) ;
    }

    GB_MEX_TOC ;

    // return C to MATLAB as a sparse matrix and free the GraphBLAS C
    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C output", false) ;

    FREE_ALL ;
}
156 
157