1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""Utilities for working with tree-like container data structures.
16
17This module provides a small set of utility functions for working with tree-like
18data structures, such as nested tuples, lists, and dicts. We call these
19structures pytrees. They are trees in that they are defined recursively (any
20non-pytree is a pytree, i.e. a leaf, and any pytree of pytrees is a pytree) and
21can be operated on recursively (object identity equivalence is not preserved by
22mapping operations, and the structures cannot contain reference cycles).
23
24The set of Python types that are considered pytree nodes (e.g. that can be
25mapped over, rather than treated as leaves) is extensible. There is a single
26module-level registry of types, and class hierarchy is ignored. By registering a
27new pytree node type, that type in effect becomes transparent to the utility
28functions in this file.
29
The primary purpose of this module is to enable interoperability between
user-defined data structures and JAX transformations (e.g. `jit`). It is not
meant to be a general-purpose library for handling tree-like data structures.
33
34See the `JAX pytrees note <pytrees.html>`_
35for examples.
36"""
37
38
39import functools
40import collections
41import operator as op
42from typing import Any, Callable, Optional, Sequence, Tuple, Type, TypeVar, overload
43
44from .lib import pytree
45
46from ._src.util import partial, safe_zip, unzip2
47
48from ._src import traceback_util
49traceback_util.register_exclusion(__file__)
50
51T = TypeVar("T")
52U = TypeVar("U")
53
def tree_flatten(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
  """Flattens a pytree into its leaves and a structure description.

  Args:
    tree: the pytree to flatten.
    is_leaf: optional callable invoked at each flattening step. It returns a
      boolean deciding whether flattening descends into the current object or
      stops there and treats the whole subtree as a single leaf.

  Returns:
    A pair ``(leaves, treedef)``: a list of leaf values and a treedef
    representing the structure of the flattened tree.
  """
  # We skip the second argument in support of old jaxlibs
  # TODO: Remove once 0.1.58 becomes the minimum supported jaxlib version
  if is_leaf is not None:
    return pytree.flatten(tree, is_leaf)
  return pytree.flatten(tree)
71
72
def tree_unflatten(treedef, leaves):
  """Rebuilds a pytree from a treedef and a matching collection of leaves.

  This is the inverse of :func:`tree_flatten`.

  Args:
    treedef: the treedef describing the structure to rebuild.
    leaves: the leaves to place into that structure. They must match the
      leaves of the treedef.

  Returns:
    The pytree produced by ``treedef.unflatten(leaves)``, i.e. ``leaves``
    arranged in the structure described by ``treedef``.
  """
  rebuild = treedef.unflatten
  return rebuild(leaves)
88
def tree_leaves(tree):
  """Returns only the leaves of ``tree``, discarding its treedef."""
  leaves, _ = pytree.flatten(tree)
  return leaves
92
def tree_structure(tree):
  """Returns only the treedef of ``tree``, discarding its leaves."""
  _, treedef = pytree.flatten(tree)
  return treedef
96
def treedef_tuple(treedefs):
  """Builds a tuple treedef whose children are the given treedefs.

  Args:
    treedefs: an iterable of child treedefs; it is materialized into a list
      before being handed to ``pytree.tuple``.
  """
  child_treedefs = list(treedefs)
  return pytree.tuple(child_treedefs)
100
def treedef_children(treedef):
  """Returns whatever ``treedef.children()`` yields for the given treedef."""
  children = treedef.children()
  return children
103
def treedef_is_leaf(treedef):
  """Reports whether ``treedef`` consists of exactly one node."""
  node_count = treedef.num_nodes
  return node_count == 1
106
def all_leaves(iterable):
  """Checks that every element of ``iterable`` is a leaf.

  >>> tree = {"a": [1, 2, 3]}
  >>> assert all_leaves(jax.tree_leaves(tree))
  >>> assert not all_leaves([tree])

  This is intended for advanced use, e.g. a library that lets callers run
  arbitrary map operations over a flat list of leaves and wants to verify that
  the result is still a flat list of leaves.

  Args:
    iterable: the elements to test.

  Returns:
    A boolean that is true when all elements of the input are leaves.
  """
  verdict = pytree.all_leaves(iterable)
  return verdict
125
# The auxiliary data is hashable, but mypy has poor support for Hashable, so it
# is annotated as Any.
def register_pytree_node(nodetype: Type[T],
                         flatten_func: Callable[[T], Tuple[Sequence[Any], Any]],
                         unflatten_func: Callable[[Any, Sequence[Any]], T]):
  """Registers ``nodetype`` as an internal (non-leaf) pytree node type.

  See `example usage <pytrees.html>`_.

  Args:
    nodetype: the Python type to treat as an internal pytree node.
    flatten_func: called during flattening with a value of type ``nodetype``.
      It returns a pair of (1) an iterable of children to flatten recursively
      and (2) hashable auxiliary data that is stored in the treedef and later
      handed to ``unflatten_func``.
    unflatten_func: called with two arguments, the auxiliary data produced by
      ``flatten_func`` and the unflattened children. It should return an
      instance of ``nodetype``.
  """
  # Register with the ``pytree`` extension first: if that call raises, the
  # Python-side registry below is left untouched.
  pytree.register_node(nodetype, flatten_func, unflatten_func)
  python_entry = _RegistryEntry(to_iter=flatten_func, from_iter=unflatten_func)
  _registry[nodetype] = python_entry
149
def register_pytree_node_class(cls):
  """Class decorator that registers ``cls`` as an internal pytree node type.

  It is a thin wrapper around ``register_pytree_node`` that reads the flatten
  and unflatten functions from the class itself: the instance method
  ``tree_flatten`` and the ``tree_unflatten`` attribute of the class::

    @register_pytree_node_class
    class Special:
      def __init__(self, x, y):
        self.x = x
        self.y = y
      def tree_flatten(self):
        return ((self.x, self.y), None)
      @classmethod
      def tree_unflatten(cls, aux_data, children):
        return cls(*children)

  Returns:
    ``cls`` itself, so the function can be used as a decorator.
  """
  # methodcaller looks the method up on each instance at flatten time.
  flatten_via_method = op.methodcaller('tree_flatten')
  register_pytree_node(cls, flatten_via_method, cls.tree_unflatten)
  return cls
169
def tree_map(f: Callable[[Any], Any], tree: Any) -> Any:
  """Applies a unary function to every leaf of a pytree.

  Args:
    f: unary function called on each leaf.
    tree: the pytree to map over.

  Returns:
    A new pytree with the same structure as `tree` in which each leaf ``x`` of
    the input has been replaced by ``f(x)``.
  """
  leaves, treedef = pytree.flatten(tree)
  mapped_leaves = (f(leaf) for leaf in leaves)
  return treedef.unflatten(mapped_leaves)
184
def tree_multimap(f: Callable[..., Any], tree: Any, *rest: Any) -> Any:
  """Applies a multi-argument function across the leaves of several pytrees.

  Args:
    f: function taking ``1 + len(rest)`` arguments, applied at the
      corresponding leaves of the pytrees.
    tree: the pytree to map over; each of its leaves supplies the first
      positional argument to ``f``.
    *rest: further pytrees, each of which either has the same structure as
      ``tree`` or has ``tree`` as a prefix.

  Returns:
    A new pytree with the same structure as ``tree`` in which each leaf is
    ``f(x, *xs)``, where ``x`` is the corresponding leaf of ``tree`` and ``xs``
    is the tuple of values at the corresponding nodes of ``rest``.
  """
  leaves, treedef = pytree.flatten(tree)
  # One column of values per input tree, each aligned with the leaves of `tree`.
  columns = [leaves]
  columns.extend(treedef.flatten_up_to(other) for other in rest)
  return treedef.unflatten(f(*args) for args in zip(*columns))
205
# TODO(mattjj,phawkins): consider removing this function
def _process_pytree(process_node, tree):
  """Flattens ``tree`` and walks its treedef with ``process_node``.

  Returns:
    A pair of the result of ``treedef.walk(process_node, None, leaves)`` and
    the treedef of ``tree``.
  """
  leaves, treedef = pytree.flatten(tree)
  walked = treedef.walk(process_node, None, leaves)
  return walked, treedef
210
def build_tree(treedef, xs):
  """Builds a pytree by passing ``xs`` to ``treedef.from_iterable_tree``."""
  built = treedef.from_iterable_tree(xs)
  return built
213
def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
  """Swaps the outer and inner structure of a nested pytree.

  Args:
    outer_treedef: treedef of the outer structure of the input.
    inner_treedef: treedef of the structure found under each outer leaf.
    pytree_to_transpose: the pytree to transpose.

  Returns:
    A pytree built with ``inner_treedef`` whose entries are pytrees built with
    ``outer_treedef``.

  Raises:
    TypeError: if the number of leaves of ``pytree_to_transpose`` differs from
      the product of the leaf counts of the two treedefs.
  """
  flat, treedef = tree_flatten(pytree_to_transpose)
  inner_size = inner_treedef.num_leaves
  outer_size = outer_treedef.num_leaves
  if treedef.num_leaves != (inner_size * outer_size):
    expected_treedef = outer_treedef.compose(inner_treedef)
    raise TypeError(f"Mismatch\n{treedef}\n != \n{expected_treedef}")
  # Cut the flat leaves into `outer_size` consecutive rows of `inner_size`.
  rows = [flat[i * inner_size:(i + 1) * inner_size] for i in range(outer_size)]
  # Each column gathers one inner position across all outer positions.
  columns = zip(*rows)
  subtrees = (tree_unflatten(outer_treedef, column) for column in columns)
  return tree_unflatten(inner_treedef, subtrees)
226
# TODO(mattjj): remove the Python-side registry when the C++-side registry is
# sufficiently queryable that we can express _replace_nones. That may mean once
# we have a flatten_one function.
_RegistryEntry = collections.namedtuple("RegistryEntry", ["to_iter", "from_iter"])


def _sequence_to_iter(xs):
  # A tuple or list is used directly as its own children; no metadata is kept.
  return xs, None


def _tuple_from_iter(_, xs):
  return tuple(xs)


def _list_from_iter(_, xs):
  return list(xs)


def _dict_to_iter(xs):
  # Items are sorted before unzipping. The result is reversed so the values
  # come first (children) and the keys second (metadata); this assumes unzip2
  # returns (keys, values), which matches how _dict_from_iter consumes them.
  return unzip2(sorted(xs.items()))[::-1]


def _dict_from_iter(keys, xs):
  return dict(zip(keys, xs))


def _none_to_iter(z):
  # None has no children and no metadata.
  return (), None


def _none_from_iter(_, xs):
  return None


# Maps a node type to the functions that split it into (children, metadata)
# and rebuild it from (metadata, children).
_registry = {
    tuple: _RegistryEntry(to_iter=_sequence_to_iter, from_iter=_tuple_from_iter),
    list: _RegistryEntry(to_iter=_sequence_to_iter, from_iter=_list_from_iter),
    dict: _RegistryEntry(to_iter=_dict_to_iter, from_iter=_dict_from_iter),
    type(None): _RegistryEntry(to_iter=_none_to_iter, from_iter=_none_from_iter),
}
def _replace_nones(sentinel, tree):
  """Returns a copy of ``tree`` with every ``None`` swapped for ``sentinel``.

  Node types found in ``_registry`` are taken apart and rebuilt with their
  registered functions. Tuples that expose ``_fields`` are treated as
  namedtuples. Anything else is returned unchanged.
  """
  if tree is None:
    return sentinel

  entry = _registry.get(type(tree))
  if entry:
    kids, aux = entry.to_iter(tree)
    replaced = [_replace_nones(sentinel, kid) for kid in kids]
    return entry.from_iter(aux, replaced)

  # handle namedtuple as a special case, based on heuristic
  if isinstance(tree, tuple) and hasattr(tree, '_fields'):
    replaced = [_replace_nones(sentinel, field) for field in tree]
    return type(tree)(*replaced)

  return tree
255
# Sentinel marking "no initializer supplied"; compared by identity so that any
# real value, including None, can be passed as an initializer.
no_initializer = object()

@overload
def tree_reduce(function: Callable[[T, Any], T],
                tree: Any) -> T:
  ...

@overload
def tree_reduce(function: Callable[[T, Any], T],
                tree: Any,
                initializer: T) -> T:
  ...

def tree_reduce(function: Callable[[T, Any], T],
                tree: Any,
                initializer: Any = no_initializer) -> T:
  """Reduces the leaves of a pytree to one value with ``functools.reduce``.

  Args:
    function: binary function taking the running result and the next leaf.
    tree: the pytree whose leaves are reduced.
    initializer: optional starting value. When omitted, the reduction starts
      from the first leaf.

  Returns:
    The result of ``functools.reduce`` over ``tree_leaves(tree)``.
  """
  leaves = tree_leaves(tree)
  if initializer is no_initializer:
    return functools.reduce(function, leaves)
  return functools.reduce(function, leaves, initializer)
276
def tree_all(tree):
  """Returns ``all`` over the leaves of ``tree`` (true when there are none)."""
  leaves = tree_leaves(tree)
  return all(leaves)
279
# OrderedDict: the children are the values and the auxiliary data is the key
# order. ``register_pytree_node`` documents the auxiliary data as hashable, so
# the keys are stored in a tuple rather than a list (a list is unhashable).
register_pytree_node(
  collections.OrderedDict,
  lambda x: (tuple(x.values()), tuple(x.keys())),
  lambda keys, values: collections.OrderedDict(safe_zip(keys, values)))

# defaultdict: the children are the values and the auxiliary data is the pair
# (default_factory, keys), so the rebuilt mapping keeps its default factory.
register_pytree_node(
  collections.defaultdict,
  lambda x: (tuple(x.values()), (x.default_factory, tuple(x.keys()))),
  lambda s, values: collections.defaultdict(s[0], safe_zip(s[1], values)))
289
290
class Partial(functools.partial):
  """A ``functools.partial`` subclass that is registered as a pytree node.

  Build one with ``Partial(func, *args, **kwargs)`` when a partially applied
  function has to be compatible with JAX's transformations.

  (This is opt-in on purpose: plain ``functools.partial`` keeps the same
  semantics as an ordinary function closure.)
  """
300
def _partial_flatten(partial_):
  # Children: the bound positional and keyword arguments.
  # Auxiliary data: the wrapped callable.
  children = (partial_.args, partial_.keywords)
  return children, partial_.func


def _partial_unflatten(func, xs):
  # xs holds the positional arguments first and the keyword arguments second.
  args = xs[0]
  kwargs = xs[1]
  return Partial(func, *args, **kwargs)


register_pytree_node(Partial, _partial_flatten, _partial_unflatten)
306