# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import functools
import itertools as it
import operator
import types
from typing import Any, Callable

import numpy as np

import jax
from jax.config import FLAGS

partial = functools.partial


def safe_zip(*args):
  """Like ``list(zip(*args))`` but asserts that all arguments have equal length.

  Arguments are materialized into lists first (as in `safe_map`), so
  iterators and generators are accepted. At least one argument is required.
  """
  args = list(map(list, args))
  n = len(args[0])
  for arg in args[1:]:
    assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
  return list(zip(*args))

def safe_map(f, *args):
  """Like ``list(map(f, *args))`` but asserts that all arguments have equal length.

  Arguments are materialized into lists before their lengths are compared.
  At least one iterable argument is required.
  """
  args = list(map(list, args))
  n = len(args[0])
  for arg in args[1:]:
    assert len(arg) == n, 'length mismatch: {}'.format(list(map(len, args)))
  return list(map(f, *args))

def unzip2(xys):
  """Splits an iterable of pairs into a pair of tuples ``(xs, ys)``."""
  xs = []
  ys = []
  for x, y in xys:
    xs.append(x)
    ys.append(y)
  return tuple(xs), tuple(ys)

def unzip3(xyzs):
  """Splits an iterable of triples into a triple of tuples ``(xs, ys, zs)``."""
  xs = []
  ys = []
  zs = []
  for x, y, z in xyzs:
    xs.append(x)
    ys.append(y)
    zs.append(z)
  return tuple(xs), tuple(ys), tuple(zs)

def unzip4(wxyzs):
  """Splits an iterable of 4-tuples into four tuples ``(ws, xs, ys, zs)``."""
  ws = []
  xs = []
  ys = []
  zs = []
  for w, x, y, z in wxyzs:
    ws.append(w)
    xs.append(x)
    ys.append(y)
    zs.append(z)
  return tuple(ws), tuple(xs), tuple(ys), tuple(zs)

def subvals(lst, replace):
  """Returns a tuple copy of `lst` with ``(index, value)`` pairs from `replace` applied."""
  lst = list(lst)
  for i, v in replace:
    lst[i] = v
  return tuple(lst)

def split_list(args, ns):
  """Splits `args` into consecutive chunks whose sizes are given by `ns`.

  Args:
    args: An iterable of items to split.
    ns: A ``list`` of chunk sizes (asserted to be exactly of type ``list``).

  Returns:
    A list of ``len(ns) + 1`` lists: one chunk per entry of `ns`, followed by
    a final list holding whatever items remain.
  """
  assert type(ns) is list
  args = list(args)
  lists = []
  for n in ns:
    lists.append(args[:n])
    args = args[n:]
  lists.append(args)
  return lists

def split_dict(dct, names):
  """Returns the values of `dct` in the order of `names`.

  Works on a copy of `dct`. Raises ``KeyError`` if a name is missing and
  asserts that no keys are left over once all `names` have been extracted.
  """
  dct = dict(dct)
  lst = [dct.pop(name) for name in names]
  assert not dct
  return lst

def concatenate(xs):
  """Flattens one level of nesting: a list of all items of the iterables in `xs`."""
  return list(it.chain.from_iterable(xs))

class partialmethod(functools.partial):
  """A `functools.partial` that binds the instance when accessed as an attribute.

  Accessed on the class it returns itself; accessed on an instance it returns
  a new partial with the instance inserted as the first positional argument.
  """

  def __get__(self, instance, owner):
    if instance is None:
      return self
    else:
      return partial(self.func, instance,
                     *(self.args or ()), **(self.keywords or {}))

def curry(f):
  """Curries arguments of f, returning a function on any remaining arguments.

  For example:
  >>> f = lambda x, y, z, w: x * y + z * w
  >>> f(2,3,4,5)
  26
  >>> curry(f)(2)(3, 4, 5)
  26
  >>> curry(f)(2, 3)(4, 5)
  26
  >>> curry(f)(2, 3, 4, 5)()
  26
  """
  return partial(partial, f)

