# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from functools import partial
from typing import Dict, Any, Callable

import jax
from jax import core
from jax import linear_util as lu
from . import ad
from . import partial_eval as pe
from ..core import raise_to_shaped, get_aval, Literal, Jaxpr
from ..api_util import flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten, register_pytree_node
from .._src.util import safe_map, safe_zip, split_list
from .. import custom_derivatives
from ..config import config

map = safe_map
zip = safe_zip

def _initial_style_jaxpr(fun, in_avals):
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
  return core.ClosedJaxpr(jaxpr, consts)

################################################################################
# Reverse call primitive
################################################################################

class DontFlatten:
  def __init__(self, val):
    self.val = val

register_pytree_node(DontFlatten,
                     lambda x: ((), x.val),
                     lambda val, _: DontFlatten(val))

def get_concrete_array(aval):
  assert isinstance(aval, core.ConcreteArray), aval
  return aval.val

def invertible(fun):
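  """Asserts that the decorated function is invertible.

  Applying reverse-mode AD to a decorated function avoids storing intermediate
  values: only the inputs and outputs are saved as residuals, and the backward
  pass reconstructs the intermediates by inverting the function's operations
  (via registered primitive inverses or custom IVJPs) while accumulating
  cotangents.
  """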
  # TODO: Avoid materializing zeros!
  ifun = custom_derivatives.custom_vjp(fun)

  def fwd(*args):
    flat_args, in_tree = tree_flatten(args)
    in_pvals = tuple(pe.PartialVal.unknown(raise_to_shaped(get_aval(arg))) for arg in flat_args)
    fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun_flat, in_pvals)
    # TODO: Don't warn if consts contain JVP tracers?
    if consts:
      warnings.warn("Values that an @invertible function closes over will not have their " +
                    "gradients computed correctly (their uses inside this function will be ignored)!")
    # TODO: This requires the body to be jittable, but this shouldn't be necessary.
    #       Is there a way to trace a jaxpr while running it?
    flat_outs = core.eval_jaxpr(jaxpr, consts, *flat_args)
    return tree_unflatten(out_tree(), flat_outs), (flat_args, flat_outs, consts, DontFlatten((jaxpr, in_tree)))

  def bwd(res, cts):
    flat_args, flat_outs, consts, aux = res
    jaxpr, in_tree = aux.val
    flat_cts, _ = tree_flatten(cts)
    return tree_unflatten(in_tree, inv_backward_pass(jaxpr, consts, flat_args, flat_outs, flat_cts))

  ifun.defvjp(fwd, bwd)

  return ifun
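# Illustrative usage sketch (assuming inverse/IVJP rules are registered for the
# primitives involved; `jnp` stands for jax.numpy and `xs` is a hypothetical input):
#
#   @invertible
#   def f(x):
#     return jnp.exp(x)
#
#   jax.grad(lambda x: f(x).sum())(xs)  # intermediates are recomputed by inversion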

################################################################################
# Custom inverse
################################################################################

class custom_ivjp:
  def __init__(self, fun):
    self.fun = fun
    self.ivjp = None

  def defivjp(self, ivjp):
    # ivjp(inputs, outputs, output_cotangents) -> (inputs, input_cotangents)
    self.ivjp = ivjp

  def __call__(self, *args, **kwargs):
    if self.ivjp is None:
      msg = "No IVJP defined for custom_ivjp function {}. Did you forget to use defivjp?"
      raise AttributeError(msg.format(self.fun.__name__))
    args = custom_derivatives._resolve_kwargs(self.fun, args, kwargs)
    # TODO: Support nondiff_argnums
    fun, ivjp = lu.wrap_init(self.fun), lu.wrap_init(self.ivjp)
    args_flat, in_tree = tree_flatten(args)
    flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
    flat_ivjp = _flatten_ivjp(ivjp, in_tree, out_tree)
    out_flat = _custom_ivjp(flat_fun, flat_ivjp, args_flat)
    return tree_unflatten(out_tree(), out_flat)
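# Illustrative sketch of defining a custom IVJP rule (hypothetical function,
# mirroring the defivjp signature documented above):
#
#   @custom_ivjp
#   def f(x):
#     return jnp.exp(x)
#
#   def f_ivjp(inputs, outputs, output_cotangents):
#     x, = inputs
#     y, y_bar = outputs, output_cotangents
#     # Reconstruct the input from the output, and map the output cotangent
#     # back to an input cotangent (d exp(x)/dx = exp(x) = y).
#     return (jnp.log(y),), (y * y_bar,)
#
#   f.defivjp(f_ivjp)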

def zip_with(fun, *args):
  return map(lambda p: fun(*p), zip(*args))

@lu.transformation
def _flatten_ivjp(in_tree, out_tree, *args):
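  """Adapts a pytree-based ivjp to a function over flat lists of leaves.

  The flat arguments are laid out as inputs, then outputs, then output
  cotangents; the (reconstructed inputs, input cotangents) pair returned by the
  wrapped ivjp is flattened back into a single list.
  """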
  out_tree = out_tree()
  num_inputs, num_outputs = in_tree.num_leaves, out_tree.num_leaves
  assert len(args) == num_inputs + 2 * num_outputs
  arg_leaves = split_list(args, [num_inputs, num_outputs])
  py_args = zip_with(tree_unflatten, [in_tree, out_tree, out_tree], arg_leaves)
  pair_out = yield py_args, {}
  if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
    raise TypeError("Expected a two element pair as output of custom ivjp")
  yield tree_flatten(pair_out)[0]

def _custom_ivjp(fun, ivjp, args):
  in_avals = [raise_to_shaped(get_aval(x)) for x in args]
  fun_jaxpr = _initial_style_jaxpr(fun, in_avals)
  try:
    ivjp_jaxpr = _initial_style_jaxpr(ivjp, in_avals + fun_jaxpr.out_avals * 2)
  except RecursionError:
    raise ValueError("Calls to {} from its custom ivjp aren't supported yet".format(fun.__name__))
  return custom_ivjp_p.bind(*args, fun_jaxpr=fun_jaxpr,
                                   ivjp_jaxpr=ivjp_jaxpr)

def _custom_ivjp_impl(*args, fun_jaxpr, **_):
  return core.jaxpr_as_fun(fun_jaxpr)(*args)

custom_ivjp_p = core.Primitive('custom_ivjp')
custom_ivjp_p.multiple_results = True
custom_ivjp_p.def_impl(_custom_ivjp_impl)
custom_ivjp_p.def_abstract_eval(lambda *_, fun_jaxpr, **__: fun_jaxpr.out_avals)

def _custom_ivjp_jvp(primals, tangents, *, fun_jaxpr, ivjp_jaxpr):
  primals_out = custom_ivjp_p.bind(*primals, fun_jaxpr=fun_jaxpr,
                                             ivjp_jaxpr=ivjp_jaxpr)
  fun = core.jaxpr_as_fun(fun_jaxpr)
  # FIXME: This might compute the primals multiple times, but we only need to do
  #        this trick while linearizing. It should be possible to do it through
  #        a custom partial eval rule.
  _, tangents_out = ad.jvp(lu.wrap_init(fun)).call_wrapped(primals, tangents)
  return primals_out, tangents_out
ad.primitive_jvps[custom_ivjp_p] = _custom_ivjp_jvp

################################################################################
# Backward pass implementation
################################################################################

def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotangents_in):
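  """Backward pass used by @invertible functions.

  Walks the jaxpr's equations in reverse, reconstructing unknown primal values
  from the known outputs (via custom IVJPs, registered primitive IVJPs, or
  synthesized inverse+VJP rules) while accumulating cotangents, and returns the
  cotangents corresponding to jaxpr.invars.
  """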
  if all(type(ct) is ad.Zero for ct in cotangents_in):
    return map(lambda v: ad.Zero(v.aval), jaxpr.invars)

  def write_cotangent(v, ct):
    # assert v not in primal_env
    if ct is not None and type(v) is not Literal:
      ct_env[v] = ad.add_tangents(ct_env[v], ct) if v in ct_env else ct

  def read_cotangent(v):
    return ct_env.get(v, ad.Zero(v.aval))

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, ad.UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if type(v) is Literal:
      return
    primal_env.setdefault(v, val)

  # Invert while computing cotangents
  ct_env: Dict[Any, Any] = {}
  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.invars, primals_in)
  map(write_primal, jaxpr.outvars, primals_out)
  map(write_primal, jaxpr.constvars, consts)
  map(write_cotangent, jaxpr.outvars, cotangents_in)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.invars)
    primals_out = map(read_primal, eqn.outvars)
    cts_in = map(read_cotangent, eqn.outvars)
    should_invert = any(type(primal) is not ad.UndefinedPrimal
                        for primal in primals_out)
    should_vjp = any(type(ct) is not ad.Zero for ct in cts_in)
    assert not eqn.primitive.call_primitive

    # Skip primal equations that only contribute JVP coefficients and don't
    # affect the primal outputs.
    if not should_invert and not should_vjp:
      continue

    def abstract(value):
      return raise_to_shaped(value.aval if ad.is_undefined_primal(value) else get_aval(value))

    # Get the ivjp_jaxpr
    if eqn.primitive is custom_ivjp_p:
      ivjp_jaxpr = eqn.params['ivjp_jaxpr']
    else:
      if eqn.primitive in primitive_ivjps:
        complete_ivjp = lu.wrap_init(primitive_ivjps[eqn.primitive])
      else:
        complete_ivjp = lu.wrap_init(partial(synthesize_ivjp, eqn, map(ad.is_undefined_primal, primals_in)))
      _, in_tree = tree_flatten(
          tuple(map(abstract, x) for x in (primals_in, primals_out, primals_out)))
      complete_ivjp_flat, _ = flatten_fun_nokwargs(complete_ivjp, in_tree)

      in_avals = map(abstract, primals_in + primals_out + primals_out)
      if config.omnistaging_enabled:
        # TODO: Actually we do know some of the inputs, because they might be literals!
        ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
            complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
      else:
        ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(  # type: ignore
          complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals),
          instantiate=True, stage_out=False)  # type: ignore
      assert not ivjp_jaxpr.constvars  # That might happen some time, but don't bother until then
      ivjp_jaxpr = core.ClosedJaxpr(ivjp_jaxpr, [])

    # Once we know what the ivjp can do exactly, we have to isolate the part we are
    # actually able to compute with the values we have at hand.
    num_inputs = len(eqn.invars)
    unknowns = (map(ad.is_undefined_primal, primals_in) +
                map(ad.is_undefined_primal, primals_out) +
                [False] * len(cts_in))
    if config.omnistaging_enabled:
      jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(  # type: ignore
          ivjp_jaxpr, unknowns, instantiate=False)  # type: ignore
    else:
      jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(  # type: ignore
          ivjp_jaxpr, unknowns, instantiate=False, trace_type=None)  # type: ignore
    unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs])
    # Make sure we're able to compute all cotangents. We don't really care whether
    # we can reconstruct the primals or not, although failure to do so might result
    # in failing to compute cotangents later.
    assert not any(unknown_cotangents)
    # Remove residual outputs -- we won't be computing the unknown jaxpr anyway.
    num_outputs = len(jaxpr_unknown.jaxpr.outvars)
    jaxpr_known.jaxpr.outvars = jaxpr_known.jaxpr.outvars[:num_outputs]
    # TODO: We could drop the outputs that correspond to primals that we already know.
    #       This only matters in eager mode, so leaving it out for now...
    ivjp = core.jaxpr_as_fun(jaxpr_known)
    rec_primals_in, cts_out = split_list(ivjp(*primals_in, *primals_out, *cts_in),
                                         [num_inputs])
    # Unknown rec_primals_in are core.units, so we have to replace them with the
    # previous values (possibly UndefinedPrimals) instead of writing units into
    # the primal environment.
    rec_primals_in = [prev if unknown else rec
                      for prev, rec, unknown
                      in zip(primals_in, rec_primals_in, unknown_rec_primals_in)]
    map(write_primal, eqn.invars, rec_primals_in)
    map(write_cotangent, [v for v in eqn.invars if type(v) is not Literal], cts_out)

  # NOTE: We keep the cotangents associated with primal variables, while the contract of a
  #       transpose is to return them in positions associated with tangent variables, which
  #       is what causes this whole confusion.
  return map(read_cotangent, jaxpr.invars)

primitive_ivjps: Dict[core.Primitive, Callable] = {}

def synthesize_ivjp(eqn, unknown_primals, primals_in, primals_out, cts_in):
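  """Builds an IVJP for eqn from a registered primitive inverse and a VJP.

  The inverse reconstructs the equation's inputs from its outputs; the VJP then
  maps the output cotangents to input cotangents using the (possibly
  reconstructed) primal values.
  """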
  # Invert eqn
  if not eqn.primitive.multiple_results:
    primals_out, = primals_out
  rec_primals_in = get_primitive_inverse(eqn.primitive)(primals_out, *primals_in)
  if len(eqn.invars) == 1:
    rec_primals_in = (rec_primals_in,)

  # Use the reconstructed primals if some primals_in were unknown, because we
  # might have reconstructed some of them.
  primals_in = map(lambda p, rp, unknown: rp if unknown else p,
                   primals_in, rec_primals_in, unknown_primals)

  # Compute the VJP of eqn
  variable_invars = [v for v in eqn.invars if type(v) is not Literal]
  variable_primals_in = [p for p, v in zip(primals_in, eqn.invars) if type(v) is not Literal]
  eqn_jaxpr = Jaxpr([], variable_invars, eqn.outvars, [eqn])
  eqn_callable = lambda args: core.eval_jaxpr(eqn_jaxpr, (), *args)
  _, eqn_vjp = jax.vjp(eqn_callable, variable_primals_in)
  # TODO: Instantiate zeros or (better) figure out how to avoid it!
  cts_out, = eqn_vjp(cts_in)

  return rec_primals_in, cts_out

def split(l, parts):
  assert len(l) % parts == 0
  chunk = len(l) // parts
  return [l[i:i + chunk] for i in range(0, len(l), chunk)]

################################################################################
# Primitive inverses
################################################################################

primitive_inverses: Dict[core.Primitive, Callable] = {}

def get_primitive_inverse(p):
  try:
    return primitive_inverses[p]
  except KeyError:
    pass
  raise NotImplementedError(
    "Inverse rule for '{}' not implemented".format(p))


def definverse(primitive, inverse_rule):
  primitive_inverses[primitive] = inverse_rule
  return inverse_rule
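# Illustrative sketch of registering an inverse rule (shown only as an example;
# actual registrations live alongside the primitives). A rule receives the
# primitive's outputs followed by its possibly-undefined inputs and returns the
# reconstructed inputs:
#
#   definverse(lax.exp_p, lambda y, x: lax.log(y))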


@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  global _initial_style_jaxpr, custom_jvp_call

  def _initial_style_jaxpr(fun, in_avals):
    in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    jaxpr, _, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
                                         bottom=True, stage_out=False)  # type: ignore
    assert not any(isinstance(c, core.Tracer) for c in consts)
    return core.ClosedJaxpr(jaxpr, consts)