# Copyright 2018 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Forward-mode (JVP) tracing and jaxpr transposition.

This module defines:

* `JVPTrace` / `JVPTracer`, which pair each primal value with a tangent and
  dispatch primitives to the rules registered in `primitive_jvps`;
* `backward_pass`, which walks a jaxpr's equations in reverse and applies the
  rules registered in `primitive_transposes`;
* helpers (`deflinear`, `defjvp`, `defbilinear`, ...) that populate those two
  rule tables;
* `linearize` / `vjp`, which combine the JVP transformation with partial
  evaluation and `backward_pass`.
"""


import functools
import itertools as it
from typing import Any, Callable, Dict, Set, List

from . import partial_eval as pe
from ..config import config
from .. import core
from ..dtypes import dtype, float0
from ..core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
                    raise_to_shaped)
from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
                       zeros_like_p, Zero)
from .._src.util import (unzip2, safe_map, safe_zip, partial, split_list,
                         wrap_name, as_hashable_function)
from ..tree_util import register_pytree_node
from .. import linear_util as lu
from ..api_util import flatten_fun, flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten, Partial
from jax._src import source_info_util

# NOTE: throughout this module `zip` and `map` are the project's `safe_zip` /
# `safe_map`, not the builtins. Several call sites below invoke `map(...)` as a
# bare statement purely for its side effects.
zip = safe_zip
map = safe_map
# Identity function: returns its argument unchanged.
def identity(x): return x


def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True) -> Any:
  """Apply the JVP transformations of this module to `fun`.

  Args:
    fun: the wrapped function to differentiate.
    has_aux: when true, `fun` is wrapped with `jvp_subtrace_aux` (instead of
      `jvp_subtrace`) and the auxiliary-output handle produced by that
      transformation is returned as well.
    instantiate: forwarded to `jvpfun`; a bool, or a sequence with one entry
      per output tangent, selecting which symbolic zeros get instantiated.

  Returns:
    The transformed function, or `(transformed_function, aux)` if `has_aux`.
  """
  if not has_aux:
    return jvpfun(jvp_subtrace(fun), instantiate)
  else:
    fun, aux = jvp_subtrace_aux(fun)
    return jvpfun(fun, instantiate), aux


@lu.transformation
def jvpfun(instantiate, primals, tangents):
  """Outer JVP transformation: opens a new `JVPTrace` main and post-processes tangents.

  The wrapped function is called with `(main, primals, tangents)` and must
  produce `(out_primals, out_tangents)`.
  """
  # Input tangents with dtype float0 (that are not already `Zero`) are replaced
  # by symbolic zeros.
  tangents = [Zero.from_value(t) if not isinstance(t, Zero)
              and dtype(t) is float0 else t for t in tangents]
  with core.new_main(JVPTrace) as main:
    out_primals, out_tangents = yield (main, primals, tangents), {}
    del main
  # A single bool applies to every output tangent.
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_tangents)
  out_tangents = [instantiate_zeros(t) if inst else t for t, inst
                  in zip(out_tangents, instantiate)]
  yield out_primals, out_tangents

@lu.transformation
def jvp_subtrace(main, primals, tangents):
  """Inner JVP transformation: runs the wrapped function on `JVPTracer`s.

  Yields the `(primals, tangents)` pair extracted from the outputs after they
  have been raised to this trace.
  """
  trace = JVPTrace(main, core.cur_sublevel())
  # Any incoming tracer must belong to a trace at a strictly lower level.
  for x in list(primals) + list(tangents):
    if isinstance(x, Tracer):
      assert x._trace.level < trace.level
  # Inputs whose tangent is a symbolic `Zero` are passed through unwrapped.
  in_tracers = [JVPTracer(trace, x, t) if type(t) is not Zero else x
                for x, t in zip(primals, tangents)]
  ans = yield in_tracers, {}
  out_tracers = map(trace.full_raise, ans)
  yield unzip2([(out_tracer.primal, out_tracer.tangent)
                for out_tracer in out_tracers])

@lu.transformation_with_aux
def jvp_subtrace_aux(main, primals, tangents):
  """Like `jvp_subtrace`, for functions returning `(ans, aux)`.

  Unlike `jvp_subtrace`, every input is wrapped in a `JVPTracer`, including
  those with `Zero` tangents. Auxiliary outputs are not differentiated: aux
  values that are `JVPTracer`s of this trace's level are replaced by their
  fully lowered primal, everything else is passed through.
  """
  trace = JVPTrace(main, core.cur_sublevel())
  for x in list(primals) + list(tangents):
    if isinstance(x, Tracer):
      assert x._trace.level < trace.level
  ans, aux = yield map(partial(JVPTracer, trace), primals, tangents), {}
  ans_tracers = map(trace.full_raise, ans)
  out_primals, out_tangents = unzip2((t.primal, t.tangent) for t in ans_tracers)
  aux_primals = [core.full_lower(x.primal)
                 if isinstance(x, JVPTracer) and x._trace.level == trace.level
                 else x for x in aux]
  yield (out_primals, out_tangents), aux_primals

def linearize(traceable, *primals, **kwargs):
  """Trace the JVP of `traceable` at `primals` with the tangent inputs left unknown.

  Args:
    traceable: wrapped function to linearize.
    *primals: primal inputs (passed to partial evaluation as known values).
    **kwargs: only `has_aux` (default False) is read; other keys are ignored.

  Returns:
    `(out_primals_consts, out_tangents_pvals, jaxpr, consts)`, plus the
    auxiliary output as a fifth element when `has_aux` is true. The returned
    jaxpr has had its primal inputs and primal outputs stripped, so it maps
    tangent inputs to tangent outputs only.
  """
  has_aux = kwargs.pop('has_aux', False)
  if not has_aux:
    jvpfun = jvp(traceable)
  else:
    jvpfun, aux = jvp(traceable, has_aux=True)

  # Primals are known; tangents are unknown, with avals derived from the primals.
  in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
              + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace())
                      for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvpfun_flat, out_tree = flatten_fun(jvpfun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvpfun_flat, in_pvals)
  out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
  assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
  _, out_primals_consts = unzip2(out_primals_pvals)
  # Mutates `jaxpr` in place: drop the binders that correspond to primal
  # inputs and primal outputs.
  jaxpr.invars = jaxpr.invars[len(primals):]
  jaxpr.outvars = jaxpr.outvars[len(out_primals_pvals):]
  if not has_aux:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts
  else:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()

def vjp(traceable, primals, has_aux=False):
  """Compute primal outputs of `traceable` and a function that pulls back cotangents.

  Args:
    traceable: wrapped function to differentiate.
    primals: sequence of primal inputs.
    has_aux: whether `traceable` also returns auxiliary data.

  Returns:
    `(out_primals, vjp_)`, or `(out_primals, vjp_, aux)` when `has_aux`.
    `vjp_` takes output cotangents and returns input cotangents with symbolic
    zeros instantiated.
  """
  if not has_aux:
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)
  else:
    out_primals, pvals, jaxpr, consts, aux = linearize(traceable, *primals, has_aux=True)

  def unbound_vjp(pvals, jaxpr, consts, *cts):
    cts = tuple(map(ignore_consts, cts, pvals))
    # Every input of the linearized jaxpr is a tangent, hence has no primal value.
    dummy_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
    arg_cts = backward_pass(jaxpr, consts, dummy_args, cts)
    return map(instantiate_zeros, arg_cts)

  # Ensure that vjp_ is a PyTree so that we can pass it from the forward to the backward
  # pass in a custom VJP.
  vjp_ = Partial(partial(unbound_vjp, pvals, jaxpr), consts)
  if not has_aux:
    return out_primals, vjp_
  else:
    return out_primals, vjp_, aux

def ignore_consts(ct, pval):
  """Return `ct` if `pval` has an abstract value, `core.unit` if its aval is None.

  Raises:
    TypeError: if the aval component is neither an `AbstractValue` nor None.
  """
  aval, const = pval
  if isinstance(aval, core.AbstractValue):
    return ct
  elif aval is None:
    return core.unit
  else:
    raise TypeError(aval)

def unpair_pval(pval):
  """Split an `(aval, const)` pair whose components are 2-tuples into two pairs.

  When `aval` is None both resulting pairs get a None aval.
  """
  aval, const = pval
  const_1, const_2 = const
  if aval is None:
    return (None, const_1), (None, const_2)
  else:
    aval_1, aval_2 = aval
    return (aval_1, const_1), (aval_2, const_2)

def replace_float0s(primal, tangent):
  """Replace a float0 tangent by `core.zeros_like_float0` using the primal's dtype."""
  if dtype(tangent) is float0:
    return core.zeros_like_float0(tangent, dtype(primal))
  else:
    return tangent

