1# Copyright 2020 Google LLC
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7#     https://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14
15from typing import Callable
16
17from functools import partial
18
19import numpy as np
20
21import jax
22import jax.numpy as jnp
23from jax import core
24from jax._src.util import unzip2
25from jax import ad_util
26from jax.tree_util import (register_pytree_node, tree_structure,
27                           treedef_is_leaf, tree_flatten, tree_unflatten)
28import jax.linear_util as lu
29from jax.interpreters import xla
30from jax.custom_derivatives import custom_jvp_call_jaxpr_p
31from jax._src.lax import lax
32from jax._src.lax import control_flow as lax_control_flow
33from jax._src.lax import fft as lax_fft
34
def jet(fun, primals, series):
  """Propagate series terms through ``fun`` alongside its primal evaluation.

  Args:
    fun: callable taking one positional argument per entry of ``primals``.
      Its output may be a pytree; it is flattened before tracing and
      unflattened again on return.
    primals: sequence of array-like values (each must be a pytree leaf).
    series: sequence with exactly one entry per primal; each entry is a
      sequence of array-like terms.  Every entry must contain the same number
      of terms, and that number is used as the order of the computation.

  Returns:
    A pair ``(out_primals, out_terms)``.  ``out_primals`` has the structure
    of ``fun``'s output; ``out_terms`` has the same structure with each leaf
    replaced by the list of series terms for that output.

  Raises:
    ValueError: if ``primals`` and ``series`` differ in length, if the series
      entries have inconsistent lengths, or if any primal or term is not a
      pytree leaf.
  """
  # Without this check the zip() below silently drops the unmatched tail and
  # the failure surfaces much later with an unrelated message.
  if len(primals) != len(series):
    msg = ("jet requires one series per primal, but got {} primals and {} "
           "series".format(len(primals), len(series)))
    raise ValueError(msg)

  try:
    order, = set(map(len, series))
  except ValueError:
    msg = "jet terms have inconsistent lengths for different arguments"
    raise ValueError(msg) from None

  # TODO(mattjj): consider supporting pytree inputs
  for i, (x, terms) in enumerate(zip(primals, series)):
    if not treedef_is_leaf(tree_structure(x)):
      raise ValueError("primal value at position {} is not an array".format(i))
    for j, t in enumerate(terms):
      if not treedef_is_leaf(tree_structure(t)):
        raise ValueError("term {} for argument {} is not an array".format(j, i))

  @lu.transformation_with_aux
  def flatten_fun_output(*args):
    # Flatten the user function's output; the treedef is stored as the aux
    # value so the results can be rebuilt below.
    ans = yield args, {}
    yield tree_flatten(ans)

  f, out_tree = flatten_fun_output(lu.wrap_init(fun))
  out_primals, out_terms = jet_fun(jet_subtrace(f), order).call_wrapped(primals, series)
  return tree_unflatten(out_tree(), out_primals), tree_unflatten(out_tree(), out_terms)
60
@lu.transformation
def jet_fun(order, primals, series):
  """Run the wrapped function under a fresh ``JetTrace`` main trace.

  The requested ``order`` is stored on the main trace so that
  ``JetTrace.process_primitive`` can read it back.  Any output whose terms
  came back as the ``zero_series`` sentinel is replaced by ``order`` explicit
  zero arrays shaped like the corresponding output primal.
  """
  with core.new_main(JetTrace) as main:
    main.order = order
    out_primals, out_terms = yield (main, primals, series), {}
    del main

  materialized = []
  for primal, terms in zip(out_primals, out_terms):
    if terms is zero_series:
      terms = [np.zeros_like(primal)] * order
    materialized.append(terms)
  yield out_primals, materialized
70
@lu.transformation
def jet_subtrace(main, primals, series):
  """Wrap inputs in ``JetTracer``s and unwrap the outputs again.

  Builds a ``JetTrace`` for ``main`` at the current sublevel, pairs each
  primal with its series in a tracer, runs the wrapped function, and yields
  the primals and terms of the (fully raised) results.
  """
  jet_trace = JetTrace(main, core.cur_sublevel())
  make_tracer = partial(JetTracer, jet_trace)
  results = yield map(make_tracer, primals, series), {}
  raised = map(jet_trace.full_raise, results)
  primals_out, terms_out = unzip2(
      (tracer.primal, tracer.terms) for tracer in raised)
  yield primals_out, terms_out
