1from __future__ import annotations
2
3import filecmp
4import inspect
5import logging
6import os
7import shutil
8import sys
9import urllib.request
10from collections.abc import Iterable
11from importlib import import_module
12from types import ModuleType
13from typing import cast
14
15import click
16
17from dask.utils import tmpfile
18
19from .core import Server
20from .utils import import_file
21
# Module-level logger, named after this module; the preload helpers below
# report imports, downloads and setup runs through it.
logger = logging.getLogger(__name__)
23
24
def validate_preload_argv(ctx, param, value):
    """Click option callback providing validation of preload subcommand arguments.

    Parameters
    ----------
    ctx
        Click context of the command being parsed. Only ``ctx.params`` is
        read, to look up the ``"preload"`` entry.
    param
        The click parameter this callback is attached to (unused).
    value
        Sequence of extra command-line arguments meant for a
        click-configurable ``dask_setup``.

    Returns
    -------
    The ``value`` that was passed in, unchanged.

    Raises
    ------
    click.NoSuchOption
        Arguments were given without any ``--preload`` and one of them
        starts with ``-``.
    click.UsageError
        Arguments were given without any ``--preload``, more than one preload
        module exposes a click-configurable ``dask_setup``, or arguments were
        given but no preload module exposes one. Errors raised while the
        preload command parses ``value`` propagate as well.
    """
    preload = ctx.params.get("preload", None)

    if not preload:
        if not value:
            # No preload argv provided and no preload modules specified.
            return value

        # Report a usage error matching standard click error conventions:
        # the first option-looking argument is reported as an unknown option,
        # otherwise everything is reported as unexpected extra arguments.
        for arg in value:
            if arg.startswith("-"):
                raise click.NoSuchOption(arg)
        raise click.UsageError(
            "Got unexpected extra argument%s: (%s)"
            % ("s" if len(value) > 1 else "", " ".join(value))
        )

    # Web addresses are not imported during validation; only module names,
    # file paths and script text are.
    preload_modules = {
        name: _import_module(name) for name in preload if not is_webaddress(name)
    }

    preload_commands = []
    for module in preload_modules.values():
        setup = getattr(module, "dask_setup", None)
        if isinstance(setup, click.Command):
            preload_commands.append(setup)

    if len(preload_commands) > 1:
        raise click.UsageError(
            "Multiple --preload modules with click-configurable setup: %s"
            % list(preload_modules.keys())
        )

    if not preload_commands:
        if value:
            # The message used to be raised with its "%r" placeholder left
            # unformatted; interpolate the offending arguments. The one-tuple
            # keeps a tuple-valued ``value`` from being unpacked by "%".
            raise click.UsageError(
                "Unknown argument specified: %r Was click-configurable --preload target provided?"
                % (value,)
            )
        return value

    preload_command = preload_commands[0]

    # Parse a copy of the arguments in a throwaway context purely so that the
    # preload command can reject arguments it does not understand.
    command_ctx = click.Context(preload_command, allow_extra_args=False)
    preload_command.parse_args(command_ctx, list(value))

    return value
72
73
def is_webaddress(s: str) -> bool:
    """Return True when ``s`` begins with ``http://`` or ``https://``.

    The comparison is case-sensitive.
    """
    return s.startswith(("http://", "https://"))
76
77
def _import_module(name, file_dir=None) -> ModuleType:
    """Import a preload module given by name, file path, or source text.

    Parameters
    ----------
    name : str
        Module name, path to a ``.py`` file, or the text of a module or
        script. A value ending in ``.py`` is treated as a file path; any
        other value containing no space is treated as an importable module
        name; everything else is treated as source text.
    file_dir : str, optional
        Directory into which a ``.py`` file is copied before it is imported.
        When ``None`` the file is imported from where it is.

    Returns
    -------
    types.ModuleType
        The imported module.
    """
    if name.endswith(".py"):
        # name is a file path
        if file_dir is None:
            module = import_file(name)[0]
        else:
            basename = os.path.basename(name)
            copy_dst = os.path.join(file_dir, basename)
            dst_exists = os.path.exists(copy_dst)
            # If the source already *is* the destination there is nothing to
            # copy, and shutil.copy would raise SameFileError.
            if not (dst_exists and os.path.samefile(name, copy_dst)):
                if dst_exists and not filecmp.cmp(name, copy_dst):
                    # A different file of the same name is already there; it
                    # is reported and then overwritten by the copy below.
                    logger.error("File name collision: %s", basename)
                shutil.copy(name, copy_dst)
            module = import_file(copy_dst)[0]

    elif " " not in name:
        # name is a module name
        if name not in sys.modules:
            import_module(name)
        module = sys.modules[name]

    else:
        # not a name, actually the text of the script
        with tmpfile(extension=".py") as fn:
            # Python source files are decoded as UTF-8 by default, so write
            # the script as UTF-8 rather than in the locale encoding.
            with open(fn, mode="w", encoding="utf-8") as f:
                f.write(name)
            # Re-enter through the file-path branch while the temporary file
            # still exists.
            return _import_module(fn, file_dir=file_dir)

    logger.info("Import preload module: %s", name)
    return module
