
Searched refs: nz_tangents (results 1-2 of 2), sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
ad.py
295 nz_tangents = [type(t) is not Zero for t in tangents]
300 tangent_in_axes = [ax for ax, nz in zip(in_axes, nz_tangents) if nz]
308 @as_hashable_function(closure=(tuple(nz_tangents), out_axes_thunk))
317 new_params = update_params(params, nz_tangents) if update_params else params
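The ad.py hits follow one pattern. Line 295 builds a boolean mask that is True for every tangent that is not a symbolic zero, and line 300 uses that mask to keep only the `in_axes` entries whose tangent is nonzero. Below is a minimal, self-contained sketch of that pattern. The `Zero` class and the helper names `nonzero_mask` and `filter_by_mask` are hypothetical stand-ins written for illustration, not JAX's actual objects.

```python
from typing import Any, List, Sequence


class Zero:
    """Hypothetical stand-in for JAX's symbolic-zero tangent marker."""

    def __init__(self, aval: Any = None):
        self.aval = aval


def nonzero_mask(tangents: Sequence[Any]) -> List[bool]:
    # Mirrors line 295: a tangent counts as nonzero unless its type is
    # exactly Zero (an exact type check, not isinstance).
    return [type(t) is not Zero for t in tangents]


def filter_by_mask(items: Sequence[Any], mask: Sequence[bool]) -> List[Any]:
    # Mirrors line 300: keep only entries whose tangent is nonzero.
    return [x for x, nz in zip(items, mask) if nz]


tangents = [1.0, Zero(), 2.0]
in_axes = [0, None, 1]
nz_tangents = nonzero_mask(tangents)
tangent_in_axes = filter_by_mask(in_axes, nz_tangents)
```

Dropping the symbolic-zero positions lets the transformed call skip those inputs entirely rather than materializing zero arrays.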
xla.py
880 def _xla_call_jvp_update_params(params, nz_tangents):
882 donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
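Line 882 applies the same mask to the `donated_invars` flags, and ad.py line 317 calls such an `update_params` hook when one is registered. The sketch below shows how the two hits might fit together. It assumes the JVP call takes the primal inputs followed by the nonzero tangents, so the tangents' donation flags are appended after the primal flags. The search results do not show the lines after 882, so this is an assumption, and `jvp_update_params_sketch` is a hypothetical name rather than JAX's function.

```python
from typing import Any, Dict, Sequence


def jvp_update_params_sketch(
    params: Dict[str, Any], nz_tangents: Sequence[bool]
) -> Dict[str, Any]:
    """Hypothetical re-creation of the donation bookkeeping at xla.py:882."""
    donated_invars = params["donated_invars"]
    # Line 882: keep the donation flag of each input whose tangent is nonzero.
    donated_tangents = [d for d, nz in zip(donated_invars, nz_tangents) if nz]
    # Assumption: nonzero tangents are passed after the primals, so their
    # flags are appended after the primal flags.
    new_donated_invars = (*donated_invars, *donated_tangents)
    # dict(params, ...) returns a copy; the caller's params are not mutated.
    return dict(params, donated_invars=new_donated_invars)


params = {"name": "f", "donated_invars": (True, False, True)}
nz_tangents = [True, False, False]
update_params = jvp_update_params_sketch
# Same conditional shape as ad.py line 317.
new_params = update_params(params, nz_tangents) if update_params else params
```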