//------------------------------------------------------------------------------
// AxB_dot3_phase3_warpix.cu
//------------------------------------------------------------------------------

// This CUDA kernel produces the semi-ring product of two
// sparse matrices of types T_A and T_B, with a common index space of size n,
// into an output matrix of type T_C. The matrices are sparse, with different
// numbers of non-zeros and different sparsity patterns.
// i.e., we want to produce C = A'*B in the sense of the given semi-ring.

// This version uses a merge-path style algorithm, intended for the case where
// nnzA and nnzB are relatively close in size and neither vector is very sparse
// nor very dense, for any size of n.
// It handles arbitrary sparsity patterns with guaranteed load balance.
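//
// For example (illustrative values only): if A(:,i) has row indices {0,3,5,9}
// and B(:,j) has row indices {1,3,4,9}, their intersection is {3,9}, and the
// dot product C(i,j) reduces the two products A(3,i)*B(3,j) and A(9,i)*B(9,j)
// with the semi-ring's additive operator.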

// Both the grid and block are 1D, so blockDim.x is the # threads in a
// threadblock, and the # of threadblocks is gridDim.x.

// Let b = blockIdx.x, and let s = blockDim.x.  Here s = 32 (one warp per
// threadblock), with a variable number of active lanes in the intersection
// loop (at most 32).

// Threadblock b owns a slice of the bucket of (i,j) pairs.  For each pair, its
// job is to find the intersection of the index sets of A(:,i) and B(:,j),
// perform the semi-ring dot product on the entries in that intersection,
// reduce the partial results to a scalar, and on exit write it to Cx [pair_id].

//  int64_t start          <- start of vector pairs for this kernel
//  int64_t end            <- end of vector pairs for this kernel
//  int64_t *Bucket        <- array of pair indices for all kernels
//  matrix<T_C> *C         <- result matrix
//  matrix<T_M> *M         <- mask matrix
//  matrix<T_A> *A         <- input matrix A
//  matrix<T_B> *B         <- input matrix B
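//
//  For reference, a host-side launch of this kernel might look like the
//  following sketch; the grid size, template parameters, and launch mechanism
//  are chosen by the caller and are only assumed here:
//
//      dim3 grid (gridsz) ;    // one slice of the bucket per threadblock
//      dim3 block (32) ;       // one warp per threadblock, as noted above
//      AxB_dot3_phase3_warpix <T_C,T_A,T_B,T_X,T_Y,T_Z>
//          <<<grid, block>>> (start, end, Bucket, C, M, A, B, sz) ;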
#define GB_KERNEL
#include <limits>
#include <cstdint>
#include "matrix.h"
#include <cooperative_groups.h>
#include "mySemiRing.h"

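// The GB_* macros used below (GB_PARTITION, GB_GETA, GB_GETB, GB_MULT, GB_ADD,
// GB_ADD_F, GB_IDENTITY, GB_PUTC, GB_FLIP) are assumed to be supplied by the
// included headers (matrix.h and the semiring-specific mySemiRing.h).
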
// With the tile size fixed at compile time, the warp-level reduction needs no shared memory
#define tile_sz 32

using namespace cooperative_groups;

template< typename T, int warp_sz>
__device__ __inline__
T GB_reduce_sum(thread_block_tile<warp_sz> g, T val)
{
    // Each iteration halves the number of active threads
    // Each thread adds its partial sum[i] to sum[lane+i]
    for (int i = g.size() / 2; i > 0; i /= 2)
    {
        T next = g.shfl_down( val, i);
        val = GB_ADD( val, next ) ;
    }
    return val;
}
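
// For tile_sz = 32 the loop above performs five shuffle steps with offsets
// 16, 8, 4, 2, 1; after the last step lane 0 holds the reduction of all 32
// lane values under GB_ADD.  For example, with GB_ADD as ordinary addition:
//   step 1: val[t] = val[t] + val[t+16]   for lanes t = 0..15
//   ...
//   step 5: val[0] = val[0] + val[1]      (lane 0 now holds the full sum)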

template< typename T, int warp_sz>
__device__ __inline__
T reduce_plus(thread_block_tile<warp_sz> g, T val)
{
    // Each iteration halves the number of active threads
    // Each thread adds its partial sum[i] to sum[lane+i]
    for (int i = g.size() / 2; i > 0; i /= 2)
    {
        val += g.shfl_down( val, i) ;
    }
    return val; // note: only thread 0 will return full sum and flag value
}

#define intersects_per_thread 8

template< typename T_C, typename T_A, typename T_B, typename T_X, typename T_Y, typename T_Z>
__global__ void AxB_dot3_phase3_warpix
(
    int64_t start,
    int64_t end,
    int64_t *__restrict__ Bucket,
    GrB_Matrix C,
    GrB_Matrix M,
    GrB_Matrix A,
    GrB_Matrix B,
    int sz
)
{

    T_A *__restrict__ Ax = (T_A*)A->x;
    T_B *__restrict__ Bx = (T_B*)B->x;
    T_C *__restrict__ Cx = (T_C*)C->x;
    int64_t *__restrict__ Ci = C->i;
    int64_t *__restrict__ Mi = M->i;
    int64_t *__restrict__ Mp = M->p;
    int64_t *__restrict__ Ai = A->i;
    int64_t *__restrict__ Bi = B->i;
    int64_t *__restrict__ Ap = A->p;
    int64_t *__restrict__ Bp = B->p;

    int64_t mnvec = M->nvec;

    // zombie count
    int zc;

    int64_t pair_id;

    // set thread ID
    int tid_global = threadIdx.x+ blockDim.x* blockIdx.x;
    int tid = threadIdx.x;
    int b = blockIdx.x ;

    // total items to be inspected
    int64_t nnzA = 0;
    int64_t nnzB = 0;

    thread_block_tile<tile_sz> tile = tiled_partition<tile_sz>( this_thread_block());

    //int parts = gridDim.x; //Each warp is a part

    //Find our part of the work bucket
    int64_t pfirst, plast, kfirst, klast ;
    GB_PARTITION (pfirst, plast, end-start, b, gridDim.x ) ;
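    // GB_PARTITION (assumed to be defined by the included headers) splits the
    // range [0, end-start) of bucket entries roughly evenly across the
    // gridDim.x threadblocks, giving this block the half-open slice
    // [pfirst, plast).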
    /*
    if( tid ==0 ) {
       printf("block%d is alive, pf,pl=%ld,%ld \n", b, pfirst, plast);
    }
    __syncthreads();
    */


    __shared__ int64_t As[256];
    __shared__ int64_t Bs[256];
    __shared__ T_A Axs[256];
    __shared__ T_B Bxs[256];
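    // NOTE: the per-pair loads below copy nnzA entries of A(:,i) and nnzB
    // entries of B(:,j) into these 256-entry buffers with no bounds check, so
    // this kernel is assumed to receive only pairs whose vectors have at most
    // 256 entries (the upstream bucketing is assumed to guarantee this).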

   /*
    int Bpl[9]; // local offsets into shared for multiple vectors of B
    int shr_vec[8] ; //columns of B we see in this task

    pair_id = Bucket[pfirst];
    int64_t i = Mi[pair_id] ;
    int vecs = 1 ;
    int last_vec = i;
    shr_vec[0] = i;
    for (int id =1; id< plast-pfirst; id++)
    {
         pair_id = Bucket[pfirst+id];
         i = Mi[pair_id];
         if (i == last_vec) continue;
         vecs++;
         shr_vec[vecs] = i;
         last_vec = i;
    }
    int all_loaded = 0;

    Bpl[0] = 0;
    for ( int k = 0; k < vecs; k++)
    {
        int64_t pA       = Ap[ shr_vec[k] ];
        int64_t pA_end   = Ap[ shr_vec[k] +1];
        nnzA = pA_end - pA;
        Bpl[k+1] = Bpl[k] + nnzA;
        for (int i = tid ; i < nnzA; i+= blockDim.x)
        {
           As[ Bpl[k] +i ] = Ai[ pA + i ] ;
        }
        __syncthreads();
    }

    //pre-load columns of B, which will be reused, to shared memory
    //Due to loading a contiguous block with stride 1 this is fast

    all_loaded = (Bpl[vecs] < 256 );
    if( tid == 0 ) {
       printf("block%d loaded %d vals from B, vecs=%d, all_loaded=%d\n",
                 b, Bpl[vecs], vecs, all_loaded );
    }
    __syncthreads();


    // reset counter
    */
    // Main loop over pairs
    for (int id = start + pfirst; // loop on pairs
         id < start+ plast;
         id ++ )
    {
         int64_t pair_id = Bucket[id];

         int64_t i = Mi[pair_id];
         int64_t j = Ci[pair_id] >> 4;

         int64_t pA       = Ap[i];
         int64_t pA_end   = Ap[i+1];
         nnzA = pA_end - pA;

         int64_t pB       = Bp[j];
         int64_t pB_end   = Bp[j+1];
         nnzB = pB_end - pB;

         zc = 0 ;
         int j_last = -1 ;


    // No search, this warp does all the work

    int tx_start = pA;
    int tx_end   = pA_end;
    int ty_start = pB;
    int ty_end   = pB_end;

    for ( int i = tid; i < nnzA ; i+= blockDim.x)
    {
       As [i] = Ai[ pA + i];
       Axs[i] = Ax[ pA + i];
    }
    __syncthreads();

    if ( j != j_last) {
        for ( int i = tid; i < nnzB ; i+= blockDim.x)
        {
           Bs [i] = Bi[ pB + i];
           Bxs[i] = Bx[ pB + i];
        }
        __syncthreads();
        j_last = j;
    }


    /*
    if ( tid==0 ) {
      //printf("block %d dot %lld i,j= %lld,%lld\n", blockIdx.x, pair_id, i, j);
      printf("block%d dot %ld(i,j)=(%ld,%ld) xs,xe= %d,%d ys,ye = %d,%d \n",
               b, pair_id, i, j, tx_start,tx_end, ty_start, ty_end);
      //for(int a = 0; a < nnzA; a++) printf(" As[%d]:%ld ",a, As[j]);
    }
    tile.sync();
    */



    // Warp intersection: balanced by design, no idle threads.
    // Each 32 thread warp will handle 32 comparisons per loop.
    // Either A or B takes stride 4 and the other stride 8, with the shorter
    // vector getting stride 4 (when nnzA <= nnzB, A strides 4 and B strides 8).
    T_A aki;
    T_B bkj;
    T_Z cij = GB_IDENTITY ;
    int Astride = nnzA > nnzB ? 8 : 4;
    int Ashift  = nnzA > nnzB ? 3 : 2;
    int Amask   = nnzA > nnzB ? 7 : 3;
    int Bstride = nnzB >= nnzA ? 8 : 4;
    //printf(" Astride = %d, Bstride = %d\n", Astride, Bstride);

    // TODO PLUS_PAIR_INT64, FP32, FP64: no need for cij_exists.
    // just check if cij > 0

    int cij_exists  = 0 ;

    //Warp intersection dot product
    int bitty_row = tid &  Amask ;
    int bitty_col = tid >> Ashift ;

    int k = tx_start + bitty_row ;
    int l = ty_start + bitty_col ;
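    // For example, when nnzA <= nnzB: Astride = 4, Ashift = 2, Amask = 3, so
    // bitty_row = tid & 3 indexes 4 consecutive entries of A(:,i) and
    // bitty_col = tid >> 2 indexes 8 consecutive entries of B(:,j); the 32
    // lanes of the tile thus cover a 4-by-8 window of index comparisons on
    // each iteration of the loop below.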

    //Ai[k] = As[ k -pA ];  for lookup
    //Bi[l] = Bs[ l -pB ];


    int inc_k,inc_l;

    int active = ( ( k < tx_end) && (l < ty_end ) );

    /*
    printf("block%d tid%d  Ai,As=%ld,%ld Bi,Bs=%ld,%ld  k,l =%d,%d active:%d\n",
                    b,tid, Ai[k], As[k -pA], Bi[l], Bs[l -pB],
                    k, l,  active );
    */


    while ( tile.any(active) )
    {
       inc_k = 0;
       inc_l = 0;
       int kp = k-pA;
       int lp = l-pB;
       if ( active )
       {
          coalesced_group g = coalesced_threads();
          if ( g.thread_rank() == g.size()-1)
          {
             inc_k = ( As[kp] <= Bs[lp] ) ;
             inc_l = ( Bs[lp] <= As[kp] ) ;
             // printf("block%d tid%d inc_k= %d inc_l = %d\n",b, tid, inc_k, inc_l );
          }
          //tile.sync();

          if ( As [kp] == Bs [lp] )
          {
              //Axs[kp] = Ax[k];
              //Bxs[lp] = Bx[l];

              GB_GETA ( aki=(T_Z)Axs[kp] ) ;
              GB_GETB ( bkj=(T_Z)Bxs[lp] ) ;
              if (cij_exists)
              {
                T_Z t = GB_MULT( (T_Z) aki, (T_Z) bkj);
                GB_ADD_F( cij, t ) ;
                //printf("block%d  thd%d ix at %ld(%ld)  cij += %d * %d\n",b, tid, Ai[k], As[kp], aki, bkj);
              }
              else
              {
                cij_exists = 1 ;
                cij = GB_MULT ( (T_Z) aki, (T_Z) bkj) ;
                //printf("  thd%d ix at %ld(%ld)  cij = %d * %d \n", tid, Ai[k], Ais[kp], aki, bkj);
              }
          }
          // TODO check terminal condition
          //printf(" block%u work value = %d, exists = %d\n", b, cij, cij_exists);
          //printf("block%d tid%d k,l = %d,%d Ai,Bi = %ld,%ld \n", b, tid, k, l, Ai[k], Bi[l] );
       }
       //tile.sync();
       //inc_k = tile.shfl_down( inc_k, 31-tid);
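       // Advance the window: shfl_down(k, 31-tid) reads k from lane 31, which
       // holds the largest k in the current window, so if the last active lane
       // found As[kp] <= Bs[lp] the whole A window slides to start just past
       // its previous end; the B window advances the same way when
       // Bs[lp] <= As[kp].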
       if( tile.any(inc_k) ) {
          k =1+ tile.shfl_down(k,31-tid) + bitty_row ; // tid%Astride;
          //Ais [k-pA] = As[k-pA];
          //Axs [bitty_row] = Ax[k];
       }
       if( tile.any(inc_l) ) {
          l =1+ tile.shfl_down(l,31-tid) + bitty_col ; // tid/Astride;
          //Bis [l-pB] = Bs[l-pB];
          //Bxs [bitty_col] = Bx[l];
       }
       active = ( ( k < tx_end) && (l < ty_end ) );
       //printf("block%d tid = %d k = %d l= %d active=%d\n", b, tid, k, l,active);
    }
    tile.sync();

    //--------------------------------------------------------------------------
    // reduce sum per-thread values to a single scalar, get OR of flag
    //--------------------------------------------------------------------------

    // Do vote here for control.
    cij_exists  = tile.any( cij_exists);
    tile.sync();

    if (cij_exists)
    {
       cij = GB_reduce_sum<T_Z, tile_sz>( tile, cij );
    }
    tile.sync();


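    // A pair whose index sets do not intersect contributes no entry to C: its
    // row index is flipped with GB_FLIP to mark the position as a zombie (a
    // deleted entry that a later GraphBLAS phase is assumed to prune), and the
    // block's zombie count is added atomically to C->nzombies.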
    // Write the result for this pair to global memory; the zombie count is accumulated atomically
    if (tid == 0)
    {
        //printf ("final %d : %d exists = %d\n", b,  cij, cij_exists) ;
        if (cij_exists)
        {
           //printf("block%d i,j =%ld,%ld cij = %d\n",b, i, j, cij);
           GB_PUTC( Cx[pair_id] = (T_C) cij ) ;
           GB_PUTC ( Ci[pair_id] = i ) ;

        }
        else
        {
            //printf(" dot %d is a zombie\n", pair_id);
            zc++;
            GB_PUTC ( Ci[pair_id] = GB_FLIP (i) ) ;
        }

    //__syncthreads();


       if( zc > 0)
       {
          //printf("warp %d zombie count = %d\n", blockIdx.x, zc);
          atomicAdd( (unsigned long long int*)&(C->nzombies), (unsigned long long int)zc);
          //printf("blk:%d Czombie = %lld\n",blockIdx.x,C->zombies);
       }

    }
    tile.sync();
    /*
    */
  }
}