def recast_to_float0(primal, tangent):
  """Return a symbolic `Zero` when the primal's tangent dtype is float0, else `tangent`."""
  if core.primal_dtype_to_tangent_dtype(dtype(primal)) == float0:
    return Zero(get_aval(primal).at_least_vspace())
  else:
    return tangent

# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
  """Transpose `jaxpr`: propagate output cotangents back to its inputs.

  Args:
    jaxpr: the jaxpr to transpose.
    consts: values for `jaxpr.constvars`.
    primals_in: one entry per `jaxpr.invars`; entries that are
      `UndefinedPrimal` are not recorded as primal values.
    cotangents_in: one cotangent per `jaxpr.outvars`.

  Returns:
    One cotangent per `jaxpr.invars`; inputs that received no contribution
    get a symbolic `Zero` of the variable's aval.
  """
  # Fast path: nothing to propagate.
  if all(type(ct) is Zero for ct in cotangents_in):
    return map(lambda v: Zero(v.aval), jaxpr.invars)

  def write_cotangent(prim, v, ct):
    # Accumulate `ct` into the cotangent of `v`. `prim` is only used in
    # assertion messages.
    # assert v not in primal_env
    assert ct is not Zero, (prim, v.aval)  # check for an old harmless type error
    if ct is None or type(v) is Literal:
      return
    if type(ct) is Zero:
      # FIXME: This triggers a lot of failures!
      # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
      return
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
    if not core.skip_checks:
      ct_aval = core.get_aval(ct_env[v])
      joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type()
      assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval)

  def read_cotangent(v):
    # Missing cotangents read as a symbolic zero.
    return ct_env.get(v, Zero(v.aval))

  def read_primal(v):
    # Literals carry their value; unknown variables read as UndefinedPrimal.
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if not is_undefined_primal(val):
      primal_env[v] = val

  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  # FIXME: invars can contain both primal and tangent values, and this line
  #        forces primal_in to contain UndefinedPrimals for tangent values!
  map(write_primal, jaxpr.invars, primals_in)

  # Find the last use of each cotangent so that they can be removed
  # as soon as possible.
  drop_cts: List[Set[Any]] = []
  seen_vars: Set[Any] = set(jaxpr.invars)
  for eqn in jaxpr.eqns:
    read_set = set(eqn.outvars)  # NOTE: eqn is not transposed yet!
    drop_cts.append(read_set - seen_vars)
    seen_vars |= read_set

  ct_env: Dict[Any, Any] = {}
  map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
  # Visit equations in reverse, pairing each with the set of cotangents that
  # can be dropped once it has been transposed.
  for eqn, to_drop in zip(jaxpr.eqns[::-1], drop_cts[::-1]):
    # FIXME: Some invars correspond to tangents
    invals = map(read_primal, eqn.invars)
    if eqn.primitive.multiple_results:
      cts_in = map(read_cotangent, eqn.outvars)
    else:
      cts_in, = map(read_cotangent, eqn.outvars)
    with source_info_util.user_context(eqn.source_info):
      if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
        # Call-like primitives use a different transpose-rule signature that
        # receives the embedded jaxpr and the output avals.
        cts_in_avals = [v.aval for v in eqn.outvars]
        call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
        cts_out = get_primitive_transpose(eqn.primitive)(
            params, call_jaxpr, invals, cts_in, cts_in_avals)
      else:
        cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
                                                         **eqn.params)
    # A rule may return the `Zero` class itself to mean "all cotangents are zero".
    cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
    # FIXME: Some invars correspond to primals!
    map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
    for var in to_drop:
      ct_env.pop(var, None)  # NB: Constant cotangents might be missing

  cotangents_out = map(read_cotangent, jaxpr.invars)
  return cotangents_out

