1import logging
2import os
3import socket
4import subprocess
5import sys
6import uuid
7import zipfile
8
9from dask.utils import funcname, tmpfile
10
# Module-level logger shared by the concrete plugins below (e.g. PipInstall).
logger = logging.getLogger(__name__)
12
13
class SchedulerPlugin:
    """Extension interface for the Scheduler.

    The scheduler emits events such as ``task_finished``, ``update_graph``
    and ``task_erred`` as it runs.  Subclasses hook into those events by
    overriding the corresponding methods below; the scheduler invokes them
    from its own thread, so plugin code runs in lock-step with the scheduler
    and can inspect — and in principle influence — core scheduling state.

    Plugins are typically used for diagnostics and measurement, though the
    full scheduler is available to them.

    Register an instance with ``Scheduler.add_plugin(myplugin)``.

    Examples
    --------
    >>> class Counter(SchedulerPlugin):
    ...     def __init__(self):
    ...         self.counter = 0
    ...
    ...     def transition(self, key, start, finish, *args, **kwargs):
    ...         if start == 'processing' and finish == 'memory':
    ...             self.counter += 1
    ...
    ...     def restart(self, scheduler):
    ...         self.counter = 0

    >>> plugin = Counter()
    >>> scheduler.add_plugin(plugin)  # doctest: +SKIP
    """

    async def start(self, scheduler):
        """Run at the end of the Scheduler's startup process."""

    async def close(self):
        """Run at the beginning of Scheduler shutdown, after workers have
        been asked to shut down gracefully."""

    def update_graph(self, scheduler, dsk=None, keys=None, restrictions=None, **kwargs):
        """Run when a new graph / tasks enter the scheduler."""

    def restart(self, scheduler, **kwargs):
        """Run when the scheduler restarts itself."""

    def transition(self, key, start, finish, *args, **kwargs):
        """Run whenever a task changes state.

        Parameters
        ----------
        key : string
        start : string
            Start state of the transition.
            One of released, waiting, processing, memory, error.
        finish : string
            Final state of the transition.
        *args, **kwargs : More options passed when transitioning
            This may include worker ID, compute time, etc.
        """

    def add_worker(self, scheduler=None, worker=None, **kwargs):
        """Run when a new worker enters the cluster."""

    def remove_worker(self, scheduler=None, worker=None, **kwargs):
        """Run when a worker leaves the cluster."""

    def add_client(self, scheduler=None, client=None, **kwargs):
        """Run when a new client connects."""

    def remove_client(self, scheduler=None, client=None, **kwargs):
        """Run when a client disconnects."""
95
96
class WorkerPlugin:
    """Extension interface for the Worker.

    A worker plugin lets custom code run at well-defined points of a
    Worker's lifecycle: when the plugin is attached (``setup``), whenever a
    task changes state (``transition``), and when the worker shuts down
    (``teardown``).  Every hook executes within the Worker's main thread.

    To write a plugin, subclass and override any of the methods below, then
    register it on your client with ``Client.register_worker_plugin`` so it
    is attached to every existing and future worker.

    Examples
    --------
    >>> class ErrorLogger(WorkerPlugin):
    ...     def __init__(self, logger):
    ...         self.logger = logger
    ...
    ...     def setup(self, worker):
    ...         self.worker = worker
    ...
    ...     def transition(self, key, start, finish, *args, **kwargs):
    ...         if finish == 'error':
    ...             ts = self.worker.tasks[key]
    ...             exc_info = (type(ts.exception), ts.exception, ts.traceback)
    ...             self.logger.error(
    ...                 "Error during computation of '%s'.", key,
    ...                 exc_info=exc_info
    ...             )

    >>> plugin = ErrorLogger()
    >>> client.register_worker_plugin(plugin)  # doctest: +SKIP
    """

    def setup(self, worker):
        """Run when the plugin is attached to a worker, either because it
        was just registered or because the worker was created after the
        plugin had been registered."""

    def teardown(self, worker):
        """Run when the worker this plugin is attached to is closed."""

    def transition(self, key, start, finish, **kwargs):
        """Run whenever one of this worker's tasks changes state.

        Throughout a task's lifecycle (see :doc:`Worker <worker>`), the
        scheduler instructs workers to compute tasks, causing the task's
        state to transition; the worker owning the task is notified of each
        such transition and calls this method.

        Parameters
        ----------
        key : string
        start : string
            Start state of the transition.
            One of waiting, ready, executing, long-running, memory, error.
        finish : string
            Final state of the transition.
        kwargs : More options passed when transitioning
        """
