Home
last modified time | relevance | path

Searched refs:jaxpr_subcomp (Results 1 – 8 of 8) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dxla.py405 def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args): function
708 out_nodes = jaxpr_subcomp(
901 out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
957 outs = jaxpr_subcomp(c, jaxpr, None, axis_env, _xla_consts(c, consts), '',
964 outs = jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts, '', *xla_args)
984 outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, _xla_consts(c, consts),
991 outs = jaxpr_subcomp(c, jaxpr, backend, axis_env, xla_consts, name_stack,
1419 out_nodes = jaxpr_subcomp(remat_subc, call_jaxpr, backend, axis_env, (),
1452 out_nodes = jaxpr_subcomp(subc, call_jaxpr, backend, axis_env, (),
H A Dsharded_jit.py164 out_nodes = xla.jaxpr_subcomp(
217 out_nodes = xla.jaxpr_subcomp(
H A Dpxla.py841 out_nodes = xla.jaxpr_subcomp(c, jaxpr, backend_name, axis_env, xla_consts,
1238 sharded_outs = xla.jaxpr_subcomp(
1482 out_nodes = xla.jaxpr_subcomp(
1638 out_nodes = xla.jaxpr_subcomp(c, jaxpr, None, axis_env, xla_consts,
/dports/math/py-jax/jax-0.2.9/jax/
H A Dtest_util.py334 jaxpr_subcomp = xla.jaxpr_subcomp
339 return jaxpr_subcomp(*args, **kwargs)
341 xla.jaxpr_subcomp = jaxpr_subcomp_and_count
345 xla.jaxpr_subcomp = jaxpr_subcomp
H A Dapi.py663 out_nodes = xla.jaxpr_subcomp(
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dcontrol_flow.py318 pred, = xla.jaxpr_subcomp(cond_c, cond_jaxpr.jaxpr, backend, axis_env,
331 new_z = xla.jaxpr_subcomp(body_c, body_jaxpr.jaxpr, backend, axis_env,
335 body_pred, = xla.jaxpr_subcomp(body_c, cond_jaxpr.jaxpr, backend, axis_env,
739 outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
H A Dlax.py4944 out_nodes = xla.jaxpr_subcomp(subc, jaxpr, None, axis_env, consts, '', *args)
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dmaps.py489 tiled_outs = xla.jaxpr_subcomp(