# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for working with tree-like container data structures.

This module provides a small set of utility functions for working with tree-like
data structures, such as nested tuples, lists, and dicts. We call these
structures pytrees. They are trees in that they are defined recursively (any
non-pytree is a pytree, i.e. a leaf, and any pytree of pytrees is a pytree) and
can be operated on recursively (object identity equivalence is not preserved by
mapping operations, and the structures cannot contain reference cycles).

The set of Python types that are considered pytree nodes (e.g. that can be
mapped over, rather than treated as leaves) is extensible. There is a single
module-level registry of types, and class hierarchy is ignored. By registering a
new pytree node type, that type in effect becomes transparent to the utility
functions in this file.

The primary purpose of this module is to enable the interoperability between
user defined data structures and JAX transformations (e.g. `jit`). This is not
meant to be a general purpose tree-like data structure handling library.

See the `JAX pytrees note <pytrees.html>`_
for examples.
"""


import functools
import collections
import operator as op
from typing import Any, Callable, Optional, Sequence, Tuple, Type, TypeVar, overload

from .lib import pytree

from ._src.util import partial, safe_zip, unzip2

from ._src import traceback_util
traceback_util.register_exclusion(__file__)

T = TypeVar("T")
U = TypeVar("U")


def tree_flatten(tree, is_leaf: Optional[Callable[[Any], bool]] = None):
  """Flattens a pytree into its leaves and a description of its structure.

  Args:
    tree: the pytree to flatten.
    is_leaf: optional predicate called at each flattening step. It should
      return a boolean saying whether traversal stops at the current object
      (the whole subtree is then treated as a single leaf) or continues into
      it.

  Returns:
    A pair ``(leaves, treedef)``: the list of leaf values and a treedef
    representing the structure of the flattened tree.
  """
  # Only forward `is_leaf` when the caller supplied it: old jaxlibs do not
  # accept the second argument.
  # TODO: Remove once 0.1.58 becomes the minimum supported jaxlib version
  if is_leaf is not None:
    return pytree.flatten(tree, is_leaf)
  return pytree.flatten(tree)


def tree_unflatten(treedef, leaves):
  """Rebuilds a pytree from a treedef and its leaves.

  This is the inverse of :func:`tree_flatten`.

  Args:
    treedef: the treedef describing the structure to rebuild.
    leaves: the list of leaves to place into that structure. It must match
      the leaves of the treedef.

  Returns:
    The reconstructed pytree, with ``leaves`` placed in the structure
    described by ``treedef``.
  """
  return treedef.unflatten(leaves)


def tree_leaves(tree):
  """Returns the leaves of a pytree (first element of the flatten result)."""
  leaves, _ = pytree.flatten(tree)
  return leaves


def tree_structure(tree):
  """Returns the treedef of a pytree (second element of the flatten result)."""
  _, treedef = pytree.flatten(tree)
  return treedef


def treedef_tuple(treedefs):
  """Builds a tuple treedef whose children are the given treedefs."""
  children = list(treedefs)
  return pytree.tuple(children)


def treedef_children(treedef):
  """Returns ``treedef.children()``, the child treedefs of ``treedef``."""
  return treedef.children()


def treedef_is_leaf(treedef):
  """Returns True when ``treedef`` consists of exactly one node."""
  return treedef.num_nodes == 1


def all_leaves(iterable):
  """Tests whether every element of the given iterable is a leaf.

  >>> tree = {"a": [1, 2, 3]}
  >>> assert all_leaves(jax.tree_leaves(tree))
  >>> assert not all_leaves([tree])

  This function is useful in advanced cases, for example if a library allows
  arbitrary map operations on a flat list of leaves it may want to check if
  the result is still a flat list of leaves.

  Args:
    iterable: Iterable of leaves.

  Returns:
    A boolean indicating if all elements in the input are leaves.
  """
  return pytree.all_leaves(iterable)


# The auxiliary is hashable, but because mypy has poor support for Hashable, we
# annotate it as Any.
def register_pytree_node(nodetype: Type[T],
                         flatten_func: Callable[[T], Tuple[Sequence[Any], Any]],
                         unflatten_func: Callable[[Any, Sequence[Any]], T]):
  """Extends the set of types that are considered internal nodes in pytrees.

  See `example usage <pytrees.html>`_.

  Args:
    nodetype: a Python type to treat as an internal pytree node.
    flatten_func: a function used during flattening. It takes a value of type
      ``nodetype`` and returns a pair of (1) an iterable of the children to be
      flattened recursively and (2) some hashable auxiliary data, which is
      stored in the treedef and later handed to ``unflatten_func``.
    unflatten_func: a function of two arguments: the auxiliary data returned
      by ``flatten_func`` (as stored in the treedef) and the unflattened
      children. It should return an instance of ``nodetype``.
  """
  # Register with the pytree library first; the Python-side registry below is
  # only updated once that call has returned.
  pytree.register_node(nodetype, flatten_func, unflatten_func)
  entry = _RegistryEntry(flatten_func, unflatten_func)
  _registry[nodetype] = entry


def register_pytree_node_class(cls):
  """Extends the set of types that are considered internal nodes in pytrees.

  This is a thin, class-oriented wrapper around ``register_pytree_node``. The
  class must provide a ``tree_flatten`` method and a ``tree_unflatten``
  classmethod::

    @register_pytree_node_class
    class Special:
      def __init__(self, x, y):
        self.x = x
        self.y = y
      def tree_flatten(self):
        return ((self.x, self.y), None)
      @classmethod
      def tree_unflatten(cls, aux_data, children):
        return cls(*children)

  Returns:
    ``cls`` itself, so the function can be used as a class decorator.
  """
  flatten_via_method = op.methodcaller('tree_flatten')
  register_pytree_node(cls, flatten_via_method, cls.tree_unflatten)
  return cls


def tree_map(f: Callable[[Any], Any], tree: Any) -> Any:
  """Maps a function over a pytree to produce a new pytree.

  Args:
    f: unary function to be applied at each leaf.
    tree: a pytree to be mapped over.

  Returns:
    A new pytree with the same structure as ``tree`` in which each leaf ``x``
    of the input has been replaced by ``f(x)``.
  """
  leaves, treedef = pytree.flatten(tree)
  mapped_leaves = map(f, leaves)
  return treedef.unflatten(mapped_leaves)


def tree_multimap(f: Callable[..., Any], tree: Any, *rest: Any) -> Any:
  """Maps a multi-input function over pytree args to produce a new pytree.

  Args:
    f: function that takes ``1 + len(rest)`` arguments, to be applied at the
      corresponding leaves of the pytrees.
    tree: a pytree to be mapped over, with each leaf providing the first
      positional argument to ``f``.
    *rest: a tuple of pytrees, each of which has the same structure as
      ``tree`` or has ``tree`` as a prefix.

  Returns:
    A new pytree with the same structure as ``tree`` but with the value at each
    leaf given by ``f(x, *xs)`` where ``x`` is the value at the corresponding
    leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in
    ``rest``.
  """
  leaves, treedef = pytree.flatten(tree)
  # Flatten every extra tree only as deep as `tree` goes, so that the i-th
  # entry of each flattened list lines up with the i-th leaf of `tree`.
  rest_leaves = [treedef.flatten_up_to(other) for other in rest]
  return treedef.unflatten(f(*args) for args in zip(leaves, *rest_leaves))


# TODO(mattjj,phawkins): consider removing this function
def _process_pytree(process_node, tree):
  """Walks the treedef of ``tree`` with ``process_node``.

  Returns:
    A pair of the result of ``treedef.walk(process_node, None, leaves)`` and
    the treedef obtained by flattening ``tree``.
  """
  leaves, treedef = pytree.flatten(tree)
  walked = treedef.walk(process_node, None, leaves)
  return walked, treedef


def build_tree(treedef, xs):
  """Returns ``treedef.from_iterable_tree(xs)``."""
  return treedef.from_iterable_tree(xs)


def tree_transpose(outer_treedef, inner_treedef, pytree_to_transpose):
  """Regroups the leaves of a pytree from outer-of-inner to inner-of-outer.

  The leaves of ``pytree_to_transpose`` are read as ``outer_treedef.num_leaves``
  consecutive groups of ``inner_treedef.num_leaves`` leaves each. The result
  has the structure ``inner_treedef``, and each of its leaf positions holds a
  subtree of structure ``outer_treedef``.

  Args:
    outer_treedef: treedef of the outer structure of the input.
    inner_treedef: treedef of the inner structure of the input.
    pytree_to_transpose: the pytree whose leaves are regrouped.

  Returns:
    The regrouped pytree.

  Raises:
    TypeError: if the number of leaves in ``pytree_to_transpose`` is not the
      product of the leaf counts of the two treedefs.
  """
  flat, treedef = tree_flatten(pytree_to_transpose)
  inner_size = inner_treedef.num_leaves
  outer_size = outer_treedef.num_leaves
  if treedef.num_leaves != (inner_size * outer_size):
    expected_treedef = outer_treedef.compose(inner_treedef)
    raise TypeError(f"Mismatch\n{treedef}\n != \n{expected_treedef}")
  # One row per outer leaf, each holding that leaf's `inner_size` inner leaves.
  rows = [flat[i * inner_size:(i + 1) * inner_size] for i in range(outer_size)]
  # Zipping the rows yields one column per inner leaf position.
  columns = zip(*rows)
  subtrees = map(partial(tree_unflatten, outer_treedef), columns)
  return tree_unflatten(inner_treedef, subtrees)


# TODO(mattjj): remove the Python-side registry when the C++-side registry is
# sufficiently queryable that we can express _replace_nones. That may mean once
# we have a flatten_one function.
_RegistryEntry = collections.namedtuple("RegistryEntry", ["to_iter", "from_iter"])

# Python-side handlers for the built-in node types. `to_iter` maps a node to a
# `(children, metadata)` pair; `from_iter` rebuilds a node from
# `(metadata, children)`.
_registry = {
    tuple: _RegistryEntry(
        lambda node: (node, None),
        lambda _, children: tuple(children)),
    list: _RegistryEntry(
        lambda node: (node, None),
        lambda _, children: list(children)),
    dict: _RegistryEntry(
        # Items are sorted by key; the reversed unzip presumably yields
        # (values, keys), i.e. values as children and keys as metadata.
        lambda node: unzip2(sorted(node.items()))[::-1],
        lambda keys, children: dict(zip(keys, children))),
    type(None): _RegistryEntry(
        lambda node: ((), None),
        lambda _, children: None),
}


def _replace_nones(sentinel, tree):
  """Replaces ``None`` in ``tree`` with ``sentinel``.

  Containers whose exact type is in ``_registry`` are rebuilt recursively, and
  so are namedtuples (detected heuristically). Any other object is returned
  unchanged.
  """
  if tree is None:
    return sentinel

  entry = _registry.get(type(tree))
  if entry:
    children, metadata = entry.to_iter(tree)
    replaced = [_replace_nones(sentinel, item) for item in children]
    return entry.from_iter(metadata, replaced)

  # handle namedtuple as a special case, based on heuristic
  if isinstance(tree, tuple) and hasattr(tree, '_fields'):
    replaced = [_replace_nones(sentinel, item) for item in iter(tree)]
    return type(tree)(*replaced)

  return tree


# Unique marker meaning "no initializer was passed to tree_reduce".
no_initializer = object()


@overload
def tree_reduce(function: Callable[[T, Any], T], tree: Any) -> T:
  ...


@overload
def tree_reduce(function: Callable[[T, Any], T], tree: Any, initializer: T) -> T:
  ...

def tree_reduce(function: Callable[[T, Any], T],
                tree: Any,
                initializer: Any = no_initializer) -> T:
  """Reduces the leaves of ``tree`` with ``functools.reduce``.

  Args:
    function: binary function taking the running result and the next leaf.
    tree: the pytree whose leaves are reduced.
    initializer: optional starting value. When omitted, ``functools.reduce``
      is called without one.

  Returns:
    The value produced by ``functools.reduce`` over the leaves.
  """
  leaves = tree_leaves(tree)
  if initializer is not no_initializer:
    return functools.reduce(function, leaves, initializer)
  return functools.reduce(function, leaves)


def tree_all(tree):
  """Returns ``all(...)`` applied to the leaves of ``tree``."""
  leaves = tree_leaves(tree)
  return all(leaves)


def _flatten_ordered_dict(ordered):
  """Children are the values; the list of keys is the auxiliary data."""
  return list(ordered.values()), list(ordered.keys())


def _unflatten_ordered_dict(keys, values):
  """Rebuilds an OrderedDict by pairing the stored keys with ``values``."""
  return collections.OrderedDict(safe_zip(keys, values))


register_pytree_node(
    collections.OrderedDict,
    _flatten_ordered_dict,
    _unflatten_ordered_dict)


def _flatten_defaultdict(mapping):
  """Children are the values; aux data is ``(default_factory, keys)``."""
  aux = (mapping.default_factory, tuple(mapping.keys()))
  return tuple(mapping.values()), aux


def _unflatten_defaultdict(aux, values):
  """Rebuilds a defaultdict from ``(default_factory, keys)`` and ``values``."""
  default_factory, keys = aux[0], aux[1]
  return collections.defaultdict(default_factory, safe_zip(keys, values))


register_pytree_node(
    collections.defaultdict,
    _flatten_defaultdict,
    _unflatten_defaultdict)


class Partial(functools.partial):
  """A version of functools.partial that works in pytrees.

  Use it for partial function evaluation in a way that is compatible with JAX's
  transformations, e.g., ``Partial(func, *args, **kwargs)``.

  (You need to explicitly opt-in to this behavior because we didn't want to give
  functools.partial different semantics than normal function closures.)
  """


def _flatten_partial(partial_):
  """Children are ``(args, keywords)``; the wrapped function is the aux data."""
  children = (partial_.args, partial_.keywords)
  return children, partial_.func


def _unflatten_partial(func, children):
  """Rebuilds a Partial from the stored function and ``(args, keywords)``."""
  args, keywords = children[0], children[1]
  return Partial(func, *args, **keywords)


register_pytree_node(
    Partial,
    _flatten_partial,
    _unflatten_partial,
)