1 //------------------------------------------------------------------------------
2 // GB_mex_AdotB: compute C=spones(Mask).*(A'*B)
3 //------------------------------------------------------------------------------
4
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7
8 //------------------------------------------------------------------------------
9
10 // Returns a plain MATLAB sparse matrix, not a struct. Only works in double
11 // and complex. Input matrices must be MATLAB sparse matrices, or GraphBLAS
12 // structs in CSC format.
13
14 #include "GB_mex.h"
15
16 #define USAGE "C = GB_mex_AdotB (A,B,Mask,flipxy)"
17
18 #define FREE_ALL \
19 { \
20 GrB_Matrix_free_(&A) ; \
21 GrB_Matrix_free_(&Aconj) ; \
22 GrB_Matrix_free_(&B) ; \
23 GrB_Matrix_free_(&C) ; \
24 GrB_Matrix_free_(&Mask) ; \
25 GrB_Monoid_free_(&add) ; \
26 GrB_Semiring_free_(&semiring) ; \
27 GB_mx_put_global (true) ; \
28 }
29
30 GrB_Matrix A = NULL, B = NULL, C = NULL, Aconj = NULL, Mask = NULL ;
31 GrB_Monoid add = NULL ;
32 GrB_Semiring semiring = NULL ;
33 GrB_Info adotb_complex (GB_Context Context) ;
34 GrB_Info adotb (GB_Context Context) ;
35 GrB_Index anrows, ancols, bnrows, bncols, mnrows, mncols ;
36 bool flipxy = false ;
37 struct GB_Matrix_opaque C_header ;
38
39 //------------------------------------------------------------------------------
40
adotb_complex(GB_Context Context)41 GrB_Info adotb_complex (GB_Context Context)
42 {
43 GrB_Info info = GrB_Matrix_new (&Aconj, Complex, anrows, ancols) ;
44 if (info != GrB_SUCCESS) return (info) ;
45 info = GrB_Matrix_apply_(Aconj, NULL, NULL, Complex_conj, A, NULL) ;
46 if (info != GrB_SUCCESS)
47 {
48 GrB_Matrix_free_(&Aconj) ;
49 return (info) ;
50 }
51
52 // force completion
53 info = GrB_Matrix_wait_(&Aconj) ;
54 if (info != GrB_SUCCESS)
55 {
56 GrB_Matrix_free_(&Aconj) ;
57 return (info) ;
58 }
59
60 bool mask_applied = false ;
61
62 GrB_Semiring semiring = Complex_plus_times ;
63
64 if (Mask != NULL)
65 {
66 // C<M> = A'*B using dot product method
67 info = GB_AxB_dot3 (C, Mask, false, Aconj, B, semiring, flipxy,
68 Context) ;
69 mask_applied = true ;
70 }
71 else
72 {
73 // C = A'*B using dot product method
74 mask_applied = false ; // no mask to apply
75 info = GB_AxB_dot2 (C, NULL, false, false, Aconj, B, semiring, flipxy,
76 Context) ;
77 }
78
79 GrB_Matrix_free_(&Aconj) ;
80 return (info) ;
81 }
82
83 //------------------------------------------------------------------------------
84
adotb(GB_Context Context)85 GrB_Info adotb (GB_Context Context)
86 {
87 // create the Semiring for regular z += x*y
88 GrB_Info info = GrB_Monoid_new_FP64_(&add, GrB_PLUS_FP64, (double) 0) ;
89 if (info != GrB_SUCCESS) return (info) ;
90 info = GrB_Semiring_new (&semiring, add, GrB_TIMES_FP64) ;
91 if (info != GrB_SUCCESS)
92 {
93 GrB_Monoid_free_(&add) ;
94 return (info) ;
95 }
96 // C = A'*B
97 bool mask_applied = false ;
98
99 if (Mask != NULL)
100 {
101 // C<M> = A'*B using dot product method
102 info = GB_AxB_dot3 (C, Mask, false, A, B,
103 semiring /* GxB_PLUS_TIMES_FP64 */,
104 flipxy, Context) ;
105 mask_applied = true ;
106 }
107 else
108 {
109 mask_applied = false ; // no mask to apply
110 info = GB_AxB_dot2 (C, NULL, false, false, A, B,
111 semiring /* GxB_PLUS_TIMES_FP64 */, flipxy, Context) ;
112 }
113
114 GrB_Monoid_free_(&add) ;
115 GrB_Semiring_free_(&semiring) ;
116 return (info) ;
117 }
118
119 //------------------------------------------------------------------------------
120
// mexFunction: C = GB_mex_AdotB (A, B, Mask, flipxy)
//
// Computes C = A'*B, or C<Mask> = A'*B when a non-empty Mask is given,
// using the dot product kernels.  It returns C to MATLAB.
//   - A, B and Mask must be in CSC format.
//   - Mask, if present, must be ncols(A)-by-ncols(B).
//   - The inner dimension nrows(A) == nrows(B) must be > 0.
//   - If A has the Complex type, adotb_complex is used; otherwise adotb.
void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    // malloc_debug is not used directly here; presumably the METHOD macro
    // (from GB_mex.h) reads it -- TODO confirm
    bool malloc_debug = GB_mx_get_global (true) ;

    GB_CONTEXT (USAGE) ;

    // check inputs
    if (nargout > 1 || nargin < 2 || nargin > 4)
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // no-op deep-copy hooks; presumably referenced by METHOD -- TODO confirm
    #define GET_DEEP_COPY ;
    #define FREE_DEEP_COPY ;

    GET_DEEP_COPY ;
    // get A and B (shallow copies)
    A = GB_mx_mxArray_to_Matrix (pargin [0], "A input", false, true) ;
    B = GB_mx_mxArray_to_Matrix (pargin [1], "B input", false, true) ;
    if (A == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }
    if (B == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("B failed") ;
    }

    GrB_Matrix_nrows (&anrows, A) ;
    GrB_Matrix_ncols (&ancols, A) ;

    GrB_Matrix_nrows (&bnrows, B) ;
    GrB_Matrix_ncols (&bncols, B) ;

    if (!A->is_csc || !B->is_csc)
    {
        FREE_ALL ;
        mexErrMsgTxt ("matrices must be CSC only") ;
    }

    // get Mask (shallow copy), if present
    if (nargin > 2)
    {
        Mask = GB_mx_mxArray_to_Matrix (pargin [2], "Mask input", false, false);

        // The conversion can return NULL.  A NULL Mask from an empty Mask
        // argument means "no mask"; otherwise the conversion failed.
        if (Mask == NULL && !mxIsEmpty (pargin [2]))
        {
            FREE_ALL ;
            mexErrMsgTxt ("Mask failed") ;
        }

        // only inspect the Mask when it exists (avoids a NULL dereference)
        if (Mask != NULL)
        {
            GrB_Matrix_nrows (&mnrows, Mask) ;
            GrB_Matrix_ncols (&mncols, Mask) ;

            if (!Mask->is_csc)
            {
                FREE_ALL ;
                mexErrMsgTxt ("matrices must be CSC only") ;
            }

            // C = A'*B is ancols-by-bncols, so the Mask must match that
            if (mnrows != ancols || mncols != bncols)
            {
                FREE_ALL ;
                mexErrMsgTxt ("mask wrong dimension") ;
            }
        }
    }

    if (anrows != bnrows)
    {
        FREE_ALL ;
        mexErrMsgTxt ("inner dimensions of A'*B do not match") ;
    }

    if (anrows == 0)
    {
        FREE_ALL ;
        mexErrMsgTxt ("inner dimensions of A'*B must be > 0") ;
    }

    // get flipxy
    GET_SCALAR (3, bool, flipxy, false) ;

    // C uses a static header that lives in this stack frame; C is freed
    // (by GB_mx_Matrix_to_mxArray / FREE_ALL) before this function returns
    struct GB_Matrix_opaque C_header ;
    C = GB_clear_static_header (&C_header) ;

    if (A->type == Complex)
    {
        // C = A'*B, complex case
        METHOD (adotb_complex (Context)) ;
    }
    else
    {
        METHOD (adotb (Context)) ;
    }

    // return C to MATLAB
    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C AdotB result", false) ;

    FREE_ALL ;
}
224
225