/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | partial_eval.py | 32 unit, unitvar, abstract_unit, ClosedJaxpr, new_jaxpr_eqn, 668 def partial_eval_jaxpr(jaxpr: ClosedJaxpr, unknowns: Sequence[bool], argument 670 ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]: 744 return ClosedJaxpr(jaxpr_1, consts_1), ClosedJaxpr(jaxpr_2, ()), uk_out 784 closed_jaxpr = core.ClosedJaxpr(jaxpr, ()) 842 def _dce_jaxpr(closed_jaxpr: ClosedJaxpr, outputs: Sequence[bool], drop_outputs=False) -> ClosedJax… argument 844 return core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts) 884 def move_binders_to_front(closed_jaxpr: ClosedJaxpr, to_move: Sequence[bool]) -> ClosedJaxpr: argument 891 new_closed_jaxpr = core.ClosedJaxpr(new_jaxpr, closed_jaxpr.consts) 1304 ) -> Tuple[ClosedJaxpr, ClosedJaxpr, Sequence[bool]]: [all …]
|
H A D | invertible_ad.py | 36 return core.ClosedJaxpr(jaxpr, consts) 226 ivjp_jaxpr = core.ClosedJaxpr(ivjp_jaxpr, []) 326 return core.ClosedJaxpr(jaxpr, consts)
|
H A D | ad.py | 561 typed_call_jaxpr = core.ClosedJaxpr(call_jaxpr, []) 642 return core.ClosedJaxpr(jaxpr_out, literals_out), out_nonzeros() 656 def rearrange_binders(jaxpr: core.ClosedJaxpr, primals_in, tangents_in, primals_out, tangents_out): argument 661 return core.ClosedJaxpr(new_jaxpr, jaxpr.consts) 792 return core.ClosedJaxpr(jaxpr_out, consts), out_nonzeros()
|
H A D | batching.py | 424 return core.ClosedJaxpr(jaxpr_out, consts), out_batched() 486 return core.ClosedJaxpr(jaxpr_out, consts_out), out_batched()
|
H A D | xla.py | 211 elif type(param) is core.ClosedJaxpr:
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | custom_derivatives.py | 298 def _custom_jvp_call_jaxpr_impl(*args, fun_jaxpr: core.ClosedJaxpr, **params): argument 302 def _custom_jvp_call_jaxpr_abstract_eval(*args, fun_jaxpr: core.ClosedJaxpr, **params): argument 313 primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, argument 332 args, in_dims, axis_name, *, fun_jaxpr: core.ClosedJaxpr, argument 347 jvp_jaxpr = core.ClosedJaxpr(*jvp_jaxpr_thunk()) # consts can be tracers 604 primals, tangents, *, fun_jaxpr: core.ClosedJaxpr, argument 627 args, in_dims, axis_name, *, fun_jaxpr: core.ClosedJaxpr, argument 643 fwd_jaxpr = core.ClosedJaxpr(*fwd_jaxpr_thunk()) # consts can be tracers 723 closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(fun_jaxpr), ()) 732 closed_fun_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(fun_jaxpr), ())
|
H A D | core.py | 110 elif isinstance(v, ClosedJaxpr): 123 class ClosedJaxpr: class 145 return ClosedJaxpr(f(self.jaxpr), self.consts) 151 def jaxpr_as_fun(closed_jaxpr: ClosedJaxpr, *args): argument 331 if type(param) in (Jaxpr, ClosedJaxpr)) 1549 not isinstance(v, (Jaxpr, ClosedJaxpr)))} 1602 jaxprs = [j.jaxpr if isinstance(j, ClosedJaxpr) else j for j in jaxprs] 1606 if type(v) is tuple and all(isinstance(j, (Jaxpr, ClosedJaxpr)) for j in v):
|
H A D | api.py | 2005 ) -> Callable[..., core.ClosedJaxpr]: 2091 closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | host_callback.py | 1085 cjaxpr: core.ClosedJaxpr, has_input_token: bool, argument 1086 has_output_token: bool) -> core.ClosedJaxpr: 1089 return core.ClosedJaxpr(new_jaxpr, cjaxpr.consts) 1339 new_cond_jaxpr = core.ClosedJaxpr( 1386 new_body_jaxpr = core.ClosedJaxpr(
|
H A D | maps.py | 290 f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(jaxpr, consts))) 310 return core.jaxpr_as_fun(core.ClosedJaxpr(final_jaxpr, final_consts)) 470 f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(resource_call_jaxpr, ())))
|
H A D | loops.py | 465 closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 344 def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args: TfVal) -> Sequence[TfVal]: argument 1586 closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) 1871 closed_jaxpr = core.ClosedJaxpr(update_jaxpr, update_consts) 1898 branches: Sequence[core.ClosedJaxpr], argument 1909 def _while(*args: TfVal, cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr, argument 1910 body_nconsts: int, body_jaxpr: core.ClosedJaxpr) -> Sequence[TfVal]: argument 1928 cond_nconsts: int, cond_jaxpr: core.ClosedJaxpr, argument 1929 body_nconsts: int, body_jaxpr: core.ClosedJaxpr argument 2167 fun_jaxpr: core.ClosedJaxpr, argument 2177 fun_jaxpr: core.ClosedJaxpr, argument
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | control_flow.py | 73 closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) 104 closed_jaxprs = [core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) 459 cond_jaxpr: pe.ClosedJaxpr, body_nconsts: int, argument 460 body_jaxpr: pe.ClosedJaxpr) -> Sequence[pe.Tracer]: argument 988 jaxpr_aug = core.ClosedJaxpr(jaxpr_aug, jaxpr.consts) 1046 all(type(x) is core.ClosedJaxpr for x in branches)) 1579 jaxpr_1_opt = pe.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr_1_opt), ()) 1737 return core.ClosedJaxpr(jaxpr, consts) 1826 tc(jaxpr, 'jaxpr', 'ClosedJaxpr', type(jaxpr) is core.ClosedJaxpr) 2626 closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) [all …]
|