# Copyright 2020 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import math
from typing import Callable

from functools import partial

import numpy as np

import jax
import jax.numpy as jnp
from jax import core
from jax._src.util import unzip2
from jax import ad_util
from jax.tree_util import (register_pytree_node, tree_structure,
                           treedef_is_leaf, tree_flatten, tree_unflatten)
import jax.linear_util as lu
from jax.interpreters import xla
from jax.custom_derivatives import custom_jvp_call_jaxpr_p
from jax._src.lax import lax
from jax._src.lax import control_flow as lax_control_flow
from jax._src.lax import fft as lax_fft


def jet(fun, primals, series):
  """Propagate truncated derivative series through `fun` (Taylor mode).

  Args:
    fun: function called as `fun(*primals)`; may return any pytree.
    primals: sequence of primal input values, one per positional argument of
      `fun`. Each must be an array (a pytree leaf), not a container.
    series: sequence with one entry per primal. Entry `i` is a sequence of
      terms for `primals[i]`; the per-primitive recurrences below treat
      `series[i][k]` as the (k+1)-th derivative of the input path. All
      entries must have the same length, which is the truncation order.

  Returns:
    `(primals_out, series_out)`, both with the pytree structure of `fun`'s
    output; each leaf of `series_out` is a list of `order` terms.

  Raises:
    ValueError: if `primals` and `series` differ in length, if the series
      have inconsistent lengths, or if any primal or term is not an array.
  """
  # Without this check `zip` would silently drop the unmatched arguments.
  if len(primals) != len(series):
    msg = ("jet expects one series per primal, got {} primals and {} series"
           .format(len(primals), len(series)))
    raise ValueError(msg)
  try:
    order, = set(map(len, series))
  except ValueError:
    msg = "jet terms have inconsistent lengths for different arguments"
    raise ValueError(msg) from None

  # TODO(mattjj): consider supporting pytree inputs
  for i, (x, terms) in enumerate(zip(primals, series)):
    treedef = tree_structure(x)
    if not treedef_is_leaf(treedef):
      raise ValueError("primal value at position {} is not an array".format(i))
    for j, t in enumerate(terms):
      treedef = tree_structure(t)
      if not treedef_is_leaf(treedef):
        raise ValueError("term {} for argument {} is not an array".format(j, i))

  @lu.transformation_with_aux
  def flatten_fun_output(*args):
    ans = yield args, {}
    yield tree_flatten(ans)

  f, out_tree = flatten_fun_output(lu.wrap_init(fun))
  out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
  return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms)


@lu.transformation
def jet_fun(order, primals, series):
  """Run the wrapped function under a fresh `JetTrace` main of `order`.

  Outputs whose series is the `zero_series` sentinel are materialized as
  `order` explicit zero terms.
  """
  with core.new_main(JetTrace) as main:
    main.order = order
    out_primals, out_terms = yield (main, primals, series), {}
    del main
  # Build zeros from shape/dtype rather than with np.zeros_like: an output
  # primal may be a tracer of an enclosing transformation (e.g. jit), which
  # NumPy cannot convert to an array.
  out_terms = [[_zeros_like_primal(p)] * order if s is zero_series else s
               for p, s in zip(out_primals, out_terms)]
  yield out_primals, out_terms


@lu.transformation
def jet_subtrace(main, primals, series):
  """Wrap inputs as `JetTracer`s at the current sublevel and unwrap outputs."""
  trace = JetTrace(main, core.cur_sublevel())
  in_tracers = map(partial(JetTracer, trace), primals, series)
  ans = yield in_tracers, {}
  out_tracers = map(trace.full_raise, ans)
  out_primals, out_terms = unzip2((t.primal, t.terms) for t in out_tracers)
  yield out_primals, out_terms


