Searched refs:with_sharding_proto (Results 1 – 3 of 3) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/lib/
xla_bridge.py
  381: return with_sharding_proto(builder, sharding_proto, xops.CustomCall,
  384: def with_sharding_proto(builder, sharding_proto, op_fn, *args, **kwargs):  [function]
  398: return with_sharding_proto(builder, _sharding_to_proto(sharding), op_fn, *args, **kwargs)
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
xla.py
  830: with_sharding = xb.with_sharding_proto if parts_proto else xb.with_sharding
pxla.py
  1488: out_tuple = xb.with_sharding_proto(c, out_partitions_t, xops.Tuple, c, out_nodes)