1from __future__ import annotations
2
3import ctypes
4import struct
5from collections.abc import Sequence
6
7import dask
8
9from ..utils import nbytes
10
# Shard size in bytes, read once at import time from the
# "distributed.comm.shard" dask config entry and converted with
# ``dask.utils.parse_bytes``. Used as the default ``n`` of ``frame_split_size``.
BIG_BYTES_SHARD_SIZE = dask.utils.parse_bytes(
    dask.config.get("distributed.comm.shard")
)
12
13
# Keyword options gathered for msgpack (presumably forwarded to the unpacking
# calls elsewhere in the package -- confirm against callers).
# Every ``max_*_len`` limit is set to 2**31 - 1.
msgpack_opts = {
    **{
        f"max_{kind}_len": 2**31 - 1
        for kind in ("str", "bin", "array", "map", "ext")
    },
    "strict_map_key": False,
    "raw": False,
}
19
20
def frame_split_size(frame, n=BIG_BYTES_SHARD_SIZE) -> list:
    """
    Split a frame into a list of frames of maximum size

    This helps us to avoid passing around very large bytestrings.

    Parameters
    ----------
    frame
        Any object accepted by ``memoryview`` (e.g. ``bytes``, ``bytearray``
        or another ``memoryview``).
    n
        Maximum size of each shard, in bytes. A falsy value (``0``/``None``)
        falls back to ``BIG_BYTES_SHARD_SIZE``.

    Returns
    -------
    list of memoryview
        ``[memoryview(frame)]`` when the frame already fits within ``n``
        bytes; otherwise consecutive slices of the frame, each holding
        ``n // itemsize`` items (the last one possibly fewer). The slices are
        views into the original buffer, not copies.

    Notes
    -----
    A shard always holds at least one whole item. If ``n`` is smaller than
    the frame's ``itemsize`` (or negative), every shard therefore holds
    exactly one item and is larger than ``n``; previously this case either
    failed with a zero-step ``range`` error or silently returned no data.

    Examples
    --------
    >>> [bytes(part) for part in frame_split_size(b'12345678', n=3)]
    [b'123', b'456', b'78']
    """
    n = n or BIG_BYTES_SHARD_SIZE
    frame = memoryview(frame)

    if frame.nbytes <= n:
        return [frame]

    # NOTE(review): ``nitems`` counts every element while the slices below
    # index the first dimension only, so this presumably expects
    # 1-dimensional frames -- confirm with callers.
    nitems = frame.nbytes // frame.itemsize
    # Clamp to one item: a step of 0 makes ``range`` raise, and a negative
    # step would yield an empty range and drop the whole frame.
    items_per_shard = max(1, n // frame.itemsize)

    return [frame[i : i + items_per_shard] for i in range(0, nitems, items_per_shard)]
42
43
def pack_frames_prelude(frames):
    """Build the binary header that describes ``frames``

    The header is a run of ``struct`` ``Q`` values: first the number of
    frames, then the size of every frame as reported by ``nbytes``. The
    format string carries no byte-order prefix, so ``struct`` uses its
    native layout.

    See Also
    --------
    pack_frames
    unpack_frames
    """
    count = len(frames)
    sizes = [nbytes(frame) for frame in frames]
    header = struct.Struct(f"Q{count}Q")
    return header.pack(count, *sizes)
48
49
def pack_frames(frames):
    """Concatenate frames into one bytes object, preceded by a header

    The header produced by ``pack_frames_prelude`` (frame count and
    per-frame lengths) comes first, followed by the frames themselves in
    order.

    See Also
    --------
    unpack_frames
    """
    pieces = [pack_frames_prelude(frames)]
    pieces.extend(frames)
    return b"".join(pieces)
60
61
def unpack_frames(b):
    """Split a packed bytes-like object back into its frames

    The input is expected to start with the header written by
    ``pack_frames``: one ``Q`` holding the number of frames, followed by one
    ``Q`` per frame holding its length. The frame payloads follow directly
    after the header.

    Returns a list of memoryview slices into ``b``; no payload is copied.

    See Also
    --------
    pack_frames
    """
    view = memoryview(b)
    word = struct.calcsize("Q")

    # Header: the frame count first, then that many lengths.
    (count,) = struct.unpack_from("Q", view)
    sizes = struct.unpack_from(f"{count}Q", view, word)

    # Payloads start right after the (1 + count) header words.
    offset = word * (1 + count)
    out = []
    for size in sizes:
        out.append(view[offset : offset + size])
        offset += size

    return out
88
89
def merge_memoryviews(mvs: Sequence[memoryview]) -> memoryview:
    """
    Zero-copy "concatenate" a sequence of adjacent memoryviews.

    The result is a single memoryview over the region of the shared
    underlying buffer that the inputs jointly cover, i.e. what you would get
    from concatenating all of ``mvs``, but without copying any data.

    Special cases: an empty ``mvs`` yields a new empty memoryview, and a
    single-element ``mvs`` returns that element unchanged without any checks.

    Otherwise zero-length memoryviews are skipped, and every non-empty one
    must:

    * Have the same underlying buffer (``.obj``) as the first element
    * Be contiguous
    * Be 1-dimensional
    * Have the same format as the first element
    * Start exactly where the previous non-empty memoryview ended

    Raises
    ------
    TypeError
        If an element is not a memoryview.
    ValueError
        If any of the conditions above is not met.
    """
    if not mvs:
        return memoryview(bytearray())
    if len(mvs) == 1:
        return mvs[0]

    head = mvs[0]
    if not isinstance(head, memoryview):
        raise TypeError(f"Expected memoryview; got {type(head)}")
    owner, fmt = head.obj, head.format

    # Address of the first non-empty view; 0 means "none seen yet".
    run_start = 0
    # Bytes covered so far by the non-empty views.
    total = 0
    for idx, view in enumerate(mvs):
        if not isinstance(view, memoryview):
            raise TypeError(f"{idx}: expected memoryview; got {type(view)}")

        # Empty views contribute nothing and are exempt from the checks below.
        if view.nbytes == 0:
            continue

        if view.obj is not owner:
            raise ValueError(
                f"{idx}: memoryview has different buffer: {view.obj!r} vs {owner!r}"
            )
        if not view.contiguous:
            raise ValueError(f"{idx}: memoryview non-contiguous")
        if view.ndim != 1:
            raise ValueError(f"{idx}: memoryview has {view.ndim} dimensions, not 1")
        if view.format != fmt:
            raise ValueError(f"{idx}: inconsistent format: {view.format} vs {fmt}")

        addr = address_of_memoryview(view)
        if run_start == 0:
            run_start = addr
        else:
            # Each later view has to begin at the byte right after the run.
            expected = run_start + total
            if addr != expected:
                raise ValueError(
                    f"memoryview {idx} does not start where the previous ends. "
                    f"Expected {expected:x}, starts {addr - expected} byte(s) away."
                )
        total += view.nbytes

    if total == 0:
        # Nothing but zero-length views: any of them is a valid result.
        assert len(head) == 0
        return head

    assert run_start != 0, "Underlying buffer is null pointer?!"

    # Translate the absolute start address into an index within the
    # byte-cast view of the whole underlying buffer, slice, and restore the
    # original format.
    raw = memoryview(owner).cast("B")
    offset = run_start - address_of_memoryview(raw)

    return raw[offset : offset + total].cast(fmt)
162
163
# A one-element ctypes array type. Neither its element type nor its length
# matters: it is only overlaid on a buffer to learn where the first byte lives.
one_byte_carr = ctypes.c_byte * 1


def address_of_memoryview(mv: memoryview) -> int:
    """
    Return the memory address of the first byte of a memoryview's data.

    Writeable memoryviews are handled with ``ctypes`` alone. If the
    memoryview is read-only, NumPy must be installed; otherwise a
    ``ValueError`` is raised.
    """
    # There is no direct API for the address of a memoryview, so a ctypes
    # array is overlaid on the buffer through the buffer protocol and ctypes
    # is asked where that array lives:
    # https://mattgwwalker.wordpress.com/2020/10/15/address-of-a-buffer-in-python/
    try:
        overlay = one_byte_carr.from_buffer(mv)
    except TypeError:
        # `from_buffer` requires a writeable buffer, so `mv` is read-only.
        # See https://bugs.python.org/issue11427 for discussion.
        # This typically comes from `deserialize_bytes`, where `mv.obj` is an
        # immutable bytestring.
        overlay = None

    if overlay is not None:
        return ctypes.addressof(overlay)

    try:
        import numpy
    except ImportError:
        raise ValueError(
            f"Cannot get address of read-only memoryview {mv} since NumPy is not installed."
        )

    # NumPy accepts read-only buffers. It could serve every case, but the
    # pure-Python route above is preferred for the common case of writeable
    # buffers (created by TCP comms, for example).
    interface = numpy.asarray(mv).__array_interface__
    return interface["data"][0]
201