@lu.transformation_with_aux
def traceable(in_tree_def, *primals_and_series):
  """Adapt a (primals, series) function to flat positional inputs/outputs."""
  primals_in, series_in = tree_unflatten(in_tree_def, primals_and_series)
  primals_out, series_out = yield (primals_in, series_in), {}
  out_flat, out_tree_def = tree_flatten((primals_out, series_out))
  yield out_flat, out_tree_def


class JetTracer(core.Tracer):
  """Tracer holding a primal value and its series terms.

  `terms` is a list/tuple of terms or the `zero_series` sentinel.
  """
  __slots__ = ["primal", "terms"]

  def __init__(self, trace, primal, terms):
    assert type(terms) in (ZeroSeries, list, tuple)
    self._trace = trace
    self.primal = primal
    self.terms = terms

  @property
  def aval(self):
    return core.get_aval(self.primal)

  def full_lower(self):
    # A tracer whose terms are all known-zero carries no extra information.
    if self.terms is zero_series or all(t is zero_term for t in self.terms):
      return core.full_lower(self.primal)
    else:
      return self


class JetTrace(core.Trace):
  """Trace that evaluates primitives via the rules in `jet_rules`."""

  def pure(self, val):
    return JetTracer(self, val, zero_series)

  def lift(self, val):
    return JetTracer(self, val, zero_series)

  def sublift(self, val):
    return JetTracer(self, val.primal, val.terms)

  def process_primitive(self, primitive, tracers, params):
    order = self.main.order  # pytype: disable=attribute-error
    primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
    series_in = [[zero_term] * order if s is zero_series else s
                 for s in series_in]
    # TODO(mattjj): avoid always instantiating zeros
    series_in = [[_zeros_like_primal(x) if t is zero_term else t for t in series]
                 for x, series in zip(primals_in, series_in)]
    rule = jet_rules[primitive]
    primal_out, terms_out = rule(primals_in, series_in, **params)
    if not primitive.multiple_results:
      return JetTracer(self, primal_out, terms_out)
    else:
      return [JetTracer(self, p, ts) for p, ts in zip(primal_out, terms_out)]

  def process_call(self, call_primitive, f, tracers, params):
    primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
    primals_and_series, in_tree_def = tree_flatten((primals_in, series_in))
    f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def)
    # Some call primitives carry per-input params that must be resized to the
    # flattened (primals + series) input list.
    update_params = call_param_updaters.get(call_primitive)
    new_params = (update_params(params, len(primals_and_series))
                  if update_params else params)
    result = call_primitive.bind(f_jet, *primals_and_series, **new_params)
    primals_out, series_out = tree_unflatten(out_tree_def(), result)
    return [JetTracer(self, p, ts) for p, ts in zip(primals_out, series_out)]

  def post_process_call(self, call_primitive, out_tracers, params):
    primals, series = unzip2((t.primal, t.terms) for t in out_tracers)
    out, treedef = tree_flatten((primals, series))
    del primals, series
    main = self.main
    def todo(x):
      primals, series = tree_unflatten(treedef, x)
      trace = JetTrace(main, core.cur_sublevel())
      return map(partial(JetTracer, trace), primals, series)
    return out, todo

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    # TODO(mattjj): don't just ignore custom jvp rules?
    del primitive, jvp  # Unused.
    return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    del primitive, fwd, bwd, out_trees  # Unused.
    return fun.call_wrapped(*tracers)


class ZeroTerm(object): pass
zero_term = ZeroTerm()
register_pytree_node(ZeroTerm, lambda z: ((), None), lambda _, xs: zero_term)

class ZeroSeries(object): pass
zero_series = ZeroSeries()
register_pytree_node(ZeroSeries, lambda z: ((), None), lambda _, xs: zero_series)


def _zeros_like_primal(x):
  """Return a NumPy zero array with the shape and dtype of `x`.

  Uses `np.shape`/`np.result_type` so that `x` may also be a tracer.
  """
  return np.zeros(np.shape(x), dtype=np.result_type(x))


call_param_updaters = {}

