1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16import functools
17import itertools as it
18import operator
19import types
20from typing import Any, Callable
21
22import numpy as np
23
24import jax
25from jax.config import FLAGS
26
# Module-level shorthand for `functools.partial`, used throughout this file
# (e.g. by `curry` and `partialmethod`).
partial = functools.partial
28
29
def safe_zip(*args):
  """Zips `args` into a list of tuples, asserting they all have equal length.

  Unlike the builtin `zip`, mismatched inputs are never silently truncated.
  Every argument must support `len`, and at least one must be given. The
  check is an `assert`, so it is skipped under `python -O`.
  """
  expected = len(args[0])
  assert all(len(seq) == expected for seq in args[1:]), (
      'length mismatch: {}'.format(list(map(len, args))))
  return list(zip(*args))
35
def safe_map(f, *args):
  """Maps `f` over `args` like `map`, but eagerly and with a length check.

  Each argument is first materialized as a list, and all lists must have the
  same length (checked with `assert`, so skipped under `python -O`). Returns
  a list.
  """
  materialized = [list(seq) for seq in args]
  expected = len(materialized[0])
  assert all(len(seq) == expected for seq in materialized[1:]), (
      'length mismatch: {}'.format(list(map(len, materialized))))
  return list(map(f, *materialized))
42
def unzip2(xys):
  """Transposes an iterable of pairs into a pair of tuples.

  Unlike ``zip(*xys)``, an empty input still yields two (empty) tuples.
  """
  columns = ([], [])
  for x, y in xys:
    columns[0].append(x)
    columns[1].append(y)
  return tuple(columns[0]), tuple(columns[1])
50
def unzip3(xyzs):
  """Transposes an iterable of 3-tuples into a triple of tuples.

  An empty input yields three empty tuples.
  """
  columns = ([], [], [])
  for x, y, z in xyzs:
    for column, item in zip(columns, (x, y, z)):
      column.append(item)
  return tuple(tuple(column) for column in columns)
60
def unzip4(wxyzs):
  """Transposes an iterable of 4-tuples into a 4-tuple of tuples.

  An empty input yields four empty tuples.
  """
  columns = ([], [], [], [])
  for w, x, y, z in wxyzs:
    for column, item in zip(columns, (w, x, y, z)):
      column.append(item)
  return tuple(tuple(column) for column in columns)
72
def subvals(lst, replace):
  """Returns `lst` as a tuple with the ``(index, value)`` pairs of `replace` applied.

  Replacements are applied in order, so a later pair for the same index wins.
  The input sequence is not modified.
  """
  updated = list(lst)
  for idx, val in replace:
    updated[idx] = val
  return tuple(updated)
78
def split_list(args, ns):
  """Cuts `args` into consecutive chunks whose sizes are given by `ns`.

  Args:
    args: any iterable; it is copied into a list first.
    ns: a list (asserted) of chunk sizes.

  Returns:
    ``len(ns) + 1`` lists: one per size in `ns`, followed by whatever remains.
  """
  assert type(ns) is list
  remaining = list(args)
  chunks = []
  for size in ns:
    chunk, remaining = remaining[:size], remaining[size:]
    chunks.append(chunk)
  return chunks + [remaining]
88
def split_dict(dct, names):
  """Returns the values of `dct` for each of `names`, in that order.

  Raises KeyError if a name is missing (or listed twice). Raises
  AssertionError if `dct` has keys not listed in `names`. The input mapping
  is left unmodified because a copy is consumed.
  """
  remaining = dict(dct)
  values = []
  for name in names:
    values.append(remaining.pop(name))
  assert not remaining
  return values
94
def concatenate(xs):
  """Flattens one level of nesting: an iterable of iterables becomes one list."""
  return [item for group in xs for item in group]
97
class partialmethod(functools.partial):
  """A `functools.partial` that binds like a method when read from an instance.

  Looked up on the class, the descriptor returns itself. Looked up on an
  instance, it returns a new `partial` of the underlying function with the
  instance placed first, followed by the stored positional and keyword
  arguments.
  """

  def __get__(self, instance, owner):
    if instance is not None:
      bound_args = (instance,) + tuple(self.args or ())
      bound_kwargs = dict(self.keywords or {})
      return partial(self.func, *bound_args, **bound_kwargs)
    return self
