1 //------------------------------------------------------------------------------
2 // GB_mex_select: C<M> = accum(C,select(A,k)) or select(A',k)
3 //------------------------------------------------------------------------------
4 
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7 
8 //------------------------------------------------------------------------------
9 
10 // Apply a select operator to a matrix
11 
12 #include "GB_mex.h"
13 
14 #define USAGE "C = GB_mex_select (C, M, accum, op, A, Thunk, desc, test)"
15 
// FREE_ALL: release every GraphBLAS object this mexFunction may own (any of
// them may still be NULL), then call GB_mx_put_global (true) last, after all
// of the frees.
#define FREE_ALL                        \
{                                       \
    GrB_Descriptor_free_(&desc) ;       \
    GxB_SelectOp_free_(&isnanop) ;      \
    GrB_Matrix_free_(&A) ;              \
    GrB_Matrix_free_(&M) ;              \
    GrB_Matrix_free_(&C) ;              \
    GxB_Scalar_free_(&Thunk) ;          \
    /* keep this after the frees */     \
    GB_mx_put_global (true) ;           \
}
26 
27 bool isnan64 (GrB_Index i, GrB_Index j, const void *x, const void *b) ;
28 
isnan64(GrB_Index i,GrB_Index j,const void * x,const void * b)29 bool isnan64 (GrB_Index i, GrB_Index j, const void *x, const void *b)
30 {
31     double aij = * ((double *) x) ;
32     return (isnan (aij)) ;
33 }
34 
// mexFunction: MATLAB gateway for GB_mex_select (see USAGE).
//
// Computes C<M> = accum (C, select (A, Thunk)).  GxB_Matrix_select is used,
// except when C has a single column (C->vdim == 1) and desc is NULL or its
// in0 setting is GxB_DEFAULT; then GxB_Vector_select is called instead, to
// test the vector version of the method.
//
// Inputs:
//      pargin [0]  C       matrix to update; a deep copy is made
//      pargin [1]  M       mask (shallow copy); an empty M means no mask
//      pargin [2]  accum   accumulator operator
//      pargin [3]  op      select operator, required.  If
//                          GB_mx_mxArray_to_SelectOp yields a NULL op, the
//                          user-defined isnan64 operator is used and Thunk is
//                          ignored.
//      pargin [4]  A       input matrix (shallow copy)
//      pargin [5]  Thunk   optional: either a sparse MATLAB matrix that must
//                          be a valid GxB_Scalar (GB_SCALAR_OK), or a dense
//                          MATLAB scalar.  An empty dense array means
//                          "no Thunk".
//      pargin [6]  desc    optional descriptor
//      pargin [7]  test    optional: if present, nvec_nonempty of C, M and A
//                          is set to -1 (just for testing)
// Output:
//      pargout [0] the result C, returned to MATLAB as a struct
//
// Every error path calls FREE_ALL before mexErrMsgTxt, so no GraphBLAS object
// is leaked.

