1 //------------------------------------------------------------------------------
2 // sparseDotProduct_merge_path.cu
3 //------------------------------------------------------------------------------
4
5 // The sparseDotProduct CUDA kernel produces the semi-ring dot product of two
6 // sparse vectors of types T1 and T2 and common index space size n, to a scalar
7 // odata of type T3. The vectors are sparse, with different numbers of non-zeros.
8 // ie. we want to produce dot(x,y) in the sense of the given semi-ring.
9
// This version uses a merge-path algorithm. It is intended for the case where
// g_xnz and g_ynz are relatively close in size, and it works for any size of
// the index space n.
12
13 // Both the grid and block are 1D, so blockDim.x is the # threads in a
14 // threadblock, and the # of threadblocks is grid.x
15
16 // Let b = blockIdx.x, and let s be blockDim.x. s= 32 with a variable number
17 // of active threads = min( min(g_xnz, g_ynz), 32)
18
// Thus, threadblock b owns a part of the index set spanned by g_xi and g_yi.
// Its job is to find the intersection of the index sets g_xi and g_yi, perform
// the semi-ring dot product on the items in that intersection, and finally
// reduce the per-thread results to a scalar, which is written to g_odata [b]
// on exit.
23
24 #include <limits>
25 #include <cooperative_groups.h>
26 #include "mySemiRing.h"
27
28 using namespace cooperative_groups;
29
// Combine the per-thread values of a cooperative-groups tile with the
// semi-ring ADD operator, using shuffle-down exchanges (no shared memory).
//
//   g   : the tile whose threads take part in the reduction
//   val : this thread's contribution
//
// Returns the running partial result held by the calling thread. Only the
// thread with rank 0 in the tile ends up holding the reduction over the whole
// tile; the values returned to the other threads are partial results.
template< typename T, int tile_width>
__device__ T reduce_sum(thread_block_tile<tile_width> g, T val)
{
    // The exchange distance starts at half the tile size and is halved after
    // every step. In each step a thread folds in the value held by the
    // thread `offset` ranks above it.
    int offset = g.size() / 2;
    while (offset > 0)
    {
        val = ADD( val, g.shfl_down(val, offset) );
        offset /= 2;
    }
    return val;
}
43
// Two-argument minimum and maximum. The whole expansion is parenthesized so
// that the macros can be used safely inside larger expressions (for example
// `1 + INTMIN(a, b)`). Each argument may be evaluated more than once, so do
// not pass expressions with side effects.
#define INTMIN( A, B) ( ( (A) < (B) ) ? (A) : (B) )
#define INTMAX( A, B) ( ( (A) > (B) ) ? (A) : (B) )

// Number of potential intersections assigned to each thread when the merge
// path is split into parts.
#define intersects_per_thread 4
47
// Locate where merge-path diagonal `diag` crosses the merge of the two sorted
// index lists xi (length xnz) and yi (length ynz).
//
// Returns the smallest x position p in [max(diag-ynz,0), min(diag,xnz)] such
// that xi[p] >= yi[diag-p-1], i.e. the number of x entries that lie before
// the split; the number of y entries before the split is then diag - p.
// No memory is read when the search interval is empty.
static __device__ __forceinline__ unsigned int sparseDotProduct_split
(
    unsigned int diag,          // diagonal to search, 0 <= diag <= xnz + ynz
    const unsigned int *xi,     // sorted indices of x, size xnz
    unsigned int xnz,
    const unsigned int *yi,     // sorted indices of y, size ynz
    unsigned int ynz
)
{
    // Clamp in unsigned arithmetic instead of casting a wrapped difference
    // to int.
    unsigned int lo = (diag > ynz) ? (diag - ynz) : 0u;
    unsigned int hi = (diag < xnz) ? diag : xnz;

    while (lo < hi)
    {
        // lo + (hi-lo)/2 cannot overflow, unlike (lo+hi)/2
        unsigned int pivot = lo + (hi - lo) / 2;
        if (xi[pivot] < yi[diag - pivot - 1])
        {
            lo = pivot + 1;
        }
        else
        {
            hi = pivot;
        }
    }
    return lo;
}