105
def curry(f):
  """Returns a callable that binds leading arguments of `f`.

  Calling the result with some arguments yields ``partial(f, *args, **kwargs)``,
  which then takes any remaining arguments. The result is itself a `partial`
  object rather than a plain function.

  For example:
  >>> f = lambda x, y, z, w: x * y + z * w
  >>> f(2,3,4,5)
  26
  >>> curry(f)(2)(3, 4, 5)
  26
  >>> curry(f)(2, 3)(4, 5)
  26
  >>> curry(f)(2, 3, 4, 5)()
  26
  """
  # partial(partial, f)(*a, **k) == partial(f, *a, **k)
  binder = partial
  return binder(binder, f)
121
def toposort(end_nodes):
  """Returns `end_nodes` and all their ancestors, with parents before children.

  Each node must expose an iterable `parents` attribute. Nodes are tracked by
  identity (`id`), never by `==`. The result is validated with
  `check_toposort` before it is returned.

  Args:
    end_nodes: the nodes to start from. Repeated nodes (by identity) are
      ignored.

  Returns:
    A list of every node reachable from `end_nodes` through `parents` links,
    ordered so that each node appears after all of its parents. An empty
    input yields an empty list.

  Raises:
    AssertionError: if no end node is childless, which can only happen when
      the reachable graph contains a cycle.
  """
  if not end_nodes: return []
  end_nodes = _remove_duplicates(end_nodes)

  # Pass 1 (DFS over `parents`): every time a node is popped it gains one
  # count. It is popped once per seed entry and once per appearance in some
  # visited node's `parents`. Parents are pushed only on a node's first visit,
  # so each child edge is counted exactly once.
  child_counts = {}
  stack = list(end_nodes)
  while stack:
    node = stack.pop()
    if id(node) in child_counts:
      child_counts[id(node)] += 1
    else:
      child_counts[id(node)] = 1
      stack.extend(node.parents)
  # Remove the extra count each end node got from its seed entry, leaving
  # child_counts[id(n)] == number of child edges pointing at n.
  for node in end_nodes:
    child_counts[id(node)] -= 1

  # Pass 2 (Kahn's algorithm, walking child -> parent): a node is emitted once
  # all of its children have been emitted, so sorted_nodes is children-first.
  sorted_nodes = []
  childless_nodes = [node for node in end_nodes if child_counts[id(node)] == 0]
  assert childless_nodes
  while childless_nodes:
    node = childless_nodes.pop()
    sorted_nodes.append(node)
    for parent in node.parents:
      if child_counts[id(parent)] == 1:
        # This was the parent's last outstanding child edge.
        childless_nodes.append(parent)
      else:
        child_counts[id(parent)] -= 1

  # Reverse to get parents-first order.
  check_toposort(sorted_nodes[::-1])
  return sorted_nodes[::-1]
152
def check_toposort(nodes):
  """Asserts that every node in `nodes` appears after all of its parents.

  Nodes are compared by identity, and each must expose an iterable `parents`.
  """
  emitted_ids = set()
  for node in nodes:
    for parent in node.parents:
      assert id(parent) in emitted_ids
    emitted_ids.add(id(node))
158
def _remove_duplicates(node_list):
  """Drops repeated nodes (by identity), keeping first occurrences in order."""
  by_id = {}
  for node in node_list:
    by_id.setdefault(id(node), node)
  return list(by_id.values())
167
def split_merge(predicate, xs):
  """Partitions `xs` by `predicate` and returns a function that undoes it.

  Args:
    predicate: called once per element of `xs`; its truthiness picks the side.
    xs: a sequence (it is iterated more than once).

  Returns:
    A triple ``(lhs, rhs, merge)``:
      - `lhs` lists the elements for which `predicate` was truthy.
      - `rhs` lists the rest.
      - ``merge(new_lhs, new_rhs)`` interleaves two sequences back into the
        original positions and returns a list.
    Both `lhs` and `rhs` keep their original relative order. `merge` raises
    IndexError if either input is too short and AssertionError if either has
    leftover elements.
  """
  sides = list(map(predicate, xs))
  lhs = [x for x, s in zip(xs, sides) if s]
  rhs = [x for x, s in zip(xs, sides) if not s]

  def merge(new_lhs, new_rhs):
    # Advance a cursor into each input rather than re-slicing it on every
    # step; re-slicing made merging quadratic in the number of elements.
    out = []
    i = j = 0
    for s in sides:
      if s:
        out.append(new_lhs[i])
        i += 1
      else:
        out.append(new_rhs[j])
        j += 1
    # Compare lengths instead of relying on truthiness of the leftover
    # slice, which is ambiguous (or an error) for array-like inputs.
    assert j == len(new_rhs)
    assert i == len(new_lhs)
    return out

  return lhs, rhs, merge