class UndefinedPrimal:
  """Placeholder for an input with no primal value; carries only its `aval`."""
  __slots__ = ['aval']
  def __init__(self, aval):
    self.aval = aval
  def __repr__(self):
    return 'UndefinedPrimal({})'.format(self.aval)

def is_undefined_primal(x):
  """Return True iff `x` is exactly an `UndefinedPrimal` instance."""
  return type(x) is UndefinedPrimal

# Register UndefinedPrimal as a pytree node with no children; the aval is
# carried as auxiliary data.
register_pytree_node(UndefinedPrimal,
                     lambda z: ((), z.aval),
                     lambda aval, _: UndefinedPrimal(aval))

def get_primitive_transpose(p):
  """Look up the transpose rule registered for primitive `p`.

  Raises:
    NotImplementedError: if `p` has no entry in `primitive_transposes`.
  """
  try:
    return primitive_transposes[p]
  except KeyError as err:
    raise NotImplementedError(
        "Transpose rule (for reverse-mode differentiation) for '{}' "
        "not implemented".format(p)) from err

@lu.transformation_with_aux
def nonzero_tangent_outputs(*args, **kwargs):
  """Pass results through; record as aux which output tangents are not `Zero`."""
  results = (_, tangents_out) = yield args, kwargs
  yield results, [type(r) is not Zero for r in tangents_out]


