Home
last modified time | relevance | path

Searched refs:donated_invars (Results 1 – 7 of 7) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dxla.py707 …xla_args, donated_invars = _xla_callable_args(c, abstract_args, tuple_args, donated_invars=donated…
714 donated_invars = set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
715 if any(donated_invars):
786 donated_invars=None): argument
803 if donated_invars is not None:
804 donated_invars = [d
807 return xla_args, donated_invars
825 return xla_args, donated_invars
869 if not in_unknowns and donated_invars:
874 donated_invars = [d for d, uk in zip(donated_invars, in_unknowns) if uk]
[all …]
H A Dpxla.py641 donated_invars, global_arg_shapes): argument
646 donated_invars, global_arg_shapes,
660 donated_invars: Iterable[bool],
839 donated_invars=donated_invars)
850 donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args)
1222 donated_invars, global_arg_shapes): argument
1223 del donated_invars # Unused.
1455 donated_invars = (False,) * len(in_jaxpr_avals) # TODO(apaszke): support donation
1475 xla_args, donated_invars = xla._xla_callable_args(
1480 donated_invars=donated_invars)
[all …]
H A Dmasking.py499 params = dict(params, donated_invars=((False,) * len(logical_env_vals) +
/dports/math/py-jax/jax-0.2.9/jax/
H A Dapi.py220 donated_invars = (False,) * len(args_flat)
230 donated_invars=donated_invars)
284 donated_invars = (False,) * len(args_flat)
295 donated_invars=donated_invars)
632 donated_invars = (False,) * len(args_flat)
662 c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars)
672 if any(donated_invars):
673 donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars,
675 if any(donated_invars):
1539 donated_invars = (False,) * len(args)
[all …]
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dhost_callback.py1227 donated_invars=eqn.params["donated_invars"] + (False, False)
1240 donated_invars=eqn.params["donated_invars"] + (False, False),
1332 donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
1375 donated_invars=(False,) * len(transformed_body_jaxpr.in_avals)),
1383 donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
H A Djet.py178 donated_invars = params['donated_invars']
179 if any(donated_invars):
181 return dict(params, donated_invars=(False,) * num_inputs)
H A Ddoubledouble.py83 donated_invars=(False,) * (len(heads) + len(nonzero_tails)))