1 //------------------------------------------------------------------------------
2 // GB_mex_kron: C<Mask> = accum(C,kron(A,B))
3 //------------------------------------------------------------------------------
4 
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7 
8 //------------------------------------------------------------------------------
9 
10 #include "GB_mex.h"
11 
#define USAGE "C = GB_mex_kron (C, Mask, accum, mult, A, B, desc)"

/* FREE_ALL: release every GraphBLAS object this gateway may hold, then call */
/* GB_mx_put_global.  It is invoked on each error path and again at the end */
/* of mexFunction.  The handles are listed in the order of the USAGE */
/* arguments; each free acts on its own handle. */
#define FREE_ALL                        \
{                                       \
    /* matrices */                      \
    GrB_Matrix_free_(&C) ;              \
    GrB_Matrix_free_(&Mask) ;           \
    GrB_Matrix_free_(&A) ;              \
    GrB_Matrix_free_(&B) ;              \
    /* descriptor */                    \
    GrB_Descriptor_free_(&desc) ;       \
    /* must remain the last step */     \
    GB_mx_put_global (true) ;           \
}
23 
// mexFunction: gateway for C = GB_mex_kron (C, Mask, accum, mult, A, B, desc).
//
// Inputs (pargin):
//      [0] C       input/output matrix; a deep copy is made
//      [1] Mask    optional mask; may be empty
//      [2] accum   optional accumulator operator
//      [3] mult    multiply operator; required
//      [4] A       first input matrix (shallow copy)
//      [5] B       second input matrix (shallow copy)
//      [6] desc    optional descriptor (nargin may be 6 or 7)
// Output (pargout):
//      [0] C       the result, returned as a struct
//
// On any failure, FREE_ALL is invoked and the error is reported with
// mexErrMsgTxt.

void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    // NOTE(review): malloc_debug is not referenced directly below; it is
    // presumably consumed by the METHOD macro (defined outside this file), so
    // its name must not change -- confirm against GB_mex.h.
    bool malloc_debug = GB_mx_get_global (true) ;

    // every handle starts out NULL so that FREE_ALL is safe on any error path
    GrB_Matrix A = NULL ;
    GrB_Matrix B = NULL ;
    GrB_Matrix C = NULL ;
    GrB_Matrix Mask = NULL ;
    GrB_Descriptor desc = NULL ;
    GrB_BinaryOp mult = NULL ;
    GrB_BinaryOp accum = NULL ;

    // check inputs
    if (nargout > 1 || nargin < 6 || nargin > 7)
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // get C (make a deep copy)
    // NOTE(review): GET_DEEP_COPY and FREE_DEEP_COPY are presumably also used
    // by the METHOD macro to rebuild C between attempts -- confirm.
    #define GET_DEEP_COPY \
        C = GB_mx_mxArray_to_Matrix (pargin [0], "C input", true, true) ;
    #define FREE_DEEP_COPY GrB_Matrix_free_(&C) ;

    GET_DEEP_COPY ;
    if (C == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("C failed") ;
    }

    // get Mask (shallow copy); a NULL result is an error only if the caller
    // passed a non-empty Mask
    Mask = GB_mx_mxArray_to_Matrix (pargin [1], "Mask", false, false) ;
    if (Mask == NULL && !mxIsEmpty (pargin [1]))
    {
        FREE_ALL ;
        mexErrMsgTxt ("Mask failed") ;
    }

    // get A (shallow copy)
    A = GB_mx_mxArray_to_Matrix (pargin [4], "A input", false, true) ;
    if (A == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }

    // get B (shallow copy)
    B = GB_mx_mxArray_to_Matrix (pargin [5], "B input", false, true) ;
    if (B == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("B failed") ;
    }

    // get mult operator; it is required, so a NULL operator is an error
    // NOTE(review): the default type passed here is C->type while user_complex
    // is derived from A and B -- confirm that C->type is the intended default.
    bool user_complex = (Complex != GxB_FC64)
        && (A->type == Complex || B->type == Complex) ;
    if (!GB_mx_mxArray_to_BinaryOp (&mult, pargin [3], "mult",
        C->type, user_complex) || mult == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("mult failed") ;
    }

    // get accum, if present; unlike mult, a NULL accum is not an error.
    // mult is known to be non-NULL here, so mult->ztype is safe to read.
    user_complex = (Complex != GxB_FC64)
        && (C->type == Complex || mult->ztype == Complex) ;
    if (!GB_mx_mxArray_to_BinaryOp (&accum, pargin [2], "accum",
        C->type, user_complex))
    {
        FREE_ALL ;
        mexErrMsgTxt ("accum failed") ;
    }

    // get desc
    if (!GB_mx_mxArray_to_Descriptor (&desc, PARGIN (6), "desc"))
    {
        FREE_ALL ;
        mexErrMsgTxt ("desc failed") ;
    }

    // exercise all four entry points (monoid, semiring, and the binary op
    // under both its old and new names), selected by the mult operator
    if (mult == GrB_PLUS_FP64)
    {
        // C<Mask> = accum(C,kron(A,B)), monoid variant
        METHOD (GrB_Matrix_kronecker_Monoid_ (C, Mask, accum,
            GrB_PLUS_MONOID_FP64, A, B, desc)) ;
    }
    else if (mult == GrB_TIMES_FP64)
    {
        // C<Mask> = accum(C,kron(A,B)), semiring variant
        METHOD (GrB_Matrix_kronecker_Semiring_ (C, Mask, accum,
            GrB_PLUS_TIMES_SEMIRING_FP64, A, B, desc)) ;
    }
    else if (mult == GrB_TIMES_FP32)
    {
        // C<Mask> = accum(C,kron(A,B)), binary op variant (old name)
        METHOD (GxB_kron (C, Mask, accum, mult, A, B, desc)) ;
    }
    else
    {
        // C<Mask> = accum(C,kron(A,B)), binary op variant (new name)
        METHOD (GrB_Matrix_kronecker_BinaryOp_ (C, Mask, accum, mult, A, B,
            desc)) ;
    }

    // return C to MATLAB as a struct and free the GraphBLAS C
    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C output", true) ;

    FREE_ALL ;
}
140 
141