def _xla_call_param_updater(params, num_inputs):
  donated_invars = params['donated_invars']
  if any(donated_invars):
    raise NotImplementedError("donated_invars not supported with jet")
  return dict(params, donated_invars=(False,) * num_inputs)
call_param_updaters[xla.xla_call_p] = _xla_call_param_updater


### rule definitions

jet_rules = {}

def defzero(prim):
  """Register `prim` as having identically-zero series output."""
  jet_rules[prim] = partial(zero_prop, prim)

def zero_prop(prim, primals_in, series_in, **params):
  primal_out = prim.bind(*primals_in, **params)
  return primal_out, zero_series

defzero(lax.le_p)
defzero(lax.lt_p)
defzero(lax.gt_p)
defzero(lax.ge_p)
defzero(lax.eq_p)
defzero(lax.ne_p)
defzero(lax.not_p)
defzero(lax.and_p)
defzero(lax.or_p)
defzero(lax.xor_p)
defzero(lax.floor_p)
defzero(lax.ceil_p)
defzero(lax.round_p)
defzero(lax.sign_p)
defzero(ad_util.stop_gradient_p)
defzero(lax.is_finite_p)
defzero(lax.shift_left_p)
defzero(lax.shift_right_arithmetic_p)
defzero(lax.shift_right_logical_p)
defzero(lax.bitcast_convert_type_p)


def deflinear(prim):
  """Register `prim` as linear: apply it to the primal and to each term."""
  jet_rules[prim] = partial(linear_prop, prim)

def linear_prop(prim, primals_in, series_in, **params):
  primal_out = prim.bind(*primals_in, **params)
  series_out = [prim.bind(*terms_in, **params) for terms_in in zip(*series_in)]
  return primal_out, series_out

deflinear(lax.neg_p)
deflinear(lax.real_p)
deflinear(lax.complex_p)
deflinear(lax.conj_p)
deflinear(lax.imag_p)
deflinear(lax.add_p)
deflinear(ad_util.add_jaxvals_p)
deflinear(lax.sub_p)
deflinear(lax.convert_element_type_p)
deflinear(lax.broadcast_in_dim_p)
deflinear(lax.concatenate_p)
deflinear(lax.pad_p)
deflinear(lax.reshape_p)
deflinear(lax.rev_p)
deflinear(lax.transpose_p)
deflinear(lax.slice_p)
deflinear(lax.reduce_sum_p)
deflinear(lax.reduce_window_sum_p)
deflinear(lax_fft.fft_p)
deflinear(xla.device_put_p)

def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool,
                         combine_fn: Callable):
  # Irrespective of backend, we always use the parallel prefix scan
  # implementation when differentiating because reduce_window is not
  # arbitrarily differentiable.
  return jet(partial(lax_control_flow.associative_scan, combine_fn, axis=axis,
                     reverse=reverse),
             primals_in, series_in)

deflinear(lax_control_flow.cumsum_p)
jet_rules[lax_control_flow.cumprod_p] = partial(_cumulative_jet_rule,
                                                combine_fn=lax.mul)
jet_rules[lax_control_flow.cummax_p] = partial(_cumulative_jet_rule,
                                               combine_fn=lax.max)
jet_rules[lax_control_flow.cummin_p] = partial(_cumulative_jet_rule,
                                               combine_fn=lax.min)


def def_deriv(prim, deriv):
  """
  Define the jet rule for a primitive in terms of its first derivative.
  """
  jet_rules[prim] = partial(deriv_prop, prim, deriv)


def deriv_prop(prim, deriv, primals_in, series_in):
  """Generic unary rule: push the series of `deriv(x)` through v' = c * u'."""
  x, = primals_in
  series, = series_in
  primal_out = prim.bind(x)
  c0, cs = jet(deriv, primals_in, series_in)
  c = [c0] + cs
  u = [x] + series
  v = [primal_out] + [None] * len(series)
  for k in range(1, len(v)):
    v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))
  primal_out, *series_out = v
  return primal_out, series_out


