import logging
import math
import socket

import dask
from dask.sizeof import sizeof
from dask.utils import parse_bytes

from .. import protocol
from ..utils import get_ip, get_ipv6, nbytes, offload

logger = logging.getLogger(__name__)


# Offload (de)serializing large frames to improve event loop responsiveness.
# The config value may be given as a human-readable byte string, which is
# parsed into an integer here. A falsy value disables offloading entirely,
# since every use below checks its truthiness first.
OFFLOAD_THRESHOLD = dask.config.get("distributed.comm.offload")
if isinstance(OFFLOAD_THRESHOLD, str):
    OFFLOAD_THRESHOLD = parse_bytes(OFFLOAD_THRESHOLD)

# Truncated payloads larger than this many bytes are not echoed into the log.
_MAX_LOGGED_FRAME_BYTES = 1000


async def to_frames(
    msg,
    allow_offload=True,
    **kwargs,
):
    """
    Serialize a message into a list of Distributed protocol frames.

    Any kwargs are forwarded to protocol.dumps().

    Parameters
    ----------
    msg
        The message to serialize.
    allow_offload : bool
        When true and ``OFFLOAD_THRESHOLD`` is set, messages whose estimated
        size (``dask.sizeof.sizeof``) exceeds the threshold are serialized
        through ``offload`` rather than inline on the calling coroutine.

    Returns
    -------
    list
        The frames produced by ``protocol.dumps``.

    Raises
    ------
    Exception
        Whatever ``protocol.dumps`` raises is logged and re-raised unchanged.
    """

    def _to_frames():
        try:
            return list(protocol.dumps(msg, **kwargs))
        except Exception as e:
            logger.info("Unserializable Message: %s", msg)
            logger.exception(e)
            raise

    if OFFLOAD_THRESHOLD and allow_offload:
        try:
            msg_size = sizeof(msg)
        except RecursionError:
            # Too deeply nested to measure: treat it as large so the work is
            # moved off the calling coroutine.
            msg_size = math.inf
    else:
        msg_size = 0

    if allow_offload and OFFLOAD_THRESHOLD and msg_size > OFFLOAD_THRESHOLD:
        return await offload(_to_frames)
    return _to_frames()


async def from_frames(frames, deserialize=True, deserializers=None, allow_offload=True):
    """
    Unserialize a list of Distributed protocol frames.

    Parameters
    ----------
    frames : list
        Frames as produced by ``to_frames``.
    deserialize : bool
        Forwarded to ``protocol.loads``.
    deserializers
        Forwarded to ``protocol.loads``.
    allow_offload : bool
        When true, ``deserialize`` is true and the total frame size exceeds
        ``OFFLOAD_THRESHOLD``, the work runs through ``offload``.

    Returns
    -------
    The result of ``protocol.loads``.

    Raises
    ------
    EOFError
        If the frames are truncated; a diagnostic is logged before re-raising.
    """
    # Total frame size in bytes; ``None`` means "not measured yet".
    size = None

    def _from_frames():
        try:
            return protocol.loads(
                frames, deserialize=deserialize, deserializers=deserializers
            )
        except EOFError:
            # The size is only measured up front when offloading is possible.
            # Measure it here otherwise, so the diagnostic reports the real
            # byte count and never dumps an arbitrarily large payload into
            # the log.
            total = size if size is not None else sum(map(nbytes, frames))
            if total > _MAX_LOGGED_FRAME_BYTES:
                datastr = "[too large to display]"
            else:
                datastr = frames
            # Aid diagnosing
            logger.error("truncated data stream (%d bytes): %s", total, datastr)
            raise

    if allow_offload and deserialize and OFFLOAD_THRESHOLD:
        size = sum(map(nbytes, frames))
        if size > OFFLOAD_THRESHOLD:
            return await offload(_from_frames)
    return _from_frames()


def get_tcp_server_addresses(tcp_server):
    """
    Get all bound addresses of a started Tornado TCPServer.

    IPv4 sockets are preferred. IPv6 sockets are only reported when the
    server has no IPv4 socket.

    Returns
    -------
    list
        ``getsockname()`` results of the selected sockets.

    Raises
    ------
    RuntimeError
        If the server has no sockets yet, or none of them is IPv4/IPv6.
    """
    # NOTE: relies on the private ``_sockets`` mapping of the server object.
    sockets = list(tcp_server._sockets.values())
    if not sockets:
        raise RuntimeError(f"TCP Server {tcp_server!r} not started yet?")

    def _look_for_family(fam):
        return [sock for sock in sockets if sock.family == fam]

    # If listening on both IPv4 and IPv6, prefer IPv4 as defective IPv6
    # is common (e.g. Travis-CI).
    socks = _look_for_family(socket.AF_INET)
    if not socks:
        socks = _look_for_family(socket.AF_INET6)
    if not socks:
        raise RuntimeError("No Internet socket found on TCPServer??")

    return [sock.getsockname() for sock in socks]


def get_tcp_server_address(tcp_server):
    """
    Get the first bound address of a started Tornado TCPServer.

    See ``get_tcp_server_addresses`` for the selection rules and errors.
    """
    return get_tcp_server_addresses(tcp_server)[0]


def ensure_concrete_host(host, default_host=None):
    """
    Ensure the given host string (or IP) denotes a concrete host, not a
    wildcard listening address.

    Parameters
    ----------
    host : str
        Host name or IP. ``"0.0.0.0"`` and ``""`` are treated as the IPv4
        wildcard, and ``"::"`` as the IPv6 wildcard.
    default_host : str, optional
        Replacement for a wildcard host. When not given, ``get_ip()`` (IPv4)
        or ``get_ipv6()`` (IPv6) is used instead.

    Returns
    -------
    str
        ``host`` unchanged if it is not a wildcard, otherwise its replacement.
    """
    if host in ("0.0.0.0", ""):
        return default_host or get_ip()
    elif host == "::":
        return default_host or get_ipv6()
    else:
        return host