79
@lu.transformation_with_aux
def traceable(in_tree_def, *primals_and_series):
  """Adapt a ``(primals, series)`` function to flat positional arguments.

  The flat inputs are rebuilt into ``(primals, series)`` using
  ``in_tree_def``; the ``(primals, series)`` result is flattened and its
  treedef is emitted as the auxiliary output.
  """
  primals, series = tree_unflatten(in_tree_def, primals_and_series)
  out_primals, out_series = yield (primals, series), {}
  flat_outputs, out_structure = tree_flatten((out_primals, out_series))
  yield flat_outputs, out_structure
86
87
class JetTracer(core.Tracer):
  """Tracer that carries a primal value together with its series terms.

  ``terms`` is either the ``zero_series`` sentinel or a list/tuple of terms;
  individual terms may be the ``zero_term`` sentinel.
  """
  __slots__ = ["primal", "terms"]

  def __init__(self, trace, primal, terms):
    # Exact type check: subclasses of list/tuple do not pass.
    assert type(terms) in (ZeroSeries, list, tuple)
    self._trace = trace
    self.terms = terms
    self.primal = primal

  @property
  def aval(self):
    """Abstract value of the primal."""
    primal = self.primal
    return core.get_aval(primal)

  def full_lower(self):
    """Drop the tracer when every term is a zero sentinel."""
    terms = self.terms
    carries_terms = (terms is not zero_series
                     and any(t is not zero_term for t in terms))
    if carries_terms:
      return self
    return core.full_lower(self.primal)
106
class JetTrace(core.Trace):
  """Trace that applies the rules in ``jet_rules`` to ``JetTracer`` inputs."""

  def pure(self, val):
    # A value entering the trace from outside has no series terms.
    return JetTracer(self, val, zero_series)

  def lift(self, val):
    # Same treatment as `pure`: wrap with the zero-series sentinel.
    return JetTracer(self, val, zero_series)

  def sublift(self, val):
    # Re-wrap an existing JetTracer's primal and terms in this trace.
    return JetTracer(self, val.primal, val.terms)

  def process_primitive(self, primitive, tracers, params):
    """Apply the jet rule registered for `primitive` to the given tracers."""
    # `order` is attached to the main trace by `jet_fun`.
    order = self.main.order              # pytype: disable=attribute-error
    primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
    # Expand a whole-series zero sentinel into `order` per-term sentinels.
    series_in = [[zero_term] * order if s is zero_series else s
                 for s in series_in]
    # TODO(mattjj): avoid always instantiating zeros
    # Replace each per-term sentinel with a concrete zero array matching the
    # shape and result dtype of the corresponding primal, so rules only ever
    # see arrays.
    series_in = [[np.zeros(np.shape(x), dtype=np.result_type(x))
                  if t is zero_term else t for t in series]
                 for x, series in zip(primals_in, series_in)]
    # Raises KeyError if no rule has been registered for this primitive.
    rule = jet_rules[primitive]
    primal_out, terms_out = rule(primals_in, series_in, **params)
    if not primitive.multiple_results:
      return JetTracer(self, primal_out, terms_out)
    else:
      return [JetTracer(self, p, ts) for p, ts in zip(primal_out, terms_out)]

  def process_call(self, call_primitive, f, tracers, params):
    """Bind `call_primitive` on a jet-transformed version of `f`."""
    primals_in, series_in = unzip2((t.primal, t.terms) for t in tracers)
    # The call primitive takes flat arguments, so flatten primals and series
    # together and have `traceable` rebuild them on the inside.
    primals_and_series, in_tree_def = tree_flatten((primals_in, series_in))
    f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def)
    # Some call primitives carry per-argument params that must be resized to
    # the new (flattened) argument count.
    update_params = call_param_updaters.get(call_primitive)
    new_params = (update_params(params, len(primals_and_series))
                  if update_params else params)
    result = call_primitive.bind(f_jet, *primals_and_series, **new_params)
    primals_out, series_out = tree_unflatten(out_tree_def(), result)
    return [JetTracer(self, p, ts) for p, ts in zip(primals_out, series_out)]

  def post_process_call(self, call_primitive, out_tracers, params):
    """Return flat outputs plus a callback that re-wraps them as tracers."""
    primals, series = unzip2((t.primal, t.terms) for t in out_tracers)
    out, treedef = tree_flatten((primals, series))
    del primals, series
    main = self.main
    def todo(x):
      # Rebuild (primals, series) from the flat values and wrap them in a
      # JetTrace created at whatever sublevel is current when this runs.
      primals, series = tree_unflatten(treedef, x)
      trace = JetTrace(main, core.cur_sublevel())
      return map(partial(JetTracer, trace), primals, series)
    return out, todo

  def process_custom_jvp_call(self, primitive, fun, jvp, tracers):
    # TODO(mattjj): don't just ignore custom jvp rules?
    # The custom rule is discarded; the original function is traced directly.
    del primitive, jvp  # Unused.
    return fun.call_wrapped(*tracers)

  def process_custom_vjp_call(self, primitive, fun, fwd, bwd, tracers, out_trees):
    # The custom fwd/bwd rules are discarded; the original function is traced
    # directly.
    del primitive, fwd, bwd, out_trees  # Unused.
    return fun.call_wrapped(*tracers)
