1# Copyright 2018 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15
16import functools
17import itertools as it
18from typing import Any, Callable, Dict, Set, List
19
20from . import partial_eval as pe
21from ..config import config
22from .. import core
23from ..dtypes import dtype, float0
24from ..core import (Trace, Tracer, get_aval, call_p, Primitive, Literal,
25                    raise_to_shaped)
26from ..ad_util import (add_jaxvals, add_jaxvals_p, zeros_like_jaxval, zeros_like_aval,
27                       zeros_like_p, Zero)
28from .._src.util import (unzip2, safe_map, safe_zip, partial, split_list,
29                         wrap_name, as_hashable_function)
30from ..tree_util import register_pytree_node
31from .. import linear_util as lu
32from ..api_util import flatten_fun, flatten_fun_nokwargs
33from ..tree_util import tree_flatten, tree_unflatten, Partial
34from jax._src import source_info_util
35
# Within this module, `zip` and `map` refer to the `safe_*` variants imported
# from `_src.util` rather than the builtins.
zip = safe_zip
map = safe_map


def identity(x):
  """Return `x` unchanged."""
  return x
39
40
def jvp(fun: lu.WrappedFun, has_aux=False, instantiate=True) -> Any:
  """Wrap `fun` so that calling it computes both primal and tangent outputs.

  Args:
    fun: the function to differentiate.
    has_aux: whether `fun` also returns auxiliary data.
    instantiate: forwarded to `jvpfun`; says which output tangents are
      materialized instead of being left as symbolic `Zero`s.

  Returns:
    The transformed `lu.WrappedFun`. When `has_aux` is True, the `aux` handle
    produced by `jvp_subtrace_aux` is returned as a second element.
  """
  if has_aux:
    subtraced, aux = jvp_subtrace_aux(fun)
    return jvpfun(subtraced, instantiate), aux
  return jvpfun(jvp_subtrace(fun), instantiate)
47
48
@lu.transformation
def jvpfun(instantiate, primals, tangents):
  """Outer JVP transformation: opens a new `JVPTrace` main for the wrapped call.

  Yields `(main, primals, tangents)` to the wrapped function (for example
  `jvp_subtrace`) and receives its `(out_primals, out_tangents)`.

  Args:
    instantiate: a bool, or one bool per output, saying which output tangents
      must be materialized with `instantiate_zeros` instead of staying `Zero`.
    primals: flat primal inputs.
    tangents: flat tangent inputs, one per primal.

  Yields:
    `(out_primals, out_tangents)`.
  """
  # float0 tangents carry no information, so turn them into symbolic `Zero`s
  # up front. Compare dtypes with `==` rather than `is`: numpy dtype objects for
  # structured dtypes such as float0 are not guaranteed to be singletons, and
  # the rest of this module (e.g. `recast_to_float0`) compares with `==`.
  tangents = [Zero.from_value(t) if not isinstance(t, Zero)
              and dtype(t) == float0 else t for t in tangents]
  with core.new_main(JVPTrace) as main:
    out_primals, out_tangents = yield (main, primals, tangents), {}
    # NOTE(review): presumably dropped so no reference to `main` outlives the
    # context manager (e.g. for leak checking).
    del main
  if type(instantiate) is bool:
    instantiate = [instantiate] * len(out_tangents)
  out_tangents = [instantiate_zeros(t) if inst else t for t, inst
                  in zip(out_tangents, instantiate)]
  yield out_primals, out_tangents
61
@lu.transformation
def jvp_subtrace(main, primals, tangents):
  """Run the wrapped function under a `JVPTrace` at the current sublevel.

  Inputs whose tangent is a symbolic `Zero` are passed through unwrapped. Every
  other input is boxed in a `JVPTracer`. The outputs are raised to this trace
  and yielded as a `(primals, tangents)` pair.
  """
  trace = JVPTrace(main, core.cur_sublevel())
  # Any tracer coming in must belong to a strictly lower-level trace.
  for val in it.chain(primals, tangents):
    if isinstance(val, Tracer):
      assert val._trace.level < trace.level
  boxed = [x if type(t) is Zero else JVPTracer(trace, x, t)
           for x, t in zip(primals, tangents)]
  outs = yield boxed, {}
  raised = map(trace.full_raise, outs)
  yield unzip2([(r.primal, r.tangent) for r in raised])
74
@lu.transformation_with_aux
def jvp_subtrace_aux(main, primals, tangents):
  """Like `jvp_subtrace`, for a wrapped function that returns `(ans, aux)`.

  Every input is boxed in a `JVPTracer`, including those with `Zero` tangents.
  Auxiliary outputs that are `JVPTracer`s of this very trace are lowered to
  their primal values. Other auxiliary outputs are passed through untouched.
  """
  trace = JVPTrace(main, core.cur_sublevel())
  # Any tracer coming in must belong to a strictly lower-level trace.
  for val in it.chain(primals, tangents):
    if isinstance(val, Tracer):
      assert val._trace.level < trace.level
  boxed = map(partial(JVPTracer, trace), primals, tangents)
  ans, aux = yield boxed, {}
  raised = map(trace.full_raise, ans)
  out_primals, out_tangents = unzip2((r.primal, r.tangent) for r in raised)

  def strip_tangent(x):
    if isinstance(x, JVPTracer) and x._trace.level == trace.level:
      return core.full_lower(x.primal)
    return x

  aux_primals = [strip_tangent(x) for x in aux]
  yield (out_primals, out_tangents), aux_primals
88
def linearize(traceable, *primals, has_aux=False):
  """Linearize `traceable` around `primals`.

  Traces the JVP of `traceable` with known primals and unknown tangents. The
  primal outputs are therefore computed as known constants, and the tangent
  computation is staged out into a jaxpr.

  Args:
    traceable: the `lu.WrappedFun` to linearize.
    *primals: primal input values.
    has_aux: keyword-only; if True, `traceable` also returns auxiliary data,
      which is returned as an extra trailing element. Any other keyword
      argument raises `TypeError` instead of being silently ignored.

  Returns:
    `(out_primals_consts, out_tangents_pvals, jaxpr, consts)`, followed by `aux`
    when `has_aux` is True. The returned jaxpr has its primal input and output
    binders stripped, so it maps tangents to tangents.
  """
  if has_aux:
    jvp_fun, aux = jvp(traceable, has_aux=True)
  else:
    jvp_fun = jvp(traceable)

  # Primals are known; tangents are unknown values of the primals' vspace.
  in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
              + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace())
                      for p in primals))
  _, in_tree = tree_flatten(((primals, primals), {}))
  jvp_fun_flat, out_tree = flatten_fun(jvp_fun, in_tree)
  jaxpr, out_pvals, consts = pe.trace_to_jaxpr(jvp_fun_flat, in_pvals)
  out_primals_pvals, out_tangents_pvals = tree_unflatten(out_tree(), out_pvals)
  assert all(out_primal_pval.is_known() for out_primal_pval in out_primals_pvals)
  _, out_primals_consts = unzip2(out_primals_pvals)
  # Drop the primal binders so the jaxpr is a function of tangents only.
  jaxpr.invars = jaxpr.invars[len(primals):]
  jaxpr.outvars = jaxpr.outvars[len(out_primals_pvals):]
  if has_aux:
    return out_primals_consts, out_tangents_pvals, jaxpr, consts, aux()
  return out_primals_consts, out_tangents_pvals, jaxpr, consts
