1"""CL device arrays."""
2
3from __future__ import division, absolute_import
4
5__copyright__ = "Copyright (C) 2009 Andreas Kloeckner"
6
7__license__ = """
8Permission is hereby granted, free of charge, to any person
9obtaining a copy of this software and associated documentation
10files (the "Software"), to deal in the Software without
11restriction, including without limitation the rights to use,
12copy, modify, merge, publish, distribute, sublicense, and/or sell
13copies of the Software, and to permit persons to whom the
14Software is furnished to do so, subject to the following
15conditions:
16
17The above copyright notice and this permission notice shall be
18included in all copies or substantial portions of the Software.
19
20THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
21EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
22OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
23NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
24HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
25WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
26FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
27OTHER DEALINGS IN THE SOFTWARE.
28"""
29
30import six
31from six.moves import range, reduce
32
33import numpy as np
34import pyopencl.elementwise as elementwise
35import pyopencl as cl
36from pytools import memoize_method
37from pyopencl.compyte.array import (
38        as_strided as _as_strided,
39        f_contiguous_strides as _f_contiguous_strides,
40        c_contiguous_strides as _c_contiguous_strides,
41        equal_strides as _equal_strides,
42        ArrayFlags as _ArrayFlags,
43        get_common_dtype as _get_common_dtype_base)
44from pyopencl.characterize import has_double_support
45from pyopencl import cltypes
46
47
48def _get_common_dtype(obj1, obj2, queue):
49    return _get_common_dtype_base(obj1, obj2,
50                                  has_double_support(queue.device))
51
52
53# Work around PyPy not currently supporting the object dtype.
54# (Yes, it doesn't even support checking!)
55# (as of May 27, 2014 on PyPy 2.3)
56try:
57    np.dtype(object)
58
59    def _dtype_is_object(t):
60        return t == object
61except Exception:
62    def _dtype_is_object(t):
63        return False
64
65
66class VecLookupWarner(object):
67    def __getattr__(self, name):
68        from warnings import warn
69        warn("pyopencl.array.vec is deprecated. "
70             "Please use pyopencl.cltypes for OpenCL vector and scalar types",
71             DeprecationWarning, 2)
72
73        if name == "types":
74            name = "vec_types"
75        elif name == "type_to_scalar_and_count":
76            name = "vec_type_to_scalar_and_count"
77
78        return getattr(cltypes, name)
79
80
81vec = VecLookupWarner()
82
83# {{{ helper functionality
84
85
86def splay(queue, n, kernel_specific_max_wg_size=None):
87    dev = queue.device
88    max_work_items = _builtin_min(128, dev.max_work_group_size)
89
90    if kernel_specific_max_wg_size is not None:
91        from six.moves.builtins import min
92        max_work_items = min(max_work_items, kernel_specific_max_wg_size)
93
94    min_work_items = _builtin_min(32, max_work_items)
95    max_groups = dev.max_compute_units * 4 * 8
96    # 4 to overfill the device
97    # 8 is an Nvidia constant--that's how many
98    # groups fit onto one compute device
99
100    if n < min_work_items:
101        group_count = 1
102        work_items_per_group = min_work_items
103    elif n < (max_groups * min_work_items):
104        group_count = (n + min_work_items - 1) // min_work_items
105        work_items_per_group = min_work_items
106    elif n < (max_groups * max_work_items):
107        group_count = max_groups
108        grp = (n + min_work_items - 1) // min_work_items
109        work_items_per_group = (
110                (grp + max_groups - 1) // max_groups) * min_work_items
111    else:
112        group_count = max_groups
113        work_items_per_group = max_work_items
114
115    #print "n:%d gc:%d wipg:%d" % (n, group_count, work_items_per_group)
116    return (group_count*work_items_per_group,), (work_items_per_group,)
117
118
119def elwise_kernel_runner(kernel_getter):
120    """Take a kernel getter of the same signature as the kernel
121    and return a function that invokes that kernel.
122
123    Assumes that the zeroth entry in *args* is an :class:`Array`.
124    """
125
126    def kernel_runner(*args, **kwargs):
127        repr_ary = args[0]
128        queue = kwargs.pop("queue", None) or repr_ary.queue
129        wait_for = kwargs.pop("wait_for", None)
130
131        # wait_for must be a copy, because we modify it in-place below
132        if wait_for is None:
133            wait_for = []
134        else:
135            wait_for = list(wait_for)
136
137        knl = kernel_getter(*args, **kwargs)
138
139        gs, ls = repr_ary.get_sizes(queue,
140                knl.get_work_group_info(
141                    cl.kernel_work_group_info.WORK_GROUP_SIZE,
142                    queue.device))
143
144        assert isinstance(repr_ary, Array)
145
146        actual_args = []
147        for arg in args:
148            if isinstance(arg, Array):
149                if not arg.flags.forc:
150                    raise RuntimeError("only contiguous arrays may "
151                            "be used as arguments to this operation")
152                actual_args.append(arg.base_data)
153                actual_args.append(arg.offset)
154                wait_for.extend(arg.events)
155            else:
156                actual_args.append(arg)
157        actual_args.append(repr_ary.size)
158
159        return knl(queue, gs, ls, *actual_args, **dict(wait_for=wait_for))
160
161    try:
162        from functools import update_wrapper
163    except ImportError:
164        return kernel_runner
165    else:
166        return update_wrapper(kernel_runner, kernel_getter)
167
168
169class DefaultAllocator(cl.tools.DeferredAllocator):
170    def __init__(self, *args, **kwargs):
171        from warnings import warn
172        warn("pyopencl.array.DefaultAllocator is deprecated. "
173                "It will be continue to exist throughout the 2013.x "
174                "versions of PyOpenCL.",
175                DeprecationWarning, 2)
176        cl.tools.DeferredAllocator.__init__(self, *args, **kwargs)
177
178
179def _make_strides(itemsize, shape, order):
180    if order in "fF":
181        return _f_contiguous_strides(itemsize, shape)
182    elif order in "cC":
183        return _c_contiguous_strides(itemsize, shape)
184    else:
185        raise ValueError("invalid order: %s" % order)
186
187# }}}
188
189
190# {{{ array class
191
192class ArrayHasOffsetError(ValueError):
193    """
194    .. versionadded:: 2013.1
195    """
196
197    def __init__(self, val="The operation you are attempting does not yet "
198                "support arrays that start at an offset from the beginning "
199                "of their buffer."):
200        ValueError.__init__(self, val)
201
202
203class _copy_queue:  # noqa
204    pass
205
206
207class Array(object):
208    """A :class:`numpy.ndarray` work-alike that stores its data and performs
209    its computations on the compute device.  *shape* and *dtype* work exactly
210    as in :mod:`numpy`.  Arithmetic methods in :class:`Array` support the
211    broadcasting of scalars. (e.g. `array+5`)
212
213    *cq* must be a :class:`pyopencl.CommandQueue` or a :class:`pyopencl.Context`.
214
215    If it is a queue, *cq* specifies the queue in which the array carries out
216    its computations by default. If a default queue (and thereby overloaded
217    operators and many other niceties) are not desired, pass a
218    :class:`Context`.
219
220    *allocator* may be `None` or a callable that, upon being called with an
221    argument of the number of bytes to be allocated, returns an
222    :class:`pyopencl.Buffer` object. (A :class:`pyopencl.tools.MemoryPool`
223    instance is one useful example of an object to pass here.)
224
225    .. versionchanged:: 2011.1
226        Renamed *context* to *cqa*, made it general-purpose.
227
228        All arguments beyond *order* should be considered keyword-only.
229
230    .. versionchanged:: 2015.2
231        Renamed *context* to *cq*, disallowed passing allocators through it.
232
233    .. attribute :: data
234
235        The :class:`pyopencl.MemoryObject` instance created for the memory that
236        backs this :class:`Array`.
237
238        .. versionchanged:: 2013.1
239
240            If a non-zero :attr:`offset` has been specified for this array,
241            this will fail with :exc:`ArrayHasOffsetError`.
242
243    .. attribute :: base_data
244
245        The :class:`pyopencl.MemoryObject` instance created for the memory that
246        backs this :class:`Array`. Unlike :attr:`data`, the base address of
247        *base_data* is allowed to be different from the beginning of the array.
248        The actual beginning is the base address of *base_data* plus
249        :attr:`offset` bytes.
250
251        Unlike :attr:`data`, retrieving :attr:`base_data` always succeeds.
252
253        .. versionadded:: 2013.1
254
255    .. attribute :: offset
256
257        See :attr:`base_data`.
258
259        .. versionadded:: 2013.1
260
261    .. attribute :: shape
262
263        The tuple of lengths of each dimension in the array.
264
265    .. attribute :: ndim
266
267        The number of dimensions in :attr:`shape`.
268
269    .. attribute :: dtype
270
271        The :class:`numpy.dtype` of the items in the GPU array.
272
273    .. attribute :: size
274
275        The number of meaningful entries in the array. Can also be computed by
276        multiplying up the numbers in :attr:`shape`.
277
278    .. attribute :: nbytes
279
280        The size of the entire array in bytes. Computed as :attr:`size` times
281        ``dtype.itemsize``.
282
283    .. attribute :: strides
284
285        Tuple of bytes to step in each dimension when traversing an array.
286
287    .. attribute :: flags
288
289        Return an object with attributes `c_contiguous`, `f_contiguous` and
290        `forc`, which may be used to query contiguity properties in analogy to
291        :attr:`numpy.ndarray.flags`.
292
293    .. rubric:: Methods
294
295    .. automethod :: with_queue
296
297    .. automethod :: __len__
298    .. automethod :: reshape
299    .. automethod :: ravel
300    .. automethod :: view
301    .. automethod :: squeeze
302    .. automethod :: transpose
303    .. attribute :: T
304    .. automethod :: set
305    .. automethod :: get
306    .. automethod :: copy
307
308    .. automethod :: __str__
309    .. automethod :: __repr__
310
311    .. automethod :: mul_add
312    .. automethod :: __add__
313    .. automethod :: __sub__
314    .. automethod :: __iadd__
315    .. automethod :: __isub__
316    .. automethod :: __neg__
317    .. automethod :: __mul__
318    .. automethod :: __div__
319    .. automethod :: __rdiv__
320    .. automethod :: __pow__
321
322    .. automethod :: __and__
323    .. automethod :: __xor__
324    .. automethod :: __or__
325    .. automethod :: __iand__
326    .. automethod :: __ixor__
327    .. automethod :: __ior__
328
329    .. automethod :: __abs__
330    .. automethod :: __invert__
331
332    .. UNDOC reverse()
333
334    .. automethod :: fill
335
336    .. automethod :: astype
337
338    .. autoattribute :: real
339    .. autoattribute :: imag
340    .. automethod :: conj
341
342    .. automethod :: __getitem__
343    .. automethod :: __setitem__
344
345    .. automethod :: setitem
346
347    .. automethod :: map_to_host
348
349    .. rubric:: Comparisons, conditionals, any, all
350
351    .. versionadded:: 2013.2
352
353    Boolean arrays are stored as :class:`numpy.int8` because ``bool``
354    has an unspecified size in the OpenCL spec.
355
356    .. automethod :: __nonzero__
357
358        Only works for device scalars. (i.e. "arrays" with ``shape == ()``.)
359
360    .. automethod :: any
361    .. automethod :: all
362
363    .. automethod :: __eq__
364    .. automethod :: __ne__
365    .. automethod :: __lt__
366    .. automethod :: __le__
367    .. automethod :: __gt__
368    .. automethod :: __ge__
369
370    .. rubric:: Event management
371
372    If an array is used from within an out-of-order queue, it needs to take
373    care of its own operation ordering. The facilities in this section make
374    this possible.
375
376    .. versionadded:: 2014.1.1
377
378    .. attribute:: events
379
380        A list of :class:`pyopencl.Event` instances that the current content of
381        this array depends on. User code may read, but should never modify this
382        list directly. To update this list, instead use the following methods.
383
384    .. automethod:: add_event
385    .. automethod:: finish
386    """
387
388    __array_priority__ = 100
389
390    def __init__(self, cq, shape, dtype, order="C", allocator=None,
391            data=None, offset=0, strides=None, events=None):
392        # {{{ backward compatibility
393
394        if isinstance(cq, cl.CommandQueue):
395            queue = cq
396            context = queue.context
397
398        elif isinstance(cq, cl.Context):
399            context = cq
400            queue = None
401
402        else:
403            raise TypeError("cq may be a queue or a context, not '%s'"
404                    % type(cq))
405
406        if allocator is not None:
407            # "is" would be wrong because two Python objects are allowed
408            # to hold handles to the same context.
409
410            # FIXME It would be nice to check this. But it would require
411            # changing the allocator interface. Trust the user for now.
412
413            #assert allocator.context == context
414            pass
415
416        # Queue-less arrays do have a purpose in life.
417        # They don't do very much, but at least they don't run kernels
418        # in random queues.
419        #
420        # See also :meth:`with_queue`.
421
422        del cq
423
424        # }}}
425
426        # invariant here: allocator, queue set
427
428        # {{{ determine shape and strides
429        dtype = np.dtype(dtype)
430
431        try:
432            s = 1
433            for dim in shape:
434                s *= dim
435        except TypeError:
436            import sys
437            if sys.version_info >= (3,):
438                admissible_types = (int, np.integer)
439            else:
440                admissible_types = (np.integer,) + six.integer_types
441
442            if not isinstance(shape, admissible_types):
443                raise TypeError("shape must either be iterable or "
444                        "castable to an integer")
445            s = shape
446            shape = (shape,)
447
448        if isinstance(s, np.integer):
449            # bombs if s is a Python integer
450            s = np.asscalar(s)
451
452        if strides is None:
453            strides = _make_strides(dtype.itemsize, shape, order)
454
455        else:
456            # FIXME: We should possibly perform some plausibility
457            # checking on 'strides' here.
458
459            strides = tuple(strides)
460
461        # }}}
462
463        if _dtype_is_object(dtype):
464            raise TypeError("object arrays on the compute device are not allowed")
465
466        assert isinstance(shape, tuple)
467        assert isinstance(strides, tuple)
468
469        self.queue = queue
470        self.shape = shape
471        self.dtype = dtype
472        self.strides = strides
473        if events is None:
474            self.events = []
475        else:
476            self.events = events
477
478        self.size = s
479        alloc_nbytes = self.nbytes = self.dtype.itemsize * self.size
480
481        self.allocator = allocator
482
483        if data is None:
484            if alloc_nbytes <= 0:
485                if alloc_nbytes == 0:
486                    # Work around CL not allowing zero-sized buffers.
487                    alloc_nbytes = 1
488
489                else:
490                    raise ValueError("cannot allocate CL buffer with "
491                            "negative size")
492
493            if allocator is None:
494                if context is None and queue is not None:
495                    context = queue.context
496
497                self.base_data = cl.Buffer(
498                        context, cl.mem_flags.READ_WRITE, alloc_nbytes)
499            else:
500                self.base_data = self.allocator(alloc_nbytes)
501        else:
502            self.base_data = data
503
504        self.offset = offset
505        self.context = context
506
507    @property
508    def ndim(self):
509        return len(self.shape)
510
511    @property
512    def data(self):
513        if self.offset:
514            raise ArrayHasOffsetError()
515        else:
516            return self.base_data
517
518    @property
519    @memoize_method
520    def flags(self):
521        return _ArrayFlags(self)
522
523    def _new_with_changes(self, data, offset, shape=None, dtype=None,
524            strides=None, queue=_copy_queue, allocator=None):
525        """
526        :arg data: *None* means allocate a new array.
527        """
528        if shape is None:
529            shape = self.shape
530        if dtype is None:
531            dtype = self.dtype
532        if strides is None:
533            strides = self.strides
534        if queue is _copy_queue:
535            queue = self.queue
536        if allocator is None:
537            allocator = self.allocator
538
539        # If we're allocating new data, then there's not likely to be
540        # a data dependency. Otherwise, the two arrays should probably
541        # share the same events list.
542
543        if data is None:
544            events = None
545        else:
546            events = self.events
547
548        if queue is not None:
549            return Array(queue, shape, dtype, allocator=allocator,
550                    strides=strides, data=data, offset=offset,
551                    events=events)
552        else:
553            return Array(self.context, shape, dtype,
554                    strides=strides, data=data, offset=offset,
555                    events=events, allocator=allocator)
556
557    def with_queue(self, queue):
558        """Return a copy of *self* with the default queue set to *queue*.
559
560        *None* is allowed as a value for *queue*.
561
562        .. versionadded:: 2013.1
563        """
564
565        if queue is not None:
566            assert queue.context == self.context
567
568        return self._new_with_changes(self.base_data, self.offset,
569                queue=queue)
570
571    #@memoize_method FIXME: reenable
572    def get_sizes(self, queue, kernel_specific_max_wg_size=None):
573        if not self.flags.forc:
574            raise NotImplementedError("cannot operate on non-contiguous array")
575        return splay(queue, self.size,
576                kernel_specific_max_wg_size=kernel_specific_max_wg_size)
577
578    def set(self, ary, queue=None, async_=None, **kwargs):
579        """Transfer the contents the :class:`numpy.ndarray` object *ary*
580        onto the device.
581
582        *ary* must have the same dtype and size (not necessarily shape) as
583        *self*.
584
585        .. versionchanged:: 2017.2.1
586
587            Python 3.7 makes ``async`` a reserved keyword. On older Pythons,
588            we will continue to  accept *async* as a parameter, however this
589            should be considered deprecated. *async_* is the new, official
590            spelling.
591        """
592
593        # {{{ handle 'async' deprecation
594
595        async_arg = kwargs.pop("async", None)
596        if async_arg is not None:
597            if async_ is not None:
598                raise TypeError("may not specify both 'async' and 'async_'")
599            async_ = async_arg
600
601        if async_ is None:
602            async_ = False
603
604        if kwargs:
605            raise TypeError("extra keyword arguments specified: %s"
606                    % ", ".join(kwargs))
607
608        # }}}
609
610        assert ary.size == self.size
611        assert ary.dtype == self.dtype
612
613        if not ary.flags.forc:
614            raise RuntimeError("cannot set from non-contiguous array")
615
616        if not _equal_strides(ary.strides, self.strides, self.shape):
617            from warnings import warn
618            warn("Setting array from one with different "
619                    "strides/storage order. This will cease to work "
620                    "in 2013.x.",
621                    stacklevel=2)
622
623        if self.size:
624            event1 = cl.enqueue_copy(queue or self.queue, self.base_data, ary,
625                    device_offset=self.offset,
626                    is_blocking=not async_)
627            self.add_event(event1)
628
629    def get(self, queue=None, ary=None, async_=None, **kwargs):
630        """Transfer the contents of *self* into *ary* or a newly allocated
631        :mod:`numpy.ndarray`. If *ary* is given, it must have the same
632        shape and dtype.
633
634        .. versionchanged:: 2015.2
635
636            *ary* with different shape was deprecated.
637
638        .. versionchanged:: 2017.2.1
639
640            Python 3.7 makes ``async`` a reserved keyword. On older Pythons,
641            we will continue to  accept *async* as a parameter, however this
642            should be considered deprecated. *async_* is the new, official
643            spelling.
644        """
645
646        # {{{ handle 'async' deprecation
647
648        async_arg = kwargs.pop("async", None)
649        if async_arg is not None:
650            if async_ is not None:
651                raise TypeError("may not specify both 'async' and 'async_'")
652            async_ = async_arg
653
654        if async_ is None:
655            async_ = False
656
657        if kwargs:
658            raise TypeError("extra keyword arguments specified: %s"
659                    % ", ".join(kwargs))
660
661        # }}}
662
663        if ary is None:
664            ary = np.empty(self.shape, self.dtype)
665
666            if self.strides != ary.strides:
667                ary = _as_strided(ary, strides=self.strides)
668        else:
669            if ary.size != self.size:
670                raise TypeError("'ary' has non-matching size")
671            if ary.dtype != self.dtype:
672                raise TypeError("'ary' has non-matching type")
673
674            if self.shape != ary.shape:
675                from warnings import warn
676                warn("get() between arrays of different shape is deprecated "
677                        "and will be removed in PyCUDA 2017.x",
678                        DeprecationWarning, stacklevel=2)
679
680        assert self.flags.forc, "Array in get() must be contiguous"
681
682        queue = queue or self.queue
683        if queue is None:
684            raise ValueError("Cannot copy array to host. "
685                    "Array has no queue. Use "
686                    "'new_array = array.with_queue(queue)' "
687                    "to associate one.")
688
689        if self.size:
690            cl.enqueue_copy(queue, ary, self.base_data,
691                    device_offset=self.offset,
692                    wait_for=self.events, is_blocking=not async_)
693
694        return ary
695
696    def copy(self, queue=_copy_queue):
697        """
698        :arg queue: The :class:`CommandQueue` for the returned array.
699
700        .. versionchanged:: 2017.1.2
701            Updates the queue of the returned array.
702
703        .. versionadded:: 2013.1
704        """
705
706        if queue is _copy_queue:
707            queue = self.queue
708
709        result = self._new_like_me(queue=queue)
710
711        # result.queue won't be the same as queue if queue is None.
712        # We force them to be the same here.
713        if result.queue is not queue:
714            result = result.with_queue(queue)
715
716        if self.nbytes:
717            event1 = cl.enqueue_copy(queue or self.queue,
718                    result.base_data, self.base_data,
719                    src_offset=self.offset, byte_count=self.nbytes,
720                    wait_for=self.events)
721            result.add_event(event1)
722
723        return result
724
725    def __str__(self):
726        return str(self.get())
727
728    def __repr__(self):
729        return repr(self.get())
730
731    def safely_stringify_for_pudb(self):
732        return "cl.Array %s %s" % (self.dtype, self.shape)
733
734    def __hash__(self):
735        raise TypeError("pyopencl arrays are not hashable.")
736
737    # {{{ kernel invocation wrappers
738
739    @staticmethod
740    @elwise_kernel_runner
741    def _axpbyz(out, afac, a, bfac, b, queue=None):
742        """Compute ``out = selffac * self + otherfac*other``,
743        where *other* is an array."""
744        assert out.shape == a.shape
745        assert out.shape == b.shape
746
747        return elementwise.get_axpbyz_kernel(
748                out.context, a.dtype, b.dtype, out.dtype)
749
750    @staticmethod
751    @elwise_kernel_runner
752    def _axpbz(out, a, x, b, queue=None):
753        """Compute ``z = a * x + b``, where *b* is a scalar."""
754        a = np.array(a)
755        b = np.array(b)
756        assert out.shape == x.shape
757        return elementwise.get_axpbz_kernel(out.context,
758                a.dtype, x.dtype, b.dtype, out.dtype)
759
760    @staticmethod
761    @elwise_kernel_runner
762    def _elwise_multiply(out, a, b, queue=None):
763        assert out.shape == a.shape
764        assert out.shape == b.shape
765        return elementwise.get_multiply_kernel(
766                a.context, a.dtype, b.dtype, out.dtype)
767
768    @staticmethod
769    @elwise_kernel_runner
770    def _rdiv_scalar(out, ary, other, queue=None):
771        other = np.array(other)
772        assert out.shape == ary.shape
773        return elementwise.get_rdivide_elwise_kernel(
774                out.context, ary.dtype, other.dtype, out.dtype)
775
776    @staticmethod
777    @elwise_kernel_runner
778    def _div(out, self, other, queue=None):
779        """Divides an array by another array."""
780
781        assert self.shape == other.shape
782
783        return elementwise.get_divide_kernel(self.context,
784                self.dtype, other.dtype, out.dtype)
785
786    @staticmethod
787    @elwise_kernel_runner
788    def _fill(result, scalar):
789        return elementwise.get_fill_kernel(result.context, result.dtype)
790
791    @staticmethod
792    @elwise_kernel_runner
793    def _abs(result, arg):
794        if arg.dtype.kind == "c":
795            from pyopencl.elementwise import complex_dtype_to_name
796            fname = "%s_abs" % complex_dtype_to_name(arg.dtype)
797        elif arg.dtype.kind == "f":
798            fname = "fabs"
799        elif arg.dtype.kind in ["u", "i"]:
800            fname = "abs"
801        else:
802            raise TypeError("unsupported dtype in _abs()")
803
804        return elementwise.get_unary_func_kernel(
805                arg.context, fname, arg.dtype, out_dtype=result.dtype)
806
807    @staticmethod
808    @elwise_kernel_runner
809    def _real(result, arg):
810        from pyopencl.elementwise import complex_dtype_to_name
811        fname = "%s_real" % complex_dtype_to_name(arg.dtype)
812        return elementwise.get_unary_func_kernel(
813                arg.context, fname, arg.dtype, out_dtype=result.dtype)
814
815    @staticmethod
816    @elwise_kernel_runner
817    def _imag(result, arg):
818        from pyopencl.elementwise import complex_dtype_to_name
819        fname = "%s_imag" % complex_dtype_to_name(arg.dtype)
820        return elementwise.get_unary_func_kernel(
821                arg.context, fname, arg.dtype, out_dtype=result.dtype)
822
823    @staticmethod
824    @elwise_kernel_runner
825    def _conj(result, arg):
826        from pyopencl.elementwise import complex_dtype_to_name
827        fname = "%s_conj" % complex_dtype_to_name(arg.dtype)
828        return elementwise.get_unary_func_kernel(
829                arg.context, fname, arg.dtype, out_dtype=result.dtype)
830
831    @staticmethod
832    @elwise_kernel_runner
833    def _pow_scalar(result, ary, exponent):
834        exponent = np.array(exponent)
835        return elementwise.get_pow_kernel(result.context,
836                ary.dtype, exponent.dtype, result.dtype,
837                is_base_array=True, is_exp_array=False)
838
839    @staticmethod
840    @elwise_kernel_runner
841    def _rpow_scalar(result, base, exponent):
842        base = np.array(base)
843        return elementwise.get_pow_kernel(result.context,
844                base.dtype, exponent.dtype, result.dtype,
845                is_base_array=False, is_exp_array=True)
846
847    @staticmethod
848    @elwise_kernel_runner
849    def _pow_array(result, base, exponent):
850        return elementwise.get_pow_kernel(
851                result.context, base.dtype, exponent.dtype, result.dtype,
852                is_base_array=True, is_exp_array=True)
853
854    @staticmethod
855    @elwise_kernel_runner
856    def _reverse(result, ary):
857        return elementwise.get_reverse_kernel(result.context, ary.dtype)
858
859    @staticmethod
860    @elwise_kernel_runner
861    def _copy(dest, src):
862        return elementwise.get_copy_kernel(
863                dest.context, dest.dtype, src.dtype)
864
865    def _new_like_me(self, dtype=None, queue=None):
866        strides = None
867        if dtype is None:
868            dtype = self.dtype
869
870        if dtype == self.dtype:
871            strides = self.strides
872
873        queue = queue or self.queue
874        if queue is not None:
875            return self.__class__(queue, self.shape, dtype,
876                    allocator=self.allocator, strides=strides)
877        else:
878            return self.__class__(self.context, self.shape, dtype,
879                    strides=strides, allocator=self.allocator)
880
881    @staticmethod
882    @elwise_kernel_runner
883    def _scalar_binop(out, a, b, queue=None, op=None):
884        return elementwise.get_array_scalar_binop_kernel(
885                out.context, op, out.dtype, a.dtype,
886                np.array(b).dtype)
887
888    @staticmethod
889    @elwise_kernel_runner
890    def _array_binop(out, a, b, queue=None, op=None):
891        if a.shape != b.shape:
892            raise ValueError("shapes of binop arguments do not match")
893        return elementwise.get_array_binop_kernel(
894                out.context, op, out.dtype, a.dtype, b.dtype)
895
896    @staticmethod
897    @elwise_kernel_runner
898    def _unop(out, a, queue=None, op=None):
899        if out.shape != a.shape:
900            raise ValueError("shapes of arguments do not match")
901        return elementwise.get_unop_kernel(
902                out.context, op, a.dtype, out.dtype)
903
904    # }}}
905
906    # {{{ operators
907
908    def mul_add(self, selffac, other, otherfac, queue=None):
909        """Return `selffac * self + otherfac*other`.
910        """
911        result = self._new_like_me(
912                _get_common_dtype(self, other, queue or self.queue))
913        result.add_event(
914                self._axpbyz(result, selffac, self, otherfac, other))
915        return result
916
917    def __add__(self, other):
918        """Add an array with an array or an array with a scalar."""
919
920        if isinstance(other, Array):
921            # add another vector
922            result = self._new_like_me(
923                    _get_common_dtype(self, other, self.queue))
924
925            result.add_event(
926                    self._axpbyz(result,
927                        self.dtype.type(1), self,
928                        other.dtype.type(1), other))
929
930            return result
931        else:
932            # add a scalar
933            if other == 0:
934                return self.copy()
935            else:
936                common_dtype = _get_common_dtype(self, other, self.queue)
937                result = self._new_like_me(common_dtype)
938                result.add_event(
939                        self._axpbz(result, self.dtype.type(1),
940                            self, common_dtype.type(other)))
941                return result
942
943    __radd__ = __add__
944
945    def __sub__(self, other):
946        """Substract an array from an array or a scalar from an array."""
947
948        if isinstance(other, Array):
949            result = self._new_like_me(
950                    _get_common_dtype(self, other, self.queue))
951            result.add_event(
952                    self._axpbyz(result,
953                        self.dtype.type(1), self,
954                        other.dtype.type(-1), other))
955
956            return result
957        else:
958            # subtract a scalar
959            if other == 0:
960                return self.copy()
961            else:
962                result = self._new_like_me(
963                        _get_common_dtype(self, other, self.queue))
964                result.add_event(
965                        self._axpbz(result, self.dtype.type(1), self, -other))
966                return result
967
968    def __rsub__(self, other):
969        """Substracts an array by a scalar or an array::
970
971           x = n - self
972        """
973        common_dtype = _get_common_dtype(self, other, self.queue)
974        # other must be a scalar
975        result = self._new_like_me(common_dtype)
976        result.add_event(
977                self._axpbz(result, self.dtype.type(-1), self,
978                    common_dtype.type(other)))
979        return result
980
981    def __iadd__(self, other):
982        if isinstance(other, Array):
983            self.add_event(
984                    self._axpbyz(self,
985                        self.dtype.type(1), self,
986                        other.dtype.type(1), other))
987            return self
988        else:
989            self.add_event(
990                    self._axpbz(self, self.dtype.type(1), self, other))
991            return self
992
993    def __isub__(self, other):
994        if isinstance(other, Array):
995            self.add_event(
996                    self._axpbyz(self, self.dtype.type(1), self,
997                        other.dtype.type(-1), other))
998            return self
999        else:
1000            self._axpbz(self, self.dtype.type(1), self, -other)
1001            return self
1002
1003    def __neg__(self):
1004        result = self._new_like_me()
1005        result.add_event(self._axpbz(result, -1, self, 0))
1006        return result
1007
1008    def __mul__(self, other):
1009        if isinstance(other, Array):
1010            result = self._new_like_me(
1011                    _get_common_dtype(self, other, self.queue))
1012            result.add_event(
1013                    self._elwise_multiply(result, self, other))
1014            return result
1015        else:
1016            common_dtype = _get_common_dtype(self, other, self.queue)
1017            result = self._new_like_me(common_dtype)
1018            result.add_event(
1019                    self._axpbz(result,
1020                        common_dtype.type(other), self, self.dtype.type(0)))
1021            return result
1022
1023    def __rmul__(self, scalar):
1024        common_dtype = _get_common_dtype(self, scalar, self.queue)
1025        result = self._new_like_me(common_dtype)
1026        result.add_event(
1027                self._axpbz(result,
1028                    common_dtype.type(scalar), self, self.dtype.type(0)))
1029        return result
1030
1031    def __imul__(self, other):
1032        if isinstance(other, Array):
1033            self.add_event(
1034                    self._elwise_multiply(self, self, other))
1035        else:
1036            # scalar
1037            self.add_event(
1038                    self._axpbz(self, other, self, self.dtype.type(0)))
1039
1040        return self
1041
1042    def __div__(self, other):
1043        """Divides an array by an array or a scalar, i.e. ``self / other``.
1044        """
1045        if isinstance(other, Array):
1046            result = self._new_like_me(
1047                    _get_common_dtype(self, other, self.queue))
1048            result.add_event(self._div(result, self, other))
1049        else:
1050            if other == 1:
1051                return self.copy()
1052            else:
1053                # create a new array for the result
1054                common_dtype = _get_common_dtype(self, other, self.queue)
1055                result = self._new_like_me(common_dtype)
1056                result.add_event(
1057                        self._axpbz(result,
1058                            common_dtype.type(1/other), self, self.dtype.type(0)))
1059
1060        return result
1061
1062    __truediv__ = __div__
1063
1064    def __rdiv__(self, other):
1065        """Divides an array by a scalar or an array, i.e. ``other / self``.
1066        """
1067
1068        if isinstance(other, Array):
1069            result = self._new_like_me(
1070                    _get_common_dtype(self, other, self.queue))
1071            result.add_event(other._div(result, self))
1072        else:
1073            # create a new array for the result
1074            common_dtype = _get_common_dtype(self, other, self.queue)
1075            result = self._new_like_me(common_dtype)
1076            result.add_event(
1077                    self._rdiv_scalar(result, self, common_dtype.type(other)))
1078
1079        return result
1080
1081    __rtruediv__ = __rdiv__
1082
1083    def __and__(self, other):
1084        common_dtype = _get_common_dtype(self, other, self.queue)
1085
1086        if not np.issubdtype(common_dtype, np.integer):
1087            raise TypeError("Integral types only")
1088
1089        if isinstance(other, Array):
1090            result = self._new_like_me(common_dtype)
1091            result.add_event(self._array_binop(result, self, other, op="&"))
1092        else:
1093            # create a new array for the result
1094            result = self._new_like_me(common_dtype)
1095            result.add_event(
1096                    self._scalar_binop(result, self, other, op="&"))
1097
1098        return result
1099
1100    __rand__ = __and__  # commutes
1101
1102    def __or__(self, other):
1103        common_dtype = _get_common_dtype(self, other, self.queue)
1104
1105        if not np.issubdtype(common_dtype, np.integer):
1106            raise TypeError("Integral types only")
1107
1108        if isinstance(other, Array):
1109            result = self._new_like_me(common_dtype)
1110            result.add_event(self._array_binop(result, self, other, op="|"))
1111        else:
1112            # create a new array for the result
1113            result = self._new_like_me(common_dtype)
1114            result.add_event(
1115                    self._scalar_binop(result, self, other, op="|"))
1116
1117        return result
1118
1119    __ror__ = __or__  # commutes
1120
1121    def __xor__(self, other):
1122        common_dtype = _get_common_dtype(self, other, self.queue)
1123
1124        if not np.issubdtype(common_dtype, np.integer):
1125            raise TypeError("Integral types only")
1126
1127        if isinstance(other, Array):
1128            result = self._new_like_me(common_dtype)
1129            result.add_event(self._array_binop(result, self, other, op="^"))
1130        else:
1131            # create a new array for the result
1132            result = self._new_like_me(common_dtype)
1133            result.add_event(
1134                    self._scalar_binop(result, self, other, op="^"))
1135
1136        return result
1137
1138    __rxor__ = __xor__  # commutes
1139
1140    def __iand__(self, other):
1141        common_dtype = _get_common_dtype(self, other, self.queue)
1142
1143        if not np.issubdtype(common_dtype, np.integer):
1144            raise TypeError("Integral types only")
1145
1146        if isinstance(other, Array):
1147            self.add_event(self._array_binop(self, self, other, op="&"))
1148        else:
1149            self.add_event(
1150                    self._scalar_binop(self, self, other, op="&"))
1151
1152        return self
1153
1154    def __ior__(self, other):
1155        common_dtype = _get_common_dtype(self, other, self.queue)
1156
1157        if not np.issubdtype(common_dtype, np.integer):
1158            raise TypeError("Integral types only")
1159
1160        if isinstance(other, Array):
1161            self.add_event(self._array_binop(self, self, other, op="|"))
1162        else:
1163            self.add_event(
1164                    self._scalar_binop(self, self, other, op="|"))
1165
1166        return self
1167
1168    def __ixor__(self, other):
1169        common_dtype = _get_common_dtype(self, other, self.queue)
1170
1171        if not np.issubdtype(common_dtype, np.integer):
1172            raise TypeError("Integral types only")
1173
1174        if isinstance(other, Array):
1175            self.add_event(self._array_binop(self, self, other, op="^"))
1176        else:
1177            self.add_event(
1178                    self._scalar_binop(self, self, other, op="^"))
1179
1180        return self
1181
1182    def _zero_fill(self, queue=None, wait_for=None):
1183        queue = queue or self.queue
1184
1185        if (
1186                queue._get_cl_version() >= (1, 2)
1187                and cl.get_cl_header_version() >= (1, 2)):
1188
1189            self.add_event(
1190                    cl.enqueue_fill_buffer(queue, self.base_data, np.int8(0),
1191                        self.offset, self.nbytes, wait_for=wait_for))
1192        else:
1193            zero = np.zeros((), self.dtype)
1194            self.fill(zero, queue=queue)
1195
1196    def fill(self, value, queue=None, wait_for=None):
1197        """Fill the array with *scalar*.
1198
1199        :returns: *self*.
1200        """
1201
1202        self.add_event(
1203                self._fill(self, value, queue=queue, wait_for=wait_for))
1204
1205        return self
1206
1207    def __len__(self):
1208        """Returns the size of the leading dimension of *self*."""
1209        if len(self.shape):
1210            return self.shape[0]
1211        else:
1212            return TypeError("scalar has no len()")
1213
1214    def __abs__(self):
1215        """Return a `Array` of the absolute values of the elements
1216        of *self*.
1217        """
1218
1219        result = self._new_like_me(self.dtype.type(0).real.dtype)
1220        result.add_event(self._abs(result, self))
1221        return result
1222
1223    def __pow__(self, other):
1224        """Exponentiation by a scalar or elementwise by another
1225        :class:`Array`.
1226        """
1227
1228        if isinstance(other, Array):
1229            assert self.shape == other.shape
1230
1231            result = self._new_like_me(
1232                    _get_common_dtype(self, other, self.queue))
1233            result.add_event(
1234                    self._pow_array(result, self, other))
1235        else:
1236            result = self._new_like_me(
1237                    _get_common_dtype(self, other, self.queue))
1238            result.add_event(self._pow_scalar(result, self, other))
1239
1240        return result
1241
1242    def __rpow__(self, other):
1243        # other must be a scalar
1244        common_dtype = _get_common_dtype(self, other, self.queue)
1245        result = self._new_like_me(common_dtype)
1246        result.add_event(
1247                self._rpow_scalar(result, common_dtype.type(other), self))
1248        return result
1249
1250    def __invert__(self):
1251        if not np.issubdtype(self.dtype, np.integer):
1252            raise TypeError("Integral types only")
1253
1254        result = self._new_like_me()
1255        result.add_event(self._unop(result, self, op="~"))
1256
1257        return result
1258
1259    # }}}
1260
1261    def reverse(self, queue=None):
1262        """Return this array in reversed order. The array is treated
1263        as one-dimensional.
1264        """
1265
1266        result = self._new_like_me()
1267        result.add_event(
1268                self._reverse(result, self))
1269        return result
1270
1271    def astype(self, dtype, queue=None):
1272        """Return a copy of *self*, cast to *dtype*."""
1273        if dtype == self.dtype:
1274            return self.copy()
1275
1276        result = self._new_like_me(dtype=dtype)
1277        result.add_event(self._copy(result, self, queue=queue))
1278        return result
1279
1280    # {{{ rich comparisons, any, all
1281
1282    def __nonzero__(self):
1283        if self.shape == ():
1284            return bool(self.get())
1285        else:
1286            raise ValueError("The truth value of an array with "
1287                    "more than one element is ambiguous. Use a.any() or a.all()")
1288
1289    __bool__ = __nonzero__
1290
1291    def any(self, queue=None, wait_for=None):
1292        from pyopencl.reduction import get_any_kernel
1293        krnl = get_any_kernel(self.context, self.dtype)
1294        if wait_for is None:
1295            wait_for = []
1296        result, event1 = krnl(self, queue=queue,
1297               wait_for=wait_for + self.events, return_event=True)
1298        result.add_event(event1)
1299        return result
1300
1301    def all(self, queue=None, wait_for=None):
1302        from pyopencl.reduction import get_all_kernel
1303        krnl = get_all_kernel(self.context, self.dtype)
1304        if wait_for is None:
1305            wait_for = []
1306        result, event1 = krnl(self, queue=queue,
1307               wait_for=wait_for + self.events, return_event=True)
1308        result.add_event(event1)
1309        return result
1310
1311    @staticmethod
1312    @elwise_kernel_runner
1313    def _scalar_comparison(out, a, b, queue=None, op=None):
1314        return elementwise.get_array_scalar_comparison_kernel(
1315                out.context, op, a.dtype)
1316
1317    @staticmethod
1318    @elwise_kernel_runner
1319    def _array_comparison(out, a, b, queue=None, op=None):
1320        if a.shape != b.shape:
1321            raise ValueError("shapes of comparison arguments do not match")
1322        return elementwise.get_array_comparison_kernel(
1323                out.context, op, a.dtype, b.dtype)
1324
1325    def __eq__(self, other):
1326        if isinstance(other, Array):
1327            result = self._new_like_me(np.int8)
1328            result.add_event(
1329                    self._array_comparison(result, self, other, op="=="))
1330            return result
1331        else:
1332            result = self._new_like_me(np.int8)
1333            result.add_event(
1334                    self._scalar_comparison(result, self, other, op="=="))
1335            return result
1336
1337    def __ne__(self, other):
1338        if isinstance(other, Array):
1339            result = self._new_like_me(np.int8)
1340            result.add_event(
1341                    self._array_comparison(result, self, other, op="!="))
1342            return result
1343        else:
1344            result = self._new_like_me(np.int8)
1345            result.add_event(
1346                    self._scalar_comparison(result, self, other, op="!="))
1347            return result
1348
1349    def __le__(self, other):
1350        if isinstance(other, Array):
1351            result = self._new_like_me(np.int8)
1352            result.add_event(
1353                    self._array_comparison(result, self, other, op="<="))
1354            return result
1355        else:
1356            result = self._new_like_me(np.int8)
1357            self._scalar_comparison(result, self, other, op="<=")
1358            return result
1359
1360    def __ge__(self, other):
1361        if isinstance(other, Array):
1362            result = self._new_like_me(np.int8)
1363            result.add_event(
1364                    self._array_comparison(result, self, other, op=">="))
1365            return result
1366        else:
1367            result = self._new_like_me(np.int8)
1368            result.add_event(
1369                    self._scalar_comparison(result, self, other, op=">="))
1370            return result
1371
1372    def __lt__(self, other):
1373        if isinstance(other, Array):
1374            result = self._new_like_me(np.int8)
1375            result.add_event(
1376                    self._array_comparison(result, self, other, op="<"))
1377            return result
1378        else:
1379            result = self._new_like_me(np.int8)
1380            result.add_event(
1381                    self._scalar_comparison(result, self, other, op="<"))
1382            return result
1383
1384    def __gt__(self, other):
1385        if isinstance(other, Array):
1386            result = self._new_like_me(np.int8)
1387            result.add_event(
1388                    self._array_comparison(result, self, other, op=">"))
1389            return result
1390        else:
1391            result = self._new_like_me(np.int8)
1392            result.add_event(
1393                    self._scalar_comparison(result, self, other, op=">"))
1394            return result
1395
1396    # }}}
1397
1398    # {{{ complex-valued business
1399
1400    def real(self):
1401        if self.dtype.kind == "c":
1402            result = self._new_like_me(self.dtype.type(0).real.dtype)
1403            result.add_event(
1404                    self._real(result, self))
1405            return result
1406        else:
1407            return self
1408    real = property(real, doc=".. versionadded:: 2012.1")
1409
1410    def imag(self):
1411        if self.dtype.kind == "c":
1412            result = self._new_like_me(self.dtype.type(0).real.dtype)
1413            result.add_event(
1414                    self._imag(result, self))
1415            return result
1416        else:
1417            return zeros_like(self)
1418    imag = property(imag, doc=".. versionadded:: 2012.1")
1419
1420    def conj(self):
1421        """.. versionadded:: 2012.1"""
1422        if self.dtype.kind == "c":
1423            result = self._new_like_me()
1424            result.add_event(self._conj(result, self))
1425            return result
1426        else:
1427            return self
1428
1429    # }}}
1430
1431    # {{{ event management
1432
1433    def add_event(self, evt):
1434        """Add *evt* to :attr:`events`. If :attr:`events` is too long, this method
1435        may implicitly wait for a subset of :attr:`events` and clear them from the
1436        list.
1437        """
1438        n_wait = 4
1439
1440        self.events.append(evt)
1441
1442        if len(self.events) > 3*n_wait:
1443            wait_events = self.events[:n_wait]
1444            cl.wait_for_events(wait_events)
1445            del self.events[:n_wait]
1446
1447    def finish(self):
1448        """Wait for the entire contents of :attr:`events`, clear it."""
1449
1450        if self.events:
1451            cl.wait_for_events(self.events)
1452            del self.events[:]
1453
1454    # }}}
1455
1456    # {{{ views
1457
1458    def reshape(self, *shape, **kwargs):
1459        """Returns an array containing the same data with a new shape."""
1460
1461        order = kwargs.pop("order", "C")
1462        if kwargs:
1463            raise TypeError("unexpected keyword arguments: %s"
1464                    % list(kwargs.keys()))
1465
1466        if order not in "CF":
1467            raise ValueError("order must be either 'C' or 'F'")
1468
1469        # TODO: add more error-checking, perhaps
1470
1471        # FIXME: The following is overly conservative. As long as we don't change
1472        # our memory footprint, we're good.
1473
1474        # if not self.flags.forc:
1475        #     raise RuntimeError("only contiguous arrays may "
1476        #             "be used as arguments to this operation")
1477
1478        if isinstance(shape[0], tuple) or isinstance(shape[0], list):
1479            shape = tuple(shape[0])
1480
1481        if -1 in shape:
1482            shape = list(shape)
1483            idx = shape.index(-1)
1484            size = -reduce(lambda x, y: x * y, shape, 1)
1485            shape[idx] = self.size // size
1486            if any(s < 0 for s in shape):
1487                raise ValueError("can only specify one unknown dimension")
1488            shape = tuple(shape)
1489
1490        if shape == self.shape:
1491            return self._new_with_changes(
1492                    data=self.base_data, offset=self.offset, shape=shape,
1493                    strides=self.strides)
1494
1495        import operator
1496        size = reduce(operator.mul, shape, 1)
1497        if size != self.size:
1498            raise ValueError("total size of new array must be unchanged")
1499
1500        # {{{ determine reshaped strides
1501
1502        # copied and translated from
1503        # https://github.com/numpy/numpy/blob/4083883228d61a3b571dec640185b5a5d983bf59/numpy/core/src/multiarray/shape.c  # noqa
1504
1505        newdims = shape
1506        newnd = len(newdims)
1507
1508        # Remove axes with dimension 1 from the old array. They have no effect
1509        # but would need special cases since their strides do not matter.
1510
1511        olddims = []
1512        oldstrides = []
1513        for oi in range(len(self.shape)):
1514            s = self.shape[oi]
1515            if s != 1:
1516                olddims.append(s)
1517                oldstrides.append(self.strides[oi])
1518
1519        oldnd = len(olddims)
1520
1521        newstrides = [-1]*len(newdims)
1522
1523        # oi to oj and ni to nj give the axis ranges currently worked with
1524        oi = 0
1525        oj = 1
1526        ni = 0
1527        nj = 1
1528        while ni < newnd and oi < oldnd:
1529            np = newdims[ni]
1530            op = olddims[oi]
1531
1532            while np != op:
1533                if np < op:
1534                    # Misses trailing 1s, these are handled later
1535                    np *= newdims[nj]
1536                    nj += 1
1537                else:
1538                    op *= olddims[oj]
1539                    oj += 1
1540
1541            # Check whether the original axes can be combined
1542            for ok in range(oi, oj-1):
1543                if order == "F":
1544                    if oldstrides[ok+1] != olddims[ok]*oldstrides[ok]:
1545                        raise ValueError("cannot reshape without copy")
1546                else:
1547                    # C order
1548                    if (oldstrides[ok] != olddims[ok+1]*oldstrides[ok+1]):
1549                        raise ValueError("cannot reshape without copy")
1550
1551            # Calculate new strides for all axes currently worked with
1552            if order == "F":
1553                newstrides[ni] = oldstrides[oi]
1554                for nk in range(ni+1, nj):
1555                    newstrides[nk] = newstrides[nk - 1]*newdims[nk - 1]
1556            else:
1557                # C order
1558                newstrides[nj - 1] = oldstrides[oj - 1]
1559                for nk in range(nj-1, ni, -1):
1560                    newstrides[nk - 1] = newstrides[nk]*newdims[nk]
1561
1562            ni = nj
1563            nj += 1
1564
1565            oi = oj
1566            oj += 1
1567
1568        # Set strides corresponding to trailing 1s of the new shape.
1569        if ni >= 1:
1570            last_stride = newstrides[ni - 1]
1571        else:
1572            last_stride = self.dtype.itemsize
1573
1574        if order == "F":
1575            last_stride *= newdims[ni - 1]
1576
1577        for nk in range(ni, len(shape)):
1578            newstrides[nk] = last_stride
1579
1580        # }}}
1581
1582        return self._new_with_changes(
1583                data=self.base_data, offset=self.offset, shape=shape,
1584                strides=tuple(newstrides))
1585
1586    def ravel(self):
1587        """Returns flattened array containing the same data."""
1588        return self.reshape(self.size)
1589
1590    def view(self, dtype=None):
1591        """Returns view of array with the same data. If *dtype* is different
1592        from current dtype, the actual bytes of memory will be reinterpreted.
1593        """
1594
1595        if dtype is None:
1596            dtype = self.dtype
1597
1598        old_itemsize = self.dtype.itemsize
1599        itemsize = np.dtype(dtype).itemsize
1600
1601        from pytools import argmin2
1602        min_stride_axis = argmin2(
1603                (axis, abs(stride))
1604                for axis, stride in enumerate(self.strides))
1605
1606        if self.shape[min_stride_axis] * old_itemsize % itemsize != 0:
1607            raise ValueError("new type not compatible with array")
1608
1609        new_shape = (
1610                self.shape[:min_stride_axis]
1611                + (self.shape[min_stride_axis] * old_itemsize // itemsize,)
1612                + self.shape[min_stride_axis+1:])
1613        new_strides = (
1614                self.strides[:min_stride_axis]
1615                + (self.strides[min_stride_axis] * itemsize // old_itemsize,)
1616                + self.strides[min_stride_axis+1:])
1617
1618        return self._new_with_changes(
1619                self.base_data, self.offset,
1620                shape=new_shape, dtype=dtype,
1621                strides=new_strides)
1622
1623    def squeeze(self):
1624        """Returns a view of the array with dimensions of
1625        length 1 removed.
1626
1627        .. versionadded:: 2015.2
1628        """
1629        new_shape = tuple([dim for dim in self.shape if dim > 1])
1630        new_strides = tuple([self.strides[i]
1631            for i, dim in enumerate(self.shape) if dim > 1])
1632
1633        return self._new_with_changes(
1634                self.base_data, self.offset,
1635                shape=new_shape, strides=new_strides)
1636
1637    def transpose(self, axes=None):
1638        """Permute the dimensions of an array.
1639
1640        :arg axes: list of ints, optional.
1641            By default, reverse the dimensions, otherwise permute the axes
1642            according to the values given.
1643
1644        :returns: :class:`Array` A view of the array with its axes permuted.
1645
1646        .. versionadded:: 2015.2
1647        """
1648
1649        if axes is None:
1650            axes = range(self.ndim-1, -1, -1)
1651
1652        if len(axes) != len(self.shape):
1653            raise ValueError("axes don't match array")
1654
1655        new_shape = [self.shape[axes[i]] for i in range(len(axes))]
1656        new_strides = [self.strides[axes[i]] for i in range(len(axes))]
1657
1658        return self._new_with_changes(
1659                self.base_data, self.offset,
1660                shape=tuple(new_shape),
1661                strides=tuple(new_strides))
1662
1663    @property
1664    def T(self):  # noqa
1665        """
1666        .. versionadded:: 2015.2
1667        """
1668        return self.transpose()
1669
1670    # }}}
1671
1672    def map_to_host(self, queue=None, flags=None, is_blocking=True, wait_for=None):
1673        """If *is_blocking*, return a :class:`numpy.ndarray` corresponding to the
1674        same memory as *self*.
1675
1676        If *is_blocking* is not true, return a tuple ``(ary, evt)``, where
1677        *ary* is the above-mentioned array.
1678
1679        The host array is obtained using :func:`pyopencl.enqueue_map_buffer`.
1680        See there for further details.
1681
1682        :arg flags: A combination of :class:`pyopencl.map_flags`.
1683            Defaults to read-write.
1684
1685        .. versionadded :: 2013.2
1686        """
1687
1688        if flags is None:
1689            flags = cl.map_flags.READ | cl.map_flags.WRITE
1690        if wait_for is None:
1691            wait_for = []
1692
1693        ary, evt = cl.enqueue_map_buffer(
1694                queue or self.queue, self.base_data, flags, self.offset,
1695                self.shape, self.dtype, strides=self.strides,
1696                wait_for=wait_for + self.events, is_blocking=is_blocking)
1697
1698        if is_blocking:
1699            return ary
1700        else:
1701            return ary, evt
1702
1703    # {{{ getitem/setitem
1704
1705    def __getitem__(self, index):
1706        """
1707        .. versionadded:: 2013.1
1708        """
1709
1710        if isinstance(index, Array):
1711            if index.dtype.kind != "i":
1712                raise TypeError(
1713                        "fancy indexing is only allowed with integers")
1714            if len(index.shape) != 1:
1715                raise NotImplementedError(
1716                        "multidimensional fancy indexing is not supported")
1717            if len(self.shape) != 1:
1718                raise NotImplementedError(
1719                        "fancy indexing into a multi-d array is not supported")
1720
1721            return take(self, index)
1722
1723        if not isinstance(index, tuple):
1724            index = (index,)
1725
1726        new_shape = []
1727        new_offset = self.offset
1728        new_strides = []
1729
1730        seen_ellipsis = False
1731
1732        index_axis = 0
1733        array_axis = 0
1734        while index_axis < len(index):
1735            index_entry = index[index_axis]
1736
1737            if array_axis > len(self.shape):
1738                raise IndexError("too many axes in index")
1739
1740            if isinstance(index_entry, slice):
1741                start, stop, idx_stride = index_entry.indices(
1742                        self.shape[array_axis])
1743
1744                array_stride = self.strides[array_axis]
1745
1746                new_shape.append((abs(stop-start)-1)//abs(idx_stride)+1)
1747                new_strides.append(idx_stride*array_stride)
1748                new_offset += array_stride*start
1749
1750                index_axis += 1
1751                array_axis += 1
1752
1753            elif isinstance(index_entry, (int, np.integer)):
1754                array_shape = self.shape[array_axis]
1755                if index_entry < 0:
1756                    index_entry += array_shape
1757
1758                if not (0 <= index_entry < array_shape):
1759                    raise IndexError(
1760                            "subindex in axis %d out of range" % index_axis)
1761
1762                new_offset += self.strides[array_axis]*index_entry
1763
1764                index_axis += 1
1765                array_axis += 1
1766
1767            elif index_entry is Ellipsis:
1768                index_axis += 1
1769
1770                remaining_index_count = len(index) - index_axis
1771                new_array_axis = len(self.shape) - remaining_index_count
1772                if new_array_axis < array_axis:
1773                    raise IndexError("invalid use of ellipsis in index")
1774                while array_axis < new_array_axis:
1775                    new_shape.append(self.shape[array_axis])
1776                    new_strides.append(self.strides[array_axis])
1777                    array_axis += 1
1778
1779                if seen_ellipsis:
1780                    raise IndexError(
1781                            "more than one ellipsis not allowed in index")
1782                seen_ellipsis = True
1783
1784            elif index_entry is np.newaxis:
1785                new_shape.append(1)
1786                new_strides.append(0)
1787                index_axis += 1
1788
1789            else:
1790                raise IndexError("invalid subindex in axis %d" % index_axis)
1791
1792        while array_axis < len(self.shape):
1793            new_shape.append(self.shape[array_axis])
1794            new_strides.append(self.strides[array_axis])
1795
1796            array_axis += 1
1797
1798        return self._new_with_changes(
1799                self.base_data, offset=new_offset,
1800                shape=tuple(new_shape),
1801                strides=tuple(new_strides))
1802
1803    def setitem(self, subscript, value, queue=None, wait_for=None):
1804        """Like :meth:`__setitem__`, but with the ability to specify
1805        a *queue* and *wait_for*.
1806
1807        .. versionadded:: 2013.1
1808
1809        .. versionchanged:: 2013.2
1810
1811            Added *wait_for*.
1812        """
1813
1814        queue = queue or self.queue or value.queue
1815        if wait_for is None:
1816            wait_for = []
1817        wait_for = wait_for + self.events
1818
1819        if isinstance(subscript, Array):
1820            if subscript.dtype.kind != "i":
1821                raise TypeError(
1822                        "fancy indexing is only allowed with integers")
1823            if len(subscript.shape) != 1:
1824                raise NotImplementedError(
1825                        "multidimensional fancy indexing is not supported")
1826            if len(self.shape) != 1:
1827                raise NotImplementedError(
1828                        "fancy indexing into a multi-d array is not supported")
1829
1830            multi_put([value], subscript, out=[self], queue=queue,
1831                    wait_for=wait_for)
1832            return
1833
1834        subarray = self[subscript]
1835
1836        if isinstance(value, np.ndarray):
1837            if subarray.shape == value.shape and subarray.strides == value.strides:
1838                self.add_event(
1839                        cl.enqueue_copy(queue, subarray.base_data,
1840                            value, device_offset=subarray.offset, wait_for=wait_for))
1841                return
1842            else:
1843                value = to_device(queue, value, self.allocator)
1844
1845        if isinstance(value, Array):
1846            if len(subarray.shape) != len(value.shape):
1847                raise NotImplementedError("broadcasting is not "
1848                        "supported in __setitem__")
1849            if subarray.shape != value.shape:
1850                raise ValueError("cannot assign between arrays of "
1851                        "differing shapes")
1852            if subarray.strides != value.strides:
1853                raise ValueError("cannot assign between arrays of "
1854                        "differing strides")
1855
1856            self.add_event(
1857                    self._copy(subarray, value, queue=queue, wait_for=wait_for))
1858
1859        else:
1860            # Let's assume it's a scalar
1861            subarray.fill(value, queue=queue, wait_for=wait_for)
1862
1863    def __setitem__(self, subscript, value):
1864        """Set the slice of *self* identified *subscript* to *value*.
1865
1866        *value* is allowed to be:
1867
1868        * A :class:`Array` of the same :attr:`shape` and (for now) :attr:`strides`,
1869          but with potentially different :attr:`dtype`.
1870        * A :class:`numpy.ndarray` of the same :attr:`shape` and (for now)
1871          :attr:`strides`, but with potentially different :attr:`dtype`.
1872        * A scalar.
1873
1874        Non-scalar broadcasting is not currently supported.
1875
1876        .. versionadded:: 2013.1
1877        """
1878        self.setitem(subscript, value)
1879
1880    # }}}
1881
1882# }}}
1883
1884
1885# {{{ creation helpers
1886
1887def as_strided(ary, shape=None, strides=None):
1888    """Make an :class:`Array` from the given array with the given
1889    shape and strides.
1890    """
1891
1892    # undocumented for the moment
1893
1894    if shape is None:
1895        shape = ary.shape
1896    if strides is None:
1897        strides = ary.strides
1898
1899    return Array(ary.queue, shape, ary.dtype, allocator=ary.allocator,
1900            data=ary.data, strides=strides)
1901
1902
1903class _same_as_transfer(object):  # noqa
1904    pass
1905
1906
1907def to_device(queue, ary, allocator=None, async_=None,
1908        array_queue=_same_as_transfer, **kwargs):
1909    """Return a :class:`Array` that is an exact copy of the
1910    :class:`numpy.ndarray` instance *ary*.
1911
1912    :arg array_queue: The :class:`CommandQueue` which will
1913        be stored in the resulting array. Useful
1914        to make sure there is no implicit queue associated
1915        with the array by passing *None*.
1916
1917    See :class:`Array` for the meaning of *allocator*.
1918
1919    .. versionchanged:: 2015.2
1920        *array_queue* argument was added.
1921
1922    .. versionchanged:: 2017.2.1
1923
1924        Python 3.7 makes ``async`` a reserved keyword. On older Pythons,
1925        we will continue to  accept *async* as a parameter, however this
1926        should be considered deprecated. *async_* is the new, official
1927        spelling.
1928    """
1929
1930    # {{{ handle 'async' deprecation
1931
1932    async_arg = kwargs.pop("async", None)
1933    if async_arg is not None:
1934        if async_ is not None:
1935            raise TypeError("may not specify both 'async' and 'async_'")
1936        async_ = async_arg
1937
1938    if async_ is None:
1939        async_ = False
1940
1941    if kwargs:
1942        raise TypeError("extra keyword arguments specified: %s"
1943                % ", ".join(kwargs))
1944
1945    # }}}
1946
1947    if _dtype_is_object(ary.dtype):
1948        raise RuntimeError("to_device does not work on object arrays.")
1949
1950    if array_queue is _same_as_transfer:
1951        first_arg = queue
1952    else:
1953        first_arg = queue.context
1954
1955    result = Array(first_arg, ary.shape, ary.dtype,
1956                    allocator=allocator, strides=ary.strides)
1957    result.set(ary, async_=async_, queue=queue)
1958    return result
1959
1960
1961empty = Array
1962
1963
1964def zeros(queue, shape, dtype, order="C", allocator=None):
1965    """Same as :func:`empty`, but the :class:`Array` is zero-initialized before
1966    being returned.
1967
1968    .. versionchanged:: 2011.1
1969        *context* argument was deprecated.
1970    """
1971
1972    result = Array(queue, shape, dtype,
1973            order=order, allocator=allocator)
1974    result._zero_fill()
1975    return result
1976
1977
1978def empty_like(ary, queue=_copy_queue, allocator=None):
1979    """Make a new, uninitialized :class:`Array` having the same properties
1980    as *other_ary*.
1981    """
1982
1983    return ary._new_with_changes(data=None, offset=0, queue=queue,
1984            allocator=allocator)
1985
1986
1987def zeros_like(ary):
1988    """Make a new, zero-initialized :class:`Array` having the same properties
1989    as *other_ary*.
1990    """
1991
1992    result = empty_like(ary)
1993    result._zero_fill()
1994    return result
1995
1996
1997@elwise_kernel_runner
1998def _arange_knl(result, start, step):
1999    return elementwise.get_arange_kernel(
2000            result.context, result.dtype)
2001
2002
2003def arange(queue, *args, **kwargs):
2004    """Create a :class:`Array` filled with numbers spaced `step` apart,
2005    starting from `start` and ending at `stop`.
2006
2007    For floating point arguments, the length of the result is
2008    `ceil((stop - start)/step)`.  This rule may result in the last
2009    element of the result being greater than `stop`.
2010
2011    *dtype*, if not specified, is taken as the largest common type
2012    of *start*, *stop* and *step*.
2013
2014    .. versionchanged:: 2011.1
2015        *context* argument was deprecated.
2016
2017    .. versionchanged:: 2011.2
2018        *allocator* keyword argument was added.
2019    """
2020
2021    # argument processing -----------------------------------------------------
2022
2023    # Yuck. Thanks, numpy developers. ;)
2024    from pytools import Record
2025
2026    class Info(Record):
2027        pass
2028
2029    explicit_dtype = False
2030
2031    inf = Info()
2032    inf.start = None
2033    inf.stop = None
2034    inf.step = None
2035    inf.dtype = None
2036    inf.allocator = None
2037    inf.wait_for = []
2038
2039    if isinstance(args[-1], np.dtype):
2040        inf.dtype = args[-1]
2041        args = args[:-1]
2042        explicit_dtype = True
2043
2044    argc = len(args)
2045    if argc == 0:
2046        raise ValueError("stop argument required")
2047    elif argc == 1:
2048        inf.stop = args[0]
2049    elif argc == 2:
2050        inf.start = args[0]
2051        inf.stop = args[1]
2052    elif argc == 3:
2053        inf.start = args[0]
2054        inf.stop = args[1]
2055        inf.step = args[2]
2056    else:
2057        raise ValueError("too many arguments")
2058
2059    admissible_names = ["start", "stop", "step", "dtype", "allocator"]
2060    for k, v in six.iteritems(kwargs):
2061        if k in admissible_names:
2062            if getattr(inf, k) is None:
2063                setattr(inf, k, v)
2064                if k == "dtype":
2065                    explicit_dtype = True
2066            else:
2067                raise ValueError(
2068                        "may not specify '%s' by position and keyword" % k)
2069        else:
2070            raise ValueError("unexpected keyword argument '%s'" % k)
2071
2072    if inf.start is None:
2073        inf.start = 0
2074    if inf.step is None:
2075        inf.step = 1
2076    if inf.dtype is None:
2077        inf.dtype = np.array([inf.start, inf.stop, inf.step]).dtype
2078
2079    # actual functionality ----------------------------------------------------
2080    dtype = np.dtype(inf.dtype)
2081    start = dtype.type(inf.start)
2082    step = dtype.type(inf.step)
2083    stop = dtype.type(inf.stop)
2084    wait_for = inf.wait_for
2085
2086    if not explicit_dtype:
2087        raise TypeError("arange requires a dtype argument")
2088
2089    from math import ceil
2090    size = int(ceil((stop-start)/step))
2091
2092    result = Array(queue, (size,), dtype, allocator=inf.allocator)
2093    result.add_event(
2094            _arange_knl(result, start, step, queue=queue, wait_for=wait_for))
2095    return result
2096
2097# }}}
2098
2099
2100# {{{ take/put/concatenate/diff
2101
2102@elwise_kernel_runner
2103def _take(result, ary, indices):
2104    return elementwise.get_take_kernel(
2105            result.context, result.dtype, indices.dtype)
2106
2107
2108def take(a, indices, out=None, queue=None, wait_for=None):
2109    """Return the :class:`Array` ``[a[indices[0]], ..., a[indices[n]]]``.
2110    For the moment, *a* must be a type that can be bound to a texture.
2111    """
2112
2113    queue = queue or a.queue
2114    if out is None:
2115        out = Array(queue, indices.shape, a.dtype, allocator=a.allocator)
2116
2117    assert len(indices.shape) == 1
2118    out.add_event(
2119            _take(out, a, indices, queue=queue, wait_for=wait_for))
2120    return out
2121
2122
2123def multi_take(arrays, indices, out=None, queue=None):
2124    if not len(arrays):
2125        return []
2126
2127    assert len(indices.shape) == 1
2128
2129    from pytools import single_valued
2130    a_dtype = single_valued(a.dtype for a in arrays)
2131    a_allocator = arrays[0].dtype
2132    context = indices.context
2133    queue = queue or indices.queue
2134
2135    vec_count = len(arrays)
2136
2137    if out is None:
2138        out = [Array(context, queue, indices.shape, a_dtype,
2139            allocator=a_allocator)
2140                for i in range(vec_count)]
2141    else:
2142        if len(out) != len(arrays):
2143            raise ValueError("out and arrays must have the same length")
2144
2145    chunk_size = _builtin_min(vec_count, 10)
2146
2147    def make_func_for_chunk_size(chunk_size):
2148        knl = elementwise.get_take_kernel(
2149                indices.context, a_dtype, indices.dtype,
2150                vec_count=chunk_size)
2151        knl.set_block_shape(*indices._block)
2152        return knl
2153
2154    knl = make_func_for_chunk_size(chunk_size)
2155
2156    for start_i in range(0, len(arrays), chunk_size):
2157        chunk_slice = slice(start_i, start_i+chunk_size)
2158
2159        if start_i + chunk_size > vec_count:
2160            knl = make_func_for_chunk_size(vec_count-start_i)
2161
2162        gs, ls = indices.get_sizes(queue,
2163                knl.get_work_group_info(
2164                    cl.kernel_work_group_info.WORK_GROUP_SIZE,
2165                    queue.device))
2166
2167        wait_for_this = (indices.events
2168            + _builtin_sum((i.events for i in arrays[chunk_slice]), [])
2169            + _builtin_sum((o.events for o in out[chunk_slice]), []))
2170        evt = knl(queue, gs, ls,
2171                indices.data,
2172                *([o.data for o in out[chunk_slice]]
2173                    + [i.data for i in arrays[chunk_slice]]
2174                    + [indices.size]), wait_for=wait_for_this)
2175        for o in out[chunk_slice]:
2176            o.add_event(evt)
2177
2178    return out
2179
2180
2181def multi_take_put(arrays, dest_indices, src_indices, dest_shape=None,
2182        out=None, queue=None, src_offsets=None):
2183    if not len(arrays):
2184        return []
2185
2186    from pytools import single_valued
2187    a_dtype = single_valued(a.dtype for a in arrays)
2188    a_allocator = arrays[0].allocator
2189    context = src_indices.context
2190    queue = queue or src_indices.queue
2191
2192    vec_count = len(arrays)
2193
2194    if out is None:
2195        out = [Array(queue, dest_shape, a_dtype, allocator=a_allocator)
2196                for i in range(vec_count)]
2197    else:
2198        if a_dtype != single_valued(o.dtype for o in out):
2199            raise TypeError("arrays and out must have the same dtype")
2200        if len(out) != vec_count:
2201            raise ValueError("out and arrays must have the same length")
2202
2203    if src_indices.dtype != dest_indices.dtype:
2204        raise TypeError(
2205                "src_indices and dest_indices must have the same dtype")
2206
2207    if len(src_indices.shape) != 1:
2208        raise ValueError("src_indices must be 1D")
2209
2210    if src_indices.shape != dest_indices.shape:
2211        raise ValueError(
2212                "src_indices and dest_indices must have the same shape")
2213
2214    if src_offsets is None:
2215        src_offsets_list = []
2216    else:
2217        src_offsets_list = src_offsets
2218        if len(src_offsets) != vec_count:
2219            raise ValueError(
2220                    "src_indices and src_offsets must have the same length")
2221
2222    max_chunk_size = 10
2223
2224    chunk_size = _builtin_min(vec_count, max_chunk_size)
2225
2226    def make_func_for_chunk_size(chunk_size):
2227        return elementwise.get_take_put_kernel(context,
2228                a_dtype, src_indices.dtype,
2229                with_offsets=src_offsets is not None,
2230                vec_count=chunk_size)
2231
2232    knl = make_func_for_chunk_size(chunk_size)
2233
2234    for start_i in range(0, len(arrays), chunk_size):
2235        chunk_slice = slice(start_i, start_i+chunk_size)
2236
2237        if start_i + chunk_size > vec_count:
2238            knl = make_func_for_chunk_size(vec_count-start_i)
2239
2240        gs, ls = src_indices.get_sizes(queue,
2241                knl.get_work_group_info(
2242                    cl.kernel_work_group_info.WORK_GROUP_SIZE,
2243                    queue.device))
2244
2245        from pytools import flatten
2246        wait_for_this = (dest_indices.events + src_indices.events
2247            + _builtin_sum((i.events for i in arrays[chunk_slice]), [])
2248            + _builtin_sum((o.events for o in out[chunk_slice]), []))
2249        evt = knl(queue, gs, ls,
2250                *([o.data for o in out[chunk_slice]]
2251                    + [dest_indices.base_data,
2252                        dest_indices.offset,
2253                        src_indices.base_data,
2254                        src_indices.offset]
2255                    + list(flatten(
2256                        (i.base_data, i.offset)
2257                        for i in arrays[chunk_slice]))
2258                    + src_offsets_list[chunk_slice]
2259                    + [src_indices.size]), wait_for=wait_for_this)
2260        for o in out[chunk_slice]:
2261            o.add_event(evt)
2262
2263    return out
2264
2265
2266def multi_put(arrays, dest_indices, dest_shape=None, out=None, queue=None,
2267        wait_for=None):
2268    if not len(arrays):
2269        return []
2270
2271    from pytools import single_valued
2272    a_dtype = single_valued(a.dtype for a in arrays)
2273    a_allocator = arrays[0].allocator
2274    context = dest_indices.context
2275    queue = queue or dest_indices.queue
2276    if wait_for is None:
2277        wait_for = []
2278    wait_for = wait_for + dest_indices.events
2279
2280    vec_count = len(arrays)
2281
2282    if out is None:
2283        out = [Array(queue, dest_shape, a_dtype,
2284            allocator=a_allocator, queue=queue)
2285            for i in range(vec_count)]
2286    else:
2287        if a_dtype != single_valued(o.dtype for o in out):
2288            raise TypeError("arrays and out must have the same dtype")
2289        if len(out) != vec_count:
2290            raise ValueError("out and arrays must have the same length")
2291
2292    if len(dest_indices.shape) != 1:
2293        raise ValueError("dest_indices must be 1D")
2294
2295    chunk_size = _builtin_min(vec_count, 10)
2296
2297    # array of bools to specify whether the array of same index in this chunk
2298    # will be filled with a single value.
2299    use_fill = np.ndarray((chunk_size,), dtype=np.uint8)
2300    array_lengths = np.ndarray((chunk_size,), dtype=np.int64)
2301
2302    def make_func_for_chunk_size(chunk_size):
2303        knl = elementwise.get_put_kernel(
2304                context, a_dtype, dest_indices.dtype,
2305                vec_count=chunk_size)
2306        return knl
2307
2308    knl = make_func_for_chunk_size(chunk_size)
2309
2310    for start_i in range(0, len(arrays), chunk_size):
2311        chunk_slice = slice(start_i, start_i+chunk_size)
2312        for fill_idx, ary in enumerate(arrays[chunk_slice]):
2313            # If there is only one value in the values array for this src array
2314            # in the chunk then fill every index in `dest_idx` array with it.
2315            use_fill[fill_idx] = 1 if ary.size == 1 else 0
2316            array_lengths[fill_idx] = len(ary)
2317        # Copy the populated `use_fill` array to a buffer on the device.
2318        use_fill_cla = to_device(queue, use_fill)
2319        array_lengths_cla = to_device(queue, array_lengths)
2320
2321        if start_i + chunk_size > vec_count:
2322            knl = make_func_for_chunk_size(vec_count-start_i)
2323
2324        gs, ls = dest_indices.get_sizes(queue,
2325                knl.get_work_group_info(
2326                    cl.kernel_work_group_info.WORK_GROUP_SIZE,
2327                    queue.device))
2328
2329        from pytools import flatten
2330        wait_for_this = (wait_for
2331            + _builtin_sum((i.events for i in arrays[chunk_slice]), [])
2332            + _builtin_sum((o.events for o in out[chunk_slice]), []))
2333        evt = knl(queue, gs, ls,
2334                *(
2335                    list(flatten(
2336                        (o.base_data, o.offset)
2337                        for o in out[chunk_slice]))
2338                    + [dest_indices.base_data, dest_indices.offset]
2339                    + list(flatten(
2340                        (i.base_data, i.offset)
2341                        for i in arrays[chunk_slice]))
2342                    + [use_fill_cla.base_data, use_fill_cla.offset]
2343                    + [array_lengths_cla.base_data, array_lengths_cla.offset]
2344                    + [dest_indices.size]),
2345                **dict(wait_for=wait_for_this))
2346
2347        for o in out[chunk_slice]:
2348            o.add_event(evt)
2349
2350    return out
2351
2352
2353def concatenate(arrays, axis=0, queue=None, allocator=None):
2354    """
2355    .. versionadded:: 2013.1
2356    """
2357    # {{{ find properties of result array
2358
2359    shape = None
2360
2361    for i_ary, ary in enumerate(arrays):
2362        queue = queue or ary.queue
2363        allocator = allocator or ary.allocator
2364
2365        if shape is None:
2366            # first array
2367            shape = list(ary.shape)
2368        else:
2369            if len(ary.shape) != len(shape):
2370                raise ValueError("%d'th array has different number of axes "
2371                        "(shold have %d, has %d)"
2372                        % (i_ary, len(ary.shape), len(shape)))
2373
2374            ary_shape_list = list(ary.shape)
2375            if (ary_shape_list[:axis] != shape[:axis]
2376                    or ary_shape_list[axis+1:] != shape[axis+1:]):
2377                raise ValueError("%d'th array has residual not matching "
2378                        "other arrays" % i_ary)
2379
2380            shape[axis] += ary.shape[axis]
2381
2382    # }}}
2383
2384    shape = tuple(shape)
2385    dtype = np.find_common_type([ary.dtype for ary in arrays], [])
2386    result = empty(queue, shape, dtype, allocator=allocator)
2387
2388    full_slice = (slice(None),) * len(shape)
2389
2390    base_idx = 0
2391    for ary in arrays:
2392        my_len = ary.shape[axis]
2393        result.setitem(
2394                full_slice[:axis]
2395                + (slice(base_idx, base_idx+my_len),)
2396                + full_slice[axis+1:],
2397                ary)
2398
2399        base_idx += my_len
2400
2401    return result
2402
2403
2404@elwise_kernel_runner
2405def _diff(result, array):
2406    return elementwise.get_diff_kernel(array.context, array.dtype)
2407
2408
2409def diff(array, queue=None, allocator=None):
2410    """
2411    .. versionadded:: 2013.2
2412    """
2413
2414    if len(array.shape) != 1:
2415        raise ValueError("multi-D arrays are not supported")
2416
2417    n, = array.shape
2418
2419    queue = queue or array.queue
2420    allocator = allocator or array.allocator
2421
2422    result = empty(queue, (n-1,), array.dtype, allocator=allocator)
2423    event1 = _diff(result, array, queue=queue)
2424    result.add_event(event1)
2425    return result
2426
2427
2428def hstack(arrays, queue=None):
2429    from pyopencl.array import empty
2430
2431    if len(arrays) == 0:
2432        return empty(queue, (), dtype=np.float64)
2433
2434    if queue is None:
2435        for ary in arrays:
2436            if ary.queue is not None:
2437                queue = ary.queue
2438                break
2439
2440    from pytools import all_equal, single_valued
2441    if not all_equal(len(ary.shape) for ary in arrays):
2442        raise ValueError("arguments must all have the same number of axes")
2443
2444    lead_shape = single_valued(ary.shape[:-1] for ary in arrays)
2445
2446    w = _builtin_sum([ary.shape[-1] for ary in arrays])
2447    result = empty(queue, lead_shape+(w,), arrays[0].dtype)
2448    index = 0
2449    for ary in arrays:
2450        result[..., index:index+ary.shape[-1]] = ary
2451        index += ary.shape[-1]
2452
2453    return result
2454
2455# }}}
2456
2457
2458# {{{ shape manipulation
2459
2460def transpose(a, axes=None):
2461    """Permute the dimensions of an array.
2462
2463    :arg a: :class:`Array`
2464    :arg axes: list of ints, optional.
2465        By default, reverse the dimensions, otherwise permute the axes
2466        according to the values given.
2467
2468    :returns: :class:`Array` A view of the array with its axes permuted.
2469    """
2470    return a.transpose(axes)
2471
2472
2473def reshape(a, shape):
2474    """Gives a new shape to an array without changing its data.
2475
2476    .. versionadded:: 2015.2
2477    """
2478
2479    return a.reshape(shape)
2480
2481# }}}
2482
2483
2484# {{{ conditionals
2485
2486@elwise_kernel_runner
2487def _if_positive(result, criterion, then_, else_):
2488    return elementwise.get_if_positive_kernel(
2489            result.context, criterion.dtype, then_.dtype)
2490
2491
2492def if_positive(criterion, then_, else_, out=None, queue=None):
2493    """Return an array like *then_*, which, for the element at index *i*,
2494    contains *then_[i]* if *criterion[i]>0*, else *else_[i]*.
2495    """
2496
2497    if not (criterion.shape == then_.shape == else_.shape):
2498        raise ValueError("shapes do not match")
2499
2500    if not (then_.dtype == else_.dtype):
2501        raise ValueError("dtypes do not match")
2502
2503    if out is None:
2504        out = empty_like(then_)
2505    event1 = _if_positive(out, criterion, then_, else_, queue=queue)
2506    out.add_event(event1)
2507    return out
2508
2509
2510def maximum(a, b, out=None, queue=None):
2511    """Return the elementwise maximum of *a* and *b*."""
2512
2513    # silly, but functional
2514    return if_positive(a.mul_add(1, b, -1, queue=queue), a, b,
2515            queue=queue, out=out)
2516
2517
2518def minimum(a, b, out=None, queue=None):
2519    """Return the elementwise minimum of *a* and *b*."""
2520    # silly, but functional
2521    return if_positive(a.mul_add(1, b, -1, queue=queue), b, a,
2522            queue=queue, out=out)
2523
2524# }}}
2525
2526
2527# {{{ reductions
2528_builtin_sum = sum
2529_builtin_min = min
2530_builtin_max = max
2531
2532
2533def sum(a, dtype=None, queue=None, slice=None):
2534    """
2535    .. versionadded:: 2011.1
2536    """
2537    from pyopencl.reduction import get_sum_kernel
2538    krnl = get_sum_kernel(a.context, dtype, a.dtype)
2539    result, event1 = krnl(a, queue=queue, slice=slice, wait_for=a.events,
2540            return_event=True)
2541    result.add_event(event1)
2542    return result
2543
2544
2545def dot(a, b, dtype=None, queue=None, slice=None):
2546    """
2547    .. versionadded:: 2011.1
2548    """
2549    from pyopencl.reduction import get_dot_kernel
2550    krnl = get_dot_kernel(a.context, dtype, a.dtype, b.dtype)
2551    result, event1 = krnl(a, b, queue=queue, slice=slice,
2552            wait_for=a.events + b.events, return_event=True)
2553    result.add_event(event1)
2554    return result
2555
2556
2557def vdot(a, b, dtype=None, queue=None, slice=None):
2558    """Like :func:`numpy.vdot`.
2559
2560    .. versionadded:: 2013.1
2561    """
2562    from pyopencl.reduction import get_dot_kernel
2563    krnl = get_dot_kernel(a.context, dtype, a.dtype, b.dtype,
2564            conjugate_first=True)
2565    result, event1 = krnl(a, b, queue=queue, slice=slice,
2566            wait_for=a.events + b.events, return_event=True)
2567    result.add_event(event1)
2568    return result
2569
2570
2571def subset_dot(subset, a, b, dtype=None, queue=None, slice=None):
2572    """
2573    .. versionadded:: 2011.1
2574    """
2575    from pyopencl.reduction import get_subset_dot_kernel
2576    krnl = get_subset_dot_kernel(
2577            a.context, dtype, subset.dtype, a.dtype, b.dtype)
2578    result, event1 = krnl(subset, a, b, queue=queue, slice=slice,
2579            wait_for=subset.events + a.events + b.events, return_event=True)
2580    result.add_event(event1)
2581    return result
2582
2583
2584def _make_minmax_kernel(what):
2585    def f(a, queue=None):
2586        from pyopencl.reduction import get_minmax_kernel
2587        krnl = get_minmax_kernel(a.context, what, a.dtype)
2588        result, event1 = krnl(a, queue=queue, wait_for=a.events,
2589                return_event=True)
2590        result.add_event(event1)
2591        return result
2592
2593    return f
2594
2595
2596min = _make_minmax_kernel("min")
2597min.__doc__ = """
2598    .. versionadded:: 2011.1
2599    """
2600
2601max = _make_minmax_kernel("max")
2602max.__doc__ = """
2603    .. versionadded:: 2011.1
2604    """
2605
2606
2607def _make_subset_minmax_kernel(what):
2608    def f(subset, a, queue=None, slice=None):
2609        from pyopencl.reduction import get_subset_minmax_kernel
2610        krnl = get_subset_minmax_kernel(a.context, what, a.dtype, subset.dtype)
2611        result, event1 = krnl(subset, a,  queue=queue, slice=slice,
2612                wait_for=a.events + subset.events, return_event=True)
2613        result.add_event(event1)
2614        return result
2615    return f
2616
2617
2618subset_min = _make_subset_minmax_kernel("min")
2619subset_min.__doc__ = """.. versionadded:: 2011.1"""
2620subset_max = _make_subset_minmax_kernel("max")
2621subset_max.__doc__ = """.. versionadded:: 2011.1"""
2622
2623# }}}
2624
2625
2626# {{{ scans
2627
2628def cumsum(a, output_dtype=None, queue=None,
2629        wait_for=None, return_event=False):
2630    # undocumented for now
2631
2632    """
2633    .. versionadded:: 2013.1
2634    """
2635
2636    if output_dtype is None:
2637        output_dtype = a.dtype
2638    if wait_for is None:
2639        wait_for = []
2640
2641    result = a._new_like_me(output_dtype)
2642
2643    from pyopencl.scan import get_cumsum_kernel
2644    krnl = get_cumsum_kernel(a.context, a.dtype, output_dtype)
2645    evt = krnl(a, result, queue=queue, wait_for=wait_for + a.events)
2646    result.add_event(evt)
2647
2648    if return_event:
2649        return evt, result
2650    else:
2651        return result
2652
2653# }}}
2654
2655# vim: foldmethod=marker
2656