1 //------------------------------------------------------------------------------ 2 // GB_AxB_dot_generic: generic template for all dot-product methods 3 //------------------------------------------------------------------------------ 4 5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved. 6 // SPDX-License-Identifier: Apache-2.0 7 8 //------------------------------------------------------------------------------ 9 10 // This template serves all three dot product methods. The #including file 11 // defines GB_DOT2_GENERIC, GB_DOT3_GENERIC, or GB_DOT4_GENERIC. 12 13 { 14 15 //-------------------------------------------------------------------------- 16 // get operators, functions, workspace, contents of A, B, C 17 //-------------------------------------------------------------------------- 18 19 GxB_binary_function fmult = mult->function ; // NULL if positional 20 GxB_binary_function fadd = add->op->function ; 21 GB_Opcode opcode = mult->opcode ; 22 bool op_is_positional = GB_OPCODE_IS_POSITIONAL (opcode) ; 23 24 size_t csize = C->type->size ; 25 size_t asize = A_is_pattern ? 0 : A->type->size ; 26 size_t bsize = B_is_pattern ? 0 : B->type->size ; 27 28 size_t xsize = mult->xtype->size ; 29 size_t ysize = mult->ytype->size ; 30 31 // scalar workspace: because of typecasting, the x/y types need not 32 // be the same as the size of the A and B types. 33 // flipxy false: aki = (xtype) A(k,i) and bkj = (ytype) B(k,j) 34 // flipxy true: aki = (ytype) A(k,i) and bkj = (xtype) B(k,j) 35 size_t aki_size = flipxy ? ysize : xsize ; 36 size_t bkj_size = flipxy ? xsize : ysize ; 37 38 GB_void *restrict terminal = (GB_void *) add->terminal ; 39 40 GB_cast_function cast_A, cast_B ; 41 if (flipxy) 42 { 43 // A is typecasted to y, and B is typecasted to x 44 cast_A = A_is_pattern ? NULL : 45 GB_cast_factory (mult->ytype->code, A->type->code) ; 46 cast_B = B_is_pattern ? NULL : 47 GB_cast_factory (mult->xtype->code, B->type->code) ; 48 } 49 else 50 { 51 // A is typecasted to x, and B is typecasted to y 52 cast_A = A_is_pattern ? NULL : 53 GB_cast_factory (mult->xtype->code, A->type->code) ; 54 cast_B = B_is_pattern ? NULL : 55 GB_cast_factory (mult->ytype->code, B->type->code) ; 56 } 57 58 //-------------------------------------------------------------------------- 59 // C = A'*B via dot products, function pointers, and typecasting 60 //-------------------------------------------------------------------------- 61 62 #define GB_ATYPE GB_void 63 #define GB_BTYPE GB_void 64 #define GB_PHASE_2_OF_2 65 66 // no vectorization 67 #define GB_PRAGMA_SIMD_VECTORIZE ; 68 #define GB_PRAGMA_SIMD_DOT(cij) ; 69 70 if (op_is_positional) 71 { 72 73 //---------------------------------------------------------------------- 74 // generic semirings with positional multiply operators 75 //---------------------------------------------------------------------- 76 77 if (flipxy) 78 { 79 // flip a positional multiplicative operator 80 bool handled ; 81 opcode = GB_binop_flip (opcode, &handled) ; // for positional ops 82 ASSERT (handled) ; // all positional ops can be flipped 83 } 84 85 // aki = A(i,k), located in Ax [pA], value not used 86 #define GB_GETA(aki,Ax,pA) ; 87 88 // bkj = B(k,j), located in Bx [pB], value not used 89 #define GB_GETB(bkj,Bx,pB) ; 90 91 // define cij for each task 92 #define GB_CIJ_DECLARE(cij) GB_CTYPE cij 93 94 // address of Cx [p] 95 #define GB_CX(p) (&Cx [p]) 96 97 // cij = Cx [p] 98 #define GB_GETC(cij,p) cij = Cx [p] 99 100 // Cx [p] = cij 101 #define GB_PUTC(cij,p) Cx [p] = cij 102 103 // break if cij reaches the terminal value 104 #define GB_DOT_TERMINAL(cij) \ 105 if (is_terminal && cij == cij_terminal) \ 106 { \ 107 break ; \ 108 } 109 110 // C(i,j) += (A')(i,k) * B(k,j) 111 #define GB_MULTADD(cij, aki, bkj, i, k, j) \ 112 GB_CTYPE zwork ; \ 113 GB_MULT (zwork, aki, bkj, i, k, j) ; \ 114 fadd (&cij, &cij, &zwork) 115 116 int64_t offset = GB_positional_offset (opcode) ; 117 118 if (mult->ztype == GrB_INT64) 119 { 120 #define GB_CTYPE int64_t 121 int64_t cij_terminal = 0 ; 122 bool is_terminal = (terminal != NULL) ; 123 if (is_terminal) 124 { 125 memcpy (&cij_terminal, terminal, sizeof (int64_t)) ; 126 } 127 switch (opcode) 128 { 129 case GB_FIRSTI_opcode : // z = first_i(A'(i,k),y) == i 130 case GB_FIRSTI1_opcode : // z = first_i1(A'(i,k),y) == i+1 131 #undef GB_MULT 132 #define GB_MULT(t, aki, bkj, i, k, j) t = i + offset 133 #if defined ( GB_DOT2_GENERIC ) 134 #include "GB_AxB_dot2_meta.c" 135 #elif defined ( GB_DOT3_GENERIC ) 136 #include "GB_AxB_dot3_meta.c" 137 #else 138 #include "GB_AxB_dot4_meta.c" 139 #endif 140 break ; 141 case GB_FIRSTJ_opcode : // z = first_j(A'(i,k),y) == k 142 case GB_FIRSTJ1_opcode : // z = first_j1(A'(i,k),y) == k+1 143 case GB_SECONDI_opcode : // z = second_i(x,B(k,j)) == k 144 case GB_SECONDI1_opcode : // z = second_i1(x,B(k,j)) == k+1 145 #undef GB_MULT 146 #define GB_MULT(t, aki, bkj, i, k, j) t = k + offset 147 #if defined ( GB_DOT2_GENERIC ) 148 #include "GB_AxB_dot2_meta.c" 149 #elif defined ( GB_DOT3_GENERIC ) 150 #include "GB_AxB_dot3_meta.c" 151 #else 152 #include "GB_AxB_dot4_meta.c" 153 #endif 154 break ; 155 case GB_SECONDJ_opcode : // z = second_j(x,B(k,j)) == j 156 case GB_SECONDJ1_opcode : // z = second_j1(x,B(k,j)) == j+1 157 #undef GB_MULT 158 #define GB_MULT(t, aki, bkj, i, k, j) t = j + offset 159 #if defined ( GB_DOT2_GENERIC ) 160 #include "GB_AxB_dot2_meta.c" 161 #elif defined ( GB_DOT3_GENERIC ) 162 #include "GB_AxB_dot3_meta.c" 163 #else 164 #include "GB_AxB_dot4_meta.c" 165 #endif 166 break ; 167 default: ; 168 } 169 } 170 else 171 { 172 #undef GB_CTYPE 173 #define GB_CTYPE int32_t 174 int32_t cij_terminal = 0 ; 175 bool is_terminal = (terminal != NULL) ; 176 if (is_terminal) 177 { 178 memcpy (&cij_terminal, terminal, sizeof (int32_t)) ; 179 } 180 switch (opcode) 181 { 182 case GB_FIRSTI_opcode : // z = first_i(A'(i,k),y) == i 183 case GB_FIRSTI1_opcode : // z = first_i1(A'(i,k),y) == i+1 184 #undef GB_MULT 185 #define GB_MULT(t,aki,bkj,i,k,j) t = (int32_t) (i + offset) 186 #if defined ( GB_DOT2_GENERIC ) 187 #include "GB_AxB_dot2_meta.c" 188 #elif defined ( GB_DOT3_GENERIC ) 189 #include "GB_AxB_dot3_meta.c" 190 #else 191 #include "GB_AxB_dot4_meta.c" 192 #endif 193 break ; 194 case GB_FIRSTJ_opcode : // z = first_j(A'(i,k),y) == k 195 case GB_FIRSTJ1_opcode : // z = first_j1(A'(i,k),y) == k+1 196 case GB_SECONDI_opcode : // z = second_i(x,B(k,j)) == k 197 case GB_SECONDI1_opcode : // z = second_i1(x,B(k,j)) == k+1 198 #undef GB_MULT 199 #define GB_MULT(t,aki,bkj,i,k,j) t = (int32_t) (k + offset) 200 #if defined ( GB_DOT2_GENERIC ) 201 #include "GB_AxB_dot2_meta.c" 202 #elif defined ( GB_DOT3_GENERIC ) 203 #include "GB_AxB_dot3_meta.c" 204 #else 205 #include "GB_AxB_dot4_meta.c" 206 #endif 207 break ; 208 case GB_SECONDJ_opcode : // z = second_j(x,B(k,j)) == j 209 case GB_SECONDJ1_opcode : // z = second_j1(x,B(k,j)) == j+1 210 #undef GB_MULT 211 #define GB_MULT(t,aki,bkj,i,k,j) t = (int32_t) (j + offset) 212 #if defined ( GB_DOT2_GENERIC ) 213 #include "GB_AxB_dot2_meta.c" 214 #elif defined ( GB_DOT3_GENERIC ) 215 #include "GB_AxB_dot3_meta.c" 216 #else 217 #include "GB_AxB_dot4_meta.c" 218 #endif 219 break ; 220 default: ; 221 } 222 } 223 224 } 225 else 226 { 227 228 //---------------------------------------------------------------------- 229 // generic semirings with standard multiply operators 230 //---------------------------------------------------------------------- 231 232 // aki = A(k,i), located in Ax [pA] 233 #undef GB_GETA 234 #define GB_GETA(aki,Ax,pA) \ 235 GB_void aki [GB_VLA(aki_size)] ; \ 236 if (!A_is_pattern) cast_A (aki, Ax +((pA)*asize), asize) 237 238 // bkj = B(k,j), located in Bx [pB] 239 #undef GB_GETB 240 #define GB_GETB(bkj,Bx,pB) \ 241 GB_void bkj [GB_VLA(bkj_size)] ; \ 242 if (!B_is_pattern) cast_B (bkj, Bx +((pB)*bsize), bsize) 243 244 // define cij for each task 245 #undef GB_CIJ_DECLARE 246 #define GB_CIJ_DECLARE(cij) GB_void cij [GB_VLA(csize)] 247 248 // address of Cx [p] 249 #undef GB_CX 250 #define GB_CX(p) Cx +((p)*csize) 251 252 // cij = Cx [p] 253 #undef GB_GETC 254 #define GB_GETC(cij,p) memcpy (cij, GB_CX (p), csize) 255 256 // Cx [p] = cij 257 #undef GB_PUTC 258 #define GB_PUTC(cij,p) memcpy (GB_CX (p), cij, csize) 259 260 // break if cij reaches the terminal value 261 #undef GB_DOT_TERMINAL 262 #define GB_DOT_TERMINAL(cij) \ 263 if (terminal != NULL && memcmp (cij, terminal, csize) == 0) \ 264 { \ 265 break ; \ 266 } 267 268 // C(i,j) += (A')(i,k) * B(k,j) 269 #undef GB_MULTADD 270 #define GB_MULTADD(cij, aki, bkj, i, k, j) \ 271 GB_void zwork [GB_VLA(csize)] ; \ 272 GB_MULT (zwork, aki, bkj, i, k, j) ; \ 273 fadd (cij, cij, zwork) 274 275 #undef GB_CTYPE 276 #define GB_CTYPE GB_void 277 278 if (opcode == GB_FIRST_opcode || opcode == GB_SECOND_opcode) 279 { 280 // fmult is not used and can be NULL (for user-defined types) 281 if (flipxy) 282 { 283 // flip first and second 284 bool handled ; 285 opcode = GB_binop_flip (opcode, &handled) ; // for 1st and 2nd 286 ASSERT (handled) ; // FIRST and SECOND ops can be flipped 287 } 288 if (opcode == GB_FIRST_opcode) 289 { 290 // t = A(i,k) 291 ASSERT (B_is_pattern) ; 292 #undef GB_MULT 293 #define GB_MULT(t, aik, bkj, i, k, j) memcpy (t, aik, csize) 294 #if defined ( GB_DOT2_GENERIC ) 295 #include "GB_AxB_dot2_meta.c" 296 #elif defined ( GB_DOT3_GENERIC ) 297 #include "GB_AxB_dot3_meta.c" 298 #else 299 #include "GB_AxB_dot4_meta.c" 300 #endif 301 } 302 else // opcode == GB_SECOND_opcode 303 { 304 // t = B(i,k) 305 ASSERT (A_is_pattern) ; 306 #undef GB_MULT 307 #define GB_MULT(t, aik, bkj, i, k, j) memcpy (t, bkj, csize) 308 #if defined ( GB_DOT2_GENERIC ) 309 #include "GB_AxB_dot2_meta.c" 310 #elif defined ( GB_DOT3_GENERIC ) 311 #include "GB_AxB_dot3_meta.c" 312 #else 313 #include "GB_AxB_dot4_meta.c" 314 #endif 315 } 316 } 317 else 318 { 319 if (flipxy) 320 { 321 // t = B(k,j) * (A')(i,k) 322 #undef GB_MULT 323 #define GB_MULT(t, aki, bkj, i, k, j) fmult (t, bkj, aki) 324 #if defined ( GB_DOT2_GENERIC ) 325 #include "GB_AxB_dot2_meta.c" 326 #elif defined ( GB_DOT3_GENERIC ) 327 #include "GB_AxB_dot3_meta.c" 328 #else 329 #include "GB_AxB_dot4_meta.c" 330 #endif 331 } 332 else 333 { 334 // t = (A')(i,k) * B(k,j) 335 #undef GB_MULT 336 #define GB_MULT(t, aki, bkj, i, k, j) fmult (t, aki, bkj) 337 #if defined ( GB_DOT2_GENERIC ) 338 #include "GB_AxB_dot2_meta.c" 339 #elif defined ( GB_DOT3_GENERIC ) 340 #include "GB_AxB_dot3_meta.c" 341 #else 342 #include "GB_AxB_dot4_meta.c" 343 #endif 344 } 345 } 346 } 347 } 348 349