Home
last modified time | relevance | path

Searched refs:call_primitive (Results 1 – 11 of 11) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dbatching.py157 def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): argument
158 assert call_primitive.multiple_results
162 return call_primitive.bind(f, *vals, **params)
165 vals_out = call_primitive.bind(f, *vals, **params)
168 def post_process_call(self, call_primitive, out_tracers, params): argument
216 def post_process_map(self, call_primitive, out_tracers, params): argument
225 if call_primitive.map_primitive:
H A Dmasking.py485 def process_call(self, call_primitive, f, tracers, params): argument
486 assert call_primitive.multiple_results
490 return call_primitive.bind(f, *vals, **params)
501 vals_out = call_primitive.bind(f, *(logical_env_vals + vals), **params)
504 def post_process_call(self, call_primitive, out_tracers, params): argument
H A Dad.py221 if eqn.primitive.call_primitive or eqn.primitive.map_primitive:
291 def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): argument
292 assert call_primitive.multiple_results
298 if isinstance(call_primitive, core.MapPrimitive):
316 update_params = call_param_updaters.get(call_primitive)
318 result = call_primitive.bind(f_jvp, *primals, *nonzero_tangents, **new_params)
322 def post_process_call(self, call_primitive, out_tracers, params): argument
332 if call_primitive.map_primitive:
H A Dpartial_eval.py561 if primitive.call_primitive or primitive.map_primitive:
1068 def process_call(self, call_primitive, f, tracers, params): argument
1079 update_params = call_param_updaters.get(call_primitive)
1082 eqn = new_jaxpr_eqn([*constvars, *invars], outvars, call_primitive,
1087 def post_process_call(self, call_primitive, out_tracers, params): argument
H A Dinvertible_ad.py194 assert not eqn.primitive.call_primitive
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dcallback.py152 def process_call(self, call_primitive, f: lu.WrappedFun, tracers, params): argument
157 vals_out = call_primitive.bind(f, *vals_in, **params)
H A Ddoubledouble.py75 def process_call(self, call_primitive, f, tracers, params): argument
76 assert call_primitive.multiple_results
84 result = call_primitive.bind(f_double, *heads, *nonzero_tails, **new_params)
H A Djet.py134 def process_call(self, call_primitive, f, tracers, params): argument
138 update_params = call_param_updaters.get(call_primitive)
141 result = call_primitive.bind(f_jet, *primals_and_series, **new_params)
145 def post_process_call(self, call_primitive, out_tracers, params): argument
H A Dmaps.py368 def _process_xmap_default(self, call_primitive, f, tracers, params): argument
/dports/math/py-jax/jax-0.2.9/jax/
H A Dcore.py267 call_primitive = False # set for call primitives processed in final style variable in Primitive
318 if not (primitive.call_primitive or primitive.map_primitive):
428 def process_call(self, call_primitive, f, tracers, params): argument
433 def process_map(self, call_primitive, f, tracers, params): argument
1272 call_primitive = True variable in CallPrimitive
1459 if prim.call_primitive:
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
H A Djax2tf.py740 def process_call(self, call_primitive: core.Primitive, f: lu.WrappedFun,
742 assert call_primitive.multiple_results
745 if call_primitive == core.named_call_p:
749 elif call_primitive == sharded_jit.sharded_call_p:
755 def post_process_call(self, call_primitive: core.Primitive,