class JVPTrace(Trace):
  """Trace that propagates (primal, tangent) pairs via the rules in `primitive_jvps`."""

  def pure(self, val):
    """Wrap `val` as a `JVPTracer` with a symbolic zero tangent."""
    tangent_zero = Zero(get_aval(val).at_least_vspace())
    return JVPTracer(self, val, tangent_zero)

  def lift(self, val):
    """Wrap `val` as a `JVPTracer` with a symbolic zero tangent (same as `pure`)."""
    tangent_zero = Zero(get_aval(val).at_least_vspace())
    return JVPTracer(self, val, tangent_zero)

  def sublift(self, val):
    """Re-wrap the primal and tangent of tracer `val` in a tracer of this trace."""
    return JVPTracer(self, val.primal, val.tangent)

  def process_primitive(self, primitive, tracers, params):
    """Apply the registered JVP rule for `primitive`.

    Raises:
      NotImplementedError: if no rule is registered in `primitive_jvps`.
    """
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    jvp = primitive_jvps.get(primitive)
    if not jvp:
      msg = f"Differentiation rule for '{primitive}' not implemented"
      raise NotImplementedError(msg)
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
    if primitive.multiple_results:
      return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
    else:
      return JVPTracer(self, primal_out, tangent_out)

  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
    """Differentiate through a call (or map) primitive by binding it on the JVP of `f`."""
    assert call_primitive.multiple_results
    primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
    # NOTE(review): presumably `Zero` tangents flatten to no leaves, so that
    # `nonzero_tangents` holds only the non-zero ones — confirm in ad_util.
    nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
    nz_tangents = [type(t) is not Zero for t in tangents]
    params = dict(params, name=wrap_name(params['name'], 'jvp'))
    f_jvp = jvp_subtrace(f, self.main)
    if isinstance(call_primitive, core.MapPrimitive):
      # For map primitives the axis specs must be extended to cover the extra
      # (non-zero) tangent inputs and outputs.
      in_axes = params['in_axes']
      tangent_in_axes = [ax for ax, nz in zip(in_axes, nz_tangents) if nz]
      out_axes_thunk = params['out_axes_thunk']
      f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp)
      # The new thunk depends deterministically on the old thunk and the wrapped function.
      # Any caching already has to include the wrapped function as part of the key, so we
      # only use the previous thunk for equality checks.
      # NOTE: This assumes that the output tangents being zero is a deterministic
      #       function of which input tangents were zero.
      @as_hashable_function(closure=(tuple(nz_tangents), out_axes_thunk))
      def new_out_axes_thunk():
        out_axes = out_axes_thunk()
        return (*out_axes, *(ax for ax, nz in zip(out_axes, nz_tangents_out()) if nz))
      params = dict(params,
                    in_axes=(*in_axes, *tangent_in_axes),
                    out_axes_thunk=new_out_axes_thunk)
    f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def)
    # Primitive-specific hook for adjusting params given which tangents are non-zero.
    update_params = call_param_updaters.get(call_primitive)
    new_params = update_params(params, nz_tangents) if update_params else params
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
    primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

  def post_process_call(self, call_primitive, out_tracers, params):
    """Flatten outputs of a call into `(out, todo)`.

    `todo` rebuilds `JVPTracer`s from the flat outputs; for map primitives it is
    a pair `(todo, out_axes_transform)`, where the second element extends the
    output axes with entries for the non-zero tangents.
    """
    primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
    out, treedef = tree_flatten((primals, tangents))
    tangents_nz = [type(t) is not Zero for t in tangents]
    del primals, tangents
    main = self.main
    def todo(x):
      primals, tangents = tree_unflatten(treedef, x)
      trace = JVPTrace(main, core.cur_sublevel())
      return map(partial(JVPTracer, trace), primals, tangents)
    if call_primitive.map_primitive:
      def out_axes_transform(out_axes):
        return (*out_axes, *(ax for ax, nz in zip(out_axes, tangents_nz) if nz))
      todo = (todo, out_axes_transform)
    return out, todo

  # The only difference between process_map and process_call is that
  # the `in_axes` and `out_axes_thunk` params must be updated;
  # that's handled in process_call.
  process_map = process_call
  post_process_map = post_process_call

  def process_custom_jvp_call(self, _, __, f_jvp, tracers):
    """Evaluate a custom JVP rule `f_jvp` on the tracers' primals and tangents."""
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    primals_in = map(core.full_lower, primals_in)
    tangents_in = map(instantiate_zeros, tangents_in)
    # Cast float0 to zeros with the primal dtype because custom jvp rules don't
    # currently handle float0s
    tangents_in = map(replace_float0s, primals_in, tangents_in)
    outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in))
    # The rule returns primal outputs followed by an equal number of tangent outputs.
    primals_out, tangents_out = split_list(outs, [len(outs) // 2])
    tangents_out = map(recast_to_float0, primals_out, tangents_out)
    return map(partial(JVPTracer, self), primals_out, tangents_out)

  def post_process_custom_jvp_call(self, out_tracers, params):
    """Always raises `CustomJVPException`."""
    raise CustomJVPException()

  def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
    """Run `fwd` on the primals and bind `custom_lin_p` to stand in for the tangents."""
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    tangents_in = map(instantiate_zeros, tangents_in)
    res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
    out_tree, res_tree = out_trees()
    # `fwd` returns the residuals first, followed by the primal outputs.
    res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
    avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
    tangents_out = custom_lin_p.bind(
        *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
        avals_out=avals_out)
    tangents_out = map(recast_to_float0, primals_out, tangents_out)
    return map(partial(JVPTracer, self), primals_out, tangents_out)

  def post_process_custom_vjp_call(self, out_tracers, params):
    """Always raises `CustomVJPException`."""
    raise CustomVJPException()

  def join(self, xt, yt):
    """Make two tangents agree on zero-ness by instantiating a `Zero` if needed."""
    xz, yz = type(xt) is Zero, type(yt) is Zero
    if xz == yz:
      return xt, yt
    elif yz and not xz:
      return xt, zeros_like_jaxval(xt)
    elif xz and not yz:
      return zeros_like_jaxval(yt), yt
    else:
      # The three branches above cover every combination of (xz, yz).
      raise TypeError((xt, yt))


class JVPTracer(Tracer):
  """Tracer holding a `primal` value and its `tangent` (possibly a symbolic `Zero`)."""
  __slots__ = ['primal', 'tangent']

  def __init__(self, trace, primal, tangent):
    if not core.skip_checks:
      _primal_tangent_shapes_match(primal, tangent)
    self._trace = trace
    self.primal = primal
    self.tangent = tangent

  @property
  def aval(self):
    """Abstract value of the primal."""
    # TODO(dougalm): add epsilon ball
    return get_aval(self.primal)

  def full_lower(self):
    """Return the fully lowered primal if the tangent is `Zero`, else this tracer."""
    if type(self.tangent) is Zero:
      return core.full_lower(self.primal)
    else:
      return self

def _primal_tangent_shapes_match(primal, tangent):
  """Assert that a non-`Zero` tangent has the primal's shape and expected tangent dtype."""
  if type(tangent) is not Zero:
    primal_aval = raise_to_shaped(get_aval(primal), weak_type=False)
    tangent_aval = raise_to_shaped(get_aval(tangent), weak_type=False)
    assert primal_aval.shape == tangent_aval.shape, (primal_aval.shape, tangent_aval.shape)
    expected_tangent_dtype = core.primal_dtype_to_tangent_dtype(primal_aval.dtype)
    assert expected_tangent_dtype == tangent_aval.dtype, (expected_tangent_dtype, tangent_aval.dtype)

# Optional per-primitive hooks: `call_param_updaters` is consulted in
# `JVPTrace.process_call`, `call_transpose_param_updaters` in `call_transpose`
# and `map_transpose`.
call_param_updaters: Dict[core.Primitive, Callable] = {}
call_transpose_param_updaters: Dict[core.Primitive, Callable] = {}


# -------------------- Primitives --------------------

# JVP rules, called as `rule(primals, tangents, **params)`.
primitive_jvps : Dict[core.Primitive, Callable] = {}

# Transpose rules, looked up through `get_primitive_transpose`.
primitive_transposes: Dict[core.Primitive, Callable] = {}


def deflinear(primitive, transpose_rule):
  """Register `primitive` as linear; `transpose_rule` is called as `rule(ct, **params)`."""
  primitive_jvps[primitive] = partial(linear_jvp, primitive)
  primitive_transposes[primitive] = partial(linear_transpose, transpose_rule)

def linear_jvp(primitive, primals, tangents, **params):
  """JVP of a linear primitive: bind the primitive itself to the tangents.

  If every input tangent is `Zero`, the output tangent is a symbolic zero.
  """
  val_out = primitive.bind(*primals, **params)
  if all(type(tangent) is Zero for tangent in tangents):
    return val_out, Zero.from_value(val_out)
  else:
    tangents = map(instantiate_zeros, tangents)
    return val_out, primitive.bind(*tangents, **params)

def linear_transpose(transpose_rule, cotangent, *args, **kwargs):
  """Apply `transpose_rule(cotangent, **kwargs)`; `args` are ignored.

  Returns the `Zero` class itself when the cotangent is a symbolic zero.
  """
  return Zero if type(cotangent) is Zero else transpose_rule(cotangent, **kwargs)


def deflinear2(primitive, transpose_rule):
  """Like `deflinear`, but `transpose_rule` is called as `rule(ct, *args, **params)`."""
  primitive_jvps[primitive] = partial(linear_jvp, primitive)
  primitive_transposes[primitive] = partial(linear_transpose2, transpose_rule)

def linear_transpose2(transpose_rule, cotangent, *args, **kwargs):
  """Apply `transpose_rule(cotangent, *args, **kwargs)` unless the cotangent is `Zero`."""
  return Zero if type(cotangent) is Zero else transpose_rule(cotangent, *args, **kwargs)


def defjvp(primitive, *jvprules):
  """Register one JVP rule per argument of a single-result `primitive`.

  Each rule is called as `rule(tangent, *primals, **params)`; a rule may be
  None to indicate that the argument contributes no tangent.
  """
  assert isinstance(primitive, Primitive)
  assert not primitive.multiple_results
  primitive_jvps[primitive] = partial(standard_jvp, jvprules, primitive)


def standard_jvp(jvprules, primitive, primals, tangents, **params):
  """Sum the contributions of the per-argument rules, skipping `Zero` tangents."""
  val_out = primitive.bind(*primals, **params)
  tangents_out = [rule(t, *primals, **params) for rule, t in zip(jvprules, tangents)
                  if rule is not None and type(t) is not Zero]
  return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_value(val_out))

def defjvp2(primitive, *jvprules):
  """Like `defjvp`, but rules are called as `rule(tangent, ans, *primals, **params)`."""
  assert isinstance(primitive, Primitive)
  assert not primitive.multiple_results
  primitive_jvps[primitive] = partial(standard_jvp2, jvprules, primitive)

def standard_jvp2(jvprules, primitive, primals, tangents, **params):
  """Like `standard_jvp`, additionally passing the primal output to each rule."""
  val_out = primitive.bind(*primals, **params)
  tangents_out = (rule(t, val_out, *primals, **params) for rule, t in zip(jvprules, tangents)
                  if rule is not None and type(t) is not Zero)
  tangents_out = list(tangents_out)
  return val_out, functools.reduce(add_tangents, tangents_out, Zero.from_value(val_out))

def add_tangents(x, y):
  """Add two tangents, treating a symbolic `Zero` as the additive identity."""
  if type(x) is Zero:
    return y
  elif type(y) is Zero:
    return x
  else:
    return add_jaxvals(x, y)


def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
  """Register JVP and transpose rules for a two-argument bilinear primitive.

  The JVP w.r.t. each argument binds `prim` with that argument replaced by its
  tangent (after applying `bcast` to it and the other operand).
  """
  assert isinstance(prim, Primitive)
  lhs_jvp = lambda g, x, y, **kwargs: prim.bind(bcast(g, y), y, **kwargs)
  rhs_jvp = lambda g, x, y, **kwargs: prim.bind(x, bcast(g, x), **kwargs)
  defjvp(prim, lhs_jvp, rhs_jvp)
  primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)
