/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | partial_eval.py | 157 return out_tracers 280 for t in out_tracers: 282 return out_tracers 337 return out_tracers 377 return out_tracers 522 out_tracers = map(partial(instantiate_const_at, trace), instantiate, out_tracers) 1066 return out_tracers if primitive.multiple_results else out_tracers.pop() 1085 return out_tracers 1118 return out_tracers 1139 return out_tracers [all …]
|
H A D | batching.py | 58 out_tracers = map(trace.full_raise, outs) 59 out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers) 168 def post_process_call(self, call_primitive, out_tracers, params): argument 169 vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers) 216 def post_process_map(self, call_primitive, out_tracers, params): argument 217 vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers) 243 def post_process_custom_jvp_call(self, out_tracers, params): argument 244 vals, dims = unzip2((t.val, t.batch_dim) for t in out_tracers) 434 out_tracers = map(trace.full_raise, outs) 435 out_vals, out_dims = unzip2((t.val, t.batch_dim) for t in out_tracers) [all …]
|
H A D | masking.py | 98 out_tracers = map(trace.full_raise, outs) 99 out_vals, out_shapes = unzip2((t.val, t.polymorphic_shape) for t in out_tracers) 504 def post_process_call(self, call_primitive, out_tracers, params): argument 505 vals, shapes = unzip2((t.val, t.polymorphic_shape) for t in out_tracers)
|
H A D | ad.py | 71 out_tracers = map(trace.full_raise, ans) 73 for out_tracer in out_tracers]) 322 def post_process_call(self, call_primitive, out_tracers, params): argument 323 primals, tangents = unzip2((t.primal, t.tangent) for t in out_tracers) 356 def post_process_custom_jvp_call(self, out_tracers, params): argument 372 def post_process_custom_vjp_call(self, out_tracers, params): argument
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | loops.py | 427 out_tracers=body_out_tracers, 457 def trace_to_jaxpr_finalize(in_tracers, out_tracers, trace, instantiate=True): argument 459 instantiate = [instantiate] * len(out_tracers) 460 out_tracers = safe_map(trace.full_raise, safe_map(core.full_lower, out_tracers)) 461 out_tracers = safe_map(partial(pe.instantiate_const_at, trace), 462 instantiate, out_tracers) 463 jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
|
H A D | callback.py | 100 out_tracers = map(trace.full_raise, outs) 101 out_vals = [t.val for t in out_tracers]
|
H A D | doubledouble.py | 95 out_tracers = map(trace.full_raise, ans) 97 for out_tracer in out_tracers]) 117 out_tracers = map(trace.full_raise, outputs) 118 result = [(x.head, x.tail) for x in out_tracers]
|
H A D | maps.py | 363 def post_process(self, trace, out_tracers, params): argument 390 out_tracers = [DynamicJaxprTracer(self, a, source_info) for a in out_avals] 393 outvars = map(self.makevar, out_tracers) 404 return out_tracers
|
H A D | jet.py | 76 out_tracers = map(trace.full_raise, ans) 77 out_primals, out_terms = unzip2((t.primal, t.terms) for t in out_tracers) 145 def post_process_call(self, call_primitive, out_tracers, params): argument 146 primals, series = unzip2((t.primal, t.terms) for t in out_tracers)
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | custom_derivatives.py | 286 def post_process(self, trace, out_tracers, params): argument 287 return trace.post_process_custom_jvp_call(out_tracers, params) 587 def post_process(self, trace, out_tracers, params): argument 588 return trace.post_process_custom_vjp_call(out_tracers, params) 700 def jvp_post_process(self, trace, out_tracers, params): argument 716 def vjp_post_process(self, trace, out_tracers, params): argument
|
H A D | core.py | 1280 def post_process(self, trace, out_tracers, params): argument 1281 return trace.post_process_call(self, out_tracers, params) 1307 def post_process(self, trace, out_tracers, params): argument 1308 return trace.post_process_map(self, out_tracers, params)
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 338 out_tracers: Iterable[TensorFlowTracer] = map(trace.full_raise, outs) # type: ignore 340 tuple((t.val, t.aval) for t in out_tracers)) 756 out_tracers: Sequence[TensorFlowTracer], params): 760 vals = tuple(t.val for t in out_tracers) 765 for v, out_tracer in util.safe_zip(vals, out_tracers)] 771 def post_process_map(self, map_primitive, out_tracers, params): argument 781 def post_process_custom_jvp_call(self, out_tracers, params): argument 791 def post_process_custom_vjp_call(self, out_tracers, params): argument
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | control_flow.py | 527 out_tracers: Sequence[pe.Tracer] = [ 532 return out_tracers 905 out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None) 911 [index_tracer] + res_tracers + ops_tracers, out_tracers, cond_p, params, 913 for t in out_tracers: t.recipe = eqn 914 return out_tracers 1638 out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None) 1644 out_tracers, scan_p, 1650 for t in out_tracers: t.recipe = eqn 1651 return out_tracers
|