1 //------------------------------------------------------------------------------
2 // sparseDotProduct_merge_path.cu
3 //------------------------------------------------------------------------------
4 
// The sparseDotProduct CUDA kernel computes the semi-ring dot product of two
// sparse vectors of types T1 and T2, sharing a common index space of size n,
// and reduces it to a scalar odata of type T3. The vectors are sparse, with
// different numbers of non-zeros; i.e., we want to produce dot(x,y) in the
// sense of the given semi-ring.
9 
10 // This version uses a merge-path algorithm, when the sizes g_xnz and g_ynz are
11 // relatively close in size, but for any size of N.
12 
13 // Both the grid and block are 1D, so blockDim.x is the # threads in a
14 // threadblock, and the # of threadblocks is grid.x
15 
16 // Let b = blockIdx.x, and let s be blockDim.x. s= 32 with a variable number
17 // of active threads = min( min(g_xnz, g_ynz), 32)
18 
19 // Thus, threadblock b owns a part of the index set spanned by g_xi and g_yi.  Its job
20 // is to find the intersection of the index sets g_xi and g_yi, perform the semi-ring dot
21 // product on those items in the intersection, and finally reduce this data to a scalar,
22 // on exit write it to g_odata [b].
23 
24 #include <limits>
25 #include <cooperative_groups.h>
26 #include "mySemiRing.h"
27 
28 using namespace cooperative_groups;
29 
// Combine the per-thread values of one cooperative-groups tile with ADD.
//
//   g    the tile (of tile_sz threads) whose lanes take part in the reduction
//   val  this thread's contribution
//
// Returns the running combination held by the calling thread.  Only the
// thread with rank 0 in the tile ends up with the combination of all lanes;
// the other lanes return partial results.
//
// ADD is not defined in this file (presumably supplied by mySemiRing.h).
template< typename T, int tile_sz>
__device__ T reduce_sum(thread_block_tile<tile_sz> g, T val)
{
    // Start with a shuffle distance of half the tile and halve it each round,
    // so every round folds the upper half of the live lanes onto the lower.
    int offset = g.size() / 2;
    while (offset > 0)
    {
        val = ADD( val, g.shfl_down(val, offset) );
        offset /= 2;
    }

    // note: only thread 0 will return full sum
    return val;
}
43 
// Smaller / larger of two values.  The whole expansion is parenthesized so
// the macros can be used safely inside larger expressions, e.g.
// 2 * INTMIN(a,b) or 1 + INTMAX(a,b).  Arguments may be evaluated twice, so
// do not pass expressions with side effects.
#define INTMIN( A, B) ( ( (A) < (B) ) ?  (A) : (B) )
#define INTMAX( A, B) ( ( (A) > (B) ) ?  (A) : (B) )

// Number of potential index intersections assigned to each unit of work; the
// kernel divides min(g_xnz, g_ynz) by this value (rounding up) to obtain the
// number of parts the merge path is split into.
#define intersects_per_thread 4
47 
// Find where merge-path diagonal `diag` crosses the merge of the two index
// lists xi (xnz entries) and yi (ynz entries).
//
// On return, x_split + y_split == diag (or diag - 1, see below): entries
// xi[0 .. x_split-1] and yi[0 .. y_split-1] lie before the crossing point.
// The binary search relies on both lists being sorted in ascending order.
//
// If the crossing point separates a matching pair, i.e.
// xi[x_split] == yi[y_split-1], y_split is moved back by one so that both
// members of the pair fall on the same side of the split (the side that
// starts here).  Every index read is bounds-checked.
static __device__ __forceinline__ void sparseDotProduct_split
(
    unsigned int diag,        // diagonal to locate, in [0, xnz + ynz]
    const unsigned int *xi,   // indices of x, size xnz
    unsigned int xnz,
    const unsigned int *yi,   // indices of y, size ynz
    unsigned int ynz,
    unsigned int &x_split,    // out: position in xi
    unsigned int &y_split     // out: position in yi
)
{
    // feasible range of the x coordinate on this diagonal
    unsigned int lo = (diag > ynz) ? (diag - ynz) : 0;
    unsigned int hi = (diag < xnz) ? diag : xnz;

    // binary search for the crossing point; inside the loop
    // lo <= pivot < hi <= min(diag, xnz), so both reads are in range
    while (lo < hi)
    {
        unsigned int pivot = lo + (hi - lo) / 2;
        if (xi[pivot] < yi[diag - pivot - 1])
        {
            lo = pivot + 1;
        }
        else
        {
            hi = pivot;
        }
    }

    x_split = lo;
    y_split = diag - lo;

    // adjust for an intersection, which advances both pointers at once
    if ((lo < xnz) && (y_split > 0) && (xi[lo] == yi[y_split - 1]))
    {
        y_split--;
    }
}

