1 //------------------------------------------------------------------------------
2 // GB_mex_kron: C<Mask> = accum(C,kron(A,B))
3 //------------------------------------------------------------------------------
4
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7
8 //------------------------------------------------------------------------------
9
10 #include "GB_mex.h"
11
#define USAGE "C = GB_mex_kron (C, Mask, accum, mult, A, B, desc)"

// FREE_ALL: free the GraphBLAS objects held by mexFunction (the local handles
// A, B, C, desc and Mask must be in scope at the point of use), then call
// GB_mx_put_global (true), which pairs with the GB_mx_get_global (true) call at
// the top of mexFunction.  Used on mexFunction's error exits and at its normal
// exit.  The do { } while (0) wrapper makes "FREE_ALL ;" a single statement, so
// it is also safe as the body of an unbraced if/else.
#define FREE_ALL                        \
do                                      \
{                                       \
    GrB_Matrix_free_(&A) ;              \
    GrB_Matrix_free_(&B) ;              \
    GrB_Matrix_free_(&C) ;              \
    GrB_Descriptor_free_(&desc) ;       \
    GrB_Matrix_free_(&Mask) ;           \
    GB_mx_put_global (true) ;           \
} while (0)
23
//------------------------------------------------------------------------------
// mexFunction: C = GB_mex_kron (C, Mask, accum, mult, A, B, desc)
//------------------------------------------------------------------------------

// Test gateway for the GraphBLAS Kronecker product: computes
// C<Mask> = accum (C, kron (A,B)) and returns C to MATLAB as a struct.
//
// Inputs (pargin):
//      0: C      input matrix (a deep copy is made)
//      1: Mask   may be empty
//      2: accum  optional accumulator operator
//      3: mult   binary operator (required)
//      4: A, 5: B  input matrices (shallow copies)
//      6: desc   optional descriptor
//
// The mult operator also selects which API variant is exercised:
//      GrB_PLUS_FP64   GrB_Matrix_kronecker_Monoid_ with GrB_PLUS_MONOID_FP64
//      GrB_TIMES_FP64  GrB_Matrix_kronecker_Semiring_ with
//                      GrB_PLUS_TIMES_SEMIRING_FP64
//      GrB_TIMES_FP32  GxB_kron (older name of the binary-op variant)
//      any other op    GrB_Matrix_kronecker_BinaryOp_
//
// NOTE(review): METHOD is defined outside this file.  malloc_debug and
// FREE_DEEP_COPY are not referenced directly here, so METHOD presumably uses
// them (together with FREE_ALL and GET_DEEP_COPY) by name; keep these
// identifiers unchanged.

void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    bool malloc_debug = GB_mx_get_global (true) ;
    GrB_Matrix A = NULL ;
    GrB_Matrix B = NULL ;
    GrB_Matrix C = NULL ;
    GrB_Matrix Mask = NULL ;
    GrB_Descriptor desc = NULL ;
    GrB_BinaryOp mult = NULL ;
    GrB_BinaryOp accum = NULL ;     // optional; initialized like every handle

    // check inputs
    if (nargout > 1 || nargin < 6 || nargin > 7)
    {
        // Every other error exit runs FREE_ALL, which ends with
        // GB_mx_put_global (true) to pair with GB_mx_get_global (true) above.
        // Do the same here; all handles are still NULL, so only that call
        // has any effect.
        FREE_ALL ;
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // get C (make a deep copy).  These macros are defined here, presumably so
    // that METHOD can free and rebuild C between malloc-debug retries.
    #define GET_DEEP_COPY \
    C = GB_mx_mxArray_to_Matrix (pargin [0], "C input", true, true) ;
    #define FREE_DEEP_COPY GrB_Matrix_free_(&C) ;

    GET_DEEP_COPY ;
    if (C == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("C failed") ;
    }

    // get Mask (shallow copy); an empty MATLAB argument means "no mask"
    Mask = GB_mx_mxArray_to_Matrix (pargin [1], "Mask", false, false) ;
    if (Mask == NULL && !mxIsEmpty (pargin [1]))
    {
        FREE_ALL ;
        mexErrMsgTxt ("Mask failed") ;
    }

    // get A (shallow copy)
    A = GB_mx_mxArray_to_Matrix (pargin [4], "A input", false, true) ;
    if (A == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }

    // get B (shallow copy)
    B = GB_mx_mxArray_to_Matrix (pargin [5], "B input", false, true) ;
    if (B == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("B failed") ;
    }

    // get mult operator (required, so a NULL result is an error).  The
    // user-defined Complex type is only relevant when it differs from the
    // built-in GxB_FC64 and one of the inputs uses it.
    bool user_complex = (Complex != GxB_FC64)
        && (A->type == Complex || B->type == Complex) ;
    if (!GB_mx_mxArray_to_BinaryOp (&mult, pargin [3], "mult",
        C->type, user_complex) || mult == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("mult failed") ;
    }

    // get accum, if present (NULL is allowed)
    user_complex = (Complex != GxB_FC64)
        && (C->type == Complex || mult->ztype == Complex) ;
    if (!GB_mx_mxArray_to_BinaryOp (&accum, pargin [2], "accum",
        C->type, user_complex))
    {
        FREE_ALL ;
        mexErrMsgTxt ("accum failed") ;
    }

    // get desc
    if (!GB_mx_mxArray_to_Descriptor (&desc, PARGIN (6), "desc"))
    {
        FREE_ALL ;
        mexErrMsgTxt ("desc failed") ;
    }

    // test all 3 variants (monoid, semiring, and binary op); the binary-op
    // variant is tested under both its old (GxB) and new (GrB) names
    if (mult == GrB_PLUS_FP64)
    {
        // C<Mask> = accum(C,kron(A,B)), monoid variant
        METHOD (GrB_Matrix_kronecker_Monoid_ (C, Mask, accum,
            GrB_PLUS_MONOID_FP64, A, B, desc)) ;
    }
    else if (mult == GrB_TIMES_FP64)
    {
        // C<Mask> = accum(C,kron(A,B)), semiring variant
        METHOD (GrB_Matrix_kronecker_Semiring_ (C, Mask, accum,
            GrB_PLUS_TIMES_SEMIRING_FP64, A, B, desc)) ;
    }
    else if (mult == GrB_TIMES_FP32)
    {
        // C<Mask> = accum(C,kron(A,B)), binary op variant (old name)
        METHOD (GxB_kron (C, Mask, accum, mult, A, B, desc)) ;
    }
    else
    {
        // C<Mask> = accum(C,kron(A,B)), binary op variant (new name)
        METHOD (GrB_Matrix_kronecker_BinaryOp_ (C, Mask, accum, mult,
            A, B, desc)) ;
    }

    // return C to MATLAB as a struct and free the GraphBLAS C
    pargout [0] = GB_mx_Matrix_to_mxArray (&C, "C output", true) ;

    FREE_ALL ;
}
140
141