1 #ifndef VEXCL_REDUCE_BY_KEY_HPP
2 #define VEXCL_REDUCE_BY_KEY_HPP
3 
4 /*
5 The MIT License
6 
7 Copyright (c) 2012-2018 Denis Demidov <dennis.demidov@gmail.com>
8 
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
15 
16 The above copyright notice and this permission notice shall be included in
17 all copies or substantial portions of the Software.
18 
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
25 THE SOFTWARE.
26 */
27 
28 /**
29  * \file   vexcl/reduce_by_key.hpp
30  * \author Denis Demidov <dennis.demidov@gmail.com>
 * \brief  Reduce by key algorithm.
32 
33 Adopted from Bolt code, see <https://github.com/HSA-Libraries/Bolt>.
34 The original code came with the following copyright notice:
35 
36 \verbatim
37 Copyright 2012 - 2013 Advanced Micro Devices, Inc.
38 
39 Licensed under the Apache License, Version 2.0 (the "License");
40 you may not use this file except in compliance with the License.
41 You may obtain a copy of the License at
42 
43     http://www.apache.org/licenses/LICENSE-2.0
44 
45 Unless required by applicable law or agreed to in writing, software
46 distributed under the License is distributed on an "AS IS" BASIS,
47 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
48 See the License for the specific language governing permissions and
49 limitations under the License.
50 \endverbatim
51 */
52 
53 #include <string>
54 
55 #include <vexcl/vector.hpp>
56 #include <vexcl/scan.hpp>
57 #include <vexcl/detail/fusion.hpp>
58 #include <vexcl/function.hpp>
59 
60 namespace vex {
61 namespace detail {
62 namespace rbk {
63 
64 //---------------------------------------------------------------------------
65 template <typename T, class Comp>
offset_calculation(const backend::command_queue & queue)66 backend::kernel offset_calculation(const backend::command_queue &queue) {
67     static detail::kernel_cache cache;
68 
69     auto kernel = cache.find(queue);
70 
71     if (kernel == cache.end()) {
72         backend::source_generator src(queue);
73 
74         Comp::define(src, "comp");
75 
76         src.begin_kernel("offset_calculation");
77         src.begin_kernel_parameters();
78         src.template parameter< size_t >("n");
79 
80         boost::mpl::for_each<T>(pointer_param<global_ptr, true>(src, "keys"));
81 
82         src.template parameter< global_ptr<int> >("offsets");
83         src.end_kernel_parameters();
84 
85         src.new_line().grid_stride_loop().open("{");
86         src.new_line()
87             << "if (idx > 0)"
88             << " offsets[idx] = !comp(";
89         for(int p = 0; p < boost::mpl::size<T>::value; ++p)
90             src << (p ? ", " : "") << "keys" << p << "[idx - 1]";
91         for(int p = 0; p < boost::mpl::size<T>::value; ++p)
92             src << ", keys" << p << "[idx]";
93         src << ");";
94         src.new_line() << "else offsets[idx] = 0;";
95         src.close("}");
96         src.end_kernel();
97 
98         kernel = cache.insert(queue, backend::kernel(
99                     queue, src.str(), "offset_calculation"));
100     }
101 
102     return kernel->second;
103 }
104 
105 //---------------------------------------------------------------------------
// Builds (and caches, per queue) the first-pass kernel of reduce_by_key:
// an inclusive scan-by-key over each work group of NT items.
//   keys    - per-element segment ids (the scanned offsets array);
//   vals    - values being reduced;
//   output  - receives the running sum, but only at positions that close a
//             key segment inside this group;
//   key_buf - last key of each work group (carry-out for the block scan);
//   val_buf - last running sum of each work group (carry-out).
template <int NT, typename T, class Oper>
backend::kernel block_scan_by_key(const backend::command_queue &queue) {
    static detail::kernel_cache cache;

    auto kernel = cache.find(queue);

    if (kernel == cache.end()) {
        backend::source_generator src(queue);

        Oper::define(src, "oper");

        src.begin_kernel("block_scan_by_key");
        src.begin_kernel_parameters();
        src.template parameter< size_t                >("n");
        src.template parameter< global_ptr<const int> >("keys");
        src.template parameter< global_ptr<const T>   >("vals");
        src.template parameter< global_ptr<T>         >("output");
        src.template parameter< global_ptr<int>       >("key_buf");
        src.template parameter< global_ptr<T>         >("val_buf");
        src.end_kernel_parameters();

        src.new_line() << "size_t l_id  = " << src.local_id(0)   << ";";
        src.new_line() << "size_t g_id  = " << src.global_id(0)  << ";";
        src.new_line() << "size_t block = " << src.group_id(0)   << ";";

        // Shared (local) memory holds one key and one value per work item.
        src.new_line() << "struct Shared";
        src.open("{");
            src.new_line() << "int keys[" << NT << "];";
            src.new_line() << type_name<T>() << " vals[" << NT << "];";
        src.close("};");

        src.smem_static_var("struct Shared", "shared");

        src.new_line() << "int key;";
        src.new_line() << type_name<T>() << " val;";

        src.new_line() << "if (g_id < n)";
        src.open("{");
        src.new_line() << "key = keys[g_id];";
        src.new_line() << "val = vals[g_id];";
        src.new_line() << "shared.keys[l_id] = key;";
        src.new_line() << "shared.vals[l_id] = val;";
        src.close("}");

        // Computes a scan within a workgroup; updates vals in shared memory
        // but not keys. Offsets double each round, and a work item only
        // accumulates its neighbour's partial sum while that neighbour
        // belongs to the same key segment.
        // NOTE(review): out-of-range items (g_id >= n) reach this loop with
        // key/val uninitialized; they must still hit every barrier, and
        // their results are discarded by the early return below.
        src.new_line() << type_name<T>() << " sum = val;";
        src.new_line() << "for(size_t offset = 1; offset < " << NT << "; offset *= 2)";
        src.open("{");
        src.new_line().barrier();
        src.new_line() << "if (l_id >= offset && shared.keys[l_id - offset] == key)";
        src.open("{");
        src.new_line() << "sum = oper(sum, shared.vals[l_id - offset]);";
        src.close("}");
        src.new_line().barrier();
        src.new_line() << "shared.vals[l_id] = sum;";
        src.close("}");
        src.new_line().barrier();

        src.new_line() << "if (g_id >= n) return;";

        // Each work item writes out its calculated scan result, relative to
        // the beginning of each work group — but only if it terminates a key
        // segment (the next element has a different key or does not exist).
        src.new_line() << "int key2 = -1;";
        src.new_line() << "if (g_id < n - 1) key2 = keys[g_id + 1];";
        src.new_line() << "if (key != key2) output[g_id] = sum;";

        // First in-range work item stores the group's carry-out (last key
        // and last partial sum) for block_inclusive_scan_by_key.
        src.new_line() << "if (l_id == 0)";
        src.open("{");
        src.new_line() << "key_buf[block] = shared.keys[" << NT - 1 << "];";
        src.new_line() << "val_buf[block] = shared.vals[" << NT - 1 << "];";
        src.close("}");

        src.end_kernel();

        kernel = cache.insert(queue, backend::kernel(
                    queue, src.str(), "block_scan_by_key"));
    }

    return kernel->second;
}
186 
187 //---------------------------------------------------------------------------
// Builds (and caches, per queue) the second-pass kernel of reduce_by_key:
// a single work group of NT items scans the per-block carry-outs
// (key_sum/pre_sum) produced by block_scan_by_key, writing an inclusive
// scan-by-key into post_sum. Each work item first serially accumulates
// work_per_thread consecutive entries; the per-item partials are then
// combined with a scan in shared memory and folded back into post_sum.
template <int NT, typename T, class Oper>
backend::kernel block_inclusive_scan_by_key(const backend::command_queue &queue)
{
    static detail::kernel_cache cache;

    auto kernel = cache.find(queue);

    if (kernel == cache.end()) {
        backend::source_generator src(queue);

        Oper::define(src, "oper");

        src.begin_kernel("block_inclusive_scan_by_key");
        src.begin_kernel_parameters();
        src.template parameter< size_t                >("n");
        src.template parameter< global_ptr<const int> >("key_sum");
        src.template parameter< global_ptr<const T>   >("pre_sum");
        src.template parameter< global_ptr<T>         >("post_sum");
        src.template parameter< cl_uint               >("work_per_thread");
        src.end_kernel_parameters();

        src.new_line() << "size_t l_id   = " << src.local_id(0)   << ";";
        src.new_line() << "size_t g_id   = " << src.global_id(0)  << ";";
        // Start of this work item's serial chunk.
        src.new_line() << "size_t map_id = g_id * work_per_thread;";

        src.new_line() << "struct Shared";
        src.open("{");
            src.new_line() << "int keys[" << NT << "];";
            src.new_line() << type_name<T>() << " vals[" << NT << "];";
        src.close("};");

        src.smem_static_var("struct Shared", "shared");

        // NOTE(review): key/work_sum stay uninitialized when map_id >= n;
        // such items only participate in barriers, and in-range items never
        // read shared entries written by them (they only look at lower l_id).
        src.new_line() << "uint offset;";
        src.new_line() << "int  key;";
        src.new_line() << type_name<T>() << " work_sum;";

        src.new_line() << "if (map_id < n)";
        src.open("{");
        src.new_line() << "int prev_key;";

        // accumulate zeroth value manually
        src.new_line() << "offset   = 0;";
        src.new_line() << "key      = key_sum[map_id];";
        src.new_line() << "work_sum = pre_sum[map_id];";

        src.new_line() << "post_sum[map_id] = work_sum;";

        //  Serial accumulation over the rest of the chunk; the running sum
        //  restarts whenever the key changes.
        src.new_line() << "for( offset = offset + 1; offset < work_per_thread; ++offset )";
        src.open("{");
        src.new_line() << "prev_key = key;";
        src.new_line() << "key      = key_sum[ map_id + offset ];";

        src.new_line() << "if ( map_id + offset < n )";
        src.open("{");
        src.new_line() << type_name<T>() << " y = pre_sum[ map_id + offset ];";

        src.new_line() << "if ( key == prev_key ) work_sum = oper( work_sum, y );";
        src.new_line() << "else work_sum = y;";

        src.new_line() << "post_sum[ map_id + offset ] = work_sum;";
        src.close("}");
        src.close("}");
        src.close("}");
        src.new_line().barrier();

        // load shared memory with per-item register sums
        src.new_line() << "shared.vals[ l_id ] = work_sum;";
        src.new_line() << "shared.keys[ l_id ] = key;";

        // scan the per-item partials in shared memory (offsets double each
        // round; accumulation stops at key boundaries)
        src.new_line() << type_name<T>() << " scan_sum = work_sum;";

        src.new_line() << "for( offset = 1; offset < " << NT << "; offset *= 2 )";
        src.open("{");
        src.new_line().barrier();

        src.new_line() << "if (map_id < n)";
        src.open("{");
        src.new_line() << "if (l_id >= offset)";
        src.open("{");
        src.new_line() << "int key1 = shared.keys[ l_id ];";
        src.new_line() << "int key2 = shared.keys[ l_id - offset ];";

        src.new_line() << "if ( key1 == key2 ) scan_sum = oper( scan_sum, shared.vals[ l_id - offset ] );";
        src.new_line() << "else scan_sum = shared.vals[ l_id ];";
        src.close("}");

        src.close("}");
        src.new_line().barrier();

        src.new_line() << "shared.vals[ l_id ] = scan_sum;";
        src.close("}");

        src.new_line().barrier();

        // write final scan result: fold the preceding item's scanned total
        // into this item's serial pre-scan wherever the key matches
        src.new_line() << "for( offset = 0; offset < work_per_thread; ++offset )";
        src.open("{");
        src.new_line().barrier(true);

        src.new_line() << "if (map_id < n && l_id > 0)";
        src.open("{");
        src.new_line() << type_name<T>() << " y = post_sum[ map_id + offset ];";
        src.new_line() << "int key1 = key_sum    [ map_id + offset ];";
        src.new_line() << "int key2 = shared.keys[ l_id - 1 ];";

        src.new_line() << "if ( key1 == key2 ) y = oper( y, shared.vals[l_id - 1] );";

        src.new_line() << "post_sum[ map_id + offset ] = y;";
        src.close("}");
        src.close("}");

        src.end_kernel();

        kernel = cache.insert(queue, backend::kernel(
                    queue, src.str(), "block_inclusive_scan_by_key"));
    }

    return kernel->second;
}
310 
311 //---------------------------------------------------------------------------
312 template <typename T, class Oper>
block_sum_by_key(const backend::command_queue & queue)313 backend::kernel block_sum_by_key(const backend::command_queue &queue) {
314     static detail::kernel_cache cache;
315 
316     auto kernel = cache.find(queue);
317 
318     if (kernel == cache.end()) {
319         backend::source_generator src(queue);
320 
321         Oper::define(src, "oper");
322 
323         src.begin_kernel("block_sum_by_key");
324         src.begin_kernel_parameters();
325         src.template parameter< size_t                >("n");
326         src.template parameter< global_ptr<const int> >("key_sum");
327         src.template parameter< global_ptr<const T>   >("post_sum");
328         src.template parameter< global_ptr<const int> >("keys");
329         src.template parameter< global_ptr<T>         >("output");
330         src.end_kernel_parameters();
331 
332         src.new_line() << "size_t g_id  = " << src.global_id(0)  << ";";
333         src.new_line() << "size_t block = " << src.group_id(0)   << ";";
334 
335         src.new_line() << "if (g_id >= n) return;";
336 
337         // accumulate prefix
338         src.new_line() << "int key2 = keys[ g_id ];";
339         src.new_line() << "int key1 = (block > 0    ) ? key_sum[ block - 1 ] : key2 - 1;";
340         src.new_line() << "int key3 = (g_id  < n - 1) ? keys   [ g_id  + 1 ] : key2 - 1;";
341 
342         src.new_line() << "if (block > 0 && key1 == key2 && key2 != key3)";
343         src.open("{");
344         src.new_line() << type_name<T>() << " scan_result    = output  [ g_id      ];";
345         src.new_line() << type_name<T>() << " post_block_sum = post_sum[ block - 1 ];";
346         src.new_line() << "output[ g_id ] = oper( scan_result, post_block_sum );";
347         src.close("}");
348 
349         src.end_kernel();
350 
351         kernel = cache.insert(queue, backend::kernel(
352                     queue, src.str(), "block_sum_by_key"));
353     }
354 
355     return kernel->second;
356 }
357 
358 //---------------------------------------------------------------------------
359 template <typename K, typename V>
key_value_mapping(const backend::command_queue & queue)360 backend::kernel key_value_mapping(const backend::command_queue &queue) {
361     static detail::kernel_cache cache;
362 
363     auto kernel = cache.find(queue);
364 
365     if (kernel == cache.end()) {
366         backend::source_generator src(queue);
367 
368         src.begin_kernel("key_value_mapping");
369         src.begin_kernel_parameters();
370         src.template parameter< size_t >("n");
371 
372         boost::mpl::for_each<K>(pointer_param<global_ptr, true>(src, "ikeys"));
373         boost::mpl::for_each<K>(pointer_param<global_ptr      >(src, "okeys"));
374 
375         src.template parameter< global_ptr<V>       >("ovals");
376         src.template parameter< global_ptr<int>     >("offset");
377         src.template parameter< global_ptr<const V> >("ivals");
378         src.end_kernel_parameters();
379 
380         src.new_line().grid_stride_loop().open("{");
381 
382         src.new_line() << "int num_sections = offset[n - 1] + 1;";
383 
384         src.new_line() << "int off = offset[idx];";
385         src.new_line() << "if (idx < (n - 1) && off != offset[idx + 1])";
386         src.open("{");
387         for(int p = 0; p < boost::mpl::size<K>::value; ++p)
388             src.new_line() << "okeys" << p << "[off] = ikeys" << p << "[idx];";
389         src.new_line() << "ovals[off] = ivals[idx];";
390         src.close("}");
391 
392         src.new_line() << "if (idx == (n - 1))";
393         src.open("{");
394         for(int p = 0; p < boost::mpl::size<K>::value; ++p)
395             src.new_line() << "okeys" << p << "[num_sections - 1] = ikeys" << p << "[idx];";
396         src.new_line() << "ovals[num_sections - 1] = ivals[idx];";
397         src.close("}");
398 
399         src.close("}");
400 
401         src.end_kernel();
402 
403         kernel = cache.insert(queue, backend::kernel(
404                     queue, src.str(), "key_value_mapping"));
405     }
406 
407     return kernel->second;
408 }
409 
410 struct do_vex_resize {
411     const std::vector<backend::command_queue> &q;
412     size_t n;
413 
do_vex_resizevex::detail::rbk::do_vex_resize414     do_vex_resize(const std::vector<backend::command_queue> &q, size_t n)
415         : q(q), n(n) {}
416 
417     template <class V>
operator ()vex::detail::rbk::do_vex_resize418     void operator()(V &v) const {
419         v.resize(q, n);
420     }
421 };
422 
423 struct do_push_arg {
424     backend::kernel &k;
425 
do_push_argvex::detail::rbk::do_push_arg426     do_push_arg(backend::kernel &k) : k(k) {}
427 
428     template <class T>
operator ()vex::detail::rbk::do_push_arg429     void operator()(const T &t) const {
430         k.push_arg( t(0) );
431     }
432 };
433 
434 template <typename IKTuple, typename OKTuple, typename V, class Comp, class Oper>
reduce_by_key_sink(IKTuple && ikeys,vector<V> const & ivals,OKTuple && okeys,vector<V> & ovals,Comp,Oper)435 int reduce_by_key_sink(
436         IKTuple &&ikeys, vector<V> const &ivals,
437         OKTuple &&okeys, vector<V>       &ovals,
438         Comp, Oper
439         )
440 {
441     namespace fusion = boost::fusion;
442     typedef typename extract_value_types<IKTuple>::type K;
443 
444     static_assert(
445             std::is_same<K, typename extract_value_types<OKTuple>::type>::value,
446             "Incompatible input and output key types");
447 
448     precondition(
449             fusion::at_c<0>(ikeys).nparts() == 1 && ivals.nparts() == 1,
450             "reduce_by_key is only supported for single device contexts"
451             );
452 
453     precondition(fusion::at_c<0>(ikeys).size() == ivals.size(),
454             "keys and values should have same size"
455             );
456 
457     const auto &queue = fusion::at_c<0>(ikeys).queue_list();
458     backend::select_context(queue[0]);
459 
460     const int NT_cpu = 1;
461     const int NT_gpu = 256;
462     const int NT = is_cpu(queue[0]) ? NT_cpu : NT_gpu;
463 
464     size_t count         = fusion::at_c<0>(ikeys).size();
465     size_t num_blocks    = (count + NT - 1) / NT;
466     size_t scan_buf_size = alignup(num_blocks, NT);
467 
468     backend::device_vector<int> key_sum   (queue[0], scan_buf_size);
469     backend::device_vector<V>   pre_sum   (queue[0], scan_buf_size);
470     backend::device_vector<V>   post_sum  (queue[0], scan_buf_size);
471     backend::device_vector<V>   offset_val(queue[0], count);
472     backend::device_vector<int> offset    (queue[0], count);
473 
474     /***** Kernel 0 *****/
475     auto krn0 = offset_calculation<K, Comp>(queue[0]);
476 
477     krn0.push_arg(count);
478     boost::fusion::for_each(ikeys, do_push_arg(krn0));
479     krn0.push_arg(offset);
480 
481     krn0(queue[0]);
482 
483     VEX_FUNCTION(int, plus, (int, x)(int, y), return x + y;);
484     scan(queue[0], offset, offset, 0, false, plus);
485 
486     /***** Kernel 1 *****/
487     auto krn1 = is_cpu(queue[0]) ?
488         block_scan_by_key<NT_cpu, V, Oper>(queue[0]) :
489         block_scan_by_key<NT_gpu, V, Oper>(queue[0]);
490 
491     krn1.push_arg(count);
492     krn1.push_arg(offset);
493     krn1.push_arg(ivals(0));
494     krn1.push_arg(offset_val);
495     krn1.push_arg(key_sum);
496     krn1.push_arg(pre_sum);
497 
498     krn1.config(num_blocks, NT);
499     krn1(queue[0]);
500 
501     /***** Kernel 2 *****/
502     uint work_per_thread = std::max<uint>(1U, static_cast<uint>(scan_buf_size / NT));
503 
504     auto krn2 = is_cpu(queue[0]) ?
505         block_inclusive_scan_by_key<NT_cpu, V, Oper>(queue[0]) :
506         block_inclusive_scan_by_key<NT_gpu, V, Oper>(queue[0]);
507 
508     krn2.push_arg(num_blocks);
509     krn2.push_arg(key_sum);
510     krn2.push_arg(pre_sum);
511     krn2.push_arg(post_sum);
512     krn2.push_arg(work_per_thread);
513 
514     krn2.config(1, NT);
515     krn2(queue[0]);
516 
517     /***** Kernel 3 *****/
518     auto krn3 = block_sum_by_key<V, Oper>(queue[0]);
519 
520     krn3.push_arg(count);
521     krn3.push_arg(key_sum);
522     krn3.push_arg(post_sum);
523     krn3.push_arg(offset);
524     krn3.push_arg(offset_val);
525 
526     krn3.config(num_blocks, NT);
527     krn3(queue[0]);
528 
529     /***** resize okeys and ovals *****/
530     int out_elements = 0;
531     offset.read(queue[0], count - 1, 1, &out_elements, true);
532     ++out_elements;
533 
534     boost::fusion::for_each(okeys, do_vex_resize(queue, out_elements));
535     ovals.resize(ivals.queue_list(), out_elements);
536 
537     /***** Kernel 4 *****/
538     auto krn4 = key_value_mapping<K, V>(queue[0]);
539 
540     krn4.push_arg(count);
541     boost::fusion::for_each(ikeys, do_push_arg(krn4));
542     boost::fusion::for_each(okeys, do_push_arg(krn4));
543     krn4.push_arg(ovals(0));
544     krn4.push_arg(offset);
545     krn4.push_arg(offset_val);
546 
547     krn4(queue[0]);
548 
549     return out_elements;
550 }
551 
552 } // namespace rbk
553 } // namespace detail
554 
/// Reduce by key algorithm.
/**
 * Folds runs of consecutive values whose keys compare equal under \p comp,
 * combining them with \p oper. Unique keys are written to \p okeys and the
 * reduced values to \p ovals.
 *
 * Keys may be single vex::vectors or boost::fusion sequences of vectors
 * (multi-component keys); forward_as_sequence normalizes both forms.
 *
 * \returns the number of unique key segments written to the outputs.
 */
template <typename IKeys, typename OKeys, typename V, class Comp, class Oper>
int reduce_by_key(
        IKeys &&ikeys, vector<V> const &ivals,
        OKeys &&okeys, vector<V>       &ovals,
        Comp comp, Oper oper
        )
{
    return detail::rbk::reduce_by_key_sink(
            detail::forward_as_sequence(ikeys), ivals,
            detail::forward_as_sequence(okeys), ovals,
            comp, oper);
}
568 
/// Reduce by key algorithm.
/**
 * Convenience overload: compares keys with operator== and combines values
 * with operator+.
 *
 * \returns the number of unique key segments written to the outputs.
 */
template <typename K, typename V>
int reduce_by_key(
        vector<K> const &ikeys,
        vector<V> const &ivals,
        vector<K>       &okeys,
        vector<V>       &ovals
        )
{
    // Default comparison/reduction functors, generated as device functions.
    VEX_FUNCTION(bool, equal, (K, x)(K, y), return x == y;);
    VEX_FUNCTION(V, plus, (V, x)(V, y), return x + y;);
    return reduce_by_key(ikeys, ivals, okeys, ovals, equal, plus);
}
582 
583 }
584 
585 #endif
586