1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15"""
16Utilities for defining functions composed with transformations.
17
18For example,
19
20   from jax import linear_util as lu
21
22   wf = lu.wrap_init(f)  # Produce a WrappedFun for applying transformations on `f`
23
24A `WrappedFun` object represents a function `f`, together with a sequence of
25nested transformations that are to be applied to the positional and keyword
26arguments at call time and function return values at return time.
27A transformation can take some static positional arguments that are given
28at the wrapping time, and may also return some auxiliary output:
29
30    wf, aux_out_thunk = trans1(wf, static_arg)
31
32We can call the transformed function. First, the transformation is applied
33to the dynamic args and keyword args to produce new dynamic and keyword args.
34Then the underlying function is called and the transformation is applied to
35the results.
36If there are multiple transformations, they form a stack. The arguments are
37transformed first with the last applied transformation; the results are
38transformed first with the first applied transformation.
39
    res = wf.call_wrapped(*dynamic_args, **kwargs)
41    # Now `aux_out_thunk()` is the auxiliary output.
42
43A transformation is written as a generator function that takes zero or more
44static positional arguments (given when the transformation is instantiated),
45along with positional and keyword arguments to be transformed.
46The generator will yield twice:
47
48    @lu.transformation_with_aux
49    def trans1(static_arg, *dynamic_args, **kwargs):
50      ...
51      # First yield: pair of transformed (args, kwargs). Get back the results.
52      results = yield (new_dynamic_args, new_kwargs)
53      ...
54      # Second yield: pair of (transformed results, and auxiliary output)
55      yield new_results, auxiliary_output
56
57
58`WrappedFun` objects explicitly represent the set of transformations so that
59they can be used as dictionary keys for memoization. `WrappedFun` objects
compare as equal only if they compute the same function. The static and the
dynamic positional arguments for the generators, and also the auxiliary output
data, must be immutable, because they will be stored in function memoization tables.
63"""
64
65import threading
66from functools import partial
67from typing import Any, Tuple, Callable
68import weakref
69
70from . import core
71from ._src.util import curry
72from .tree_util import tree_map
73
74from ._src import traceback_util
75
76from .config import FLAGS
77
# NOTE(review): presumably this keeps frames originating in this module out of
# user-facing (filtered) tracebacks; confirm against
# `traceback_util.register_exclusion`.
traceback_util.register_exclusion(__file__)
79
80
class StoreException(Exception):
  """Error raised when a `Store` is written while occupied or read while empty."""
82
83
class EmptyStoreValue:
  """Type of the sentinel object that marks a `Store` as holding no value."""


# Unique marker compared by identity, so that any real value (including None)
# can be distinguished from "nothing stored yet".
_EMPTY_STORE_VALUE = EmptyStoreValue()


class Store:
  """Single-slot, write-once container.

  A `Store` starts out empty. `store` fills it exactly once, `val` reads the
  stored value, and both raise `StoreException` when used in the wrong state.
  The truth value of a `Store` tells whether it currently holds a value.
  """
  __slots__ = ("_val",)

  def __init__(self):
    self._val = _EMPTY_STORE_VALUE

  def __nonzero__(self):
    """Return True iff a value is currently stored."""
    return self._val is not _EMPTY_STORE_VALUE

  # Python 3 looks up `__bool__`; both names refer to the same function.
  __bool__ = __nonzero__

  def store(self, val):
    """Save `val`; raises `StoreException` if a value is already present."""
    if self._val is _EMPTY_STORE_VALUE:
      self._val = val
    else:
      raise StoreException("Store occupied")

  def reset(self):
    """Empty the store again.

    This should only be called in exceptional circumstances (e.g. debugging).
    """
    self._val = _EMPTY_STORE_VALUE

  @property
  def val(self):
    """The stored value; raises `StoreException` if the store is empty."""
    if self:
      return self._val
    raise StoreException("Store empty")
