// Radix-select configuration. Keys are examined RADIX_BITS bits at a time,
// giving RADIX_SIZE possible digit values (histogram bins) per pass.
#define RADIX_BITS 4
#define RADIX_SIZE (1<<RADIX_BITS)
// Mask selecting the n-th RADIX_BITS-wide digit (n = 0 is the lowest digit).
// The argument is parenthesized so expressions such as RADIX_MASK(a + b)
// expand to the intended shift amount.
#define RADIX_MASK(n) ((RADIX_SIZE-1) << ((n)*RADIX_BITS))
// Number of RADIX_BITS-wide digits in type T. `bitsof` is not defined in this
// file section; it is expected to be provided elsewhere.
#define RADIX_DIGITS(T) (bitsof(T)/RADIX_BITS)

// Top-k selection along one axis, one thread block per slice and one thread
// per element of the slice.
// Works when length on axis is within max allowed threads in block (1024).
//
// This source is a template: the dollar-prefixed placeholders in the
// parameter list and body are filled in before compilation (the comments
// next to each placeholder show the intended expansion).
//
// Parameters (after expansion):
//   dims_*          extents of the non-selected dimensions, used to locate
//                   the slice handled by this block
//   dstv / dsti     output buffers for values / indices, each with an offset
//                   and per-dimension strides
//   k               number of elements to select; a negative k selects from
//                   the opposite end of the ordering (the radix key is
//                   bit-inverted and k negated)
//   src             input buffer, with offset and per-dimension strides
//   size            length of the selected axis; threads with
//                   threadIdx.x >= size do not take part in the selection
//
// Shared memory: 32 * RADIX_SIZE ints of per-warp histograms (i.e. room for
// 32 warps) plus one int holding the remaining count k2.
//
// NOTE(review): dstv_offset, dsti_offset and src_offset are not referenced
// anywhere in the visible body; confirm they are applied by the slice-setup
// expansion or by the caller.
// NOTE(review): __ballot is the legacy mask-less warp vote; newer toolchains
// and architectures expect __ballot_sync. Left unchanged here because the
// targeted toolchain version cannot be determined from this file.
extern "C" __global__ void k_topk_dense(
        $dims
        // size_t dims_1, ssize_t dims_2, ... , dims_$${NDIM}
        $dstv
        // INPUT_TYPE *dstv
        $dstv_offset
        // size_t offset
        $dstv_strides
        // ssize_t dstv_strides_0, ssize_t dstv_strides_1, ... , dstv_strides_$${NDIM}
        $dsti
        // INDEX_TYPE *dsti
        $dsti_offset
        // size_t offset
        $dsti_strides
        // ssize_t dsti_strides_0, ssize_t dsti_strides_1, ... , dsti_strides_$${NDIM}
        ssize_t k,
        INPUT_TYPE* src,
        size_t src_offset,
        $src_strides
        // ssize_t src_strides_0, ssize_t src_strides_1, ... , src_strides_$${NDIM}
        size_t size) {
    // Per-warp digit histograms: RADIX_SIZE bins for each of up to 32 warps.
    __shared__ int smem[32 * RADIX_SIZE];
    // Remaining number of elements still to be selected (updated per digit).
    __shared__ int k2;
    const unsigned int idx = threadIdx.x;
    // is_topk:   this element is still a candidate for the output
    // is_topkth: this element still ties with the k-th key on all digits
    //            examined so far
    bool is_topk= (idx < size);
    bool is_topkth = is_topk;
    size_t out_idx;

    const unsigned char warp_id = idx / GA_WARP_SIZE;
    // Number of warps in the block, rounded up so that a trailing partial
    // warp (blockDim.x not a multiple of the warp size) is still included
    // when the per-warp histograms are summed below.
    const unsigned int n_warps = (blockDim.x + GA_WARP_SIZE - 1) / GA_WARP_SIZE;
    // 0. get the slice for thread block to work on

    size_t gid = blockIdx.x, gidx;
    $set_slice
    // $$set_slice expands into:
    //for(int i=1; i<NDIM; i++) {
        // gidx = gid % dims_$${i};
        // gid /= dims_$${i};
        // dsti = ptr_add(dsti, gidx*dsti_strides_$${i};
        // dstv = ptr_add(dstv, gidx*dstv_strides_$${i};
        // src = ptr_add(src, gidx*src_strides_$${i});
    //}

    // get input and its radix friendly form
    const INPUT_TYPE xval = is_topk ? ptr_at(src, idx*src_strides_0) : theano_zero<INPUT_TYPE>();
    radix_t x = RadixConfig<INPUT_TYPE>::convert(xval);

    // resolve negative k
    if (k<0) { x = ~x; k = -k; }
    if (idx==0)
        k2 = k;

    // 1. filter is_topk and is_topkth using radix select
    //    Digits are processed from the most significant to the least.

    #pragma unroll
    for (int i=bitsof(INPUT_TYPE)-RADIX_BITS; i>=0; i-=RADIX_BITS) {
        const int digit = Bitfield<radix_t>::get(x, i, RADIX_BITS);
        /*int digit = (x>>i) & (RADIX_SIZE-1);*/
        // count within warp: one ballot per bin, lane 0 stores the popcount
        #pragma unroll
        for (int bin=0; bin<RADIX_SIZE; ++bin) {
            bool vote = (bin == digit) && is_topkth;
            unsigned int votes = __ballot(vote);
            if (lane_id()==0)
                smem[bin + RADIX_SIZE*warp_id] = __popc(votes);
        }
        local_barrier();
        // sum counts across all warps; the first RADIX_SIZE threads each
        // own one bin and accumulate it into the warp-0 row
        if (idx < RADIX_SIZE) {
            int sum = smem[idx];
            #pragma unroll
            for (unsigned int w=1; w<n_warps; ++w)
                sum += smem[idx + w*RADIX_SIZE];
            smem[idx] = sum;
        }
        local_barrier();

        // find the bucket and update k2
        // smem[:RADIX_SIZE:-1] = k2 - cumsum(smem[:RADIX_SIZE-1:-1])
        if (idx == 0) {
            int sum = k2;
            #pragma unroll
            for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
                sum -= smem[bin];
                smem[bin] = sum;
                // keep the last still-positive remainder as the new k2
                k2 = (sum > 0) ? sum : k2;
            }
            // sentinel so that smem[digit+1] is valid (and positive) for the
            // highest digit value
            smem[RADIX_SIZE] = 1;
        }
        local_barrier();

        if (is_topkth) {
            // still a candidate if elements remained to pick after all
            // higher bins were consumed
            is_topk &= (smem[digit+1] > 0);
            // ties with the k-th key only if this bin is where the
            // remainder drops to zero or below
            is_topkth &= (smem[digit] <= 0) && (smem[digit+1] > 0);
        }
        // keep smem stable until every thread has read it; the next
        // iteration overwrites the histograms
        local_barrier();
    }

    // set k2 as number of exceeding values
    if (idx==0) {
        #pragma unroll
        for (int bin=RADIX_SIZE-1; bin>=0; --bin) {
            if (smem[bin] <= 0)
                break;
            k2 = smem[bin];
        }
    }
    local_barrier();

    // 2. find the index of output array, if exists

    // k2 lives in shared memory and is read after a barrier, so every thread
    // of the block takes the same branch here.
    if (k2 != 0) {
        // top_kth value may not be unique, so we need to
        // perform binary cumsum on is_topkth to drop exceeding top-kth values
        out_idx = binary_cumsum_exclusive(idx, warp_id, smem, is_topkth);
        if ((out_idx >= k2) && is_topkth)
            is_topk = false;
        local_barrier();
    }

    // perform binary cumsum on is_topk to determine the indices to put result
    out_idx = binary_cumsum_exclusive(idx, warp_id, smem, is_topk);

    if (is_topk) {
#if WRITE_VALUE == 1
        ptr_at(dstv, out_idx * dstv_strides_0) = xval;
#endif
#if WRITE_INDEX == 1
        ptr_at(dsti, out_idx * dsti_strides_0) = (INDEX_TYPE)idx;
#endif
    }
}
138