1 //------------------------------------------------------------------------------
2 // GB_AxB_dot4: compute C+=A'*B in-place
3 //------------------------------------------------------------------------------
4 
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7 
8 //------------------------------------------------------------------------------
9 
10 // GB_AxB_dot4 does its computation in a single phase, computing its result in
11 // the input matrix C, which is already dense.  The mask M is not handled by
12 // this function.
13 
14 #include "GB_mxm.h"
15 #include "GB_binop.h"
16 #include "GB_unused.h"
17 #ifndef GBCOMPACT
18 #include "GB_AxB__include.h"
19 #endif
20 
// GB_FREE_WORK: release the A_slice and B_slice workspace.  A_slice is pushed
// before B_slice, so they are popped in the reverse order (B_slice first);
// presumably the GB_WERK stack requires this LIFO order -- keep it.
// (Comments are kept outside the macro body: a // comment before a trailing
// backslash would also comment out the spliced next line.)
#define GB_FREE_WORK                    \
{                                       \
    GB_WERK_POP (B_slice, int64_t) ;    \
    GB_WERK_POP (A_slice, int64_t) ;    \
}
26 
GB_AxB_dot4(GrB_Matrix C,const GrB_Matrix A,const GrB_Matrix B,const GrB_Semiring semiring,const bool flipxy,GB_Context Context)27 GrB_Info GB_AxB_dot4                // C+=A'*B, dot product method
28 (
29     GrB_Matrix C,                   // input/output matrix, must be dense
30     const GrB_Matrix A,             // input matrix
31     const GrB_Matrix B,             // input matrix
32     const GrB_Semiring semiring,    // semiring that defines C+=A*B
33     const bool flipxy,              // if true, do z=fmult(b,a) vs fmult(a,b)
34     GB_Context Context
35 )
36 {
37 
38     //--------------------------------------------------------------------------
39     // check inputs
40     //--------------------------------------------------------------------------
41 
42     GrB_Info info ;
43     ASSERT_MATRIX_OK (C, "C for dot in-place += A'*B", GB0) ;
44     ASSERT_MATRIX_OK (A, "A for dot in-place += A'*B", GB0) ;
45     ASSERT_MATRIX_OK (B, "B for dot in-place += A'*B", GB0) ;
46     ASSERT (GB_is_dense (C)) ;
47     ASSERT (!GB_ZOMBIES (C)) ;
48     ASSERT (!GB_JUMBLED (C)) ;
49     ASSERT (!GB_PENDING (C)) ;
50     ASSERT (!GB_ZOMBIES (A)) ;
51     ASSERT (!GB_JUMBLED (A)) ;
52     ASSERT (!GB_PENDING (A)) ;
53     ASSERT (!GB_ZOMBIES (B)) ;
54     ASSERT (!GB_JUMBLED (B)) ;
55     ASSERT (!GB_PENDING (B)) ;
56 
57     ASSERT (!GB_IS_BITMAP (C)) ;
58 
59     ASSERT_SEMIRING_OK (semiring, "semiring for in-place += A'*B", GB0) ;
60     ASSERT (A->vlen == B->vlen) ;
61 
62     GB_WERK_DECLARE (A_slice, int64_t) ;
63     GB_WERK_DECLARE (B_slice, int64_t) ;
64 
65     GBURBLE ("(%s+=%s'*%s) ",
66         GB_sparsity_char_matrix (C),
67         GB_sparsity_char_matrix (A),
68         GB_sparsity_char_matrix (B)) ;
69 
70     //--------------------------------------------------------------------------
71     // determine the number of threads to use
72     //--------------------------------------------------------------------------
73 
74     int64_t anz = GB_NNZ_HELD (A) ;
75     int64_t bnz = GB_NNZ_HELD (B) ;
76     GB_GET_NTHREADS_MAX (nthreads_max, chunk, Context) ;
77     int nthreads = GB_nthreads (anz + bnz, chunk, nthreads_max) ;
78 
79     //--------------------------------------------------------------------------
80     // get the semiring operators
81     //--------------------------------------------------------------------------
82 
83     GrB_BinaryOp mult = semiring->multiply ;
84     GrB_Monoid add = semiring->add ;
85     ASSERT (mult->ztype == add->op->ztype) ;
86     ASSERT (C->type     == add->op->ztype) ;
87 
88     bool op_is_first  = mult->opcode == GB_FIRST_opcode ;
89     bool op_is_second = mult->opcode == GB_SECOND_opcode ;
90     bool op_is_pair   = mult->opcode == GB_PAIR_opcode ;
91     bool A_is_pattern = false ;
92     bool B_is_pattern = false ;
93 
94     if (flipxy)
95     {
96         // z = fmult (b,a) will be computed
97         A_is_pattern = op_is_first  || op_is_pair ;
98         B_is_pattern = op_is_second || op_is_pair ;
99         ASSERT (GB_IMPLIES (!A_is_pattern,
100             GB_Type_compatible (A->type, mult->ytype))) ;
101         ASSERT (GB_IMPLIES (!B_is_pattern,
102             GB_Type_compatible (B->type, mult->xtype))) ;
103     }
104     else
105     {
106         // z = fmult (a,b) will be computed
107         A_is_pattern = op_is_second || op_is_pair ;
108         B_is_pattern = op_is_first  || op_is_pair ;
109         ASSERT (GB_IMPLIES (!A_is_pattern,
110             GB_Type_compatible (A->type, mult->xtype))) ;
111         ASSERT (GB_IMPLIES (!B_is_pattern,
112             GB_Type_compatible (B->type, mult->ytype))) ;
113     }
114 
115     //--------------------------------------------------------------------------
116     // slice A and B
117     //--------------------------------------------------------------------------
118 
119     // A and B can have any sparsity: full, sparse, or hypersparse.
120     // C is always full.
121 
122     int64_t anvec = A->nvec ;
123     int64_t vlen  = A->vlen ;
124     int64_t bnvec = B->nvec ;
125 
126     int naslice = (nthreads == 1) ? 1 : (16 * nthreads) ;
127     int nbslice = (nthreads == 1) ? 1 : (16 * nthreads) ;
128 
129     naslice = GB_IMIN (naslice, anvec) ;
130     nbslice = GB_IMIN (nbslice, bnvec) ;
131 
132     GB_WERK_PUSH (A_slice, naslice + 1, int64_t) ;
133     GB_WERK_PUSH (B_slice, nbslice + 1, int64_t) ;
134     if (A_slice == NULL || B_slice == NULL)
135     {
136         // out of memory
137         GB_FREE_WORK ;
138         return (GrB_OUT_OF_MEMORY) ;
139     }
140     GB_pslice (A_slice, A->p, anvec, naslice, false) ;
141     GB_pslice (B_slice, B->p, bnvec, nbslice, false) ;
142 
143     //--------------------------------------------------------------------------
144     // C += A'*B, computing each entry with a dot product, via builtin semiring
145     //--------------------------------------------------------------------------
146 
147     bool done = false ;
148 
149     #ifndef GBCOMPACT
150 
151         //----------------------------------------------------------------------
152         // define the worker for the switch factory
153         //----------------------------------------------------------------------
154 
155         #define GB_Adot4B(add,mult,xname) GB (_Adot4B_ ## add ## mult ## xname)
156 
157         #define GB_AxB_WORKER(add,mult,xname)           \
158         {                                               \
159             info = GB_Adot4B (add,mult,xname) (C,       \
160                 A, A_is_pattern, A_slice, naslice,      \
161                 B, B_is_pattern, B_slice, nbslice,      \
162                 nthreads) ;                             \
163             done = (info != GrB_NO_VALUE) ;             \
164         }                                               \
165         break ;
166 
167         //----------------------------------------------------------------------
168         // launch the switch factory
169         //----------------------------------------------------------------------
170 
171         GB_Opcode mult_opcode, add_opcode ;
172         GB_Type_code xcode, ycode, zcode ;
173 
174         if (GB_AxB_semiring_builtin (A, A_is_pattern, B, B_is_pattern, semiring,
175             flipxy, &mult_opcode, &add_opcode, &xcode, &ycode, &zcode))
176         {
177             #include "GB_AxB_factory.c"
178         }
179 
180     #endif
181 
182     //--------------------------------------------------------------------------
183     // C += A'*B, computing each entry with a dot product, with typecasting
184     //--------------------------------------------------------------------------
185 
186     if (!done)
187     {
188         #define GB_DOT4_GENERIC
189         GB_BURBLE_MATRIX (C, "(generic C+=A'*B) ") ;
190         #include "GB_AxB_dot_generic.c"
191     }
192 
193     //--------------------------------------------------------------------------
194     // free workspace and return result
195     //--------------------------------------------------------------------------
196 
197     GB_FREE_WORK ;
198     ASSERT_MATRIX_OK (C, "dot4: C += A'*B output", GB0) ;
199     return (GrB_SUCCESS) ;
200 }
201 
202