164
165
class ZeroTerm(object):
  """Type of the sentinel standing in for a single zero series term."""


# The one shared instance; code compares against it with `is`.
zero_term = ZeroTerm()


def _zero_term_flatten(_):
  # No children and no auxiliary data.
  return (), None


def _zero_term_unflatten(_aux, _children):
  # Always hand back the shared instance so identity checks keep working.
  return zero_term


register_pytree_node(ZeroTerm, _zero_term_flatten, _zero_term_unflatten)
169
class ZeroSeries(object):
  """Type of the sentinel standing in for a series whose terms are all zero."""


# The one shared instance; code compares against it with `is`.
zero_series = ZeroSeries()


def _zero_series_flatten(_):
  # No children and no auxiliary data.
  return (), None


def _zero_series_unflatten(_aux, _children):
  # Always hand back the shared instance so identity checks keep working.
  return zero_series


register_pytree_node(ZeroSeries, _zero_series_flatten, _zero_series_unflatten)
173
174
175call_param_updaters = {}
176
177def _xla_call_param_updater(params, num_inputs):
178  donated_invars = params['donated_invars']
179  if any(donated_invars):
180    raise NotImplementedError("donated_invars not supported with jet")
181  return dict(params, donated_invars=(False,) * num_inputs)
# xla_call carries a per-input `donated_invars` tuple, which must be resized
# when JetTrace.process_call flattens primals and series into one arg list.
call_param_updaters[xla.xla_call_p] = _xla_call_param_updater
183
184
185### rule definitions
186
# Registry consulted by JetTrace.process_primitive: primitive -> jet rule,
# where a rule is called as `rule(primals_in, series_in, **params)`.
jet_rules = {}

def defzero(prim):
  """Register ``prim`` with a rule that always reports a zero output series."""
  rule = partial(zero_prop, prim)
  jet_rules[prim] = rule

def zero_prop(prim, primals_in, series_in, **params):
  """Evaluate ``prim`` on the primals only and return ``zero_series``."""
  del series_in  # The input series never influence the result.
  return prim.bind(*primals_in, **params), zero_series
195
# Primitives handled by `zero_prop`: only the primal is computed and the
# output series is reported as `zero_series`, whatever the input series are.
defzero(lax.le_p)
defzero(lax.lt_p)
defzero(lax.gt_p)
defzero(lax.ge_p)
defzero(lax.eq_p)
defzero(lax.ne_p)
defzero(lax.not_p)
defzero(lax.and_p)
defzero(lax.or_p)
defzero(lax.xor_p)
defzero(lax.floor_p)
defzero(lax.ceil_p)
defzero(lax.round_p)
defzero(lax.sign_p)
defzero(ad_util.stop_gradient_p)
defzero(lax.is_finite_p)
defzero(lax.shift_left_p)
defzero(lax.shift_right_arithmetic_p)
defzero(lax.shift_right_logical_p)
defzero(lax.bitcast_convert_type_p)
216
217
def deflinear(prim):
  """Register ``prim`` with the term-by-term rule ``linear_prop``."""
  rule = partial(linear_prop, prim)
  jet_rules[prim] = rule

def linear_prop(prim, primals_in, series_in, **params):
  """Apply ``prim`` to the primals and, separately, to each order of terms.

  ``series_in`` holds one list of terms per input; terms of the same order
  across all inputs are passed to ``prim`` together.
  """
  apply = partial(prim.bind, **params)
  primal_out = apply(*primals_in)
  series_out = []
  for same_order_terms in zip(*series_in):
    series_out.append(apply(*same_order_terms))
  return primal_out, series_out
