//------------------------------------------------------------------------------
// GB_AxB_dot4: compute C+=A'*B in-place
//------------------------------------------------------------------------------

// SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0

//------------------------------------------------------------------------------

// GB_AxB_dot4 does its computation in a single phase, computing its result in
// the input matrix C, which is already dense. The mask M is not handled by
// this function.
13
14 #include "GB_mxm.h"
15 #include "GB_binop.h"
16 #include "GB_unused.h"
17 #ifndef GBCOMPACT
18 #include "GB_AxB__include.h"
19 #endif
20
// GB_FREE_WORK: release the two slice workspaces used by GB_AxB_dot4.
// B_slice is popped before A_slice, the reverse of the order in which
// GB_AxB_dot4 pushes them (A_slice first, then B_slice).
// NOTE(review): presumably GB_WERK_POP requires this last-in/first-out order;
// confirm against its definition before reordering.
#define GB_FREE_WORK                        \
{                                           \
    GB_WERK_POP (B_slice, int64_t) ;        \
    GB_WERK_POP (A_slice, int64_t) ;        \
}
26
GB_AxB_dot4(GrB_Matrix C,const GrB_Matrix A,const GrB_Matrix B,const GrB_Semiring semiring,const bool flipxy,GB_Context Context)27 GrB_Info GB_AxB_dot4 // C+=A'*B, dot product method
28 (
29 GrB_Matrix C, // input/output matrix, must be dense
30 const GrB_Matrix A, // input matrix
31 const GrB_Matrix B, // input matrix
32 const GrB_Semiring semiring, // semiring that defines C+=A*B
33 const bool flipxy, // if true, do z=fmult(b,a) vs fmult(a,b)
34 GB_Context Context
35 )
36 {
37
38 //--------------------------------------------------------------------------
39 // check inputs
40 //--------------------------------------------------------------------------
41
42 GrB_Info info ;
43 ASSERT_MATRIX_OK (C, "C for dot in-place += A'*B", GB0) ;
44 ASSERT_MATRIX_OK (A, "A for dot in-place += A'*B", GB0) ;
45 ASSERT_MATRIX_OK (B, "B for dot in-place += A'*B", GB0) ;
46 ASSERT (GB_is_dense (C)) ;
47 ASSERT (!GB_ZOMBIES (C)) ;
48 ASSERT (!GB_JUMBLED (C)) ;
49 ASSERT (!GB_PENDING (C)) ;
50 ASSERT (!GB_ZOMBIES (A)) ;
51 ASSERT (!GB_JUMBLED (A)) ;
52 ASSERT (!GB_PENDING (A)) ;
53 ASSERT (!GB_ZOMBIES (B)) ;
54 ASSERT (!GB_JUMBLED (B)) ;
55 ASSERT (!GB_PENDING (B)) ;
56
57 ASSERT (!GB_IS_BITMAP (C)) ;
58
59 ASSERT_SEMIRING_OK (semiring, "semiring for in-place += A'*B", GB0) ;
60 ASSERT (A->vlen == B->vlen) ;
61
62 GB_WERK_DECLARE (A_slice, int64_t) ;
63 GB_WERK_DECLARE (B_slice, int64_t) ;
64
65 GBURBLE ("(%s+=%s'*%s) ",
66 GB_sparsity_char_matrix (C),
67 GB_sparsity_char_matrix (A),
68 GB_sparsity_char_matrix (B)) ;
69
70 //--------------------------------------------------------------------------
71 // determine the number of threads to use
72 //--------------------------------------------------------------------------
73
74 int64_t anz = GB_NNZ_HELD (A) ;
75 int64_t bnz = GB_NNZ_HELD (B) ;
76 GB_GET_NTHREADS_MAX (nthreads_max, chunk, Context) ;
77 int nthreads = GB_nthreads (anz + bnz, chunk, nthreads_max) ;
78
79 //--------------------------------------------------------------------------
80 // get the semiring operators
81 //--------------------------------------------------------------------------
82
83 GrB_BinaryOp mult = semiring->multiply ;
84 GrB_Monoid add = semiring->add ;
85 ASSERT (mult->ztype == add->op->ztype) ;
86 ASSERT (C->type == add->op->ztype) ;
87
88 bool op_is_first = mult->opcode == GB_FIRST_opcode ;
89 bool op_is_second = mult->opcode == GB_SECOND_opcode ;
90 bool op_is_pair = mult->opcode == GB_PAIR_opcode ;
91 bool A_is_pattern = false ;
92 bool B_is_pattern = false ;
93
94 if (flipxy)
95 {
96 // z = fmult (b,a) will be computed
97 A_is_pattern = op_is_first || op_is_pair ;
98 B_is_pattern = op_is_second || op_is_pair ;
99 ASSERT (GB_IMPLIES (!A_is_pattern,
100 GB_Type_compatible (A->type, mult->ytype))) ;
101 ASSERT (GB_IMPLIES (!B_is_pattern,
102 GB_Type_compatible (B->type, mult->xtype))) ;
103 }
104 else
105 {
106 // z = fmult (a,b) will be computed
107 A_is_pattern = op_is_second || op_is_pair ;
108 B_is_pattern = op_is_first || op_is_pair ;
109 ASSERT (GB_IMPLIES (!A_is_pattern,
110 GB_Type_compatible (A->type, mult->xtype))) ;
111 ASSERT (GB_IMPLIES (!B_is_pattern,
112 GB_Type_compatible (B->type, mult->ytype))) ;
113 }
114
115 //--------------------------------------------------------------------------
116 // slice A and B
117 //--------------------------------------------------------------------------
118
119 // A and B can have any sparsity: full, sparse, or hypersparse.
120 // C is always full.
121
122 int64_t anvec = A->nvec ;
123 int64_t vlen = A->vlen ;
124 int64_t bnvec = B->nvec ;
125
126 int naslice = (nthreads == 1) ? 1 : (16 * nthreads) ;
127 int nbslice = (nthreads == 1) ? 1 : (16 * nthreads) ;
128
129 naslice = GB_IMIN (naslice, anvec) ;
130 nbslice = GB_IMIN (nbslice, bnvec) ;
131
132 GB_WERK_PUSH (A_slice, naslice + 1, int64_t) ;
133 GB_WERK_PUSH (B_slice, nbslice + 1, int64_t) ;
134 if (A_slice == NULL || B_slice == NULL)
135 {
136 // out of memory
137 GB_FREE_WORK ;
138 return (GrB_OUT_OF_MEMORY) ;
139 }
140 GB_pslice (A_slice, A->p, anvec, naslice, false) ;
141 GB_pslice (B_slice, B->p, bnvec, nbslice, false) ;
142
143 //--------------------------------------------------------------------------
144 // C += A'*B, computing each entry with a dot product, via builtin semiring
145 //--------------------------------------------------------------------------
146
147 bool done = false ;
148
149 #ifndef GBCOMPACT
150
151 //----------------------------------------------------------------------
152 // define the worker for the switch factory
153 //----------------------------------------------------------------------
154
155 #define GB_Adot4B(add,mult,xname) GB (_Adot4B_ ## add ## mult ## xname)
156
157 #define GB_AxB_WORKER(add,mult,xname) \
158 { \
159 info = GB_Adot4B (add,mult,xname) (C, \
160 A, A_is_pattern, A_slice, naslice, \
161 B, B_is_pattern, B_slice, nbslice, \
162 nthreads) ; \
163 done = (info != GrB_NO_VALUE) ; \
164 } \
165 break ;
166
167 //----------------------------------------------------------------------
168 // launch the switch factory
169 //----------------------------------------------------------------------
170
171 GB_Opcode mult_opcode, add_opcode ;
172 GB_Type_code xcode, ycode, zcode ;
173
174 if (GB_AxB_semiring_builtin (A, A_is_pattern, B, B_is_pattern, semiring,
175 flipxy, &mult_opcode, &add_opcode, &xcode, &ycode, &zcode))
176 {
177 #include "GB_AxB_factory.c"
178 }
179
180 #endif
181
182 //--------------------------------------------------------------------------
183 // C += A'*B, computing each entry with a dot product, with typecasting
184 //--------------------------------------------------------------------------
185
186 if (!done)
187 {
188 #define GB_DOT4_GENERIC
189 GB_BURBLE_MATRIX (C, "(generic C+=A'*B) ") ;
190 #include "GB_AxB_dot_generic.c"
191 }
192
193 //--------------------------------------------------------------------------
194 // free workspace and return result
195 //--------------------------------------------------------------------------
196
197 GB_FREE_WORK ;
198 ASSERT_MATRIX_OK (C, "dot4: C += A'*B output", GB0) ;
199 return (GrB_SUCCESS) ;
200 }
201
202