from __future__ import annotations

import asyncio
import contextvars
import functools
import importlib
import inspect
import json
import logging
import multiprocessing
import os
import pkgutil
import re
import socket
import sys
import tempfile
import threading
import warnings
import weakref
import xml.etree.ElementTree
from asyncio import TimeoutError
from collections import OrderedDict, UserDict, deque
from concurrent.futures import CancelledError, ThreadPoolExecutor  # noqa: F401
from contextlib import contextmanager, suppress
from hashlib import md5
from importlib.util import cache_from_source
from time import sleep
from typing import TYPE_CHECKING
from typing import Any as AnyType
from typing import ClassVar, Container, Sequence, overload

import click
import tblib.pickling_support

if TYPE_CHECKING:
    from typing_extensions import Protocol

try:
    import resource
except ImportError:
    resource = None  # type: ignore

import tlz as toolz
from tornado import gen
from tornado.ioloop import IOLoop

import dask
from dask import istask
from dask.utils import parse_timedelta as _parse_timedelta
from dask.widgets import get_template

try:
    from tornado.ioloop import PollIOLoop  # type: ignore
except ImportError:
    PollIOLoop = None  # dropped in tornado 6.0

from .compatibility import PYPY, WINDOWS
from .metrics import time

try:
    from dask.context import thread_state
except ImportError:
    thread_state = threading.local()

# For some reason this is required in python >= 3.9
if WINDOWS:
    import multiprocessing.popen_spawn_win32
else:
    import multiprocessing.popen_spawn_posix

logger = _logger = logging.getLogger(__name__)


no_default = "__no_default__"


def _initialize_mp_context():
    if WINDOWS or PYPY:
        return multiprocessing
    else:
        method = dask.config.get("distributed.worker.multiprocessing-method")
        ctx = multiprocessing.get_context(method)
        # Makes the test suite much faster
        preload = ["distributed"]
        if "pkg_resources" in sys.modules:
            preload.append("pkg_resources")

        from .versions import optional_packages, required_packages

        for pkg, _ in required_packages + optional_packages:
            try:
                importlib.import_module(pkg)
            except ImportError:
                pass
            else:
                preload.append(pkg)
        ctx.set_forkserver_preload(preload)
        return ctx


mp_context = _initialize_mp_context()


def has_arg(func, argname):
    """
    Whether the function takes an argument with the given name.
    """
    while True:
        try:
            if argname in inspect.getfullargspec(func).args:
                return True
        except TypeError:
            break
        try:
            # For Tornado coroutines and other decorated functions
            func = func.__wrapped__
        except AttributeError:
            break
    return False


def get_fileno_limit():
    """
    Get the maximum number of open files per process.
    """
    if resource is not None:
        return resource.getrlimit(resource.RLIMIT_NOFILE)[0]
    else:
        # Default ceiling for Windows when using the CRT, though it
        # is settable using _setmaxstdio().
        return 512
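
# Usage sketch for ``has_arg`` (kept as a comment so that importing this
# module stays side-effect free; the function ``f`` is illustrative):
#
#     def f(x, y=1):
#         pass
#
#     has_arg(f, "y")        # True
#     has_arg(f, "missing")  # False
#
# Decorated callables are handled as well: has_arg follows ``__wrapped__``
# chains, so functools.wraps-style wrappers report the arguments of the
# function they wrap.
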
@toolz.memoize
def _get_ip(host, port, family):
    # By using a UDP socket, we don't actually try to connect but
    # simply select the local address through which *host* is reachable.
    sock = socket.socket(family, socket.SOCK_DGRAM)
    try:
        sock.connect((host, port))
        ip = sock.getsockname()[0]
        return ip
    except OSError as e:
        warnings.warn(
            "Couldn't detect a suitable IP address for "
            "reaching %r, defaulting to hostname: %s" % (host, e),
            RuntimeWarning,
        )
        addr_info = socket.getaddrinfo(
            socket.gethostname(), port, family, socket.SOCK_DGRAM, socket.IPPROTO_UDP
        )[0]
        return addr_info[4][0]
    finally:
        sock.close()


def get_ip(host="8.8.8.8", port=80):
    """
    Get the local IP address through which the *host* is reachable.

    *host* defaults to a well-known Internet host (one of Google's public
    DNS servers).
    """
    return _get_ip(host, port, family=socket.AF_INET)


def get_ipv6(host="2001:4860:4860::8888", port=80):
    """
    The same as get_ip(), but for IPv6.
    """
    return _get_ip(host, port, family=socket.AF_INET6)


def get_ip_interface(ifname):
    """
    Get the local IPv4 address of a network interface.

    ValueError is raised if the interface doesn't exist or if it does not
    have an IPv4 address associated with it.
    """
    import psutil

    net_if_addrs = psutil.net_if_addrs()

    if ifname not in net_if_addrs:
        allowed_ifnames = list(net_if_addrs.keys())
        raise ValueError(
            "{!r} is not a valid network interface. "
            "Valid network interfaces are: {}".format(ifname, allowed_ifnames)
        )

    for info in net_if_addrs[ifname]:
        if info.family == socket.AF_INET:
            return info.address
    raise ValueError(f"interface {ifname!r} doesn't have an IPv4 address")
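
# Usage sketch (results depend on the local network configuration, so this
# stays a comment; the interface name is illustrative):
#
#     >>> get_ip()                  # doctest: +SKIP
#     '192.168.1.42'
#     >>> get_ip_interface("eth0")  # doctest: +SKIP
#     '192.168.1.42'
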
async def All(args, quiet_exceptions=()):
    """Wait on many tasks at the same time

    Err once any of the tasks err.

    See https://github.com/tornadoweb/tornado/issues/1546

    Parameters
    ----------
    args: futures to wait for
    quiet_exceptions: tuple, Exception
        Exception types to avoid logging if they fail
    """
    tasks = gen.WaitIterator(*map(asyncio.ensure_future, args))
    results = [None for _ in args]
    while not tasks.done():
        try:
            result = await tasks.next()
        except Exception:

            @gen.coroutine
            def quiet():
                """Watch unfinished tasks

                Otherwise if they err they get logged in a way that is hard to
                control.  They need some other task to watch them so that they
                are not orphaned.
                """
                for task in list(tasks._unfinished):
                    try:
                        yield task
                    except quiet_exceptions:
                        pass

            quiet()
            raise

        results[tasks.current_index] = result
    return results


async def Any(args, quiet_exceptions=()):
    """Wait on many tasks at the same time and return when any is finished

    Err once any of the tasks err.

    Parameters
    ----------
    args: futures to wait for
    quiet_exceptions: tuple, Exception
        Exception types to avoid logging if they fail
    """
    tasks = gen.WaitIterator(*map(asyncio.ensure_future, args))
    results = [None for _ in args]
    while not tasks.done():
        try:
            result = await tasks.next()
        except Exception:

            @gen.coroutine
            def quiet():
                """Watch unfinished tasks

                Otherwise if they err they get logged in a way that is hard to
                control.  They need some other task to watch them so that they
                are not orphaned.
                """
                for task in list(tasks._unfinished):
                    try:
                        yield task
                    except quiet_exceptions:
                        pass

            quiet()
            raise

        results[tasks.current_index] = result
        break
    return results
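
# Usage sketch (illustrative; ``fetch`` stands in for any coroutine, and the
# calls must themselves run inside a coroutine):
#
#     async def fetch(i):
#         await asyncio.sleep(i)
#         return i
#
#     results = await All([fetch(1), fetch(2)])  # [1, 2], like asyncio.gather
#     partial = await Any([fetch(1), fetch(2)])  # returns once one finishes;
#     # slots for unfinished tasks remain None in the returned list.
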
def sync(loop, func, *args, callback_timeout=None, **kwargs):
    """
    Run a coroutine function in an event loop running in a separate thread,
    blocking the calling thread until it completes.
    """
    callback_timeout = _parse_timedelta(callback_timeout, "s")
    # Tornado's PollIOLoop doesn't raise when using closed, do it ourselves
    if PollIOLoop and (
        (isinstance(loop, PollIOLoop) and getattr(loop, "_closing", False))
        or (hasattr(loop, "asyncio_loop") and loop.asyncio_loop._closed)
    ):
        raise RuntimeError("IOLoop is closed")
    try:
        if loop.asyncio_loop.is_closed():  # tornado 6
            raise RuntimeError("IOLoop is closed")
    except AttributeError:
        pass

    e = threading.Event()
    main_tid = threading.get_ident()
    result = [None]
    error = [False]

    @gen.coroutine
    def f():
        # We flag the thread state asynchronous, which will make sync() calls
        # within `func` use async semantics. In order to support concurrent
        # calls to sync(), `asynchronous` is used as a ref counter.
        thread_state.asynchronous = getattr(thread_state, "asynchronous", 0)
        thread_state.asynchronous += 1
        try:
            if main_tid == threading.get_ident():
                raise RuntimeError("sync() called from thread of running loop")
            yield gen.moment
            future = func(*args, **kwargs)
            if callback_timeout is not None:
                future = asyncio.wait_for(future, callback_timeout)
            result[0] = yield future
        except Exception:
            error[0] = sys.exc_info()
        finally:
            assert thread_state.asynchronous > 0
            thread_state.asynchronous -= 1
            e.set()

    loop.add_callback(f)
    if callback_timeout is not None:
        if not e.wait(callback_timeout):
            raise TimeoutError(f"timed out after {callback_timeout} s.")
    else:
        while not e.is_set():
            e.wait(10)
    if error[0]:
        typ, exc, tb = error[0]
        raise exc.with_traceback(tb)
    else:
        return result[0]
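
# Usage sketch (illustrative; assumes ``loop`` is a tornado IOLoop already
# running in another thread, e.g. one managed by the LoopRunner below):
#
#     async def add(x, y):
#         return x + y
#
#     sync(loop, add, 1, 2)                         # 3
#     sync(loop, add, 1, 2, callback_timeout="5s")  # same, but bounded
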
class LoopRunner:
    """
    A helper to start and stop an IO loop in a controlled way.
    Several loop runners can safely be associated with the same IO loop.

    Parameters
    ----------
    loop: IOLoop (optional)
        If given, this loop will be re-used, otherwise an appropriate one
        will be looked up or created.
    asynchronous: boolean (optional, default False)
        If false (the default), the loop is meant to run in a separate
        thread and will be started if necessary.
        If true, the loop is meant to run in the thread this
        object is instantiated from, and will not be started automatically.
    """

    # All loops currently associated to loop runners
    _all_loops: ClassVar[
        weakref.WeakKeyDictionary[IOLoop, tuple[int, LoopRunner | None]]
    ] = weakref.WeakKeyDictionary()
    _lock = threading.Lock()

    def __init__(self, loop=None, asynchronous=False):
        current = IOLoop.current()
        if loop is None:
            if asynchronous:
                self._loop = current
            else:
                # We're expecting the loop to run in another thread,
                # avoid re-using this thread's assigned loop
                self._loop = IOLoop()
        else:
            self._loop = loop
        self._asynchronous = asynchronous
        self._loop_thread = None
        self._started = False
        with self._lock:
            self._all_loops.setdefault(self._loop, (0, None))

    def start(self):
        """
        Start the IO loop if required.  The loop is run in a dedicated
        thread.

        If the loop is already running, this method does nothing.
        """
        with self._lock:
            self._start_unlocked()

    def _start_unlocked(self):
        assert not self._started

        count, real_runner = self._all_loops[self._loop]
        if self._asynchronous or real_runner is not None or count > 0:
            self._all_loops[self._loop] = count + 1, real_runner
            self._started = True
            return

        assert self._loop_thread is None
        assert count == 0

        loop_evt = threading.Event()
        done_evt = threading.Event()
        in_thread = [None]
        start_exc = [None]

        def loop_cb():
            in_thread[0] = threading.current_thread()
            loop_evt.set()

        def run_loop(loop=self._loop):
            loop.add_callback(loop_cb)
            # run loop forever if it's not running already
            try:
                if (
                    getattr(loop, "asyncio_loop", None) is None
                    or not loop.asyncio_loop.is_running()
                ):
                    loop.start()
            except Exception as e:
                start_exc[0] = e
            finally:
                done_evt.set()

        thread = threading.Thread(target=run_loop, name="IO loop")
        thread.daemon = True
        thread.start()

        loop_evt.wait(timeout=10)
        self._started = True

        actual_thread = in_thread[0]
        if actual_thread is not thread:
            # Loop already running in other thread (user-launched)
            done_evt.wait(5)
            if start_exc[0] is not None and not isinstance(start_exc[0], RuntimeError):
                if not isinstance(
                    start_exc[0], Exception
                ):  # track down infrequent error
                    raise TypeError(
                        f"not an exception: {start_exc[0]!r}",
                    )
                raise start_exc[0]
            self._all_loops[self._loop] = count + 1, None
        else:
            assert start_exc[0] is None, start_exc
            self._loop_thread = thread
            self._all_loops[self._loop] = count + 1, self

    def stop(self, timeout=10):
        """
        Stop and close the loop if it was created by us.
        Otherwise, just mark this object "stopped".
        """
        with self._lock:
            self._stop_unlocked(timeout)

    def _stop_unlocked(self, timeout):
        if not self._started:
            return

        self._started = False

        count, real_runner = self._all_loops[self._loop]
        if count > 1:
            self._all_loops[self._loop] = count - 1, real_runner
        else:
            assert count == 1
            del self._all_loops[self._loop]
            if real_runner is not None:
                real_runner._real_stop(timeout)

    def _real_stop(self, timeout):
        assert self._loop_thread is not None
        if self._loop_thread is not None:
            try:
                self._loop.add_callback(self._loop.stop)
                self._loop_thread.join(timeout=timeout)
                with suppress(KeyError):  # IOLoop can be missing
                    self._loop.close()
            finally:
                self._loop_thread = None

    def is_started(self):
        """
        Return True between start() and stop() calls, False otherwise.
        """
        return self._started

    def run_sync(self, func, *args, **kwargs):
        """
        Convenience helper: start the loop if needed,
        run sync(func, *args, **kwargs), then stop the loop again.
        """
        if self._started:
            return sync(self.loop, func, *args, **kwargs)
        else:
            self.start()
            try:
                return sync(self.loop, func, *args, **kwargs)
            finally:
                self.stop()

    @property
    def loop(self):
        return self._loop
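
# Usage sketch (illustrative; ``some_coroutine_function`` is hypothetical):
#
#     runner = LoopRunner()   # creates and manages a fresh IOLoop
#     runner.start()          # spins the loop up in a daemon thread
#     try:
#         runner.run_sync(some_coroutine_function)
#     finally:
#         runner.stop()       # stops and closes the loop it created
#
# ``run_sync`` alone also works: if the runner is not started yet, it
# starts and stops the loop around the single call.
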
@contextmanager
def set_thread_state(**kwargs):
    old = {}
    for k in kwargs:
        try:
            old[k] = getattr(thread_state, k)
        except AttributeError:
            pass
    for k, v in kwargs.items():
        setattr(thread_state, k, v)
    try:
        yield
    finally:
        for k in kwargs:
            try:
                v = old[k]
            except KeyError:
                delattr(thread_state, k)
            else:
                setattr(thread_state, k, v)


@contextmanager
def tmp_text(filename, text):
    fn = os.path.join(tempfile.gettempdir(), filename)
    with open(fn, "w") as f:
        f.write(text)

    try:
        yield fn
    finally:
        if os.path.exists(fn):
            os.remove(fn)


def is_kernel():
    """Determine if we're running within an IPython kernel

    >>> is_kernel()
    False
    """
    # http://stackoverflow.com/questions/34091701/determine-if-were-in-an-ipython-notebook-session
    if "IPython" not in sys.modules:  # IPython hasn't been imported
        return False
    from IPython import get_ipython

    # check for `kernel` attribute on the IPython instance
    return getattr(get_ipython(), "kernel", None) is not None


hex_pattern = re.compile("[a-f]+")


@functools.lru_cache(100000)
def key_split(s):
    """
    >>> key_split('x')
    'x'
    >>> key_split('x-1')
    'x'
    >>> key_split('x-1-2-3')
    'x'
    >>> key_split(('x-2', 1))
    'x'
    >>> key_split("('x-2', 1)")
    'x'
    >>> key_split("('x', 1)")
    'x'
    >>> key_split('hello-world-1')
    'hello-world'
    >>> key_split(b'hello-world-1')
    'hello-world'
    >>> key_split('ae05086432ca935f6eba409a8ecd4896')
    'data'
    >>> key_split('<module.submodule.myclass object at 0xdaf372')
    'myclass'
    >>> key_split(None)
    'Other'
    >>> key_split('x-abcdefab')  # ignores hex
    'x'
    """
    if type(s) is bytes:
        s = s.decode()
    if type(s) is tuple:
        s = s[0]
    try:
        words = s.split("-")
        if not words[0][0].isalpha():
            result = words[0].split(",")[0].strip("'(\"")
        else:
            result = words[0]
        for word in words[1:]:
            if word.isalpha() and not (
                len(word) == 8 and hex_pattern.match(word) is not None
            ):
                result += "-" + word
            else:
                break
        if len(result) == 32 and re.match(r"[a-f0-9]{32}", result):
            return "data"
        else:
            if result[0] == "<":
                result = result.strip("<>").split()[0].split(".")[-1]
            return result
    except Exception:
        return "Other"


def key_split_group(x) -> str:
    """A more fine-grained version of key_split

    >>> key_split_group(('x-2', 1))
    'x-2'
    >>> key_split_group("('x-2', 1)")
    'x-2'
    >>> key_split_group('ae05086432ca935f6eba409a8ecd4896')
    'data'
    >>> key_split_group('<module.submodule.myclass object at 0xdaf372')
    'myclass'
    >>> key_split_group('x')
    'x'
    >>> key_split_group('x-1')
    'x'
    """
    typ = type(x)
    if typ is tuple:
        return x[0]
    elif typ is str:
        if x[0] == "(":
            return x.split(",", 1)[0].strip("()\"'")
        elif len(x) == 32 and re.match(r"[a-f0-9]{32}", x):
            return "data"
        elif x[0] == "<":
            return x.strip("<>").split()[0].split(".")[-1]
        else:
            return key_split(x)
    elif typ is bytes:
        return key_split_group(x.decode())
    else:
        return "Other"


@contextmanager
def log_errors(pdb=False):
    from .comm import CommClosedError

    try:
        yield
    except (CommClosedError, gen.Return):
        raise
    except Exception as e:
        try:
            logger.exception(e)
        except TypeError:  # logger becomes None during process cleanup
            pass
        if pdb:
            import pdb

            pdb.set_trace()
        raise


def silence_logging(level, root="distributed"):
    """
    Change all StreamHandlers for the given logger to the given level
    """
    if isinstance(level, str):
        level = getattr(logging, level.upper())

    old = None
    logger = logging.getLogger(root)
    for handler in logger.handlers:
        if isinstance(handler, logging.StreamHandler):
            old = handler.level
            handler.setLevel(level)

    return old
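
# Usage sketch (illustrative; ``do_something_risky`` is hypothetical):
#
#     with log_errors():  # log, then re-raise, unexpected exceptions
#         do_something_risky()
#
#     old_level = silence_logging("critical")  # quiet distributed's handlers
#     ...
#     silence_logging(old_level)               # restore the previous level
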
@toolz.memoize
def ensure_ip(hostname):
    """Ensure that address is an IP address

    Examples
    --------
    >>> ensure_ip('localhost')
    '127.0.0.1'
    >>> ensure_ip('')  # Maps as localhost for binding e.g. 'tcp://:8811'
    '127.0.0.1'
    >>> ensure_ip('123.123.123.123')  # pass through IP addresses
    '123.123.123.123'
    """
    if not hostname:
        hostname = "localhost"

    # Prefer IPv4 over IPv6, for compatibility
    families = [socket.AF_INET, socket.AF_INET6]
    for fam in families:
        try:
            results = socket.getaddrinfo(
                hostname, 1234, fam, socket.SOCK_STREAM  # dummy port number
            )
        except socket.gaierror as e:
            exc = e
        else:
            return results[0][4][0]

    raise exc


tblib.pickling_support.install()


def get_traceback():
    exc_type, exc_value, exc_traceback = sys.exc_info()
    bad = [
        os.path.join("distributed", "worker"),
        os.path.join("distributed", "scheduler"),
        os.path.join("tornado", "gen.py"),
        os.path.join("concurrent", "futures"),
    ]
    while exc_traceback and any(
        b in exc_traceback.tb_frame.f_code.co_filename for b in bad
    ):
        exc_traceback = exc_traceback.tb_next
    return exc_traceback


def truncate_exception(e, n=10000):
    """Truncate exception to be about a certain length"""
    if len(str(e)) > n:
        try:
            return type(e)("Long error message", str(e)[:n])
        except Exception:
            return Exception("Long error message", type(e), str(e)[:n])
    else:
        return e


def validate_key(k):
    """Validate a key as received on a stream."""
    typ = type(k)
    if typ is not str and typ is not bytes:
        raise TypeError(f"Unexpected key type {typ} (value: {k!r})")


def _maybe_complex(task):
    """Possibly contains a nested task"""
    return (
        istask(task)
        or type(task) is list
        and any(map(_maybe_complex, task))
        or type(task) is dict
        and any(map(_maybe_complex, task.values()))
    )


def seek_delimiter(file, delimiter, blocksize):
    """Seek current file to next byte after a delimiter bytestring

    This seeks the file to the next byte following the delimiter.  It does
    not return anything.  Use ``file.tell()`` to see location afterwards.

    Parameters
    ----------
    file: a file
    delimiter: bytes
        a delimiter like ``b'\\n'`` or message sentinel
    blocksize: int
        Number of bytes to read from the file at once.
    """

    if file.tell() == 0:
        return

    last = b""
    while True:
        current = file.read(blocksize)
        if not current:
            return
        full = last + current
        try:
            i = full.index(delimiter)
            file.seek(file.tell() - (len(full) - i) + len(delimiter))
            return
        except ValueError:
            pass
        last = full[-len(delimiter) :]
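
# Usage sketch for ``seek_delimiter`` (deterministic, shown as a comment):
#
#     from io import BytesIO
#
#     f = BytesIO(b"Alice 100\nBob 200\nCharlie 300")
#     f.seek(3)                          # somewhere inside the first record
#     seek_delimiter(f, b"\n", 2 ** 16)
#     f.tell()                           # 10, the byte right after b"\n"
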
def read_block(f, offset, length, delimiter=None):
    """Read a block of bytes from a file

    Parameters
    ----------
    f: file
        File-like object supporting seek, read, tell, etc..
    offset: int
        Byte offset to start read
    length: int
        Number of bytes to read
    delimiter: bytes (optional)
        Ensure reading starts and stops at delimiter bytestring

    If using the ``delimiter=`` keyword argument we ensure that the read
    starts and stops at delimiter boundaries that follow the locations
    ``offset`` and ``offset + length``.  If ``offset`` is zero then we
    start at zero.  The bytestring returned WILL include the
    terminating delimiter string.

    Examples
    --------

    >>> from io import BytesIO  # doctest: +SKIP
    >>> f = BytesIO(b'Alice, 100\\nBob, 200\\nCharlie, 300')  # doctest: +SKIP
    >>> read_block(f, 0, 13)  # doctest: +SKIP
    b'Alice, 100\\nBo'

    >>> read_block(f, 0, 13, delimiter=b'\\n')  # doctest: +SKIP
    b'Alice, 100\\nBob, 200\\n'

    >>> read_block(f, 10, 10, delimiter=b'\\n')  # doctest: +SKIP
    b'Bob, 200\\nCharlie, 300'
    """
    if delimiter:
        f.seek(offset)
        seek_delimiter(f, delimiter, 2 ** 16)
        start = f.tell()
        length -= start - offset

        f.seek(start + length)
        seek_delimiter(f, delimiter, 2 ** 16)
        end = f.tell()

        offset = start
        length = end - start

    f.seek(offset)
    bytes = f.read(length)
    return bytes


def ensure_bytes(s):
    """Attempt to turn `s` into bytes.

    Parameters
    ----------
    s : Any
        The object to be converted. Will correctly handle

        * str
        * bytes
        * objects implementing the buffer protocol (memoryview, ndarray, etc.)

    Returns
    -------
    b : bytes

    Raises
    ------
    TypeError
        When `s` cannot be converted

    Examples
    --------
    >>> ensure_bytes('123')
    b'123'
    >>> ensure_bytes(b'123')
    b'123'
    """
    if isinstance(s, bytes):
        return s
    elif hasattr(s, "encode"):
        return s.encode()
    else:
        try:
            return bytes(s)
        except Exception as e:
            raise TypeError(
                "Object %s is neither a bytes object nor has an encode method" % s
            ) from e


def open_port(host=""):
    """Return a probably-open port

    There is a chance that this port will be taken by the operating system soon
    after returning from this function.
    """
    # http://stackoverflow.com/questions/2838244/get-open-tcp-port-in-python
    s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    s.bind((host, 0))
    s.listen(1)
    port = s.getsockname()[1]
    s.close()
    return port
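
# Usage sketch (the port returned varies, so this stays a comment):
#
#     >>> open_port()  # doctest: +SKIP
#     54321
#
# Note the inherent race: another process may claim the port between this
# call and a subsequent bind, so callers should be prepared to retry.
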
def import_file(path):
    """Loads modules for a file (.py, .zip, .egg)"""
    directory, filename = os.path.split(path)
    name, ext = os.path.splitext(filename)
    names_to_import = []
    tmp_python_path = None

    if ext in (".py",):  # , '.pyc'):
        if directory not in sys.path:
            tmp_python_path = directory
        names_to_import.append(name)
        if ext == ".py":  # Ensure that no pyc file will be reused
            cache_file = cache_from_source(path)
            with suppress(OSError):
                os.remove(cache_file)
    if ext in (".egg", ".zip", ".pyz"):
        if path not in sys.path:
            sys.path.insert(0, path)
        names = (mod_info.name for mod_info in pkgutil.iter_modules([path]))
        names_to_import.extend(names)

    loaded = []
    if not names_to_import:
        logger.warning("Found nothing to import from %s", filename)
    else:
        importlib.invalidate_caches()
        if tmp_python_path is not None:
            sys.path.insert(0, tmp_python_path)
        try:
            for name in names_to_import:
                logger.info("Reload module %s from %s file", name, ext)
                loaded.append(importlib.reload(importlib.import_module(name)))
        finally:
            if tmp_python_path is not None:
                sys.path.remove(tmp_python_path)
    return loaded


def asciitable(columns, rows):
    """Formats an ascii table for given columns and rows.

    Parameters
    ----------
    columns : list
        The column names
    rows : list of tuples
        The rows in the table. Each tuple must be the same length as
        ``columns``.
    """
    rows = [tuple(str(i) for i in r) for r in rows]
    columns = tuple(str(i) for i in columns)
    widths = tuple(max(max(map(len, x)), len(c)) for x, c in zip(zip(*rows), columns))
    row_template = ("|" + (" %%-%ds |" * len(columns))) % widths
    header = row_template % tuple(columns)
    bar = "+%s+" % "+".join("-" * (w + 2) for w in widths)
    data = "\n".join(row_template % r for r in rows)
    return "\n".join([bar, header, bar, data, bar])
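
# Example output for ``asciitable`` (deterministic, shown as a comment):
#
#     >>> print(asciitable(["name", "count"], [("Alice", 1), ("Bob", 20)]))
#     +-------+-------+
#     | name  | count |
#     +-------+-------+
#     | Alice | 1     |
#     | Bob   | 20    |
#     +-------+-------+
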
def nbytes(frame, _bytes_like=(bytes, bytearray)):
    """Number of bytes of a frame or memoryview"""
    if isinstance(frame, _bytes_like):
        return len(frame)
    else:
        try:
            return frame.nbytes
        except AttributeError:
            return len(frame)


def json_load_robust(fn, load=json.load):
    """Reads a JSON file from disk that may still be being written as we read"""
    while not os.path.exists(fn):
        sleep(0.01)
    for i in range(10):
        try:
            with open(fn) as f:
                cfg = load(f)
            if cfg:
                return cfg
        except (ValueError, KeyError):  # race with writing process
            pass
        sleep(0.1)


class DequeHandler(logging.Handler):
    """A logging.Handler that stores log records in a deque"""

    _instances: ClassVar[weakref.WeakSet[DequeHandler]] = weakref.WeakSet()

    def __init__(self, *args, n=10000, **kwargs):
        self.deque = deque(maxlen=n)
        super().__init__(*args, **kwargs)
        self._instances.add(self)

    def emit(self, record):
        self.deque.append(record)

    def clear(self):
        """
        Clear internal storage.
        """
        self.deque.clear()

    @classmethod
    def clear_all_instances(cls):
        """
        Clear the internal storage of all live DequeHandlers.
        """
        for inst in list(cls._instances):
            inst.clear()


def reset_logger_locks():
    """Python 2's logger's locks don't survive a fork event

    https://github.com/dask/distributed/issues/1491
    """
    for name in logging.Logger.manager.loggerDict.keys():
        for handler in logging.getLogger(name).handlers:
            handler.createLock()
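
# Usage sketch (illustrative):
#
#     handler = DequeHandler(n=100)  # keep only the last 100 records
#     logging.getLogger("distributed").addHandler(handler)
#     ...
#     messages = [record.getMessage() for record in handler.deque]
#     DequeHandler.clear_all_instances()  # drop stored records everywhere
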
is_server_extension = False

if "notebook" in sys.modules:
    import traitlets
    from notebook.notebookapp import NotebookApp

    is_server_extension = traitlets.config.Application.initialized() and isinstance(
        traitlets.config.Application.instance(), NotebookApp
    )

if not is_server_extension:
    is_kernel_and_no_running_loop = False

    if is_kernel():
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            is_kernel_and_no_running_loop = True

    if not is_kernel_and_no_running_loop:

        # TODO: Use tornado's AnyThreadEventLoopPolicy, instead of class below,
        # once tornado > 6.0.3 is available.
        if WINDOWS:
            # WindowsProactorEventLoopPolicy is not compatible with tornado 6
            # fallback to the pre-3.8 default of Selector
            # https://github.com/tornadoweb/tornado/issues/2608
            BaseEventLoopPolicy = asyncio.WindowsSelectorEventLoopPolicy  # type: ignore
        else:
            BaseEventLoopPolicy = asyncio.DefaultEventLoopPolicy

        class AnyThreadEventLoopPolicy(BaseEventLoopPolicy):  # type: ignore
            def get_event_loop(self):
                try:
                    return super().get_event_loop()
                except (RuntimeError, AssertionError):
                    loop = self.new_event_loop()
                    self.set_event_loop(loop)
                    return loop

        asyncio.set_event_loop_policy(AnyThreadEventLoopPolicy())


@functools.lru_cache(1000)
def has_keyword(func, keyword):
    return keyword in inspect.signature(func).parameters


@functools.lru_cache(1000)
def command_has_keyword(cmd, k):
    if cmd is not None:
        if isinstance(cmd, str):
            try:
                from importlib import import_module

                cmd = import_module(cmd)
            except ImportError:
                raise ImportError("Module for command %s is not available" % cmd)

        if isinstance(getattr(cmd, "main"), click.core.Command):
            cmd = cmd.main
        if isinstance(cmd, click.core.Command):
            cmd_params = {
                p.human_readable_name
                for p in cmd.params
                if isinstance(p, click.core.Option)
            }
            return k in cmd_params

    return False


# from bokeh.palettes import viridis
# palette = viridis(18)
palette = [
    "#440154",
    "#471669",
    "#472A79",
    "#433C84",
    "#3C4D8A",
    "#355D8C",
    "#2E6C8E",
    "#287A8E",
    "#23898D",
    "#1E978A",
    "#20A585",
    "#2EB27C",
    "#45BF6F",
    "#64CB5D",
    "#88D547",
    "#AFDC2E",
    "#D7E219",
    "#FDE724",
]


@toolz.memoize
def color_of(x, palette=palette):
    h = md5(str(x).encode())
    n = int(h.hexdigest()[:8], 16)
    return palette[n % len(palette)]


def _iscoroutinefunction(f):
    # Python < 3.8 does not support determining if `partial` objects wrap async funcs
    if sys.version_info < (3, 8):
        while isinstance(f, functools.partial):
            f = f.func
    return inspect.iscoroutinefunction(f) or gen.is_coroutine_function(f)


@functools.lru_cache(None)
def _iscoroutinefunction_cached(f):
    return _iscoroutinefunction(f)


def iscoroutinefunction(f):
    # Attempt to use lru_cache version and fall back to non-cached version if needed
    try:
        return _iscoroutinefunction_cached(f)
    except TypeError:  # unhashable type
        return _iscoroutinefunction(f)


@contextmanager
def warn_on_duration(duration, msg):
    start = time()
    yield
    stop = time()
    if stop - start > _parse_timedelta(duration):
        warnings.warn(msg, stacklevel=2)
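
# Usage sketch (illustrative; ``do_work`` is hypothetical):
#
#     with warn_on_duration("1s", "This took a surprisingly long time"):
#         do_work()
#
# A warning with the given message is emitted if the block runs longer than
# the dask-style timedelta, e.g. "1s", "500ms", or a float of seconds.
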
def format_dashboard_link(host, port):
    template = dask.config.get("distributed.dashboard.link")
    if dask.config.get("distributed.scheduler.dashboard.tls.cert"):
        scheme = "https"
    else:
        scheme = "http"
    return template.format(
        **toolz.merge(os.environ, dict(scheme=scheme, host=host, port=port))
    )


def parse_ports(port):
    """Parse input port information into list of ports

    Parameters
    ----------
    port : int, str, None
        Input port or ports. Can be an integer like 8787, a string for a
        single port like "8787", a string for a sequential range of ports like
        "8000:8200", or None.

    Returns
    -------
    ports : list
        List of ports

    Examples
    --------
    A single port can be specified using an integer:

    >>> parse_ports(8787)
    [8787]

    or a string:

    >>> parse_ports("8787")
    [8787]

    A sequential range of ports can be specified by a string which indicates
    the first and last ports which should be included in the sequence of ports:

    >>> parse_ports("8787:8790")
    [8787, 8788, 8789, 8790]

    An input of ``None`` is also valid and can be used to indicate that no port
    has been specified:

    >>> parse_ports(None)
    [None]

    """
    if isinstance(port, str) and ":" not in port:
        port = int(port)

    if isinstance(port, (int, type(None))):
        ports = [port]
    else:
        port_start, port_stop = map(int, port.split(":"))
        if port_stop <= port_start:
            raise ValueError(
                "When specifying a range of ports like port_start:port_stop, "
                "port_stop must be greater than port_start, but got "
                f"port_start={port_start} and port_stop={port_stop}"
            )
        ports = list(range(port_start, port_stop + 1))

    return ports


is_coroutine_function = iscoroutinefunction


class Log(str):
    """A container for a newline-delimited string of log entries"""

    def _repr_html_(self):
        return get_template("log.html.j2").render(log=self)


class Logs(dict):
    """A container for a dict mapping names to strings of log entries"""

    def _repr_html_(self):
        return get_template("logs.html.j2").render(logs=self)


def cli_keywords(d: dict, cls=None, cmd=None):
    """Convert a kwargs dictionary into a list of CLI keywords

    Parameters
    ----------
    d : dict
        The keywords to convert
    cls : callable
        The callable that consumes these terms to check them for validity
    cmd : string or object
        A string with the name of a module, or the module containing a
        click-generated command with a "main" function, or the function itself.
        It may be used to parse a module's custom arguments (i.e., arguments that
        are not part of the Worker class), such as nprocs from dask-worker CLI or
        enable_nvlink from dask-cuda-worker CLI.

    Examples
    --------
    >>> cli_keywords({"x": 123, "save_file": "foo.txt"})
    ['--x', '123', '--save-file', 'foo.txt']

    >>> from dask.distributed import Worker
    >>> cli_keywords({"x": 123}, Worker)
    Traceback (most recent call last):
    ...
    ValueError: Class distributed.worker.Worker does not support keyword x
    """
    from dask.utils import typename

    if cls or cmd:
        for k in d:
            if not has_keyword(cls, k) and not command_has_keyword(cmd, k):
                if cls and cmd:
                    raise ValueError(
                        "Neither class %s nor module %s supports keyword %s"
                        % (typename(cls), typename(cmd), k)
                    )
                elif cls:
                    raise ValueError(
                        f"Class {typename(cls)} does not support keyword {k}"
                    )
                else:
                    raise ValueError(
                        f"Module {typename(cmd)} does not support keyword {k}"
                    )

    def convert_value(v):
        out = str(v)
        if " " in out and "'" not in out and '"' not in out:
            out = '"' + out + '"'
        return out

    return sum(
        (["--" + k.replace("_", "-"), convert_value(v)] for k, v in d.items()), []
    )


def is_valid_xml(text):
    return xml.etree.ElementTree.fromstring(text) is not None


_offload_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="Dask-Offload")
weakref.finalize(_offload_executor, _offload_executor.shutdown)


def import_term(name: str):
    """Return the fully qualified term

    Examples
    --------
    >>> import_term("math.sin")  # doctest: +SKIP
    <function math.sin(x, /)>
    """
    try:
        module_name, attr_name = name.rsplit(".", 1)
    except ValueError:
        return importlib.import_module(name)

    module = importlib.import_module(module_name)
    return getattr(module, attr_name)


async def offload(fn, *args, **kwargs):
    loop = asyncio.get_event_loop()
    # Retain context vars while deserializing; see https://bugs.python.org/issue34014
    context = contextvars.copy_context()
    return await loop.run_in_executor(
        _offload_executor, lambda: context.run(fn, *args, **kwargs)
    )
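
# Usage sketch for ``offload`` (illustrative; ``handler`` is hypothetical):
# run a blocking function on the dedicated offload thread without stalling
# the event loop:
#
#     async def handler(data):
#         parsed = await offload(json.loads, data)  # heavy work off the loop
#         return parsed
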
class EmptyContext:
    def __enter__(self):
        pass

    def __exit__(self, *args):
        pass

    async def __aenter__(self):
        pass

    async def __aexit__(self, *args):
        pass


empty_context = EmptyContext()


class LRU(UserDict):
    """Limited size mapping, evicting the least recently looked-up key when full"""

    def __init__(self, maxsize):
        super().__init__()
        self.data = OrderedDict()
        self.maxsize = maxsize

    def __getitem__(self, key):
        value = super().__getitem__(key)
        self.data.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        # Only evict when inserting a new key; overwriting an existing key
        # should not push another entry out.
        if key not in self.data and len(self) >= self.maxsize:
            self.data.popitem(last=False)
        super().__setitem__(key, value)
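
# Usage sketch (deterministic, shown as a comment):
#
#     cache = LRU(maxsize=2)
#     cache["a"] = 1
#     cache["b"] = 2
#     cache["a"]      # look-up refreshes "a"
#     cache["c"] = 3  # evicts "b", the least recently used key
#     list(cache)     # ['a', 'c']
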
def clean_dashboard_address(addrs: AnyType, default_listen_ip: str = "") -> list[dict]:
    """
    Examples
    --------
    >>> clean_dashboard_address(8787)
    [{'address': '', 'port': 8787}]
    >>> clean_dashboard_address(":8787")
    [{'address': '', 'port': 8787}]
    >>> clean_dashboard_address("8787")
    [{'address': '', 'port': 8787}]
    >>> clean_dashboard_address("foo:8787")
    [{'address': 'foo', 'port': 8787}]
    >>> clean_dashboard_address([8787, 8887])
    [{'address': '', 'port': 8787}, {'address': '', 'port': 8887}]
    >>> clean_dashboard_address(":8787,:8887")
    [{'address': '', 'port': 8787}, {'address': '', 'port': 8887}]
    """

    if default_listen_ip == "0.0.0.0":
        default_listen_ip = ""  # for IPV6

    if isinstance(addrs, str):
        addrs = addrs.split(",")
    if not isinstance(addrs, list):
        addrs = [addrs]

    addresses = []
    for addr in addrs:
        try:
            addr = int(addr)
        except (TypeError, ValueError):
            pass

        if isinstance(addr, str):
            addr = addr.split(":")

        if isinstance(addr, (tuple, list)):
            if len(addr) == 2:
                host, port = (addr[0], int(addr[1]))
            elif len(addr) == 1:
                [host], port = addr, 0
            else:
                raise ValueError(addr)
        elif isinstance(addr, int):
            host = default_listen_ip
            port = addr

        addresses.append({"address": host, "port": port})
    return addresses


_deprecations = {
    "deserialize_for_cli": "dask.config.deserialize",
    "serialize_for_cli": "dask.config.serialize",
    "format_bytes": "dask.utils.format_bytes",
    "format_time": "dask.utils.format_time",
    "funcname": "dask.utils.funcname",
    "parse_bytes": "dask.utils.parse_bytes",
    "parse_timedelta": "dask.utils.parse_timedelta",
    "typename": "dask.utils.typename",
    "tmpfile": "dask.utils.tmpfile",
}


def __getattr__(name):
    if name in _deprecations:
        use_instead = _deprecations[name]

        warnings.warn(
            f"{name} is deprecated and will be removed in a future release. "
            f"Please use {use_instead} instead.",
            category=FutureWarning,
            stacklevel=2,
        )
        return import_term(use_instead)
    else:
        raise AttributeError(f"module {__name__} has no attribute {name}")


if TYPE_CHECKING:

    class SupportsToDict(Protocol):
        def _to_dict(
            self, *, exclude: Container[str] | None = None, **kwargs
        ) -> dict[str, AnyType]:
            ...


@overload
def recursive_to_dict(
    obj: SupportsToDict, exclude: Container[str] = None, seen: set[AnyType] = None
) -> dict[str, AnyType]:
    ...


@overload
def recursive_to_dict(
    obj: Sequence, exclude: Container[str] = None, seen: set[AnyType] = None
) -> Sequence:
    ...


@overload
def recursive_to_dict(
    obj: dict, exclude: Container[str] = None, seen: set[AnyType] = None
) -> dict:
    ...


@overload
def recursive_to_dict(
    obj: None, exclude: Container[str] = None, seen: set[AnyType] = None
) -> None:
    ...


def recursive_to_dict(obj, exclude=None, seen=None):
    """
    This is for debugging purposes only. It recursively calls the ``_to_dict``
    methods, where available, on ``obj`` and its elements. The output of this
    function is intended to be json serializable.

    Parameters
    ----------
    exclude:
        A list of attribute names to be excluded from the dump.
        This is forwarded to the objects' ``_to_dict`` methods, which are
        required to honor it.
    seen:
        Used internally to avoid infinite recursion. If an object has already
        been encountered, its representation is generated instead of its
        ``_to_dict``. This is necessary since we have multiple cyclic-referencing
        data structures.
    """
    if obj is None:
        return None
    if isinstance(obj, str):
        return obj
    if seen is None:
        seen = set()
    if id(obj) in seen:
        return repr(obj)
    seen.add(id(obj))
    if isinstance(obj, type):
        return repr(obj)
    if hasattr(obj, "_to_dict"):
        return obj._to_dict(exclude=exclude)
    if isinstance(obj, (deque, set)):
        obj = tuple(obj)
    if isinstance(obj, (list, tuple)):
        return tuple(
            recursive_to_dict(
                el,
                exclude=exclude,
                seen=seen,
            )
            for el in obj
        )
    elif isinstance(obj, dict):
        res = {}
        for k, v in obj.items():
            k = recursive_to_dict(k, exclude=exclude, seen=seen)
            try:
                hash(k)
            except TypeError:
                k = str(k)
            v = recursive_to_dict(v, exclude=exclude, seen=seen)
            res[k] = v
        return res
    else:
        return repr(obj)
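
# Usage sketch for ``recursive_to_dict`` (illustrative; ``Foo`` is a
# hypothetical class implementing the ``_to_dict`` protocol above):
#
#     class Foo:
#         def _to_dict(self, *, exclude=None, **kwargs):
#             return {"x": 1}
#
#     recursive_to_dict([Foo(), "text", {1, 2}])
#     # ({'x': 1}, 'text', ('1', '2'))  -- sets become tuples and other
#     # non-serializable scalars fall back to repr()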