1 //------------------------------------------------------------------------------
2 // GB_mex_op: apply a built-in GraphBLAS operator to MATLAB arrays
3 //------------------------------------------------------------------------------
4
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7
8 //------------------------------------------------------------------------------
9
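// Usage:

// Z = GB_mex_op (op, X, Y)
// Z = GB_mex_op (op, X)
// Z = GB_mex_op (op, X, Y, cover)

// When the optional fourth argument (cover) is present, do_cover is set and
// passed to GB_mx_get_global and GB_mx_put_global.  In that form, a char Y is
// a placeholder: the op is treated as unary and Y is ignored.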
14
// Apply a built-in GraphBLAS operator or a user-defined Complex operator to
// one or two arrays X and Y of any MATLAB logical or numeric type (Y must have
// the same dimensions as X).  X and Y are first typecast to the x and y
// operand types of the op.  The output Z has the same dimensions as X and the
// z type of the op.  Positional operators are not supported.
19
20 #include "GB_mex.h"
21
// Usage string reported through mexErrMsgTxt when the arguments are invalid.
#define USAGE "Z = GB_mex_op (opname, X, Y, cover)"

// FREE_ALL: call GB_mx_put_global, the counterpart of the GB_mx_get_global
// call made on entry to mexFunction.  It must be expanded inside mexFunction,
// since it reads that function's local variable do_cover.  The body is wrapped
// in do { } while (0) so that "FREE_ALL ;" is exactly one statement, which
// keeps it safe as the unbraced body of an if/else.
#define FREE_ALL                                \
do                                              \
{                                               \
    GB_mx_put_global (do_cover) ;               \
}                                               \
while (0)
28
//------------------------------------------------------------------------------
// mexFunction: MATLAB gateway for GB_mex_op
//------------------------------------------------------------------------------

// Inputs:
//  pargin [0]  operator name; looked up as a binary op when there are 3 or 4
//              inputs (and Y is not a char placeholder), else as a unary op
//  pargin [1]  X, a MATLAB logical or numeric array
//  pargin [2]  Y, same dimensions as X (binary case); with 4 inputs, a char
//              Y selects the unary case
//  pargin [3]  cover (optional); its presence sets do_cover, which is passed
//              to GB_mx_get_global and GB_mx_put_global
// Output:
//  pargout [0] Z, an nrows-by-ncols full array of the op's ztype, with
//              Z(k) = f (X(k)) or Z(k) = f (X(k), Y(k)) after typecasting
//              X(k) and Y(k) to the op's input types
//
// mexErrMsgTxt does not return, so every error path below runs FREE_ALL
// first; FREE_ALL calls GB_mx_put_global to pair with the GB_mx_get_global
// call made on entry.

void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    GB_void *X = NULL, *Y = NULL, *Z = NULL ;
    GrB_Type X_type = NULL, Y_type = NULL ;
    int64_t nrows = 0, ncols = 0, nx = 0, nrows2 = 0, ncols2 = 0 ;
    size_t Y_size = 1 ;

    bool do_cover = (nargin == 4) ;
    bool malloc_debug = GB_mx_get_global (do_cover) ;

    // With a cover argument present, a char Y is only a placeholder: treat
    // the call exactly as if nargin == 2 (unary op).
    if (do_cover && mxIsChar (pargin [2]))
    {
        nargin = 2 ;
    }

    //--------------------------------------------------------------------------
    // check inputs
    //--------------------------------------------------------------------------

    if (nargout > 1 || nargin < 2 || nargin > 4)
    {
        // GB_mx_get_global has already been called above, so pair it with
        // GB_mx_put_global here as well, like every other error path.
        FREE_ALL ;
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    //--------------------------------------------------------------------------
    // get op; default type is the same type as X
    //--------------------------------------------------------------------------

    GrB_UnaryOp op1 = NULL ;
    GrB_BinaryOp op2 = NULL ;
    GrB_Type op_ztype = NULL, op_xtype = NULL, op_ytype = NULL ;
    size_t op_zsize = 0, op_xsize = 0, op_ysize = 1 ;
    GrB_Type xtype = GB_mx_Type (pargin [1]) ;

    // user_complex is true only when Complex is not the built-in GxB_FC64
    // and X or Y is complex; it is passed on to the op lookup below.
    bool XisComplex = mxIsComplex (pargin [1]) ;
    bool YisComplex = (nargin > 2) ? mxIsComplex (pargin [2]) : false ;
    bool user_complex = (Complex != GxB_FC64) && (XisComplex || YisComplex) ;

    if (nargin > 2)
    {
        // get a binary op
        if (!GB_mx_mxArray_to_BinaryOp (&op2, pargin [0], "GB_mex_op",
            xtype, user_complex) || op2 == NULL)
        {
            FREE_ALL ;
            mexErrMsgTxt ("binary op missing") ;
        }
        // Validate the op and reject positional ops before reading its type
        // fields, so an unsupported op is reported rather than dereferenced.
        ASSERT_BINARYOP_OK (op2, "binary op", GB0) ;
        if (GB_OP_IS_POSITIONAL (op2))
        {
            FREE_ALL ;
            mexErrMsgTxt ("binary positional op not supported") ;
        }
        op_ztype = op2->ztype ; op_zsize = op_ztype->size ;
        op_xtype = op2->xtype ; op_xsize = op_xtype->size ;
        op_ytype = op2->ytype ; op_ysize = op_ytype->size ;
    }
    else
    {
        // get a unary op
        if (!GB_mx_mxArray_to_UnaryOp (&op1, pargin [0], "GB_mex_op",
            xtype, user_complex) || op1 == NULL)
        {
            FREE_ALL ;
            mexErrMsgTxt ("unary op missing") ;
        }
        // Validate the op and reject positional ops before reading its type
        // fields, so an unsupported op is reported rather than dereferenced.
        ASSERT_UNARYOP_OK (op1, "unary op", GB0) ;
        if (GB_OP_IS_POSITIONAL (op1))
        {
            FREE_ALL ;
            mexErrMsgTxt ("unary positional op not supported") ;
        }
        op_ztype = op1->ztype ; op_zsize = op_ztype->size ;
        op_xtype = op1->xtype ; op_xsize = op_xtype->size ;
        op_ytype = NULL ;       op_ysize = 1 ;
    }

    ASSERT_TYPE_OK (op_ztype, "Z type", GB0) ;

    //--------------------------------------------------------------------------
    // get X
    //--------------------------------------------------------------------------

    GB_mx_mxArray_to_array (pargin [1], &X, &nrows, &ncols, &X_type) ;
    nx = nrows * ncols ;
    if (X_type == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("X must be numeric") ;
    }
    ASSERT_TYPE_OK (X_type, "X type", GB0) ;
    size_t X_size = X_type->size ;

    if (!GB_Type_compatible (op_xtype, X_type))
    {
        FREE_ALL ;
        mexErrMsgTxt ("op xtype not compatible with X") ;
    }

    //--------------------------------------------------------------------------
    // get Y (binary case only); it must have the same dimensions as X
    //--------------------------------------------------------------------------

    if (nargin > 2)
    {
        GB_mx_mxArray_to_array (pargin [2], &Y, &nrows2, &ncols2, &Y_type) ;
        if (nrows2 != nrows || ncols2 != ncols)
        {
            FREE_ALL ;
            mexErrMsgTxt ("X and Y must be the same size") ;
        }
        if (Y_type == NULL)
        {
            FREE_ALL ;
            mexErrMsgTxt ("Y must be numeric") ;
        }
        ASSERT_TYPE_OK (Y_type, "Y type", GB0) ;
        Y_size = Y_type->size ;

        if (!GB_Type_compatible (op_ytype, Y_type))
        {
            FREE_ALL ;
            mexErrMsgTxt ("op ytype not compatible with Y") ;
        }
    }

    //--------------------------------------------------------------------------
    // create Z of the same type as op_ztype
    //--------------------------------------------------------------------------

    pargout [0] = GB_mx_create_full (nrows, ncols, op_ztype) ;
    Z = mxGetData (pargout [0]) ;

    //--------------------------------------------------------------------------
    // get scalar workspace
    //--------------------------------------------------------------------------

    // One typecast entry of X and of Y (ywork is unused in the unary case,
    // where op_ysize is 1).
    // NOTE(review): these are char arrays, but the cast and op functions
    // presumably access them as typed scalars (double, int64_t, complex, ...);
    // a char array has no alignment guarantee for those types -- confirm
    // this is acceptable on every target platform.
    char xwork [GB_VLA (op_xsize)] ;
    char ywork [GB_VLA (op_ysize)] ;

    GB_cast_function cast_X = GB_cast_factory (op_xtype->code, X_type->code) ;

    //--------------------------------------------------------------------------
    // do the op, one entry at a time
    //--------------------------------------------------------------------------

    if (nargin > 2)
    {
        // Z = f (X,Y)
        GxB_binary_function f_binary = op2->function ;
        GB_cast_function cast_Y =
            GB_cast_factory (op_ytype->code, Y_type->code) ;
        for (int64_t k = 0 ; k < nx ; k++)
        {
            cast_X (xwork, X +(k*X_size), X_size) ;
            cast_Y (ywork, Y +(k*Y_size), Y_size) ;
            f_binary (Z +(k*op_zsize), xwork, ywork) ;
        }
    }
    else
    {
        // Z = f (X)
        GxB_unary_function f_unary = op1->function ;
        for (int64_t k = 0 ; k < nx ; k++)
        {
            cast_X (xwork, X +(k*X_size), X_size) ;
            f_unary (Z +(k*op_zsize), xwork) ;
        }
    }

    //--------------------------------------------------------------------------
    // free workspace and return to MATLAB
    //--------------------------------------------------------------------------

    FREE_ALL ;
}
223
224