//------------------------------------------------------------------------------
// GB_AxB_saxpy3_flopcount:  compute flops for GB_AxB_saxpy3
//------------------------------------------------------------------------------

// SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

//------------------------------------------------------------------------------

// On input, A, B, and M (optional) are matrices for C=A*B, C<M>=A*B, or
// C<!M>=A*B.  The flop count for each B(:,j) is computed, and returned as a
// cumulative sum.  This function is CSR/CSC agnostic, but for simplicity of
// this description, assume A and B are both CSC matrices, so that ncols(A) ==
// nrows(B).  For both CSR and CSC, A->vdim == B->vlen holds.  A and/or B may
// be hypersparse, in any combination.

// Bflops has size (B->nvec)+1, for both standard and hypersparse B.  Let
// n=B->vdim be the column dimension of B (that is, B is m-by-n).

// If B is a standard CSC matrix then Bflops has size n+1 == B->nvec+1, and on
// output, Bflops [j] is the # of flops required to compute C (:, 0:j-1).  B->h
// is NULL, and is implicitly the vector 0:(n-1).
// If B is hypersparse, then let Bh = B->h.  Its size is B->nvec, and j = Bh
// [kk] is the (kk)th column in the data structure for B.  C will also be
// hypersparse, and only C(:,Bh) will be computed (C may have fewer non-empty
// columns than B).  On output, Bflops [kk] is the number of flops needed to
// compute C (:, Bh [0:kk-1]).

// In both cases, Bflops [0] = 0, and Bflops [B->nvec] = total number of flops.
// The size of Bflops is B->nvec+1 so that it has the same size as B->p.  The
// first entry of B->p and Bflops are both zero.  This allows B to be sliced
// either by # of entries in B (by slicing B->p) or by the flop count required
// (by slicing Bflops).
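
// For example (a small sketch with made-up counts): if B is a standard CSC
// matrix with n = 3 vectors, and computing C(:,0), C(:,1), and C(:,2) takes
// 4, 0, and 6 flops respectively, then on output Bflops = [0 4 4 10], and
// Bflops [3] = 10 is the total flop count.  A scheduler that wants to split
// the work roughly in half can binary search Bflops for 10/2 = 5, landing on
// the boundary between vectors 1 and 2.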

// This algorithm does not look at the values of M, A, or B, just their
// patterns.  The flop count of C=A*B, C<M>=A*B, or C<!M>=A*B is computed for a
// saxpy-based method; the work for A'*B for the dot product method is not
// computed.

// The algorithm scans all nonzeros in B.  It only scans at most the min and
// max (first and last) row indices in A and M (if M is present).  If A and M
// are not hypersparse, the time taken is O(nnz(B)+n).  If all matrices are
// hypersparse, the time is O(nnz(B)*log(h)) where h = max # of vectors present
// in A and M.  In pseudo-MATLAB, and assuming B is in standard (not
// hypersparse) form:

/*
    [m n] = size (B) ;
    Bflops = zeros (1,n+1) ;        % (set to zero in the caller)
    Mwork = 0 ;
    for each column j in B:
        if (B (:,j) is empty) continue ;
        mjnz = nnz (M (:,j))
        if (M is present, not complemented, and M (:,j) is empty) continue ;
        Bflops (j) = mjnz if M present and not dense, to scatter M(:,j)
        Mwork += mjnz
        for each k where B (k,j) is nonzero:
            aknz = nnz (A (:,k))
            if (aknz == 0) continue ;
            % numerical phase will compute: C(:,j)<#M(:,j)> += A(:,k)*B(k,j)
            % where #M is no mask, M, or !M.  This typically takes aknz flops,
            % or fewer, via binary search, if nnz(M(:,j)) << nnz(A(:,k)).
            Bflops (j) += aknz
        end
    end
*/

#include "GB_mxm.h"
#include "GB_ek_slice.h"
#include "GB_bracket.h"
#include "GB_AxB_saxpy3.h"

#define GB_FREE_ALL                         \
{                                           \
    GB_WERK_POP (Work, int64_t) ;           \
    GB_WERK_POP (B_ek_slicing, int64_t) ;   \
}

GB_PUBLIC   // accessed by the MATLAB tests in GraphBLAS/Test only
GrB_Info GB_AxB_saxpy3_flopcount
(
    int64_t *Mwork,             // amount of work to handle the mask M
    int64_t *Bflops,            // size B->nvec+1
    const GrB_Matrix M,         // optional mask matrix
    const bool Mask_comp,       // if true, mask is complemented
    const GrB_Matrix A,
    const GrB_Matrix B,
    GB_Context Context
)
{

    //--------------------------------------------------------------------------
    // check inputs
    //--------------------------------------------------------------------------

    ASSERT_MATRIX_OK_OR_NULL (M, "M for flop count A*B", GB0) ;
    ASSERT (!GB_ZOMBIES (M)) ;
    ASSERT (GB_JUMBLED_OK (M)) ;
    ASSERT (!GB_PENDING (M)) ;

    ASSERT_MATRIX_OK (A, "A for flop count A*B", GB0) ;
    ASSERT (!GB_ZOMBIES (A)) ;
    ASSERT (GB_JUMBLED_OK (A)) ;
    ASSERT (!GB_PENDING (A)) ;

    ASSERT_MATRIX_OK (B, "B for flop count A*B", GB0) ;
    ASSERT (!GB_ZOMBIES (B)) ;
    ASSERT (GB_JUMBLED_OK (B)) ;
    ASSERT (!GB_PENDING (B)) ;

    ASSERT (A->vdim == B->vlen) ;
    ASSERT (Bflops != NULL) ;
    ASSERT (Mwork != NULL) ;

    //--------------------------------------------------------------------------
    // determine the number of threads to use
    //--------------------------------------------------------------------------

    int64_t bnvec = B->nvec ;

    GB_GET_NTHREADS_MAX (nthreads_max, chunk, Context) ;

    // clear Bflops
    GB_memset (Bflops, 0, (bnvec+1) * sizeof (int64_t), nthreads_max) ;

    //--------------------------------------------------------------------------
    // get the mask, if present: any sparsity structure
    //--------------------------------------------------------------------------

    bool mask_is_M = (M != NULL && !Mask_comp) ;
    const int64_t *restrict Mp = NULL ;
    const int64_t *restrict Mh = NULL ;
    int64_t mnvec = 0 ;
    int64_t mvlen = 0 ;
    bool M_is_hyper = GB_IS_HYPERSPARSE (M) ;
    bool M_is_dense = false ;
    if (M != NULL)
    {
        Mh = M->h ;
        Mp = M->p ;
        mnvec = M->nvec ;
        mvlen = M->vlen ;
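        // M is "dense" here if all of its entries are present: full, bitmap,
        // or sparse/hypersparse with every entry present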
        M_is_dense = GB_is_packed (M) ;
    }

    //--------------------------------------------------------------------------
    // get A and B: any sparsity structure
    //--------------------------------------------------------------------------

    const int64_t *restrict Ap = A->p ;
    const int64_t *restrict Ah = A->h ;
    const int64_t anvec = A->nvec ;
    const int64_t avlen = A->vlen ;
    const bool A_is_hyper = GB_IS_HYPERSPARSE (A) ;

    const int64_t *restrict Bp = B->p ;
    const int64_t *restrict Bh = B->h ;
    const int8_t  *restrict Bb = B->b ;
    const int64_t *restrict Bi = B->i ;
    const bool B_is_hyper = GB_IS_HYPERSPARSE (B) ;
    const bool B_is_bitmap = GB_IS_BITMAP (B) ;
    const bool B_is_sparse_or_hyper = B_is_hyper || GB_IS_SPARSE (B) ;
    const int64_t bvlen = B->vlen ;
    const bool B_jumbled = B->jumbled ;

    //--------------------------------------------------------------------------
    // declare workspace
    //--------------------------------------------------------------------------

    GB_WERK_DECLARE (Work, int64_t) ;
    GB_WERK_DECLARE (B_ek_slicing, int64_t) ;
    int64_t *restrict Wfirst = NULL ;
    int64_t *restrict Wlast  = NULL ;

    //--------------------------------------------------------------------------
    // construct the parallel tasks
    //--------------------------------------------------------------------------

    int B_ntasks, B_nthreads ;
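    // slice the entries of B into B_ntasks tasks of roughly equal size; this
    // defines B_ntasks, B_nthreads, and the B_ek_slicing workspace holding
    // kfirst_Bslice, klast_Bslice, and pstart_Bslice, used below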
    GB_SLICE_MATRIX (B, 64, chunk) ;

    //--------------------------------------------------------------------------
    // allocate workspace
    //--------------------------------------------------------------------------

    GB_WERK_PUSH (Work, 2*B_ntasks, int64_t) ;
    if (Work == NULL)
    {
        // out of memory
        GB_FREE_ALL ;
        return (GrB_OUT_OF_MEMORY) ;
    }
    Wfirst = Work ;
    Wlast  = Work + B_ntasks ;
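
    // Wfirst [taskid] and Wlast [taskid] hold the partial flop counts for the
    // first and last vectors of each task, since those vectors may be shared
    // with neighboring tasks; they are summed into Bflops after the parallel
    // region, below.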

    //--------------------------------------------------------------------------
    // compute flop counts for C=A*B, C<M>=A*B, or C<!M>=A*B
    //--------------------------------------------------------------------------

    int64_t total_Mwork = 0 ;
    int taskid ;
    #pragma omp parallel for num_threads(B_nthreads) schedule(dynamic,1) \
        reduction(+:total_Mwork)
    for (taskid = 0 ; taskid < B_ntasks ; taskid++)
    {

        //----------------------------------------------------------------------
        // get the task descriptor
        //----------------------------------------------------------------------

        int64_t kfirst = kfirst_Bslice [taskid] ;
        int64_t klast  = klast_Bslice  [taskid] ;
        Wfirst [taskid] = 0 ;
        Wlast  [taskid] = 0 ;
        int64_t mpleft = 0 ;     // for GB_lookup of the mask M
        int64_t task_Mwork = 0 ;

        //----------------------------------------------------------------------
        // count flops for vectors kfirst to klast of B
        //----------------------------------------------------------------------

        for (int64_t kk = kfirst ; kk <= klast ; kk++)
        {

            // nnz (B (:,j)), for all tasks; if Bp is NULL then B is bitmap
            // or full, so B(:,j) has all bvlen entry slots
            int64_t bjnz = (Bp == NULL) ? bvlen : (Bp [kk+1] - Bp [kk]) ;
            // C(:,j) is empty if the entire vector B(:,j) is empty
            if (bjnz == 0) continue ;

            //------------------------------------------------------------------
            // find the part of B(:,j) to be computed by this task
            //------------------------------------------------------------------

            int64_t pB, pB_end ;
            GB_get_pA (&pB, &pB_end, taskid, kk,
                kfirst, klast, pstart_Bslice, Bp, bvlen) ;
            int64_t my_bjnz = pB_end - pB ;
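            // j = Bh [kk] if B is hypersparse, or j = kk otherwise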
            int64_t j = GBH (Bh, kk) ;

            //------------------------------------------------------------------
            // see if M(:,j) is present and non-empty
            //------------------------------------------------------------------

            // if M(:,j) is full, bitmap, or dense, do not add mjnz to bjflops
            // or task_Mwork.

            int64_t bjflops = (B_is_bitmap) ? my_bjnz : 0 ;
            int64_t mjnz = 0 ;
            if (M != NULL && !M_is_dense)
            {
                int64_t mpright = mnvec - 1 ;
                int64_t pM, pM_end ;
                GB_lookup (M_is_hyper, Mh, Mp, mvlen, &mpleft, mpright, j,
                    &pM, &pM_end) ;
                mjnz = pM_end - pM ;
                // If M not complemented: C(:,j) is empty if M(:,j) is empty.
                if (mjnz == 0 && !Mask_comp) continue ;
                if (mjnz > 0)
                {
                    // M(:,j) not empty
                    if (pB == GBP (Bp, kk, bvlen))
                    {
                        // this task owns the top part of B(:,j), so it can
                        // account for the work to access M(:,j), without the
                        // work being duplicated by other tasks working on
                        // B(:,j)
                        bjflops = mjnz ;
                        // keep track of total work spent examining the mask.
                        // If any B(:,j) is empty, M(:,j) can be ignored.  So
                        // total_Mwork will be <= nnz (M).
                        task_Mwork += mjnz ;
                    }
                }
            }
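
            // mjnz_much is a heuristic threshold, used below: if nnz(M(:,j))
            // is more than 64 times smaller than nnz(A(:,k)), it is likely
            // cheaper to scan M(:,j) and binary search A(:,k) than to scan
            // all of A(:,k)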
            int64_t mjnz_much = 64 * mjnz ;

            //------------------------------------------------------------------
            // trim Ah on right
            //------------------------------------------------------------------

            // Ah [0..A->nvec-1] holds the set of non-empty vectors of A, but
            // only vectors k corresponding to nonzero entries B(k,j) are
            // accessed for this vector B(:,j).  If nnz (B(:,j)) > 2, prune the
            // search space on the right, so the remaining calls to GB_lookup
            // will only need to search Ah [pleft...pright-1].  pright does not
            // change.  pleft is advanced as B(:,j) is traversed, since the
            // indices in B(:,j) are sorted in ascending order.

            int64_t pleft = 0 ;
            int64_t pright = anvec-1 ;
            if (A_is_hyper && B_is_sparse_or_hyper && my_bjnz > 2 && !B_jumbled)
            {
                // trim Ah [0..pright] to remove any entries past last B(:,j)
                int64_t ilast = Bi [pB_end-1] ;
                GB_bracket_right (ilast, Ah, 0, &pright) ;
            }

            //------------------------------------------------------------------
            // count the flops to compute C(:,j)<#M(:,j)> = A*B(:,j)
            //------------------------------------------------------------------

            // where #M is either not present, M, or !M

            for ( ; pB < pB_end ; pB++)
            {
                // get B(k,j): GBI (Bi, pB, bvlen) is Bi [pB], or (pB % bvlen)
                // if B is bitmap or full
                int64_t k = GBI (Bi, pB, bvlen) ;
                // GBB (Bb, pB) is false only if B is bitmap and the entry is
                // not present
                if (!GBB (Bb, pB)) continue ;

                // B(k,j) is nonzero

                // find A(:,k), reusing pleft if B is not jumbled
                if (B_jumbled)
                {
                    pleft = 0 ;
                }
                int64_t pA, pA_end ;
                GB_lookup (A_is_hyper, Ah, Ap, avlen, &pleft, pright, k,
                    &pA, &pA_end) ;

                // skip if A(:,k) empty
                int64_t aknz = pA_end - pA ;
                if (aknz == 0) continue ;

                double bkjflops ;

                // skip if intersection of A(:,k) and M(:,j) is empty
                // and mask is not complemented (C<M>=A*B)
                if (mask_is_M)
                {
                    // A(:,k) is non-empty; get first and last index of A(:,k)
                    if (aknz > 256 && mjnz_much < aknz && mjnz < mvlen &&
                        aknz < avlen && !(A->jumbled))
                    {
                        // scan M(:,j), and for each entry M(i,j) do a binary
                        // search for row i in A(:,k): about log2(aknz) work
                        // per mask entry, with a fudge factor of 4
                        bkjflops = mjnz * (1 + 4 * log2 ((double) aknz)) ;
                    }
                    else
                    {
                        // scan A(:,k), and lookup M(i,j)
                        bkjflops = aknz ;
                    }
                }
                else
                {
                    // A(:,k)*B(k,j) requires aknz flops
                    bkjflops = aknz ;
                }

                // increment by flops for the single entry B(k,j)
                // C(:,j)<#M(:,j)> += A(:,k)*B(k,j).
                bjflops += bkjflops ;
            }

            //------------------------------------------------------------------
            // log the flops for B(:,j)
            //------------------------------------------------------------------

            if (kk == kfirst)
            {
                Wfirst [taskid] = bjflops ;
            }
            else if (kk == klast)
            {
                Wlast [taskid] = bjflops ;
            }
            else
            {
                Bflops [kk] = bjflops ;
            }
        }

        // compute the total work to access the mask, which is <= nnz (M)
        total_Mwork += task_Mwork ;
    }

    //--------------------------------------------------------------------------
    // reduce the first and last vector of each slice
    //--------------------------------------------------------------------------

    // See also Template/GB_select_phase1.c
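
    // For example (hypothetical task boundaries): if task 2 ends inside
    // B(:,7) and task 3 starts inside it, then task 2's portion of the flops
    // for B(:,7) is in Wlast [2] and task 3's portion is in Wfirst [3] ;
    // the loop below sums both into Bflops [7].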

    int64_t kprior = -1 ;

    for (int taskid = 0 ; taskid < B_ntasks ; taskid++)
    {

        //----------------------------------------------------------------------
        // sum up the partial flops that taskid computed for kfirst
        //----------------------------------------------------------------------

        int64_t kfirst = kfirst_Bslice [taskid] ;
        int64_t klast  = klast_Bslice  [taskid] ;

        if (kfirst <= klast)
        {
            int64_t pB = pstart_Bslice [taskid] ;
            int64_t pB_end = GBP (Bp, kfirst+1, bvlen) ;
            pB_end = GB_IMIN (pB_end, pstart_Bslice [taskid+1]) ;
            if (pB < pB_end)
            {
                if (kprior < kfirst)
                {
                    // This task is the first one that did work on
                    // B(:,kfirst), so use it to start the reduction.
                    Bflops [kfirst] = Wfirst [taskid] ;
                }
                else
                {
                    // subsequent task for B(:,kfirst)
                    Bflops [kfirst] += Wfirst [taskid] ;
                }
                kprior = kfirst ;
            }
        }

        //----------------------------------------------------------------------
        // sum up the partial flops that taskid computed for klast
        //----------------------------------------------------------------------

        if (kfirst < klast)
        {
            int64_t pB = GBP (Bp, klast, bvlen) ;
            int64_t pB_end = pstart_Bslice [taskid+1] ;
            if (pB < pB_end)
            {
                /* if */ ASSERT (kprior < klast) ;
                {
                    // This task is the first one that did work on
                    // B(:,klast), so use it to start the reduction.
                    Bflops [klast] = Wlast [taskid] ;
                }
                /*
                else
                {
                    // If kfirst < klast and B(:,klast) is not empty,
                    // then this task is always the first one to do
                    // work on B(:,klast), so this case is never used.
                    ASSERT (GB_DEAD_CODE) ;
                    // subsequent task to work on B(:,klast)
                    Bflops [klast] += Wlast [taskid] ;
                }
                */
                kprior = klast ;
            }
        }
    }

    //--------------------------------------------------------------------------
    // cumulative sum of Bflops
    //--------------------------------------------------------------------------

    // Bflops = cumsum ([0 Bflops]) ;
    ASSERT (Bflops [bnvec] == 0) ;
    GB_cumsum (Bflops, bnvec, NULL, B_nthreads, Context) ;
    // Bflops [bnvec] is now the total flop count, including the time to
    // compute A*B and to handle the mask.  total_Mwork is part of this total
    // flop count, but is also returned separately.

    //--------------------------------------------------------------------------
    // free workspace and return result
    //--------------------------------------------------------------------------

    GB_FREE_ALL ;
    (*Mwork) = total_Mwork ;
    return (GrB_SUCCESS) ;
}