// Radix-select digit layout: keys are examined RADIX_BITS bits at a time.
#define RADIX_BITS      4
// Number of distinct values a single digit can take (histogram bins per pass).
#define RADIX_SIZE      (1<<RADIX_BITS)
// Bit mask covering the n-th digit, counting from the least significant digit (n = 0).
// The argument is parenthesised so that expressions such as RADIX_MASK(a+b) shift correctly.
#define RADIX_MASK(n)   ((RADIX_SIZE-1) << ((n)*RADIX_BITS))
// Number of RADIX_BITS-wide digits in type T; relies on bitsof() being defined elsewhere.
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)

// Top-k selection along axis 0 using a radix select: one thread block per
// slice (blockIdx.x enumerates the slices) and one thread per element of the
// slice (thread idx reads element idx).
//
// Works when the length on the axis is within the max allowed threads in a
// block (1024): the shared histogram below has room for 32 warps.
//
// Parameters (the dollar-sign placeholders are filled in by the code
// generator; presumably a literal dollar sign in a comment has to be doubled
// so that the substitution leaves it alone -- TODO confirm):
//   dims / *_strides  shape and strides of the tensors, see the per-argument
//                     comments in the signature
//   dstv, dsti        output values / output indices (written only when
//                     WRITE_VALUE == 1 / WRITE_INDEX == 1)
//   k                 number of elements to select; for negative k the radix
//                     keys are bitwise complemented, which reverses their order,
//                     and |k| elements are selected
//   src               input tensor
//   size              number of valid elements along axis 0; threads with
//                     threadIdx.x >= size hold no element
//
// Selected elements are written in the order they appear in the input
// (positions come from a prefix sum over the selection flags), not sorted.
extern "C" __global__ void k_topk_dense(
        $dims
        // size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM}
        $dstv
        // INPUT_TYPE *dstv
        $dstv_offset
        // size_t offset
        $dstv_strides
        // ssize_t dstv_strides_0, ssize_t dstv_strides_1, ... , dstv_strides_$${NDIM}
        $dsti
        // INDEX_TYPE *dsti
        $dsti_offset
        // size_t offset
        $dsti_strides
        // ssize_t dsti_strides_0, ssize_t dsti_strides_1, ... , dsti_strides_$${NDIM}
        ssize_t k,
        INPUT_TYPE* src,
        size_t src_offset,
        $src_strides
        // ssize_t src_strides_0, ssize_t src_strides_1, ... , src_strides_$${NDIM}
        size_t size) {
    // RADIX_SIZE digit counters for each of up to 32 warps; the same buffer is
    // later handed to binary_cumsum_exclusive
    __shared__ int smem[32 * RADIX_SIZE];
    // how many elements are still to be taken from the current candidates
    __shared__ int k2;
    const unsigned int idx = threadIdx.x;
    // is_topk:   this thread's element is (still) part of the result
    // is_topkth: this thread's element may still be equal to the k-th value
    bool is_topk = (idx < size);
    bool is_topkth = is_topk;
    size_t out_idx;

    const unsigned char warp_id = idx / GA_WARP_SIZE;
    // Warp count rounded up, so that a trailing partial warp (even one made of
    // a single thread) is included when the per-warp histograms are summed.
    const int n_warps = (int)((blockDim.x + GA_WARP_SIZE - 1) / GA_WARP_SIZE);

    // 0. get the slice for thread block to work on

    size_t gid = blockIdx.x, gidx;
    $set_slice
    // $$set_slice expands into:
    //for(int i=1; i<NDIM; i++) {
        // gidx = gid % dims_$${i};
        // gid /= dims_$${i};
        // dsti = ptr_add(dsti, gidx*dsti_strides_$${i});
        // dstv = ptr_add(dstv, gidx*dstv_strides_$${i});
        // src = ptr_add(src, gidx*src_strides_$${i});
    //}

    // get input and its radix friendly form
    const INPUT_TYPE xval = is_topk ? ptr_at(src, idx*src_strides_0) : theano_zero<INPUT_TYPE>();
    radix_t x = RadixConfig<INPUT_TYPE>::convert(xval);

    // resolve negative k: complementing the key reverses the ordering
    if (k<0) { x = ~x; k = -k; }
    // published to the other threads by the first barrier inside the loop
    if (idx==0)
        k2 = k;

    // 1. filter is_topk and is_topkth using radix select,
    //    one RADIX_BITS-wide digit per pass, most significant digit first

    #pragma unroll
    for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
        const int digit = Bitfield<radix_t>::get(x, i, RADIX_BITS);
        /*int digit = (x>>i) & (RADIX_SIZE-1);*/
        // count within warp: for every bin, how many candidates have that digit
        // NOTE(review): __ballot is the legacy mask-less vote intrinsic; toolchains
        // targeting Volta and newer expect __ballot_sync with an explicit mask.
        #pragma unroll
        for (int bin=0; bin<RADIX_SIZE; ++bin) {
            bool vote = (bin == digit) && is_topkth;
            unsigned int votes = __ballot(vote);
            if (lane_id()==0)
                smem[bin + RADIX_SIZE*warp_id] = __popc(votes);
        }
        local_barrier();
        // sum counts across all warps into the first RADIX_SIZE entries
        if (idx < RADIX_SIZE) {
            int sum = smem[idx];
            #pragma unroll
            for(int w=RADIX_SIZE; w<n_warps*RADIX_SIZE; w+=RADIX_SIZE)
                sum += smem[idx + w];
            smem[idx] = sum;
        }
        local_barrier();

        // find the bucket and update k2
        // smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
        // i.e. smem[bin] becomes k2 minus the number of candidates whose digit
        // is >= bin, and k2 becomes the last remainder that is still positive
        if (idx == 0) {
            int sum = k2;
            #pragma unroll
            for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
                sum -= smem[bin];
                smem[bin] = sum;
                k2 = (sum > 0) ? sum : k2;
            }
            // sentinel so that smem[digit+1] is valid (and positive) for the top bin
            smem[RADIX_SIZE] = 1;
        }
        local_barrier();

        if (is_topkth) {
            // still selected only if fewer than k2 candidates have a larger digit
            is_topk &= (smem[digit+1] > 0);
            // still a k-th candidate only if the remainder runs out inside this bin
            is_topkth &= (smem[digit] <= 0) && (smem[digit+1] > 0);
        }
        local_barrier();
    }

    // set k2 to the last positive remainder of the final pass, i.e. the number
    // of elements that may still be taken from the k-th candidates
    if (idx==0) {
        #pragma unroll
        for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
            if (smem[bin] <= 0)
                break;
            k2 = smem[bin];
        }
    }
    local_barrier();

    // 2. find the index of output array, if exists

    // k2 lives in shared memory and is read after a barrier, so every thread
    // of the block takes the same side of this branch (it contains a barrier)
    if (k2 != 0) {
        // top_kth value may not be unique, so we need to
        // perform binary cumsum on is_topkth to drop exceeding top-kth values
        out_idx = binary_cumsum_exclusive(idx, warp_id, smem, is_topkth);
        if ((out_idx >= k2) && is_topkth)
            is_topk = false;
        local_barrier();
    }

    // perform binary cumsum on is_topk to determine the indices to put result
    out_idx = binary_cumsum_exclusive(idx, warp_id, smem, is_topk);

    if (is_topk) {
#if WRITE_VALUE == 1
        ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
        ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
    }
}
