Searched refs:in_tracers (Results 1 – 9 of 9) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | partial_eval.py | 318 in_tracers = (*const_tracers, *env_tracers, *tracers) 518 in_tracers = map(trace.new_arg, pvals) 519 ans = yield in_tracers, {} 525 del trace, in_tracers, out_tracers 576 invars = [getvar(t) for t in in_tracers] 582 in_tracers: Sequence[JaxprTracer], 605 invars = map(getvar, in_tracers) 626 assert in_tracers, "Lambda binding with no args" 954 def to_jaxpr(self, in_tracers, out_tracers): argument 1195 in_tracers = map(trace.new_arg, in_avals) [all …]
|
H A D | batching.py | 55 in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val 57 outs = yield in_tracers, {} 431 in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val 433 outs = yield in_tracers, {} 450 in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val 452 outs = yield in_tracers, {}
|
H A D | masking.py | 94 in_tracers = [MaskTracer(trace, x, s).full_lower() 97 outs = yield in_tracers, {}
|
H A D | ad.py | 68 in_tracers = [JVPTracer(trace, x, t) if type(t) is not Zero else x 70 ans = yield in_tracers, {}
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | loops.py | 400 …in_tracers = tuple(itertools.chain(*[self.carried_state_vars[ms] for ms in self.carried_state_name… 402 in_tracers += (self._index_var,) 426 in_tracers=in_tracers, 457 def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True): argument 463 jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
|
H A D | callback.py | 98 in_tracers = [CallbackTracer(trace, val) for val in in_vals] 99 outs = yield in_tracers, params
|
H A D | doubledouble.py | 92 in_tracers = [DoublingTracer(trace, h, t) if t is not None else h 94 ans = yield in_tracers, {} 114 in_tracers = [DoublingTracer(trace, head, tail) for head, tail in args] 115 outputs = yield in_tracers, {}
|
H A D | jet.py | 74 in_tracers = map(partial(JetTracer, trace), primals, series) 75 ans = yield in_tracers, {}
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 334 in_tracers = tuple(TensorFlowTracer(trace, val, aval) 337 outs = yield in_tracers, {} # type: Sequence[Union[TfVal, core.Unit]]
|