1 //------------------------------------------------------------------------------ 2 // GB_AxB_dot4_template: C+=A'*B via dot products, where C is dense 3 //------------------------------------------------------------------------------ 4 5 // SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved. 6 // SPDX-License-Identifier: Apache-2.0 7 8 //------------------------------------------------------------------------------ 9 10 // C+=A'*B where C is a dense matrix and computed in-place. The monoid of the 11 // semiring matches the accum operator, and the type of C matches the ztype of 12 // accum. That is, no typecasting can be done with C. 13 14 // The PAIR operator as the multiplier provides important special cases. 15 16 { 17 18 //-------------------------------------------------------------------------- 19 // C += A'*B 20 //-------------------------------------------------------------------------- 21 22 int tid ; 23 #pragma omp parallel for num_threads(nthreads) schedule(dynamic,1) 24 for (tid = 0 ; tid < ntasks ; tid++) 25 { 26 27 //---------------------------------------------------------------------- 28 // get the task descriptor 29 //---------------------------------------------------------------------- 30 31 const int a_tid = tid / nbslice ; 32 const int b_tid = tid % nbslice ; 33 const int64_t kA_start = A_slice [a_tid] ; 34 const int64_t kA_end = A_slice [a_tid+1] ; 35 const int64_t kB_start = B_slice [b_tid] ; 36 const int64_t kB_end = B_slice [b_tid+1] ; 37 38 //---------------------------------------------------------------------- 39 // C+=A'*B via dot products 40 //---------------------------------------------------------------------- 41 42 for (int64_t kB = kB_start ; kB < kB_end ; kB++) 43 { 44 45 //------------------------------------------------------------------ 46 // get B(:,j) and C(:,j) 47 //------------------------------------------------------------------ 48 49 #if GB_B_IS_HYPER 50 const int64_t j = Bh [kB] ; 51 #else 52 const int64_t j = kB ; 53 #endif 54 55 const int64_t pC_start = j * cvlen ; 56 57 #if ( GB_B_IS_HYPER || GB_B_IS_SPARSE ) 58 // B is sparse or hyper 59 const int64_t pB_start = Bp [kB] ; 60 const int64_t pB_end = Bp [kB+1] ; 61 const int64_t bjnz = pB_end - pB_start ; 62 if (bjnz == 0) continue ; 63 #if ( GB_A_IS_HYPER || GB_A_IS_SPARSE ) 64 // Both A and B are sparse/hyper; get first & last in B(:,j) 65 const int64_t ib_first = Bi [pB_start] ; 66 const int64_t ib_last = Bi [pB_end-1] ; 67 #endif 68 #else 69 // B is bitmap or full 70 const int64_t pB_start = j * vlen ; 71 #endif 72 73 //------------------------------------------------------------------ 74 // C(:,j) += A'*B(:,j) where C is full 75 //------------------------------------------------------------------ 76 77 for (int64_t kA = kA_start ; kA < kA_end ; kA++) 78 { 79 80 //-------------------------------------------------------------- 81 // get A(:,i) 82 //-------------------------------------------------------------- 83 84 #if GB_A_IS_HYPER 85 const int64_t i = Ah [kA] ; 86 #else 87 const int64_t i = kA ; 88 #endif 89 90 #if ( GB_A_IS_HYPER || GB_A_IS_SPARSE ) 91 // A is sparse or hyper 92 int64_t pA = Ap [kA] ; 93 const int64_t pA_end = Ap [kA+1] ; 94 const int64_t ainz = pA_end - pA ; 95 if (ainz == 0) continue ; 96 #else 97 // A is bitmap or full 98 const int64_t pA = kA * vlen ; 99 #endif 100 101 //-------------------------------------------------------------- 102 // get C(i,j) 103 //-------------------------------------------------------------- 104 105 GB_CIJ_DECLARE (cij) ; // declare the cij scalar 106 int64_t pC = i + pC_start ; // C(i,j) is at Cx [pC] 107 bool cij_updated = false ; 108 109 //-------------------------------------------------------------- 110 // C(i,j) += A (:,i)*B(:,j): a single dot product 111 //-------------------------------------------------------------- 112 113 int64_t pB = pB_start ; 114 115 #if ( GB_A_IS_FULL && GB_B_IS_FULL ) 116 { 117 118 //---------------------------------------------------------- 119 // both A and B are full 120 //---------------------------------------------------------- 121 122 GB_GETC (cij, pC) ; // cij = Cx [pC] 123 #if GB_IS_PAIR_MULTIPLIER 124 { 125 #if GB_IS_ANY_MONOID 126 // ANY monoid: take the first entry found 127 GB_MULT (cij, ignore, ignore, 0, 0, 0) ; 128 #elif GB_IS_EQ_MONOID 129 // EQ_PAIR semiring 130 cij = (cij == 1) ; 131 #elif (GB_CTYPE_BITS > 0) 132 // PLUS, XOR monoids: A(:,i)'*B(:,j) is nnz(A(:,i)), 133 // for bool, 8-bit, 16-bit, or 32-bit integer 134 uint64_t t = ((uint64_t) cij) + vlen ; 135 cij = (GB_CTYPE) (t & GB_CTYPE_BITS) ; 136 #elif GB_IS_PLUS_FC32_MONOID 137 // PLUS monoid for float complex 138 cij = GxB_CMPLXF (crealf (cij) + (float) vlen, 0) ; 139 #elif GB_IS_PLUS_FC64_MONOID 140 // PLUS monoid for double complex 141 cij = GxB_CMPLX (creal (cij) + (double) vlen, 0) ; 142 #else 143 // PLUS monoid for float, double, or 64-bit integers 144 cij += (GB_CTYPE) vlen ; 145 #endif 146 } 147 #else 148 { 149 GB_PRAGMA_SIMD_DOT (cij) 150 for (int64_t k = 0 ; k < vlen ; k++) 151 { 152 GB_DOT_TERMINAL (cij) ; // break if terminal 153 // cij += A(k,i) * B(k,j) 154 GB_GETA (aki, Ax, pA+k) ; // aki = A(k,i) 155 GB_GETB (bkj, Bx, pB+k) ; // bkj = B(k,j) 156 // cij += aki * bkj 157 GB_MULTADD (cij, aki, bkj, i, k, j) ; 158 } 159 } 160 #endif 161 GB_DOT_ALWAYS_SAVE_CIJ ; 162 163 } 164 #elif ( GB_A_IS_FULL && GB_B_IS_BITMAP ) 165 { 166 167 //---------------------------------------------------------- 168 // A is full and B is bitmap 169 //---------------------------------------------------------- 170 171 for (int64_t k = 0 ; k < vlen ; k++) 172 { 173 if (Bb [pB+k]) 174 { 175 GB_DOT (k, pA+k, pB+k) ; 176 } 177 } 178 GB_DOT_SAVE_CIJ ; 179 180 } 181 #elif ( GB_A_IS_FULL && ( GB_B_IS_SPARSE || GB_B_IS_HYPER ) ) 182 { 183 184 //---------------------------------------------------------- 185 // A is full and B is sparse/hyper 186 //---------------------------------------------------------- 187 188 GB_GETC (cij, pC) ; // cij = Cx [pC] 189 #if GB_IS_PAIR_MULTIPLIER 190 { 191 #if GB_IS_ANY_MONOID 192 // ANY monoid: take the first entry found 193 // cij = 1, or CMPLX(1,0) for complex ANY 194 GB_MULT (cij, ignore, ignore, 0, 0, 0) ; 195 #elif GB_IS_EQ_MONOID 196 // EQ_PAIR semiring 197 cij = (cij == 1) ; 198 #elif (GB_CTYPE_BITS > 0) 199 // PLUS, XOR monoids: A(:,i)'*B(:,j) is nnz(A(:,i)), 200 // for bool, 8-bit, 16-bit, or 32-bit integer 201 uint64_t t = ((uint64_t) cij) + bjnz ; 202 cij = (GB_CTYPE) (t & GB_CTYPE_BITS) ; 203 #elif GB_IS_PLUS_FC32_MONOID 204 // PLUS monoid for float complex 205 cij = GxB_CMPLXF (crealf (cij) + (float) bjnz, 0) ; 206 #elif GB_IS_PLUS_FC64_MONOID 207 // PLUS monoid for double complex 208 cij = GxB_CMPLX (creal (cij) + (double) bjnz, 0) ; 209 #else 210 // PLUS monoid for float, double, or 64-bit integers 211 cij += (GB_CTYPE) bjnz ; 212 #endif 213 } 214 #else 215 { 216 GB_PRAGMA_SIMD_DOT (cij) 217 for (int64_t p = pB ; p < pB_end ; p++) 218 { 219 GB_DOT_TERMINAL (cij) ; // break if terminal 220 int64_t k = Bi [p] ; 221 // cij += A(k,i) * B(k,j) 222 GB_GETA (aki, Ax, pA+k) ; // aki = A(k,i) 223 GB_GETB (bkj, Bx, p ) ; // bkj = B(k,j) 224 GB_MULTADD (cij, aki, bkj, i, k, j) ; 225 } 226 } 227 #endif 228 GB_DOT_ALWAYS_SAVE_CIJ ; 229 230 } 231 #elif ( GB_A_IS_BITMAP && GB_B_IS_FULL ) 232 { 233 234 //---------------------------------------------------------- 235 // A is bitmap and B is full 236 //---------------------------------------------------------- 237 238 for (int64_t k = 0 ; k < vlen ; k++) 239 { 240 if (Ab [pA+k]) 241 { 242 GB_DOT (k, pA+k, pB+k) ; 243 } 244 } 245 GB_DOT_SAVE_CIJ ; 246 247 } 248 #elif ( GB_A_IS_BITMAP && GB_B_IS_BITMAP ) 249 { 250 251 //---------------------------------------------------------- 252 // both A and B are bitmap 253 //---------------------------------------------------------- 254 255 for (int64_t k = 0 ; k < vlen ; k++) 256 { 257 if (Ab [pA+k] && Bb [pB+k]) 258 { 259 GB_DOT (k, pA+k, pB+k) ; 260 } 261 } 262 GB_DOT_SAVE_CIJ ; 263 264 } 265 #elif ( GB_A_IS_BITMAP && ( GB_B_IS_SPARSE || GB_B_IS_HYPER ) ) 266 { 267 268 //---------------------------------------------------------- 269 // A is bitmap and B is sparse/hyper 270 //---------------------------------------------------------- 271 272 for (int64_t p = pB ; p < pB_end ; p++) 273 { 274 int64_t k = Bi [p] ; 275 if (Ab [pA+k]) 276 { 277 GB_DOT (k, pA+k, p) ; 278 } 279 } 280 GB_DOT_SAVE_CIJ ; 281 282 } 283 #elif ( (GB_A_IS_SPARSE || GB_A_IS_HYPER) && GB_B_IS_FULL ) 284 { 285 286 //---------------------------------------------------------- 287 // A is sparse/hyper and B is full 288 //---------------------------------------------------------- 289 290 GB_GETC (cij, pC) ; // cij = Cx [pC] 291 #if GB_IS_PAIR_MULTIPLIER 292 { 293 #if GB_IS_ANY_MONOID 294 // ANY monoid: take the first entry found 295 GB_MULT (cij, ignore, ignore, 0, 0, 0) ; 296 #elif GB_IS_EQ_MONOID 297 // EQ_PAIR semiring 298 cij = (cij == 1) ; 299 #elif (GB_CTYPE_BITS > 0) 300 // PLUS, XOR monoids: A(:,i)'*B(:,j) is nnz(A(:,i)), 301 // for bool, 8-bit, 16-bit, or 32-bit integer 302 uint64_t t = ((uint64_t) cij) + ainz ; 303 cij = (GB_CTYPE) (t & GB_CTYPE_BITS) ; 304 #elif GB_IS_PLUS_FC32_MONOID 305 // PLUS monoid for float complex 306 cij = GxB_CMPLXF (crealf (cij) + (float) ainz, 0) ; 307 #elif GB_IS_PLUS_FC64_MONOID 308 // PLUS monoid for double complex 309 cij = GxB_CMPLX (creal (cij) + (double) ainz, 0) ; 310 #else 311 // PLUS monoid for float, double, or 64-bit integers 312 cij += (GB_CTYPE) ainz ; 313 #endif 314 } 315 #else 316 { 317 GB_PRAGMA_SIMD_DOT (cij) 318 for (int64_t p = pA ; p < pA_end ; p++) 319 { 320 GB_DOT_TERMINAL (cij) ; // break if terminal 321 int64_t k = Ai [p] ; 322 // cij += A(k,i) * B(k,j) 323 GB_GETA (aki, Ax, p ) ; // aki = A(k,i) 324 GB_GETB (bkj, Bx, pB+k) ; // bkj = B(k,j) 325 GB_MULTADD (cij, aki, bkj, i, k, j) ; 326 } 327 } 328 #endif 329 GB_DOT_ALWAYS_SAVE_CIJ ; 330 331 } 332 #elif ( (GB_A_IS_SPARSE || GB_A_IS_HYPER) && GB_B_IS_BITMAP ) 333 { 334 335 //---------------------------------------------------------- 336 // A is sparse/hyper and B is bitmap 337 //---------------------------------------------------------- 338 339 for (int64_t p = pA ; p < pA_end ; p++) 340 { 341 int64_t k = Ai [p] ; 342 if (Bb [pB+k]) 343 { 344 GB_DOT (k, p, pB+k) ; 345 } 346 } 347 GB_DOT_SAVE_CIJ ; 348 349 } 350 #else 351 { 352 353 //---------------------------------------------------------- 354 // both A and B are sparse/hyper 355 //---------------------------------------------------------- 356 357 if (Ai [pA_end-1] < ib_first || ib_last < Ai [pA]) 358 { 359 360 //------------------------------------------------------ 361 // pattern of A(:,i) and B(:,j) don't overlap 362 //------------------------------------------------------ 363 364 } 365 else if (ainz > 8 * bjnz) 366 { 367 368 //------------------------------------------------------ 369 // B(:,j) is very sparse compared to A(:,i) 370 //------------------------------------------------------ 371 372 while (pA < pA_end && pB < pB_end) 373 { 374 int64_t ia = Ai [pA] ; 375 int64_t ib = Bi [pB] ; 376 if (ia < ib) 377 { 378 // A(ia,i) appears before B(ib,j) 379 // discard all entries A(ia:ib-1,i) 380 int64_t pleft = pA + 1 ; 381 int64_t pright = pA_end - 1 ; 382 GB_TRIM_BINARY_SEARCH (ib, Ai, pleft, pright) ; 383 ASSERT (pleft > pA) ; 384 pA = pleft ; 385 } 386 else if (ib < ia) 387 { 388 // B(ib,j) appears before A(ia,i) 389 pB++ ; 390 } 391 else // ia == ib == k 392 { 393 // A(k,i) and B(k,j) are next entries to merge 394 GB_DOT (ia, pA, pB) ; 395 pA++ ; 396 pB++ ; 397 } 398 } 399 GB_DOT_SAVE_CIJ ; 400 401 } 402 else if (bjnz > 8 * ainz) 403 { 404 405 //------------------------------------------------------ 406 // A(:,i) is very sparse compared to B(:,j) 407 //------------------------------------------------------ 408 409 while (pA < pA_end && pB < pB_end) 410 { 411 int64_t ia = Ai [pA] ; 412 int64_t ib = Bi [pB] ; 413 if (ia < ib) 414 { 415 // A(ia,i) appears before B(ib,j) 416 pA++ ; 417 } 418 else if (ib < ia) 419 { 420 // B(ib,j) appears before A(ia,i) 421 // discard all entries B(ib:ia-1,j) 422 int64_t pleft = pB + 1 ; 423 int64_t pright = pB_end - 1 ; 424 GB_TRIM_BINARY_SEARCH (ia, Bi, pleft, pright) ; 425 ASSERT (pleft > pB) ; 426 pB = pleft ; 427 } 428 else // ia == ib == k 429 { 430 // A(k,i) and B(k,j) are next entries to merge 431 GB_DOT (ia, pA, pB) ; 432 pA++ ; 433 pB++ ; 434 } 435 } 436 GB_DOT_SAVE_CIJ ; 437 438 } 439 else 440 { 441 442 //------------------------------------------------------ 443 // A(:,i) and B(:,j) have about the same sparsity 444 //------------------------------------------------------ 445 446 while (pA < pA_end && pB < pB_end) 447 { 448 int64_t ia = Ai [pA] ; 449 int64_t ib = Bi [pB] ; 450 if (ia < ib) 451 { 452 // A(ia,i) appears before B(ib,j) 453 pA++ ; 454 } 455 else if (ib < ia) 456 { 457 // B(ib,j) appears before A(ia,i) 458 pB++ ; 459 } 460 else // ia == ib == k 461 { 462 // A(k,i) and B(k,j) are the entries to merge 463 GB_DOT (ia, pA, pB) ; 464 pA++ ; 465 pB++ ; 466 } 467 } 468 GB_DOT_SAVE_CIJ ; 469 } 470 } 471 #endif 472 } 473 } 474 } 475 } 476 477