# `defbilinear` is `defbilinear_broadcasting` with a `bcast` that returns the
# tangent unchanged.
defbilinear: Callable = partial(defbilinear_broadcasting, lambda g, x: g)

def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
  """Transpose a bilinear primitive with respect to whichever operand is undefined.

  Exactly one of `x`, `y` must be an `UndefinedPrimal`. Returns a pair with the
  cotangent in the undefined operand's position and None in the other, or the
  `Zero` class when the result is zero.
  """
  assert is_undefined_primal(x) ^ is_undefined_primal(y)
  if type(cotangent) is Zero:
    return Zero
  if is_undefined_primal(x):
    out = lhs_rule(cotangent, y, **kwargs)
    return Zero if out is Zero else (out, None)
  else:
    out = rhs_rule(cotangent, x, **kwargs)
    return Zero if out is Zero else (None, out)


def defjvp_zero(primitive):
  """Register a JVP rule for `primitive` whose output tangent is always zero."""
  assert isinstance(primitive, Primitive)
  primitive_jvps[primitive] = partial(zero_jvp, primitive)

def zero_jvp(primitive, primals, tangents, **params):
  """Bind `primitive` on the primals and return a symbolic zero tangent."""
  r = primitive.bind(*primals, **params)
  return r, Zero.from_value(r)


