1 //------------------------------------------------------------------------------
2 // GB_mex_apply1: C<Mask> = accum(C,op(x,A)) or op(x,A')
3 //------------------------------------------------------------------------------
4 
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7 
8 //------------------------------------------------------------------------------
9 
// Apply a binary operator to a matrix or vector, binding its first input x to
// a scalar: C<Mask> = accum (C, op (x,A)).
11 
12 #include "GB_mex.h"
13 
#define USAGE "C = GB_mex_apply1 (C, Mask, accum, op, how, x, A, desc)"

// if how == 0: use the GxB_Scalar and GxB_Matrix/Vector_apply_BinaryOp1st
// if how == 1: use the C scalar   and GrB_Matrix/Vector_apply_BinaryOp1st_T

// FREE_ALL: free the matrices C, Mask, S, and A and the descriptor desc, then
// call GB_mx_put_global (true).  The global "scalar" is only an alias of S
// (mexFunction sets scalar = (GxB_Scalar) S), so it is not freed on its own;
// it is reset to NULL so it does not dangle once S has been freed.
#define FREE_ALL                        \
{                                       \
    GrB_Matrix_free_(&C) ;              \
    GrB_Matrix_free_(&Mask) ;           \
    GrB_Matrix_free_(&S) ;              \
    scalar = NULL ;                     \
    GrB_Matrix_free_(&A) ;              \
    GrB_Descriptor_free_(&desc) ;       \
    GB_mx_put_global (true) ;           \
}
28 
29 GrB_Matrix C = NULL, S = NULL ;
30 GxB_Scalar scalar = NULL ;
31 GrB_Matrix Mask = NULL ;
32 GrB_Matrix A = NULL ;
33 GrB_Descriptor desc = NULL ;
34 GrB_BinaryOp accum = NULL ;
35 GrB_BinaryOp op = NULL ;
36 GrB_Info apply1 (bool is_matrix) ;
37 int how = 0 ;
38 
39 //------------------------------------------------------------------------------
40 
apply1(bool is_matrix)41 GrB_Info apply1 (bool is_matrix)
42 {
43     GrB_Info info ;
44     GrB_Type stype ;
45     GxB_Scalar_type (&stype, scalar) ;
46 
47     if (is_matrix && how == 1)
48     {
49         if (stype == GrB_BOOL)
50         {
51             bool x = *((bool *) (scalar->x)) ;
52             info = GrB_Matrix_apply_BinaryOp1st_BOOL_
53                 (C, Mask, accum, op, x, A, desc) ;
54         }
55         else if (stype == GrB_INT8)
56         {
57             int8_t x = *((int8_t *) (scalar->x)) ;
58             info = GrB_Matrix_apply_BinaryOp1st_INT8_
59                 (C, Mask, accum, op, x, A, desc) ;
60         }
61         else if (stype == GrB_INT16)
62         {
63             int16_t x = *((int16_t *) (scalar->x)) ;
64             info = GrB_Matrix_apply_BinaryOp1st_INT16_
65                 (C, Mask, accum, op, x, A, desc) ;
66         }
67         else if (stype == GrB_INT32)
68         {
69             int32_t x = *((int32_t *) (scalar->x)) ;
70             info = GrB_Matrix_apply_BinaryOp1st_INT32_
71                 (C, Mask, accum, op, x, A, desc) ;
72         }
73         else if (stype == GrB_INT64)
74         {
75             int64_t x = *((int64_t *) (scalar->x)) ;
76             info = GrB_Matrix_apply_BinaryOp1st_INT64_
77                 (C, Mask, accum, op, x, A, desc) ;
78         }
79         else if (stype == GrB_UINT8)
80         {
81             uint8_t x = *((uint8_t *) (scalar->x)) ;
82             info = GrB_Matrix_apply_BinaryOp1st_UINT8_
83                 (C, Mask, accum, op, x, A, desc) ;
84         }
85         else if (stype == GrB_UINT16)
86         {
87             uint16_t x = *((uint16_t *) (scalar->x)) ;
88             info = GrB_Matrix_apply_BinaryOp1st_UINT16_
89                 (C, Mask, accum, op, x, A, desc) ;
90         }
91         else if (stype == GrB_UINT32)
92         {
93             uint32_t x = *((uint32_t *) (scalar->x)) ;
94             info = GrB_Matrix_apply_BinaryOp1st_UINT32_
95                 (C, Mask, accum, op, x, A, desc) ;
96         }
97         else if (stype == GrB_UINT64)
98         {
99             uint64_t x = *((uint64_t *) (scalar->x)) ;
100             info = GrB_Matrix_apply_BinaryOp1st_UINT64_
101                 (C, Mask, accum, op, x, A, desc) ;
102         }
103         else if (stype == GrB_FP32)
104         {
105             float x = *((float *) (scalar->x)) ;
106             info = GrB_Matrix_apply_BinaryOp1st_FP32_
107                 (C, Mask, accum, op, x, A, desc) ;
108         }
109         else if (stype == GrB_FP64)
110         {
111             double x = *((double *) (scalar->x)) ;
112             info = GrB_Matrix_apply_BinaryOp1st_FP64_
113                 (C, Mask, accum, op, x, A, desc) ;
114         }
115         else if (stype == GxB_FC32)
116         {
117             GxB_FC32_t x = *((GxB_FC32_t *) (scalar->x)) ;
118             info = GxB_Matrix_apply_BinaryOp1st_FC32_
119                 (C, Mask, accum, op, x, A, desc) ;
120         }
121         else if (stype == GxB_FC64)
122         {
123             GxB_FC64_t x = *((GxB_FC64_t *) (scalar->x)) ;
124             info = GxB_Matrix_apply_BinaryOp1st_FC64_
125                 (C, Mask, accum, op, x, A, desc) ;
126         }
127     }
128     else if (is_matrix && how == 0)
129     {
130         info = GxB_Matrix_apply_BinaryOp1st_
131             (C, Mask, accum, op, scalar, A, desc) ;
132     }
133     else if (!is_matrix && how == 1)
134     {
135         GrB_Vector w = (GrB_Vector) C ;
136         GrB_Vector m = (GrB_Vector) Mask ;
137         GrB_Vector a = (GrB_Vector) A ;
138         if (stype == GrB_BOOL)
139         {
140             bool x = *((bool *) (scalar->x)) ;
141             info = GrB_Vector_apply_BinaryOp1st_BOOL_
142                 (w, Mask, accum, op, x, A, desc) ;
143         }
144         else if (stype == GrB_INT8)
145         {
146             int8_t x = *((int8_t *) (scalar->x)) ;
147             info = GrB_Vector_apply_BinaryOp1st_INT8_
148                 (w, Mask, accum, op, x, A, desc) ;
149         }
150         else if (stype == GrB_INT16)
151         {
152             int16_t x = *((int16_t *) (scalar->x)) ;
153             info = GrB_Vector_apply_BinaryOp1st_INT16_
154                 (w, Mask, accum, op, x, A, desc) ;
155         }
156         else if (stype == GrB_INT32)
157         {
158             int32_t x = *((int32_t *) (scalar->x)) ;
159             info = GrB_Vector_apply_BinaryOp1st_INT32_
160                 (w, Mask, accum, op, x, A, desc) ;
161         }
162         else if (stype == GrB_INT64)
163         {
164             int64_t x = *((int64_t *) (scalar->x)) ;
165             info = GrB_Vector_apply_BinaryOp1st_INT64_
166                 (w, Mask, accum, op, x, A, desc) ;
167         }
168         else if (stype == GrB_UINT8)
169         {
170             uint8_t x = *((uint8_t *) (scalar->x)) ;
171             info = GrB_Vector_apply_BinaryOp1st_UINT8_
172                 (w, Mask, accum, op, x, A, desc) ;
173         }
174         else if (stype == GrB_UINT16)
175         {
176             uint16_t x = *((uint16_t *) (scalar->x)) ;
177             info = GrB_Vector_apply_BinaryOp1st_UINT16_
178                 (w, Mask, accum, op, x, A, desc) ;
179         }
180         else if (stype == GrB_UINT32)
181         {
182             uint32_t x = *((uint32_t *) (scalar->x)) ;
183             info = GrB_Vector_apply_BinaryOp1st_UINT32_
184                 (w, Mask, accum, op, x, A, desc) ;
185         }
186         else if (stype == GrB_UINT64)
187         {
188             uint64_t x = *((uint64_t *) (scalar->x)) ;
189             info = GrB_Vector_apply_BinaryOp1st_UINT64_
190                 (w, Mask, accum, op, x, A, desc) ;
191         }
192         else if (stype == GrB_FP32)
193         {
194             float x = *((float *) (scalar->x)) ;
195             info = GrB_Vector_apply_BinaryOp1st_FP32_
196                 (w, Mask, accum, op, x, A, desc) ;
197         }
198         else if (stype == GrB_FP64)
199         {
200             double x = *((double *) (scalar->x)) ;
201             info = GrB_Vector_apply_BinaryOp1st_FP64_
202                 (w, Mask, accum, op, x, A, desc) ;
203         }
204         else if (stype == GxB_FC32)
205         {
206             GxB_FC32_t x = *((GxB_FC32_t *) (scalar->x)) ;
207             info = GxB_Vector_apply_BinaryOp1st_FC32_
208                 (w, Mask, accum, op, x, A, desc) ;
209         }
210         else if (stype == GxB_FC64)
211         {
212             GxB_FC64_t x = *((GxB_FC64_t *) (scalar->x)) ;
213             info = GxB_Vector_apply_BinaryOp1st_FC64_
214                 (w, Mask, accum, op, x, A, desc) ;
215         }
216     }
217     else if (!is_matrix && how == 0)
218     {
219         GrB_Vector w = (GrB_Vector) C ;
220         GrB_Vector m = (GrB_Vector) Mask ;
221         GrB_Vector a = (GrB_Vector) A ;
222         info = GxB_Vector_apply_BinaryOp1st_
223             (w, m, accum, op, scalar, a, desc) ;
224     }
225 
226     return (info) ;
227 }
228 
229 //------------------------------------------------------------------------------
230 
// mexFunction: MATLAB entry point for
//
//      C = GB_mex_apply1 (C, Mask, accum, op, how, x, A, desc)
//
// It builds C (deep copy), Mask, the 1-by-1 scalar x, and A from the MATLAB
// inputs, and looks up accum, op, and the optional desc. It then computes
// C<Mask> = accum (C, op (x,A)) with apply1, and returns C as pargout [0].
// Every failure path calls FREE_ALL and then mexErrMsgTxt.

