1 //------------------------------------------------------------------------------
2 // GB_AxB_dot_generic: generic template for all dot-product methods
3 //------------------------------------------------------------------------------
4 
5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
6 // SPDX-License-Identifier: Apache-2.0
7 
8 //------------------------------------------------------------------------------
9 
10 // This template serves all three dot product methods.  The #including file
11 // defines GB_DOT2_GENERIC, GB_DOT3_GENERIC, or GB_DOT4_GENERIC.
12 
13 {
14     // C=A'*B for a generic semiring: define macros, then #include a dot*_meta.c
15     //--------------------------------------------------------------------------
16     // get operators, functions, workspace, contents of A, B, C
17     //--------------------------------------------------------------------------
18     // mult, add, flipxy, A, B, C, and *_is_pattern are not declared in this block
19     GxB_binary_function fmult = mult->function ;    // NULL if positional
20     GxB_binary_function fadd  = add->op->function ;
21     GB_Opcode opcode = mult->opcode ;
22     bool op_is_positional = GB_OPCODE_IS_POSITIONAL (opcode) ;
23 
24     size_t csize = C->type->size ;
25     size_t asize = A_is_pattern ? 0 : A->type->size ;
26     size_t bsize = B_is_pattern ? 0 : B->type->size ;
27 
28     size_t xsize = mult->xtype->size ;
29     size_t ysize = mult->ytype->size ;
30 
31     // scalar workspace: because of typecasting, the sizes of the x/y types
32     // need not be the same as the sizes of the A and B types.
33     // flipxy false: aki = (xtype) A(k,i) and bkj = (ytype) B(k,j)
34     // flipxy true:  aki = (ytype) A(k,i) and bkj = (xtype) B(k,j)
35     size_t aki_size = flipxy ? ysize : xsize ;
36     size_t bkj_size = flipxy ? xsize : ysize ;
37     // terminal value of add; it is only used below after a test against NULL
38     GB_void *restrict terminal = (GB_void *) add->terminal ;
39     // typecast functions; set to NULL when the matrix is pattern-only
40     GB_cast_function cast_A, cast_B ;
41     if (flipxy)
42     {
43         // A is typecasted to y, and B is typecasted to x
44         cast_A = A_is_pattern ? NULL :
45                  GB_cast_factory (mult->ytype->code, A->type->code) ;
46         cast_B = B_is_pattern ? NULL :
47                  GB_cast_factory (mult->xtype->code, B->type->code) ;
48     }
49     else
50     {
51         // A is typecasted to x, and B is typecasted to y
52         cast_A = A_is_pattern ? NULL :
53                  GB_cast_factory (mult->xtype->code, A->type->code) ;
54         cast_B = B_is_pattern ? NULL :
55                  GB_cast_factory (mult->ytype->code, B->type->code) ;
56     }
57 
58     //--------------------------------------------------------------------------
59     // C = A'*B via dot products, function pointers, and typecasting
60     //--------------------------------------------------------------------------
61 
62     #define GB_ATYPE GB_void
63     #define GB_BTYPE GB_void
64     #define GB_PHASE_2_OF_2
65 
66     // no vectorization: both SIMD pragma macros expand to an empty statement
67     #define GB_PRAGMA_SIMD_VECTORIZE ;
68     #define GB_PRAGMA_SIMD_DOT(cij) ;
69 
70     if (op_is_positional)
71     {
72 
73         //----------------------------------------------------------------------
74         // generic semirings with positional multiply operators
75         //----------------------------------------------------------------------
76 
77         if (flipxy)
78         {
79             // flip a positional multiplicative operator
80             bool handled ;
81             opcode = GB_binop_flip (opcode, &handled) ; // for positional ops
82             ASSERT (handled) ;      // all positional ops can be flipped
83         }
84 
85         // aki = A(k,i), located in Ax [pA], value not used
86         #define GB_GETA(aki,Ax,pA) ;
87 
88         // bkj = B(k,j), located in Bx [pB], value not used
89         #define GB_GETB(bkj,Bx,pB) ;
90 
91         // define cij for each task
92         #define GB_CIJ_DECLARE(cij) GB_CTYPE cij
93 
94         // address of Cx [p]
95         #define GB_CX(p) (&Cx [p])
96 
97         // cij = Cx [p]
98         #define GB_GETC(cij,p) cij = Cx [p]
99 
100         // Cx [p] = cij
101         #define GB_PUTC(cij,p) Cx [p] = cij
102 
103         // break if cij reaches the terminal value
104         #define GB_DOT_TERMINAL(cij)                                    \
105             if (is_terminal && cij == cij_terminal)                     \
106             {                                                           \
107                 break ;                                                 \
108             }
109 
110         // C(i,j) += (A')(i,k) * B(k,j)
111         #define GB_MULTADD(cij, aki, bkj, i, k, j)                      \
112             GB_CTYPE zwork ;                                            \
113             GB_MULT (zwork, aki, bkj, i, k, j) ;                        \
114             fadd (&cij, &cij, &zwork)
115         // added to i, k, or j by GB_MULT below; presumably 0 or 1 (TODO confirm)
116         int64_t offset = GB_positional_offset (opcode) ;
117         // cij is int64_t when ztype is GrB_INT64, and int32_t otherwise
118         if (mult->ztype == GrB_INT64)
119         {
120             #define GB_CTYPE int64_t
121             int64_t cij_terminal = 0 ;
122             bool is_terminal = (terminal != NULL) ;
123             if (is_terminal)
124             {
125                 memcpy (&cij_terminal, terminal, sizeof (int64_t)) ;
126             }
127             switch (opcode)
128             {
129                 case GB_FIRSTI_opcode   :   // z = first_i(A'(i,k),y) == i
130                 case GB_FIRSTI1_opcode  :   // z = first_i1(A'(i,k),y) == i+1
131                     #undef  GB_MULT
132                     #define GB_MULT(t, aki, bkj, i, k, j) t = i + offset
133                     #if defined ( GB_DOT2_GENERIC )
134                     #include "GB_AxB_dot2_meta.c"
135                     #elif defined ( GB_DOT3_GENERIC )
136                     #include "GB_AxB_dot3_meta.c"
137                     #else
138                     #include "GB_AxB_dot4_meta.c"
139                     #endif
140                     break ;
141                 case GB_FIRSTJ_opcode   :   // z = first_j(A'(i,k),y) == k
142                 case GB_FIRSTJ1_opcode  :   // z = first_j1(A'(i,k),y) == k+1
143                 case GB_SECONDI_opcode  :   // z = second_i(x,B(k,j)) == k
144                 case GB_SECONDI1_opcode :   // z = second_i1(x,B(k,j)) == k+1
145                     #undef  GB_MULT
146                     #define GB_MULT(t, aki, bkj, i, k, j) t = k + offset
147                     #if defined ( GB_DOT2_GENERIC )
148                     #include "GB_AxB_dot2_meta.c"
149                     #elif defined ( GB_DOT3_GENERIC )
150                     #include "GB_AxB_dot3_meta.c"
151                     #else
152                     #include "GB_AxB_dot4_meta.c"
153                     #endif
154                     break ;
155                 case GB_SECONDJ_opcode  :   // z = second_j(x,B(k,j)) == j
156                 case GB_SECONDJ1_opcode :   // z = second_j1(x,B(k,j)) == j+1
157                     #undef  GB_MULT
158                     #define GB_MULT(t, aki, bkj, i, k, j) t = j + offset
159                     #if defined ( GB_DOT2_GENERIC )
160                     #include "GB_AxB_dot2_meta.c"
161                     #elif defined ( GB_DOT3_GENERIC )
162                     #include "GB_AxB_dot3_meta.c"
163                     #else
164                     #include "GB_AxB_dot4_meta.c"
165                     #endif
166                     break ;
167                 default: ;      // no kernel is included for any other opcode
168             }
169         }
170         else
171         {
172             #undef  GB_CTYPE
173             #define GB_CTYPE int32_t
174             int32_t cij_terminal = 0 ;
175             bool is_terminal = (terminal != NULL) ;
176             if (is_terminal)
177             {
178                 memcpy (&cij_terminal, terminal, sizeof (int32_t)) ;
179             }
180             switch (opcode)
181             {
182                 case GB_FIRSTI_opcode   :   // z = first_i(A'(i,k),y) == i
183                 case GB_FIRSTI1_opcode  :   // z = first_i1(A'(i,k),y) == i+1
184                     #undef  GB_MULT
185                     #define GB_MULT(t,aki,bkj,i,k,j) t = (int32_t) (i + offset)
186                     #if defined ( GB_DOT2_GENERIC )
187                     #include "GB_AxB_dot2_meta.c"
188                     #elif defined ( GB_DOT3_GENERIC )
189                     #include "GB_AxB_dot3_meta.c"
190                     #else
191                     #include "GB_AxB_dot4_meta.c"
192                     #endif
193                     break ;
194                 case GB_FIRSTJ_opcode   :   // z = first_j(A'(i,k),y) == k
195                 case GB_FIRSTJ1_opcode  :   // z = first_j1(A'(i,k),y) == k+1
196                 case GB_SECONDI_opcode  :   // z = second_i(x,B(k,j)) == k
197                 case GB_SECONDI1_opcode :   // z = second_i1(x,B(k,j)) == k+1
198                     #undef  GB_MULT
199                     #define GB_MULT(t,aki,bkj,i,k,j) t = (int32_t) (k + offset)
200                     #if defined ( GB_DOT2_GENERIC )
201                     #include "GB_AxB_dot2_meta.c"
202                     #elif defined ( GB_DOT3_GENERIC )
203                     #include "GB_AxB_dot3_meta.c"
204                     #else
205                     #include "GB_AxB_dot4_meta.c"
206                     #endif
207                     break ;
208                 case GB_SECONDJ_opcode  :   // z = second_j(x,B(k,j)) == j
209                 case GB_SECONDJ1_opcode :   // z = second_j1(x,B(k,j)) == j+1
210                     #undef  GB_MULT
211                     #define GB_MULT(t,aki,bkj,i,k,j) t = (int32_t) (j + offset)
212                     #if defined ( GB_DOT2_GENERIC )
213                     #include "GB_AxB_dot2_meta.c"
214                     #elif defined ( GB_DOT3_GENERIC )
215                     #include "GB_AxB_dot3_meta.c"
216                     #else
217                     #include "GB_AxB_dot4_meta.c"
218                     #endif
219                     break ;
220                 default: ;      // no kernel is included for any other opcode
221             }
222         }
223 
224     }
225     else
226     {
227 
228         //----------------------------------------------------------------------
229         // generic semirings with standard multiply operators
230         //----------------------------------------------------------------------
231 
232         // aki = A(k,i), located in Ax [pA]
233         #undef  GB_GETA
234         #define GB_GETA(aki,Ax,pA)                                      \
235             GB_void aki [GB_VLA(aki_size)] ;                            \
236             if (!A_is_pattern) cast_A (aki, Ax +((pA)*asize), asize)
237 
238         // bkj = B(k,j), located in Bx [pB]
239         #undef  GB_GETB
240         #define GB_GETB(bkj,Bx,pB)                                      \
241             GB_void bkj [GB_VLA(bkj_size)] ;                            \
242             if (!B_is_pattern) cast_B (bkj, Bx +((pB)*bsize), bsize)
243 
244         // define cij for each task
245         #undef  GB_CIJ_DECLARE
246         #define GB_CIJ_DECLARE(cij) GB_void cij [GB_VLA(csize)]
247 
248         // address of Cx [p]
249         #undef  GB_CX
250         #define GB_CX(p) Cx +((p)*csize)
251 
252         // cij = Cx [p]
253         #undef  GB_GETC
254         #define GB_GETC(cij,p) memcpy (cij, GB_CX (p), csize)
255 
256         // Cx [p] = cij
257         #undef  GB_PUTC
258         #define GB_PUTC(cij,p) memcpy (GB_CX (p), cij, csize)
259 
260         // break if cij reaches the terminal value
261         #undef  GB_DOT_TERMINAL
262         #define GB_DOT_TERMINAL(cij)                                    \
263             if (terminal != NULL && memcmp (cij, terminal, csize) == 0) \
264             {                                                           \
265                 break ;                                                 \
266             }
267 
268         // C(i,j) += (A')(i,k) * B(k,j)
269         #undef  GB_MULTADD
270         #define GB_MULTADD(cij, aki, bkj, i, k, j)                      \
271             GB_void zwork [GB_VLA(csize)] ;                             \
272             GB_MULT (zwork, aki, bkj, i, k, j) ;                        \
273             fadd (cij, cij, zwork)
274 
275         #undef  GB_CTYPE
276         #define GB_CTYPE GB_void
277 
278         if (opcode == GB_FIRST_opcode || opcode == GB_SECOND_opcode)
279         {
280             // fmult is not used and can be NULL (for user-defined types)
281             if (flipxy)
282             {
283                 // flip first and second
284                 bool handled ;
285                 opcode = GB_binop_flip (opcode, &handled) ; // for 1st and 2nd
286                 ASSERT (handled) ;      // FIRST and SECOND ops can be flipped
287             }
288             if (opcode == GB_FIRST_opcode)
289             {
290                 // t = (A')(i,k): csize bytes are copied from aik
291                 ASSERT (B_is_pattern) ;
292                 #undef  GB_MULT
293                 #define GB_MULT(t, aik, bkj, i, k, j) memcpy (t, aik, csize)
294                 #if defined ( GB_DOT2_GENERIC )
295                 #include "GB_AxB_dot2_meta.c"
296                 #elif defined ( GB_DOT3_GENERIC )
297                 #include "GB_AxB_dot3_meta.c"
298                 #else
299                 #include "GB_AxB_dot4_meta.c"
300                 #endif
301             }
302             else // opcode == GB_SECOND_opcode
303             {
304                 // t = B(k,j): csize bytes are copied from bkj
305                 ASSERT (A_is_pattern) ;
306                 #undef  GB_MULT
307                 #define GB_MULT(t, aik, bkj, i, k, j) memcpy (t, bkj, csize)
308                 #if defined ( GB_DOT2_GENERIC )
309                 #include "GB_AxB_dot2_meta.c"
310                 #elif defined ( GB_DOT3_GENERIC )
311                 #include "GB_AxB_dot3_meta.c"
312                 #else
313                 #include "GB_AxB_dot4_meta.c"
314                 #endif
315             }
316         }
317         else
318         {
319             if (flipxy)
320             {
321                 // t = B(k,j) * (A')(i,k)
322                 #undef  GB_MULT
323                 #define GB_MULT(t, aki, bkj, i, k, j) fmult (t, bkj, aki)
324                 #if defined ( GB_DOT2_GENERIC )
325                 #include "GB_AxB_dot2_meta.c"
326                 #elif defined ( GB_DOT3_GENERIC )
327                 #include "GB_AxB_dot3_meta.c"
328                 #else
329                 #include "GB_AxB_dot4_meta.c"
330                 #endif
331             }
332             else
333             {
334                 // t = (A')(i,k) * B(k,j)
335                 #undef  GB_MULT
336                 #define GB_MULT(t, aki, bkj, i, k, j) fmult (t, aki, bkj)
337                 #if defined ( GB_DOT2_GENERIC )
338                 #include "GB_AxB_dot2_meta.c"
339                 #elif defined ( GB_DOT3_GENERIC )
340                 #include "GB_AxB_dot3_meta.c"
341                 #else
342                 #include "GB_AxB_dot4_meta.c"
343                 #endif
344             }
345         }
346     }
347 }
348 
349