1"""Scan primitive."""
2
3from __future__ import division
4from __future__ import absolute_import
5from six.moves import range
6from six.moves import zip
7
8__copyright__ = """Copyright 2011-2012 Andreas Kloeckner \
9                   Copyright 2017 Hao Gao"""
10
11__license__ = """
12Permission is hereby granted, free of charge, to any person
13obtaining a copy of this software and associated documentation
14files (the "Software"), to deal in the Software without
15restriction, including without limitation the rights to use,
16copy, modify, merge, publish, distribute, sublicense, and/or sell
17copies of the Software, and to permit persons to whom the
18Software is furnished to do so, subject to the following
19conditions:
20
21The above copyright notice and this permission notice shall be
22included in all copies or substantial portions of the Software.
23
24THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
25EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
26OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
27NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
28HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
29WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
30FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
31OTHER DEALINGS IN THE SOFTWARE.
32"""
33
34import numpy as np
35import pyopencl as cl
36import pyopencl.array  # noqa
37from pyopencl.scan import ScanTemplate
38from pyopencl.tools import dtype_to_ctype
39from pytools import memoize, memoize_method, Record
40from mako.template import Template
41
42
43# {{{ copy_if
44
# Scan template behind copy_if(). For each index i the scan input is 1 when
# the user predicate (spliced in for %(predicate)s by the "printf" template
# processor) holds and 0 otherwise; the a+b scan with neutral 0 turns this
# into a running match count. `item` differs from `prev_item` exactly when
# ary[i] matched, and that element is then written to out[item-1], so the
# matches are packed at the front of `out` in their original order. The
# work item with i+1 == N stores the final total in *count. `item_t` and
# `scan_t` are type aliases supplied by copy_if() when building the kernel.
_copy_if_template = ScanTemplate(
        arguments="item_t *ary, item_t *out, scan_t *count",
        input_expr="(%(predicate)s) ? 1 : 0",
        scan_expr="a+b", neutral="0",
        output_statement="""
            if (prev_item != item) out[item-1] = ary[i];
            if (i+1 == N) *count = item;
            """,
        template_processor="printf")
54
55
def extract_extra_args_types_values(extra_args):
    """Split *extra_args* into kernel argument descriptors and kernel values.

    :arg extra_args: an iterable of ``(name, value)`` pairs. *value* must be
        either a :class:`pyopencl.array.Array` (described as a vector
        argument without an offset) or a numpy scalar
        (:class:`numpy.generic`, described as a scalar argument).
    :returns: a tuple ``(types, values)``. *types* is a tuple of
        :class:`pyopencl.tools.VectorArg` / :class:`pyopencl.tools.ScalarArg`
        descriptors and *values* is a list of the corresponding values, both
        in the order of *extra_args*.
    :raises RuntimeError: if a value is of neither supported kind.
    """
    from pyopencl.tools import VectorArg, ScalarArg

    extra_args_types = []
    extra_args_values = []
    for name, val in extra_args:
        if isinstance(val, cl.array.Array):
            extra_args_types.append(VectorArg(val.dtype, name, with_offset=False))
            extra_args_values.append(val)
        elif isinstance(val, np.generic):
            extra_args_types.append(ScalarArg(val.dtype, name))
            extra_args_values.append(val)
        else:
            # *name* is a string: formatting it with '%d' (as before) raised a
            # TypeError that hid this intended error message.
            raise RuntimeError("argument '%s' not understood" % (name,))

    return tuple(extra_args_types), extra_args_values
72
73
def copy_if(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=None):
    """Copy the elements of *ary* satisfying *predicate* to an output array.

    Selected elements keep their relative order and are packed at the start
    of the output array.

    :arg predicate: a C expression evaluating to a `bool`, represented as a string.
        The value to test is available as `ary[i]`, and if the expression evaluates
        to `true`, then this value ends up in the output.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg queue: the :class:`pyopencl.CommandQueue` passed on to the scan kernel.
    :arg wait_for: |explain-waitfor|
    :returns: a tuple *(out, count, event)* where *out* is the output array, *count*
        is an on-device scalar (fetch to host with `count.get()`) indicating
        how many elements satisfied *predicate*, and *event* is a
        :class:`pyopencl.Event` for dependency management. *out* is allocated
        to the same length as *ary*, but only the first *count* entries carry
        meaning.

    .. versionadded:: 2013.1
    """
    # 32-bit scan counters suffice unless the element count exceeds them.
    scan_dtype = (np.int64 if len(ary) > np.iinfo(np.int32).max
            else np.int32)

    arg_types, arg_values = extract_extra_args_types_values(extra_args)

    kernel = _copy_if_template.build(
            ary.context,
            type_aliases=(("scan_t", scan_dtype), ("item_t", ary.dtype)),
            var_values=(("predicate", predicate),),
            more_preamble=preamble,
            more_arguments=arg_types)

    result = cl.array.empty_like(ary)
    # Zero-dimensional device array receiving the number of selected items.
    n_selected = ary._new_with_changes(
            data=None, offset=0, shape=(), strides=(), dtype=scan_dtype)

    event = kernel(ary, result, n_selected, *arg_values,
            queue=queue, wait_for=wait_for)

    return result, n_selected, event
112
113# }}}
114
115
116# {{{ remove_if
117
def remove_if(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=None):
    """Copy the elements of *ary* that do not satisfy *predicate* to an
    output array.

    Implemented as :func:`copy_if` with the logically negated predicate.

    :arg predicate: a C expression evaluating to a `bool`, represented as a string.
        The value to test is available as `ary[i]`, and if the expression evaluates
        to `false`, then this value ends up in the output.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg queue: the :class:`pyopencl.CommandQueue` passed on to :func:`copy_if`.
    :arg wait_for: |explain-waitfor|
    :returns: a tuple *(out, count, event)* where *out* is the output array, *count*
        is an on-device scalar (fetch to host with `count.get()`) indicating
        how many elements did not satisfy *predicate*, and *event* is a
        :class:`pyopencl.Event` for dependency management.

    .. versionadded:: 2013.1
    """
    negated_predicate = "!(%s)" % predicate
    return copy_if(
            ary, negated_predicate,
            extra_args=extra_args,
            preamble=preamble,
            queue=queue,
            wait_for=wait_for)
136
137# }}}
138
139
140# {{{ partition
141
# Scan template behind partition(). The a+b scan counts, for each index i,
# how many elements of ary[0..i] satisfy the user predicate (spliced in for
# %(predicate)s by the "printf" template processor). A matching element
# (prev_item != item) is written to out_true[item-1]. A non-matching one is
# written to out_false[i-item], because i-item is the number of
# non-matching elements before index i. Both outputs therefore keep the
# original relative order. The work item with i+1 == N stores the total
# match count in *count_true.
_partition_template = ScanTemplate(
        arguments=(
            "item_t *ary, item_t *out_true, item_t *out_false, "
            "scan_t *count_true"),
        input_expr="(%(predicate)s) ? 1 : 0",
        scan_expr="a+b", neutral="0",
        output_statement="""//CL//
                if (prev_item != item)
                    out_true[item-1] = ary[i];
                else
                    out_false[i-item] = ary[i];
                if (i+1 == N) *count_true = item;
                """,
        template_processor="printf")