113
114
class WrappedFun(object):
  """Represents a function `f` to which `transforms` are to be applied.

  Args:
    f: the function to be transformed.
    transforms: a tuple of `(gen, gen_static_args)` pairs representing
      transformations to apply to `f.` Here `gen` is a generator function
      and `gen_static_args` is a tuple of static arguments for the generator. See
      description at the start of this module for the expected behavior of the
      generator. It must be a tuple because `wrap` concatenates onto it and
      `__hash__` hashes it.
    stores: a tuple with one entry per transform: the out_store receiving the
      auxiliary output of that transform, or None if it has none.
    params: extra parameters to pass as keyword arguments to `f`, along with the
      transformed keyword arguments. Must be accepted by `dict(...)` and be
      hashable (e.g. a tuple of `(name, value)` pairs).
  """
  __slots__ = ("f", "transforms", "stores", "params")

  def __init__(self, f, transforms, stores, params):
    self.f = f
    self.transforms = transforms
    self.stores = stores
    self.params = params

  @property
  def __name__(self):
    """Name of the underlying function, or a placeholder if it has none."""
    return getattr(self.f, '__name__', '<unnamed wrapped function>')

  def wrap(self, gen, gen_static_args, out_store) -> 'WrappedFun':
    """Add another transform and its store.

    Returns a new `WrappedFun`; `self` is left untouched. The new transform is
    placed at the front, so it is the first one to see the call arguments.
    """
    return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms,
                      (out_store,) + self.stores, self.params)

  def populate_stores(self, stores):
    """Copy the values from the `stores` into `self.stores`."""
    for self_store, other_store in zip(self.stores, stores):
      if self_store is not None:
        self_store.store(other_store.val)

  def call_wrapped(self, *args, **kwargs):
    """Calls the underlying function, applying the transforms.

    The positional `args` and keyword `kwargs` are passed to the first
    transformation generator. Each generator yields the arguments for the next
    one; the last yielded arguments (plus `self.params`) are passed to `f`.
    The result is then sent back through the generators in reverse order.

    If anything raises along the way, every generator that is still suspended
    is closed before the exception is re-raised.
    """
    stack = []
    try:
      for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
        gen = gen(*(gen_static_args + tuple(args)), **kwargs)
        # Register the generator before priming it, so that it is also closed
        # if it yields something that cannot be unpacked into (args, kwargs).
        stack.append((gen, out_store))
        args, kwargs = next(gen)
      gen = gen_static_args = out_store = None

      ans = self.f(*args, **dict(self.params, **kwargs))

      # Drop references to the arguments before post-processing the result.
      args = kwargs = None
      while stack:
        gen, out_store = stack.pop()
        ans = gen.send(ans)
        if out_store is not None:
          ans, side = ans
          out_store.store(side)
    except BaseException:
      # Some transformations yield from inside context managers, so we have to
      # interrupt them before reraising the exception. Otherwise they will only
      # get garbage-collected at some later time, running their cleanup tasks only
      # after this exception is handled, which can corrupt the global state.
      # This applies equally to failures while priming a transformation, while
      # calling `f`, and while a transformation post-processes the result.
      while stack:
        stack.pop()[0].close()
      raise

    return ans

  def __repr__(self):
    def transform_to_str(x):
      i, (gen, args) = x
      return "{}   : {}   {}".format(i, fun_name(gen), fun_name(args))
    transformation_stack = map(transform_to_str, enumerate(self.transforms))
    return "Wrapped function:\n" + '\n'.join(transformation_stack) + '\nCore: ' + fun_name(self.f) + '\n'

  def __hash__(self):
    # `stores` is deliberately left out, mirroring `__eq__`.
    return hash((self.f, self.transforms, self.params))

  def __eq__(self, other):
    # Let Python fall back to its default handling instead of failing with an
    # AttributeError when compared against an unrelated object.
    if not isinstance(other, WrappedFun):
      return NotImplemented
    return (self.f == other.f and self.transforms == other.transforms and
            self.params == other.params)
199
@curry
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
  """Stack the transformation `gen` on top of `fun`, without auxiliary output.

  Args:
    gen: the transformation generator function
    fun: a WrappedFun on which to apply the transformation
    gen_static_args: static args for the generator function

  Returns:
    The WrappedFun produced by `fun.wrap`, with `gen` added.
  """
  # A transformation without auxiliary output has no store to fill.
  no_aux_store = None
  return fun.wrap(gen, gen_static_args, no_aux_store)
209
@curry
def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args) -> Tuple[WrappedFun, Any]:
  """Stack the transformation `gen`, which has auxiliary output, on top of `fun`.

  Args:
    gen: the transformation generator function
    fun: a WrappedFun on which to apply the transformation
    gen_static_args: static args for the generator function

  Returns:
    A pair of the WrappedFun produced by `fun.wrap` and a zero-argument thunk
    that reads the auxiliary output from a fresh `Store` attached to the
    transformation. The thunk raises `StoreException` while that store is
    still empty.
  """
  out_store = Store()
  wrapped = fun.wrap(gen, gen_static_args, out_store)
  return wrapped, (lambda: out_store.val)
216
def fun_name(f):
  """Return a printable name for `f`.

  Args:
    f: any object, typically a function.

  Returns:
    `f.__name__` if that attribute can be read, otherwise `str(f)`.
  """
  try:
    return f.__name__
  except Exception:
    # Fall back for objects without a usable `__name__`. Only `Exception` is
    # caught, so that KeyboardInterrupt/SystemExit are never swallowed here.
    return str(f)
