1from __future__ import annotations
2
3import itertools
4
5import dask
6
7from ..utils import get_ip_interface
8from . import registry
9
10
def parse_address(addr: str, strict: bool = False) -> tuple[str, str]:
    """
    Break an address apart into ``(scheme, location)``.

    >>> parse_address('tcp://127.0.0.1')
    ('tcp', '127.0.0.1')

    The split is made at the last ``://`` found in *addr*.  When *addr*
    carries no scheme, the ``distributed.comm.default-scheme`` config value
    is used instead, unless *strict* is true, in which case ValueError is
    raised.  A non-string *addr* raises TypeError.
    """
    if not isinstance(addr, str):
        raise TypeError("expected str, got %r" % addr.__class__.__name__)

    scheme, separator, location = addr.rpartition("://")
    if separator:
        # An explicit scheme was given; nothing else to decide.
        return scheme, location

    if strict:
        raise ValueError(
            "Invalid url scheme. "
            "Must include protocol like tcp://localhost:8000. "
            "Got %s" % addr
        )

    # No scheme in the address: fall back to the configured default.
    return dask.config.get("distributed.comm.default-scheme"), location
33
34
def unparse_address(scheme: str, loc: str) -> str:
    """
    Join a scheme and a location back into a single address string.

    This is the inverse of parse_address().

    >>> unparse_address('tcp', '127.0.0.1')
    'tcp://127.0.0.1'
    """
    return "{}://{}".format(scheme, loc)
43
44
def normalize_address(addr: str) -> str:
    """
    Return *addr* in canonical ``scheme://location`` form.

    The address is run through parse_address() (non-strict, so a missing
    scheme is filled in with the configured default) and then reassembled
    with unparse_address().

    >>> normalize_address('tls://[::1]')
    'tls://[::1]'
    >>> normalize_address('[::1]')  # with 'tcp' as the configured default
    'tcp://[::1]'
    """
    scheme, location = parse_address(addr)
    return unparse_address(scheme, location)
55
56
def parse_host_port(
    address: str | tuple[str, int], default_port: str | int | None = None
) -> tuple[str, int]:
    """
    Split an endpoint address of the form "host:port" into ``(host, port)``.

    A tuple *address* is returned untouched.  A leading ``scheme://`` is
    discarded before parsing.  IPv6 hosts must be bracketed (``[::1]:80``);
    the brackets are removed from the returned host.  When the address has
    no port, *default_port* is used; the port is always converted with
    ``int()``.

    Raises ValueError if the address is malformed, or if it has no port and
    *default_port* is None.
    """
    if isinstance(address, tuple):
        return address

    if "://" in address:
        # Two-way unpacking on purpose: more than one "://" is an error.
        _, address = address.split("://")

    # Both messages quote the address with its scheme already removed.
    invalid_msg = (
        f"invalid address {address!r}; maybe: ipv6 needs brackets like [::1]"
    )
    missing_msg = f"missing port number in address {address!r}"

    def fallback_port():
        if default_port is None:
            raise ValueError(missing_msg)
        return default_port

    if address.startswith("["):
        # IPv6 notation: '[addr]:port' or '[addr]'; addr may hold many colons.
        host, bracket, rest = address[1:].partition("]")
        if not bracket:
            raise ValueError(invalid_msg)
        if rest:
            if not rest.startswith(":"):
                raise ValueError(invalid_msg)
            port = rest[1:]
        else:
            port = fallback_port()
    elif ":" not in address:
        # Generic notation without a port: 'addr'.
        host = address
        port = fallback_port()
    else:
        # Generic notation with a port: 'addr:port'.
        host, _, port = address.rpartition(":")
        if ":" in host:
            # Several colons outside brackets: most likely bare IPv6.
            raise ValueError(invalid_msg)

    return host, int(port)
100
101
def unparse_host_port(host: str, port: int | None = None) -> str:
    """
    Build a "host:port" string; the inverse of parse_host_port().

    A host containing a colon (e.g. a bare IPv6 address) is wrapped in
    square brackets unless it already starts with one.  When *port* is
    None only the (possibly bracketed) host is returned.
    """
    needs_brackets = ":" in host and not host.startswith("[")
    shown_host = f"[{host}]" if needs_brackets else host
    return shown_host if port is None else f"{shown_host}:{port}"
