# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import warnings
from functools import partial
from typing import Dict, Any, Callable

import jax
from jax import core
from jax import linear_util as lu
from . import ad
from . import partial_eval as pe
from ..core import raise_to_shaped, get_aval, Literal, Jaxpr
from ..api_util import flatten_fun_nokwargs
from ..tree_util import tree_flatten, tree_unflatten, register_pytree_node
from .._src.util import safe_map, safe_zip, split_list
from .. import custom_derivatives
from ..config import config

# Module-wide convention: `map` and `zip` are the `safe_*` variants imported
# above. The code below relies on `map` returning a list (results are
# concatenated with `+`).
map = safe_map
zip = safe_zip

def _initial_style_jaxpr(fun, in_avals):
  """Trace `fun` at `in_avals` and return the result as a `core.ClosedJaxpr`.

  This definition is replaced by `omnistaging_disabler` (bottom of the file)
  when that hook runs.
  """
  jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
  return core.ClosedJaxpr(jaxpr, consts)

################################################################################
# Reverse call primitive
################################################################################

class DontFlatten:
  """Opaque box that is registered as a pytree node with no children.

  The wrapped value is carried as the node's auxiliary data, so tree
  flattening does not look inside it.
  """

  def __init__(self, val):
    self.val = val

register_pytree_node(DontFlatten,
                     lambda x: ((), x.val),
                     lambda val, _: DontFlatten(val))

def get_concrete_array(aval):
  """Return the `.val` of `aval`, asserting it is a `core.ConcreteArray`."""
  assert isinstance(aval, core.ConcreteArray), aval
  return aval.val

def invertible(fun):
  """Wrap `fun` in a `custom_vjp` whose backward rule is `inv_backward_pass`.

  The forward rule traces `fun` to a jaxpr, evaluates that jaxpr, and saves the
  flat inputs, flat outputs, the traced constants and the jaxpr itself as
  residuals. The backward rule hands all of those to `inv_backward_pass`
  together with the incoming cotangents.

  Args:
    fun: the function to wrap.

  Returns:
    The `custom_derivatives.custom_vjp` object built from `fun`.
  """
  # TODO: Avoid materializing zeros!
  ifun = custom_derivatives.custom_vjp(fun)

  def fwd(*args):
    flat_args, in_tree = tree_flatten(args)
    in_pvals = tuple(pe.PartialVal.unknown(raise_to_shaped(get_aval(arg))) for arg in flat_args)
    fun_flat, out_tree = flatten_fun_nokwargs(lu.wrap_init(fun), in_tree)
    jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun_flat, in_pvals)
    # TODO: Don't warn if consts contain JVP tracers?
    if consts:
      warnings.warn("Values that an @invertible function closes over will not have their " +
                    "gradients computed correctly (their uses inside this function will be ignored)!")
    # TODO: This requires the body to be jittable, but this shouldn't be necessary.
    #       Is there a way to trace a jaxpr while running it?
    flat_outs = core.eval_jaxpr(jaxpr, consts, *flat_args)
    # The jaxpr and input tree structure are wrapped in DontFlatten so that
    # they travel through the residuals as a single leafless node.
    return tree_unflatten(out_tree(), flat_outs), (flat_args, flat_outs, consts, DontFlatten((jaxpr, in_tree)))

  def bwd(res, cts):
    flat_args, flat_outs, consts, aux = res
    jaxpr, in_tree = aux.val
    flat_cts, _ = tree_flatten(cts)
    return tree_unflatten(in_tree, inv_backward_pass(jaxpr, consts, flat_args, flat_outs, flat_cts))

  ifun.defvjp(fwd, bwd)

  return ifun

################################################################################
# Custom inverse
################################################################################