// sparseDotProduct: semi-ring dot product of two sparse vectors, merge-path
// variant.
//
// Launch layout: 1D grid and 1D block. The merge of the two index lists is
// cut into contiguous pieces, one per global thread (threadIdx.x +
// blockDim.x * blockIdx.x). Each thread accumulates MULADD over the matching
// indices found in its piece; the per-thread values are then reduced across a
// 32-thread tile and thread 0 of the block writes the result to g_odata [b].
//
// Preconditions / notes:
//   * g_xi and g_yi must be sorted in ascending order (required by the
//     binary search and by the merge loop).
//   * Only the tile containing thread 0 contributes to g_odata [b], so the
//     kernel is meant to be launched with blockDim.x == 32; with larger
//     blocks the sums of the other tiles are not written.
//   * g_xnz + g_ynz is computed in unsigned int and must not overflow.
//   * No shared memory is used.
//   * NOTE(review): the accumulator starts at (T3) 0, which presumably is the
//     identity of the semi-ring ADD defined in mySemiRing.h -- confirm.
template< typename T1, typename T2, typename T3>
__global__ void sparseDotProduct
(
    unsigned int g_xnz,       // Number of non-zeros in x
    unsigned int *g_xi,       // Non-zero indices in x, size xnz
    T1 *g_xdata,              // array of size xnz, type T1
    unsigned int g_ynz,       // Number of non-zeros in y
    unsigned int *g_yi,       // Non-zero indices in y, size ynz
    T2 *g_ydata,              // array of size ynz, type T2
    T3 *g_odata               // array of size grid.x, type T3
)
{
    // thread and block identifiers
    unsigned int tid_global = threadIdx.x + blockDim.x * blockIdx.x;
    unsigned int tid = threadIdx.x;
    unsigned long int b = blockIdx.x;

    // total items to be inspected
    unsigned int nxy = g_xnz + g_ynz;

    // largest possible number of intersections is the smaller nz
    unsigned int n_intersect = (g_xnz < g_ynz) ? g_xnz : g_ynz;

    // we want more than one intersection per thread
    unsigned int parts =
        (n_intersect + intersects_per_thread - 1) / intersects_per_thread;

    // If either vector is empty, parts is 0: give every thread an empty
    // range instead of dividing by zero. The block result is then (T3) 0.
    unsigned int work_per_thread = (parts > 0) ? ((nxy + parts - 1) / parts) : 0u;

    // Diagonals bounding this thread's piece of the merge path. The products
    // and sums are formed in 64 bits so that large thread ids cannot wrap
    // around before the clamp to nxy.
    unsigned long long diag_wide =
        (unsigned long long) work_per_thread * (unsigned long long) tid_global;
    unsigned int diag = (diag_wide < nxy) ? (unsigned int) diag_wide : nxy;

    unsigned long long diag_end_wide =
        (unsigned long long) diag + (unsigned long long) work_per_thread;
    unsigned int diag_end =
        (diag_end_wide < nxy) ? (unsigned int) diag_end_wide : nxy;

    //--------------------------------------------------------------------------
    // find the start point (x_start, y_start) of this thread's range
    //--------------------------------------------------------------------------
    unsigned int x_start =
        sparseDotProduct_split(diag, g_xi, g_xnz, g_yi, g_ynz);
    unsigned int y_start = diag - x_start;

    // If the split separates a matching pair (the last y entry before the
    // split equals the first x entry after it), pull that y entry back into
    // this thread's range so that both members of the pair are seen by the
    // same thread. The bounds tests keep the reads inside both arrays.
    if ( (x_start < g_xnz) && (y_start > 0)
        && (g_xi[x_start] == g_yi[y_start - 1]) )
    {
        y_start--;
    }

    //--------------------------------------------------------------------------
    // find the end point (x_end, y_end) of this thread's range
    //--------------------------------------------------------------------------
    unsigned int x_end =
        sparseDotProduct_split(diag_end, g_xi, g_xnz, g_yi, g_ynz);
    unsigned int y_end = diag_end - x_end;

    // Same adjustment as for the start point, applied to the END of the range:
    // the matching pair that straddles diag_end belongs to the next thread,
    // whose start point is computed from the same diagonal.
    if ( (x_end < g_xnz) && (y_end > 0)
        && (g_xi[x_end] == g_yi[y_end - 1]) )
    {
        y_end--;
    }

    //--------------------------------------------------------------------------
    // merge-path dot product over [x_start,x_end) x [y_start,y_end)
    //--------------------------------------------------------------------------
    T3 sum = (T3) 0;

    // Threads with an empty range simply skip the loop. They must NOT return
    // early: every thread has to reach the barrier and the tile reduction
    // below, and thread 0 must always write g_odata [b].
    unsigned int k = x_start;
    unsigned int l = y_start;
    while ( k < x_end && l < y_end )
    {
        unsigned int ix = g_xi[k];
        unsigned int iy = g_yi[l];
        if ( ix < iy )
        {
            k += 1;
        }
        else if ( ix > iy )
        {
            l += 1;
        }
        else
        {
            // common index: accumulate the semi-ring product
            MULADD( sum, g_xdata[k], g_ydata[l]);
            k += 1;
            l += 1;
        }
    }

    // Reached by every thread of the block now that there is no early return.
    __syncthreads ( ) ;

    //--------------------------------------------------------------------------
    // reduce sum per-thread values to a single scalar
    //--------------------------------------------------------------------------
    // Using tile size fixed at compile time, we don't need shared memory
    #define tile_sz 32
    thread_block_tile<tile_sz> tile = tiled_partition<tile_sz>( this_thread_block());
    T3 block_sum = reduce_sum<T3,tile_sz>(tile, sum);

    // write result for this block to global mem
    if (tid == 0)
    {
        // diagnostic output; %lu matches the unsigned long block id
        printf ("final %lu : %g\n", b, block_sum) ;
        g_odata [b] = block_sum ;
    }
}
189
190