void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    // malloc_debug is not read directly below; presumably the METHOD macro
    // (defined outside this file) uses it, together with GET_DEEP_COPY and
    // FREE_DEEP_COPY.
    bool malloc_debug = GB_mx_get_global (true) ;
    GrB_Matrix C = NULL ;
    GrB_Matrix M = NULL ;
    GrB_Matrix A = NULL ;
    GrB_Descriptor desc = NULL ;
    GxB_Scalar Thunk = NULL ;
    GxB_SelectOp isnanop = NULL ;

    // check inputs
    if (nargout > 1 || nargin < 5 || nargin > 8)
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // get C (make a deep copy)
    #define GET_DEEP_COPY \
    C = GB_mx_mxArray_to_Matrix (pargin [0], "C input", true, true) ;   \
    if (nargin > 7 && C != NULL) C->nvec_nonempty = -1 ;
    #define FREE_DEEP_COPY GrB_Matrix_free_(&C) ;
    GET_DEEP_COPY ;
    if (C == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("C failed") ;
    }

    // get M (shallow copy); NULL is only accepted when the input is empty
    M = GB_mx_mxArray_to_Matrix (pargin [1], "M", false, false) ;
    if (M == NULL && !mxIsEmpty (pargin [1]))
    {
        FREE_ALL ;
        mexErrMsgTxt ("M failed") ;
    }

    // get A (shallow copy)
    A = GB_mx_mxArray_to_Matrix (pargin [4], "A input", false, true) ;
    if (A == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }

    // get accum, if present
    bool user_complex = (Complex != GxB_FC64)
        && (C->type == Complex || A->type == Complex) ;
    GrB_BinaryOp accum ;
    if (!GB_mx_mxArray_to_BinaryOp (&accum, pargin [2], "accum",
        C->type, user_complex))
    {
        FREE_ALL ;
        mexErrMsgTxt ("accum failed") ;
    }

    // get select operator; must be present
    GxB_SelectOp op ;
    if (!GB_mx_mxArray_to_SelectOp (&op, pargin [3], "op"))
    {
        FREE_ALL ;
        mexErrMsgTxt ("SelectOp failed") ;
    }

    if (op == NULL)
    {
        // user-defined isnan operator, with no Thunk; check the creation so
        // that a NULL op is never passed on to the select method
        if (GxB_SelectOp_new (&isnanop, isnan64, GrB_FP64, NULL)
            != GrB_SUCCESS)
        {
            FREE_ALL ;
            mexErrMsgTxt ("isnan op failed") ;
        }
        op = isnanop ;
    }
    else if (nargin > 5)
    {
        if (mxIsSparse (pargin [5]))
        {
            // get Thunk (shallow copy) from a sparse MATLAB matrix
            Thunk = (GxB_Scalar) GB_mx_mxArray_to_Matrix (pargin [5],
                "Thunk input", false, false) ;
            if (Thunk == NULL)
            {
                FREE_ALL ;
                mexErrMsgTxt ("Thunk failed") ;
            }
            if (!GB_SCALAR_OK (Thunk))
            {
                FREE_ALL ;
                mexErrMsgTxt ("Thunk not a valid scalar") ;
            }
        }
        else if (mxIsEmpty (pargin [5]))
        {
            // An empty dense array has no entry to read, so dereferencing
            // its mxGet* data pointer would be invalid.  Treat it like an
            // absent Thunk: Thunk stays NULL.
        }
        else
        {
            // get k from a dense MATLAB scalar
            GrB_Type thunk_type = GB_mx_Type (pargin [5]) ;
            if (thunk_type == NULL)
            {
                FREE_ALL ;
                mexErrMsgTxt ("unknown thunk type") ;
            }
            if (GxB_Scalar_new (&Thunk, thunk_type) != GrB_SUCCESS)
            {
                FREE_ALL ;
                mexErrMsgTxt ("Thunk failed") ;
            }
            if (thunk_type == GrB_BOOL)
            {
                bool *p = mxGetData (pargin [5]) ;
                GxB_Scalar_setElement_BOOL_(Thunk, *p) ;
            }
            else if (thunk_type == GrB_INT8)
            {
                int8_t *p = mxGetInt8s (pargin [5]) ;
                GxB_Scalar_setElement_INT8_(Thunk, *p) ;
            }
            else if (thunk_type == GrB_INT16)
            {
                int16_t *p = mxGetInt16s (pargin [5]) ;
                GxB_Scalar_setElement_INT16_(Thunk, *p) ;
            }
            else if (thunk_type == GrB_INT32)
            {
                int32_t *p = mxGetInt32s (pargin [5]) ;
                GxB_Scalar_setElement_INT32_(Thunk, *p) ;
            }
            else if (thunk_type == GrB_INT64)
            {
                int64_t *p = mxGetInt64s (pargin [5]) ;
                GxB_Scalar_setElement_INT64_(Thunk, *p) ;
            }
            else if (thunk_type == GrB_UINT8)
            {
                uint8_t *p = mxGetUint8s (pargin [5]) ;
                GxB_Scalar_setElement_UINT8_(Thunk, *p) ;
            }
            else if (thunk_type == GrB_UINT16)
            {
                uint16_t *p = mxGetUint16s (pargin [5]) ;
                GxB_Scalar_setElement_UINT16_(Thunk, *p) ;
            }
            else if (thunk_type == GrB_UINT32)
            {
                uint32_t *p = mxGetUint32s (pargin [5]) ;
                GxB_Scalar_setElement_UINT32_(Thunk, *p) ;
            }
            else if (thunk_type == GrB_UINT64)
            {
                uint64_t *p = mxGetUint64s (pargin [5]) ;
                GxB_Scalar_setElement_UINT64_(Thunk, *p) ;
            }
            else if (thunk_type == GrB_FP32)
            {
                float *p = mxGetSingles (pargin [5]) ;
                GxB_Scalar_setElement_FP32_(Thunk, *p) ;
            }
            else if (thunk_type == GrB_FP64)
            {
                double *p = mxGetDoubles (pargin [5]) ;
                GxB_Scalar_setElement_FP64_(Thunk, *p) ;
            }
            else if (thunk_type == GxB_FC32)
            {
                GxB_FC32_t *p = mxGetComplexSingles (pargin [5]) ;
                GxB_Scalar_setElement_FC32_(Thunk, *p) ;
            }
            else if (thunk_type == GxB_FC64)
            {
                GxB_FC64_t *p = mxGetComplexDoubles (pargin [5]) ;
                GxB_Scalar_setElement_FC64_(Thunk, *p) ;
            }
            else if (thunk_type == Complex)
            {
                // user-defined complex type: set the value as a UDT
                GxB_FC64_t *p = mxGetComplexDoubles (pargin [5]) ;
                GxB_Scalar_setElement_UDT (Thunk, p) ;
            }
            else
            {
                // free everything (including the new Thunk) before the
                // error, like every other error path in this function
                FREE_ALL ;
                mexErrMsgTxt ("unknown thunk type") ;
            }
            GxB_Scalar_wait_(&Thunk) ;
        }
    }

    // get desc
    if (!GB_mx_mxArray_to_Descriptor (&desc, PARGIN (6), "desc"))
    {
        FREE_ALL ;
        mexErrMsgTxt ("desc failed") ;
    }

    // just for testing
    if (nargin > 7)
    {
        if (M != NULL) M->nvec_nonempty = -1 ;
        A->nvec_nonempty = -1 ;
        C->nvec_nonempty = -1 ;
    }

    // C<M> = accum(C,op(A))
    if (C->vdim == 1 && (desc == NULL || desc->in0 == GxB_DEFAULT))
    {
        // this is just to test the Vector version
        METHOD (GxB_Vector_select_((GrB_Vector) C, (GrB_Vector) M, accum, op,
            (GrB_Vector) A, Thunk, desc)) ; // C
    }
    else
    {
        METHOD (GxB_Matrix_select_(C, M, accum, op, A, Thunk, desc)) ; // C
    }

    // return C to MATLAB as a struct and free the GraphBLAS C
    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C output", true) ;

    FREE_ALL ;
}
244 
245