1import asyncio
2import logging
3import random
4from collections import defaultdict
5from functools import partial
6from itertools import cycle
8from tlz import concat, drop, groupby, merge
10import dask.config
11from dask.optimization import SubgraphCallable
12from dask.utils import parse_timedelta, stringify
14from .core import rpc
15from .utils import All
# Module-level logger used by the gather/retry helpers in this module.
logger = logging.getLogger(name=__name__)
async def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None):
    """Gather data directly from peers

    Keys are fetched in rounds. In each round every outstanding key is
    assigned to a randomly chosen worker that has not failed yet, and each
    worker is asked for all of its assigned keys at once. A worker that does
    not deliver a key it was asked for is excluded from later rounds. A key
    whose candidate workers are all excluded is given up on.

    Parameters
    ----------
    who_has: dict
        Dict mapping keys to sets of workers that may have that key
    rpc: callable
        Called as ``rpc(address)`` for every worker contacted in a round, and
        also passed through to ``get_data_from_worker``.
    close: bool
        Not used by this function; kept for backward compatibility.
    serializers: optional
        Forwarded to ``get_data_from_worker``.
    who: optional
        Forwarded to ``get_data_from_worker``.

    Returns
    -------
    results: dict
        Mapping of key to value for every key that was gathered.
    bad_keys: dict
        Mapping of key to the list of workers originally given for it in
        ``who_has``, for every key that could not be fetched from any of them.
    missing_workers: list
        Addresses of workers whose fetch raised ``OSError`` or ``ValueError``.

    See Also
    --------
    gather
    _gather
    """
    from .worker import get_data_from_worker

    bad_addresses = set()
    missing_workers = set()
    original_who_has = who_has
    who_has = {k: set(v) for k, v in who_has.items()}
    results = dict()
    all_bad_keys = set()

    # Every round either resolves all outstanding keys or adds at least one
    # address to ``bad_addresses``, so the loop terminates.
    while len(results) + len(all_bad_keys) < len(who_has):
        keys_by_worker = defaultdict(list)
        worker_by_key = dict()
        for key, addresses in who_has.items():
            if key in results:
                continue
            candidates = list(addresses - bad_addresses)
            if not candidates:
                # No untried holder left for this key; give up on it.
                all_bad_keys.add(key)
                continue
            addr = random.choice(candidates)
            keys_by_worker[addr].append(key)
            worker_by_key[key] = addr

        rpcs = {}
        tasks = {}
        response = {}
        try:
            # Created inside ``try`` so that objects built before a failing
            # ``rpc(addr)`` call are still closed in ``finally``.
            # NOTE(review): these objects are not used to issue requests here
            # (``rpc`` itself goes to get_data_from_worker); they are only
            # closed. Confirm whether they can be dropped.
            for addr in keys_by_worker:
                rpcs[addr] = rpc(addr)
            tasks = {
                address: asyncio.ensure_future(
                    get_data_from_worker(
                        rpc,
                        keys,
                        address,
                        who=who,
                        serializers=serializers,
                        max_connections=False,
                    )
                )
                for address, keys in keys_by_worker.items()
            }
            for worker, task in tasks.items():
                try:
                    r = await task
                except OSError:
                    missing_workers.add(worker)
                except ValueError as e:
                    logger.info(
                        "Got an unexpected error while collecting from workers: %s", e
                    )
                    missing_workers.add(worker)
                else:
                    response.update(r["data"])
        finally:
            # If an unexpected exception (or cancellation) escapes above, do
            # not leave the remaining fetches running unobserved.
            for task in tasks.values():
                if not task.done():
                    task.cancel()
            for r in rpcs.values():
                await r.close_rpc()

        # A worker that was asked for a key but did not return it is not
        # tried again for any key.
        bad_addresses |= {v for k, v in worker_by_key.items() if k not in response}
        results.update(response)

    bad_keys = {k: list(original_who_has[k]) for k in all_bad_keys}
    return (results, bad_keys, list(missing_workers))