186
def cache(max_size=4096):
  """Returns a decorator that memoizes a function in a bounded LRU cache.

  The cache key includes ``bool(FLAGS.jax_enable_x64)``, so results computed
  under one x64 setting are not reused under the other. While
  `jax.core.debug_state.check_leaks` is truthy, the cache is bypassed and the
  function is called directly. Arguments must be hashable, as required by
  `functools.lru_cache`.

  Args:
    max_size: maximum number of entries kept by `functools.lru_cache`.

  Returns:
    A decorator. The decorated function exposes `cache_clear` and
    `cache_info` from the underlying LRU cache.
  """
  def decorator(fn):
    @functools.lru_cache(max_size)
    def cached(_, *args, **kwargs):
      # The first positional slot only participates in the cache key.
      return fn(*args, **kwargs)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
      if not jax.core.debug_state.check_leaks:
        return cached(bool(FLAGS.jax_enable_x64), *args, **kwargs)
      return fn(*args, **kwargs)

    wrapper.cache_clear = cached.cache_clear
    wrapper.cache_info = cached.cache_info
    return wrapper
  return decorator
204
def memoize(f):
  """Memoizes `f` in an unbounded cache keyed on its (hashable) arguments.

  The key also includes ``bool(FLAGS.jax_enable_x64)``, so results computed
  under one x64 setting are not reused under the other. The returned wrapper
  exposes `cache_clear` and `cache_info` from the underlying
  `functools.lru_cache`.
  """
  def memoized(_, *args, **kwargs):
    # The first positional slot only participates in the cache key.
    return f(*args, **kwargs)
  memoized = functools.lru_cache(None)(memoized)

  def wrapper(*args, **kwargs):
    x64_enabled = bool(FLAGS.jax_enable_x64)
    return memoized(x64_enabled, *args, **kwargs)
  wrapper = functools.wraps(f)(wrapper)

  wrapper.cache_clear = memoized.cache_clear
  wrapper.cache_info = memoized.cache_info
  return wrapper
217
def prod(xs):
  """Multiplies together the elements of `xs`, starting from 1.

  An empty iterable gives 1. The running product is updated with in-place
  multiplication (``acc *= x``) via `operator.imul`.
  """
  return functools.reduce(operator.imul, xs, 1)
223
class WrapHashably(object):
  """Wrapper that hashes and compares `val` by object identity.

  Only `id(val)` and `is` are used, so `val` itself need not be hashable.
  Two wrappers are equal exactly when they wrap the same object.
  """
  __slots__ = ["val"]

  def __init__(self, val):
    self.val = val

  def __hash__(self):
    return id(self.val)

  def __eq__(self, other):
    # For foreign types, defer to Python's default comparison instead of
    # raising AttributeError on `other.val`. Such a comparison can occur
    # inside tuple or dict-key comparisons whose hashes happen to collide.
    if not isinstance(other, WrapHashably):
      return NotImplemented
    return self.val is other.val
235
class Hashable(object):
  """Wrapper whose hash and equality delegate to `val`'s own `hash` and `==`.

  Comparing with an object that is not a `Hashable` falls back to Python's
  default comparison (identity) rather than raising.
  """
  __slots__ = ["val"]

  def __init__(self, val):
    self.val = val

  def __hash__(self):
    return hash(self.val)

  def __eq__(self, other):
    # Return NotImplemented for foreign types rather than raising
    # AttributeError on `other.val`.
    if not isinstance(other, Hashable):
      return NotImplemented
    return self.val == other.val
