Home
last modified time | relevance | path

Searched refs:out_tracers (Results 1 – 13 of 13) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dpartial_eval.py157 return out_tracers
280 for t in out_tracers:
282 return out_tracers
337 return out_tracers
377 return out_tracers
522 out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers)
1066 return out_tracers if primitive.multiple_results else out_tracers.pop()
1085 return out_tracers
1118 return out_tracers
1139 return out_tracers
[all …]
H A Dbatching.py58 out_tracers = map(trace.full_raise, outs)
59 out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
168 def post_process_call(self, call_primitive, out_tracers, params): argument
169 vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
216 def post_process_map(self, call_primitive, out_tracers, params): argument
217 vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
243 def post_process_custom_jvp_call(self, out_tracers, params): argument
244 vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
434 out_tracers = map(trace.full_raise, outs)
435 out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers)
[all …]
H A Dmasking.py98 out_tracers = map(trace.full_raise, outs)
99 out_vals, out_shapes = unzip2((t.val, t.polymorphic_shape) for t in out_tracers)
504 def post_process_call(self, call_primitive, out_tracers, params): argument
505 vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in out_tracers)
H A Dad.py71 out_tracers = map(trace.full_raise, ans)
73 for out_tracer in out_tracers])
322 def post_process_call(self, call_primitive, out_tracers, params): argument
323 primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers)
356 def post_process_custom_jvp_call(self, out_tracers, params): argument
372 def post_process_custom_vjp_call(self, out_tracers, params): argument
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dloops.py427 out_tracers=body_out_tracers,
457 def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True): argument
459 instantiate = [instantiate] * len(out_tracers)
460 out_tracers = safe_map(trace.full_raise, safe_map(core.full_lower, out_tracers))
461 out_tracers = safe_map(partial(pe.instantiate_const_at, trace),
462 instantiate, out_tracers)
463 jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
H A Dcallback.py100 out_tracers = map(trace.full_raise, outs)
101 out_vals = [t.val for t in out_tracers]
H A Ddoubledouble.py95 out_tracers = map(trace.full_raise, ans)
97 for out_tracer in out_tracers])
117 out_tracers = map(trace.full_raise, outputs)
118 result = [(x.head, x.tail) for x in out_tracers]
H A Dmaps.py363 def post_process(self, trace, out_tracers, params): argument
390 out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals]
393 outvars = map(self.makevar, out_tracers)
404 return out_tracers
H A Djet.py76 out_tracers = map(trace.full_raise, ans)
77 out_primals, out_terms = unzip2((t.primal, t.terms) for t in out_tracers)
145 def post_process_call(self, call_primitive, out_tracers, params): argument
146 primals, series = unzip2((t.primal, t.terms) for t in out_tracers)
/dports/math/py-jax/jax-0.2.9/jax/
H A Dcustom_derivatives.py286 def post_process(self, trace, out_tracers, params): argument
287 return trace.post_process_custom_jvp_call(out_tracers, params)
587 def post_process(self, trace, out_tracers, params): argument
588 return trace.post_process_custom_vjp_call(out_tracers, params)
700 def jvp_post_process(self, trace, out_tracers, params): argument
716 def vjp_post_process(self, trace, out_tracers, params): argument
H A Dcore.py1280 def post_process(self, trace, out_tracers, params): argument
1281 return trace.post_process_call(self, out_tracers, params)
1307 def post_process(self, trace, out_tracers, params): argument
1308 return trace.post_process_map(self, out_tracers, params)
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
H A Djax2tf.py338 out_tracers: Iterable[TensorFlowTracer] = map(trace.full_raise, outs) # type: ignore
340 tuple((t.val, t.aval) for t in out_tracers))
756 out_tracers: Sequence[TensorFlowTracer], params):
760 vals = tuple(t.val for t in out_tracers)
765 for v, out_tracer in util.safe_zip(vals, out_tracers)]
771 def post_process_map(self, map_primitive, out_tracers, params): argument
781 def post_process_custom_jvp_call(self, out_tracers, params): argument
791 def post_process_custom_vjp_call(self, out_tracers, params): argument
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dcontrol_flow.py527 out_tracers: Sequence[pe.Tracer] = [
532 return out_tracers
905 out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
911 [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params,
913 for t in out_tracers: t.recipe = eqn
914 return out_tracers
1638 out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
1644 out_tracers, scan_p,
1650 for t in out_tracers: t.recipe = eqn
1651 return out_tracers