class WrappedKey:
    """Interface for a key in a dask graph.

    A subclass must expose a ``.key`` attribute naming a key in a dask graph.

    Wrapping a key lets metadata travel alongside it. For example, we might
    know that the key lives on a particular machine, or that it can only be
    accessed in a certain way. Schedulers may have particular needs that can
    only be addressed by additional metadata.
    """

    def __init__(self, key):
        self.key = key

    def __repr__(self):
        # The key is rendered with str() and wrapped in single quotes,
        # regardless of its type.
        return "{}('{}')".format(type(self).__name__, self.key)
# Position in the round-robin worker cycle shared by successive calls to
# scatter_to_workers. It is a one-element list so the function can advance it
# without a ``global`` statement.
_round_robin_counter = [0]


async def scatter_to_workers(nthreads, data, rpc=rpc, report=True, serializers=None):
    """Scatter data directly to workers

    This distributes data in a round-robin fashion to a set of workers based on
    how many cores they have.  nthreads should be a dictionary mapping worker
    identities to numbers of cores.

    See scatter for parameter docstring

    Returns
    -------
    names: tuple
        The keys of ``data``, in the order they were distributed. Empty if
        ``data`` is empty.
    who_has: dict
        Mapping of key to the list of workers that received it.
    nbytes: dict
        The merged ``"nbytes"`` mappings returned by the workers'
        ``update_data`` calls.
    """
    assert isinstance(nthreads, dict)
    assert isinstance(data, dict)

    # Nothing to send. Without this guard, unpacking zip(*{}.items()) below
    # raises ValueError.
    if not data:
        return (), {}, {}

    # One slot per thread, so workers with more threads receive more keys.
    workers = list(concat([w] * nc for w, nc in nthreads.items()))
    names, data = list(zip(*data.items()))

    # Resume the cycle where the previous scatter stopped, then reserve the
    # slots used by this call.
    worker_iter = drop(_round_robin_counter[0] % len(workers), cycle(workers))
    _round_robin_counter[0] += len(data)

    assignments = list(zip(worker_iter, names, data))
    by_worker = groupby(0, assignments)
    by_worker = {
        worker: {key: value for _, key, value in v} for worker, v in by_worker.items()
    }

    rpcs = {}
    try:
        # Created inside ``try`` so that objects built before a failing
        # ``rpc(addr)`` call are still closed in ``finally``.
        for addr in by_worker:
            rpcs[addr] = rpc(addr)
        out = await All(
            [
                rpcs[address].update_data(
                    data=v, report=report, serializers=serializers
                )
                for address, v in by_worker.items()
            ]
        )
    finally:
        for r in rpcs.values():
            await r.close_rpc()

    nbytes = merge(o["nbytes"] for o in out)

    who_has = {k: [w for w, _, _ in v] for k, v in groupby(1, assignments).items()}

    return (names, who_has, nbytes)
# Container types that unpack_remotedata and pack_data rebuild element by
# element (tuples get extra SubgraphCallable handling in unpack_remotedata).
collection_types = (tuple, list, set, frozenset)


