Searched refs:lower_fun (Results 1 – 7 of 7) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | linalg.py | 890 perm = xla.lower_fun(lambda x: lu_pivots_to_permutation(x, m), 900 return xla.lower_fun(_lu_python, multiple_results=True)(c, operand) 907 xla.translations[lu_p] = xla.lower_fun(_lu_python, multiple_results=True) 1099 r = xla.lower_fun(jnp.triu, multiple_results=False)(c, r) 1132 return xla.lower_fun(_empty_svd, multiple_results=True)( 1229 return xla.lower_fun(_empty_svd, multiple_results=True)(
|
H A D | parallel.py | 707 lowering = xla.lower_fun(_all_to_all_via_all_gather, multiple_results=False, parallel=True) 724 …x = xla.lower_fun(partial(lax.expand_dims, dimensions=(concat_axis,)), multiple_results=False)(c, … 726 … x = xla.lower_fun(partial(lax.squeeze, dimensions=(split_axis,)), multiple_results=False)(c, x) 887 lowering = xla.lower_fun(_all_gather_via_psum, multiple_results=False, parallel=True)
|
H A D | control_flow.py | 2564 xla.backend_specific_translations['tpu'][cumsum_p] = xla.lower_fun( 2574 xla.backend_specific_translations['tpu'][reducer_p] = xla.lower_fun( 2585 xla.translations[cumsum_p] = xla.lower_fun( 2587 xla.translations[cumprod_p] = xla.lower_fun( 2589 xla.translations[cummin_p] = xla.lower_fun( 2591 xla.translations[cummax_p] = xla.lower_fun(
|
H A D | lax.py | 2184 rounding_fun = xla.lower_fun(_round_to_nearest_even, multiple_results=False) 2221 @partial(xla.lower_fun, multiple_results=False) 2231 @partial(xla.lower_fun, multiple_results=False) 2244 @partial(xla.lower_fun, multiple_results=False) 2266 @partial(xla.lower_fun, multiple_results=False) 5158 xla.backend_specific_translations['gpu'][argmin_p] = xla.lower_fun( 5166 xla.backend_specific_translations['gpu'][argmax_p] = xla.lower_fun( 5750 result = xla.lower_fun(partial(_sort_lt_comparator, num_keys=num_keys),
|
/dports/math/py-jax/jax-0.2.9/jax/_src/ |
H A D | random.py | 210 xla.translations_with_avals[threefry2x32_p] = xla.lower_fun( 213 xla.backend_specific_translations['cpu'][threefry2x32_p] = xla.lower_fun( 1041 xla.translations_with_avals[random_gamma_p] = xla.lower_fun( 1044 xla.backend_specific_translations['cpu'][random_gamma_p] = xla.lower_fun(
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | api.py | 2553 return xla.lower_fun(fun_impl, multiple_results=True)(c, *xla_args, **params)
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 940 def lower_fun(fun, multiple_results, parallel=False, with_avals=False): function
|