1from __future__ import annotations 2 3import itertools 4 5import dask 6 7from ..utils import get_ip_interface 8from . import registry 9 10 11def parse_address(addr: str, strict: bool = False) -> tuple[str, str]: 12 """ 13 Split address into its scheme and scheme-dependent location string. 14 15 >>> parse_address('tcp://127.0.0.1') 16 ('tcp', '127.0.0.1') 17 18 If strict is set to true the address must have a scheme. 19 """ 20 if not isinstance(addr, str): 21 raise TypeError("expected str, got %r" % addr.__class__.__name__) 22 scheme, sep, loc = addr.rpartition("://") 23 if strict and not sep: 24 msg = ( 25 "Invalid url scheme. " 26 "Must include protocol like tcp://localhost:8000. " 27 "Got %s" % addr 28 ) 29 raise ValueError(msg) 30 if not sep: 31 scheme = dask.config.get("distributed.comm.default-scheme") 32 return scheme, loc 33 34 35def unparse_address(scheme: str, loc: str) -> str: 36 """ 37 Undo parse_address(). 38 39 >>> unparse_address('tcp', '127.0.0.1') 40 'tcp://127.0.0.1' 41 """ 42 return f"{scheme}://{loc}" 43 44 45def normalize_address(addr: str) -> str: 46 """ 47 Canonicalize address, adding a default scheme if necessary. 48 49 >>> normalize_address('tls://[::1]') 50 'tls://[::1]' 51 >>> normalize_address('[::1]') 52 'tcp://[::1]' 53 """ 54 return unparse_address(*parse_address(addr)) 55 56 57def parse_host_port( 58 address: str | tuple[str, int], default_port: str | int | None = None 59) -> tuple[str, int]: 60 """ 61 Parse an endpoint address given in the form "host:port". 62 """ 63 if isinstance(address, tuple): 64 return address 65 66 def _fail(): 67 raise ValueError( 68 f"invalid address {address!r}; maybe: ipv6 needs brackets like [::1]" 69 ) 70 71 def _default(): 72 if default_port is None: 73 raise ValueError(f"missing port number in address {address!r}") 74 return default_port 75 76 if "://" in address: 77 _, address = address.split("://") 78 if address.startswith("["): 79 # IPv6 notation: '[addr]:port' or '[addr]'. 80 # The address may contain multiple colons. 81 host, sep, tail = address[1:].partition("]") 82 if not sep: 83 _fail() 84 if not tail: 85 port = _default() 86 else: 87 if not tail.startswith(":"): 88 _fail() 89 port = tail[1:] 90 else: 91 # Generic notation: 'addr:port' or 'addr'. 92 host, sep, port = address.rpartition(":") 93 if not sep: 94 host = port 95 port = _default() 96 elif ":" in host: 97 _fail() 98 99 return host, int(port) 100 101 102def unparse_host_port(host: str, port: int | None = None) -> str: 103 """ 104 Undo parse_host_port(). 105 """ 106 if ":" in host and not host.startswith("["): 107 host = f"[{host}]" 108 if port is not None: 109 return f"{host}:{port}" 110 else: 111 return host 112 113 114def get_address_host_port(addr: str, strict: bool = False) -> tuple[str, int]: 115 """ 116 Get a (host, port) tuple out of the given address. 117 For definition of strict check parse_address 118 ValueError is raised if the address scheme doesn't allow extracting 119 the requested information. 120 121 >>> get_address_host_port('tcp://1.2.3.4:80') 122 ('1.2.3.4', 80) 123 >>> get_address_host_port('tcp://[::1]:80') 124 ('::1', 80) 125 """ 126 scheme, loc = parse_address(addr, strict=strict) 127 backend = registry.get_backend(scheme) 128 try: 129 return backend.get_address_host_port(loc) 130 except NotImplementedError: 131 raise ValueError( 132 f"don't know how to extract host and port for address {addr!r}" 133 ) 134 135 136def get_address_host(addr: str) -> str: 137 """ 138 Return a hostname / IP address identifying the machine this address 139 is located on. 140 141 In contrast to get_address_host_port(), this function should always 142 succeed for well-formed addresses. 143 144 >>> get_address_host('tcp://1.2.3.4:80') 145 '1.2.3.4' 146 """ 147 scheme, loc = parse_address(addr) 148 backend = registry.get_backend(scheme) 149 return backend.get_address_host(loc) 150 151 152def get_local_address_for(addr: str) -> str: 153 """ 154 Get a local listening address suitable for reaching *addr*. 155 156 For instance, trying to reach an external TCP address will return 157 a local TCP address that's routable to that external address. 158 159 >>> get_local_address_for('tcp://8.8.8.8:1234') 160 'tcp://192.168.1.68' 161 >>> get_local_address_for('tcp://127.0.0.1:1234') 162 'tcp://127.0.0.1' 163 """ 164 scheme, loc = parse_address(addr) 165 backend = registry.get_backend(scheme) 166 return unparse_address(scheme, backend.get_local_address_for(loc)) 167 168 169def resolve_address(addr: str) -> str: 170 """ 171 Apply scheme-specific address resolution to *addr*, replacing 172 all symbolic references with concrete location specifiers. 173 174 In practice, this can mean hostnames are resolved to IP addresses. 175 176 >>> resolve_address('tcp://localhost:8786') 177 'tcp://127.0.0.1:8786' 178 """ 179 scheme, loc = parse_address(addr) 180 backend = registry.get_backend(scheme) 181 return unparse_address(scheme, backend.resolve_address(loc)) 182 183 184def uri_from_host_port( 185 host_arg: str | None, port_arg: str | None, default_port: int 186) -> str: 187 """ 188 Process the *host* and *port* CLI options. 189 Return a URI. 190 """ 191 # Much of distributed depends on a well-known IP being assigned to 192 # each entity (Worker, Scheduler, etc.), so avoid "universal" addresses 193 # like '' which would listen on all registered IPs and interfaces. 194 scheme, loc = parse_address(host_arg or "") 195 196 host, port = parse_host_port( 197 loc, port_arg if port_arg is not None else default_port 198 ) 199 200 if port is None and port_arg is None: 201 port_arg = default_port 202 203 if port and port_arg and port != port_arg: 204 raise ValueError( 205 "port number given twice in options: " 206 "host %r and port %r" % (host_arg, port_arg) 207 ) 208 if port is None and port_arg is not None: 209 port = port_arg 210 # Note `port = 0` means "choose a random port" 211 if port is None: 212 port = default_port 213 loc = unparse_host_port(host, port) 214 addr = unparse_address(scheme, loc) 215 216 return addr 217 218 219def addresses_from_user_args( 220 host=None, 221 port=None, 222 interface=None, 223 protocol=None, 224 peer=None, 225 security=None, 226 default_port=0, 227) -> list: 228 """Get a list of addresses if the inputs are lists 229 230 This is like ``address_from_user_args`` except that it also accepts lists 231 for some of the arguments. If these arguments are lists then it will map 232 over them accordingly. 233 234 Examples 235 -------- 236 >>> addresses_from_user_args(host="127.0.0.1", protocol=["inproc", "tcp"]) 237 ["inproc://127.0.0.1:", "tcp://127.0.0.1:"] 238 """ 239 240 def listify(obj): 241 if isinstance(obj, (tuple, list)): 242 return obj 243 else: 244 return itertools.repeat(obj) 245 246 if any(isinstance(x, (tuple, list)) for x in (host, port, interface, protocol)): 247 return [ 248 address_from_user_args( 249 host=h, 250 port=p, 251 interface=i, 252 protocol=pr, 253 peer=peer, 254 security=security, 255 default_port=default_port, 256 ) 257 for h, p, i, pr in zip(*map(listify, (host, port, interface, protocol))) 258 ] 259 else: 260 return [ 261 address_from_user_args( 262 host, port, interface, protocol, peer, security, default_port 263 ) 264 ] 265 266 267def address_from_user_args( 268 host=None, 269 port=None, 270 interface=None, 271 protocol=None, 272 peer=None, 273 security=None, 274 default_port=0, 275) -> str: 276 """Get an address to listen on from common user provided arguments""" 277 278 if security and security.require_encryption and not protocol: 279 protocol = "tls" 280 281 if protocol and protocol.rstrip("://") == "inplace": 282 if host or port or interface: 283 raise ValueError( 284 "Can not specify inproc protocol and host or port or interface" 285 ) 286 else: 287 return "inproc://" 288 289 if interface: 290 if host: 291 raise ValueError("Can not specify both interface and host", interface, host) 292 else: 293 host = get_ip_interface(interface) 294 295 if protocol and host and "://" not in host: 296 host = protocol.rstrip("://") + "://" + host 297 298 if host or port: 299 addr = uri_from_host_port(host, port, default_port) 300 else: 301 addr = "" 302 303 if protocol: 304 addr = protocol.rstrip("://") + "://" + addr.split("://")[-1] 305 306 return addr 307