Searched refs:in_tree_def (Results 1 – 3 of 3) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | ad.py | 534 def traceable(num_primals, in_tree_def, *primals_and_tangents): argument 537 new_tangents = tree_unflatten(in_tree_def, new_tangents) 544 all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts 546 fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def) 581 flat_args, in_tree_def = tree_flatten((primals_in, cotangents_in)) 582 flat_do_transpose, out_tree = flatten_fun_nokwargs(lu.wrap_init(do_transpose), in_tree_def) 594 all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts 597 fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | doubledouble.py | 78 nonzero_tails, in_tree_def = tree_flatten(tails) 80 len(heads), in_tree_def) 101 def screen_nones(num_heads, in_tree_def, *heads_and_tails): argument 104 new_tails = tree_unflatten(in_tree_def, new_tails)
|
H A D | jet.py | 81 def traceable(in_tree_def, *primals_and_series): argument 82 primals_in, series_in = tree_unflatten(in_tree_def, primals_and_series) 136 primals_and_series, in_tree_def = tree_flatten((primals_in, series_in)) 137 f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def)
|