1 //------------------------------------------------------------------------------
2 // GB_AxB_colscale: C = A*D where D is diagonal
3 //------------------------------------------------------------------------------
4
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7
8 //------------------------------------------------------------------------------
9
10 #include "GB_mxm.h"
11 #include "GB_binop.h"
12 #include "GB_apply.h"
13 #include "GB_ek_slice.h"
14 #ifndef GBCOMPACT
15 #include "GB_binop__include.h"
16 #endif
17
// GB_FREE_WORK: free the workspace used by GB_AxB_colscale (the task
// descriptor created by GB_SLICE_MATRIX / GB_WERK_PUSH).  Safe to invoke
// even if A_ek_slicing was never allocated (GB_WERK_POP handles NULL).
#define GB_FREE_WORK                        \
{                                           \
    GB_WERK_POP (A_ek_slicing, int64_t) ;   \
}
22
// GB_FREE_ALL: error-path cleanup used by GB_OK: free the workspace and
// free all content of the output matrix C (its static header survives).
#define GB_FREE_ALL                         \
{                                           \
    GB_FREE_WORK ;                          \
    GB_phbix_free (C) ;                     \
}
28
// GB_AxB_colscale: computes C = A*D where D is a square diagonal matrix,
// so each column A(:,j) is scaled by the single entry D(j,j) using the
// semiring's multiplicative operator.  The monoid of the semiring is not
// needed (no additive reductions occur), only its multiply op.  C is
// constructed with the same sparsity pattern as A (copied by GB_dup2);
// only the values C->x are computed here.  On success C is the valid
// result; on error GB_FREE_ALL (via GB_OK) frees C's content.

