import math

import numpy as np

from ..utils import log_errors
from . import pickle
from .serialize import dask_deserialize, dask_serialize


def itemsize(dt):
    """Itemsize of dtype

    Try to return the itemsize of the base element, return 8 as a fallback
    """
    result = dt.base.itemsize
    # Guard against absurd itemsizes (e.g. very wide fixed-width string
    # dtypes); callers treat 8 as a safe generic element size.
    if result > 255:
        result = 8
    return result


@dask_serialize.register(np.ndarray)
def serialize_numpy_ndarray(x, context=None):
    """Serialize a numpy ndarray into a msgpack-friendly ``(header, frames)``.

    Object arrays (or dtypes numpy flags as requiring pickle) are pickled,
    collecting protocol-5 out-of-band buffers as extra frames.  All other
    arrays get a small descriptive header plus one zero-copy data frame.

    Parameters
    ----------
    x : np.ndarray
        Array to serialize.
    context : dict, optional
        May carry a ``"pickle-protocol"`` entry selecting the pickle
        protocol used for object data and unpicklable dtypes.

    Returns
    -------
    header : dict
    frames : list of buffers
    """
    # NOTE(review): ``np.core`` is deprecated/removed in numpy 2.x —
    # confirm the pinned numpy version before upgrading this access.
    if x.dtype.hasobject or (x.dtype.flags & np.core.multiarray.LIST_PICKLE):
        header = {"pickle": True}
        frames = [None]
        # Pickle protocol 5 hands large payloads out-of-band through this
        # callback; keep them as separate frames to avoid copies.
        buffer_callback = lambda f: frames.append(memoryview(f))
        frames[0] = pickle.dumps(
            x,
            buffer_callback=buffer_callback,
            protocol=(context or {}).get("pickle-protocol", None),
        )
        return header, frames

    # We cannot blindly pickle the dtype as some may fail pickling,
    # so we have a mixture of strategies.
    if x.dtype.kind == "V":
        # Preserving all the information works best when pickling
        try:
            # Only use stdlib pickle as cloudpickle is slow when failing
            # (microseconds instead of nanoseconds)
            dt = (
                1,
                pickle.pickle.dumps(
                    x.dtype, protocol=(context or {}).get("pickle-protocol", None)
                ),
            )
            pickle.loads(dt[1])  # does it unpickle fine?
        except Exception:
            # dtype fails pickling => fall back on the descr if reasonable.
            if x.dtype.type is not np.void or x.dtype.alignment != 1:
                raise
            else:
                dt = (0, x.dtype.descr)
    else:
        dt = (0, x.dtype.str)

    # Only serialize broadcastable data for arrays with zero strided axes
    broadcast_to = None
    if 0 in x.strides:
        broadcast_to = x.shape
        strides = x.strides
        writeable = x.flags.writeable
        # Slice every zero-stride axis down to length 1; slicing preserves
        # the zero strides, so the receiver can rebuild the broadcast view.
        x = x[tuple(slice(None) if s != 0 else slice(1) for s in strides)]
        if not x.flags.c_contiguous and not x.flags.f_contiguous:
            # Broadcasting can only be done with contiguous arrays.
            # ascontiguousarray materializes normal strides, so reinstate
            # the zeros on the originally-broadcast axes afterwards.
            x = np.ascontiguousarray(x)
            x = np.lib.stride_tricks.as_strided(
                x,
                strides=[j if i != 0 else i for i, j in zip(strides, x.strides)],
                writeable=writeable,
            )

    if not x.shape:
        # 0d array
        strides = x.strides
        data = x.ravel()
    elif x.flags.c_contiguous or x.flags.f_contiguous:
        # Avoid a copy and respect order when unserializing
        strides = x.strides
        data = x.ravel(order="K")
    else:
        x = np.ascontiguousarray(x)
        strides = x.strides
        data = x.ravel()

    # Structured or wide dtypes cannot always expose a flat buffer; view
    # them as unsigned integers whose size divides the itemsize.
    if data.dtype.fields or data.dtype.itemsize > 8:
        data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8))

    try:
        data = data.data
    except ValueError:
        # "ValueError: cannot include dtype 'M' in a buffer"
        data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8)).data

    header = {
        "dtype": dt,
        "shape": x.shape,
        "strides": strides,
        # Wrapped in a list so the flag survives msgpack round-tripping.
        "writeable": [x.flags.writeable],
    }

    if broadcast_to is not None:
        header["broadcast_to"] = broadcast_to

    frames = [data]
    return header, frames