247
def get_module_functions(module):
  """Collects the plain functions, builtins and ufuncs exposed by `module`.

  Args:
    module: A Python module.
  Returns:
    A dict mapping attribute names to attributes that are Python functions,
    builtin functions, or NumPy ufuncs.
  """
  # Module-level `__getattr__`/`__dir__` (PEP 562, Python 3.7+) are module
  # hooks, not API functions, so they are excluded.
  # https://www.python.org/dev/peps/pep-0562/
  excluded = ('__getattr__', '__dir__')
  wanted_types = (types.BuiltinFunctionType, types.FunctionType, np.ufunc)
  found = {}
  for name in dir(module):
    if name in excluded:
      continue
    value = getattr(module, name)
    if isinstance(value, wanted_types):
      found[name] = value
  return found
266
def wrap_name(name, transform_name):
  """Returns `name` wrapped in `transform_name`, e.g. ``wrap_name('f', 'jit') == 'jit(f)'``."""
  return ''.join([transform_name, '(', name, ')'])
269
def extend_name_stack(stack, name=''):
  """Appends `name` and a trailing ``'/'`` separator to the name `stack` string."""
  extended = stack + name
  return extended + '/'
272
def canonicalize_axis(axis, num_dims) -> int:
  """Maps an axis in ``[-num_dims, num_dims)`` onto ``[0, num_dims)``.

  Raises:
    TypeError: if `axis` is not an integer (raised by `operator.index`).
    ValueError: if `axis` lies outside ``[-num_dims, num_dims)``.
  """
  axis = operator.index(axis)
  in_bounds = -num_dims <= axis < num_dims
  if not in_bounds:
    message = "axis {} is out of bounds for array of dimension {}".format(
        axis, num_dims)
    raise ValueError(message)
  return axis + num_dims if axis < 0 else axis
283
def moveaxis(x, src, dst):
  """Returns `x` with axis `src` moved to position `dst`.

  The other axes keep their relative order. `x` must provide `ndim` and
  `transpose`. If `src` and `dst` are equal as given, `x` itself is returned
  without any bounds checking.
  """
  if src == dst:
    return x
  ndim = x.ndim
  src = canonicalize_axis(src, ndim)
  dst = canonicalize_axis(dst, ndim)
  others = [axis for axis in range(ndim) if axis != src]
  perm = others[:dst] + [src] + others[dst:]
  return x.transpose(perm)
292
def ceil_of_ratio(x, y):
  """Returns ``ceil(x / y)`` using only floor division (exact for integers)."""
  # Floor-dividing the negated numerator rounds toward -inf; negating the
  # quotient turns that into rounding toward +inf.
  floored = (-x) // y
  return -floored
295
@curry
def wraps(wrapped, fun, namestr="{fun}", docstr="{doc}", **kwargs):
  """Copies the name, module and docstring of `wrapped` onto `fun`.

  Because of `@curry`, ``wraps(wrapped, ...)`` returns a decorator that is
  then applied to `fun`.

  Args:
    wrapped: the function whose metadata is copied.
    fun: the function that receives the metadata; it is modified in place.
    namestr: format string for ``fun.__name__``; ``{fun}`` expands to the name
      of `wrapped`.
    docstr: format string for ``fun.__doc__``; ``{fun}`` expands to the name of
      `wrapped` and ``{doc}`` to its docstring (empty if it has none).
    **kwargs: additional fields available when formatting `docstr`.

  Returns:
    `fun`. Copying metadata is best-effort. If any step raises an
    `Exception` (e.g. a format string referencing an unknown field, or `fun`
    rejecting attribute assignment), the remaining steps are skipped and
    `fun` is returned, possibly partially updated.
  """
  try:
    name = get_name(wrapped)
    fun.__name__ = namestr.format(fun=name)
    fun.__module__ = get_module(wrapped)
    # `__doc__` is None for undocumented functions; substituting None directly
    # would produce the literal docstring "None".
    doc = get_doc(wrapped) or ""
    fun.__doc__ = docstr.format(fun=name, doc=doc, **kwargs)
    fun.__wrapped__ = wrapped
  except Exception:  # Deliberately best-effort; see docstring.
    pass
  # Returning after the `try` (not from a `finally` clause) avoids also
  # swallowing KeyboardInterrupt/SystemExit, and avoids the PEP 765 warning.
  return fun
