Home
last modified time | relevance | path

Searched refs:xla_call_p (Results 1 – 5 of 5) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dxla.py842 check_special(xla_call_p, out_bufs)
850 check_special(xla_call_p, out_bufs)
862 xla_call_p = core.CallPrimitive('xla_call') variable
863 xla_call = xla_call_p.bind
864 xla_call_p.def_impl(_xla_call_impl)
878 pe.call_param_updaters[xla_call_p] = _xla_call_partial_eval_update_params
885 ad.call_param_updaters[xla_call_p] = _xla_call_jvp_update_params
892 ad.call_transpose_param_updaters[xla_call_p] = _xla_call_transpose_update_params
905 ad.primitive_transposes[xla_call_p] = partial(ad.call_transpose, xla_call_p)
917 call_translations[xla_call_p] = _xla_call_translation_rule
[all …]
H A Dpxla.py1213 pe.call_param_updaters[xla_pmap_p] = pe.call_param_updaters[xla.xla_call_p]
1214 ad.call_param_updaters[xla_pmap_p] = ad.call_param_updaters[xla.xla_call_p]
1216 ad.call_transpose_param_updaters[xla.xla_call_p]
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dhost_callback.py1218 elif eqn.primitive is xla.xla_call_p:
1328 pred1_and_token1, xla.xla_call_p,
1371 xla.xla_call_p,
1379 [new_body_pred2, new_body_token3, new_body_itoken3], xla.xla_call_p,
H A Djet.py182 call_param_updaters[xla.xla_call_p] = _xla_call_param_updater
/dports/math/py-jax/jax-0.2.9/jax/
H A Dapi.py384 xla.check_special(xla.xla_call_p, [