"""Stride, contiguity, dtype-promotion and memory-overlap helpers for
array-like objects, plus an ``as_strided`` compatibility shim."""

from __future__ import division

__copyright__ = "Copyright (C) 2011 Andreas Kloeckner"

__license__ = """
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in
all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE.
"""

import numpy as np


def f_contiguous_strides(itemsize, shape):
    """Return the byte strides of a Fortran-ordered (column-major) array.

    :arg itemsize: size of one element in bytes.
    :arg shape: sequence of axis lengths.
    :returns: a tuple with one stride per axis, the first axis varying
        fastest. An empty *shape* yields ``()``.
    """
    if shape:
        strides = [itemsize]
        for s in shape[:-1]:
            strides.append(strides[-1]*s)
        return tuple(strides)
    else:
        return ()


def c_contiguous_strides(itemsize, shape):
    """Return the byte strides of a C-ordered (row-major) array.

    :arg itemsize: size of one element in bytes.
    :arg shape: sequence of axis lengths.
    :returns: a tuple with one stride per axis, the last axis varying
        fastest. An empty *shape* yields ``()``.
    """
    if shape:
        strides = [itemsize]
        # Walk the axes from last to second, then reverse the result.
        for s in shape[:0:-1]:
            strides.append(strides[-1]*s)
        return tuple(strides[::-1])
    else:
        return ()


def equal_strides(strides1, strides2, shape):
    """Return whether *strides1* and *strides2* describe the same layout
    for an array of the given *shape*.

    Axes of length 1 are skipped: their stride is never used to step to
    another element, so any value is acceptable there. Sequences whose
    lengths disagree with each other or with *shape* compare unequal.
    """
    if len(strides1) != len(strides2) or len(strides2) != len(shape):
        return False

    for s, st1, st2 in zip(shape, strides1, strides2):
        if s != 1 and st1 != st2:
            return False

    return True


def is_f_contiguous_strides(strides, itemsize, shape):
    """Return whether *strides* is a Fortran-contiguous layout for *shape*."""
    return equal_strides(strides, f_contiguous_strides(itemsize, shape), shape)


def is_c_contiguous_strides(strides, itemsize, shape):
    """Return whether *strides* is a C-contiguous layout for *shape*."""
    return equal_strides(strides, c_contiguous_strides(itemsize, shape), shape)


class ArrayFlags:
    """Contiguity flags for *ary*, computed from its ``strides``,
    ``shape`` and ``dtype.itemsize`` (a subset of what numpy's
    ``ndarray.flags`` reports)."""

    def __init__(self, ary):
        # True if the layout matches Fortran (column-major) order.
        self.f_contiguous = is_f_contiguous_strides(
                ary.strides, ary.dtype.itemsize, ary.shape)
        # True if the layout matches C (row-major) order.
        self.c_contiguous = is_c_contiguous_strides(
                ary.strides, ary.dtype.itemsize, ary.shape)
        # "Fortran or C": contiguous in at least one of the two orders.
        self.forc = self.f_contiguous or self.c_contiguous


def get_common_dtype(obj1, obj2, allow_double):
    """Return the dtype numpy produces when adding *obj1* and *obj2*.

    :arg obj1: an object with a ``dtype`` attribute.
    :arg obj2: an object with a ``dtype`` attribute, or anything numpy
        can add to an array directly (e.g. a Python scalar).
    :arg allow_double: if false, a ``float64`` result is demoted to
        ``float32`` and a ``complex128`` result to ``complex64``.
    """
    # Yes, numpy behaves differently depending on whether
    # we're dealing with arrays or scalars.

    zero1 = np.zeros(1, dtype=obj1.dtype)

    try:
        zero2 = np.zeros(1, dtype=obj2.dtype)
    except AttributeError:
        # obj2 has no dtype: add it as-is so numpy's scalar promotion
        # rules apply.
        zero2 = obj2

    result = (zero1 + zero2).dtype

    if not allow_double:
        if result == np.float64:
            result = np.dtype(np.float32)
        elif result == np.complex128:
            result = np.dtype(np.complex64)

    return result


