1# Copyright 2018 Google LLC 2# 3# Licensed under the Apache License, Version 2.0 (the "License"); 4# you may not use this file except in compliance with the License. 5# You may obtain a copy of the License at 6# 7# https://www.apache.org/licenses/LICENSE-2.0 8# 9# Unless required by applicable law or agreed to in writing, software 10# distributed under the License is distributed on an "AS IS" BASIS, 11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12# See the License for the specific language governing permissions and 13# limitations under the License. 14 15import itertools as it 16from collections import namedtuple 17import contextlib 18import functools 19from typing import (Any, Callable, Dict, NamedTuple, Optional, Sequence, Tuple, 20 List, Union, cast, Type, no_type_check) 21from weakref import ref 22 23import numpy as np 24 25from .. import core 26from .. import dtypes 27from .. import linear_util as lu 28from ..ad_util import Zero 29from .._src.util import (unzip2, safe_zip, safe_map, toposort, partial, 30 split_list, cache, as_hashable_function) 31from ..core import (Trace, Tracer, Jaxpr, Literal, get_aval, AbstractValue, 32 unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn, 33 dropvar, ConcreteArray, raise_to_shaped) 34from jax._src import source_info_util 35from ..config import config 36 37map = safe_map 38zip = safe_zip 39def identity(x): return x 40 41class PartialVal(tuple): 42 """Partial value: either a known value or an unknown (abstract) value. 43 44 Represented as a pair `(aval_opt, const)` of one of two kinds: 45 * `(None, <Constant>)` indicates a known value, either a Python regular 46 value, or a Tracer. 47 * `(<AbstractValue>, *)` indicates an unknown value characterized by an 48 abstract value. 
49 """ 50 def __new__(cls, xs: Tuple[Optional[AbstractValue], core.Value]): 51 pv, const = xs 52 if not core.skip_checks: 53 # type checks 54 assert isinstance(pv, (AbstractValue, type(None))), xs 55 assert isinstance(const, core.Tracer) or type(const) is Zero or core.valid_jaxtype(const), xs 56 # invariant checks 57 if isinstance(pv, AbstractValue): 58 assert get_aval(const) == core.abstract_unit, xs 59 return tuple.__new__(cls, xs) 60 61 @classmethod 62 def known(cls, const: core.Value) -> 'PartialVal': 63 return PartialVal((None, const)) 64 65 @classmethod 66 def unknown(cls, aval: AbstractValue) -> 'PartialVal': 67 return PartialVal((aval, core.unit)) 68 69 def is_known(self) -> bool: 70 return self[0] is None 71 72 def get_known(self) -> Optional[core.Value]: 73 """Get the known value, if known, else None.""" 74 return self[1] if self[0] is None else None 75 76 def get_aval(self) -> AbstractValue: 77 """Get AbstractValue directly (if unknown) or from the constant (known).""" 78 known = self.get_known() 79 if known is not None: 80 return get_aval(known) 81 else: 82 return self[0] 83 84 def merge_with_known(self, val: core.Value) -> core.Value: 85 """Either the stored known value, or the given 'val'.""" 86 known = self.get_known() 87 return known if known is not None else val 88 89 90class JaxprTrace(Trace): 91 def pure(self, val) -> 'JaxprTracer': 92 return self.new_const(val) 93 94 def lift(self, val) -> 'JaxprTracer': 95 return self.new_const(val) 96 97 def sublift(self, val) -> 'JaxprTracer': 98 return JaxprTracer(self, val.pval, FreeVar(val)) 99 100 def new_const(self, val) -> 'JaxprTracer': 101 if isinstance(val, Tracer) and val._trace.level == self.level: 102 raise Exception 103 return JaxprTracer(self, PartialVal.known(val), unit) 104 105 def new_instantiated_literal(self, val) -> 'JaxprTracer': 106 return JaxprTracer(self, PartialVal.unknown(get_aval(val)), Literal(val)) 107 108 def new_instantiated_const(self, val) -> 'JaxprTracer': 109 return 
JaxprTracer(self, PartialVal.unknown(get_aval(val)), ConstVar(val)) 110 111 def new_arg(self, pval: PartialVal) -> 'JaxprTracer': 112 const = pval.get_known() 113 if const is None: 114 return JaxprTracer(self, pval, LambdaBinding()) 115 else: 116 return self.new_const(const) 117 118 def instantiate_const(self, tracer) -> Tracer: 119 const = tracer.pval.get_known() 120 if const is None: 121 return tracer 122 else: 123 if type(const) in core.literalable_types and np.shape(const) == (): 124 return self.new_instantiated_literal(const) 125 else: 126 return self.new_instantiated_const(const) 127 128 def instantiate_const_abstracted(self, tracer) -> 'JaxprTracer': 129 const = tracer.pval.get_known() 130 if const is None: 131 return tracer 132 else: 133 aval = raise_to_shaped(get_aval(const), np.isscalar(const)) 134 return JaxprTracer(self, PartialVal.unknown(aval), ConstVar(const)) 135 136 def process_primitive(self, primitive, tracers, params): 137 if primitive in custom_partial_eval_rules: 138 return custom_partial_eval_rules[primitive](self, *tracers, **params) 139 else: 140 return self.default_process_primitive(primitive, tracers, params) 141 142 def default_process_primitive(self, primitive, tracers, params): 143 """By default, if all the input tracers are known, then execute the primitive 144 and all the ouputs are known. 
Otherwise, all the outputs are unknown.""" 145 consts = [t.pval.get_known() for t in tracers] 146 if all(c is not None for c in consts): 147 return primitive.bind(*consts, **params) 148 tracers = map(self.instantiate_const, tracers) 149 avals = [t.aval for t in tracers] 150 out_aval = primitive.abstract_eval(*avals, **params) 151 source = source_info_util.current() 152 if primitive.multiple_results: 153 out_tracers = [JaxprTracer(self, PartialVal.unknown(aval), None) 154 for aval in out_aval] 155 eqn = new_eqn_recipe(tracers, out_tracers, primitive, params, source) 156 for t in out_tracers: t.recipe = eqn 157 return out_tracers 158 else: 159 out_tracer = JaxprTracer(self, PartialVal.unknown(out_aval), None) 160 out_tracer.recipe = new_eqn_recipe(tracers, [out_tracer], primitive, 161 params, source) 162 return out_tracer 163 164 # We use process_call to handle both call and map primitives. 165 def process_call(self, primitive, f: lu.WrappedFun, tracers, params): 166 if not config.omnistaging_enabled: 167 if (self.main.trace_type is StagingJaxprTrace # type: ignore 168 and primitive in staged_out_calls): # type: ignore 169 tracers = map(self.instantiate_const_abstracted, tracers) 170 171 if primitive in call_partial_eval_rules: 172 return call_partial_eval_rules[primitive](self, primitive, f, tracers, params) 173 174 in_pvals = [t.pval for t in tracers] 175 if primitive.map_primitive: 176 mapped_aval = partial(core.mapped_aval, params['axis_size']) 177 in_pvals = [pval if pval.is_known() or in_axis is None 178 else PartialVal.unknown(mapped_aval(in_axis, pval[0])) 179 for pval, in_axis in zip(in_pvals, params['in_axes'])] 180 181 def app(f, *args): 182 f, num_outputs = count_outputs(f) 183 out_axes_thunk = params['out_axes_thunk'] 184 @as_hashable_function(closure=out_axes_thunk) 185 def new_out_axes_thunk(): 186 out_axes = out_axes_thunk() 187 return out_axes + (0,) * (num_outputs() - len(out_axes)) 188 pe_params = dict(params, out_axes_thunk=new_out_axes_thunk) 189 
return primitive.bind(f, *args, **pe_params) 190 else: 191 app = partial(primitive.bind, **params) 192 jaxpr, out_pvals, consts, env_tracers = self.partial_eval( 193 f, in_pvals, app, instantiate=False) 194 if primitive.map_primitive: 195 unmapped_aval = partial(core.unmapped_aval, params['axis_size']) 196 out_axes = params['out_axes_thunk']() 197 out_pvals = [pval if pval.is_known() else 198 PartialVal.unknown(unmapped_aval(out_axis, pval[0])) if out_axis is not None else 199 PartialVal.unknown(pval[0]) 200 for pval, out_axis in zip(out_pvals, out_axes)] 201 202 # Skip known invars and outvars, and lift constants as regular invars 203 in_knowns = tuple(t.pval.is_known() for t in it.chain(env_tracers, tracers)) 204 out_unknowns = tuple(not pval.is_known() for pval in out_pvals) 205 jaxpr = _drop_invars(jaxpr, in_knowns) 206 jaxpr = _dce_open_jaxpr(jaxpr, out_unknowns, drop_outputs=True) 207 208 # Known tracers get propagated as if they were constants 209 known_tracers_out = [self.new_const(pval.get_known()) for pval in out_pvals 210 if pval.is_known()] 211 212 # Unknown tracers need to have the jaxpr set up as their recipe 213 unknown_tracers_out = [JaxprTracer(self, pval, None) for pval in out_pvals 214 if not pval.is_known()] 215 unknown_tracers_in = [t for t in tracers if not t.pval.is_known()] 216 const_tracers = map(self.new_instantiated_const, consts) 217 in_tracers = (*const_tracers, *env_tracers, *unknown_tracers_in) 218 219 # Set up new params 220 new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) 221 if primitive.map_primitive: 222 in_axes = params['in_axes'] 223 # NOTE: const_tracers are added as map outputs, and we always map them 224 # along axis 0 (see `new_out_axes_thunk` above). 
225 new_in_axes = ((0,) * len(const_tracers) + 226 (None,) * len(env_tracers) + 227 tuple(axis for axis, t in zip(in_axes, tracers) 228 if not t.pval.is_known())) 229 new_out_axes = tuple(axis for axis, pval in zip(out_axes, out_pvals) 230 if not pval.is_known()) 231 new_params = dict(new_params, in_axes=new_in_axes, out_axes=new_out_axes) 232 del new_params['out_axes_thunk'] 233 update_params = call_param_updaters.get(primitive) 234 if update_params: 235 new_params = update_params(new_params, [not t.pval.is_known() for t in tracers]) 236 237 eqn = new_eqn_recipe(in_tracers, unknown_tracers_out, primitive, new_params, 238 source_info_util.current()) 239 for t in unknown_tracers_out: t.recipe = eqn 240 return _zip_knowns(known_tracers_out, unknown_tracers_out, out_unknowns) 241 242 process_map = process_call 243 244 # We use post_process_call to handle both call and map primitives. 245 def post_process_call(self, primitive, out_tracers, params): 246 jaxpr, consts, env = tracers_to_jaxpr([], out_tracers) 247 out_pvs, out_pv_consts = unzip2(t.pval for t in out_tracers) 248 out = out_pv_consts + consts 249 nconsts = len(consts) 250 del consts, out_pv_consts 251 main = self.main 252 253 if primitive.map_primitive: 254 out_axes = params['out_axes_thunk']() 255 sz = params['axis_size'] 256 out_pvs = [None if pv is None else core.unmapped_aval(sz, ax, pv) 257 for pv, ax in zip(out_pvs, out_axes)] 258 259 def todo(x): 260 n = len(jaxpr.outvars) 261 out_pv_consts, consts = x[:n], x[n:] 262 trace = JaxprTrace(main, core.cur_sublevel()) 263 const_tracers = map(trace.new_instantiated_const, consts) 264 out_tracers = [JaxprTracer(trace, PartialVal((out_pv, out_pv_const)), None) 265 for out_pv, out_pv_const in zip(out_pvs, out_pv_consts)] 266 in_tracers = (*const_tracers, *map(trace.full_raise, env)) 267 268 new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr)) 269 if primitive.map_primitive: 270 # NOTE: We've assigned axis 0 to const tracers below, in 
out_axes_transform. 271 new_in_axes = (0,) * len(const_tracers) + (None,) * len(env) 272 new_params = dict(new_params, in_axes=new_in_axes, out_axes=out_axes) 273 del new_params['out_axes_thunk'] 274 update_params = call_param_updaters.get(primitive) 275 if update_params: 276 new_params = update_params(new_params, []) 277 278 eqn = new_eqn_recipe(in_tracers, out_tracers, primitive, new_params, 279 source_info_util.current()) 280 for t in out_tracers: 281 t.recipe = eqn 282 return out_tracers 283 284 if primitive.map_primitive: 285 def out_axes_transform(out_axes): 286 return out_axes + (0,) * nconsts 287 todo = (todo, out_axes_transform) 288 289 return out, todo 290 291 post_process_map = post_process_call 292 293 def partial_eval(self, f: lu.WrappedFun, pvals: Sequence[PartialVal], 294 app: Callable[[lu.WrappedFun, Tuple[core.Value, ...]], Tuple[core.Value]], 295 instantiate: bool): 296 """Partially evaluate f on a sequence of PartialVals.""" 297 in_avals, in_consts = unzip2(pvals) 298 f = trace_to_subjaxpr(f, self.main, instantiate) 299 f, aux = partial_eval_wrapper(f, tuple(in_avals)) 300 out_flat, (out_avals, jaxpr, env) = app(f, *in_consts), aux() 301 out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) 302 out_pvs = map(PartialVal, zip(out_avals, out_consts)) 303 env_tracers = map(self.full_raise, env) 304 return jaxpr, out_pvs, consts, env_tracers 305 306 def process_custom_jvp_call(self, prim, fun, jvp, tracers): 307 tracers = map(self.instantiate_const_abstracted, tracers) 308 in_avals, in_consts = unzip2(t.pval for t in tracers) # in_consts are units 309 fun = trace_to_subjaxpr(fun, self.main, True) 310 fun, aux = partial_eval_wrapper(fun, tuple(in_avals)) 311 out_flat = prim.bind(fun, jvp, *in_consts) 312 out_avals, jaxpr, env = aux() 313 out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) 314 out_pvals = map(PartialVal, zip(out_avals, out_consts)) # out_consts are units 315 env_tracers = 
map(self.full_raise, env) 316 out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals] 317 const_tracers = map(self.new_instantiated_const, consts) 318 in_tracers = (*const_tracers, *env_tracers, *tracers) 319 closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ()) 320 321 @_memoize 322 def jvp_jaxpr_thunk(): 323 jvp_ = trace_to_subjaxpr(jvp, self.main, True) 324 jvp_, aux = partial_eval_wrapper(jvp_, tuple(in_avals) * 2) 325 out_flat = jvp_.call_wrapped(*(in_consts * 2)) # in_consts are units 326 out_avals, jaxpr, env = aux() 327 _, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) 328 converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env)) 329 return converted_jaxpr, (*consts, *env) 330 331 eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style, 332 dict(fun_jaxpr=closed_jaxpr, 333 jvp_jaxpr_thunk=jvp_jaxpr_thunk, 334 num_consts=len(consts) + len(env)), 335 source_info_util.current()) 336 for t in out_tracers: t.recipe = eqn 337 return out_tracers 338 339 def post_process_custom_jvp_call(self, out_tracers, params): 340 # This path should only be reachable if we expose a partial eval API 341 # unrelated to autodiff, since we raise an error when differentiation with 342 # respect to values over which a custom_jvp function closes is detected. 
343 raise NotImplementedError # TODO(mattjj) 344 345 def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): 346 tracers = map(self.instantiate_const_abstracted, tracers) 347 in_avals, in_consts = unzip2(t.pval for t in tracers) # in_consts are units 348 fun = trace_to_subjaxpr(fun, self.main, True) 349 fun, aux = partial_eval_wrapper(fun, tuple(in_avals)) 350 out_flat = prim.bind(fun, fwd, bwd, *in_consts, out_trees=out_trees) 351 out_avals, jaxpr, env = aux() 352 out_consts, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) 353 out_pvals = map(PartialVal, zip(out_avals, out_consts)) # out_consts are units 354 env_tracers = map(self.full_raise, env) 355 out_tracers = [JaxprTracer(self, pval, None) for pval in out_pvals] 356 const_tracers = map(self.new_instantiated_const, consts) 357 in_tracers = (*const_tracers, *env_tracers, *tracers) 358 closed_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(jaxpr), ()) 359 360 @_memoize 361 def fwd_jaxpr_thunk(): 362 fwd_ = trace_to_subjaxpr(fwd, self.main, True) 363 fwd_, aux = partial_eval_wrapper(fwd_, tuple(in_avals)) 364 out_flat = fwd_.call_wrapped(*in_consts) # in_consts are units 365 out_avals, jaxpr, env = aux() 366 _, consts = split_list(out_flat, [len(out_flat)-len(jaxpr.constvars)]) 367 converted_jaxpr = convert_envvars_to_constvars(jaxpr, len(env)) 368 return converted_jaxpr, (*consts, *env) 369 370 eqn = new_eqn_recipe(in_tracers, out_tracers, prim.initial_style, 371 dict(fun_jaxpr=closed_jaxpr, 372 fwd_jaxpr_thunk=fwd_jaxpr_thunk, 373 num_consts=len(consts) + len(env), 374 bwd=bwd, out_trees=out_trees), 375 source_info_util.current()) 376 for t in out_tracers: t.recipe = eqn 377 return out_tracers 378 379 def post_process_custom_vjp_call(self, out_tracers, params): 380 # This path should only be reachable if we expose a partial eval API 381 # unrelated to autodiff, since we raise an error when differentiation with 382 # respect to values over which a custom_vjp function 
closes is detected. 383 raise NotImplementedError # TODO(mattjj) 384 385 386@lu.transformation_with_aux 387def partial_eval_wrapper(pvs: Sequence[Optional[AbstractValue]], *consts): 388 py_args = map(PartialVal, zip(pvs, consts)) 389 jaxpr, (out_pvals, consts, env) = yield (py_args,), {} 390 out_pvs, out_consts = unzip2(out_pvals) 391 out = tuple(out_consts) + tuple(consts) 392 yield out, (out_pvs, jaxpr, env) 393 394@lu.transformation_with_aux 395def count_outputs(*args, **kwargs): 396 ans = yield args, kwargs 397 yield ans, len(ans) 398 399custom_partial_eval_rules: Dict[core.Primitive, Callable] = {} 400call_partial_eval_rules: Dict[core.Primitive, Callable] = {} 401call_param_updaters: Dict[core.Primitive, Callable] = {} 402 403 404def abstract_eval_fun(fun, *avals, **params): 405 if config.omnistaging_enabled: 406 _, avals_out, _ = trace_to_jaxpr_dynamic(lu.wrap_init(fun, params), avals) 407 else: 408 pvals_in = [PartialVal.unknown(a) for a in avals] 409 _, pvals_out, _ = trace_to_jaxpr(lu.wrap_init(fun, params), pvals_in, 410 instantiate=True, stage_out=True) # type: ignore 411 avals_out, _ = unzip2(pvals_out) 412 for aval_out in avals_out: 413 assert isinstance(aval_out, AbstractValue) # instantiate=True 414 return avals_out 415 416 417JaxprTracerRecipe = Union['JaxprEqnRecipe', 'LambdaBinding', 'FreeVar', 418 'ConstVar', Literal, core.Unit] 419 420class JaxprTracer(Tracer): 421 __slots__ = ['pval', 'recipe'] 422 423 def __init__(self, trace: JaxprTrace, pval: PartialVal, 424 recipe: Optional[JaxprTracerRecipe]): 425 assert isinstance(pval, PartialVal) 426 pv, const = pval 427 if isinstance(const, Tracer) and const._trace.level >= trace.level: 428 raise core.escaped_tracer_error( 429 const, "Tracer from a higher level: {} in trace {}".format(const, trace)) 430 self._trace = trace 431 self.pval = pval 432 self.recipe = recipe 433 434 def __repr__(self): 435 return 'Traced<{}:{}>'.format(self.aval, self._trace) 436 437 @property 438 def aval(self) -> 
AbstractValue: 439 return self.pval.get_aval() 440 441 @property 442 def parents(self) -> Sequence['JaxprTracer']: 443 if isinstance(self.recipe, JaxprEqnRecipe): 444 return self.recipe.invars 445 else: 446 return [] 447 448 def full_lower(self): 449 known = self.pval.get_known() 450 if known is not None: 451 return core.full_lower(known) 452 else: 453 return self 454 455 def is_known(self): 456 return self.pval.is_known() 457 458# TODO(necula): this could return a ClosedJaxpr with out_pvals 459def trace_to_jaxpr(fun: lu.WrappedFun, pvals: Sequence[PartialVal], 460 instantiate: Union[bool, Sequence[bool]] = False, 461 ) -> Tuple[Jaxpr, Tuple[PartialVal, ...], Tuple[core.Value, ...]]: 462 """Traces a function into a Jaxpr, given PartialVals for inputs. 463 464 Returns (`jaxpr`, `out_pvals`, `consts`). The `jaxpr` contains only the 465 computation that depends on unknown inputs. The `out_pvals` are the PartialVal 466 for the outputs. The intermediate values that depend only on known inputs and 467 are needed to compute the output of `jaxpr` are in `consts` and are passed in 468 as the constvars of the `jaxpr`. The handling of the known outputs depends on 469 `instantiate`. 470 471 For example, given `fun` defined as follows:: 472 473 def fun(ki, ui): # ki will be a known input in this example 474 ka = ki + 2 475 kb = ka + 3 476 return (kb, ui + ka) 477 478 with `ki` the known PartialVal `1.`, and `ui` an unknown PartialVal. The only 479 computation that depends on unknown inputs is `ui + ka` and will be the only 480 computation in the body of the `jaxpr`. This computation depends on the known 481 intermediate value `ka`, which will be computed statically. Currently, such 482 constants are either embedded in the Jaxpr if they are scalars, or passed as a 483 constvar to `jaxpr`, and then the value of the actual constant will be in 484 `consts`: 485 486 When `instantiate=False` we get:: 487 488 jaxpr = 489 { lambda ka ; ki ui. 
490 let c = add ui ka 491 in (*, c) } # known outputs are `*` 492 out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)] 493 consts = [3] # the constant for `ka` 494 495 When `instantiate=True` we get:: 496 497 jaxpr = 498 { lambda ka kb ; ki ui. 499 let c = add ui ka 500 in (kb, c) } # known output are explicit 501 out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)] 502 consts = [3, 6] # values for `ka` and `kb` constvars 503 """ 504 with core.new_main(JaxprTrace) as main: 505 fun = trace_to_subjaxpr(fun, main, instantiate) 506 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) 507 assert not env 508 del main, fun, env 509 510 return jaxpr, out_pvals, consts 511 512 513@lu.transformation 514def trace_to_subjaxpr(main: core.MainTrace, instantiate: Union[bool, Sequence[bool]], 515 pvals: Sequence[PartialVal]): 516 assert all([isinstance(pv, PartialVal) for pv in pvals]), pvals 517 trace = JaxprTrace(main, core.cur_sublevel()) 518 in_tracers = map(trace.new_arg, pvals) 519 ans = yield in_tracers, {} 520 instantiate = [instantiate] * len(ans) if isinstance(instantiate, bool) else instantiate 521 out_tracers = map(trace.full_raise, map(core.full_lower, ans)) 522 out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers) 523 jaxpr, consts, env = tracers_to_jaxpr(in_tracers, out_tracers) 524 out_pvals = [t.pval for t in out_tracers] 525 del trace, in_tracers, out_tracers 526 yield jaxpr, (out_pvals, consts, env) 527 528def instantiate_const_at(trace: JaxprTrace, instantiate: bool, tracer): 529 if instantiate: 530 return trace.instantiate_const(trace.full_raise(tracer)) 531 else: 532 return tracer 533 534 535FreeVar = namedtuple('FreeVar', ['val']) 536ConstVar = namedtuple('ConstVar', ['val']) 537LambdaBinding = namedtuple('LambdaBinding', []) 538class JaxprEqnRecipe(NamedTuple): 539 eqn_id: object 540 invars: Sequence[JaxprTracer] 541 outvars: 'Sequence[ref[JaxprTracer]]' 542 primitive: 
core.Primitive 543 params: Dict[str, Any] 544 source_info: Optional[source_info_util.Traceback] 545 546def new_eqn_recipe(invars: Sequence[JaxprTracer], 547 outvars: Sequence[JaxprTracer], 548 primitive: core.Primitive, 549 params: Dict[str, Any], 550 source_info: Optional[source_info_util.Traceback] 551 ) -> JaxprEqnRecipe: 552 """Constructs a new JaxEqnRecipe. 553 554 Params: 555 invars: the tracers for the primitive inputs. 556 outvars: the tracers for the primitive outputs. 557 primitive: the primitive. 558 params: the primitive params 559 """ 560 # TODO(necula): move these checks to core.check_jaxpr, and call in more places 561 if primitive.call_primitive or primitive.map_primitive: 562 assert "call_jaxpr" in params 563 if primitive.map_primitive: 564 assert ("in_axes" in params and 565 len(params["in_axes"]) == len(params["call_jaxpr"].invars)) 566 assert ("donated_invars" in params and 567 len(params["donated_invars"]) == len(params["call_jaxpr"].invars)) 568 return JaxprEqnRecipe(object(), tuple(invars), map(ref, outvars), primitive, 569 params, source_info) 570 571 572def recipe_to_eqn(getvar: Callable[[JaxprTracer], core.Atom], 573 recipe: JaxprEqnRecipe) -> core.JaxprEqn: 574 _, in_tracers, out_tracer_refs, primitive, params, source_info = recipe 575 out_tracers = [t_ref() for t_ref in out_tracer_refs] 576 invars = [getvar(t) for t in in_tracers] 577 outvars = [core.dropvar if t is None else cast(core.Var, getvar(t)) 578 for t in out_tracers] 579 return new_jaxpr_eqn(invars, outvars, primitive, params, source_info) 580 581def tracers_to_jaxpr( 582 in_tracers: Sequence[JaxprTracer], 583 out_tracers: Sequence[JaxprTracer] 584 ) -> Tuple[Jaxpr, Tuple[Any, ...], Tuple[Any, ...]]: 585 """Constructs Jaxpr given tracers for inputs and outputs. 586 587 Params: 588 in_tracers: the tracers that were created for the function inputs 589 out_tracers: the tracers that were output by the function. 
590 591 Returns: a triple of a `Jaxpr`, a list of constant values corresponding to 592 the `constvars` in the returned Jaxps, and a list of environment values. 593 The vars for the environment values have been prepended to the Jaxpr's 594 `invars`. 595 """ 596 newvar = core.gensym() 597 t_to_var: Dict[int, core.Atom] = {} 598 def getvar(t: JaxprTracer) -> core.Atom: 599 var = t_to_var.get(id(t)) 600 if var is None: 601 aval = t.pval.get_aval() if not t.pval.is_known() else abstract_unit 602 var = t_to_var[id(t)] = newvar(aval) 603 return var 604 sorted_tracers = toposort(out_tracers) 605 invars = map(getvar, in_tracers) 606 eqns: List[core.JaxprEqn] = [] 607 env: Dict[core.Var, Any] = {} 608 consts: Dict[core.Var, Any] = {} 609 const_to_var: Dict[int, core.Var] = {} 610 def getconstvar(c): 611 var = const_to_var.get(id(c)) 612 if var is None: 613 var = const_to_var[id(c)] = newvar(get_aval(c)) 614 return var 615 processed_eqn_ids = set() 616 for t in sorted_tracers: 617 recipe = t.recipe 618 if isinstance(recipe, JaxprEqnRecipe): 619 if recipe.eqn_id not in processed_eqn_ids: 620 eqns.append(recipe_to_eqn(getvar, recipe)) 621 processed_eqn_ids.add(recipe.eqn_id) 622 elif isinstance(recipe, LambdaBinding): 623 if not any(t is in_tracer for in_tracer in in_tracers): 624 raise core.escaped_tracer_error( 625 t, "Tracer not among input tracers {}".format(t)) 626 assert in_tracers, "Lambda binding with no args" 627 elif isinstance(recipe, FreeVar): 628 env[cast(core.Var, getvar(t))] = recipe.val 629 elif isinstance(recipe, ConstVar): 630 v = t_to_var[id(t)] = getconstvar(recipe.val) 631 consts[v] = recipe.val 632 elif isinstance(recipe, Literal): 633 t_to_var[id(t)] = recipe 634 elif recipe is unit: 635 t_to_var[id(t)] = unitvar 636 else: 637 raise TypeError(recipe) 638 639 env_vars, env_vals = unzip2(env.items()) 640 const_vars, const_vals = unzip2(consts.items()) 641 # The env_vars are pre-pended to the invars 642 jaxpr = Jaxpr(const_vars, [*env_vars, *invars], 
map(getvar, out_tracers), eqns) 643 core.skip_checks or core.check_jaxpr(jaxpr) 644 return jaxpr, const_vals, env_vals 645 646@cache() 647def convert_constvars_jaxpr(jaxpr: Jaxpr): 648 """Moves the constvars to the start of invars.""" 649 core.skip_checks or core.check_jaxpr(jaxpr) 650 lifted_jaxpr = Jaxpr(constvars=(), 651 invars=jaxpr.constvars + jaxpr.invars, 652 outvars=jaxpr.outvars, eqns=jaxpr.eqns) 653 core.skip_checks or core.check_jaxpr(lifted_jaxpr) 654 return lifted_jaxpr 655 656def convert_envvars_to_constvars(jaxpr: Jaxpr, num_env_vars: int): 657 core.skip_checks or core.check_jaxpr(jaxpr) 658 env_vars, invars = split_list(jaxpr.invars, [num_env_vars]) 659 converted_jaxpr = Jaxpr(constvars=jaxpr.constvars + env_vars, 660 invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns) 661 core.skip_checks or core.check_jaxpr(converted_jaxpr) 662 return converted_jaxpr 663 664 665def _split_aval(unknown: bool, aval: AbstractValue) -> Tuple[AbstractValue, AbstractValue]: 666 return (abstract_unit, aval) if unknown else (aval, abstract_unit) 667 668def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool], 669 instantiate: Union[bool, Sequence[bool]], 670 ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]: 671 """Specializes a Jaxpr given an indication of which inputs are known. 672 673 Returns: (jaxpr_known, jaxpr_unknown, out_unknowns). 674 675 `out_unknowns` specifies which outputs are unknown (depend on some unknown inputs). 676 `jaxpr_known` takes the same inputs as `jaxpr`, ignores the unknown inputs, 677 and performs *all* the computation in `jaxpr` that depends only on the known inputs. 678 Outputs correspond to those of `jaxpr`, with the unknown ones replaced with `*`, 679 appended with the known residuals (the intermediate computations in `jaxpr` 680 that depend only on known inputs and that are needed to compute the unknown outputs). 
681 682 `jaxpr_unknown` takes the same inputs as `jaxpr` along with the known residuals 683 computed by `jaxpr_known` and returns the same outputs as `jaxpr` with the known 684 outputs replaced by `*`. 685 686 Roughly, `jaxpr(ki, ui)` is decomposed assuming `ki` and `ui` are the known and respectively 687 unknown inputs into: 688 689 jaxpr(ki, ui) = let kout, _, kresidual = jaxpr_known(kin, *) 690 let _, uout = jaxpr_unknown(ki, ui, kresidual) 691 in (kout, uout) 692 693 For example, if `jaxpr` is lambda ki, ui: let ka = ki + 2 694 in (ki + 3, ui + ka)" 695 then 696 `jaxpr_known` = lambda ki, ui: let ka = ki + 2 697 in (ki + 3, *, ka) 698 'jaxpr_unknown` = lambda ki, ui, ka: (*, ui + ka) 699 700 Note that if instantiate is True for a given output, then jaxpr_known always returns a 701 unit in its place. So when instantiate is True, the expectation is the one doesn't 702 run `jaxpr_known` for any of its outputs, but only to generate residuals that will allow 703 to obtain the full outputs once `jaxpr_unknown` is ran. Outputs known ahead of time will 704 simply get passed as residual constants and returned immediately. 705 """ 706 f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) 707 708 cell = [] 709 def fun(*vals): 710 pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val) 711 for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)] 712 jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate) 713 out_pvs_2, out_consts_2 = unzip2(out_pvals_2) 714 cell.append((out_pvs_2, jaxpr_2, len(consts_2))) 715 return out_consts_2 + consts_2 716 717 # For jaxpr_known we pass core.unit for the unknown inputs, and known PartialVal for the 718 # known inputs. 
def _remat_partial_eval(trace, _, f, tracers, params):
  """Partial-evaluation rule for `remat_call_p`.

  Traces the whole of `f` into a jaxpr (not only the unknown part), works out
  which outputs are known, evaluates the known outputs, and stages a single
  `remat_call_p` equation that produces the unknown outputs.

  Args:
    trace: the trace performing partial evaluation; this function uses its
      `instantiate_const`, `instantiate_const_abstracted`, `partial_eval`,
      `new_instantiated_const` and `new_const` methods (presumably a
      `JaxprTrace` -- confirm against the caller).
    _: unused.
    f: the function being called; handed to `trace.partial_eval`.
    tracers: the input tracers of the call.
    params: the call's parameters. Must contain the key `'concrete'`; all
      entries are forwarded to `remat_call_p.bind` and copied into the staged
      equation together with a new `call_jaxpr` entry.

  Returns:
    A list of output tracers in the original output order: constants for the
    known outputs and `JaxprTracer`s (sharing one equation recipe) for the
    unknown ones.
  """
  concrete = params['concrete']

  # Unlike JaxprTrace.process_call, we want to form a jaxpr for the entirety of
  # the function being called, not just for the unknown parts. To do that, we
  # instantiate all the input tracers as constants in the jaxpr being formed.
  # Those tracers might have concrete avals, and doing abstract interpretation
  # on concrete avals engenders a tradeoff: it allows data-dependent Python
  # control flow to work, but it can in some cases lead to redundant FLOPs (done
  # both in the `bind` call below and the `core.jaxpr_as_fun` call). We use the
  # `concrete` parameter to switch this behavior, and if `concrete` is False
  # then we raise the avals to the Shaped level.
  if concrete:
    instantiated_tracers = map(trace.instantiate_const, tracers)
  else:
    instantiated_tracers = map(trace.instantiate_const_abstracted, tracers)

  # Using the instantiated tracers, run call_bind like JaxprTrace.process_call.
  in_pvals = [t.pval for t in instantiated_tracers]
  if config.omnistaging_enabled:
    jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
        f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)
  else:
    # Without omnistaging the same call is made under initial-style staging.
    with core.initial_style_staging():  # type: ignore
      jaxpr, eval_out_pvals, consts, env_tracers = trace.partial_eval(
          f, in_pvals, partial(remat_call_p.bind, **params), instantiate=False)

  # Convert consts to inputs, since they may contain Tracer instances.
  jaxpr = convert_constvars_jaxpr(jaxpr)
  const_tracers = map(trace.new_instantiated_const, consts)

  # Since we traced with everything marked as unknown, but we need to know which
  # outputs are known/unknown, we use partial_eval_jaxpr to get out_unknowns.
  closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
  # Input order matches the converted jaxpr: consts first (always known), then
  # the environment tracers, then the original arguments.
  in_unknowns = ([False] * len(consts) +
                 [not t.is_known() for t in it.chain(env_tracers, tracers)])
  if config.omnistaging_enabled:
    jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
        closed_jaxpr, in_unknowns, instantiate=False)  # type: ignore
  else:
    jaxpr_known, jaxpr_unknown, out_unknowns = partial_eval_jaxpr(
        closed_jaxpr, in_unknowns, instantiate=False, trace_type=trace.main.trace_type)  # type: ignore
  out_knowns = [not b for b in out_unknowns]
  out_known_pvals, out_unknown_pvals = _partition_knowns(eval_out_pvals, out_unknowns)

  # Next, we need values for the outputs that should be known. Since consts
  # weren't passed through Python for evaluation, we need to evaluate jaxpr_known,
  # minus the residual outputs that we don't need. When `concrete=True`, as an
  # optimization we can avoid redoing *some* redundant FLOPs, namely those that
  # produced concrete avals at the output, simply by using those as computed
  # values. For the use case of reverse-mode AD in op-by-op ("eager mode")
  # evaluation, all the primal outputs should be concrete (thus not recomputed).
  to_compute = [type(pval[0]) is not ConcreteArray
                for uk, pval in zip(out_unknowns, eval_out_pvals) if not uk]
  num_outputs = len(jaxpr_unknown.out_avals)
  # jaxpr_known's outputs beyond the original ones are the residuals.
  num_res = len(jaxpr_known.out_avals) - num_outputs
  jaxpr_known_nores = _dce_jaxpr(jaxpr_known, out_knowns + [False] * num_res, drop_outputs=True)
  jaxpr_known_comp = _dce_jaxpr(jaxpr_known_nores, to_compute)
  _, in_consts = unzip2(t.pval for t in it.chain(env_tracers, tracers))
  reconstructed_consts = core.jaxpr_as_fun(jaxpr_known_comp)(*consts, *in_consts)
  out_known_pvals = map(_reconstruct_pval, out_known_pvals, reconstructed_consts)

  # Known outputs should keep propagating as constants
  assert all(pv.is_known() for pv in out_known_pvals)
  known_output_tracers = [trace.new_const(pval.get_known())
                          for pval in out_known_pvals]
  # Unknown outputs get wrapped in tracers with the appropriate recipe
  unknown_output_tracers = [JaxprTracer(trace, out_pval, None)
                            for out_pval in out_unknown_pvals]

  # dce jaxpr outputs
  new_jaxpr = _dce_jaxpr(closed_jaxpr, out_unknowns, drop_outputs=True).jaxpr
  new_params = dict(params, call_jaxpr=new_jaxpr)

  # set up eqn for unknown outputs
  in_tracers = (*const_tracers, *env_tracers, *instantiated_tracers)
  eqn = new_eqn_recipe(in_tracers, unknown_output_tracers, remat_call_p, new_params,
                       source_info_util.current())
  # The recipe can only be attached once the equation exists, hence the
  # placeholder `None` passed to JaxprTracer above.
  for t in unknown_output_tracers: t.recipe = eqn
  return _zip_knowns(known_output_tracers, unknown_output_tracers, out_unknowns)