def_deriv(lax.erf_p, lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)), lax.exp(lax.neg(lax.square(x)))))


def def_comp(prim, comp):
  """
  Define the jet rule for a primitive in terms of a composition of simpler primitives.
  """
  jet_rules[prim] = partial(jet, comp)


def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))


def _rem_taylor_rule(primals_in, series_in):
  """Jet rule for `lax.rem`, whose result takes the sign of the dividend.

  rem(x, y) == x - y * trunc(x / y). The quotient trunc(x / y) is piecewise
  constant, so every series term is x_k - trunc(x / y) * y_k. (A floor-based
  quotient would compute a floor-mod instead and disagree with `lax.rem`
  whenever x and y have opposite signs.)
  """
  x, y = primals_in
  x_terms, y_terms = series_in
  primal_out = lax.rem(x, y)
  quotient = lax.div(x, y)
  # trunc(q) == sign(q) * floor(|q|)
  trunc_quotient = lax.mul(lax.sign(quotient), lax.floor(lax.abs(quotient)))
  series_out = [lax.sub(xt, lax.mul(trunc_quotient, yt))
                for xt, yt in zip(x_terms, y_terms)]
  return primal_out, series_out
jet_rules[lax.rem_p] = _rem_taylor_rule


def _erf_inv_rule(primals_in, series_in):
  x, = primals_in
  series, = series_in

  u = [x] + series
  primal_out = lax.erf_inv(x)
  if not series:
    # Order-0 jet: nothing to propagate. The recurrence below needs at least
    # one term (it would otherwise evaluate fact(-1)).
    return primal_out, []
  v = [primal_out] + [None] * len(series)

  # derivative on co-domain for caching purposes
  deriv_const = np.sqrt(np.pi) / 2.
  deriv_y = lambda y: lax.mul(deriv_const, lax.exp(lax.square(y)))

  # manually propagate through deriv_y since we don't have lazy evaluation of sensitivities

  c = [deriv_y(primal_out)] + [None] * (len(series) - 1)
  tmp_sq = [lax.square(v[0])] + [None] * (len(series) - 1)
  tmp_exp = [lax.exp(tmp_sq[0])] + [None] * (len(series) - 1)
  for k in range(1, len(series)):
    # we know c[:k], we compute c[k]

    # propagate c to get v
    v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))

    # propagate v to get next c

    # square
    tmp_sq[k] = fact(k) * sum(_scale2(k, j) * v[k-j] * v[j] for j in range(k + 1))

    # exp
    tmp_exp[k] = fact(k-1) * sum(_scale(k, j) * tmp_exp[k-j] * tmp_sq[j] for j in range(1, k + 1))

    # const
    c[k] = deriv_const * tmp_exp[k]

  # we can't, and don't need, to compute c[k+1], just need to get the last v[k]
  k = len(series)
  v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))

  primal_out, *series_out = v
  return primal_out, series_out
jet_rules[lax.erf_inv_p] = _erf_inv_rule

### More complicated rules

def fact(n):
  """Return n! exactly as a Python int (`n` must be a non-negative int).

  Exact integer arithmetic avoids the rounding of a floating-point
  exp(lgamma(n + 1)) evaluation in the recurrence coefficients below.
  """
  return math.factorial(n)

def _scale(k, j):
  return 1. / (fact(k - j) * fact(j - 1))

def _scale2(k, j):
  return 1. / (fact(k - j) * fact(j))

def _exp_taylor(primals_in, series_in):
  x, = primals_in
  series, = series_in
  u = [x] + series
  v = [lax.exp(x)] + [None] * len(series)
  for k in range(1, len(v)):
    v[k] = fact(k-1) * sum(_scale(k, j) * v[k-j] * u[j] for j in range(1, k+1))
  primal_out, *series_out = v
  return primal_out, series_out