156
157
def partition(ary, predicate, extra_args=[], preamble="", queue=None, wait_for=None):
    """Split *ary* into two arrays according to *predicate*.

    Elements for which *predicate* holds are copied, in their original
    order, to the front of the first output array. All other elements are
    copied, in their original order, to the front of the second.

    :arg predicate: a C expression evaluating to a `bool`, represented as a string.
        The value to test is available as `ary[i]`.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg queue: the :class:`pyopencl.CommandQueue` passed on to the scan kernel.
    :arg wait_for: |explain-waitfor|
    :returns: a tuple *(out_true, out_false, count, event)* where *count*
        is an on-device scalar (fetch to host with `count.get()`) indicating
        how many elements satisfied the predicate, and *event* is a
        :class:`pyopencl.Event` for dependency management.

    .. versionadded:: 2013.1
    """
    # Unsigned 32-bit scan counters unless the element count exceeds them.
    scan_dtype = (np.uint64 if len(ary) > np.iinfo(np.uint32).max
            else np.uint32)

    arg_types, arg_values = extract_extra_args_types_values(extra_args)

    kernel = _partition_template.build(
            ary.context,
            type_aliases=(("item_t", ary.dtype), ("scan_t", scan_dtype)),
            var_values=(("predicate", predicate),),
            more_preamble=preamble,
            more_arguments=arg_types)

    matching = cl.array.empty_like(ary)
    rest = cl.array.empty_like(ary)
    n_matching = ary._new_with_changes(
            data=None, offset=0, shape=(), strides=(), dtype=scan_dtype)

    event = kernel(ary, matching, rest, n_matching, *arg_values,
            queue=queue, wait_for=wait_for)

    return matching, rest, n_matching, event
197
198# }}}
199
200
201# {{{ unique
202
# Scan template behind unique(). Element i contributes 1 to the a+b scan
# when it is the first element, or when IS_EQUAL_EXPR says it differs from
# its predecessor (fetched as ary_im1 via input_fetch_exprs). Such elements
# are packed into out[item-1], and the last work item stores the total in
# *count_unique.
#
# The user's equality expression is wrapped in parentheses inside the
# macro. Without them, an expression containing low-precedence operators
# would be split apart by the surrounding "? 0 : 1" in input_expr. For
# example, "a.x == b.x ? a.y == b.y : 0" would become
# "a.x == b.x ? a.y == b.y : (0 ? 0 : 1)".
_unique_template = ScanTemplate(
        arguments="item_t *ary, item_t *out, scan_t *count_unique",
        input_fetch_exprs=[
            ("ary_im1", "ary", -1),
            ("ary_i", "ary", 0),
            ],
        input_expr="(i == 0) || (IS_EQUAL_EXPR(ary_im1, ary_i) ? 0 : 1)",
        scan_expr="a+b", neutral="0",
        output_statement="""
                if (prev_item != item) out[item-1] = ary[i];
                if (i+1 == N) *count_unique = item;
                """,
        preamble="#define IS_EQUAL_EXPR(a, b) (%(macro_is_equal_expr)s)\n",
        template_processor="printf")
217
218
def unique(ary, is_equal_expr="a == b", extra_args=[], preamble="",
        queue=None, wait_for=None):
    """Copy the elements of *ary* into the output if *is_equal_expr*, applied to the
    array element and its predecessor, yields false.

    Works like the UNIX command :program:`uniq`, with a potentially custom
    comparison.  This operation is often used on sorted sequences.
    The first element of *ary* is always copied.

    :arg is_equal_expr: a C expression evaluating to a `bool`,
        represented as a string.  The elements being compared are
        available as `a` and `b`. If this expression yields `false`, the
        two are considered distinct.
    :arg extra_args: |scan_extra_args|
    :arg preamble: |preamble|
    :arg queue: the :class:`pyopencl.CommandQueue` passed on to the scan kernel.
    :arg wait_for: |explain-waitfor|
    :returns: a tuple *(out, count, event)* where *out* is the output array, *count*
        is an on-device scalar (fetch to host with `count.get()`) indicating
        how many elements were copied to *out*, and *event* is a
        :class:`pyopencl.Event` for dependency management. *out* is allocated
        to the same length as *ary*, but only the first *count* entries carry
        meaning.

    .. versionadded:: 2013.1
    """

    if len(ary) > np.iinfo(np.uint32).max:
        scan_dtype = np.uint64
    else:
        scan_dtype = np.uint32

    extra_args_types, extra_args_values = extract_extra_args_types_values(extra_args)

    knl = _unique_template.build(
            ary.context,
            type_aliases=(("item_t", ary.dtype), ("scan_t", scan_dtype)),
            var_values=(("macro_is_equal_expr", is_equal_expr),),
            more_preamble=preamble, more_arguments=extra_args_types)

    out = cl.array.empty_like(ary)
    # Zero-dimensional device array receiving the number of copied elements.
    count = ary._new_with_changes(data=None, offset=0,
            shape=(), strides=(), dtype=scan_dtype)

    evt = knl(ary, out, count, *extra_args_values,
            queue=queue, wait_for=wait_for)

    return out, count, evt
264
265# }}}
266
267
268# {{{ radix_sort
269
def to_bin(n):
    """Return the binary digits of the non-negative integer *n* as a string,
    most significant digit first and without a ``0b`` prefix
    (e.g. ``to_bin(6) == "110"``, ``to_bin(0) == "0"``).

    :raises ValueError: if *n* is negative.
    """
    # Hand-rolled rather than bin(), which Python 2.5 lacked.
    if n < 0:
        # For negative n, "n >>= 1" never reaches 0, so the loop below
        # would never terminate.
        raise ValueError("to_bin requires a non-negative integer, got %r" % (n,))
    if n == 0:
        return "0"

    digs = []
    while n:
        digs.append(str(n % 2))
        n >>= 1

    return ''.join(digs[::-1])
278
279
def _padded_bin(i, l):
    """Return ``to_bin(i)`` left-padded with ``'0'`` to at least *l* characters.

    Used to name the per-bin counter fields (``c00``, ``c01``, ...) of the
    radix-sort scan type.
    """
    return to_bin(i).rjust(l, "0")
285
286
@memoize
def _make_sort_scan_type(device, bits, index_dtype):
    """Create (once per argument combination) the struct type scanned by
    :class:`RadixSort`.

    The struct has one *index_dtype* counter field ``c<bin>`` for each of
    the ``2**bits`` bins, where ``<bin>`` is the bin number written in binary
    and zero-padded to *bits* digits.

    :returns: a tuple ``(c_type_name, dtype, c_declaration)``. *dtype* is
        the result of :func:`pyopencl.tools.match_dtype_to_c_struct` for
        *device`, registered under *c_type_name*.
    """
    from pyopencl.tools import get_or_register_dtype, match_dtype_to_c_struct

    type_name = "pyopencl_sort_scan_%s_%dbits_t" % (
            index_dtype.type.__name__, bits)

    counter_fields = [
            ("c%s" % _padded_bin(bin_nr, bits), index_dtype)
            for bin_nr in range(2**bits)]

    matched_dtype, c_declaration = match_dtype_to_c_struct(
            device, type_name, np.dtype(counter_fields))

    registered_dtype = get_or_register_dtype(type_name, matched_dtype)
    return type_name, registered_dtype, c_declaration
