Searched refs:full_lower (Results 1 – 11 of 11) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | custom_derivatives.py | 271 args = map(core.full_lower, args) 281 return _apply_todos(env_trace_todo, map(core.full_lower, outs)) 292 outs = map(core.full_lower, todos_list.pop()(outs)) 570 args = map(core.full_lower, args) 581 return _apply_todos(env_trace_todo, map(core.full_lower, outs)) 682 args = map(core.full_lower, args) 697 return map(core.full_lower, outs) 705 args = map(core.full_lower, args) 713 return map(core.full_lower, outs)
|
H A D | core.py | 283 return map(full_lower, out) if self.multiple_results else full_lower(out) 814 def full_lower(val): function 816 return val.full_lower() 1212 outs = map(full_lower, todos_list.pop()(outs)) 1267 return map(full_lower, apply_todos(env_trace_todo(), outs)) 1727 return map(full_lower, out_tracer) 1729 return full_lower(out_tracer) 1760 return apply_todos(env_trace_todo(), map(full_lower, outs))
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | masking.py | 94 in_tracers = [MaskTracer(trace, x, s).full_lower() 451 def full_lower(self): member in MaskTracer 453 return core.full_lower(self.val)
|
H A D | ad.py | 84 aux_primals = [core.full_lower(x.primal) 346 primals_in = map(core.full_lower, primals_in) 362 res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in)) 402 def full_lower(self): member in JVPTracer 404 return core.full_lower(self.primal)
|
H A D | batching.py | 121 def full_lower(self): member in BatchTracer 123 return core.full_lower(self.val)
|
H A D | partial_eval.py | 448 def full_lower(self): member in JaxprTracer 451 return core.full_lower(known) 521 out_tracers = map(trace.full_raise, map(core.full_lower, ans)) 907 def full_lower(self): member in DynamicJaxprTracer
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | callback.py | 132 def full_lower(self): member in CallbackTracer
|
H A D | doubledouble.py | 50 def full_lower(self): member in DoublingTracer 52 return core.full_lower(self.head)
|
H A D | jet.py | 101 def full_lower(self): member in JetTracer 103 return core.full_lower(self.primal)
|
H A D | loops.py | 460 out_tracers = safe_map(trace.full_raise, safe_map(core.full_lower, out_tracers))
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 656 def full_lower(self): member in TensorFlowTracer
|