Searched defs:call_jaxpr (Results 1 – 5 of 5) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
  ad.py
    543   def call_transpose(primitive, params, call_jaxpr, args, ct, _):   [argument]
    557   def remat_transpose(params, call_jaxpr, primals_in, cotangents_in, cotangent_in_avals):   [argument]
    593   def map_transpose(primitive, params, call_jaxpr, args, ct, _):   [argument]
  xla.py
    897   call_jaxpr, donated_invars, device=None):   [argument]
    1401  name_stack, backend, name, call_jaxpr,   [argument]
    1449  name="core_call", backend, call_jaxpr):   [argument]
    1460  call_jaxpr):   [argument]
  sharded_jit.py
    201   name, call_jaxpr, local_in_parts,   [argument]
  pxla.py
    1221  call_jaxpr, *, backend=None, in_axes, out_axes,   [argument]

/dports/math/py-jax/jax-0.2.9/jax/experimental/
  maps.py
    452   call_jaxpr, name,   [argument]
    559   call_jaxpr, name,   [argument]