jet_rules[lax.exp_p] = _exp_taylor

def _pow_taylor(primals_in, series_in):
  # x ** r == exp(r * log(x)): propagate the series of r * log(x) through exp.
  u_, r_ = primals_in

  x, series = jet(lambda x, y: lax.mul(y, lax.log(x)), primals_in, series_in)

  u = [x] + series
  v = [u_ ** r_] + [None] * len(series)
  for k in range(1, len(v)):
    v[k] = fact(k-1) * sum(_scale(k, j) * v[k-j] * u[j] for j in range(1, k+1))
  primal_out, *series_out = v

  return primal_out, series_out
jet_rules[lax.pow_p] = _pow_taylor

def _integer_pow_by_squaring(x, *, y):
  """Compute `x ** y` for a positive int `y` using only `lax.mul`.

  Under `jet`, each multiplication goes through the exact Leibniz rule for
  `mul_p`, so the result is exact for every `x`, including `x == 0`.
  """
  acc = None
  while y > 0:
    if y & 1:
      acc = x if acc is None else lax.mul(acc, x)
    y >>= 1
    if y > 0:
      x = lax.mul(x, x)
  return acc

def _integer_pow_taylor(primals_in, series_in, *, y):
  if y == 0:
    return jet(jnp.ones_like, primals_in, series_in)
  if y > 0:
    # x ** y is a polynomial. The division-based recurrence below cannot
    # recover its higher-order terms at x == 0 (e.g. the third derivative
    # of x ** 3 is 6 there), so multiply the series directly instead.
    return jet(partial(_integer_pow_by_squaring, y=y), primals_in, series_in)
  # Negative exponents: recurrence from x * v' = y * v * x', with v = x ** y.
  x, = primals_in
  series, = series_in
  u = [x] + series
  v = [lax.integer_pow(x, y)] + [None] * len(series)
  for k in range(1, len(v)):
    vu = sum(_scale(k, j) * v[k-j] * u[j] for j in range(1, k + 1))
    uv = sum(_scale(k, j) * u[k-j] * v[j] for j in range(1, k))
    v[k] = jnp.where(x == 0, 0, fact(k-1) * (y * vu - uv) / x)
  primal_out, *series_out = v

  return primal_out, series_out
jet_rules[lax.integer_pow_p] = _integer_pow_taylor


def _expit_taylor(primals_in, series_in):
  x, = primals_in
  series, = series_in
  u = [x] + series
  v = [jax.scipy.special.expit(x)] + [None] * len(series)
  e = [v[0] * (1 - v[0])] + [None] * len(series)  # terms for sigmoid' = sigmoid * (1 - sigmoid)
  for k in range(1, len(v)):
    v[k] = fact(k-1) * sum(_scale(k, j) * e[k-j] * u[j] for j in range(1, k+1))
    e[k] = (1 - v[0]) * v[k] - fact(k) * sum(_scale2(k, j) * v[j] * v[k-j] for j in range(1, k+1))

  primal_out, *series_out = v
  return primal_out, series_out

def _tanh_taylor(primals_in, series_in):
  # tanh(x) == 2 * expit(2 * x) - 1
  x, = primals_in
  series, = series_in
  primal_out, series_out = _expit_taylor((2 * x,), ([2 * t for t in series],))
  return 2 * primal_out - 1, [2 * t for t in series_out]
jet_rules[lax.tanh_p] = _tanh_taylor

def _log_taylor(primals_in, series_in):
  x, = primals_in
  series, = series_in
  u = [x] + series
  v = [lax.log(x)] + [None] * len(series)
  for k in range(1, len(v)):
    conv = sum(_scale(k, j) * v[j] * u[k-j] for j in range(1, k))
    v[k] = (u[k] - fact(k - 1) * conv) / u[0]
  primal_out, *series_out = v
  return primal_out, series_out
jet_rules[lax.log_p] = _log_taylor

