import asyncio
import logging
import random
from collections import defaultdict
from functools import partial
from itertools import cycle

from tlz import concat, drop, groupby, merge

import dask.config
from dask.optimization import SubgraphCallable
from dask.utils import parse_timedelta, stringify

from .core import rpc
from .utils import All

logger = logging.getLogger(__name__)


async def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None):
    """Gather data directly from peers

    Parameters
    ----------
    who_has: dict
        Dict mapping keys to sets of workers that may have that key
    rpc: callable
        Function that, given a worker address, returns an RPC object for
        communicating with that worker

    Returns
    -------
    Tuple of
        - dict mapping key to gathered value
        - dict mapping key to the workers listed for it in ``who_has``, for
          keys that could not be gathered from any worker
        - list of workers that failed to respond

    See Also
    --------
    gather
    _gather
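
    Examples
    --------
    A sketch of a typical call; the worker address is illustrative:

    >>> data, missing_keys, missing_workers = await gather_from_workers(
    ...     {"x": {"tcp://127.0.0.1:34567"}}, rpc=rpc
    ... )  # doctest: +SKIP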
35    """
36    from .worker import get_data_from_worker
37
38    bad_addresses = set()
39    missing_workers = set()
40    original_who_has = who_has
41    who_has = {k: set(v) for k, v in who_has.items()}
42    results = dict()
43    all_bad_keys = set()
44
45    while len(results) + len(all_bad_keys) < len(who_has):
46        d = defaultdict(list)
47        rev = dict()
48        bad_keys = set()
49        for key, addresses in who_has.items():
50            if key in results:
51                continue
52            try:
53                addr = random.choice(list(addresses - bad_addresses))
54                d[addr].append(key)
55                rev[key] = addr
56            except IndexError:
57                bad_keys.add(key)
58        if bad_keys:
59            all_bad_keys |= bad_keys
60
61        rpcs = {addr: rpc(addr) for addr in d}
62        try:
63            coroutines = {
64                address: asyncio.ensure_future(
65                    get_data_from_worker(
66                        rpc,
67                        keys,
68                        address,
69                        who=who,
70                        serializers=serializers,
71                        max_connections=False,
72                    )
73                )
74                for address, keys in d.items()
75            }
76            response = {}
77            for worker, c in coroutines.items():
78                try:
79                    r = await c
80                except OSError:
81                    missing_workers.add(worker)
82                except ValueError as e:
83                    logger.info(
84                        "Got an unexpected error while collecting from workers: %s", e
85                    )
86                    missing_workers.add(worker)
87                else:
88                    response.update(r["data"])
89        finally:
90            for r in rpcs.values():
91                await r.close_rpc()
92
93        bad_addresses |= {v for k, v in rev.items() if k not in response}
94        results.update(response)
95
96    bad_keys = {k: list(original_who_has[k]) for k in all_bad_keys}
97    return (results, bad_keys, list(missing_workers))
98
99
100class WrappedKey:
101    """Interface for a key in a dask graph.
102
103    Subclasses must have .key attribute that refers to a key in a dask graph.
104
105    Sometimes we want to associate metadata to keys in a dask graph.  For
106    example we might know that that key lives on a particular machine or can
107    only be accessed in a certain way.  Schedulers may have particular needs
108    that can only be addressed by additional metadata.
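
    Examples
    --------
    A minimal illustration of the ``.key`` attribute and ``repr``:

    >>> WrappedKey("x").key
    'x'
    >>> WrappedKey("x")
    WrappedKey('x')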
109    """
110
111    def __init__(self, key):
112        self.key = key
113
114    def __repr__(self):
115        return f"{type(self).__name__}('{self.key}')"
116
117
118_round_robin_counter = [0]
119
120
121async def scatter_to_workers(nthreads, data, rpc=rpc, report=True, serializers=None):
122    """Scatter data directly to workers
123
124    This distributes data in a round-robin fashion to a set of workers based on
125    how many cores they have.  nthreads should be a dictionary mapping worker
126    identities to numbers of cores.
127
128    See scatter for parameter docstring
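
    Examples
    --------
    A sketch of a typical call; the worker address and thread count are
    illustrative:

    >>> names, who_has, nbytes = await scatter_to_workers(
    ...     {"tcp://127.0.0.1:34567": 4}, {"x": 1, "y": 2}
    ... )  # doctest: +SKIP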
129    """
130    assert isinstance(nthreads, dict)
131    assert isinstance(data, dict)
132
133    workers = list(concat([w] * nc for w, nc in nthreads.items()))
134    names, data = list(zip(*data.items()))
135
136    worker_iter = drop(_round_robin_counter[0] % len(workers), cycle(workers))
137    _round_robin_counter[0] += len(data)
138
139    L = list(zip(worker_iter, names, data))
140    d = groupby(0, L)
141    d = {worker: {key: value for _, key, value in v} for worker, v in d.items()}
142
143    rpcs = {addr: rpc(addr) for addr in d}
144    try:
145        out = await All(
146            [
147                rpcs[address].update_data(
148                    data=v, report=report, serializers=serializers
149                )
150                for address, v in d.items()
151            ]
152        )
153    finally:
154        for r in rpcs.values():
155            await r.close_rpc()
156
157    nbytes = merge(o["nbytes"] for o in out)
158
159    who_has = {k: [w for w, _, _ in v] for k, v in groupby(1, L).items()}
160
161    return (names, who_has, nbytes)
162
163
164collection_types = (tuple, list, set, frozenset)
165
166
167def unpack_remotedata(o, byte_keys=False, myset=None):
168    """Unpack WrappedKey objects from collection
169
170    Returns original collection and set of all found WrappedKey objects
171
172    Examples
173    --------
174    >>> rd = WrappedKey('mykey')
175    >>> unpack_remotedata(1)
176    (1, set())
177    >>> unpack_remotedata(())
178    ((), set())
179    >>> unpack_remotedata(rd)
180    ('mykey', {WrappedKey('mykey')})
181    >>> unpack_remotedata([1, rd])
182    ([1, 'mykey'], {WrappedKey('mykey')})
183    >>> unpack_remotedata({1: rd})
184    ({1: 'mykey'}, {WrappedKey('mykey')})
185    >>> unpack_remotedata({1: [rd]})
186    ({1: ['mykey']}, {WrappedKey('mykey')})
187
188    Use the ``byte_keys=True`` keyword to force string keys
189
190    >>> rd = WrappedKey(('x', 1))
191    >>> unpack_remotedata(rd, byte_keys=True)
192    ("('x', 1)", {WrappedKey('('x', 1)')})
193    """
194    if myset is None:
195        myset = set()
196        out = unpack_remotedata(o, byte_keys, myset)
197        return out, myset
198
199    typ = type(o)
200
201    if typ is tuple:
202        if not o:
203            return o
        if type(o[0]) is SubgraphCallable:
            # Unpack futures hidden inside a SubgraphCallable: rewrite its
            # task graph and pass the futures' keys in as extra inputs
            sc = o[0]
            futures = set()
            dsk = {
                k: unpack_remotedata(v, byte_keys, futures) for k, v in sc.dsk.items()
            }
            args = tuple(unpack_remotedata(i, byte_keys, futures) for i in o[1:])
            if futures:
                myset.update(futures)
                futures = (
                    tuple(stringify(f.key) for f in futures)
                    if byte_keys
                    else tuple(f.key for f in futures)
                )
                inkeys = sc.inkeys + futures
                return (
                    (SubgraphCallable(dsk, sc.outkey, inkeys, sc.name),)
                    + args
                    + futures
                )
            else:
                return o
        else:
            return tuple(unpack_remotedata(item, byte_keys, myset) for item in o)
    if typ in collection_types:
        if not o:
            return o
        outs = [unpack_remotedata(item, byte_keys, myset) for item in o]
        return typ(outs)
    elif typ is dict:
        if o:
            return {k: unpack_remotedata(v, byte_keys, myset) for k, v in o.items()}
        else:
            return o
    elif issubclass(typ, WrappedKey):  # TODO use type is Future
        k = o.key
        if byte_keys:
            k = stringify(k)
        myset.add(o)
        return k
    else:
        return o


def pack_data(o, d, key_types=object):
    """Merge known data into tuple or dict

    Parameters
    ----------
    o
        core data structures containing literals and keys
    d : dict
        mapping of keys to data

    Examples
    --------
    >>> data = {'x': 1}
    >>> pack_data(('x', 'y'), data)
    (1, 'y')
    >>> pack_data({'a': 'x', 'b': 'y'}, data)  # doctest: +SKIP
    {'a': 1, 'b': 'y'}
    >>> pack_data({'a': ['x'], 'b': 'y'}, data)  # doctest: +SKIP
    {'a': [1], 'b': 'y'}
    """
    typ = type(o)
    try:
        if isinstance(o, key_types) and o in d:
            return d[o]
    except TypeError:
        pass

    if typ in collection_types:
        return typ([pack_data(x, d, key_types=key_types) for x in o])
    elif typ is dict:
        return {k: pack_data(v, d, key_types=key_types) for k, v in o.items()}
    else:
        return o


def subs_multiple(o, d):
    """Perform substitutions on a task

    Parameters
    ----------
    o
        Core data structures containing literals and keys
    d : dict
        Mapping of keys to values

    Examples
    --------
    >>> dsk = {"a": (sum, ["x", 2])}
    >>> data = {"x": 1}
    >>> subs_multiple(dsk, data)  # doctest: +SKIP
    {'a': (sum, [1, 2])}

    """
    typ = type(o)
    if typ is tuple and o and callable(o[0]):  # istask(o)
        return (o[0],) + tuple(subs_multiple(i, d) for i in o[1:])
    elif typ is list:
        return [subs_multiple(i, d) for i in o]
    elif typ is dict:
        return {k: subs_multiple(v, d) for (k, v) in o.items()}
    else:
        try:
            return d.get(o, o)
        except TypeError:
            return o


async def retry(
    coro,
    count,
    delay_min,
    delay_max,
    jitter_fraction=0.1,
    retry_on_exceptions=(EnvironmentError, IOError),
    operation=None,
):
    """
    Return the result of ``await coro()``, re-trying in case of exceptions

    The delay between attempts is ``delay_min * (2 ** i - 1)`` where ``i`` enumerates the attempt that just failed
    (starting at 0), but never larger than ``delay_max``.
    This yields no delay between the first and second attempt, then ``delay_min``, ``3 * delay_min``, etc.
    (The reason to re-try with no delay is that in most cases this is sufficient and will thus recover faster
    from a communication failure.)

    Parameters
    ----------
    coro
        The coroutine function to call and await
    count
        The maximum number of re-tries before giving up. 0 means no re-try; must be >= 0.
    delay_min
        The base factor for the delay (in seconds); this is the first non-zero delay between re-tries.
    delay_max
        The maximum delay (in seconds) between consecutive re-tries (without jitter)
    jitter_fraction
        The maximum jitter to add to the delay, as a fraction of the total delay. No jitter is added if this
        value is <= 0.
        Using a non-zero value here avoids "herd effects" of many operations re-tried at the same time
    retry_on_exceptions
        A tuple of exception classes to retry on. Other exceptions are not caught and re-tried, but propagate immediately.
    operation
        A human-readable description of the operation attempted; used only for logging failures

    Returns
    -------
    Any
        Whatever ``await coro()`` returned
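
    Examples
    --------
    A sketch of a typical call; ``fetch_data`` stands in for any coroutine
    function that may fail transiently:

    >>> result = await retry(
    ...     fetch_data, count=3, delay_min=0.5, delay_max=5
    ... )  # doctest: +SKIP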
356    """
357    # this loop is a no-op in case max_retries<=0
358    for i_try in range(count):
359        try:
360            return await coro()
361        except retry_on_exceptions as ex:
362            operation = operation or str(coro)
363            logger.info(
364                f"Retrying {operation} after exception in attempt {i_try}/{count}: {ex}"
365            )
366            delay = min(delay_min * (2 ** i_try - 1), delay_max)
367            if jitter_fraction > 0:
368                delay *= 1 + random.random() * jitter_fraction
369            await asyncio.sleep(delay)
370    return await coro()
371
372
373async def retry_operation(coro, *args, operation=None, **kwargs):
374    """
375    Retry an operation using the configuration values for the retry parameters
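
    Examples
    --------
    A sketch of a typical call; ``fetch_data`` stands in for any coroutine
    function, and ``"x"`` is an illustrative positional argument:

    >>> result = await retry_operation(
    ...     fetch_data, "x", operation="fetch_data"
    ... )  # doctest: +SKIP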
376    """
377
378    retry_count = dask.config.get("distributed.comm.retry.count")
379    retry_delay_min = parse_timedelta(
380        dask.config.get("distributed.comm.retry.delay.min"), default="s"
381    )
382    retry_delay_max = parse_timedelta(
383        dask.config.get("distributed.comm.retry.delay.max"), default="s"
384    )
385    return await retry(
386        partial(coro, *args, **kwargs),
387        count=retry_count,
388        delay_min=retry_delay_min,
389        delay_max=retry_delay_max,
390        operation=operation,
391    )
392