def toposort(end_nodes):
  """Topologically sorts the graph reachable from `end_nodes` via ``.parents``.

  Args:
    end_nodes: A sequence of nodes; every node must expose a ``parents``
      iterable. Nodes are identified by ``id``, not by equality.

  Returns:
    A list of nodes in which every node appears after all of its parents.
  """
  if not end_nodes: return []
  end_nodes = _remove_duplicates(end_nodes)

  # Count how many times each node is reached; each visit corresponds to one
  # child edge (plus the initial push for the end nodes themselves).
  child_counts = {}
  stack = list(end_nodes)
  while stack:
    node = stack.pop()
    if id(node) in child_counts:
      child_counts[id(node)] += 1
    else:
      child_counts[id(node)] = 1
      stack.extend(node.parents)
  # Undo the extra count contributed by seeding the stack with the end nodes.
  for node in end_nodes:
    child_counts[id(node)] -= 1

  sorted_nodes = []
  childless_nodes = [node for node in end_nodes if child_counts[id(node)] == 0]
  assert childless_nodes
  while childless_nodes:
    node = childless_nodes.pop()
    sorted_nodes.append(node)
    for parent in node.parents:
      # A parent becomes ready once its last remaining child has been emitted.
      if child_counts[id(parent)] == 1:
        childless_nodes.append(parent)
      else:
        child_counts[id(parent)] -= 1

  # Nodes were collected children-first; reverse to get parents first.
  check_toposort(sorted_nodes[::-1])
  return sorted_nodes[::-1]

def check_toposort(nodes):
  """Asserts that every node in `nodes` appears after all of its parents."""
  visited = set()
  for node in nodes:
    assert all(id(parent) in visited for parent in node.parents)
    visited.add(id(node))

def _remove_duplicates(node_list):
  """Returns `node_list` without identity (``id``) duplicates, order preserved."""
  seen = set()
  out = []
  for n in node_list:
    if id(n) not in seen:
      seen.add(id(n))
      out.append(n)
  return out

def split_merge(predicate, xs):
  """Partitions `xs` by `predicate` and returns a function to re-interleave.

  Args:
    predicate: A callable applied once to each element of `xs`.
    xs: An iterable of items; it is materialized once, so one-shot iterators
      are supported.

  Returns:
    A triple ``(lhs, rhs, merge)`` where `lhs` holds the items for which the
    predicate was truthy, `rhs` the others, and ``merge(new_lhs, new_rhs)``
    builds a list that places the items of `new_lhs` / `new_rhs` at the
    positions originally occupied by `lhs` / `rhs` items. `merge` asserts
    that both inputs are consumed exactly.
  """
  xs = list(xs)
  sides = list(map(predicate, xs))
  lhs = [x for x, s in zip(xs, sides) if s]
  rhs = [x for x, s in zip(xs, sides) if not s]
  def merge(new_lhs, new_rhs):
    out = []
    # Index cursors avoid re-slicing the inputs on every step (which was
    # quadratic) while still raising IndexError if an input is too short.
    lhs_pos = 0
    rhs_pos = 0
    for s in sides:
      if s:
        out.append(new_lhs[lhs_pos])
        lhs_pos += 1
      else:
        out.append(new_rhs[rhs_pos])
        rhs_pos += 1
    assert rhs_pos == len(new_rhs)
    assert lhs_pos == len(new_lhs)
    return out

  return lhs, rhs, merge

def cache(max_size=4096):
  """Returns a decorator that LRU-caches a function, keyed also on the x64 flag.

  The current value of ``FLAGS.jax_enable_x64`` is passed as an extra leading
  cache-key argument, so results are cached separately per flag value. When
  ``jax.core.debug_state.check_leaks`` is truthy the cache is bypassed and
  the function is called directly.

  Args:
    max_size: Maximum number of entries, forwarded to `functools.lru_cache`.

  Returns:
    A decorator; the decorated function exposes ``cache_clear`` and
    ``cache_info`` from the underlying LRU cache.
  """
  def wrap(f):
    @functools.lru_cache(max_size)
    def cached(_, *args, **kwargs):
      # The first argument only participates in the cache key.
      return f(*args, **kwargs)

    @functools.wraps(f)
    def wrapper(*args, **kwargs):
      if jax.core.debug_state.check_leaks:
        return f(*args, **kwargs)
      else:
        return cached(bool(FLAGS.jax_enable_x64), *args, **kwargs)

    wrapper.cache_clear = cached.cache_clear
    wrapper.cache_info = cached.cache_info
    return wrapper
  return wrap

def memoize(f):
  """Memoizes `f` with an unbounded cache, keyed also on the x64 flag.

  The current value of ``FLAGS.jax_enable_x64`` is passed as an extra leading
  cache-key argument, so results are cached separately per flag value. The
  returned wrapper exposes ``cache_clear`` and ``cache_info``.
  """
  @functools.lru_cache(None)
  def memoized(_, *args, **kwargs):
    # The first argument only participates in the cache key.
    return f(*args, **kwargs)

  @functools.wraps(f)
  def wrapper(*args, **kwargs):
    return memoized(bool(FLAGS.jax_enable_x64), *args, **kwargs)

  wrapper.cache_clear = memoized.cache_clear
  wrapper.cache_info = memoized.cache_info
  return wrapper

