Home
last modified time | relevance | path

Searched refs:ClosedJaxpr (Results 1 – 13 of 13) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dpartial_eval.py32 unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn,
668 def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool], argument
670 ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]:
744 return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out
784 closed_jaxpr = core.ClosedJaxpr(jaxpr, ())
842 def _dce_jaxpr(closed_jaxpr: ClosedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> ClosedJax… argument
844 return core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts)
884 def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJaxpr: argument
891 new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts)
1304 ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]:
[all …]
H A Dinvertible_ad.py36 return core.ClosedJaxpr(jaxpr, consts)
226 ivjp_jaxpr = core.ClosedJaxpr(ivjp_jaxpr, [])
326 return core.ClosedJaxpr(jaxpr, consts)
H A Dad.py561 typed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, [])
642 return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros()
656 def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out): argument
661 return core.ClosedJaxpr(new_jaxpr, jaxpr.consts)
792 return core.ClosedJaxpr(jaxpr_out, consts), out_nonzeros()
H A Dbatching.py424 return core.ClosedJaxpr(jaxpr_out, consts), out_batched()
486 return core.ClosedJaxpr(jaxpr_out, consts_out), out_batched()
H A Dxla.py211 elif type(param) is core.ClosedJaxpr:
/dports/math/py-jax/jax-0.2.9/jax/
H A Dcustom_derivatives.py298 def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr: core.ClosedJaxpr, **params): argument
302 def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **params): argument
313 primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, argument
332 args, in_dims, axis_name, *, fun_jaxpr: core.ClosedJaxpr, argument
347 jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk()) # consts can be tracers
604 primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, argument
627 args, in_dims, axis_name, *, fun_jaxpr: core.ClosedJaxpr, argument
643 fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk()) # consts can be tracers
723 closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(fun_jaxpr), ())
732 closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(fun_jaxpr), ())
H A Dcore.py110 elif isinstance(v, ClosedJaxpr):
123 class ClosedJaxpr: class
145 return ClosedJaxpr(f(self.jaxpr), self.consts)
151 def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args): argument
331 if type(param) in (Jaxpr, ClosedJaxpr))
1549 not isinstance(v, (Jaxpr, ClosedJaxpr)))}
1602 jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs]
1606 if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v):
H A Dapi.py2005 ) -> Callable[..., core.ClosedJaxpr]:
2091 closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dhost_callback.py1085 cjaxpr: core.ClosedJaxpr, has_input_token: bool, argument
1086 has_output_token: bool) -> core.ClosedJaxpr:
1089 return core.ClosedJaxpr(new_jaxpr, cjaxpr.consts)
1339 new_cond_jaxpr = core.ClosedJaxpr(
1386 new_body_jaxpr = core.ClosedJaxpr(
H A Dmaps.py290 f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(jaxpr, consts)))
310 return core.jaxpr_as_fun(core.ClosedJaxpr(final_jaxpr, final_consts))
470 f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ())))
H A Dloops.py465 closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
H A Djax2tf.py344 def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args: TfVal) -> Sequence[TfVal]: argument
1586 closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
1871 closed_jaxpr = core.ClosedJaxpr(update_jaxpr, update_consts)
1898 branches: Sequence[core.ClosedJaxpr], argument
1909 def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr, argument
1910 body_nconsts: int, body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]: argument
1928 cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr, argument
1929 body_nconsts: int, body_jaxpr: core.ClosedJaxpr argument
2167 fun_jaxpr: core.ClosedJaxpr, argument
2177 fun_jaxpr: core.ClosedJaxpr, argument
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dcontrol_flow.py73 closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
104 closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
459 cond_jaxpr: pe.ClosedJaxpr, body_nconsts: int, argument
460 body_jaxpr: pe.ClosedJaxpr) -> Sequence[pe.Tracer]: argument
988 jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts)
1046 all(type(x) is core.ClosedJaxpr for x in branches))
1579 jaxpr_1_opt = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_1_opt), ())
1737 return core.ClosedJaxpr(jaxpr, consts)
1826 tc(jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is core.ClosedJaxpr)
2626 closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
[all …]