225
# Primitives handled by `linear_prop`: the primitive itself (with the same
# params) is applied to the primals and to each order of series terms.
deflinear(lax.neg_p)
deflinear(lax.real_p)
deflinear(lax.complex_p)
deflinear(lax.conj_p)
deflinear(lax.imag_p)
deflinear(lax.add_p)
deflinear(ad_util.add_jaxvals_p)
deflinear(lax.sub_p)
deflinear(lax.convert_element_type_p)
deflinear(lax.broadcast_in_dim_p)
deflinear(lax.concatenate_p)
deflinear(lax.pad_p)
deflinear(lax.reshape_p)
deflinear(lax.rev_p)
deflinear(lax.transpose_p)
deflinear(lax.slice_p)
deflinear(lax.reduce_sum_p)
deflinear(lax.reduce_window_sum_p)
deflinear(lax_fft.fft_p)
deflinear(xla.device_put_p)
246
def _cumulative_jet_rule(primals_in, series_in, *, axis: int, reverse: bool,
                         combine_fn: Callable):
  """Jet rule for cumulative reductions, expressed via ``associative_scan``.

  ``combine_fn`` is the binary operation being accumulated along ``axis``.
  """
  # Irrespective of backend, we always use the parallel prefix scan
  # implementation when differentiating because reduce_window is not
  # arbitrarily differentiable.
  scan = partial(lax_control_flow.associative_scan, combine_fn,
                 axis=axis, reverse=reverse)
  return jet(scan, primals_in, series_in)
255
# cumsum uses the term-by-term linear rule; the other cumulative reductions
# go through `_cumulative_jet_rule` with their binary combine function.
deflinear(lax_control_flow.cumsum_p)
jet_rules[lax_control_flow.cumprod_p] = partial(_cumulative_jet_rule,
                                                combine_fn=lax.mul)
jet_rules[lax_control_flow.cummax_p] = partial(_cumulative_jet_rule,
                                               combine_fn=lax.max)
jet_rules[lax_control_flow.cummin_p] = partial(_cumulative_jet_rule,
                                               combine_fn=lax.min)
263
264
def def_deriv(prim, deriv):
  """
  Register a jet rule for a unary primitive given its first derivative.

  ``deriv`` is a callable of the primal input; the rule itself is
  ``deriv_prop`` with ``prim`` and ``deriv`` bound.
  """
  rule = partial(deriv_prop, prim, deriv)
  jet_rules[prim] = rule
270
271
def deriv_prop(prim, deriv, primals_in, series_in):
  """Jet rule for a unary primitive described by its first derivative.

  The series of ``deriv`` at the input is obtained with a nested ``jet``
  call and then combined with the input series, order by order, to build
  the output series.
  """
  (x,) = primals_in
  (x_terms,) = series_in
  value = prim.bind(x)
  d0, d_terms = jet(deriv, primals_in, series_in)
  deriv_coeffs = [d0] + d_terms
  in_coeffs = [x] + x_terms
  order = len(x_terms)
  out_coeffs = [value] + [None] * order
  for k in range(1, order + 1):
    conv = sum([_scale(k, j) * deriv_coeffs[k-j] * in_coeffs[j]
                for j in range(1, k + 1)])
    out_coeffs[k] = fact(k-1) * conv
  return out_coeffs[0], out_coeffs[1:]
284
285
# erf'(x) = 2 / sqrt(pi) * exp(-x**2)
def_deriv(lax.erf_p, lambda x: lax.mul(lax._const(x, 2. / np.sqrt(np.pi)), lax.exp(lax.neg(lax.square(x)))))
287
288
def def_comp(prim, comp):
  """
  Register a jet rule for a primitive by re-expressing it with other ops.

  ``comp`` is a callable taking the primitive's operands; the rule simply
  runs ``jet`` on ``comp``, so every operation used inside it must have a
  jet rule of its own.
  """
  rule = partial(jet, comp)
  jet_rules[prim] = rule
294
295
# Rules defined by composition: each primitive is re-expressed with
# operations that already have jet rules, and `jet` is run on that expression.
def_comp(lax.expm1_p, lambda x: lax.exp(x) - 1)
def_comp(lax.log1p_p, lambda x: lax.log(1 + x))
def_comp(lax.sqrt_p, lambda x: x ** 0.5)
def_comp(lax.rsqrt_p, lambda x: x ** -0.5)
def_comp(lax.asinh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) + 1)))
def_comp(lax.acosh_p, lambda x: lax.log(x + lax.sqrt(lax.square(x) - 1)))
def_comp(lax.atanh_p, lambda x: 0.5 * lax.log(lax.div(1 + x, 1 - x)))
def_comp(lax.erfc_p, lambda x: 1 - lax.erf(x))


def _rem_comp(x, y):
  """Composition for ``lax.rem``: ``x - y * trunc(x / y)``."""
  # lax.rem takes the sign of the dividend, i.e. the quotient is truncated
  # toward zero.  floor(x / y) would give the floored modulo instead, which
  # yields a different primal whenever x / y is negative and non-integral.
  # trunc(q) is spelled sign(q) * floor(|q|) because both sign and floor
  # have (zero-series) jet rules.
  quotient = x / y
  truncated = lax.sign(quotient) * lax.floor(lax.abs(quotient))
  return x - y * truncated