def prod(xs):
  """Returns the product of the items of `xs` (``1`` for an empty iterable)."""
  out = 1
  for x in xs:
    out *= x
  return out

class WrapHashably(object):
  """Wraps a value so that it hashes and compares by object identity."""
  __slots__ = ["val"]

  def __init__(self, val):
    self.val = val

  def __hash__(self):
    return id(self.val)

  def __eq__(self, other):
    # Defer to Python's default handling for unrelated types instead of
    # raising AttributeError on the missing `.val` (e.g. on a hash collision
    # with a different kind of key).
    if not isinstance(other, WrapHashably):
      return NotImplemented
    return self.val is other.val

class Hashable(object):
  """Wraps a value so that the wrapper hashes and compares by the value itself."""
  __slots__ = ["val"]

  def __init__(self, val):
    self.val = val

  def __hash__(self):
    return hash(self.val)

  def __eq__(self, other):
    # Defer to Python's default handling for unrelated types instead of
    # raising AttributeError on the missing `.val`.
    if not isinstance(other, Hashable):
      return NotImplemented
    return self.val == other.val

def get_module_functions(module):
  """Finds functions in module.

  Args:
    module: A Python module.

  Returns:
    module_fns: A dict of names mapped to functions, builtins or ufuncs in
      `module`.
  """
  module_fns = {}
  for key in dir(module):
    # Omitting module level __getattr__, __dir__ which was added in Python 3.7
    # https://www.python.org/dev/peps/pep-0562/
    if key in ('__getattr__', '__dir__'):
      continue
    attr = getattr(module, key)
    if isinstance(
        attr, (types.BuiltinFunctionType, types.FunctionType, np.ufunc)):
      module_fns[key] = attr
  return module_fns

def wrap_name(name, transform_name):
  """Returns the string ``transform_name(name)``."""
  return transform_name + '(' + name + ')'

def extend_name_stack(stack, name=''):
  """Appends `name` and a trailing ``'/'`` to the name-stack string `stack`."""
  return stack + name + '/'

def canonicalize_axis(axis, num_dims) -> int:
  """Canonicalize an axis in [-num_dims, num_dims) to [0, num_dims).

  Raises:
    ValueError: If `axis` is outside ``[-num_dims, num_dims)``.
    TypeError: If `axis` cannot be interpreted as an integer by
      `operator.index`.
  """
  axis = operator.index(axis)
  if not -num_dims <= axis < num_dims:
    raise ValueError(
        "axis {} is out of bounds for array of dimension {}".format(
            axis, num_dims))
  if axis < 0:
    axis = axis + num_dims
  return axis

def moveaxis(x, src, dst):
  """Moves axis `src` of `x` to position `dst`, keeping the other axes in order.

  Args:
    x: An object exposing ``ndim`` and ``transpose(perm)``.
    src: Source axis; may be negative.
    dst: Destination axis; may be negative.

  Returns:
    `x` itself when ``src == dst``, otherwise ``x.transpose(perm)``.
  """
  # NOTE: this shortcut compares the raw values, before canonicalization, so
  # equal arguments return early without bounds checking.
  if src == dst:
    return x
  src = canonicalize_axis(src, x.ndim)
  dst = canonicalize_axis(dst, x.ndim)
  perm = [i for i in range(np.ndim(x)) if i != src]
  perm.insert(dst, src)
  return x.transpose(perm)

