1 //------------------------------------------------------------------------------
2 // GB_mex_mxm_generic: C<Mask> = accum(C,A*B)
3 //------------------------------------------------------------------------------
4 
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7 
8 //------------------------------------------------------------------------------
9 
10 #include "GB_mex.h"
11 
12 #define USAGE "C = GB_mex_mxm_generic (C, Mask, accum, semiring, A, B, desc)"
13 
// FREE_ALL: release every GraphBLAS object mexFunction may have created.  It
// is invoked on early-error paths where some of these handles are still NULL,
// as well as at the normal end of mexFunction.  The shared Complex_plus_times
// semiring is deliberately not freed; NOTE(review): presumably it is owned by
// the test harness rather than by this mexFunction -- confirm in GB_mex.h.
// It finishes with GB_mx_put_global (true), the counterpart of the
// GB_mx_get_global (true) call at the top of mexFunction.
// (No comments inside the macro body: a // comment before a trailing
// backslash would swallow the next line.)
#define FREE_ALL                                    \
{                                                   \
    GrB_Matrix_free_(&A) ;                          \
    GrB_Matrix_free_(&B) ;                          \
    GrB_Matrix_free_(&C) ;                          \
    GrB_Matrix_free_(&Mask) ;                       \
    GrB_Monoid_free_(&myplus_monoid) ;              \
    GrB_BinaryOp_free_(&myplus) ;                   \
    if (semiring != Complex_plus_times)             \
    {                                               \
        GrB_Semiring_free_(&semiring) ;             \
    }                                               \
    GrB_Descriptor_free_(&desc) ;                   \
    GB_mx_put_global (true) ;                       \
}
29 
// Prototypes for the user-defined "plus" operators defined below.  Each one
// computes z = x + y on scalars of the type named by its suffix, using the
// (void *z, const void *x, const void *y) form that mexFunction hands to
// GrB_BinaryOp_new when it builds a user-defined PLUS monoid.
void My_Plus_int64 (void *z, const void *x, const void *y) ;
void My_Plus_int32 (void *z, const void *x, const void *y) ;
void My_Plus_fp64  (void *z, const void *x, const void *y) ;
33 
// My_Plus_int64: z = x + y for int64_t scalars, in the binary-function form
// (void *z, const void *x, const void *y) that mexFunction passes to
// GrB_BinaryOp_new.
//
// The sum is formed in uint64_t so that overflow wraps modulo 2^64 instead of
// being undefined signed-integer overflow.  Converting the wrapped value back
// to int64_t is implementation-defined in C11; GCC and Clang document it as
// reduction modulo 2^64, i.e. the usual two's-complement result.
void My_Plus_int64 (void *z, const void *x, const void *y)
{
    uint64_t a = (uint64_t) (*((const int64_t *) x)) ;
    uint64_t b = (uint64_t) (*((const int64_t *) y)) ;
    uint64_t c = a + b ;
    (*((int64_t *) z)) = (int64_t) c ;
}
41 
// My_Plus_int32: z = x + y for int32_t scalars, in the binary-function form
// (void *z, const void *x, const void *y) that mexFunction passes to
// GrB_BinaryOp_new.
//
// int32_t operands promote to int, so a plain a + b is undefined behaviour on
// overflow.  The sum is formed in uint32_t instead, which wraps modulo 2^32;
// the conversion back to int32_t is implementation-defined in C11 and is
// documented by GCC and Clang as reduction modulo 2^32.
void My_Plus_int32 (void *z, const void *x, const void *y)
{
    uint32_t a = (uint32_t) (*((const int32_t *) x)) ;
    uint32_t b = (uint32_t) (*((const int32_t *) y)) ;
    uint32_t c = a + b ;
    (*((int32_t *) z)) = (int32_t) c ;
}
49 
// My_Plus_fp64: z = x + y for double scalars, in the binary-function form
// (void *z, const void *x, const void *y) that mexFunction passes to
// GrB_BinaryOp_new.  Both inputs are read before the output is written, so z
// may alias x or y.
void My_Plus_fp64  (void *z, const void *x, const void *y)
{
    const double *lhs = (const double *) x ;
    const double *rhs = (const double *) y ;
    double *out = (double *) z ;
    double sum = (*lhs) + (*rhs) ;
    (*out) = sum ;
}
57 
// mexFunction: C = GB_mex_mxm_generic (C, Mask, accum, semiring, A, B, desc)
//
// Computes C<Mask> = accum (C, A*B) with GrB_mxm and returns C to MATLAB.
// If the semiring's additive monoid is the built-in GrB_PLUS_MONOID_INT64,
// GrB_PLUS_MONOID_INT32, or GrB_PLUS_MONOID_FP64, the semiring is rebuilt
// from a user-defined monoid (My_Plus_int64/int32/fp64) with the same
// multiplicative operator; NOTE(review): presumably this forces GrB_mxm onto
// its generic (non-built-in) code path, as the file name suggests -- confirm.
//
// Inputs:  pargin [0] C, [1] Mask (may be empty), [2] accum, [3] semiring,
//          [4] A, [5] B, [6] desc (optional).
// Outputs: pargout [0] the resulting C.
// Every failure frees all objects via FREE_ALL and raises a MATLAB error.
void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    // malloc_debug is not referenced directly below; NOTE(review): it is
    // presumably consumed by the METHOD macro from GB_mex.h -- confirm.
    bool malloc_debug = GB_mx_get_global (true) ;
    GrB_Matrix A = NULL ;
    GrB_Matrix B = NULL ;
    GrB_Matrix C = NULL ;
    GrB_Matrix Mask = NULL ;
    GrB_Semiring semiring = NULL ;
    GrB_Descriptor desc = NULL ;
    GrB_BinaryOp myplus = NULL ;
    GrB_Monoid   myplus_monoid = NULL ;

    // check inputs: 6 or 7 inputs (desc is optional), at most one output
    if (nargout > 1 || nargin < 6 || nargin > 7)
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // get C (make a deep copy).  GET_DEEP_COPY and FREE_DEEP_COPY must be
    // defined before METHOD is used below; NOTE(review): METHOD presumably
    // uses them to rebuild C between retries -- confirm in GB_mex.h.
    #define GET_DEEP_COPY \
    C = GB_mx_mxArray_to_Matrix (pargin [0], "C input", true, true) ;
    #define FREE_DEEP_COPY GrB_Matrix_free_(&C) ;
    GET_DEEP_COPY ;
    if (C == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("C failed") ;
    }

    // get Mask (shallow copy); an empty MATLAB array means "no mask"
    Mask = GB_mx_mxArray_to_Matrix (pargin [1], "Mask", false, false) ;
    if (Mask == NULL && !mxIsEmpty (pargin [1]))
    {
        FREE_ALL ;
        mexErrMsgTxt ("Mask failed") ;
    }

    // get A (shallow copy)
    A = GB_mx_mxArray_to_Matrix (pargin [4], "A input", false, true) ;
    if (A == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }

    // get B (shallow copy)
    B = GB_mx_mxArray_to_Matrix (pargin [5], "B input", false, true) ;
    if (B == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("B failed") ;
    }

    // true when a user-defined Complex type (distinct from GxB_FC64) is in
    // use and C has that type
    bool user_complex = (Complex != GxB_FC64) && (C->type == Complex) ;

    // get semiring
    if (!GB_mx_mxArray_to_Semiring (&semiring, pargin [3], "semiring",
        C->type, user_complex))
    {
        FREE_ALL ;
        mexErrMsgTxt ("semiring failed") ;
    }

    // Replace a built-in PLUS monoid with an equivalent user-defined one,
    // keeping the semiring's multiplicative operator.  The original semiring
    // is freed before the new objects are constructed, so the status of each
    // constructor is tracked in replace_info and checked once below; without
    // that check a failed constructor would leave semiring NULL and the error
    // would only surface later inside GrB_mxm.
    GrB_Info replace_info = GrB_SUCCESS ;
    if (semiring != NULL && semiring->add == GrB_PLUS_MONOID_INT64)
    {
        GrB_BinaryOp mult = semiring->multiply ;
        GrB_Monoid_free_(&(semiring->add)) ;
        GrB_Semiring_free_(&semiring) ;
        replace_info = GrB_BinaryOp_new (&myplus, My_Plus_int64,
            GrB_INT64, GrB_INT64, GrB_INT64) ;
        if (replace_info == GrB_SUCCESS)
        {
            // add a spurious terminal value (-111 does not absorb under
            // PLUS); NOTE(review): presumably to test that a bogus terminal
            // does not change the result -- confirm
            replace_info = GxB_Monoid_terminal_new_INT64 (&myplus_monoid,
                myplus, (int64_t) 0, (int64_t) -111) ;
        }
        if (replace_info == GrB_SUCCESS)
        {
            replace_info = GrB_Semiring_new (&semiring, myplus_monoid, mult) ;
        }
    }
    else if (semiring != NULL && semiring->add == GrB_PLUS_MONOID_INT32)
    {
        GrB_BinaryOp mult = semiring->multiply ;
        GrB_Monoid_free_(&(semiring->add)) ;
        GrB_Semiring_free_(&semiring) ;
        replace_info = GrB_BinaryOp_new (&myplus, My_Plus_int32,
            GrB_INT32, GrB_INT32, GrB_INT32) ;
        if (replace_info == GrB_SUCCESS)
        {
            // add a spurious terminal value (see the INT64 case above)
            replace_info = GxB_Monoid_terminal_new_INT32 (&myplus_monoid,
                myplus, (int32_t) 0, (int32_t) -111) ;
        }
        if (replace_info == GrB_SUCCESS)
        {
            replace_info = GrB_Semiring_new (&semiring, myplus_monoid, mult) ;
        }
    }
    else if (semiring != NULL && semiring->add == GrB_PLUS_MONOID_FP64)
    {
        GrB_BinaryOp mult = semiring->multiply ;
        GrB_Monoid_free_(&(semiring->add)) ;
        GrB_Semiring_free_(&semiring) ;
        replace_info = GrB_BinaryOp_new (&myplus, My_Plus_fp64,
            GrB_FP64, GrB_FP64, GrB_FP64) ;
        if (replace_info == GrB_SUCCESS)
        {
            // plain monoid with identity 0, no terminal value
            replace_info = GrB_Monoid_new_FP64 (&myplus_monoid, myplus,
                (double) 0) ;
        }
        if (replace_info == GrB_SUCCESS)
        {
            replace_info = GrB_Semiring_new (&semiring, myplus_monoid, mult) ;
        }
    }
    if (replace_info != GrB_SUCCESS)
    {
        FREE_ALL ;
        mexErrMsgTxt ("user-defined semiring failed") ;
    }

    // get accum, if present
    GrB_BinaryOp accum = NULL ;
    if (!GB_mx_mxArray_to_BinaryOp (&accum, pargin [2], "accum",
        C->type, user_complex))
    {
        FREE_ALL ;
        mexErrMsgTxt ("accum failed") ;
    }

    // get desc; NOTE(review): PARGIN (6) presumably yields NULL when only six
    // inputs were given -- confirm in GB_mex.h
    if (!GB_mx_mxArray_to_Descriptor (&desc, PARGIN (6), "desc"))
    {
        FREE_ALL ;
        mexErrMsgTxt ("desc failed") ;
    }

    // C<Mask> = accum(C,A*B)
    METHOD (GrB_mxm (C, Mask, accum, semiring, A, B, desc)) ;

    // return C to MATLAB as a struct and free the GraphBLAS C
    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C output from GrB_mxm", true) ;

    FREE_ALL ;
}
190 
191