Home
last modified time | relevance | path

Searched refs:unsafe_map (Results 1 – 4 of 4) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dmaps.py40 map, unsafe_map = safe_map, map variable
327 map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes))
328 map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes))
341 return (tuple(unsafe_map(to_mesh, in_axes)),
342 tuple(unsafe_map(to_mesh, out_axes)))
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dxla.py45 map, unsafe_map = safe_map, map variable
221 if any(unsafe_map(_param_uses_outfeed, param)):
238 compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params)
510 mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name))
530 return tuple(unsafe_map(tuple, groups.T))
538 return max(unsafe_map(eqn_replicas, jaxpr.eqns), default=1)
581 *unsafe_map(arg_spec, args))
H A Dpxla.py70 unsafe_map, map = map, safe_map # type: ignore variable
642 abstract_args = unsafe_map(xla.abstractify, args)
1604 abstract_args = unsafe_map(xla.abstractify, args)
/dports/math/py-jax/jax-0.2.9/jax/
H A Djaxpr_util.py24 map, unsafe_map = util.safe_map, map variable