1import itertools
2import os
3from itertools import product
4from typing import (
5    Any,
6    Hashable,
7    Iterable,
8    List,
9    Mapping,
10    Optional,
11    Sequence,
12    Set,
13    Tuple,
14    Union,
15)
16
17import tlz as toolz
18
19from .base import clone_key, get_name_from_key, tokenize
20from .compatibility import prod
21from .core import flatten, keys_in_tasks, reverse_dict
22from .delayed import unpack_collections
23from .highlevelgraph import HighLevelGraph, Layer
24from .optimization import SubgraphCallable, fuse
25from .utils import (
26    _deprecated,
27    apply,
28    ensure_dict,
29    homogeneous_deepmap,
30    stringify,
31    stringify_collection_keys,
32)
33
34
35class BlockwiseDep:
36    """Blockwise-IO argument
37
38    This is the base class for indexable Blockwise-IO arguments.
39    When constructing a ``Blockwise`` Layer, one or more of the
40    collection tuples passed in with ``indices`` may contain a
41    ``BlockwiseDep`` instance (in place of a "real" collection name).
42    This allows a new collection to be created (via IO) within a
43    ``Blockwise`` layer.
44
45    All ``BlockwiseDep`` instances must define a ``numblocks``
    attribute to specify the number of blocks/partitions the
47    object can support along each dimension. The object should
48    also define a ``produces_tasks`` attribute to specify if
49    any nested tasks will be passed to the Blockwise function.
50
51    See Also
52    --------
53    dask.blockwise.Blockwise
54    dask.blockwise.BlockwiseDepDict
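
    Examples
    --------
    A minimal, illustrative subclass (the class name and its behavior are
    arbitrary; real implementations typically wrap IO metadata):

    >>> class RangeDep(BlockwiseDep):
    ...     produces_tasks = False
    ...     def __init__(self, npartitions):
    ...         self.numblocks = (npartitions,)
    ...     def __getitem__(self, idx):
    ...         return idx[0]
    >>> RangeDep(3)[(2,)]
    2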
55    """
56
57    numblocks: Tuple[int, ...]
58    produces_tasks: bool
59
60    def __getitem__(self, idx: Tuple[int, ...]) -> Any:
61        """Return Blockwise-function arguments for a specific index"""
62        raise NotImplementedError(
63            "Must define `__getitem__` for `BlockwiseDep` subclass."
64        )
65
66    def get(self, idx: Tuple[int, ...], default) -> Any:
67        """BlockwiseDep ``__getitem__`` Wrapper"""
68        try:
69            return self.__getitem__(idx)
70        except KeyError:
71            return default
72
73    def __dask_distributed_pack__(
74        self, required_indices: Optional[List[Tuple[int, ...]]] = None
75    ):
76        """Client-side serialization for ``BlockwiseDep`` objects.
77
78        Should return a ``state`` dictionary, with msgpack-serializable
79        values, that can be used to initialize a new ``BlockwiseDep`` object
80        on a scheduler process.
81        """
82        raise NotImplementedError(
83            "Must define `__dask_distributed_pack__` for `BlockwiseDep` subclass."
84        )
85
86    @classmethod
87    def __dask_distributed_unpack__(cls, state):
88        """Scheduler-side deserialization for ``BlockwiseDep`` objects.
89
90        Should use an input ``state`` dictionary to initialize a new
91        ``BlockwiseDep`` object.
92        """
93        raise NotImplementedError(
94            "Must define `__dask_distributed_unpack__` for `BlockwiseDep` subclass."
95        )
96
97    def __repr__(self) -> str:
98        return f"<{type(self).__name__} {self.numblocks}>"
99
100
101class BlockwiseDepDict(BlockwiseDep):
102    """Dictionary-based Blockwise-IO argument
103
104    This is a dictionary-backed instance of ``BlockwiseDep``.
105    The purpose of this class is to simplify the construction
106    of IO-based Blockwise Layers with block/partition-dependent
107    function arguments that are difficult to calculate at
108    graph-materialization time.
109
110    Examples
111    --------
112
113    Specify an IO-based function for the Blockwise Layer. Note
114    that the function will be passed a single input object when
115    the task is executed (e.g. a single ``tuple`` or ``dict``):
116
117    >>> import pandas as pd
118    >>> func = lambda x: pd.read_csv(**x)
119
120    Use ``BlockwiseDepDict`` to define the input argument to
121    ``func`` for each block/partition:
122
123    >>> dep = BlockwiseDepDict(
124    ...     mapping={
125    ...         (0,) : {
126    ...             "filepath_or_buffer": "data.csv",
127    ...             "skiprows": 1,
128    ...             "nrows": 2,
129    ...             "names": ["a", "b"],
130    ...         },
131    ...         (1,) : {
132    ...             "filepath_or_buffer": "data.csv",
133    ...             "skiprows": 3,
134    ...             "nrows": 2,
135    ...             "names": ["a", "b"],
136    ...         },
137    ...     }
138    ... )
139
    Construct a Blockwise Layer with ``dep`` specified
141    in the ``indices`` list:
142
143    >>> layer = Blockwise(
144    ...     output="collection-name",
145    ...     output_indices="i",
146    ...     dsk={"collection-name": (func, '_0')},
147    ...     indices=[(dep, "i")],
148    ...     numblocks={},
149    ... )
150
151    See Also
152    --------
153    dask.blockwise.Blockwise
154    dask.blockwise.BlockwiseDep
155    """
156
157    def __init__(
158        self,
159        mapping: dict,
160        numblocks: Optional[Tuple[int, ...]] = None,
161        produces_tasks: bool = False,
162    ):
163        self.mapping = mapping
164        self.produces_tasks = produces_tasks
165
166        # By default, assume 1D shape
167        self.numblocks = numblocks or (len(mapping),)
168
169    def __getitem__(self, idx: Tuple[int, ...]) -> Any:
170        return self.mapping[idx]
171
172    def __dask_distributed_pack__(
173        self, required_indices: Optional[List[Tuple[int, ...]]] = None
174    ):
175        from distributed.protocol import to_serialize
176
177        if required_indices is None:
178            required_indices = self.mapping.keys()
179
180        return {
181            "mapping": {k: to_serialize(self.mapping[k]) for k in required_indices},
182            "numblocks": self.numblocks,
183            "produces_tasks": self.produces_tasks,
184        }
185
186    @classmethod
187    def __dask_distributed_unpack__(cls, state):
188        return cls(**state)
189
190
191def subs(task, substitution):
192    """Create a new task with the values substituted
193
194    This is like dask.core.subs, but takes a dict of many substitutions to
195    perform simultaneously.  It is not as concerned with micro performance.
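
    Examples
    --------
    A small illustrative substitution (the task and the keys are arbitrary):

    >>> subs(('add', '_0', ['_0', '_1']), {'_0': 'a'})
    ('add', 'a', ['a', '_1'])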
196    """
197    if isinstance(task, dict):
198        return {k: subs(v, substitution) for k, v in task.items()}
199    if type(task) in (tuple, list, set):
200        return type(task)([subs(x, substitution) for x in task])
201    try:
202        return substitution[task]
203    except (KeyError, TypeError):
204        return task
205
206
207def index_subs(ind, substitution):
208    """A simple subs function that works both on tuples and strings"""
209    if ind is None:
210        return ind
211    else:
212        return tuple(substitution.get(c, c) for c in ind)
213
214
215_BLOCKWISE_DEFAULT_PREFIX = "__dask_blockwise__"
216
217
218def blockwise_token(i, prefix=_BLOCKWISE_DEFAULT_PREFIX):
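    """Return the placeholder name for the ``i``-th blockwise input.

    Illustrative outputs with the default and a custom prefix:

    >>> blockwise_token(0)
    '__dask_blockwise__0'
    >>> blockwise_token(1, ".")
    '.1'
    """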
219    return prefix + "%d" % i
220
221
222def blockwise(
223    func,
224    output,
225    output_indices,
226    *arrind_pairs,
227    numblocks=None,
228    concatenate=None,
229    new_axes=None,
230    dependencies=(),
231    **kwargs,
232):
233    """Create a Blockwise symbolic mutable mapping
234
235    This is like the ``make_blockwise_graph`` function, but rather than construct a
236    dict, it returns a symbolic Blockwise object.
237
238    See Also
239    --------
240    make_blockwise_graph
241    Blockwise
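
    Examples
    --------
    A small, illustrative elementwise layer over a single one-dimensional
    collection (the collection name ``'x'`` and its block count are arbitrary):

    >>> layer = blockwise(lambda b: b + 1, 'z', 'i', 'x', 'i', numblocks={'x': (2,)})
    >>> sorted(layer.get_output_keys())
    [('z', 0), ('z', 1)]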
242    """
243    new_axes = new_axes or {}
244
245    arrind_pairs = list(arrind_pairs)
246
247    # Transform indices to canonical elements
    # We use names like .0 and .1 rather than the provided index elements
249    unique_indices = {
250        i for ii in arrind_pairs[1::2] if ii is not None for i in ii
251    } | set(output_indices)
252    sub = {k: blockwise_token(i, ".") for i, k in enumerate(sorted(unique_indices))}
253    output_indices = index_subs(tuple(output_indices), sub)
254    a_pairs_list = []
255    for a in arrind_pairs[1::2]:
256        if a is not None:
257            val = tuple(a)
258        else:
259            val = a
260        a_pairs_list.append(index_subs(val, sub))
261
262    arrind_pairs[1::2] = a_pairs_list
263    new_axes = {index_subs((k,), sub)[0]: v for k, v in new_axes.items()}
264
265    # Unpack dask values in non-array arguments
266    inputs = []
267    inputs_indices = []
268    for name, index in toolz.partition(2, arrind_pairs):
269        inputs.append(name)
270        inputs_indices.append(index)
271
272    # Unpack delayed objects in kwargs
273    new_keys = {n for c in dependencies for n in c.__dask_layers__()}
274    if kwargs:
275        # replace keys in kwargs with _0 tokens
276        new_tokens = tuple(
277            blockwise_token(i) for i in range(len(inputs), len(inputs) + len(new_keys))
278        )
279        sub = dict(zip(new_keys, new_tokens))
280        inputs.extend(new_keys)
281        inputs_indices.extend((None,) * len(new_keys))
282        kwargs = subs(kwargs, sub)
283
284    indices = [(k, v) for k, v in zip(inputs, inputs_indices)]
285    keys = map(blockwise_token, range(len(inputs)))
286
287    # Construct local graph
288    if not kwargs:
289        subgraph = {output: (func,) + tuple(keys)}
290    else:
291        _keys = list(keys)
292        if new_keys:
293            _keys = _keys[: -len(new_keys)]
294        kwargs2 = (dict, list(map(list, kwargs.items())))
295        subgraph = {output: (apply, func, _keys, kwargs2)}
296
297    # Construct final output
298    subgraph = Blockwise(
299        output,
300        output_indices,
301        subgraph,
302        indices,
303        numblocks=numblocks,
304        concatenate=concatenate,
305        new_axes=new_axes,
306    )
307    return subgraph
308
309
310class Blockwise(Layer):
311    """Tensor Operation
312
313    This is a lazily constructed mapping for tensor operation graphs.
314    This defines a dictionary using an operation and an indexing pattern.
315    It is built for many operations like elementwise, transpose, tensordot, and
316    so on.  We choose to keep these as symbolic mappings rather than raw
317    dictionaries because we are able to fuse them during optimization,
318    sometimes resulting in much lower overhead.
319
320    Parameters
321    ----------
322    output: str
323        The name of the output collection.  Used in keynames
324    output_indices: tuple
325        The output indices, like ``('i', 'j', 'k')`` used to determine the
326        structure of the block computations
327    dsk: dict
328        A small graph to apply per-output-block.  May include keys from the
329        input indices.
330    indices: Tuple[Tuple[str, Optional[Tuple[str, ...]]], ...]
331        An ordered mapping from input key name, like ``'x'``
332        to input indices, like ``('i', 'j')``
333        Or includes literals, which have ``None`` for an index value.
334        In place of input-key names, the first tuple element may also be a
335        ``BlockwiseDep`` object.
336    numblocks: Mapping[key, Sequence[int]]
337        Number of blocks along each dimension for each input
338    concatenate: bool
339        Whether or not to pass contracted dimensions as a list of inputs or a
340        single input to the block function
341    new_axes: Mapping
342        New index dimensions that may have been created and their size,
343        e.g. ``{'j': 2, 'k': 3}``
344    output_blocks: Set[Tuple[int, ...]]
345        Specify a specific set of required output blocks. Since the graph
346        will only contain the necessary tasks to generate these outputs,
347        this kwarg can be used to "cull" the abstract layer (without needing
348        to materialize the low-level graph).
349    annotations: dict (optional)
350        Layer annotations
351    io_deps: Dict[str, BlockwiseDep] (optional)
352        Dictionary containing the mapping between "place-holder" collection
353        keys and ``BlockwiseDep``-based objects.
354        **WARNING**: This argument should only be used internally (for culling,
355        fusion and cloning of existing Blockwise layers). Explicit use of this
356        argument will be deprecated in the future.
357
358    See Also
359    --------
360    dask.blockwise.blockwise
361    dask.array.blockwise
362    """
363
364    output: str
365    output_indices: Tuple[str, ...]
366    dsk: Mapping[str, tuple]
367    indices: Tuple[Tuple[str, Optional[Tuple[str, ...]]], ...]
368    numblocks: Mapping[str, Sequence[int]]
369    concatenate: Optional[bool]
370    new_axes: Mapping[str, int]
371    output_blocks: Optional[Set[Tuple[int, ...]]]
372
373    def __init__(
374        self,
375        output: str,
376        output_indices: Iterable[str],
377        dsk: Mapping[str, tuple],
378        indices: Iterable[Tuple[Union[str, BlockwiseDep], Optional[Iterable[str]]]],
379        numblocks: Mapping[str, Sequence[int]],
380        concatenate: bool = None,
381        new_axes: Mapping[str, int] = None,
382        output_blocks: Set[Tuple[int, ...]] = None,
383        annotations: Mapping[str, Any] = None,
384        io_deps: Optional[Mapping[str, BlockwiseDep]] = None,
385    ):
386        super().__init__(annotations=annotations)
387        self.output = output
388        self.output_indices = tuple(output_indices)
389        self.output_blocks = output_blocks
390        self.dsk = dsk
391
392        # Remove `BlockwiseDep` arguments from input indices
393        # and add them to `self.io_deps`.
394        # TODO: Remove `io_deps` and handle indexable objects
395        # in `self.indices` throughout `Blockwise`.
396        self.indices = []
397        self.numblocks = numblocks
398        self.io_deps = io_deps or {}
399        for dep, ind in indices:
400            name = dep
401            if isinstance(dep, BlockwiseDep):
402                name = tokenize(dep)
403                self.io_deps[name] = dep
404                self.numblocks[name] = dep.numblocks
405            self.indices.append((name, tuple(ind) if ind is not None else ind))
406        self.indices = tuple(self.indices)
407
408        # optimize_blockwise won't merge where `concatenate` doesn't match, so
409        # enforce a canonical value if there are no axes for reduction.
410        output_indices_set = set(self.output_indices)
411        if concatenate is not None and all(
412            i in output_indices_set
413            for name, ind in self.indices
414            if ind is not None
415            for i in ind
416        ):
417            concatenate = None
418        self.concatenate = concatenate
419        self.new_axes = new_axes or {}
420
421    @property
422    def dims(self):
423        """Returns a dictionary mapping between each index specified in
        `self.indices` and the number of output blocks for that index.
425        """
426        if not hasattr(self, "_dims"):
427            self._dims = _make_dims(self.indices, self.numblocks, self.new_axes)
428        return self._dims
429
430    def __repr__(self):
431        return f"Blockwise<{self.indices} -> {self.output}>"
432
433    @property
434    def _dict(self):
435        if hasattr(self, "_cached_dict"):
436            return self._cached_dict["dsk"]
437        else:
438            keys = tuple(map(blockwise_token, range(len(self.indices))))
439            dsk, _ = fuse(self.dsk, [self.output])
440            func = SubgraphCallable(dsk, self.output, keys)
441
442            dsk = make_blockwise_graph(
443                func,
444                self.output,
445                self.output_indices,
446                *list(toolz.concat(self.indices)),
447                new_axes=self.new_axes,
448                numblocks=self.numblocks,
449                concatenate=self.concatenate,
450                output_blocks=self.output_blocks,
451                dims=self.dims,
452                io_deps=self.io_deps,
453            )
454
455            self._cached_dict = {"dsk": dsk}
456        return self._cached_dict["dsk"]
457
458    def get_output_keys(self):
459        if self.output_blocks:
460            # Culling has already generated a list of output blocks
461            return {(self.output, *p) for p in self.output_blocks}
462
463        # Return all possible output keys (no culling)
464        return {
465            (self.output, *p)
466            for p in itertools.product(
467                *[range(self.dims[i]) for i in self.output_indices]
468            )
469        }
470
471    def __getitem__(self, key):
472        return self._dict[key]
473
474    def __iter__(self):
475        return iter(self._dict)
476
477    def __len__(self) -> int:
        # same logic as `get_output_keys`, without materializing the keys themselves
479        return (
480            len(self.output_blocks)
481            if self.output_blocks
482            else prod(self.dims[i] for i in self.output_indices)
483        )
484
485    def is_materialized(self):
486        return hasattr(self, "_cached_dict")
487
488    def __dask_distributed_pack__(
489        self, all_hlg_keys, known_key_dependencies, client, client_keys
490    ):
491        from distributed.protocol import to_serialize
492        from distributed.utils import CancelledError
493        from distributed.utils_comm import unpack_remotedata
494        from distributed.worker import dumps_function
495
496        keys = tuple(map(blockwise_token, range(len(self.indices))))
497        dsk, _ = fuse(self.dsk, [self.output])
498
499        # Embed literals in `dsk`
500        keys2 = []
501        indices2 = []
502        global_dependencies = set()
503        for key, (val, index) in zip(keys, self.indices):
504            if index is None:
505                try:
506                    val_is_a_key = val in all_hlg_keys
507                except TypeError:  # not hashable
508                    val_is_a_key = False
509                if val_is_a_key:
510                    keys2.append(key)
511                    indices2.append((val, index))
512                    global_dependencies.add(stringify(val))
513                else:
514                    dsk[key] = val  # Literal
515            else:
516                keys2.append(key)
517                indices2.append((val, index))
518
519        dsk = (SubgraphCallable(dsk, self.output, tuple(keys2)),)
520        dsk, dsk_unpacked_futures = unpack_remotedata(dsk, byte_keys=True)
521
522        # Handle `io_deps` serialization. Assume each element
523        # is a `BlockwiseDep`-based object.
524        packed_io_deps = {}
525        inline_tasks = False
526        for name, blockwise_dep in self.io_deps.items():
527            packed_io_deps[name] = {
528                "__module__": blockwise_dep.__module__,
529                "__name__": type(blockwise_dep).__name__,
530                # TODO: Pass a `required_indices` list to __pack__
531                "state": blockwise_dep.__dask_distributed_pack__(),
532            }
533            inline_tasks = inline_tasks or blockwise_dep.produces_tasks
534
535        # Dump (pickle + cache) the function here if we know `make_blockwise_graph`
536        # will NOT be producing "nested" tasks (via `__dask_distributed_unpack__`).
537        #
538        # If `make_blockwise_graph` DOES need to produce nested tasks later on, it
539        # will need to call `to_serialize` on the entire task.  That will be a
540        # problem if the function was already pickled here. Therefore, we want to
541        # call `to_serialize` on the function if we know there will be nested tasks.
542        #
543        # We know there will be nested tasks if either:
544        #   (1) `concatenate=True`   # Check `self.concatenate`
545        #   (2) `inline_tasks=True`  # Check `BlockwiseDep.produces_tasks`
546        #
547        # We do not call `to_serialize` in ALL cases, because that code path does
548        # not cache the function on the scheduler or worker (or warn if there are
549        # large objects being passed into the graph).  However, in the future,
550        # single-pass serialization improvements should allow us to remove this
551        # special logic altogether.
552        func = (
553            to_serialize(dsk[0])
554            if (self.concatenate or inline_tasks)
555            else dumps_function(dsk[0])
556        )
557        func_future_args = dsk[1:]
558
559        indices = list(toolz.concat(indices2))
560        indices, indices_unpacked_futures = unpack_remotedata(indices, byte_keys=True)
561
562        # Check the legality of the unpacked futures
563        for future in itertools.chain(dsk_unpacked_futures, indices_unpacked_futures):
564            if future.client is not client:
565                raise ValueError(
566                    "Inputs contain futures that were created by another client."
567                )
568            if stringify(future.key) not in client.futures:
569                raise CancelledError(stringify(future.key))
570
571        # All blockwise tasks will depend on the futures in `indices`
572        global_dependencies |= {stringify(f.key) for f in indices_unpacked_futures}
573
574        return {
575            "output": self.output,
576            "output_indices": self.output_indices,
577            "func": func,
578            "func_future_args": func_future_args,
579            "global_dependencies": global_dependencies,
580            "indices": indices,
581            "is_list": [isinstance(x, list) for x in indices],
582            "numblocks": self.numblocks,
583            "concatenate": self.concatenate,
584            "new_axes": self.new_axes,
585            "output_blocks": self.output_blocks,
586            "dims": self.dims,
587            "io_deps": packed_io_deps,
588        }
589
590    @classmethod
591    def __dask_distributed_unpack__(cls, state, dsk, dependencies):
592        from distributed.protocol.serialize import import_allowed_module
593
594        # Make sure we convert list items back from tuples in `indices`.
595        # The msgpack serialization will have converted lists into
596        # tuples, and tuples may be stringified during graph
597        # materialization (bad if the item was not a key).
598        indices = [
599            list(ind) if is_list else ind
600            for ind, is_list in zip(state["indices"], state["is_list"])
601        ]
602
603        # Unpack io_deps state
604        io_deps = {}
605        for replace_name, packed_dep in state["io_deps"].items():
606            mod = import_allowed_module(packed_dep["__module__"])
607            dep_cls = getattr(mod, packed_dep["__name__"])
608            io_deps[replace_name] = dep_cls.__dask_distributed_unpack__(
609                packed_dep["state"]
610            )
611
612        layer_dsk, layer_deps = make_blockwise_graph(
613            state["func"],
614            state["output"],
615            state["output_indices"],
616            *indices,
617            new_axes=state["new_axes"],
618            numblocks=state["numblocks"],
619            concatenate=state["concatenate"],
620            output_blocks=state["output_blocks"],
621            dims=state["dims"],
622            return_key_deps=True,
623            deserializing=True,
624            func_future_args=state["func_future_args"],
625            io_deps=io_deps,
626        )
627        g_deps = state["global_dependencies"]
628
629        # Stringify layer graph and dependencies
630        layer_dsk = {
631            stringify(k): stringify_collection_keys(v) for k, v in layer_dsk.items()
632        }
633        deps = {
634            stringify(k): {stringify(d) for d in v} | g_deps
635            for k, v in layer_deps.items()
636        }
637        return {"dsk": layer_dsk, "deps": deps}
638
639    def _cull_dependencies(self, all_hlg_keys, output_blocks):
640        """Determine the necessary dependencies to produce `output_blocks`.
641
642        This method does not require graph materialization.
643        """
644
645        # Check `concatenate` option
646        concatenate = None
647        if self.concatenate is True:
648            from dask.array.core import concatenate_axes as concatenate
649
650        # Generate coordinate map
651        (coord_maps, concat_axes, dummies) = _get_coord_mapping(
652            self.dims,
653            self.output,
654            self.output_indices,
655            self.numblocks,
656            self.indices,
657            concatenate,
658        )
659
660        # Gather constant dependencies (for all output keys)
661        const_deps = set()
662        for (arg, ind) in self.indices:
663            if ind is None:
664                try:
665                    if arg in all_hlg_keys:
666                        const_deps.add(arg)
667                except TypeError:
668                    pass  # unhashable
669
670        # Get dependencies for each output block
671        key_deps = {}
672        for out_coords in output_blocks:
673            deps = set()
674            coords = out_coords + dummies
675            for cmap, axes, (arg, ind) in zip(coord_maps, concat_axes, self.indices):
676                if ind is not None and arg not in self.io_deps:
677                    arg_coords = tuple(coords[c] for c in cmap)
678                    if axes:
679                        tups = lol_product((arg,), arg_coords)
680                        deps.update(flatten(tups))
681                        if concatenate:
682                            tups = (concatenate, tups, axes)
683                    else:
684                        tups = (arg,) + arg_coords
685                        deps.add(tups)
686            key_deps[(self.output,) + out_coords] = deps | const_deps
687
688        return key_deps
689
690    def _cull(self, output_blocks):
691        return Blockwise(
692            self.output,
693            self.output_indices,
694            self.dsk,
695            self.indices,
696            self.numblocks,
697            concatenate=self.concatenate,
698            new_axes=self.new_axes,
699            output_blocks=output_blocks,
700            annotations=self.annotations,
701            io_deps=self.io_deps,
702        )
703
704    def cull(
705        self, keys: set, all_hlg_keys: Iterable
706    ) -> Tuple[Layer, Mapping[Hashable, set]]:
707        # Culling is simple for Blockwise layers.  We can just
708        # collect a set of required output blocks (tuples), and
709        # only construct graph for these blocks in `make_blockwise_graph`
710
711        output_blocks = set()
712        for key in keys:
713            if key[0] == self.output:
714                output_blocks.add(key[1:])
715        culled_deps = self._cull_dependencies(all_hlg_keys, output_blocks)
716        out_size_iter = (self.dims[i] for i in self.output_indices)
717        if prod(out_size_iter) != len(culled_deps):
718            culled_layer = self._cull(output_blocks)
719            return culled_layer, culled_deps
720        else:
721            return self, culled_deps
722
723    def clone(
724        self,
725        keys: set,
726        seed: Hashable,
727        bind_to: Hashable = None,
728    ) -> Tuple[Layer, bool]:
729        names = {get_name_from_key(k) for k in keys}
730        # We assume that 'keys' will contain either all or none of the output keys of
731        # each of the layers, because clone/bind are always invoked at collection level.
732        # Asserting this is very expensive, so we only check it during unit tests.
733        if "PYTEST_CURRENT_TEST" in os.environ:
734            assert not self.get_output_keys() - keys
735            for name, nb in self.numblocks.items():
736                if name in names:
737                    for block in product(*(list(range(nbi)) for nbi in nb)):
738                        assert (name, *block) in keys
739
740        is_leaf = True
741
742        indices = []
743        for k, idxv in self.indices:
744            if k in names:
745                is_leaf = False
746                k = clone_key(k, seed)
747            indices.append((k, idxv))
748
749        numblocks = {}
750        for k, nbv in self.numblocks.items():
751            if k in names:
752                is_leaf = False
753                k = clone_key(k, seed)
754            numblocks[k] = nbv
755
756        dsk = {clone_key(k, seed): v for k, v in self.dsk.items()}
757
758        if bind_to is not None and is_leaf:
759            from .graph_manipulation import chunks
760
761            # It's always a Delayed generated by dask.graph_manipulation.checkpoint;
762            # the layer name always matches the key
763            assert isinstance(bind_to, str)
            dsk = {
                k: (chunks.bind, v, blockwise_token(len(indices)))
                for k, v in dsk.items()
            }
765            indices.append((bind_to, None))
766
767        return (
768            Blockwise(
769                output=clone_key(self.output, seed),
770                output_indices=self.output_indices,
771                dsk=dsk,
772                indices=indices,
773                numblocks=numblocks,
774                concatenate=self.concatenate,
775                new_axes=self.new_axes,
776                output_blocks=self.output_blocks,
777                annotations=self.annotations,
778                io_deps=self.io_deps,
779            ),
780            (bind_to is not None and is_leaf),
781        )
782
783
784def _get_coord_mapping(
785    dims,
786    output,
787    out_indices,
788    numblocks,
789    argpairs,
790    concatenate,
791):
792    """Calculate coordinate mapping for graph construction.
793
794    This function handles the high-level logic behind Blockwise graph
795    construction. The output is a tuple containing: The mapping between
796    input and output block coordinates (`coord_maps`), the axes along
797    which to concatenate for each input (`concat_axes`), and the dummy
798    indices needed for broadcasting (`dummies`).
799
800    Used by `make_blockwise_graph` and `Blockwise._cull_dependencies`.
801
802    Parameters
803    ----------
804    dims : dict
805        Mapping between each index specified in `argpairs` and
806        the number of output blocks for that index. Corresponds
807        to the Blockwise `dims` attribute.
808    output : str
809        Corresponds to the Blockwise `output` attribute.
810    out_indices : tuple
811        Corresponds to the Blockwise `output_indices` attribute.
812    numblocks : dict
813        Corresponds to the Blockwise `numblocks` attribute.
814    argpairs : tuple
815        Corresponds to the Blockwise `indices` attribute.
816    concatenate : bool
817        Corresponds to the Blockwise `concatenate` attribute.
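
    Examples
    --------
    A small elementwise case with no contracted or broadcast dimensions
    (the names and block counts are arbitrary):

    >>> _get_coord_mapping(
    ...     dims={'i': 2, 'j': 3},
    ...     output='z',
    ...     out_indices=('i', 'j'),
    ...     numblocks={'x': (2, 3)},
    ...     argpairs=[('x', ('i', 'j'))],
    ...     concatenate=None,
    ... )
    ([[0, 1]], [[]], (0,))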
818    """
819
820    block_names = set()
821    all_indices = set()
822    for name, ind in argpairs:
823        if ind is not None:
824            block_names.add(name)
825            for x in ind:
826                all_indices.add(x)
827    assert set(numblocks) == block_names
828
829    dummy_indices = all_indices - set(out_indices)
830
831    # For each position in the output space, we'll construct a
832    # "coordinate set" that consists of
833    # - the output indices
834    # - the dummy indices
    # - the dummy indices, with their values replaced by zeros (for broadcasting);
    #   we are careful to emit only a single dummy zero when concatenate=True so
    #   that the same array is not concatenated with itself several times.
838    # - a 0 to assist with broadcasting.
839
840    index_pos, zero_pos = {}, {}
841    for i, ind in enumerate(out_indices):
842        index_pos[ind] = i
843        zero_pos[ind] = -1
844
845    _dummies_list = []
846    for i, ind in enumerate(dummy_indices):
847        index_pos[ind] = 2 * i + len(out_indices)
848        zero_pos[ind] = 2 * i + 1 + len(out_indices)
849        reps = 1 if concatenate else dims[ind]
850        _dummies_list.append([list(range(dims[ind])), [0] * reps])
851
852    # ([0, 1, 2], [0, 0, 0], ...)  For a dummy index of dimension 3
853    dummies = tuple(itertools.chain.from_iterable(_dummies_list))
854    dummies += (0,)
855
856    # For each coordinate position in each input, gives the position in
857    # the coordinate set.
858    coord_maps = []
859
860    # Axes along which to concatenate, for each input
861    concat_axes = []
862    for arg, ind in argpairs:
863        if ind is not None:
864            coord_maps.append(
865                [
866                    zero_pos[i] if nb == 1 else index_pos[i]
867                    for i, nb in zip(ind, numblocks[arg])
868                ]
869            )
870            concat_axes.append([n for n, i in enumerate(ind) if i in dummy_indices])
871        else:
872            coord_maps.append(None)
873            concat_axes.append(None)
874
875    return coord_maps, concat_axes, dummies
876
877
878def make_blockwise_graph(
879    func,
880    output,
881    out_indices,
882    *arrind_pairs,
883    numblocks=None,
884    concatenate=None,
885    new_axes=None,
886    output_blocks=None,
887    dims=None,
888    deserializing=False,
889    func_future_args=None,
890    return_key_deps=False,
891    io_deps=None,
892    **kwargs,
893):
894    """Tensor operation
895
896    Applies a function, ``func``, across blocks from many different input
897    collections.  We arrange the pattern with which those blocks interact with
898    sets of matching indices.  E.g.::
899
900        make_blockwise_graph(func, 'z', 'i', 'x', 'i', 'y', 'i')
901
    yields an embarrassingly parallel communication pattern and is read as
903
904        $$ z_i = func(x_i, y_i) $$
905
906    More complex patterns may emerge, including multiple indices::
907
908        make_blockwise_graph(func, 'z', 'ij', 'x', 'ij', 'y', 'ji')
909
910        $$ z_{ij} = func(x_{ij}, y_{ji}) $$
911
    Indices missing in the output but present in the inputs result in many
913    inputs being sent to one function (see examples).
914
915    Examples
916    --------
917    Simple embarrassing map operation
918
919    >>> inc = lambda x: x + 1
920    >>> make_blockwise_graph(inc, 'z', 'ij', 'x', 'ij', numblocks={'x': (2, 2)})  # doctest: +SKIP
921    {('z', 0, 0): (inc, ('x', 0, 0)),
922     ('z', 0, 1): (inc, ('x', 0, 1)),
923     ('z', 1, 0): (inc, ('x', 1, 0)),
924     ('z', 1, 1): (inc, ('x', 1, 1))}
925
926    Simple operation on two datasets
927
928    >>> add = lambda x, y: x + y
929    >>> make_blockwise_graph(add, 'z', 'ij', 'x', 'ij', 'y', 'ij', numblocks={'x': (2, 2),
930    ...                                                      'y': (2, 2)})  # doctest: +SKIP
931    {('z', 0, 0): (add, ('x', 0, 0), ('y', 0, 0)),
932     ('z', 0, 1): (add, ('x', 0, 1), ('y', 0, 1)),
933     ('z', 1, 0): (add, ('x', 1, 0), ('y', 1, 0)),
934     ('z', 1, 1): (add, ('x', 1, 1), ('y', 1, 1))}
935
936    Operation that flips one of the datasets
937
938    >>> addT = lambda x, y: x + y.T  # Transpose each chunk
939    >>> #                                        z_ij ~ x_ij y_ji
940    >>> #               ..         ..         .. notice swap
941    >>> make_blockwise_graph(addT, 'z', 'ij', 'x', 'ij', 'y', 'ji', numblocks={'x': (2, 2),
942    ...                                                       'y': (2, 2)})  # doctest: +SKIP
943    {('z', 0, 0): (add, ('x', 0, 0), ('y', 0, 0)),
944     ('z', 0, 1): (add, ('x', 0, 1), ('y', 1, 0)),
945     ('z', 1, 0): (add, ('x', 1, 0), ('y', 0, 1)),
946     ('z', 1, 1): (add, ('x', 1, 1), ('y', 1, 1))}
947
948    Dot product with contraction over ``j`` index.  Yields list arguments
949
950    >>> make_blockwise_graph(dotmany, 'z', 'ik', 'x', 'ij', 'y', 'jk', numblocks={'x': (2, 2),
951    ...                                                          'y': (2, 2)})  # doctest: +SKIP
952    {('z', 0, 0): (dotmany, [('x', 0, 0), ('x', 0, 1)],
953                            [('y', 0, 0), ('y', 1, 0)]),
954     ('z', 0, 1): (dotmany, [('x', 0, 0), ('x', 0, 1)],
955                            [('y', 0, 1), ('y', 1, 1)]),
956     ('z', 1, 0): (dotmany, [('x', 1, 0), ('x', 1, 1)],
957                            [('y', 0, 0), ('y', 1, 0)]),
958     ('z', 1, 1): (dotmany, [('x', 1, 0), ('x', 1, 1)],
959                            [('y', 0, 1), ('y', 1, 1)])}
960
961    Pass ``concatenate=True`` to concatenate arrays ahead of time
962
963    >>> make_blockwise_graph(f, 'z', 'i', 'x', 'ij', 'y', 'ij', concatenate=True,
964    ...     numblocks={'x': (2, 2), 'y': (2, 2,)})  # doctest: +SKIP
965    {('z', 0): (f, (concatenate_axes, [('x', 0, 0), ('x', 0, 1)], (1,)),
966                   (concatenate_axes, [('y', 0, 0), ('y', 0, 1)], (1,)))
967     ('z', 1): (f, (concatenate_axes, [('x', 1, 0), ('x', 1, 1)], (1,)),
968                   (concatenate_axes, [('y', 1, 0), ('y', 1, 1)], (1,)))}
969
970    Supports Broadcasting rules
971
972    >>> make_blockwise_graph(add, 'z', 'ij', 'x', 'ij', 'y', 'ij', numblocks={'x': (1, 2),
973    ...                                                      'y': (2, 2)})  # doctest: +SKIP
974    {('z', 0, 0): (add, ('x', 0, 0), ('y', 0, 0)),
975     ('z', 0, 1): (add, ('x', 0, 1), ('y', 0, 1)),
976     ('z', 1, 0): (add, ('x', 0, 0), ('y', 1, 0)),
977     ('z', 1, 1): (add, ('x', 0, 1), ('y', 1, 1))}
978
979    Support keyword arguments with apply
980
981    >>> def f(a, b=0): return a + b
982    >>> make_blockwise_graph(f, 'z', 'i', 'x', 'i', numblocks={'x': (2,)}, b=10)  # doctest: +SKIP
983    {('z', 0): (apply, f, [('x', 0)], {'b': 10}),
984     ('z', 1): (apply, f, [('x', 1)], {'b': 10})}
985
986    Include literals by indexing with ``None``
987
988    >>> make_blockwise_graph(add, 'z', 'i', 'x', 'i', 100, None,  numblocks={'x': (2,)})  # doctest: +SKIP
989    {('z', 0): (add, ('x', 0), 100),
990     ('z', 1): (add, ('x', 1), 100)}
991
992    See Also
993    --------
994    dask.array.blockwise
995    dask.blockwise.blockwise
996    """
997
998    if numblocks is None:
999        raise ValueError("Missing required numblocks argument.")
1000    new_axes = new_axes or {}
1001    io_deps = io_deps or {}
1002    argpairs = list(toolz.partition(2, arrind_pairs))
1003
1004    if return_key_deps:
1005        key_deps = {}
1006
1007    if deserializing:
1008        from distributed.protocol.serialize import to_serialize
1009
1010    if concatenate is True:
1011        from dask.array.core import concatenate_axes as concatenate
1012
1013    # Dictionary mapping {i: 3, j: 4, ...} for i, j, ... the dimensions
1014    dims = dims or _make_dims(argpairs, numblocks, new_axes)
1015
1016    # Generate the abstract "plan" before constructing
1017    # the actual graph
1018    (coord_maps, concat_axes, dummies) = _get_coord_mapping(
1019        dims,
1020        output,
1021        out_indices,
1022        numblocks,
1023        argpairs,
1024        concatenate,
1025    )
1026
1027    # Unpack delayed objects in kwargs
1028    dsk2 = {}
1029    if kwargs:
1030        task, dsk2 = unpack_collections(kwargs)
1031        if dsk2:
1032            kwargs2 = task
1033        else:
1034            kwargs2 = kwargs
1035
1036    # Apply Culling.
1037    # Only need to construct the specified set of output blocks
1038    output_blocks = output_blocks or itertools.product(
1039        *[range(dims[i]) for i in out_indices]
1040    )
1041
1042    dsk = {}
1043    # Create argument lists
1044    for out_coords in output_blocks:
1045        deps = set()
1046        coords = out_coords + dummies
1047        args = []
1048        for cmap, axes, (arg, ind) in zip(coord_maps, concat_axes, argpairs):
1049            if ind is None:
1050                if deserializing:
1051                    args.append(stringify_collection_keys(arg))
1052                else:
1053                    args.append(arg)
1054            else:
1055                arg_coords = tuple(coords[c] for c in cmap)
1056                if axes:
1057                    tups = lol_product((arg,), arg_coords)
1058                    if arg not in io_deps:
1059                        deps.update(flatten(tups))
1060
1061                    if concatenate:
1062                        tups = (concatenate, tups, axes)
1063                else:
1064                    tups = (arg,) + arg_coords
1065                    if arg not in io_deps:
1066                        deps.add(tups)
1067                # Replace "place-holder" IO keys with "real" args
1068                if arg in io_deps:
1069                    # We don't want to stringify keys for args
1070                    # we are replacing here
1071                    idx = tups[1:]
1072                    args.append(io_deps[arg].get(idx, idx))
1073                elif deserializing:
1074                    args.append(stringify_collection_keys(tups))
1075                else:
1076                    args.append(tups)
1077        out_key = (output,) + out_coords
1078
1079        if deserializing:
1080            deps.update(func_future_args)
1081            args += list(func_future_args)
1082
1083        if deserializing and isinstance(func, bytes):
1084            # Construct a function/args/kwargs dict if we
1085            # do not have a nested task (i.e. concatenate=False).
1086            # TODO: Avoid using the iterate_collection-version
            # of to_serialize if we know there are no embedded
1088            # Serialized/Serialize objects in args and/or kwargs.
1089            if kwargs:
1090                dsk[out_key] = {
1091                    "function": func,
1092                    "args": to_serialize(args),
1093                    "kwargs": to_serialize(kwargs2),
1094                }
1095            else:
1096                dsk[out_key] = {"function": func, "args": to_serialize(args)}
1097        else:
1098            if kwargs:
1099                val = (apply, func, args, kwargs2)
1100            else:
1101                args.insert(0, func)
1102                val = tuple(args)
1103            # May still need to serialize (if concatenate=True)
1104            dsk[out_key] = to_serialize(val) if deserializing else val
1105
1106        if return_key_deps:
1107            key_deps[out_key] = deps
1108
1109    if dsk2:
1110        dsk.update(ensure_dict(dsk2))
1111
1112    if return_key_deps:
1113        return dsk, key_deps
1114    else:
1115        return dsk
1116
1117
1118def lol_product(head, values):
1119    """List of list of tuple keys, similar to `itertools.product`.
1120
1121    Parameters
1122    ----------
1123    head : tuple
1124        Prefix prepended to all results.
1125    values : sequence
1126        Mix of singletons and lists. Each list is substituted with every
1127        possible value and introduces another level of list in the output.
1128
1129    Examples
1130    --------
1131    >>> lol_product(('x',), (1, 2, 3))
1132    ('x', 1, 2, 3)
1133    >>> lol_product(('x',), (1, [2, 3], 4, [5, 6]))  # doctest: +NORMALIZE_WHITESPACE
1134    [[('x', 1, 2, 4, 5), ('x', 1, 2, 4, 6)],
1135     [('x', 1, 3, 4, 5), ('x', 1, 3, 4, 6)]]
1136    """
1137    if not values:
1138        return head
1139    elif isinstance(values[0], list):
1140        return [lol_product(head + (x,), values[1:]) for x in values[0]]
1141    else:
1142        return lol_product(head + (values[0],), values[1:])
1143
1144
1145def lol_tuples(head, ind, values, dummies):
1146    """List of list of tuple keys
1147
1148    Parameters
1149    ----------
1150    head : tuple
1151        The known tuple so far
1152    ind : Iterable
1153        An iterable of indices not yet covered
1154    values : dict
1155        Known values for non-dummy indices
1156    dummies : dict
1157        Ranges of values for dummy indices
1158
1159    Examples
1160    --------
1161    >>> lol_tuples(('x',), 'ij', {'i': 1, 'j': 0}, {})
1162    ('x', 1, 0)
1163
1164    >>> lol_tuples(('x',), 'ij', {'i': 1}, {'j': range(3)})
1165    [('x', 1, 0), ('x', 1, 1), ('x', 1, 2)]
1166
1167    >>> lol_tuples(('x',), 'ijk', {'i': 1}, {'j': [0, 1, 2], 'k': [0, 1]}) # doctest: +NORMALIZE_WHITESPACE
1168    [[('x', 1, 0, 0), ('x', 1, 0, 1)],
1169     [('x', 1, 1, 0), ('x', 1, 1, 1)],
1170     [('x', 1, 2, 0), ('x', 1, 2, 1)]]
1171    """
1172    if not ind:
1173        return head
1174    if ind[0] not in dummies:
1175        return lol_tuples(head + (values[ind[0]],), ind[1:], values, dummies)
1176    else:
1177        return [
1178            lol_tuples(head + (v,), ind[1:], values, dummies) for v in dummies[ind[0]]
1179        ]
1180
1181
1182def optimize_blockwise(graph, keys=()):
1183    """High level optimization of stacked Blockwise layers
1184
    For computations that apply multiple Blockwise operations one after the other, like
    ``x.T + 123``, we can fuse these into a single Blockwise operation.  This happens
1187    before any actual tasks are generated, and so can reduce overhead.
1188
1189    This finds groups of Blockwise operations that can be safely fused, and then
1190    passes them to ``rewrite_blockwise`` for rewriting.
1191
1192    Parameters
1193    ----------
1194    graph : HighLevelGraph
1195    keys : Iterable
1196        The keys of all outputs of all collections.
1197        Used to make sure that we don't fuse a layer needed by an output
1198
1199    Returns
1200    -------
1201    HighLevelGraph
1202
1203    See Also
1204    --------
1205    rewrite_blockwise
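
    Examples
    --------
    A sketch of typical usage on a dask array expression built from two
    stacked Blockwise layers (``dask.array`` is assumed to be available and
    the chunking is arbitrary):

    >>> import dask.array as da  # doctest: +SKIP
    >>> y = (da.ones(10, chunks=5) + 1) * 2  # doctest: +SKIP
    >>> hlg = y.__dask_graph__()  # doctest: +SKIP
    >>> optimized = optimize_blockwise(hlg, keys=list(flatten(y.__dask_keys__())))  # doctest: +SKIP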
1206    """
1207    out = _optimize_blockwise(graph, keys=keys)
1208    while out.dependencies != graph.dependencies:
1209        graph = out
1210        out = _optimize_blockwise(graph, keys=keys)
1211    return out
1212
1213
1214def _optimize_blockwise(full_graph, keys=()):
1215    keep = {k[0] if type(k) is tuple else k for k in keys}
1216    layers = full_graph.layers
1217    dependents = reverse_dict(full_graph.dependencies)
1218    roots = {k for k in full_graph.layers if not dependents.get(k)}
1219    stack = list(roots)
1220
1221    out = {}
1222    dependencies = {}
1223    seen = set()
1224    io_names = set()
1225
1226    while stack:
1227        layer = stack.pop()
1228        if layer in seen or layer not in layers:
1229            continue
1230        seen.add(layer)
1231
1232        # Outer loop walks through possible output Blockwise layers
1233        if isinstance(layers[layer], Blockwise):
1234            blockwise_layers = {layer}
1235            deps = set(blockwise_layers)
1236            io_names |= layers[layer].io_deps.keys()
1237            while deps:  # we gather as many sub-layers as we can
1238                dep = deps.pop()
1239
1240                if dep not in layers:
1241                    stack.append(dep)
1242                    continue
1243                if not isinstance(layers[dep], Blockwise):
1244                    stack.append(dep)
1245                    continue
1246                if dep != layer and dep in keep:
1247                    stack.append(dep)
1248                    continue
1249                if layers[dep].concatenate != layers[layer].concatenate:
1250                    stack.append(dep)
1251                    continue
1252                if (
1253                    sum(k == dep for k, ind in layers[layer].indices if ind is not None)
1254                    > 1
1255                ):
1256                    stack.append(dep)
1257                    continue
1258                if (
1259                    blockwise_layers
1260                    and layers[next(iter(blockwise_layers))].annotations
1261                    != layers[dep].annotations
1262                ):
1263                    stack.append(dep)
1264                    continue
1265
1266                # passed everything, proceed
1267                blockwise_layers.add(dep)
1268
1269                # traverse further to this child's children
1270                for d in full_graph.dependencies.get(dep, ()):
1271                    # Don't allow reductions to proceed
1272                    output_indices = set(layers[dep].output_indices)
1273                    input_indices = {
1274                        i for _, ind in layers[dep].indices if ind for i in ind
1275                    }
1276
1277                    if len(dependents[d]) <= 1 and output_indices.issuperset(
1278                        input_indices
1279                    ):
1280                        deps.add(d)
1281                    else:
1282                        stack.append(d)
1283
1284            # Merge these Blockwise layers into one
1285            new_layer = rewrite_blockwise([layers[l] for l in blockwise_layers])
1286            out[layer] = new_layer
1287
1288            new_deps = set()
1289            for k, v in new_layer.indices:
1290                if v is None:
1291                    new_deps |= keys_in_tasks(full_graph.dependencies, [k])
1292                elif k not in io_names:
1293                    new_deps.add(k)
1294            dependencies[layer] = new_deps
1295        else:
1296            out[layer] = layers[layer]
1297            dependencies[layer] = full_graph.dependencies.get(layer, set())
1298            stack.extend(full_graph.dependencies.get(layer, ()))
1299
1300    return HighLevelGraph(out, dependencies)
1301
1302
1303def rewrite_blockwise(inputs):
1304    """Rewrite a stack of Blockwise expressions into a single blockwise expression
1305
1306    Given a set of Blockwise layers, combine them into a single layer.  The provided
1307    layers are expected to fit well together.  That job is handled by
1308    ``optimize_blockwise``
1309
1310    Parameters
1311    ----------
1312    inputs : List[Blockwise]
1313
1314    Returns
1315    -------
1316    blockwise: Blockwise
1317
1318    See Also
1319    --------
1320    optimize_blockwise
1321    """
1322    if len(inputs) == 1:
1323        # Fast path: if there's only one input we can just use it as-is.
1324        return inputs[0]
1325
1326    inputs = {inp.output: inp for inp in inputs}
1327    dependencies = {
1328        inp.output: {d for d, v in inp.indices if v is not None and d in inputs}
1329        for inp in inputs.values()
1330    }
1331    dependents = reverse_dict(dependencies)
1332
1333    new_index_iter = (
1334        c + (str(d) if d else "")  # A, B, ... A1, B1, ...
1335        for d in itertools.count()
1336        for c in "ABCDEFGHIJKLMNOPQRSTUVWXYZ"
1337    )
1338
1339    [root] = [k for k, v in dependents.items() if not v]
1340
1341    # Our final results.  These will change during fusion below
1342    indices = list(inputs[root].indices)
1343    new_axes = inputs[root].new_axes
1344    concatenate = inputs[root].concatenate
1345    dsk = dict(inputs[root].dsk)
1346
1347    changed = True
1348    while changed:
1349        changed = False
1350        for i, (dep, ind) in enumerate(indices):
1351            if ind is None:
1352                continue
1353            if dep not in inputs:
1354                continue
1355
1356            changed = True
1357
1358            # Replace _n with dep name in existing tasks
1359            # (inc, _0) -> (inc, 'b')
1360            dsk = {k: subs(v, {blockwise_token(i): dep}) for k, v in dsk.items()}
1361
1362            # Remove current input from input indices
1363            # [('a', 'i'), ('b', 'i')] -> [('a', 'i')]
1364            _, current_dep_indices = indices.pop(i)
1365            sub = {
1366                blockwise_token(i): blockwise_token(i - 1)
1367                for i in range(i + 1, len(indices) + 1)
1368            }
1369            dsk = subs(dsk, sub)
1370
            # Change new input_indices to match the given index from the current computation
1372            # [('c', j')] -> [('c', 'i')]
1373            new_indices = inputs[dep].indices
1374            sub = dict(zip(inputs[dep].output_indices, current_dep_indices))
1375            contracted = {
1376                x
1377                for _, j in new_indices
1378                if j is not None
1379                for x in j
1380                if x not in inputs[dep].output_indices
1381            }
1382            extra = dict(zip(contracted, new_index_iter))
1383            sub.update(extra)
1384            new_indices = [(x, index_subs(j, sub)) for x, j in new_indices]
1385
1386            # Update new_axes
1387            for k, v in inputs[dep].new_axes.items():
1388                new_axes[sub[k]] = v
1389
1390            # Bump new inputs up in list
1391            sub = {}
1392            # Map from (id(key), inds or None) -> index in indices. Used to deduplicate indices.
1393            index_map = {(id(k), inds): n for n, (k, inds) in enumerate(indices)}
1394            for ii, index in enumerate(new_indices):
1395                id_key = (id(index[0]), index[1])
1396                if id_key in index_map:  # use old inputs if available
1397                    sub[blockwise_token(ii)] = blockwise_token(index_map[id_key])
1398                else:
1399                    index_map[id_key] = len(indices)
1400                    sub[blockwise_token(ii)] = blockwise_token(len(indices))
1401                    indices.append(index)
1402            new_dsk = subs(inputs[dep].dsk, sub)
1403
1404            # indices.extend(new_indices)
1405            dsk.update(new_dsk)
1406
1407    # De-duplicate indices like [(a, ij), (b, i), (a, ij)] -> [(a, ij), (b, i)]
1408    # Make sure that we map everything else appropriately as we remove inputs
1409    new_indices = []
1410    seen = {}
1411    sub = {}  # like {_0: _0, _1: _0, _2: _1}
1412    for i, x in enumerate(indices):
1413        if x[1] is not None and x in seen:
1414            sub[i] = seen[x]
1415        else:
1416            if x[1] is not None:
1417                seen[x] = len(new_indices)
1418            sub[i] = len(new_indices)
1419            new_indices.append(x)
1420
1421    sub = {blockwise_token(k): blockwise_token(v) for k, v in sub.items()}
1422    dsk = {k: subs(v, sub) for k, v in dsk.items() if k not in sub.keys()}
1423
1424    indices_check = {k for k, v in indices if v is not None}
1425    numblocks = toolz.merge([inp.numblocks for inp in inputs.values()])
1426    numblocks = {k: v for k, v in numblocks.items() if v is None or k in indices_check}
1427
1428    # Update IO-dependency information
1429    io_deps = {}
1430    for v in inputs.values():
1431        io_deps.update(v.io_deps)
1432
1433    return Blockwise(
1434        root,
1435        inputs[root].output_indices,
1436        dsk,
1437        new_indices,
1438        numblocks=numblocks,
1439        new_axes=new_axes,
1440        concatenate=concatenate,
1441        annotations=inputs[root].annotations,
1442        io_deps=io_deps,
1443    )
1444
1445
1446@_deprecated()
1447def zero_broadcast_dimensions(lol, nblocks):
1448    """
1449    >>> lol = [('x', 1, 0), ('x', 1, 1), ('x', 1, 2)]
1450    >>> nblocks = (4, 1, 2)  # note singleton dimension in second place
1451    >>> lol = [[('x', 1, 0, 0), ('x', 1, 0, 1)],
1452    ...        [('x', 1, 1, 0), ('x', 1, 1, 1)],
1453    ...        [('x', 1, 2, 0), ('x', 1, 2, 1)]]
1454
1455    >>> zero_broadcast_dimensions(lol, nblocks)  # doctest: +SKIP
1456    [[('x', 1, 0, 0), ('x', 1, 0, 1)],
1457     [('x', 1, 0, 0), ('x', 1, 0, 1)],
1458     [('x', 1, 0, 0), ('x', 1, 0, 1)]]
1459
1460    See Also
1461    --------
1462    lol_tuples
1463    """
1464    f = lambda t: (t[0],) + tuple(0 if d == 1 else i for i, d in zip(t[1:], nblocks))
1465    return homogeneous_deepmap(f, lol)
1466
1467
1468def broadcast_dimensions(argpairs, numblocks, sentinels=(1, (1,)), consolidate=None):
1469    """Find block dimensions from arguments
1470
1471    Parameters
1472    ----------
1473    argpairs : iterable
1474        name, ijk index pairs
1475    numblocks : dict
1476        maps {name: number of blocks}
1477    sentinels : iterable (optional)
1478        values for singleton dimensions
1479    consolidate : func (optional)
1480        use this to reduce each set of common blocks into a smaller set
1481
1482    Examples
1483    --------
1484    >>> argpairs = [('x', 'ij'), ('y', 'ji')]
1485    >>> numblocks = {'x': (2, 3), 'y': (3, 2)}
1486    >>> broadcast_dimensions(argpairs, numblocks)
1487    {'i': 2, 'j': 3}
1488
1489    Supports numpy broadcasting rules
1490
1491    >>> argpairs = [('x', 'ij'), ('y', 'ij')]
1492    >>> numblocks = {'x': (2, 1), 'y': (1, 3)}
1493    >>> broadcast_dimensions(argpairs, numblocks)
1494    {'i': 2, 'j': 3}
1495
1496    Works in other contexts too
1497
1498    >>> argpairs = [('x', 'ij'), ('y', 'ij')]
1499    >>> d = {'x': ('Hello', 1), 'y': (1, (2, 3))}
1500    >>> broadcast_dimensions(argpairs, d)
1501    {'i': 'Hello', 'j': (2, 3)}
1502    """
1503    # List like [('i', 2), ('j', 1), ('i', 1), ('j', 2)]
1504    argpairs2 = [(a, ind) for a, ind in argpairs if ind is not None]
1505    L = toolz.concat(
1506        [
1507            zip(inds, dims)
1508            for (x, inds), (x, dims) in toolz.join(
1509                toolz.first, argpairs2, toolz.first, numblocks.items()
1510            )
1511        ]
1512    )
1513
1514    g = toolz.groupby(0, L)
1515    g = {k: {d for i, d in v} for k, v in g.items()}
1516
1517    g2 = {k: v - set(sentinels) if len(v) > 1 else v for k, v in g.items()}
1518
1519    if consolidate:
1520        return toolz.valmap(consolidate, g2)
1521
1522    if g2 and not set(map(len, g2.values())) == {1}:
1523        raise ValueError("Shapes do not align %s" % g)
1524
1525    return toolz.valmap(toolz.first, g2)
1526
1527
1528def _make_dims(indices, numblocks, new_axes):
1529    """Returns a dictionary mapping between each index specified in
    `indices` and the number of output blocks for that index.
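
    Examples
    --------
    For example (a new axis given as a tuple of chunk sizes contributes one
    block per chunk):

    >>> _make_dims([('x', ('i', 'j'))], {'x': (2, 3)}, {'k': (0, 0)})
    {'i': 2, 'j': 3, 'k': 2}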
1531    """
1532    dims = broadcast_dimensions(indices, numblocks)
1533    for k, v in new_axes.items():
1534        dims[k] = len(v) if isinstance(v, tuple) else 1
1535    return dims
1536
1537
1538def fuse_roots(graph: HighLevelGraph, keys: list):
1539    """
1540    Fuse nearby layers if they don't have dependencies
1541
1542    Often Blockwise sections of the graph fill out all of the computation
1543    except for the initial data access or data loading layers::
1544
1545      Large Blockwise Layer
1546        |       |       |
1547        X       Y       Z
1548
    This can be troublesome because X, Y, and Z tasks may be executed on
    different machines, and their outputs may then need to be moved between
    workers.
1551
1552    This optimization identifies this situation, lowers all of the graphs to
1553    concrete dicts, and then calls ``fuse`` on them, with a width equal to the
1554    number of layers like X, Y, and Z.
1555
1556    This is currently used within array and dataframe optimizations.
1557
1558    Parameters
1559    ----------
1560    graph : HighLevelGraph
1561        The full graph of the computation
1562    keys : list
1563        The output keys of the computation, to be passed on to fuse
1564
1565    See Also
1566    --------
1567    Blockwise
1568    fuse
1569    """
1570    layers = ensure_dict(graph.layers, copy=True)
1571    dependencies = ensure_dict(graph.dependencies, copy=True)
1572    dependents = reverse_dict(dependencies)
1573
1574    for name, layer in graph.layers.items():
1575        deps = graph.dependencies[name]
1576        if (
1577            isinstance(layer, Blockwise)
1578            and len(deps) > 1
1579            and not any(dependencies[dep] for dep in deps)  # no need to fuse if 0 or 1
1580            and all(len(dependents[dep]) == 1 for dep in deps)
1581            and all(layer.annotations == graph.layers[dep].annotations for dep in deps)
1582        ):
1583            new = toolz.merge(layer, *[layers[dep] for dep in deps])
1584            new, _ = fuse(new, keys, ave_width=len(deps))
1585
1586            for dep in deps:
1587                del layers[dep]
1588                del dependencies[dep]
1589
1590            layers[name] = new
1591            dependencies[name] = set()
1592
1593    return HighLevelGraph(layers, dependencies)
1594