111
def vjp(traceable, primals, has_aux=False):
  """Return the primal outputs of `traceable` at `primals` and a VJP function.

  The VJP function is a `Partial`, and therefore a pytree. This lets it be
  passed from the forward pass to the backward pass of a custom VJP. When
  `has_aux` is True, the auxiliary data is returned as a third element.
  """
  if has_aux:
    out_primals, pvals, jaxpr, consts, aux = linearize(
        traceable, *primals, has_aux=True)
  else:
    out_primals, pvals, jaxpr, consts = linearize(traceable, *primals)

  def unbound_vjp(pvals, jaxpr, consts, *cts):
    cts = tuple(map(ignore_consts, cts, pvals))
    # Every input of the linear jaxpr is a tangent whose cotangent we want.
    undefined_args = [UndefinedPrimal(v.aval) for v in jaxpr.invars]
    arg_cts = backward_pass(jaxpr, consts, undefined_args, cts)
    return map(instantiate_zeros, arg_cts)

  vjp_fn = Partial(partial(unbound_vjp, pvals, jaxpr), consts)
  return (out_primals, vjp_fn, aux) if has_aux else (out_primals, vjp_fn)
131
def ignore_consts(ct, pval):
  """Keep `ct` for traced outputs and replace it with `core.unit` otherwise.

  `pval` is unpacked as an `(aval, const)` pair. An `AbstractValue` aval keeps
  the cotangent as-is. An aval of `None` (presumably an output that was a known
  constant) maps it to `core.unit`.

  Raises:
    TypeError: if the aval is neither an `AbstractValue` nor `None`.
  """
  aval, _ = pval
  if aval is None:
    return core.unit
  if not isinstance(aval, core.AbstractValue):
    raise TypeError(aval)
  return ct
140
def unpair_pval(pval):
  """Split a partial value whose aval and const are both pairs.

  `pval` is `(aval, (const_1, const_2))`, where `aval` is either None or a pair
  `(aval_1, aval_2)`. Returns `((aval_1, const_1), (aval_2, const_2))`. Both
  avals are None when `aval` is None.
  """
  aval, (const_1, const_2) = pval
  aval_1, aval_2 = (None, None) if aval is None else aval
  return (aval_1, const_1), (aval_2, const_2)
149
def replace_float0s(primal, tangent):
  """Replace a float0 `tangent` with zeros of `primal`'s dtype.

  Custom JVP rules do not currently handle float0 tangents, so callers densify
  them here before invoking the rule. Any other tangent is returned unchanged.
  """
  # Compare with `==`, not `is`: structured numpy dtypes such as float0 are not
  # guaranteed to be singleton objects, and `recast_to_float0` already uses
  # `==` for the same check.
  if dtype(tangent) == float0:
    return core.zeros_like_float0(tangent, dtype(primal))
  return tangent
155
def recast_to_float0(primal, tangent):
  """Return a symbolic `Zero` tangent when `primal`'s tangent dtype is float0.

  This is the counterpart to `replace_float0s`: after a custom rule has produced
  a dense tangent, primals whose tangent dtype is float0 get a `Zero` again.
  Otherwise `tangent` is returned unchanged.
  """
  tangent_dtype = core.primal_dtype_to_tangent_dtype(dtype(primal))
  if tangent_dtype != float0:
    return tangent
  return Zero(get_aval(primal).at_least_vspace())
161
# NOTE: The FIXMEs below are caused by primal/tangent mixups (type errors if you will)
def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in):
  """Transpose `jaxpr`, propagating cotangents from its outputs to its inputs.

  Args:
    jaxpr: the jaxpr to transpose. Its equations are visited in reverse order.
    consts: values bound to `jaxpr.constvars`.
    primals_in: one value per `jaxpr.invars`. `UndefinedPrimal` entries are not
      written to the primal environment; reading them back yields an
      `UndefinedPrimal`, which tells transpose rules which operands are linear.
    cotangents_in: one cotangent per `jaxpr.outvars`; may be symbolic `Zero`s.

  Returns:
    A list with one cotangent per `jaxpr.invars`. Where nothing flowed back,
    the entry is a `Zero` of that var's aval.
  """
  # Fast path: with only zero cotangents, every input cotangent is zero.
  if all(type(ct) is Zero for ct in cotangents_in):
    return map(lambda v: Zero(v.aval), jaxpr.invars)

  def write_cotangent(prim, v, ct):
    # Accumulate `ct` into the cotangent of var `v`; `prim` only labels errors.
    # assert v not in primal_env
    # Transpose rules may return the bare `Zero` class as an "all zero"
    # sentinel; the caller must have expanded it before reaching here.
    assert ct is not Zero, (prim, v.aval)  # check for an old harmless type error
    # None (e.g. what transpose rules return for non-linear operands) and
    # literals get no cotangent.
    if ct is None or type(v) is Literal:
      return
    # Symbolic zeros are never stored; read_cotangent synthesizes them.
    if type(ct) is Zero:
      # FIXME: This triggers a lot of failures!
      # assert v.aval == ct.aval, (prim, v.aval, ct.aval)
      return
    # A var used several times receives the sum of its cotangents.
    ct_env[v] = add_tangents(ct_env[v], ct) if v in ct_env else ct
    if not core.skip_checks:
      ct_aval = core.get_aval(ct_env[v])
      joined_aval = core.lattice_join(v.aval, ct_aval).strip_weak_type()
      assert v.aval.strip_weak_type() == joined_aval, (prim, v.aval, ct_aval)

  def read_cotangent(v):
    # Missing cotangents are symbolic zeros of the var's aval.
    return ct_env.get(v, Zero(v.aval))

  def read_primal(v):
    # Literals carry their value; unknown vars read back as UndefinedPrimal.
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, UndefinedPrimal(v.aval))

  def write_primal(v, val):
    # Only defined primal values are recorded.
    if not is_undefined_primal(val):
      primal_env[v] = val

  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.constvars, consts)
  # FIXME: invars can contain both primal and tangent values, and this line
  #        forces primal_in to contain UndefinedPrimals for tangent values!
  map(write_primal, jaxpr.invars, primals_in)

  # Find the last use of each cotangent so that they can be removed
  # as soon as possible. drop_cts[i] holds the outvars first defined by
  # equation i. Once equation i has been transposed in the reverse walk below,
  # no remaining equation reads their cotangents.
  drop_cts: List[Set[Any]] = []
  seen_vars: Set[Any] = set(jaxpr.invars)
  for eqn in jaxpr.eqns:
    read_set = set(eqn.outvars)  # NOTE: eqn is not transposed yet!
    drop_cts.append(read_set - seen_vars)
    seen_vars |= read_set

  ct_env: Dict[Any, Any] = {}
  # Seed the cotangent environment with the output cotangents.
  map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
  # Transpose the equations in reverse order.
  for eqn, to_drop in zip(jaxpr.eqns[::-1], drop_cts[::-1]):
    # FIXME: Some invars correspond to tangents
    invals = map(read_primal, eqn.invars)
    if eqn.primitive.multiple_results:
      cts_in = map(read_cotangent, eqn.outvars)
    else:
      cts_in, = map(read_cotangent, eqn.outvars)
    with source_info_util.user_context(eqn.source_info):
      # Call/map primitives use the (params, call_jaxpr, invals, cts, avals)
      # transpose-rule signature; all others use (cts, *invals, **params).
      if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
        cts_in_avals = [v.aval for v in eqn.outvars]
        call_jaxpr, params = core.extract_call_jaxpr(eqn.primitive, eqn.params)
        cts_out = get_primitive_transpose(eqn.primitive)(
            params, call_jaxpr, invals, cts_in, cts_in_avals)
      else:
        cts_out = get_primitive_transpose(eqn.primitive)(cts_in, *invals,
                                                         **eqn.params)
    # Expand the bare `Zero` sentinel into one Zero per operand.
    cts_out = [Zero(v.aval) for v in eqn.invars] if cts_out is Zero else cts_out
    # FIXME: Some invars correspond to primals!
    map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
    for var in to_drop:
      ct_env.pop(var, None)  # NB: Constant cotangents might be missing

  cotangents_out = map(read_cotangent, jaxpr.invars)
  return cotangents_out
