1import html
2import logging
3import sys
4import weakref
5from contextlib import suppress
6from timeit import default_timer
7
8from tlz import valmap
9from tornado.ioloop import IOLoop
10
11import dask
12
13from ..client import default_client, futures_of
14from ..core import CommClosedError, clean_exception, coerce_to_address, connect
15from ..protocol.pickle import dumps
16from ..utils import LoopRunner, is_kernel, key_split
17from .progress import MultiProgress, Progress, format_time
18
19logger = logging.getLogger(__name__)
20
21
def get_scheduler(scheduler):
    """Resolve ``scheduler`` to an address.

    When ``scheduler`` is None the address of the default client's
    scheduler is returned; any other value is passed through
    ``coerce_to_address``.
    """
    if scheduler is not None:
        return coerce_to_address(scheduler)
    return default_client().scheduler.address
26
27
class ProgressBar:
    """Base class for a single progress bar driven by a scheduler feed.

    ``listen`` opens a comm to the scheduler, registers a ``feed`` and hands
    every response to the ``_draw_bar`` / ``_draw_stop`` hooks, which
    subclasses override to render the bar.

    Parameters
    ----------
    keys : iterable
        Keys to track.  Items exposing a ``key`` attribute are replaced by
        that attribute; the first item exposing a ``client`` attribute
        provides the client used for connection settings.
    scheduler : optional
        Scheduler to connect to, resolved through ``get_scheduler``.
    interval : str or number
        Parsed with ``dask.utils.parse_timedelta`` (bare numbers are taken
        as seconds) and sent to the scheduler as the feed's ``interval``.
    complete : bool
        Forwarded to ``Progress``.
    """

    def __init__(self, keys, scheduler=None, interval="100ms", complete=True):
        self.scheduler = get_scheduler(scheduler)

        # Only a weak reference is kept, so the bar never keeps a client alive.
        self.client = next(
            (weakref.ref(key.client) for key in keys if hasattr(key, "client")),
            None,
        )

        self.keys = {k.key if hasattr(k, "key") else k for k in keys}
        self.interval = dask.utils.parse_timedelta(interval, default="s")
        self.complete = complete
        self._start_time = default_timer()

    @property
    def elapsed(self):
        """Seconds elapsed since this bar was created."""
        return default_timer() - self._start_time

    def _client_attr(self, name, default=None):
        """Return attribute ``name`` of the referenced client.

        ``default`` is returned when no client was recorded or when the
        weakly referenced client has already been garbage collected.  A
        weakref object is always truthy, so the referent itself has to be
        checked rather than the reference.
        """
        client = self.client() if self.client is not None else None
        if client is None:
            return default
        return getattr(client, name)

    async def listen(self):
        """Subscribe to progress updates and draw them until completion.

        The loop ends when a response reports status ``"error"`` or
        ``"finished"`` (the comm is closed and ``_draw_stop`` is called),
        or when the comm is closed from the other side.
        """
        complete = self.complete
        keys = self.keys

        # ``setup`` and ``function`` are serialized with ``dumps`` and sent
        # along with the "feed" request below.
        async def setup(scheduler):
            p = Progress(keys, scheduler, complete=complete)
            await p.setup()
            return p

        def function(scheduler, p):
            result = {
                "all": len(p.all_keys),
                "remaining": len(p.keys),
                "status": p.status,
            }
            if p.status == "error":
                result.update(p.extra)
            return result

        self.comm = await connect(
            self.scheduler, **self._client_attr("connection_args", {})
        )
        logger.debug("Progressbar Connected to scheduler")

        await self.comm.write(
            {
                "op": "feed",
                "setup": dumps(setup),
                "function": dumps(function),
                "interval": self.interval,
            },
            serializers=self._client_attr("_serializers"),
        )

        while True:
            try:
                response = await self.comm.read(
                    deserializers=self._client_attr("_deserializers")
                )
            except CommClosedError:
                break
            self._last_response = response
            self.status = response["status"]
            self._draw_bar(**response)
            if response["status"] in ("error", "finished"):
                await self.comm.close()
                self._draw_stop(**response)
                break

        logger.debug("Progressbar disconnected from scheduler")

    def _draw_bar(self, **kwargs):
        """Hook called with every response; subclasses render the bar here."""
        pass

    def _draw_stop(self, **kwargs):
        """Hook called once with the final response; no-op by default."""
        pass

    def __del__(self):
        # ``self.comm`` only exists once ``listen`` has connected.
        with suppress(AttributeError):
            self.comm.abort()