def unpack_remotedata(o, byte_keys=False, myset=None):
    """Replace every WrappedKey inside ``o`` by its key

    When called without ``myset``, returns a pair: the rebuilt structure and
    the set of WrappedKey objects that were found. When ``myset`` is given,
    found objects are added to it and only the rebuilt structure is returned.

    A tuple whose first element is a SubgraphCallable gets special treatment.
    WrappedKeys found in its inner graph or in its arguments are appended to
    the callable's ``inkeys``, and their keys are appended to the task's
    arguments.

    Examples
    --------
    >>> rd = WrappedKey('mykey')
    >>> unpack_remotedata(1)
    (1, set())
    >>> unpack_remotedata(())
    ((), set())
    >>> unpack_remotedata(rd)
    ('mykey', {WrappedKey('mykey')})
    >>> unpack_remotedata([1, rd])
    ([1, 'mykey'], {WrappedKey('mykey')})
    >>> unpack_remotedata({1: rd})
    ({1: 'mykey'}, {WrappedKey('mykey')})
    >>> unpack_remotedata({1: [rd]})
    ({1: ['mykey']}, {WrappedKey('mykey')})

    Use the ``byte_keys=True`` keyword to force string keys

    >>> rd = WrappedKey(('x', 1))
    >>> unpack_remotedata(rd, byte_keys=True)
    ("('x', 1)", {WrappedKey('('x', 1)')})
    """
    if myset is None:
        found = set()
        return unpack_remotedata(o, byte_keys, found), found

    kind = type(o)

    if kind is tuple:
        if not o:
            return o
        head = o[0]
        if type(head) is not SubgraphCallable:
            return tuple(unpack_remotedata(part, byte_keys, myset) for part in o)

        # Collect the wrapped keys of this subgraph separately, so we can tell
        # whether the callable needs to be rebuilt at all.
        inner = set()
        new_dsk = {
            name: unpack_remotedata(task, byte_keys, inner)
            for name, task in head.dsk.items()
        }
        new_args = tuple(unpack_remotedata(arg, byte_keys, inner) for arg in o[1:])
        if not inner:
            return o

        myset.update(inner)
        if byte_keys:
            extra_keys = tuple(stringify(wk.key) for wk in inner)
        else:
            extra_keys = tuple(wk.key for wk in inner)
        rebuilt = SubgraphCallable(
            new_dsk, head.outkey, head.inkeys + extra_keys, head.name
        )
        return (rebuilt,) + new_args + extra_keys

    if kind in collection_types:
        if not o:
            return o
        return kind([unpack_remotedata(part, byte_keys, myset) for part in o])

    if kind is dict:
        if not o:
            return o
        return {k: unpack_remotedata(v, byte_keys, myset) for k, v in o.items()}

    if issubclass(kind, WrappedKey):  # TODO use type is Future
        key = o.key
        if byte_keys:
            key = stringify(key)
        myset.add(o)
        return key

    return o
def pack_data(o, d, key_types=object):
    """Substitute known data for keys inside a (possibly nested) structure

    Parameters
    ----------
    o
        A literal, a key, or a container (any type in ``collection_types``,
        or a dict) built from literals and keys
    d : dict
        Mapping of keys to the data that should replace them
    key_types : type or tuple of types, optional
        Only objects that are instances of these types are looked up in ``d``

    Examples
    --------
    >>> data = {'x': 1}
    >>> pack_data(('x', 'y'), data)
    (1, 'y')
    >>> pack_data({'a': 'x', 'b': 'y'}, data)  # doctest: +SKIP
    {'a': 1, 'b': 'y'}
    >>> pack_data({'a': ['x'], 'b': 'y'}, data)  # doctest: +SKIP
    {'a': [1], 'b': 'y'}
    """
    try:
        if isinstance(o, key_types) and o in d:
            return d[o]
    except TypeError:
        # e.g. ``o`` is unhashable and therefore cannot be a key of ``d``.
        pass

    def recurse(item):
        return pack_data(item, d, key_types=key_types)

    container = type(o)
    if container in collection_types:
        return container([recurse(item) for item in o])
    if container is dict:
        return {k: recurse(v) for k, v in o.items()}
    return o
