Home
last modified time | relevance | path

Searched refs:jaxpr (Results 1 – 18 of 18) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/
H A Djaxpr_util.py28 def all_eqns(jaxpr: core.Jaxpr):
29 for eqn in jaxpr.eqns:
30 yield (jaxpr, eqn)
42 d = collect_eqns(jaxpr, key)
64 return histogram(jaxpr, key)
86 for v in jaxpr.constvars:
88 for v in jaxpr.invars:
91 for eqn in jaxpr.eqns:
97 for a in jaxpr.outvars:
102 return [(jaxpr, res), *subs] if subs else (jaxpr, res)
[all …]
H A Dcore.py111 yield v.jaxpr
119 for eqn in jaxpr.eqns:
124 jaxpr: Jaxpr
129 self.jaxpr = jaxpr
197 def _jaxpr_vars(jaxpr): argument
199 jaxpr.invars, jaxpr.constvars,
348 for eqn in jaxpr.eqns:
1417 _check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars])
1448 map(write, jaxpr.constvars, [v.aval for v in jaxpr.constvars])
1473 map(read, jaxpr.outvars)
[all …]
H A Dcustom_derivatives.py55 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals)
56 return jaxpr, consts
679 return jaxpr, consts
817 jaxpr, in_tree, out_tree, consts = res
820 cts_out = core.eval_jaxpr(jaxpr, consts, *cts_flat)
831 def __init__(self, jaxpr, in_tree, out_tree, consts): argument
832 self.jaxpr = jaxpr
842 jaxpr, in_tree, out_tree = aux
843 return cls(jaxpr, in_tree, out_tree, consts)
919 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
[all …]
H A Dapi.py648 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
651 jaxpr = xla.apply_outfeed_rewriter(jaxpr)
652 axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr))
664 c, jaxpr, backend, axis_env_, xla_consts,
1801 lifted_jvp = partial(_lift_linearized, jaxpr, primal_avals, consts,
1815 tangents_out = eval_jaxpr(jaxpr, consts, *tangents)
1982 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals,
1995 in_cotangents = ad.backward_pass(jaxpr, consts, dummies, out_cotangents)
2088 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(
2091 closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
[all …]
H A Dtest_util.py730 def iter_eqns(jaxpr): argument
732 for eqn in jaxpr.eqns:
734 for subjaxpr in core.subjaxprs(jaxpr):
738 jaxpr = api.make_jaxpr(fun)(*args)
739 precisions = [eqn.params['precision'] for eqn in iter_eqns(jaxpr.jaxpr)
/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dpartial_eval.py205 jaxpr = _drop_invars(jaxpr, in_knowns)
206 jaxpr = _dce_open_jaxpr(jaxpr, out_unknowns, drop_outputs=True)
651 invars=jaxpr.constvars + jaxpr.invars,
652 outvars=jaxpr.outvars, eqns=jaxpr.eqns)
660 invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns)
779 jaxpr = convert_constvars_jaxpr(jaxpr)
864 return core.Jaxpr(jaxpr.constvars, jaxpr.invars,
870 jaxpr.outvars, jaxpr.eqns)
959 jaxpr, constvals = _inline_literals(jaxpr, constvals)
994 used = {v for eqn in jaxpr.eqns for v in eqn.invars} | set(jaxpr.outvars)
[all …]
H A Dinvertible_ad.py36 return core.ClosedJaxpr(jaxpr, consts)
74 jaxpr, in_tree = aux.val
183 map(write_primal, jaxpr.invars, primals_in)
184 map(write_primal, jaxpr.outvars, primals_out)
185 map(write_primal, jaxpr.constvars, consts)
186 map(write_cotangent, jaxpr.outvars, cotangents_in)
187 for eqn in jaxpr.eqns[::-1]:
246 num_outputs = len(jaxpr_unknown.jaxpr.outvars)
247 jaxpr_known.jaxpr.outvars = jaxpr_known.jaxpr.outvars[:num_outputs]
264 return map(read_cotangent, jaxpr.invars)
[all …]
H A Dad.py105 jaxpr.invars = jaxpr.invars[len(primals):]
106 jaxpr.outvars = jaxpr.outvars[len(out_primals_pvals):]
118 def unbound_vjp(pvals, jaxpr, consts, *cts): argument
197 map(write_primal, jaxpr.constvars, consts)
200 map(write_primal, jaxpr.invars, primals_in)
205 seen_vars: Set[Any] = set(jaxpr.invars)
206 for eqn in jaxpr.eqns:
657 new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars)
658 new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars)
659 new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars,
[all …]
H A Dxla.py199 return jaxpr
389 def jaxpr_literals(jaxpr): argument
391 for eqn in jaxpr.eqns:
431 for eqn in jaxpr.eqns:
557 def jaxpr_has_pmap(jaxpr): argument
559 for eqn in jaxpr.eqns:
568 def jaxpr_collectives(jaxpr): argument
570 for eqn in jaxpr.eqns:
664 jaxpr = apply_outfeed_rewriter(jaxpr)
666 nreps = jaxpr_replicas(jaxpr)
[all …]
H A Dpxla.py719 jaxpr = xla.apply_outfeed_rewriter(jaxpr)
733 jaxpr.invars = jaxpr.invars[1:] # ignore dummy
734 jaxpr = xla.apply_outfeed_rewriter(jaxpr)
967 for eqn in jaxpr.eqns:
969 if len(jaxpr.eqns) > 1:
1003 for eqn in jaxpr.eqns:
1450 jaxpr = xla.apply_outfeed_rewriter(jaxpr)
1570 for eqn in jaxpr.eqns:
1617 jaxpr = xla.apply_outfeed_rewriter(jaxpr)
1625 jaxpr, _, consts = _soft_pmap_jaxpr(jaxpr, consts, in_axes,
[all …]
H A Dsharded_jit.py98 jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final(
102 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, # type: ignore
106 if not jaxpr.eqns and all(outvar.aval is core.abstract_unit
107 for outvar in jaxpr.outvars):
116 nparts = pxla.reconcile_num_partitions(jaxpr, nparts)
165 c, jaxpr, None, axis_env, xla_consts,
H A Dbatching.py477 def batch_jaxpr(jaxpr, axis_size, in_batched, instantiate, axis_name): argument
478 f = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
482 for aval, b in zip(jaxpr.in_avals, in_batched)]
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dcontrol_flow.py100 outvars=jaxpr.outvars, eqns=jaxpr.eqns)
739 outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env,
981 res_vars = jaxpr.jaxpr.invars[:num_res]
982 non_res_vars = jaxpr.jaxpr.invars[num_res:]
986 jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars,
987 jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns)
1006 jaxpr.jaxpr, jaxpr.consts, primals, cts_out)
1099 core.check_jaxpr(jaxpr.jaxpr)
1273 reverse=reverse, length=length, jaxpr=jaxpr,
1722 cbar_abar = ad.backward_pass(jaxpr.jaxpr, jaxpr.consts, primals, b_bar)
[all …]
H A Dlax.py934 operand, scatter_indices, updates, update_jaxpr=jaxpr,
969 operand, scatter_indices, updates, update_jaxpr=jaxpr,
1004 operand, scatter_indices, updates, update_jaxpr=jaxpr,
1141 jaxpr, consts, out_tree = _variadic_reduction_jaxpr(
1144 jaxpr=jaxpr, consts=consts, dimensions=tuple(dimensions))
1152 return jaxpr, consts
1163 return jaxpr, consts, out_tree()
1247 operand, init_value, jaxpr=jaxpr, consts=consts,
1291 operand, init_value, jaxpr=jaxpr, consts=consts,
4931 jaxpr=jaxpr), new_operand_bdims
[all …]
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dhost_callback.py1098 return jaxpr
1100 mk_new_var = core.gensym([jaxpr])
1108 invars = jaxpr.invars
1117 for eqn in jaxpr.eqns:
1155 if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr):
1181 for jaxpr in branches),
1194 new_jaxpr_invars = new_jaxpr.jaxpr.invars
1198 new_jaxpr.jaxpr.invars = new_jaxpr_invars
1200 new_jaxpr_outvars = new_jaxpr.jaxpr.outvars
1204 new_jaxpr.jaxpr.outvars = new_jaxpr_outvars
[all …]
H A Dmaps.py286 jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals)
288 jaxpr = subst_axis_names(jaxpr, plan.axis_subst)
290 f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(jaxpr, consts)))
294 used_resources = _jaxpr_resources(jaxpr, resource_env) | set(it.chain(*axis_resources.values()))
384 jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic(
396 call_jaxpr=convert_constvars_jaxpr(jaxpr))
625 def _jaxpr_resources(jaxpr, resource_env) -> Set[ResourceAxisName]: argument
627 for eqn in jaxpr.eqns:
639 def subst_axis_names(jaxpr, axis_subst: Dict[AxisName, Tuple[AxisName]]): argument
640 eqns = [subst_eqn_axis_names(eqn, axis_subst) for eqn in jaxpr.eqns]
[all …]
H A Dloops.py463 jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers)
465 closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ())
517 jaxpr=body_closed_jaxpr,
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/
H A Djax2tf.py344 def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args: TfVal) -> Sequence[TfVal]:
349 fun: lu.WrappedFun = lu.wrap_init(core.jaxpr_as_fun(jaxpr))
350 out_with_avals = _interpret_fun(fun, args, jaxpr.in_avals)
1564 def _reduce_window(operand, init_value, *, jaxpr, consts, window_dimensions, argument
1586 closed_jaxpr = core.ClosedJaxpr(jaxpr, consts)
1902 branches_tf = [functools.partial(_interpret_jaxpr, jaxpr, *operands)
1903 for jaxpr in branches]