1from __future__ import division
2
3__copyright__ = "Copyright (C) 2011 Andreas Kloeckner"
4
5__license__ = """
6Permission is hereby granted, free of charge, to any person obtaining a copy
7of this software and associated documentation files (the "Software"), to deal
8in the Software without restriction, including without limitation the rights
9to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10copies of the Software, and to permit persons to whom the Software is
11furnished to do so, subject to the following conditions:
12
13The above copyright notice and this permission notice shall be included in
14all copies or substantial portions of the Software.
15
16THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
22THE SOFTWARE.
23"""
24
25import numpy as np
26
27
def f_contiguous_strides(itemsize, shape):
    """Return the byte strides of a Fortran-ordered (column-major) layout.

    The first axis has stride *itemsize*; each following axis has the
    previous stride multiplied by the previous axis' extent. An empty
    *shape* yields an empty tuple.
    """
    if not shape:
        return ()

    running = itemsize
    collected = [running]
    # The last extent never contributes to any stride.
    for extent in shape[:-1]:
        running = running * extent
        collected.append(running)
    return tuple(collected)
36
37
def c_contiguous_strides(itemsize, shape):
    """Return the byte strides of a C-ordered (row-major) layout.

    The last axis has stride *itemsize*; moving towards the first axis,
    each stride is the next axis' stride times that axis' extent. An empty
    *shape* yields an empty tuple.
    """
    if not shape:
        return ()

    running = itemsize
    backwards = [running]
    # Walk from the last axis towards the second one; the first axis'
    # extent never contributes to any stride.
    for extent in reversed(shape[1:]):
        running = running * extent
        backwards.append(running)
    backwards.reverse()
    return tuple(backwards)
46
47
def equal_strides(strides1, strides2, shape):
    """Return True if two stride tuples describe the same layout of *shape*.

    All three sequences must have the same length. Strides along axes of
    extent 1 are ignored, since such an axis is never stepped along.
    """
    if not (len(strides1) == len(strides2) == len(shape)):
        return False

    return all(
        extent == 1 or first == second
        for extent, first, second in zip(shape, strides1, strides2))
57
58
def is_f_contiguous_strides(strides, itemsize, shape):
    """Return True if *strides* match a Fortran-contiguous layout of *shape*."""
    expected = f_contiguous_strides(itemsize, shape)
    return equal_strides(strides, expected, shape)
61
62
def is_c_contiguous_strides(strides, itemsize, shape):
    """Return True if *strides* match a C-contiguous layout of *shape*."""
    expected = c_contiguous_strides(itemsize, shape)
    return equal_strides(strides, expected, shape)
65
66
class ArrayFlags:
    """Contiguity flags derived from an array's strides, itemsize and shape.

    Attributes:
        f_contiguous: strides match a Fortran-contiguous layout.
        c_contiguous: strides match a C-contiguous layout.
        forc: either of the two above holds.
    """

    def __init__(self, ary):
        strides = ary.strides
        itemsize = ary.dtype.itemsize
        shape = ary.shape

        self.f_contiguous = is_f_contiguous_strides(strides, itemsize, shape)
        self.c_contiguous = is_c_contiguous_strides(strides, itemsize, shape)
        self.forc = self.f_contiguous or self.c_contiguous
74
75
def get_common_dtype(obj1, obj2, allow_double):
    """Return the dtype numpy produces when adding *obj1* and *obj2*.

    *obj1* must have a ``dtype`` attribute. If *obj2* has one too it is
    represented by a one-element array; otherwise it is used as-is, because
    numpy promotes arrays and scalars differently.

    When *allow_double* is false, a ``float64`` result is narrowed to
    ``float32`` and a ``complex128`` result to ``complex64``.
    """
    probe1 = np.zeros(1, dtype=obj1.dtype)

    try:
        probe2 = np.zeros(1, dtype=obj2.dtype)
    except AttributeError:
        # obj2 has no dtype: let numpy's scalar promotion rules apply.
        probe2 = obj2

    common = (probe1 + probe2).dtype

    if allow_double:
        return common
    if common == np.float64:
        return np.dtype(np.float32)
    if common == np.complex128:
        return np.dtype(np.complex64)
    return common
96
97
def bound(a):
    """Return the half-open byte range ``(low, high)`` spanned by *a*.

    ``a.bytes`` is used as the integer position of the element at index
    ``(0, ..., 0)`` (presumably a device pointer -- confirm with callers).
    ``a.strides``, ``a.shape`` and ``a.dtype.itemsize`` describe the
    layout.

    Negative strides extend the range below that starting position and
    positive strides above it. The last element's ``itemsize`` bytes are
    included, so *high* is exclusive (as with ``numpy.byte_bounds``) and
    ranges that merely touch the last element are still detected as
    overlapping. An array with a zero-length axis occupies no memory and
    yields the empty range ``(start, start)``.
    """
    start = a.bytes

    # A zero-length axis means there are no elements at all; the stride
    # arithmetic below would otherwise produce an inverted range.
    if any(extent == 0 for extent in a.shape):
        return start, start

    low = high = start
    for stride, extent in zip(a.strides, a.shape):
        if stride < 0:
            low += stride * (extent - 1)
        else:
            high += stride * (extent - 1)

    # high so far is the start of the farthest element; include its bytes.
    return low, high + a.dtype.itemsize