void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    bool malloc_debug = GB_mx_get_global (true) ;

    // check inputs
    if (nargout > 1 || nargin < 7 || nargin > 8)
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // get C (make a deep copy).  FREE_DEEP_COPY is not used directly in this
    // function; NOTE(review): presumably the METHOD macro from GB_mex.h uses
    // GET_DEEP_COPY / FREE_DEEP_COPY to rebuild C -- confirm there.
    #define GET_DEEP_COPY \
    C = GB_mx_mxArray_to_Matrix (pargin [0], "C input", true, true) ;
    #define FREE_DEEP_COPY GrB_Matrix_free_(&C) ;
    GET_DEEP_COPY ;
    if (C == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("C failed") ;
    }

    // get Mask (shallow copy); an empty MATLAB array means "no mask"
    Mask = GB_mx_mxArray_to_Matrix (pargin [1], "Mask", false, false) ;
    if (Mask == NULL && !mxIsEmpty (pargin [1]))
    {
        FREE_ALL ;
        mexErrMsgTxt ("Mask failed") ;
    }

    // get how.  0: use GxB_Scalar, 1: use bare C scalar
    GET_SCALAR (4, int, how, 0) ;

    // get scalar (shallow copy)
    S = GB_mx_mxArray_to_Matrix (pargin [5], "scalar input", false, true) ;
    if (S == NULL || S->magic != GB_MAGIC)
    {
        FREE_ALL ;
        mexErrMsgTxt ("scalar failed") ;
    }

    // the scalar must be a 1-by-1 matrix, held by column, with one entry,
    // before it can be viewed as a GxB_Scalar
    GrB_Index snrows, sncols, snvals ;
    GrB_Matrix_nrows (&snrows, S) ;
    GrB_Matrix_ncols (&sncols, S) ;
    GrB_Matrix_nvals (&snvals, S) ;
    GxB_Format_Value fmt ;
    GxB_Matrix_Option_get_(S, GxB_FORMAT, &fmt) ;
    if (snrows != 1 || sncols != 1 || snvals != 1 || fmt != GxB_BY_COL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("scalar failed") ;
    }
    scalar = (GxB_Scalar) S ;
    GrB_Info info = GxB_Scalar_fprint (scalar, "scalar", GxB_SILENT, NULL) ;
    if (info != GrB_SUCCESS)
    {
        FREE_ALL ;
        mexErrMsgTxt ("scalar failed") ;
    }

    // get A (shallow copy)
    A = GB_mx_mxArray_to_Matrix (pargin [6], "A input", false, true) ;
    if (A == NULL || A->magic != GB_MAGIC)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }

    // user_complex is true only if Complex is not the built-in GxB_FC64 type
    // and either C or A has the Complex type
    bool user_complex = (Complex != GxB_FC64)
        && (C->type == Complex || A->type == Complex) ;

    // get accum, if present
    if (!GB_mx_mxArray_to_BinaryOp (&accum, pargin [2], "accum",
        C->type, user_complex))
    {
        FREE_ALL ;
        mexErrMsgTxt ("accum failed") ;
    }

    // get op (required; this driver applies a binary operator)
    if (!GB_mx_mxArray_to_BinaryOp (&op, pargin [3], "op",
        A->type, user_complex) || op == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("op failed") ;
    }

    // get desc
    if (!GB_mx_mxArray_to_Descriptor (&desc, PARGIN (7), "desc"))
    {
        FREE_ALL ;
        mexErrMsgTxt ("desc failed") ;
    }

    // C<Mask> = accum(C,op(x,A))
    if (GB_NCOLS (C) == 1 && (desc == NULL || desc->in1 == GxB_DEFAULT))
    {
        // this is just to test the Vector version
        METHOD (apply1 (false)) ;
    }
    else
    {
        METHOD (apply1 (true)) ;
    }

    // return C to MATLAB as a struct and free the GraphBLAS C
    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C output", true) ;

    FREE_ALL ;
}
345 
346