# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Utilities for defining functions composed with transformations.

For example,

   from jax import linear_util as lu

   wf = lu.wrap_init(f)  # Produce a WrappedFun for applying transformations on `f`

A `WrappedFun` object represents a function `f`, together with a sequence of
nested transformations that are to be applied to the positional and keyword
arguments at call time and function return values at return time.
A transformation can take some static positional arguments that are given
at wrapping time, and may also return some auxiliary output:

    wf, aux_out_thunk = trans1(wf, static_arg)

We can call the transformed function. First, the transformation is applied
to the dynamic args and keyword args to produce new dynamic and keyword args.
Then the underlying function is called and the transformation is applied to
the results.
If there are multiple transformations, they form a stack. The arguments are
transformed first with the last applied transformation; the results are
transformed first with the first applied transformation.

    res = wf.call_wrapped(*dynamic_args, **kwargs)
    # Now `aux_out_thunk()` is the auxiliary output.

A transformation is written as a generator function that takes zero or more
static positional arguments (given when the transformation is instantiated),
along with positional and keyword arguments to be transformed.
The generator will yield twice:

    @lu.transformation_with_aux
    def trans1(static_arg, *dynamic_args, **kwargs):
      ...
      # First yield: pair of transformed (args, kwargs). Get back the results.
      results = yield (new_dynamic_args, new_kwargs)
      ...
      # Second yield: pair of (transformed results, and auxiliary output)
      yield new_results, auxiliary_output

A transformation declared with `@lu.transformation` (no auxiliary output)
instead yields only the transformed results the second time.

