from __future__ import annotations

import asyncio
import contextvars
import functools
import importlib
import inspect
import json
import logging
import multiprocessing
import os
import pkgutil
import re
import socket
import sys
import tempfile
import threading
import warnings
import weakref
import xml.etree.ElementTree
from asyncio import TimeoutError
from collections import OrderedDict, UserDict, deque
from concurrent.futures import CancelledError, ThreadPoolExecutor  # noqa: F401
from contextlib import contextmanager, suppress
from hashlib import md5
from importlib.util import cache_from_source
from time import sleep
from typing import TYPE_CHECKING
from typing import Any as AnyType
from typing import ClassVar, Container, Sequence, overload

import click
import tblib.pickling_support

if TYPE_CHECKING:
    from typing_extensions import Protocol

try:
    import resource
except ImportError:
    resource = None  # type: ignore

import tlz as toolz
from tornado import gen
from tornado.ioloop import IOLoop

import dask
from dask import istask
from dask.utils import parse_timedelta as _parse_timedelta
from dask.widgets import get_template

try:
    from tornado.ioloop import PollIOLoop  # type: ignore
except ImportError:
    PollIOLoop = None  # dropped in tornado 6.0

from .compatibility import PYPY, WINDOWS
from .metrics import time

try:
    from dask.context import thread_state
except ImportError:
    thread_state = threading.local()

# For some reason this is required in python >= 3.9
if WINDOWS:
    import multiprocessing.popen_spawn_win32
else:
    import multiprocessing.popen_spawn_posix

logger = _logger = logging.getLogger(__name__)


no_default = "__no_default__"


def _initialize_mp_context():
    if WINDOWS or PYPY:
        return multiprocessing
    else:
        method = dask.config.get("distributed.worker.multiprocessing-method")
        ctx = multiprocessing.get_context(method)
        # Makes the test suite much faster
        preload = ["distributed"]
        if "pkg_resources" in sys.modules:
            preload.append("pkg_resources")

        from .versions import optional_packages, required_packages

        for pkg, _ in required_packages + optional_packages:
            try:
                importlib.import_module(pkg)
            except ImportError:
                pass
            else:
                preload.append(pkg)
        ctx.set_forkserver_preload(preload)
        return ctx


mp_context = _initialize_mp_context()


def has_arg(func, argname):
    """
    Whether the function takes an argument with the given name.
    """
    while True:
        try:
            if argname in inspect.getfullargspec(func).args:
                return True
        except TypeError:
            break
        try:
            # For Tornado coroutines and other decorated functions
            func = func.__wrapped__
        except AttributeError:
            break
    return False
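
# A minimal sketch of how ``has_arg`` sees through decorator wrappers
# (``target`` and ``wrapper`` are illustrative names, not part of this module):
#
#     import functools
#
#     def target(x, verbose=False):
#         return x
#
#     @functools.wraps(target)          # sets wrapper.__wrapped__ = target
#     def wrapper(*args, **kwargs):
#         return target(*args, **kwargs)
#
#     has_arg(wrapper, "x")             # True, found by following __wrapped__
#     has_arg(wrapper, "missing")       # False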


def get_fileno_limit():
    """
    Get the maximum number of open files per process.
    """
    if resource is not None:
        return resource.getrlimit(resource.RLIMIT_NOFILE)[0]
    else:
        # Default ceiling for Windows when using the CRT, though it
        # is settable using _setmaxstdio().
        return 512


@toolz.memoize
def _get_ip(host, port, family):
    # By using a UDP socket, we don't actually try to connect but
    # simply select the local address through which *host* is reachable.
    sock = socket.socket(family, socket.SOCK_DGRAM)
    try:
        sock.connect((host, port))
        ip = sock.getsockname()[0]
        return ip
    except OSError as e:
        warnings.warn(
            "Couldn't detect a suitable IP address for "
            "reaching %r, defaulting to hostname: %s" % (host, e),
            RuntimeWarning,
        )
        addr_info = socket.getaddrinfo(
            socket.gethostname(), port, family, socket.SOCK_DGRAM, socket.IPPROTO_UDP
        )[0]
        return addr_info[4][0]
    finally:
        sock.close()


def get_ip(host="8.8.8.8", port=80):
    """
    Get the local IP address through which the *host* is reachable.

    *host* defaults to a well-known Internet host (one of Google's public
    DNS servers).
    """
    return _get_ip(host, port, family=socket.AF_INET)


def get_ipv6(host="2001:4860:4860::8888", port=80):
    """
    The same as get_ip(), but for IPv6.
    """
    return _get_ip(host, port, family=socket.AF_INET6)


def get_ip_interface(ifname):
    """
    Get the local IPv4 address of a network interface.

    ValueError is raised if the interface doesn't exist or if it doesn't
    have an IPv4 address associated with it.
    """
    import psutil

    net_if_addrs = psutil.net_if_addrs()

    if ifname not in net_if_addrs:
        allowed_ifnames = list(net_if_addrs.keys())
        raise ValueError(
            "{!r} is not a valid network interface. "
            "Valid network interfaces are: {}".format(ifname, allowed_ifnames)
        )

    for info in net_if_addrs[ifname]:
        if info.family == socket.AF_INET:
            return info.address
    raise ValueError(f"interface {ifname!r} doesn't have an IPv4 address")
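
# Usage sketch (interface names vary by system; "lo" is just an example):
#
#     get_ip_interface("lo")        # typically '127.0.0.1' on Linux
#     get_ip_interface("no-such")   # raises ValueError listing valid interfaces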


async def All(args, quiet_exceptions=()):
    """Wait on many tasks at the same time

    Raise an error as soon as any of the tasks errs.

    See https://github.com/tornadoweb/tornado/issues/1546

    Parameters
    ----------
    args: iterable of futures to wait for
    quiet_exceptions: tuple of Exception classes
        Exception types to avoid logging if they fail
    """
    tasks = gen.WaitIterator(*map(asyncio.ensure_future, args))
    results = [None for _ in args]
    while not tasks.done():
        try:
            result = await tasks.next()
        except Exception:

            @gen.coroutine
            def quiet():
                """Watch unfinished tasks

                Otherwise if they err they get logged in a way that is hard to
                control.  They need some other task to watch them so that they
                are not orphaned
                """
                for task in list(tasks._unfinished):
                    try:
                        yield task
                    except quiet_exceptions:
                        pass

            quiet()
            raise
        results[tasks.current_index] = result
    return results


async def Any(args, quiet_exceptions=()):
    """Wait on many tasks at the same time and return when any is finished

    Raise an error as soon as any of the tasks errs.

    Parameters
    ----------
    args: iterable of futures to wait for
    quiet_exceptions: tuple of Exception classes
        Exception types to avoid logging if they fail
    """
    tasks = gen.WaitIterator(*map(asyncio.ensure_future, args))
    results = [None for _ in args]
    while not tasks.done():
        try:
            result = await tasks.next()
        except Exception:

            @gen.coroutine
            def quiet():
                """Watch unfinished tasks

                Otherwise if they err they get logged in a way that is hard to
                control.  They need some other task to watch them so that they
                are not orphaned
                """
                for task in list(tasks._unfinished):
                    try:
                        yield task
                    except quiet_exceptions:
                        pass

            quiet()
            raise

        results[tasks.current_index] = result
        break
    return results


def sync(loop, func, *args, callback_timeout=None, **kwargs):
    """
    Run coroutine in loop running in separate thread.
    """
    callback_timeout = _parse_timedelta(callback_timeout, "s")
    # Tornado's PollIOLoop doesn't raise when used after closing, so do it ourselves
    if PollIOLoop and (
        (isinstance(loop, PollIOLoop) and getattr(loop, "_closing", False))
        or (hasattr(loop, "asyncio_loop") and loop.asyncio_loop._closed)
    ):
        raise RuntimeError("IOLoop is closed")
    try:
        if loop.asyncio_loop.is_closed():  # tornado 6
            raise RuntimeError("IOLoop is closed")
    except AttributeError:
        pass

    e = threading.Event()
    main_tid = threading.get_ident()
    result = [None]
    error = [False]

    @gen.coroutine
    def f():
        # We flag the thread state asynchronous, which will make sync() calls
        # within `func` use async semantics. In order to support concurrent
        # calls to sync(), `asynchronous` is used as a ref counter.
        thread_state.asynchronous = getattr(thread_state, "asynchronous", 0)
        thread_state.asynchronous += 1
        try:
            if main_tid == threading.get_ident():
                raise RuntimeError("sync() called from thread of running loop")
            yield gen.moment
            future = func(*args, **kwargs)
            if callback_timeout is not None:
                future = asyncio.wait_for(future, callback_timeout)
            result[0] = yield future
        except Exception:
            error[0] = sys.exc_info()
        finally:
            assert thread_state.asynchronous > 0
            thread_state.asynchronous -= 1
            e.set()

    loop.add_callback(f)
    if callback_timeout is not None:
        if not e.wait(callback_timeout):
            raise TimeoutError(f"timed out after {callback_timeout} s.")
    else:
        while not e.is_set():
            e.wait(10)
    if error[0]:
        typ, exc, tb = error[0]
        raise exc.with_traceback(tb)
    else:
        return result[0]
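
# Usage sketch for sync() (illustrative only; ``double`` is a made-up coroutine):
#
#     runner = LoopRunner()          # defined below; runs an IOLoop in a thread
#     runner.start()
#
#     async def double(x):
#         return 2 * x
#
#     sync(runner.loop, double, 21)  # blocks the calling thread, returns 42
#     runner.stop()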


class LoopRunner:
    """
    A helper to start and stop an IO loop in a controlled way.
    Several loop runners can safely share the same IO loop.

    Parameters
    ----------
    loop: IOLoop (optional)
        If given, this loop will be re-used, otherwise an appropriate one
        will be looked up or created.
    asynchronous: boolean (optional, default False)
        If false (the default), the loop is meant to run in a separate
        thread and will be started if necessary.
        If true, the loop is meant to run in the thread this
        object is instantiated from, and will not be started automatically.
    """

    # All loops currently associated to loop runners
    _all_loops: ClassVar[
        weakref.WeakKeyDictionary[IOLoop, tuple[int, LoopRunner | None]]
    ] = weakref.WeakKeyDictionary()
    _lock = threading.Lock()

    def __init__(self, loop=None, asynchronous=False):
        current = IOLoop.current()
        if loop is None:
            if asynchronous:
                self._loop = current
            else:
                # We're expecting the loop to run in another thread,
                # avoid re-using this thread's assigned loop
                self._loop = IOLoop()
        else:
            self._loop = loop
        self._asynchronous = asynchronous
        self._loop_thread = None
        self._started = False
        with self._lock:
            self._all_loops.setdefault(self._loop, (0, None))

    def start(self):
        """
        Start the IO loop if required.  The loop is run in a dedicated
        thread.

        If the loop is already running, this method does nothing.
        """
        with self._lock:
            self._start_unlocked()

    def _start_unlocked(self):
        assert not self._started

        count, real_runner = self._all_loops[self._loop]
        if self._asynchronous or real_runner is not None or count > 0:
            self._all_loops[self._loop] = count + 1, real_runner
            self._started = True
            return

        assert self._loop_thread is None
        assert count == 0

        loop_evt = threading.Event()
        done_evt = threading.Event()
        in_thread = [None]
        start_exc = [None]

        def loop_cb():
            in_thread[0] = threading.current_thread()
            loop_evt.set()

        def run_loop(loop=self._loop):
            loop.add_callback(loop_cb)
            # run loop forever if it's not running already
            try:
                if (
                    getattr(loop, "asyncio_loop", None) is None
                    or not loop.asyncio_loop.is_running()
                ):
                    loop.start()
            except Exception as e:
                start_exc[0] = e
            finally:
                done_evt.set()

        thread = threading.Thread(target=run_loop, name="IO loop")
        thread.daemon = True
        thread.start()

        loop_evt.wait(timeout=10)
        self._started = True

        actual_thread = in_thread[0]
        if actual_thread is not thread:
            # Loop already running in other thread (user-launched)
            done_evt.wait(5)
            if start_exc[0] is not None and not isinstance(start_exc[0], RuntimeError):
                if not isinstance(
                    start_exc[0], Exception
                ):  # track down infrequent error
                    raise TypeError(
                        f"not an exception: {start_exc[0]!r}",
                    )
                raise start_exc[0]
            self._all_loops[self._loop] = count + 1, None
        else:
            assert start_exc[0] is None, start_exc
            self._loop_thread = thread
            self._all_loops[self._loop] = count + 1, self

    def stop(self, timeout=10):
        """
        Stop and close the loop if it was created by us.
        Otherwise, just mark this object "stopped".
        """
        with self._lock:
            self._stop_unlocked(timeout)

    def _stop_unlocked(self, timeout):
        if not self._started:
            return

        self._started = False

        count, real_runner = self._all_loops[self._loop]
        if count > 1:
            self._all_loops[self._loop] = count - 1, real_runner
        else:
            assert count == 1
            del self._all_loops[self._loop]
            if real_runner is not None:
                real_runner._real_stop(timeout)

    def _real_stop(self, timeout):
        assert self._loop_thread is not None
        if self._loop_thread is not None:
            try:
                self._loop.add_callback(self._loop.stop)
                self._loop_thread.join(timeout=timeout)
                with suppress(KeyError):  # IOLoop can be missing
                    self._loop.close()
            finally:
                self._loop_thread = None

    def is_started(self):
        """
        Return True between start() and stop() calls, False otherwise.
        """
        return self._started

    def run_sync(self, func, *args, **kwargs):
        """
        Convenience helper: start the loop if needed,
        run sync(func, *args, **kwargs), then stop the loop again.
        """
        if self._started:
            return sync(self.loop, func, *args, **kwargs)
        else:
            self.start()
            try:
                return sync(self.loop, func, *args, **kwargs)
            finally:
                self.stop()

    @property
    def loop(self):
        return self._loop
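
# A short LoopRunner usage sketch (``fetch`` is a hypothetical coroutine):
#
#     runner = LoopRunner()
#
#     async def fetch():
#         return "payload"
#
#     runner.run_sync(fetch)   # starts the loop thread, runs fetch(), stops it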


@contextmanager
def set_thread_state(**kwargs):
    old = {}
    for k in kwargs:
        try:
            old[k] = getattr(thread_state, k)
        except AttributeError:
            pass
    for k, v in kwargs.items():
        setattr(thread_state, k, v)
    try:
        yield
    finally:
        for k in kwargs:
            try:
                v = old[k]
            except KeyError:
                delattr(thread_state, k)
            else:
                setattr(thread_state, k, v)
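
# set_thread_state temporarily overrides attributes on the module-level
# ``thread_state`` and restores (or deletes) them on exit, e.g.:
#
#     with set_thread_state(asynchronous=1):
#         ...  # code here sees thread_state.asynchronous == 1
#     # previous value restored, or attribute removed if it didn't exist before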


@contextmanager
def tmp_text(filename, text):
    fn = os.path.join(tempfile.gettempdir(), filename)
    with open(fn, "w") as f:
        f.write(text)

    try:
        yield fn
    finally:
        if os.path.exists(fn):
            os.remove(fn)


def is_kernel():
    """Determine if we're running within an IPython kernel

    >>> is_kernel()
    False
    """
    # http://stackoverflow.com/questions/34091701/determine-if-were-in-an-ipython-notebook-session
    if "IPython" not in sys.modules:  # IPython hasn't been imported
        return False
    from IPython import get_ipython

    # check for `kernel` attribute on the IPython instance
    return getattr(get_ipython(), "kernel", None) is not None


# Matches a run of hex letters; key_split uses this (with a length check)
# to detect hex-ish fragments in task names
hex_pattern = re.compile("[a-f]+")


@functools.lru_cache(100000)
def key_split(s):
    """
    >>> key_split('x')
    'x'
    >>> key_split('x-1')
    'x'
    >>> key_split('x-1-2-3')
    'x'
    >>> key_split(('x-2', 1))
    'x'
    >>> key_split("('x-2', 1)")
    'x'
    >>> key_split("('x', 1)")
    'x'
    >>> key_split('hello-world-1')
    'hello-world'
    >>> key_split(b'hello-world-1')
    'hello-world'
    >>> key_split('ae05086432ca935f6eba409a8ecd4896')
    'data'
    >>> key_split('<module.submodule.myclass object at 0xdaf372')
    'myclass'
    >>> key_split(None)
    'Other'
    >>> key_split('x-abcdefab')  # ignores hex
    'x'
    """
    if type(s) is bytes:
        s = s.decode()
    if type(s) is tuple:
        s = s[0]
    try:
        words = s.split("-")
        if not words[0][0].isalpha():
            result = words[0].split(",")[0].strip("'(\"")
        else:
            result = words[0]
        for word in words[1:]:
            if word.isalpha() and not (
                len(word) == 8 and hex_pattern.match(word) is not None
            ):
                result += "-" + word
            else:
                break
        if len(result) == 32 and re.match(r"[a-f0-9]{32}", result):
            return "data"
        else:
            if result[0] == "<":
                result = result.strip("<>").split()[0].split(".")[-1]
            return result
    except Exception:
        return "Other"


def key_split_group(x) -> str:
    """A more fine-grained version of key_split

    >>> key_split_group(('x-2', 1))
    'x-2'
    >>> key_split_group("('x-2', 1)")
    'x-2'
    >>> key_split_group('ae05086432ca935f6eba409a8ecd4896')
    'data'
    >>> key_split_group('<module.submodule.myclass object at 0xdaf372')
    'myclass'
    >>> key_split_group('x')
    'x'
    >>> key_split_group('x-1')
    'x'
    """
    typ = type(x)
    if typ is tuple:
        return x[0]
    elif typ is str:
        if x[0] == "(":
            return x.split(",", 1)[0].strip("()\"'")
        elif len(x) == 32 and re.match(r"[a-f0-9]{32}", x):
            return "data"
        elif x[0] == "<":
            return x.strip("<>").split()[0].split(".")[-1]
        else:
            return key_split(x)
    elif typ is bytes:
        return key_split_group(x.decode())
    else:
        return "Other"


@contextmanager
def log_errors(pdb=False):
    from .comm import CommClosedError

    try:
        yield
    except (CommClosedError, gen.Return):
        raise
    except Exception as e:
        try:
            logger.exception(e)
        except TypeError:  # logger becomes None during process cleanup
            pass
        if pdb:
            import pdb

            pdb.set_trace()
        raise


def silence_logging(level, root="distributed"):
    """
    Change all StreamHandlers for the given logger to the given level
    """
    if isinstance(level, str):
        level = getattr(logging, level.upper())

    old = None
    logger = logging.getLogger(root)
    for handler in logger.handlers:
        if isinstance(handler, logging.StreamHandler):
            old = handler.level
            handler.setLevel(level)

    return old
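
# Usage sketch: quiet distributed's stream handlers during a noisy section,
# then put the previous level back (values illustrative; note the returned
# level is that of the last matching handler):
#
#     old_level = silence_logging("critical")   # or logging.CRITICAL
#     ...                                       # noisy work
#     silence_logging(old_level)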


@toolz.memoize
def ensure_ip(hostname):
    """Ensure that the given host name or address resolves to an IP address

    Examples
    --------
    >>> ensure_ip('localhost')
    '127.0.0.1'
    >>> ensure_ip('')  # Maps as localhost for binding e.g. 'tcp://:8811'
    '127.0.0.1'
    >>> ensure_ip('123.123.123.123')  # pass through IP addresses
    '123.123.123.123'
    """
    if not hostname:
        hostname = "localhost"

    # Prefer IPv4 over IPv6, for compatibility
    families = [socket.AF_INET, socket.AF_INET6]
    for fam in families:
        try:
            results = socket.getaddrinfo(
                hostname, 1234, fam, socket.SOCK_STREAM  # dummy port number
            )
        except socket.gaierror as e:
            exc = e
        else:
            return results[0][4][0]

    raise exc


tblib.pickling_support.install()


def get_traceback():
    exc_type, exc_value, exc_traceback = sys.exc_info()
    bad = [
        os.path.join("distributed", "worker"),
        os.path.join("distributed", "scheduler"),
        os.path.join("tornado", "gen.py"),
        os.path.join("concurrent", "futures"),
    ]
    while exc_traceback and any(
        b in exc_traceback.tb_frame.f_code.co_filename for b in bad
    ):
        exc_traceback = exc_traceback.tb_next
    return exc_traceback


def truncate_exception(e, n=10000):
    """Truncate exception to be about a certain length"""
    if len(str(e)) > n:
        try:
            return type(e)("Long error message", str(e)[:n])
        except Exception:
            return Exception("Long error message", type(e), str(e)[:n])
    else:
        return e


def validate_key(k):
    """Validate a key as received on a stream."""
    typ = type(k)
    if typ is not str and typ is not bytes:
        raise TypeError(f"Unexpected key type {typ} (value: {k!r})")


def _maybe_complex(task):
    """Possibly contains a nested task"""
    return (
        istask(task)
        or type(task) is list
        and any(map(_maybe_complex, task))
        or type(task) is dict
        and any(map(_maybe_complex, task.values()))
    )


def seek_delimiter(file, delimiter, blocksize):
    """Seek current file to next byte after a delimiter bytestring

    This seeks the file to the next byte following the delimiter.  It does
    not return anything.  Use ``file.tell()`` to see location afterwards.

    Parameters
    ----------
    file: a file
    delimiter: bytes
        a delimiter like ``b'\n'`` or message sentinel
    blocksize: int
        Number of bytes to read from the file at once.
    """

    if file.tell() == 0:
        return

    last = b""
    while True:
        current = file.read(blocksize)
        if not current:
            return
        full = last + current
        try:
            i = full.index(delimiter)
            file.seek(file.tell() - (len(full) - i) + len(delimiter))
            return
        except ValueError:
            pass
        last = full[-len(delimiter) :]
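
# A small sketch of seek_delimiter on an in-memory file:
#
#     from io import BytesIO
#
#     f = BytesIO(b"Alice, 100\nBob, 200\nCharlie, 300")
#     f.seek(5)                          # somewhere mid-record
#     seek_delimiter(f, b"\n", 2 ** 16)
#     f.tell()                           # 11, the byte right after the first b"\n"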


def read_block(f, offset, length, delimiter=None):
    """Read a block of bytes from a file

    Parameters
    ----------
    f: file
        File-like object supporting seek, read, tell, etc..
    offset: int
        Byte offset to start read
    length: int
        Number of bytes to read
    delimiter: bytes (optional)
        Ensure reading starts and stops at delimiter bytestring

    If using the ``delimiter=`` keyword argument we ensure that the read
    starts and stops at delimiter boundaries that follow the locations
    ``offset`` and ``offset + length``.  If ``offset`` is zero then we
    start at zero.  The bytestring returned WILL include the
    terminating delimiter string.

    Examples
    --------

    >>> from io import BytesIO  # doctest: +SKIP
    >>> f = BytesIO(b'Alice, 100\\nBob, 200\\nCharlie, 300')  # doctest: +SKIP
    >>> read_block(f, 0, 13)  # doctest: +SKIP
    b'Alice, 100\\nBo'

    >>> read_block(f, 0, 13, delimiter=b'\\n')  # doctest: +SKIP
    b'Alice, 100\\nBob, 200\\n'

    >>> read_block(f, 10, 10, delimiter=b'\\n')  # doctest: +SKIP
    b'Bob, 200\\nCharlie, 300'
    """
    if delimiter:
        f.seek(offset)
        seek_delimiter(f, delimiter, 2 ** 16)
        start = f.tell()
        length -= start - offset

        f.seek(start + length)
        seek_delimiter(f, delimiter, 2 ** 16)
        end = f.tell()

        offset = start
        length = end - start

    f.seek(offset)
    data = f.read(length)
    return data


def ensure_bytes(s):
    """Attempt to turn `s` into bytes.

    Parameters
    ----------
    s : Any
        The object to be converted. The following are handled correctly:

        * str
        * bytes
        * objects implementing the buffer protocol (memoryview, ndarray, etc.)

    Returns
    -------
    b : bytes

    Raises
    ------
    TypeError
        When `s` cannot be converted

    Examples
    --------
    >>> ensure_bytes('123')
    b'123'
    >>> ensure_bytes(b'123')
    b'123'
    """
    if isinstance(s, bytes):
        return s
    elif hasattr(s, "encode"):
        return s.encode()
    else:
        try:
            return bytes(s)
        except Exception as e:
            raise TypeError(
                "Object %s is neither a bytes object nor has an encode method" % s
            ) from e


def open_port(host=""):
    """Return a probably-open port

    There is a chance that this port will be taken by the operating system soon
    after returning from this function.
    """
    # http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind((host, 0))
    s.listen(1)
    port = s.getsockname()[1]
    s.close()
    return port
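
# Usage sketch: grab a free port, then bind to it quickly, since the port is
# only *probably* still free after this call returns:
#
#     port = open_port()
#     server = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
#     server.bind(("", port))   # race: another process may claim the port first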


def import_file(path):
    """Loads modules for a file (.py, .zip, .egg)"""
    directory, filename = os.path.split(path)
    name, ext = os.path.splitext(filename)
    names_to_import = []
    tmp_python_path = None

    if ext in (".py",):  # , '.pyc'):
        if directory not in sys.path:
            tmp_python_path = directory
        names_to_import.append(name)
    if ext == ".py":  # Ensure that no pyc file will be reused
        cache_file = cache_from_source(path)
        with suppress(OSError):
            os.remove(cache_file)
    if ext in (".egg", ".zip", ".pyz"):
        if path not in sys.path:
            sys.path.insert(0, path)
        names = (mod_info.name for mod_info in pkgutil.iter_modules([path]))
        names_to_import.extend(names)

    loaded = []
    if not names_to_import:
        logger.warning("Found nothing to import from %s", filename)
    else:
        importlib.invalidate_caches()
        if tmp_python_path is not None:
            sys.path.insert(0, tmp_python_path)
        try:
            for name in names_to_import:
                logger.info("Reload module %s from %s file", name, ext)
                loaded.append(importlib.reload(importlib.import_module(name)))
        finally:
            if tmp_python_path is not None:
                sys.path.remove(tmp_python_path)
    return loaded


def asciitable(columns, rows):
    """Formats an ascii table for given columns and rows.

    Parameters
    ----------
    columns : list
        The column names
    rows : list of tuples
        The rows in the table. Each tuple must be the same length as
        ``columns``.
    """
    rows = [tuple(str(i) for i in r) for r in rows]
    columns = tuple(str(i) for i in columns)
    widths = tuple(max(max(map(len, x)), len(c)) for x, c in zip(zip(*rows), columns))
    row_template = ("|" + (" %%-%ds |" * len(columns))) % widths
    header = row_template % tuple(columns)
    bar = "+%s+" % "+".join("-" * (w + 2) for w in widths)
    data = "\n".join(row_template % r for r in rows)
    return "\n".join([bar, header, bar, data, bar])
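
# Example output (illustrative values):
#
#     print(asciitable(["name", "value"], [("x", 1), ("yy", 22)]))
#
#     +------+-------+
#     | name | value |
#     +------+-------+
#     | x    | 1     |
#     | yy   | 22    |
#     +------+-------+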


def nbytes(frame, _bytes_like=(bytes, bytearray)):
    """Number of bytes of a frame or memoryview"""
    if isinstance(frame, _bytes_like):
        return len(frame)
    else:
        try:
            return frame.nbytes
        except AttributeError:
            return len(frame)


def json_load_robust(fn, load=json.load):
    """Reads a JSON file from disk that may be being written as we read"""
    while not os.path.exists(fn):
        sleep(0.01)
    for i in range(10):
        try:
            with open(fn) as f:
                cfg = load(f)
            if cfg:
                return cfg
        except (ValueError, KeyError):  # race with writing process
            pass
        sleep(0.1)


class DequeHandler(logging.Handler):
    """A logging.Handler that keeps log records in a deque"""

    _instances: ClassVar[weakref.WeakSet[DequeHandler]] = weakref.WeakSet()

    def __init__(self, *args, n=10000, **kwargs):
        self.deque = deque(maxlen=n)
        super().__init__(*args, **kwargs)
        self._instances.add(self)

    def emit(self, record):
        self.deque.append(record)

    def clear(self):
        """
        Clear internal storage.
        """
        self.deque.clear()

    @classmethod
    def clear_all_instances(cls):
        """
        Clear the internal storage of all live DequeHandlers.
        """
        for inst in list(cls._instances):
            inst.clear()
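
# Usage sketch: capture the most recent records from a logger
# (logger name and sizes illustrative):
#
#     handler = DequeHandler(n=100)
#     logging.getLogger("distributed").addHandler(handler)
#     ...
#     [r.getMessage() for r in handler.deque]   # up to the last 100 messages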


def reset_logger_locks():
    """Python 2's logger's locks don't survive a fork event

    https://github.com/dask/distributed/issues/1491
    """
    for name in logging.Logger.manager.loggerDict.keys():
        for handler in logging.getLogger(name).handlers:
            handler.createLock()


is_server_extension = False

if "notebook" in sys.modules:
    import traitlets
    from notebook.notebookapp import NotebookApp

    is_server_extension = traitlets.config.Application.initialized() and isinstance(
        traitlets.config.Application.instance(), NotebookApp
    )

if not is_server_extension:
    is_kernel_and_no_running_loop = False

    if is_kernel():
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            is_kernel_and_no_running_loop = True

    if not is_kernel_and_no_running_loop:

        # TODO: Use tornado's AnyThreadEventLoopPolicy, instead of class below,
        # once tornado > 6.0.3 is available.
        if WINDOWS:
            # WindowsProactorEventLoopPolicy is not compatible with tornado 6
            # fallback to the pre-3.8 default of Selector
            # https://github.com/tornadoweb/tornado/issues/2608
            BaseEventLoopPolicy = asyncio.WindowsSelectorEventLoopPolicy  # type: ignore
        else:
            BaseEventLoopPolicy = asyncio.DefaultEventLoopPolicy

        class AnyThreadEventLoopPolicy(BaseEventLoopPolicy):  # type: ignore
            def get_event_loop(self):
                try:
                    return super().get_event_loop()
                except (RuntimeError, AssertionError):
                    loop = self.new_event_loop()
                    self.set_event_loop(loop)
                    return loop

        asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())


@functools.lru_cache(1000)
def has_keyword(func, keyword):
    return keyword in inspect.signature(func).parameters


@functools.lru_cache(1000)
def command_has_keyword(cmd, k):
    if cmd is not None:
        if isinstance(cmd, str):
            try:
                from importlib import import_module

                cmd = import_module(cmd)
            except ImportError:
                raise ImportError("Module for command %s is not available" % cmd)

        if isinstance(getattr(cmd, "main"), click.core.Command):
            cmd = cmd.main
        if isinstance(cmd, click.core.Command):
            cmd_params = {
                p.human_readable_name
                for p in cmd.params
                if isinstance(p, click.core.Option)
            }
            return k in cmd_params

    return False


# from bokeh.palettes import viridis
# palette = viridis(18)
palette = [
    "#440154",
    "#471669",
    "#472A79",
    "#433C84",
    "#3C4D8A",
    "#355D8C",
    "#2E6C8E",
    "#287A8E",
    "#23898D",
    "#1E978A",
    "#20A585",
    "#2EB27C",
    "#45BF6F",
    "#64CB5D",
    "#88D547",
    "#AFDC2E",
    "#D7E219",
    "#FDE724",
]


@toolz.memoize
def color_of(x, palette=palette):
    h = md5(str(x).encode())
    n = int(h.hexdigest()[:8], 16)
    return palette[n % len(palette)]
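
# color_of deterministically maps any value to one of the 18 palette entries
# by hashing its string form, e.g.:
#
#     color_of("inc")                       # some fixed hex color
#     color_of("inc") == color_of("inc")    # True, stable across calls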


def _iscoroutinefunction(f):
    # Python < 3.8 does not support determining if `partial` objects wrap async funcs
    if sys.version_info < (3, 8):
        while isinstance(f, functools.partial):
            f = f.func
    return inspect.iscoroutinefunction(f) or gen.is_coroutine_function(f)


@functools.lru_cache(None)
def _iscoroutinefunction_cached(f):
    return _iscoroutinefunction(f)


def iscoroutinefunction(f):
    # Attempt to use lru_cache version and fall back to non-cached version if needed
    try:
        return _iscoroutinefunction_cached(f)
    except TypeError:  # unhashable type
        return _iscoroutinefunction(f)


@contextmanager
def warn_on_duration(duration, msg):
    start = time()
    yield
    stop = time()
    if stop - start > _parse_timedelta(duration):
        warnings.warn(msg, stacklevel=2)


def format_dashboard_link(host, port):
    template = dask.config.get("distributed.dashboard.link")
    if dask.config.get("distributed.scheduler.dashboard.tls.cert"):
        scheme = "https"
    else:
        scheme = "http"
    return template.format(
        **toolz.merge(os.environ, dict(scheme=scheme, host=host, port=port))
    )


def parse_ports(port):
    """Parse input port information into list of ports

    Parameters
    ----------
    port : int, str, None
        Input port or ports. Can be an integer like 8787, a string for a
        single port like "8787", a string for a sequential range of ports like
        "8000:8200", or None.

    Returns
    -------
    ports : list
        List of ports

    Examples
    --------
    A single port can be specified using an integer:

    >>> parse_ports(8787)
    [8787]

    or a string:

    >>> parse_ports("8787")
    [8787]

    A sequential range of ports can be specified by a string which indicates
    the first and last ports which should be included in the sequence of ports:

    >>> parse_ports("8787:8790")
    [8787, 8788, 8789, 8790]

    An input of ``None`` is also valid and can be used to indicate that no port
    has been specified:

    >>> parse_ports(None)
    [None]

    """
    if isinstance(port, str) and ":" not in port:
        port = int(port)

    if isinstance(port, (int, type(None))):
        ports = [port]
    else:
        port_start, port_stop = map(int, port.split(":"))
        if port_stop <= port_start:
            raise ValueError(
                "When specifying a range of ports like port_start:port_stop, "
                "port_stop must be greater than port_start, but got "
                f"port_start={port_start} and port_stop={port_stop}"
            )
        ports = list(range(port_start, port_stop + 1))

    return ports


is_coroutine_function = iscoroutinefunction


class Log(str):
    """A container for a newline-delimited string of log entries"""

    def _repr_html_(self):
        return get_template("log.html.j2").render(log=self)


class Logs(dict):
    """A container for a dict mapping names to strings of log entries"""

    def _repr_html_(self):
        return get_template("logs.html.j2").render(logs=self)


def cli_keywords(d: dict, cls=None, cmd=None):
    """Convert a kwargs dictionary into a list of CLI keywords

    Parameters
    ----------
    d : dict
        The keywords to convert
    cls : callable
        The callable that consumes these terms to check them for validity
    cmd : string or object
        A string with the name of a module, or the module containing a
        click-generated command with a "main" function, or the function itself.
        It may be used to parse a module's custom arguments (i.e., arguments that
        are not part of the Worker class), such as nprocs from dask-worker CLI or
        enable_nvlink from dask-cuda-worker CLI.

    Examples
    --------
    >>> cli_keywords({"x": 123, "save_file": "foo.txt"})
    ['--x', '123', '--save-file', 'foo.txt']

    >>> from dask.distributed import Worker
    >>> cli_keywords({"x": 123}, Worker)
    Traceback (most recent call last):
    ...
    ValueError: Class distributed.worker.Worker does not support keyword x
    """
    from dask.utils import typename

    if cls or cmd:
        for k in d:
            if not has_keyword(cls, k) and not command_has_keyword(cmd, k):
                if cls and cmd:
                    raise ValueError(
                        "Neither class %s nor module %s supports keyword %s"
                        % (typename(cls), typename(cmd), k)
                    )
                elif cls:
                    raise ValueError(
                        f"Class {typename(cls)} does not support keyword {k}"
                    )
                else:
                    raise ValueError(
                        f"Module {typename(cmd)} does not support keyword {k}"
                    )

    def convert_value(v):
        out = str(v)
        if " " in out and "'" not in out and '"' not in out:
            out = '"' + out + '"'
        return out

    return sum(
        (["--" + k.replace("_", "-"), convert_value(v)] for k, v in d.items()), []
    )


def is_valid_xml(text):
    return xml.etree.ElementTree.fromstring(text) is not None


_offload_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Dask-Offload")
weakref.finalize(_offload_executor, _offload_executor.shutdown)


def import_term(name: str):
    """Return the fully qualified term

    Examples
    --------
    >>> import_term("math.sin") # doctest: +SKIP
    <function math.sin(x, /)>
    """
    try:
        module_name, attr_name = name.rsplit(".", 1)
    except ValueError:
        return importlib.import_module(name)

    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


async def offload(fn, *args, **kwargs):
    loop = asyncio.get_event_loop()
    # Retain context vars while deserializing; see https://bugs.python.org/issue34014
    context = contextvars.copy_context()
    return await loop.run_in_executor(
        _offload_executor, lambda: context.run(fn, *args, **kwargs)
    )
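
# Usage sketch: run a blocking function off the event loop on the single
# shared "Dask-Offload" thread (``blocking_parse`` is a made-up name):
#
#     async def handler(raw):
#         parsed = await offload(blocking_parse, raw)
#         return parsed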


class EmptyContext:
    def __enter__(self):
        pass

    def __exit__(self, *args):
        pass

    async def __aenter__(self):
        pass

    async def __aexit__(self, *args):
        pass


empty_context = EmptyContext()


class LRU(UserDict):
    """Limited size mapping, evicting the least recently looked-up key when full"""

    def __init__(self, maxsize):
        super().__init__()
        self.data = OrderedDict()
        self.maxsize = maxsize

    def __getitem__(self, key):
        value = super().__getitem__(key)
        self.data.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        if len(self) >= self.maxsize:
            self.data.popitem(last=False)
        super().__setitem__(key, value)
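
# Eviction behaviour in brief:
#
#     cache = LRU(maxsize=2)
#     cache["a"] = 1
#     cache["b"] = 2
#     cache["a"]            # touch "a" so "b" becomes least recently used
#     cache["c"] = 3        # evicts "b"
#     list(cache)           # ['a', 'c']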


def clean_dashboard_address(addrs: AnyType, default_listen_ip: str = "") -> list[dict]:
    """
    Examples
    --------
    >>> clean_dashboard_address(8787)
    [{'address': '', 'port': 8787}]
    >>> clean_dashboard_address(":8787")
    [{'address': '', 'port': 8787}]
    >>> clean_dashboard_address("8787")
    [{'address': '', 'port': 8787}]
    >>> clean_dashboard_address("foo:8787")
    [{'address': 'foo', 'port': 8787}]
    >>> clean_dashboard_address([8787, 8887])
    [{'address': '', 'port': 8787}, {'address': '', 'port': 8887}]
    >>> clean_dashboard_address(":8787,:8887")
    [{'address': '', 'port': 8787}, {'address': '', 'port': 8887}]
    """

    if default_listen_ip == "0.0.0.0":
        default_listen_ip = ""  # for IPv6

    if isinstance(addrs, str):
        addrs = addrs.split(",")
    if not isinstance(addrs, list):
        addrs = [addrs]

    addresses = []
    for addr in addrs:
        try:
            addr = int(addr)
        except (TypeError, ValueError):
            pass

        if isinstance(addr, str):
            addr = addr.split(":")

        if isinstance(addr, (tuple, list)):
            if len(addr) == 2:
                host, port = (addr[0], int(addr[1]))
            elif len(addr) == 1:
                [host], port = addr, 0
            else:
                raise ValueError(addr)
        elif isinstance(addr, int):
            host = default_listen_ip
            port = addr

        addresses.append({"address": host, "port": port})
    return addresses


_deprecations = {
    "deserialize_for_cli": "dask.config.deserialize",
    "serialize_for_cli": "dask.config.serialize",
    "format_bytes": "dask.utils.format_bytes",
    "format_time": "dask.utils.format_time",
    "funcname": "dask.utils.funcname",
    "parse_bytes": "dask.utils.parse_bytes",
    "parse_timedelta": "dask.utils.parse_timedelta",
    "typename": "dask.utils.typename",
    "tmpfile": "dask.utils.tmpfile",
}


def __getattr__(name):
    if name in _deprecations:
        use_instead = _deprecations[name]

        warnings.warn(
            f"{name} is deprecated and will be removed in a future release. "
            f"Please use {use_instead} instead.",
            category=FutureWarning,
            stacklevel=2,
        )
        return import_term(use_instead)
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")


if TYPE_CHECKING:

    class SupportsToDict(Protocol):
        def _to_dict(
            self, *, exclude: Container[str] | None = None, **kwargs
        ) -> dict[str, AnyType]:
            ...


@overload
def recursive_to_dict(
    obj: SupportsToDict,
    exclude: Container[str] | None = None,
    seen: set[AnyType] | None = None,
) -> dict[str, AnyType]:
    ...


@overload
def recursive_to_dict(
    obj: Sequence,
    exclude: Container[str] | None = None,
    seen: set[AnyType] | None = None,
) -> Sequence:
    ...


@overload
def recursive_to_dict(
    obj: dict,
    exclude: Container[str] | None = None,
    seen: set[AnyType] | None = None,
) -> dict:
    ...


@overload
def recursive_to_dict(
    obj: None,
    exclude: Container[str] | None = None,
    seen: set[AnyType] | None = None,
) -> None:
    ...


def recursive_to_dict(obj, exclude=None, seen=None):
    """
    This is for debugging purposes only and calls ``_to_dict`` methods on ``obj``
    or its elements recursively, if available. The output of this function is
    intended to be JSON serializable.

    Parameters
    ----------
    exclude:
        A list of attribute names to be excluded from the dump.
        This will be forwarded to the objects' ``_to_dict`` methods and these
        methods are required to honour it.
    seen:
        Used internally to avoid infinite recursion. If an object has already
        been encountered, its representation will be generated instead of its
        ``_to_dict``. This is necessary since we have multiple cyclic referencing
        data structures.
    """
    if obj is None:
        return None
    if isinstance(obj, str):
        return obj
    if seen is None:
        seen = set()
    if id(obj) in seen:
        return repr(obj)
    seen.add(id(obj))
    if isinstance(obj, type):
        return repr(obj)
    if hasattr(obj, "_to_dict"):
        return obj._to_dict(exclude=exclude)
    if isinstance(obj, (deque, set)):
        obj = tuple(obj)
    if isinstance(obj, (list, tuple)):
        return tuple(
            recursive_to_dict(
                el,
                exclude=exclude,
                seen=seen,
            )
            for el in obj
        )
    elif isinstance(obj, dict):
        res = {}
        for k, v in obj.items():
            k = recursive_to_dict(k, exclude=exclude, seen=seen)
            try:
                hash(k)
            except TypeError:
                k = str(k)
            v = recursive_to_dict(v, exclude=exclude, seen=seen)
            res[k] = v
        return res
    else:
        return repr(obj)
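
# A short sketch of recursive_to_dict (``Point`` is a made-up class; cycles
# inside containers fall back to repr()):
#
#     class Point:
#         def __init__(self, x, y):
#             self.x, self.y = x, y
#
#         def _to_dict(self, *, exclude=None, **kwargs):
#             return {"x": self.x, "y": self.y}
#
#     recursive_to_dict(Point(1, 2))   # {'x': 1, 'y': 2}
#
#     a = []
#     a.append(a)                      # self-referencing list
#     recursive_to_dict(a)             # ('[[...]]',) — the cycle becomes a repr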