deflinear2(zeros_like_p, lambda t, _: [Zero.from_value(t)])
deflinear2(add_jaxvals_p, lambda t, *args: (t, t))

def instantiate_zeros(tangent):
  """Turn a symbolic `Zero` into a concrete zero value; pass other tangents through."""
  if type(tangent) is Zero:
    # NOTE(review): a Zero whose `aval` is itself a Tracer is returned as that
    # tracer; the circumstances producing such a Zero are not visible here.
    if isinstance(tangent.aval, Tracer):
      return tangent.aval
    return zeros_like_aval(tangent.aval)
  else:
    return tangent

# This function seems similar to instantiate_zeros, but it is sometimes used
# to instantiate zero abstract units with a different aval
def instantiate_zeros_aval(aval, tangent):
  """Instantiate a symbolic `Zero` as zeros of the given `aval`."""
  if type(tangent) is Zero:
    assert type(tangent.aval) is core.AbstractUnit or tangent.aval == aval
    return zeros_like_aval(aval)
  else:
    return tangent

@lu.transformation_with_aux
def traceable(num_primals, in_tree_def, *primals_and_tangents):
  """Adapt a `(primals, tangents)` function to flat arguments and flat outputs.

  The first `num_primals` arguments are primals; the rest are unflattened into
  tangents with `in_tree_def`. The output tree definition is stored as aux.
  """
  new_primals  = primals_and_tangents[:num_primals]
  new_tangents = primals_and_tangents[num_primals:]
  new_tangents = tree_unflatten(in_tree_def, new_tangents)
  primal_out, tangent_out = yield (new_primals, new_tangents), {}
  out_flat, tree_def = tree_flatten((primal_out, tangent_out))
  yield out_flat, tree_def


def call_transpose(primitive, params, call_jaxpr, args, ct, _):
  """Transpose a call primitive by binding it on `backward_pass` of its jaxpr."""
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    new_params = update_params(new_params, map(is_undefined_primal, args),
                               [type(x) is not Zero for x in ct])
  out_flat = primitive.bind(fun, *all_args, **new_params)
  return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)


def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_avals):
  """Transpose rule for `pe.remat_call_p`."""
  # backward_pass can only transpose linear computations, but the call_jaxpr embedded in
  # remat contains primal (non-linear) equations too. Hence, we have to eliminate those
  # (in this case via partial_eval) before we call into backward_pass again.
  typed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, [])
  unknowns = map(is_undefined_primal, primals_in)
  if config.omnistaging_enabled:
    primal_jaxpr, tangent_jaxpr, out_unknowns = \
      pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True)  # type: ignore
  else:
    primal_jaxpr, tangent_jaxpr, out_unknowns = \
      pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True,
                            trace_type=None)  # type: ignore

  def do_transpose(primals_in, cotangents_in):
    # NOTE: This is passing in undefined primals in place of tangent arguments, but it
    #       should all work out, because we're only computing the primal part here.
    residuals = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)[len(cotangents_in):]
    # Now that we have a purely linear jaxpr, we can transpose it
    cotangents_out = backward_pass(tangent_jaxpr.jaxpr, (), primals_in + residuals, cotangents_in)
    # backward_pass will return cotangents computed for all invars, but some of them
    # are residuals appended by partial eval, so we need to skip those before we return.
    return cotangents_out[:len(primals_in)]

  flat_args, in_tree_def = tree_flatten((primals_in, cotangents_in))
  flat_do_transpose, out_tree = flatten_fun_nokwargs(lu.wrap_init(do_transpose), in_tree_def)
  flat_cotangents_out = pe.remat_call_p.bind(flat_do_transpose, *flat_args, **params)
  return tree_unflatten(out_tree(), flat_cotangents_out)
primitive_transposes[pe.remat_call_p] = remat_transpose

@lu.transformation_with_aux
def nonzero_outputs(*args, **kwargs):
  """Pass results through; record as aux which results are not `Zero`."""
  results = yield args, kwargs
  yield results, [type(r) is not Zero for r in results]


def map_transpose(primitive, params, call_jaxpr, args, ct, _):
  """Transpose a map primitive by binding it on `backward_pass` of its jaxpr.

  Reads `in_axes`, `out_axes`, `name` and `axis_size` from `params`. Returns a
  tuple with one cotangent per entry of `in_axes`.
  """
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
  fun, nz_arg_cts = nonzero_outputs(fun)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  # Preserve axis for primal arguments, skip tangents (represented as undefined primals).
  in_axes, out_axes = params['in_axes'], params['out_axes']
  new_in_axes = (*[axis for axis, x in zip(in_axes, args)
                   if not is_undefined_primal(x)],
                 *[axis for axis, x in zip(out_axes, ct)
                   if type(x) is not Zero])
  # The interim strategy we use below (until avals-with-names) only works
  # when all outputs are mapped.
  assert all(out_axis is not None for out_axis in out_axes), out_axes
  # NOTE: This assumes that the output cotangents being zero is a deterministic
  #       function of which input cotangents were zero.
  @as_hashable_function(closure=(in_axes, tuple(type(c) is Zero for c in ct)))
  def out_axes_thunk():
    # `axis or 0` maps a None in_axis to output axis 0; such cotangents are
    # summed over axis 0 below.
    return tuple(axis or 0 for axis, nz in zip(in_axes, nz_arg_cts()) if nz)
  new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
                    in_axes=new_in_axes, out_axes_thunk=out_axes_thunk)
  # The transposed call is bound with `out_axes_thunk` in place of `out_axes`.
  del new_params['out_axes']
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    new_params = update_params(new_params, map(is_undefined_primal, args),
                               [type(x) is not Zero for x in ct])
  out_flat = primitive.bind(fun, *all_args, **new_params)
  arg_cts = tree_unflatten(out_tree(), out_flat)

  # The freevars are being fanned out (not mapped). During transpose the
  # dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
  assert len(in_axes) == len(arg_cts)
  def unmap_zero(zero, in_axis):
    return (zero if in_axis is None else
            Zero(core.unmapped_aval(params['axis_size'], in_axis, zero.aval)))
  arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else
             arg_ct if in_axis is not None else
             arg_ct.sum(0)
             for arg_ct, in_axis in zip(arg_cts, in_axes))
  return tuple(arg_cts)