237
class UndefinedPrimal:
  """Placeholder for a primal input whose value is unknown during transposition.

  Only the abstract value `aval` is kept. `backward_pass` does not record these
  in its primal environment, so transpose rules see them as the operands whose
  cotangents must be computed.
  """
  __slots__ = ['aval']

  def __init__(self, aval):
    self.aval = aval

  def __repr__(self):
    return f'UndefinedPrimal({self.aval})'
244
def is_undefined_primal(x):
  """Return True iff `x` is exactly an `UndefinedPrimal` (subclasses excluded)."""
  return UndefinedPrimal is type(x)
247
# Register UndefinedPrimal as a pytree node with no children. Its aval travels
# as the node's auxiliary data, so flattening yields no leaves for it and
# unflattening rebuilds it from the aval alone.
register_pytree_node(UndefinedPrimal,
                     lambda z: ((), z.aval),
                     lambda aval, _: UndefinedPrimal(aval))
251
def get_primitive_transpose(p):
  """Return the transpose rule registered for primitive `p`.

  Raises:
    NotImplementedError: if `primitive_transposes` has no entry for `p`. The
      lookup's KeyError is chained as the cause.
  """
  try:
    rule = primitive_transposes[p]
  except KeyError as missing:
    msg = ("Transpose rule (for reverse-mode differentiation) for '{}' "
           "not implemented".format(p))
    raise NotImplementedError(msg) from missing
  return rule
259
@lu.transformation_with_aux
def nonzero_tangent_outputs(*args, **kwargs):
  """Pass results through unchanged; aux marks which output tangents are nonzero.

  The wrapped function must return a `(primals, tangents)` pair.
  """
  results = yield args, kwargs
  _, tangents_out = results
  yield results, [type(t) is not Zero for t in tangents_out]
