
Searched refs: primitive_subcomputation (results 1–4 of 4, sorted by relevance)

/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
  parallel.py
    554  computation = xla.primitive_subcomputation(prim, scalar, scalar)
    574  computation = xla.primitive_subcomputation(prim, scalar, scalar)
  lax.py
    4985  xla.primitive_subcomputation(add_p, scalar, scalar),
    5017  xla.primitive_subcomputation(mul_p, scalar, scalar), axes)
    5069  xla.primitive_subcomputation(prim, scalar, scalar), axes)
    5180  xla.primitive_subcomputation(prim, scalar, scalar), axes)
    5261  xla.primitive_subcomputation(add_p, scalar, scalar), window_dimensions,
    5315  xla.primitive_subcomputation(prim, scalar, scalar), window_dimensions,
    5428  select = xla.primitive_subcomputation(select_prim, scalar, scalar)
    5429  scatter = xla.primitive_subcomputation(add_p, scalar, scalar)
  control_flow.py
    323  or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
  xla.py
    345  def primitive_subcomputation(prim, *avals, **params):   (function definition)
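Every call site above passes `primitive_subcomputation(prim, scalar, scalar)`, a computation over two scalar inputs built from a binary primitive such as `add_p`, `mul_p` or `lax.or_p`. Judging only from the snippets (the trailing `axes)` and `window_dimensions,` arguments), these computations appear to serve as the combining function of reduction-style XLA operations: plain reductions over axes, windowed reductions, and select/scatter. The following is a minimal pure-Python sketch of that idea, assuming nothing about JAX internals. `reduce_axis` and `reduce_window_1d` are hypothetical helpers written for illustration, not JAX or XLA APIs. They show how one reduction routine is parameterized by a scalar binary function plus an initial value.

```python
from functools import reduce as fold
from typing import Callable, List, Sequence

# Hypothetical stand-in for a "scalar subcomputation": a function of two scalars,
# analogous in spirit to primitive_subcomputation(prim, scalar, scalar).
# This is an illustration only, NOT JAX's API.
Reducer = Callable[[float, float], float]


def reduce_axis(matrix: List[List[float]], init: float,
                reducer: Reducer, axis: int) -> List[float]:
    """Reduce a 2-D list along `axis` (0 = down each column, 1 = across each row)."""
    if axis == 1:
        return [fold(reducer, row, init) for row in matrix]
    if axis == 0:
        # zip(*matrix) yields the columns as tuples.
        return [fold(reducer, col, init) for col in zip(*matrix)]
    raise ValueError("axis must be 0 or 1")


def reduce_window_1d(xs: Sequence[float], init: float, reducer: Reducer,
                     window: int, stride: int) -> List[float]:
    """Apply `reducer` over each full length-`window` slice (no padding)."""
    return [fold(reducer, xs[i:i + window], init)
            for i in range(0, len(xs) - window + 1, stride)]


if __name__ == "__main__":
    m = [[1, 2, 3], [4, 5, 6]]
    print(reduce_axis(m, 0, lambda a, b: a + b, 1))   # [6, 15]  (sum of each row)
    print(reduce_axis(m, 1, lambda a, b: a * b, 0))   # [4, 10, 18]  (product of each column)
    print(reduce_window_1d([1, 3, 2, 5, 4], float("-inf"), max, 2, 1))  # [3, 3, 5, 5]
```

The initial value is chosen as the identity of the reducer (0 for addition, 1 for multiplication, negative infinity for max). XLA's reductions make the same assumption, because they may combine partial results in any order.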