Searched refs:donated_invars (Results 1 – 7 of 7) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | xla.py | 707 …xla_args, donated_invars = _xla_callable_args(c, abstract_args, tuple_args, donated_invars=donated… 714 donated_invars = set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args) 715 if any(donated_invars): 786 donated_invars=None): argument 803 if donated_invars is not None: 804 donated_invars = [d 807 return xla_args, donated_invars 825 return xla_args, donated_invars 869 if not in_unknowns and donated_invars: 874 donated_invars = [d for d, uk in zip(donated_invars, in_unknowns) if uk] [all …]
|
H A D | pxla.py | 641 donated_invars, global_arg_shapes): argument 646 donated_invars, global_arg_shapes, 660 donated_invars: Iterable[bool], 839 donated_invars=donated_invars) 850 donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, tuple_args) 1222 donated_invars, global_arg_shapes): argument 1223 del donated_invars # Unused. 1455 donated_invars = (False,) * len(in_jaxpr_avals) # TODO(apaszke): support donation 1475 xla_args, donated_invars = xla._xla_callable_args( 1480 donated_invars=donated_invars) [all …]
|
H A D | masking.py | 499 params = dict(params, donated_invars=((False,) * len(logical_env_vals) +
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | api.py | 220 donated_invars = (False,) * len(args_flat) 230 donated_invars=donated_invars) 284 donated_invars = (False,) * len(args_flat) 295 donated_invars=donated_invars) 632 donated_invars = (False,) * len(args_flat) 662 c, avals, should_tuple, partitions=in_parts_flat, donated_invars=donated_invars) 672 if any(donated_invars): 673 donated_invars = xla.set_up_aliases(c, xla_args, out_tuple, donated_invars, 675 if any(donated_invars): 1539 donated_invars = (False,) * len(args) [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | host_callback.py | 1227 donated_invars=eqn.params["donated_invars"] + (False, False) 1240 donated_invars=eqn.params["donated_invars"] + (False, False), 1332 donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)), 1375 donated_invars=(False,) * len(transformed_body_jaxpr.in_avals)), 1383 donated_invars=(False,) * len(transformed_cond_jaxpr.in_avals)),
|
H A D | jet.py | 178 donated_invars = params['donated_invars'] 179 if any(donated_invars): 181 return dict(params, donated_invars=(False,) * num_inputs)
|
H A D | doubledouble.py | 83 donated_invars=(False,) * (len(heads) + len(nonzero_tails)))
|