1 //------------------------------------------------------------------------------
2 // GB_AxB_colscale: C = A*D where D is diagonal
3 //------------------------------------------------------------------------------
4 
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7 
8 //------------------------------------------------------------------------------
9 
10 #include "GB_mxm.h"
11 #include "GB_binop.h"
12 #include "GB_apply.h"
13 #include "GB_ek_slice.h"
14 #ifndef GBCOMPACT
15 #include "GB_binop__include.h"
16 #endif
17 
// Workspace cleanup for GB_AxB_colscale.  This releases the A_ek_slicing
// workspace declared there with GB_WERK_DECLARE.  It is expanded on the
// normal exit path and also by GB_FREE_ALL below.
#define GB_FREE_WORK { GB_WERK_POP (A_ek_slicing, int64_t) ; }

// Full cleanup for the error paths (presumably expanded by GB_OK when a call
// fails; confirm against its definition).  It releases the workspace first,
// then frees the contents of the output matrix C via GB_phbix_free.
#define GB_FREE_ALL  { GB_FREE_WORK ; GB_phbix_free (C) ; }
28 
// GB_AxB_colscale: compute C = A*D, where D is diagonal.  Each entry A(i,j)
// is combined with D(j,j) using only the semiring's multiplicative operator.
// C is created with the pattern of A (via GB_dup2) and type mult->ztype.
// Its values are then computed in one of three ways:
//   (1) positional multiply: a positional unary operator is applied to A;
//   (2) built-in operator and types (when GBCOMPACT is not defined): a
//       hard-coded GB_AxD_* kernel selected by GB_binop_factory.c;
//   (3) otherwise: a generic kernel that typecasts each operand and calls
//       mult->function through a function pointer.
//
// Parameters:
//   C        output matrix with a static header; built here.
//   A        input matrix: may be jumbled; no zombies, no pending tuples,
//            not bitmap.
//   D        diagonal input matrix: no zombies, not jumbled, no pending
//            tuples, not bitmap, not full.
//   semiring only semiring->multiply is used to compute the values.
//   flipxy   if true, z = fmult (D(j,j), A(i,j)) is computed instead of
//            z = fmult (A(i,j), D(j,j)).
//   Context  passed to GB_GET_NTHREADS_MAX and to the helpers called here.
//
// Returns GrB_SUCCESS.  Errors from GB_dup2 or GB_apply_op are returned
// through GB_OK (which presumably expands GB_FREE_ALL first; confirm).
GB_AxB_colscale(GrB_Matrix C,const GrB_Matrix A,const GrB_Matrix D,const GrB_Semiring semiring,const bool flipxy,GB_Context Context)29 GrB_Info GB_AxB_colscale            // C = A*D, column scale with diagonal D
30 (
31     GrB_Matrix C,                   // output matrix, static header
32     const GrB_Matrix A,             // input matrix
33     const GrB_Matrix D,             // diagonal input matrix
34     const GrB_Semiring semiring,    // semiring that defines C=A*D
35     const bool flipxy,              // if true, do z=fmult(b,a) vs fmult(a,b)
36     GB_Context Context
37 )
38 {
39
40     //--------------------------------------------------------------------------
41     // check inputs
42     //--------------------------------------------------------------------------
43
44     GrB_Info info ;
45     ASSERT (C != NULL && C->static_header) ;
46     ASSERT_MATRIX_OK (A, "A for colscale A*D", GB0) ;
47     ASSERT_MATRIX_OK (D, "D for colscale A*D", GB0) ;
48     ASSERT (!GB_ZOMBIES (A)) ;
49     ASSERT (GB_JUMBLED_OK (A)) ;
50     ASSERT (!GB_PENDING (A)) ;
51     ASSERT (!GB_ZOMBIES (D)) ;
52     ASSERT (!GB_JUMBLED (D)) ;
53     ASSERT (!GB_PENDING (D)) ;
54     ASSERT_SEMIRING_OK (semiring, "semiring for numeric A*D", GB0) ;
55     ASSERT (A->vdim == D->vlen) ;
56     ASSERT (GB_is_diagonal (D, Context)) ;
57
58     ASSERT (!GB_IS_BITMAP (A)) ;        // TODO: ok for now
59     ASSERT (!GB_IS_BITMAP (D)) ;
60     ASSERT (!GB_IS_FULL (D)) ;
    // A_ek_slicing is only declared here.  It is not populated until the
    // slicing step further down (presumably by GB_SLICE_MATRIX), so returns
    // before that point have no workspace to release.
61     GB_WERK_DECLARE (A_ek_slicing, int64_t) ;
62
63     GBURBLE ("(%s=%s*%s) ",
64         GB_sparsity_char_matrix (A),    // C has the sparsity structure of A
65         GB_sparsity_char_matrix (A),
66         GB_sparsity_char_matrix (D)) ;
67
68     //--------------------------------------------------------------------------
69     // get the semiring operators
70     //--------------------------------------------------------------------------
71
    // Only the multiplicative operator is used below.  The additive monoid
    // is never applied; its type is only checked against mult->ztype.
72     GrB_BinaryOp mult = semiring->multiply ;
73     ASSERT (mult->ztype == semiring->add->op->ztype) ;
74     GB_Opcode opcode = mult->opcode ;
75     // GB_reduce_to_vector does not use GB_AxB_colscale:
76     ASSERT (!(mult->function == NULL &&
77         (opcode == GB_FIRST_opcode || opcode == GB_SECOND_opcode))) ;
78
79     //--------------------------------------------------------------------------
80     // copy the pattern of A into C
81     //--------------------------------------------------------------------------
82
83     // allocate C->x but do not initialize it
84     GB_OK (GB_dup2 (&C, A, false, mult->ztype, Context)) ; // static header
    // C now has the pattern of A and type mult->ztype.  Every path below
    // fills in C->x.
85
86     //--------------------------------------------------------------------------
87     // apply a positional operator: convert C=A*D to C=op(A)
88     //--------------------------------------------------------------------------
89
90     if (GB_OPCODE_IS_POSITIONAL (opcode))
91     {
        // Positional operators depend only on the row/column indices of the
        // entries, not on their values.  D is diagonal, so the inner index k
        // of each product A(i,k)*D(k,j) equals j.  Every positional multiply
        // therefore reduces to a positional unary operator applied to A.
92         if (flipxy)
93         {
94             // the multiplicative operator is fmult(y,x), so flip the opcode
95             bool handled ;
96             opcode = GB_binop_flip (opcode, &handled) ; // for positional ops
97             ASSERT (handled) ;      // all positional ops can be flipped
98         }
99         // determine unary operator to compute C=A*D
        // The unary op's type follows mult->ztype: the INT64 variants when
        // ztype is GrB_INT64, the INT32 variants otherwise.
100         GrB_UnaryOp op1 = NULL ;
101         if (mult->ztype == GrB_INT64)
102         {
103             switch (opcode)
104             {
105                 // first_op(A,D) becomes position_op(A)
106                 case GB_FIRSTI_opcode   : op1 = GxB_POSITIONI_INT64  ;
107                     break ;
108                 case GB_FIRSTJ_opcode   : op1 = GxB_POSITIONJ_INT64  ;
109                     break ;
110                 case GB_FIRSTI1_opcode  : op1 = GxB_POSITIONI1_INT64 ;
111                     break ;
112                 case GB_FIRSTJ1_opcode  : op1 = GxB_POSITIONJ1_INT64 ;
113                     break ;
114                 // second_op(A,D) becomes position_j(A)
                // (SECONDI yields the inner index k, which equals j here)
115                 case GB_SECONDI_opcode  :
116                 case GB_SECONDJ_opcode  : op1 = GxB_POSITIONJ_INT64  ;
117                     break ;
118                 case GB_SECONDI1_opcode :
119                 case GB_SECONDJ1_opcode : op1 = GxB_POSITIONJ1_INT64 ;
120                     break ;
121                 default:  ;
122             }
123         }
124         else
125         {
126             switch (opcode)
127             {
128                 // first_op(A,D) becomes position_op(A)
129                 case GB_FIRSTI_opcode   : op1 = GxB_POSITIONI_INT32  ;
130                     break ;
131                 case GB_FIRSTJ_opcode   : op1 = GxB_POSITIONJ_INT32  ;
132                     break ;
133                 case GB_FIRSTI1_opcode  : op1 = GxB_POSITIONI1_INT32 ;
134                     break ;
135                 case GB_FIRSTJ1_opcode  : op1 = GxB_POSITIONJ1_INT32 ;
136                     break ;
137                 // second_op(A,D) becomes position_j(A)
                // (SECONDI yields the inner index k, which equals j here)
138                 case GB_SECONDI_opcode  :
139                 case GB_SECONDJ_opcode  : op1 = GxB_POSITIONJ_INT32  ;
140                     break ;
141                 case GB_SECONDI1_opcode :
142                 case GB_SECONDJ1_opcode : op1 = GxB_POSITIONJ1_INT32 ;
143                     break ;
144                 default:  ;
145             }
146         }
        // Fill C->x from the pattern of A.  C's pattern was already copied
        // from A by GB_dup2 above.
147         GB_OK (GB_apply_op ((GB_void *) (C->x), op1,    // positional unary op
148             NULL, NULL, false, A, Context)) ;
149         ASSERT_MATRIX_OK (C, "colscale positional: C = A*D output", GB0) ;
        // No workspace has been allocated yet on this path (slicing happens
        // below), so GB_FREE_WORK is not needed before returning.
150         return (GrB_SUCCESS) ;
151     }
152
153     //--------------------------------------------------------------------------
154     // determine the number of threads to use
155     //--------------------------------------------------------------------------
156
157     GB_GET_NTHREADS_MAX (nthreads_max, chunk, Context) ;
158
159     //--------------------------------------------------------------------------
160     // slice the entries for each task
161     //--------------------------------------------------------------------------
162
    // Partition the entries of A into A_ntasks tasks run by A_nthreads
    // threads.  GB_SLICE_MATRIX presumably also sets A_ek_slicing (declared
    // above); confirm in GB_ek_slice.h.  From here on, GB_FREE_WORK must run
    // before returning.
163     int A_nthreads, A_ntasks ;
164     GB_SLICE_MATRIX (A, 32, chunk) ;
165
166     //--------------------------------------------------------------------------
167     // determine if the values are accessed
168     //--------------------------------------------------------------------------
169
    // FIRST ignores its second operand, SECOND its first, and PAIR both.
    // The values of an ignored matrix need not be read or typecast.  Which
    // matrix is ignored depends on whether the operands are flipped.
170     bool op_is_first  = (opcode == GB_FIRST_opcode) ;
171     bool op_is_second = (opcode == GB_SECOND_opcode) ;
172     bool op_is_pair   = (opcode == GB_PAIR_opcode) ;
173     bool A_is_pattern = false ;
174     bool D_is_pattern = false ;
175
176     if (flipxy)
177     {
178         // z = fmult (b,a) will be computed
179         A_is_pattern = op_is_first  || op_is_pair ;
180         D_is_pattern = op_is_second || op_is_pair ;
181         ASSERT (GB_IMPLIES (!A_is_pattern,
182             GB_Type_compatible (A->type, mult->ytype))) ;
183         ASSERT (GB_IMPLIES (!D_is_pattern,
184             GB_Type_compatible (D->type, mult->xtype))) ;
185     }
186     else
187     {
188         // z = fmult (a,b) will be computed
189         A_is_pattern = op_is_second || op_is_pair ;
190         D_is_pattern = op_is_first  || op_is_pair ;
191         ASSERT (GB_IMPLIES (!A_is_pattern,
192             GB_Type_compatible (A->type, mult->xtype))) ;
193         ASSERT (GB_IMPLIES (!D_is_pattern,
194             GB_Type_compatible (D->type, mult->ytype))) ;
195     }
196
197     //--------------------------------------------------------------------------
198     // C = A*D, column scale, via built-in binary operators
199     //--------------------------------------------------------------------------
200
    // done becomes true only if a hard-coded built-in kernel computes C
    // below.  Otherwise the generic function-pointer path further down runs.
201     bool done = false ;
202
203     #ifndef GBCOMPACT
204
205         //----------------------------------------------------------------------
206         // define the worker for the switch factory
207         //----------------------------------------------------------------------
208
        // GB_AxD(mult,xname) forms the kernel name _AxD_<mult><xname>,
        // wrapped by the GB() name macro.
209         #define GB_AxD(mult,xname) GB (_AxD_ ## mult ## xname)
210
        // GB_BINOP_WORKER is presumably expanded by GB_binop_factory.c for
        // each supported operator/type pair.  A kernel that returns
        // GrB_NO_VALUE leaves done false, so the generic path below takes
        // over.  The trailing break presumably ends a case of a switch inside
        // the factory.
        // NOTE(review): any other non-success info from a kernel is not
        // propagated; this function still returns GrB_SUCCESS.  Presumably
        // the GB_AxD_* kernels only return GrB_SUCCESS or GrB_NO_VALUE;
        // confirm.
211         #define GB_BINOP_WORKER(mult,xname)                                  \
212         {                                                                    \
213             info = GB_AxD(mult,xname) (C, A, A_is_pattern, D, D_is_pattern,  \
214                 A_ek_slicing, A_ntasks, A_nthreads) ;                        \
215             done = (info != GrB_NO_VALUE) ;                                  \
216         }                                                                    \
217         break ;
218
219         //----------------------------------------------------------------------
220         // launch the switch factory
221         //----------------------------------------------------------------------
222
        // GB_binop_builtin reports whether the operator and operand types
        // have a built-in kernel.  It passes opcode and the x/y/z type codes
        // back by pointer for the factory to dispatch on.
223         GB_Type_code xcode, ycode, zcode ;
224         if (GB_binop_builtin (A->type, A_is_pattern, D->type, D_is_pattern,
225             mult, flipxy, &opcode, &xcode, &ycode, &zcode))
226         {
227             // C=A*D, colscale with built-in operator
228             #define GB_BINOP_IS_SEMIRING_MULTIPLIER
229             #include "GB_binop_factory.c"
230             #undef  GB_BINOP_IS_SEMIRING_MULTIPLIER
231         }
232
233     #endif
234
235     //--------------------------------------------------------------------------
236     // C = A*D, column scale, with typecasting or user-defined operator
237     //--------------------------------------------------------------------------
238
239     if (!done)
240     {
241
242         //----------------------------------------------------------------------
243         // get operators, functions, workspace, contents of A, D, and C
244         //----------------------------------------------------------------------
245
246         GB_BURBLE_MATRIX (C, "(generic C=A*D colscale) ") ;
247
248         GxB_binary_function fmult = mult->function ;
249
        // Entry sizes in bytes.  asize/dsize are 0 when that matrix's values
        // are not accessed.
250         size_t csize = C->type->size ;
251         size_t asize = A_is_pattern ? 0 : A->type->size ;
252         size_t dsize = D_is_pattern ? 0 : D->type->size ;
253
254         size_t xsize = mult->xtype->size ;
255         size_t ysize = mult->ytype->size ;
256
257         // scalar workspace: because of typecasting, the x/y types need not
258         // be the same as the size of the A and D types.
259         // flipxy false: aij = (xtype) A(i,j) and djj = (ytype) D(j,j)
260         // flipxy true:  aij = (ytype) A(i,j) and djj = (xtype) D(j,j)
261         size_t aij_size = flipxy ? ysize : xsize ;
262         size_t djj_size = flipxy ? xsize : ysize ;
263
264         GB_void *restrict Cx = (GB_void *) C->x ;
265
        // cast_A / cast_D convert one entry of A / D to the operator's input
        // type.  Each stays NULL when that matrix's values are not needed.
266         GB_cast_function cast_A, cast_D ;
267         if (flipxy)
268         {
269             // A is typecasted to y, and D is typecasted to x
270             cast_A = A_is_pattern ? NULL :
271                      GB_cast_factory (mult->ytype->code, A->type->code) ;
272             cast_D = D_is_pattern ? NULL :
273                      GB_cast_factory (mult->xtype->code, D->type->code) ;
274         }
275         else
276         {
277             // A is typecasted to x, and D is typecasted to y
278             cast_A = A_is_pattern ? NULL :
279                      GB_cast_factory (mult->xtype->code, A->type->code) ;
280             cast_D = D_is_pattern ? NULL :
281                      GB_cast_factory (mult->ytype->code, D->type->code) ;
282         }
283
284         //----------------------------------------------------------------------
285         // C = A*D via function pointers, and typecasting
286         //----------------------------------------------------------------------
287
        // The macros below (GB_GETA, GB_GETB, GB_CX, GB_ATYPE/GB_BTYPE/
        // GB_CTYPE, GB_PRAGMA_SIMD_VECTORIZE, GB_BINOP) are defined for the
        // template GB_AxB_colscale_meta.c, which is included twice below,
        // once per operand order.
        // When A_is_pattern (or D_is_pattern) is true, the scalar buffer is
        // left uninitialized.  The operator ignores that operand in that case.
288         // aij = A(i,j), located in Ax [pA]
289         #define GB_GETA(aij,Ax,pA)                                          \
290             GB_void aij [GB_VLA(aij_size)] ;                                \
291             if (!A_is_pattern) cast_A (aij, Ax +((pA)*asize), asize) ;
292
293         // djj = D(j,j), located in Dx [j]
294         #define GB_GETB(djj,Dx,j)                                           \
295             GB_void djj [GB_VLA(djj_size)] ;                                \
296             if (!D_is_pattern) cast_D (djj, Dx +((j)*dsize), dsize) ;
297
298         // address of Cx [p]
299         #define GB_CX(p) Cx +((p)*csize)
300
301         #define GB_ATYPE GB_void
302         #define GB_BTYPE GB_void
303         #define GB_CTYPE GB_void
304
305         // no vectorization
306         #define GB_PRAGMA_SIMD_VECTORIZE ;
307
        // In GB_BINOP, x is the typecast A(i,j) and y the typecast D(j,j),
        // matching the casts chosen above.  flipxy passes them to fmult in
        // swapped order.
308         if (flipxy)
309         {
310             #define GB_BINOP(z,x,y,i,j) fmult (z,y,x)
311             #include "GB_AxB_colscale_meta.c"
312             #undef GB_BINOP
313         }
314         else
315         {
316             #define GB_BINOP(z,x,y,i,j) fmult (z,x,y)
317             #include "GB_AxB_colscale_meta.c"
318             #undef GB_BINOP
319         }
320     }
321
322     //--------------------------------------------------------------------------
323     // free workspace and return result
324     //--------------------------------------------------------------------------
325
    // Reached by both the built-in kernel path and the generic path.
326     ASSERT_MATRIX_OK (C, "colscale: C = A*D output", GB0) ;
327     GB_FREE_WORK ;
328     return (GrB_SUCCESS) ;
329 }
330 
331