1 //------------------------------------------------------------------------------
2 // GB_mex_apply1: C<Mask> = accum(C,op(x,A)) or op(x,A')
3 //------------------------------------------------------------------------------
4
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7
8 //------------------------------------------------------------------------------
9
// Apply a binary operator to a matrix or vector, binding x to a scalar.
11
12 #include "GB_mex.h"
13
14 #define USAGE "C = GB_mex_apply1 (C, Mask, accum, op, how, x, A, desc)"
15
16 // if how == 0: use the GxB_Scalar and GxB_Matrix/Vector_apply_BinaryOp1st
17 // if how == 1: use the C scalar and GrB_Matrix/Vector_apply_BinaryOp1st_T
18
// FREE_ALL: release every GraphBLAS object held in this file's globals, then
// call GB_mx_put_global.  The global `scalar` is only a typecast alias of S
// (it is never freed on its own), so it is reset to NULL once S has been freed
// to avoid leaving a dangling pointer behind between calls.
#define FREE_ALL                        \
{                                       \
    GrB_Matrix_free_(&C) ;              \
    GrB_Matrix_free_(&Mask) ;           \
    GrB_Matrix_free_(&S) ;              \
    scalar = NULL ;                     \
    GrB_Matrix_free_(&A) ;              \
    GrB_Descriptor_free_(&desc) ;       \
    GB_mx_put_global (true) ;           \
}
28
29 GrB_Matrix C = NULL, S = NULL ;
30 GxB_Scalar scalar = NULL ;
31 GrB_Matrix Mask = NULL ;
32 GrB_Matrix A = NULL ;
33 GrB_Descriptor desc = NULL ;
34 GrB_BinaryOp accum = NULL ;
35 GrB_BinaryOp op = NULL ;
36 GrB_Info apply1 (bool is_matrix) ;
37 int how = 0 ;
38
39 //------------------------------------------------------------------------------
40
apply1(bool is_matrix)41 GrB_Info apply1 (bool is_matrix)
42 {
43 GrB_Info info ;
44 GrB_Type stype ;
45 GxB_Scalar_type (&stype, scalar) ;
46
47 if (is_matrix && how == 1)
48 {
49 if (stype == GrB_BOOL)
50 {
51 bool x = *((bool *) (scalar->x)) ;
52 info = GrB_Matrix_apply_BinaryOp1st_BOOL_
53 (C, Mask, accum, op, x, A, desc) ;
54 }
55 else if (stype == GrB_INT8)
56 {
57 int8_t x = *((int8_t *) (scalar->x)) ;
58 info = GrB_Matrix_apply_BinaryOp1st_INT8_
59 (C, Mask, accum, op, x, A, desc) ;
60 }
61 else if (stype == GrB_INT16)
62 {
63 int16_t x = *((int16_t *) (scalar->x)) ;
64 info = GrB_Matrix_apply_BinaryOp1st_INT16_
65 (C, Mask, accum, op, x, A, desc) ;
66 }
67 else if (stype == GrB_INT32)
68 {
69 int32_t x = *((int32_t *) (scalar->x)) ;
70 info = GrB_Matrix_apply_BinaryOp1st_INT32_
71 (C, Mask, accum, op, x, A, desc) ;
72 }
73 else if (stype == GrB_INT64)
74 {
75 int64_t x = *((int64_t *) (scalar->x)) ;
76 info = GrB_Matrix_apply_BinaryOp1st_INT64_
77 (C, Mask, accum, op, x, A, desc) ;
78 }
79 else if (stype == GrB_UINT8)
80 {
81 uint8_t x = *((uint8_t *) (scalar->x)) ;
82 info = GrB_Matrix_apply_BinaryOp1st_UINT8_
83 (C, Mask, accum, op, x, A, desc) ;
84 }
85 else if (stype == GrB_UINT16)
86 {
87 uint16_t x = *((uint16_t *) (scalar->x)) ;
88 info = GrB_Matrix_apply_BinaryOp1st_UINT16_
89 (C, Mask, accum, op, x, A, desc) ;
90 }
91 else if (stype == GrB_UINT32)
92 {
93 uint32_t x = *((uint32_t *) (scalar->x)) ;
94 info = GrB_Matrix_apply_BinaryOp1st_UINT32_
95 (C, Mask, accum, op, x, A, desc) ;
96 }
97 else if (stype == GrB_UINT64)
98 {
99 uint64_t x = *((uint64_t *) (scalar->x)) ;
100 info = GrB_Matrix_apply_BinaryOp1st_UINT64_
101 (C, Mask, accum, op, x, A, desc) ;
102 }
103 else if (stype == GrB_FP32)
104 {
105 float x = *((float *) (scalar->x)) ;
106 info = GrB_Matrix_apply_BinaryOp1st_FP32_
107 (C, Mask, accum, op, x, A, desc) ;
108 }
109 else if (stype == GrB_FP64)
110 {
111 double x = *((double *) (scalar->x)) ;
112 info = GrB_Matrix_apply_BinaryOp1st_FP64_
113 (C, Mask, accum, op, x, A, desc) ;
114 }
115 else if (stype == GxB_FC32)
116 {
117 GxB_FC32_t x = *((GxB_FC32_t *) (scalar->x)) ;
118 info = GxB_Matrix_apply_BinaryOp1st_FC32_
119 (C, Mask, accum, op, x, A, desc) ;
120 }
121 else if (stype == GxB_FC64)
122 {
123 GxB_FC64_t x = *((GxB_FC64_t *) (scalar->x)) ;
124 info = GxB_Matrix_apply_BinaryOp1st_FC64_
125 (C, Mask, accum, op, x, A, desc) ;
126 }
127 }
128 else if (is_matrix && how == 0)
129 {
130 info = GxB_Matrix_apply_BinaryOp1st_
131 (C, Mask, accum, op, scalar, A, desc) ;
132 }
133 else if (!is_matrix && how == 1)
134 {
135 GrB_Vector w = (GrB_Vector) C ;
136 GrB_Vector m = (GrB_Vector) Mask ;
137 GrB_Vector a = (GrB_Vector) A ;
138 if (stype == GrB_BOOL)
139 {
140 bool x = *((bool *) (scalar->x)) ;
141 info = GrB_Vector_apply_BinaryOp1st_BOOL_
142 (w, Mask, accum, op, x, A, desc) ;
143 }
144 else if (stype == GrB_INT8)
145 {
146 int8_t x = *((int8_t *) (scalar->x)) ;
147 info = GrB_Vector_apply_BinaryOp1st_INT8_
148 (w, Mask, accum, op, x, A, desc) ;
149 }
150 else if (stype == GrB_INT16)
151 {
152 int16_t x = *((int16_t *) (scalar->x)) ;
153 info = GrB_Vector_apply_BinaryOp1st_INT16_
154 (w, Mask, accum, op, x, A, desc) ;
155 }
156 else if (stype == GrB_INT32)
157 {
158 int32_t x = *((int32_t *) (scalar->x)) ;
159 info = GrB_Vector_apply_BinaryOp1st_INT32_
160 (w, Mask, accum, op, x, A, desc) ;
161 }
162 else if (stype == GrB_INT64)
163 {
164 int64_t x = *((int64_t *) (scalar->x)) ;
165 info = GrB_Vector_apply_BinaryOp1st_INT64_
166 (w, Mask, accum, op, x, A, desc) ;
167 }
168 else if (stype == GrB_UINT8)
169 {
170 uint8_t x = *((uint8_t *) (scalar->x)) ;
171 info = GrB_Vector_apply_BinaryOp1st_UINT8_
172 (w, Mask, accum, op, x, A, desc) ;
173 }
174 else if (stype == GrB_UINT16)
175 {
176 uint16_t x = *((uint16_t *) (scalar->x)) ;
177 info = GrB_Vector_apply_BinaryOp1st_UINT16_
178 (w, Mask, accum, op, x, A, desc) ;
179 }
180 else if (stype == GrB_UINT32)
181 {
182 uint32_t x = *((uint32_t *) (scalar->x)) ;
183 info = GrB_Vector_apply_BinaryOp1st_UINT32_
184 (w, Mask, accum, op, x, A, desc) ;
185 }
186 else if (stype == GrB_UINT64)
187 {
188 uint64_t x = *((uint64_t *) (scalar->x)) ;
189 info = GrB_Vector_apply_BinaryOp1st_UINT64_
190 (w, Mask, accum, op, x, A, desc) ;
191 }
192 else if (stype == GrB_FP32)
193 {
194 float x = *((float *) (scalar->x)) ;
195 info = GrB_Vector_apply_BinaryOp1st_FP32_
196 (w, Mask, accum, op, x, A, desc) ;
197 }
198 else if (stype == GrB_FP64)
199 {
200 double x = *((double *) (scalar->x)) ;
201 info = GrB_Vector_apply_BinaryOp1st_FP64_
202 (w, Mask, accum, op, x, A, desc) ;
203 }
204 else if (stype == GxB_FC32)
205 {
206 GxB_FC32_t x = *((GxB_FC32_t *) (scalar->x)) ;
207 info = GxB_Vector_apply_BinaryOp1st_FC32_
208 (w, Mask, accum, op, x, A, desc) ;
209 }
210 else if (stype == GxB_FC64)
211 {
212 GxB_FC64_t x = *((GxB_FC64_t *) (scalar->x)) ;
213 info = GxB_Vector_apply_BinaryOp1st_FC64_
214 (w, Mask, accum, op, x, A, desc) ;
215 }
216 }
217 else if (!is_matrix && how == 0)
218 {
219 GrB_Vector w = (GrB_Vector) C ;
220 GrB_Vector m = (GrB_Vector) Mask ;
221 GrB_Vector a = (GrB_Vector) A ;
222 info = GxB_Vector_apply_BinaryOp1st_
223 (w, m, accum, op, scalar, a, desc) ;
224 }
225
226 return (info) ;
227 }
228
229 //------------------------------------------------------------------------------
230
// mexFunction: MATLAB entry point for
//
//      C = GB_mex_apply1 (C, Mask, accum, op, how, x, A, desc)
//
// pargin [0]: C, copied deeply since it is modified and returned
// pargin [1]: Mask, may be empty
// pargin [2]: accum, the optional accumulator
// pargin [3]: op, the binary operator (required)
// pargin [4]: how, 0 to use the GxB_Scalar methods, 1 to use a bare C scalar
// pargin [5]: x, the scalar; must be 1-by-1 with one entry, held by column
// pargin [6]: A, the input matrix
// pargin [7]: desc, the optional descriptor
//
// On any failure, all objects are freed with FREE_ALL and the error is
// reported to MATLAB via mexErrMsgTxt.