112
113
def get_address_host_port(addr: str, strict: bool = False) -> tuple[str, int]:
    """
    Extract a (host, port) tuple from the given address.

    *strict* is forwarded to parse_address(); see there for its meaning.
    The extraction itself is delegated to the backend registered for the
    address scheme.  If that backend raises NotImplementedError, a
    ValueError is raised in its place.

    >>> get_address_host_port('tcp://1.2.3.4:80')
    ('1.2.3.4', 80)
    >>> get_address_host_port('tcp://[::1]:80')
    ('::1', 80)
    """
    scheme, location = parse_address(addr, strict=strict)
    backend = registry.get_backend(scheme)
    try:
        host_and_port = backend.get_address_host_port(location)
    except NotImplementedError:
        raise ValueError(
            f"don't know how to extract host and port for address {addr!r}"
        )
    else:
        return host_and_port
134
135
def get_address_host(addr: str) -> str:
    """
    Return a hostname / IP address identifying the machine *addr* is
    located on.

    The address is parsed non-strictly and the lookup is delegated to the
    backend registered for its scheme.  Unlike get_address_host_port(),
    this is meant to succeed for every well-formed address.

    >>> get_address_host('tcp://1.2.3.4:80')
    '1.2.3.4'
    """
    scheme, location = parse_address(addr)
    return registry.get_backend(scheme).get_address_host(location)
150
151
def get_local_address_for(addr: str) -> str:
    """
    Return a local listening address suitable for reaching *addr*.

    The backend registered for the address scheme picks the local
    location; the result is returned with the same scheme as *addr*.
    For instance, asking about an external TCP address yields a local TCP
    address that is routable to it.

    >>> get_local_address_for('tcp://8.8.8.8:1234')
    'tcp://192.168.1.68'
    >>> get_local_address_for('tcp://127.0.0.1:1234')
    'tcp://127.0.0.1'
    """
    scheme, remote_location = parse_address(addr)
    backend = registry.get_backend(scheme)
    local_location = backend.get_local_address_for(remote_location)
    return unparse_address(scheme, local_location)
167
168
def resolve_address(addr: str) -> str:
    """
    Run scheme-specific resolution on *addr*, so that symbolic references
    are replaced by concrete location specifiers.

    The resolution is performed by the backend registered for the address
    scheme, and the scheme itself is kept.  In practice, this can mean
    hostnames are resolved to IP addresses.

    >>> resolve_address('tcp://localhost:8786')
    'tcp://127.0.0.1:8786'
    """
    scheme, location = parse_address(addr)
    backend = registry.get_backend(scheme)
    resolved_location = backend.resolve_address(location)
    return unparse_address(scheme, resolved_location)
182
183
def uri_from_host_port(
    host_arg: str | None, port_arg: str | None, default_port: int
) -> str:
    """
    Process the *host* and *port* CLI options.
    Return a URI.

    Parameters
    ----------
    host_arg
        Host option, optionally with a ``scheme://`` prefix and/or a
        ``:port`` suffix.  None is treated as the empty string.
    port_arg
        Port option, as an integer or a numeric string.  None means the
        option was not given.
    default_port
        Port to use when neither *host_arg* nor *port_arg* supplies one.

    Raises
    ------
    ValueError
        If *host_arg* and *port_arg* both name a non-zero port and the two
        differ, or if either of them cannot be parsed.
    """
    # Much of distributed depends on a well-known IP being assigned to
    # each entity (Worker, Scheduler, etc.).
    # NOTE(review): an empty host is passed through unchanged here; callers
    # are presumably expected to supply a concrete host when one matters.
    scheme, loc = parse_address(host_arg or "")

    # parse_host_port() always returns an integer port: the one embedded in
    # the host string if present, otherwise this fallback.
    fallback_port = port_arg if port_arg is not None else default_port
    host, port = parse_host_port(loc, fallback_port)

    # `port_arg` may arrive as a string, so normalise it before comparing
    # with the parsed integer port; otherwise "8786" != 8786 would be
    # reported as a conflict.  Note `port = 0` means "choose a random
    # port", so a zero on either side never counts as a conflict.
    if port and port_arg and port != int(port_arg):
        raise ValueError(
            "port number given twice in options: "
            "host %r and port %r" % (host_arg, port_arg)
        )

    return unparse_address(scheme, unparse_host_port(host, port))
