Searched refs:with_sharding_proto (Results 1 – 3 of 3) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/lib/ |
H A D | xla_bridge.py | 381 return with_sharding_proto(builder, sharding_proto, xops.CustomCall, 384 def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs): function 398 return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 830 with_sharding = xb.with_sharding_proto if parts_proto else xb.with_sharding
|
H A D | pxla.py | 1488 out_tuple = xb.with_sharding_proto(c, out_partitions_t, xops.Tuple, c, out_nodes)
|