from __future__ import annotations

import ctypes
import struct
from collections.abc import Sequence

import dask

from ..utils import nbytes

# Default maximum shard size in bytes, read from the dask configuration.
BIG_BYTES_SHARD_SIZE = dask.utils.parse_bytes(dask.config.get("distributed.comm.shard"))


# Keyword options for msgpack: raise every ``max_*_len`` limit to 2**31 - 1.
msgpack_opts = {
    ("max_%s_len" % x): 2 ** 31 - 1 for x in ["str", "bin", "array", "map", "ext"]
}
msgpack_opts["strict_map_key"] = False
msgpack_opts["raw"] = False


def frame_split_size(frame, n=BIG_BYTES_SHARD_SIZE) -> list:
    """
    Split a frame into a list of frames of maximum size

    This helps us to avoid passing around very large bytestrings.
    The returned shards are ``memoryview`` slices of ``frame``; no data
    is copied.

    Parameters
    ----------
    frame
        Any object supporting the buffer protocol.
    n
        Maximum shard size in bytes. A falsy value (``0`` / ``None``) falls
        back to ``BIG_BYTES_SHARD_SIZE``. If ``n`` is smaller than the item
        size of ``frame``, each shard holds exactly one item (and is
        therefore larger than ``n``), since an item cannot be split.

    Returns
    -------
    list of memoryview
        ``[frame]`` (as a memoryview) if it already fits within ``n`` bytes,
        otherwise consecutive slices of at most ``n // itemsize`` items.

    Examples
    --------
    >>> [bytes(f) for f in frame_split_size(b'12345678', n=3)]  # doctest: +SKIP
    [b'123', b'456', b'78']
    """
    n = n or BIG_BYTES_SHARD_SIZE
    frame = memoryview(frame)

    if frame.nbytes <= n:
        return [frame]

    # NOTE(review): the item count below assumes a 1-dimensional buffer,
    # because slicing a memoryview indexes along its first dimension.
    nitems = frame.nbytes // frame.itemsize
    # Never let the step reach 0 (n < itemsize): ``range`` would raise
    # "arg 3 must not be zero" instead of producing shards.
    items_per_shard = max(1, n // frame.itemsize)

    return [frame[i : i + items_per_shard] for i in range(0, nitems, items_per_shard)]


def pack_frames_prelude(frames):
    """Build the binary header describing ``frames``

    The header is the number of frames followed by the byte length of each
    frame, every value packed as a ``struct`` ``"Q"`` (unsigned 64-bit
    integer) in native byte order.

    See Also
    --------
    pack_frames
    unpack_frames
    """
    nframes = len(frames)
    nbytes_frames = map(nbytes, frames)
    return struct.pack(f"Q{nframes}Q", nframes, *nbytes_frames)


def pack_frames(frames):
    """Pack frames into a byte-like object

    This prepends length information to the front of the bytes-like object

    See Also
    --------
    unpack_frames
    """
    return b"".join([pack_frames_prelude(frames), *frames])


def unpack_frames(b):
    """Unpack bytes into a sequence of frames

    This assumes that length information is at the front of the bytestring,
    as performed by pack_frames

    The returned frames are ``memoryview`` slices of ``b``; no data is copied.

    See Also
    --------
    pack_frames
    """
    b = memoryview(b)

    fmt = "Q"
    fmt_size = struct.calcsize(fmt)

    # Header layout: <n_frames> <length of frame 0> ... <length of frame n-1>
    (n_frames,) = struct.unpack_from(fmt, b)
    lengths = struct.unpack_from(f"{n_frames}{fmt}", b, fmt_size)

    frames = []
    # Payload starts right after the (1 + n_frames) header fields
    start = fmt_size * (1 + n_frames)
    for length in lengths:
        end = start + length
        frames.append(b[start:end])
        start = end

    return frames


def merge_memoryviews(mvs: Sequence[memoryview]) -> memoryview:
    """
    Zero-copy "concatenate" a sequence of contiguous memoryviews.

    Returns a new memoryview which slices into the underlying buffer
    to extract out the portion equivalent to all of ``mvs`` being concatenated.

    All the memoryviews must:
    * Share the same underlying buffer (``.obj``)
    * When merged, cover a continuous portion of that buffer with no gaps
    * Have the same strides
    * Be 1-dimensional
    * Have the same format
    * Be contiguous

    Zero-length memoryviews are skipped and not checked against these
    conditions. An empty sequence yields an empty memoryview, and a
    single-element sequence is returned as-is without any validation.

    Raises ValueError if these conditions are not met, and TypeError if an
    element is not a memoryview.
    """
    if not mvs:
        return memoryview(bytearray())
    if len(mvs) == 1:
        return mvs[0]

    first = mvs[0]
    if not isinstance(first, memoryview):
        raise TypeError(f"Expected memoryview; got {type(first)}")
    obj = first.obj
    fmt = first.format

    # 0 doubles as the "no non-empty memoryview seen yet" sentinel
    first_start_addr = 0
    total_nbytes = 0
    for i, mv in enumerate(mvs):
        if not isinstance(mv, memoryview):
            raise TypeError(f"{i}: expected memoryview; got {type(mv)}")

        if mv.nbytes == 0:
            continue

        if mv.obj is not obj:
            raise ValueError(
                f"{i}: memoryview has different buffer: {mv.obj!r} vs {obj!r}"
            )
        if not mv.contiguous:
            raise ValueError(f"{i}: memoryview non-contiguous")
        if mv.ndim != 1:
            raise ValueError(f"{i}: memoryview has {mv.ndim} dimensions, not 1")
        if mv.format != fmt:
            raise ValueError(f"{i}: inconsistent format: {mv.format} vs {fmt}")

        start_addr = address_of_memoryview(mv)
        if first_start_addr == 0:
            first_start_addr = start_addr
        else:
            # Each view must begin exactly where the bytes merged so far end
            expected_addr = first_start_addr + total_nbytes
            if start_addr != expected_addr:
                raise ValueError(
                    f"memoryview {i} does not start where the previous ends. "
                    f"Expected {expected_addr:x}, starts {start_addr - expected_addr} byte(s) away."
                )
        total_nbytes += mv.nbytes

    if total_nbytes == 0:
        # all memoryviews were zero-length
        assert len(first) == 0
        return first

    assert first_start_addr != 0, "Underlying buffer is null pointer?!"

    # Re-slice the shared base buffer as raw bytes, then restore the format
    base_mv = memoryview(obj).cast("B")
    base_start_addr = address_of_memoryview(base_mv)
    start_index = first_start_addr - base_start_addr

    return base_mv[start_index : start_index + total_nbytes].cast(fmt)


one_byte_carr = ctypes.c_byte * 1
# ^ length and type don't matter, just use it to get the address of the first byte


def address_of_memoryview(mv: memoryview) -> int:
    """
    Get the pointer to the first byte of a memoryview's data.

    If the memoryview is read-only, NumPy must be installed.

    Raises ValueError if the memoryview is read-only and NumPy cannot be
    imported.
    """
    # NOTE: this method relies on pointer arithmetic to figure out
    # where each memoryview starts within the underlying buffer.
    # There's no direct API to get the address of a memoryview,
    # so we use a trick through ctypes and the buffer protocol:
    # https://mattgwwalker.wordpress.com/2020/10/15/address-of-a-buffer-in-python/
    try:
        carr = one_byte_carr.from_buffer(mv)
    except TypeError:
        # `mv` is read-only. `from_buffer` requires the buffer to be writeable.
        # See https://bugs.python.org/issue11427 for discussion.
        # This typically comes from `deserialize_bytes`, where `mv.obj` is an
        # immutable bytestring.
        pass
    else:
        return ctypes.addressof(carr)

    try:
        import numpy as np
    except ImportError as err:
        # Chain the ImportError so the root cause stays visible in tracebacks
        raise ValueError(
            f"Cannot get address of read-only memoryview {mv} since NumPy is not installed."
        ) from err

    # NumPy doesn't mind read-only buffers. We could just use this method
    # for all cases, but it's nice to use the pure-Python method for the common
    # case of writeable buffers (created by TCP comms, for example).
    return np.asarray(mv).__array_interface__["data"][0]