217
218
def addresses_from_user_args(
    host=None,
    port=None,
    interface=None,
    protocol=None,
    peer=None,
    security=None,
    default_port=0,
) -> list:
    """Get a list of addresses, mapping over any list-valued inputs

    This behaves like ``address_from_user_args`` except that *host*,
    *port*, *interface* and *protocol* may each be a list or tuple.  When at
    least one of them is, the sequences are walked in lockstep (stopping at
    the shortest one), scalar values are reused for every entry, and one
    address is produced per step.  When none of them is a sequence, the
    result is a one-element list.

    Examples
    --------
    >>> addresses_from_user_args(host="127.0.0.1", protocol=["tcp", "tls"])
    ['tcp://127.0.0.1:0', 'tls://127.0.0.1:0']
    """
    per_address_options = (host, port, interface, protocol)

    if not any(isinstance(opt, (tuple, list)) for opt in per_address_options):
        # Purely scalar input: a single address, still wrapped in a list.
        return [
            address_from_user_args(
                host, port, interface, protocol, peer, security, default_port
            )
        ]

    # Scalars are repeated endlessly so that zip() is bounded only by the
    # real sequences.
    columns = [
        opt if isinstance(opt, (tuple, list)) else itertools.repeat(opt)
        for opt in per_address_options
    ]

    addresses = []
    for one_host, one_port, one_interface, one_protocol in zip(*columns):
        addresses.append(
            address_from_user_args(
                host=one_host,
                port=one_port,
                interface=one_interface,
                protocol=one_protocol,
                peer=peer,
                security=security,
                default_port=default_port,
            )
        )
    return addresses
265
266
def address_from_user_args(
    host=None,
    port=None,
    interface=None,
    protocol=None,
    peer=None,
    security=None,
    default_port=0,
) -> str:
    """Get an address to listen on from common user provided arguments

    Parameters
    ----------
    host, port
        Host and port options; combined through ``uri_from_host_port`` when
        either is given.
    interface
        Network interface name; resolved to a host with
        ``get_ip_interface``.  Mutually exclusive with *host*.
    protocol
        Scheme to use, with or without a trailing ``://``.  When given it
        overrides any scheme found in *host*.
    peer
        Accepted for signature compatibility; not used by this function.
    security
        Optional object; if it is truthy and its ``require_encryption``
        attribute is true while no *protocol* is given, ``"tls"`` is used.
    default_port
        Port used when neither *host* nor *port* provides one.

    Returns
    -------
    str
        The listening address.  This is ``"inproc://"`` for the in-process
        protocol, and may be an empty string when nothing was specified.

    Raises
    ------
    ValueError
        If the in-process protocol is combined with *host*, *port* or
        *interface*, or if both *interface* and *host* are given.
    """

    if security and security.require_encryption and not protocol:
        protocol = "tls"

    # Accept both "tcp" and "tcp://".  rstrip() takes a *set* of characters,
    # so any trailing run of ':' and '/' is removed.
    scheme = protocol.rstrip(":/") if protocol else None

    # In-process communication has no host, port or interface.
    if scheme == "inproc":
        if host or port or interface:
            raise ValueError(
                "Can not specify inproc protocol and host or port or interface"
            )
        return "inproc://"

    if interface:
        if host:
            raise ValueError("Can not specify both interface and host", interface, host)
        host = get_ip_interface(interface)

    # Attach the protocol to a scheme-less host so it is parsed with it.
    if protocol and host and "://" not in host:
        host = scheme + "://" + host

    if host or port:
        addr = uri_from_host_port(host, port, default_port)
    else:
        addr = ""

    # An explicit protocol wins over whatever scheme the host carried.
    if protocol:
        addr = scheme + "://" + addr.split("://")[-1]

    return addr
307