Searched refs:jaxpr_subcomp (Results 1 – 8 of 8) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 405 def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args): function 708 out_nodes = jaxpr_subcomp( 901 out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (), 957 outs = jaxpr_subcomp(c, jaxpr, None, axis_env, _xla_consts(c, consts), '', 964 outs = jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts, '', *xla_args) 984 outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, _xla_consts(c, consts), 991 outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts, name_stack, 1419 out_nodes = jaxpr_subcomp(remat_subc, call_jaxpr, backend, axis_env, (), 1452 out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
|
H A D | sharded_jit.py | 164 out_nodes = xla.jaxpr_subcomp( 217 out_nodes = xla.jaxpr_subcomp(
|
H A D | pxla.py | 841 out_nodes = xla.jaxpr_subcomp(c, jaxpr, backend_name, axis_env, xla_consts, 1238 sharded_outs = xla.jaxpr_subcomp( 1482 out_nodes = xla.jaxpr_subcomp( 1638 out_nodes = xla.jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts,
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | test_util.py | 334 jaxpr_subcomp = xla.jaxpr_subcomp 339 return jaxpr_subcomp(*args, **kwargs) 341 xla.jaxpr_subcomp = jaxpr_subcomp_and_count 345 xla.jaxpr_subcomp = jaxpr_subcomp
|
H A D | api.py | 663 out_nodes = xla.jaxpr_subcomp(
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | control_flow.py | 318 pred, = xla.jaxpr_subcomp(cond_c, cond_jaxpr.jaxpr, backend, axis_env, 331 new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env, 335 body_pred, = xla.jaxpr_subcomp(body_c, cond_jaxpr.jaxpr, backend, axis_env, 739 outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
|
H A D | lax.py | 4944 out_nodes = xla.jaxpr_subcomp(subc, jaxpr, None, axis_env, consts, '', *args)
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | maps.py | 489 tiled_outs = xla.jaxpr_subcomp(
|