class custom_ivjp:
  """Callable wrapper pairing a function with a user-supplied "ivjp" rule.

  The rule is registered with `defivjp` and must be set before the wrapper is
  called. Calling the wrapper binds the `custom_ivjp` primitive with jaxprs
  traced from both the function and the rule.
  """

  def __init__(self, fun):
    self.fun = fun
    self.ivjp = None

  def defivjp(self, ivjp):
    """Register the rule.

    Expected signature:
      ivjp(inputs, outputs, output_cotangents) -> (inputs, input_cotangents)
    """
    self.ivjp = ivjp

  def __call__(self, *args, **kwargs):
    if self.ivjp is None:
      msg = "No IVJP defined for custom_ivjp function {}. Did you forget to use defivjp?"
      # Instances of this class are never given a `__name__` attribute, so the
      # name has to come from the wrapped function; formatting with
      # `self.__name__` would fail before the intended message is produced.
      fun_name = getattr(self.fun, "__name__", str(self.fun))
      raise AttributeError(msg.format(fun_name))
    args = custom_derivatives._resolve_kwargs(self.fun, args, kwargs)
    # TODO: Support nondiff_argnums
    fun, ivjp = lu.wrap_init(self.fun), lu.wrap_init(self.ivjp)
    args_flat, in_tree = tree_flatten(args)
    flat_fun, out_tree = flatten_fun_nokwargs(fun, in_tree)
    flat_ivjp = _flatten_ivjp(ivjp, in_tree, out_tree)
    out_flat = _custom_ivjp(flat_fun, flat_ivjp, args_flat)
    return tree_unflatten(out_tree(), out_flat)

def zip_with(fun, *args):
  """Apply `fun` to the tuples obtained by zipping `args` together."""
  return map(lambda p: fun(*p), zip(*args))

@lu.transformation
def _flatten_ivjp(in_tree, out_tree, *args):
  """Adapt a pytree-level ivjp rule to flat arguments and flat results.

  `args` holds, in order, the input leaves, the output leaves and the output
  cotangent leaves. They are unflattened with `in_tree`, `out_tree` and
  `out_tree` respectively before the wrapped rule is called; the rule must
  return a two-element list or tuple, which is flattened on the way out.

  Raises:
    TypeError: if the wrapped rule does not return a two-element list/tuple.
  """
  out_tree = out_tree()
  num_inputs, num_outputs = in_tree.num_leaves, out_tree.num_leaves
  assert len(args) == num_inputs + 2 * num_outputs
  arg_leaves = split_list(args, [num_inputs, num_outputs])
  py_args = zip_with(tree_unflatten, [in_tree, out_tree, out_tree], arg_leaves)
  pair_out = yield py_args, {}
  if not isinstance(pair_out, (list, tuple)) or len(pair_out) != 2:
    raise TypeError("Expected a two element pair as output of custom ivjp")
  yield tree_flatten(pair_out)[0]

def _custom_ivjp(fun, ivjp, args):
  """Trace `fun` and `ivjp` to closed jaxprs and bind `custom_ivjp_p`.

  `ivjp` is traced at the input avals followed by two copies of the output
  avals of `fun` (one for the outputs, one for their cotangents).

  Raises:
    ValueError: if tracing `ivjp` hits a `RecursionError`.
  """
  in_avals = [raise_to_shaped(get_aval(x)) for x in args]
  fun_jaxpr = _initial_style_jaxpr(fun, in_avals)
  try:
    ivjp_jaxpr = _initial_style_jaxpr(ivjp, in_avals + fun_jaxpr.out_avals * 2)
  except RecursionError:
    raise ValueError("Calls to {} from its custom ivjp aren't supported yet".format(fun.__name__))
  return custom_ivjp_p.bind(*args, fun_jaxpr=fun_jaxpr,
                            ivjp_jaxpr=ivjp_jaxpr)

def _custom_ivjp_impl(*args, fun_jaxpr, **_):
  """Implementation rule: evaluate `fun_jaxpr` and ignore all other params."""
  return core.jaxpr_as_fun(fun_jaxpr)(*args)

custom_ivjp_p = core.Primitive('custom_ivjp')
custom_ivjp_p.multiple_results = True
custom_ivjp_p.def_impl(_custom_ivjp_impl)
custom_ivjp_p.def_abstract_eval(lambda *_, fun_jaxpr, **__: fun_jaxpr.out_avals)