def bound(a):
    """Return the half-open byte range ``(low, high)`` touched by *a*.

    *a* must expose ``bytes``, ``strides``, ``shape`` and ``dtype``.
    ``a.bytes`` is used as the address of the element at index zero.
    NOTE(review): confirm that meaning against the array types this is
    called with.

    Negative strides extend the range downward from ``a.bytes``, positive
    strides upward.
    """
    low = high = a.bytes

    for stri, shp in zip(a.strides, a.shape):
        if stri < 0:
            low += stri*(shp-1)
        else:
            high += stri*(shp-1)

    # ``high`` now points at the *start* of the element furthest along;
    # add that element's own size so the range is half-open. Without
    # this, a one-element array would yield an empty range.
    high += a.dtype.itemsize
    return low, high


def may_share_memory(a, b):
    """Return whether *a* and *b* might refer to overlapping memory.

    Identical objects always share memory. Objects of different classes
    are assumed not to. Otherwise the byte ranges from :func:`bound` are
    checked for intersection.
    """
    # When this is called with a an ndarray and b
    # a sparse matrix, numpy.may_share_memory fails.
    if a is b:
        return True
    if a.__class__ is not b.__class__:
        return False

    a_l, a_h = bound(a)
    b_l, b_h = bound(b)
    # Half-open ranges [a_l, a_h) and [b_l, b_h) are disjoint iff one
    # starts at or after the other ends.
    return not (b_l >= a_h or a_l >= b_h)


# {{{ as_strided implementation

try:
    from numpy.lib.stride_tricks import as_strided as _as_strided
    # Probe numpy's as_strided with an aligned structured dtype: some
    # numpy versions silently lose such dtypes (numpy issue #2466).
    _test_dtype = np.dtype(
            [("a", np.float64), ("b", np.float64)], align=True)
    _test_result = _as_strided(np.zeros(10, dtype=_test_dtype))
    if _test_result.dtype != _test_dtype:
        raise RuntimeError("numpy's as_strided is broken")

    as_strided = _as_strided
except Exception:
    # stolen from numpy to be compatible with older versions of numpy
    class _DummyArray(object):
        """ Dummy object that just exists to hang __array_interface__ dictionaries
        and possibly keep alive a reference to a base array.
        """
        def __init__(self, interface, base=None):
            self.__array_interface__ = interface
            self.base = base

    def as_strided(x, shape=None, strides=None):
        """ Make an ndarray from the given array with the given shape and strides.

        For arrays with non-builtin dtypes only contiguous reshapes are
        supported; anything else raises :exc:`NotImplementedError`.
        """
        # work around Numpy bug 1873 (reported by Irwin Zaid)
        # Since this is stolen from numpy, this implementation has the same bug.
        # http://projects.scipy.org/numpy/ticket/1873
        # == https://github.com/numpy/numpy/issues/2466

        # Normalize to tuples so that comparisons against x.shape /
        # x.strides (and the contiguous-stride helpers) work for lists too.
        if shape is not None:
            shape = tuple(shape)
        if strides is not None:
            strides = tuple(strides)

        # Do not recreate the array if nothing needs to be changed.
        # This fixes a lot of errors on pypy, since the DummyArray hack
        # does not currently (2014/May/17) work on pypy.

        if ((shape is None or x.shape == shape) and
                (strides is None or x.strides == strides)):
            return x
        if not x.dtype.isbuiltin:
            if shape is None:
                shape = x.shape

            from pytools import product
            if strides is not None \
                    and product(shape) == product(x.shape) \
                    and x.flags.forc:
                # Workaround: If we're being asked to do what amounts to a
                # contiguous reshape, at least do that.

                if strides == f_contiguous_strides(x.dtype.itemsize, shape):
                    result = x.reshape(-1).reshape(*shape, order="F")
                    assert result.strides == strides
                    return result
                elif strides == c_contiguous_strides(x.dtype.itemsize, shape):
                    result = x.reshape(-1).reshape(*shape, order="C")
                    assert result.strides == strides
                    return result

            raise NotImplementedError(
                    "as_strided won't work on non-builtin arrays for now. "
                    "See https://github.com/numpy/numpy/issues/2466")

        interface = dict(x.__array_interface__)
        if shape is not None:
            interface['shape'] = tuple(shape)
        if strides is not None:
            interface['strides'] = tuple(strides)
        return np.asarray(_DummyArray(interface, base=x))

# }}}