303
304
305# {{{ types, helpers preamble
306
# Common C preamble for the RadixSort kernels. This is a Mako template,
# rendered with the codegen_args built in RadixSort.__init__. It declares:
#   * the scan_t, key_t and index_t typedefs;
#   * a dbg_printf macro, which is a no-op unless DEBUG is defined;
#   * get_count(s, mnr), which returns the counter field c<mnr in binary>
#     of s through the nested conditional produced by get_count_branch;
#   * BIN_NR(key), which extracts the `bits` key bits starting at
#     `base_bit`. `base_bit` is not a macro parameter, so it must be in
#     scope wherever BIN_NR is used.
RADIX_SORT_PREAMBLE_TPL = Template(r"""//CL//
    typedef ${scan_ctype} scan_t;
    typedef ${key_ctype} key_t;
    typedef ${index_ctype} index_t;

    // #define DEBUG
    #ifdef DEBUG
        #define dbg_printf(ARGS) printf ARGS
    #else
        #define dbg_printf(ARGS) /* */
    #endif

    index_t get_count(scan_t s, int mnr)
    {
        return ${get_count_branch("")};
    }

    #define BIN_NR(key_arg) ((key_arg >> base_bit) & ${2**bits - 1})

""", strict_undefined=True)
327
328# }}}
329
330# {{{ scan helpers
331
# Scan-only C helpers for RadixSort. This Mako template is rendered in
# RadixSort.__init__ and appended to the common preamble of the scan kernel.
#   scan_t_neutral()    -- every one of the 2**bits bin counters is zero;
#                          this is the scan's neutral element.
#   scan_t_from_value() -- a one-hot counter vector: the counter of the bin
#                          that `key` falls into (key bits
#                          base_bit .. base_bit+bits-1) is 1, all others 0.
#                          `i` is only used for debug output.
#   scan_t_add()        -- field-wise sum. across_seg_boundary is accepted
#                          but not used.
# Scanning these values gives, for each element, running per-bin counts of
# the elements seen so far.
RADIX_SORT_SCAN_PREAMBLE_TPL = Template(r"""//CL//
    scan_t scan_t_neutral()
    {
        scan_t result;
        %for mnr in range(2**bits):
            result.c${padded_bin(mnr, bits)} = 0;
        %endfor
        return result;
    }

    // considers bits (base_bit+bits-1, ..., base_bit)
    scan_t scan_t_from_value(
        key_t key,
        int base_bit,
        int i
    )
    {
        // extract relevant bit range
        key_t bin_nr = BIN_NR(key);

        dbg_printf(("i: %d key:%d bin_nr:%d\n", i, key, bin_nr));

        scan_t result;
        %for mnr in range(2**bits):
            result.c${padded_bin(mnr, bits)} = (bin_nr == ${mnr});
        %endfor

        return result;
    }

    scan_t scan_t_add(scan_t a, scan_t b, bool across_seg_boundary)
    {
        %for mnr in range(2**bits):
            <% field = "c"+padded_bin(mnr, bits) %>
            b.${field} = a.${field} + b.${field};
        %endfor

        return b;
    }
""", strict_undefined=True)
372
# Output statement of the RadixSort scan: the scatter step of one
# counting-sort pass. For element i it works out the target position as
# follows:
#   1. Sum the counts of all bins below the element's own bin, read from
#      `last_item`. NOTE(review): this is presumably the scan result of the
#      final element, i.e. the per-bin totals -- confirm against the
#      GenericScanKernel docs.
#   2. Add the element's running count within its own bin,
#      get_count(item, ...). That count includes the element itself, hence
#      the "- 1".
# Every array named in sort_arg_names is then copied to that position of
# its sorted_<name> output. Elements in the same bin keep their relative
# order.
RADIX_SORT_OUTPUT_STMT_TPL = Template(r"""//CL//
    {
        key_t key = ${key_expr};
        key_t my_bin_nr = BIN_NR(key);

        index_t previous_bins_size = 0;
        %for mnr in range(2**bits):
            previous_bins_size +=
                (my_bin_nr > ${mnr})
                    ? last_item.c${padded_bin(mnr, bits)}
                    : 0;
        %endfor

        index_t tgt_idx =
            previous_bins_size
            + get_count(item, my_bin_nr) - 1;

        %for arg_name in sort_arg_names:
            sorted_${arg_name}[tgt_idx] = ${arg_name}[i];
        %endfor
    }
""", strict_undefined=True)
395
396# }}}
397
398
399# {{{ driver
400
401# import hoisted here to be used as a default argument in the constructor
402from pyopencl.scan import GenericScanKernel
403
404
class RadixSort(object):
    """Provides a general `radix sort <https://en.wikipedia.org/wiki/Radix_sort>`_
    on the compute device.

    Each pass runs one scan kernel that performs a stable counting sort on
    *bits_at_a_time* bits of the key, starting from the least significant
    bits.

    .. seealso:: :class:`pyopencl.algorithm.BitonicSort`

    .. versionadded:: 2013.1
    """
    def __init__(self, context, arguments, key_expr, sort_arg_names,
            bits_at_a_time=2, index_dtype=np.int32, key_dtype=np.uint32,
            scan_kernel=GenericScanKernel, options=[]):
        """
        :arg arguments: A string of comma-separated C argument declarations.
            All types used here must be known to PyOpenCL.
            (see :func:`pyopencl.tools.get_or_register_dtype`).
        :arg key_expr: An integer-valued C expression returning the
            key based on which the sort is performed. The array index
            for which the key is to be computed is available as `i`.
            The expression may refer to any of the *arguments*.
        :arg sort_arg_names: A list of argument names whose corresponding
            array arguments will be sorted according to *key_expr*.
        :arg bits_at_a_time: the number of key bits handled per scan pass.
            Each scanned value carries ``2**bits_at_a_time`` counters, so
            larger values mean fewer passes but a larger scan type.
        :arg index_dtype: the dtype of the per-bin counters and of the
            computed target indices.
        :arg key_dtype: the dtype of the value computed by *key_expr*. Its
            bit width is the default number of key bits sorted.
        :arg scan_kernel: the scan kernel class to instantiate, by default
            :class:`pyopencl.scan.GenericScanKernel`.
        :arg options: compiler options passed on to *scan_kernel*.
        :raises ValueError: if *bits_at_a_time* is less than 1.
        """

        # {{{ arg processing

        from pyopencl.tools import parse_arg_list
        self.arguments = parse_arg_list(arguments)
        del arguments

        self.sort_arg_names = sort_arg_names
        self.bits = int(bits_at_a_time)
        if self.bits < 1:
            # zero bits per pass would never advance base_bit in __call__
            raise ValueError(
                    "bits_at_a_time must be at least 1, got %d" % self.bits)
        self.index_dtype = np.dtype(index_dtype)
        self.key_dtype = np.dtype(key_dtype)

        self.options = options

        # }}}

        # {{{ kernel creation

        scan_ctype, scan_dtype, scan_t_cdecl = \
                _make_sort_scan_type(context.devices[0], self.bits, self.index_dtype)

        from pyopencl.tools import VectorArg, ScalarArg
        # The "sorted_*" output arrays follow the user arguments in the
        # order in which the sorted arrays appear in *arguments*; __call__
        # relies on this order.
        scan_arguments = (
                list(self.arguments)
                + [VectorArg(arg.dtype, "sorted_"+arg.name) for arg in self.arguments
                    if arg.name in sort_arg_names]
                + [ScalarArg(np.int32, "base_bit")])

        def get_count_branch(known_bits):
            # Generate a binary decision tree of C conditionals that selects
            # field s.c<mnr> for a bin number mnr known only at run time.
            if len(known_bits) == self.bits:
                return "s.c%s" % known_bits

            boundary_mnr = known_bits + "1" + (self.bits-len(known_bits)-1)*"0"

            return ("((mnr < %s) ? %s : %s)" % (
                int(boundary_mnr, 2),
                get_count_branch(known_bits+"0"),
                get_count_branch(known_bits+"1")))

        codegen_args = dict(
                bits=self.bits,
                key_ctype=dtype_to_ctype(self.key_dtype),
                key_expr=key_expr,
                index_ctype=dtype_to_ctype(self.index_dtype),
                index_type_max=np.iinfo(self.index_dtype).max,
                padded_bin=_padded_bin,
                scan_ctype=scan_ctype,
                sort_arg_names=sort_arg_names,
                get_count_branch=get_count_branch,
                )

        preamble = scan_t_cdecl+RADIX_SORT_PREAMBLE_TPL.render(**codegen_args)
        scan_preamble = preamble \
                + RADIX_SORT_SCAN_PREAMBLE_TPL.render(**codegen_args)

        self.scan_kernel = scan_kernel(
                context, scan_dtype,
                arguments=scan_arguments,
                input_expr="scan_t_from_value(%s, base_bit, i)" % key_expr,
                scan_expr="scan_t_add(a, b, across_seg_boundary)",
                neutral="scan_t_neutral()",
                output_statement=RADIX_SORT_OUTPUT_STMT_TPL.render(**codegen_args),
                preamble=scan_preamble, options=self.options)

        # Index of the first array argument. __call__ takes the element
        # count and the default allocator and queue from it. (Without the
        # break, this used to end up as the *last* array argument.)
        for i, arg in enumerate(self.arguments):
            if isinstance(arg, VectorArg):
                self.first_array_arg_idx = i
                break

        # }}}

    def __call__(self, *args, **kwargs):
        """Run the radix sort. In addition to *args* which must match the
        *arguments* specification on the constructor, the following
        keyword arguments are supported:

        :arg key_bits: specify how many bits (starting from least-significant)
            there are in the key. Defaults to the bit width of *key_dtype*.
            Bits are processed *bits_at_a_time* per pass, so the number of
            bits actually sorted is rounded up to a multiple of that.
        :arg allocator: See the *allocator* argument of :func:`pyopencl.array.empty`.
            Defaults to the allocator of the first argument array.
        :arg queue: A :class:`pyopencl.CommandQueue`, defaulting to the
            one from the first argument array.
        :arg wait_for: |explain-waitfor|
        :returns: A tuple ``(sorted, event)``. *sorted* is a list of sorted
            copies of the arrays named in *sort_arg_names*, in the order of
            that list. *event* is a :class:`pyopencl.Event` for dependency
            management.
        :raises ValueError: if *key_bits* is not positive.
        """

        wait_for = kwargs.pop("wait_for", None)

        # {{{ run control

        key_bits = kwargs.pop("key_bits", None)
        if key_bits is None:
            key_bits = int(np.iinfo(self.key_dtype).bits)
        if key_bits <= 0:
            # no pass would run, leaving no event to return
            raise ValueError("key_bits must be positive, got %r" % (key_bits,))

        first_array = args[self.first_array_arg_idx]
        n = len(first_array)

        allocator = kwargs.pop("allocator", None)
        if allocator is None:
            allocator = first_array.allocator

        queue = kwargs.pop("queue", None)
        if queue is None:
            queue = first_array.queue

        args = list(args)

        # Positions (within self.arguments) of the arrays being sorted, in
        # declaration order. The scan kernel takes its "sorted_*" outputs in
        # this same order, so output buffers are created and substituted
        # back using these positions. The old code looked them up by their
        # index in sort_arg_names instead, which swapped arrays whenever
        # that list's order differed from *arguments*.
        sort_arg_indices = [
                i for i, arg_descr in enumerate(self.arguments)
                if arg_descr.name in self.sort_arg_names]

        base_bit = 0
        while base_bit < key_bits:
            sorted_args = [
                    cl.array.empty(queue, n, self.arguments[i].dtype,
                        allocator=allocator)
                    for i in sort_arg_indices]

            scan_args = args + sorted_args + [base_bit]

            last_evt = self.scan_kernel(*scan_args,
                    queue=queue, wait_for=wait_for)
            wait_for = [last_evt]

            # the output of this pass is the input of the next one
            for i, sorted_ary in zip(sort_arg_indices, sorted_args):
                args[i] = sorted_ary

            base_bit += self.bits

        sorted_by_name = dict(
                (self.arguments[i].name, args[i]) for i in sort_arg_indices)

        return [sorted_by_name[name]
                for name in self.sort_arg_names
                if name in sorted_by_name], last_evt

        # }}}