def _atan2_taylor(primals_in, series_in):
  # Derivatives follow atan(z) with z = x / y and atan'(z) = 1 / (1 + z ** 2);
  # the primal comes from atan2 itself.
  x, y = primals_in
  primal_out = lax.atan2(x, y)

  z, z_series = jet(lax.div, primals_in, series_in)
  c0, cs = jet(lambda x: lax.div(1, 1 + lax.square(x)), (z, ), (z_series, ))
  c = [c0] + cs
  u = [z] + z_series
  v = [primal_out] + [None] * len(z_series)
  for k in range(1, len(v)):
    v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))
  primal_out, *series_out = v
  return primal_out, series_out
jet_rules[lax.atan2_p] = _atan2_taylor

def _div_taylor_rule(primals_in, series_in):
  # From u == v * w (Leibniz): v_k = (u_k - sum_{j<k} C(k, j) v_j w_{k-j}) / w_0.
  x, y = primals_in
  x_terms, y_terms = series_in
  u = [x] + x_terms
  w = [y] + y_terms
  v = [None] * len(u)
  for k in range(0, len(v)):
    conv = sum(_scale2(k, j) * v[j] * w[k-j] for j in range(0, k))
    v[k] = (u[k] - fact(k) * conv) / w[0]
  primal_out, *series_out = v
  return primal_out, series_out
jet_rules[lax.div_p] = _div_taylor_rule

def _sinusoidal_rule(sign, prims, primals_in, series_in):
  # s' = c and c' = sign * s (sign=-1 for sin/cos, +1 for sinh/cosh).
  x, = primals_in
  series, = series_in
  u = [x] + series
  s, c = prims
  s = [s(x)] + [None] * len(series)
  c = [c(x)] + [None] * len(series)
  for k in range(1, len(s)):
    s[k] = fact(k-1) * sum(_scale(k, j) * u[j] * c[k-j] for j in range(1, k + 1))
    c[k] = fact(k-1) * sum(_scale(k, j) * u[j] * s[k-j] for j in range(1, k + 1)) * sign
  return (s[0], s[1:]), (c[0], c[1:])

def _get_ind(f, ind):
  return lambda *args: f(*args)[ind]

jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0)
jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1)
jet_rules[lax.sinh_p] = _get_ind(partial(_sinusoidal_rule, 1, (lax.sinh, lax.cosh)), 0)
jet_rules[lax.cosh_p] = _get_ind(partial(_sinusoidal_rule, 1, (lax.sinh, lax.cosh)), 1)

def _bilinear_taylor_rule(prim, primals_in, series_in, **params):
  # General Leibniz rule: v_k = sum_j C(k, j) op(u_j, w_{k-j}).
  x, y = primals_in
  x_terms, y_terms = series_in
  u = [x] + x_terms
  w = [y] + y_terms
  v = [None] * len(u)
  op = partial(prim.bind, **params)
  for k in range(0, len(v)):
    v[k] = fact(k) * sum(_scale2(k, j) * op(u[j], w[k-j]) for j in range(0, k+1))
  primal_out, *series_out = v
  return primal_out, series_out
jet_rules[lax.dot_general_p] = partial(_bilinear_taylor_rule, lax.dot_general_p)
jet_rules[lax.mul_p] = partial(_bilinear_taylor_rule, lax.mul_p)
jet_rules[lax.conv_general_dilated_p] = partial(_bilinear_taylor_rule, lax.conv_general_dilated_p)

def _gather_taylor_rule(primals_in, series_in, **params):
  # Gather is linear in the operand; the indices carry no series.
  operand, start_indices = primals_in
  gs, _ = series_in
  primal_out = lax.gather_p.bind(operand, start_indices, **params)
  series_out = [lax.gather_p.bind(g, start_indices, **params) for g in gs]
  return primal_out, series_out
jet_rules[lax.gather_p] = _gather_taylor_rule