264
265
class JVPTrace(Trace):
  """Trace for forward-mode differentiation.

  Its tracers are `JVPTracer`s, each carrying a (primal, tangent) pair. The
  tangent may be a symbolic `Zero`.
  """

  def pure(self, val):
    # Constants enter the trace with a symbolic zero tangent.
    tangent_zero = Zero(get_aval(val).at_least_vspace())
    return JVPTracer(self, val, tangent_zero)

  def lift(self, val):
    # Values lifted into this trace also start with a symbolic zero tangent.
    tangent_zero = Zero(get_aval(val).at_least_vspace())
    return JVPTracer(self, val, tangent_zero)

  def sublift(self, val):
    # Re-wrap an existing JVPTracer for this trace, keeping its primal/tangent.
    return JVPTracer(self, val.primal, val.tangent)

  def process_primitive(self, primitive, tracers, params):
    # Apply the rule registered in `primitive_jvps` to the unpacked primals
    # and tangents, then box the outputs in JVPTracers.
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    jvp = primitive_jvps.get(primitive)
    if not jvp:
      msg = f"Differentiation rule for '{primitive}' not implemented"
      raise NotImplementedError(msg)
    primal_out, tangent_out = jvp(primals_in, tangents_in, **params)
    if primitive.multiple_results:
      return [JVPTracer(self, x, t) for x, t in zip(primal_out, tangent_out)]
    else:
      return JVPTracer(self, primal_out, tangent_out)

  def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params):
    # Differentiate through a call (or, via `process_map`, a map) primitive by
    # binding it again on the JVP of `f`. The operands are the primals followed
    # by the flattened nonzero tangents.
    assert call_primitive.multiple_results
    primals, tangents = unzip2((t.primal, t.tangent) for t in tracers)
    # NOTE(review): relies on `Zero` flattening to no leaves, so that the
    # resulting list holds only the nonzero tangents -- confirm in ad_util.
    nonzero_tangents, tangent_tree_def = tree_flatten(tangents)
    nz_tangents = [type(t) is not Zero for t in tangents]
    params = dict(params, name=wrap_name(params['name'], 'jvp'))
    f_jvp = jvp_subtrace(f, self.main)
    if isinstance(call_primitive, core.MapPrimitive):
      # Nonzero tangent operands reuse their primals' in_axes. Nonzero tangent
      # outputs reuse the primal outputs' out_axes (computed lazily below).
      in_axes = params['in_axes']
      tangent_in_axes = [ax for ax, nz in zip(in_axes, nz_tangents) if nz]
      out_axes_thunk = params['out_axes_thunk']
      f_jvp, nz_tangents_out = nonzero_tangent_outputs(f_jvp)
      # The new thunk depends deterministically on the old thunk and the wrapped function.
      # Any caching already has to include the wrapped function as part of the key, so we
      # only use the previous thunk for equality checks.
      # NOTE: This assumes that the output tangents being zero is a deterministic
      #       function of which input tangents were zero.
      @as_hashable_function(closure=(tuple(nz_tangents), out_axes_thunk))
      def new_out_axes_thunk():
        out_axes = out_axes_thunk()
        return (*out_axes, *(ax for ax, nz in zip(out_axes, nz_tangents_out()) if nz))
      params = dict(params,
                    in_axes=(*in_axes, *tangent_in_axes),
                    out_axes_thunk=new_out_axes_thunk)
    # Adapt f_jvp to the flat (*primals, *nonzero_tangents) calling convention.
    f_jvp, out_tree_def = traceable(f_jvp, len(primals), tangent_tree_def)
    update_params = call_param_updaters.get(call_primitive)
    new_params = update_params(params, nz_tangents) if update_params else params
    result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
    primal_out, tangent_out = tree_unflatten(out_tree_def(), result)
    return [JVPTracer(self, p, t) for p, t in zip(primal_out, tangent_out)]

  def post_process_call(self, call_primitive, out_tracers, params):
    # Flatten the primals and nonzero tangents of `out_tracers` into `out`.
    # Return them with a `todo` callback that rebuilds JVPTracers (on a new
    # JVPTrace of the same main at the current sublevel) from flat values.
    primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
    out, treedef = tree_flatten((primals, tangents))
    tangents_nz = [type(t) is not Zero for t in tangents]
    del primals, tangents
    main = self.main
    def todo(x):
      primals, tangents = tree_unflatten(treedef, x)
      trace = JVPTrace(main, core.cur_sublevel())
      return map(partial(JVPTracer, trace), primals, tangents)
    if call_primitive.map_primitive:
      # For maps, also extend out_axes to cover the nonzero tangent outputs.
      def out_axes_transform(out_axes):
        return (*out_axes, *(ax for ax, nz in zip(out_axes, tangents_nz) if nz))
      todo = (todo, out_axes_transform)
    return out, todo

  # The only difference between process_map and process_call is that
  # the `in_axes` and `out_axes_thunk` params must be updated;
  # that's handled in process_call.
  process_map = process_call
  post_process_map = post_process_call

  def process_custom_jvp_call(self, _, __, f_jvp, tracers):
    # Call the user's JVP rule `f_jvp` directly on lowered primals and dense
    # tangents. Its flat output is split in half into primals and tangents.
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    primals_in = map(core.full_lower, primals_in)
    tangents_in = map(instantiate_zeros, tangents_in)
    # Cast float0 to zeros with the primal dtype because custom jvp rules don't
    # currently handle float0s
    tangents_in = map(replace_float0s, primals_in, tangents_in)
    outs = f_jvp.call_wrapped(*it.chain(primals_in, tangents_in))
    primals_out, tangents_out = split_list(outs, [len(outs) // 2])
    tangents_out = map(recast_to_float0, primals_out, tangents_out)
    return map(partial(JVPTracer, self), primals_out, tangents_out)

  def post_process_custom_jvp_call(self, out_tracers, params):
    # Unsupported: always raises (see CustomJVPException's message).
    raise CustomJVPException()

  def process_custom_vjp_call(self, _, __, fwd, bwd, tracers, *, out_trees):
    # Run `fwd` on the primals to get residuals and primal outputs. Represent
    # the tangent outputs as a `custom_lin_p` application on the residuals and
    # dense input tangents; its transpose rule later calls `bwd`.
    primals_in, tangents_in = unzip2((t.primal, t.tangent) for t in tracers)
    tangents_in = map(instantiate_zeros, tangents_in)
    res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
    out_tree, res_tree = out_trees()
    res, primals_out = split_list(res_and_primals_out, [res_tree.num_leaves])
    avals_out = [raise_to_shaped(core.get_aval(x)) for x in primals_out]
    tangents_out = custom_lin_p.bind(
        *res, *tangents_in, num_res=res_tree.num_leaves, bwd=bwd,
        avals_out=avals_out)
    tangents_out = map(recast_to_float0, primals_out, tangents_out)
    return map(partial(JVPTracer, self), primals_out, tangents_out)

  def post_process_custom_vjp_call(self, out_tracers, params):
    # Unsupported: always raises (see CustomVJPException's message).
    raise CustomVJPException()

  def join(self, xt, yt):
    # Make two tangents agree in symbolic-zero-ness. When exactly one is a
    # `Zero`, replace it with zeros_like_jaxval of the other. (The final
    # `else` is unreachable: the branches above cover every combination.)
    xz, yz = type(xt) is Zero, type(yt) is Zero
    if xz == yz:
      return xt, yt
    elif yz and not xz:
      return xt, zeros_like_jaxval(xt)
    elif xz and not yz:
      return zeros_like_jaxval(yt), yt
    else:
      raise TypeError((xt, yt))
385
386
class JVPTracer(Tracer):
  """Tracer of a `JVPTrace` holding a primal value and its tangent.

  The tangent may be a symbolic `Zero`. Unless `core.skip_checks` is set, a
  non-`Zero` tangent's shape and dtype are checked against the primal when the
  tracer is constructed.
  """
  __slots__ = ['primal', 'tangent']

  def __init__(self, trace, primal, tangent):
    if not core.skip_checks:
      _primal_tangent_shapes_match(primal, tangent)
    self._trace = trace
    self.primal = primal
    self.tangent = tangent

  @property
  def aval(self):
    """Abstract value of the primal."""
    # TODO(dougalm): add epsilon ball
    return get_aval(self.primal)

  def full_lower(self):
    """Unwrap to the (lowered) primal when the tangent is a symbolic zero."""
    if type(self.tangent) is not Zero:
      return self
    return core.full_lower(self.primal)
407
def _primal_tangent_shapes_match(primal, tangent):
  """Assert that a non-`Zero` tangent matches its primal.

  The tangent must have the primal's shape and the primal dtype's tangent
  dtype. Weak types are ignored.
  """
  if type(tangent) is Zero:
    return
  p_aval = raise_to_shaped(get_aval(primal), weak_type=False)
  t_aval = raise_to_shaped(get_aval(tangent), weak_type=False)
  assert p_aval.shape == t_aval.shape, (p_aval.shape, t_aval.shape)
  want_dtype = core.primal_dtype_to_tangent_dtype(p_aval.dtype)
  assert want_dtype == t_aval.dtype, (want_dtype, t_aval.dtype)
415
# Per-call-primitive hooks that rewrite params for the JVP'd call. JVPTrace.process_call
# invokes them as update_params(params, nonzero_tangent_mask).
call_param_updaters: Dict[core.Primitive, Callable] = {}
# Per-call-primitive hooks applied when transposing a call/map. Invoked as
# update_params(params, undefined_primal_mask, nonzero_cotangent_mask).
call_transpose_param_updaters: Dict[core.Primitive, Callable] = {}


# -------------------- Primitives --------------------

# JVP rules: primitive -> rule(primals, tangents, **params) returning
# (primal_out, tangent_out).
primitive_jvps : Dict[core.Primitive, Callable] = {}

# Transpose rules: primitive -> rule. backward_pass calls non-call rules as
# rule(cts_in, *invals, **params), and call/map rules as
# rule(params, call_jaxpr, invals, cts_in, cts_in_avals).
primitive_transposes: Dict[core.Primitive, Callable] = {}
425
426
def deflinear(primitive, transpose_rule):
  """Register `primitive` as linear in all of its operands.

  Its JVP applies the primitive itself to the tangents. Its transpose calls
  `transpose_rule(cotangent, **params)`; the operands are not passed.
  """
  jvp_rule = partial(linear_jvp, primitive)
  transpose = partial(linear_transpose, transpose_rule)
  primitive_jvps[primitive] = jvp_rule
  primitive_transposes[primitive] = transpose
430
def linear_jvp(primitive, primals, tangents, **params):
  """JVP rule for a linear primitive: the tangent map is the primitive itself.

  If every tangent is a symbolic `Zero`, the output tangent is a `Zero` too.
  Otherwise the zeros are instantiated and the primitive is applied to them.
  """
  primal_out = primitive.bind(*primals, **params)
  if any(type(t) is not Zero for t in tangents):
    dense = map(instantiate_zeros, tangents)
    return primal_out, primitive.bind(*dense, **params)
  return primal_out, Zero.from_value(primal_out)
438
def linear_transpose(transpose_rule, cotangent, *args, **kwargs):
  """Transpose rule installed by `deflinear`; the operands `args` are ignored.

  A symbolic-zero cotangent returns the `Zero` class itself, which
  `backward_pass` expands into one `Zero` per operand.
  """
  if type(cotangent) is Zero:
    return Zero
  return transpose_rule(cotangent, **kwargs)
441
442
def deflinear2(primitive, transpose_rule):
  """Like `deflinear`, but the transpose rule also receives the operands.

  The rule is called as `transpose_rule(cotangent, *args, **params)`.
  """
  jvp_rule = partial(linear_jvp, primitive)
  transpose = partial(linear_transpose2, transpose_rule)
  primitive_jvps[primitive] = jvp_rule
  primitive_transposes[primitive] = transpose
446
def linear_transpose2(transpose_rule, cotangent, *args, **kwargs):
  """Transpose rule installed by `deflinear2`; forwards operands to the rule.

  A symbolic-zero cotangent returns the `Zero` class itself as an
  "all cotangents are zero" sentinel.
  """
  if type(cotangent) is Zero:
    return Zero
  return transpose_rule(cotangent, *args, **kwargs)
449
450
def defjvp(primitive, *jvprules):
  """Register per-operand JVP rules for a single-result `primitive`.

  `jvprules[i]` is called as `rule(tangent_i, *primals, **params)`. A None
  entry contributes nothing for operand i. Contributions are summed by
  `standard_jvp`.
  """
  assert isinstance(primitive, Primitive)
  assert not primitive.multiple_results
  rule = partial(standard_jvp, jvprules, primitive)
  primitive_jvps[primitive] = rule
455
456
def standard_jvp(jvprules, primitive, primals, tangents, **params):
  """Sum the contributions of the per-operand JVP rules registered by `defjvp`.

  Operands with a None rule or a symbolic-zero tangent are skipped. With no
  contributions at all, the output tangent is a `Zero`.
  """
  primal_out = primitive.bind(*primals, **params)
  contributions = []
  for rule, t in zip(jvprules, tangents):
    if rule is None or type(t) is Zero:
      continue
    contributions.append(rule(t, *primals, **params))
  return primal_out, functools.reduce(add_tangents, contributions,
                                      Zero.from_value(primal_out))
462
def defjvp2(primitive, *jvprules):
  """Like `defjvp`, but each rule also receives the primal output.

  Rules are called as `rule(tangent_i, primal_out, *primals, **params)`.
  """
  assert isinstance(primitive, Primitive)
  assert not primitive.multiple_results
  rule = partial(standard_jvp2, jvprules, primitive)
  primitive_jvps[primitive] = rule
467
def standard_jvp2(jvprules, primitive, primals, tangents, **params):
  """Sum the contributions of the per-operand JVP rules registered by `defjvp2`.

  Like `standard_jvp`, except each rule is also passed the primal output.
  """
  primal_out = primitive.bind(*primals, **params)
  contributions = []
  for rule, t in zip(jvprules, tangents):
    if rule is None or type(t) is Zero:
      continue
    contributions.append(rule(t, primal_out, *primals, **params))
  return primal_out, functools.reduce(add_tangents, contributions,
                                      Zero.from_value(primal_out))
474
def add_tangents(x, y):
  """Add two tangents, short-circuiting when either is a symbolic `Zero`."""
  x_is_zero = type(x) is Zero
  if x_is_zero:
    return y
  if type(y) is Zero:
    return x
  return add_jaxvals(x, y)
482
483
def defbilinear_broadcasting(bcast, prim, lhs_rule, rhs_rule):
  """Register JVP and transpose rules for a primitive bilinear in `(x, y)`.

  The JVP w.r.t. `x` is `prim(bcast(g, y), y)` and w.r.t. `y` is
  `prim(x, bcast(g, x))`. Transposition is delegated to `lhs_rule` and
  `rhs_rule` through `bilinear_transpose`.
  """
  assert isinstance(prim, Primitive)

  def lhs_jvp(g, x, y, **kwargs):
    return prim.bind(bcast(g, y), y, **kwargs)

  def rhs_jvp(g, x, y, **kwargs):
    return prim.bind(x, bcast(g, x), **kwargs)

  defjvp(prim, lhs_jvp, rhs_jvp)
  primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)

# Non-broadcasting variant: the tangent is used as-is.
defbilinear: Callable = partial(defbilinear_broadcasting, lambda g, x: g)
491
def bilinear_transpose(lhs_rule, rhs_rule, cotangent, x, y, **kwargs):
  """Transpose rule installed by `defbilinear_broadcasting`.

  Exactly one of `x`, `y` must be an `UndefinedPrimal`. Its cotangent is
  computed by `lhs_rule` (for `x`) or `rhs_rule` (for `y`) from the other,
  defined operand. The defined operand receives a None cotangent. A `Zero`
  cotangent, or a rule returning the `Zero` class, returns `Zero`.
  """
  assert is_undefined_primal(x) ^ is_undefined_primal(y)
  if type(cotangent) is Zero:
    return Zero
  x_undefined = is_undefined_primal(x)
  if x_undefined:
    out = lhs_rule(cotangent, y, **kwargs)
  else:
    out = rhs_rule(cotangent, x, **kwargs)
  if out is Zero:
    return Zero
  return (out, None) if x_undefined else (None, out)
502
503
def defjvp_zero(primitive):
  """Register a JVP rule giving `primitive` a symbolic-zero output tangent."""
  assert isinstance(primitive, Primitive)
  rule = partial(zero_jvp, primitive)
  primitive_jvps[primitive] = rule
507
def zero_jvp(primitive, primals, tangents, **params):
  """JVP rule that ignores `tangents`: primal output plus a `Zero` tangent."""
  primal_out = primitive.bind(*primals, **params)
  zero_tangent = Zero.from_value(primal_out)
  return primal_out, zero_tangent
511
512
# zeros_like transposes to a symbolic Zero shaped like the cotangent.
# add_jaxvals(a, b) transposes to the same cotangent for both operands.
deflinear2(zeros_like_p, lambda t, _: [Zero.from_value(t)])
deflinear2(add_jaxvals_p, lambda t, *args: (t, t))
515
def instantiate_zeros(tangent):
  """Materialize a symbolic `Zero` tangent as a concrete zero value.

  Non-`Zero` tangents are returned unchanged. If a `Zero`'s `aval` field holds
  a `Tracer` instead of an abstract value, that tracer is returned as-is.
  """
  if type(tangent) is not Zero:
    return tangent
  aval = tangent.aval
  return aval if isinstance(aval, Tracer) else zeros_like_aval(aval)
523
def instantiate_zeros_aval(aval, tangent):
  """Materialize a `Zero` tangent as zeros of the explicitly given `aval`.

  This is similar to `instantiate_zeros`, but it is sometimes used to
  instantiate zero abstract units with a different aval: the `Zero`'s own aval
  must either be an `AbstractUnit` or equal `aval`. Non-`Zero` tangents are
  returned unchanged.
  """
  if type(tangent) is not Zero:
    return tangent
  assert type(tangent.aval) is core.AbstractUnit or tangent.aval == aval
  return zeros_like_aval(aval)
532
@lu.transformation_with_aux
def traceable(num_primals, in_tree_def, *primals_and_tangents):
  """Adapt a flat `(*primals, *tangent_leaves)` call to `(primals, tangents)`.

  The tangent leaves are unflattened with `in_tree_def`. The wrapped
  function's `(primal_out, tangent_out)` is flattened, and the output treedef
  is returned as aux.
  """
  primals = primals_and_tangents[:num_primals]
  tangents = tree_unflatten(in_tree_def, primals_and_tangents[num_primals:])
  primal_out, tangent_out = yield (primals, tangents), {}
  flat_out, out_tree_def = tree_flatten((primal_out, tangent_out))
  yield flat_out, out_tree_def
541
542
def call_transpose(primitive, params, call_jaxpr, args, ct, _):
  """Transpose a call primitive by binding it on `backward_pass` of its jaxpr.

  Args:
    primitive: the call primitive to re-bind.
    params: its params; the name is wrapped with 'transpose'.
    call_jaxpr: the body to transpose.
    args: operands; `UndefinedPrimal`s mark linear inputs.
    ct: cotangents of the call's outputs.
    _: output avals (unused).

  Returns:
    The cotangents for `args`, as produced by `backward_pass`.
  """
  # The leading () stands in for backward_pass's (empty) consts argument.
  flat_args, in_tree_def = tree_flatten(((), args, ct))
  transposed = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
  transposed, out_tree = flatten_fun_nokwargs(transposed, in_tree_def)
  new_params = dict(params, name=wrap_name(params['name'], 'transpose'))
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    undefined_mask = map(is_undefined_primal, args)
    nonzero_ct_mask = [type(x) is not Zero for x in ct]
    new_params = update_params(new_params, undefined_mask, nonzero_ct_mask)
  out_flat = primitive.bind(transposed, *flat_args, **new_params)
  return tree_unflatten(out_tree(), out_flat)
primitive_transposes[core.call_p] = partial(call_transpose, call_p)
555
556
def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_avals):
  """Transpose rule for `pe.remat_call_p`.

  Partially evaluates `call_jaxpr` into a primal part (computable from the
  defined primals) and a tangent part. It then binds `remat_call_p` again on a
  function that recomputes the residuals with the primal part and transposes
  the tangent part with `backward_pass`.

  Args:
    params: the remat call's params, reused for the re-bound call.
    call_jaxpr: the remat call's body.
    primals_in: one value per invar; `UndefinedPrimal` marks the unknown inputs.
    cotangents_in: cotangents of the call's outputs.
    cotangent_in_avals: unused here.

  Returns:
    Cotangents for the call's inputs, one per entry of `primals_in`.
  """
  # backward_pass can only transpose linear computations, but the call_jaxpr embedded in
  # remat contains primal (non-linear) equations too. Hence, we have to eliminate those
  # (in this case via partial_eval) before we call into backward_pass again.
  typed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, [])
  unknowns = map(is_undefined_primal, primals_in)
  if config.omnistaging_enabled:
    primal_jaxpr, tangent_jaxpr, out_unknowns = \
      pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True)  # type: ignore
  else:
    primal_jaxpr, tangent_jaxpr, out_unknowns = \
      pe.partial_eval_jaxpr(typed_call_jaxpr, unknowns=unknowns, instantiate=True,
                            trace_type=None)  # type: ignore

  def do_transpose(primals_in, cotangents_in):
    # NOTE: This is passing in undefined primals in place of tangent arguments, but it
    #       should all work out, because we're only computing the primal part here.
    # The primal jaxpr's outputs beyond the first len(cotangents_in) are the
    # residuals needed by the tangent jaxpr.
    residuals = core.jaxpr_as_fun(primal_jaxpr)(*primals_in)[len(cotangents_in):]
    # Now that we have a purely linear jaxpr, we can transpose it
    cotangents_out = backward_pass(tangent_jaxpr.jaxpr, (), primals_in + residuals, cotangents_in)
    # backward_pass will return cotangents computed for all invars, but some of them
    # are residuals appended by partial eval, so we need to skip those before we return.
    return cotangents_out[:len(primals_in)]

  flat_args, in_tree_def = tree_flatten((primals_in, cotangents_in))
  flat_do_transpose, out_tree = flatten_fun_nokwargs(lu.wrap_init(do_transpose), in_tree_def)
  flat_cotangents_out = pe.remat_call_p.bind(flat_do_transpose, *flat_args, **params)
  return tree_unflatten(out_tree(), flat_cotangents_out)