@dask_deserialize.register(np.ndarray)
def deserialize_numpy_ndarray(header, frames):
    """Reconstruct an ndarray from the ``(header, frames)`` pair produced by
    :func:`serialize_numpy_ndarray`.

    Pickled payloads are unpickled with their out-of-band buffers; plain
    payloads are rebuilt zero-copy on top of the single data frame, honoring
    the recorded dtype, strides, broadcast shape, and writeability.
    """
    with log_errors():
        if header.get("pickle"):
            return pickle.loads(frames[0], buffers=frames[1:])

        (frame,) = frames
        (writeable,) = header["writeable"]

        # dtype was either pickled (custom/V-kind) or stored as a plain descr.
        is_custom, dt = header["dtype"]
        if is_custom:
            dt = pickle.loads(dt)
        else:
            dt = np.dtype(dt)

        # Zero strides in the header re-expand the data to the original
        # broadcast shape without copying.
        if header.get("broadcast_to"):
            shape = header["broadcast_to"]
        else:
            shape = header["shape"]

        x = np.ndarray(shape, dtype=dt, buffer=frame, strides=header["strides"])
        if not writeable:
            x.flags.writeable = False
        else:
            # The frame may be a read-only buffer; np.require copies only
            # when needed to restore writeability.
            x = np.require(x, requirements=["W"])

        return x


@dask_serialize.register(np.ma.core.MaskedConstant)
def serialize_numpy_ma_masked(x):
    """``np.ma.masked`` is a singleton; nothing needs to be transmitted."""
    return {}, []


@dask_deserialize.register(np.ma.core.MaskedConstant)
def deserialize_numpy_ma_masked(header, frames):
    """Return the ``np.ma.masked`` singleton."""
    return np.ma.masked


@dask_serialize.register(np.ma.core.MaskedArray)
def serialize_numpy_maskedarray(x, context=None):
    """Serialize a masked array as data frames + optional mask frames +
    a fill value embedded in the header.

    Parameters
    ----------
    x : np.ma.MaskedArray
    context : dict, optional
        Forwarded to the ndarray serializer and fill-value pickling for its
        ``"pickle-protocol"`` entry.
    """
    # FIX: propagate ``context`` so the pickle-protocol option is honored
    # when serializing the underlying data (and mask below).
    data_header, frames = serialize_numpy_ndarray(x.data, context=context)
    header = {"data-header": data_header, "nframes": len(frames)}

    # Serialize mask if present
    if x.mask is not np.ma.nomask:
        mask_header, mask_frames = serialize_numpy_ndarray(x.mask, context=context)
        header["mask-header"] = mask_header
        frames += mask_frames

    # Only a few dtypes have python equivalents msgpack can serialize
    if isinstance(x.fill_value, (np.integer, np.floating, np.bool_)):
        serialized_fill_value = (False, x.fill_value.item())
    else:
        serialized_fill_value = (
            True,
            pickle.dumps(
                x.fill_value, protocol=(context or {}).get("pickle-protocol", None)
            ),
        )
    header["fill-value"] = serialized_fill_value

    return header, frames


@dask_deserialize.register(np.ma.core.MaskedArray)
def deserialize_numpy_maskedarray(header, frames):
    """Rebuild a masked array from the output of
    :func:`serialize_numpy_maskedarray`.

    The first ``header["nframes"]`` frames belong to the data array; any
    remaining frames belong to the mask.
    """
    data_header = header["data-header"]
    data_frames = frames[: header["nframes"]]
    data = deserialize_numpy_ndarray(data_header, data_frames)

    if "mask-header" in header:
        mask_header = header["mask-header"]
        mask_frames = frames[header["nframes"] :]
        mask = deserialize_numpy_ndarray(mask_header, mask_frames)
    else:
        mask = np.ma.nomask

    pickled_fv, fill_value = header["fill-value"]
    if pickled_fv:
        fill_value = pickle.loads(fill_value)

    return np.ma.masked_array(data, mask=mask, fill_value=fill_value)