def_comp(lax.rem_p, _rem_comp)
def_comp(lax.clamp_p, lambda a, x, b: lax.min(lax.max(a, x), b))
306
307
def _erf_inv_rule(primals_in, series_in):
  """Jet rule for ``erf_inv``.

  The derivative is written in terms of the *output* y = erf_inv(x), namely
  sqrt(pi) / 2 * exp(y**2), so the series of the derivative (``c``) and the
  series of the output (``v``) are built up together: each new ``v[k]`` needs
  ``c[:k]`` and each new ``c[k]`` needs ``v[:k + 1]``.
  """
  x, = primals_in
  series, = series_in

  u = [x] + series
  primal_out = lax.erf_inv(x)
  v = [primal_out] + [None] * len(series)

  # derivative on co-domain for caching purposes
  deriv_const = np.sqrt(np.pi) / 2.
  deriv_y = lambda y: lax.mul(deriv_const, lax.exp(lax.square(y)))

  # manually propagate through deriv_y since we don't have lazy evaluation of sensitivities

  # c, tmp_sq and tmp_exp hold the series of deriv_y(y), y**2 and exp(y**2)
  # respectively; only their order-0 entries are known up front.
  c = [deriv_y(primal_out)] + [None] * (len(series) - 1)
  tmp_sq = [lax.square(v[0])] + [None] * (len(series) - 1)
  tmp_exp = [lax.exp(tmp_sq[0])] + [None] * (len(series) - 1)
  for k in range(1, len(series)):
    # we know c[:k], we compute c[k]

    # propagate c to get v
    v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))

    # propagate v to get next c

    # square
    tmp_sq[k] = fact(k) * sum(_scale2(k, j) * v[k-j] * v[j] for j in range(k + 1))

    # exp
    tmp_exp[k] = fact(k-1) * sum(_scale(k, j) * tmp_exp[k-j] * tmp_sq[j] for j in range(1, k + 1))

    # const
    c[k] = deriv_const * tmp_exp[k]

  # we can't, and don't need, to compute c[k+1], just need to get the last v[k]
  k = len(series)
  v[k] = fact(k-1) * sum(_scale(k, j) * c[k-j] * u[j] for j in range(1, k + 1))

  primal_out, *series_out = v
  return primal_out, series_out
jet_rules[lax.erf_inv_p] = _erf_inv_rule
349
350### More complicated rules
351
352def fact(n):
353  return lax.exp(lax.lgamma(n+1.))
354
355def _scale(k, j):
356  return 1. / (fact(k - j) * fact(j - 1))
357
358def _scale2(k, j):
359  return 1. / (fact(k - j) * fact(j))
360
def _exp_taylor(primals_in, series_in):
  """Jet rule for ``exp``.

  Each output term is built from lower-order output terms and the input
  terms, so the output series is filled in increasing order.
  """
  (x,) = primals_in
  (x_terms,) = series_in
  in_coeffs = [x] + x_terms
  order = len(x_terms)
  out_coeffs = [lax.exp(x)] + [None] * order
  for k in range(1, order + 1):
    conv = sum(_scale(k, j) * out_coeffs[k-j] * in_coeffs[j]
               for j in range(1, k+1))
    out_coeffs[k] = fact(k-1) * conv
  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.exp_p] = _exp_taylor
371
def _pow_taylor(primals_in, series_in):
  """Jet rule for ``pow``, via ``u ** r == exp(r * log(u))``.

  The series of ``r * log(u)`` is computed with a nested ``jet`` call and
  then pushed through the same recurrence that ``_exp_taylor`` uses.
  """
  base, exponent = primals_in

  def _scaled_log(x, y):
    return lax.mul(y, lax.log(x))

  log_primal, log_terms = jet(_scaled_log, primals_in, series_in)

  in_coeffs = [log_primal] + log_terms
  order = len(log_terms)
  out_coeffs = [base ** exponent] + [None] * order
  for k in range(1, order + 1):
    conv = sum(_scale(k, j) * out_coeffs[k-j] * in_coeffs[j]
               for j in range(1, k+1))
    out_coeffs[k] = fact(k-1) * conv

  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.pow_p] = _pow_taylor
