Searched refs:xla_call_p (Results 1 – 5 of 5) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 842 check_special(xla_call_p, out_bufs) 850 check_special(xla_call_p, out_bufs) 862 xla_call_p = core.CallPrimitive('xla_call') variable 863 xla_call = xla_call_p.bind 864 xla_call_p.def_impl(_xla_call_impl) 878 pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params 885 ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params 892 ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params 905 ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p) 917 call_translations[xla_call_p] = _xla_call_translation_rule [all …]
|
H A D | pxla.py | 1213 pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p] 1214 ad.call_param_updaters[xla_pmap_p] = ad.call_param_updaters[xla.xla_call_p] 1216 ad.call_transpose_param_updaters[xla.xla_call_p]
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | host_callback.py | 1218 elif eqn.primitive is xla.xla_call_p: 1328 pred1_and_token1, xla.xla_call_p, 1371 xla.xla_call_p, 1379 [new_body_pred2, new_body_token3, new_body_itoken3], xla.xla_call_p,
|
H A D | jet.py | 182 call_param_updaters[xla.xla_call_p] = _xla_call_param_updater
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | api.py | 384 xla.check_special(xla.xla_call_p, [
|