559
560# }}}
561
562# }}}
563
564
565# {{{ generic parallel list builder
566
567# {{{ kernel template
568
# Mako template for the ListOfListsBuilder kernels. It is rendered once per
# stage.
#
# Count stage (is_count_stage=True):
#   * APPEND_<name> only increments a per-object counter,
#     plb_loc_<name>_count.
#   * That counter is stored to plb_<name>_count[i] when the pointer is
#     non-NULL.
#   * Lists listed in count_sharing expand APPEND to nothing.
#
# Write stage (is_count_stage=False):
#   * APPEND_<name> stores the value at the running index plb_<name>_index
#     and then advances it.
#   * The index starts at plb_<name>_start_index[i]. For lists in
#     eliminate_empty_output_lists it starts at the entry selected by
#     <name>_compressed_indices[i] instead, and it starts at 0 when the
#     start-index pointer is NULL.
#   * Count-sharing lists write at their mother list's index - 1, i.e. at
#     the position of the entry most recently appended to the mother.
#
# The user's `generate` function is called once per object index i < n.
# With do_not_vectorize, each work item handles chunks of 128 consecutive
# objects in work groups of size 1. Otherwise each work item starts at
# get_local_size(0)*get_group_id(0) + get_local_id(0) and strides by the
# global size.
# The trailing backslashes in the APPEND macros are Python line
# continuations (this is not a raw string), so each macro ends up on one C
# line.
_LIST_BUILDER_TEMPLATE = Template("""//CL//
% if double_support:
    #if __OPENCL_C_VERSION__ < 120
    #pragma OPENCL EXTENSION cl_khr_fp64: enable
    #endif
    #define PYOPENCL_DEFINE_CDOUBLE
% endif

#include <pyopencl-complex.h>

${preamble}

// {{{ declare helper macros for user interface

typedef ${index_type} index_type;

%if is_count_stage:
    #define PLB_COUNT_STAGE

    %for name, dtype in list_names_and_dtypes:
        %if name in count_sharing:
            #define APPEND_${name}(value) { /* nothing */ }
        %else:
            #define APPEND_${name}(value) { ++(*plb_loc_${name}_count); }
        %endif
    %endfor
%else:
    #define PLB_WRITE_STAGE

    %for name, dtype in list_names_and_dtypes:
        %if name in count_sharing:
            #define APPEND_${name}(value) \
                { plb_${name}_list[(*plb_${count_sharing[name]}_index) - 1] \
                    = value; }
        %else:
            #define APPEND_${name}(value) \
                { plb_${name}_list[(*plb_${name}_index)++] = value; }
        %endif
    %endfor
%endif

#define LIST_ARG_DECL ${user_list_arg_decl}
#define LIST_ARGS ${user_list_args}
#define USER_ARG_DECL ${user_arg_decl}
#define USER_ARGS ${user_args}

// }}}

${generate_template}

// {{{ kernel entry point

__kernel
%if do_not_vectorize:
__attribute__((reqd_work_group_size(1, 1, 1)))
%endif
void ${kernel_name}(${kernel_list_arg_decl} USER_ARG_DECL index_type n)

{
    %if not do_not_vectorize:
        int lid = get_local_id(0);
        index_type gsize = get_global_size(0);
        index_type work_group_start = get_local_size(0)*get_group_id(0);
        for (index_type i = work_group_start + lid; i < n; i += gsize)
    %else:
        const int chunk_size = 128;
        index_type chunk_base = get_global_id(0)*chunk_size;
        index_type gsize = get_global_size(0);
        for (; chunk_base < n; chunk_base += gsize*chunk_size)
        for (index_type i = chunk_base; i < min(n, chunk_base+chunk_size); ++i)
    %endif
    {
        %if is_count_stage:
            %for name, dtype in list_names_and_dtypes:
                %if name not in count_sharing:
                    index_type plb_loc_${name}_count = 0;
                %endif
            %endfor
        %else:
            %for name, dtype in list_names_and_dtypes:
                %if name not in count_sharing:
                    index_type plb_${name}_index;
                    if (plb_${name}_start_index)
                        %if name in eliminate_empty_output_lists:
                            plb_${name}_index =
                                plb_${name}_start_index[
                                    ${name}_compressed_indices[i]
                                ];
                        %else:
                            plb_${name}_index = plb_${name}_start_index[i];
                        %endif
                    else
                        plb_${name}_index = 0;
                %endif
            %endfor
        %endif

        generate(${kernel_list_arg_values} USER_ARGS i);

        %if is_count_stage:
            %for name, dtype in list_names_and_dtypes:
                %if name not in count_sharing:
                    if (plb_${name}_count)
                        plb_${name}_count[i] = plb_loc_${name}_count;
                %endif
            %endfor
        %endif
    }
}

// }}}

""", strict_undefined=True)
682
683# }}}
684
685
def _get_arg_decl(arg_list):
    """Return the C declarators of *arg_list*, each followed by ``", "``.

    The trailing separator lets the result be placed directly in front of
    further parameters. An empty *arg_list* yields ``""``.
    """
    return "".join(arg.declarator() + ", " for arg in arg_list)