GrB_Info GB_AxB_colscale            // C = A*D, column scale with diagonal D
(
    GrB_Matrix C,                   // output matrix, static header
    const GrB_Matrix A,             // input matrix
    const GrB_Matrix D,             // diagonal input matrix
    const GrB_Semiring semiring,    // semiring that defines C=A*D
    const bool flipxy,              // if true, do z=fmult(b,a) vs fmult(a,b)
    GB_Context Context
)
{

    //--------------------------------------------------------------------------
    // check inputs
    //--------------------------------------------------------------------------

    GrB_Info info ;
    ASSERT (C != NULL && C->static_header) ;
    ASSERT_MATRIX_OK (A, "A for colscale A*D", GB0) ;
    ASSERT_MATRIX_OK (D, "D for colscale A*D", GB0) ;
    // A may be jumbled (unsorted within its vectors), but must have no
    // zombies or pending tuples; D must be fully finished and sorted
    ASSERT (!GB_ZOMBIES (A)) ;
    ASSERT (GB_JUMBLED_OK (A)) ;
    ASSERT (!GB_PENDING (A)) ;
    ASSERT (!GB_ZOMBIES (D)) ;
    ASSERT (!GB_JUMBLED (D)) ;
    ASSERT (!GB_PENDING (D)) ;
    ASSERT_SEMIRING_OK (semiring, "semiring for numeric A*D", GB0) ;
    ASSERT (A->vdim == D->vlen) ;           // dimensions must match for A*D
    ASSERT (GB_is_diagonal (D, Context)) ;  // D has entries only on its diagonal

    ASSERT (!GB_IS_BITMAP (A)) ;            // TODO: ok for now
    ASSERT (!GB_IS_BITMAP (D)) ;
    ASSERT (!GB_IS_FULL (D)) ;
    GB_WERK_DECLARE (A_ek_slicing, int64_t) ;   // declared here; allocated
                                                // later by GB_SLICE_MATRIX

    // diagnostic burble: report the sparsity of the operands
    GBURBLE ("(%s=%s*%s) ",
        GB_sparsity_char_matrix (A),    // C has the sparsity structure of A
        GB_sparsity_char_matrix (A),
        GB_sparsity_char_matrix (D)) ;

    //--------------------------------------------------------------------------
    // get the semiring operators
    //--------------------------------------------------------------------------

    GrB_BinaryOp mult = semiring->multiply ;
    ASSERT (mult->ztype == semiring->add->op->ztype) ;
    GB_Opcode opcode = mult->opcode ;
    // GB_reduce_to_vector does not use GB_AxB_colscale:
    ASSERT (!(mult->function == NULL &&
        (opcode == GB_FIRST_opcode || opcode == GB_SECOND_opcode))) ;

    //--------------------------------------------------------------------------
    // copy the pattern of A into C
    //--------------------------------------------------------------------------

    // allocate C->x but do not initialize it
    GB_OK (GB_dup2 (&C, A, false, mult->ztype, Context)) ;  // static header

    //--------------------------------------------------------------------------
    // apply a positional operator: convert C=A*D to C=op(A)
    //--------------------------------------------------------------------------

    // positional multiply ops depend only on the row/column indices of
    // their operands, never on the values, so C=A*D reduces to applying
    // the equivalent unary positional op to the pattern of A.

    if (GB_OPCODE_IS_POSITIONAL (opcode))
    {
        if (flipxy)
        {
            // the multiplicative operator is fmult(y,x), so flip the opcode
            bool handled ;
            opcode = GB_binop_flip (opcode, &handled) ;     // for positional ops
            ASSERT (handled) ;      // all positional ops can be flipped
        }
        // determine unary operator to compute C=A*D
        GrB_UnaryOp op1 = NULL ;
        if (mult->ztype == GrB_INT64)
        {
            switch (opcode)
            {
                // first_op(A,D) becomes position_op(A)
                case GB_FIRSTI_opcode   : op1 = GxB_POSITIONI_INT64 ;
                    break ;
                case GB_FIRSTJ_opcode   : op1 = GxB_POSITIONJ_INT64 ;
                    break ;
                case GB_FIRSTI1_opcode  : op1 = GxB_POSITIONI1_INT64 ;
                    break ;
                case GB_FIRSTJ1_opcode  : op1 = GxB_POSITIONJ1_INT64 ;
                    break ;
                // second_op(A,D) becomes position_j(A): D is diagonal, so
                // the entry D(j,j) has row index == column index == j
                case GB_SECONDI_opcode  :
                case GB_SECONDJ_opcode  : op1 = GxB_POSITIONJ_INT64 ;
                    break ;
                case GB_SECONDI1_opcode :
                case GB_SECONDJ1_opcode : op1 = GxB_POSITIONJ1_INT64 ;
                    break ;
                default: ;
            }
        }
        else
        {
            // mult->ztype is INT32 for the remaining positional ops
            switch (opcode)
            {
                // first_op(A,D) becomes position_op(A)
                case GB_FIRSTI_opcode   : op1 = GxB_POSITIONI_INT32 ;
                    break ;
                case GB_FIRSTJ_opcode   : op1 = GxB_POSITIONJ_INT32 ;
                    break ;
                case GB_FIRSTI1_opcode  : op1 = GxB_POSITIONI1_INT32 ;
                    break ;
                case GB_FIRSTJ1_opcode  : op1 = GxB_POSITIONJ1_INT32 ;
                    break ;
                // second_op(A,D) becomes position_j(A)
                case GB_SECONDI_opcode  :
                case GB_SECONDJ_opcode  : op1 = GxB_POSITIONJ_INT32 ;
                    break ;
                case GB_SECONDI1_opcode :
                case GB_SECONDJ1_opcode : op1 = GxB_POSITIONJ1_INT32 ;
                    break ;
                default: ;
            }
        }
        GB_OK (GB_apply_op ((GB_void *) (C->x), op1,    // positional unary op
            NULL, NULL, false, A, Context)) ;
        ASSERT_MATRIX_OK (C, "colscale positional: C = A*D output", GB0) ;
        // A_ek_slicing has not been allocated yet (see GB_SLICE_MATRIX
        // below), so there is no workspace to free on this early return
        return (GrB_SUCCESS) ;
    }

    //--------------------------------------------------------------------------
    // determine the number of threads to use
    //--------------------------------------------------------------------------

    GB_GET_NTHREADS_MAX (nthreads_max, chunk, Context) ;

    //--------------------------------------------------------------------------
    // slice the entries for each task
    //--------------------------------------------------------------------------

    // creates A_ek_slicing and sets A_ntasks and A_nthreads, used by the
    // parallel workers below
    int A_nthreads, A_ntasks ;
    GB_SLICE_MATRIX (A, 32, chunk) ;

    //--------------------------------------------------------------------------
    // determine if the values are accessed
    //--------------------------------------------------------------------------

    // FIRST, SECOND, and PAIR ignore one or both operands, so the values
    // of A and/or D need not be read at all ("pattern only")
    bool op_is_first  = (opcode == GB_FIRST_opcode) ;
    bool op_is_second = (opcode == GB_SECOND_opcode) ;
    bool op_is_pair   = (opcode == GB_PAIR_opcode) ;
    bool A_is_pattern = false ;
    bool D_is_pattern = false ;

    if (flipxy)
    {
        // z = fmult (b,a) will be computed
        A_is_pattern = op_is_first  || op_is_pair ;
        D_is_pattern = op_is_second || op_is_pair ;
        ASSERT (GB_IMPLIES (!A_is_pattern,
            GB_Type_compatible (A->type, mult->ytype))) ;
        ASSERT (GB_IMPLIES (!D_is_pattern,
            GB_Type_compatible (D->type, mult->xtype))) ;
    }
    else
    {
        // z = fmult (a,b) will be computed
        A_is_pattern = op_is_second || op_is_pair ;
        D_is_pattern = op_is_first  || op_is_pair ;
        ASSERT (GB_IMPLIES (!A_is_pattern,
            GB_Type_compatible (A->type, mult->xtype))) ;
        ASSERT (GB_IMPLIES (!D_is_pattern,
            GB_Type_compatible (D->type, mult->ytype))) ;
    }

    //--------------------------------------------------------------------------
    // C = A*D, column scale, via built-in binary operators
    //--------------------------------------------------------------------------

    bool done = false ;

    #ifndef GBCOMPACT

        //----------------------------------------------------------------------
        // define the worker for the switch factory
        //----------------------------------------------------------------------

        // GB_AxD expands to the name of the type- and op-specific worker
        // function generated in GB_binop__include.h
        #define GB_AxD(mult,xname) GB (_AxD_ ## mult ## xname)

        // the factory invokes this for the selected (mult,xname) pair; the
        // worker returns GrB_NO_VALUE if no specialized kernel exists
        #define GB_BINOP_WORKER(mult,xname)                                 \
        {                                                                   \
            info = GB_AxD(mult,xname) (C, A, A_is_pattern, D, D_is_pattern, \
                A_ek_slicing, A_ntasks, A_nthreads) ;                       \
            done = (info != GrB_NO_VALUE) ;                                 \
        }                                                                   \
        break ;

        //----------------------------------------------------------------------
        // launch the switch factory
        //----------------------------------------------------------------------

        GB_Type_code xcode, ycode, zcode ;
        if (GB_binop_builtin (A->type, A_is_pattern, D->type, D_is_pattern,
            mult, flipxy, &opcode, &xcode, &ycode, &zcode))
        {
            // C=A*D, colscale with built-in operator
            #define GB_BINOP_IS_SEMIRING_MULTIPLIER
            #include "GB_binop_factory.c"
            #undef  GB_BINOP_IS_SEMIRING_MULTIPLIER
        }

    #endif

    //--------------------------------------------------------------------------
    // C = A*D, column scale, with typecasting or user-defined operator
    //--------------------------------------------------------------------------

    // generic fallback: no specialized kernel was found (or GBCOMPACT is
    // defined), so compute C via function pointers with typecasting

    if (!done)
    {

        //----------------------------------------------------------------------
        // get operators, functions, workspace, contents of A, D, and C
        //----------------------------------------------------------------------

        GB_BURBLE_MATRIX (C, "(generic C=A*D colscale) ") ;

        GxB_binary_function fmult = mult->function ;

        size_t csize = C->type->size ;
        // pattern-only operands need no value workspace
        size_t asize = A_is_pattern ? 0 : A->type->size ;
        size_t dsize = D_is_pattern ? 0 : D->type->size ;

        size_t xsize = mult->xtype->size ;
        size_t ysize = mult->ytype->size ;

        // scalar workspace: because of typecasting, the x/y types need not
        // be the same as the size of the A and D types.
        // flipxy false: aij = (xtype) A(i,j) and djj = (ytype) D(j,j)
        // flipxy true:  aij = (ytype) A(i,j) and djj = (xtype) D(j,j)
        size_t aij_size = flipxy ? ysize : xsize ;
        size_t djj_size = flipxy ? xsize : ysize ;

        GB_void *restrict Cx = (GB_void *) C->x ;

        GB_cast_function cast_A, cast_D ;
        if (flipxy)
        {
            // A is typecasted to y, and D is typecasted to x
            cast_A = A_is_pattern ? NULL :
                     GB_cast_factory (mult->ytype->code, A->type->code) ;
            cast_D = D_is_pattern ? NULL :
                     GB_cast_factory (mult->xtype->code, D->type->code) ;
        }
        else
        {
            // A is typecasted to x, and D is typecasted to y
            cast_A = A_is_pattern ? NULL :
                     GB_cast_factory (mult->xtype->code, A->type->code) ;
            cast_D = D_is_pattern ? NULL :
                     GB_cast_factory (mult->ytype->code, D->type->code) ;
        }

        //----------------------------------------------------------------------
        // C = A*D via function pointers, and typecasting
        //----------------------------------------------------------------------

        // the macros below are consumed by the template in
        // GB_AxB_colscale_meta.c, which contains the actual loops

        // aij = A(i,j), located in Ax [pA]
        #define GB_GETA(aij,Ax,pA)                                          \
            GB_void aij [GB_VLA(aij_size)] ;                                \
            if (!A_is_pattern) cast_A (aij, Ax +((pA)*asize), asize) ;

        // dji = D(j,j), located in Dx [j]
        #define GB_GETB(djj,Dx,j)                                           \
            GB_void djj [GB_VLA(djj_size)] ;                                \
            if (!D_is_pattern) cast_D (djj, Dx +((j)*dsize), dsize) ;

        // address of Cx [p]
        #define GB_CX(p) Cx +((p)*csize)

        #define GB_ATYPE GB_void
        #define GB_BTYPE GB_void
        #define GB_CTYPE GB_void

        // no vectorization (opaque function pointers cannot be vectorized)
        #define GB_PRAGMA_SIMD_VECTORIZE ;

        if (flipxy)
        {
            // the template computes z = fmult (djj, aij)
            #define GB_BINOP(z,x,y,i,j) fmult (z,y,x)
            #include "GB_AxB_colscale_meta.c"
            #undef  GB_BINOP
        }
        else
        {
            // the template computes z = fmult (aij, djj)
            #define GB_BINOP(z,x,y,i,j) fmult (z,x,y)
            #include "GB_AxB_colscale_meta.c"
            #undef  GB_BINOP
        }
    }

    //--------------------------------------------------------------------------
    // free workspace and return result
    //--------------------------------------------------------------------------

    ASSERT_MATRIX_OK (C, "colscale: C = A*D output", GB0) ;
    GB_FREE_WORK ;
    return (GrB_SUCCESS) ;
}
330
331