1 //------------------------------------------------------------------------------
2 // GB_mex_mxm_flops: compute flops to do C=A*B, C<M>=A*B or C<!M>=A*B
3 //------------------------------------------------------------------------------
4
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7
8 //------------------------------------------------------------------------------
9
10 #include "GB_mex.h"
11
// Calling sequence of this mexFunction; it is handed to GB_CONTEXT and
// appended to the "Usage: " error raised when the argument counts are wrong.
#define USAGE "[bflops mwork] = GB_mex_mxm_flops (M, Mask_comp, A, B)"
13
// FREE_ALL: release the three GrB_Matrix objects held by mexFunction (M, A,
// and B), then call GB_mx_put_global (true).  Each matrix handle is passed by
// address to GrB_Matrix_free_.  Expanded on every error path and at the end
// of a successful call.
#define FREE_ALL                        \
{                                       \
    GrB_Matrix_free_(&M) ;              \
    GrB_Matrix_free_(&A) ;              \
    GrB_Matrix_free_(&B) ;              \
    GB_mx_put_global (true) ;           \
}
21
//------------------------------------------------------------------------------
// mexFunction: MATLAB gateway for GB_AxB_saxpy3_flopcount
//------------------------------------------------------------------------------

// Usage: [bflops mwork] = GB_mex_mxm_flops (M, Mask_comp, A, B)
//
// Inputs (pargin):
//      [0] M           mask matrix; may be empty, in which case M stays NULL
//      [1] Mask_comp   scalar read as bool via GET_SCALAR
//      [2] A           first input matrix (required)
//      [3] B           second input matrix (required)
//
// Outputs (pargout):
//      [0] bflops      1-by-(B->nvec+1) double row vector holding the Bflops
//                      array filled in by GB_AxB_saxpy3_flopcount
//      [1] mwork       double scalar holding Mwork; only created when the
//                      caller asks for a second output
//
// Errors are reported through mexErrMsgTxt after FREE_ALL has run.

void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    // NOTE(review): malloc_debug is not referenced directly below; it is kept
    // because helper macros from GB_mex.h may expect it to be in scope.
    bool malloc_debug = GB_mx_get_global (true) ;
    GrB_Matrix A = NULL ;
    GrB_Matrix B = NULL ;
    GrB_Matrix M = NULL ;

    // check inputs
    GB_CONTEXT (USAGE) ;
    if (nargout > 2 || nargin != 4)
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // get M (shallow copy).  A NULL result is only an error when the MATLAB
    // argument is non-empty; an empty argument leaves M as NULL and that NULL
    // is passed on to the flop counter.
    M = GB_mx_mxArray_to_Matrix (pargin [0], "M", false, false) ;
    if (M == NULL && !mxIsEmpty (pargin [0]))
    {
        FREE_ALL ;
        mexErrMsgTxt ("M failed") ;
    }

    // get Mask_comp
    bool GET_SCALAR (1, bool, Mask_comp, 0) ;

    // get A (shallow copy)
    A = GB_mx_mxArray_to_Matrix (pargin [2], "A", false, true) ;
    if (A == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }

    // get B (shallow copy)
    B = GB_mx_mxArray_to_Matrix (pargin [3], "B", false, true) ;
    if (B == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("B failed") ;
    }

    // allocate Bflops with bnvec+1 entries, all zero (mxCalloc zero-fills)
    int64_t bnvec = B->nvec ;
    int64_t *Bflops = mxCalloc ((size_t) (bnvec+1), sizeof (int64_t)) ;

    // compute the flop count
    int64_t Mwork = 0 ;

    // NOTE(review): assumes GB_AxB_saxpy3_flopcount returns a GrB_Info status;
    // confirm against its prototype.
    GrB_Info flop_info = GB_AxB_saxpy3_flopcount (&Mwork, Bflops, M,
        Mask_comp, A, B, Context) ;
    if (flop_info != GrB_SUCCESS)
    {
        // do not hand partially computed counts back to MATLAB
        mxFree (Bflops) ;
        FREE_ALL ;
        mexErrMsgTxt ("flopcount failed") ;
    }

    // return result to MATLAB: Bflops as a 1-by-(bnvec+1) row of doubles
    pargout [0] = mxCreateDoubleMatrix (1, bnvec+1, mxREAL) ;
    double *Bflops_matlab = mxGetPr (pargout [0]) ;
    for (int64_t kk = 0 ; kk <= bnvec ; kk++)
    {
        Bflops_matlab [kk] = (double) Bflops [kk] ;
    }

    // pargout [1] may only be written when a second output was requested
    if (nargout > 1)
    {
        pargout [1] = mxCreateDoubleScalar ((double) Mwork) ;
    }

    mxFree (Bflops) ;
    FREE_ALL ;
}
93
94