1 //------------------------------------------------------------------------------
2 // GB_mex_AxB: compute C=A*B, A'*B, A*B', or A'*B'
3 //------------------------------------------------------------------------------
4 
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7 
8 //------------------------------------------------------------------------------
9 
10 // This is for testing only.  See GrB_mxm instead.  Returns a plain MATLAB
11 // matrix, in double.
12 
13 #include "GB_mex.h"
14 
15 #define USAGE "C = GB_mex_AxB (A, B, atranspose, btranspose, axb_method)"
16 
// FREE_ALL: release every file-scope GraphBLAS object used by this test
// (A, Aconj, B, Bconj, C, Mask, the monoid and the semiring), then call
// GB_mx_put_global (true).  It is used on the error paths of mexFunction
// before mexErrMsgTxt, and once at the end of a successful call.  Handles
// may still be NULL when it runs (e.g. when A failed to load).
// NOTE(review): this expands to a bare { ... } block rather than
// do { ... } while (0); always write it as "FREE_ALL ;" inside braces, never
// as the lone body of an unbraced if/else.
#define FREE_ALL                                \
{                                               \
    GrB_Matrix_free_(&A) ;                      \
    GrB_Matrix_free_(&Aconj) ;                  \
    GrB_Matrix_free_(&B) ;                      \
    GrB_Matrix_free_(&Bconj) ;                  \
    GrB_Matrix_free_(&C) ;                      \
    GrB_Matrix_free_(&Mask) ;                   \
    GrB_Monoid_free_(&add) ;                    \
    GrB_Semiring_free_(&semiring) ;             \
    GB_mx_put_global (true) ;                   \
}
29 
30 //------------------------------------------------------------------------------
31 
32 GrB_Info info ;
33 bool malloc_debug = false ;
34 bool ignore = false, ignore1 = false, ignore2 = false ;
35 bool atranspose = false ;
36 bool btranspose = false ;
37 GrB_Matrix A = NULL, B = NULL, C = NULL, Aconj = NULL, Bconj = NULL,
38     Mask = NULL ;
39 GrB_Monoid add = NULL ;
40 GrB_Semiring semiring = NULL ;
41 int64_t anrows = 0 ;
42 int64_t ancols = 0 ;
43 int64_t bnrows = 0 ;
44 int64_t bncols = 0 ;
45 struct GB_Matrix_opaque C_header ;
46 
47 GrB_Desc_Value AxB_method = GxB_DEFAULT ;
48 
49 GrB_Info axb (GB_Context Context) ;
50 GrB_Info axb_complex (GB_Context Context) ;
51 
52 //------------------------------------------------------------------------------
53 
axb(GB_Context Context)54 GrB_Info axb (GB_Context Context)
55 {
56 
57     // create the Semiring for regular z += x*y
58     info = GrB_Monoid_new_FP64_(&add, GrB_PLUS_FP64, (double) 0) ;
59     if (info != GrB_SUCCESS) return (info) ;
60 
61     info = GrB_Semiring_new (&semiring, add, GrB_TIMES_FP64) ;
62     if (info != GrB_SUCCESS)
63     {
64         GrB_Monoid_free_(&add) ;
65         return (info) ;
66     }
67 
68     struct GB_Matrix_opaque MT_header ;
69     GrB_Matrix MT = GB_clear_static_header (&MT_header) ;
70 
71     // C = A*B, A'*B, A*B', or A'*B'
72     info = GB_AxB_meta (C, NULL,
73         false,      // C_replace
74         true,       // CSC
75         MT,         // no MT returned
76         &ignore1,   // M_transposed will be false
77         NULL,       // no Mask
78         false,      // mask not complemented
79         false,      // mask not structural
80         NULL,       // no accum
81         A,
82         B,
83         semiring,   // GrB_PLUS_TIMES_FP64
84         atranspose,
85         btranspose,
86         false,      // flipxy
87         &ignore,    // mask_applied
88         &ignore2,   // done_in_place
89         AxB_method,
90         true,       // do the sort
91         Context) ;
92 
93     GrB_Monoid_free_(&add) ;
94     GrB_Semiring_free_(&semiring) ;
95 
96     return (info) ;
97 }
98 
99 //------------------------------------------------------------------------------
100 
axb_complex(GB_Context Context)101 GrB_Info axb_complex (GB_Context Context)
102 {
103 
104     // C = A*B, complex case
105 
106     Aconj = NULL ;
107     Bconj = NULL ;
108 
109     if (atranspose)
110     {
111         // Aconj = A
112         info = GrB_Matrix_new (&Aconj, Complex, A->vlen, A->vdim) ;
113         if (info != GrB_SUCCESS) return (info) ;
114         info = GrB_Matrix_apply_(Aconj, NULL, NULL, Complex_conj, A, NULL) ;
115         if (info != GrB_SUCCESS)
116         {
117             GrB_Matrix_free_(&Aconj) ;
118             return (info) ;
119         }
120     }
121 
122     if (btranspose)
123     {
124         // Bconj = B
125         info = GrB_Matrix_new (&Bconj, Complex, B->vlen, B->vdim) ;
126         if (info != GrB_SUCCESS)
127         {
128             GrB_Matrix_free_(&Aconj) ;
129             return (info) ;
130         }
131 
132         info = GrB_Matrix_apply_(Bconj, NULL, NULL, Complex_conj, B, NULL) ;
133         if (info != GrB_SUCCESS)
134         {
135             GrB_Matrix_free_(&Bconj) ;
136             GrB_Matrix_free_(&Aconj) ;
137             return (info) ;
138         }
139 
140     }
141 
142     // force completion
143     if (Aconj != NULL)
144     {
145         info = GrB_Matrix_wait_(&Aconj) ;
146         if (info != GrB_SUCCESS)
147         {
148             GrB_Matrix_free_(&Aconj) ;
149             GrB_Matrix_free_(&Bconj) ;
150             return (info) ;
151         }
152     }
153 
154     if (Bconj != NULL)
155     {
156         info = GrB_Matrix_wait_(&Bconj) ;
157         if (info != GrB_SUCCESS)
158         {
159             GrB_Matrix_free_(&Aconj) ;
160             GrB_Matrix_free_(&Bconj) ;
161             return (info) ;
162         }
163     }
164 
165     struct GB_Matrix_opaque MT_header ;
166     GrB_Matrix MT = GB_clear_static_header (&MT_header) ;
167 
168     info = GB_AxB_meta (C, NULL,
169         false,      // C_replace
170         true,       // CSC
171         MT,         // no MT returned
172         &ignore1,   // M_transposed will be false
173         NULL,       // no Mask
174         false,      // mask not complemented
175         false,      // mask not structural
176         NULL,       // no accum
177         (atranspose) ? Aconj : A,
178         (btranspose) ? Bconj : B,
179         Complex_plus_times,
180         atranspose,
181         btranspose,
182         false,      // flipxy
183         &ignore,    // mask_applied
184         &ignore2,   // done_in_place
185         AxB_method,
186         true,       // do the sort
187         Context) ;
188 
189     GrB_Matrix_free_(&Bconj) ;
190     GrB_Matrix_free_(&Aconj) ;
191 
192     return (info) ;
193 }
194 
195 //------------------------------------------------------------------------------
196 
void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    // Reset the file-scope state.  Globals in a loaded MEX file keep their
    // values between calls, so every handle must start out NULL for FREE_ALL
    // to be safe on the error paths below.
    info = GrB_SUCCESS ;
    malloc_debug = GB_mx_get_global (true) ;
    ignore = false ;
    ignore1 = false ;
    ignore2 = false ;
    A = NULL ;
    B = NULL ;
    C = NULL ;
    Aconj = NULL ;
    Bconj = NULL ;
    Mask = NULL ;
    add = NULL ;
    semiring = NULL ;

    GB_CONTEXT (USAGE) ;

    // check inputs
    if (nargout > 1 || nargin < 2 || nargin > 5)
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // no deep copy of the inputs is needed; presumably these hooks are
    // expanded by the METHOD macro from GB_mex.h
    #define GET_DEEP_COPY ;
    #define FREE_DEEP_COPY ;

    // get A and B
    A = GB_mx_mxArray_to_Matrix (pargin [0], "A", false, true) ;
    B = GB_mx_mxArray_to_Matrix (pargin [1], "B", false, true) ;
    if (A == NULL || B == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("failed") ;
    }

    // mexErrMsgTxt does not return, so A and B must be released first
    if (!A->is_csc || !B->is_csc)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A and B must be in CSC format") ;
    }

    // get the atranspose option
    GET_SCALAR (2, bool, atranspose, false) ;

    // get the btranspose option
    GET_SCALAR (3, bool, btranspose, false) ;

    // get the axb_method:
    //   0 or not present: default
    //   1001: Gustavson
    //   1003: dot
    //   1004: hash
    // NOTE(review): saxpy (1005) is not in the accepted list below; confirm
    // whether GxB_AxB_SAXPY should also be allowed by this driver.
    GET_SCALAR (4, GrB_Desc_Value, AxB_method, GxB_DEFAULT) ;

    if (! ((AxB_method == GxB_DEFAULT) ||
        (AxB_method == GxB_AxB_GUSTAVSON) ||
        (AxB_method == GxB_AxB_HASH) ||
        (AxB_method == GxB_AxB_DOT)))
    {
        FREE_ALL ;
        mexErrMsgTxt ("unknown method") ;
    }

    // determine the dimensions of op(A) and op(B)
    anrows = (atranspose) ? GB_NCOLS (A) : GB_NROWS (A) ;
    ancols = (atranspose) ? GB_NROWS (A) : GB_NCOLS (A) ;
    bnrows = (btranspose) ? GB_NCOLS (B) : GB_NROWS (B) ;
    bncols = (btranspose) ? GB_NROWS (B) : GB_NCOLS (B) ;
    if (ancols != bnrows)
    {
        FREE_ALL ;
        mexErrMsgTxt ("invalid dimensions") ;
    }

    // C uses the file-scope static header C_header
    C = GB_clear_static_header (&C_header) ;

    if (A->type == Complex)
    {
        METHOD (axb_complex (Context)) ;
    }
    else
    {
        METHOD (axb (Context)) ;
    }

    // return C to MATLAB
    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C AxB result", false) ;

    FREE_ALL ;
}
294 
295