Searched refs:aval_to_result_handler (Results 1 – 3 of 3) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | sharded_jit.py | 72 return pxla.aval_to_result_handler(spec, indices, aval) 503 return pxla.aval_to_result_handler(spec, indices, pv)
|
H A D | xla.py | 108 def aval_to_result_handler(device: Optional[Device], aval: core.AbstractValue) -> Callable: function 265 handle_result = aval_to_result_handler(device, aval_out) 267 handlers = map(partial(aval_to_result_handler, device), aval_out) 670 result_handlers = map(partial(aval_to_result_handler, device), out_avals) 1390 return aval_to_result_handler(device, a)(*device_put(x, device)) 1477 return aval_to_result_handler(device, pv)
|
H A D | pxla.py | 440 def aval_to_result_handler(sharding_spec: Optional[ShardingSpec], function 1104 handlers = [aval_to_result_handler(spec, idcs, aval) 1829 return aval_to_result_handler(sharding_spec, indices, unsharded_aval)
|