def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJaxpr:
  """Return a copy of `closed_jaxpr` whose flagged input binders come first.

  Args:
    closed_jaxpr: a closed jaxpr whose inner jaxpr has no constvars.
    to_move: one flag per input; inputs flagged True are placed ahead of the
      rest, and relative order is kept within each group.

  Returns:
    A new ClosedJaxpr with the same outvars, equations and consts, and the
    reordered invars.
  """
  inner = closed_jaxpr.jaxpr
  assert not inner.constvars
  assert len(closed_jaxpr.in_avals) == len(to_move)
  reordered_invars = _move_to_front(inner.invars, to_move)
  return core.ClosedJaxpr(
      core.Jaxpr((), reordered_invars, inner.outvars, inner.eqns),
      closed_jaxpr.consts)
class JaxprStackFrame:
  """Mutable bookkeeping for one jaxpr while it is being traced.

  Attributes (all created in `__init__`):
    newvar: fresh-variable generator obtained from `core.gensym()`.
    tracer_to_var: maps `id(tracer)` to the jaxpr variable bound to it.
    constid_to_var: maps `id(constant)` to the constvar created for it.
    constvar_to_val: maps each constvar to the constant value it stands for.
    tracers: tracers registered with this frame.
    eqns: the equations recorded so far, in program order.
    invars: the input variables, in argument order.
  """
  __slots__ = ['newvar', 'tracer_to_var', 'constid_to_var', 'constvar_to_val',
               'tracers', 'eqns', 'invars']

  def __init__(self):
    self.newvar = core.gensym()
    self.tracer_to_var = {}
    self.constid_to_var = {}
    self.constvar_to_val = {}
    self.tracers = []   # circ refs, frame->tracer->trace->main->frame,
    self.eqns = []      # cleared when we pop frame from main
    self.invars = []

  def to_jaxpr(self, in_tracers, out_tracers):
    """Build the jaxpr recorded in this frame.

    Args:
      in_tracers: tracers whose variables become the jaxpr's invars.
      out_tracers: tracers whose variables become the jaxpr's outvars.

    Returns:
      A triple `(jaxpr, out_avals, constvals)` where `jaxpr` and `constvals`
      are the result of `_inline_literals`, and `out_avals` are the avals of
      `out_tracers`.
    """
    invars = [self.tracer_to_var[id(t)] for t in in_tracers]
    outvars = [self.tracer_to_var[id(t)] for t in out_tracers]
    constvars, constvals = unzip2(self.constvar_to_val.items())
    jaxpr = Jaxpr(constvars, invars, outvars, self.eqns)
    jaxpr, constvals = _inline_literals(jaxpr, constvals)
    out_avals = [t.aval for t in out_tracers]
    return jaxpr, out_avals, constvals

  def find_progenitors(self, tracer):
    """Find the inputs and constant-consuming equations `tracer` depends on.

    Args:
      tracer: a tracer that may have a variable in this frame.

    Returns:
      `(None, None)` if the tracer has no variable in this frame. Otherwise a
      pair `(invar_positions, const_eqns)`: the positions in `self.invars` of
      the inputs the tracer's variable depends on, and the equations that
      read any constvar it depends on.
    """
    var = self.tracer_to_var.get(id(tracer))
    # Test for absence explicitly; the truthiness of a variable is irrelevant.
    if var is None:
      return None, None
    # Walk the equations backwards, replacing each variable by the inputs of
    # the equation that produced it, until only roots remain.
    active_vars = {var}
    for eqn in reversed(self.eqns):
      produced = set(eqn.outvars) & active_vars
      if produced:
        active_vars.difference_update(produced)
        active_vars.update(eqn.invars)
    invar_positions = [i for i, v in enumerate(self.invars) if v in active_vars]
    constvars = active_vars & set(self.constvar_to_val)
    const_eqns = [eqn for eqn in self.eqns if set(eqn.invars) & constvars]
    return invar_positions, const_eqns
