1 // =============================================================================
2 // === GPUQREngine/Include/Kernel/Apply/block_apply_chunk.cu ===================
3 // =============================================================================
4 
5 //------------------------------------------------------------------------------
6 // block_apply_chunk macro
7 //------------------------------------------------------------------------------
8 
9 // A = A - V*T'*V'*A, for a single chunk of N columns of A, starting at column
10 // j1 and ending at j1+N-1.
11 //
// This function uses fixed thread geometry and loop unrolling, which requires
// the geometry to be known at compile time for best efficiency.  It is then
// #include'd by the block_apply_x function (block_apply.cu).  Each specific
// version is determined by the following compile-time terms.  ROW_PANELSIZE
// and COL_PANELSIZE are #define'd by the including file; the bitty block
// sizes are derived from them below:
//
//      ROW_PANELSIZE    # of row tiles in V and A
//      COL_PANELSIZE    # of column tiles in C and A
//      CBITTYROWS       # of rows in the C bitty block
//      CBITTYCOLS       # of cols in the C bitty block
//      ABITTYROWS       # of rows in the A bitty block
//      ABITTYCOLS       # of cols in the A bitty block
23 //
// The C bitty block cannot be larger than the A bitty block, since additional
// registers are used to buffer the A matrix while the C bitty block is being
// computed.  These buffer registers are not used while computing with the A
// bitty block, so for some variants of this kernel, they can be overlapped
// with the A bitty block.
29 //
30 // The ROW_PANELSIZE, COL_PANELSIZE, ROW_EDGE_CASE, and COL_EDGE_CASE are
31 // #define'd by the parent file(s) that include this file.  The *_EDGE_CASE
32 // macros are then #undefined here.  The bitty block dimensions are defined
33 // below.  This file is #include'd into block_apply.cu.  It is not a standalone
34 // function.
35 
{

    //--------------------------------------------------------------------------
    // bitty block sizes
    //--------------------------------------------------------------------------

    // The bitty block shapes are selected from ROW_PANELSIZE and
    // COL_PANELSIZE, which the including file #define's.

    #if (ROW_PANELSIZE == 3)

        #if (COL_PANELSIZE == 2)

            //------------------------------------------------------------------
            // 3-by-2 block apply
            //------------------------------------------------------------------

            // V is 3-by-1, C is 1-by-2, A is 3-by-2 (in # tiles)
            // 256 threads, each does a 4-by-2 block of C = T'*V'*A
            #define CBITTYROWS      4
            #define CBITTYCOLS      2
            // 384 threads, each does a 4-by-4 block of A = A-V*C
            #define ABITTYROWS      4
            #define ABITTYCOLS      4

        #else

            //------------------------------------------------------------------
            // 3-by-1 block apply
            //------------------------------------------------------------------

            // V is 3-by-1, C is 1-by-1, A is 3-by-1 (in # tiles)
            // 256 threads, each does a 2-by-2 block of C = T'*V'*A
            #define CBITTYROWS      2
            #define CBITTYCOLS      2
            // 384 threads, each does a 2-by-4 block of A = A-V*C
            #define ABITTYROWS      2
            #define ABITTYCOLS      4

        #endif

    #elif (ROW_PANELSIZE == 2)

        #if (COL_PANELSIZE == 2)

            //------------------------------------------------------------------
            // block_apply_2_by_2
            //------------------------------------------------------------------

            // V is 2-by-1, C is 1-by-2, A is 2-by-2 (in # tiles)
            // 256 threads, each does a 4-by-2 block of C = T'*V'*A
            #define CBITTYROWS      4
            #define CBITTYCOLS      2
            // 256 threads, each does a 4-by-4 block of A = A-V*C
            #define ABITTYROWS      4
            #define ABITTYCOLS      4

        #else

            //------------------------------------------------------------------
            // block_apply_2_by_1
            //------------------------------------------------------------------

            // V is 2-by-1, C is 1-by-1, A is 2-by-1 (in # tiles)
            // 256 threads, each does a 2-by-2 block of C = T'*V'*A
            #define CBITTYROWS      2
            #define CBITTYCOLS      2
            // 256 threads, each does a 2-by-4 block of A = A-V*C
            #define ABITTYROWS      2
            #define ABITTYCOLS      4

        #endif

    #else

        #if (COL_PANELSIZE == 2)

            //------------------------------------------------------------------
            // block_apply_1_by_2
            //------------------------------------------------------------------

            // V is 1-by-1, C is 1-by-2, A is 1-by-2 (in # tiles)
            // 256 threads, each does a 2-by-4 block of C = T'*V'*A
            #define CBITTYROWS      2
            #define CBITTYCOLS      4
            // 256 threads, each does a 2-by-4 block of A = A-V*C
            #define ABITTYROWS      2
            #define ABITTYCOLS      4

        #else

            //------------------------------------------------------------------
            // block_apply_1_by_1
            //------------------------------------------------------------------

            // V is 1-by-1, C is 1-by-1, A is 1-by-1 (in # tiles)
            // 256 threads, each does a 2-by-2 block of C = T'*V'*A
            #define CBITTYROWS      2
            #define CBITTYCOLS      2
            // 256 threads, each does a 2-by-2 block of A = A-V*C
            #define ABITTYROWS      2
            #define ABITTYCOLS      2

        #endif

    #endif

    // The register allocation below relies on the C bitty block fitting inside
    // the A bitty block (rbit holds both, one after the other).
    #if (CBITTYROWS > ABITTYROWS) || (CBITTYCOLS > ABITTYCOLS)
    #error "block_apply_chunk: the C bitty block must not exceed the A bitty block"
    #endif

    //--------------------------------------------------------------------------
    // matrix sizes and thread geometry
    //--------------------------------------------------------------------------

    // For each outer iteration, C is M-by-N, V is (K+1)-by-M (with an extra
    // row for T), and A is K-by-N.
    #define K           (ROW_PANELSIZE * M)
    #define N           (COL_PANELSIZE * M)

    // threads to use for C=T'*(V'*A)
    #define CTHREADS    ((M * N) / (CBITTYROWS * CBITTYCOLS))

    // threads to use for A=A-V*C
    #define ATHREADS    ((K * N) / (ABITTYROWS * ABITTYCOLS))

    //--------------------------------------------------------------------------
    // bitty blocks for the computation
    //--------------------------------------------------------------------------

    // Each thread owns a bitty block of C for C=T'*V'*A.  The top left entry
    // owned by a thread is C(ic,jc).  Thread 0 does C(0,0), thread 1 does
    // C(1,0) ...  The (ii,jj) entry of the bitty block is strided by
    // M/CBITTYROWS rows and N/CBITTYCOLS columns.
    #define ic          (threadIdx.x % (M/CBITTYROWS))
    #define jc          (threadIdx.x / (M/CBITTYROWS))
    #define MYCBITTYROW(ii) ((ii) * (M/CBITTYROWS) + ic)
    #define MYCBITTYCOL(jj) ((jj) * (N/CBITTYCOLS) + jc)

    // Each thread owns a bitty block of A for A=A-V*C, with top left entry
    // A(ia,ja).  Thread 0 does A(0,0), thread 1 does A(0,1), thread 2 does
    // A(0,2), ... so that global memory loads/stores are coalesced across a
    // warp.
    #define ia          (threadIdx.x / (N/ABITTYCOLS))
    #define ja          (threadIdx.x % (N/ABITTYCOLS))
    #define MYABITTYROW(ii) ((ii) * (K/ABITTYROWS) + ia)
    #define MYABITTYCOL(jj) ((jj) * (N/ABITTYCOLS) + ja)

    //--------------------------------------------------------------------------
    // loading the A matrix
    //--------------------------------------------------------------------------

    // Each thread loads a set of entries of A defined by iaload and jaload.
    // The first entry loaded by a thread is A(iaload,jaload), and then it
    // loads entries every ACHUNKSIZE rows after that (in the same column
    // jaload).
    #define iaload      (threadIdx.x / N)
    #define jaload      (threadIdx.x % N)
    #define ACHUNKSIZE  (NUMTHREADS / N)
    #define NACHUNKS    CEIL (HALFTILE*N, NUMTHREADS)

    // column of the frontal matrix this thread loads A from
    int fjload = j1 + jaload ;

    //--------------------------------------------------------------------------
    // register allocation
    //--------------------------------------------------------------------------

    // C bitty block is no larger than the A bitty block, in both dimensions
    // (checked by the #error above).
    double rbit [ABITTYROWS][ABITTYCOLS] ;
    double rrow [ABITTYROWS] ;
    double rcol [ABITTYCOLS] ;

    #if (CBITTYCOLS == ABITTYCOLS)
        // the C bitty block uses every column of rbit, so the A prefetch
        // buffer needs its own registers
        double abuffer [NACHUNKS] ;
        #define rbitA(i) abuffer [i]
    #else
        // use the last column of the A bitty block (never touched by the C
        // bitty block) for the A buffer
        #define rbitA(i) (rbit [i][ABITTYCOLS-1])
    #endif

    //--------------------------------------------------------------------------
    // edge case
    //--------------------------------------------------------------------------

    #ifdef ROW_EDGE_CASE
        // check if a row is inside the front.
        #define INSIDE_ROW(test) (test)
    #else
        // the row is guaranteed to reside inside the frontal matrix.
        #define INSIDE_ROW(test) (1)
    #endif

    #ifdef COL_EDGE_CASE
        // check if a column is inside the front.
        #define INSIDE_COL(test) (test)
    #else
        // the column is guaranteed to reside inside the frontal matrix.
        #define INSIDE_COL(test) (1)
    #endif

    // true if this thread's load column lies inside the front
    bool aloader = INSIDE_COL (fjload < fn) ;

    //--------------------------------------------------------------------------
    // C = V'*A, where V is now in shared, and A is loaded from global
    //--------------------------------------------------------------------------

    // prefetch the first halftile of A from global to register
    #pragma unroll
    for (int ii = 0 ; ii < NACHUNKS ; ii++)
    {
        rbitA (ii) = 0 ;
    }
    #pragma unroll
    for (int ii = 0 ; ii < NACHUNKS ; ii++)
    {
        int i = ii * ACHUNKSIZE + iaload ;
        // only the last chunk can run past the end of the halftile
        if (ii < NACHUNKS-1 || i < HALFTILE)
        {
            int fi = IFRONT (0, i) ;
            if (aloader && INSIDE_ROW (fi < fm))
            {
                rbitA (ii) = glF [fi * fn + fjload] ;
            }
        }
    }

    // The X=V*C computation in the prior iteration reads shC, but the same
    // space is used to load A from the frontal matrix in this iteration.
    __syncthreads ( ) ;

    // clear the C bitty block
    #pragma unroll
    for (int ii = 0 ; ii < CBITTYROWS ; ii++)
    {
        #pragma unroll
        for (int jj = 0 ; jj < CBITTYCOLS ; jj++)
        {
            rbit [ii][jj] = 0 ;
        }
    }

    // C=V'*A for the first tile of V, which is lower triangular
    #define FIRST_TILE
    #include "cevta_tile.cu"

    // Subsequent tiles of V are square.  Result is in C bitty block.
    for (int t = 1 ; t < ROW_PANELSIZE ; t++)
    {
        #include "cevta_tile.cu"
    }

    //--------------------------------------------------------------------------
    // write result of C=V'*A into shared, and clear the C bitty block
    //--------------------------------------------------------------------------

    if (CTHREADS == NUMTHREADS || threadIdx.x < CTHREADS)
    {
        #pragma unroll
        for (int ii = 0 ; ii < CBITTYROWS ; ii++)
        {
            int i = MYCBITTYROW (ii) ;
            #pragma unroll
            for (int jj = 0 ; jj < CBITTYCOLS ; jj++)
            {
                int j = MYCBITTYCOL (jj) ;
                shC [i][j] = rbit [ii][jj] ;
                rbit [ii][jj] = 0 ;
            }
        }
    }

    // make sure all of shC is available to all threads
    __syncthreads ( ) ;

    //--------------------------------------------------------------------------
    // C = triu(T)'*C, leaving the result in the C bitty block
    //--------------------------------------------------------------------------

    // newC(j,:) = sum over i <= j of T(i,j) * oldC(i,:)
    if (CTHREADS == NUMTHREADS || threadIdx.x < CTHREADS)
    {
        #pragma unroll
        for (int i = 0 ; i < M ; i++)
        {
            #pragma unroll
            for (int ii = 0 ; ii < CBITTYROWS ; ii++)
            {
                int j = MYCBITTYROW (ii) ;
                if (i <= j)
                {
                    rrow [ii] = ST (i,j) ;
                }
            }
            #pragma unroll
            for (int jj = 0 ; jj < CBITTYCOLS ; jj++)
            {
                int j = MYCBITTYCOL (jj) ;
                rcol [jj] = shC [i][j] ;
            }
            #pragma unroll
            for (int ii = 0 ; ii < CBITTYROWS ; ii++)
            {
                int j = MYCBITTYROW (ii) ;
                if (i <= j)
                {
                    #pragma unroll
                    for (int jj = 0 ; jj < CBITTYCOLS ; jj++)
                    {
                        rbit [ii][jj] += rrow [ii] * rcol [jj] ;
                    }
                }
            }
        }
    }

    // We need syncthreads here because of the write-after-read hazard.  Each
    // thread reads the old C, above, and then C is modified below with the new
    // C, where newC = triu(T)'*oldC.
    __syncthreads ( ) ;

    //--------------------------------------------------------------------------
    // write the result of C = T'*C to shared memory
    //--------------------------------------------------------------------------

    if (CTHREADS == NUMTHREADS || threadIdx.x < CTHREADS)
    {
        #pragma unroll
        for (int ii = 0 ; ii < CBITTYROWS ; ii++)
        {
            int i = MYCBITTYROW (ii) ;
            #pragma unroll
            for (int jj = 0 ; jj < CBITTYCOLS ; jj++)
            {
                int j = MYCBITTYCOL (jj) ;
                shC [i][j] = rbit [ii][jj] ;
            }
        }
    }

    // All threads come here.  We need a syncthreads because
    // shC has been written above and must be read below in A=A-V*C.
    __syncthreads ( ) ;

    //--------------------------------------------------------------------------
    // A = A - V*C
    //--------------------------------------------------------------------------

    if (ATHREADS == NUMTHREADS || threadIdx.x < ATHREADS)
    {

        //----------------------------------------------------------------------
        // clear the A bitty block
        //----------------------------------------------------------------------

        #pragma unroll
        for (int ii = 0 ; ii < ABITTYROWS ; ii++)
        {
            #pragma unroll
            for (int jj = 0 ; jj < ABITTYCOLS ; jj++)
            {
                rbit [ii][jj] = 0 ;
            }
        }

        //----------------------------------------------------------------------
        // X = tril(V)*C, store result into register (rbit)
        //----------------------------------------------------------------------

        // V(i,p) is stored in shV [1+i][p]; row 0 of shV is the extra row
        // noted above.
        #pragma unroll
        for (int p = 0 ; p < M ; p++)
        {
            #pragma unroll
            for (int ii = 0 ; ii < ABITTYROWS ; ii++)
            {
                int i = MYABITTYROW (ii) ;
                if (i >= p)
                {
                    rrow [ii] = shV [1+i][p] ;
                }
            }
            #pragma unroll
            for (int jj = 0 ; jj < ABITTYCOLS ; jj++)
            {
                int j = MYABITTYCOL (jj) ;
                rcol [jj] = shC [p][j] ;
            }
            #pragma unroll
            for (int ii = 0 ; ii < ABITTYROWS ; ii++)
            {
                int i = MYABITTYROW (ii) ;
                if (i >= p)
                {
                    #pragma unroll
                    for (int jj = 0 ; jj < ABITTYCOLS ; jj++)
                    {
                        rbit [ii][jj] += rrow [ii] * rcol [jj] ;
                    }
                }
            }
        }

        //----------------------------------------------------------------------
        // A = A - X, which finalizes the computation A = A - V*(T'*(V'*A))
        //----------------------------------------------------------------------

        #if (COL_PANELSIZE == 2)

            // update A in place in the frontal matrix glF
            #pragma unroll
            for (int ii = 0 ; ii < ABITTYROWS ; ii++)
            {
                int i = MYABITTYROW (ii) ;
                int fi = IFRONT (i / M, i % M) ;
                #pragma unroll
                for (int jj = 0 ; jj < ABITTYCOLS ; jj++)
                {
                    int fj = j1 + MYABITTYCOL (jj) ;
                    if (INSIDE_ROW (fi < fm) && INSIDE_COL (fj < fn))
                    {
                        glF [fi * fn + fj] -= rbit [ii][jj] ;
                    }
                }
            }

        #else

            // Form A-X in the A bitty block registers (0 for entries outside
            // the front).  glF is not updated on this path; the result is
            // copied into shV below, once no thread is still reading V.
            #pragma unroll
            for (int ii = 0 ; ii < ABITTYROWS ; ii++)
            {
                int i = MYABITTYROW (ii) ;
                int fi = IFRONT (i / M, i % M) ;
                #pragma unroll
                for (int jj = 0 ; jj < ABITTYCOLS ; jj++)
                {
                    int fj = j1 + MYABITTYCOL (jj) ;
                    if (INSIDE_ROW (fi < fm) && INSIDE_COL (fj < fn))
                    {
                        rbit [ii][jj] = glF [fi * fn + fj] - rbit [ii][jj] ;
                    }
                    else
                    {
                        rbit [ii][jj] = 0.0 ;
                    }
                }
            }

        #endif
    }

    #if (COL_PANELSIZE != 2)

        // X = tril(V)*C above reads V from shV, and A-X is written into shV
        // below.  Without this barrier, a fast warp could overwrite entries
        // of V that a slower warp has not yet read.  The barrier sits outside
        // the divergent branch so that every thread in the block reaches it.
        __syncthreads ( ) ;

        if (ATHREADS == NUMTHREADS || threadIdx.x < ATHREADS)
        {
            #pragma unroll
            for (int ii = 0 ; ii < ABITTYROWS ; ii++)
            {
                int i = MYABITTYROW (ii) ;
                #pragma unroll
                for (int jj = 0 ; jj < ABITTYCOLS ; jj++)
                {
                    shV [i][MYABITTYCOL (jj)] = rbit [ii][jj] ;
                }
            }
        }

    #endif

    //--------------------------------------------------------------------------
    // sync
    //--------------------------------------------------------------------------

    // The X=V*C computation in this iteration reads shC, but the same space is
    // used to load A from the frontal matrix in C=V'*A in the next iteration.
    // This final sync also ensures that all threads finish the block_apply
    // at the same time.  Thus, no syncthreads is needed at the start of a
    // subsequent function (the pipelined apply+factorize, for example).

    __syncthreads ( ) ;
}
489 
//------------------------------------------------------------------------------
// clean up every macro introduced by this file
//------------------------------------------------------------------------------

// FIRST_TILE does not appear here: it is expected to be #undef'd by
// cevta_tile.cu itself.

// bitty block dimensions
#undef ABITTYCOLS
#undef ABITTYROWS
#undef CBITTYCOLS
#undef CBITTYROWS

// matrix sizes and per-phase thread counts
#undef N
#undef K
#undef ATHREADS
#undef CTHREADS

// thread-to-entry mapping for the C bitty block
#undef MYCBITTYCOL
#undef MYCBITTYROW
#undef jc
#undef ic

// thread-to-entry mapping for the A bitty block
#undef MYABITTYCOL
#undef MYABITTYROW
#undef ja
#undef ia

// prefetch of A from global memory
#undef NACHUNKS
#undef ACHUNKSIZE
#undef jaload
#undef iaload

// A buffer alias and edge-case predicates
#undef INSIDE_COL
#undef INSIDE_ROW
#undef rbitA

// The edge-case switches are #define'd by the including file and consumed
// here.  ROW_PANELSIZE is left alone; the parent file #undef's it.
#undef COL_EDGE_CASE
#undef ROW_EDGE_CASE
530