/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | batching.py | 157 def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): argument 158 assert call_primitive.multiple_results 162 return call_primitive.bind(f, *vals, **params) 165 vals_out = call_primitive.bind(f, *vals, **params) 168 def post_process_call(self, call_primitive, out_tracers, params): argument 216 def post_process_map(self, call_primitive, out_tracers, params): argument 225 if call_primitive.map_primitive:
|
H A D | masking.py | 485 def process_call(self, call_primitive, f, tracers, params): argument 486 assert call_primitive.multiple_results 490 return call_primitive.bind(f, *vals, **params) 501 vals_out = call_primitive.bind(f, *(logical_env_vals + vals), **params) 504 def post_process_call(self, call_primitive, out_tracers, params): argument
|
H A D | ad.py | 221 if eqn.primitive.call_primitive or eqn.primitive.map_primitive: 291 def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): argument 292 assert call_primitive.multiple_results 298 if isinstance(call_primitive, core.MapPrimitive): 316 update_params = call_param_updaters.get(call_primitive) 318 result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params) 322 def post_process_call(self, call_primitive, out_tracers, params): argument 332 if call_primitive.map_primitive:
|
H A D | partial_eval.py | 561 if primitive.call_primitive or primitive.map_primitive: 1068 def process_call(self, call_primitive, f, tracers, params): argument 1079 update_params = call_param_updaters.get(call_primitive) 1082 eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive, 1087 def post_process_call(self, call_primitive, out_tracers, params): argument
|
H A D | invertible_ad.py | 194 assert not eqn.primitive.call_primitive
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | callback.py | 152 def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): argument 157 vals_out = call_primitive.bind(f, *vals_in, **params)
|
H A D | doubledouble.py | 75 def process_call(self, call_primitive, f, tracers, params): argument 76 assert call_primitive.multiple_results 84 result = call_primitive.bind(f_double, *heads, *nonzero_tails, **new_params)
|
H A D | jet.py | 134 def process_call(self, call_primitive, f, tracers, params): argument 138 update_params = call_param_updaters.get(call_primitive) 141 result = call_primitive.bind(f_jet, *primals_and_series, **new_params) 145 def post_process_call(self, call_primitive, out_tracers, params): argument
|
H A D | maps.py | 368 def _process_xmap_default(self, call_primitive, f, tracers, params): argument
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | core.py | 267 call_primitive = False # set for call primitives processed in final style variable in Primitive 318 if not (primitive.call_primitive or primitive.map_primitive): 428 def process_call(self, call_primitive, f, tracers, params): argument 433 def process_map(self, call_primitive, f, tracers, params): argument 1272 call_primitive = True variable in CallPrimitive 1459 if prim.call_primitive:
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 740 def process_call(self, call_primitive: core.Primitive, f: lu.WrappedFun, 742 assert call_primitive.multiple_results 745 if call_primitive == core.named_call_p: 749 elif call_primitive == sharded_jit.sharded_call_p: 755 def post_process_call(self, call_primitive: core.Primitive,
|