1 //------------------------------------------------------------------------------
2 // GB_mex_assign: C<Mask>(I,J) = accum (C (I,J), A)
3 //------------------------------------------------------------------------------
4 
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7 
8 // This function is a wrapper for GrB_Matrix_assign, GrB_Matrix_assign_T
9 // GrB_Vector_assign, and GrB_Vector_assign_T.  For these uses, the Mask must
10 // always be the same size as C.
11 
12 // This mexFunction does not call GrB_Row_assign or GrB_Col_assign, since
13 // the Mask is a single row or column in these cases, and C is not modified
14 // outside that single row (for GrB_Row_assign) or column (for GrB_Col_assign).
15 
16 // This function does the same thing as the MATLAB mimic GB_spec_assign.m.
17 
18 //------------------------------------------------------------------------------
19 
20 #include "GB_mex.h"
21 
// Usage string reported via mexErrMsgTxt when the argument count is wrong.
#define USAGE "C =GB_mex_assign (C, Mask, accum, A, I, J, desc) or (C, Work)"

// FREE_ALL: free every GraphBLAS object this mexFunction may hold, then call
// GB_mx_put_global (true).  Invoked before each mexErrMsgTxt error exit and
// at the normal end of mexFunction.
#define FREE_ALL                        \
{                                       \
    GrB_Matrix_free_(&A) ;              \
    GrB_Matrix_free_(&Mask) ;           \
    GrB_Matrix_free_(&C) ;              \
    GrB_Descriptor_free_(&desc) ;       \
    GB_mx_put_global (true) ;           \
}

// GET_DEEP_COPY builds C from pargin [0] (the "deep copy" of the input C);
// FREE_DEEP_COPY frees it.  FREE_DEEP_COPY is not referenced in this file;
// presumably both are used by the METHOD macro from GB_mex.h to retry the
// operation on a fresh copy of C -- TODO confirm.
#define GET_DEEP_COPY \
    C = GB_mx_mxArray_to_Matrix (pargin [0], "C input", true, true) ;

#define FREE_DEEP_COPY GrB_Matrix_free_(&C) ;

// File-scope state shared by mexFunction, assign, and many_assign.  assign ()
// takes no arguments and reads all of its inputs from these variables.
GrB_Matrix C = NULL ;
GrB_Matrix Mask = NULL ;
GrB_Matrix A = NULL ;
GrB_Descriptor desc = NULL ;
GrB_BinaryOp accum = NULL ;
// I_range and J_range are scratch arrays passed to GB_mx_mxArray_to_indices;
// presumably they hold a begin/inc/end description of a range -- confirm.
GrB_Index *I = NULL, ni = 0, I_range [3] ;
GrB_Index *J = NULL, nj = 0, J_range [3] ;
// output flag of GB_mx_mxArray_to_indices that this file never reads
bool ignore ;
// set from GB_mx_get_global in mexFunction; not read directly in this file
// (NOTE(review): presumably consulted by the METHOD macro -- confirm)
bool malloc_debug = false ;
// assign and many_assign each declare a local info that shadows this one
GrB_Info info = GrB_SUCCESS ;
GrB_Info assign (void) ;