def subs_multiple(o, d):
    """Substitute values for keys throughout tasks

    Lists and dicts are rebuilt recursively. A tuple whose first element is
    callable is treated as a task: its callable is kept and its arguments are
    substituted. Anything else is looked up in ``d`` and replaced if present.
    Unhashable leftovers are returned unchanged.

    Parameters
    ----------
    o
        Core data structures containing literals and keys
    d : dict
        Mapping of keys to values

    Examples
    --------
    >>> dsk = {"a": (sum, ["x", 2])}
    >>> data = {"x": 1}
    >>> subs_multiple(dsk, data)  # doctest: +SKIP
    {'a': (sum, [1, 2])}

    """
    kind = type(o)
    if kind is dict:
        return {key: subs_multiple(val, d) for key, val in o.items()}
    if kind is list:
        return [subs_multiple(item, d) for item in o]
    if kind is tuple and o and callable(o[0]):  # istask(o)
        func, *args = o
        return (func, *(subs_multiple(arg, d) for arg in args))
    try:
        return d.get(o, o)
    except TypeError:
        # Unhashable values (sets, ...) cannot be keys; leave them alone.
        return o
async def retry(
    coro,
    count,
    delay_min,
    delay_max,
    jitter_fraction=0.1,
    retry_on_exceptions=(EnvironmentError, IOError),
    operation=None,
):
    """
    Return the result of ``await coro()``, re-trying in case of exceptions

    The delay between attempts is ``delay_min * (2 ** i - 1)`` where ``i`` enumerates the attempt that just failed
    (starting at 0), but never larger than ``delay_max``.
    This yields no delay between the first and second attempt, then ``delay_min``, ``3 * delay_min``, etc.
    (The reason to re-try with no delay is that in most cases this is sufficient and will thus recover faster
    from a communication failure).

    Parameters
    ----------
    coro
        The coroutine function to call and await
    count
        The maximum number of re-tries before giving up. Values <= 0 mean no re-try (a single attempt).
    delay_min
        The base factor for the delay (in seconds); this is the first non-zero delay between re-tries.
    delay_max
        The maximum delay (in seconds) between consecutive re-tries (without jitter)
    jitter_fraction
        The maximum jitter to add to the delay, as fraction of the total delay. No jitter is added if this
        value is <= 0.
        Using a non-zero value here avoids "herd effects" of many operations re-tried at the same time
    retry_on_exceptions
        A tuple of exception classes to retry. Other exceptions are not caught and re-tried, but propagate immediately.
    operation
        A human-readable description of the operation attempted; used only for logging failures

    Returns
    -------
    Any
        Whatever ``await coro()`` returned

    Raises
    ------
    Exception
        Any exception raised by the final attempt, or by an earlier attempt if
        it is not listed in ``retry_on_exceptions``.
    """
    # range(count) is empty when count <= 0, so we go straight to the final
    # attempt below.
    for i_try in range(count):
        try:
            return await coro()
        except retry_on_exceptions as ex:
            operation = operation or str(coro)
            logger.info(
                "Retrying %s after exception in attempt %s/%s: %s",
                operation,
                i_try,
                count,
                ex,
            )
            delay = min(delay_min * (2**i_try - 1), delay_max)
            if jitter_fraction > 0:
                delay *= 1 + random.random() * jitter_fraction
            await asyncio.sleep(delay)
    # Final attempt: exceptions are no longer caught and reach the caller.
    return await coro()
async def retry_operation(coro, *args, operation=None, **kwargs):
    """
    Await ``coro(*args, **kwargs)`` through ``retry``, configured from dask config.

    The number of re-tries comes from ``distributed.comm.retry.count``. The
    delay bounds come from ``distributed.comm.retry.delay.min`` and
    ``distributed.comm.retry.delay.max``, parsed with ``parse_timedelta``
    using seconds as the default unit. ``operation`` is handed to ``retry``
    for its log messages.
    """
    count = dask.config.get("distributed.comm.retry.count")
    delay_min, delay_max = (
        parse_timedelta(dask.config.get(setting), default="s")
        for setting in (
            "distributed.comm.retry.delay.min",
            "distributed.comm.retry.delay.max",
        )
    )
    attempt = partial(coro, *args, **kwargs)
    return await retry(
        attempt,
        count=count,
        delay_min=delay_min,
        delay_max=delay_max,
        operation=operation,
    )