
Searched refs:call_transpose (Results 1 – 2 of 2) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
ad.py
  543: def call_transpose(primitive, params, call_jaxpr, args, ct, _):  [function]
  554: primitive_transposes[core.call_p] = partial(call_transpose, call_p)
xla.py
  905: ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
  1444: ad.primitive_transposes[core.named_call_p] = partial(ad.call_transpose,