def jvp_jaxpr(jaxpr, nonzeros, instantiate):
  """Build a closed jaxpr computing the JVP of `jaxpr`.

  Args:
    jaxpr: closed jaxpr to differentiate (its `in_avals` are read).
    nonzeros: one bool per input, true where a tangent input is supplied.
    instantiate: forwarded to `jvp`.

  Returns:
    `(jvp_closed_jaxpr, out_nonzeros)`. The new jaxpr takes the primal inputs
    followed by the tangents for the inputs flagged in `nonzeros`;
    `out_nonzeros` flags which output tangents are not `Zero`.
  """
  assert len(jaxpr.in_avals) == len(nonzeros)
  f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
  tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
  avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
  jaxpr_out, avals_out, literals_out = pe.trace_to_jaxpr_dynamic(f_jvp, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()

@lu.transformation_with_aux
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
  """Adapt a `(primals, tangents)` JVP function to a flat calling convention.

  Inputs are the primals followed by only the non-zero tangents; missing
  tangents are filled in with symbolic zeros. Outputs are the primal outputs
  followed by only the non-`Zero` output tangents; the zero/non-zero mask of
  the output tangents is stored as aux.
  """
  num_primals = len(nonzeros)
  primals = list(primals_and_nztangents[:num_primals])
  nonzero_tangents = iter(primals_and_nztangents[num_primals:])
  tangents = [next(nonzero_tangents) if nz else Zero.from_value(p)
              for p, nz in zip(primals, nonzeros)]
  primals_out, tangents_out = yield (primals, tangents), {}
  out_nonzeros = [type(t) is not Zero for t in tangents_out]
  nonzero_tangents_out = [t for t in tangents_out if type(t) is not Zero]
  yield list(primals_out) + nonzero_tangents_out, out_nonzeros

def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
  """Return a closed jaxpr with invars/outvars permuted by `_perm`.

  The `*_in` / `*_out` arguments are sequences of group sizes passed to
  `_perm` for the invars and outvars respectively.
  """
  new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
  new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
  new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
                         new_invars, new_outvars, jaxpr.jaxpr.eqns)
  return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)

def _perm(primal_counts, tangent_counts, lst):
  """Regroup `lst` (all primals, then all tangents) into interleaved groups."""
  n = sum(primal_counts)
  primals, tangents = lst[:n], lst[n:]
  # NOTE(review): presumably `split_list` returns the remainder as a final
  # group, which is why the last count is omitted — confirm in util.
  primal_groups = split_list(primals, primal_counts[:-1])
  tangent_groups = split_list(tangents, tangent_counts[:-1])
  return _interleave(primal_groups, tangent_groups)

def _interleave(xs, ys):
  """Flatten `xs[0] + ys[0] + xs[1] + ys[1] + ...` into a single list."""
  assert len(xs) == len(ys)
  return [e for pair in zip(xs, ys) for l in pair for e in l]


# Primitive bound by `JVPTrace.process_custom_vjp_call` to stand in for the
# tangent outputs of a custom_vjp function. Its impl always raises.
custom_lin_p = core.Primitive('custom_lin')
custom_lin_p.def_abstract_eval(lambda *_, avals_out, **__: avals_out)
custom_lin_p.multiple_results = True

def _raise_custom_vjp_error_on_jvp(*_, **__):
  """Impl rule for `custom_lin_p`: always raises TypeError."""
  raise TypeError("can't apply forward-mode autodiff (jvp) to a custom_vjp "
                  "function.")
custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp)

def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out):
  """Transpose rule for `custom_lin_p`: call `bwd` on the residuals and cotangents."""
  res, _ = split_list(invals, [num_res])
  cts_out = map(instantiate_zeros_aval, avals_out, cts_out)
  cts_in = bwd.call_wrapped(*res, *cts_out)
  # Residual inputs receive no cotangent.
  return [None] * num_res + list(cts_in)
primitive_transposes[custom_lin_p] = _custom_lin_transpose


# TODO(mattjj): delete everything below here (deprecated custom_transforms)

