
Searched refs: linear_transpose (results 1–5 of 5, sorted by relevance)

/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
fft.py
    20  from jax.api import jit, linear_transpose, ShapeDtypeStruct
    94  transpose = linear_transpose(
control_flow.py
  2114  transpose_fun = jax.linear_transpose(linear_fun, primals)
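fft.py imports `ShapeDtypeStruct` next to `linear_transpose`, which suggests the primals passed in only need to describe shape and dtype rather than hold real values. The following is a minimal sketch of that pattern, assuming a recent JAX install (behaviour in 0.2.9 may differ in details). It transposes `jnp.sum`, whose transpose is a broadcast. The names `spec`, `sum_t` and `ct` are illustrative and do not come from the JAX sources.

```python
import jax
import jax.numpy as jnp

# Illustrative example (not from the JAX sources): jnp.sum maps a
# length-3 float32 vector to a scalar and is linear in its input.
# Only shape and dtype are needed, so a ShapeDtypeStruct stands in
# for a concrete primal value.
spec = jax.ShapeDtypeStruct((3,), jnp.float32)
sum_t = jax.linear_transpose(jnp.sum, spec)

# The transpose of a sum is a broadcast: every input slot receives the
# scalar output cotangent. The result is a 1-tuple because jnp.sum is
# called here with a single positional argument.
(ct,) = sum_t(jnp.array(2.0, dtype=jnp.float32))
print(ct)
```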
/dports/math/py-jax/jax-0.2.9/jax/
__init__.py
    61  linear_transpose,
api.py
  1934  def linear_transpose(fun: Callable, *primals) -> Callable:    (function definition)
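Going by the signature in api.py, `linear_transpose(fun, *primals)` takes a linear function plus example primals and returns a callable. The following is a minimal usage sketch of the public `jax.linear_transpose`, assuming a recent JAX install (0.2.9 may differ in details). The matrix `A` and the function `f` are hypothetical. For `f(x) = A @ x`, the transpose should map an output cotangent `y` to `A.T @ y`.

```python
import jax
import jax.numpy as jnp

# Hypothetical 2x3 matrix; f(x) = A @ x is linear in x.
A = jnp.array([[1.0, 2.0, 3.0],
               [4.0, 5.0, 6.0]])

def f(x):
    return A @ x

# The primal only fixes the input shape/dtype; its values are not used.
f_t = jax.linear_transpose(f, jnp.zeros(3))

# f_t maps an output cotangent y to A.T @ y. The result comes back as a
# 1-tuple because f takes one positional argument.
y = jnp.array([1.0, 1.0])
(x_ct,) = f_t(y)
print(x_ct)
```

Because only the shape and dtype of the primal matter, the values in `jnp.zeros(3)` are irrelevant. Any float32 vector of length 3 would produce the same transpose.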
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
ad.py
   429  primitive_transposes[primitive] = partial(linear_transpose, transpose_rule)
   439  def linear_transpose(transpose_rule, cotangent, *args, **kwargs):    (function definition)