692
693
def _get_arg_list(arg_list, prefix=""):
    """Return the names of *arg_list*, each prefixed with *prefix* and
    followed by ``", "``, for use as a C call argument list.

    An empty *arg_list* yields ``""``.
    """
    return "".join(prefix + arg.name + ", " for arg in arg_list)
700
701
class BuiltList(Record):
    """Plain record type; no fields or behavior are added beyond
    :class:`pytools.Record`.

    NOTE(review): presumably this describes one list produced by
    :class:`ListOfListsBuilder` (its usage example reads ``count`` and
    ``list`` from each result) -- confirm in ``ListOfListsBuilder.__call__``.
    """
704
705
706class ListOfListsBuilder:
707    """Generates and executes code to produce a large number of variable-size
708    lists, simply.
709
710    .. note:: This functionality is provided as a preview. Its interface
711        is subject to change until this notice is removed.
712
713    .. versionadded:: 2013.1
714
715    Here's a usage example::
716
717        from pyopencl.algorithm import ListOfListsBuilder
718        builder = ListOfListsBuilder(context, [("mylist", np.int32)], \"\"\"
719                void generate(LIST_ARG_DECL USER_ARG_DECL index_type i)
720                {
721                    int count = i % 4;
722                    for (int j = 0; j < count; ++j)
723                    {
724                        APPEND_mylist(count);
725                    }
726                }
727                \"\"\", arg_decls=[])
728
729        result, event = builder(queue, 2000)
730
731        inf = result["mylist"]
732        assert inf.count == 3000
733        assert (inf.list.get()[-6:] == [1, 2, 2, 3, 3, 3]).all()
734
735    The function `generate` above is called once for each "input object".
736    Each input object can then generate zero or more list entries.
737    The number of these input objects is given to :meth:`__call__` as *n_objects*.
738    List entries are generated by calls to `APPEND_<list name>(value)`.
739    Multiple lists may be generated at once.
740
741    .. automethod:: __init__
742    .. automethod:: __call__
743    """
744    def __init__(self, context, list_names_and_dtypes, generate_template,
745            arg_decls, count_sharing=None, devices=None,
746            name_prefix="plb_build_list", options=[], preamble="",
747            debug=False, complex_kernel=False,
748            eliminate_empty_output_lists=[]):
749        """
750        :arg context: A :class:`pyopencl.Context`.
751        :arg list_names_and_dtypes: a list of `(name, dtype)` tuples
752            indicating the lists to be built.
753        :arg generate_template: a snippet of C as described below
754        :arg arg_decls: A string of comma-separated C argument declarations.
755        :arg count_sharing: A mapping consisting of `(child, mother)`
756            indicating that `mother` and `child` will always have the
757            same number of indices, and the `APPEND` to `mother`
758            will always happen *before* the `APPEND` to the child.
759        :arg name_prefix: the name prefix to use for the compiled kernels
760        :arg options: OpenCL compilation options for kernels using
761            *generate_template*.
762        :arg complex_kernel: If `True`, prevents vectorization on CPUs.
763        :arg eliminate_empty_output_lists: A Python list of list names
764            for which the empty output lists are eliminated.
765
766        *generate_template* may use the following C macros/identifiers:
767
768        * `index_type`: expands to C identifier for the index type used
769          for the calculation
770        * `USER_ARG_DECL`: expands to the C declarator for `arg_decls`
771        * `USER_ARGS`: a list of C argument values corresponding to
772          `user_arg_decl`
773        * `LIST_ARG_DECL`: expands to a C argument list representing the
774          data for the output lists. These are escaped prefixed with
775          `"plg_"` so as to not interfere with user-provided names.
776        * `LIST_ARGS`: a list of C argument values corresponding to
777          `LIST_ARG_DECL`
778        * `APPEND_name(entry)`: inserts `entry` into the list `name`.
779          *entry* must be a valid C expression of the correct type.
780
781        All argument-list related macros have a trailing comma included
782        if they are non-empty.
783
784        *generate_template* must supply a function:
785
786        .. code-block:: c
787
788            void generate(USER_ARG_DECL LIST_ARG_DECL index_type i)
789            {
790                APPEND_mylist(5);
791            }
792
793        Internally, the `kernel_template` is expanded (at least) twice. Once,
794        for a 'counting' stage where the size of all the lists is determined,
795        and a second time, for a 'generation' stage where the lists are
796        actually filled. A `generate` function that has side effects beyond
797        calling `append` is therefore ill-formed.
798
799        .. versionchanged:: 2018.1
800
801            Change *eliminate_empty_output_lists* argument type from `bool` to
802            `list`.
803        """
804
805        if devices is None:
806            devices = context.devices
807
808        if count_sharing is None:
809            count_sharing = {}
810
811        self.context = context
812        self.devices = devices
813
814        self.list_names_and_dtypes = list_names_and_dtypes
815        self.generate_template = generate_template
816
817        from pyopencl.tools import parse_arg_list
818        self.arg_decls = parse_arg_list(arg_decls)
819
820        self.count_sharing = count_sharing
821
822        self.name_prefix = name_prefix
823        self.preamble = preamble
824        self.options = options
825
826        self.debug = debug
827
828        self.complex_kernel = complex_kernel
829
830        if eliminate_empty_output_lists is True:
831            eliminate_empty_output_lists = \
832                    [name for name, _ in self.list_names_and_dtypes]
833
834        if eliminate_empty_output_lists is False:
835            eliminate_empty_output_lists = []
836
837        self.eliminate_empty_output_lists = eliminate_empty_output_lists
838        for list_name in self.eliminate_empty_output_lists:
839            if not any(list_name == name for name, _ in self.list_names_and_dtypes):
840                raise ValueError(
841                    "invalid list name '%s' in eliminate_empty_output_lists"
842                    % list_name)
843
844    # {{{ kernel generators
845
846    @memoize_method
847    def get_scan_kernel(self, index_dtype):
848        from pyopencl.scan import GenericScanKernel
849        return GenericScanKernel(
850                self.context, index_dtype,
851                arguments="__global %s *ary" % dtype_to_ctype(index_dtype),
852                input_expr="ary[i]",
853                scan_expr="a+b", neutral="0",
854                output_statement="ary[i+1] = item;",
855                devices=self.devices)
856
857    @memoize_method
858    def get_compress_kernel(self, index_dtype):
859        arguments = """
860            __global ${index_t} *count,
861            __global ${index_t} *compressed_counts,
862            __global ${index_t} *nonempty_indices,
863            __global ${index_t} *compressed_indices,
864            __global ${index_t} *num_non_empty_list
865        """
866        from sys import version_info
867        if version_info > (3, 0):
868            arguments = Template(arguments)
869        else:
870            arguments = Template(arguments, disable_unicode=True)
871
872        from pyopencl.scan import GenericScanKernel
873        return GenericScanKernel(
874                self.context, index_dtype,
875                arguments=arguments.render(index_t=dtype_to_ctype(index_dtype)),
876                input_expr="count[i] == 0 ? 0 : 1",
877                scan_expr="a+b", neutral="0",
878                output_statement="""
879                    if (i + 1 < N) compressed_indices[i + 1] = item;
880                    if (prev_item != item) {
881                        nonempty_indices[item - 1] = i;
882                        compressed_counts[item - 1] = count[i];
883                    }
884                    if (i + 1 == N) *num_non_empty_list = item;
885                    """,
886                devices=self.devices)
887
888    def do_not_vectorize(self):
889        from pytools import any
890        return (self.complex_kernel
891                and any(dev.type & cl.device_type.CPU
892                    for dev in self.context.devices))
893
894    @memoize_method
895    def get_count_kernel(self, index_dtype):
896        index_ctype = dtype_to_ctype(index_dtype)
897        from pyopencl.tools import VectorArg, OtherArg
898        kernel_list_args = [
899                VectorArg(index_dtype, "plb_%s_count" % name)
900                for name, dtype in self.list_names_and_dtypes
901                if name not in self.count_sharing]
902
903        user_list_args = []
904        for name, dtype in self.list_names_and_dtypes:
905            if name in self.count_sharing:
906                continue
907
908            name = "plb_loc_%s_count" % name
909            user_list_args.append(OtherArg("%s *%s" % (
910                index_ctype, name), name))
911
912        kernel_name = self.name_prefix+"_count"
913
914        from pyopencl.characterize import has_double_support
915        src = _LIST_BUILDER_TEMPLATE.render(
916                is_count_stage=True,
917                kernel_name=kernel_name,
918                double_support=all(has_double_support(dev) for dev in
919                    self.context.devices),
920                debug=self.debug,
921                do_not_vectorize=self.do_not_vectorize(),
922                eliminate_empty_output_lists=self.eliminate_empty_output_lists,
923
924                kernel_list_arg_decl=_get_arg_decl(kernel_list_args),
925                kernel_list_arg_values=_get_arg_list(user_list_args, prefix="&"),
926                user_list_arg_decl=_get_arg_decl(user_list_args),
927                user_list_args=_get_arg_list(user_list_args),
928                user_arg_decl=_get_arg_decl(self.arg_decls),
929                user_args=_get_arg_list(self.arg_decls),
930
931                list_names_and_dtypes=self.list_names_and_dtypes,
932                count_sharing=self.count_sharing,
933                name_prefix=self.name_prefix,
934                generate_template=self.generate_template,
935                preamble=self.preamble,
936
937                index_type=index_ctype,
938                )
939
940        src = str(src)
941
942        prg = cl.Program(self.context, src).build(self.options)
943        knl = getattr(prg, kernel_name)
944
945        from pyopencl.tools import get_arg_list_scalar_arg_dtypes
946        knl.set_scalar_arg_dtypes(get_arg_list_scalar_arg_dtypes(
947            kernel_list_args+self.arg_decls) + [index_dtype])
948
949        return knl
950
951    @memoize_method
952    def get_write_kernel(self, index_dtype):
953        index_ctype = dtype_to_ctype(index_dtype)
954        from pyopencl.tools import VectorArg, OtherArg
955        kernel_list_args = []
956        kernel_list_arg_values = ""
957        user_list_args = []
958
959        for name, dtype in self.list_names_and_dtypes:
960            list_name = "plb_%s_list" % name
961            list_arg = VectorArg(dtype, list_name)
962
963            kernel_list_args.append(list_arg)
964            user_list_args.append(list_arg)
965
966            if name in self.count_sharing:
967                kernel_list_arg_values += "%s, " % list_name
968                continue
969
970            kernel_list_args.append(
971                    VectorArg(index_dtype, "plb_%s_start_index" % name))
972
973            if name in self.eliminate_empty_output_lists:
974                kernel_list_args.append(
975                    VectorArg(index_dtype, "%s_compressed_indices" % name))
976
977            index_name = "plb_%s_index" % name
978            user_list_args.append(OtherArg("%s *%s" % (
979                index_ctype, index_name), index_name))
980
981            kernel_list_arg_values += "%s, &%s, " % (list_name, index_name)
982
983        kernel_name = self.name_prefix+"_write"
984
985        from pyopencl.characterize import has_double_support
986        src = _LIST_BUILDER_TEMPLATE.render(
987                is_count_stage=False,
988                kernel_name=kernel_name,
989                double_support=all(has_double_support(dev) for dev in
990                    self.context.devices),
991                debug=self.debug,
992                do_not_vectorize=self.do_not_vectorize(),
993                eliminate_empty_output_lists=self.eliminate_empty_output_lists,
994
995                kernel_list_arg_decl=_get_arg_decl(kernel_list_args),
996                kernel_list_arg_values=kernel_list_arg_values,
997                user_list_arg_decl=_get_arg_decl(user_list_args),
998                user_list_args=_get_arg_list(user_list_args),
999                user_arg_decl=_get_arg_decl(self.arg_decls),
1000                user_args=_get_arg_list(self.arg_decls),
1001
1002                list_names_and_dtypes=self.list_names_and_dtypes,
1003                count_sharing=self.count_sharing,
1004                name_prefix=self.name_prefix,
1005                generate_template=self.generate_template,
1006                preamble=self.preamble,
1007
1008                index_type=index_ctype,
1009                )
1010
1011        src = str(src)
1012
1013        prg = cl.Program(self.context, src).build(self.options)
1014        knl = getattr(prg, kernel_name)
1015
1016        from pyopencl.tools import get_arg_list_scalar_arg_dtypes
1017        knl.set_scalar_arg_dtypes(get_arg_list_scalar_arg_dtypes(
1018            kernel_list_args+self.arg_decls) + [index_dtype])
1019
1020        return knl
1021
1022    # }}}
1023
1024    # {{{ driver
1025
1026    def __call__(self, queue, n_objects, *args, **kwargs):
1027        """
1028        :arg args: arguments corresponding to arg_decls in the constructor.
1029            :class:`pyopencl.array.Array` are not allowed directly and should
1030            be passed as their :attr:`pyopencl.array.Array.data` attribute instead.
1031        :arg allocator: optionally, the allocator to use to allocate new
1032            arrays.
1033        :arg omit_lists: An iterable of list names that should *not* be built
1034            with this invocation. The kernel code may *not* call ``APPEND_name``
1035            for these omitted lists. If it does, undefined behavior will result.
1036            The returned *lists* dictionary will not contain an entry for names
1037            in *omit_lists*.
1038        :arg wait_for: |explain-waitfor|
1039        :returns: a tuple ``(lists, event)``, where
1040            *lists* a mapping from (built) list names to objects which
1041            have attributes
1042
1043            * ``count`` for the total number of entries in all lists combined
1044            * ``lists`` for the array containing all lists.
1045            * ``starts`` for the array of starting indices in `lists`.
1046              `starts` is built so that it has n+1 entries, so that
1047              the *i*'th entry is the start of the *i*'th list, and the
1048              *i*'th entry is the index one past the *i*'th list's end,
1049              even for the last list.
1050
1051              This implies that all lists are contiguous.
1052
1053            If the list name is specified in *eliminate_empty_output_lists*
1054            constructor argument, *lists* has two additional attributes
1055            ``num_nonempty_lists`` and ``nonempty_indices``
1056
1057            * ``num_nonempty_lists`` for the number of nonempty lists.
1058            * ``nonempty_indices`` for the index of nonempty list in input objects.
1059
1060            In this case, `starts` has `num_nonempty_lists` + 1 entries. The *i*'s
1061            entry is the start of the *i*'th nonempty list, which is generated by
1062            the object with index *nonempty_indices[i]*.
1063
1064            *event* is a :class:`pyopencl.Event` for dependency management.
1065
1066        .. versionchanged:: 2016.2
1067
1068            Added omit_lists.
1069        """
1070        if n_objects >= int(np.iinfo(np.int32).max):
1071            index_dtype = np.int64
1072        else:
1073            index_dtype = np.int32
1074        index_dtype = np.dtype(index_dtype)
1075
1076        allocator = kwargs.pop("allocator", None)
1077        omit_lists = kwargs.pop("omit_lists", [])
1078        wait_for = kwargs.pop("wait_for", None)
1079        if kwargs:
1080            raise TypeError("invalid keyword arguments: '%s'" % ", ".join(kwargs))
1081
1082        for oml in omit_lists:
1083            if not any(oml == name for name, _ in self.list_names_and_dtypes):
1084                raise ValueError("invalid list name '%s' in omit_lists")
1085
1086        result = {}
1087        count_list_args = []
1088
1089        if wait_for is None:
1090            wait_for = []
1091
1092        count_kernel = self.get_count_kernel(index_dtype)
1093        write_kernel = self.get_write_kernel(index_dtype)
1094        scan_kernel = self.get_scan_kernel(index_dtype)
1095        if self.eliminate_empty_output_lists:
1096            compress_kernel = self.get_compress_kernel(index_dtype)
1097
1098        # {{{ allocate memory for counts
1099
1100        for name, dtype in self.list_names_and_dtypes:
1101            if name in self.count_sharing:
1102                continue
1103            if name in omit_lists:
1104                count_list_args.append(None)
1105                continue
1106
1107            counts = cl.array.empty(queue,
1108                    (n_objects + 1), index_dtype, allocator=allocator)
1109            counts[-1] = 0
1110            wait_for = wait_for + counts.events
1111
1112            # The scan will turn the "counts" array into the "starts" array
1113            # in-place.
1114            if name in self.eliminate_empty_output_lists:
1115                result[name] = BuiltList(count=None, starts=counts, lists=None,
1116                                         num_nonempty_lists=None,
1117                                         nonempty_indices=None)
1118            else:
1119                result[name] = BuiltList(count=None, starts=counts, lists=None)
1120            count_list_args.append(counts.data)
1121
1122        # }}}
1123
1124        if self.debug:
1125            gsize = (1,)
1126            lsize = (1,)
1127        elif self.do_not_vectorize():
1128            gsize = (4*queue.device.max_compute_units,)
1129            lsize = (1,)
1130        else:
1131            from pyopencl.array import splay
1132            gsize, lsize = splay(queue, n_objects)
1133
1134        count_event = count_kernel(queue, gsize, lsize,
1135                *(tuple(count_list_args) + args + (n_objects,)),
1136                **dict(wait_for=wait_for))
1137
1138        compress_events = {}
1139        for name, dtype in self.list_names_and_dtypes:
1140            if name in omit_lists:
1141                continue
1142            if name in self.count_sharing:
1143                continue
1144            if name not in self.eliminate_empty_output_lists:
1145                continue
1146
1147            compressed_counts = cl.array.empty(
1148                queue, (n_objects + 1,), index_dtype, allocator=allocator)
1149            info_record = result[name]
1150            info_record.nonempty_indices = cl.array.empty(
1151                queue, (n_objects + 1,), index_dtype, allocator=allocator)
1152            info_record.num_nonempty_lists = cl.array.empty(
1153                queue, (1,), index_dtype, allocator=allocator)
1154            info_record.compressed_indices = cl.array.empty(
1155                queue, (n_objects + 1,), index_dtype, allocator=allocator)
1156            info_record.compressed_indices[0] = 0
1157            compress_events[name] = compress_kernel(
1158                info_record.starts,
1159                compressed_counts,
1160                info_record.nonempty_indices,
1161                info_record.compressed_indices,
1162                info_record.num_nonempty_lists,
1163                wait_for=[count_event] + info_record.compressed_indices.events)
1164
1165            info_record.starts = compressed_counts
1166
1167        # {{{ run scans
1168
1169        scan_events = []
1170
1171        for name, dtype in self.list_names_and_dtypes:
1172            if name in self.count_sharing:
1173                continue
1174            if name in omit_lists:
1175                continue
1176
1177            info_record = result[name]
1178            if name in self.eliminate_empty_output_lists:
1179                compress_events[name].wait()
1180                num_nonempty_lists = info_record.num_nonempty_lists.get()[0]
1181                info_record.num_nonempty_lists = num_nonempty_lists
1182                info_record.starts = info_record.starts[:num_nonempty_lists + 1]
1183                info_record.nonempty_indices = \
1184                    info_record.nonempty_indices[:num_nonempty_lists]
1185                info_record.starts[-1] = 0
1186
1187            starts_ary = info_record.starts
1188            if name in self.eliminate_empty_output_lists:
1189                evt = scan_kernel(
1190                        starts_ary,
1191                        size=info_record.num_nonempty_lists,
1192                        wait_for=starts_ary.events)
1193            else:
1194                evt = scan_kernel(starts_ary, wait_for=[count_event],
1195                        size=n_objects)
1196
1197            starts_ary.setitem(0, 0, queue=queue, wait_for=[evt])
1198            scan_events.extend(starts_ary.events)
1199
1200            # retrieve count
1201            info_record.count = int(starts_ary[-1].get())
1202
1203        # }}}
1204
1205        # {{{ deal with count-sharing lists, allocate memory for lists
1206
1207        write_list_args = []
1208        for name, dtype in self.list_names_and_dtypes:
1209            if name in omit_lists:
1210                write_list_args.append(None)
1211                if name not in self.count_sharing:
1212                    write_list_args.append(None)
1213                if name in self.eliminate_empty_output_lists:
1214                    write_list_args.append(None)
1215                continue
1216
1217            if name in self.count_sharing:
1218                sharing_from = self.count_sharing[name]
1219
1220                info_record = result[name] = BuiltList(
1221                        count=result[sharing_from].count,
1222                        starts=result[sharing_from].starts,
1223                        )
1224
1225            else:
1226                info_record = result[name]
1227
1228            info_record.lists = cl.array.empty(queue,
1229                    info_record.count, dtype, allocator=allocator)
1230            write_list_args.append(info_record.lists.data)
1231
1232            if name not in self.count_sharing:
1233                write_list_args.append(info_record.starts.data)
1234
1235            if name in self.eliminate_empty_output_lists:
1236                write_list_args.append(info_record.compressed_indices.data)
1237
1238        # }}}
1239
1240        evt = write_kernel(queue, gsize, lsize,
1241                *(tuple(write_list_args) + args + (n_objects,)),
1242                **dict(wait_for=scan_events))
1243
1244        return result, evt
1245
1246    # }}}
1247
1248# }}}
1249
1250
1251# {{{ key-value sorting
1252
class _KernelInfo(Record):
    """Keyword-constructed bundle of the kernels built by
    :meth:`KeyValueSorter.get_kernels` for one dtype combination."""