104
105
class TextProgressBar(ProgressBar):
    """Progress bar rendered as a single text line on ``sys.stdout``.

    Parameters
    ----------
    keys, scheduler, interval, complete
        Passed to ``ProgressBar``.
    width : int
        Number of characters between the brackets of the bar.
    loop : IOLoop, optional
        Event loop used to run ``listen``; a new ``IOLoop`` is created when
        none is given.
    start : bool
        When true, ``listen`` is run synchronously through a ``LoopRunner``
        from within the constructor.
    **kwargs
        Accepted and ignored.
    """

    def __init__(
        self,
        keys,
        scheduler=None,
        interval="100ms",
        width=40,
        loop=None,
        complete=True,
        start=True,
        **kwargs,
    ):
        super().__init__(keys, scheduler, interval, complete)
        self.width = width
        self.loop = loop or IOLoop()

        if start:
            loop_runner = LoopRunner(self.loop)
            loop_runner.run_sync(self.listen)

    def _draw_bar(self, remaining, all, **kwargs):
        """Redraw the bar in place from the ``remaining`` / ``all`` counts."""
        # With nothing to track the bar is shown as complete.
        frac = (1 - remaining / all) if all else 1.0
        bar = "#" * int(self.width * frac)
        percent = int(100 * frac)
        elapsed = format_time(self.elapsed)
        msg = f"\r[{bar:<{self.width}}] | {percent}% Completed | {elapsed}"
        # Writing to a closed stream raises ValueError; drawing is
        # best-effort and must not interrupt listening.
        with suppress(ValueError):
            sys.stdout.write(msg)
            sys.stdout.flush()

    def _draw_stop(self, **kwargs):
        """Return the cursor to the start of the line after the final draw."""
        # Same best-effort policy as ``_draw_bar``: ignore a closed stdout.
        with suppress(ValueError):
            sys.stdout.write("\r")
            sys.stdout.flush()
141
142
class ProgressWidget(ProgressBar):
    """ProgressBar that uses an IPython ProgressBar widget for the notebook

    The widget is a ``VBox`` holding an elapsed-time line above a row with a
    ``done / total`` counter and a ``FloatProgress`` bar.

    Parameters
    ----------
    keys, scheduler, interval, complete
        Passed to ``ProgressBar``.
    loop : optional
        Accepted but not used by this class.
    **kwargs
        Accepted and ignored.

    See Also
    --------
    progress: User function
    TextProgressBar: Text version suitable for the console
    """

    def __init__(
        self,
        keys,
        scheduler=None,
        interval="100ms",
        complete=False,
        loop=None,
        **kwargs,
    ):
        super().__init__(keys, scheduler, interval, complete)

        # Imported lazily so ipywidgets is only required when a widget is built.
        from ipywidgets import HTML, FloatProgress, HBox, VBox

        self.elapsed_time = HTML("")
        self.bar = FloatProgress(min=0, max=1, description="")
        self.bar_text = HTML("")

        self.bar_widget = HBox([self.bar_text, self.bar])
        self.widget = VBox([self.elapsed_time, self.bar_widget])

    def _ipython_display_(self, **kwargs):
        """Schedule ``listen`` on the current IOLoop and display the widget."""
        IOLoop.current().add_callback(self.listen)
        return self.widget._ipython_display_(**kwargs)

    def _draw_stop(self, remaining, status, exception=None, **kwargs):
        """Show the final state: the exception on error, else "Finished"."""
        if status == "error":
            _, exception, _ = clean_exception(exception)
            self.bar.bar_style = "danger"
            # The repr is inserted into HTML markup, so it must be escaped:
            # exception text may contain "<", ">" or "&".
            self.elapsed_time.value = (
                '<div style="padding: 0px 10px 5px 10px"><b>Exception</b> '
                "<tt>"
                + html.escape(repr(exception))
                + "</tt>:"
                + format_time(self.elapsed)
                + " "
                + "</div>"
            )
        elif not remaining:
            self.bar.bar_style = "success"
            self.elapsed_time.value = (
                '<div style="padding: 0px 10px 5px 10px"><b>Finished:</b> '
                + format_time(self.elapsed)
                + "</div>"
            )

    def _draw_bar(self, remaining, all, **kwargs):
        """Update the elapsed time, the bar value and the ``done / total`` text."""
        ndone = all - remaining
        self.elapsed_time.value = (
            '<div style="padding: 0px 10px 5px 10px"><b>Computing:</b> '
            + format_time(self.elapsed)
            + "</div>"
        )
        # With nothing to track the bar is shown as full.
        self.bar.value = ndone / all if all else 1.0
        self.bar_text.value = (
            '<div style="padding: 0px 10px 0px 10px; text-align:right;">%d / %d</div>'
            % (ndone, all)
        )
