1 //------------------------------------------------------------------------------
2 // GB_mex_mxm_generic: C<Mask> = accum(C,A*B)
3 //------------------------------------------------------------------------------
4
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7
8 //------------------------------------------------------------------------------
9
10 #include "GB_mex.h"
11
#define USAGE "C = GB_mex_mxm_generic (C, Mask, accum, semiring, A, B, desc)"

// FREE_ALL: release every GraphBLAS object the mexFunction may own, then call
// GB_mx_put_global (true).  The do { } while (0) wrapper makes "FREE_ALL ;"
// expand to exactly one statement, so it is safe even inside an unbraced
// if/else.  The semiring is freed before the monoid and operator it may have
// been built from.  Complex_plus_times is never freed here: it is not created
// in this file (presumably a shared object owned by the GB_mex test library).
#define FREE_ALL                                        \
do                                                      \
{                                                       \
    GrB_Matrix_free_(&A) ;                              \
    GrB_Matrix_free_(&B) ;                              \
    GrB_Matrix_free_(&C) ;                              \
    GrB_Matrix_free_(&Mask) ;                           \
    if (semiring != Complex_plus_times)                 \
    {                                                   \
        GrB_Semiring_free_(&semiring) ;                 \
    }                                                   \
    GrB_Monoid_free_(&myplus_monoid) ;                  \
    GrB_BinaryOp_free_(&myplus) ;                       \
    GrB_Descriptor_free_(&desc) ;                       \
    GB_mx_put_global (true) ;                           \
}                                                       \
while (0)

