Searched refs:unsafe_map (Results 1 – 4 of 4) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | maps.py | 40 map, unsafe_map = safe_map, map variable 327 map_in_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), in_axes)) 328 map_out_axes = tuple(unsafe_map(lambda spec: spec.get(naxis, None), out_axes)) 341 return (tuple(unsafe_map(to_mesh, in_axes)), 342 tuple(unsafe_map(to_mesh, out_axes)))
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 45 map, unsafe_map = safe_map, map variable 221 if any(unsafe_map(_param_uses_outfeed, param)): 238 compiled_fun = xla_primitive_callable(prim, *unsafe_map(arg_spec, args), **params) 510 mesh_axes = tuple(unsafe_map(partial(axis_read, axis_env), name)) 530 return tuple(unsafe_map(tuple, groups.T)) 538 return max(unsafe_map(eqn_replicas, jaxpr.eqns), default=1) 581 *unsafe_map(arg_spec, args))
|
H A D | pxla.py | 70 unsafe_map, map = map, safe_map # type: ignore variable 642 abstract_args = unsafe_map(xla.abstractify, args) 1604 abstract_args = unsafe_map(xla.abstractify, args)
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | jaxpr_util.py | 24 map, unsafe_map = util.safe_map, map variable
|