Searched refs:primitive_subcomputation (Results 1 – 4 of 4) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | parallel.py | 554 computation = xla.primitive_subcomputation(prim, scalar, scalar) 574 computation = xla.primitive_subcomputation(prim, scalar, scalar)
|
H A D | lax.py | 4985 xla.primitive_subcomputation(add_p, scalar, scalar), 5017 xla.primitive_subcomputation(mul_p, scalar, scalar), axes) 5069 xla.primitive_subcomputation(prim, scalar, scalar), axes) 5180 xla.primitive_subcomputation(prim, scalar, scalar), axes) 5261 xla.primitive_subcomputation(add_p, scalar, scalar), window_dimensions, 5315 xla.primitive_subcomputation(prim, scalar, scalar), window_dimensions, 5428 select = xla.primitive_subcomputation(select_prim, scalar, scalar) 5429 scatter = xla.primitive_subcomputation(add_p, scalar, scalar)
|
H A D | control_flow.py | 323 or_ = xla.primitive_subcomputation(lax.or_p, scalar, scalar)
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 345 def primitive_subcomputation(prim, *avals, **params): function
|