class DynamicJaxprTrace(core.Trace):
  """Trace that records every operation as an equation in a JaxprStackFrame.

  All state lives in the frame on top of `self.main.jaxpr_stack`; the trace
  object itself carries no extra attributes.
  """
  __slots__ = []  # type: ignore

  @property
  def frame(self):
    """The innermost JaxprStackFrame of this trace's main."""
    return self.main.jaxpr_stack[-1]  # pytype: disable=attribute-error

  def new_arg(self, aval):
    """Create a tracer for a new jaxpr input with abstract value `aval`."""
    tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
    self.frame.tracers.append(tracer)
    self.frame.tracer_to_var[id(tracer)] = var = self.frame.newvar(aval)
    self.frame.invars.append(var)
    return tracer

  def new_const(self, val):
    """Create a tracer standing for the constant `val`, bound to a constvar."""
    aval = raise_to_shaped(get_aval(val), weak_type=dtypes.is_weakly_typed(val))
    tracer = DynamicJaxprTracer(self, aval, source_info_util.current())
    self.frame.tracers.append(tracer)
    var = self.frame.tracer_to_var[id(tracer)] = self.getconstvar(val)
    self.frame.constvar_to_val[var] = val
    return tracer

  # Every way of bringing a value into this trace makes it a constant.
  pure = lift = sublift = new_const

  def getvar(self, tracer):
    """Return the variable bound to `tracer`; raise if the frame has none."""
    var = self.frame.tracer_to_var.get(id(tracer))
    if var is None:
      raise core.escaped_tracer_error(tracer)
    return var

  def makevar(self, tracer):
    """Register `tracer` with the frame and bind a fresh variable to it."""
    var = self.frame.tracer_to_var.get(id(tracer))
    assert var is None, "a jaxpr variable must be created only once per tracer"
    self.frame.tracers.append(tracer)
    var = self.frame.tracer_to_var[id(tracer)] = self.frame.newvar(tracer.aval)
    return var

  def getconstvar(self, c):
    """Return the constvar for constant `c`, creating it on first use.

    Constants are keyed by `id(c)`, so the same object maps to one constvar.
    """
    var = self.frame.constid_to_var.get(id(c))
    if var is None:
      var = self.frame.constid_to_var[id(c)] = self.frame.newvar(get_aval(c))
    return var

  def instantiate_const(self, val):
    """Return `val` if it is already a tracer of this trace, else a constant."""
    if (isinstance(val, Tracer) and val._trace.main is self.main
        and val._trace.sublevel == self.sublevel):
      return val
    else:
      return self.new_const(val)

  def process_primitive(self, primitive, tracers, params):
    """Record one equation applying `primitive`; return its output tracer(s)."""
    avals = [t.aval for t in tracers]
    out_avals = primitive.abstract_eval(*avals, **params)
    out_avals = [out_avals] if not primitive.multiple_results else out_avals
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    outvars = map(self.makevar, out_tracers)
    eqn = new_jaxpr_eqn(invars, outvars, primitive, params, source_info)
    self.frame.eqns.append(eqn)
    return out_tracers if primitive.multiple_results else out_tracers.pop()

  def process_call(self, call_primitive, f, tracers, params):
    """Trace `f` to a sub-jaxpr and record a `call_primitive` equation."""
    in_avals = [t.aval for t in tracers]
    jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(f, self.main, in_avals)
    # A body with no equations is evaluated directly instead of staging a call.
    if not jaxpr.eqns:
      return core.eval_jaxpr(jaxpr, consts, *tracers)
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    new_params = dict(params, call_jaxpr=convert_constvars_jaxpr(jaxpr))
    update_params = call_param_updaters.get(call_primitive)
    if update_params:
      new_params = update_params(new_params, [True] * len(tracers))
    # The sub-jaxpr's constants are passed as leading arguments of the call.
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive,
                        new_params, source_info)
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_call(self, call_primitive, out_tracers, params):
    assert False  # unreachable

  def process_map(self, map_primitive, f, tracers, params):
    """Trace `f` on per-axis-element avals and record a `map_primitive` eqn."""
    in_avals = [t.aval for t in tracers]
    axis_name, axis_size = params['axis_name'], params['axis_size']
    reduced_in_avals = [core.mapped_aval(axis_size, in_axis, a)
                        if in_axis is not None else a
                        for a, in_axis in zip(in_avals, params['in_axes'])]
    with core.extend_axis_env(axis_name, axis_size, None):  # type: ignore
      jaxpr, reduced_out_avals, consts = trace_to_subjaxpr_dynamic(
          f, self.main, reduced_in_avals)
      out_axes = params['out_axes_thunk']()
    out_avals = [core.unmapped_aval(axis_size, out_axis, a)
                 if out_axis is not None else a
                 for a, out_axis in zip(reduced_out_avals, out_axes)]
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    # Constants become leading, unmapped (in_axis None) arguments.
    new_in_axes = (None,) * len(consts) + params['in_axes']
    new_params = dict(params, in_axes=new_in_axes, out_axes=out_axes,
                      call_jaxpr=convert_constvars_jaxpr(jaxpr))
    # The thunk has been forced above; the equation stores the concrete axes.
    del new_params['out_axes_thunk']
    update_params = call_param_updaters.get(map_primitive)
    if update_params:
      new_params = update_params(new_params, [True] * len(tracers))
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, map_primitive,
                        new_params, source_info)
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_map(self, map_primitive, out_tracers, params):
    assert False  # unreachable

  def process_custom_jvp_call(self, prim, fun, jvp, tracers):
    """Record `prim.initial_style` with `fun`'s jaxpr and a lazy JVP jaxpr."""
    in_avals = [t.aval for t in tracers]
    fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
    closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
    # The JVP rule is traced only if the thunk is called; `[::2]` keeps the
    # first and third elements (jaxpr and consts) of the tracing result.
    jvp_jaxpr_thunk = _memoize(
        lambda: trace_to_subjaxpr_dynamic(jvp, self.main, 2 * in_avals)[::2])
    # Give the output tracers source info, as the other process_* methods do.
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
                        dict(fun_jaxpr=closed_fun_jaxpr,
                             jvp_jaxpr_thunk=jvp_jaxpr_thunk,
                             num_consts=len(consts)),
                        source_info)
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_custom_jvp_call(self, out_tracers, params):
    assert False  # unreachable

  def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees):
    """Record `prim.initial_style` with `fun`'s jaxpr and a lazy fwd jaxpr."""
    in_avals = [t.aval for t in tracers]
    fun_jaxpr, out_avals, consts = trace_to_subjaxpr_dynamic(fun, self.main, in_avals)
    closed_fun_jaxpr = core.ClosedJaxpr(convert_constvars_jaxpr(fun_jaxpr), ())
    # As for the JVP case, the forward rule is traced lazily and memoized.
    fwd_jaxpr_thunk = _memoize(
        lambda: trace_to_subjaxpr_dynamic(fwd, self.main, in_avals)[::2])
    # Give the output tracers source info, as the other process_* methods do.
    source_info = source_info_util.current()
    out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
    invars = map(self.getvar, tracers)
    constvars = map(self.getvar, map(self.instantiate_const, consts))
    outvars = map(self.makevar, out_tracers)
    eqn = new_jaxpr_eqn([*constvars, *invars], outvars, prim.initial_style,
                        dict(fun_jaxpr=closed_fun_jaxpr,
                             fwd_jaxpr_thunk=fwd_jaxpr_thunk,
                             num_consts=len(consts),
                             bwd=bwd, out_trees=out_trees),
                        source_info)
    self.frame.eqns.append(eqn)
    return out_tracers

  def post_process_custom_vjp_call(self, out_tracers, params):
    assert False  # unreachable
