Home
last modified time | relevance | path

Searched refs:xla_args (Results 1 – 4 of 4) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dxla.py322 ans = rule(c, *xla_args, **params)
325 ans = rule(c, *xla_args, **params)
328 ans = rule(c, avals, xla_args, params)
332 *xla_args, **params)
738 for arg_index, arg in enumerate(xla_args):
807 return xla_args, donated_invars
825 return xla_args, donated_invars
942 def f(c, *xla_args, **params): argument
946 def f_with_avals(c, avals, xla_args, params): argument
958 *xla_args)
[all …]
H A Dsharded_jit.py162 xla_args = _xla_sharded_args(c, global_abstract_args, in_parts)
166 extend_name_stack(wrap_name(name, "sharded_jit")), *xla_args)
236 xla_args = []
240 xla_args.append(param)
241 return xla_args
H A Dpxla.py836 xla_args, donated_invars = xla._xla_callable_args(c, global_sharded_avals, tuple_args,
842 extend_name_stack(wrap_name(name, 'pmap')), *xla_args)
850 donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
1475 xla_args, donated_invars = xla._xla_callable_args(
1484 extend_name_stack(wrap_name(transformed_name, 'xmap')), *xla_args)
1493 donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
1636 xla_args, _ = xla._xla_callable_args(c, chunked_avals, tuple_args)
1639 'soft_pmap', *xla_args)
/dports/math/py-jax/jax-0.2.9/jax/
H A Dapi.py661 xla_args, donated_invars = xla._xla_callable_args(
665 extend_name_stack(wrap_name(fun_name, "xla_computation")), *xla_args)
673 donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars,
676 shapes = [str(c.GetShape(a)) for a, d in zip(xla_args, donated_invars) if d]
2552 def fun_translation(c, *xla_args, **params): argument
2553 return xla.lower_fun(fun_impl, multiple_results=True)(c, *xla_args, **params)