// Run a list of assignments described by the struct array pargin [1].
GrB_Info many_assign
(
    int nwork,
    int fA,
    int fI,
    int fJ,
    int faccum,
    int fMask,
    int fdesc,
    const mxArray *pargin [ ]
) ;
61 
62 //------------------------------------------------------------------------------
63 // assign: perform a single assignment
64 //------------------------------------------------------------------------------
65 
66 #define OK(method)                      \
67 {                                       \
68     info = method ;                     \
69     if (info != GrB_SUCCESS)            \
70     {                                   \
71         return (info) ;                 \
72     }                                   \
73 }
74 
assign()75 GrB_Info assign ( )
76 {
77     bool at = (desc != NULL && desc->in0 == GrB_TRAN) ;
78     GrB_Info info ;
79 
80     ASSERT_MATRIX_OK (C, "C", GB0) ;
81     ASSERT_MATRIX_OK_OR_NULL (Mask, "Mask", GB0) ;
82     ASSERT_MATRIX_OK (A, "A", GB0) ;
83     ASSERT_BINARYOP_OK_OR_NULL (accum, "accum", GB0) ;
84     ASSERT_DESCRIPTOR_OK_OR_NULL (desc, "desc", GB0) ;
85 
86     if (GB_NROWS (A) == 1 && GB_NCOLS (A) == 1 && GB_NNZ (A) == 1)
87     {
88         // scalar expansion to matrix or vector
89         GB_void *Ax = A->x ;
90 
91         if (ni == 1 && nj == 1 && Mask == NULL && I != GrB_ALL && J != GrB_ALL
92             && GB_op_is_second (accum, C->type) && A->type->code < GB_FC64_code
93             && desc == NULL)
94         {
95             // test GrB_Matrix_setElement
96             #define ASSIGN(prefix,suffix,type)                          \
97             {                                                           \
98                 type x = ((type *) Ax) [0] ;                            \
99                 OK (prefix ## Matrix_setElement ## suffix               \
100                     (C, x, I [0], J [0])) ;                             \
101             } break ;
102 
103             switch (A->type->code)
104             {
105                 case GB_BOOL_code   : ASSIGN (GrB_, _BOOL,   bool) ;
106                 case GB_INT8_code   : ASSIGN (GrB_, _INT8,   int8_t) ;
107                 case GB_INT16_code  : ASSIGN (GrB_, _INT16,  int16_t) ;
108                 case GB_INT32_code  : ASSIGN (GrB_, _INT32,  int32_t) ;
109                 case GB_INT64_code  : ASSIGN (GrB_, _INT64,  int64_t) ;
110                 case GB_UINT8_code  : ASSIGN (GrB_, _UINT8,  uint8_t) ;
111                 case GB_UINT16_code : ASSIGN (GrB_, _UINT16, uint16_t) ;
112                 case GB_UINT32_code : ASSIGN (GrB_, _UINT32, uint32_t) ;
113                 case GB_UINT64_code : ASSIGN (GrB_, _UINT64, uint64_t) ;
114                 case GB_FP32_code   : ASSIGN (GrB_, _FP32,   float) ;
115                 case GB_FP64_code   : ASSIGN (GrB_, _FP64,   double) ;
116                 case GB_FC32_code   : ASSIGN (GxB_, _FC32,   GxB_FC32_t) ;
117                 case GB_FC64_code   : ASSIGN (GxB_, _FC64,   GxB_FC64_t) ;
118                 case GB_UDT_code    :
119                 default:
120                     FREE_ALL ;
121                     mexErrMsgTxt ("unknown type: col setEl") ;
122             }
123 
124             ASSERT_MATRIX_OK (C, "C after setElement", GB0) ;
125 
126         }
127 
128         if (C->vdim == 1)
129         {
130 
131             // test GrB_Vector_assign_scalar functions
132             #undef  ASSIGN
133             #define ASSIGN(prefix,suffix,type)                          \
134             {                                                           \
135                 type x = ((type *) Ax) [0] ;                            \
136                 OK (prefix ## Vector_assign ## suffix ((GrB_Vector) C,  \
137                     (GrB_Vector) Mask, accum, x, I, ni, desc)) ;        \
138             } break ;
139 
140             switch (A->type->code)
141             {
142                 case GB_BOOL_code   : ASSIGN (GrB_, _BOOL,   bool) ;
143                 case GB_INT8_code   : ASSIGN (GrB_, _INT8,   int8_t) ;
144                 case GB_INT16_code  : ASSIGN (GrB_, _INT16,  int16_t) ;
145                 case GB_INT32_code  : ASSIGN (GrB_, _INT32,  int32_t) ;
146                 case GB_INT64_code  : ASSIGN (GrB_, _INT64,  int64_t) ;
147                 case GB_UINT8_code  : ASSIGN (GrB_, _UINT8,  uint8_t) ;
148                 case GB_UINT16_code : ASSIGN (GrB_, _UINT16, uint16_t) ;
149                 case GB_UINT32_code : ASSIGN (GrB_, _UINT32, uint32_t) ;
150                 case GB_UINT64_code : ASSIGN (GrB_, _UINT64, uint64_t) ;
151                 case GB_FP32_code   : ASSIGN (GrB_, _FP32,   float) ;
152                 case GB_FP64_code   : ASSIGN (GrB_, _FP64,   double) ;
153                 case GB_FC32_code   : ASSIGN (GxB_, _FC32,   GxB_FC32_t) ;
154                 case GB_FC64_code   : ASSIGN (GxB_, _FC64,   GxB_FC64_t) ;
155                 case GB_UDT_code    :
156                 {
157                     OK (GrB_Vector_assign_UDT ((GrB_Vector) C,
158                         (GrB_Vector) Mask, accum, Ax, I, ni, desc)) ;
159                 }
160                 break ;
161                 default:
162                     FREE_ALL ;
163                     mexErrMsgTxt ("unknown type: vec assign") ;
164             }
165 
166         }
167         else
168         {
169 
170             // test Matrix_assign_scalar functions
171             #undef  ASSIGN
172             #define ASSIGN(prefix,suffix,type)                          \
173             {                                                           \
174                 type x = ((type *) Ax) [0] ;                            \
175                 OK (prefix ## Matrix_assign ## suffix (C, Mask, accum,  \
176                     x, I, ni, J, nj,desc)) ;                            \
177             } break ;
178 
179             switch (A->type->code)
180             {
181                 case GB_BOOL_code   : ASSIGN (GrB_, _BOOL,   bool) ;
182                 case GB_INT8_code   : ASSIGN (GrB_, _INT8,   int8_t) ;
183                 case GB_INT16_code  : ASSIGN (GrB_, _INT16,  int16_t) ;
184                 case GB_INT32_code  : ASSIGN (GrB_, _INT32,  int32_t) ;
185                 case GB_INT64_code  : ASSIGN (GrB_, _INT64,  int64_t) ;
186                 case GB_UINT8_code  : ASSIGN (GrB_, _UINT8,  uint8_t) ;
187                 case GB_UINT16_code : ASSIGN (GrB_, _UINT16, uint16_t) ;
188                 case GB_UINT32_code : ASSIGN (GrB_, _UINT32, uint32_t) ;
189                 case GB_UINT64_code : ASSIGN (GrB_, _UINT64, uint64_t) ;
190                 case GB_FP32_code   : ASSIGN (GrB_, _FP32,   float) ;
191                 case GB_FP64_code   : ASSIGN (GrB_, _FP64,   double) ;
192                 case GB_FC32_code   : ASSIGN (GxB_, _FC32,   GxB_FC32_t) ;
193                 case GB_FC64_code   : ASSIGN (GxB_, _FC64,   GxB_FC64_t) ;
194                 case GB_UDT_code    :
195                 {
196                     OK (GrB_Matrix_assign_UDT (C, Mask, accum,
197                         Ax, I, ni, J, nj, desc)) ;
198                 }
199                 break ;
200 
201                 default:
202                     FREE_ALL ;
203                     mexErrMsgTxt ("unknown type: mtx assign") ;
204             }
205         }
206 
207     }
208     else if (C->vdim == 1 && A->vdim == 1 &&
209         (Mask == NULL || Mask->vdim == 1) && !at)
210     {
211         // test GrB_Vector_assign
212         OK (GrB_Vector_assign_((GrB_Vector) C, (GrB_Vector) Mask, accum,
213             (GrB_Vector) A, I, ni, desc)) ;
214     }
215     else
216     {
217         // standard submatrix assignment
218         OK (GrB_Matrix_assign_(C, Mask, accum, A, I, ni, J, nj, desc)) ;
219     }
220 
221     ASSERT_MATRIX_OK (C, "Final C before wait", GB0) ;
222     OK (GrB_Matrix_wait_(&C)) ;
223     return (info) ;
224 }
225 
226 //------------------------------------------------------------------------------
227 // many_assign: do a sequence of assignments
228 //------------------------------------------------------------------------------
229 
230 // The list of assignments is in a struct array
231 
many_assign(int nwork,int fA,int fI,int fJ,int faccum,int fMask,int fdesc,const mxArray * pargin[])232 GrB_Info many_assign
233 (
234     int nwork,
235     int fA,
236     int fI,
237     int fJ,
238     int faccum,
239     int fMask,
240     int fdesc,
241     const mxArray *pargin [ ]
242 )
243 {
244     GrB_Info info = GrB_SUCCESS ;
245 
246     for (int64_t k = 0 ; k < nwork ; k++)
247     {
248 
249         //----------------------------------------------------------------------
250         // get the kth work to do
251         //----------------------------------------------------------------------
252 
253         // each struct has fields A, I, J, and optionally Mask, accum, and desc
254 
255         mxArray *p ;
256 
257         // [ turn off malloc debugging
258         bool save = GB_Global_malloc_debug_get ( ) ;
259         GB_Global_malloc_debug_set (false) ;
260 
261         // get Mask (shallow copy)
262         Mask = NULL ;
263         if (fMask >= 0)
264         {
265             p = mxGetFieldByNumber (pargin [1], k, fMask) ;
266             Mask = GB_mx_mxArray_to_Matrix (p, "Mask", false, false) ;
267             if (Mask == NULL && !mxIsEmpty (p))
268             {
269                 FREE_ALL ;
270                 mexErrMsgTxt ("Mask failed") ;
271             }
272         }
273 
274         // get A (shallow copy)
275         p = mxGetFieldByNumber (pargin [1], k, fA) ;
276         A = GB_mx_mxArray_to_Matrix (p, "A", false, true) ;
277         if (A == NULL)
278         {
279             FREE_ALL ;
280             mexErrMsgTxt ("A failed") ;
281         }
282 
283         // get accum, if present
284         accum = NULL ;
285         if (faccum >= 0)
286         {
287             p = mxGetFieldByNumber (pargin [1], k, faccum) ;
288             bool user_complex = (Complex != GxB_FC64)
289                 && (C->type == Complex || A->type == Complex) ;
290             if (!GB_mx_mxArray_to_BinaryOp (&accum, p, "accum",
291                 C->type, user_complex))
292             {
293                 FREE_ALL ;
294                 mexErrMsgTxt ("accum failed") ;
295             }
296         }
297 
298         // get I
299         p = mxGetFieldByNumber (pargin [1], k, fI) ;
300         if (!GB_mx_mxArray_to_indices (&I, p, &ni, I_range, &ignore))
301         {
302             FREE_ALL ;
303             mexErrMsgTxt ("I failed") ;
304         }
305 
306         // get J
307         p = mxGetFieldByNumber (pargin [1], k, fJ) ;
308         if (!GB_mx_mxArray_to_indices (&J, p, &nj, J_range, &ignore))
309         {
310             FREE_ALL ;
311             mexErrMsgTxt ("J failed") ;
312         }
313 
314         // get desc
315         desc = NULL ;
316         if (fdesc > 0)
317         {
318             p = mxGetFieldByNumber (pargin [1], k, fdesc) ;
319             if (!GB_mx_mxArray_to_Descriptor (&desc, p, "desc"))
320             {
321                 FREE_ALL ;
322                 mexErrMsgTxt ("desc failed") ;
323             }
324         }
325 
326         // restore malloc debugging to test the method
327         GB_Global_malloc_debug_set (save) ; // ]
328 
329         //----------------------------------------------------------------------
330         // C<Mask>(I,J) = A
331         //----------------------------------------------------------------------
332 
333         info = assign ( ) ;
334 
335         GrB_Matrix_free_(&A) ;
336         GrB_Matrix_free_(&Mask) ;
337         GrB_Descriptor_free_(&desc) ;
338 
339         if (info != GrB_SUCCESS)
340         {
341             return (info) ;
342         }
343     }
344 
345     ASSERT_MATRIX_OK (C, "Final C before wait", GB0) ;
346     OK (GrB_Matrix_wait_(&C)) ;
347     return (info) ;
348 }
349 
350 //------------------------------------------------------------------------------
351 // GB_mex_assign mexFunction
352 //------------------------------------------------------------------------------
353 
mexFunction(int nargout,mxArray * pargout[],int nargin,const mxArray * pargin[])354 void mexFunction
355 (
356     int nargout,
357     mxArray *pargout [ ],
358     int nargin,
359     const mxArray *pargin [ ]
360 )
361 {
362 
363     //--------------------------------------------------------------------------
364     // check inputs
365     //--------------------------------------------------------------------------
366 
367     malloc_debug = GB_mx_get_global (true) ;
368     A = NULL ;
369     C = NULL ;
370     Mask = NULL ;
371     desc = NULL ;
372 
373     if (nargout > 1 || ! (nargin == 2 || nargin == 6 || nargin == 7))
374     {
375         mexErrMsgTxt ("Usage: " USAGE) ;
376     }
377 
378     //--------------------------------------------------------------------------
379     // get C (make a deep copy)
380     //--------------------------------------------------------------------------
381 
382     GET_DEEP_COPY ;
383     if (C == NULL)
384     {
385         FREE_ALL ;
386         mexErrMsgTxt ("C failed") ;
387     }
388 
389     if (nargin == 2)
390     {
391 
392         //----------------------------------------------------------------------
393         // get a list of work to do: a struct array of length nwork
394         //----------------------------------------------------------------------
395 
396         // each entry is a struct with fields:
397         // Mask, accum, A, I, J, desc
398 
399         if (!mxIsStruct (pargin [1]))
400         {
401             FREE_ALL ;
402             mexErrMsgTxt ("2nd argument must be a struct") ;
403         }
404 
405         int nwork = mxGetNumberOfElements (pargin [1]) ;
406         int nf = mxGetNumberOfFields (pargin [1]) ;
407         for (int f = 0 ; f < nf ; f++)
408         {
409             mxArray *p ;
410             for (int k = 0 ; k < nwork ; k++)
411             {
412                 p = mxGetFieldByNumber (pargin [1], k, f) ;
413             }
414         }
415 
416         int fA = mxGetFieldNumber (pargin [1], "A") ;
417         int fI = mxGetFieldNumber (pargin [1], "I") ;
418         int fJ = mxGetFieldNumber (pargin [1], "J") ;
419         int faccum = mxGetFieldNumber (pargin [1], "accum") ;
420         int fMask = mxGetFieldNumber (pargin [1], "Mask") ;
421         int fdesc = mxGetFieldNumber (pargin [1], "desc") ;
422 
423         if (fA < 0 || fI < 0 || fJ < 0) mexErrMsgTxt ("A,I,J required") ;
424 
425         METHOD (many_assign (nwork, fA, fI, fJ, faccum, fMask, fdesc, pargin)) ;
426 
427     }
428     else
429     {
430 
431         //----------------------------------------------------------------------
432         // C<Mask>(I,J) = A, with a single assignment
433         //----------------------------------------------------------------------
434 
435         // get Mask (shallow copy)
436         Mask = GB_mx_mxArray_to_Matrix (pargin [1], "Mask", false, false) ;
437         if (Mask == NULL && !mxIsEmpty (pargin [1]))
438         {
439             FREE_ALL ;
440             mexErrMsgTxt ("Mask failed") ;
441         }
442 
443         // get A (shallow copy)
444         A = GB_mx_mxArray_to_Matrix (pargin [3], "A", false, true) ;
445         if (A == NULL)
446         {
447             FREE_ALL ;
448             mexErrMsgTxt ("A failed") ;
449         }
450 
451         // get accum, if present
452         bool user_complex = (Complex != GxB_FC64)
453             && (C->type == Complex || A->type == Complex) ;
454         accum = NULL ;
455         if (!GB_mx_mxArray_to_BinaryOp (&accum, pargin [2], "accum",
456             C->type, user_complex))
457         {
458             FREE_ALL ;
459             mexErrMsgTxt ("accum failed") ;
460         }
461 
462         // get I
463         if (!GB_mx_mxArray_to_indices (&I, pargin [4], &ni, I_range, &ignore))
464         {
465             FREE_ALL ;
466             mexErrMsgTxt ("I failed") ;
467         }
468 
469         // get J
470         if (!GB_mx_mxArray_to_indices (&J, pargin [5], &nj, J_range, &ignore))
471         {
472             FREE_ALL ;
473             mexErrMsgTxt ("J failed") ;
474         }
475 
476         // get desc
477         if (!GB_mx_mxArray_to_Descriptor (&desc, PARGIN (6), "desc"))
478         {
479             FREE_ALL ;
480             mexErrMsgTxt ("desc failed") ;
481         }
482 
483         // C<Mask>(I,J) = A
484 
485         METHOD (assign ( )) ;
486     }
487 
488     //--------------------------------------------------------------------------
489     // return C to MATLAB as a struct
490     //--------------------------------------------------------------------------
491 
492     pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C assign result", true) ;
493     FREE_ALL ;
494 }
495 
496