def _custom_ivjp_jvp(primals, tangents, *, fun_jaxpr, ivjp_jaxpr):
  """JVP rule for `custom_ivjp_p`.

  Primal outputs come from re-binding the primitive; tangent outputs come from
  running `ad.jvp` over `fun_jaxpr`.
  """
  primals_out = custom_ivjp_p.bind(*primals, fun_jaxpr=fun_jaxpr,
                                   ivjp_jaxpr=ivjp_jaxpr)
  fun = core.jaxpr_as_fun(fun_jaxpr)
  # FIXME: This might compute the primals multiple times, but we only need to do
  #        this trick while linearizing. It should be possible to do it through
  #        a custom partial eval rule.
  _, tangents_out = ad.jvp(lu.wrap_init(fun)).call_wrapped(primals, tangents)
  return primals_out, tangents_out
ad.primitive_jvps[custom_ivjp_p] = _custom_ivjp_jvp

################################################################################
# Backward pass implementation
################################################################################

def inv_backward_pass(jaxpr: core.Jaxpr, consts, primals_in, primals_out, cotangents_in):
  """Walk `jaxpr` backwards, reconstructing primals while computing cotangents.

  Equations are visited in reverse order. For each one an "ivjp" jaxpr is
  obtained (from the equation's params for `custom_ivjp_p`, from
  `primitive_ivjps`, or synthesized via `synthesize_ivjp`), partially evaluated
  with respect to the values currently known, and the known part is run to
  produce reconstructed inputs and input cotangents.

  Args:
    jaxpr: the jaxpr to process.
    consts: values for `jaxpr.constvars`.
    primals_in: values for `jaxpr.invars`.
    primals_out: values for `jaxpr.outvars`.
    cotangents_in: cotangents for `jaxpr.outvars`.

  Returns:
    A list with one cotangent per entry of `jaxpr.invars`; `ad.Zero` where no
    cotangent was accumulated.
  """
  if all(type(ct) is ad.Zero for ct in cotangents_in):
    return map(lambda v: ad.Zero(v.aval), jaxpr.invars)

  def write_cotangent(v, ct):
    # assert v not in primal_env
    if ct is not None and type(v) is not Literal:
      ct_env[v] = ad.add_tangents(ct_env[v], ct) if v in ct_env else ct

  def read_cotangent(v):
    return ct_env.get(v, ad.Zero(v.aval))

  def read_primal(v):
    if type(v) is Literal:
      return v.val
    else:
      return primal_env.get(v, ad.UndefinedPrimal(v.aval))

  def write_primal(v, val):
    if type(v) is Literal:
      return
    # First write wins: an existing entry is never overwritten.
    primal_env.setdefault(v, val)

  def abstract(value):
    # Shaped abstract value of either a concrete value or an undefined primal.
    return raise_to_shaped(value.aval if ad.is_undefined_primal(value) else get_aval(value))

  # Invert while computing cotangents
  ct_env: Dict[Any, Any] = {}
  primal_env: Dict[Any, Any] = {}
  write_primal(core.unitvar, core.unit)
  map(write_primal, jaxpr.invars, primals_in)
  map(write_primal, jaxpr.outvars, primals_out)
  map(write_primal, jaxpr.constvars, consts)
  map(write_cotangent, jaxpr.outvars, cotangents_in)
  for eqn in jaxpr.eqns[::-1]:
    primals_in = map(read_primal, eqn.invars)
    primals_out = map(read_primal, eqn.outvars)
    cts_in = map(read_cotangent, eqn.outvars)
    should_invert = any(type(primal) is not ad.UndefinedPrimal
                        for primal in primals_out)
    should_vjp = any(type(ct) is not ad.Zero for ct in cts_in)
    assert not eqn.primitive.call_primitive

    # Skip primals equations that are only jvp coefficients and don't affect
    # primal outputs.
    if not should_invert and not should_vjp:
      continue

    # Get the ivjp_jaxpr
    if eqn.primitive is custom_ivjp_p:
      ivjp_jaxpr = eqn.params['ivjp_jaxpr']
    else:
      if eqn.primitive in primitive_ivjps:
        complete_ivjp = lu.wrap_init(primitive_ivjps[eqn.primitive])
      else:
        complete_ivjp = lu.wrap_init(partial(synthesize_ivjp, eqn, map(ad.is_undefined_primal, primals_in)))
      _, in_tree = tree_flatten(
          tuple(map(abstract, x) for x in (primals_in, primals_out, primals_out)))
      complete_ivjp_flat, _ = flatten_fun_nokwargs(complete_ivjp, in_tree)

      in_avals = map(abstract, primals_in + primals_out + primals_out)
      if config.omnistaging_enabled:
        # TODO: Actually we do know some of the inputs, because they might be literals!
        ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(
            complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
      else:
        ivjp_jaxpr, out_pvals, _ = pe.trace_to_jaxpr(  # type: ignore
            complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals),
            instantiate=True, stage_out=False)  # type: ignore
      assert not ivjp_jaxpr.constvars  # That might happen some time, but don't bother until then
      ivjp_jaxpr = core.ClosedJaxpr(ivjp_jaxpr, [])

    # Once we know what the ivjp can do exactly, we have to isolate the part we are
    # actually able to compute with the values we have at hand.
    num_inputs = len(eqn.invars)
    unknowns = (map(ad.is_undefined_primal, primals_in) +
                map(ad.is_undefined_primal, primals_out) +
                [False] * len(cts_in))
    if config.omnistaging_enabled:
      jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(  # type: ignore
          ivjp_jaxpr, unknowns, instantiate=False)  # type:ignore
    else:
      jaxpr_known, jaxpr_unknown, out_unknowns = pe.partial_eval_jaxpr(  # type: ignore
          ivjp_jaxpr, unknowns, instantiate=False, trace_type=None)  # type: ignore
    unknown_rec_primals_in, unknown_cotangents = split_list(out_unknowns, [num_inputs])
    # Make sure we're able to compute all cotangents. We don't really care if we
    # can reconstruct primals or not, although failure to do so might result in
    # failing to compute cotangents later.
    assert not any(unknown_cotangents)
    # Remove residual outputs -- we won't be computing the unknown jaxpr anyway.
    num_outputs = len(jaxpr_unknown.jaxpr.outvars)
    jaxpr_known.jaxpr.outvars = jaxpr_known.jaxpr.outvars[:num_outputs]
    # TODO: We could drop the outputs that correspond to primals that we already know.
    #       This only matters in eager mode, so leaving it out for now...
    ivjp = core.jaxpr_as_fun(jaxpr_known)
    rec_primals_in, cts_out = split_list(ivjp(*primals_in, *primals_out, *cts_in),
                                         [num_inputs])
    # Unknown rec_primals_in are core.units, so we have to replace them with
    # the previously read value (an UndefinedPrimal when nothing was known).
    # NOTE(review): write_primal uses setdefault, so such a placeholder is
    # stored if the variable has no entry yet -- confirm this is intended.
    rec_primals_in = [prev if unknown else rec
                      for prev, rec, unknown
                      in zip(primals_in, rec_primals_in, unknown_rec_primals_in)]
    map(write_primal, eqn.invars, rec_primals_in)
    map(write_cotangent, [v for v in eqn.invars if type(v) is not Literal], cts_out)

  # NOTE: We keep the cotangents associated with primal variables, while the contract of a
  #       transpose is to return them in positions associated with tangent variables, which
  #       is what causes this whole confusion.
  return map(read_cotangent, jaxpr.invars)