385
def _integer_pow_taylor(primals_in, series_in, *, y):
  """Jet rule for ``integer_pow`` with static exponent ``y``.

  Exponents 0, 1 and 2 are delegated to ``jet`` on an equivalent simple
  function.  Other exponents use a recurrence that divides by the primal;
  wherever the primal equals zero the output terms are set to zero.
  """
  simple_forms = (
      (0, jnp.ones_like),
      (1, lambda x: x),
      (2, lambda x: x * x),
  )
  for exponent, simple_fun in simple_forms:
    if y == exponent:
      return jet(simple_fun, primals_in, series_in)

  (x,) = primals_in
  (x_terms,) = series_in
  in_coeffs = [x] + x_terms
  order = len(x_terms)
  out_coeffs = [lax.integer_pow(x, y)] + [None] * order
  for k in range(1, order + 1):
    vu = sum([_scale(k, j) * out_coeffs[k-j] * in_coeffs[j]
              for j in range(1, k + 1)])
    uv = sum([_scale(k, j) * in_coeffs[k-j] * out_coeffs[j]
              for j in range(1, k)])
    out_coeffs[k] = jnp.where(x == 0, 0, fact(k-1) * (y * vu - uv) / x)

  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.integer_pow_p] = _integer_pow_taylor
405
406
def _expit_taylor(primals_in, series_in):
  """Series propagation for the logistic sigmoid (``expit``).

  Two series are advanced together: ``sig`` for the sigmoid itself and
  ``dsig`` for its derivative ``sigmoid * (1 - sigmoid)``.
  """
  (x,) = primals_in
  (x_terms,) = series_in
  in_coeffs = [x] + x_terms
  order = len(x_terms)
  sig = [jax.scipy.special.expit(x)] + [None] * order
  # terms for sigmoid' = sigmoid * (1 - sigmoid)
  dsig = [sig[0] * (1 - sig[0])] + [None] * order
  for k in range(1, order + 1):
    sig[k] = fact(k-1) * sum(_scale(k, j) * dsig[k-j] * in_coeffs[j]
                             for j in range(1, k+1))
    dsig[k] = (1 - sig[0]) * sig[k] - fact(k) * sum(
        _scale2(k, j) * sig[j] * sig[k-j] for j in range(1, k+1))

  return sig[0], sig[1:]
419
def _tanh_taylor(primals_in, series_in):
  """Jet rule for ``tanh`` using ``tanh(x) == 2 * expit(2 * x) - 1``."""
  (x,) = primals_in
  (x_terms,) = series_in
  # Double the primal and every term, run the sigmoid rule, then double the
  # resulting terms and rescale/shift the primal.
  doubled_primal = 2*x
  doubled_terms = [2 * term for term in x_terms]
  sig_primal, sig_terms = _expit_taylor((doubled_primal, ), (doubled_terms, ))
  series_out = [2 * term for term in sig_terms]
  primal_out = 2 * sig_primal - 1
  return primal_out, series_out
jet_rules[lax.tanh_p] = _tanh_taylor
429
def _log_taylor(primals_in, series_in):
  """Jet rule for ``log``.

  Each output term is the matching input term minus a combination of
  lower-order output terms, all divided by the primal input.
  """
  (x,) = primals_in
  (x_terms,) = series_in
  in_coeffs = [x] + x_terms
  order = len(x_terms)
  out_coeffs = [lax.log(x)] + [None] * order
  for k in range(1, order + 1):
    conv = sum(_scale(k, j) * out_coeffs[j] * in_coeffs[k-j]
               for j in range(1, k))
    out_coeffs[k] = (in_coeffs[k] - fact(k - 1) * conv) / in_coeffs[0]
  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.log_p] = _log_taylor
441
def _atan2_taylor(primals_in, series_in):
  """Jet rule for ``atan2``.

  The series of the ratio of the two operands and of ``1 / (1 + ratio**2)``
  are obtained with nested ``jet`` calls and combined order by order; the
  primal output comes from ``lax.atan2`` directly.
  """
  first, second = primals_in
  primal_out = lax.atan2(first, second)

  def _arctan_deriv(r):
    return lax.div(1, 1 + lax.square(r))

  ratio, ratio_terms = jet(lax.div, primals_in, series_in)
  d0, d_terms = jet(_arctan_deriv, (ratio, ), (ratio_terms, ))
  deriv_coeffs = [d0] + d_terms
  in_coeffs = [ratio] + ratio_terms
  order = len(ratio_terms)
  out_coeffs = [primal_out] + [None] * order
  for k in range(1, order + 1):
    conv = sum([_scale(k, j) * deriv_coeffs[k-j] * in_coeffs[j]
                for j in range(1, k + 1)])
    out_coeffs[k] = fact(k-1) * conv
  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.atan2_p] = _atan2_taylor
456
def _div_taylor_rule(primals_in, series_in):
  """Jet rule for ``div``.

  The quotient series is built in increasing order; entry ``k`` uses the
  numerator entry ``k`` and all previously computed quotient entries.
  Entry 0 is the primal output.
  """
  numer, denom = primals_in
  numer_terms, denom_terms = series_in
  numer_coeffs = [numer] + numer_terms
  denom_coeffs = [denom] + denom_terms

  def weight(k, j):
    return 1. / (fact(k - j) * fact(j))

  quotient = []
  for k in range(len(numer_coeffs)):
    conv = sum(weight(k, j) * quotient[j] * denom_coeffs[k-j]
               for j in range(0, k))
    quotient.append((numer_coeffs[k] - fact(k) * conv) / denom_coeffs[0])
  return quotient[0], quotient[1:]
