import asyncio
import logging
import random
from collections import defaultdict
from functools import partial
from itertools import cycle

from tlz import concat, drop, groupby, merge

import dask.config
from dask.optimization import SubgraphCallable
from dask.utils import parse_timedelta, stringify

from .core import rpc
from .utils import All

logger = logging.getLogger(__name__)


async def gather_from_workers(who_has, rpc, close=True, serializers=None, who=None):
    """Gather data directly from peers

    Parameters
    ----------
    who_has: dict
        Dict mapping keys to sets of workers that may have that key
    rpc: callable
        Function that takes a worker address and returns an RPC object to it

    Returns
    -------
    tuple
        A 3-tuple ``(data, bad_keys, missing_workers)``: ``data`` maps key to
        value, ``bad_keys`` maps keys that could not be gathered to the workers
        originally listed as holding them, and ``missing_workers`` is a list of
        workers that could not be contacted.

    See Also
    --------
    gather
    _gather
    """
    from .worker import get_data_from_worker

    bad_addresses = set()
    missing_workers = set()
    original_who_has = who_has
    who_has = {k: set(v) for k, v in who_has.items()}
    results = dict()
    all_bad_keys = set()

    while len(results) + len(all_bad_keys) < len(who_has):
        d = defaultdict(list)
        rev = dict()
        bad_keys = set()
        for key, addresses in who_has.items():
            if key in results:
                continue
            try:
                addr = random.choice(list(addresses - bad_addresses))
                d[addr].append(key)
                rev[key] = addr
            except IndexError:
                bad_keys.add(key)
        if bad_keys:
            all_bad_keys |= bad_keys

        rpcs = {addr: rpc(addr) for addr in d}
        try:
            coroutines = {
                address: asyncio.ensure_future(
                    get_data_from_worker(
                        rpc,
                        keys,
                        address,
                        who=who,
                        serializers=serializers,
                        max_connections=False,
                    )
                )
                for address, keys in d.items()
            }
            response = {}
            for worker, c in coroutines.items():
                try:
                    r = await c
                except OSError:
                    missing_workers.add(worker)
                except ValueError as e:
                    logger.info(
                        "Got an unexpected error while collecting from workers: %s", e
                    )
                    missing_workers.add(worker)
                else:
                    response.update(r["data"])
        finally:
            for r in rpcs.values():
                await r.close_rpc()

        bad_addresses |= {v for k, v in rev.items() if k not in response}
        results.update(response)

    bad_keys = {k: list(original_who_has[k]) for k in all_bad_keys}
    return (results, bad_keys, list(missing_workers))


class WrappedKey:
    """Interface for a key in a dask graph.

    Subclasses must have a .key attribute that refers to a key in a dask graph.

    Sometimes we want to associate metadata to keys in a dask graph. For
    example we might know that a key lives on a particular machine or can
    only be accessed in a certain way. Schedulers may have particular needs
    that can only be addressed by additional metadata.
    """

    def __init__(self, key):
        self.key = key

    def __repr__(self):
        return f"{type(self).__name__}('{self.key}')"


_round_robin_counter = [0]


async def scatter_to_workers(nthreads, data, rpc=rpc, report=True, serializers=None):
    """Scatter data directly to workers

    This distributes data in a round-robin fashion to a set of workers based on
    how many cores they have. nthreads should be a dictionary mapping worker
    identities to numbers of cores.

    See scatter for parameter docstring
    """
    assert isinstance(nthreads, dict)
    assert isinstance(data, dict)

    workers = list(concat([w] * nc for w, nc in nthreads.items()))
    names, data = list(zip(*data.items()))

    worker_iter = drop(_round_robin_counter[0] % len(workers), cycle(workers))
    _round_robin_counter[0] += len(data)

    L = list(zip(worker_iter, names, data))
    d = groupby(0, L)
    d = {worker: {key: value for _, key, value in v} for worker, v in d.items()}

    rpcs = {addr: rpc(addr) for addr in d}
    try:
        out = await All(
            [
                rpcs[address].update_data(
                    data=v, report=report, serializers=serializers
                )
                for address, v in d.items()
            ]
        )
    finally:
        for r in rpcs.values():
            await r.close_rpc()

    nbytes = merge(o["nbytes"] for o in out)

    who_has = {k: [w for w, _, _ in v] for k, v in groupby(1, L).items()}

    return (names, who_has, nbytes)
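

# Illustrative sketch of the round-robin bookkeeping above, using hypothetical
# worker addresses and pure tlz/itertools calls only (kept in comments so that
# nothing runs at import time):
#
#   >>> nthreads = {"tcp://a:1234": 1, "tcp://b:1234": 2}
#   >>> workers = list(concat([w] * nc for w, nc in nthreads.items()))
#   >>> workers                                      # one slot per thread
#   ['tcp://a:1234', 'tcp://b:1234', 'tcp://b:1234']
#   >>> L = list(zip(cycle(workers), ["x", "y", "z"], [1, 2, 3]))
#   >>> {w: {k: v for _, k, v in grp} for w, grp in groupby(0, L).items()}
#   {'tcp://a:1234': {'x': 1}, 'tcp://b:1234': {'y': 2, 'z': 3}}
#
# A worker with twice the threads receives roughly twice as many keys, and
# _round_robin_counter keeps successive scatter calls from always starting the
# cycle at the same worker.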


collection_types = (tuple, list, set, frozenset)


def unpack_remotedata(o, byte_keys=False, myset=None):
    """Unpack WrappedKey objects from collection

    Returns original collection and set of all found WrappedKey objects

    Examples
    --------
    >>> rd = WrappedKey('mykey')
    >>> unpack_remotedata(1)
    (1, set())
    >>> unpack_remotedata(())
    ((), set())
    >>> unpack_remotedata(rd)
    ('mykey', {WrappedKey('mykey')})
    >>> unpack_remotedata([1, rd])
    ([1, 'mykey'], {WrappedKey('mykey')})
    >>> unpack_remotedata({1: rd})
    ({1: 'mykey'}, {WrappedKey('mykey')})
    >>> unpack_remotedata({1: [rd]})
    ({1: ['mykey']}, {WrappedKey('mykey')})

    Use the ``byte_keys=True`` keyword to force string keys

    >>> rd = WrappedKey(('x', 1))
    >>> unpack_remotedata(rd, byte_keys=True)
    ("('x', 1)", {WrappedKey('('x', 1)')})
    """
    if myset is None:
        myset = set()
        out = unpack_remotedata(o, byte_keys, myset)
        return out, myset

    typ = type(o)

    if typ is tuple:
        if not o:
            return o
        if type(o[0]) is SubgraphCallable:
            sc = o[0]
            futures = set()
            dsk = {
                k: unpack_remotedata(v, byte_keys, futures) for k, v in sc.dsk.items()
            }
            args = tuple(unpack_remotedata(i, byte_keys, futures) for i in o[1:])
            if futures:
                myset.update(futures)
                futures = (
                    tuple(stringify(f.key) for f in futures)
                    if byte_keys
                    else tuple(f.key for f in futures)
                )
                inkeys = sc.inkeys + futures
                return (
                    (SubgraphCallable(dsk, sc.outkey, inkeys, sc.name),)
                    + args
                    + futures
                )
            else:
                return o
        else:
            return tuple(unpack_remotedata(item, byte_keys, myset) for item in o)
    if typ in collection_types:
        if not o:
            return o
        outs = [unpack_remotedata(item, byte_keys, myset) for item in o]
        return typ(outs)
    elif typ is dict:
        if o:
            return {k: unpack_remotedata(v, byte_keys, myset) for k, v in o.items()}
        else:
            return o
    elif issubclass(typ, WrappedKey):  # TODO use type is Future
        k = o.key
        if byte_keys:
            k = stringify(k)
        myset.add(o)
        return k
    else:
        return o


def pack_data(o, d, key_types=object):
    """Merge known data into tuple or dict

    Parameters
    ----------
    o
        Core data structures containing literals and keys
    d : dict
        Mapping of keys to data

    Examples
    --------
    >>> data = {'x': 1}
    >>> pack_data(('x', 'y'), data)
    (1, 'y')
    >>> pack_data({'a': 'x', 'b': 'y'}, data)  # doctest: +SKIP
    {'a': 1, 'b': 'y'}
    >>> pack_data({'a': ['x'], 'b': 'y'}, data)  # doctest: +SKIP
    {'a': [1], 'b': 'y'}
    """
    typ = type(o)
    try:
        if isinstance(o, key_types) and o in d:
            return d[o]
    except TypeError:
        pass

    if typ in collection_types:
        return typ([pack_data(x, d, key_types=key_types) for x in o])
    elif typ is dict:
        return {k: pack_data(v, d, key_types=key_types) for k, v in o.items()}
    else:
        return o
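

# A hedged illustration of how unpack_remotedata and pack_data complement each
# other: wrapped keys are stripped out of a task before it is shipped, and the
# gathered values are substituted back in afterwards.  The key "mykey" and the
# placeholder callable name "add" are arbitrary examples, not real API objects.
#
#   >>> rd = WrappedKey("mykey")
#   >>> task, futures = unpack_remotedata(("add", [1, rd]))
#   >>> task
#   ('add', [1, 'mykey'])
#   >>> futures
#   {WrappedKey('mykey')}
#   >>> pack_data(task, {"mykey": 2})
#   ('add', [1, 2])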


def subs_multiple(o, d):
    """Perform substitutions on a task or graph

    Parameters
    ----------
    o
        Core data structures containing literals and keys
    d : dict
        Mapping of keys to values

    Examples
    --------
    >>> dsk = {"a": (sum, ["x", 2])}
    >>> data = {"x": 1}
    >>> subs_multiple(dsk, data)  # doctest: +SKIP
    {'a': (sum, [1, 2])}

    """
    typ = type(o)
    if typ is tuple and o and callable(o[0]):  # istask(o)
        return (o[0],) + tuple(subs_multiple(i, d) for i in o[1:])
    elif typ is list:
        return [subs_multiple(i, d) for i in o]
    elif typ is dict:
        return {k: subs_multiple(v, d) for (k, v) in o.items()}
    else:
        try:
            return d.get(o, o)
        except TypeError:
            return o


async def retry(
    coro,
    count,
    delay_min,
    delay_max,
    jitter_fraction=0.1,
    retry_on_exceptions=(EnvironmentError, IOError),
    operation=None,
):
    """
    Return the result of ``await coro()``, re-trying in case of exceptions

    The delay between attempts is ``delay_min * (2 ** i - 1)`` where ``i`` enumerates the attempt that just failed
    (starting at 0), but never larger than ``delay_max``.
    This yields no delay between the first and second attempt, then ``delay_min``, ``3 * delay_min``, etc.
    (The reason to re-try with no delay is that in most cases this is sufficient and will thus recover faster
    from a communication failure.)

    Parameters
    ----------
    coro
        The coroutine function to call and await
    count
        The maximum number of re-tries before giving up. 0 means no re-try; must be >= 0.
    delay_min
        The base factor for the delay (in seconds); this is the first non-zero delay between re-tries.
    delay_max
        The maximum delay (in seconds) between consecutive re-tries (without jitter)
    jitter_fraction
        The maximum jitter to add to the delay, as a fraction of the total delay. No jitter is added if this
        value is <= 0.
        Using a non-zero value here avoids "herd effects" of many operations being re-tried at the same time.
    retry_on_exceptions
        A tuple of exception classes to retry on. Other exceptions are not caught and re-tried, but propagate immediately.
    operation
        A human-readable description of the operation attempted; used only for logging failures

    Returns
    -------
    Any
        Whatever ``await coro()`` returned
    """
    # This loop is a no-op if count <= 0
    for i_try in range(count):
        try:
            return await coro()
        except retry_on_exceptions as ex:
            operation = operation or str(coro)
            logger.info(
                f"Retrying {operation} after exception in attempt {i_try}/{count}: {ex}"
            )
            delay = min(delay_min * (2 ** i_try - 1), delay_max)
            if jitter_fraction > 0:
                delay *= 1 + random.random() * jitter_fraction
            await asyncio.sleep(delay)
    return await coro()
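

# Worked example of the back-off schedule implemented above (pure arithmetic,
# nothing executed at import time).  With delay_min=1 and delay_max=20 the
# failed attempts are spaced by:
#
#   >>> [min(1 * (2 ** i - 1), 20) for i in range(6)]
#   [0, 1, 3, 7, 15, 20]
#
# i.e. no pause before the first re-try, then an exponential back-off capped at
# delay_max; jitter_fraction, when positive, stretches each delay by up to that
# fraction so that many failed operations are not re-tried in lock step.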


async def retry_operation(coro, *args, operation=None, **kwargs):
    """
    Retry an operation using the configuration values for the retry parameters
    """

    retry_count = dask.config.get("distributed.comm.retry.count")
    retry_delay_min = parse_timedelta(
        dask.config.get("distributed.comm.retry.delay.min"), default="s"
    )
    retry_delay_max = parse_timedelta(
        dask.config.get("distributed.comm.retry.delay.max"), default="s"
    )
    return await retry(
        partial(coro, *args, **kwargs),
        count=retry_count,
        delay_min=retry_delay_min,
        delay_max=retry_delay_max,
        operation=operation,
    )
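

# A minimal usage sketch of retry_operation.  ``fetch`` and its address argument
# are hypothetical stand-ins for a real comm call; the retry count and delay
# bounds come from the distributed.comm.retry.* configuration keys read above.
#
#   async def fetch(address):              # hypothetical coroutine
#       ...
#
#   result = await retry_operation(fetch, "tcp://10.0.0.1:8788", operation="fetch")
#
# Positional and keyword arguments are bound with functools.partial, and the
# resulting zero-argument coroutine function is handed to ``retry``.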