305
def get_name(fun):
  """Returns ``fun.__name__``, or ``"<unnamed function>"`` if it has none."""
  try:
    return fun.__name__
  except AttributeError:
    return "<unnamed function>"

def get_module(fun):
  """Returns ``fun.__module__``, or ``"<unknown module>"`` if it has none."""
  try:
    return fun.__module__
  except AttributeError:
    return "<unknown module>"

def get_doc(fun):
  """Returns ``fun.__doc__`` (which may be None), or ``""`` if it is missing."""
  try:
    return fun.__doc__
  except AttributeError:
    return ""
309
# NOTE: Ideally both the argument and the return type would be annotated as
#       NoReturn, but pytype does not appear to support that.
def assert_unreachable(x):
  """Raises AssertionError naming the type of `x`, a case that should not occur."""
  type_name = type(x).__name__
  raise AssertionError(f"Unhandled case: {type_name}")
314
def tuple_insert(t, idx, val):
  """Returns a copy of tuple `t` with `val` inserted at position `idx`.

  `idx` must lie in ``[0, len(t)]`` (checked with `assert`).
  """
  assert 0 <= idx <= len(t), (idx, len(t))
  head, tail = t[:idx], t[idx:]
  return head + (val,) + tail
318
def tuple_delete(t, idx):
  """Returns a copy of tuple `t` without the element at position `idx`.

  `idx` must lie in ``[0, len(t))`` (checked with `assert`).
  """
  assert 0 <= idx < len(t), (idx, len(t))
  before = t[:idx]
  after = t[idx + 1:]
  return before + after
322
# TODO(mattjj): replace with dataclass when Python 2 support is removed
def taggedtuple(name, fields) -> Callable[..., Any]:
  """Creates a lightweight namedtuple-like class whose equality depends on type.

  Each instance is a tuple whose first slot holds the generated class itself,
  followed by the constructor arguments. Instances of different generated
  classes therefore never compare equal, even when their values match. Each
  entry of `fields` becomes a read-only property for the matching value.
  """
  def __new__(cls, *xs):
    tagged = (cls,) + xs
    return tuple.__new__(cls, tagged)

  def __repr__(self):
    values = self[1:]
    return '{}{}'.format(name, tuple.__str__(values))

  namespace = {'__new__': __new__, '__repr__': __repr__}
  # Slot 0 holds the tag, so field i lives at index i + 1.
  namespace.update(
      {field: property(operator.itemgetter(pos))  # type: ignore
       for pos, field in enumerate(fields, start=1)})
  return type(name, (tuple,), namespace)
334
class HashableFunction:
  """Wraps a function so that equality and hashing ignore object identity.

  Local lambdas and function defs are re-created on every call of their
  enclosing function, so function objects created by different calls compare
  unequal. That breaks caching, which should only care about what a function
  computes, not which object it is.

  A `HashableFunction` instead compares and hashes by two things:
    - the wrapped function's bytecode (`f.__code__`), which the CPython
      interpreter caches and keeps stable across invocations of the
      surrounding function;
    - `closure`, which should capture every in-scope value that affects the
      function's semantics.
  In particular, `closure` should contain every element of the real function
  closure. Alternatively, it should contain enough values that the relevant
  closed-over values are entirely determined by them, which is useful when
  some closed-over values are not hashable themselves. Other function
  attributes (defaults, globals, ...) are not compared.
  """

  def __init__(self, f, closure):
    self.f = f
    self.closure = closure

  def __eq__(self, other):
    if type(other) is not HashableFunction:
      return False
    if self.f.__code__ != other.f.__code__:
      return False
    return self.closure == other.closure

  def __hash__(self):
    key = (self.f.__code__, self.closure)
    return hash(key)

  def __call__(self, *args, **kwargs):
    return self.f(*args, **kwargs)

  def __repr__(self):
    fn_name = self.f.__name__
    return f'<hashable {fn_name} with closure={self.closure}>'
371
def as_hashable_function(closure):
  """Returns a decorator that wraps its function in `HashableFunction` with `closure`."""
  def decorate(f):
    return HashableFunction(f, closure)
  return decorate
374