Searched refs:xla_args (Results 1 – 4 of 4) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 322 ans = rule(c, *xla_args, **params) 325 ans = rule(c, *xla_args, **params) 328 ans = rule(c, avals, xla_args, params) 332 *xla_args, **params) 738 for arg_index, arg in enumerate(xla_args): 807 return xla_args, donated_invars 825 return xla_args, donated_invars 942 def f(c, *xla_args, **params): argument 946 def f_with_avals(c, avals, xla_args, params): argument 958 *xla_args) [all …]
|
H A D | sharded_jit.py | 162 xla_args = _xla_sharded_args(c, global_abstract_args, in_parts) 166 extend_name_stack(wrap_name(name, "sharded_jit")), *xla_args) 236 xla_args = [] 240 xla_args.append(param) 241 return xla_args
|
H A D | pxla.py | 836 xla_args, donated_invars = xla._xla_callable_args(c, global_sharded_avals, tuple_args, 842 extend_name_stack(wrap_name(name, 'pmap')), *xla_args) 850 donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args) 1475 xla_args, donated_invars = xla._xla_callable_args( 1484 extend_name_stack(wrap_name(transformed_name, 'xmap')), *xla_args) 1493 donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args) 1636 xla_args, _ = xla._xla_callable_args(c, chunked_avals, tuple_args) 1639 'soft_pmap', *xla_args)
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | api.py | 661 xla_args, donated_invars = xla._xla_callable_args( 665 extend_name_stack(wrap_name(fun_name, "xla_computation")), *xla_args) 673 donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, 676 shapes = [str(c.GetShape(a)) for a, d in zip(xla_args, donated_invars) if d] 2552 def fun_translation(c, *xla_args, **params): argument 2553 return xla.lower_fun(fun_impl, multiple_results=True)(c, *xla_args, **params)
|