163
164
class NannyPlugin:
    """Extension interface for the Nanny.

    Like a worker plugin, a nanny plugin runs custom code at stages of the
    workers' lifecycle — but because it attaches to the nanny process it can
    also run code before the worker is started, or restart the worker if
    necessary.

    To write a plugin, subclass and override any of the methods below, then
    register it on your client with
    :meth:`Client.register_worker_plugin<distributed.Client.register_worker_plugin>`
    passing ``nanny=True`` so it is attached to every existing and future
    nanny.

    The ``restart`` attribute controls whether or not a running ``Worker``
    needs to be restarted when the plugin is registered.

    See Also
    --------
    WorkerPlugin
    SchedulerPlugin
    """

    restart = False

    def setup(self, nanny):
        """Run when the plugin is attached to a nanny, either because it
        was just registered or because the nanny was created after the
        plugin had been registered."""

    def teardown(self, nanny):
        """Run when the nanny this plugin is attached to is closed."""
198
199
200def _get_plugin_name(plugin) -> str:
201    """Return plugin name.
202
203    If plugin has no name attribute a random name is used.
204
205    """
206    if hasattr(plugin, "name"):
207        return plugin.name
208    else:
209        return funcname(type(plugin)) + "-" + str(uuid.uuid4())
210
211
class PipInstall(WorkerPlugin):
    """A Worker Plugin to pip install a set of packages

    This accepts a set of packages to install on all workers.
    You can also optionally ask for the worker to restart itself after
    performing this installation.

    .. note::

       This will increase the time it takes to start up
       each worker. If possible, we recommend including the
       libraries in the worker environment or image. This is
       primarily intended for experimentation and debugging.

       Additional issues may arise if multiple workers share the same
       file system. Each worker might try to install the packages
       simultaneously.

    Parameters
    ----------
    packages : List[str]
        A list of strings to place after "pip install" command
    pip_options : List[str]
        Additional options to pass to pip.
    restart : bool, default False
        Whether or not to restart the worker after pip installing
        Only functions if the worker has an attached nanny process

    Examples
    --------
    >>> from dask.distributed import PipInstall
    >>> plugin = PipInstall(packages=["scikit-learn"], pip_options=["--upgrade"])

    >>> client.register_worker_plugin(plugin)
    """

    name = "pip"

    def __init__(self, packages, pip_options=None, restart=False):
        self.packages = packages
        self.restart = restart
        if pip_options is None:
            pip_options = []
        self.pip_options = pip_options

    async def setup(self, worker):
        from ..lock import Lock

        # Serialize installations per host so that workers sharing a
        # filesystem don't clobber one another's pip runs.
        async with Lock(socket.gethostname()):
            logger.info("Pip installing the following packages: %s", self.packages)
            proc = subprocess.Popen(
                [sys.executable, "-m", "pip", "install"]
                + self.pip_options
                + self.packages,
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE,
            )
            stdout, stderr = proc.communicate()
            # communicate() already waits for the process to finish, so the
            # extra proc.wait() call was redundant; the exit status is
            # available as proc.returncode.
            if proc.returncode:
                logger.error("Pip install failed with '%s'", stderr.decode().strip())
                return

            if self.restart and worker.nanny:
                lines = stdout.strip().split(b"\n")
                # Only restart when pip actually installed something new;
                # "Requirement already satisfied" on every line means no-op.
                if not all(
                    line.startswith(b"Requirement already satisfied") for line in lines
                ):
                    worker.loop.add_callback(
                        worker.close_gracefully, restart=True
                    )  # restart
