from __future__ import annotations

import filecmp
import inspect
import logging
import os
import shutil
import sys
import urllib.request
from collections.abc import Iterable
from importlib import import_module
from types import ModuleType
from typing import cast

import click

from dask.utils import tmpfile

from .core import Server
from .utils import import_file

logger = logging.getLogger(__name__)


def validate_preload_argv(ctx, param, value):
    """Click option callback providing validation of preload subcommand arguments.

    Parameters
    ----------
    ctx : click.Context
        Context of the command being parsed; ``ctx.params["preload"]`` holds
        the ``--preload`` targets parsed so far (if any).
    param : click.Parameter
        The parameter this callback is attached to (unused).
    value : sequence of str
        Extra arguments destined for a click-configurable ``dask_setup``.

    Returns
    -------
    The unchanged ``value`` once it has been validated.

    Raises
    ------
    click.NoSuchOption
        If option-like arguments are given without any ``--preload`` target.
    click.UsageError
        If extra arguments are given without a ``--preload`` target, if more
        than one preload module exposes a click-configurable ``dask_setup``,
        or if arguments are given but no preload module can consume them.
    """
    preload = ctx.params.get("preload", None)

    if not value and not preload:
        # No preload argv provided and no preload modules specified.
        return value

    if value and not preload:
        # Report a usage error matching standard click error conventions:
        # option-like arguments are reported first, one at a time.
        unexpected_args = [v for v in value if v.startswith("-")]
        if unexpected_args:
            raise click.NoSuchOption(unexpected_args[0])
        raise click.UsageError(
            "Got unexpected extra argument%s: (%s)"
            % ("s" if len(value) > 1 else "", " ".join(value))
        )

    # Web addresses are not downloaded here; only local modules/files are
    # inspected for a click-configurable ``dask_setup``.
    preload_modules = {
        name: _import_module(name) for name in preload if not is_webaddress(name)
    }

    preload_commands = [
        getattr(m, "dask_setup", None)
        for m in preload_modules.values()
        if isinstance(getattr(m, "dask_setup", None), click.Command)
    ]

    if len(preload_commands) > 1:
        raise click.UsageError(
            "Multiple --preload modules with click-configurable setup: %s"
            % list(preload_modules.keys())
        )

    if value and not preload_commands:
        # ``value`` is wrapped in a 1-tuple so a tuple ``value`` is rendered
        # as a whole by %r rather than being spread over format arguments.
        raise click.UsageError(
            "Unknown argument specified: %r. Was click-configurable --preload "
            "target provided?" % (value,)
        )

    if not preload_commands:
        return value

    # Exactly one click-configurable ``dask_setup`` remains. Parse the extra
    # arguments against it purely to surface click errors early; the parsed
    # values are discarded.
    preload_command = preload_commands[0]
    setup_ctx = click.Context(preload_command, allow_extra_args=False)
    preload_command.parse_args(setup_ctx, list(value))

    return value


def is_webaddress(s: str) -> bool:
    """Return True if ``s`` starts with ``http://`` or ``https://``."""
    return s.startswith(("http://", "https://"))


def _import_module(name: str, file_dir: str | None = None) -> ModuleType:
    """Import a preload module given by name, file path, or source text.

    The imported module may define the preload interface functions
    ``dask_setup`` and ``dask_teardown``; ``Preload`` looks them up later.

    Parameters
    ----------
    name : str
        Module name, path of a ``.py`` file, or the source text of a module or
        script (a value not ending in ``.py`` that contains a space is treated
        as source text).
    file_dir : str, optional
        Path of a directory into which ``.py`` files are copied before being
        imported. If ``None``, files are imported from where they are.

    Returns
    -------
    ModuleType
        The imported module.
    """
    if name.endswith(".py"):
        # name is a file path
        if file_dir is not None:
            basename = os.path.basename(name)
            copy_dst = os.path.join(file_dir, basename)
            if os.path.exists(copy_dst) and not filecmp.cmp(name, copy_dst):
                # A different file with the same name is already present; it
                # is reported but still overwritten by the copy below.
                logger.error("File name collision: %s", basename)
            shutil.copy(name, copy_dst)
            module = import_file(copy_dst)[0]
        else:
            module = import_file(name)[0]

    elif " " not in name:
        # name is a module name; reuse the module if it is already imported.
        if name not in sys.modules:
            import_module(name)
        module = sys.modules[name]

    else:
        # Not a name, actually the text of the script: write it to a temporary
        # ``.py`` file and import that while the file still exists. Write UTF-8
        # explicitly because Python source is decoded as UTF-8 by default,
        # whereas ``open`` would otherwise use the locale encoding.
        # NOTE(review): assumes import_file reads source via the standard
        # import machinery (UTF-8 default) — confirm in .utils.
        with tmpfile(extension=".py") as fn:
            with open(fn, mode="w", encoding="utf-8") as f:
                f.write(name)
            return _import_module(fn, file_dir=file_dir)

    logger.info("Import preload module: %s", name)
    return module


