
Searched refs: psum_p (results 1–3 of 3, sorted by relevance)

/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
parallel.py
593 psum_p = core.Primitive('psum')
594 psum_p.multiple_results = True
596 pxla.soft_pmap_rules[psum_p] = \
599 ad.deflinear2(psum_p, _psum_transpose_rule)
601 batching.primitive_batchers[psum_p] = partial(_collective_batcher, psum_p)
602 batching.collective_rules[psum_p] = \
604 psum_p,
610 @psum_p.def_custom_bind
1035 out_tup = xla.parallel_translations[psum_p](
1063 psum_p.bind = partial(core.Primitive.bind, psum_p) # type: ignore
… (further matches in parallel.py not shown)
/dports/math/py-jax/jax-0.2.9/jax/lax/
__init__.py
343 psum_p,
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
jax2tf.py
832 lax_parallel.pmin_p, lax_parallel.ppermute_p, lax_parallel.psum_p,
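For context, in the 0.2.9 source listed here `psum_p` is the primitive that `jax.lax.psum` binds, and the batching rules registered around lines 601–604 of `parallel.py` let it run under `jax.vmap` with a named axis. The following is a minimal sketch, assuming a JAX version whose `jax.vmap` accepts an `axis_name` argument (recent releases do; it has not been checked against 0.2.9). `normalize` and `xs` are illustrative names, not part of JAX.

```python
import jax
import jax.numpy as jnp
from jax import lax


def normalize(x):
    # lax.psum sums x over every element of the named axis "i".
    # Under vmap, "i" is the batch axis, so each element sees the total.
    total = lax.psum(x, axis_name="i")
    return x / total


xs = jnp.array([1.0, 2.0, 3.0, 4.0])

# Map over the leading axis and name it "i" so psum knows what to reduce over.
out = jax.vmap(normalize, axis_name="i")(xs)
print(out)
```

Each element is divided by the sum of the whole batch (10.0 here), so the results add up to 1. The same function would also work under `jax.pmap(normalize, axis_name="i")`, where the reduction runs across devices instead of a vectorized batch.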