from __future__ import absolute_import

from functools import partial

from dask.callbacks import Callback

from .auto import tqdm as tqdm_auto

__author__ = {"github.com/": ["casperdcl"]}
__all__ = ['TqdmCallback']


class TqdmCallback(Callback):
    """Dask callback for task progress."""
    def __init__(self, start=None, pretask=None, tqdm_class=tqdm_auto,
                 **tqdm_kwargs):
        """
        Parameters
        ----------
        start : optional
            Forwarded unchanged to `dask.callbacks.Callback.__init__`.
        pretask : optional
            Forwarded unchanged to `dask.callbacks.Callback.__init__`.
        tqdm_class : optional
            `tqdm` class to use for bars [default: `tqdm.auto.tqdm`].
        tqdm_kwargs : optional
            Any other arguments used for all bars.
        """
        super(TqdmCallback, self).__init__(start=start, pretask=pretask)
        if tqdm_kwargs:
            # Bind the shared keyword arguments once so that every bar
            # created later through `self.tqdm_class(...)` receives them.
            tqdm_class = partial(tqdm_class, **tqdm_kwargs)
        self.tqdm_class = tqdm_class

    def _start_state(self, _, state):
        """Create the progress bar sized to the number of tracked tasks.

        The first positional argument is ignored. `state` must provide
        sized entries under the keys 'ready', 'waiting', 'running' and
        'finished'; the bar total is the sum of their lengths.
        """
        self.pbar = self.tqdm_class(total=sum(
            len(state[k]) for k in ['ready', 'waiting', 'running', 'finished']))

    def _posttask(self, *_, **__):
        """Advance the progress bar by one step; all arguments are ignored."""
        self.pbar.update()

    def _finish(self, *_, **__):
        """Close the progress bar; all arguments are ignored."""
        self.pbar.close()

    def display(self):
        """Displays in the current cell in Notebooks."""
        # The bar is stored as `self.pbar` (see `_start_state`); it does not
        # exist until the computation has started, and only bars exposing a
        # `container` attribute have something to display. In either case
        # there is nothing to show, so return quietly.
        container = getattr(getattr(self, 'pbar', None), 'container', None)
        if container is None:
            return
        from .notebook import display
        display(container)