1 //------------------------------------------------------------------------------
2 // GB_mex_mxm_flops: compute flops to do C=A*B, C<M>=A*B or C<!M>=A*B
3 //------------------------------------------------------------------------------
4 
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7 
8 //------------------------------------------------------------------------------
9 
10 #include "GB_mex.h"
11 
// USAGE: the MATLAB calling syntax of this mexFunction.  It is appended to the
// "Usage: " error message and is also handed to GB_CONTEXT.
#define USAGE "[bflops mwork] = GB_mex_mxm_flops (M, Mask_comp, A, B)"

// FREE_ALL: cleanup performed on every exit path of mexFunction.  It frees the
// three GrB_Matrix handles (M, A, and B) declared in mexFunction and then calls
// GB_mx_put_global (true).  The handles are initialized to NULL, so the macro
// may be used before all of them have been obtained.
#define FREE_ALL                            \
{                                           \
    GrB_Matrix_free_(&M) ;                  \
    GrB_Matrix_free_(&A) ;                  \
    GrB_Matrix_free_(&B) ;                  \
    GB_mx_put_global (true) ;               \
}
21 
//------------------------------------------------------------------------------
// mexFunction: MATLAB gateway for GB_mex_mxm_flops
//------------------------------------------------------------------------------

// MATLAB usage: [bflops mwork] = GB_mex_mxm_flops (M, Mask_comp, A, B)
//
// Inputs (pargin):
//      [0] M           mask matrix; if the conversion yields NULL and the
//                      input is empty, M stays NULL and is passed on as-is
//      [1] Mask_comp   scalar read via GET_SCALAR as a bool (last macro
//                      argument is 0)
//      [2] A           matrix; a NULL result is an error
//      [3] B           matrix; a NULL result is an error
//
// Outputs (pargout):
//      [0] bflops      1-by-(B->nvec+1) double row vector, a copy of Bflops
//      [1] mwork       double scalar holding Mwork; created only when the
//                      caller asks for a second output
//
// Errors are reported with mexErrMsgTxt, after FREE_ALL has released the
// matrices obtained so far.

void mexFunction
(
    int nargout,
    mxArray *pargout [ ],
    int nargin,
    const mxArray *pargin [ ]
)
{

    // malloc_debug is not referenced directly in this function.
    // NOTE(review): presumably kept because macros from GB_mex.h expect a
    // variable of this name -- confirm before removing.
    bool malloc_debug = GB_mx_get_global (true) ;
    GrB_Matrix A = NULL ;
    GrB_Matrix B = NULL ;
    GrB_Matrix M = NULL ;

    // check inputs: at most 2 outputs, exactly 4 inputs
    GB_CONTEXT (USAGE) ;
    if (nargout > 2 || nargin != 4)
    {
        mexErrMsgTxt ("Usage: " USAGE) ;
    }

    // get M (shallow copy).  A NULL result is only an error when the MATLAB
    // input is not empty; an empty input leaves M as NULL.
    M = GB_mx_mxArray_to_Matrix (pargin [0], "M", false, false) ;
    if (M == NULL && !mxIsEmpty (pargin [0]))
    {
        FREE_ALL ;
        mexErrMsgTxt ("M failed") ;
    }

    // get Mask_comp
    bool GET_SCALAR (1, bool, Mask_comp, 0) ;

    // get A (shallow copy)
    A = GB_mx_mxArray_to_Matrix (pargin [2], "A", false, true) ;
    if (A == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("A failed") ;
    }

    // get B (shallow copy)
    B = GB_mx_mxArray_to_Matrix (pargin [3], "B", false, true) ;
    if (B == NULL)
    {
        FREE_ALL ;
        mexErrMsgTxt ("B failed") ;
    }

    // allocate Bflops, of size B->nvec+1, and clear it (mxMalloc followed by
    // memset).  Inside a MEX function mxMalloc does not return NULL: on
    // failure MATLAB terminates the MEX call, so no NULL check is made here.
    int64_t bnvec = B->nvec ;
    size_t bfsize = (bnvec+1) * sizeof (int64_t) ;
    int64_t *Bflops = mxMalloc (bfsize) ;
    memset (Bflops, 0, bfsize) ;

    // compute the flop count; results are delivered through Mwork and Bflops.
    // NOTE(review): any status returned by GB_AxB_saxpy3_flopcount is ignored.
    int64_t Mwork = 0 ;

    GB_AxB_saxpy3_flopcount (&Mwork, Bflops, M, Mask_comp, A, B, Context) ;

    // return Bflops to MATLAB as a 1-by-(bnvec+1) row vector of doubles
    pargout [0] = mxCreateDoubleMatrix (1, bnvec+1, mxREAL) ;
    double *Bflops_matlab = mxGetPr (pargout [0]) ;
    for (int64_t kk = 0 ; kk <= bnvec ; kk++)
    {
        Bflops_matlab [kk] = (double) Bflops [kk] ;
    }

    // return Mwork only if a second output was requested: pargout [1] is not
    // guaranteed to exist when nargout is 0 or 1.
    if (nargout > 1)
    {
        pargout [1] = mxCreateDoubleScalar ((double) Mwork) ;
    }

    mxFree (Bflops) ;
    FREE_ALL ;
}
93 
94