const char* const templates_GB_jit_AxB_dot3_phase3_mp_cu = "templates/GB_jit_AxB_dot3_phase3_mp.cu\n"
"//------------------------------------------------------------------------------\n"
"// AxB_dot3_phase3_mp.cu\n"
"//------------------------------------------------------------------------------\n"
"\n"
6"// This CUDA kernel produces the semi-ring product of two\n"
7"// sparse matrices of types T_A and T_B and common index space size n, to a  \n"
8"// output matrix of type T_C. The matrices are sparse, with different numbers\n"
9"// of non-zeros and different sparsity patterns. \n"
10"// ie. we want to produce C = A'*B in the sense of the given semi-ring.\n"
11"\n"
12"// This version uses a merge-path algorithm, when the sizes nnzA and nnzB are \n"
13"// relatively close in size, neither is very spare nor dense, for any size of N.\n"
14"// Handles arbitrary sparsity patterns with guaranteed load balance.\n"
15"\n"
16"// Both the grid and block are 1D, so blockDim.x is the # threads in a\n"
17"// threadblock, and the # of threadblocks is grid.x\n"
18"\n"
19"// Let b = blockIdx.x, and let s be blockDim.x. s= 32 with a variable number\n"
20"// of active threads = min( min(g_xnz, g_ynz), 32) \n"
21"\n"
22"// Thus, threadblock b owns a part of the index set spanned by g_xi and g_yi.  Its job\n"
23"// is to find the intersection of the index sets g_xi and g_yi, perform the semi-ring dot\n"
24"// product on those items in the intersection, and finally reduce this data to a scalar, \n"
25"// on exit write it to g_odata [b].\n"
26"\n"
27"//  int64_t start          <- start of vector pairs for this kernel\n"
28"//  int64_t end            <- end of vector pairs for this kernel\n"
29"//  int64_t *Bucket        <- array of pair indices for all kernels \n"
30"//  matrix<T_C> *C         <- result matrix \n"
31"//  matrix<T_M> *M         <- mask matrix\n"
32"//  matrix<T_A> *A         <- input matrix A\n"
33"//  matrix<T_B> *B         <- input matrix B\n"
34"#include <limits>\n"
35"#include <cstdint>\n"
36"#include <cooperative_groups.h>\n"
37"#include \"mySemiRing.h\"\n"
38"#include \"matrix.h\"\n"
39"\n"
40"// Using tile size fixed at compile time, we don't need shared memory\n"
41"#define tile_sz 32 \n"
42"\n"
43"using namespace cooperative_groups;\n"
44"\n"
45"template< typename T, int warp_sz>\n"
46"__device__ __inline__ \n"
47"T GB_reduce_sum(thread_block_tile<warp_sz> g, T val)\n"
48"{\n"
49"    // Each iteration halves the number of active threads\n"
50"    // Each thread adds its partial sum[i] to sum[lane+i]\n"
51"    for (int i = g.size() / 2; i > 0; i /= 2)\n"
52"    {\n"
53"        T next = g.shfl_down( val, i);\n"
54"        val = GB_ADD( val, next ) ;\n"
55"    }\n"
56"    return val;\n"
57"}\n"
58"\n"
59"template< typename T, int warp_sz>\n"
60"__device__ __inline__ \n"
61"T reduce_plus(thread_block_tile<warp_sz> g, T val)\n"
62"{\n"
63"    // Each iteration halves the number of active threads\n"
64"    // Each thread adds its partial sum[i] to sum[lane+i]\n"
65"    for (int i = g.size() / 2; i > 0; i /= 2)\n"
66"    {\n"
67"        val += g.shfl_down( val, i) ;\n"
68"    }\n"
69"    return val; // note: only thread 0 will return full sum and flag value\n"
70"}\n"
71"\n"
72"#define intersects_per_thread 8\n"
73"\n"
74"template< typename T_C, typename T_A, typename T_B, typename T_X, typename T_Y, typename T_Z>  \n"
75"__global__ void AxB_dot3_phase3_mp\n"
76"(\n"
77"    int64_t start,\n"
78"    int64_t end,\n"
79"    int64_t *Bucket,\n"
80"    GrB_Matrix C,\n"
81"    GrB_Matrix M,\n"
82"    GrB_Matrix A,\n"
83"    GrB_Matrix B,\n"
84"    int sz\n"
85")\n"
86"{\n"
87"\n"
88"    T_A *Ax = (T_A*)A->x;\n"
89"    T_B *Bx = (T_B*)B->x;\n"
90"    T_C *Cx = (T_C*)C->x;\n"
91"    int64_t *Ci = C->i;\n"
92"    int64_t *Mi = M->i;\n"
93"    int64_t *Ai = A->i;\n"
94"    int64_t *Bi = B->i;\n"
95"    int64_t *Ap = A->p;\n"
96"    int64_t *Bp = B->p;\n"
97"\n"
98"\n"
99"    // zombie count\n"
100"    int zc = 0;\n"
101"\n"
102"    int64_t pair_id;\n"
103"\n"
104"    // set thread ID\n"
105"    int tid_global = threadIdx.x+ blockDim.x* blockIdx.x;\n"
106"    int tid = threadIdx.x;\n"
107"\n"
108"    int b = blockIdx.x ;\n"
109"\n"
110"    // total items to be inspected\n"
111"    int64_t nnzA = 0;\n"
112"    int64_t nnzB = 0;\n"
113"    int64_t n_intersect = 0;\n"
114"\n"
115"    thread_block_tile<tile_sz> tile = tiled_partition<tile_sz>( this_thread_block());\n"
116"\n"
117"    int parts = blockDim.x; //(n_intersect+ intersects_per_thread -1)/ intersects_per_thread; \n"
118"\n"
119"    // int has_zombies = 0 ;\n"
120"\n"
121"    // Main loop over pairs \n"
122"    for (pair_id = start+ blockIdx.x; //warp per pair \n"
123"         pair_id < end;  \n"
124"         pair_id += gridDim.x )\n"
125"    {\n"
126"\n"
127"         int64_t i = Mi[pair_id];\n"
128"         int64_t j = Ci[pair_id] >> 4;\n"
129"\n"
130"         int64_t xstart = Ap[i];\n"
131"         int64_t xend   = Ap[i+1];\n"
132"         nnzA = xend - xstart;\n"
133"\n"
134"         int64_t ystart = Bp[j]; \n"
135"         int64_t yend   = Bp[j+1]; \n"
136"         nnzB = yend - ystart;\n"
137"\n"
138"         n_intersect = GB_IMIN( xend -xstart, yend -ystart); \n"
139"    /* \n"
140"    if (threadIdx.x ==0 ) {\n"
141"      printf(\"block %d  doing dot %lld  i,j= %lld,%lld\\n\", blockIdx.x, pair_id, i, j);\n"
142"    }\n"
143"    */\n"
144"    //we want more than one intersection per thread\n"
145"    int64_t nxy = nnzA + nnzB;\n"
146"\n"
147"    int work_per_thread = (nxy +parts -1)/parts;\n"
148"    int diag = GB_IMIN( work_per_thread*tid, nxy);\n"
149"    int diag_end = GB_IMIN( diag + work_per_thread, nxy);\n"
150"    //printf(\" thd%d parts = %u wpt = %u diag, diag_end  = %u,%u\\n\",tid, parts, work_per_thread, diag, diag_end); \n"
151"\n"
152"    int x_min = GB_IMAX( (int)(diag - nnzB), 0);\n"
153"    int x_max = GB_IMIN( diag, nnzA);\n"
154"\n"
155"    //printf(\"start thd%u x_min = %u x_max = %u\\n\", tid_global, x_min,x_max);\n"
156"    while ( x_min < x_max) { //binary search for correct diag break\n"
157"      int pivot = (x_min +x_max)/2;\n"
158"      if ( Ai[pivot + xstart] < Bi[ diag -pivot -1 + ystart]) {\n"
159"         x_min = pivot +1;\n"
160"      }\n"
161"      else {\n"
162"         x_max = pivot;\n"
163"      }\n"
164"    }\n"
165"    int xcoord = x_min;\n"
166"    int ycoord = diag -x_min -1;\n"
167"    if (( diag > 0) &&(diag < (nnzA+nnzB)) && (Ai[xcoord+xstart] == Bi[ycoord+ystart]) ) { \n"
168"       diag--; //adjust for intersection incrementing both pointers \n"
169"    }\n"
170"    // two start points are known now\n"
171"    int tx_start = xcoord +xstart;\n"
172"    int ty_start = diag -xcoord +ystart; \n"
173"\n"
174"    //if (x_start != y_start)\n"
175"    //   printf(\"start thd%u  xs,ys = %i,%i\\n\", tid_global, x_start, y_start);\n"
176"\n"
177"    x_min = GB_IMAX( (int)(diag_end - nnzB), 0);\n"
178"    x_max = GB_IMIN( diag_end, nnzA);\n"
179"\n"
180"    while ( x_min < x_max) {\n"
181"       int pivot = (x_min +x_max)/2;\n"
182"       //printf(\"thd%u pre_sw piv=%u diag_e = %u  xmin,xmax=%u,%u\\n\", tid_global, pivot, diag_end,x_min, x_max);\n"
183"       if ( Ai[pivot+ xstart] < Bi[ diag_end -pivot -1 +ystart]) {\n"
184"          x_min = pivot +1;\n"
185"       }\n"
186"       else {\n"
187"          x_max = pivot;\n"
188"       }\n"
189"       //printf(\"thd%u piv=%u xmin,xmax = %u,%u\\n\", tid_global, pivot, x_min, x_max);\n"
190"    }\n"
191"    xcoord = x_min;\n"
192"    ycoord = diag_end -x_min -1;\n"
193"    if ( (diag_end < (nnzA +nnzB)) && (Ai[xcoord +xstart] == Bi[ycoord + ystart]) ) { \n"
194"        diag--; //adjust for intersection incrementing both pointers  \n"
195"    }\n"
196"    // two end points are known now\n"
197"    int tx_end = xcoord +xstart; \n"
198"    int ty_end = diag_end - xcoord + ystart; \n"
199"\n"
200"    T_A aki;\n"
201"    T_B bkj;\n"
202"    T_Z cij = GB_IDENTITY ;\n"
203"\n"
204"    // TODO PLUS_PAIR_INT64, FP32, FP64: no need for cij_exists.\n"
205"    // just check if cij > 0\n"
206"\n"
207"    int cij_exists  = 0 ;\n"
208"    //printf(\" thd%u has init value %f\\n\",tid, cij);\n"
209"\n"
210"    //merge-path dot product\n"
211"    int k = tx_start;\n"
212"    int l = ty_start;\n"
213"    while ( k < tx_end && l < ty_end )\n"
214"    {\n"
215"       if (Ai [k] == Bi [l])\n"
216"       {\n"
217"          GB_GETA ( aki=(T_Z)Ax[k] ) ;\n"
218"          GB_GETB ( bkj=(T_Z)Bx[l] ) ;\n"
219"          if (cij_exists)\n"
220"          {\n"
221"            T_Z t = GB_MULT( (T_Z)aki, (T_Z)bkj );\n"
222"            GB_ADD_F (cij, t ) ;\n"
223"          //printf(\"  thd%d ix at %lld   cij += %d * %d \\n\", tid_global, Ai[k], aki, bkj);\n"
224"          }\n"
225"          else\n"
226"          {\n"
227"            cij_exists = 1 ;\n"
228"            cij = GB_MULT ( (T_Z)aki, (T_Z)bkj ) ;\n"
229"          //printf(\"  thd%d ix at %lld   cij = %d * %d \\n\", tid_global, Ai[k], Ax[k], Bx[l]);\n"
230"          }\n"
231"          // TODO check terminal condition\n"
232"          k+= 1;\n"
233"          l+= 1;\n"
234"          //printf(\" block%u work value = %d, exists = %d\\n\", b, cij, cij_exists);\n"
235"       }\n"
236"       else\n"
237"       {\n"
238"            k += ( Ai[k] < Bi[l] ) ;\n"
239"            l += ( Ai[k] > Bi[l] ) ;\n"
240"       }\n"
241"    }\n"
242"\n"
243"    //tile.sync( ) ;\n"
244"    //--------------------------------------------------------------------------\n"
245"    // reduce sum per-thread values to a single scalar, get OR of flag\n"
246"    //--------------------------------------------------------------------------\n"
247"    /*\n"
248"    if (tid == 0)\n"
249"    {\n"
250"        printf (\"reduce %d : %d exists = %d\\n\", b,  cij, cij_exists) ;\n"
251"    }\n"
252"    __syncthreads();\n"
253"    */\n"
254"\n"
255"    // Do vote here for control.\n"
256"    cij_exists  = tile.any( cij_exists);\n"
257"    //tile.sync();\n"
258"\n"
259"    if (cij_exists)\n"
260"    {\n"
261"       cij = GB_reduce_sum<T_Z, tile_sz>( tile, cij );\n"
262"       \n"
263"    }\n"
264"    // else has_zombies = 1;\n"
265"\n"
266"\n"
267"    //__syncthreads();\n"
268"    //tile.sync( );\n"
269"    // write result for this block to global mem\n"
270"    if (tid == 0)\n"
271"    {\n"
272"        //printf (\"final %d : %d exists = %d\\n\", b,  cij, cij_exists) ;\n"
273"        if (cij_exists)\n"
274"        {\n"
275"           //printf(\" cij = %d\\n\", cij);\n"
276"           GB_PUTC ( Cx[pair_id]=(T_C)cij ) ;\n"
277"           GB_PUTC ( Ci[pair_id]=i ) ;\n"
278"        }\n"
279"        else\n"
280"        {\n"
281"           //printf(\" dot %d is a zombie\\n\", pair_id);\n"
282"           zc++;\n"
283"           GB_PUTC ( Ci[pair_id]=GB_FLIP (i) ) ;\n"
284"        }\n"
285"    }\n"
286"    //__syncthreads(); \n"
287"  }\n"
288"\n"
289"//--------------------------------------------------------------------------\n"
290"\n"
291"  if( tid ==0 && zc > 0)\n"
292"  {\n"
293"      //printf(\"warp %d zombie count = %d\\n\", blockIdx.x, zc);\n"
294"      atomicAdd( (unsigned long long int*)&(C->zombie_count), (unsigned long long int)zc);\n"
295"      //printf(\" Czombie = %lld\\n\",C->zombie_count);\n"
296"  }\n"
297"\n"
298"  //__syncthreads();\n"
299"\n"
300"}\n"
301"\n"
302;