222
def wrap_init(f, params=None) -> WrappedFun:
  """Wraps function `f` as a `WrappedFun`, suitable for transformation.

  Args:
    f: the function to wrap.
    params: optional mapping of extra keyword arguments for `f`. It is stored
      as a tuple of `(name, value)` pairs sorted by name, which keeps it
      hashable and independent of insertion order. Omitting it (or passing
      None) is the same as passing an empty mapping.

  Returns:
    A `WrappedFun` around `f` with no transformations applied yet.
  """
  # A None default avoids sharing one mutable dict object between all calls.
  items = () if params is None else tuple(sorted(params.items()))
  return WrappedFun(f, (), (), items)
226
227
class _CacheLocalContext(threading.local):
  """Thread-local state with a `most_recent_entry` attribute, initially None."""

  def __init__(self):
    super().__init__()
    self.most_recent_entry = None
233
234
def cache(call: Callable):
  """Memoization decorator for functions taking a WrappedFun as first argument.

  Results are kept in one table per underlying function `fun.f`, held in a
  `WeakKeyDictionary`. Inside a table the key is made of the transforms and
  params of the WrappedFun, the remaining positional arguments, and the
  current value of `FLAGS.jax_enable_x64`.

  Args:
    call: a Python callable that takes a WrappedFun as its first argument. The
      underlying transforms and params on the WrappedFun are used as part of the
      memoization cache key.

  Returns:
     A memoized version of ``call``. It additionally exposes
     ``most_recent_entry()``, which returns (and forgets) the result most
     recently returned on the calling thread, and ``cache_clear()``, which
     drops all cached results.
  """
  fun_caches: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
  thread_local: threading.local = _CacheLocalContext()

  def memoized_fun(fun: WrappedFun, *args):
    entries = fun_caches.setdefault(fun.f, {})
    # With leak checking enabled, the key uses copies of the MainTrace objects
    # found in the transforms rather than the originals.
    transforms = (_copy_main_traces(fun.transforms)
                  if core.debug_state.check_leaks else fun.transforms)
    key = (transforms, fun.params, args, bool(FLAGS.jax_enable_x64))

    hit = entries.get(key)
    if hit is None:
      ans = call(fun, *args)
      # Keep the stores too, so a later hit can replay the auxiliary outputs.
      entries[key] = (ans, fun.stores)
    else:
      ans, cached_stores = hit
      fun.populate_stores(cached_stores)

    # Only a weak reference is kept, so this does not extend `ans`'s lifetime.
    thread_local.most_recent_entry = weakref.ref(ans)
    return ans

  def _most_recent_entry():
    entry_ref = thread_local.most_recent_entry
    if entry_ref is None:
      return None
    entry = entry_ref()
    thread_local.most_recent_entry = None
    return entry

  memoized_fun.most_recent_entry = _most_recent_entry  # type: ignore
  memoized_fun.cache_clear = fun_caches.clear  # type: ignore

  return memoized_fun
277
@partial(partial, tree_map)
def _copy_main_traces(x):
  """Return a copy of `x` if it is a `core.MainTrace`, else `x` itself.

  The decorator turns this into `partial(tree_map, <this function>)`, so the
  resulting module-level name is called on a whole tree and this body is the
  function handed to `tree_map`. A copy is a new `core.MainTrace` built from
  the same level, trace type and payload.
  """
  if not isinstance(x, core.MainTrace):
    return x
  return core.MainTrace(x.level, x.trace_type, **x.payload)
284
285
@transformation
def hashable_partial(x, *args):
  """Transformation that passes the static argument `x` as first positional arg.

  The wrapped function is called with `x` followed by the dynamic `args` and
  with no keyword arguments from this transformation; its result is passed
  through unchanged.
  """
  call_args = (x,) + args
  result = yield call_args, {}
  yield result
290
291
def merge_linear_aux(aux1, aux2):
  """Read whichever of two auxiliary-output thunks has a value.

  Args:
    aux1: zero-argument callable that raises `StoreException` when it has no
      value.
    aux2: same as `aux1`, for the second store.

  Returns:
    `(True, out1)` if only `aux1` has a value, `(False, out2)` if only `aux2`
    has one.

  Raises:
    StoreException: if neither or both thunks have a value.
  """
  try:
    first = aux1()
  except StoreException:
    # store 1 was not occupied, so store 2 better be
    try:
      second = aux2()
    except StoreException:
      raise StoreException("neither store occupied") from None
    return False, second

  # store 1 was occupied, so let's check store 2 is not occupied
  try:
    aux2()
  except StoreException:
    return True, first
  raise StoreException("both stores occupied")
311