import math

import numpy as np

from ..utils import log_errors
from . import pickle
from .serialize import dask_deserialize, dask_serialize


def itemsize(dt):
    """Itemsize of dtype

    Try to return the itemsize of the base element; return 8 as a fallback.
    """
    result = dt.base.itemsize
    if result > 255:
        result = 8
    return result


@dask_serialize.register(np.ndarray)
def serialize_numpy_ndarray(x, context=None):
    if x.dtype.hasobject or (x.dtype.flags & np.core.multiarray.LIST_PICKLE):
        header = {"pickle": True}
        frames = [None]
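        # With pickle protocol 5, out-of-band buffers are handed to
        # ``buffer_callback`` and collected into ``frames[1:]``; ``frames[0]``
        # is filled in below with the pickle bytestream itself.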
        buffer_callback = lambda f: frames.append(memoryview(f))
        frames[0] = pickle.dumps(
            x,
            buffer_callback=buffer_callback,
            protocol=(context or {}).get("pickle-protocol", None),
        )
        return header, frames

    # We cannot blindly pickle the dtype as some may fail pickling,
    # so we have a mixture of strategies.
    if x.dtype.kind == "V":
        # Preserving all the information works best when pickling
        try:
            # Only use stdlib pickle as cloudpickle is slow when failing
            # (microseconds instead of nanoseconds)
            dt = (
                1,
                pickle.pickle.dumps(
                    x.dtype, protocol=(context or {}).get("pickle-protocol", None)
                ),
            )
            pickle.loads(dt[1])  # does it unpickle fine?
        except Exception:
            # dtype fails pickling => fall back on the descr if reasonable.
            if x.dtype.type is not np.void or x.dtype.alignment != 1:
                raise
            else:
                dt = (0, x.dtype.descr)
    else:
        dt = (0, x.dtype.str)

    # For arrays with zero-strided (broadcast) axes, serialize only the unique,
    # non-repeated data
    broadcast_to = None
    if 0 in x.strides:
        broadcast_to = x.shape
        strides = x.strides
        writeable = x.flags.writeable
        # Keep a single element along each zero-stride axis
        x = x[tuple(slice(None) if s != 0 else slice(1) for s in strides)]
        if not x.flags.c_contiguous and not x.flags.f_contiguous:
            # Broadcasting can only be done with contiguous arrays: make a
            # contiguous copy, then restore the zero strides so the recorded
            # strides still describe the broadcast layout
            x = np.ascontiguousarray(x)
            x = np.lib.stride_tricks.as_strided(
                x,
                strides=[j if i != 0 else i for i, j in zip(strides, x.strides)],
                writeable=writeable,
            )

    if not x.shape:
        # 0d array
        strides = x.strides
        data = x.ravel()
    elif x.flags.c_contiguous or x.flags.f_contiguous:
        # Avoid a copy and respect order when deserializing
        strides = x.strides
        data = x.ravel(order="K")
    else:
        x = np.ascontiguousarray(x)
        strides = x.strides
        data = x.ravel()

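    # Reinterpret structured or wide (> 8 byte) dtypes as plain unsigned
    # integers before exposing the memory; math.gcd keeps the new itemsize an
    # exact divisor of the original, so the view is always valid.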
    if data.dtype.fields or data.dtype.itemsize > 8:
        data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8))

    try:
        data = data.data
    except ValueError:
        # "ValueError: cannot include dtype 'M' in a buffer"
        data = data.view("u%d" % math.gcd(x.dtype.itemsize, 8)).data

    header = {
        "dtype": dt,
        "shape": x.shape,
        "strides": strides,
        "writeable": [x.flags.writeable],
    }

    if broadcast_to is not None:
        header["broadcast_to"] = broadcast_to

    frames = [data]
    return header, frames


@dask_deserialize.register(np.ndarray)
def deserialize_numpy_ndarray(header, frames):
    with log_errors():
        if header.get("pickle"):
            return pickle.loads(frames[0], buffers=frames[1:])

        (frame,) = frames
        (writeable,) = header["writeable"]

        is_custom, dt = header["dtype"]
        if is_custom:
            dt = pickle.loads(dt)
        else:
            dt = np.dtype(dt)

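        # For broadcast inputs only the unique data was shipped; rebuilding the
        # array with the full ``broadcast_to`` shape and the recorded strides
        # (zero along the broadcast axes) recreates the broadcast view over
        # that smaller buffer.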
        if header.get("broadcast_to"):
            shape = header["broadcast_to"]
        else:
            shape = header["shape"]

        x = np.ndarray(shape, dtype=dt, buffer=frame, strides=header["strides"])
        if not writeable:
            x.flags.writeable = False
        else:
            x = np.require(x, requirements=["W"])

        return x

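# Minimal roundtrip sketch for a plain ndarray (illustrative only; in normal
# use these handlers are invoked through the dask_serialize/dask_deserialize
# dispatch rather than called directly):
#
#     a = np.arange(12.0).reshape(3, 4)
#     header, frames = serialize_numpy_ndarray(a)
#     b = deserialize_numpy_ndarray(header, frames)
#     assert (a == b).all()
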

@dask_serialize.register(np.ma.core.MaskedConstant)
def serialize_numpy_ma_masked(x):
    return {}, []


@dask_deserialize.register(np.ma.core.MaskedConstant)
def deserialize_numpy_ma_masked(header, frames):
    return np.ma.masked


@dask_serialize.register(np.ma.core.MaskedArray)
def serialize_numpy_maskedarray(x, context=None):
    data_header, frames = serialize_numpy_ndarray(x.data)
    header = {"data-header": data_header, "nframes": len(frames)}
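    # Frame layout: the first ``nframes`` frames belong to the data array; any
    # mask frames are appended after them.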

    # Serialize mask if present
    if x.mask is not np.ma.nomask:
        mask_header, mask_frames = serialize_numpy_ndarray(x.mask)
        header["mask-header"] = mask_header
        frames += mask_frames

    # Only a few dtypes have Python equivalents that msgpack can serialize
    if isinstance(x.fill_value, (np.integer, np.floating, np.bool_)):
        serialized_fill_value = (False, x.fill_value.item())
    else:
        serialized_fill_value = (
            True,
            pickle.dumps(
                x.fill_value, protocol=(context or {}).get("pickle-protocol", None)
            ),
        )
    header["fill-value"] = serialized_fill_value

    return header, frames


@dask_deserialize.register(np.ma.core.MaskedArray)
def deserialize_numpy_maskedarray(header, frames):
    data_header = header["data-header"]
    data_frames = frames[: header["nframes"]]
    data = deserialize_numpy_ndarray(data_header, data_frames)

    if "mask-header" in header:
        mask_header = header["mask-header"]
        mask_frames = frames[header["nframes"] :]
        mask = deserialize_numpy_ndarray(mask_header, mask_frames)
    else:
        mask = np.ma.nomask

    pickled_fv, fill_value = header["fill-value"]
    if pickled_fv:
        fill_value = pickle.loads(fill_value)

    return np.ma.masked_array(data, mask=mask, fill_value=fill_value)
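# Minimal roundtrip sketch for a masked array (illustrative only, exercising
# the handlers above directly):
#
#     x = np.ma.masked_array([1.0, 2.0, 3.0], mask=[False, True, False], fill_value=-1.0)
#     header, frames = serialize_numpy_maskedarray(x)
#     y = deserialize_numpy_maskedarray(header, frames)
#     assert (y.mask == x.mask).all() and y.fill_value == x.fill_value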