// sparseDotProduct: merge-path dot product of two sparse vectors.
//
// Launch layout: 1D grid, 1D block.  The final reduction uses a single tile
// of 32 threads and only thread 0 of the block writes the result, so the
// kernel expects blockDim.x == 32 (as stated in the file header).
//
// Each thread takes a contiguous slice of the merge path over (g_xi, g_yi),
// accumulates MULADD over the indices present in both lists within its
// slice, and the per-thread values are then reduced with reduce_sum.
// Thread 0 of block b writes the block's result to g_odata[b].
//
// Threads whose slice is empty still take part in the barrier and in the
// reduction (contributing the initial value), so every block always writes
// its entry of g_odata.
//
// MULADD is not defined in this file (presumably supplied by mySemiRing.h).
template< typename T1, typename T2, typename T3>
__global__ void sparseDotProduct
(
    unsigned int g_xnz,       // Number of non-zeros in x
    unsigned int *g_xi,       // Non-zero indices in x, size xnz
    T1 *g_xdata,              // array of size xnz, type T1
    unsigned int g_ynz,       // Number of non-zeros in y
    unsigned int *g_yi,       // Non-zero indices in y, size ynz
    T2 *g_ydata,              // array of size ynz, type T2
    T3 *g_odata               // array of size grid.x, type T3
)
{
    // set thread ID
    unsigned int tid_global = threadIdx.x + blockDim.x * blockIdx.x;
    unsigned int tid = threadIdx.x;

    unsigned long int b = blockIdx.x ;

    // total items to be inspected
    unsigned int nxy = (g_xnz + g_ynz);

    // largest possible number of intersections is the smaller nz
    unsigned int n_intersect = (g_xnz < g_ynz) ? g_xnz : g_ynz;

    // we want more than one intersection per thread
    unsigned int parts = (n_intersect + intersects_per_thread - 1) / intersects_per_thread;

    // parts is 0 when either vector is empty; there is then nothing to
    // intersect, and a zero-length slice avoids dividing by zero
    unsigned int work_per_thread = (parts > 0) ? (nxy + parts - 1) / parts : 0;

    // slice [diag, diag_end) of the merge path owned by this thread, clamped
    // to nxy; computed in 64 bits so the products and sums cannot wrap
    unsigned long long first = (unsigned long long) work_per_thread * tid_global;
    unsigned int diag = (first < nxy) ? (unsigned int) first : nxy;
    unsigned long long last = (unsigned long long) diag + work_per_thread;
    unsigned int diag_end = (last < nxy) ? (unsigned int) last : nxy;

    // start and end points of this thread's slice in x and y; the end point
    // of one thread is computed exactly like the start point of the next,
    // so neighbouring slices neither overlap nor leave a gap
    unsigned int x_start, y_start, x_end, y_end;
    sparseDotProduct_split (diag,     g_xi, g_xnz, g_yi, g_ynz, x_start, y_start);
    sparseDotProduct_split (diag_end, g_xi, g_xnz, g_yi, g_ynz, x_end,   y_end);

    // NOTE(review): the accumulator starts at (T3) 0; confirm this is the
    // identity of ADD for every semi-ring this kernel is instantiated with.
    T3 sum = (T3) 0;

    // merge-path dot product; an empty slice simply skips the loop (no early
    // return: every thread must reach the barrier and the reduction below)
    unsigned int k = x_start;
    unsigned int l = y_start;
    while ( k < x_end && l < y_end )
    {
        if      ( g_xi[k] < g_yi[l] ) k += 1;
        else if ( g_xi[k] > g_yi[l] ) l += 1;
        else
        {
            // index present in both vectors
            MULADD( sum, g_xdata[k], g_ydata[l]);
            k += 1;
            l += 1;
        }
    }

    __syncthreads ( ) ;

    //--------------------------------------------------------------------------
    // reduce sum per-thread values to a single scalar
    //--------------------------------------------------------------------------
    // Using tile size fixed at compile time, we don't need shared memory
    #define tile_sz 32
    thread_block_tile<tile_sz> tile = tiled_partition<tile_sz>( this_thread_block());
    T3 block_sum = reduce_sum<T3,tile_sz>(tile, sum);

    // write result for this block to global mem
    if (tid == 0)
    {
        // b is unsigned long; the value is converted to double to match %g
        printf ("final %lu : %g\n", b, (double) block_sum) ;
        g_odata [b] = block_sum ;
    }
}
189 
190