primitive_transposes[pe.remat_call_p] = remat_transpose
586
@lu.transformation_with_aux
def nonzero_outputs(*args, **kwargs):
  """Pass outputs through unchanged; aux marks which outputs are not `Zero`."""
  outs = yield args, kwargs
  mask = [type(o) is not Zero for o in outs]
  yield outs, mask
591
592
def map_transpose(primitive, params, call_jaxpr, args, ct, _):
  """Transpose a map primitive by binding it on `backward_pass` of its jaxpr.

  The transposed map's operands are the defined primals followed by the
  nonzero output cotangents, using the matching `in_axes` / `out_axes` entries.
  Inputs that were unmapped (`in_axis is None`) were broadcast into every map
  instance, so their cotangents are summed over the mapped axis afterwards.

  Args:
    primitive: the map primitive to re-bind.
    params: its params; must contain 'in_axes', 'out_axes', 'name' and
      'axis_size'.
    call_jaxpr: the body to transpose.
    args: operands; `UndefinedPrimal`s mark linear inputs.
    ct: cotangents of the map's outputs.
    _: output avals (unused).

  Returns:
    A tuple with one cotangent per entry of `in_axes`.
  """
  all_args, in_tree_def = tree_flatten(((), args, ct))  # empty consts
  fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
  fun, nz_arg_cts = nonzero_outputs(fun)
  fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
  # Preserve axis for primal arguments, skip tangents (represented as undefined primals).
  in_axes, out_axes = params['in_axes'], params['out_axes']
  new_in_axes = (*[axis for axis, x in zip(in_axes, args)
                   if not is_undefined_primal(x)],
                 *[axis for axis, x in zip(out_axes, ct)
                   if type(x) is not Zero])
  # The interim strategy we use below (until avals-with-names) only works
  # when all outputs are mapped.
  assert all(out_axis is not None for out_axis in out_axes), out_axes
  # NOTE: This assumes that the output cotangents being zero is a deterministic
  #       function of which input cotangents were zero.
  # Cotangents for unmapped (None) inputs are emitted mapped along axis 0; they
  # are summed over axis 0 below.
  @as_hashable_function(closure=(in_axes, tuple(type(c) is Zero for c in ct)))
  def out_axes_thunk():
    return tuple(axis or 0 for axis, nz in zip(in_axes, nz_arg_cts()) if nz)
  new_params = dict(params, name=wrap_name(params['name'], 'transpose'),
                    in_axes=new_in_axes, out_axes_thunk=out_axes_thunk)
  del new_params['out_axes']
  update_params = call_transpose_param_updaters.get(primitive)
  if update_params:
    new_params = update_params(new_params, map(is_undefined_primal, args),
                               [type(x) is not Zero for x in ct])
  out_flat = primitive.bind(fun, *all_args, **new_params)
  arg_cts = tree_unflatten(out_tree(), out_flat)

  # The freevars are being fanned out (not mapped). During transpose the
  # dual of fan-out is fan-in-sum. We apply it to the unmapped invars.
  assert len(in_axes) == len(arg_cts)
  def unmap_zero(zero, in_axis):
    # A symbolic Zero for a mapped input must get the unmapped (full) aval.
    return (zero if in_axis is None else
            Zero(core.unmapped_aval(params['axis_size'], in_axis, zero.aval)))
  arg_cts = (unmap_zero(arg_ct, in_axis) if type(arg_ct) is Zero else
             arg_ct if in_axis is not None else
             arg_ct.sum(0)
             for arg_ct, in_axis in zip(arg_cts, in_axes))
  return tuple(arg_cts)
