1import logging
2import math
3import socket
4
5import dask
6from dask.sizeof import sizeof
7from dask.utils import parse_bytes
8
9from .. import protocol
10from ..utils import get_ip, get_ipv6, nbytes, offload
11
logger = logging.getLogger(__name__)


# Offload (de)serializing large frames to improve event loop responsiveness.
# The config entry may be a plain value or a size string; strings are run
# through parse_bytes so the threshold can be compared against byte counts.
_offload_config = dask.config.get("distributed.comm.offload")
OFFLOAD_THRESHOLD = (
    parse_bytes(_offload_config)
    if isinstance(_offload_config, str)
    else _offload_config
)
del _offload_config
19
20
async def to_frames(
    msg,
    allow_offload=True,
    **kwargs,
):
    """
    Serialize a message into a list of Distributed protocol frames.

    Parameters
    ----------
    msg:
        The message to serialize.
    allow_offload:
        When true (and OFFLOAD_THRESHOLD is truthy), messages whose estimated
        size exceeds OFFLOAD_THRESHOLD are serialized via ``offload`` instead
        of inline.
    **kwargs:
        Forwarded unchanged to protocol.dumps().

    Returns
    -------
    list
        The frames produced by protocol.dumps(), materialized as a list.

    Any exception raised by protocol.dumps() is logged (together with the
    offending message) and then re-raised.
    """

    def _dump():
        try:
            frames = protocol.dumps(msg, **kwargs)
            return list(frames)
        except Exception as exc:
            logger.info("Unserializable Message: %s", msg)
            logger.exception(exc)
            raise

    should_offload = False
    if OFFLOAD_THRESHOLD and allow_offload:
        # sizeof() may recurse deeply on nested structures; treat a message
        # it cannot measure as "too big" so it is offloaded.
        try:
            estimated_size = sizeof(msg)
        except RecursionError:
            estimated_size = math.inf
        should_offload = estimated_size > OFFLOAD_THRESHOLD

    if should_offload:
        return await offload(_dump)
    return _dump()
51
52
async def from_frames(frames, deserialize=True, deserializers=None, allow_offload=True):
    """
    Unserialize a list of Distributed protocol frames.

    Parameters
    ----------
    frames:
        The frames to decode; passed to protocol.loads().
    deserialize, deserializers:
        Forwarded unchanged to protocol.loads().
    allow_offload:
        When true (and ``deserialize`` and OFFLOAD_THRESHOLD are truthy),
        payloads whose total byte size exceeds OFFLOAD_THRESHOLD are decoded
        via ``offload`` instead of inline.

    Returns
    -------
    Whatever protocol.loads() returns.

    Raises
    ------
    EOFError
        Re-raised from protocol.loads() after logging a diagnostic that
        includes the payload size (and the frames themselves if small).
    """
    # Total payload size in bytes. It is computed up front only when it is
    # needed for the offload decision; otherwise the error path computes it
    # on demand (previously it stayed ``False`` there, so the diagnostic
    # reported "0 bytes" and dumped arbitrarily large frames into the log).
    size = None

    def _from_frames():
        try:
            return protocol.loads(
                frames, deserialize=deserialize, deserializers=deserializers
            )
        except EOFError:
            total_bytes = size if size is not None else sum(map(nbytes, frames))
            if total_bytes > 1000:
                datastr = "[too large to display]"
            else:
                datastr = frames
            # Aid diagnosing
            logger.error("truncated data stream (%d bytes): %s", total_bytes, datastr)
            raise

    if allow_offload and deserialize and OFFLOAD_THRESHOLD:
        size = sum(map(nbytes, frames))
        if size > OFFLOAD_THRESHOLD:
            return await offload(_from_frames)
    return _from_frames()
81
82
def get_tcp_server_addresses(tcp_server):
    """
    Return the bound addresses of a started Tornado TCPServer.

    Only sockets of a single address family are reported: IPv4 sockets if
    there are any, otherwise IPv6 sockets. Each entry is the result of the
    socket's getsockname(), in the order the server stores its sockets.

    Raises RuntimeError if the server has no sockets at all, or if none of
    its sockets is IPv4 or IPv6.
    """
    bound = list(tcp_server._sockets.values())
    if not bound:
        raise RuntimeError(f"TCP Server {tcp_server!r} not started yet?")

    # If listening on both IPv4 and IPv6, prefer IPv4 as defective IPv6
    # is common (e.g. Travis-CI).
    for family in (socket.AF_INET, socket.AF_INET6):
        matching = [s for s in bound if s.family == family]
        if matching:
            return [s.getsockname() for s in matching]

    raise RuntimeError("No Internet socket found on TCPServer??")
107
108
def get_tcp_server_address(tcp_server):
    """
    Return the first bound address of a started Tornado TCPServer.

    This is element 0 of get_tcp_server_addresses(tcp_server); any error
    raised there propagates unchanged.
    """
    addresses = get_tcp_server_addresses(tcp_server)
    return addresses[0]
114
115
def ensure_concrete_host(host, default_host=None):
    """
    Map a wildcard listening address to a concrete host.

    - "::" becomes ``default_host`` if it is truthy, else get_ipv6().
    - "0.0.0.0" and "" become ``default_host`` if it is truthy, else get_ip().
    - Any other value is returned unchanged.

    get_ip()/get_ipv6() are only called when no truthy ``default_host``
    is supplied.
    """
    if host == "::":
        return default_host or get_ipv6()
    if host not in ("0.0.0.0", ""):
        return host
    return default_host or get_ip()
127