1 //------------------------------------------------------------------------------
2 // AxB_dot3_phase3_warpix.cu
3 //------------------------------------------------------------------------------
4
// This CUDA kernel produces the semiring product of two sparse matrices of
// types T_A and T_B with common index space size n, into an output matrix of
// type T_C. The matrices are sparse, with different numbers of nonzeros and
// different sparsity patterns; i.e., we want to produce C = A'*B in the sense
// of the given semiring.
10
// This version uses a merge-path style intersection, intended for the case
// where nnzA and nnzB are relatively close in size and neither vector is very
// sparse nor very dense, for any size of n. It handles arbitrary sparsity
// patterns with guaranteed load balance.
14
15 // Both the grid and block are 1D, so blockDim.x is the # threads in a
16 // threadblock, and the # of threadblocks is grid.x
17
18 // Let b = blockIdx.x, and let s be blockDim.x. s= 32 with a variable number
19 // of active threads = min( min(g_xnz, g_ynz), 32)
20
21 // Thus, threadblock b owns a part of the index set spanned by g_xi and g_yi. Its job
22 // is to find the intersection of the index sets g_xi and g_yi, perform the semi-ring dot
23 // product on those items in the intersection, and finally reduce this data to a scalar,
24 // on exit write it to g_odata [b].
25
// int64_t start      <- start of vector pairs for this kernel
// int64_t end        <- end (exclusive) of vector pairs for this kernel
// int64_t *Bucket    <- array of pair indices for all kernels
// matrix<T_C> *C     <- result matrix
// matrix<T_M> *M     <- mask matrix
// matrix<T_A> *A     <- input matrix A
// matrix<T_B> *B     <- input matrix B
// int sz             <- not used by this kernel
33 #define GB_KERNEL
34 #include <limits>
35 #include <cstdint>
36 #include "matrix.h"
37 #include <cooperative_groups.h>
38 #include "mySemiRing.h"
39
// Warp tile size, fixed at compile time so it can be used as the template
// argument of thread_block_tile<tile_sz>. The tile collectives used below
// (any, shfl_down) need no shared memory themselves; the kernel still uses
// __shared__ buffers to stage vector entries.
#define tile_sz 32
42
43 using namespace cooperative_groups;
44
// Tree reduction of `val` across the lanes of tile `g`, combining values with
// the semiring's GB_ADD. Each round, lane r folds in the value held by lane
// r + offset, for offset = size/2, size/4, ..., 1. On return only lane 0 is
// guaranteed to hold the combination of every lane; the other lanes hold
// partial results. Every lane of `g` must call this, because shfl_down is a
// collective on the tile.
template< typename T, int warp_sz>
__device__ __inline__
T GB_reduce_sum(thread_block_tile<warp_sz> g, T val)
{
    int offset = g.size() >> 1 ;
    while (offset > 0)
    {
        // Keep the shuffle in its own statement: GB_ADD may be a macro that
        // evaluates its arguments more than once.
        T from_above = g.shfl_down( val, offset ) ;
        val = GB_ADD( val, from_above ) ;
        offset >>= 1 ;
    }
    return val;
}
58
// Tree reduction of `val` across the lanes of tile `g` using operator +=.
// Each round halves the shuffle distance; lane r adds the value held by lane
// r + delta. Only lane 0 is guaranteed to hold the full sum on return.
// Every lane of `g` must call this (shfl_down is a tile collective).
template< typename T, int warp_sz>
__device__ __inline__
T reduce_plus(thread_block_tile<warp_sz> g, T val)
{
    for (unsigned int delta = g.size() >> 1 ; delta != 0 ; delta >>= 1)
    {
        const T from_above = g.shfl_down( val, delta ) ;
        val += from_above ;
    }
    return val;
}
71
// NOTE(review): not referenced anywhere in the code visible in this file; the
// intersection loop in AxB_dot3_phase3_warpix picks its window strides (4 or 8)
// directly. Confirm whether anything else uses it before removing.
#define intersects_per_thread 8
73
//------------------------------------------------------------------------------
// AxB_dot3_phase3_warpix: compute the masked dot products
// C(i,j) = A(:,i)'*B(:,j) for the pairs listed in Bucket [start:end-1], one
// warp tile per dot product.
//
// start, end : half-open range of Bucket entries handled by this launch; the
//              range is split across the gridDim.x blocks with GB_PARTITION.
// Bucket     : Bucket [id] is the position pair_id of an entry of M/C.
// C          : output. Cx [pair_id] receives the result and Ci [pair_id] is
//              overwritten with i, or with GB_FLIP (i) when A(:,i) and B(:,j)
//              share no index (a zombie). C->nzombies is incremented
//              atomically by the number of zombies this block creates.
// M          : mask; Mi [pair_id] selects the vector A(:,i).
// A, B       : inputs; the entries of A(:,i) are Ai/Ax [Ap [i] ... Ap [i+1]-1].
// sz         : unused.
//
// Assumptions (NOTE(review): not checked here, confirm at the call site):
//   * blockDim.x is a multiple of tile_sz (the design uses blockDim.x == 32);
//     every warp computes the same dot product, and thread 0 writes it.
//   * the indices within each vector of A and B are sorted ascending with no
//     duplicates -- the windowed merge below depends on this.
//   * Ci [pair_id] >> 4 is the column j of B; presumably an earlier phase
//     packed bucket information into the low 4 bits -- TODO confirm.
//------------------------------------------------------------------------------
template< typename T_C, typename T_A, typename T_B, typename T_X, typename T_Y, typename T_Z>
__global__ void AxB_dot3_phase3_warpix
(
    int64_t start,
    int64_t end,
    int64_t *__restrict__ Bucket,
    GrB_Matrix C,
    GrB_Matrix M,
    GrB_Matrix A,
    GrB_Matrix B,
    int sz
)
{
    // Capacity (entries per vector) of the shared-memory staging buffers.
    // Longer vectors are read directly from global memory instead of
    // overflowing these buffers.
    constexpr int64_t shared_cap = 256 ;

    const T_A *__restrict__ Ax = (const T_A *) A->x ;
    const T_B *__restrict__ Bx = (const T_B *) B->x ;
    T_C *__restrict__ Cx = (T_C *) C->x ;
    int64_t *__restrict__ Ci = C->i ;
    const int64_t *__restrict__ Mi = M->i ;
    const int64_t *__restrict__ Ai = A->i ;
    const int64_t *__restrict__ Bi = B->i ;
    const int64_t *__restrict__ Ap = A->p ;
    const int64_t *__restrict__ Bp = B->p ;

    const int tid = threadIdx.x ;
    const int b = blockIdx.x ;
    thread_block_tile<tile_sz> tile = tiled_partition<tile_sz>( this_thread_block() ) ;
    const int lane = (int) tile.thread_rank() ;

    // this block's share [pfirst, plast) of the bucket range [start, end)
    int64_t pfirst, plast ;
    GB_PARTITION (pfirst, plast, end-start, b, gridDim.x) ;

    __shared__ int64_t As  [shared_cap] ;
    __shared__ int64_t Bs  [shared_cap] ;
    __shared__ T_A     Axs [shared_cap] ;
    __shared__ T_B     Bxs [shared_cap] ;

    int64_t j_last = -1 ;             // column of B currently staged in Bs/Bxs
    unsigned long long int zc = 0 ;   // zombies created by this block (tid 0)

    for (int64_t id = start + pfirst ; id < start + plast ; id++)
    {
        const int64_t pair_id = Bucket [id] ;
        const int64_t i = Mi [pair_id] ;
        const int64_t j = Ci [pair_id] >> 4 ;

        const int64_t pA = Ap [i] ;
        const int64_t nnzA = Ap [i+1] - pA ;
        const int64_t pB = Bp [j] ;
        const int64_t nnzB = Bp [j+1] - pB ;

        //----------------------------------------------------------------------
        // stage A(:,i) and B(:,j) in shared memory when they fit
        //----------------------------------------------------------------------

        const bool A_fits = (nnzA <= shared_cap) ;
        const bool B_fits = (nnzB <= shared_cap) ;

        if (A_fits)
        {
            for (int64_t p = tid ; p < nnzA ; p += blockDim.x)
            {
                As  [p] = Ai [pA + p] ;
                Axs [p] = Ax [pA + p] ;
            }
        }
        // consecutive pairs with the same j reuse the staged copy of B(:,j)
        if (B_fits && j != j_last)
        {
            for (int64_t p = tid ; p < nnzB ; p += blockDim.x)
            {
                Bs  [p] = Bi [pB + p] ;
                Bxs [p] = Bx [pB + p] ;
            }
            j_last = j ;
        }
        __syncthreads ( ) ;

        // entry sources: shared copy when staged, global memory otherwise
        const int64_t *Aik = A_fits ? As  : (Ai + pA) ;
        const T_A     *Aix = A_fits ? Axs : (Ax + pA) ;
        const int64_t *Bik = B_fits ? Bs  : (Bi + pB) ;
        const T_B     *Bix = B_fits ? Bxs : (Bx + pB) ;

        //----------------------------------------------------------------------
        // windowed intersection of A(:,i) and B(:,j)
        //----------------------------------------------------------------------

        // The 32 lanes form an Astride x Bstride grid (8x4 when A(:,i) is the
        // longer vector, 4x8 otherwise). Each round, every active lane compares
        // one entry of the current A window with one entry of the current B
        // window. Then each window whose last index is not larger than the
        // other window's last index advances by a full stride.
        const bool A_longer = (nnzA > nnzB) ;
        const int Astride = A_longer ? 8 : 4 ;
        const int Ashift  = A_longer ? 3 : 2 ;
        const int Bstride = A_longer ? 4 : 8 ;

        int64_t kp = lane & (Astride - 1) ;   // this lane's offset into A(:,i)
        int64_t lp = lane >> Ashift ;         // this lane's offset into B(:,j)
        bool active = (kp < nnzA) && (lp < nnzB) ;

        T_A aki ;
        T_B bkj ;
        T_Z cij = GB_IDENTITY ;
        int cij_exists = 0 ;

        while (tile.any (active))
        {
            // The active lanes form a rectangle, so the highest active lane
            // holds the last entry of both windows. Find it with a ballot over
            // the whole tile: coalesced_threads() would only report the lanes
            // that happen to be converged at this point.
            const unsigned int active_lanes = tile.ballot (active) ;
            const int last_lane = 31 - __clz ((int) active_lanes) ;

            int inc_k = 0 ;
            int inc_l = 0 ;
            if (active)
            {
                const int64_t ia = Aik [kp] ;
                const int64_t ib = Bik [lp] ;
                if (lane == last_lane)
                {
                    inc_k = (ia <= ib) ;
                    inc_l = (ib <= ia) ;
                }
                if (ia == ib)
                {
                    GB_GETA ( aki=(T_Z)Aix[kp] ) ;
                    GB_GETB ( bkj=(T_Z)Bix[lp] ) ;
                    if (cij_exists)
                    {
                        T_Z t = GB_MULT( (T_Z) aki, (T_Z) bkj ) ;
                        GB_ADD_F( cij, t ) ;
                    }
                    else
                    {
                        cij_exists = 1 ;
                        cij = GB_MULT( (T_Z) aki, (T_Z) bkj ) ;
                    }
                }
            }
            // every lane advances the same window(s) by a full stride
            if (tile.any (inc_k)) kp += Astride ;
            if (tile.any (inc_l)) lp += Bstride ;
            active = (kp < nnzA) && (lp < nnzB) ;
        }

        //----------------------------------------------------------------------
        // reduce the per-lane partial results and write C(i,j)
        //----------------------------------------------------------------------

        // tile-uniform: did any lane find a common index?
        cij_exists = tile.any (cij_exists) ;
        if (cij_exists)
        {
            // lanes without a match still hold GB_IDENTITY
            cij = GB_reduce_sum<T_Z, tile_sz>( tile, cij ) ;
        }

        if (tid == 0)
        {
            if (cij_exists)
            {
                GB_PUTC( Cx[pair_id] = (T_C) cij ) ;
                GB_PUTC( Ci[pair_id] = i ) ;
            }
            else
            {
                // no common index: C(i,j) becomes a zombie
                zc++ ;
                GB_PUTC( Ci[pair_id] = GB_FLIP (i) ) ;
            }
        }

        // all reads of the staging buffers for this pair must complete before
        // the next pair overwrites them
        __syncthreads ( ) ;
    }

    // one atomic per block rather than one per zombie pair
    if (tid == 0 && zc > 0)
    {
        atomicAdd( (unsigned long long int *) &(C->nzombies), zc ) ;
    }
}
388
389