633
634
def jvp_jaxpr(jaxpr, nonzeros, instantiate):
  """Build the JVP of a closed jaxpr as a new closed jaxpr.

  The new jaxpr takes all primal inputs followed by the tangents of the inputs
  where `nonzeros` is True. It returns the primal outputs followed by the
  nonzero output tangents.

  Returns:
    `(closed_jvp_jaxpr, out_nonzeros)`.
  """
  assert len(jaxpr.in_avals) == len(nonzeros)
  fun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
  fun_jvp, out_nonzeros = f_jvp_traceable(jvp(fun, instantiate=instantiate), nonzeros)
  nz_tangent_avals = [a for a, nz in zip(jaxpr.in_avals, nonzeros) if nz]
  avals_in = [*jaxpr.in_avals, *nz_tangent_avals]
  jaxpr_out, _, literals_out = pe.trace_to_jaxpr_dynamic(fun_jvp, avals_in)
  return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
643
@lu.transformation_with_aux
def f_jvp_traceable(nonzeros, *primals_and_nztangents):
  """Adapt a flat `(*primals, *nonzero_tangents)` call to `(primals, tangents)`.

  Missing tangents (where `nonzeros` is False) become symbolic `Zero`s. The
  flat output is the primal outputs followed by the nonzero output tangents.
  The aux is the mask of which output tangents are nonzero.
  """
  n = len(nonzeros)
  primals = list(primals_and_nztangents[:n])
  nz_iter = iter(primals_and_nztangents[n:])
  tangents = [next(nz_iter) if nz else Zero.from_value(p)
              for p, nz in zip(primals, nonzeros)]
  primals_out, tangents_out = yield (primals, tangents), {}
  out_nonzeros = [type(t) is not Zero for t in tangents_out]
  nz_tangents_out = [t for t, nz in zip(tangents_out, out_nonzeros) if nz]
  yield [*primals_out, *nz_tangents_out], out_nonzeros
