"""
Thin wrappers around common functions.

Subpackages contain potentially unstable extensions.
"""
import sys
from functools import wraps

from ..auto import tqdm as tqdm_auto
from ..std import tqdm
from ..utils import ObjectWrapper

__author__ = {"github.com/": ["casperdcl"]}
__all__ = ['tenumerate', 'tzip', 'tmap']


class DummyTqdmFile(ObjectWrapper):
    """
    File-like stand-in whose output is routed through `tqdm.write`.

    Written data is accumulated in `_buf` until a newline shows up; at that
    point everything up to and including the last newline is emitted to the
    wrapped object via `tqdm.write`, and the remainder stays buffered.
    """

    def __init__(self, wrapped):
        super(DummyTqdmFile, self).__init__(wrapped)
        self._buf = []

    def write(self, x, nolock=False):
        """
        Buffer `x`, emitting complete lines through `tqdm.write`.

        Parameters
        ----------
        x  : str or bytes
            Data to write. The newline and joiner used match its type.
        nolock  : bool, optional
            Passed straight through to `tqdm.write`.
        """
        newline = b"\n" if isinstance(x, bytes) else "\n"
        head, found, tail = x.rpartition(newline)
        if not found:
            # No newline yet: keep accumulating.
            self._buf.append(x)
            return
        # Empty str/bytes of the same type as `x`, used as joiner and `end`
        # so that nothing extra is appended to the output.
        empty = type(newline)()
        tqdm.write(empty.join(self._buf + [head, found]),
                   end=empty, file=self._wrapped, nolock=nolock)
        # Whatever followed the last newline (possibly empty) starts the
        # next buffer.
        self._buf = [tail]

    def __del__(self):
        """Emit whatever is still buffered; `OSError`/`ValueError` are ignored."""
        pending = self._buf
        if not pending:
            return
        empty = type(pending[0])()
        try:
            tqdm.write(empty.join(pending), end=empty, file=self._wrapped)
        except (OSError, ValueError):
            # Best-effort flush -- presumably the wrapped stream may already
            # be closed/unusable at this point (TODO confirm).
            pass


def builtin_iterable(func):
    """
    Decorator: on py2 make `func()` return a `list()` of its output.

    On py3 `func` is returned untouched.
    """
    if sys.version_info[:1] >= (3,):
        return func

    @wraps(func)
    def listified(*args, **kwargs):
        return list(func(*args, **kwargs))
    return listified


def tenumerate(iterable, start=0, total=None, tqdm_class=tqdm_auto, **tqdm_kwargs):
    """
    Equivalent of `numpy.ndenumerate` or builtin `enumerate`.

    If `numpy` is importable and `iterable` is a `numpy.ndarray`, returns
    `tqdm_class` wrapped around `numpy.ndenumerate(iterable)`; `start` is not
    used on that path. Otherwise returns `enumerate` over
    `tqdm_class(iterable)`, counting from `start`.

    Parameters
    ----------
    start  : int, optional
        First index passed to `enumerate` [default: 0].
    total  : int, optional
        Passed to `tqdm_class`. For arrays, `iterable.size` is used
        when `total` is falsy.
    tqdm_class  : [default: tqdm.auto.tqdm].
    **tqdm_kwargs
        Passed on to `tqdm_class`.
    """
    try:
        import numpy as np
    except ImportError:
        np = None

    if np is not None and isinstance(iterable, np.ndarray):
        indexed = np.ndenumerate(iterable)
        return tqdm_class(indexed, total=total or iterable.size,
                          **tqdm_kwargs)

    progress = tqdm_class(iterable, total=total, **tqdm_kwargs)
    return enumerate(progress, start)


@builtin_iterable
def tzip(iter1, *iter2plus, **tqdm_kwargs):
    """
    Equivalent of builtin `zip`.

    Only `iter1` is wrapped in a progress bar; the remaining iterables are
    handed to `zip` as they are.

    Parameters
    ----------
    tqdm_class  : [default: tqdm.auto.tqdm].
    """
    # Work on a copy so `tqdm_class` can be removed before forwarding.
    options = dict(tqdm_kwargs)
    bar_factory = options.pop("tqdm_class", tqdm_auto)
    for item in zip(bar_factory(iter1, **options), *iter2plus):
        yield item


@builtin_iterable
def tmap(function, *sequences, **tqdm_kwargs):
    """
    Equivalent of builtin `map`.

    `sequences` and `tqdm_kwargs` are forwarded to `tzip`; `function` is
    applied to each resulting tuple, unpacked as positional arguments.

    Parameters
    ----------
    tqdm_class  : [default: tqdm.auto.tqdm].
    """
    for args in tzip(*sequences, **tqdm_kwargs):
        yield function(*args)