1 //------------------------------------------------------------------------------
2 // GB_mex_assign: C<Mask>(I,J) = accum (C (I,J), A)
3 //------------------------------------------------------------------------------
4
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7
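// This mexFunction is a wrapper for GrB_Matrix_assign, GrB_Matrix_assign_T,
// GrB_Vector_assign, and GrB_Vector_assign_T.  For these uses, the Mask must
// always be the same size as C.  In some scalar cases it also tests
// GrB_Matrix_setElement.  It can also be called as C = GB_mex_assign (C, Work),
// where Work is a struct array describing a sequence of assignments, each with
// fields A, I, J, and optionally Mask, accum, and desc.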
11
12 // This mexFunction does not call GrB_Row_assign or GrB_Col_assign, since
13 // the Mask is a single row or column in these cases, and C is not modified
14 // outside that single row (for GrB_Row_assign) or column (for GrB_Col_assign).
15
16 // This function does the same thing as the MATLAB mimic GB_spec_assign.m.
17
18 //------------------------------------------------------------------------------
19
20 #include "GB_mex.h"
21
// Usage message reported by mexErrMsgTxt when the argument list is invalid.
#define USAGE "C =GB_mex_assign (C, Mask, accum, A, I, J, desc) or (C, Work)"

// FREE_ALL: free the matrices and descriptor held in the file-scope globals
// below, then call GB_mx_put_global (the counterpart of the GB_mx_get_global
// call at the start of mexFunction).  Used on error paths and before exit.
#define FREE_ALL                        \
{                                       \
    GrB_Matrix_free_(&A) ;              \
    GrB_Matrix_free_(&Mask) ;           \
    GrB_Matrix_free_(&C) ;              \
    GrB_Descriptor_free_(&desc) ;       \
    GB_mx_put_global (true) ;           \
}

// GET_DEEP_COPY: build the global C as a deep copy of the MATLAB input
// pargin [0]; FREE_DEEP_COPY releases it.  NOTE(review): FREE_DEEP_COPY is not
// used in this file, so presumably both are expanded by the METHOD macro in
// GB_mex.h (e.g. to rebuild C between malloc-debug trials) -- confirm there.
#define GET_DEEP_COPY \
    C = GB_mx_mxArray_to_Matrix (pargin [0], "C input", true, true) ;

#define FREE_DEEP_COPY GrB_Matrix_free_(&C) ;

// File-scope state shared by mexFunction, many_assign, and assign.  The
// operands of the current assignment live here rather than being passed as
// arguments (assign takes no parameters).
GrB_Matrix C = NULL ;           // the matrix being modified
GrB_Matrix Mask = NULL ;        // optional mask (NULL if none)
GrB_Matrix A = NULL ;           // the matrix (or 1-by-1 scalar) assigned
GrB_Descriptor desc = NULL ;    // optional descriptor
GrB_BinaryOp accum = NULL ;     // optional accumulator
// I and J are the index lists with lengths ni and nj.  I_range and J_range
// are scratch space passed to GB_mx_mxArray_to_indices (presumably for
// colon-style ranges -- confirm in that helper).
GrB_Index *I = NULL, ni = 0, I_range [3] ;
GrB_Index *J = NULL, nj = 0, J_range [3] ;
bool ignore ;                   // output of GB_mx_mxArray_to_indices, unused
bool malloc_debug = false ;     // set from GB_mx_get_global in mexFunction
GrB_Info info = GrB_SUCCESS ;
GrB_Info assign (void) ;

GrB_Info many_assign
(
    int nwork,
    int fA,
    int fI,
    int fJ,
    int faccum,
    int fMask,
    int fdesc,
    const mxArray *pargin [ ]
) ;

//------------------------------------------------------------------------------
// assign: perform a single assignment
//------------------------------------------------------------------------------

// OK(method): evaluate a GraphBLAS call, store its status in whichever `info`
// is in scope, and return that status from the enclosing function (which must
// return GrB_Info) if it is not GrB_SUCCESS.
#define OK(method)                      \
{                                       \
    info = method ;                     \
    if (info != GrB_SUCCESS)            \
    {                                   \
        return (info) ;                 \
    }                                   \
}
74
// assign: perform one assignment C<Mask>(I,J) = accum (C(I,J), A) using the
// file-scope globals C, Mask, accum, A, I, ni, J, nj, and desc.  Returns
// GrB_SUCCESS or the first failing GraphBLAS status (via OK).  An unexpected
// type code frees the globals (FREE_ALL) and raises a MATLAB error.
//
// Dispatch:
//  - A is 1-by-1 with one entry: scalar expansion with the typed
//    GrB/GxB_*_assign_<type> methods (Vector form if C->vdim == 1, Matrix form
//    otherwise), optionally preceded by a GrB_Matrix_setElement test.
//  - C, A (and Mask, if present) have vdim == 1 and A is not transposed:
//    GrB_Vector_assign.
//  - otherwise: GrB_Matrix_assign.
// C is finished with GrB_Matrix_wait before returning.
GrB_Info assign ( )
{
    // at: the descriptor transposes the first input, A (desc->in0 == GrB_TRAN)
    bool at = (desc != NULL && desc->in0 == GrB_TRAN) ;
    // local status; shadows the file-scope info, so the global is not updated
    GrB_Info info ;

    ASSERT_MATRIX_OK (C, "C", GB0) ;
    ASSERT_MATRIX_OK_OR_NULL (Mask, "Mask", GB0) ;
    ASSERT_MATRIX_OK (A, "A", GB0) ;
    ASSERT_BINARYOP_OK_OR_NULL (accum, "accum", GB0) ;
    ASSERT_DESCRIPTOR_OK_OR_NULL (desc, "desc", GB0) ;

    if (GB_NROWS (A) == 1 && GB_NCOLS (A) == 1 && GB_NNZ (A) == 1)
    {
        // scalar expansion to matrix or vector; the scalar is ((type *) Ax) [0]
        GB_void *Ax = A->x ;

        // With a single (i,j) index pair, no Mask, no descriptor, an accum
        // accepted by GB_op_is_second, and a type code below GB_FC64_code,
        // first test GrB_Matrix_setElement.  NOTE(review): the type-code test
        // relies on the GB_*_code ordering to skip FC64 (and presumably
        // user-defined types).  There is no return after this test: control
        // falls through to the scalar assign below, which writes the same
        // entry again.
        if (ni == 1 && nj == 1 && Mask == NULL && I != GrB_ALL && J != GrB_ALL
            && GB_op_is_second (accum, C->type) && A->type->code < GB_FC64_code
            && desc == NULL)
        {
            // test GrB_Matrix_setElement
            // ASSIGN ends in `break`, so each case below is one full statement
            #define ASSIGN(prefix,suffix,type)                          \
            {                                                           \
                type x = ((type *) Ax) [0] ;                            \
                OK (prefix ## Matrix_setElement ## suffix               \
                    (C, x, I [0], J [0])) ;                             \
            } break ;

            switch (A->type->code)
            {
                case GB_BOOL_code   : ASSIGN (GrB_, _BOOL,   bool) ;
                case GB_INT8_code   : ASSIGN (GrB_, _INT8,   int8_t) ;
                case GB_INT16_code  : ASSIGN (GrB_, _INT16,  int16_t) ;
                case GB_INT32_code  : ASSIGN (GrB_, _INT32,  int32_t) ;
                case GB_INT64_code  : ASSIGN (GrB_, _INT64,  int64_t) ;
                case GB_UINT8_code  : ASSIGN (GrB_, _UINT8,  uint8_t) ;
                case GB_UINT16_code : ASSIGN (GrB_, _UINT16, uint16_t) ;
                case GB_UINT32_code : ASSIGN (GrB_, _UINT32, uint32_t) ;
                case GB_UINT64_code : ASSIGN (GrB_, _UINT64, uint64_t) ;
                case GB_FP32_code   : ASSIGN (GrB_, _FP32,   float) ;
                case GB_FP64_code   : ASSIGN (GrB_, _FP64,   double) ;
                case GB_FC32_code   : ASSIGN (GxB_, _FC32,   GxB_FC32_t) ;
                case GB_FC64_code   : ASSIGN (GxB_, _FC64,   GxB_FC64_t) ;
                case GB_UDT_code    :
                default:
                    FREE_ALL ;
                    mexErrMsgTxt ("unknown type: col setEl") ;
            }

            ASSERT_MATRIX_OK (C, "C after setElement", GB0) ;

        }

        if (C->vdim == 1)
        {

            // test GrB_Vector_assign_scalar functions (C and Mask are used
            // as GrB_Vector)
            #undef  ASSIGN
            #define ASSIGN(prefix,suffix,type)                          \
            {                                                           \
                type x = ((type *) Ax) [0] ;                            \
                OK (prefix ## Vector_assign ## suffix ((GrB_Vector) C,  \
                    (GrB_Vector) Mask, accum, x, I, ni, desc)) ;        \
            } break ;

            switch (A->type->code)
            {
                case GB_BOOL_code   : ASSIGN (GrB_, _BOOL,   bool) ;
                case GB_INT8_code   : ASSIGN (GrB_, _INT8,   int8_t) ;
                case GB_INT16_code  : ASSIGN (GrB_, _INT16,  int16_t) ;
                case GB_INT32_code  : ASSIGN (GrB_, _INT32,  int32_t) ;
                case GB_INT64_code  : ASSIGN (GrB_, _INT64,  int64_t) ;
                case GB_UINT8_code  : ASSIGN (GrB_, _UINT8,  uint8_t) ;
                case GB_UINT16_code : ASSIGN (GrB_, _UINT16, uint16_t) ;
                case GB_UINT32_code : ASSIGN (GrB_, _UINT32, uint32_t) ;
                case GB_UINT64_code : ASSIGN (GrB_, _UINT64, uint64_t) ;
                case GB_FP32_code   : ASSIGN (GrB_, _FP32,   float) ;
                case GB_FP64_code   : ASSIGN (GrB_, _FP64,   double) ;
                case GB_FC32_code   : ASSIGN (GxB_, _FC32,   GxB_FC32_t) ;
                case GB_FC64_code   : ASSIGN (GxB_, _FC64,   GxB_FC64_t) ;
                case GB_UDT_code    :
                {
                    // user-defined type: pass a pointer to the scalar value
                    OK (GrB_Vector_assign_UDT ((GrB_Vector) C,
                        (GrB_Vector) Mask, accum, Ax, I, ni, desc)) ;
                }
                break ;
                default:
                    FREE_ALL ;
                    mexErrMsgTxt ("unknown type: vec assign") ;
            }

        }
        else
        {

            // test Matrix_assign_scalar functions
            #undef  ASSIGN
            #define ASSIGN(prefix,suffix,type)                          \
            {                                                           \
                type x = ((type *) Ax) [0] ;                            \
                OK (prefix ## Matrix_assign ## suffix (C, Mask, accum,  \
                    x, I, ni, J, nj,desc)) ;                            \
            } break ;

            switch (A->type->code)
            {
                case GB_BOOL_code   : ASSIGN (GrB_, _BOOL,   bool) ;
                case GB_INT8_code   : ASSIGN (GrB_, _INT8,   int8_t) ;
                case GB_INT16_code  : ASSIGN (GrB_, _INT16,  int16_t) ;
                case GB_INT32_code  : ASSIGN (GrB_, _INT32,  int32_t) ;
                case GB_INT64_code  : ASSIGN (GrB_, _INT64,  int64_t) ;
                case GB_UINT8_code  : ASSIGN (GrB_, _UINT8,  uint8_t) ;
                case GB_UINT16_code : ASSIGN (GrB_, _UINT16, uint16_t) ;
                case GB_UINT32_code : ASSIGN (GrB_, _UINT32, uint32_t) ;
                case GB_UINT64_code : ASSIGN (GrB_, _UINT64, uint64_t) ;
                case GB_FP32_code   : ASSIGN (GrB_, _FP32,   float) ;
                case GB_FP64_code   : ASSIGN (GrB_, _FP64,   double) ;
                case GB_FC32_code   : ASSIGN (GxB_, _FC32,   GxB_FC32_t) ;
                case GB_FC64_code   : ASSIGN (GxB_, _FC64,   GxB_FC64_t) ;
                case GB_UDT_code    :
                {
                    // user-defined type: pass a pointer to the scalar value
                    OK (GrB_Matrix_assign_UDT (C, Mask, accum,
                        Ax, I, ni, J, nj, desc)) ;
                }
                break ;

                default:
                    FREE_ALL ;
                    mexErrMsgTxt ("unknown type: mtx assign") ;
            }
        }

    }
    else if (C->vdim == 1 && A->vdim == 1 &&
        (Mask == NULL || Mask->vdim == 1) && !at)
    {
        // test GrB_Vector_assign (J and nj are not used on this path)
        OK (GrB_Vector_assign_((GrB_Vector) C, (GrB_Vector) Mask, accum,
            (GrB_Vector) A, I, ni, desc)) ;
    }
    else
    {
        // standard submatrix assignment
        OK (GrB_Matrix_assign_(C, Mask, accum, A, I, ni, J, nj, desc)) ;
    }

    ASSERT_MATRIX_OK (C, "Final C before wait", GB0) ;
    OK (GrB_Matrix_wait_(&C)) ;
    return (info) ;
}
225
226 //------------------------------------------------------------------------------
227 // many_assign: do a sequence of assignments
228 //------------------------------------------------------------------------------
229
230 // The list of assignments is in a struct array
231
many_assign(int nwork,int fA,int fI,int fJ,int faccum,int fMask,int fdesc,const mxArray * pargin[])232 GrB_Info many_assign
233 (
234 int nwork,
235 int fA,
236 int fI,
237 int fJ,
238 int faccum,
239 int fMask,
240 int fdesc,
241 const mxArray *pargin [ ]
242 )
243 {
244 GrB_Info info = GrB_SUCCESS ;
245
246 for (int64_t k = 0 ; k < nwork ; k++)
247 {
248
249 //----------------------------------------------------------------------
250 // get the kth work to do
251 //----------------------------------------------------------------------
252
253 // each struct has fields A, I, J, and optionally Mask, accum, and desc
254
255 mxArray *p ;
256
257 // [ turn off malloc debugging
258 bool save = GB_Global_malloc_debug_get ( ) ;
259 GB_Global_malloc_debug_set (false) ;
260
261 // get Mask (shallow copy)
262 Mask = NULL ;
263 if (fMask >= 0)
264 {
265 p = mxGetFieldByNumber (pargin [1], k, fMask) ;
266 Mask = GB_mx_mxArray_to_Matrix (p, "Mask", false, false) ;
267 if (Mask == NULL && !mxIsEmpty (p))
268 {
269 FREE_ALL ;
270 mexErrMsgTxt ("Mask failed") ;
271 }
272 }
273
274 // get A (shallow copy)
275 p = mxGetFieldByNumber (pargin [1], k, fA) ;
276 A = GB_mx_mxArray_to_Matrix (p, "A", false, true) ;
277 if (A == NULL)
278 {
279 FREE_ALL ;
280 mexErrMsgTxt ("A failed") ;
281 }
282
283 // get accum, if present
284 accum = NULL ;
285 if (faccum >= 0)
286 {
287 p = mxGetFieldByNumber (pargin [1], k, faccum) ;
288 bool user_complex = (Complex != GxB_FC64)
289 && (C->type == Complex || A->type == Complex) ;
290 if (!GB_mx_mxArray_to_BinaryOp (&accum, p, "accum",
291 C->type, user_complex))
292 {
293 FREE_ALL ;
294 mexErrMsgTxt ("accum failed") ;
295 }
296 }
297
298 // get I
299 p = mxGetFieldByNumber (pargin [1], k, fI) ;
300 if (!GB_mx_mxArray_to_indices (&I, p, &ni, I_range, &ignore))
301 {
302 FREE_ALL ;
303 mexErrMsgTxt ("I failed") ;
304 }
305
306 // get J
307 p = mxGetFieldByNumber (pargin [1], k, fJ) ;
308 if (!GB_mx_mxArray_to_indices (&J, p, &nj, J_range, &ignore))
309 {
310 FREE_ALL ;
311 mexErrMsgTxt ("J failed") ;
312 }
313
314 // get desc
315 desc = NULL ;
316 if (fdesc > 0)
317 {
318 p = mxGetFieldByNumber (pargin [1], k, fdesc) ;
319 if (!GB_mx_mxArray_to_Descriptor (&desc, p, "desc"))
320 {
321 FREE_ALL ;
322 mexErrMsgTxt ("desc failed") ;
323 }
324 }
325
326 // restore malloc debugging to test the method
327 GB_Global_malloc_debug_set (save) ; // ]
328
329 //----------------------------------------------------------------------
330 // C<Mask>(I,J) = A
331 //----------------------------------------------------------------------
332
333 info = assign ( ) ;
334
335 GrB_Matrix_free_(&A) ;
336 GrB_Matrix_free_(&Mask) ;
337 GrB_Descriptor_free_(&desc) ;
338
339 if (info != GrB_SUCCESS)
340 {
341 return (info) ;
342 }
343 }
344
345 ASSERT_MATRIX_OK (C, "Final C before wait", GB0) ;
346 OK (GrB_Matrix_wait_(&C)) ;
347 return (info) ;
348 }
349
350 //------------------------------------------------------------------------------
351 // GB_mex_assign mexFunction
352 //------------------------------------------------------------------------------
353
// mexFunction: entry point for
//
//      C = GB_mex_assign (C, Mask, accum, A, I, J, desc)  one assignment
//      C = GB_mex_assign (C, Work)                         a sequence of them
//
// The input C is deep-copied (GET_DEEP_COPY), so the MATLAB input is not
// modified.  The result is returned in pargout [0].  All errors free the
// deep copy and the other globals before calling mexErrMsgTxt.
void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    //--------------------------------------------------------------------------
    // check inputs
    //--------------------------------------------------------------------------

    malloc_debug = GB_mx_get_global (true) ;
    A = NULL ;
    C = NULL ;
    Mask = NULL ;
    desc = NULL ;

    if (nargout > 1 || ! (nargin == 2 || nargin == 6 || nargin == 7))
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    //--------------------------------------------------------------------------
    // get C (make a deep copy)
    //--------------------------------------------------------------------------

    GET_DEEP_COPY ;
    if (C == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("C failed") ;
    }

    if (nargin == 2)
    {

        //----------------------------------------------------------------------
        // get a list of work to do: a struct array of length nwork
        //----------------------------------------------------------------------

        // each entry is a struct with fields:
        // Mask, accum, A, I, J, desc

        if (!mxIsStruct (pargin [1]))
        {
            FREE_ALL ;
            mexErrMsgTxt ("2nd argument must be a struct") ;
        }

        int nwork = mxGetNumberOfElements (pargin [1]) ;

        // mxGetFieldNumber returns the 0-based field number, or -1 if the
        // struct has no field of that name
        int fA     = mxGetFieldNumber (pargin [1], "A") ;
        int fI     = mxGetFieldNumber (pargin [1], "I") ;
        int fJ     = mxGetFieldNumber (pargin [1], "J") ;
        int faccum = mxGetFieldNumber (pargin [1], "accum") ;
        int fMask  = mxGetFieldNumber (pargin [1], "Mask") ;
        int fdesc  = mxGetFieldNumber (pargin [1], "desc") ;

        if (fA < 0 || fI < 0 || fJ < 0)
        {
            // the deep copy of C exists by now, so release it (and call
            // GB_mx_put_global) exactly as the other error paths do
            FREE_ALL ;
            mexErrMsgTxt ("A,I,J required") ;
        }

        METHOD (many_assign (nwork, fA, fI, fJ, faccum, fMask, fdesc, pargin)) ;

    }
    else
    {

        //----------------------------------------------------------------------
        // C<Mask>(I,J) = A, with a single assignment
        //----------------------------------------------------------------------

        // get Mask (shallow copy); an empty MATLAB input means no Mask
        Mask = GB_mx_mxArray_to_Matrix (pargin [1], "Mask", false, false) ;
        if (Mask == NULL && !mxIsEmpty (pargin [1]))
        {
            FREE_ALL ;
            mexErrMsgTxt ("Mask failed") ;
        }

        // get A (shallow copy)
        A = GB_mx_mxArray_to_Matrix (pargin [3], "A", false, true) ;
        if (A == NULL)
        {
            FREE_ALL ;
            mexErrMsgTxt ("A failed") ;
        }

        // get accum, if present
        bool user_complex = (Complex != GxB_FC64)
            && (C->type == Complex || A->type == Complex) ;
        accum = NULL ;
        if (!GB_mx_mxArray_to_BinaryOp (&accum, pargin [2], "accum",
            C->type, user_complex))
        {
            FREE_ALL ;
            mexErrMsgTxt ("accum failed") ;
        }

        // get I
        if (!GB_mx_mxArray_to_indices (&I, pargin [4], &ni, I_range, &ignore))
        {
            FREE_ALL ;
            mexErrMsgTxt ("I failed") ;
        }

        // get J
        if (!GB_mx_mxArray_to_indices (&J, pargin [5], &nj, J_range, &ignore))
        {
            FREE_ALL ;
            mexErrMsgTxt ("J failed") ;
        }

        // get desc (optional 7th argument)
        if (!GB_mx_mxArray_to_Descriptor (&desc, PARGIN (6), "desc"))
        {
            FREE_ALL ;
            mexErrMsgTxt ("desc failed") ;
        }

        // C<Mask>(I,J) = A

        METHOD (assign ( )) ;
    }

    //--------------------------------------------------------------------------
    // return C to MATLAB as a struct
    //--------------------------------------------------------------------------

    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C assign result", true) ;
    FREE_ALL ;
}
495
496