jet_rules[lax.div_p] = _div_taylor_rule
470
def _sinusoidal_rule(sign, prims, primals_in, series_in):
  """Advance a sine-like and a cosine-like series together.

  Args:
    sign: factor applied to every cosine-like term.
    prims: pair ``(s, c)`` of functions giving the two primal values.
    primals_in: one-element sequence holding the input primal.
    series_in: one-element sequence holding the input's list of terms.

  Returns:
    ``((s_primal, s_terms), (c_primal, c_terms))``.
  """
  (x,) = primals_in
  (x_terms,) = series_in
  in_coeffs = [x] + x_terms
  sin_like, cos_like = prims
  order = len(x_terms)
  s_coeffs = [sin_like(x)] + [None] * order
  c_coeffs = [cos_like(x)] + [None] * order
  for k in range(1, order + 1):
    # Each series is advanced using the lower-order entries of the other.
    s_coeffs[k] = fact(k-1) * sum([_scale(k, j) * in_coeffs[j] * c_coeffs[k-j]
                                   for j in range(1, k + 1)])
    c_coeffs[k] = fact(k-1) * sum([_scale(k, j) * in_coeffs[j] * s_coeffs[k-j]
                                   for j in range(1, k + 1)]) * sign
  return (s_coeffs[0], s_coeffs[1:]), (c_coeffs[0], c_coeffs[1:])
482
483def _get_ind(f, ind):
484  return lambda *args: f(*args)[ind]
485
# `_sinusoidal_rule` returns both the sine-like and cosine-like results;
# index 0 selects the former and index 1 the latter.  The circular functions
# use sign -1 and the hyperbolic ones sign +1.
jet_rules[lax.sin_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 0)
jet_rules[lax.cos_p] = _get_ind(partial(_sinusoidal_rule, -1, (lax.sin, lax.cos)), 1)
jet_rules[lax.sinh_p] = _get_ind(partial(_sinusoidal_rule, 1, (lax.sinh, lax.cosh)), 0)
jet_rules[lax.cosh_p] = _get_ind(partial(_sinusoidal_rule, 1, (lax.sinh, lax.cosh)), 1)
490
def _bilinear_taylor_rule(prim, primals_in, series_in, **params):
  """Jet rule for a two-operand primitive treated as bilinear.

  Output entry ``k`` is a weighted sum of ``prim`` applied to every pair of
  left/right entries whose orders add up to ``k``.  Entry 0 is the primal
  output.
  """
  lhs, rhs = primals_in
  lhs_terms, rhs_terms = series_in
  lhs_coeffs = [lhs] + lhs_terms
  rhs_coeffs = [rhs] + rhs_terms
  apply = partial(prim.bind, **params)

  def weight(k, j):
    return 1. / (fact(k - j) * fact(j))

  out_coeffs = []
  for k in range(len(lhs_coeffs)):
    pairs = sum(weight(k, j) * apply(lhs_coeffs[j], rhs_coeffs[k-j])
                for j in range(0, k+1))
    out_coeffs.append(fact(k) * pairs)
  return out_coeffs[0], out_coeffs[1:]
jet_rules[lax.dot_general_p] = partial(_bilinear_taylor_rule, lax.dot_general_p)
jet_rules[lax.mul_p] = partial(_bilinear_taylor_rule, lax.mul_p)
jet_rules[lax.conv_general_dilated_p] = partial(_bilinear_taylor_rule, lax.conv_general_dilated_p)
506
def _gather_taylor_rule(primals_in, series_in, **params):
  """Jet rule for ``gather``: gather from the operand and from each of its terms.

  The series attached to the indices is ignored.
  """
  operand, start_indices = primals_in
  operand_terms, _ = series_in

  def gather_from(array):
    return lax.gather_p.bind(array, start_indices, **params)

  primal_out = gather_from(operand)
  series_out = [gather_from(term) for term in operand_terms]
  return primal_out, series_out