124
125
def _download_module(url: str) -> ModuleType:
    """Download Python source from ``url`` and execute it as a new module.

    Parameters
    ----------
    url : str
        Address of the source file; must start with ``http://`` or
        ``https://``.

    Returns
    -------
    types.ModuleType
        A fresh module, named after ``url``, whose namespace is the result of
        executing the downloaded source.
    """
    logger.info("Downloading preload at %s", url)
    assert is_webaddress(url)

    request = urllib.request.Request(url, method="GET")
    # Use the response as a context manager so the connection is closed as
    # soon as the body has been read, even if reading or decoding fails.
    with urllib.request.urlopen(request) as response:
        source = response.read().decode()

    # NOTE(security): this executes code fetched over the network with the
    # privileges of the current process; only point preloads at trusted URLs.
    compiled = compile(source, url, "exec")
    module = ModuleType(url)
    exec(compiled, module.__dict__)
    return module
138
139
class Preload:
    """
    Hold one preload module together with what is needed to set it up and
    tear it down against a server.

    Parameters
    ----------
    dask_server: dask.distributed.Server
        The Worker or Scheduler
    name: str
        module name, file name, or web address to load
    argv: [str]
        List of string arguments passed to click-configurable `dask_setup`.
    file_dir: str
        Path of a directory where files should be copied
    """

    dask_server: Server
    name: str
    argv: list[str]
    file_dir: str | None
    module: ModuleType

    def __init__(
        self, dask_server: Server, name: str, argv: Iterable[str], file_dir: str | None
    ):
        self.dask_server = dask_server
        self.name = name
        # Keep a private copy of the arguments.
        self.argv = list(argv)
        self.file_dir = file_dir

        # Web addresses are downloaded; everything else goes through the
        # regular import path.
        self.module = (
            _download_module(name)
            if is_webaddress(name)
            else _import_module(name, file_dir)
        )

    async def start(self):
        """Run when the server finishes its start method"""
        setup = getattr(self.module, "dask_setup", None)
        if not setup:
            # The module defines no setup hook.
            return

        if isinstance(setup, click.Command):
            # Click-configurable setup: parse our argv, then invoke the
            # underlying callback with the server as first argument.
            context = setup.make_context(
                "dask_setup", self.argv, allow_extra_args=False
            )
            outcome = setup.callback(
                self.dask_server, *context.args, **context.params
            )
            message = "Run preload setup click command: %s"
        else:
            outcome = setup(self.dask_server)
            message = "Run preload setup function: %s"

        # The hook may be synchronous or may hand back an awaitable.
        if inspect.isawaitable(outcome):
            await outcome
        logger.info(message, self.name)

    async def teardown(self):
        """Run when the server starts its close method"""
        hook = getattr(self.module, "dask_teardown", None)
        if not hook:
            return

        outcome = hook(self.dask_server)
        if inspect.isawaitable(outcome):
            await outcome
203
204
def process_preloads(
    dask_server,
    preload: str | list[str],
    preload_argv: list[str] | list[list[str]],
    *,
    file_dir: str | None = None,
) -> list[Preload]:
    """Build one ``Preload`` per requested preload.

    Parameters
    ----------
    dask_server
        Server handed to every ``Preload``.
    preload
        A single preload name or a list of them.
    preload_argv
        Either one flat list of string arguments, which is applied to every
        preload, or one argument list per preload. An empty value gives every
        preload an empty argument list.
    file_dir
        Forwarded to every ``Preload``.

    Returns
    -------
    list of ``Preload``, in the order of ``preload``.

    Raises
    ------
    ValueError
        If, after normalisation, the number of argument lists differs from
        the number of preloads.
    """
    names = [preload] if isinstance(preload, str) else preload

    argv_per_preload: list[list[str]]
    if not preload_argv:
        # Nothing given: every preload gets an empty argv.
        argv_per_preload = [cast("list[str]", [])] * len(names)
    elif isinstance(preload_argv[0], str):
        # A single flat argv: share it between all preloads.
        argv_per_preload = [cast("list[str]", preload_argv)] * len(names)
    else:
        # Already one argv per preload.
        argv_per_preload = cast("list[list[str]]", preload_argv)

    if len(names) != len(argv_per_preload):
        raise ValueError(
            "preload and preload_argv have mismatched lengths "
            f"{len(names)} != {len(argv_per_preload)}"
        )

    preloads = []
    for name, argv in zip(names, argv_per_preload):
        preloads.append(Preload(dask_server, name, argv, file_dir))
    return preloads
228