def ceil_of_ratio(x, y):
  """Returns ``x / y`` rounded towards positive infinity, using floor division."""
  return -(-x // y)

@curry
def wraps(wrapped, fun, namestr="{fun}", docstr="{doc}", **kwargs):
  """Copies name, module and docstring metadata from `wrapped` onto `fun`.

  Curried: typically used as ``wraps(wrapped, namestr=..., docstr=...)(fun)``.

  Args:
    wrapped: The object whose metadata is copied.
    fun: The function to update; it is modified in place and returned.
    namestr: Format string for the new ``__name__``; receives ``fun`` (the
      name of `wrapped`).
    docstr: Format string for the new ``__doc__``; receives ``fun``, ``doc``
      (the docstring of `wrapped`, or ``""`` if it has none) and `kwargs`.
    **kwargs: Extra fields made available to `docstr`.

  Returns:
    `fun`. Copying is best-effort: if any step raises an ``Exception``, the
    remaining attributes are left untouched and `fun` is still returned.
  """
  try:
    name = get_name(wrapped)
    fun.__name__ = namestr.format(fun=name)
    fun.__module__ = get_module(wrapped)
    # `__doc__` is None for undocumented functions; substitute an empty string
    # so the formatted docstring does not contain the literal text "None".
    doc = get_doc(wrapped) or ""
    fun.__doc__ = docstr.format(fun=name, doc=doc, **kwargs)
    fun.__wrapped__ = wrapped
  except Exception:
    # Unlike `return` inside `finally`, this lets KeyboardInterrupt and
    # SystemExit propagate.
    pass
  return fun

def get_name(fun):
  """Returns ``fun.__name__``, or a placeholder if the attribute is missing."""
  return getattr(fun, "__name__", "<unnamed function>")

def get_module(fun):
  """Returns ``fun.__module__``, or a placeholder if the attribute is missing."""
  return getattr(fun, "__module__", "<unknown module>")

def get_doc(fun):
  """Returns ``fun.__doc__`` (possibly ``None``), or ``""`` if the attribute is missing."""
  return getattr(fun, "__doc__", "")

# NOTE: Ideally we would annotate both the argument and return type as NoReturn
#       but it seems like pytype doesn't support that...
def assert_unreachable(x):
  """Always raises ``AssertionError`` naming the type of `x`."""
  raise AssertionError(f"Unhandled case: {type(x).__name__}")

def tuple_insert(t, idx, val):
  """Returns a copy of tuple `t` with `val` inserted at position `idx`."""
  assert 0 <= idx <= len(t), (idx, len(t))
  return t[:idx] + (val,) + t[idx:]

def tuple_delete(t, idx):
  """Returns a copy of tuple `t` with the element at position `idx` removed."""
  assert 0 <= idx < len(t), (idx, len(t))
  return t[:idx] + t[idx + 1:]

# TODO(mattjj): replace with dataclass when Python 2 support is removed
def taggedtuple(name, fields) -> Callable[..., Any]:
  """Lightweight version of namedtuple where equality depends on the type.

  Instances are tuples whose first element is the class itself, followed by
  the field values; each name in `fields` becomes a read-only property.

  Args:
    name: Name of the generated class.
    fields: Iterable of field names, in positional order.

  Returns:
    The newly created tuple subclass.
  """
  def __new__(cls, *xs):
    # Storing the class as element 0 makes tuple equality type-sensitive.
    return tuple.__new__(cls, (cls,) + xs)
  def __repr__(self):
    return '{}{}'.format(name, tuple.__str__(self[1:]))
  class_namespace = {'__new__' : __new__, '__repr__': __repr__}
  for i, f in enumerate(fields):
    # Fields are offset by one because of the leading class tag.
    class_namespace[f] = property(operator.itemgetter(i+1))  # type: ignore
  return type(name, (tuple,), class_namespace)

class HashableFunction:
  """Decouples function equality and hash from its identity.

  Local lambdas and function defs are reallocated on each function call, making
  the functions created on different calls compare as unequal. This breaks our
  caching logic, which should really only care about comparing the semantics and
  not actual identity.

  This class makes it possible to compare different functions based on their
  semantics. The parts that are taken into account are: the bytecode of
  the wrapped function (which is cached by the CPython interpreter and is stable
  across the invocations of the surrounding function), and `closure` which should
  contain all values in scope that affect the function semantics. In particular
  `closure` should contain all elements of the function closure, or it should be
  possible to derive the relevant elements of the true function closure based
  solely on the contents of the `closure` argument (e.g. in case some closed-over
  values are not hashable, but are entirely determined by hashable locals).
  """

  def __init__(self, f, closure):
    self.f = f
    self.closure = closure

  def __eq__(self, other):
    return (type(other) is HashableFunction and
            self.f.__code__ == other.f.__code__ and
            self.closure == other.closure)

  def __hash__(self):
    return hash((self.f.__code__, self.closure))

  def __call__(self, *args, **kwargs):
    return self.f(*args, **kwargs)

  def __repr__(self):
    return f'<hashable {self.f.__name__} with closure={self.closure}>'

def as_hashable_function(closure):
  """Returns a decorator that wraps a function in `HashableFunction` with `closure`."""
  return lambda f: HashableFunction(f, closure)