jet_rules[lax.gather_p] = _gather_taylor_rule
514
def _gen_reduce_choose_taylor_rule(chooser_fun):
  """Build a jet rule for a reduction that selects elements (max/min).

  ``chooser_fun`` computes the primal reduction.  Each output term is the
  average of the input term over the positions where the operand equals the
  reduced value.
  """
  def chooser_taylor_rule(primals_in, series_in, **params):
    (operand,) = primals_in
    (terms,) = series_in
    primal_out = chooser_fun(operand, **params)
    axes = params.pop("axes", None)
    term_dtype = terms[0].dtype
    # Shape of the reduced result with the reduced axes kept as size 1, so it
    # can be compared against the operand.
    kept_shape = [1 if axis in axes else dim
                  for axis, dim in enumerate(operand.shape)]
    reduced = lax.reshape(primal_out, kept_shape)
    is_chosen = lax.convert_element_type(
        lax._eq_meet(operand, reduced), term_dtype)
    num_chosen = lax._reduce_sum(is_chosen, axes)

    def average_over_chosen(term):
      total = lax._reduce_sum(lax.mul(term, is_chosen), axes)
      return lax.div(total, num_chosen)

    series_out = [average_over_chosen(term) for term in terms]
    return primal_out, series_out
  return chooser_taylor_rule
jet_rules[lax.reduce_max_p] = _gen_reduce_choose_taylor_rule(lax._reduce_max)
jet_rules[lax.reduce_min_p] = _gen_reduce_choose_taylor_rule(lax._reduce_min)
533
534def _abs_taylor_rule(x, series_in, **params):
535  x, = x
536  zero = lax.full_like(x, 0, shape=())
537  primal_out = lax.abs_p.bind(x, **params)
538  negs = lax.select(lax.lt(x, zero), lax.full_like(x, -1), lax.full_like(x, 1.0))
539  fix_sign = lambda y: negs * y
540  series_out = [fix_sign(*terms_in, **params) for terms_in in zip(*series_in)]
541  return primal_out, series_out
# abs has a hand-written rule that flips the sign of each term where x < 0.
jet_rules[lax.abs_p] = _abs_taylor_rule
543
def _select_taylor_rule(primal_in, series_in, **params):
  """Jet rule for ``select``: choose between the branch terms with the primal predicate.

  The series attached to the predicate is ignored.
  """
  pred, on_true, on_false = primal_in
  primal_out = lax.select_p.bind(pred, on_true, on_false, **params)

  def choose(_pred_term, true_term, false_term):
    return lax.select(pred, true_term, false_term)

  series_out = []
  for terms_in in zip(*series_in):
    series_out.append(choose(*terms_in, **params))
  return primal_out, series_out
jet_rules[lax.select_p] = _select_taylor_rule
551
552
553def _lax_max_taylor_rule(primal_in, series_in):
554    x, y = primal_in
555
556    xgy = x > y   # greater than mask
557    xey = x == y  # equal to mask
558    primal_out = lax.select(xgy, x, y)
559
560    def select_max_and_avg_eq(x_i, y_i):
561        """Select x where x>y or average when x==y"""
562        max_i = lax.select(xgy, x_i, y_i)
563        max_i = lax.select(xey, (x_i + y_i)/2, max_i)
564        return max_i
565
566    series_out = [select_max_and_avg_eq(*terms_in) for terms_in in zip(*series_in)]
567    return primal_out, series_out
# Elementwise max: terms follow the larger operand, averaged on ties.
jet_rules[lax.max_p] = _lax_max_taylor_rule
569
570def _lax_min_taylor_rule(primal_in, series_in):
571    x, y = primal_in
572    xgy = x < y   # less than mask
573    xey = x == y  # equal to mask
574    primal_out = lax.select(xgy, x, y)
575
576    def select_min_and_avg_eq(x_i, y_i):
577        """Select x where x>y or average when x==y"""
578        min_i = lax.select(xgy, x_i, y_i)
579        min_i = lax.select(xey, (x_i + y_i)/2, min_i)
580        return min_i
581
582    series_out = [select_min_and_avg_eq(*terms_in) for terms_in in zip(*series_in)]
583    return primal_out, series_out
# Elementwise min: terms follow the smaller operand, averaged on ties.
jet_rules[lax.min_p] = _lax_min_taylor_rule
585
def _custom_jvp_call_jaxpr_rule(primals_in, series_in, *, fun_jaxpr,
                                jvp_jaxpr_thunk):
  """Jet rule for ``custom_jvp_call_jaxpr``: run ``jet`` on the primal jaxpr.

  The user-supplied JVP rule is not consulted.
  """
  # TODO(mattjj): do something better than ignoring custom jvp rules for jet?
  del jvp_jaxpr_thunk
  primal_fun = core.jaxpr_as_fun(fun_jaxpr)
  return jet(primal_fun, primals_in, series_in)
jet_rules[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_rule
592
593
# tie_in is registered with the term-by-term linear rule; it is listed here,
# after all other rules, rather than with the main deflinear table above.
deflinear(lax.tie_in_p)
595