Home
last modified time | relevance | path

Searched refs:in_tree_def (Results 1 – 3 of 3) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dad.py534 def traceable(num_primals, in_tree_def, *primals_and_tangents): argument
537 new_tangents = tree_unflatten(in_tree_def, new_tangents)
544 all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
546 fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
581 flat_args, in_tree_def = tree_flatten((primals_in, cotangents_in))
582 flat_do_transpose, out_tree = flatten_fun_nokwargs(lu.wrap_init(do_transpose), in_tree_def)
594 all_args, in_tree_def = tree_flatten(((), args, ct)) # empty consts
597 fun, out_tree = flatten_fun_nokwargs(fun, in_tree_def)
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Ddoubledouble.py78 nonzero_tails, in_tree_def = tree_flatten(tails)
80 len(heads), in_tree_def)
101 def screen_nones(num_heads, in_tree_def, *heads_and_tails): argument
104 new_tails = tree_unflatten(in_tree_def, new_tails)
H A Djet.py81 def traceable(in_tree_def, *primals_and_series): argument
82 primals_in, series_in = tree_unflatten(in_tree_def, primals_and_series)
136 primals_and_series, in_tree_def = tree_flatten((primals_in, series_in))
137 f_jet, out_tree_def = traceable(jet_subtrace(f, self.main), in_tree_def)