`WrappedFun` objects explicitly represent the set of transformations so that
they can be used as dictionary keys for memoization. `WrappedFun` objects
compare as equal only if they compute the same function. The static and the
dynamic positional arguments for the generators, and also the auxiliary output
data, must be immutable, because they will be stored in function memoization
tables.
"""

import threading
from functools import partial
from typing import Any, Tuple, Callable
import weakref

from . import core
from ._src.util import curry
from .tree_util import tree_map

from ._src import traceback_util

from .config import FLAGS

traceback_util.register_exclusion(__file__)


class StoreException(Exception): pass


# Sentinel marking an empty `Store`; a dedicated object (rather than `None`)
# lets `None` itself be stored as a value.
class EmptyStoreValue(object): pass
_EMPTY_STORE_VALUE = EmptyStoreValue()


class Store(object):
  """Storage for a value, with checks for overwriting or reading empty store.

  A `Store` can be written once with `store`; writing again raises
  `StoreException("Store occupied")` and reading `val` before a write raises
  `StoreException("Store empty")`. Truthiness reports whether it is occupied.
  """
  __slots__ = ("_val",)

  def __init__(self):
    self._val = _EMPTY_STORE_VALUE

  def store(self, val):
    """Stores `val`; raises `StoreException` if a value is already stored."""
    if self._val is not _EMPTY_STORE_VALUE:
      raise StoreException("Store occupied")
    self._val = val

  def reset(self):
    # This should only be called in exceptional circumstances (e.g. debugging).
    self._val = _EMPTY_STORE_VALUE

  @property
  def val(self):
    """The stored value; raises `StoreException` if nothing is stored."""
    if not self:
      raise StoreException("Store empty")
    return self._val

  def __nonzero__(self):
    return self._val is not _EMPTY_STORE_VALUE

  __bool__ = __nonzero__


class WrappedFun(object):
  """Represents a function `f` to which `transforms` are to be applied.

  Args:
    f: the function to be transformed.
    transforms: a tuple of `(gen, gen_static_args)` pairs representing
      transformations to apply to `f`. Here `gen` is a generator function
      and `gen_static_args` is a tuple of static arguments for the generator.
      The most recently added transformation comes first. See the description
      at the start of this module for the expected behavior of the generator.
    stores: a tuple, parallel to `transforms`, holding for each transformation
      the `Store` for its auxiliary output, or `None` if it has none.
    params: extra parameters to pass as keyword arguments to `f`, along with the
      transformed keyword arguments. As built by `wrap_init`, a sorted tuple
      of `(name, value)` pairs.
  """
  __slots__ = ("f", "transforms", "stores", "params")

  def __init__(self, f, transforms, stores, params):
    self.f = f
    self.transforms = transforms
    self.stores = stores
    self.params = params

  @property
  def __name__(self):
    return getattr(self.f, '__name__', '<unnamed wrapped function>')

  def wrap(self, gen, gen_static_args, out_store) -> 'WrappedFun':
    """Add another transform and its store."""
    return WrappedFun(self.f, ((gen, gen_static_args),) + self.transforms,
                      (out_store,) + self.stores, self.params)

  def populate_stores(self, stores):
    """Copy the values from the `stores` into `self.stores`."""
    for self_store, other_store in zip(self.stores, stores):
      if self_store is not None:
        self_store.store(other_store.val)

  def call_wrapped(self, *args, **kwargs):
    """Calls the underlying function, applying the transforms.

    The positional `args` and keyword `kwargs` are passed to the first
    transformation generator.

    If an exception is raised at any point - while the transformations process
    the arguments, inside `f`, or while they process the results - every
    generator still suspended at its first yield is closed before the
    exception is re-raised.
    """
    stack = []
    try:
      for (gen, gen_static_args), out_store in zip(self.transforms, self.stores):
        gen = gen(*(gen_static_args + tuple(args)), **kwargs)
        args, kwargs = next(gen)
        stack.append((gen, out_store))
      # Release local references that are no longer needed.
      gen = gen_static_args = out_store = None

      ans = self.f(*args, **dict(self.params, **kwargs))

      args = kwargs = None
      while stack:
        gen, out_store = stack.pop()
        ans = gen.send(ans)
        if out_store is not None:
          ans, side = ans
          out_store.store(side)
    except BaseException:
      # Some transformations yield from inside context managers, so we have to
      # interrupt them before reraising the exception. Otherwise they will only
      # get garbage-collected at some later time, running their cleanup tasks
      # only after this exception is handled, which can corrupt the global
      # state. This holds whether the failure happened before, inside, or
      # after the call to `f`.
      while stack:
        stack.pop()[0].close()
      raise

    return ans

  def __repr__(self):
    def transform_to_str(x):
      i, (gen, args) = x
      return "{} : {} {}".format(i, fun_name(gen), fun_name(args))
    transformation_stack = map(transform_to_str, enumerate(self.transforms))
    return "Wrapped function:\n" + '\n'.join(transformation_stack) + '\nCore: ' + fun_name(self.f) + '\n'

  def __hash__(self):
    return hash((self.f, self.transforms, self.params))

  def __eq__(self, other):
    # Defer to the other operand (falling back to identity) instead of
    # raising AttributeError when compared against an unrelated type.
    if not isinstance(other, WrappedFun):
      return NotImplemented
    return (self.f == other.f and self.transforms == other.transforms and
            self.params == other.params)


@curry
def transformation(gen, fun: WrappedFun, *gen_static_args) -> WrappedFun:
  """Adds one more transformation to a WrappedFun.

  Args:
    gen: the transformation generator function
    fun: a WrappedFun on which to apply the transformation
    gen_static_args: static args for the generator function
  """
  return fun.wrap(gen, gen_static_args, None)


@curry
def transformation_with_aux(gen, fun: WrappedFun, *gen_static_args) -> Tuple[WrappedFun, Any]:
  """Adds one more transformation with auxiliary output to a WrappedFun.

  Returns:
    A pair `(wrapped_fun, out_thunk)`. Calling `out_thunk()` returns the
    auxiliary output; it raises `StoreException` if the auxiliary output has
    not been stored yet.
  """
  out_store = Store()
  out_thunk = lambda: out_store.val
  return fun.wrap(gen, gen_static_args, out_store), out_thunk


def fun_name(f):
  """Returns `f.__name__`, falling back to `str(f)` if that lookup fails."""
  try:
    return f.__name__
  except Exception:
    return str(f)


def wrap_init(f, params=None) -> WrappedFun:
  """Wraps function `f` as a `WrappedFun`, suitable for transformation.

  Args:
    f: the function to wrap.
    params: optional mapping of extra keyword arguments passed to `f` on every
      call. It is stored as a sorted tuple of items so the resulting
      `WrappedFun` stays hashable and compares independently of insertion
      order.
  """
  params = {} if params is None else params
  return WrappedFun(f, (), (), tuple(sorted(params.items())))


class _CacheLocalContext(threading.local):
  """Per-thread state for `cache`: a weak reference to the latest result."""

  def __init__(self):
    super(_CacheLocalContext, self).__init__()
    self.most_recent_entry = None


def cache(call: Callable):
  """Memoization decorator for functions taking a WrappedFun as first argument.

  Args:
    call: a Python callable that takes a WrappedFun as its first argument. The
      underlying transforms and params on the WrappedFun are used as part of the
      memoization cache key.

  Returns:
    A memoized version of ``call``. It also exposes `most_recent_entry()`
    (returns and clears, for the calling thread, the result of the latest call
    if it is still alive) and `cache_clear()`.
  """
  # One inner dict per underlying Python function `fun.f`; weak keys mean an
  # inner dict can be dropped once `fun.f` is no longer referenced elsewhere.
  fun_caches: weakref.WeakKeyDictionary = weakref.WeakKeyDictionary()
  thread_local: threading.local = _CacheLocalContext()

  def memoized_fun(fun: WrappedFun, *args):
    cache = fun_caches.setdefault(fun.f, {})
    # The x64 flag is part of the key so results computed under different
    # settings are never shared.
    if core.debug_state.check_leaks:
      # NOTE(review): presumably copies MainTrace objects so the cache key does
      # not keep the live traces alive while the leak checker is active.
      key = (_copy_main_traces(fun.transforms), fun.params, args, bool(FLAGS.jax_enable_x64))
    else:
      key = (fun.transforms, fun.params, args, bool(FLAGS.jax_enable_x64))
    result = cache.get(key, None)
    if result is not None:
      ans, stores = result
      # Cache hit: replay the auxiliary outputs recorded by the original call
      # into this WrappedFun's stores.
      fun.populate_stores(stores)
    else:
      ans = call(fun, *args)
      cache[key] = (ans, fun.stores)

    thread_local.most_recent_entry = weakref.ref(ans)
    return ans

  def _most_recent_entry():
    # Returns None when there is no recorded entry or it has been collected.
    most_recent_entry = thread_local.most_recent_entry
    if most_recent_entry is not None:
      result = most_recent_entry()
      thread_local.most_recent_entry = None
      return result

  memoized_fun.most_recent_entry = _most_recent_entry  # type: ignore
  memoized_fun.cache_clear = fun_caches.clear  # type: ignore

  return memoized_fun


# `partial(partial, tree_map)` turns the decorated function `g` into
# `partial(tree_map, g)`, so `_copy_main_traces(x)` is `tree_map(g, x)`.
@partial(partial, tree_map)
def _copy_main_traces(x):
  if isinstance(x, core.MainTrace):
    return core.MainTrace(x.level, x.trace_type, **x.payload)
  else:
    return x


@transformation
def hashable_partial(x, *args):
  """Prepends the static arguments to the dynamic positional arguments.

  The wrapped function is called without transformed keyword arguments, and
  its results are passed through unchanged.
  """
  ans = yield (x,) + args, {}
  yield ans


def merge_linear_aux(aux1, aux2):
  """Returns the auxiliary output of whichever of two aux thunks is populated.

  Args:
    aux1, aux2: thunks (e.g. as returned by `transformation_with_aux`) that
      raise `StoreException` when their store is empty.

  Returns:
    `(True, aux1())` if only `aux1` is populated, or `(False, aux2())` if only
    `aux2` is populated.

  Raises:
    StoreException: if neither or both of the thunks are populated.
  """
  try:
    out1 = aux1()
  except StoreException:
    # store 1 was not occupied, so store 2 better be
    try:
      out2 = aux2()
    except StoreException:
      raise StoreException("neither store occupied") from None
    else:
      return False, out2
  else:
    # store 1 was occupied, so let's check store 2 is not occupied
    try:
      out2 = aux2()
    except StoreException:
      return True, out1
    else:
      raise StoreException("both stores occupied")