Home
last modified time | relevance | path

Searched refs:in_tracers (Results 1 – 9 of 9) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dpartial_eval.py318 in_tracers = (*const_tracers, *env_tracers, *tracers)
518 in_tracers = map(trace.new_arg, pvals)
519 ans = yield in_tracers, {}
525 del trace, in_tracers, out_tracers
576 invars = [getvar(t) for t in in_tracers]
582 in_tracers: Sequence[JaxprTracer],
605 invars = map(getvar, in_tracers)
626 assert in_tracers, "Lambda binding with no args"
954 def to_jaxpr(self, in_tracers, out_tracers): argument
1195 in_tracers = map(trace.new_arg, in_avals)
[all …]
H A Dbatching.py55 in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
57 outs = yield in_tracers, {}
431 in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
433 outs = yield in_tracers, {}
450 in_tracers = [BatchTracer(trace, val, dim) if dim is not None else val
452 outs = yield in_tracers, {}
H A Dmasking.py94 in_tracers = [MaskTracer(trace, x, s).full_lower()
97 outs = yield in_tracers, {}
H A Dad.py68 in_tracers = [JVPTracer(trace, x, t) if type(t) is not Zero else x
70 ans = yield in_tracers, {}
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dloops.py400in_tracers = tuple(itertools.chain(*[self.carried_state_vars[ms] for ms in self.carried_state_name…
402 in_tracers += (self._index_var,)
426 in_tracers=in_tracers,
457 def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True): argument
463 jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
H A Dcallback.py98 in_tracers = [CallbackTracer(trace, val) for val in in_vals]
99 outs = yield in_tracers, params
H A Ddoubledouble.py92 in_tracers = [DoublingTracer(trace, h, t) if t is not None else h
94 ans = yield in_tracers, {}
114 in_tracers = [DoublingTracer(trace, head, tail) for head, tail in args]
115 outputs = yield in_tracers, {}
H A Djet.py74 in_tracers = map(partial(JetTracer, trace), primals, series)
75 ans = yield in_tracers, {}
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
H A Djax2tf.py334 in_tracers = tuple(TensorFlowTracer(trace, val, aval)
337 outs = yield in_tracers, {} # type: Sequence[Union[TfVal, core.Unit]]