655
def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out):
  """Reorder binders from "all primals, then all tangents" to interleaved groups.

  The `primals_*` / `tangents_*` arguments are group sizes (see `_perm`). The
  same regrouping is applied to the invars and to the outvars. Constants and
  equations are kept.
  """
  inner = jaxpr.jaxpr
  new_jaxpr = core.Jaxpr(inner.constvars,
                         _perm(primals_in, tangents_in, inner.invars),
                         _perm(primals_out, tangents_out, inner.outvars),
                         inner.eqns)
  return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)
662
def _perm(primal_counts, tangent_counts, lst):
  """Regroup `lst` from primals-then-tangents into interleaved groups.

  `lst` holds `sum(primal_counts)` primals followed by the tangents. Each half
  is split into groups using the given counts, and the two group lists are
  interleaved with `_interleave`.
  """
  num_primals = sum(primal_counts)
  primal_groups = split_list(lst[:num_primals], primal_counts[:-1])
  tangent_groups = split_list(lst[num_primals:], tangent_counts[:-1])
  return _interleave(primal_groups, tangent_groups)
669
def _interleave(xs, ys):
  """Return the elements of xs[0], ys[0], xs[1], ys[1], ... as one flat list.

  `xs` and `ys` are equally long sequences of groups (iterables).
  """
  assert len(xs) == len(ys)
  flat = []
  for x_group, y_group in zip(xs, ys):
    flat.extend(x_group)
    flat.extend(y_group)
  return flat
673
674
# Primitive representing the linear (tangent) part of a custom_vjp function.
# JVPTrace.process_custom_vjp_call binds it on residuals followed by tangents,
# with params num_res, bwd and avals_out. Its abstract evaluation just returns
# the `avals_out` param.
custom_lin_p = core.Primitive('custom_lin')
custom_lin_p.def_abstract_eval(lambda *_, avals_out, **__: avals_out)
custom_lin_p.multiple_results = True
678
def _raise_custom_vjp_error_on_jvp(*_, **__):
  """Impl of `custom_lin_p`: evaluating it concretely always raises TypeError."""
  message = ("can't apply forward-mode autodiff (jvp) to a custom_vjp "
             "function.")
  raise TypeError(message)
custom_lin_p.def_impl(_raise_custom_vjp_error_on_jvp)
683
def _custom_lin_transpose(cts_out, *invals, num_res, bwd, avals_out):
  """Transpose `custom_lin_p` by calling the user's `bwd` rule.

  Zero output cotangents are instantiated with `avals_out`. `bwd` is called on
  the residuals (the first `num_res` operands) and those cotangents. Residuals
  receive None cotangents; the tangent operands receive `bwd`'s outputs.
  """
  residuals, _ = split_list(invals, [num_res])
  dense_cts = map(instantiate_zeros_aval, avals_out, cts_out)
  cts_in = bwd.call_wrapped(*residuals, *dense_cts)
  return [None] * num_res + list(cts_in)
