Searched defs:out_axes_thunk (Results 1 – 3 of 3) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/experimental/
  maps.py
     267  def xmap_impl(fun: lu.WrappedFun, *args, name, in_axes, out_axes_thunk, axis_sizes,   (argument)
     276  in_axes, out_axes_thunk, axis_sizes,   (argument)
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
  pxla.py
     640  global_axis_size, devices, name, in_axes, out_axes_thunk,   (argument)
    1603  def soft_pmap_impl(fun: lu.WrappedFun, *args, axis_name, axis_size, in_axes, out_axes_thunk):   (argument)
    1610  def _soft_pmap_callable(fun, axis_name, axis_size, in_axes, out_axes_thunk, *avals):   (argument)
  ad.py
     610  def out_axes_thunk():   (function)