/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | control_flow.py | 95 def pad_jaxpr_constvars(i, jaxpr): argument 735 def make_computation(name, jaxpr, op_shape): argument 958 def augment_jaxpr(jaxpr, res_indices): argument 979 def augment_jaxpr(jaxpr, res_indices): argument 997 def _transpose_cond_jaxpr(jaxpr, num_res): argument 1533 jaxpr, linear, unroll): argument 1709 def _transpose_scan_jaxpr(num_res1, num_c, num_res2, jaxpr): argument 1801 def _masked_scan_jaxpr(jaxpr, num_consts, num_carry): argument 1819 jaxpr, linear, unroll): argument 2650 def pad_jaxpr_constvars(i, jaxpr): argument [all …]
|
H A D | lax.py | 4880 def _reduce_shape_rule(*args, computation, jaxpr, consts, dimensions): argument 4891 def _reduce_dtype_rule(*args, computation, jaxpr, consts, dimensions): argument 4901 def _reduce_translation_rule(c, *values, computation, jaxpr, argument 4913 def _reduce_batch_rule(batched_args, batch_dims, *, computation, jaxpr, argument 4936 def _reduction_computation(c, jaxpr, consts, init_values, singleton=True): argument 5195 def _reduce_window_shape_rule(operand, init_value, *, jaxpr, consts, argument 5210 def _reduce_window_translation_rule(c, operand, init_value, *, jaxpr, consts, argument 5219 batched_args, batch_dims, *, jaxpr, consts, window_dimensions, argument
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | pxla.py | 947 def check_multihost_collective_allowlist(jaxpr): argument 957 def _find_partitions(jaxpr) -> Tuple[ argument 982 def reconcile_num_partitions(jaxpr, outer_num_parts: Optional[int]): argument 998 def _inner_partitions(jaxpr, expected_num_parts: Optional[int]): argument 1569 def _sanitize_mesh_jaxpr(jaxpr): argument 1665 def _soft_pmap_jaxpr(jaxpr, consts, in_axes, axis_name, axis_size, chunk_size): argument 1674 def _soft_pmap_interp(chunk_size, jaxpr, consts, mapped_invars, *args): argument
|
H A D | xla.py | 389 def jaxpr_literals(jaxpr): argument 405 def jaxpr_subcomp(c, jaxpr, backend, axis_env, consts, name_stack, *args): argument 557 def jaxpr_has_pmap(jaxpr): argument 568 def jaxpr_collectives(jaxpr): argument 853 def _execute_trivial(jaxpr, device: Optional[Device], consts, avals, handlers, *args): argument
|
H A D | ad.py | 118 def unbound_vjp(pvals, jaxpr, consts, *cts): argument 635 def jvp_jaxpr(jaxpr, nonzeros, instantiate): argument 784 def jvp_jaxpr(jaxpr, nonzeros, instantiate): argument
|
H A D | batching.py | 477 def batch_jaxpr(jaxpr, axis_size, in_batched, instantiate, axis_name): argument
|
H A D | partial_eval.py | 978 def _inline_literals(jaxpr, constvals): argument
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | jaxpr_util.py | 112 def hist(jaxpr, reads): argument
|
H A D | custom_derivatives.py | 831 def __init__(self, jaxpr, in_tree, out_tree, consts): argument
|
H A D | test_util.py | 730 def iter_eqns(jaxpr): argument
|
H A D | api.py | 1805 def _lift_linearized(jaxpr, primal_avals, consts, io_tree, out_pvals, *py_args): argument
|
H A D | core.py | 197 def _jaxpr_vars(jaxpr): argument
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | maps.py | 625 def _jaxpr_resources(jaxpr, resource_env) -> Set[ResourceAxisName]: argument 639 def subst_axis_names(jaxpr, axis_subst: Dict[AxisName, Tuple[AxisName]]): argument
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 1564 def _reduce_window(operand, init_value, *, jaxpr, consts, window_dimensions, argument
|