Searched defs:out_axes_thunk (Results 1 – 3 of 3) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/experimental/
maps.py
    267: def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, axis_sizes,  [argument]
    276: in_axes, out_axes_thunk, axis_sizes,  [argument]
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
pxla.py
    640: global_axis_size, devices, name, in_axes, out_axes_thunk,  [argument]
    1603: def soft_pmap_impl(fun: lu.WrappedFun, *args, axis_name, axis_size, in_axes, out_axes_thunk):  [argument]
    1610: def _soft_pmap_callable(fun, axis_name, axis_size, in_axes, out_axes_thunk, *avals):  [argument]
ad.py
    610: def out_axes_thunk():  [function]