1 //------------------------------------------------------------------------------
2 // GB_mex_AdotB: compute C=spones(Mask).*(A'*B)
3 //------------------------------------------------------------------------------
4 
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7 
8 //------------------------------------------------------------------------------
9 
10 // Returns a plain MATLAB sparse matrix, not a struct.  Only works in double
11 // and complex.  Input matrices must be MATLAB sparse matrices, or GraphBLAS
12 // structs in CSC format.
13 
14 #include "GB_mex.h"
15 
16 #define USAGE "C = GB_mex_AdotB (A,B,Mask,flipxy)"
17 
18 #define FREE_ALL                            \
19 {                                           \
20     GrB_Matrix_free_(&A) ;                  \
21     GrB_Matrix_free_(&Aconj) ;              \
22     GrB_Matrix_free_(&B) ;                  \
23     GrB_Matrix_free_(&C) ;                  \
24     GrB_Matrix_free_(&Mask) ;               \
25     GrB_Monoid_free_(&add) ;                \
26     GrB_Semiring_free_(&semiring) ;         \
27     GB_mx_put_global (true) ;               \
28 }
29 
30 GrB_Matrix A = NULL, B = NULL, C = NULL, Aconj = NULL, Mask = NULL ;
31 GrB_Monoid add = NULL ;
32 GrB_Semiring semiring = NULL ;
33 GrB_Info adotb_complex (GB_Context Context) ;
34 GrB_Info adotb (GB_Context Context) ;
35 GrB_Index anrows, ancols, bnrows, bncols, mnrows, mncols ;
36 bool flipxy = false ;
37 struct GB_Matrix_opaque C_header ;
38 
39 //------------------------------------------------------------------------------
40 
adotb_complex(GB_Context Context)41 GrB_Info adotb_complex (GB_Context Context)
42 {
43     GrB_Info info = GrB_Matrix_new (&Aconj, Complex, anrows, ancols) ;
44     if (info != GrB_SUCCESS) return (info) ;
45     info = GrB_Matrix_apply_(Aconj, NULL, NULL, Complex_conj, A, NULL) ;
46     if (info != GrB_SUCCESS)
47     {
48         GrB_Matrix_free_(&Aconj) ;
49         return (info) ;
50     }
51 
52     // force completion
53     info = GrB_Matrix_wait_(&Aconj) ;
54     if (info != GrB_SUCCESS)
55     {
56         GrB_Matrix_free_(&Aconj) ;
57         return (info) ;
58     }
59 
60     bool mask_applied = false ;
61 
62     GrB_Semiring semiring = Complex_plus_times ;
63 
64     if (Mask != NULL)
65     {
66         // C<M> = A'*B using dot product method
67         info = GB_AxB_dot3 (C, Mask, false, Aconj, B, semiring, flipxy,
68             Context) ;
69         mask_applied = true ;
70     }
71     else
72     {
73         // C = A'*B using dot product method
74         mask_applied = false ;  // no mask to apply
75         info = GB_AxB_dot2 (C, NULL, false, false, Aconj, B, semiring, flipxy,
76             Context) ;
77     }
78 
79     GrB_Matrix_free_(&Aconj) ;
80     return (info) ;
81 }
82 
83 //------------------------------------------------------------------------------
84 
adotb(GB_Context Context)85 GrB_Info adotb (GB_Context Context)
86 {
87     // create the Semiring for regular z += x*y
88     GrB_Info info = GrB_Monoid_new_FP64_(&add, GrB_PLUS_FP64, (double) 0) ;
89     if (info != GrB_SUCCESS) return (info) ;
90     info = GrB_Semiring_new (&semiring, add, GrB_TIMES_FP64) ;
91     if (info != GrB_SUCCESS)
92     {
93         GrB_Monoid_free_(&add) ;
94         return (info) ;
95     }
96     // C = A'*B
97     bool mask_applied = false ;
98 
99     if (Mask != NULL)
100     {
101         // C<M> = A'*B using dot product method
102         info = GB_AxB_dot3 (C, Mask, false, A, B,
103             semiring /* GxB_PLUS_TIMES_FP64 */,
104             flipxy, Context) ;
105         mask_applied = true ;
106     }
107     else
108     {
109         mask_applied = false ;  // no mask to apply
110         info = GB_AxB_dot2 (C, NULL, false, false, A, B,
111             semiring /* GxB_PLUS_TIMES_FP64 */, flipxy, Context) ;
112     }
113 
114     GrB_Monoid_free_(&add) ;
115     GrB_Semiring_free_(&semiring) ;
116     return (info) ;
117 }
118 
119 //------------------------------------------------------------------------------
120 
// mexFunction: C = GB_mex_AdotB (A, B, Mask, flipxy)
//
// pargin [0], pargin [1]: A and B (required).  Both must be stored by column
//      (is_csc), and nrows (A) must equal nrows (B) and be > 0.
// pargin [2]: Mask (optional).  If it converts to a matrix, that matrix must
//      be stored by column and be ncols(A)-by-ncols(B).
// pargin [3]: flipxy (optional bool, default false).
// pargout [0]: the result C, converted by GB_mx_Matrix_to_mxArray.
// Every error path calls FREE_ALL before mexErrMsgTxt.
void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    bool malloc_debug = GB_mx_get_global (true) ;

    GB_CONTEXT (USAGE) ;

    // check inputs
    if (nargout > 1 || nargin < 2 || nargin > 4)
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // No deep copies are made here.  These empty hooks are presumably
    // referenced by the METHOD macro from GB_mex.h (confirm there).
    #define GET_DEEP_COPY ;
    #define FREE_DEEP_COPY ;

    GET_DEEP_COPY ;
    // get A and B (shallow copies)
    A = GB_mx_mxArray_to_Matrix (pargin [0], "A input", false, true) ;
    B = GB_mx_mxArray_to_Matrix (pargin [1], "B input", false, true) ;
    if (A == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }
    if (B == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("B failed") ;
    }

    GrB_Matrix_nrows (&anrows, A) ;
    GrB_Matrix_ncols (&ancols, A) ;

    GrB_Matrix_nrows (&bnrows, B) ;
    GrB_Matrix_ncols (&bncols, B) ;

    if (!A->is_csc || !B->is_csc)
    {
        FREE_ALL ;
        mexErrMsgTxt ("matrices must be CSC only") ;
    }

    // get Mask (shallow copy)
    if (nargin > 2)
    {
        Mask = GB_mx_mxArray_to_Matrix (pargin [2], "Mask input", false, false);

        if (Mask == NULL)
        {
            // Mask must never be dereferenced when NULL.  A NULL result for an
            // empty input ([]) is treated as "no mask", because adotb and
            // adotb_complex both handle Mask == NULL.  A NULL result for any
            // other input is a conversion failure.
            if (!mxIsEmpty (pargin [2]))
            {
                FREE_ALL ;
                mexErrMsgTxt ("Mask failed") ;
            }
        }
        else
        {
            GrB_Matrix_nrows (&mnrows, Mask) ;
            GrB_Matrix_ncols (&mncols, Mask) ;

            if (!Mask->is_csc)
            {
                FREE_ALL ;
                mexErrMsgTxt ("matrices must be CSC only") ;
            }

            // C = A'*B is ancols-by-bncols, so the Mask must be too
            if (mnrows != ancols || mncols != bncols)
            {
                FREE_ALL ;
                mexErrMsgTxt ("mask wrong dimension") ;
            }
        }
    }

    if (anrows != bnrows)
    {
        FREE_ALL ;
        mexErrMsgTxt ("inner dimensions of A'*B do not match") ;
    }

    if (anrows == 0)
    {
        FREE_ALL ;
        mexErrMsgTxt ("inner dimensions of A'*B must be > 0") ;
    }

    // get flipxy
    GET_SCALAR (3, bool, flipxy, false) ;

    // C uses the file-scope static header C_header.  A local header here
    // would shadow it and leave the file-scope C pointing into this stack
    // frame.
    C = GB_clear_static_header (&C_header) ;

    if (A->type == Complex)
    {
        // C = A'*B, complex case
        METHOD (adotb_complex (Context)) ;
    }
    else
    {
        METHOD (adotb (Context)) ;
    }

    // return C to MATLAB
    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C AdotB result", false) ;

    FREE_ALL ;
}
224 
225