def _gen_reduce_choose_taylor_rule(chooser_fun):
  """Build a rule for max/min reductions.

  Each term is the average of the input terms at the positions attaining
  the extremum.
  """
  def chooser_taylor_rule(primals_in, series_in, **params):
    operand, = primals_in
    gs, = series_in
    primal_out = chooser_fun(operand, **params)
    if not gs:
      # Order-0 jet: no terms (and no term dtype to read below).
      return primal_out, []
    axes = params.pop("axes", None)
    primal_dtype = gs[0].dtype
    shape = [1 if i in axes else d for i, d in enumerate(operand.shape)]
    location_indicators = lax.convert_element_type(
        lax._eq_meet(operand, lax.reshape(primal_out, shape)), primal_dtype)
    counts = lax._reduce_sum(location_indicators, axes)
    def _reduce_chooser_taylor_rule(g):
      return lax.div(lax._reduce_sum(lax.mul(g, location_indicators), axes), counts)
    series_out = [_reduce_chooser_taylor_rule(g) for g in gs]
    return primal_out, series_out
  return chooser_taylor_rule
jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(lax._reduce_max)
jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(lax._reduce_min)

def _abs_taylor_rule(x, series_in, **params):
  # Terms are multiplied by -1 where x < 0 and by +1 elsewhere (incl. x == 0).
  x, = x
  series, = series_in
  zero = lax.full_like(x, 0, shape=())
  primal_out = lax.abs_p.bind(x, **params)
  negs = lax.select(lax.lt(x, zero), lax.full_like(x, -1), lax.full_like(x, 1.0))
  series_out = [negs * t for t in series]
  return primal_out, series_out
jet_rules[lax.abs_p] = _abs_taylor_rule

def _select_taylor_rule(primal_in, series_in, **params):
  # The predicate is not differentiated; its series (the first entry) is
  # ignored.
  b, x, y = primal_in
  primal_out = lax.select_p.bind(b, x, y, **params)
  sel = lambda _, x, y: lax.select(b, x, y)
  series_out = [sel(*terms_in, **params) for terms_in in zip(*series_in)]
  return primal_out, series_out
jet_rules[lax.select_p] = _select_taylor_rule


def _lax_max_taylor_rule(primal_in, series_in):
  x, y = primal_in

  xgy = x > y   # greater than mask
  xey = x == y  # equal to mask
  primal_out = lax.select(xgy, x, y)

  def select_max_and_avg_eq(x_i, y_i):
    """Select x where x>y or average when x==y"""
    max_i = lax.select(xgy, x_i, y_i)
    max_i = lax.select(xey, (x_i + y_i)/2, max_i)
    return max_i

  series_out = [select_max_and_avg_eq(*terms_in) for terms_in in zip(*series_in)]
  return primal_out, series_out
jet_rules[lax.max_p] = _lax_max_taylor_rule

def _lax_min_taylor_rule(primal_in, series_in):
  x, y = primal_in
  xly = x < y   # less than mask
  xey = x == y  # equal to mask
  primal_out = lax.select(xly, x, y)

  def select_min_and_avg_eq(x_i, y_i):
    """Select x where x<y or average when x==y"""
    min_i = lax.select(xly, x_i, y_i)
    min_i = lax.select(xey, (x_i + y_i)/2, min_i)
    return min_i

  series_out = [select_min_and_avg_eq(*terms_in) for terms_in in zip(*series_in)]
  return primal_out, series_out
jet_rules[lax.min_p] = _lax_min_taylor_rule

def _custom_jvp_call_jaxpr_rule(primals_in, series_in, *, fun_jaxpr,
                                jvp_jaxpr_thunk):
  # TODO(mattjj): do something better than ignoring custom jvp rules for jet?
  del jvp_jaxpr_thunk
  return jet(core.jaxpr_as_fun(fun_jaxpr), primals_in, series_in)
jet_rules[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_rule


deflinear(lax.tie_in_p)