// user-defined z = x + y functions, registered with GrB_BinaryOp_new below to
// build user-defined PLUS monoids
void My_Plus_int64 (void *z, const void *x, const void *y) ;
void My_Plus_int32 (void *z, const void *x, const void *y) ;
void My_Plus_fp64 (void *z, const void *x, const void *y) ;
33
// My_Plus_int64: z = x + y on int64_t values, with the signature required by
// GrB_BinaryOp_new.  The sum is formed in uint64_t, where overflow wraps
// modulo 2^64, instead of in int64_t, where overflow is undefined behavior.
void My_Plus_int64 (void *z, const void *x, const void *y)
{
    uint64_t a = (uint64_t) (*((const int64_t *) x)) ;
    uint64_t b = (uint64_t) (*((const int64_t *) y)) ;
    // converting an out-of-range uint64_t back to int64_t is
    // implementation-defined (not undefined); it wraps on two's-complement
    // targets
    (*((int64_t *) z)) = (int64_t) (a + b) ;
}
41
// My_Plus_int32: z = x + y on int32_t values, with the signature required by
// GrB_BinaryOp_new.  The sum is formed in uint32_t, where overflow wraps
// modulo 2^32, instead of in (promoted) signed int, where overflow is
// undefined behavior.
void My_Plus_int32 (void *z, const void *x, const void *y)
{
    uint32_t a = (uint32_t) (*((const int32_t *) x)) ;
    uint32_t b = (uint32_t) (*((const int32_t *) y)) ;
    // converting an out-of-range uint32_t back to int32_t is
    // implementation-defined (not undefined); it wraps on two's-complement
    // targets
    (*((int32_t *) z)) = (int32_t) (a + b) ;
}
49
// My_Plus_fp64: z = x + y on double values, with the signature required by
// GrB_BinaryOp_new.  The inputs are only read, so they are viewed through
// const pointers.
void My_Plus_fp64 (void *z, const void *x, const void *y)
{
    const double *xval = (const double *) x ;
    const double *yval = (const double *) y ;
    double *zval = (double *) z ;
    (*zval) = (*xval) + (*yval) ;
}
57
// mexFunction: C = GB_mex_mxm_generic (C, Mask, accum, semiring, A, B, desc)
//
// MATLAB test gateway for GrB_mxm: computes C<Mask> = accum (C, A*B) and
// returns C to MATLAB as a struct.  If the semiring's additive monoid is the
// built-in GrB_PLUS_MONOID_INT64, _INT32, or _FP64, the semiring is rebuilt
// around an equivalent user-defined PLUS operator (My_Plus_*) with the same
// multiplicative operator, presumably so that GrB_mxm exercises its generic
// (non-built-in) kernels.
//
// Inputs (pargin):  C, Mask (may be empty), accum, semiring, A, B, and an
//                   optional descriptor (read via PARGIN (6)).
// Output (pargout): the result matrix C.
// On any failure, everything is released with FREE_ALL before mexErrMsgTxt.
void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    // malloc_debug is not read directly below; presumably it is consumed by
    // the METHOD macro (GB_mex.h) -- confirm there
    bool malloc_debug = GB_mx_get_global (true) ;
    GrB_Matrix A = NULL ;
    GrB_Matrix B = NULL ;
    GrB_Matrix C = NULL ;
    GrB_Matrix Mask = NULL ;
    GrB_Semiring semiring = NULL ;
    GrB_Descriptor desc = NULL ;
    GrB_BinaryOp myplus = NULL ;
    GrB_Monoid myplus_monoid = NULL ;
    GrB_BinaryOp accum = NULL ;

    // check inputs
    if (nargout > 1 || nargin < 6 || nargin > 7)
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // get C (make a deep copy).  GET_DEEP_COPY and FREE_DEEP_COPY are not
    // otherwise referenced in this file; presumably METHOD uses them to
    // rebuild C between malloc-debug trials -- confirm in GB_mex.h.
    #define GET_DEEP_COPY \
    C = GB_mx_mxArray_to_Matrix (pargin [0], "C input", true, true) ;
    #define FREE_DEEP_COPY GrB_Matrix_free_(&C) ;
    GET_DEEP_COPY ;
    if (C == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("C failed") ;
    }

    // get Mask (shallow copy); an empty Mask input means no mask
    Mask = GB_mx_mxArray_to_Matrix (pargin [1], "Mask", false, false) ;
    if (Mask == NULL && !mxIsEmpty (pargin [1]))
    {
        FREE_ALL ;
        mexErrMsgTxt ("Mask failed") ;
    }

    // get A (shallow copy)
    A = GB_mx_mxArray_to_Matrix (pargin [4], "A input", false, true) ;
    if (A == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }

    // get B (shallow copy)
    B = GB_mx_mxArray_to_Matrix (pargin [5], "B input", false, true) ;
    if (B == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("B failed") ;
    }

    bool user_complex = (Complex != GxB_FC64) && (C->type == Complex) ;

    // get semiring
    if (!GB_mx_mxArray_to_Semiring (&semiring, pargin [3], "semiring",
        C->type, user_complex))
    {
        FREE_ALL ;
        mexErrMsgTxt ("semiring failed") ;
    }

    // If the additive monoid is a built-in PLUS monoid, rebuild the semiring
    // with a user-defined PLUS operator and monoid plus the same multiply
    // operator.  The new operator and monoid are kept in myplus and
    // myplus_monoid so FREE_ALL can release them.  The old semiring pointer
    // is cleared after it is freed, so a NULL semiring after a rebuild means
    // one of the constructors failed.
    bool replaced = false ;
    if (semiring != NULL && semiring->add == GrB_PLUS_MONOID_INT64)
    {
        GrB_BinaryOp mult = semiring->multiply ;
        GrB_Monoid_free_(&(semiring->add)) ;
        GrB_Semiring_free_(&semiring) ;
        semiring = NULL ;
        GrB_BinaryOp_new (&myplus, My_Plus_int64,
            GrB_INT64, GrB_INT64, GrB_INT64) ;
        // add a spurious terminal value (-111 is not a true terminal of PLUS)
        GxB_Monoid_terminal_new_INT64 (&myplus_monoid, myplus,
            (int64_t) 0, (int64_t) -111) ;
        GrB_Semiring_new (&semiring, myplus_monoid, mult) ;
        replaced = true ;
    }
    else if (semiring != NULL && semiring->add == GrB_PLUS_MONOID_INT32)
    {
        GrB_BinaryOp mult = semiring->multiply ;
        GrB_Monoid_free_(&(semiring->add)) ;
        GrB_Semiring_free_(&semiring) ;
        semiring = NULL ;
        GrB_BinaryOp_new (&myplus, My_Plus_int32,
            GrB_INT32, GrB_INT32, GrB_INT32) ;
        // add a spurious terminal value (-111 is not a true terminal of PLUS)
        GxB_Monoid_terminal_new_INT32 (&myplus_monoid, myplus,
            (int32_t) 0, (int32_t) -111) ;
        GrB_Semiring_new (&semiring, myplus_monoid, mult) ;
        replaced = true ;
    }
    else if (semiring != NULL && semiring->add == GrB_PLUS_MONOID_FP64)
    {
        GrB_BinaryOp mult = semiring->multiply ;
        GrB_Monoid_free_(&(semiring->add)) ;
        GrB_Semiring_free_(&semiring) ;
        semiring = NULL ;
        GrB_BinaryOp_new (&myplus, My_Plus_fp64,
            GrB_FP64, GrB_FP64, GrB_FP64) ;
        GrB_Monoid_new_FP64 (&myplus_monoid, myplus, (double) 0) ;
        GrB_Semiring_new (&semiring, myplus_monoid, mult) ;
        replaced = true ;
    }

    // a failed rebuild must not reach GrB_mxm as a NULL semiring
    if (replaced && semiring == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("user-defined semiring failed") ;
    }

    // get accum, if present
    if (!GB_mx_mxArray_to_BinaryOp (&accum, pargin [2], "accum",
        C->type, user_complex))
    {
        FREE_ALL ;
        mexErrMsgTxt ("accum failed") ;
    }

    // get desc
    if (!GB_mx_mxArray_to_Descriptor (&desc, PARGIN (6), "desc"))
    {
        FREE_ALL ;
        mexErrMsgTxt ("desc failed") ;
    }

    // C<Mask> = accum(C,A*B)
    METHOD (GrB_mxm (C, Mask, accum, semiring, A, B, desc)) ;

    // return C to MATLAB as a struct and free the GraphBLAS C
    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C output from GrB_mxm", true) ;

    FREE_ALL ;
}
190
191