def _download_module(url: str) -> ModuleType:
    """Download Python source from ``url`` and execute it as a new module.

    Parameters
    ----------
    url : str
        An ``http://`` or ``https://`` address of the module source.

    Returns
    -------
    ModuleType
        A fresh module named ``url`` whose namespace holds the executed code.
        The module is not registered in ``sys.modules``.
    """
    logger.info("Downloading preload at %s", url)
    # Internal invariant: callers only pass web addresses.
    assert is_webaddress(url)

    request = urllib.request.Request(url, method="GET")
    # Close the HTTP response deterministically instead of leaking it until GC.
    with urllib.request.urlopen(request) as response:
        source = response.read().decode()

    # SECURITY NOTE: this executes arbitrary code fetched over the network;
    # only point --preload at URLs you trust.
    compiled = compile(source, url, "exec")
    module = ModuleType(url)
    exec(compiled, module.__dict__)
    return module


class Preload:
    """
    Manage state for setup/teardown of a preload module

    Parameters
    ----------
    dask_server: dask.distributed.Server
        The Worker or Scheduler
    name: str
        module name, file name, or web address to load
    argv: [str]
        List of string arguments passed to click-configurable `dask_setup`.
    file_dir: str
        Path of a directory where files should be copied
    """

    dask_server: Server
    name: str
    argv: list[str]
    file_dir: str | None
    module: ModuleType

    def __init__(
        self, dask_server: Server, name: str, argv: Iterable[str], file_dir: str | None
    ):
        self.dask_server = dask_server
        self.name = name
        # Copy so callers sharing one argv list cannot affect this preload.
        self.argv = list(argv)
        self.file_dir = file_dir

        # The module is loaded eagerly, at construction time.
        if is_webaddress(name):
            self.module = _download_module(name)
        else:
            self.module = _import_module(name, file_dir)

    async def start(self):
        """Run when the server finishes its start method.

        Calls the module's ``dask_setup`` (if present) with the server. A
        click-configurable ``dask_setup`` additionally receives the values
        parsed from ``self.argv``. Awaitable results are awaited.
        """
        dask_setup = getattr(self.module, "dask_setup", None)

        if dask_setup:
            if isinstance(dask_setup, click.Command):
                context = dask_setup.make_context(
                    "dask_setup", self.argv, allow_extra_args=False
                )
                # Invoke the underlying callback directly so the server can be
                # passed as the first positional argument.
                result = dask_setup.callback(
                    self.dask_server, *context.args, **context.params
                )
                if inspect.isawaitable(result):
                    await result
                logger.info("Run preload setup click command: %s", self.name)
            else:
                future = dask_setup(self.dask_server)
                if inspect.isawaitable(future):
                    await future
                logger.info("Run preload setup function: %s", self.name)

    async def teardown(self):
        """Run when the server starts its close method.

        Calls the module's ``dask_teardown`` (if present) with the server and
        awaits the result if it is awaitable.
        """
        dask_teardown = getattr(self.module, "dask_teardown", None)
        if dask_teardown:
            future = dask_teardown(self.dask_server)
            if inspect.isawaitable(future):
                await future


def process_preloads(
    dask_server,
    preload: str | list[str],
    preload_argv: list[str] | list[list[str]],
    *,
    file_dir: str | None = None,
) -> list[Preload]:
    """Create a ``Preload`` for each preload target.

    Parameters
    ----------
    dask_server
        The server passed to every ``Preload``.
    preload : str or list of str
        A single preload target or a list of them.
    preload_argv : list of str or list of list of str
        Either one argv list shared by every preload, or one argv list per
        preload. An empty value means no arguments for any preload.
    file_dir : str, optional
        Directory into which ``.py`` preload files are copied.

    Returns
    -------
    list of Preload
        One ``Preload`` per target, in the order given.

    Raises
    ------
    ValueError
        If per-preload argv lists are given and their count differs from the
        number of preload targets.
    """
    if isinstance(preload, str):
        preload = [preload]
    if preload_argv and isinstance(preload_argv[0], str):
        # One flat argv shared by all preloads; each Preload copies its argv,
        # so sharing the same list object here is safe.
        preload_argv = [cast("list[str]", preload_argv)] * len(preload)
    elif not preload_argv:
        preload_argv = [cast("list[str]", [])] * len(preload)
    if len(preload) != len(preload_argv):
        raise ValueError(
            "preload and preload_argv have mismatched lengths "
            f"{len(preload)} != {len(preload_argv)}"
        )

    return [
        Preload(dask_server, p, argv, file_dir)
        for p, argv in zip(preload, preload_argv)
    ]