primitive_transposes[custom_lin_p] = _custom_lin_transpose
690
691
# TODO(mattjj): delete everything below here (deprecated custom_transforms)

def defvjp_all(prim, custom_vjp):
  """Deprecated: give `prim` a custom VJP via a pair of helper primitives.

  `custom_vjp(*primals, **params)` must return `(primal_out, vjp_py)`, where
  `vjp_py` maps output cotangents to input cotangents. This installs a JVP rule
  for `prim` that binds a `<name>_jvp` primitive. That primitive's custom
  partial-eval rule traces `vjp_py` into a jaxpr and emits a `<name>_lin`
  primitive, whose transpose rule evaluates the traced jaxpr.
  """
  # see https://github.com/google/jax/pull/636
  name = prim.name

  def fun_jvp(xs, ts, **params):
    # JVP rule for `prim`: bind fun_jvp_p on (primals, dense tangents) and split
    # its flat output in half.
    ts = map(instantiate_zeros, ts)
    primals_and_tangents = fun_jvp_p.bind(*it.chain(xs, ts), **params)
    primals, tangents = split_list(primals_and_tangents, [len(primals_and_tangents) // 2])
    if prim.multiple_results:
      return primals, tangents
    else:
      primal, = primals
      tangent, = tangents
      return primal, tangent
  primitive_jvps[prim] = fun_jvp

  # In this function fun_jvp_p only gets a custom partial-eval rule.
  fun_jvp_p = core.Primitive('{name}_jvp'.format(name=name))
  fun_jvp_p.multiple_results = True
  def fun_jvp_partial_eval(trace, *tracers, **params):
    # Run custom_vjp eagerly on the primals, trace vjp_py over unknown
    # cotangents into a jaxpr, and express the tangents as fun_lin_p applied
    # to (residuals, tangents) with that jaxpr as a param.
    primals, tangents = split_list(tracers, [len(tracers) // 2])
    primals_out, vjp_py = custom_vjp(*primals, **params)
    if not prim.multiple_results:
      primals_out = [primals_out]
    out_avals = [raise_to_shaped(get_aval(x)) for x in primals_out]
    ct_pvals = [pe.PartialVal.unknown(aval) for aval in out_avals]
    if config.omnistaging_enabled:
      jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals, instantiate=True)
    else:
      with core.initial_style_staging():  # type: ignore
        jaxpr, _, res = pe.trace_to_jaxpr(lu.wrap_init(vjp_py), ct_pvals,
                                          instantiate=True)
    tangents_out = fun_lin_p.bind(*it.chain(res, tangents), trans_jaxpr=jaxpr,
                                  num_res=len(res), out_avals=out_avals)
    return primals_out + tangents_out
  pe.custom_partial_eval_rules[fun_jvp_p] = fun_jvp_partial_eval

  fun_lin_p = core.Primitive('{name}_lin'.format(name=name))
  fun_lin_p.multiple_results = True
  fun_lin_p.def_abstract_eval(lambda *_, **kwargs: kwargs['out_avals'])
  def fun_lin_transpose(cts, *args, **kwargs):
    # Evaluate the traced vjp jaxpr on the residuals and dense cotangents.
    # Residual operands receive None cotangents.
    num_res, trans_jaxpr = kwargs['num_res'], kwargs['trans_jaxpr']
    res, _ = split_list(args, [num_res])
    cts = map(instantiate_zeros_aval, kwargs['out_avals'], cts)
    outs = core.eval_jaxpr(trans_jaxpr, res, *cts)
    return [None] * num_res + outs
  primitive_transposes[fun_lin_p] = fun_lin_transpose
740
def defvjp(prim, *vjps):
  """Deprecated: define per-operand VJP rules for `prim` via `defvjp_all`.

  `vjps[i]` is called as `rule(ct, *primals)`. A falsy entry yields
  `zeros_like_jaxval` of the i-th primal.
  """
  def vjpmaker(*primals):
    ans = prim.bind(*primals)

    def vjpfun(ct):
      return [rule(ct, *primals) if rule else zeros_like_jaxval(x)
              for x, rule in zip(primals, vjps)]

    return ans, vjpfun
  defvjp_all(prim, vjpmaker)
748
def defvjp2(prim, *vjps):
  """Deprecated: like `defvjp`, but each rule also receives the primal output.

  `vjps[i]` is called as `rule(ct, ans, *primals)`. A falsy entry yields
  `zeros_like_jaxval` of the i-th primal.
  """
  def vjpmaker(*primals):
    ans = prim.bind(*primals)

    def vjpfun(ct):
      return [rule(ct, ans, *primals) if rule else zeros_like_jaxval(x)
              for x, rule in zip(primals, vjps)]

    return ans, vjpfun
  defvjp_all(prim, vjpmaker)
756
757
class CustomJVPException(Exception):
  """Raised when a custom_jvp function is differentiated w.r.t. a closed-over value."""

  def __init__(self):
    # TODO(mattjj): track source provenance on AD tracers, improve error
    super().__init__(
        "Detected differentiation of a custom_jvp function with respect to "
        "a closed-over value. That isn't supported because the custom JVP "
        "rule only specifies how to differentiate the custom_jvp function "
        "with respect to explicit input parameters. Try passing the "
        "closed-over value into the custom_jvp function as an argument, and "
        "adapting the custom_jvp rule.")
768
class CustomVJPException(Exception):
  """Raised when a custom_vjp function is differentiated w.r.t. a closed-over value."""

  def __init__(self):
    # TODO(mattjj): track source provenance on AD tracers, improve error
    super().__init__(
        "Detected differentiation of a custom_vjp function with respect to "
        "a closed-over value. That isn't supported because the custom VJP "
        "rule only specifies how to differentiate the custom_vjp function "
        "with respect to explicit input parameters. Try passing the "
        "closed-over value into the custom_vjp function as an argument, and "
        "adapting the custom_vjp fwd and bwd rules.")
779
@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Rebind the module-level `jvp_jaxpr` for when omnistaging is disabled.

  The replacement traces with `pe.trace_to_jaxpr` on all-unknown partial values
  (instantiating outputs) rather than with `pe.trace_to_jaxpr_dynamic`.
  """
  global jvp_jaxpr

  def jvp_jaxpr(jaxpr, nonzeros, instantiate):
    assert len(jaxpr.in_avals) == len(nonzeros)
    fun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
    fun_jvp, out_nonzeros = f_jvp_traceable(jvp(fun, instantiate=instantiate), nonzeros)
    nz_tangent_avals = [a for a, nz in zip(jaxpr.in_avals, nonzeros) if nz]
    in_pvals = [pe.PartialVal.unknown(a)
                for a in (*jaxpr.in_avals, *nz_tangent_avals)]
    jaxpr_out, _, consts = pe.trace_to_jaxpr(fun_jvp, in_pvals, instantiate=True)
    return core.ClosedJaxpr(jaxpr_out, consts), out_nonzeros()
793