Home
last modified time | relevance | path

Searched refs:backend_specific_translations (Results 1 – 7 of 7) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dlinalg.py534 xla.backend_specific_translations['cpu'][eigh_p] = partial(
538 xla.backend_specific_translations['gpu'][eigh_p] = partial(
542 xla.backend_specific_translations['gpu'][eigh_p] = partial(
911 xla.backend_specific_translations['cpu'][lu_p] = partial(
915 xla.backend_specific_translations['gpu'][lu_p] = partial(
919 xla.backend_specific_translations['gpu'][lu_p] = partial(
1110 xla.backend_specific_translations['cpu'][qr_p] = partial(
1114 xla.backend_specific_translations['gpu'][qr_p] = partial(
1118 xla.backend_specific_translations['gpu'][qr_p] = partial(
1269 xla.backend_specific_translations['cpu'][svd_p] = partial(
[all …]
H A Dfft.py141 xla.backend_specific_translations['cpu'][fft_p] = pocketfft.pocketfft
H A Dlax.py2994 xla.backend_specific_translations['cpu'][conv_general_dilated_p] = partial(
2996 xla.backend_specific_translations['gpu'][conv_general_dilated_p] = partial(
4657 xla.backend_specific_translations['gpu'][scatter_add_p] = partial(
5158 xla.backend_specific_translations['gpu'][argmin_p] = xla.lower_fun(
5166 xla.backend_specific_translations['gpu'][argmax_p] = xla.lower_fun(
5677 xla.backend_specific_translations['tpu'][select_and_gather_add_p] = partial(
H A Dcontrol_flow.py2564 xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
2574 xla.backend_specific_translations['tpu'][reducer_p] = xla.lower_fun(
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/tests/
H A Dprimitives_test.py119 | set(xla.backend_specific_translations["cpu"])
120 | set(xla.backend_specific_translations["gpu"])
121 | set(xla.backend_specific_translations["tpu"])
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dxla.py320 if prim in backend_specific_translations[platform]:
321 rule = backend_specific_translations[platform][prim]
441 if eqn.primitive in backend_specific_translations[platform]:
442 rule = backend_specific_translations[platform][eqn.primitive]
915 backend_specific_translations: Dict[str, Dict[core.Primitive, Callable]] = defaultdict(dict) variable
/dports/math/py-jax/jax-0.2.9/jax/_src/
H A Drandom.py213 xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
217 xla.backend_specific_translations['gpu'][threefry2x32_p] = \
1044 xla.backend_specific_translations['cpu'][random_gamma_p] = xla.lower_fun(