284
285
286# Adapted from https://github.com/dask/distributed/issues/3560#issuecomment-596138522
# Adapted from https://github.com/dask/distributed/issues/3560#issuecomment-596138522
class UploadFile(WorkerPlugin):
    """A WorkerPlugin to upload a local file to workers.

    Parameters
    ----------
    filepath: str
        A path to the file (.py, egg, or zip) to upload

    Examples
    --------
    >>> from distributed.diagnostics.plugin import UploadFile

    >>> client.register_worker_plugin(UploadFile("/path/to/file.py"))  # doctest: +SKIP
    """

    name = "upload_file"

    def __init__(self, filepath):
        """Read the file eagerly so the plugin carries its own payload."""
        self.filename = os.path.basename(filepath)
        with open(filepath, "rb") as fh:
            self.data = fh.read()

    async def setup(self, worker):
        # Hand the payload to the worker's upload_file handler, asking it to
        # load (import) the file after writing it out.
        response = await worker.upload_file(
            comm=None, filename=self.filename, data=self.data, load=True
        )
        # Sanity check: the worker must have received every byte.
        assert len(self.data) == response["nbytes"]
317
318
class Environ(NannyPlugin):
    """A NannyPlugin that adds environment variables for spawned workers.

    Parameters
    ----------
    environ : dict, optional
        Mapping of environment variable names to values.  Values are
        stringified before use.
    """

    # Registering this plugin restarts running workers (see NannyPlugin.restart).
    restart = True

    def __init__(self, environ=None):
        # Use None instead of a mutable default argument ({}); an absent
        # mapping simply means "no extra variables".
        self.environ = {k: str(v) for k, v in (environ or {}).items()}

    async def setup(self, nanny):
        # Merge into the environment the nanny applies to its worker process.
        nanny.env.update(self.environ)
327
328
class UploadDirectory(NannyPlugin):
    """A NannyPlugin to upload a local directory to workers.

    Parameters
    ----------
    path: str
        A path to the directory to upload

    Examples
    --------
    >>> from distributed.diagnostics.plugin import UploadDirectory
    >>> client.register_worker_plugin(UploadDirectory("/path/to/directory"), nanny=True)  # doctest: +SKIP
    """

    def __init__(
        self,
        path,
        restart=False,
        update_path=False,
        skip_words=(".git", ".github", ".pytest_cache", "tests", "docs"),
        skip=(lambda fn: os.path.splitext(fn)[1] == ".pyc",),
    ):
        """
        Zip up ``path`` in memory, skipping unwanted files.

        Parameters
        ----------
        path : str
            Directory to upload; ``~`` is expanded.
        restart : bool, default False
            Whether workers should be restarted after the upload.
        update_path : bool, default False
            If True, insert the unpacked directory into ``sys.path`` on
            the workers.
        skip_words : tuple of str
            Path components whose presence excludes a file from the archive.
        skip : tuple of callables
            Predicates over the full file path; a truthy result excludes
            the file.
        """
        path = os.path.expanduser(path)
        self.path = os.path.split(path)[-1]
        self.restart = restart
        self.update_path = update_path

        self.name = "upload-directory-" + os.path.split(path)[-1]

        with tmpfile(extension="zip") as fn:
            with zipfile.ZipFile(fn, "w", zipfile.ZIP_DEFLATED) as z:
                for root, dirs, files in os.walk(path):
                    for file in files:
                        filename = os.path.join(root, file)
                        if any(predicate(filename) for predicate in skip):
                            continue
                        # Use a distinct name here: the original code rebound
                        # ``dirs``, shadowing os.walk's directory list.
                        parts = filename.split(os.sep)
                        if any(word in parts for word in skip_words):
                            continue

                        # Archive relative to the parent of ``path`` so the
                        # top-level directory name is preserved in the zip.
                        archive_name = os.path.relpath(
                            filename, os.path.join(path, "..")
                        )
                        z.write(filename, archive_name)

            with open(fn, "rb") as f:
                self.data = f.read()

    async def setup(self, nanny):
        # Write the archive into the nanny's local directory under a
        # collision-free name.
        fn = os.path.join(nanny.local_directory, f"tmp-{uuid.uuid4()}.zip")
        with open(fn, "wb") as f:
            f.write(self.data)

        # zipfile is imported at module level; no local re-import needed.
        with zipfile.ZipFile(fn) as z:
            z.extractall(path=nanny.local_directory)

        if self.update_path:
            path = os.path.join(nanny.local_directory, self.path)
            if path not in sys.path:
                sys.path.insert(0, path)

        os.remove(fn)
396