# Registry of hand-written ivjp rules, keyed by primitive. Consulted by
# `inv_backward_pass` before falling back to `synthesize_ivjp`.
primitive_ivjps: Dict[core.Primitive, Callable] = {}

def synthesize_ivjp(eqn, unknown_primals, primals_in, primals_out, cts_in):
  """Build an ivjp for `eqn` from its registered inverse and `jax.vjp`.

  Args:
    eqn: the equation to invert and differentiate.
    unknown_primals: one flag per input; where true, the reconstructed input
      is used in place of the supplied one when computing the VJP.
    primals_in: input values of `eqn`.
    primals_out: output values of `eqn`.
    cts_in: cotangents of the outputs of `eqn`.

  Returns:
    A pair `(rec_primals_in, cts_out)`: the reconstructed inputs and the
    cotangents for the non-literal inputs.

  Raises:
    NotImplementedError: if no inverse is registered for `eqn.primitive`.
  """
  # Invert eqn
  if not eqn.primitive.multiple_results:
    primals_out, = primals_out
  rec_primals_in = get_primitive_inverse(eqn.primitive)(primals_out, *primals_in)
  if len(eqn.invars) == 1:
    rec_primals_in = (rec_primals_in,)

  # Use the reconstructed primals if some primals_in were unknown, because we
  # might have reconstructed some of them.
  primals_in = map(lambda p, rp, unknown: rp if unknown else p,
                   primals_in, rec_primals_in, unknown_primals)

  # Compute the VJP of eqn
  variable_invars = [v for v in eqn.invars if type(v) is not Literal]
  variable_primals_in = [p for p, v in zip(primals_in, eqn.invars) if type(v) is not Literal]
  eqn_jaxpr = Jaxpr([], variable_invars, eqn.outvars, [eqn])
  eqn_callable = lambda args: core.eval_jaxpr(eqn_jaxpr, (), *args)
  _, eqn_vjp = jax.vjp(eqn_callable, variable_primals_in)
  # TODO: Instantiate zeros or (better) figure out how to avoid it!
  cts_out, = eqn_vjp(cts_in)

  return rec_primals_in, cts_out

