1const char* const templates_GB_jit_AxB_dot3_phase3_mp_cu = "templates/GB_jit_AxB_dot3_phase3_mp.cu\n" 2"//------------------------------------------------------------------------------\n" 3"// AxB_dot3_phase3_mp.cu \n" 4"//------------------------------------------------------------------------------\n" 5"\n" 6"// This CUDA kernel produces the semi-ring product of two\n" 7"// sparse matrices of types T_A and T_B and common index space size n, to a \n" 8"// output matrix of type T_C. The matrices are sparse, with different numbers\n" 9"// of non-zeros and different sparsity patterns. \n" 10"// ie. we want to produce C = A'*B in the sense of the given semi-ring.\n" 11"\n" 12"// This version uses a merge-path algorithm, when the sizes nnzA and nnzB are \n" 13"// relatively close in size, neither is very spare nor dense, for any size of N.\n" 14"// Handles arbitrary sparsity patterns with guaranteed load balance.\n" 15"\n" 16"// Both the grid and block are 1D, so blockDim.x is the # threads in a\n" 17"// threadblock, and the # of threadblocks is grid.x\n" 18"\n" 19"// Let b = blockIdx.x, and let s be blockDim.x. s= 32 with a variable number\n" 20"// of active threads = min( min(g_xnz, g_ynz), 32) \n" 21"\n" 22"// Thus, threadblock b owns a part of the index set spanned by g_xi and g_yi. Its job\n" 23"// is to find the intersection of the index sets g_xi and g_yi, perform the semi-ring dot\n" 24"// product on those items in the intersection, and finally reduce this data to a scalar, \n" 25"// on exit write it to g_odata [b].\n" 26"\n" 27"// int64_t start <- start of vector pairs for this kernel\n" 28"// int64_t end <- end of vector pairs for this kernel\n" 29"// int64_t *Bucket <- array of pair indices for all kernels \n" 30"// matrix<T_C> *C <- result matrix \n" 31"// matrix<T_M> *M <- mask matrix\n" 32"// matrix<T_A> *A <- input matrix A\n" 33"// matrix<T_B> *B <- input matrix B\n" 34"#include <limits>\n" 35"#include <cstdint>\n" 36"#include <cooperative_groups.h>\n" 37"#include \"mySemiRing.h\"\n" 38"#include \"matrix.h\"\n" 39"\n" 40"// Using tile size fixed at compile time, we don't need shared memory\n" 41"#define tile_sz 32 \n" 42"\n" 43"using namespace cooperative_groups;\n" 44"\n" 45"template< typename T, int warp_sz>\n" 46"__device__ __inline__ \n" 47"T GB_reduce_sum(thread_block_tile<warp_sz> g, T val)\n" 48"{\n" 49" // Each iteration halves the number of active threads\n" 50" // Each thread adds its partial sum[i] to sum[lane+i]\n" 51" for (int i = g.size() / 2; i > 0; i /= 2)\n" 52" {\n" 53" T next = g.shfl_down( val, i);\n" 54" val = GB_ADD( val, next ) ;\n" 55" }\n" 56" return val;\n" 57"}\n" 58"\n" 59"template< typename T, int warp_sz>\n" 60"__device__ __inline__ \n" 61"T reduce_plus(thread_block_tile<warp_sz> g, T val)\n" 62"{\n" 63" // Each iteration halves the number of active threads\n" 64" // Each thread adds its partial sum[i] to sum[lane+i]\n" 65" for (int i = g.size() / 2; i > 0; i /= 2)\n" 66" {\n" 67" val += g.shfl_down( val, i) ;\n" 68" }\n" 69" return val; // note: only thread 0 will return full sum and flag value\n" 70"}\n" 71"\n" 72"#define intersects_per_thread 8\n" 73"\n" 74"template< typename T_C, typename T_A, typename T_B, typename T_X, typename T_Y, typename T_Z> \n" 75"__global__ void AxB_dot3_phase3_mp\n" 76"(\n" 77" int64_t start,\n" 78" int64_t end,\n" 79" int64_t *Bucket,\n" 80" GrB_Matrix C,\n" 81" GrB_Matrix M,\n" 82" GrB_Matrix A,\n" 83" GrB_Matrix B,\n" 84" int sz\n" 85")\n" 86"{\n" 87"\n" 88" T_A *Ax = (T_A*)A->x;\n" 89" T_B *Bx = (T_B*)B->x;\n" 90" T_C *Cx = (T_C*)C->x;\n" 91" int64_t *Ci = C->i;\n" 92" int64_t *Mi = M->i;\n" 93" int64_t *Ai = A->i;\n" 94" int64_t *Bi = B->i;\n" 95" int64_t *Ap = A->p;\n" 96" int64_t *Bp = B->p;\n" 97"\n" 98"\n" 99" // zombie count\n" 100" int zc = 0;\n" 101"\n" 102" int64_t pair_id;\n" 103"\n" 104" // set thread ID\n" 105" int tid_global = threadIdx.x+ blockDim.x* blockIdx.x;\n" 106" int tid = threadIdx.x;\n" 107"\n" 108" int b = blockIdx.x ;\n" 109"\n" 110" // total items to be inspected\n" 111" int64_t nnzA = 0;\n" 112" int64_t nnzB = 0;\n" 113" int64_t n_intersect = 0;\n" 114"\n" 115" thread_block_tile<tile_sz> tile = tiled_partition<tile_sz>( this_thread_block());\n" 116"\n" 117" int parts = blockDim.x; //(n_intersect+ intersects_per_thread -1)/ intersects_per_thread; \n" 118"\n" 119" // int has_zombies = 0 ;\n" 120"\n" 121" // Main loop over pairs \n" 122" for (pair_id = start+ blockIdx.x; //warp per pair \n" 123" pair_id < end; \n" 124" pair_id += gridDim.x )\n" 125" {\n" 126"\n" 127" int64_t i = Mi[pair_id];\n" 128" int64_t j = Ci[pair_id] >> 4;\n" 129"\n" 130" int64_t xstart = Ap[i];\n" 131" int64_t xend = Ap[i+1];\n" 132" nnzA = xend - xstart;\n" 133"\n" 134" int64_t ystart = Bp[j]; \n" 135" int64_t yend = Bp[j+1]; \n" 136" nnzB = yend - ystart;\n" 137"\n" 138" n_intersect = GB_IMIN( xend -xstart, yend -ystart); \n" 139" /* \n" 140" if (threadIdx.x ==0 ) {\n" 141" printf(\"block %d doing dot %lld i,j= %lld,%lld\\n\", blockIdx.x, pair_id, i, j);\n" 142" }\n" 143" */\n" 144" //we want more than one intersection per thread\n" 145" int64_t nxy = nnzA + nnzB;\n" 146"\n" 147" int work_per_thread = (nxy +parts -1)/parts;\n" 148" int diag = GB_IMIN( work_per_thread*tid, nxy);\n" 149" int diag_end = GB_IMIN( diag + work_per_thread, nxy);\n" 150" //printf(\" thd%d parts = %u wpt = %u diag, diag_end = %u,%u\\n\",tid, parts, work_per_thread, diag, diag_end); \n" 151"\n" 152" int x_min = GB_IMAX( (int)(diag - nnzB), 0);\n" 153" int x_max = GB_IMIN( diag, nnzA);\n" 154"\n" 155" //printf(\"start thd%u x_min = %u x_max = %u\\n\", tid_global, x_min,x_max);\n" 156" while ( x_min < x_max) { //binary search for correct diag break\n" 157" int pivot = (x_min +x_max)/2;\n" 158" if ( Ai[pivot + xstart] < Bi[ diag -pivot -1 + ystart]) {\n" 159" x_min = pivot +1;\n" 160" }\n" 161" else {\n" 162" x_max = pivot;\n" 163" }\n" 164" }\n" 165" int xcoord = x_min;\n" 166" int ycoord = diag -x_min -1;\n" 167" if (( diag > 0) &&(diag < (nnzA+nnzB)) && (Ai[xcoord+xstart] == Bi[ycoord+ystart]) ) { \n" 168" diag--; //adjust for intersection incrementing both pointers \n" 169" }\n" 170" // two start points are known now\n" 171" int tx_start = xcoord +xstart;\n" 172" int ty_start = diag -xcoord +ystart; \n" 173"\n" 174" //if (x_start != y_start)\n" 175" // printf(\"start thd%u xs,ys = %i,%i\\n\", tid_global, x_start, y_start);\n" 176"\n" 177" x_min = GB_IMAX( (int)(diag_end - nnzB), 0);\n" 178" x_max = GB_IMIN( diag_end, nnzA);\n" 179"\n" 180" while ( x_min < x_max) {\n" 181" int pivot = (x_min +x_max)/2;\n" 182" //printf(\"thd%u pre_sw piv=%u diag_e = %u xmin,xmax=%u,%u\\n\", tid_global, pivot, diag_end,x_min, x_max);\n" 183" if ( Ai[pivot+ xstart] < Bi[ diag_end -pivot -1 +ystart]) {\n" 184" x_min = pivot +1;\n" 185" }\n" 186" else {\n" 187" x_max = pivot;\n" 188" }\n" 189" //printf(\"thd%u piv=%u xmin,xmax = %u,%u\\n\", tid_global, pivot, x_min, x_max);\n" 190" }\n" 191" xcoord = x_min;\n" 192" ycoord = diag_end -x_min -1;\n" 193" if ( (diag_end < (nnzA +nnzB)) && (Ai[xcoord +xstart] == Bi[ycoord + ystart]) ) { \n" 194" diag--; //adjust for intersection incrementing both pointers \n" 195" }\n" 196" // two end points are known now\n" 197" int tx_end = xcoord +xstart; \n" 198" int ty_end = diag_end - xcoord + ystart; \n" 199"\n" 200" T_A aki;\n" 201" T_B bkj;\n" 202" T_Z cij = GB_IDENTITY ;\n" 203"\n" 204" // TODO PLUS_PAIR_INT64, FP32, FP64: no need for cij_exists.\n" 205" // just check if cij > 0\n" 206"\n" 207" int cij_exists = 0 ;\n" 208" //printf(\" thd%u has init value %f\\n\",tid, cij);\n" 209"\n" 210" //merge-path dot product\n" 211" int k = tx_start;\n" 212" int l = ty_start;\n" 213" while ( k < tx_end && l < ty_end )\n" 214" {\n" 215" if (Ai [k] == Bi [l])\n" 216" {\n" 217" GB_GETA ( aki=(T_Z)Ax[k] ) ;\n" 218" GB_GETB ( bkj=(T_Z)Bx[l] ) ;\n" 219" if (cij_exists)\n" 220" {\n" 221" T_Z t = GB_MULT( (T_Z)aki, (T_Z)bkj );\n" 222" GB_ADD_F (cij, t ) ;\n" 223" //printf(\" thd%d ix at %lld cij += %d * %d \\n\", tid_global, Ai[k], aki, bkj);\n" 224" }\n" 225" else\n" 226" {\n" 227" cij_exists = 1 ;\n" 228" cij = GB_MULT ( (T_Z)aki, (T_Z)bkj ) ;\n" 229" //printf(\" thd%d ix at %lld cij = %d * %d \\n\", tid_global, Ai[k], Ax[k], Bx[l]);\n" 230" }\n" 231" // TODO check terminal condition\n" 232" k+= 1;\n" 233" l+= 1;\n" 234" //printf(\" block%u work value = %d, exists = %d\\n\", b, cij, cij_exists);\n" 235" }\n" 236" else\n" 237" {\n" 238" k += ( Ai[k] < Bi[l] ) ;\n" 239" l += ( Ai[k] > Bi[l] ) ;\n" 240" }\n" 241" }\n" 242"\n" 243" //tile.sync( ) ;\n" 244" //--------------------------------------------------------------------------\n" 245" // reduce sum per-thread values to a single scalar, get OR of flag\n" 246" //--------------------------------------------------------------------------\n" 247" /*\n" 248" if (tid == 0)\n" 249" {\n" 250" printf (\"reduce %d : %d exists = %d\\n\", b, cij, cij_exists) ;\n" 251" }\n" 252" __syncthreads();\n" 253" */\n" 254"\n" 255" // Do vote here for control.\n" 256" cij_exists = tile.any( cij_exists);\n" 257" //tile.sync();\n" 258"\n" 259" if (cij_exists)\n" 260" {\n" 261" cij = GB_reduce_sum<T_Z, tile_sz>( tile, cij );\n" 262" \n" 263" }\n" 264" // else has_zombies = 1;\n" 265"\n" 266"\n" 267" //__syncthreads();\n" 268" //tile.sync( );\n" 269" // write result for this block to global mem\n" 270" if (tid == 0)\n" 271" {\n" 272" //printf (\"final %d : %d exists = %d\\n\", b, cij, cij_exists) ;\n" 273" if (cij_exists)\n" 274" {\n" 275" //printf(\" cij = %d\\n\", cij);\n" 276" GB_PUTC ( Cx[pair_id]=(T_C)cij ) ;\n" 277" GB_PUTC ( Ci[pair_id]=i ) ;\n" 278" }\n" 279" else\n" 280" {\n" 281" //printf(\" dot %d is a zombie\\n\", pair_id);\n" 282" zc++;\n" 283" GB_PUTC ( Ci[pair_id]=GB_FLIP (i) ) ;\n" 284" }\n" 285" }\n" 286" //__syncthreads(); \n" 287" }\n" 288"\n" 289"//--------------------------------------------------------------------------\n" 290"\n" 291" if( tid ==0 && zc > 0)\n" 292" {\n" 293" //printf(\"warp %d zombie count = %d\\n\", blockIdx.x, zc);\n" 294" atomicAdd( (unsigned long long int*)&(C->zombie_count), (unsigned long long int)zc);\n" 295" //printf(\" Czombie = %lld\\n\",C->zombie_count);\n" 296" }\n" 297"\n" 298" //__syncthreads();\n" 299"\n" 300"}\n" 301"\n" 302; 303