@contextlib.contextmanager
def extend_jaxpr_stack(main, frame):
  """Context manager that pushes `frame` onto `main.jaxpr_stack` for its body.

  On exit (normal or exceptional) the frame is popped again, after checking
  that it is still the top of the stack.
  """
  outer_stack = main.jaxpr_stack
  main.jaxpr_stack = outer_stack + (frame,)
  try:
    yield
  finally:
    current_stack = main.jaxpr_stack
    assert frame is current_stack[-1]
    main.jaxpr_stack = current_stack[:-1]
def fun_sourceinfo(fun):
  """Describe where `fun` was defined, as `"<name> at <filename>:<lineno>"`.

  Args:
    fun: a callable. `functools.partial` wrappers, however deeply nested, are
      stripped so that the underlying function is the one described.

  Returns:
    The description string, or `"<unknown>"` when the callable has no
    `__code__` or `__name__` attribute (e.g. builtins or arbitrary objects).
  """
  # A partial may wrap another partial that Python did not flatten (for
  # example one carrying instance attributes), so peel every layer.
  while isinstance(fun, functools.partial):
    fun = fun.func
  try:
    filename = fun.__code__.co_filename
    lineno = fun.__code__.co_firstlineno
    return f"{fun.__name__} at {filename}:{lineno}"
  except AttributeError:
    return "<unknown>"