108
109
def may_share_memory(a, b):
    """Return True if *a* and *b* might overlap in memory.

    Exists because ``numpy.may_share_memory`` fails when called with an
    ndarray and a sparse matrix. An object always shares memory with
    itself; objects of different classes are assumed not to share. For
    two objects of the same class, the byte ranges from :func:`bound` are
    compared.
    """
    if a is b:
        return True
    if a.__class__ is not b.__class__:
        return False

    a_low, a_high = bound(a)
    b_low, b_high = bound(b)
    # The ranges intersect unless one ends at or before the other starts.
    return b_low < a_high and a_low < b_high
123
124
125# {{{ as_strided implementation
126
try:
    from numpy.lib.stride_tricks import as_strided as _as_strided

    # Some numpy versions drop the fields of structured dtypes in
    # as_strided (numpy issue #2466); probe for that before trusting it.
    _test_dtype = np.dtype(
            [("a", np.float64), ("b", np.float64)], align=True)
    _test_result = _as_strided(np.zeros(10, dtype=_test_dtype))
    if _test_result.dtype != _test_dtype:
        raise RuntimeError("numpy's as_strided is broken")

    as_strided = _as_strided
except Exception:
    # numpy's as_strided is missing or failed the probe above. Catch
    # Exception rather than using a bare "except:" so KeyboardInterrupt and
    # SystemExit still propagate. The fallback below is stolen from numpy to
    # be compatible with older versions of numpy.
    class _DummyArray(object):
        """ Dummy object that just exists to hang __array_interface__ dictionaries
        and possibly keep alive a reference to a base array.
        """
        def __init__(self, interface, base=None):
            self.__array_interface__ = interface
            self.base = base

    def as_strided(x, shape=None, strides=None):
        """ Make an ndarray from the given array with the given shape and strides.

        :arg x: the source array.
        :arg shape: the new shape, or *None* to keep ``x.shape``.
        :arg strides: the new byte strides, or *None* to leave them
            unspecified.
        :returns: *x* itself if nothing changes, otherwise an array viewing
            *x*'s memory.
        :raises NotImplementedError: for non-builtin (e.g. structured)
            dtypes, unless the request amounts to a contiguous reshape.
        """
        # Work around Numpy bug 1873 (reported by Irwin Zaid).
        # Since this is stolen from numpy, this implementation has the same
        # bug for non-builtin dtypes, which is handled separately below.
        # http://projects.scipy.org/numpy/ticket/1873
        # == https://github.com/numpy/numpy/issues/2466

        # Normalize once so that list arguments compare equal to the tuples
        # numpy and the stride helpers use.
        if shape is not None:
            shape = tuple(shape)
        if strides is not None:
            strides = tuple(strides)

        # Do not recreate the array if nothing needs to be changed.
        # This fixes a lot of errors on pypy, since the DummyArray hack does
        # not currently (2014/May/17) work on pypy.
        if ((shape is None or x.shape == shape) and
                (strides is None or x.strides == strides)):
            return x

        if not x.dtype.isbuiltin:
            if shape is None:
                shape = x.shape

            from pytools import product
            if strides is not None \
                    and product(shape) == product(x.shape) \
                    and x.flags.forc:
                # Workaround: If we're being asked to do what amounts to a
                # contiguous reshape, at least do that.

                # Flatten in memory order ("A": Fortran order if x is
                # Fortran-contiguous, C order otherwise). A plain C-order
                # flatten would copy a Fortran-contiguous x, so the result
                # would view a temporary with its elements out of memory
                # order.
                flat = x.reshape(-1, order="A")

                if strides == f_contiguous_strides(x.dtype.itemsize, shape):
                    result = flat.reshape(shape, order="F")
                    assert result.strides == strides
                    return result
                elif strides == c_contiguous_strides(x.dtype.itemsize, shape):
                    result = flat.reshape(shape, order="C")
                    assert result.strides == strides
                    return result

            raise NotImplementedError(
                    "as_strided won't work on non-builtin arrays for now. "
                    "See https://github.com/numpy/numpy/issues/2466")

        interface = dict(x.__array_interface__)
        if shape is not None:
            interface['shape'] = shape
        if strides is not None:
            interface['strides'] = strides
        return np.asarray(_DummyArray(interface, base=x))
194
195# }}}
196