def defvjp_all(prim, custom_vjp):
  """Register a user-supplied VJP for `prim` (deprecated custom_transforms API).

  `custom_vjp(*primals, **params)` must return `(primals_out, vjp_py)`.
  Registers a JVP rule for `prim` plus two helper primitives,
  `<name>_jvp` and `<name>_lin`, with a custom partial-eval rule and a
  transpose rule respectively.
  """
  # see https://github.com/google/jax/pull/636
  name = prim.name

  def fun_jvp(xs, ts, **params):
    ts = map(instantiate_zeros, ts)
    primals_and_tangents = fun_jvp_p.bind(*it.chain(xs, ts), **params)
    primals, tangents = split_list(primals_and_tangents, [len(primals_and_tangents) // 2])
    if prim.multiple_results:
      return primals, tangents
    else:
      primal, = primals
      tangent, = tangents
      return primal, tangent
  primitive_jvps[prim] = fun_jvp

  fun_jvp_p = core.Primitive('{name}_jvp'.format(name=name))
  fun_jvp_p.multiple_results = True
  def fun_jvp_partial_eval(trace, *tracers, **params):
    # The first half of the tracers are primals, the second half tangents.
    primals, tangents = split_list(tracers, [len(tracers) // 2])
    primals_out, vjp_py = custom_vjp(*primals, **params)
    if not prim.multiple_results:
      primals_out = [primals_out]
    out_avals = [raise_to_shaped(get_aval(x)) for x in primals_out]
    ct_pvals = [pe.PartialVal.unknown(aval) for aval in out_avals]
    # Stage the user's VJP function out to a jaxpr.
    if config.omnistaging_enabled:
      jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True)
    else:
      with core.initial_style_staging():  # type: ignore
        jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals,
                                          instantiate=True)
    tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
                                  num_res=len(res), out_avals=out_avals)
    return primals_out + tangents_out
  pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval

  fun_lin_p = core.Primitive('{name}_lin'.format(name=name))
  fun_lin_p.multiple_results = True
  fun_lin_p.def_abstract_eval(lambda *_, **kwargs: kwargs['out_avals'])
  def fun_lin_transpose(cts, *args, **kwargs):
    num_res, trans_jaxpr = kwargs['num_res'], kwargs['trans_jaxpr']
    res, _ = split_list(args, [num_res])
    cts = map(instantiate_zeros_aval, kwargs['out_avals'], cts)
    outs = core.eval_jaxpr(trans_jaxpr, res, *cts)
    # Residual inputs receive no cotangent.
    return [None] * num_res + outs
  primitive_transposes[fun_lin_p] = fun_lin_transpose

def defvjp(prim, *vjps):
  """Register per-argument VJP rules for `prim` (deprecated).

  Each rule is called as `rule(ct, *primals)`; a falsy rule yields
  `zeros_like_jaxval` of the corresponding primal.
  """
  def vjpmaker(*primals):
    ans = prim.bind(*primals)
    vjpfun = lambda ct: [vjp(ct, *primals) if vjp else zeros_like_jaxval(x)
                         for x, vjp in zip(primals, vjps)]
    return ans, vjpfun
  defvjp_all(prim, vjpmaker)

def defvjp2(prim, *vjps):
  """Like `defvjp`, but rules are called as `rule(ct, ans, *primals)` (deprecated)."""
  def vjpmaker(*primals):
    ans = prim.bind(*primals)
    vjpfun = lambda ct: [vjp(ct, ans, *primals) if vjp else zeros_like_jaxval(x)
                         for x, vjp in zip(primals, vjps)]
    return ans, vjpfun
  defvjp_all(prim, vjpmaker)


class CustomJVPException(Exception):
  """Raised by `JVPTrace.post_process_custom_jvp_call`."""
  def __init__(self):
    # TODO(mattjj): track source provenance on AD tracers, improve error
    msg = ("Detected differentiation of a custom_jvp function with respect to "
           "a closed-over value. That isn't supported because the custom JVP "
           "rule only specifies how to differentiate the custom_jvp function "
           "with respect to explicit input parameters. Try passing the "
           "closed-over value into the custom_jvp function as an argument, and "
           "adapting the custom_jvp rule.")
    super().__init__(msg)

class CustomVJPException(Exception):
  """Raised by `JVPTrace.post_process_custom_vjp_call`."""
  def __init__(self):
    # TODO(mattjj): track source provenance on AD tracers, improve error
    msg = ("Detected differentiation of a custom_vjp function with respect to "
           "a closed-over value. That isn't supported because the custom VJP "
           "rule only specifies how to differentiate the custom_vjp function "
           "with respect to explicit input parameters. Try passing the "
           "closed-over value into the custom_vjp function as an argument, and "
           "adapting the custom_vjp fwd and bwd rules.")
    super().__init__(msg)

@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Rebind the module-level `jvp_jaxpr` to a variant using `pe.trace_to_jaxpr`."""
  global jvp_jaxpr

  def jvp_jaxpr(jaxpr, nonzeros, instantiate):
    # Same as the module-level `jvp_jaxpr`, except that tracing goes through
    # `pe.trace_to_jaxpr` with unknown partial values instead of
    # `pe.trace_to_jaxpr_dynamic`.
    assert len(jaxpr.in_avals) == len(nonzeros)
    f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
    f_jvp, out_nonzeros = f_jvp_traceable(jvp(f, instantiate=instantiate), nonzeros)
    tangent_avals = [aval for aval, nz in zip(jaxpr.in_avals, nonzeros) if nz]
    avals_in = list(it.chain(jaxpr.in_avals, tangent_avals))
    pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
    jaxpr_out, _, consts = pe.trace_to_jaxpr(f_jvp, pvals, instantiate=True)
    return core.ClosedJaxpr(jaxpr_out, consts), out_nonzeros()