1"""Scan primitive.""" 2 3from __future__ import division 4from __future__ import absolute_import 5from six.moves import range 6from six.moves import zip 7 8__copyright__ = """Copyright 2011-2012 Andreas Kloeckner \ 9 Copyright 2017 Hao Gao""" 10 11__license__ = """ 12Permission is hereby granted, free of charge, to any person 13obtaining a copy of this software and associated documentation 14files (the "Software"), to deal in the Software without 15restriction, including without limitation the rights to use, 16copy, modify, merge, publish, distribute, sublicense, and/or sell 17copies of the Software, and to permit persons to whom the 18Software is furnished to do so, subject to the following 19conditions: 20 21The above copyright notice and this permission notice shall be 22included in all copies or substantial portions of the Software. 23 24THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, 25EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES 26OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND 27NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT 28HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, 29WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING 30FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR 31OTHER DEALINGS IN THE SOFTWARE. 32""" 33 34import numpy as np 35import pyopencl as cl 36import pyopencl.array # noqa 37from pyopencl.scan import ScanTemplate 38from pyopencl.tools import dtype_to_ctype 39from pytools import memoize, memoize_method, Record 40from mako.template import Template 41 42 43# {{{ copy_if 44 45_copy_if_template = ScanTemplate( 46 arguments="item_t *ary, item_t *out, scan_t *count", 47 input_expr="(%(predicate)s) ? 1 : 0", 48 scan_expr="a+b", neutral="0", 49 output_statement=""" 50 if (prev_item != item) out[item-1] = ary[i]; 51 if (i+1 == N) *count = item; 52 """, 53 template_processor="printf") 54 55 56def extract_extra_args_types_values(extra_args): 57 from pyopencl.tools import VectorArg, ScalarArg 58 59 extra_args_types = [] 60 extra_args_values = [] 61 for name, val in extra_args: 62 if isinstance(val, cl.array.Array): 63 extra_args_types.append(VectorArg(val.dtype, name, with_offset=False)) 64 extra_args_values.append(val) 65 elif isinstance(val, np.generic): 66 extra_args_types.append(ScalarArg(val.dtype, name)) 67 extra_args_values.append(val) 68 else: 69 raise RuntimeError("argument '%d' not understood" % name) 70 71 return tuple(extra_args_types), extra_args_values 72 73 74def copy_if(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=None): 75 """Copy the elements of *ary* satisfying *predicate* to an output array. 76 77 :arg predicate: a C expression evaluating to a `bool`, represented as a string. 78 The value to test is available as `ary[i]`, and if the expression evaluates 79 to `true`, then this value ends up in the output. 80 :arg extra_args: |scan_extra_args| 81 :arg preamble: |preamble| 82 :arg wait_for: |explain-waitfor| 83 :returns: a tuple *(out, count, event)* where *out* is the output array, *count* 84 is an on-device scalar (fetch to host with `count.get()`) indicating 85 how many elements satisfied *predicate*, and *event* is a 86 :class:`pyopencl.Event` for dependency management. *out* is allocated 87 to the same length as *ary*, but only the first *count* entries carry 88 meaning. 89 90 .. versionadded:: 2013.1 91 """ 92 if len(ary) > np.iinfo(np.int32).max: 93 scan_dtype = np.int64 94 else: 95 scan_dtype = np.int32 96 97 extra_args_types, extra_args_values = extract_extra_args_types_values(extra_args) 98 99 knl = _copy_if_template.build(ary.context, 100 type_aliases=(("scan_t", scan_dtype), ("item_t", ary.dtype)), 101 var_values=(("predicate", predicate),), 102 more_preamble=preamble, more_arguments=extra_args_types) 103 out = cl.array.empty_like(ary) 104 count = ary._new_with_changes(data=None, offset=0, 105 shape=(), strides=(), dtype=scan_dtype) 106 107 # **dict is a Py2.5 workaround 108 evt = knl(ary, out, count, *extra_args_values, 109 **dict(queue=queue, wait_for=wait_for)) 110 111 return out, count, evt 112 113# }}} 114 115 116# {{{ remove_if 117 118def remove_if(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=None): 119 """Copy the elements of *ary* not satisfying *predicate* to an output array. 120 121 :arg predicate: a C expression evaluating to a `bool`, represented as a string. 122 The value to test is available as `ary[i]`, and if the expression evaluates 123 to `false`, then this value ends up in the output. 124 :arg extra_args: |scan_extra_args| 125 :arg preamble: |preamble| 126 :arg wait_for: |explain-waitfor| 127 :returns: a tuple *(out, count, event)* where *out* is the output array, *count* 128 is an on-device scalar (fetch to host with `count.get()`) indicating 129 how many elements did not satisfy *predicate*, and *event* is a 130 :class:`pyopencl.Event` for dependency management. 131 132 .. versionadded:: 2013.1 133 """ 134 return copy_if(ary, "!(%s)" % predicate, extra_args=extra_args, 135 preamble=preamble, queue=queue, wait_for=wait_for) 136 137# }}} 138 139 140# {{{ partition 141 142_partition_template = ScanTemplate( 143 arguments=( 144 "item_t *ary, item_t *out_true, item_t *out_false, " 145 "scan_t *count_true"), 146 input_expr="(%(predicate)s) ? 1 : 0", 147 scan_expr="a+b", neutral="0", 148 output_statement="""//CL// 149 if (prev_item != item) 150 out_true[item-1] = ary[i]; 151 else 152 out_false[i-item] = ary[i]; 153 if (i+1 == N) *count_true = item; 154 """, 155 template_processor="printf") 156 157 158def partition(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=None): 159 """Copy the elements of *ary* into one of two arrays depending on whether 160 they satisfy *predicate*. 161 162 :arg predicate: a C expression evaluating to a `bool`, represented as a string. 163 The value to test is available as `ary[i]`. 164 :arg extra_args: |scan_extra_args| 165 :arg preamble: |preamble| 166 :arg wait_for: |explain-waitfor| 167 :returns: a tuple *(out_true, out_false, count, event)* where *count* 168 is an on-device scalar (fetch to host with `count.get()`) indicating 169 how many elements satisfied the predicate, and *event* is a 170 :class:`pyopencl.Event` for dependency management. 171 172 .. versionadded:: 2013.1 173 """ 174 if len(ary) > np.iinfo(np.uint32).max: 175 scan_dtype = np.uint64 176 else: 177 scan_dtype = np.uint32 178 179 extra_args_types, extra_args_values = extract_extra_args_types_values(extra_args) 180 181 knl = _partition_template.build( 182 ary.context, 183 type_aliases=(("item_t", ary.dtype), ("scan_t", scan_dtype)), 184 var_values=(("predicate", predicate),), 185 more_preamble=preamble, more_arguments=extra_args_types) 186 187 out_true = cl.array.empty_like(ary) 188 out_false = cl.array.empty_like(ary) 189 count = ary._new_with_changes(data=None, offset=0, 190 shape=(), strides=(), dtype=scan_dtype) 191 192 # **dict is a Py2.5 workaround 193 evt = knl(ary, out_true, out_false, count, *extra_args_values, 194 **dict(queue=queue, wait_for=wait_for)) 195 196 return out_true, out_false, count, evt 197 198# }}} 199 200 201# {{{ unique 202 203_unique_template = ScanTemplate( 204 arguments="item_t *ary, item_t *out, scan_t *count_unique", 205 input_fetch_exprs=[ 206 ("ary_im1", "ary", -1), 207 ("ary_i", "ary", 0), 208 ], 209 input_expr="(i == 0) || (IS_EQUAL_EXPR(ary_im1, ary_i) ? 0 : 1)", 210 scan_expr="a+b", neutral="0", 211 output_statement=""" 212 if (prev_item != item) out[item-1] = ary[i]; 213 if (i+1 == N) *count_unique = item; 214 """, 215 preamble="#define IS_EQUAL_EXPR(a, b) %(macro_is_equal_expr)s\n", 216 template_processor="printf") 217 218 219def unique(ary, is_equal_expr="a == b", extra_args=[], preamble="", 220 queue=None, wait_for=None): 221 """Copy the elements of *ary* into the output if *is_equal_expr*, applied to the 222 array element and its predecessor, yields false. 223 224 Works like the UNIX command :program:`uniq`, with a potentially custom 225 comparison. This operation is often used on sorted sequences. 226 227 :arg is_equal_expr: a C expression evaluating to a `bool`, 228 represented as a string. The elements being compared are 229 available as `a` and `b`. If this expression yields `false`, the 230 two are considered distinct. 231 :arg extra_args: |scan_extra_args| 232 :arg preamble: |preamble| 233 :arg wait_for: |explain-waitfor| 234 :returns: a tuple *(out, count, event)* where *out* is the output array, *count* 235 is an on-device scalar (fetch to host with `count.get()`) indicating 236 how many elements satisfied the predicate, and *event* is a 237 :class:`pyopencl.Event` for dependency management. 238 239 .. versionadded:: 2013.1 240 """ 241 242 if len(ary) > np.iinfo(np.uint32).max: 243 scan_dtype = np.uint64 244 else: 245 scan_dtype = np.uint32 246 247 extra_args_types, extra_args_values = extract_extra_args_types_values(extra_args) 248 249 knl = _unique_template.build( 250 ary.context, 251 type_aliases=(("item_t", ary.dtype), ("scan_t", scan_dtype)), 252 var_values=(("macro_is_equal_expr", is_equal_expr),), 253 more_preamble=preamble, more_arguments=extra_args_types) 254 255 out = cl.array.empty_like(ary) 256 count = ary._new_with_changes(data=None, offset=0, 257 shape=(), strides=(), dtype=scan_dtype) 258 259 # **dict is a Py2.5 workaround 260 evt = knl(ary, out, count, *extra_args_values, 261 **dict(queue=queue, wait_for=wait_for)) 262 263 return out, count, evt 264 265# }}} 266 267 268# {{{ radix_sort 269 270def to_bin(n): 271 # Py 2.5 has no built-in bin() 272 digs = [] 273 while n: 274 digs.append(str(n % 2)) 275 n >>= 1 276 277 return ''.join(digs[::-1]) 278 279 280def _padded_bin(i, l): 281 s = to_bin(i) 282 while len(s) < l: 283 s = '0' + s 284 return s 285 286 287@memoize 288def _make_sort_scan_type(device, bits, index_dtype): 289 name = "pyopencl_sort_scan_%s_%dbits_t" % ( 290 index_dtype.type.__name__, bits) 291 292 fields = [] 293 for mnr in range(2**bits): 294 fields.append(('c%s' % _padded_bin(mnr, bits), index_dtype)) 295 296 dtype = np.dtype(fields) 297 298 from pyopencl.tools import get_or_register_dtype, match_dtype_to_c_struct 299 dtype, c_decl = match_dtype_to_c_struct(device, name, dtype) 300 301 dtype = get_or_register_dtype(name, dtype) 302 return name, dtype, c_decl 303 304 305# {{{ types, helpers preamble 306 307RADIX_SORT_PREAMBLE_TPL = Template(r"""//CL// 308 typedef ${scan_ctype} scan_t; 309 typedef ${key_ctype} key_t; 310 typedef ${index_ctype} index_t; 311 312 // #define DEBUG 313 #ifdef DEBUG 314 #define dbg_printf(ARGS) printf ARGS 315 #else 316 #define dbg_printf(ARGS) /* */ 317 #endif 318 319 index_t get_count(scan_t s, int mnr) 320 { 321 return ${get_count_branch("")}; 322 } 323 324 #define BIN_NR(key_arg) ((key_arg >> base_bit) & ${2**bits - 1}) 325 326""", strict_undefined=True) 327 328# }}} 329 330# {{{ scan helpers 331 332RADIX_SORT_SCAN_PREAMBLE_TPL = Template(r"""//CL// 333 scan_t scan_t_neutral() 334 { 335 scan_t result; 336 %for mnr in range(2**bits): 337 result.c${padded_bin(mnr, bits)} = 0; 338 %endfor 339 return result; 340 } 341 342 // considers bits (base_bit+bits-1, ..., base_bit) 343 scan_t scan_t_from_value( 344 key_t key, 345 int base_bit, 346 int i 347 ) 348 { 349 // extract relevant bit range 350 key_t bin_nr = BIN_NR(key); 351 352 dbg_printf(("i: %d key:%d bin_nr:%d\n", i, key, bin_nr)); 353 354 scan_t result; 355 %for mnr in range(2**bits): 356 result.c${padded_bin(mnr, bits)} = (bin_nr == ${mnr}); 357 %endfor 358 359 return result; 360 } 361 362 scan_t scan_t_add(scan_t a, scan_t b, bool across_seg_boundary) 363 { 364 %for mnr in range(2**bits): 365 <% field = "c"+padded_bin(mnr, bits) %> 366 b.${field} = a.${field} + b.${field}; 367 %endfor 368 369 return b; 370 } 371""", strict_undefined=True) 372 373RADIX_SORT_OUTPUT_STMT_TPL = Template(r"""//CL// 374 { 375 key_t key = ${key_expr}; 376 key_t my_bin_nr = BIN_NR(key); 377 378 index_t previous_bins_size = 0; 379 %for mnr in range(2**bits): 380 previous_bins_size += 381 (my_bin_nr > ${mnr}) 382 ? last_item.c${padded_bin(mnr, bits)} 383 : 0; 384 %endfor 385 386 index_t tgt_idx = 387 previous_bins_size 388 + get_count(item, my_bin_nr) - 1; 389 390 %for arg_name in sort_arg_names: 391 sorted_${arg_name}[tgt_idx] = ${arg_name}[i]; 392 %endfor 393 } 394""", strict_undefined=True) 395 396# }}} 397 398 399# {{{ driver 400 401# import hoisted here to be used as a default argument in the constructor 402from pyopencl.scan import GenericScanKernel 403 404 405class RadixSort(object): 406 """Provides a general `radix sort <https://en.wikipedia.org/wiki/Radix_sort>`_ 407 on the compute device. 408 409 .. seealso:: :class:`pyopencl.algorithm.BitonicSort` 410 411 .. versionadded:: 2013.1 412 """ 413 def __init__(self, context, arguments, key_expr, sort_arg_names, 414 bits_at_a_time=2, index_dtype=np.int32, key_dtype=np.uint32, 415 scan_kernel=GenericScanKernel, options=[]): 416 """ 417 :arg arguments: A string of comma-separated C argument declarations. 418 If *arguments* is specified, then *input_expr* must also be 419 specified. All types used here must be known to PyOpenCL. 420 (see :func:`pyopencl.tools.get_or_register_dtype`). 421 :arg key_expr: An integer-valued C expression returning the 422 key based on which the sort is performed. The array index 423 for which the key is to be computed is available as `i`. 424 The expression may refer to any of the *arguments*. 425 :arg sort_arg_names: A list of argument names whose corresponding 426 array arguments will be sorted according to *key_expr*. 427 """ 428 429 # {{{ arg processing 430 431 from pyopencl.tools import parse_arg_list 432 self.arguments = parse_arg_list(arguments) 433 del arguments 434 435 self.sort_arg_names = sort_arg_names 436 self.bits = int(bits_at_a_time) 437 self.index_dtype = np.dtype(index_dtype) 438 self.key_dtype = np.dtype(key_dtype) 439 440 self.options = options 441 442 # }}} 443 444 # {{{ kernel creation 445 446 scan_ctype, scan_dtype, scan_t_cdecl = \ 447 _make_sort_scan_type(context.devices[0], self.bits, self.index_dtype) 448 449 from pyopencl.tools import VectorArg, ScalarArg 450 scan_arguments = ( 451 list(self.arguments) 452 + [VectorArg(arg.dtype, "sorted_"+arg.name) for arg in self.arguments 453 if arg.name in sort_arg_names] 454 + [ScalarArg(np.int32, "base_bit")]) 455 456 def get_count_branch(known_bits): 457 if len(known_bits) == self.bits: 458 return "s.c%s" % known_bits 459 460 boundary_mnr = known_bits + "1" + (self.bits-len(known_bits)-1)*"0" 461 462 return ("((mnr < %s) ? %s : %s)" % ( 463 int(boundary_mnr, 2), 464 get_count_branch(known_bits+"0"), 465 get_count_branch(known_bits+"1"))) 466 467 codegen_args = dict( 468 bits=self.bits, 469 key_ctype=dtype_to_ctype(self.key_dtype), 470 key_expr=key_expr, 471 index_ctype=dtype_to_ctype(self.index_dtype), 472 index_type_max=np.iinfo(self.index_dtype).max, 473 padded_bin=_padded_bin, 474 scan_ctype=scan_ctype, 475 sort_arg_names=sort_arg_names, 476 get_count_branch=get_count_branch, 477 ) 478 479 preamble = scan_t_cdecl+RADIX_SORT_PREAMBLE_TPL.render(**codegen_args) 480 scan_preamble = preamble \ 481 + RADIX_SORT_SCAN_PREAMBLE_TPL.render(**codegen_args) 482 483 self.scan_kernel = scan_kernel( 484 context, scan_dtype, 485 arguments=scan_arguments, 486 input_expr="scan_t_from_value(%s, base_bit, i)" % key_expr, 487 scan_expr="scan_t_add(a, b, across_seg_boundary)", 488 neutral="scan_t_neutral()", 489 output_statement=RADIX_SORT_OUTPUT_STMT_TPL.render(**codegen_args), 490 preamble=scan_preamble, options=self.options) 491 492 for i, arg in enumerate(self.arguments): 493 if isinstance(arg, VectorArg): 494 self.first_array_arg_idx = i 495 496 # }}} 497 498 def __call__(self, *args, **kwargs): 499 """Run the radix sort. In addition to *args* which must match the 500 *arguments* specification on the constructor, the following 501 keyword arguments are supported: 502 503 :arg key_bits: specify how many bits (starting from least-significant) 504 there are in the key. 505 :arg allocator: See the *allocator* argument of :func:`pyopencl.array.empty`. 506 :arg queue: A :class:`pyopencl.CommandQueue`, defaulting to the 507 one from the first argument array. 508 :arg wait_for: |explain-waitfor| 509 :returns: A tuple ``(sorted, event)``. *sorted* consists of sorted 510 copies of the arrays named in *sorted_args*, in the order of that 511 list. *event* is a :class:`pyopencl.Event` for dependency management. 512 """ 513 514 wait_for = kwargs.pop("wait_for", None) 515 516 # {{{ run control 517 518 key_bits = kwargs.pop("key_bits", None) 519 if key_bits is None: 520 key_bits = int(np.iinfo(self.key_dtype).bits) 521 522 n = len(args[self.first_array_arg_idx]) 523 524 allocator = kwargs.pop("allocator", None) 525 if allocator is None: 526 allocator = args[self.first_array_arg_idx].allocator 527 528 queue = kwargs.pop("queue", None) 529 if queue is None: 530 queue = args[self.first_array_arg_idx].queue 531 532 args = list(args) 533 534 base_bit = 0 535 while base_bit < key_bits: 536 sorted_args = [ 537 cl.array.empty(queue, n, arg_descr.dtype, allocator=allocator) 538 for arg_descr in self.arguments 539 if arg_descr.name in self.sort_arg_names] 540 541 scan_args = args + sorted_args + [base_bit] 542 543 last_evt = self.scan_kernel(*scan_args, 544 **dict(queue=queue, wait_for=wait_for)) 545 wait_for = [last_evt] 546 547 # substitute sorted 548 for i, arg_descr in enumerate(self.arguments): 549 if arg_descr.name in self.sort_arg_names: 550 args[i] = sorted_args[self.sort_arg_names.index(arg_descr.name)] 551 552 base_bit += self.bits 553 554 return [arg_val 555 for arg_descr, arg_val in zip(self.arguments, args) 556 if arg_descr.name in self.sort_arg_names], last_evt 557 558 # }}} 559 560# }}} 561 562# }}} 563 564 565# {{{ generic parallel list builder 566 567# {{{ kernel template 568 569_LIST_BUILDER_TEMPLATE = Template("""//CL// 570% if double_support: 571 #if __OPENCL_C_VERSION__ < 120 572 #pragma OPENCL EXTENSION cl_khr_fp64: enable 573 #endif 574 #define PYOPENCL_DEFINE_CDOUBLE 575% endif 576 577#include <pyopencl-complex.h> 578 579${preamble} 580 581// {{{ declare helper macros for user interface 582 583typedef ${index_type} index_type; 584 585%if is_count_stage: 586 #define PLB_COUNT_STAGE 587 588 %for name, dtype in list_names_and_dtypes: 589 %if name in count_sharing: 590 #define APPEND_${name}(value) { /* nothing */ } 591 %else: 592 #define APPEND_${name}(value) { ++(*plb_loc_${name}_count); } 593 %endif 594 %endfor 595%else: 596 #define PLB_WRITE_STAGE 597 598 %for name, dtype in list_names_and_dtypes: 599 %if name in count_sharing: 600 #define APPEND_${name}(value) \ 601 { plb_${name}_list[(*plb_${count_sharing[name]}_index) - 1] \ 602 = value; } 603 %else: 604 #define APPEND_${name}(value) \ 605 { plb_${name}_list[(*plb_${name}_index)++] = value; } 606 %endif 607 %endfor 608%endif 609 610#define LIST_ARG_DECL ${user_list_arg_decl} 611#define LIST_ARGS ${user_list_args} 612#define USER_ARG_DECL ${user_arg_decl} 613#define USER_ARGS ${user_args} 614 615// }}} 616 617${generate_template} 618 619// {{{ kernel entry point 620 621__kernel 622%if do_not_vectorize: 623__attribute__((reqd_work_group_size(1, 1, 1))) 624%endif 625void ${kernel_name}(${kernel_list_arg_decl} USER_ARG_DECL index_type n) 626 627{ 628 %if not do_not_vectorize: 629 int lid = get_local_id(0); 630 index_type gsize = get_global_size(0); 631 index_type work_group_start = get_local_size(0)*get_group_id(0); 632 for (index_type i = work_group_start + lid; i < n; i += gsize) 633 %else: 634 const int chunk_size = 128; 635 index_type chunk_base = get_global_id(0)*chunk_size; 636 index_type gsize = get_global_size(0); 637 for (; chunk_base < n; chunk_base += gsize*chunk_size) 638 for (index_type i = chunk_base; i < min(n, chunk_base+chunk_size); ++i) 639 %endif 640 { 641 %if is_count_stage: 642 %for name, dtype in list_names_and_dtypes: 643 %if name not in count_sharing: 644 index_type plb_loc_${name}_count = 0; 645 %endif 646 %endfor 647 %else: 648 %for name, dtype in list_names_and_dtypes: 649 %if name not in count_sharing: 650 index_type plb_${name}_index; 651 if (plb_${name}_start_index) 652 %if name in eliminate_empty_output_lists: 653 plb_${name}_index = 654 plb_${name}_start_index[ 655 ${name}_compressed_indices[i] 656 ]; 657 %else: 658 plb_${name}_index = plb_${name}_start_index[i]; 659 %endif 660 else 661 plb_${name}_index = 0; 662 %endif 663 %endfor 664 %endif 665 666 generate(${kernel_list_arg_values} USER_ARGS i); 667 668 %if is_count_stage: 669 %for name, dtype in list_names_and_dtypes: 670 %if name not in count_sharing: 671 if (plb_${name}_count) 672 plb_${name}_count[i] = plb_loc_${name}_count; 673 %endif 674 %endfor 675 %endif 676 } 677} 678 679// }}} 680 681""", strict_undefined=True) 682 683# }}} 684 685 686def _get_arg_decl(arg_list): 687 result = "" 688 for arg in arg_list: 689 result += arg.declarator() + ", " 690 691 return result 692 693 694def _get_arg_list(arg_list, prefix=""): 695 result = "" 696 for arg in arg_list: 697 result += prefix + arg.name + ", " 698 699 return result 700 701 702class BuiltList(Record): 703 pass 704 705 706class ListOfListsBuilder: 707 """Generates and executes code to produce a large number of variable-size 708 lists, simply. 709 710 .. note:: This functionality is provided as a preview. Its interface 711 is subject to change until this notice is removed. 712 713 .. versionadded:: 2013.1 714 715 Here's a usage example:: 716 717 from pyopencl.algorithm import ListOfListsBuilder 718 builder = ListOfListsBuilder(context, [("mylist", np.int32)], \"\"\" 719 void generate(LIST_ARG_DECL USER_ARG_DECL index_type i) 720 { 721 int count = i % 4; 722 for (int j = 0; j < count; ++j) 723 { 724 APPEND_mylist(count); 725 } 726 } 727 \"\"\", arg_decls=[]) 728 729 result, event = builder(queue, 2000) 730 731 inf = result["mylist"] 732 assert inf.count == 3000 733 assert (inf.list.get()[-6:] == [1, 2, 2, 3, 3, 3]).all() 734 735 The function `generate` above is called once for each "input object". 736 Each input object can then generate zero or more list entries. 737 The number of these input objects is given to :meth:`__call__` as *n_objects*. 738 List entries are generated by calls to `APPEND_<list name>(value)`. 739 Multiple lists may be generated at once. 740 741 .. automethod:: __init__ 742 .. automethod:: __call__ 743 """ 744 def __init__(self, context, list_names_and_dtypes, generate_template, 745 arg_decls, count_sharing=None, devices=None, 746 name_prefix="plb_build_list", options=[], preamble="", 747 debug=False, complex_kernel=False, 748 eliminate_empty_output_lists=[]): 749 """ 750 :arg context: A :class:`pyopencl.Context`. 751 :arg list_names_and_dtypes: a list of `(name, dtype)` tuples 752 indicating the lists to be built. 753 :arg generate_template: a snippet of C as described below 754 :arg arg_decls: A string of comma-separated C argument declarations. 755 :arg count_sharing: A mapping consisting of `(child, mother)` 756 indicating that `mother` and `child` will always have the 757 same number of indices, and the `APPEND` to `mother` 758 will always happen *before* the `APPEND` to the child. 759 :arg name_prefix: the name prefix to use for the compiled kernels 760 :arg options: OpenCL compilation options for kernels using 761 *generate_template*. 762 :arg complex_kernel: If `True`, prevents vectorization on CPUs. 763 :arg eliminate_empty_output_lists: A Python list of list names 764 for which the empty output lists are eliminated. 765 766 *generate_template* may use the following C macros/identifiers: 767 768 * `index_type`: expands to C identifier for the index type used 769 for the calculation 770 * `USER_ARG_DECL`: expands to the C declarator for `arg_decls` 771 * `USER_ARGS`: a list of C argument values corresponding to 772 `user_arg_decl` 773 * `LIST_ARG_DECL`: expands to a C argument list representing the 774 data for the output lists. These are escaped prefixed with 775 `"plg_"` so as to not interfere with user-provided names. 776 * `LIST_ARGS`: a list of C argument values corresponding to 777 `LIST_ARG_DECL` 778 * `APPEND_name(entry)`: inserts `entry` into the list `name`. 779 *entry* must be a valid C expression of the correct type. 780 781 All argument-list related macros have a trailing comma included 782 if they are non-empty. 783 784 *generate_template* must supply a function: 785 786 .. code-block:: c 787 788 void generate(USER_ARG_DECL LIST_ARG_DECL index_type i) 789 { 790 APPEND_mylist(5); 791 } 792 793 Internally, the `kernel_template` is expanded (at least) twice. Once, 794 for a 'counting' stage where the size of all the lists is determined, 795 and a second time, for a 'generation' stage where the lists are 796 actually filled. A `generate` function that has side effects beyond 797 calling `append` is therefore ill-formed. 798 799 .. versionchanged:: 2018.1 800 801 Change *eliminate_empty_output_lists* argument type from `bool` to 802 `list`. 803 """ 804 805 if devices is None: 806 devices = context.devices 807 808 if count_sharing is None: 809 count_sharing = {} 810 811 self.context = context 812 self.devices = devices 813 814 self.list_names_and_dtypes = list_names_and_dtypes 815 self.generate_template = generate_template 816 817 from pyopencl.tools import parse_arg_list 818 self.arg_decls = parse_arg_list(arg_decls) 819 820 self.count_sharing = count_sharing 821 822 self.name_prefix = name_prefix 823 self.preamble = preamble 824 self.options = options 825 826 self.debug = debug 827 828 self.complex_kernel = complex_kernel 829 830 if eliminate_empty_output_lists is True: 831 eliminate_empty_output_lists = \ 832 [name for name, _ in self.list_names_and_dtypes] 833 834 if eliminate_empty_output_lists is False: 835 eliminate_empty_output_lists = [] 836 837 self.eliminate_empty_output_lists = eliminate_empty_output_lists 838 for list_name in self.eliminate_empty_output_lists: 839 if not any(list_name == name for name, _ in self.list_names_and_dtypes): 840 raise ValueError( 841 "invalid list name '%s' in eliminate_empty_output_lists" 842 % list_name) 843 844 # {{{ kernel generators 845 846 @memoize_method 847 def get_scan_kernel(self, index_dtype): 848 from pyopencl.scan import GenericScanKernel 849 return GenericScanKernel( 850 self.context, index_dtype, 851 arguments="__global %s *ary" % dtype_to_ctype(index_dtype), 852 input_expr="ary[i]", 853 scan_expr="a+b", neutral="0", 854 output_statement="ary[i+1] = item;", 855 devices=self.devices) 856 857 @memoize_method 858 def get_compress_kernel(self, index_dtype): 859 arguments = """ 860 __global ${index_t} *count, 861 __global ${index_t} *compressed_counts, 862 __global ${index_t} *nonempty_indices, 863 __global ${index_t} *compressed_indices, 864 __global ${index_t} *num_non_empty_list 865 """ 866 from sys import version_info 867 if version_info > (3, 0): 868 arguments = Template(arguments) 869 else: 870 arguments = Template(arguments, disable_unicode=True) 871 872 from pyopencl.scan import GenericScanKernel 873 return GenericScanKernel( 874 self.context, index_dtype, 875 arguments=arguments.render(index_t=dtype_to_ctype(index_dtype)), 876 input_expr="count[i] == 0 ? 0 : 1", 877 scan_expr="a+b", neutral="0", 878 output_statement=""" 879 if (i + 1 < N) compressed_indices[i + 1] = item; 880 if (prev_item != item) { 881 nonempty_indices[item - 1] = i; 882 compressed_counts[item - 1] = count[i]; 883 } 884 if (i + 1 == N) *num_non_empty_list = item; 885 """, 886 devices=self.devices) 887 888 def do_not_vectorize(self): 889 from pytools import any 890 return (self.complex_kernel 891 and any(dev.type & cl.device_type.CPU 892 for dev in self.context.devices)) 893 894 @memoize_method 895 def get_count_kernel(self, index_dtype): 896 index_ctype = dtype_to_ctype(index_dtype) 897 from pyopencl.tools import VectorArg, OtherArg 898 kernel_list_args = [ 899 VectorArg(index_dtype, "plb_%s_count" % name) 900 for name, dtype in self.list_names_and_dtypes 901 if name not in self.count_sharing] 902 903 user_list_args = [] 904 for name, dtype in self.list_names_and_dtypes: 905 if name in self.count_sharing: 906 continue 907 908 name = "plb_loc_%s_count" % name 909 user_list_args.append(OtherArg("%s *%s" % ( 910 index_ctype, name), name)) 911 912 kernel_name = self.name_prefix+"_count" 913 914 from pyopencl.characterize import has_double_support 915 src = _LIST_BUILDER_TEMPLATE.render( 916 is_count_stage=True, 917 kernel_name=kernel_name, 918 double_support=all(has_double_support(dev) for dev in 919 self.context.devices), 920 debug=self.debug, 921 do_not_vectorize=self.do_not_vectorize(), 922 eliminate_empty_output_lists=self.eliminate_empty_output_lists, 923 924 kernel_list_arg_decl=_get_arg_decl(kernel_list_args), 925 kernel_list_arg_values=_get_arg_list(user_list_args, prefix="&"), 926 user_list_arg_decl=_get_arg_decl(user_list_args), 927 user_list_args=_get_arg_list(user_list_args), 928 user_arg_decl=_get_arg_decl(self.arg_decls), 929 user_args=_get_arg_list(self.arg_decls), 930 931 list_names_and_dtypes=self.list_names_and_dtypes, 932 count_sharing=self.count_sharing, 933 name_prefix=self.name_prefix, 934 generate_template=self.generate_template, 935 preamble=self.preamble, 936 937 index_type=index_ctype, 938 ) 939 940 src = str(src) 941 942 prg = cl.Program(self.context, src).build(self.options) 943 knl = getattr(prg, kernel_name) 944 945 from pyopencl.tools import get_arg_list_scalar_arg_dtypes 946 knl.set_scalar_arg_dtypes(get_arg_list_scalar_arg_dtypes( 947 kernel_list_args+self.arg_decls) + [index_dtype]) 948 949 return knl 950 951 @memoize_method 952 def get_write_kernel(self, index_dtype): 953 index_ctype = dtype_to_ctype(index_dtype) 954 from pyopencl.tools import VectorArg, OtherArg 955 kernel_list_args = [] 956 kernel_list_arg_values = "" 957 user_list_args = [] 958 959 for name, dtype in self.list_names_and_dtypes: 960 list_name = "plb_%s_list" % name 961 list_arg = VectorArg(dtype, list_name) 962 963 kernel_list_args.append(list_arg) 964 user_list_args.append(list_arg) 965 966 if name in self.count_sharing: 967 kernel_list_arg_values += "%s, " % list_name 968 continue 969 970 kernel_list_args.append( 971 VectorArg(index_dtype, "plb_%s_start_index" % name)) 972 973 if name in self.eliminate_empty_output_lists: 974 kernel_list_args.append( 975 VectorArg(index_dtype, "%s_compressed_indices" % name)) 976 977 index_name = "plb_%s_index" % name 978 user_list_args.append(OtherArg("%s *%s" % ( 979 index_ctype, index_name), index_name)) 980 981 kernel_list_arg_values += "%s, &%s, " % (list_name, index_name) 982 983 kernel_name = self.name_prefix+"_write" 984 985 from pyopencl.characterize import has_double_support 986 src = _LIST_BUILDER_TEMPLATE.render( 987 is_count_stage=False, 988 kernel_name=kernel_name, 989 double_support=all(has_double_support(dev) for dev in 990 self.context.devices), 991 debug=self.debug, 992 do_not_vectorize=self.do_not_vectorize(), 993 eliminate_empty_output_lists=self.eliminate_empty_output_lists, 994 995 kernel_list_arg_decl=_get_arg_decl(kernel_list_args), 996 kernel_list_arg_values=kernel_list_arg_values, 997 user_list_arg_decl=_get_arg_decl(user_list_args), 998 user_list_args=_get_arg_list(user_list_args), 999 user_arg_decl=_get_arg_decl(self.arg_decls), 1000 user_args=_get_arg_list(self.arg_decls), 1001 1002 list_names_and_dtypes=self.list_names_and_dtypes, 1003 count_sharing=self.count_sharing, 1004 name_prefix=self.name_prefix, 1005 generate_template=self.generate_template, 1006 preamble=self.preamble, 1007 1008 index_type=index_ctype, 1009 ) 1010 1011 src = str(src) 1012 1013 prg = cl.Program(self.context, src).build(self.options) 1014 knl = getattr(prg, kernel_name) 1015 1016 from pyopencl.tools import get_arg_list_scalar_arg_dtypes 1017 knl.set_scalar_arg_dtypes(get_arg_list_scalar_arg_dtypes( 1018 kernel_list_args+self.arg_decls) + [index_dtype]) 1019 1020 return knl 1021 1022 # }}} 1023 1024 # {{{ driver 1025 1026 def __call__(self, queue, n_objects, *args, **kwargs): 1027 """ 1028 :arg args: arguments corresponding to arg_decls in the constructor. 1029 :class:`pyopencl.array.Array` are not allowed directly and should 1030 be passed as their :attr:`pyopencl.array.Array.data` attribute instead. 1031 :arg allocator: optionally, the allocator to use to allocate new 1032 arrays. 1033 :arg omit_lists: An iterable of list names that should *not* be built 1034 with this invocation. The kernel code may *not* call ``APPEND_name`` 1035 for these omitted lists. If it does, undefined behavior will result. 1036 The returned *lists* dictionary will not contain an entry for names 1037 in *omit_lists*. 1038 :arg wait_for: |explain-waitfor| 1039 :returns: a tuple ``(lists, event)``, where 1040 *lists* a mapping from (built) list names to objects which 1041 have attributes 1042 1043 * ``count`` for the total number of entries in all lists combined 1044 * ``lists`` for the array containing all lists. 1045 * ``starts`` for the array of starting indices in `lists`. 1046 `starts` is built so that it has n+1 entries, so that 1047 the *i*'th entry is the start of the *i*'th list, and the 1048 *i*'th entry is the index one past the *i*'th list's end, 1049 even for the last list. 1050 1051 This implies that all lists are contiguous. 1052 1053 If the list name is specified in *eliminate_empty_output_lists* 1054 constructor argument, *lists* has two additional attributes 1055 ``num_nonempty_lists`` and ``nonempty_indices`` 1056 1057 * ``num_nonempty_lists`` for the number of nonempty lists. 1058 * ``nonempty_indices`` for the index of nonempty list in input objects. 1059 1060 In this case, `starts` has `num_nonempty_lists` + 1 entries. The *i*'s 1061 entry is the start of the *i*'th nonempty list, which is generated by 1062 the object with index *nonempty_indices[i]*. 1063 1064 *event* is a :class:`pyopencl.Event` for dependency management. 1065 1066 .. versionchanged:: 2016.2 1067 1068 Added omit_lists. 1069 """ 1070 if n_objects >= int(np.iinfo(np.int32).max): 1071 index_dtype = np.int64 1072 else: 1073 index_dtype = np.int32 1074 index_dtype = np.dtype(index_dtype) 1075 1076 allocator = kwargs.pop("allocator", None) 1077 omit_lists = kwargs.pop("omit_lists", []) 1078 wait_for = kwargs.pop("wait_for", None) 1079 if kwargs: 1080 raise TypeError("invalid keyword arguments: '%s'" % ", ".join(kwargs)) 1081 1082 for oml in omit_lists: 1083 if not any(oml == name for name, _ in self.list_names_and_dtypes): 1084 raise ValueError("invalid list name '%s' in omit_lists") 1085 1086 result = {} 1087 count_list_args = [] 1088 1089 if wait_for is None: 1090 wait_for = [] 1091 1092 count_kernel = self.get_count_kernel(index_dtype) 1093 write_kernel = self.get_write_kernel(index_dtype) 1094 scan_kernel = self.get_scan_kernel(index_dtype) 1095 if self.eliminate_empty_output_lists: 1096 compress_kernel = self.get_compress_kernel(index_dtype) 1097 1098 # {{{ allocate memory for counts 1099 1100 for name, dtype in self.list_names_and_dtypes: 1101 if name in self.count_sharing: 1102 continue 1103 if name in omit_lists: 1104 count_list_args.append(None) 1105 continue 1106 1107 counts = cl.array.empty(queue, 1108 (n_objects + 1), index_dtype, allocator=allocator) 1109 counts[-1] = 0 1110 wait_for = wait_for + counts.events 1111 1112 # The scan will turn the "counts" array into the "starts" array 1113 # in-place. 1114 if name in self.eliminate_empty_output_lists: 1115 result[name] = BuiltList(count=None, starts=counts, lists=None, 1116 num_nonempty_lists=None, 1117 nonempty_indices=None) 1118 else: 1119 result[name] = BuiltList(count=None, starts=counts, lists=None) 1120 count_list_args.append(counts.data) 1121 1122 # }}} 1123 1124 if self.debug: 1125 gsize = (1,) 1126 lsize = (1,) 1127 elif self.do_not_vectorize(): 1128 gsize = (4*queue.device.max_compute_units,) 1129 lsize = (1,) 1130 else: 1131 from pyopencl.array import splay 1132 gsize, lsize = splay(queue, n_objects) 1133 1134 count_event = count_kernel(queue, gsize, lsize, 1135 *(tuple(count_list_args) + args + (n_objects,)), 1136 **dict(wait_for=wait_for)) 1137 1138 compress_events = {} 1139 for name, dtype in self.list_names_and_dtypes: 1140 if name in omit_lists: 1141 continue 1142 if name in self.count_sharing: 1143 continue 1144 if name not in self.eliminate_empty_output_lists: 1145 continue 1146 1147 compressed_counts = cl.array.empty( 1148 queue, (n_objects + 1,), index_dtype, allocator=allocator) 1149 info_record = result[name] 1150 info_record.nonempty_indices = cl.array.empty( 1151 queue, (n_objects + 1,), index_dtype, allocator=allocator) 1152 info_record.num_nonempty_lists = cl.array.empty( 1153 queue, (1,), index_dtype, allocator=allocator) 1154 info_record.compressed_indices = cl.array.empty( 1155 queue, (n_objects + 1,), index_dtype, allocator=allocator) 1156 info_record.compressed_indices[0] = 0 1157 compress_events[name] = compress_kernel( 1158 info_record.starts, 1159 compressed_counts, 1160 info_record.nonempty_indices, 1161 info_record.compressed_indices, 1162 info_record.num_nonempty_lists, 1163 wait_for=[count_event] + info_record.compressed_indices.events) 1164 1165 info_record.starts = compressed_counts 1166 1167 # {{{ run scans 1168 1169 scan_events = [] 1170 1171 for name, dtype in self.list_names_and_dtypes: 1172 if name in self.count_sharing: 1173 continue 1174 if name in omit_lists: 1175 continue 1176 1177 info_record = result[name] 1178 if name in self.eliminate_empty_output_lists: 1179 compress_events[name].wait() 1180 num_nonempty_lists = info_record.num_nonempty_lists.get()[0] 1181 info_record.num_nonempty_lists = num_nonempty_lists 1182 info_record.starts = info_record.starts[:num_nonempty_lists + 1] 1183 info_record.nonempty_indices = \ 1184 info_record.nonempty_indices[:num_nonempty_lists] 1185 info_record.starts[-1] = 0 1186 1187 starts_ary = info_record.starts 1188 if name in self.eliminate_empty_output_lists: 1189 evt = scan_kernel( 1190 starts_ary, 1191 size=info_record.num_nonempty_lists, 1192 wait_for=starts_ary.events) 1193 else: 1194 evt = scan_kernel(starts_ary, wait_for=[count_event], 1195 size=n_objects) 1196 1197 starts_ary.setitem(0, 0, queue=queue, wait_for=[evt]) 1198 scan_events.extend(starts_ary.events) 1199 1200 # retrieve count 1201 info_record.count = int(starts_ary[-1].get()) 1202 1203 # }}} 1204 1205 # {{{ deal with count-sharing lists, allocate memory for lists 1206 1207 write_list_args = [] 1208 for name, dtype in self.list_names_and_dtypes: 1209 if name in omit_lists: 1210 write_list_args.append(None) 1211 if name not in self.count_sharing: 1212 write_list_args.append(None) 1213 if name in self.eliminate_empty_output_lists: 1214 write_list_args.append(None) 1215 continue 1216 1217 if name in self.count_sharing: 1218 sharing_from = self.count_sharing[name] 1219 1220 info_record = result[name] = BuiltList( 1221 count=result[sharing_from].count, 1222 starts=result[sharing_from].starts, 1223 ) 1224 1225 else: 1226 info_record = result[name] 1227 1228 info_record.lists = cl.array.empty(queue, 1229 info_record.count, dtype, allocator=allocator) 1230 write_list_args.append(info_record.lists.data) 1231 1232 if name not in self.count_sharing: 1233 write_list_args.append(info_record.starts.data) 1234 1235 if name in self.eliminate_empty_output_lists: 1236 write_list_args.append(info_record.compressed_indices.data) 1237 1238 # }}} 1239 1240 evt = write_kernel(queue, gsize, lsize, 1241 *(tuple(write_list_args) + args + (n_objects,)), 1242 **dict(wait_for=scan_events)) 1243 1244 return result, evt 1245 1246 # }}} 1247 1248# }}} 1249 1250 1251# {{{ key-value sorting 1252 1253class _KernelInfo(Record): 1254 pass 1255 1256 1257def _make_cl_int_literal(value, dtype): 1258 iinfo = np.iinfo(dtype) 1259 result = str(int(value)) 1260 if dtype.itemsize == 8: 1261 result += "l" 1262 if int(iinfo.min) < 0: 1263 result += "u" 1264 1265 return result 1266 1267 1268class KeyValueSorter(object): 1269 """Given arrays *values* and *keys* of equal length 1270 and a number *nkeys* of keys, returns a tuple `(starts, 1271 lists)`, as follows: *values* and *keys* are sorted 1272 by *keys*, and the sorted *values* is returned as 1273 *lists*. Then for each index *i* in `range(nkeys)`, 1274 *starts[i]* is written to indicating where the 1275 group of *values* belonging to the key with index 1276 *i* begins. It implicitly ends at *starts[i+1]*. 1277 1278 `starts` is built so that it has `nkeys+1` entries, so that 1279 the *i*'th entry is the start of the *i*'th list, and the 1280 *i*'th entry is the index one past the *i*'th list's end, 1281 even for the last list. 1282 1283 This implies that all lists are contiguous. 1284 1285 .. note:: This functionality is provided as a preview. Its 1286 interface is subject to change until this notice is removed. 1287 1288 .. versionadded:: 2013.1 1289 """ 1290 1291 def __init__(self, context): 1292 self.context = context 1293 1294 @memoize_method 1295 def get_kernels(self, key_dtype, value_dtype, starts_dtype): 1296 from pyopencl.algorithm import RadixSort 1297 from pyopencl.tools import VectorArg, ScalarArg 1298 1299 by_target_sorter = RadixSort( 1300 self.context, [ 1301 VectorArg(value_dtype, "values"), 1302 VectorArg(key_dtype, "keys"), 1303 ], 1304 key_expr="keys[i]", 1305 sort_arg_names=["values", "keys"]) 1306 1307 from pyopencl.elementwise import ElementwiseTemplate 1308 start_finder = ElementwiseTemplate( 1309 arguments="""//CL// 1310 starts_t *key_group_starts, 1311 key_t *keys_sorted_by_key, 1312 """, 1313 1314 operation=r"""//CL// 1315 key_t my_key = keys_sorted_by_key[i]; 1316 1317 if (i == 0 || my_key != keys_sorted_by_key[i-1]) 1318 key_group_starts[my_key] = i; 1319 """, 1320 name="find_starts").build(self.context, 1321 type_aliases=( 1322 ("key_t", starts_dtype), 1323 ("starts_t", starts_dtype), 1324 ), 1325 var_values=()) 1326 1327 from pyopencl.scan import GenericScanKernel 1328 bound_propagation_scan = GenericScanKernel( 1329 self.context, starts_dtype, 1330 arguments=[ 1331 VectorArg(starts_dtype, "starts"), 1332 # starts has length n+1 1333 ScalarArg(key_dtype, "nkeys"), 1334 ], 1335 input_expr="starts[nkeys-i]", 1336 scan_expr="min(a, b)", 1337 neutral=_make_cl_int_literal( 1338 np.iinfo(starts_dtype).max, starts_dtype), 1339 output_statement="starts[nkeys-i] = item;") 1340 1341 return _KernelInfo( 1342 by_target_sorter=by_target_sorter, 1343 start_finder=start_finder, 1344 bound_propagation_scan=bound_propagation_scan) 1345 1346 def __call__(self, queue, keys, values, nkeys, 1347 starts_dtype, allocator=None, wait_for=None): 1348 if allocator is None: 1349 allocator = values.allocator 1350 1351 knl_info = self.get_kernels(keys.dtype, values.dtype, 1352 starts_dtype) 1353 1354 (values_sorted_by_key, keys_sorted_by_key), evt = knl_info.by_target_sorter( 1355 values, keys, queue=queue, wait_for=wait_for) 1356 1357 starts = (cl.array.empty(queue, (nkeys+1), starts_dtype, allocator=allocator) 1358 .fill(len(values_sorted_by_key), wait_for=[evt])) 1359 evt, = starts.events 1360 1361 evt = knl_info.start_finder(starts, keys_sorted_by_key, 1362 range=slice(len(keys_sorted_by_key)), 1363 wait_for=[evt]) 1364 1365 evt = knl_info.bound_propagation_scan(starts, nkeys, 1366 queue=queue, wait_for=[evt]) 1367 1368 return starts, values_sorted_by_key, evt 1369 1370# }}} 1371 1372# vim: filetype=pyopencl:fdm=marker 1373