Searched refs:parts_proto (Results 1 – 1 of 1) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/xla.py
827  def _xla_param(builder, param_num, xla_shape, replicated, partitions, parts_proto):  [argument]
830  with_sharding = xb.with_sharding_proto if parts_proto else xb.with_sharding
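Line 830 uses `parts_proto` to choose one of two sharding helpers, depending on whether `partitions` arrives as a prebuilt sharding proto or as a plain partition spec. Below is a minimal, self-contained sketch of that dispatch pattern. It assumes that `with_sharding` converts a plain spec into a proto and then behaves like `with_sharding_proto` (set the annotation, build the op, clear it). `FakeBuilder`, `partitions_to_proto`, and `xla_param` are hypothetical stand-ins, not JAX's actual `xla_bridge` API, and the dict-based "proto" is a placeholder for a real sharding proto.

```python
from functools import partial


class FakeBuilder:
    """Hypothetical stand-in for an XLA computation builder that records
    which sharding annotation was active when each op was built."""

    def __init__(self):
        self.current_sharding = None
        self.log = []

    def set_sharding(self, proto):
        self.current_sharding = proto

    def clear_sharding(self):
        self.current_sharding = None

    def parameter(self, param_num):
        # Record the op together with the sharding active at build time.
        self.log.append(("parameter", param_num, self.current_sharding))
        return param_num


def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):
    """Build op_fn(*args, **kwargs) under an already-built sharding proto."""
    builder.set_sharding(sharding_proto)
    try:
        return op_fn(*args, **kwargs)
    finally:
        # Always clear, so later ops are not accidentally annotated.
        builder.clear_sharding()


def partitions_to_proto(partitions):
    """Hypothetical converter: turn a tuple of per-dimension partition
    counts into a dict-based placeholder for a sharding proto."""
    return {"type": "OTHER", "tile_dims": tuple(partitions)}


def with_sharding(builder, partitions, op_fn, *args, **kwargs):
    """Convert a plain partition spec, then delegate to with_sharding_proto."""
    return with_sharding_proto(builder, partitions_to_proto(partitions),
                               op_fn, *args, **kwargs)


def xla_param(builder, param_num, partitions, parts_proto):
    # Same choice as line 830: pick the helper matching the form
    # in which `partitions` was supplied.
    wrap = with_sharding_proto if parts_proto else with_sharding
    make_param = partial(builder.parameter, param_num)
    if partitions is None:
        # No sharding requested: build the parameter unannotated.
        return make_param()
    return wrap(builder, partitions, make_param)
```

For example, `xla_param(b, 0, (2, 1), parts_proto=False)` routes the tuple through the converter, while `xla_param(b, 1, {"type": "REPLICATED"}, parts_proto=True)` passes the placeholder proto through unchanged. Keeping the flag at the call site lets one code path accept either representation without inspecting the value's type.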