1 #ifndef VEXCL_REDUCE_BY_KEY_HPP
2 #define VEXCL_REDUCE_BY_KEY_HPP
3
4 /*
5 The MIT License
6
7 Copyright (c) 2012-2018 Denis Demidov <dennis.demidov@gmail.com>
8
9 Permission is hereby granted, free of charge, to any person obtaining a copy
10 of this software and associated documentation files (the "Software"), to deal
11 in the Software without restriction, including without limitation the rights
12 to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13 copies of the Software, and to permit persons to whom the Software is
14 furnished to do so, subject to the following conditions:
15
16 The above copyright notice and this permission notice shall be included in
17 all copies or substantial portions of the Software.
18
19 THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20 IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21 FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22 AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23 LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24 OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
25 THE SOFTWARE.
26 */
27
28 /**
29 * \file vexcl/reduce_by_key.hpp
30 * \author Denis Demidov <dennis.demidov@gmail.com>
 * \brief Reduce by key algorithm.
32
Adapted from Bolt code, see <https://github.com/HSA-Libraries/Bolt>.
34 The original code came with the following copyright notice:
35
36 \verbatim
37 Copyright 2012 - 2013 Advanced Micro Devices, Inc.
38
39 Licensed under the Apache License, Version 2.0 (the "License");
40 you may not use this file except in compliance with the License.
41 You may obtain a copy of the License at
42
43 http://www.apache.org/licenses/LICENSE-2.0
44
45 Unless required by applicable law or agreed to in writing, software
46 distributed under the License is distributed on an "AS IS" BASIS,
47 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
48 See the License for the specific language governing permissions and
49 limitations under the License.
50 \endverbatim
51 */
52
53 #include <string>
54
55 #include <vexcl/vector.hpp>
56 #include <vexcl/scan.hpp>
57 #include <vexcl/detail/fusion.hpp>
58 #include <vexcl/function.hpp>
59
60 namespace vex {
61 namespace detail {
62 namespace rbk {
63
64 //---------------------------------------------------------------------------
65 template <typename T, class Comp>
offset_calculation(const backend::command_queue & queue)66 backend::kernel offset_calculation(const backend::command_queue &queue) {
67 static detail::kernel_cache cache;
68
69 auto kernel = cache.find(queue);
70
71 if (kernel == cache.end()) {
72 backend::source_generator src(queue);
73
74 Comp::define(src, "comp");
75
76 src.begin_kernel("offset_calculation");
77 src.begin_kernel_parameters();
78 src.template parameter< size_t >("n");
79
80 boost::mpl::for_each<T>(pointer_param<global_ptr, true>(src, "keys"));
81
82 src.template parameter< global_ptr<int> >("offsets");
83 src.end_kernel_parameters();
84
85 src.new_line().grid_stride_loop().open("{");
86 src.new_line()
87 << "if (idx > 0)"
88 << " offsets[idx] = !comp(";
89 for(int p = 0; p < boost::mpl::size<T>::value; ++p)
90 src << (p ? ", " : "") << "keys" << p << "[idx - 1]";
91 for(int p = 0; p < boost::mpl::size<T>::value; ++p)
92 src << ", keys" << p << "[idx]";
93 src << ");";
94 src.new_line() << "else offsets[idx] = 0;";
95 src.close("}");
96 src.end_kernel();
97
98 kernel = cache.insert(queue, backend::kernel(
99 queue, src.str(), "offset_calculation"));
100 }
101
102 return kernel->second;
103 }
104
//---------------------------------------------------------------------------
// Generates (and caches per-queue) a kernel that computes a segmented
// (key-bounded) inclusive scan of "vals" inside each work group of size NT.
// Elements that end a key segment write their scanned value to "output";
// each block's last key and running sum are saved to key_buf/val_buf so a
// later pass can propagate carries across block boundaries.
template <int NT, typename T, class Oper>
backend::kernel block_scan_by_key(const backend::command_queue &queue) {
    static detail::kernel_cache cache;

    auto kernel = cache.find(queue);

    if (kernel == cache.end()) {
        backend::source_generator src(queue);

        // Binary reduction operator, emitted into the kernel as "oper".
        Oper::define(src, "oper");

        src.begin_kernel("block_scan_by_key");
        src.begin_kernel_parameters();
        src.template parameter< size_t >("n");
        src.template parameter< global_ptr<const int> >("keys");
        src.template parameter< global_ptr<const T> >("vals");
        src.template parameter< global_ptr<T> >("output");
        src.template parameter< global_ptr<int> >("key_buf");
        src.template parameter< global_ptr<T> >("val_buf");
        src.end_kernel_parameters();

        src.new_line() << "size_t l_id = " << src.local_id(0) << ";";
        src.new_line() << "size_t g_id = " << src.global_id(0) << ";";
        src.new_line() << "size_t block = " << src.group_id(0) << ";";

        // Local staging area; sized by the compile-time work group size NT.
        src.new_line() << "struct Shared";
        src.open("{");
        src.new_line() << "int keys[" << NT << "];";
        src.new_line() << type_name<T>() << " vals[" << NT << "];";
        src.close("};");

        src.smem_static_var("struct Shared", "shared");

        src.new_line() << "int key;";
        src.new_line() << type_name<T>() << " val;";

        // NOTE(review): for lanes with g_id >= n, "key"/"val" and the
        // corresponding shared slots stay uninitialized, yet they are still
        // read in the scan loop below and by the tail write of a partial
        // last block. Presumably those results are never consumed
        // downstream -- confirm against the original Bolt code.
        src.new_line() << "if (g_id < n)";
        src.open("{");
        src.new_line() << "key = keys[g_id];";
        src.new_line() << "val = vals[g_id];";
        src.new_line() << "shared.keys[l_id] = key;";
        src.new_line() << "shared.vals[l_id] = val;";
        src.close("}");

        // Computes a scan within a workgroup updates vals in lds but not keys
        // (doubling-offset loop; accumulation stops at key boundaries).
        src.new_line() << type_name<T>() << " sum = val;";
        src.new_line() << "for(size_t offset = 1; offset < " << NT << "; offset *= 2)";
        src.open("{");
        src.new_line().barrier();
        src.new_line() << "if (l_id >= offset && shared.keys[l_id - offset] == key)";
        src.open("{");
        src.new_line() << "sum = oper(sum, shared.vals[l_id - offset]);";
        src.close("}");
        src.new_line().barrier();
        src.new_line() << "shared.vals[l_id] = sum;";
        src.close("}");
        src.new_line().barrier();

        src.new_line() << "if (g_id >= n) return;";

        // Each work item writes out its calculated scan result, relative to the
        // beginning of each work group
        src.new_line() << "int key2 = -1;";
        src.new_line() << "if (g_id < n - 1) key2 = keys[g_id + 1];";
        src.new_line() << "if (key != key2) output[g_id] = sum;";

        // First lane records the block's last key and partial sum for the
        // cross-block carry pass (block_inclusive_scan_by_key).
        src.new_line() << "if (l_id == 0)";
        src.open("{");
        src.new_line() << "key_buf[block] = shared.keys[" << NT - 1 << "];";
        src.new_line() << "val_buf[block] = shared.vals[" << NT - 1 << "];";
        src.close("}");

        src.end_kernel();

        kernel = cache.insert(queue, backend::kernel(
                    queue, src.str(), "block_scan_by_key"));
    }

    return kernel->second;
}
186
//---------------------------------------------------------------------------
// Generates (and caches per-queue) a kernel that turns the per-block partial
// sums (key_sum/pre_sum from block_scan_by_key) into inclusive per-block
// prefix sums in post_sum. Each work item serially accumulates
// work_per_thread consecutive entries, the per-item totals are then scanned
// in local memory, and finally the local-scan carries are folded back into
// the serial results. The caller launches this as a single work group
// (see reduce_by_key_sink).
template <int NT, typename T, class Oper>
backend::kernel block_inclusive_scan_by_key(const backend::command_queue &queue)
{
    static detail::kernel_cache cache;

    auto kernel = cache.find(queue);

    if (kernel == cache.end()) {
        backend::source_generator src(queue);

        // Binary reduction operator, emitted into the kernel as "oper".
        Oper::define(src, "oper");

        src.begin_kernel("block_inclusive_scan_by_key");
        src.begin_kernel_parameters();
        src.template parameter< size_t >("n");
        src.template parameter< global_ptr<const int> >("key_sum");
        src.template parameter< global_ptr<const T> >("pre_sum");
        src.template parameter< global_ptr<T> >("post_sum");
        src.template parameter< cl_uint >("work_per_thread");
        src.end_kernel_parameters();

        src.new_line() << "size_t l_id = " << src.local_id(0) << ";";
        src.new_line() << "size_t g_id = " << src.global_id(0) << ";";
        // First element this work item is responsible for.
        src.new_line() << "size_t map_id = g_id * work_per_thread;";

        src.new_line() << "struct Shared";
        src.open("{");
        src.new_line() << "int keys[" << NT << "];";
        src.new_line() << type_name<T>() << " vals[" << NT << "];";
        src.close("};");

        src.smem_static_var("struct Shared", "shared");

        src.new_line() << "uint offset;";
        src.new_line() << "int key;";
        src.new_line() << type_name<T>() << " work_sum;";

        // NOTE(review): when map_id >= n, "key" and "work_sum" remain
        // uninitialized but are still stored into shared memory below;
        // presumably those lanes' results are never consumed -- confirm
        // against the original Bolt implementation.
        src.new_line() << "if (map_id < n)";
        src.open("{");
        src.new_line() << "int prev_key;";

        // accumulate zeroth value manually
        src.new_line() << "offset = 0;";
        src.new_line() << "key = key_sum[map_id];";
        src.new_line() << "work_sum = pre_sum[map_id];";

        src.new_line() << "post_sum[map_id] = work_sum;";

        // Serial accumulation
        src.new_line() << "for( offset = offset + 1; offset < work_per_thread; ++offset )";
        src.open("{");
        src.new_line() << "prev_key = key;";
        src.new_line() << "key = key_sum[ map_id + offset ];";

        src.new_line() << "if ( map_id + offset < n )";
        src.open("{");
        src.new_line() << type_name<T>() << " y = pre_sum[ map_id + offset ];";

        // Restart the running sum whenever the key changes.
        src.new_line() << "if ( key == prev_key ) work_sum = oper( work_sum, y );";
        src.new_line() << "else work_sum = y;";

        src.new_line() << "post_sum[ map_id + offset ] = work_sum;";
        src.close("}");
        src.close("}");
        src.close("}");
        src.new_line().barrier();

        // load LDS with register sums
        src.new_line() << "shared.vals[ l_id ] = work_sum;";
        src.new_line() << "shared.keys[ l_id ] = key;";

        // scan in lds
        src.new_line() << type_name<T>() << " scan_sum = work_sum;";

        src.new_line() << "for( offset = 1; offset < " << NT << "; offset *= 2 )";
        src.open("{");
        src.new_line().barrier();

        src.new_line() << "if (map_id < n)";
        src.open("{");
        src.new_line() << "if (l_id >= offset)";
        src.open("{");
        src.new_line() << "int key1 = shared.keys[ l_id ];";
        src.new_line() << "int key2 = shared.keys[ l_id - offset ];";

        // Combine with the lower lane only while inside the same key segment.
        src.new_line() << "if ( key1 == key2 ) scan_sum = oper( scan_sum, shared.vals[ l_id - offset ] );";
        src.new_line() << "else scan_sum = shared.vals[ l_id ];";
        src.close("}");

        src.close("}");
        src.new_line().barrier();

        src.new_line() << "shared.vals[ l_id ] = scan_sum;";
        src.close("}");

        src.new_line().barrier();

        // write final scan from pre-scan and lds scan
        src.new_line() << "for( offset = 0; offset < work_per_thread; ++offset )";
        src.open("{");
        // barrier(true): the fence-argument variant; presumably a
        // global-memory fence, since post_sum is read and rewritten here --
        // confirm against the backend's source_generator.
        src.new_line().barrier(true);

        src.new_line() << "if (map_id < n && l_id > 0)";
        src.open("{");
        src.new_line() << type_name<T>() << " y = post_sum[ map_id + offset ];";
        src.new_line() << "int key1 = key_sum [ map_id + offset ];";
        src.new_line() << "int key2 = shared.keys[ l_id - 1 ];";

        // Fold in the previous lane's carry while still in the same segment.
        src.new_line() << "if ( key1 == key2 ) y = oper( y, shared.vals[l_id - 1] );";

        src.new_line() << "post_sum[ map_id + offset ] = y;";
        src.close("}");
        src.close("}");

        src.end_kernel();

        kernel = cache.insert(queue, backend::kernel(
                    queue, src.str(), "block_inclusive_scan_by_key"));
    }

    return kernel->second;
}
310
311 //---------------------------------------------------------------------------
312 template <typename T, class Oper>
block_sum_by_key(const backend::command_queue & queue)313 backend::kernel block_sum_by_key(const backend::command_queue &queue) {
314 static detail::kernel_cache cache;
315
316 auto kernel = cache.find(queue);
317
318 if (kernel == cache.end()) {
319 backend::source_generator src(queue);
320
321 Oper::define(src, "oper");
322
323 src.begin_kernel("block_sum_by_key");
324 src.begin_kernel_parameters();
325 src.template parameter< size_t >("n");
326 src.template parameter< global_ptr<const int> >("key_sum");
327 src.template parameter< global_ptr<const T> >("post_sum");
328 src.template parameter< global_ptr<const int> >("keys");
329 src.template parameter< global_ptr<T> >("output");
330 src.end_kernel_parameters();
331
332 src.new_line() << "size_t g_id = " << src.global_id(0) << ";";
333 src.new_line() << "size_t block = " << src.group_id(0) << ";";
334
335 src.new_line() << "if (g_id >= n) return;";
336
337 // accumulate prefix
338 src.new_line() << "int key2 = keys[ g_id ];";
339 src.new_line() << "int key1 = (block > 0 ) ? key_sum[ block - 1 ] : key2 - 1;";
340 src.new_line() << "int key3 = (g_id < n - 1) ? keys [ g_id + 1 ] : key2 - 1;";
341
342 src.new_line() << "if (block > 0 && key1 == key2 && key2 != key3)";
343 src.open("{");
344 src.new_line() << type_name<T>() << " scan_result = output [ g_id ];";
345 src.new_line() << type_name<T>() << " post_block_sum = post_sum[ block - 1 ];";
346 src.new_line() << "output[ g_id ] = oper( scan_result, post_block_sum );";
347 src.close("}");
348
349 src.end_kernel();
350
351 kernel = cache.insert(queue, backend::kernel(
352 queue, src.str(), "block_sum_by_key"));
353 }
354
355 return kernel->second;
356 }
357
358 //---------------------------------------------------------------------------
359 template <typename K, typename V>
key_value_mapping(const backend::command_queue & queue)360 backend::kernel key_value_mapping(const backend::command_queue &queue) {
361 static detail::kernel_cache cache;
362
363 auto kernel = cache.find(queue);
364
365 if (kernel == cache.end()) {
366 backend::source_generator src(queue);
367
368 src.begin_kernel("key_value_mapping");
369 src.begin_kernel_parameters();
370 src.template parameter< size_t >("n");
371
372 boost::mpl::for_each<K>(pointer_param<global_ptr, true>(src, "ikeys"));
373 boost::mpl::for_each<K>(pointer_param<global_ptr >(src, "okeys"));
374
375 src.template parameter< global_ptr<V> >("ovals");
376 src.template parameter< global_ptr<int> >("offset");
377 src.template parameter< global_ptr<const V> >("ivals");
378 src.end_kernel_parameters();
379
380 src.new_line().grid_stride_loop().open("{");
381
382 src.new_line() << "int num_sections = offset[n - 1] + 1;";
383
384 src.new_line() << "int off = offset[idx];";
385 src.new_line() << "if (idx < (n - 1) && off != offset[idx + 1])";
386 src.open("{");
387 for(int p = 0; p < boost::mpl::size<K>::value; ++p)
388 src.new_line() << "okeys" << p << "[off] = ikeys" << p << "[idx];";
389 src.new_line() << "ovals[off] = ivals[idx];";
390 src.close("}");
391
392 src.new_line() << "if (idx == (n - 1))";
393 src.open("{");
394 for(int p = 0; p < boost::mpl::size<K>::value; ++p)
395 src.new_line() << "okeys" << p << "[num_sections - 1] = ikeys" << p << "[idx];";
396 src.new_line() << "ovals[num_sections - 1] = ivals[idx];";
397 src.close("}");
398
399 src.close("}");
400
401 src.end_kernel();
402
403 kernel = cache.insert(queue, backend::kernel(
404 queue, src.str(), "key_value_mapping"));
405 }
406
407 return kernel->second;
408 }
409
410 struct do_vex_resize {
411 const std::vector<backend::command_queue> &q;
412 size_t n;
413
do_vex_resizevex::detail::rbk::do_vex_resize414 do_vex_resize(const std::vector<backend::command_queue> &q, size_t n)
415 : q(q), n(n) {}
416
417 template <class V>
operator ()vex::detail::rbk::do_vex_resize418 void operator()(V &v) const {
419 v.resize(q, n);
420 }
421 };
422
423 struct do_push_arg {
424 backend::kernel &k;
425
do_push_argvex::detail::rbk::do_push_arg426 do_push_arg(backend::kernel &k) : k(k) {}
427
428 template <class T>
operator ()vex::detail::rbk::do_push_arg429 void operator()(const T &t) const {
430 k.push_arg( t(0) );
431 }
432 };
433
// Single-device implementation behind vex::reduce_by_key.
//
// Groups runs of consecutive elements of ivals whose keys compare equal
// under Comp and reduces each run with Oper. okeys/ovals are resized to the
// number of runs; returns that number.
//
// Pipeline:
//  0. offset_calculation          - per-element head flags, then an in-place
//                                   plus-scan turns them into segment indices;
//  1. block_scan_by_key           - segmented scan inside each work group;
//  2. block_inclusive_scan_by_key - single work group scans per-block sums;
//  3. block_sum_by_key            - folds block carries into element results;
//  4. key_value_mapping           - compacts segment tails into the outputs.
template <typename IKTuple, typename OKTuple, typename V, class Comp, class Oper>
int reduce_by_key_sink(
        IKTuple &&ikeys, vector<V> const &ivals,
        OKTuple &&okeys, vector<V> &ovals,
        Comp, Oper
        )
{
    namespace fusion = boost::fusion;
    // List of scalar types making up the (possibly compound) key.
    typedef typename extract_value_types<IKTuple>::type K;

    static_assert(
            std::is_same<K, typename extract_value_types<OKTuple>::type>::value,
            "Incompatible input and output key types");

    precondition(
            fusion::at_c<0>(ikeys).nparts() == 1 && ivals.nparts() == 1,
            "reduce_by_key is only supported for single device contexts"
            );

    precondition(fusion::at_c<0>(ikeys).size() == ivals.size(),
            "keys and values should have same size"
            );

    const auto &queue = fusion::at_c<0>(ikeys).queue_list();
    backend::select_context(queue[0]);

    // Work group size: trivial (1) on CPU devices, 256 on others.
    const int NT_cpu = 1;
    const int NT_gpu = 256;
    const int NT = is_cpu(queue[0]) ? NT_cpu : NT_gpu;

    size_t count = fusion::at_c<0>(ikeys).size();
    size_t num_blocks = (count + NT - 1) / NT;
    size_t scan_buf_size = alignup(num_blocks, NT);

    // Scratch buffers: key_sum/pre_sum/post_sum hold one entry per block
    // (padded up to a multiple of NT); offset_val/offset are per-element.
    backend::device_vector<int> key_sum (queue[0], scan_buf_size);
    backend::device_vector<V> pre_sum (queue[0], scan_buf_size);
    backend::device_vector<V> post_sum (queue[0], scan_buf_size);
    backend::device_vector<V> offset_val(queue[0], count);
    backend::device_vector<int> offset (queue[0], count);

    /***** Kernel 0 *****/
    // Head flags: offset[i] = 1 where the key changes, 0 otherwise.
    auto krn0 = offset_calculation<K, Comp>(queue[0]);

    krn0.push_arg(count);
    boost::fusion::for_each(ikeys, do_push_arg(krn0));
    krn0.push_arg(offset);

    krn0(queue[0]);

    // In-place plus-scan converts the head flags into a segment index per
    // element (the `false` argument presumably selects the non-exclusive
    // variant -- see vexcl/scan.hpp).
    VEX_FUNCTION(int, plus, (int, x)(int, y), return x + y;);
    scan(queue[0], offset, offset, 0, false, plus);

    /***** Kernel 1 *****/
    // Segmented scan of the values within each work group; per-block tails
    // go to key_sum/pre_sum.
    auto krn1 = is_cpu(queue[0]) ?
        block_scan_by_key<NT_cpu, V, Oper>(queue[0]) :
        block_scan_by_key<NT_gpu, V, Oper>(queue[0]);

    krn1.push_arg(count);
    krn1.push_arg(offset);
    krn1.push_arg(ivals(0));
    krn1.push_arg(offset_val);
    krn1.push_arg(key_sum);
    krn1.push_arg(pre_sum);

    krn1.config(num_blocks, NT);
    krn1(queue[0]);

    /***** Kernel 2 *****/
    // Scan the per-block sums in a single work group; each thread handles
    // work_per_thread consecutive blocks serially.
    uint work_per_thread = std::max<uint>(1U, static_cast<uint>(scan_buf_size / NT));

    auto krn2 = is_cpu(queue[0]) ?
        block_inclusive_scan_by_key<NT_cpu, V, Oper>(queue[0]) :
        block_inclusive_scan_by_key<NT_gpu, V, Oper>(queue[0]);

    krn2.push_arg(num_blocks);
    krn2.push_arg(key_sum);
    krn2.push_arg(pre_sum);
    krn2.push_arg(post_sum);
    krn2.push_arg(work_per_thread);

    krn2.config(1, NT);
    krn2(queue[0]);

    /***** Kernel 3 *****/
    // Fold the scanned block carries into elements whose segment crosses a
    // block boundary.
    auto krn3 = block_sum_by_key<V, Oper>(queue[0]);

    krn3.push_arg(count);
    krn3.push_arg(key_sum);
    krn3.push_arg(post_sum);
    krn3.push_arg(offset);
    krn3.push_arg(offset_val);

    krn3.config(num_blocks, NT);
    krn3(queue[0]);

    /***** resize okeys and ovals *****/
    // The last segment index + 1 is the number of output elements.
    // NOTE(review): count == 0 would underflow here (count - 1); presumably
    // callers guarantee non-empty input -- confirm.
    int out_elements = 0;
    offset.read(queue[0], count - 1, 1, &out_elements, true);
    ++out_elements;

    boost::fusion::for_each(okeys, do_vex_resize(queue, out_elements));
    ovals.resize(ivals.queue_list(), out_elements);

    /***** Kernel 4 *****/
    // Compact segment-tail keys and reduced values into the dense outputs.
    // Note: the kernel's "ivals" parameter receives the scanned offset_val.
    auto krn4 = key_value_mapping<K, V>(queue[0]);

    krn4.push_arg(count);
    boost::fusion::for_each(ikeys, do_push_arg(krn4));
    boost::fusion::for_each(okeys, do_push_arg(krn4));
    krn4.push_arg(ovals(0));
    krn4.push_arg(offset);
    krn4.push_arg(offset_val);

    krn4(queue[0]);

    return out_elements;
}
551
552 } // namespace rbk
553 } // namespace detail
554
555 /// Reduce by key algorithm.
556 template <typename IKeys, typename OKeys, typename V, class Comp, class Oper>
reduce_by_key(IKeys && ikeys,vector<V> const & ivals,OKeys && okeys,vector<V> & ovals,Comp comp,Oper oper)557 int reduce_by_key(
558 IKeys &&ikeys, vector<V> const &ivals,
559 OKeys &&okeys, vector<V> &ovals,
560 Comp comp, Oper oper
561 )
562 {
563 return detail::rbk::reduce_by_key_sink(
564 detail::forward_as_sequence(ikeys), ivals,
565 detail::forward_as_sequence(okeys), ovals,
566 comp, oper);
567 }
568
569 /// Reduce by key algorithm.
570 template <typename K, typename V>
reduce_by_key(vector<K> const & ikeys,vector<V> const & ivals,vector<K> & okeys,vector<V> & ovals)571 int reduce_by_key(
572 vector<K> const &ikeys,
573 vector<V> const &ivals,
574 vector<K> &okeys,
575 vector<V> &ovals
576 )
577 {
578 VEX_FUNCTION(bool, equal, (K, x)(K, y), return x == y;);
579 VEX_FUNCTION(V, plus, (V, x)(V, y), return x + y;);
580 return reduce_by_key(ikeys, ivals, okeys, ovals, equal, plus);
581 }
582
583 }
584
585 #endif
586