Home
last modified time | relevance | path

Searched refs:full_lower (Results 1 – 11 of 11) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/
H A Dcustom_derivatives.py271 args = map(core.full_lower, args)
281 return _apply_todos(env_trace_todo, map(core.full_lower, outs))
292 outs = map(core.full_lower, todos_list.pop()(outs))
570 args = map(core.full_lower, args)
581 return _apply_todos(env_trace_todo, map(core.full_lower, outs))
682 args = map(core.full_lower, args)
697 return map(core.full_lower, outs)
705 args = map(core.full_lower, args)
713 return map(core.full_lower, outs)
H A Dcore.py283 return map(full_lower, out) if self.multiple_results else full_lower(out)
814 def full_lower(val): function
816 return val.full_lower()
1212 outs = map(full_lower, todos_list.pop()(outs))
1267 return map(full_lower, apply_todos(env_trace_todo(), outs))
1727 return map(full_lower, out_tracer)
1729 return full_lower(out_tracer)
1760 return apply_todos(env_trace_todo(), map(full_lower, outs))
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dmasking.py94 in_tracers = [MaskTracer(trace, x, s).full_lower()
451 def full_lower(self): member in MaskTracer
453 return core.full_lower(self.val)
H A Dad.py84 aux_primals = [core.full_lower(x.primal)
346 primals_in = map(core.full_lower, primals_in)
362 res_and_primals_out = fwd.call_wrapped(*map(core.full_lower, primals_in))
402 def full_lower(self): member in JVPTracer
404 return core.full_lower(self.primal)
H A Dbatching.py121 def full_lower(self): member in BatchTracer
123 return core.full_lower(self.val)
H A Dpartial_eval.py448 def full_lower(self): member in JaxprTracer
451 return core.full_lower(known)
521 out_tracers = map(trace.full_raise, map(core.full_lower, ans))
907 def full_lower(self): member in DynamicJaxprTracer
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dcallback.py132 def full_lower(self): member in CallbackTracer
H A Ddoubledouble.py50 def full_lower(self): member in DoublingTracer
52 return core.full_lower(self.head)
H A Djet.py101 def full_lower(self): member in JetTracer
103 return core.full_lower(self.primal)
H A Dloops.py460 out_tracers = safe_map(trace.full_raise, safe_map(core.full_lower, out_tracers))
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
H A Djax2tf.py656 def full_lower(self): member in TensorFlowTracer