1255
1256
1257def _make_cl_int_literal(value, dtype):
1258    iinfo = np.iinfo(dtype)
1259    result = str(int(value))
1260    if dtype.itemsize == 8:
1261        result += "l"
1262    if int(iinfo.min) < 0:
1263        result += "u"
1264
1265    return result
1266
1267
class KeyValueSorter(object):
    """Given arrays *values* and *keys* of equal length
    and a number *nkeys* of keys, returns a tuple `(starts,
    lists, event)`, as follows: *values* and *keys* are sorted
    by *keys*, and the sorted *values* is returned as
    *lists*. Then for each index *i* in `range(nkeys)`,
    *starts[i]* is written to indicate where the
    group of *values* belonging to the key with index
    *i* begins. It implicitly ends at *starts[i+1]*.
    *event* is a :class:`pyopencl.Event` for dependency management.

    `starts` is built so that it has `nkeys+1` entries, so that
    the *i*'th entry is the start of the *i*'th list, and the
    (*i*+1)'th entry is the index one past the *i*'th list's end,
    even for the last list.

    This implies that all lists are contiguous.

    .. note:: This functionality is provided as a preview. Its
        interface is subject to change until this notice is removed.

    .. versionadded:: 2013.1
    """

    def __init__(self, context):
        self.context = context

    @memoize_method
    def get_kernels(self, key_dtype, value_dtype, starts_dtype):
        """Build (and cache) the kernels used by :meth:`__call__`.

        :arg key_dtype: dtype of the *keys* array.
        :arg value_dtype: dtype of the *values* array.
        :arg starts_dtype: dtype of the returned *starts* array.
        :returns: a record with attributes ``by_target_sorter``,
            ``start_finder`` and ``bound_propagation_scan``.
        """
        from pyopencl.algorithm import RadixSort
        from pyopencl.tools import VectorArg, ScalarArg

        # Sort values and keys together, ordered by key.
        by_target_sorter = RadixSort(
                self.context, [
                    VectorArg(value_dtype, "values"),
                    VectorArg(key_dtype, "keys"),
                    ],
                key_expr="keys[i]",
                sort_arg_names=["values", "keys"])

        # For every key that occurs, record the index of its first
        # occurrence in the sorted key array.
        from pyopencl.elementwise import ElementwiseTemplate
        start_finder = ElementwiseTemplate(
                arguments="""//CL//
                starts_t *key_group_starts,
                key_t *keys_sorted_by_key,
                """,

                operation=r"""//CL//
                key_t my_key = keys_sorted_by_key[i];

                if (i == 0 || my_key != keys_sorted_by_key[i-1])
                    key_group_starts[my_key] = i;
                """,
                name="find_starts").build(self.context,
                        type_aliases=(
                            # key_t types the sorted key buffer, so it must be
                            # the key dtype, not the starts dtype.
                            ("key_t", key_dtype),
                            ("starts_t", starts_dtype),
                            ),
                        var_values=())

        # Entries of keys that never occur keep the fill value set in
        # __call__. A min-scan running from the last entry back to the first
        # gives each of them the start of the next occurring key, making
        # their groups empty.
        from pyopencl.scan import GenericScanKernel
        bound_propagation_scan = GenericScanKernel(
                self.context, starts_dtype,
                arguments=[
                    VectorArg(starts_dtype, "starts"),
                    # starts has length n+1
                    ScalarArg(key_dtype, "nkeys"),
                    ],
                input_expr="starts[nkeys-i]",
                scan_expr="min(a, b)",
                neutral=_make_cl_int_literal(
                    np.iinfo(starts_dtype).max, starts_dtype),
                output_statement="starts[nkeys-i] = item;")

        return _KernelInfo(
                by_target_sorter=by_target_sorter,
                start_finder=start_finder,
                bound_propagation_scan=bound_propagation_scan)

    def __call__(self, queue, keys, values, nkeys,
            starts_dtype, allocator=None, wait_for=None):
        """Sort *values* by *keys* and compute per-key group start indices.

        :arg keys: array of keys, same length as *values*. Each key is used
            directly as an index into *starts* (``nkeys+1`` entries), so keys
            are presumably expected to lie in ``range(nkeys)`` -- confirm at
            call sites.
        :arg values: array of values to be grouped.
        :arg nkeys: the number of distinct key slots.
        :arg starts_dtype: dtype of the returned *starts* array.
        :arg allocator: allocator for *starts*; defaults to
            ``values.allocator``.
        :arg wait_for: events to wait for before sorting.
        :returns: ``(starts, values_sorted_by_key, event)``.
        """
        if allocator is None:
            allocator = values.allocator

        knl_info = self.get_kernels(keys.dtype, values.dtype,
                starts_dtype)

        (values_sorted_by_key, keys_sorted_by_key), evt = knl_info.by_target_sorter(
                values, keys, queue=queue, wait_for=wait_for)

        # Every entry starts out as "one past the end"; this is also the
        # final value of starts[nkeys].
        starts = (cl.array.empty(queue, (nkeys+1), starts_dtype, allocator=allocator)
                .fill(len(values_sorted_by_key), wait_for=[evt]))
        evt, = starts.events

        evt = knl_info.start_finder(starts, keys_sorted_by_key,
                range=slice(len(keys_sorted_by_key)),
                wait_for=[evt])

        evt = knl_info.bound_propagation_scan(starts, nkeys,
                queue=queue, wait_for=[evt])

        return starts, values_sorted_by_key, evt
1369
1370# }}}
1371
1372# vim: filetype=pyopencl:fdm=marker
1373