1 //---------------------------------------------------------------------------//
2 // Copyright (c) 2013 Kyle Lutz <kyle.r.lutz@gmail.com>
3 //
4 // Distributed under the Boost Software License, Version 1.0
5 // See accompanying file LICENSE_1_0.txt or copy at
6 // http://www.boost.org/LICENSE_1_0.txt
7 //
8 // See http://boostorg.github.com/compute for more information.
9 //---------------------------------------------------------------------------//
10 
11 #ifndef BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP
12 #define BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP
13 
14 #include <iterator>
15 
16 #include <boost/assert.hpp>
17 #include <boost/type_traits/is_signed.hpp>
18 #include <boost/type_traits/is_floating_point.hpp>
19 
20 #include <boost/compute/kernel.hpp>
21 #include <boost/compute/program.hpp>
22 #include <boost/compute/command_queue.hpp>
23 #include <boost/compute/algorithm/exclusive_scan.hpp>
24 #include <boost/compute/container/vector.hpp>
25 #include <boost/compute/detail/iterator_range_size.hpp>
26 #include <boost/compute/detail/parameter_cache.hpp>
27 #include <boost/compute/type_traits/type_name.hpp>
28 #include <boost/compute/type_traits/is_fundamental.hpp>
29 #include <boost/compute/type_traits/is_vector_type.hpp>
30 #include <boost/compute/utility/program_cache.hpp>
31 
32 namespace boost {
33 namespace compute {
34 namespace detail {
35 
// Meta-function returning true if type T is radix-sortable.
// A type is radix-sortable when it is a fundamental scalar type;
// OpenCL vector types (e.g. float4_) are excluded because the bitwise
// radix transform below is only defined for scalars.
template<class T>
struct is_radix_sortable :
    boost::mpl::and_<
        typename ::boost::compute::is_fundamental<T>::type,
        typename boost::mpl::not_<typename is_vector_type<T>::type>::type
    >
{
};
45 
// Maps a byte size (sizeof(T)) to the unsigned integer type of the same
// width that the radix digits are extracted from. Signed and floating
// point inputs are reinterpreted as this type inside the kernels.
// The primary template is intentionally empty so that unsupported
// sizes fail to compile.
template<size_t N>
struct radix_sort_value_type
{
};

// 1-byte values sort as uchar.
template<>
struct radix_sort_value_type<1>
{
    typedef uchar_ type;
};

// 2-byte values sort as ushort.
template<>
struct radix_sort_value_type<2>
{
    typedef ushort_ type;
};

// 4-byte values (int, float, uint) sort as uint.
template<>
struct radix_sort_value_type<4>
{
    typedef uint_ type;
};

// 8-byte values (long, double, ulong) sort as ulong.
template<>
struct radix_sort_value_type<8>
{
    typedef ulong_ type;
};
74 
75 template<typename T>
enable_double()76 inline const char* enable_double()
77 {
78     return " -DT2_double=0";
79 }
80 
81 template<>
enable_double()82 inline const char* enable_double<double>()
83 {
84     return " -DT2_double=1";
85 }
86 
// OpenCL source for the radix sort kernels. The host compiles this with:
//   -DT=<unsigned sort type>        (from radix_sort_value_type)
//   -DK_BITS=<radix bits per pass>  (1, 2 or 4)
//   -DBLOCK_SIZE=<work-group size>
// and optionally -DASC, -DSORT_BY_KEY, -DIS_SIGNED, -DIS_FLOATING_POINT,
// -DT2=<value type> and -DT2_double (see radix_sort_impl below).
const char radix_sort_source[] =
"#if T2_double\n"
"#pragma OPENCL EXTENSION cl_khr_fp64 : enable\n"
"#endif\n"
// K2_BITS = number of buckets per pass (2^K_BITS)
"#define K2_BITS (1 << K_BITS)\n"
// RADIX_MASK selects the low K_BITS bits of a shifted key
"#define RADIX_MASK ((((T)(1)) << K_BITS) - 1)\n"
// SIGN_BIT = index of the most-significant (sign) bit of T
"#define SIGN_BIT ((sizeof(T) * CHAR_BIT) - 1)\n"

"#if defined(ASC)\n" // ascending order

// radix() extracts the K_BITS-wide digit starting at low_bit, after
// transforming the raw bits so that unsigned comparison of the
// transformed keys matches the desired ordering:
//  - floats: flip all bits of negatives, flip only the sign bit of
//    positives (standard float-to-orderable-uint transform)
//  - signed ints: flip the sign bit
//  - unsigned ints: no transform
"inline uint radix(const T x, const uint low_bit)\n"
"{\n"
"#if defined(IS_FLOATING_POINT)\n"
"    const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
"    return ((x ^ mask) >> low_bit) & RADIX_MASK;\n"
"#elif defined(IS_SIGNED)\n"
"    return ((x ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
"#else\n"
"    return (x >> low_bit) & RADIX_MASK;\n"
"#endif\n"
"}\n"

"#else\n" // descending order

// For signed types we just negate the x and for unsigned types we
// subtract the x from max value of its type ((T)(-1) is a max value
// of type T when T is an unsigned type).
"inline uint radix(const T x, const uint low_bit)\n"
"{\n"
"#if defined(IS_FLOATING_POINT)\n"
"    const T mask = -(x >> SIGN_BIT) | (((T)(1)) << SIGN_BIT);\n"
"    return (((-x) ^ mask) >> low_bit) & RADIX_MASK;\n"
"#elif defined(IS_SIGNED)\n"
"    return (((-x) ^ (((T)(1)) << SIGN_BIT)) >> low_bit) & RADIX_MASK;\n"
"#else\n"
"    return (((T)(-1) - x) >> low_bit) & RADIX_MASK;\n"
"#endif\n"
"}\n"

"#endif\n" // #if defined(ASC)

// count: histogram of radix digits per work-group. Writes per-block
// counts to global_counts; the last group also seeds global_offsets
// with its local counts (consumed by the scan kernel).
"__kernel void count(__global const T *input,\n"
"                    const uint input_offset,\n"
"                    const uint input_size,\n"
"                    __global uint *global_counts,\n"
"                    __global uint *global_offsets,\n"
"                    __local uint *local_counts,\n"
"                    const uint low_bit)\n"
"{\n"
     // work-item parameters
"    const uint gid = get_global_id(0);\n"
"    const uint lid = get_local_id(0);\n"

     // zero local counts
"    if(lid < K2_BITS){\n"
"        local_counts[lid] = 0;\n"
"    }\n"
"    barrier(CLK_LOCAL_MEM_FENCE);\n"

     // reduce local counts (one atomic increment per input element)
"    if(gid < input_size){\n"
"        T value = input[input_offset+gid];\n"
"        uint bucket = radix(value, low_bit);\n"
"        atomic_inc(local_counts + bucket);\n"
"    }\n"
"    barrier(CLK_LOCAL_MEM_FENCE);\n"

     // write block-relative offsets
"    if(lid < K2_BITS){\n"
"        global_counts[K2_BITS*get_group_id(0) + lid] = local_counts[lid];\n"

         // write global offsets (last group only; added to the
         // exclusive-scanned counts by the scan kernel below)
"        if(get_group_id(0) == (get_num_groups(0) - 1)){\n"
"            global_offsets[lid] = local_counts[lid];\n"
"        }\n"
"    }\n"
"}\n"

// scan: single-task kernel run after the host exclusive-scans the
// per-block counts; turns per-bucket totals into global bucket
// start offsets. block_offsets holds the scanned counts, so the last
// block's entry plus the seeded global_offsets value is the total for
// each bucket.
"__kernel void scan(__global const uint *block_offsets,\n"
"                   __global uint *global_offsets,\n"
"                   const uint block_count)\n"
"{\n"
"    __global const uint *last_block_offsets =\n"
"        block_offsets + K2_BITS * (block_count - 1);\n"

     // calculate and scan global_offsets
"    uint sum = 0;\n"
"    for(uint i = 0; i < K2_BITS; i++){\n"
"        uint x = global_offsets[i] + last_block_offsets[i];\n"
"        global_offsets[i] = sum;\n"
"        sum += x;\n"
"    }\n"
"}\n"

// scatter: moves each element to
//   global_offsets[bucket] + counts[block][bucket] + rank-within-block,
// which is a stable placement. When SORT_BY_KEY is defined the value
// buffer is permuted alongside the keys.
"__kernel void scatter(__global const T *input,\n"
"                      const uint input_offset,\n"
"                      const uint input_size,\n"
"                      const uint low_bit,\n"
"                      __global const uint *counts,\n"
"                      __global const uint *global_offsets,\n"
"#ifndef SORT_BY_KEY\n"
"                      __global T *output,\n"
"                      const uint output_offset)\n"
"#else\n"
"                      __global T *keys_output,\n"
"                      const uint keys_output_offset,\n"
"                      __global T2 *values_input,\n"
"                      const uint values_input_offset,\n"
"                      __global T2 *values_output,\n"
"                      const uint values_output_offset)\n"
"#endif\n"
"{\n"
     // work-item parameters
"    const uint gid = get_global_id(0);\n"
"    const uint lid = get_local_id(0);\n"

     // copy input to local memory (bucket per work-item, so each
     // item can compute its stable rank within the block below)
"    T value;\n"
"    uint bucket;\n"
"    __local uint local_input[BLOCK_SIZE];\n"
"    if(gid < input_size){\n"
"        value = input[input_offset+gid];\n"
"        bucket = radix(value, low_bit);\n"
"        local_input[lid] = bucket;\n"
"    }\n"

     // copy block counts to local memory
"    __local uint local_counts[(1 << K_BITS)];\n"
"    if(lid < K2_BITS){\n"
"        local_counts[lid] = counts[get_group_id(0) * K2_BITS + lid];\n"
"    }\n"

     // wait until local memory is ready
"    barrier(CLK_LOCAL_MEM_FENCE);\n"

     // out-of-range items only helped populate local memory
"    if(gid >= input_size){\n"
"        return;\n"
"    }\n"

     // get global offset
"    uint offset = global_offsets[bucket] + local_counts[bucket];\n"

     // calculate local offset (rank of this item among earlier items
     // in the block with the same bucket -- keeps the sort stable)
"    uint local_offset = 0;\n"
"    for(uint i = 0; i < lid; i++){\n"
"        if(local_input[i] == bucket)\n"
"            local_offset++;\n"
"    }\n"

"#ifndef SORT_BY_KEY\n"
     // write value to output
"    output[output_offset + offset + local_offset] = value;\n"
"#else\n"
     // write key and value if doing sort_by_key
"    keys_output[keys_output_offset+offset + local_offset] = value;\n"
"    values_output[values_output_offset+offset + local_offset] =\n"
"        values_input[values_input_offset+gid];\n"
"#endif\n"
"}\n";
246 
247 template<class T, class T2>
radix_sort_impl(const buffer_iterator<T> first,const buffer_iterator<T> last,const buffer_iterator<T2> values_first,const bool ascending,command_queue & queue)248 inline void radix_sort_impl(const buffer_iterator<T> first,
249                             const buffer_iterator<T> last,
250                             const buffer_iterator<T2> values_first,
251                             const bool ascending,
252                             command_queue &queue)
253 {
254 
255     typedef T value_type;
256     typedef typename radix_sort_value_type<sizeof(T)>::type sort_type;
257 
258     const device &device = queue.get_device();
259     const context &context = queue.get_context();
260 
261 
262     // if we have a valid values iterator then we are doing a
263     // sort by key and have to set up the values buffer
264     bool sort_by_key = (values_first.get_buffer().get() != 0);
265 
266     // load (or create) radix sort program
267     std::string cache_key =
268         std::string("__boost_radix_sort_") + type_name<value_type>();
269 
270     if(sort_by_key){
271         cache_key += std::string("_with_") + type_name<T2>();
272     }
273 
274     boost::shared_ptr<program_cache> cache =
275         program_cache::get_global_cache(context);
276     boost::shared_ptr<parameter_cache> parameters =
277         detail::parameter_cache::get_global_cache(device);
278 
279     // sort parameters
280     const uint_ k = parameters->get(cache_key, "k", 4);
281     const uint_ k2 = 1 << k;
282     const uint_ block_size = parameters->get(cache_key, "tpb", 128);
283 
284     // sort program compiler options
285     std::stringstream options;
286     options << "-DK_BITS=" << k;
287     options << " -DT=" << type_name<sort_type>();
288     options << " -DBLOCK_SIZE=" << block_size;
289 
290     if(boost::is_floating_point<value_type>::value){
291         options << " -DIS_FLOATING_POINT";
292     }
293 
294     if(boost::is_signed<value_type>::value){
295         options << " -DIS_SIGNED";
296     }
297 
298     if(sort_by_key){
299         options << " -DSORT_BY_KEY";
300         options << " -DT2=" << type_name<T2>();
301         options << enable_double<T2>();
302     }
303 
304     if(ascending){
305         options << " -DASC";
306     }
307 
308     // load radix sort program
309     program radix_sort_program = cache->get_or_build(
310         cache_key, options.str(), radix_sort_source, context
311     );
312 
313     kernel count_kernel(radix_sort_program, "count");
314     kernel scan_kernel(radix_sort_program, "scan");
315     kernel scatter_kernel(radix_sort_program, "scatter");
316 
317     size_t count = detail::iterator_range_size(first, last);
318 
319     uint_ block_count = static_cast<uint_>(count / block_size);
320     if(block_count * block_size != count){
321         block_count++;
322     }
323 
324     // setup temporary buffers
325     vector<value_type> output(count, context);
326     vector<T2> values_output(sort_by_key ? count : 0, context);
327     vector<uint_> offsets(k2, context);
328     vector<uint_> counts(block_count * k2, context);
329 
330     const buffer *input_buffer = &first.get_buffer();
331     uint_ input_offset = static_cast<uint_>(first.get_index());
332     const buffer *output_buffer = &output.get_buffer();
333     uint_ output_offset = 0;
334     const buffer *values_input_buffer = &values_first.get_buffer();
335     uint_ values_input_offset = static_cast<uint_>(values_first.get_index());
336     const buffer *values_output_buffer = &values_output.get_buffer();
337     uint_ values_output_offset = 0;
338 
339     for(uint_ i = 0; i < sizeof(sort_type) * CHAR_BIT / k; i++){
340         // write counts
341         count_kernel.set_arg(0, *input_buffer);
342         count_kernel.set_arg(1, input_offset);
343         count_kernel.set_arg(2, static_cast<uint_>(count));
344         count_kernel.set_arg(3, counts);
345         count_kernel.set_arg(4, offsets);
346         count_kernel.set_arg(5, block_size * sizeof(uint_), 0);
347         count_kernel.set_arg(6, i * k);
348         queue.enqueue_1d_range_kernel(count_kernel,
349                                       0,
350                                       block_count * block_size,
351                                       block_size);
352 
353         // scan counts
354         if(k == 1){
355             typedef uint2_ counter_type;
356             ::boost::compute::exclusive_scan(
357                 make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
358                 make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 2),
359                 make_buffer_iterator<counter_type>(counts.get_buffer()),
360                 queue
361             );
362         }
363         else if(k == 2){
364             typedef uint4_ counter_type;
365             ::boost::compute::exclusive_scan(
366                 make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
367                 make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 4),
368                 make_buffer_iterator<counter_type>(counts.get_buffer()),
369                 queue
370             );
371         }
372         else if(k == 4){
373             typedef uint16_ counter_type;
374             ::boost::compute::exclusive_scan(
375                 make_buffer_iterator<counter_type>(counts.get_buffer(), 0),
376                 make_buffer_iterator<counter_type>(counts.get_buffer(), counts.size() / 16),
377                 make_buffer_iterator<counter_type>(counts.get_buffer()),
378                 queue
379             );
380         }
381         else {
382             BOOST_ASSERT(false && "unknown k");
383             break;
384         }
385 
386         // scan global offsets
387         scan_kernel.set_arg(0, counts);
388         scan_kernel.set_arg(1, offsets);
389         scan_kernel.set_arg(2, block_count);
390         queue.enqueue_task(scan_kernel);
391 
392         // scatter values
393         scatter_kernel.set_arg(0, *input_buffer);
394         scatter_kernel.set_arg(1, input_offset);
395         scatter_kernel.set_arg(2, static_cast<uint_>(count));
396         scatter_kernel.set_arg(3, i * k);
397         scatter_kernel.set_arg(4, counts);
398         scatter_kernel.set_arg(5, offsets);
399         scatter_kernel.set_arg(6, *output_buffer);
400         scatter_kernel.set_arg(7, output_offset);
401         if(sort_by_key){
402             scatter_kernel.set_arg(8, *values_input_buffer);
403             scatter_kernel.set_arg(9, values_input_offset);
404             scatter_kernel.set_arg(10, *values_output_buffer);
405             scatter_kernel.set_arg(11, values_output_offset);
406         }
407         queue.enqueue_1d_range_kernel(scatter_kernel,
408                                       0,
409                                       block_count * block_size,
410                                       block_size);
411 
412         // swap buffers
413         std::swap(input_buffer, output_buffer);
414         std::swap(values_input_buffer, values_output_buffer);
415         std::swap(input_offset, output_offset);
416         std::swap(values_input_offset, values_output_offset);
417     }
418 }
419 
420 template<class Iterator>
radix_sort(Iterator first,Iterator last,command_queue & queue)421 inline void radix_sort(Iterator first,
422                        Iterator last,
423                        command_queue &queue)
424 {
425     radix_sort_impl(first, last, buffer_iterator<int>(), true, queue);
426 }
427 
428 template<class KeyIterator, class ValueIterator>
radix_sort_by_key(KeyIterator keys_first,KeyIterator keys_last,ValueIterator values_first,command_queue & queue)429 inline void radix_sort_by_key(KeyIterator keys_first,
430                               KeyIterator keys_last,
431                               ValueIterator values_first,
432                               command_queue &queue)
433 {
434     radix_sort_impl(keys_first, keys_last, values_first, true, queue);
435 }
436 
437 template<class Iterator>
radix_sort(Iterator first,Iterator last,const bool ascending,command_queue & queue)438 inline void radix_sort(Iterator first,
439                        Iterator last,
440                        const bool ascending,
441                        command_queue &queue)
442 {
443     radix_sort_impl(first, last, buffer_iterator<int>(), ascending, queue);
444 }
445 
446 template<class KeyIterator, class ValueIterator>
radix_sort_by_key(KeyIterator keys_first,KeyIterator keys_last,ValueIterator values_first,const bool ascending,command_queue & queue)447 inline void radix_sort_by_key(KeyIterator keys_first,
448                               KeyIterator keys_last,
449                               ValueIterator values_first,
450                               const bool ascending,
451                               command_queue &queue)
452 {
453     radix_sort_impl(keys_first, keys_last, values_first, ascending, queue);
454 }
455 
456 
457 } // end detail namespace
458 } // end compute namespace
459 } // end boost namespace
460 
461 #endif // BOOST_COMPUTE_ALGORITHM_DETAIL_RADIX_SORT_HPP
462