Home
last modified time | relevance | path

Searched refs:primitive_transposes (Results 1 – 8 of 8) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dad.py254 return primitive_transposes[p]
424 primitive_transposes: Dict[core.Primitive, Callable] = {} variable
429 primitive_transposes[primitive] = partial(linear_transpose, transpose_rule)
445 primitive_transposes[primitive] = partial(linear_transpose2, transpose_rule)
489 primitive_transposes[prim] = partial(bilinear_transpose, lhs_rule, rhs_rule)
554 primitive_transposes[core.call_p] = partial(call_transpose, call_p)
585 primitive_transposes[pe.remat_call_p] = remat_transpose
689 primitive_transposes[custom_lin_p] = _custom_lin_transpose
739 primitive_transposes[fun_lin_p] = fun_lin_transpose
H A Dxla.py905 ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
1444 ad.primitive_transposes[core.named_call_p] = partial(ad.call_transpose,
H A Dpxla.py1247 ad.primitive_transposes[xla_pmap_p] = partial(ad.map_transpose, xla_pmap_p)
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dlax.py2401 ad.primitive_transposes[conj_p] = _conj_transpose_rule
2493 ad.primitive_transposes[add_p] = _add_transpose
2516 ad.primitive_transposes[sub_p] = _sub_transpose
2534 ad.primitive_transposes[div_p] = _div_transpose_rule
3430 ad.primitive_transposes[concatenate_p] = _concatenate_transpose_rule
3819 ad.primitive_transposes[select_p] = _select_transpose_rule
4096 ad.primitive_transposes[dynamic_update_slice_p] = \
4343 ad.primitive_transposes[gather_p] = _gather_transpose_rule
4653 ad.primitive_transposes[scatter_add_p] = _scatter_add_transpose_rule
5481 ad.primitive_transposes[select_and_scatter_add_p] = \
[all …]
H A Dcontrol_flow.py546 ad.primitive_transposes[while_p] = _while_transpose_error
1108 ad.primitive_transposes[cond_p] = _cond_transpose
1876 ad.primitive_transposes[scan_p] = _scan_transpose
2359 ad.primitive_transposes[linear_solve_p] = _linear_solve_transpose_rule
H A Dlinalg.py662 ad.primitive_transposes[triangular_solve_p] = triangular_solve_transpose_rule
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dhost_callback.py679 ad.primitive_transposes[id_tap_dep_p] = _id_tap_dep_transpose_rule
1042 ad.primitive_transposes[outside_call_p] = _outside_call_transpose_rule
/dports/math/py-jax/jax-0.2.9/jax/
H A Dcustom_derivatives.py374 ad.primitive_transposes[custom_jvp_call_jaxpr_p] = _custom_jvp_call_jaxpr_transpose