209
210
class MultiProgressBar:
    """Base class for progress bars that report one count per key group.

    ``listen`` opens a comm to the scheduler, registers a ``feed`` backed by
    ``MultiProgress`` and hands every response to the ``_draw_bar`` /
    ``_draw_stop`` hooks.  ``_draw_bar`` must be provided by subclasses.

    Parameters
    ----------
    keys : iterable
        Keys to track.  Items exposing a ``key`` attribute are replaced by
        that attribute; the first item exposing a ``client`` attribute
        provides the client used for connection settings.
    scheduler : optional
        Scheduler to connect to, resolved through ``get_scheduler``.
    func : callable
        Forwarded to ``MultiProgress`` as ``func``.
    interval : str or number
        Parsed with ``dask.utils.parse_timedelta`` (bare numbers are taken
        as seconds), as ``ProgressBar`` does, and sent to the scheduler as
        the feed's ``interval``.
    complete : bool
        Forwarded to ``MultiProgress``.
    **kwargs
        Accepted and ignored.
    """

    def __init__(
        self,
        keys,
        scheduler=None,
        func=key_split,
        interval="100ms",
        complete=False,
        **kwargs,
    ):
        self.scheduler = get_scheduler(scheduler)

        # Only a weak reference is kept, so the bar never keeps a client alive.
        self.client = next(
            (weakref.ref(key.client) for key in keys if hasattr(key, "client")),
            None,
        )

        self.keys = {k.key if hasattr(k, "key") else k for k in keys}
        self.func = func
        self.interval = dask.utils.parse_timedelta(interval, default="s")
        self.complete = complete
        self._start_time = default_timer()

    @property
    def elapsed(self):
        """Seconds elapsed since this bar was created."""
        return default_timer() - self._start_time

    def _client_attr(self, name, default=None):
        """Return attribute ``name`` of the referenced client.

        ``default`` is returned when no client was recorded or when the
        weakly referenced client has already been garbage collected.  A
        weakref object is always truthy, so the referent itself has to be
        checked rather than the reference.
        """
        client = self.client() if self.client is not None else None
        if client is None:
            return default
        return getattr(client, name)

    async def listen(self):
        """Subscribe to progress updates and draw them until completion.

        The loop ends when a response reports status ``"error"`` or
        ``"finished"`` (the comm is closed and ``_draw_stop`` is called),
        or when the comm is closed from the other side.
        """
        complete = self.complete
        keys = self.keys
        func = self.func

        # ``setup`` and ``function`` are serialized with ``dumps`` and sent
        # along with the "feed" request below.
        async def setup(scheduler):
            p = MultiProgress(keys, scheduler, complete=complete, func=func)
            await p.setup()
            return p

        def function(scheduler, p):
            result = {
                "all": valmap(len, p.all_keys),
                "remaining": valmap(len, p.keys),
                "status": p.status,
            }
            if p.status == "error":
                result.update(p.extra)
            return result

        self.comm = await connect(
            self.scheduler, **self._client_attr("connection_args", {})
        )
        logger.debug("Progressbar Connected to scheduler")

        await self.comm.write(
            {
                "op": "feed",
                "setup": dumps(setup),
                "function": dumps(function),
                "interval": self.interval,
            }
        )

        while True:
            try:
                response = await self.comm.read(
                    deserializers=self._client_attr("_deserializers")
                )
            except CommClosedError:
                # Same policy as ProgressBar.listen: a closed comm simply
                # ends listening instead of raising out of the callback.
                break
            self._last_response = response
            self.status = response["status"]
            self._draw_bar(**response)
            if response["status"] in ("error", "finished"):
                await self.comm.close()
                self._draw_stop(**response)
                break
        logger.debug("Progressbar disconnected from scheduler")

    def _draw_stop(self, **kwargs):
        """Hook called once with the final response; no-op by default."""
        pass

    def __del__(self):
        # ``self.comm`` only exists once ``listen`` has connected.
        with suppress(AttributeError):
            self.comm.abort()
