// =============================================================================
// === GPUQREngine/Include/Kernel/Apply/block_apply_chunk.cu ===================
// =============================================================================

//------------------------------------------------------------------------------
// block_apply_chunk macro
//------------------------------------------------------------------------------

// A = A - V*T'*V'*A, for a single chunk of N columns of A, starting at column
// j1 and ending at j1+N-1.
//
// This code uses fixed thread geometry and loop unrolling, which requires
// the geometry to be known at compile time for best efficiency. It is
// #include'd by the block_apply_x functions (block_apply.cu). ROW_PANELSIZE
// and COL_PANELSIZE are #define'd by the parent; the bitty block sizes are
// then #define'd in this file according to those two values:
//
//      ROW_PANELSIZE   # of row tiles in V and A
//      COL_PANELSIZE   # of column tiles in C and A
//      CBITTYROWS      # of rows in the C bitty block
//      CBITTYCOLS      # of cols in the C bitty block
//      ABITTYROWS      # of rows in the A bitty block
//      ABITTYCOLS      # of cols in the A bitty block
//
// The C bitty block cannot be larger than the A bitty block, because both are
// held in the same register array, and additional registers are used to
// buffer the A matrix while the C bitty block is being computed. These buffer
// registers are not used while computing with the A bitty block, so for some
// variants of this kernel, they can be overlapped with the A bitty block.
//
// ROW_PANELSIZE, COL_PANELSIZE, ROW_EDGE_CASE, and COL_EDGE_CASE are
// #define'd by the parent file(s) that include this file. The *_EDGE_CASE
// macros are #undef'd at the end of this file. This file is not a standalone
// function: it is a brace-enclosed block of statements #include'd into
// block_apply.cu.
{

    //--------------------------------------------------------------------------
    // bitty block sizes
    //--------------------------------------------------------------------------

    // Each thread owns one "bitty block" of C (for C = T'*V'*A) and one of A
    // (for A = A - V*C). The sizes are chosen per panel geometry so that the
    // thread counts work out as noted in each case.

    #if (ROW_PANELSIZE == 3)

        #if (COL_PANELSIZE == 2)

            //------------------------------------------------------------------
            // 3-by-2 block apply
            //------------------------------------------------------------------

            // V is 3-by-1, C is 1-by-2, A is 3-by-2 (in # tiles)
            // 256 threads, each does a 4-by-2 block of C = T'*V'*A
            #define CBITTYROWS 4
            #define CBITTYCOLS 2
            // 384 threads, each does a 4-by-4 block of A = A-V*C
            #define ABITTYROWS 4
            #define ABITTYCOLS 4

        #else

            //------------------------------------------------------------------
            // 3-by-1 block apply
            //------------------------------------------------------------------

            // V is 3-by-1, C is 1-by-1, A is 3-by-1 (in # tiles)
            // 256 threads, each does a 2-by-2 block of C = T'*V'*A
            #define CBITTYROWS 2
            #define CBITTYCOLS 2
            // 384 threads, each does a 2-by-4 block of A = A-V*C
            #define ABITTYROWS 2
            #define ABITTYCOLS 4

        #endif

    #elif (ROW_PANELSIZE == 2)

        #if (COL_PANELSIZE == 2)

            //------------------------------------------------------------------
            // block_apply_2_by_2
            //------------------------------------------------------------------

            // V is 2-by-1, C is 1-by-2, A is 2-by-2 (in # tiles)
            // 256 threads, each does a 4-by-2 block of C = T'*V'*A
            #define CBITTYROWS 4
            #define CBITTYCOLS 2
            // 256 threads, each does a 4-by-4 block of A = A-V*C
            #define ABITTYROWS 4
            #define ABITTYCOLS 4

        #else

            //------------------------------------------------------------------
            // block_apply_2_by_1
            //------------------------------------------------------------------

            // V is 2-by-1, C is 1-by-1, A is 2-by-1 (in # tiles)
            // 256 threads, each does a 2-by-2 block of C = T'*V'*A
            #define CBITTYROWS 2
            #define CBITTYCOLS 2
            // 256 threads, each does a 2-by-4 block of A = A-V*C
            #define ABITTYROWS 2
            #define ABITTYCOLS 4

        #endif

    #else

        #if (COL_PANELSIZE == 2)

            //------------------------------------------------------------------
            // block_apply_1_by_2
            //------------------------------------------------------------------

            // V is 1-by-1, C is 1-by-2, A is 1-by-2 (in # tiles)
            // 256 threads, each does a 2-by-4 block of C = T'*V'*A
            #define CBITTYROWS 2
            #define CBITTYCOLS 4
            // 256 threads, each does a 2-by-4 block of A = A-V*C
            #define ABITTYROWS 2
            #define ABITTYCOLS 4

        #else

            //------------------------------------------------------------------
            // block_apply_1_by_1
            //------------------------------------------------------------------

            // V is 1-by-1, C is 1-by-1, A is 1-by-1 (in # tiles)
            // 256 threads, each does a 2-by-2 block of C = T'*V'*A
            #define CBITTYROWS 2
            #define CBITTYCOLS 2
            // 256 threads, each does a 2-by-2 block of A = A-V*C
            #define ABITTYROWS 2
            #define ABITTYCOLS 2

        #endif

    #endif

    // The C bitty block is accumulated in the register array sized for the A
    // bitty block (rbit [ABITTYROWS][ABITTYCOLS]), so it must fit inside it.
    #if (CBITTYROWS > ABITTYROWS) || (CBITTYCOLS > ABITTYCOLS)
    #error "block_apply_chunk: C bitty block must not be larger than A bitty block"
    #endif

    //--------------------------------------------------------------------------
    // matrix sizes and thread geometry
    //--------------------------------------------------------------------------

    // For each outer iteration, C is M-by-N, V is (K+1)-by-M (with an extra
    // row for T), and A is K-by-N.
147 #define K (ROW_PANELSIZE * M) 148 #define N (COL_PANELSIZE * M) 149 150 // threads to use for C=T'*(V'*A) 151 #define CTHREADS ((M * N) / (CBITTYROWS * CBITTYCOLS)) 152 153 // threads to use for A=A-V*C 154 #define ATHREADS ((K * N) / (ABITTYROWS * ABITTYCOLS)) 155 156 //-------------------------------------------------------------------------- 157 // bitty blocks for the computation 158 //-------------------------------------------------------------------------- 159 160 // Each thread owns a bitty block of C for C=T'*V'*A. The top left entry 161 // owned by a thread is C(ic,jc). Thread 0 does C(0,0), thread 1 does 162 // C(1,0) ... 163 #define ic (threadIdx.x % (M/CBITTYROWS)) 164 #define jc (threadIdx.x / (M/CBITTYROWS)) 165 #define MYCBITTYROW(ii) (ii * (M/CBITTYROWS) + ic) 166 #define MYCBITTYCOL(jj) (jj * (N/CBITTYCOLS) + jc) 167 168 // Each thread owns a bitty block of A for A=A-V*C, with top left entry 169 // A(ia,ja). Thread 0 does A(0,0), thread 1 does A(0,1), thread 2 does 170 // A(0,2), ... so that global memory loads/stores are coallesced across a 171 // warp. 172 #define ia (threadIdx.x / (N/ABITTYCOLS)) 173 #define ja (threadIdx.x % (N/ABITTYCOLS)) 174 #define MYABITTYROW(ii) (ii * (K/ABITTYROWS) + ia) 175 #define MYABITTYCOL(jj) (jj * (N/ABITTYCOLS) + ja) 176 177 //-------------------------------------------------------------------------- 178 // loading the A matrix 179 //-------------------------------------------------------------------------- 180 181 // Each thread loads a set of entries of A defined by iaload and jaload. 182 // The first entry loaded by a thread is A(iaload,jaload), and then it 183 // loads entries every ACHUNKSIZE rows after that (in the same column 184 // jaload). 
    // Consecutive threads take consecutive columns (jaload), so each load of
    // glF [fi * fn + fjload] below touches adjacent addresses across threads.
    #define iaload (threadIdx.x / N)
    #define jaload (threadIdx.x % N)
    // rows between successive loads by the same thread
    #define ACHUNKSIZE (NUMTHREADS / N)
    // number of loads per thread to cover HALFTILE rows of the N columns
    #define NACHUNKS CEIL (HALFTILE*N, NUMTHREADS)

    // column of the frontal matrix this thread loads from
    int fjload = j1 + jaload ;

    //--------------------------------------------------------------------------
    // register allocation
    //--------------------------------------------------------------------------

    // C bitty block is no larger than the A bitty block, in both dimensions.
    // rbit holds the C bitty block first, and then the A bitty block.
    double rbit [ABITTYROWS][ABITTYCOLS] ;
    double rrow [ABITTYROWS] ;
    double rcol [ABITTYCOLS] ;

    #if (CBITTYCOLS == ABITTYCOLS)
    // Every column of rbit is in use while the C bitty block is computed, so
    // there is no spare column for the A buffer; use separate registers.
    double abuffer [NACHUNKS] ;
    #define rbitA(i) abuffer [i]
    #else
    // use the last column of the A bitty block for the A buffer. That column
    // is not part of the C bitty block (CBITTYCOLS < ABITTYCOLS here).
    // NOTE(review): this also requires NACHUNKS <= ABITTYROWS, which is not
    // checked in this file -- confirm for each geometry.
    #define rbitA(i) (rbit [i][ABITTYCOLS-1])
    #endif

    //--------------------------------------------------------------------------
    // edge case
    //--------------------------------------------------------------------------

    #ifdef ROW_EDGE_CASE
    // check if a row is inside the front.
    #define INSIDE_ROW(test) (test)
    #else
    // the row is guaranteed to reside inside the frontal matrix, so the test
    // is replaced by a constant and compiled away.
    #define INSIDE_ROW(test) (1)
    #endif

    #ifdef COL_EDGE_CASE
    // check if a column is inside the front.
    #define INSIDE_COL(test) (test)
    #else
    // the column is guaranteed to reside inside the frontal matrix, so the
    // test is replaced by a constant and compiled away.
    #define INSIDE_COL(test) (1)
    #endif

    // true if this thread's load column lies inside the front
    bool aloader = INSIDE_COL (fjload < fn) ;

    //--------------------------------------------------------------------------
    // C = V'*A, where V is now in shared, and A is loaded from global
    //--------------------------------------------------------------------------

    // prefetch the first halftile of A from global to register. Entries that
    // are not loaded (outside the front) stay zero.
    #pragma unroll
    for (int ii = 0 ; ii < NACHUNKS ; ii++)
    {
        rbitA (ii) = 0 ;
    }
    #pragma unroll
    for (int ii = 0 ; ii < NACHUNKS ; ii++)
    {
        int i = ii * ACHUNKSIZE + iaload ;
        // the last chunk may extend past HALFTILE rows; skip those rows
        if (ii < NACHUNKS-1 || i < HALFTILE)
        {
            int fi = IFRONT (0, i) ;
            if (aloader && INSIDE_ROW (fi < fm))
            {
                rbitA (ii) = glF [fi * fn + fjload] ;
            }
        }
    }

    // The X=V*C computation in the prior iteration reads shC, but the same
    // space is used to load A from the frontal matrix in this iteration.
    __syncthreads ( ) ;

    // clear the C bitty block
    #pragma unroll
    for (int ii = 0 ; ii < CBITTYROWS ; ii++)
    {
        #pragma unroll
        for (int jj = 0 ; jj < CBITTYCOLS ; jj++)
        {
            rbit [ii][jj] = 0 ;
        }
    }

    // C=V'*A for the first tile of V, which is lower triangular.
    // cevta_tile.cu (not visible here) presumably consumes the prefetched
    // rbitA values and accumulates into rbit.
    #define FIRST_TILE
    #include "cevta_tile.cu"

    // Subsequent tiles of V are square. Result is in C bitty block.
    // t is the tile index; presumably read by cevta_tile.cu to select the
    // tile of V and A being processed.
    for (int t = 1 ; t < ROW_PANELSIZE ; t++)
    {
        #include "cevta_tile.cu"
    }

    //--------------------------------------------------------------------------
    // write result of C=V'*A into shared, and clear the C bitty block
    //--------------------------------------------------------------------------

    // Only the first CTHREADS threads own a C bitty block. When CTHREADS equals
    // NUMTHREADS the first operand lets the guard fold to true.
    if (CTHREADS == NUMTHREADS || threadIdx.x < CTHREADS)
    {
        #pragma unroll
        for (int ii = 0 ; ii < CBITTYROWS ; ii++)
        {
            int i = MYCBITTYROW (ii) ;
            #pragma unroll
            for (int jj = 0 ; jj < CBITTYCOLS ; jj++)
            {
                int j = MYCBITTYCOL (jj) ;
                shC [i][j] = rbit [ii][jj] ;
                rbit [ii][jj] = 0 ;
            }
        }
    }

    // make sure all of shC is available to all threads
    __syncthreads ( ) ;

    //--------------------------------------------------------------------------
    // C = triu(T)'*C, leaving the result in the C bitty block
    //--------------------------------------------------------------------------

    // newC(j,:) = sum over i <= j of T(i,j) * oldC(i,:). ST(i,j) is defined
    // elsewhere and presumably reads T(i,j) from shared memory.
    if (CTHREADS == NUMTHREADS || threadIdx.x < CTHREADS)
    {
        #pragma unroll
        for (int i = 0 ; i < M ; i++)
        {
            // rrow [ii] is only set when i <= j, and is only read under the
            // same test below, so no stale value is ever used.
            #pragma unroll
            for (int ii = 0 ; ii < CBITTYROWS ; ii++)
            {
                int j = MYCBITTYROW (ii) ;
                if (i <= j)
                {
                    rrow [ii] = ST (i,j) ;
                }
            }
            // row i of the old C, restricted to this thread's columns
            #pragma unroll
            for (int jj = 0 ; jj < CBITTYCOLS ; jj++)
            {
                int j = MYCBITTYCOL (jj) ;
                rcol [jj] = shC [i][j] ;
            }
            #pragma unroll
            for (int ii = 0 ; ii < CBITTYROWS ; ii++)
            {
                int j = MYCBITTYROW (ii) ;
                if (i <= j)
                {
                    #pragma unroll
                    for (int jj = 0 ; jj < CBITTYCOLS ; jj++)
                    {
                        rbit [ii][jj] += rrow [ii] * rcol [jj] ;
                    }
                }
            }
        }
    }

    // We need syncthreads here because of the write-after-read hazard. Each
    // thread reads the old C, above, and then C is modified below with the new
    // C, where newC = triu(T)'*oldC.
347 __syncthreads ( ) ; 348 349 //-------------------------------------------------------------------------- 350 // write the result of C = T'*C to shared memory 351 //-------------------------------------------------------------------------- 352 353 if (CTHREADS == NUMTHREADS || threadIdx.x < CTHREADS) 354 { 355 #pragma unroll 356 for (int ii = 0 ; ii < CBITTYROWS ; ii++) 357 { 358 int i = MYCBITTYROW (ii) ; 359 #pragma unroll 360 for (int jj = 0 ; jj < CBITTYCOLS ; jj++) 361 { 362 int j = MYCBITTYCOL (jj) ; 363 shC [i][j] = rbit [ii][jj] ; 364 } 365 } 366 } 367 368 // All threads come here. We need a syncthreads because 369 // shC has been written above and must be read below in A=A-V*C. 370 __syncthreads ( ) ; 371 372 //-------------------------------------------------------------------------- 373 // A = A - V*C 374 //-------------------------------------------------------------------------- 375 376 if (ATHREADS == NUMTHREADS || threadIdx.x < ATHREADS) 377 { 378 379 //---------------------------------------------------------------------- 380 // clear the A bitty block 381 //---------------------------------------------------------------------- 382 383 #pragma unroll 384 for (int ii = 0 ; ii < ABITTYROWS ; ii++) 385 { 386 #pragma unroll 387 for (int jj = 0 ; jj < ABITTYCOLS ; jj++) 388 { 389 rbit [ii][jj] = 0 ; 390 } 391 } 392 393 //---------------------------------------------------------------------- 394 // X = tril(V)*C, store result into register (rbit) 395 //---------------------------------------------------------------------- 396 397 #pragma unroll 398 for (int p = 0 ; p < M ; p++) 399 { 400 #pragma unroll 401 for (int ii = 0 ; ii < ABITTYROWS ; ii++) 402 { 403 int i = MYABITTYROW (ii) ; 404 if (i >= p) 405 { 406 rrow [ii] = shV [1+i][p] ; 407 } 408 } 409 #pragma unroll 410 for (int jj = 0 ; jj < ABITTYCOLS ; jj++) 411 { 412 int j = MYABITTYCOL (jj) ; 413 rcol [jj] = shC [p][j] ; 414 } 415 #pragma unroll 416 for (int ii = 0 ; ii < ABITTYROWS ; ii++) 417 { 
418 int i = MYABITTYROW (ii) ; 419 if (i >= p) 420 { 421 #pragma unroll 422 for (int jj = 0 ; jj < ABITTYCOLS ; jj++) 423 { 424 rbit [ii][jj] += rrow [ii] * rcol [jj] ; 425 } 426 } 427 } 428 } 429 430 //---------------------------------------------------------------------- 431 // A = A - X, which finalizes the computation A = A - V*(T'*(V'*A)) 432 //---------------------------------------------------------------------- 433 434 #if (COL_PANELSIZE == 2) 435 436 #pragma unroll 437 for (int ii = 0 ; ii < ABITTYROWS ; ii++) 438 { 439 int i = MYABITTYROW (ii) ; 440 int fi = IFRONT (i / M, i % M) ; 441 #pragma unroll 442 for (int jj = 0 ; jj < ABITTYCOLS ; jj++) 443 { 444 int fj = j1 + MYABITTYCOL (jj) ; 445 if (INSIDE_ROW (fi < fm) && INSIDE_COL (fj < fn)) 446 { 447 glF [fi * fn + fj] -= rbit [ii][jj] ; 448 } 449 } 450 } 451 452 #else 453 454 #pragma unroll 455 for (int ii = 0 ; ii < ABITTYROWS ; ii++) 456 { 457 int i = MYABITTYROW (ii) ; 458 int fi = IFRONT (i / M, i % M) ; 459 #pragma unroll 460 for (int jj = 0 ; jj < ABITTYCOLS ; jj++) 461 { 462 int fj = j1 + MYABITTYCOL (jj) ; 463 if (INSIDE_ROW (fi < fm) && INSIDE_COL (fj < fn)) 464 { 465 shV[i][MYABITTYCOL(jj)] = glF[fi*fn+fj] - rbit[ii][jj]; 466 } 467 else 468 { 469 shV[i][MYABITTYCOL(jj)] = 0.0; 470 } 471 } 472 } 473 474 #endif 475 } 476 477 //-------------------------------------------------------------------------- 478 // sync 479 //-------------------------------------------------------------------------- 480 481 // The X=V*C computation in this iteration reads shC, but the same space is 482 // used to load A from the frontal matrix in C=V'*A in the next iteration. 483 // This final sync also ensures that all threads finish the block_apply 484 // at the same time. Thus, no syncthreads is needed at the start of a 485 // subsequent function (the pipelined apply+factorize, for example). 
    // Block-wide barrier: no thread leaves the block apply until all threads
    // have finished with the shared-memory buffers used above.
    __syncthreads ( ) ;
}

//------------------------------------------------------------------------------
// undef's
//------------------------------------------------------------------------------

// The following #define's appear above. They are #undef'd so that the next
// #include of this file (with a different geometry) can redefine them. Note
// that FIRST_TILE is not #undef'd since that is done by cevta_tile.cu.
#undef CBITTYROWS
#undef CBITTYCOLS
#undef ABITTYROWS
#undef ABITTYCOLS

#undef K
#undef N

#undef CTHREADS
#undef ATHREADS

#undef ic
#undef jc
#undef MYCBITTYROW
#undef MYCBITTYCOL

#undef ia
#undef ja
#undef MYABITTYROW
#undef MYABITTYCOL

#undef iaload
#undef jaload
#undef ACHUNKSIZE
#undef NACHUNKS

#undef rbitA
#undef INSIDE_ROW
#undef INSIDE_COL

// Defined in the parent file that includes this one. Note that ROW_PANELSIZE
// is not #undef'd, since that is done in the parent. COL_PANELSIZE is not
// #undef'd here either (presumably also handled by the parent -- confirm).
#undef ROW_EDGE_CASE
#undef COL_EDGE_CASE