Searched defs:call_jaxpr (Results 1 – 5 of 5) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
ad.py
  543  def call_transpose(primitive, params, call_jaxpr, args, ct, _):  (argument)
  557  def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_avals):  (argument)
  593  def map_transpose(primitive, params, call_jaxpr, args, ct, _):  (argument)
xla.py
  897  call_jaxpr, donated_invars, device=None):  (argument)
  1401 name_stack, backend, name, call_jaxpr,  (argument)
  1449 name="core_call", backend, call_jaxpr):  (argument)
  1460 call_jaxpr):  (argument)
sharded_jit.py
  201  name, call_jaxpr, local_in_parts,  (argument)
pxla.py
  1221 call_jaxpr, *, backend=None, in_axes, out_axes,  (argument)
/dports/math/py-jax/jax-0.2.9/jax/experimental/
maps.py
  452  call_jaxpr, name,  (argument)
  559  call_jaxpr, name,  (argument)