292
293
class MultiProgressWidget(MultiProgressBar):
    """Multiple progress bar Widget suitable for the notebook

    Displays multiple progress bars for a computation, split on computation
    type.

    Parameters
    ----------
    keys, scheduler, func, interval, complete
        Passed to ``MultiProgressBar``.
    minimum : optional
        Accepted but not used by this class.
    **kwargs
        Accepted and ignored.

    See Also
    --------
    progress: User-level function <--- use this
    MultiProgress: Non-visualization component that contains most logic
    ProgressWidget: Single progress bar widget
    """

    def __init__(
        self,
        keys,
        scheduler=None,
        minimum=0,
        interval=0.1,
        func=key_split,
        complete=False,
        **kwargs,
    ):
        super().__init__(keys, scheduler, func, interval, complete)
        from ipywidgets import VBox

        # Starts empty; the bars are created by ``make_widget`` on the first
        # response, once the groups are known.
        self.widget = VBox([])

    def make_widget(self, all):
        """Build one labelled bar per key of ``all`` and attach them to the widget.

        ``all`` maps each group name to its number of tasks; bars are
        ordered by descending ``(count, name)``.
        """
        from ipywidgets import HTML, FloatProgress, HBox, VBox

        self.elapsed_time = HTML("")
        self.bars = {key: FloatProgress(min=0, max=1, description="") for key in all}
        self.bar_texts = {key: HTML("") for key in all}
        self.bar_labels = {
            key: HTML(
                '<div style="padding: 0px 10px 0px 10px;'
                " text-align:left; word-wrap: "
                'break-word;">'
                + html.escape(key.decode() if isinstance(key, bytes) else key)
                + "</div>"
            )
            for key in all
        }

        def keyfunc(kv):
            """Order keys by most numerous, then by string name"""
            return kv[::-1]

        key_order = [k for k, v in sorted(all.items(), key=keyfunc, reverse=True)]

        self.bar_widgets = VBox(
            [
                HBox([self.bar_texts[key], self.bars[key], self.bar_labels[key]])
                for key in key_order
            ]
        )
        self.widget.children = (self.elapsed_time, self.bar_widgets)

    def _ipython_display_(self, **kwargs):
        """Schedule ``listen`` on the current IOLoop and display the widget."""
        IOLoop.current().add_callback(self.listen)
        return self.widget._ipython_display_(**kwargs)

    def _draw_stop(self, remaining, status, exception=None, key=None, **kwargs):
        """Colour every bar by outcome and show the final status line."""
        if not self.widget.children:
            # ``make_widget`` never ran (``_draw_bar`` skips it when there
            # are no keys), so there is no bar or status line to update.
            return

        for k, v in remaining.items():
            self.bars[k].bar_style = "danger" if v else "success"

        if status == "error":
            _, exception, _ = clean_exception(exception)
            # TODO: mark only the bar of the failing group as 'danger',
            # i.e. self.bars[self.func(key)]
            # The repr is inserted into HTML markup, so it must be escaped:
            # exception text may contain "<", ">" or "&".
            self.elapsed_time.value = (
                '<div style="padding: 0px 10px 5px 10px"><b>Exception</b> '
                + "<tt>"
                + html.escape(repr(exception))
                + "</tt>:"
                + format_time(self.elapsed)
                + " "
                + "</div>"
            )
        else:
            self.elapsed_time.value = (
                '<div style="padding: 0px 10px 5px 10px"><b>Finished:</b> '
                + format_time(self.elapsed)
                + "</div>"
            )

    def _draw_bar(self, remaining, all, status, **kwargs):
        """Create the bars on first use, then update each group's progress."""
        if self.keys and not self.widget.children:
            self.make_widget(all)
        if all:
            # One update per response is enough; it does not depend on the group.
            self.elapsed_time.value = (
                '<div style="padding: 0px 10px 5px 10px"><b>Computing:</b> '
                + format_time(self.elapsed)
                + "</div>"
            )
        for k, ntasks in all.items():
            ndone = ntasks - remaining[k]
            # A group without tasks is shown as full.
            self.bars[k].value = ndone / ntasks if ntasks else 1.0
            self.bar_texts[k].value = (
                '<div style="padding: 0px 10px 0px 10px; text-align: right">%d / %d</div>'
                % (ndone, ntasks)
            )
398
399
def progress(*futures, notebook=None, multi=True, complete=True, **kwargs):
    """Display the progress of a collection of futures

    The behaviour depends on the environment:

    *  Notebook:  a widget object is returned right away and updates itself
       once displayed
    *  Console:  a text bar is drawn and the call blocks until the
       computation completes

    Parameters
    ----------
    futures : Futures
        Futures or keys whose progress should be tracked
    notebook : bool (optional)
        Whether to use the notebook widgets; when None this is guessed with
        ``is_kernel()``
    multi : bool (optional)
        In the notebook, show one bar per function (True, the default)
        rather than a single bar
    complete : bool (optional)
        Track all keys (True, the default) or only keys that have not yet
        run (False)
    **kwargs
        Passed on to the progress bar class that is created

    Notes
    -----
    In the notebook the returned widget is only shown when it is the value
    of the cell, so call `progress` as the last statement of the cell.

    Examples
    --------
    >>> progress(futures)  # doctest: +SKIP
    [########################################] | 100% Completed |  1.7s
    """
    futures = futures_of(futures)
    if not isinstance(futures, (set, list)):
        futures = [futures]

    if notebook is None:
        # is_kernel() is a guess that is often, but not always, right
        notebook = is_kernel()

    if not notebook:
        # Console: the text bar is created for its side effect only.
        TextProgressBar(futures, complete=complete, **kwargs)
        return None

    widget_cls = MultiProgressWidget if multi else ProgressWidget
    return widget_cls(futures, complete=complete, **kwargs)
444