This computation depends on the known 1269 intermediate value `ka`, which will be computed statically. Currently, such 1270 constants are either embedded in the Jaxpr if they are scalars, or passed as a 1271 constvar to `jaxpr`, and then the value of the actual constant will be in 1272 `consts`: 1273 1274 When `instantiate=False` we get:: 1275 1276 jaxpr = 1277 { lambda ka ; ki ui. 1278 let c = add ui ka 1279 in (*, c) } # known outputs are `*` 1280 out_pvals = [PartialVal.known(6), PartialVal.unknown(ShapedArray)] 1281 consts = [3] # the constant for `ka` 1282 1283 When `instantiate=True` we get:: 1284 1285 jaxpr = 1286 { lambda ka kb ; ki ui. 1287 let c = add ui ka 1288 in (kb, c) } # known output are explicit 1289 out_pvals = [PartialVal.unknown(ConcreteArray(6)), PartialVal.unknown(ShapedArray)] 1290 consts = [3, 6] # values for `ka` and `kb` constvars 1291 """ 1292 trace_type = trace_type or (StagingJaxprTrace if stage_out else JaxprTrace) 1293 with core.new_main(trace_type, bottom=bottom) as main: 1294 fun = trace_to_subjaxpr(fun, main, instantiate) 1295 jaxpr, (out_pvals, consts, env) = fun.call_wrapped(pvals) 1296 assert not env 1297 del main 1298 1299 return jaxpr, out_pvals, consts 1300 1301 def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool], 1302 instantiate: Union[bool, Sequence[bool]], 1303 trace_type: Optional[Type[core.Trace]] 1304 ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]: 1305 f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) 1306 1307 cell = [] 1308 def fun(*vals): 1309 pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val) 1310 for aval, val, uk in zip(jaxpr.in_avals, vals, unknowns)] 1311 jaxpr_2, out_pvals_2, consts_2 = trace_to_jaxpr(f, pvals, instantiate=instantiate, 1312 trace_type=trace_type) 1313 out_pvs_2, out_consts_2 = unzip2(out_pvals_2) 1314 cell.append((out_pvs_2, jaxpr_2, len(consts_2))) 1315 return out_consts_2 + consts_2 1316 1317 # The abstract_unit here doesn't really matter, because 
trace_to_jaxpr completely ignores 1318 # the avals, and it will never actually reach any primitives, because the `fun` above will 1319 # execute the jaxpr with the right avals (it reconstructs `pvals` inside). 1320 pvals = [PartialVal.unknown(abstract_unit) if uk else PartialVal.unknown(aval) 1321 for aval, uk in zip(jaxpr.in_avals, unknowns)] 1322 jaxpr_1, out_pvals, consts_1 = trace_to_jaxpr(lu.wrap_init(fun), pvals, instantiate=True) 1323 (out_pvs_2, jaxpr_2, num_res), = cell 1324 assert len(jaxpr_2.constvars) == num_res 1325 1326 # jaxpr :: a -> b 1327 # jaxpr_1 :: a1 -> [b1, res] 1328 # jaxpr_2 :: res | a2 -> b2 1329 # jaxpr_2 :: [a2, res] -> b2 1330 jaxpr_2 = convert_constvars_jaxpr(jaxpr_2) 1331 jaxpr_2.invars = jaxpr_2.invars[num_res:] + jaxpr_2.invars[:num_res] 1332 for var, unknown in zip(jaxpr_2.invars[:len(unknowns)], unknowns): 1333 if not unknown: 1334 var.aval = abstract_unit 1335 1336 uk_out = [pv is not None for pv in out_pvs_2] 1337 1338 return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out 1339 1340 def process_custom_jvp_call(self, prim, fun, jvp, tracers): 1341 # See comment at top of `JaxprTrace`. This method should be reachable 1342 # only when we stage out, and in that case we drop the custom differentiation 1343 # rules, because we do not need them. 1344 if not config.omnistaging_enabled: 1345 assert self.main.trace_type is StagingJaxprTrace 1346 return fun.call_wrapped(*tracers) 1347 JaxprTrace.process_custom_jvp_call = process_custom_jvp_call 1348 1349 def process_custom_vjp_call(self, prim, fun, fwd, bwd, tracers, out_trees): 1350 # See comment in the above process_custom_jvp_call method. 1351 if not config.omnistaging_enabled: 1352 assert self.main.trace_type is StagingJaxprTrace 1353 return fun.call_wrapped(*tracers) 1354 JaxprTrace.process_custom_vjp_call = process_custom_vjp_call 1355 1356 staged_out_calls = set() 1357 1358 class StagingJaxprTrace(JaxprTrace): pass 1359