Home
last modified time | relevance | path

Searched refs:lower_fun (Results 1 – 7 of 7) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dlinalg.py890 perm = xla.lower_fun(lambda x: lu_pivots_to_permutation(x, m),
900 return xla.lower_fun(_lu_python, multiple_results=True)(c, operand)
907 xla.translations[lu_p] = xla.lower_fun(_lu_python, multiple_results=True)
1099 r = xla.lower_fun(jnp.triu, multiple_results=False)(c, r)
1132 return xla.lower_fun(_empty_svd, multiple_results=True)(
1229 return xla.lower_fun(_empty_svd, multiple_results=True)(
H A Dparallel.py707 lowering = xla.lower_fun(_all_to_all_via_all_gather, multiple_results=False, parallel=True)
724 …x = xla.lower_fun(partial(lax.expand_dims, dimensions=(concat_axis,)), multiple_results=False)(c, …
726 … x = xla.lower_fun(partial(lax.squeeze, dimensions=(split_axis,)), multiple_results=False)(c, x)
887 lowering = xla.lower_fun(_all_gather_via_psum, multiple_results=False, parallel=True)
H A Dcontrol_flow.py2564 xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun(
2574 xla.backend_specific_translations['tpu'][reducer_p] = xla.lower_fun(
2585 xla.translations[cumsum_p] = xla.lower_fun(
2587 xla.translations[cumprod_p] = xla.lower_fun(
2589 xla.translations[cummin_p] = xla.lower_fun(
2591 xla.translations[cummax_p] = xla.lower_fun(
H A Dlax.py2184 rounding_fun = xla.lower_fun(_round_to_nearest_even, multiple_results=False)
2221 @partial(xla.lower_fun, multiple_results=False)
2231 @partial(xla.lower_fun, multiple_results=False)
2244 @partial(xla.lower_fun, multiple_results=False)
2266 @partial(xla.lower_fun, multiple_results=False)
5158 xla.backend_specific_translations['gpu'][argmin_p] = xla.lower_fun(
5166 xla.backend_specific_translations['gpu'][argmax_p] = xla.lower_fun(
5750 result = xla.lower_fun(partial(_sort_lt_comparator, num_keys=num_keys),
/dports/math/py-jax/jax-0.2.9/jax/_src/
H A Drandom.py210 xla.translations_with_avals[threefry2x32_p] = xla.lower_fun(
213 xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun(
1041 xla.translations_with_avals[random_gamma_p] = xla.lower_fun(
1044 xla.backend_specific_translations['cpu'][random_gamma_p] = xla.lower_fun(
/dports/math/py-jax/jax-0.2.9/jax/
H A Dapi.py2553 return xla.lower_fun(fun_impl, multiple_results=True)(c, *xla_args, **params)
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dxla.py940 def lower_fun(fun, multiple_results, parallel=False, with_avals=False): function