Searched refs:xla_primitive_callable (Results 1 – 2 of 2) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/
  test_util.py
    310  xla.xla_primitive_callable.cache_clear()
    879  cache_misses = xla.xla_primitive_callable.cache_info().misses
    882  cache_misses, xla.xla_primitive_callable.cache_info().misses,
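The test_util.py hits use `cache_clear()` and `cache_info().misses`, the interface that `functools.lru_cache` exposes, to detect recompilation: start from an empty cache, record the miss count, repeat a call, and check that the count did not move. Below is a minimal sketch of that pattern. It assumes a plain `functools.lru_cache`-decorated stand-in. `primitive_callable`, `count_compiles`, and `check_no_recompile` are hypothetical names, not JAX's API.

```python
import contextlib
import functools
from typing import Callable, Iterator, List, Tuple


@functools.lru_cache(maxsize=None)
def primitive_callable(prim: str, shape: Tuple[int, ...]) -> str:
    # Hypothetical stand-in for compiling one primitive for one input shape;
    # the body only runs on a cache miss.
    return f"executable for {prim} on {shape}"


@contextlib.contextmanager
def count_compiles() -> Iterator[List[int]]:
    """Run the with-body from an empty cache; afterwards count[0] holds the
    number of compilations (cache misses) that happened inside it."""
    primitive_callable.cache_clear()  # drops entries and resets hit/miss stats
    count = [0]
    try:
        yield count
    finally:
        count[0] = primitive_callable.cache_info().misses


def check_no_recompile(fn: Callable, *args) -> None:
    """Call fn twice; fail if the second, identical call misses the cache."""
    fn(*args)  # the first call is allowed to compile
    misses = primitive_callable.cache_info().misses
    fn(*args)
    assert misses == primitive_callable.cache_info().misses, (
        f"Compilation detected during second call of {fn}")
```

Reading the count in `finally` means it reflects every call made inside the `with` block, including calls made before an exception.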
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
  xla.py
    238  compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
    251  def xla_primitive_callable(prim, *arg_specs: Tuple[core.AbstractValue,   [function]
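The xla.py hits show the call site. Each argument is first mapped through `arg_spec`, and only the results go to the cached `xla_primitive_callable`, together with the primitive and its params. According to the truncated signature, each spec is a tuple beginning with a `core.AbstractValue`. The cache is therefore keyed on abstract descriptions of the arguments rather than on their concrete values. Below is a minimal sketch of that keying idea. It assumes that a spec is just `(shape, dtype)`. `FakeArray`, `arg_spec`, `primitive_callable`, and `apply_primitive` are hypothetical stand-ins written for illustration, not JAX's implementation.

```python
import functools
from dataclasses import dataclass
from typing import Callable, Tuple

Spec = Tuple[Tuple[int, ...], str]  # hypothetical abstract value: (shape, dtype)


@dataclass(frozen=True)
class FakeArray:
    """Hypothetical stand-in for a device array: values plus a dtype tag."""
    data: Tuple[float, ...]
    dtype: str = "float32"

    @property
    def shape(self) -> Tuple[int, ...]:
        return (len(self.data),)


def arg_spec(x: FakeArray) -> Spec:
    # Keep only abstract information and drop the values, so arrays with
    # different contents but the same shape/dtype share one cache entry.
    return (x.shape, x.dtype)


@functools.lru_cache(maxsize=None)
def primitive_callable(prim: str, *specs: Spec, **params) -> Callable[..., FakeArray]:
    # Stands in for the expensive compile step; runs once per distinct
    # (prim, specs, params) key. Params must be hashable to form the key.
    if prim == "add":
        return lambda a, b: FakeArray(
            tuple(x + y for x, y in zip(a.data, b.data)), a.dtype)
    if prim == "scale":
        factor = params["factor"]
        return lambda a: FakeArray(tuple(factor * x for x in a.data), a.dtype)
    raise NotImplementedError(prim)


def apply_primitive(prim: str, *args: FakeArray, **params) -> FakeArray:
    # Look up (or build) the callable from abstract specs, then run it on
    # the concrete arguments.
    compiled = primitive_callable(prim, *map(arg_spec, args), **params)
    return compiled(*args)
```

With this keying, `apply_primitive("add", x, y)` followed by `apply_primitive("add", y, x)` on same-shaped inputs builds the callable only once. A new shape or a different `factor` produces a new cache entry.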