1 //------------------------------------------------------------------------------
2 // GB_AxB_saxpy3_flopcount: compute flops for GB_AxB_saxpy3
3 //------------------------------------------------------------------------------
4
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7
8 //------------------------------------------------------------------------------
9
10 // On input, A, B, and M (optional) are matrices for C=A*B, C<M>=A*B, or
11 // C<!M>=A*B. The flop count for each B(:,j) is computed, and returned as a
12 // cumulative sum. This function is CSR/CSC agnostic, but for simplicity of
13 // this description, assume A and B are both CSC matrices, so that ncols(A) ==
14 // nrows(B). For both CSR and CSC, A->vdim == B->vlen holds. A and/or B may
15 // be hypersparse, in any combination.
16
17 // Bflops has size (B->nvec)+1, for both standard and hypersparse B. Let
18 // n=B->vdim be the column dimension of B (that is, B is m-by-n).
19
20 // If B is a standard CSC matrix then Bflops has size n+1 == B->nvec+1, and on
21 // output, Bflops [j] is the # of flops required to compute C (:, 0:j-1). B->h
22 // is NULL, and is implicitly the vector 0:(n-1).
23
24 // If B is hypersparse, then let Bh = B->h. Its size is B->nvec, and j = Bh
25 // [kk] is the (kk)th column in the data structure for B. C will also be
26 // hypersparse, and only C(:,Bh) will be computed (C may have fewer non-empty
27 // columns than B). On output, Bflops [kk] is the number of needed flops to
28 // compute C (:, Bh [0:kk-1]).
29
30 // In both cases, Bflops [0] = 0, and Bflops [B->nvec] = total number of flops.
31 // The size of Bflops is B->nvec+1 so that it has the same size as B->p. The
32 // first entry of B->p and Bflops are both zero. This allows B to be sliced
33 // either by # of entries in B (by slicing B->p) or by the flop count required
34 // (by slicing Bflops).
35
36 // This algorithm does not look at the values of M, A, or B, just their
37 // patterns. The flop count of C=A*B, C<M>=A*B, or C<!M>=A*B is computed for a
38 // saxpy-based method; the work for A'*B for the dot product method is not
39 // computed.
40
41 // The algorithm scans all nonzeros in B. It only scans at most the min and
42 // max (first and last) row indices in A and M (if M is present). If A and M
43 // are not hypersparse, the time taken is O(nnz(B)+n). If all matrices are
44 // hypersparse, the time is O(nnz(B)*log(h)) where h = max # of vectors present
45 // in A and M. In pseudo-MATLAB, and assuming B is in standard (not
46 // hypersparse) form:
47
48 /*
49 [m n] = size (B) ;
50 Bflops = zeros (1,n+1) ; % (set to zero in the caller)
51 Mwork = 0 ;
52 for each column j in B:
53 if (B (:,j) is empty) continue ;
54 mjnz = nnz (M (:,j))
55 if (M is present, not complemented, and M (:,j) is empty) continue ;
56 Bflops (j) = mjnz if M present and not dense, to scatter M(:,j)
57 Mwork += mjnz
58 for each k where B (k,j) is nonzero:
59 aknz = nnz (A (:,k))
60 if (aknz == 0) continue ;
61 % numerical phase will compute: C(:,j)<#M(:,j)> += A(:,k)*B(k,j)
62 % where #M is no mask, M, or !M. This typically takes aknz flops,
63 % or with a binary search if nnz(M(:,j)) << nnz(A(:,k)).
64 Bflops (j) += aknz
65 end
66 end
67 */
68
69 #include "GB_mxm.h"
70 #include "GB_ek_slice.h"
71 #include "GB_bracket.h"
72 #include "GB_AxB_saxpy3.h"
73
// GB_FREE_ALL: release the two workspaces used by GB_AxB_saxpy3_flopcount.
// Both Work and B_ek_slicing are declared with GB_WERK_DECLARE in that
// function; this macro pops them, and is used on the out-of-memory path and
// on normal exit.
#define GB_FREE_ALL                         \
{                                           \
    GB_WERK_POP (Work, int64_t) ;           \
    GB_WERK_POP (B_ek_slicing, int64_t) ;   \
}
79
80 GB_PUBLIC // accessed by the MATLAB tests in GraphBLAS/Test only
GB_AxB_saxpy3_flopcount(int64_t * Mwork,int64_t * Bflops,const GrB_Matrix M,const bool Mask_comp,const GrB_Matrix A,const GrB_Matrix B,GB_Context Context)81 GrB_Info GB_AxB_saxpy3_flopcount
82 (
83 int64_t *Mwork, // amount of work to handle the mask M
84 int64_t *Bflops, // size B->nvec+1
85 const GrB_Matrix M, // optional mask matrix
86 const bool Mask_comp, // if true, mask is complemented
87 const GrB_Matrix A,
88 const GrB_Matrix B,
89 GB_Context Context
90 )
91 {
92
93 //--------------------------------------------------------------------------
94 // check inputs
95 //--------------------------------------------------------------------------
96
97 ASSERT_MATRIX_OK_OR_NULL (M, "M for flop count A*B", GB0) ;
98 ASSERT (!GB_ZOMBIES (M)) ;
99 ASSERT (GB_JUMBLED_OK (M)) ;
100 ASSERT (!GB_PENDING (M)) ;
101
102 ASSERT_MATRIX_OK (A, "A for flop count A*B", GB0) ;
103 ASSERT (!GB_ZOMBIES (A)) ;
104 ASSERT (GB_JUMBLED_OK (A)) ;
105 ASSERT (!GB_PENDING (A)) ;
106
107 ASSERT_MATRIX_OK (B, "B for flop count A*B", GB0) ;
108 ASSERT (!GB_ZOMBIES (B)) ;
109 ASSERT (GB_JUMBLED_OK (B)) ;
110 ASSERT (!GB_PENDING (B)) ;
111
112 ASSERT (A->vdim == B->vlen) ;
113 ASSERT (Bflops != NULL) ;
114 ASSERT (Mwork != NULL) ;
115
116 //--------------------------------------------------------------------------
117 // determine the number of threads to use
118 //--------------------------------------------------------------------------
119
120 int64_t bnvec = B->nvec ;
121
122 GB_GET_NTHREADS_MAX (nthreads_max, chunk, Context) ;
123
124 // clear Bflops
125 GB_memset (Bflops, 0, (bnvec+1) * sizeof (int64_t), nthreads_max) ;
126
127 //--------------------------------------------------------------------------
128 // get the mask, if present: any sparsity structure
129 //--------------------------------------------------------------------------
130
131 bool mask_is_M = (M != NULL && !Mask_comp) ;
132 const int64_t *restrict Mp = NULL ;
133 const int64_t *restrict Mh = NULL ;
134 int64_t mnvec = 0 ;
135 int64_t mvlen = 0 ;
136 bool M_is_hyper = GB_IS_HYPERSPARSE (M) ;
137 bool M_is_dense = false ;
138 if (M != NULL)
139 {
140 Mh = M->h ;
141 Mp = M->p ;
142 mnvec = M->nvec ;
143 mvlen = M->vlen ;
144 M_is_dense = GB_is_packed (M) ;
145 }
146
147 //--------------------------------------------------------------------------
148 // get A and B: any sparsity structure
149 //--------------------------------------------------------------------------
150
151 const int64_t *restrict Ap = A->p ;
152 const int64_t *restrict Ah = A->h ;
153 const int64_t anvec = A->nvec ;
154 const int64_t avlen = A->vlen ;
155 const bool A_is_hyper = GB_IS_HYPERSPARSE (A) ;
156
157 const int64_t *restrict Bp = B->p ;
158 const int64_t *restrict Bh = B->h ;
159 const int8_t *restrict Bb = B->b ;
160 const int64_t *restrict Bi = B->i ;
161 const bool B_is_hyper = GB_IS_HYPERSPARSE (B) ;
162 const bool B_is_bitmap = GB_IS_BITMAP (B) ;
163 const bool B_is_sparse_or_hyper = B_is_hyper || GB_IS_SPARSE (B) ;
164 const int64_t bvlen = B->vlen ;
165 const bool B_jumbled = B->jumbled ;
166
167 //--------------------------------------------------------------------------
168 // declare workspace
169 //--------------------------------------------------------------------------
170
171 GB_WERK_DECLARE (Work, int64_t) ;
172 GB_WERK_DECLARE (B_ek_slicing, int64_t) ;
173 int64_t *restrict Wfirst = NULL ;
174 int64_t *restrict Wlast = NULL ;
175
176 //--------------------------------------------------------------------------
177 // construct the parallel tasks
178 //--------------------------------------------------------------------------
179
180 int B_ntasks, B_nthreads ;
181 GB_SLICE_MATRIX (B, 64, chunk) ;
182
183 //--------------------------------------------------------------------------
184 // allocate workspace
185 //--------------------------------------------------------------------------
186
187 GB_WERK_PUSH (Work, 2*B_ntasks, int64_t) ;
188 if (Work == NULL)
189 {
190 // out of memory
191 GB_FREE_ALL ;
192 return (GrB_OUT_OF_MEMORY) ;
193 }
194 Wfirst = Work ;
195 Wlast = Work + B_ntasks ;
196
197 //--------------------------------------------------------------------------
198 // compute flop counts for C=A*B, C<M>=A*B, or C<!M>=A*B
199 //--------------------------------------------------------------------------
200
201 int64_t total_Mwork = 0 ;
202 int taskid ;
203 #pragma omp parallel for num_threads(B_nthreads) schedule(dynamic,1) \
204 reduction(+:total_Mwork)
205 for (taskid = 0 ; taskid < B_ntasks ; taskid++)
206 {
207
208 //----------------------------------------------------------------------
209 // get the task descriptor
210 //----------------------------------------------------------------------
211
212 int64_t kfirst = kfirst_Bslice [taskid] ;
213 int64_t klast = klast_Bslice [taskid] ;
214 Wfirst [taskid] = 0 ;
215 Wlast [taskid] = 0 ;
216 int64_t mpleft = 0 ; // for GB_lookup of the mask M
217 int64_t task_Mwork = 0 ;
218
219 //----------------------------------------------------------------------
220 // count flops for vectors kfirst to klast of B
221 //----------------------------------------------------------------------
222
223 for (int64_t kk = kfirst ; kk <= klast ; kk++)
224 {
225
226 // nnz (B (:,j)), for all tasks
227 int64_t bjnz = (Bp == NULL) ? bvlen : (Bp [kk+1] - Bp [kk]) ;
228 // C(:,j) is empty if the entire vector B(:,j) is empty
229 if (bjnz == 0) continue ;
230
231 //------------------------------------------------------------------
232 // find the part of B(:,j) to be computed by this task
233 //------------------------------------------------------------------
234
235 int64_t pB, pB_end ;
236 GB_get_pA (&pB, &pB_end, taskid, kk,
237 kfirst, klast, pstart_Bslice, Bp, bvlen) ;
238 int64_t my_bjnz = pB_end - pB ;
239 int64_t j = GBH (Bh, kk) ;
240
241 //------------------------------------------------------------------
242 // see if M(:,j) is present and non-empty
243 //------------------------------------------------------------------
244
245 // if M(:,j) is full, bitmap, or dense, do not add mjnz to bjflops
246 // or task_MWork.
247
248 int64_t bjflops = (B_is_bitmap) ? my_bjnz : 0 ;
249 int64_t mjnz = 0 ;
250 if (M != NULL && !M_is_dense)
251 {
252 int64_t mpright = mnvec - 1 ;
253 int64_t pM, pM_end ;
254 GB_lookup (M_is_hyper, Mh, Mp, mvlen, &mpleft, mpright, j,
255 &pM, &pM_end) ;
256 mjnz = pM_end - pM ;
257 // If M not complemented: C(:,j) is empty if M(:,j) is empty.
258 if (mjnz == 0 && !Mask_comp) continue ;
259 if (mjnz > 0)
260 {
261 // M(:,j) not empty
262 if (pB == GBP (Bp, kk, bvlen))
263 {
264 // this task owns the top part of B(:,j), so it can
265 // account for the work to access M(:,j), without the
266 // work being duplicated by other tasks working on
267 // B(:,j)
268 bjflops = mjnz ;
269 // keep track of total work spent examining the mask.
270 // If any B(:,j) is empty, M(:,j) can be ignored. So
271 // total_Mwork will be <= nnz (M).
272 task_Mwork += mjnz ;
273 }
274 }
275 }
276 int64_t mjnz_much = 64 * mjnz ;
277
278 //------------------------------------------------------------------
279 // trim Ah on right
280 //------------------------------------------------------------------
281
282 // Ah [0..A->nvec-1] holds the set of non-empty vectors of A, but
283 // only vectors k corresponding to nonzero entries B(k,j) are
284 // accessed for this vector B(:,j). If nnz (B(:,j)) > 2, prune the
285 // search space on the right, so the remaining calls to GB_lookup
286 // will only need to search Ah [pleft...pright-1]. pright does not
287 // change. pleft is advanced as B(:,j) is traversed, since the
288 // indices in B(:,j) are sorted in ascending order.
289
290 int64_t pleft = 0 ;
291 int64_t pright = anvec-1 ;
292 if (A_is_hyper && B_is_sparse_or_hyper && my_bjnz > 2 && !B_jumbled)
293 {
294 // trim Ah [0..pright] to remove any entries past last B(:,j)
295 int64_t ilast = Bi [pB_end-1] ;
296 GB_bracket_right (ilast, Ah, 0, &pright) ;
297 }
298
299 //------------------------------------------------------------------
300 // count the flops to compute C(:,j)<#M(:,j)> = A*B(:,j)
301 //------------------------------------------------------------------
302
303 // where #M is either not present, M, or !M
304
305 for ( ; pB < pB_end ; pB++)
306 {
307 // get B(k,j)
308 int64_t k = GBI (Bi, pB, bvlen) ;
309 if (!GBB (Bb, pB)) continue ;
310
311 // B(k,j) is nonzero
312
313 // find A(:,k), reusing pleft if B is not jumbled
314 if (B_jumbled)
315 {
316 pleft = 0 ;
317 }
318 int64_t pA, pA_end ;
319 GB_lookup (A_is_hyper, Ah, Ap, avlen, &pleft, pright, k,
320 &pA, &pA_end) ;
321
322 // skip if A(:,k) empty
323 int64_t aknz = pA_end - pA ;
324 if (aknz == 0) continue ;
325
326 double bkjflops ;
327
328 // skip if intersection of A(:,k) and M(:,j) is empty
329 // and mask is not complemented (C<M>=A*B)
330 if (mask_is_M)
331 {
332 // A(:,k) is non-empty; get first and last index of A(:,k)
333 if (aknz > 256 && mjnz_much < aknz && mjnz < mvlen &&
334 aknz < avlen && !(A->jumbled))
335 {
336 // scan M(:j), and do binary search for A(i,j)
337 bkjflops = mjnz * (1 + 4 * log2 ((double) aknz)) ;
338 }
339 else
340 {
341 // scan A(:k), and lookup M(i,j)
342 bkjflops = aknz ;
343 }
344 }
345 else
346 {
347 // A(:,k)*B(k,j) requires aknz flops
348 bkjflops = aknz ;
349 }
350
351 // increment by flops for the single entry B(k,j)
352 // C(:,j)<#M(:,j)> += A(:,k)*B(k,j).
353 bjflops += bkjflops ;
354 }
355
356 //------------------------------------------------------------------
357 // log the flops for B(:,j)
358 //------------------------------------------------------------------
359
360 if (kk == kfirst)
361 {
362 Wfirst [taskid] = bjflops ;
363 }
364 else if (kk == klast)
365 {
366 Wlast [taskid] = bjflops ;
367 }
368 else
369 {
370 Bflops [kk] = bjflops ;
371 }
372 }
373
374 // compute the total work to access the mask, which is <= nnz (M)
375 total_Mwork += task_Mwork ;
376 }
377
378 //--------------------------------------------------------------------------
379 // reduce the first and last vector of each slice
380 //--------------------------------------------------------------------------
381
382 // See also Template/GB_select_phase1.c
383
384 int64_t kprior = -1 ;
385
386 for (int taskid = 0 ; taskid < B_ntasks ; taskid++)
387 {
388
389 //----------------------------------------------------------------------
390 // sum up the partial flops that taskid computed for kfirst
391 //----------------------------------------------------------------------
392
393 int64_t kfirst = kfirst_Bslice [taskid] ;
394 int64_t klast = klast_Bslice [taskid] ;
395
396 if (kfirst <= klast)
397 {
398 int64_t pB = pstart_Bslice [taskid] ;
399 int64_t pB_end = GBP (Bp, kfirst+1, bvlen) ;
400 pB_end = GB_IMIN (pB_end, pstart_Bslice [taskid+1]) ;
401 if (pB < pB_end)
402 {
403 if (kprior < kfirst)
404 {
405 // This task is the first one that did work on
406 // B(:,kfirst), so use it to start the reduction.
407 Bflops [kfirst] = Wfirst [taskid] ;
408 }
409 else
410 {
411 // subsequent task for B(:,kfirst)
412 Bflops [kfirst] += Wfirst [taskid] ;
413 }
414 kprior = kfirst ;
415 }
416 }
417
418 //----------------------------------------------------------------------
419 // sum up the partial flops that taskid computed for klast
420 //----------------------------------------------------------------------
421
422 if (kfirst < klast)
423 {
424 int64_t pB = GBP (Bp, klast, bvlen) ;
425 int64_t pB_end = pstart_Bslice [taskid+1] ;
426 if (pB < pB_end)
427 {
428 /* if */ ASSERT (kprior < klast) ;
429 {
430 // This task is the first one that did work on
431 // B(:,klast), so use it to start the reduction.
432 Bflops [klast] = Wlast [taskid] ;
433 }
434 /*
435 else
436 {
437 // If kfirst < klast and B(:,klast) is not empty,
438 // then this task is always the first one to do
439 // work on B(:,klast), so this case is never used.
440 ASSERT (GB_DEAD_CODE) ;
441 // subsequent task to work on B(:,klast)
442 Bflops [klast] += Wlast [taskid] ;
443 }
444 */
445 kprior = klast ;
446 }
447 }
448 }
449
450 //--------------------------------------------------------------------------
451 // cumulative sum of Bflops
452 //--------------------------------------------------------------------------
453
454 // Bflops = cumsum ([0 Bflops]) ;
455 ASSERT (Bflops [bnvec] == 0) ;
456 GB_cumsum (Bflops, bnvec, NULL, B_nthreads, Context) ;
457 // Bflops [bnvec] is now the total flop count, including the time to
458 // compute A*B and to handle the mask. total_Mwork is part of this total
459 // flop count, but is also returned separtely.
460
461 //--------------------------------------------------------------------------
462 // free workspace and return result
463 //--------------------------------------------------------------------------
464
465 GB_FREE_ALL ;
466 (*Mwork) = total_Mwork ;
467 return (GrB_SUCCESS) ;
468 }
469
470