void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    // NOTE(review): malloc_debug is not used directly below; presumably it is
    // read by the METHOD macro, which is defined outside this file.
    bool malloc_debug = GB_mx_get_global (true) ;

    // check inputs
    if (nargout > 1 || nargin < 7 || nargin > 8)
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // get C (make a deep copy)
    #define GET_DEEP_COPY \
    C = GB_mx_mxArray_to_Matrix (pargin [0], "C input", true, true) ;
    #define FREE_DEEP_COPY GrB_Matrix_free_(&C) ;
    GET_DEEP_COPY ;
    if (C == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("C failed") ;
    }

    // get Mask (shallow copy); an empty pargin [1] means no mask
    Mask = GB_mx_mxArray_to_Matrix (pargin [1], "Mask", false, false) ;
    if (Mask == NULL && !mxIsEmpty (pargin [1]))
    {
        FREE_ALL ;
        mexErrMsgTxt ("Mask failed") ;
    }

    // get how. 0: use GxB_Scalar, 1: use bare C scalar
    GET_SCALAR (4, int, how, 0) ;

    // get scalar (shallow copy)
    S = GB_mx_mxArray_to_Matrix (pargin [5], "scalar input", false, true) ;
    if (S == NULL || S->magic != GB_MAGIC)
    {
        FREE_ALL ;
        mexErrMsgTxt ("scalar failed") ;
    }

    // S must be 1-by-1 with a single entry, held by column, to be typecast
    // to a GxB_Scalar.  The results start at values that fail this test, so a
    // query that fails cannot leave an indeterminate value to be tested.
    GrB_Index snrows = 0, sncols = 0, snvals = 0 ;
    GrB_Matrix_nrows (&snrows, S) ;
    GrB_Matrix_ncols (&sncols, S) ;
    GrB_Matrix_nvals (&snvals, S) ;
    GxB_Format_Value fmt = GxB_BY_ROW ;
    GxB_Matrix_Option_get_(S, GxB_FORMAT, &fmt) ;
    if (snrows != 1 || sncols != 1 || snvals != 1 || fmt != GxB_BY_COL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("scalar failed") ;
    }
    scalar = (GxB_Scalar) S ;
    GrB_Info info = GxB_Scalar_fprint (scalar, "scalar", GxB_SILENT, NULL) ;
    if (info != GrB_SUCCESS)
    {
        FREE_ALL ;
        mexErrMsgTxt ("scalar failed") ;
    }

    // get A (shallow copy)
    A = GB_mx_mxArray_to_Matrix (pargin [6], "A input", false, true) ;
    if (A == NULL || A->magic != GB_MAGIC)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }

    // get accum, if present
    bool user_complex = (Complex != GxB_FC64)
        && (C->type == Complex || A->type == Complex) ;
    if (!GB_mx_mxArray_to_BinaryOp (&accum, pargin [2], "accum",
        C->type, user_complex))
    {
        FREE_ALL ;
        mexErrMsgTxt ("accum failed") ;
    }

    // get op: a binary operator, which must be present
    if (!GB_mx_mxArray_to_BinaryOp (&op, pargin [3], "op",
        A->type, user_complex) || op == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("BinaryOp failed") ;
    }

    // get desc
    if (!GB_mx_mxArray_to_Descriptor (&desc, PARGIN (7), "desc"))
    {
        FREE_ALL ;
        mexErrMsgTxt ("desc failed") ;
    }

    // C<Mask> = accum(C,op(x,A))
    if (GB_NCOLS (C) == 1 && (desc == NULL || desc->in1 == GxB_DEFAULT))
    {
        // this is just to test the Vector version
        METHOD (apply1 (false)) ;
    }
    else
    {
        METHOD (apply1 (true)) ;
    }

    // return C to MATLAB as a struct and free the GraphBLAS C
    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C output", true) ;

    FREE_ALL ;
}
345
346