def split(l, parts):
  """Split sequence `l` into `parts` consecutive chunks of equal length.

  `len(l)` must be divisible by `parts`. An empty `l` yields `parts` empty
  chunks.
  """
  assert len(l) % parts == 0
  chunk = len(l) // parts
  if chunk == 0:
    # Empty input: a zero step would make range() raise ValueError.
    return [l[:0] for _ in range(parts)]
  return [l[i:i + chunk] for i in range(0, len(l), chunk)]

################################################################################
# Primitive inverses
################################################################################

# Registry of inverse rules, keyed by primitive. `synthesize_ivjp` calls a rule
# as `rule(primals_out, *primals_in)`.
primitive_inverses: Dict[core.Primitive, Callable] = {}

def get_primitive_inverse(p):
  """Return the inverse rule registered for primitive `p`.

  Raises:
    NotImplementedError: if `p` has no entry in `primitive_inverses`.
  """
  try:
    return primitive_inverses[p]
  except KeyError:
    pass
  raise NotImplementedError(
    "Inverse rule for '{}' not implemented".format(p))


def definverse(primitive, inverse_rule):
  """Register `inverse_rule` as the inverse of `primitive` and return the rule."""
  primitive_inverses[primitive] = inverse_rule
  return inverse_rule


@config.register_omnistaging_disabler
def omnistaging_disabler() -> None:
  """Rebind the module-level `_initial_style_jaxpr` to a `pe.trace_to_jaxpr` version.

  Registered through `config.register_omnistaging_disabler`.
  """
  # NOTE(review): `custom_jvp_call` is declared global but never assigned in
  # this function -- presumably a leftover; confirm before removing.
  global _initial_style_jaxpr, custom_jvp_call

  def _initial_style_jaxpr(fun, in_avals):
    in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
    jaxpr, _, consts = pe.trace_to_jaxpr(fun, in_pvals, instantiate=True,
                                         bottom=True, stage_out=False)  # type: ignore
    assert not any(isinstance(c, core.Tracer) for c in consts)
    return core.ClosedJaxpr(jaxpr, consts)