Home
last modified time | relevance | path

Searched defs:jaxpr (Results 1 – 14 of 14) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dcontrol_flow.py95 def pad_jaxpr_constvars(i, jaxpr): argument
735 def make_computation(name, jaxpr, op_shape): argument
958 def augment_jaxpr(jaxpr, res_indices): argument
979 def augment_jaxpr(jaxpr, res_indices): argument
997 def _transpose_cond_jaxpr(jaxpr, num_res): argument
1533 jaxpr, linear, unroll): argument
1709 def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr): argument
1801 def _masked_scan_jaxpr(jaxpr, num_consts, num_carry): argument
1819 jaxpr, linear, unroll): argument
2650 def pad_jaxpr_constvars(i, jaxpr): argument
[all …]
H A Dlax.py4880 def _reduce_shape_rule(*args, computation, jaxpr, consts, dimensions): argument
4891 def _reduce_dtype_rule(*args, computation, jaxpr, consts, dimensions): argument
4901 def _reduce_translation_rule(c, *values, computation, jaxpr, argument
4913 def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr, argument
4936 def _reduction_computation(c, jaxpr, consts, init_values, singleton=True): argument
5195 def _reduce_window_shape_rule(operand, init_value, *, jaxpr, consts, argument
5210 def _reduce_window_translation_rule(c, operand, init_value, *, jaxpr, consts, argument
5219 batched_args, batch_dims, *, jaxpr, consts, window_dimensions, argument
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dpxla.py947 def check_multihost_collective_allowlist(jaxpr): argument
957 def _find_partitions(jaxpr) -> Tuple[ argument
982 def reconcile_num_partitions(jaxpr, outer_num_parts: Optional[int]): argument
998 def _inner_partitions(jaxpr, expected_num_parts: Optional[int]): argument
1569 def _sanitize_mesh_jaxpr(jaxpr): argument
1665 def _soft_pmap_jaxpr(jaxpr, consts, in_axes, axis_name, axis_size, chunk_size): argument
1674 def _soft_pmap_interp(chunk_size, jaxpr, consts, mapped_invars, *args): argument
H A Dxla.py389 def jaxpr_literals(jaxpr): argument
405 def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args): argument
557 def jaxpr_has_pmap(jaxpr): argument
568 def jaxpr_collectives(jaxpr): argument
853 def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers, *args): argument
H A Dad.py118 def unbound_vjp(pvals, jaxpr, consts, *cts): argument
635 def jvp_jaxpr(jaxpr, nonzeros, instantiate): argument
784 def jvp_jaxpr(jaxpr, nonzeros, instantiate): argument
H A Dbatching.py477 def batch_jaxpr(jaxpr, axis_size, in_batched, instantiate, axis_name): argument
H A Dpartial_eval.py978 def _inline_literals(jaxpr, constvals): argument
/dports/math/py-jax/jax-0.2.9/jax/
H A Djaxpr_util.py112 def hist(jaxpr, reads): argument
H A Dcustom_derivatives.py831 def __init__(self, jaxpr, in_tree, out_tree, consts): argument
H A Dtest_util.py730 def iter_eqns(jaxpr): argument
H A Dapi.py1805 def _lift_linearized(jaxpr, primal_avals, consts, io_tree, out_pvals, *py_args): argument
H A Dcore.py197 def _jaxpr_vars(jaxpr): argument
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dmaps.py625 def _jaxpr_resources(jaxpr, resource_env) -> Set[ResourceAxisName]: argument
639 def subst_axis_names(jaxpr, axis_subst: Dict[AxisName, Tuple[AxisName]]): argument
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
H A Djax2tf.py1564 def _reduce_window(operand, init_value, *, jaxpr, consts, window_dimensions, argument