import itertools
import os
from itertools import product
from typing import (
    Any,
    Hashable,
    Iterable,
    List,
    Mapping,
    Optional,
    Sequence,
    Set,
    Tuple,
    Union,
)

import tlz as toolz

from .base import clone_key, get_name_from_key, tokenize
from .compatibility import prod
from .core import flatten, keys_in_tasks, reverse_dict
from .delayed import unpack_collections
from .highlevelgraph import HighLevelGraph, Layer
from .optimization import SubgraphCallable, fuse
from .utils import (
    _deprecated,
    apply,
    ensure_dict,
    homogeneous_deepmap,
    stringify,
    stringify_collection_keys,
)


class BlockwiseDep:
    """Blockwise-IO argument

    This is the base class for indexable Blockwise-IO arguments.
    When constructing a ``Blockwise`` Layer, one or more of the
    collection tuples passed in with ``indices`` may contain a
    ``BlockwiseDep`` instance (in place of a "real" collection name).
    This allows a new collection to be created (via IO) within a
    ``Blockwise`` layer.

    All ``BlockwiseDep`` instances must define a ``numblocks``
    attribute to specify the number of blocks/partitions the
    object can support along each dimension. The object should
    also define a ``produces_tasks`` attribute to specify if
    any nested tasks will be passed to the Blockwise function.

    See Also
    --------
    dask.blockwise.Blockwise
    dask.blockwise.BlockwiseDepDict
    """

    numblocks: Tuple[int, ...]
    produces_tasks: bool

    def __getitem__(self, idx: Tuple[int, ...]) -> Any:
        """Return Blockwise-function arguments for a specific index"""
        raise NotImplementedError(
            "Must define `__getitem__` for `BlockwiseDep` subclass."
        )

    def get(self, idx: Tuple[int, ...], default) -> Any:
        """BlockwiseDep ``__getitem__`` wrapper"""
        try:
            return self.__getitem__(idx)
        except KeyError:
            return default

    def __dask_distributed_pack__(
        self, required_indices: Optional[List[Tuple[int, ...]]] = None
    ):
        """Client-side serialization for ``BlockwiseDep`` objects.

        Should return a ``state`` dictionary, with msgpack-serializable
        values, that can be used to initialize a new ``BlockwiseDep`` object
        on a scheduler process.
        """
        raise NotImplementedError(
            "Must define `__dask_distributed_pack__` for `BlockwiseDep` subclass."
        )

    @classmethod
    def __dask_distributed_unpack__(cls, state):
        """Scheduler-side deserialization for ``BlockwiseDep`` objects.

        Should use an input ``state`` dictionary to initialize a new
        ``BlockwiseDep`` object.
        """
        raise NotImplementedError(
            "Must define `__dask_distributed_unpack__` for `BlockwiseDep` subclass."
        )

    def __repr__(self) -> str:
        return f"<{type(self).__name__} {self.numblocks}>"
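

# A minimal sketch of a custom ``BlockwiseDep`` subclass (hypothetical,
# for illustration only). Each block index maps to an argument computed
# on the fly, so nothing needs to be materialized ahead of time:
#
#     class RangeDep(BlockwiseDep):
#         """Pass the block index itself as the function argument."""
#
#         def __init__(self, npartitions: int):
#             self.numblocks = (npartitions,)
#             self.produces_tasks = False
#
#         def __getitem__(self, idx: Tuple[int, ...]) -> Any:
#             return idx[0]  # e.g. RangeDep(3)[(2,)] == 2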


class BlockwiseDepDict(BlockwiseDep):
    """Dictionary-based Blockwise-IO argument

    This is a dictionary-backed instance of ``BlockwiseDep``.
    The purpose of this class is to simplify the construction
    of IO-based Blockwise Layers with block/partition-dependent
    function arguments that are difficult to calculate at
    graph-materialization time.

    Examples
    --------

    Specify an IO-based function for the Blockwise Layer. Note
    that the function will be passed a single input object when
    the task is executed (e.g. a single ``tuple`` or ``dict``):

    >>> import pandas as pd
    >>> func = lambda x: pd.read_csv(**x)

    Use ``BlockwiseDepDict`` to define the input argument to
    ``func`` for each block/partition:

    >>> dep = BlockwiseDepDict(
    ...     mapping={
    ...         (0,): {
    ...             "filepath_or_buffer": "data.csv",
    ...             "skiprows": 1,
    ...             "nrows": 2,
    ...             "names": ["a", "b"],
    ...         },
    ...         (1,): {
    ...             "filepath_or_buffer": "data.csv",
    ...             "skiprows": 3,
    ...             "nrows": 2,
    ...             "names": ["a", "b"],
    ...         },
    ...     }
    ... )

    Construct a Blockwise Layer with ``dep`` specified
    in the ``indices`` list:

    >>> layer = Blockwise(
    ...     output="collection-name",
    ...     output_indices="i",
    ...     dsk={"collection-name": (func, blockwise_token(0))},
    ...     indices=[(dep, "i")],
    ...     numblocks={},
    ... )

    See Also
    --------
    dask.blockwise.Blockwise
    dask.blockwise.BlockwiseDep
    """

    def __init__(
        self,
        mapping: dict,
        numblocks: Optional[Tuple[int, ...]] = None,
        produces_tasks: bool = False,
    ):
        self.mapping = mapping
        self.produces_tasks = produces_tasks

        # By default, assume 1D shape
        self.numblocks = numblocks or (len(mapping),)

    def __getitem__(self, idx: Tuple[int, ...]) -> Any:
        return self.mapping[idx]

    def __dask_distributed_pack__(
        self, required_indices: Optional[List[Tuple[int, ...]]] = None
    ):
        from distributed.protocol import to_serialize

        if required_indices is None:
            required_indices = self.mapping.keys()

        return {
            "mapping": {k: to_serialize(self.mapping[k]) for k in required_indices},
            "numblocks": self.numblocks,
            "produces_tasks": self.produces_tasks,
        }

    @classmethod
    def __dask_distributed_unpack__(cls, state):
        return cls(**state)


def subs(task, substitution):
    """Create a new task with the values substituted

    This is like dask.core.subs, but takes a dict of many substitutions to
    perform simultaneously. It is not as concerned with micro performance.
    """
    if isinstance(task, dict):
        return {k: subs(v, substitution) for k, v in task.items()}
    if type(task) in (tuple, list, set):
        return type(task)([subs(x, substitution) for x in task])
    try:
        return substitution[task]
    except (KeyError, TypeError):
        return task


def index_subs(ind, substitution):
    """A simple subs function that works both on tuples and strings"""
    if ind is None:
        return ind
    else:
        return tuple(substitution.get(c, c) for c in ind)


_BLOCKWISE_DEFAULT_PREFIX = "__dask_blockwise__"


def blockwise_token(i, prefix=_BLOCKWISE_DEFAULT_PREFIX):
    return prefix + "%d" % i
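

# Usage sketch (illustrative values, not executed at import): ``subs``
# applies many substitutions at once, and ``blockwise_token`` builds the
# canonical placeholder names used for Blockwise arguments:
#
#     subs((sum, "x", "y"), {"x": 1, "y": 2})  # -> (sum, 1, 2)
#     blockwise_token(0)                       # -> '__dask_blockwise__0'
#     blockwise_token(1, ".")                  # -> '.1'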


def blockwise(
    func,
    output,
    output_indices,
    *arrind_pairs,
    numblocks=None,
    concatenate=None,
    new_axes=None,
    dependencies=(),
    **kwargs,
):
    """Create a Blockwise symbolic mutable mapping

    This is like the ``make_blockwise_graph`` function, but rather than
    construct a dict, it returns a symbolic Blockwise object.

    See Also
    --------
    make_blockwise_graph
    Blockwise
    """
    new_axes = new_axes or {}

    arrind_pairs = list(arrind_pairs)

    # Transform indices to canonical elements
    # We use terms like _0, and _1 rather than provided index elements
    unique_indices = {
        i for ii in arrind_pairs[1::2] if ii is not None for i in ii
    } | set(output_indices)
    sub = {k: blockwise_token(i, ".") for i, k in enumerate(sorted(unique_indices))}
    output_indices = index_subs(tuple(output_indices), sub)
    a_pairs_list = []
    for a in arrind_pairs[1::2]:
        if a is not None:
            val = tuple(a)
        else:
            val = a
        a_pairs_list.append(index_subs(val, sub))

    arrind_pairs[1::2] = a_pairs_list
    new_axes = {index_subs((k,), sub)[0]: v for k, v in new_axes.items()}

    # Unpack dask values in non-array arguments
    inputs = []
    inputs_indices = []
    for name, index in toolz.partition(2, arrind_pairs):
        inputs.append(name)
        inputs_indices.append(index)

    # Unpack delayed objects in kwargs
    new_keys = {n for c in dependencies for n in c.__dask_layers__()}
    if kwargs:
        # replace keys in kwargs with _0 tokens
        new_tokens = tuple(
            blockwise_token(i) for i in range(len(inputs), len(inputs) + len(new_keys))
        )
        sub = dict(zip(new_keys, new_tokens))
        inputs.extend(new_keys)
        inputs_indices.extend((None,) * len(new_keys))
        kwargs = subs(kwargs, sub)

    indices = [(k, v) for k, v in zip(inputs, inputs_indices)]
    keys = map(blockwise_token, range(len(inputs)))

    # Construct local graph
    if not kwargs:
        subgraph = {output: (func,) + tuple(keys)}
    else:
        _keys = list(keys)
        if new_keys:
            _keys = _keys[: -len(new_keys)]
        kwargs2 = (dict, list(map(list, kwargs.items())))
        subgraph = {output: (apply, func, _keys, kwargs2)}

    # Construct final output
    subgraph = Blockwise(
        output,
        output_indices,
        subgraph,
        indices,
        numblocks=numblocks,
        concatenate=concatenate,
        new_axes=new_axes,
    )
    return subgraph
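

# Usage sketch (hypothetical collection names): build a symbolic layer for
# ``z_ij = add(x_ij, y_ij)`` over a 2x2 grid of blocks. No per-block tasks
# exist until the layer is materialized:
#
#     from operator import add
#     layer = blockwise(
#         add, "z", "ij", "x", "ij", "y", "ij",
#         numblocks={"x": (2, 2), "y": (2, 2)},
#     )
#     len(layer)           # 4 output blocks, computed symbolically
#     sorted(dict(layer))  # [('z', 0, 0), ('z', 0, 1), ('z', 1, 0), ('z', 1, 1)]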


class Blockwise(Layer):
    """Tensor Operation

    This is a lazily constructed mapping for tensor operation graphs.
    This defines a dictionary using an operation and an indexing pattern.
    It is built for many operations like elementwise, transpose, tensordot, and
    so on. We choose to keep these as symbolic mappings rather than raw
    dictionaries because we are able to fuse them during optimization,
    sometimes resulting in much lower overhead.

    Parameters
    ----------
    output: str
        The name of the output collection. Used in keynames
    output_indices: tuple
        The output indices, like ``('i', 'j', 'k')`` used to determine the
        structure of the block computations
    dsk: dict
        A small graph to apply per-output-block. May include keys from the
        input indices.
    indices: Tuple[Tuple[str, Optional[Tuple[str, ...]]], ...]
        An ordered mapping from input key name, like ``'x'``,
        to input indices, like ``('i', 'j')``.
        Or includes literals, which have ``None`` for an index value.
        In place of input-key names, the first tuple element may also be a
        ``BlockwiseDep`` object.
    numblocks: Mapping[key, Sequence[int]]
        Number of blocks along each dimension for each input
    concatenate: bool
        Whether or not to pass contracted dimensions as a list of inputs or a
        single input to the block function
    new_axes: Mapping
        New index dimensions that may have been created and their size,
        e.g. ``{'j': 2, 'k': 3}``
    output_blocks: Set[Tuple[int, ...]]
        Specify a specific set of required output blocks. Since the graph
        will only contain the necessary tasks to generate these outputs,
        this kwarg can be used to "cull" the abstract layer (without needing
        to materialize the low-level graph).
    annotations: dict (optional)
        Layer annotations
    io_deps: Dict[str, BlockwiseDep] (optional)
        Dictionary containing the mapping between "place-holder" collection
        keys and ``BlockwiseDep``-based objects.
        **WARNING**: This argument should only be used internally (for culling,
        fusion and cloning of existing Blockwise layers). Explicit use of this
        argument will be deprecated in the future.

    See Also
    --------
    dask.blockwise.blockwise
    dask.array.blockwise
    """

    output: str
    output_indices: Tuple[str, ...]
    dsk: Mapping[str, tuple]
    indices: Tuple[Tuple[str, Optional[Tuple[str, ...]]], ...]
    numblocks: Mapping[str, Sequence[int]]
    concatenate: Optional[bool]
    new_axes: Mapping[str, int]
    output_blocks: Optional[Set[Tuple[int, ...]]]

    def __init__(
        self,
        output: str,
        output_indices: Iterable[str],
        dsk: Mapping[str, tuple],
        indices: Iterable[Tuple[Union[str, BlockwiseDep], Optional[Iterable[str]]]],
        numblocks: Mapping[str, Sequence[int]],
        concatenate: bool = None,
        new_axes: Mapping[str, int] = None,
        output_blocks: Set[Tuple[int, ...]] = None,
        annotations: Mapping[str, Any] = None,
        io_deps: Optional[Mapping[str, BlockwiseDep]] = None,
    ):
        super().__init__(annotations=annotations)
        self.output = output
        self.output_indices = tuple(output_indices)
        self.output_blocks = output_blocks
        self.dsk = dsk

        # Remove `BlockwiseDep` arguments from input indices
        # and add them to `self.io_deps`.
        # TODO: Remove `io_deps` and handle indexable objects
        # in `self.indices` throughout `Blockwise`.
        self.indices = []
        self.numblocks = numblocks
        self.io_deps = io_deps or {}
        for dep, ind in indices:
            name = dep
            if isinstance(dep, BlockwiseDep):
                name = tokenize(dep)
                self.io_deps[name] = dep
                self.numblocks[name] = dep.numblocks
            self.indices.append((name, tuple(ind) if ind is not None else ind))
        self.indices = tuple(self.indices)

        # optimize_blockwise won't merge where `concatenate` doesn't match, so
        # enforce a canonical value if there are no axes for reduction.
        output_indices_set = set(self.output_indices)
        if concatenate is not None and all(
            i in output_indices_set
            for name, ind in self.indices
            if ind is not None
            for i in ind
        ):
            concatenate = None
        self.concatenate = concatenate
        self.new_axes = new_axes or {}

    @property
    def dims(self):
        """Returns a dictionary mapping between each index specified in
        `self.indices` and the number of output blocks for that index.
        """
        if not hasattr(self, "_dims"):
            self._dims = _make_dims(self.indices, self.numblocks, self.new_axes)
        return self._dims
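
    # Worked example for ``dims`` (sketch, hypothetical values): for a layer
    # with ``indices == (('x', ('i', 'j')), ('y', ('j', 'k')))`` and
    # ``numblocks == {'x': (2, 3), 'y': (3, 4)}``, ``dims`` evaluates to
    # ``{'i': 2, 'j': 3, 'k': 4}``. An entry in ``new_axes`` contributes
    # ``len(value)`` blocks when its value is a tuple, and 1 otherwise.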
425 """ 426 if not hasattr(self, "_dims"): 427 self._dims = _make_dims(self.indices, self.numblocks, self.new_axes) 428 return self._dims 429 430 def __repr__(self): 431 return f"Blockwise<{self.indices} -> {self.output}>" 432 433 @property 434 def _dict(self): 435 if hasattr(self, "_cached_dict"): 436 return self._cached_dict["dsk"] 437 else: 438 keys = tuple(map(blockwise_token, range(len(self.indices)))) 439 dsk, _ = fuse(self.dsk, [self.output]) 440 func = SubgraphCallable(dsk, self.output, keys) 441 442 dsk = make_blockwise_graph( 443 func, 444 self.output, 445 self.output_indices, 446 *list(toolz.concat(self.indices)), 447 new_axes=self.new_axes, 448 numblocks=self.numblocks, 449 concatenate=self.concatenate, 450 output_blocks=self.output_blocks, 451 dims=self.dims, 452 io_deps=self.io_deps, 453 ) 454 455 self._cached_dict = {"dsk": dsk} 456 return self._cached_dict["dsk"] 457 458 def get_output_keys(self): 459 if self.output_blocks: 460 # Culling has already generated a list of output blocks 461 return {(self.output, *p) for p in self.output_blocks} 462 463 # Return all possible output keys (no culling) 464 return { 465 (self.output, *p) 466 for p in itertools.product( 467 *[range(self.dims[i]) for i in self.output_indices] 468 ) 469 } 470 471 def __getitem__(self, key): 472 return self._dict[key] 473 474 def __iter__(self): 475 return iter(self._dict) 476 477 def __len__(self) -> int: 478 # same method as `get_output_keys`, without manifesting the keys themselves 479 return ( 480 len(self.output_blocks) 481 if self.output_blocks 482 else prod(self.dims[i] for i in self.output_indices) 483 ) 484 485 def is_materialized(self): 486 return hasattr(self, "_cached_dict") 487 488 def __dask_distributed_pack__( 489 self, all_hlg_keys, known_key_dependencies, client, client_keys 490 ): 491 from distributed.protocol import to_serialize 492 from distributed.utils import CancelledError 493 from distributed.utils_comm import unpack_remotedata 494 from distributed.worker import dumps_function 495 496 keys = tuple(map(blockwise_token, range(len(self.indices)))) 497 dsk, _ = fuse(self.dsk, [self.output]) 498 499 # Embed literals in `dsk` 500 keys2 = [] 501 indices2 = [] 502 global_dependencies = set() 503 for key, (val, index) in zip(keys, self.indices): 504 if index is None: 505 try: 506 val_is_a_key = val in all_hlg_keys 507 except TypeError: # not hashable 508 val_is_a_key = False 509 if val_is_a_key: 510 keys2.append(key) 511 indices2.append((val, index)) 512 global_dependencies.add(stringify(val)) 513 else: 514 dsk[key] = val # Literal 515 else: 516 keys2.append(key) 517 indices2.append((val, index)) 518 519 dsk = (SubgraphCallable(dsk, self.output, tuple(keys2)),) 520 dsk, dsk_unpacked_futures = unpack_remotedata(dsk, byte_keys=True) 521 522 # Handle `io_deps` serialization. Assume each element 523 # is a `BlockwiseDep`-based object. 524 packed_io_deps = {} 525 inline_tasks = False 526 for name, blockwise_dep in self.io_deps.items(): 527 packed_io_deps[name] = { 528 "__module__": blockwise_dep.__module__, 529 "__name__": type(blockwise_dep).__name__, 530 # TODO: Pass a `required_indices` list to __pack__ 531 "state": blockwise_dep.__dask_distributed_pack__(), 532 } 533 inline_tasks = inline_tasks or blockwise_dep.produces_tasks 534 535 # Dump (pickle + cache) the function here if we know `make_blockwise_graph` 536 # will NOT be producing "nested" tasks (via `__dask_distributed_unpack__`). 

    def __dask_distributed_pack__(
        self, all_hlg_keys, known_key_dependencies, client, client_keys
    ):
        from distributed.protocol import to_serialize
        from distributed.utils import CancelledError
        from distributed.utils_comm import unpack_remotedata
        from distributed.worker import dumps_function

        keys = tuple(map(blockwise_token, range(len(self.indices))))
        dsk, _ = fuse(self.dsk, [self.output])

        # Embed literals in `dsk`
        keys2 = []
        indices2 = []
        global_dependencies = set()
        for key, (val, index) in zip(keys, self.indices):
            if index is None:
                try:
                    val_is_a_key = val in all_hlg_keys
                except TypeError:  # not hashable
                    val_is_a_key = False
                if val_is_a_key:
                    keys2.append(key)
                    indices2.append((val, index))
                    global_dependencies.add(stringify(val))
                else:
                    dsk[key] = val  # Literal
            else:
                keys2.append(key)
                indices2.append((val, index))

        dsk = (SubgraphCallable(dsk, self.output, tuple(keys2)),)
        dsk, dsk_unpacked_futures = unpack_remotedata(dsk, byte_keys=True)

        # Handle `io_deps` serialization. Assume each element
        # is a `BlockwiseDep`-based object.
        packed_io_deps = {}
        inline_tasks = False
        for name, blockwise_dep in self.io_deps.items():
            packed_io_deps[name] = {
                "__module__": blockwise_dep.__module__,
                "__name__": type(blockwise_dep).__name__,
                # TODO: Pass a `required_indices` list to __pack__
                "state": blockwise_dep.__dask_distributed_pack__(),
            }
            inline_tasks = inline_tasks or blockwise_dep.produces_tasks

        # Dump (pickle + cache) the function here if we know `make_blockwise_graph`
        # will NOT be producing "nested" tasks (via `__dask_distributed_unpack__`).
        #
        # If `make_blockwise_graph` DOES need to produce nested tasks later on, it
        # will need to call `to_serialize` on the entire task. That will be a
        # problem if the function was already pickled here. Therefore, we want to
        # call `to_serialize` on the function if we know there will be nested tasks.
        #
        # We know there will be nested tasks if either:
        # (1) `concatenate=True`   # Check `self.concatenate`
        # (2) `inline_tasks=True`  # Check `BlockwiseDep.produces_tasks`
        #
        # We do not call `to_serialize` in ALL cases, because that code path does
        # not cache the function on the scheduler or worker (or warn if there are
        # large objects being passed into the graph). However, in the future,
        # single-pass serialization improvements should allow us to remove this
        # special logic altogether.
        func = (
            to_serialize(dsk[0])
            if (self.concatenate or inline_tasks)
            else dumps_function(dsk[0])
        )
        func_future_args = dsk[1:]

        indices = list(toolz.concat(indices2))
        indices, indices_unpacked_futures = unpack_remotedata(indices, byte_keys=True)

        # Check the legality of the unpacked futures
        for future in itertools.chain(dsk_unpacked_futures, indices_unpacked_futures):
            if future.client is not client:
                raise ValueError(
                    "Inputs contain futures that were created by another client."
                )
            if stringify(future.key) not in client.futures:
                raise CancelledError(stringify(future.key))

        # All blockwise tasks will depend on the futures in `indices`
        global_dependencies |= {stringify(f.key) for f in indices_unpacked_futures}

        return {
            "output": self.output,
            "output_indices": self.output_indices,
            "func": func,
            "func_future_args": func_future_args,
            "global_dependencies": global_dependencies,
            "indices": indices,
            "is_list": [isinstance(x, list) for x in indices],
            "numblocks": self.numblocks,
            "concatenate": self.concatenate,
            "new_axes": self.new_axes,
            "output_blocks": self.output_blocks,
            "dims": self.dims,
            "io_deps": packed_io_deps,
        }

    @classmethod
    def __dask_distributed_unpack__(cls, state, dsk, dependencies):
        from distributed.protocol.serialize import import_allowed_module

        # Make sure we convert list items back from tuples in `indices`.
        # The msgpack serialization will have converted lists into
        # tuples, and tuples may be stringified during graph
        # materialization (bad if the item was not a key).
        indices = [
            list(ind) if is_list else ind
            for ind, is_list in zip(state["indices"], state["is_list"])
        ]

        # Unpack io_deps state
        io_deps = {}
        for replace_name, packed_dep in state["io_deps"].items():
            mod = import_allowed_module(packed_dep["__module__"])
            dep_cls = getattr(mod, packed_dep["__name__"])
            io_deps[replace_name] = dep_cls.__dask_distributed_unpack__(
                packed_dep["state"]
            )

        layer_dsk, layer_deps = make_blockwise_graph(
            state["func"],
            state["output"],
            state["output_indices"],
            *indices,
            new_axes=state["new_axes"],
            numblocks=state["numblocks"],
            concatenate=state["concatenate"],
            output_blocks=state["output_blocks"],
            dims=state["dims"],
            return_key_deps=True,
            deserializing=True,
            func_future_args=state["func_future_args"],
            io_deps=io_deps,
        )
        g_deps = state["global_dependencies"]

        # Stringify layer graph and dependencies
        layer_dsk = {
            stringify(k): stringify_collection_keys(v) for k, v in layer_dsk.items()
        }
        deps = {
            stringify(k): {stringify(d) for d in v} | g_deps
            for k, v in layer_deps.items()
        }
        return {"dsk": layer_dsk, "deps": deps}

    def _cull_dependencies(self, all_hlg_keys, output_blocks):
        """Determine the necessary dependencies to produce `output_blocks`.

        This method does not require graph materialization.
        """

        # Check `concatenate` option
        concatenate = None
        if self.concatenate is True:
            from dask.array.core import concatenate_axes as concatenate

        # Generate coordinate map
        (coord_maps, concat_axes, dummies) = _get_coord_mapping(
            self.dims,
            self.output,
            self.output_indices,
            self.numblocks,
            self.indices,
            concatenate,
        )

        # Gather constant dependencies (for all output keys)
        const_deps = set()
        for (arg, ind) in self.indices:
            if ind is None:
                try:
                    if arg in all_hlg_keys:
                        const_deps.add(arg)
                except TypeError:
                    pass  # unhashable

        # Get dependencies for each output block
        key_deps = {}
        for out_coords in output_blocks:
            deps = set()
            coords = out_coords + dummies
            for cmap, axes, (arg, ind) in zip(coord_maps, concat_axes, self.indices):
                if ind is not None and arg not in self.io_deps:
                    arg_coords = tuple(coords[c] for c in cmap)
                    if axes:
                        tups = lol_product((arg,), arg_coords)
                        deps.update(flatten(tups))
                        if concatenate:
                            tups = (concatenate, tups, axes)
                    else:
                        tups = (arg,) + arg_coords
                        deps.add(tups)
            key_deps[(self.output,) + out_coords] = deps | const_deps

        return key_deps

    def _cull(self, output_blocks):
        return Blockwise(
            self.output,
            self.output_indices,
            self.dsk,
            self.indices,
            self.numblocks,
            concatenate=self.concatenate,
            new_axes=self.new_axes,
            output_blocks=output_blocks,
            annotations=self.annotations,
            io_deps=self.io_deps,
        )
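
    # Culling sketch (illustrative): if only ``("z", 0, 1)`` is required
    # downstream, ``cull`` (below) records ``output_blocks == {(0, 1)}`` via
    # ``_cull``, and ``make_blockwise_graph`` will later materialize just
    # that single task.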

    def cull(
        self, keys: set, all_hlg_keys: Iterable
    ) -> Tuple[Layer, Mapping[Hashable, set]]:
        # Culling is simple for Blockwise layers. We can just
        # collect a set of required output blocks (tuples), and
        # only construct a graph for these blocks in `make_blockwise_graph`

        output_blocks = set()
        for key in keys:
            if key[0] == self.output:
                output_blocks.add(key[1:])
        culled_deps = self._cull_dependencies(all_hlg_keys, output_blocks)
        out_size_iter = (self.dims[i] for i in self.output_indices)
        if prod(out_size_iter) != len(culled_deps):
            culled_layer = self._cull(output_blocks)
            return culled_layer, culled_deps
        else:
            return self, culled_deps

    def clone(
        self,
        keys: set,
        seed: Hashable,
        bind_to: Hashable = None,
    ) -> Tuple[Layer, bool]:
        names = {get_name_from_key(k) for k in keys}
        # We assume that 'keys' will contain either all or none of the output keys of
        # each of the layers, because clone/bind are always invoked at collection level.
        # Asserting this is very expensive, so we only check it during unit tests.
        if "PYTEST_CURRENT_TEST" in os.environ:
            assert not self.get_output_keys() - keys
            for name, nb in self.numblocks.items():
                if name in names:
                    for block in product(*(list(range(nbi)) for nbi in nb)):
                        assert (name, *block) in keys

        is_leaf = True

        indices = []
        for k, idxv in self.indices:
            if k in names:
                is_leaf = False
                k = clone_key(k, seed)
            indices.append((k, idxv))

        numblocks = {}
        for k, nbv in self.numblocks.items():
            if k in names:
                is_leaf = False
                k = clone_key(k, seed)
            numblocks[k] = nbv

        dsk = {clone_key(k, seed): v for k, v in self.dsk.items()}

        if bind_to is not None and is_leaf:
            from .graph_manipulation import chunks

            # It's always a Delayed generated by dask.graph_manipulation.checkpoint;
            # the layer name always matches the key
            assert isinstance(bind_to, str)
            dsk = {k: (chunks.bind, v, f"_{len(indices)}") for k, v in dsk.items()}
            indices.append((bind_to, None))

        return (
            Blockwise(
                output=clone_key(self.output, seed),
                output_indices=self.output_indices,
                dsk=dsk,
                indices=indices,
                numblocks=numblocks,
                concatenate=self.concatenate,
                new_axes=self.new_axes,
                output_blocks=self.output_blocks,
                annotations=self.annotations,
                io_deps=self.io_deps,
            ),
            (bind_to is not None and is_leaf),
        )


def _get_coord_mapping(
    dims,
    output,
    out_indices,
    numblocks,
    argpairs,
    concatenate,
):
    """Calculate coordinate mapping for graph construction.

    This function handles the high-level logic behind Blockwise graph
    construction. The output is a tuple containing: the mapping between
    input and output block coordinates (`coord_maps`), the axes along
    which to concatenate for each input (`concat_axes`), and the dummy
    indices needed for broadcasting (`dummies`).

    Used by `make_blockwise_graph` and `Blockwise._cull_dependencies`.

    Parameters
    ----------
    dims : dict
        Mapping between each index specified in `argpairs` and
        the number of output blocks for that index. Corresponds
        to the Blockwise `dims` attribute.
    output : str
        Corresponds to the Blockwise `output` attribute.
    out_indices : tuple
        Corresponds to the Blockwise `output_indices` attribute.
    numblocks : dict
        Corresponds to the Blockwise `numblocks` attribute.
    argpairs : tuple
        Corresponds to the Blockwise `indices` attribute.
    concatenate : bool
        Corresponds to the Blockwise `concatenate` attribute.
    """

    block_names = set()
    all_indices = set()
    for name, ind in argpairs:
        if ind is not None:
            block_names.add(name)
            for x in ind:
                all_indices.add(x)
    assert set(numblocks) == block_names

    dummy_indices = all_indices - set(out_indices)

    # For each position in the output space, we'll construct a
    # "coordinate set" that consists of
    # - the output indices
    # - the dummy indices
    # - the dummy indices, with indices replaced by zeros (for broadcasting); we
    #   are careful to only emit a single dummy zero when concatenate=True to not
    #   concatenate the same array with itself several times.
    # - a 0 to assist with broadcasting.

    index_pos, zero_pos = {}, {}
    for i, ind in enumerate(out_indices):
        index_pos[ind] = i
        zero_pos[ind] = -1

    _dummies_list = []
    for i, ind in enumerate(dummy_indices):
        index_pos[ind] = 2 * i + len(out_indices)
        zero_pos[ind] = 2 * i + 1 + len(out_indices)
        reps = 1 if concatenate else dims[ind]
        _dummies_list.append([list(range(dims[ind])), [0] * reps])

    # ([0, 1, 2], [0, 0, 0], ...)  For a dummy index of dimension 3
    dummies = tuple(itertools.chain.from_iterable(_dummies_list))
    dummies += (0,)

    # For each coordinate position in each input, gives the position in
    # the coordinate set.
    coord_maps = []

    # Axes along which to concatenate, for each input
    concat_axes = []
    for arg, ind in argpairs:
        if ind is not None:
            coord_maps.append(
                [
                    zero_pos[i] if nb == 1 else index_pos[i]
                    for i, nb in zip(ind, numblocks[arg])
                ]
            )
            concat_axes.append([n for n, i in enumerate(ind) if i in dummy_indices])
        else:
            coord_maps.append(None)
            concat_axes.append(None)

    return coord_maps, concat_axes, dummies


def make_blockwise_graph(
    func,
    output,
    out_indices,
    *arrind_pairs,
    numblocks=None,
    concatenate=None,
    new_axes=None,
    output_blocks=None,
    dims=None,
    deserializing=False,
    func_future_args=None,
    return_key_deps=False,
    io_deps=None,
    **kwargs,
):
    """Tensor operation

    Applies a function, ``func``, across blocks from many different input
    collections. We arrange the pattern with which those blocks interact with
    sets of matching indices. E.g.::

        make_blockwise_graph(func, 'z', 'i', 'x', 'i', 'y', 'i')

    yields an embarrassingly parallel communication pattern and is read as

        $$ z_i = func(x_i, y_i) $$

    More complex patterns may emerge, including multiple indices::

        make_blockwise_graph(func, 'z', 'ij', 'x', 'ij', 'y', 'ji')

        $$ z_{ij} = func(x_{ij}, y_{ji}) $$

    Indices missing in the output but present in the inputs result in many
    inputs being sent to one function (see examples).

    Examples
    --------
    Simple embarrassing map operation

    >>> inc = lambda x: x + 1
    >>> make_blockwise_graph(inc, 'z', 'ij', 'x', 'ij', numblocks={'x': (2, 2)})  # doctest: +SKIP
    {('z', 0, 0): (inc, ('x', 0, 0)),
     ('z', 0, 1): (inc, ('x', 0, 1)),
     ('z', 1, 0): (inc, ('x', 1, 0)),
     ('z', 1, 1): (inc, ('x', 1, 1))}

    Simple operation on two datasets

    >>> add = lambda x, y: x + y
    >>> make_blockwise_graph(add, 'z', 'ij', 'x', 'ij', 'y', 'ij', numblocks={'x': (2, 2),
    ...                      'y': (2, 2)})  # doctest: +SKIP
    {('z', 0, 0): (add, ('x', 0, 0), ('y', 0, 0)),
     ('z', 0, 1): (add, ('x', 0, 1), ('y', 0, 1)),
     ('z', 1, 0): (add, ('x', 1, 0), ('y', 1, 0)),
     ('z', 1, 1): (add, ('x', 1, 1), ('y', 1, 1))}

    Operation that flips one of the datasets

    >>> addT = lambda x, y: x + y.T  # Transpose each chunk
    >>> # z_ij ~ x_ij y_ji
    >>> #  ..    ..   ..  notice swap
    >>> make_blockwise_graph(addT, 'z', 'ij', 'x', 'ij', 'y', 'ji', numblocks={'x': (2, 2),
    ...                      'y': (2, 2)})  # doctest: +SKIP
    {('z', 0, 0): (addT, ('x', 0, 0), ('y', 0, 0)),
     ('z', 0, 1): (addT, ('x', 0, 1), ('y', 1, 0)),
     ('z', 1, 0): (addT, ('x', 1, 0), ('y', 0, 1)),
     ('z', 1, 1): (addT, ('x', 1, 1), ('y', 1, 1))}

    Dot product with contraction over ``j`` index. Yields list arguments

    >>> make_blockwise_graph(dotmany, 'z', 'ik', 'x', 'ij', 'y', 'jk', numblocks={'x': (2, 2),
    ...                      'y': (2, 2)})  # doctest: +SKIP
    {('z', 0, 0): (dotmany, [('x', 0, 0), ('x', 0, 1)],
                            [('y', 0, 0), ('y', 1, 0)]),
     ('z', 0, 1): (dotmany, [('x', 0, 0), ('x', 0, 1)],
                            [('y', 0, 1), ('y', 1, 1)]),
     ('z', 1, 0): (dotmany, [('x', 1, 0), ('x', 1, 1)],
                            [('y', 0, 0), ('y', 1, 0)]),
     ('z', 1, 1): (dotmany, [('x', 1, 0), ('x', 1, 1)],
                            [('y', 0, 1), ('y', 1, 1)])}

    Pass ``concatenate=True`` to concatenate arrays ahead of time

    >>> make_blockwise_graph(f, 'z', 'i', 'x', 'ij', 'y', 'ij', concatenate=True,
    ...                      numblocks={'x': (2, 2), 'y': (2, 2,)})  # doctest: +SKIP
    {('z', 0): (f, (concatenate_axes, [('x', 0, 0), ('x', 0, 1)], (1,)),
                   (concatenate_axes, [('y', 0, 0), ('y', 0, 1)], (1,))),
     ('z', 1): (f, (concatenate_axes, [('x', 1, 0), ('x', 1, 1)], (1,)),
                   (concatenate_axes, [('y', 1, 0), ('y', 1, 1)], (1,)))}

    Supports Broadcasting rules

    >>> make_blockwise_graph(add, 'z', 'ij', 'x', 'ij', 'y', 'ij', numblocks={'x': (1, 2),
    ...                      'y': (2, 2)})  # doctest: +SKIP
    {('z', 0, 0): (add, ('x', 0, 0), ('y', 0, 0)),
     ('z', 0, 1): (add, ('x', 0, 1), ('y', 0, 1)),
     ('z', 1, 0): (add, ('x', 0, 0), ('y', 1, 0)),
     ('z', 1, 1): (add, ('x', 0, 1), ('y', 1, 1))}

    Support keyword arguments with apply

    >>> def f(a, b=0): return a + b
    >>> make_blockwise_graph(f, 'z', 'i', 'x', 'i', numblocks={'x': (2,)}, b=10)  # doctest: +SKIP
    {('z', 0): (apply, f, [('x', 0)], {'b': 10}),
     ('z', 1): (apply, f, [('x', 1)], {'b': 10})}

    Include literals by indexing with ``None``

    >>> make_blockwise_graph(add, 'z', 'i', 'x', 'i', 100, None, numblocks={'x': (2,)})  # doctest: +SKIP
    {('z', 0): (add, ('x', 0), 100),
     ('z', 1): (add, ('x', 1), 100)}

    See Also
    --------
    dask.array.blockwise
    dask.blockwise.blockwise
    """

    if numblocks is None:
        raise ValueError("Missing required numblocks argument.")
    new_axes = new_axes or {}
    io_deps = io_deps or {}
    argpairs = list(toolz.partition(2, arrind_pairs))

    if return_key_deps:
        key_deps = {}

    if deserializing:
        from distributed.protocol.serialize import to_serialize

    if concatenate is True:
        from dask.array.core import concatenate_axes as concatenate

    # Dictionary mapping {i: 3, j: 4, ...} for i, j, ... the dimensions
    dims = dims or _make_dims(argpairs, numblocks, new_axes)

    # Generate the abstract "plan" before constructing
    # the actual graph
    (coord_maps, concat_axes, dummies) = _get_coord_mapping(
        dims,
        output,
        out_indices,
        numblocks,
        argpairs,
        concatenate,
    )

    # Unpack delayed objects in kwargs
    dsk2 = {}
    if kwargs:
        task, dsk2 = unpack_collections(kwargs)
        if dsk2:
            kwargs2 = task
        else:
            kwargs2 = kwargs

    # Apply culling.
    # Only need to construct the specified set of output blocks
    output_blocks = output_blocks or itertools.product(
        *[range(dims[i]) for i in out_indices]
    )

    dsk = {}
    # Create argument lists
    for out_coords in output_blocks:
        deps = set()
        coords = out_coords + dummies
        args = []
        for cmap, axes, (arg, ind) in zip(coord_maps, concat_axes, argpairs):
            if ind is None:
                if deserializing:
                    args.append(stringify_collection_keys(arg))
                else:
                    args.append(arg)
            else:
                arg_coords = tuple(coords[c] for c in cmap)
                if axes:
                    tups = lol_product((arg,), arg_coords)
                    if arg not in io_deps:
                        deps.update(flatten(tups))

                    if concatenate:
                        tups = (concatenate, tups, axes)
                else:
                    tups = (arg,) + arg_coords
                    if arg not in io_deps:
                        deps.add(tups)
                # Replace "place-holder" IO keys with "real" args
                if arg in io_deps:
                    # We don't want to stringify keys for args
                    # we are replacing here
                    idx = tups[1:]
                    args.append(io_deps[arg].get(idx, idx))
                elif deserializing:
                    args.append(stringify_collection_keys(tups))
                else:
                    args.append(tups)
        out_key = (output,) + out_coords

        if deserializing:
            deps.update(func_future_args)
            args += list(func_future_args)

        if deserializing and isinstance(func, bytes):
            # Construct a function/args/kwargs dict if we
            # do not have a nested task (i.e. concatenate=False).
            # TODO: Avoid using the iterate_collection version
            # of to_serialize if we know that there are no embedded
            # Serialized/Serialize objects in args and/or kwargs.
            if kwargs:
                dsk[out_key] = {
                    "function": func,
                    "args": to_serialize(args),
                    "kwargs": to_serialize(kwargs2),
                }
            else:
                dsk[out_key] = {"function": func, "args": to_serialize(args)}
        else:
            if kwargs:
                val = (apply, func, args, kwargs2)
            else:
                args.insert(0, func)
                val = tuple(args)
            # May still need to serialize (if concatenate=True)
            dsk[out_key] = to_serialize(val) if deserializing else val

        if return_key_deps:
            key_deps[out_key] = deps

    if dsk2:
        dsk.update(ensure_dict(dsk2))

    if return_key_deps:
        return dsk, key_deps
    else:
        return dsk


def lol_product(head, values):
    """List of list of tuple keys, similar to `itertools.product`.

    Parameters
    ----------
    head : tuple
        Prefix prepended to all results.
    values : sequence
        Mix of singletons and lists. Each list is substituted with every
        possible value and introduces another level of list in the output.

    Examples
    --------
    >>> lol_product(('x',), (1, 2, 3))
    ('x', 1, 2, 3)
    >>> lol_product(('x',), (1, [2, 3], 4, [5, 6]))  # doctest: +NORMALIZE_WHITESPACE
    [[('x', 1, 2, 4, 5), ('x', 1, 2, 4, 6)],
     [('x', 1, 3, 4, 5), ('x', 1, 3, 4, 6)]]
    """
    if not values:
        return head
    elif isinstance(values[0], list):
        return [lol_product(head + (x,), values[1:]) for x in values[0]]
    else:
        return lol_product(head + (values[0],), values[1:])


def lol_tuples(head, ind, values, dummies):
    """List of list of tuple keys

    Parameters
    ----------
    head : tuple
        The known tuple so far
    ind : Iterable
        An iterable of indices not yet covered
    values : dict
        Known values for non-dummy indices
    dummies : dict
        Ranges of values for dummy indices

    Examples
    --------
    >>> lol_tuples(('x',), 'ij', {'i': 1, 'j': 0}, {})
    ('x', 1, 0)

    >>> lol_tuples(('x',), 'ij', {'i': 1}, {'j': range(3)})
    [('x', 1, 0), ('x', 1, 1), ('x', 1, 2)]

    >>> lol_tuples(('x',), 'ijk', {'i': 1}, {'j': [0, 1, 2], 'k': [0, 1]})  # doctest: +NORMALIZE_WHITESPACE
    [[('x', 1, 0, 0), ('x', 1, 0, 1)],
     [('x', 1, 1, 0), ('x', 1, 1, 1)],
     [('x', 1, 2, 0), ('x', 1, 2, 1)]]
    """
    if not ind:
        return head
    if ind[0] not in dummies:
        return lol_tuples(head + (values[ind[0]],), ind[1:], values, dummies)
    else:
        return [
            lol_tuples(head + (v,), ind[1:], values, dummies) for v in dummies[ind[0]]
        ]


def optimize_blockwise(graph, keys=()):
    """High level optimization of stacked Blockwise layers

    For operations that have multiple Blockwise operations one after the other,
    like ``x.T + 123``, we can fuse these into a single Blockwise operation. This
    happens before any actual tasks are generated, and so can reduce overhead.

    This finds groups of Blockwise operations that can be safely fused, and then
    passes them to ``rewrite_blockwise`` for rewriting.

    Parameters
    ----------
    graph : HighLevelGraph
    keys : Iterable
        The keys of all outputs of all collections.
        Used to make sure that we don't fuse a layer needed by an output

    Returns
    -------
    HighLevelGraph

    See Also
    --------
    rewrite_blockwise
    """
    out = _optimize_blockwise(graph, keys=keys)
    while out.dependencies != graph.dependencies:
        graph = out
        out = _optimize_blockwise(graph, keys=keys)
    return out


def _optimize_blockwise(full_graph, keys=()):
    keep = {k[0] if type(k) is tuple else k for k in keys}
    layers = full_graph.layers
    dependents = reverse_dict(full_graph.dependencies)
    roots = {k for k in full_graph.layers if not dependents.get(k)}
    stack = list(roots)

    out = {}
    dependencies = {}
    seen = set()
    io_names = set()

    while stack:
        layer = stack.pop()
        if layer in seen or layer not in layers:
            continue
        seen.add(layer)

        # Outer loop walks through possible output Blockwise layers
        if isinstance(layers[layer], Blockwise):
            blockwise_layers = {layer}
            deps = set(blockwise_layers)
            io_names |= layers[layer].io_deps.keys()
            while deps:  # we gather as many sub-layers as we can
                dep = deps.pop()

                if dep not in layers:
                    stack.append(dep)
                    continue
                if not isinstance(layers[dep], Blockwise):
                    stack.append(dep)
                    continue
                if dep != layer and dep in keep:
                    stack.append(dep)
                    continue
                if layers[dep].concatenate != layers[layer].concatenate:
                    stack.append(dep)
                    continue
                if (
                    sum(k == dep for k, ind in layers[layer].indices if ind is not None)
                    > 1
                ):
                    stack.append(dep)
                    continue
                if (
                    blockwise_layers
                    and layers[next(iter(blockwise_layers))].annotations
                    != layers[dep].annotations
                ):
                    stack.append(dep)
                    continue

                # passed everything, proceed
                blockwise_layers.add(dep)

                # traverse further to this child's children
                for d in full_graph.dependencies.get(dep, ()):
                    # Don't allow reductions to proceed
                    output_indices = set(layers[dep].output_indices)
                    input_indices = {
                        i for _, ind in layers[dep].indices if ind for i in ind
                    }

                    if len(dependents[d]) <= 1 and output_indices.issuperset(
                        input_indices
                    ):
                        deps.add(d)
                    else:
                        stack.append(d)

            # Merge these Blockwise layers into one
            new_layer = rewrite_blockwise([layers[l] for l in blockwise_layers])
            out[layer] = new_layer

            new_deps = set()
            for k, v in new_layer.indices:
                if v is None:
                    new_deps |= keys_in_tasks(full_graph.dependencies, [k])
                elif k not in io_names:
                    new_deps.add(k)
            dependencies[layer] = new_deps
        else:
            out[layer] = layers[layer]
            dependencies[layer] = full_graph.dependencies.get(layer, set())
            stack.extend(full_graph.dependencies.get(layer, ()))

    return HighLevelGraph(out, dependencies)
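

# Usage sketch (assumes ``dask.array`` is installed; illustrative only):
# stacked Blockwise layers, such as a transpose followed by an elementwise
# add, can be fused before any low-level tasks exist.
#
#     import dask.array as da
#     x = da.ones((10, 10), chunks=(5, 5))
#     y = x.T + 123
#     hlg = y.__dask_graph__()  # HighLevelGraph of stacked Blockwise layers
#     fused = optimize_blockwise(hlg, keys=list(flatten(y.__dask_keys__())))
#     # ``fused`` computes the same result with fewer layers.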


def rewrite_blockwise(inputs):
    """Rewrite a stack of Blockwise expressions into a single Blockwise expression

    Given a set of Blockwise layers, combine them into a single layer. The provided
    layers are expected to fit well together. That job is handled by
    ``optimize_blockwise``.

    Parameters
    ----------
    inputs : List[Blockwise]

    Returns
    -------
    blockwise: Blockwise

    See Also
    --------
    optimize_blockwise
    """
    if len(inputs) == 1:
        # Fast path: if there's only one input we can just use it as-is.
        return inputs[0]

    inputs = {inp.output: inp for inp in inputs}
    dependencies = {
        inp.output: {d for d, v in inp.indices if v is not None and d in inputs}
        for inp in inputs.values()
    }
    dependents = reverse_dict(dependencies)

    new_index_iter = (
        c + (str(d) if d else "")  # A, B, ... A1, B1, ...
        for d in itertools.count()
        for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
    )

    [root] = [k for k, v in dependents.items() if not v]

    # Our final results. These will change during fusion below
    indices = list(inputs[root].indices)
    new_axes = inputs[root].new_axes
    concatenate = inputs[root].concatenate
    dsk = dict(inputs[root].dsk)

    changed = True
    while changed:
        changed = False
        for i, (dep, ind) in enumerate(indices):
            if ind is None:
                continue
            if dep not in inputs:
                continue

            changed = True

            # Replace _n with dep name in existing tasks
            # (inc, _0) -> (inc, 'b')
            dsk = {k: subs(v, {blockwise_token(i): dep}) for k, v in dsk.items()}

            # Remove current input from input indices
            # [('a', 'i'), ('b', 'i')] -> [('a', 'i')]
            _, current_dep_indices = indices.pop(i)
            sub = {
                blockwise_token(i): blockwise_token(i - 1)
                for i in range(i + 1, len(indices) + 1)
            }
            dsk = subs(dsk, sub)

            # Change new input_indices to match the given index from the
            # current computation: [('c', 'j')] -> [('c', 'i')]
            new_indices = inputs[dep].indices
            sub = dict(zip(inputs[dep].output_indices, current_dep_indices))
            contracted = {
                x
                for _, j in new_indices
                if j is not None
                for x in j
                if x not in inputs[dep].output_indices
            }
            extra = dict(zip(contracted, new_index_iter))
            sub.update(extra)
            new_indices = [(x, index_subs(j, sub)) for x, j in new_indices]

            # Update new_axes
            for k, v in inputs[dep].new_axes.items():
                new_axes[sub[k]] = v

            # Bump new inputs up in list
            sub = {}
            # Map from (id(key), inds or None) -> index in indices. Used to deduplicate indices.
            index_map = {(id(k), inds): n for n, (k, inds) in enumerate(indices)}
            for ii, index in enumerate(new_indices):
                id_key = (id(index[0]), index[1])
                if id_key in index_map:  # use old inputs if available
                    sub[blockwise_token(ii)] = blockwise_token(index_map[id_key])
                else:
                    index_map[id_key] = len(indices)
                    sub[blockwise_token(ii)] = blockwise_token(len(indices))
                    indices.append(index)
            new_dsk = subs(inputs[dep].dsk, sub)

            # indices.extend(new_indices)
            dsk.update(new_dsk)

    # De-duplicate indices like [(a, ij), (b, i), (a, ij)] -> [(a, ij), (b, i)]
    # Make sure that we map everything else appropriately as we remove inputs
    new_indices = []
    seen = {}
    sub = {}  # like {_0: _0, _1: _0, _2: _1}
    for i, x in enumerate(indices):
        if x[1] is not None and x in seen:
            sub[i] = seen[x]
        else:
            if x[1] is not None:
                seen[x] = len(new_indices)
            sub[i] = len(new_indices)
            new_indices.append(x)

    sub = {blockwise_token(k): blockwise_token(v) for k, v in sub.items()}
    dsk = {k: subs(v, sub) for k, v in dsk.items() if k not in sub.keys()}

    indices_check = {k for k, v in indices if v is not None}
    numblocks = toolz.merge([inp.numblocks for inp in inputs.values()])
    numblocks = {k: v for k, v in numblocks.items() if v is None or k in indices_check}

    # Update IO-dependency information
    io_deps = {}
    for v in inputs.values():
        io_deps.update(v.io_deps)

    return Blockwise(
        root,
        inputs[root].output_indices,
        dsk,
        new_indices,
        numblocks=numblocks,
        new_axes=new_axes,
        concatenate=concatenate,
        annotations=inputs[root].annotations,
        io_deps=io_deps,
    )


@_deprecated()
def zero_broadcast_dimensions(lol, nblocks):
    """
    >>> lol = [('x', 1, 0), ('x', 1, 1), ('x', 1, 2)]
    >>> nblocks = (4, 1, 2)  # note singleton dimension in second place
    >>> lol = [[('x', 1, 0, 0), ('x', 1, 0, 1)],
    ...        [('x', 1, 1, 0), ('x', 1, 1, 1)],
    ...        [('x', 1, 2, 0), ('x', 1, 2, 1)]]

    >>> zero_broadcast_dimensions(lol, nblocks)  # doctest: +SKIP
    [[('x', 1, 0, 0), ('x', 1, 0, 1)],
     [('x', 1, 0, 0), ('x', 1, 0, 1)],
     [('x', 1, 0, 0), ('x', 1, 0, 1)]]

    See Also
    --------
    lol_tuples
    """
    f = lambda t: (t[0],) + tuple(0 if d == 1 else i for i, d in zip(t[1:], nblocks))
    return homogeneous_deepmap(f, lol)


def broadcast_dimensions(argpairs, numblocks, sentinels=(1, (1,)), consolidate=None):
    """Find block dimensions from arguments

    Parameters
    ----------
    argpairs : iterable
        name, ijk index pairs
    numblocks : dict
        maps {name: number of blocks}
    sentinels : iterable (optional)
        values for singleton dimensions
    consolidate : func (optional)
        use this to reduce each set of common blocks into a smaller set

    Examples
    --------
    >>> argpairs = [('x', 'ij'), ('y', 'ji')]
    >>> numblocks = {'x': (2, 3), 'y': (3, 2)}
    >>> broadcast_dimensions(argpairs, numblocks)
    {'i': 2, 'j': 3}

    Supports numpy broadcasting rules

    >>> argpairs = [('x', 'ij'), ('y', 'ij')]
    >>> numblocks = {'x': (2, 1), 'y': (1, 3)}
    >>> broadcast_dimensions(argpairs, numblocks)
    {'i': 2, 'j': 3}

    Works in other contexts too

    >>> argpairs = [('x', 'ij'), ('y', 'ij')]
    >>> d = {'x': ('Hello', 1), 'y': (1, (2, 3))}
    >>> broadcast_dimensions(argpairs, d)
    {'i': 'Hello', 'j': (2, 3)}
    """
    # List like [('i', 2), ('j', 1), ('i', 1), ('j', 2)]
    argpairs2 = [(a, ind) for a, ind in argpairs if ind is not None]
    L = toolz.concat(
        [
            zip(inds, dims)
            for (x, inds), (x, dims) in toolz.join(
                toolz.first, argpairs2, toolz.first, numblocks.items()
            )
        ]
    )

    g = toolz.groupby(0, L)
    g = {k: {d for i, d in v} for k, v in g.items()}

    g2 = {k: v - set(sentinels) if len(v) > 1 else v for k, v in g.items()}

    if consolidate:
        return toolz.valmap(consolidate, g2)

    if g2 and not set(map(len, g2.values())) == {1}:
        raise ValueError("Shapes do not align %s" % g)

    return toolz.valmap(toolz.first, g2)


def _make_dims(indices, numblocks, new_axes):
    """Returns a dictionary mapping between each index specified in
    `indices` and the number of output blocks for that index.
    """
    dims = broadcast_dimensions(indices, numblocks)
    for k, v in new_axes.items():
        dims[k] = len(v) if isinstance(v, tuple) else 1
    return dims


def fuse_roots(graph: HighLevelGraph, keys: list):
    """
    Fuse nearby layers if they don't have dependencies

    Often Blockwise sections of the graph fill out all of the computation
    except for the initial data access or data loading layers::

        Large Blockwise Layer
          |       |       |
          X       Y       Z

    This can be troublesome because X, Y, and Z tasks may be executed on
    different machines, and then require communication to move around.

    This optimization identifies this situation, lowers all of the graphs to
    concrete dicts, and then calls ``fuse`` on them, with a width equal to the
    number of layers like X, Y, and Z.

    This is currently used within array and dataframe optimizations.

    Parameters
    ----------
    graph : HighLevelGraph
        The full graph of the computation
    keys : list
        The output keys of the computation, to be passed on to fuse

    See Also
    --------
    Blockwise
    fuse
    """
    layers = ensure_dict(graph.layers, copy=True)
    dependencies = ensure_dict(graph.dependencies, copy=True)
    dependents = reverse_dict(dependencies)

    for name, layer in graph.layers.items():
        deps = graph.dependencies[name]
        if (
            isinstance(layer, Blockwise)
            and len(deps) > 1
            and not any(dependencies[dep] for dep in deps)  # no need to fuse if 0 or 1
            and all(len(dependents[dep]) == 1 for dep in deps)
            and all(layer.annotations == graph.layers[dep].annotations for dep in deps)
        ):
            new = toolz.merge(layer, *[layers[dep] for dep in deps])
            new, _ = fuse(new, keys, ave_width=len(deps))

            for dep in deps:
                del layers[dep]
                del dependencies[dep]

            layers[name] = new
            dependencies[name] = set()

    return HighLevelGraph(layers, dependencies)
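

# Usage sketch for ``fuse_roots`` (assumes ``dask.array``; illustrative only):
#
#     import dask.array as da
#     x = da.ones(10, chunks=5) + da.ones(10, chunks=5)
#     hlg = x.__dask_graph__()
#     fused = fuse_roots(hlg, keys=list(flatten(x.__dask_keys__())))
#     # The two data-creation root layers are merged into their shared
#     # dependent, so root tasks need not move between workers.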