/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | jaxpr_util.py | 28 def all_eqns(jaxpr: core.Jaxpr): 29 for eqn in jaxpr.eqns: 30 yield (jaxpr, eqn) 42 d = collect_eqns(jaxpr, key) 64 return histogram(jaxpr, key) 86 for v in jaxpr.constvars: 88 for v in jaxpr.invars: 91 for eqn in jaxpr.eqns: 97 for a in jaxpr.outvars: 102 return [(jaxpr, res), *subs] if subs else (jaxpr, res) [all …]
|
H A D | core.py | 111 yield v.jaxpr 119 for eqn in jaxpr.eqns: 124 jaxpr: Jaxpr 129 self.jaxpr = jaxpr 197 def _jaxpr_vars(jaxpr): argument 199 jaxpr.invars, jaxpr.constvars, 348 for eqn in jaxpr.eqns: 1417 _check_jaxpr(jaxpr, [v.aval for v in jaxpr.invars]) 1448 map(write, jaxpr.constvars, [v.aval for v in jaxpr.constvars]) 1473 map(read, jaxpr.outvars) [all …]
|
H A D | custom_derivatives.py | 55 jaxpr, _, consts = pe.trace_to_jaxpr_dynamic(fun, in_avals) 56 return jaxpr, consts 679 return jaxpr, consts 817 jaxpr, in_tree, out_tree, consts = res 820 cts_out = core.eval_jaxpr(jaxpr, consts, *cts_flat) 831 def __init__(self, jaxpr, in_tree, out_tree, consts): argument 832 self.jaxpr = jaxpr 842 jaxpr, in_tree, out_tree = aux 843 return cls(jaxpr, in_tree, out_tree, consts) 919 jaxpr, out_pvals, consts = pe.trace_to_jaxpr( [all …]
|
H A D | api.py | 648 jaxpr, out_pvals, consts = pe.trace_to_jaxpr( 651 jaxpr = xla.apply_outfeed_rewriter(jaxpr) 652 axis_env_ = make_axis_env(xla.jaxpr_replicas(jaxpr)) 664 c, jaxpr, backend, axis_env_, xla_consts, 1801 lifted_jvp = partial(_lift_linearized, jaxpr, primal_avals, consts, 1815 tangents_out = eval_jaxpr(jaxpr, consts, *tangents) 1982 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(flat_fun, in_pvals, 1995 in_cotangents = ad.backward_pass(jaxpr, consts, dummies, out_cotangents) 2088 jaxpr, out_pvals, consts = pe.trace_to_jaxpr( 2091 closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) [all …]
|
H A D | test_util.py | 730 def iter_eqns(jaxpr): argument 732 for eqn in jaxpr.eqns: 734 for subjaxpr in core.subjaxprs(jaxpr): 738 jaxpr = api.make_jaxpr(fun)(*args) 739 precisions = [eqn.params['precision'] for eqn in iter_eqns(jaxpr.jaxpr)
|
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | partial_eval.py | 205 jaxpr = _drop_invars(jaxpr, in_knowns) 206 jaxpr = _dce_open_jaxpr(jaxpr, out_unknowns, drop_outputs=True) 651 invars=jaxpr.constvars + jaxpr.invars, 652 outvars=jaxpr.outvars, eqns=jaxpr.eqns) 660 invars=invars, outvars=jaxpr.outvars, eqns=jaxpr.eqns) 779 jaxpr = convert_constvars_jaxpr(jaxpr) 864 return core.Jaxpr(jaxpr.constvars, jaxpr.invars, 870 jaxpr.outvars, jaxpr.eqns) 959 jaxpr, constvals = _inline_literals(jaxpr, constvals) 994 used = {v for eqn in jaxpr.eqns for v in eqn.invars} | set(jaxpr.outvars) [all …]
|
H A D | invertible_ad.py | 36 return core.ClosedJaxpr(jaxpr, consts) 74 jaxpr, in_tree = aux.val 183 map(write_primal, jaxpr.invars, primals_in) 184 map(write_primal, jaxpr.outvars, primals_out) 185 map(write_primal, jaxpr.constvars, consts) 186 map(write_cotangent, jaxpr.outvars, cotangents_in) 187 for eqn in jaxpr.eqns[::-1]: 246 num_outputs = len(jaxpr_unknown.jaxpr.outvars) 247 jaxpr_known.jaxpr.outvars = jaxpr_known.jaxpr.outvars[:num_outputs] 264 return map(read_cotangent, jaxpr.invars) [all …]
|
H A D | ad.py | 105 jaxpr.invars = jaxpr.invars[len(primals):] 106 jaxpr.outvars = jaxpr.outvars[len(out_primals_pvals):] 118 def unbound_vjp(pvals, jaxpr, consts, *cts): argument 197 map(write_primal, jaxpr.constvars, consts) 200 map(write_primal, jaxpr.invars, primals_in) 205 seen_vars: Set[Any] = set(jaxpr.invars) 206 for eqn in jaxpr.eqns: 657 new_invars = _perm(primals_in, tangents_in, jaxpr.jaxpr.invars) 658 new_outvars = _perm(primals_out, tangents_out, jaxpr.jaxpr.outvars) 659 new_jaxpr = core.Jaxpr(jaxpr.jaxpr.constvars, [all …]
|
H A D | xla.py | 199 return jaxpr 389 def jaxpr_literals(jaxpr): argument 391 for eqn in jaxpr.eqns: 431 for eqn in jaxpr.eqns: 557 def jaxpr_has_pmap(jaxpr): argument 559 for eqn in jaxpr.eqns: 568 def jaxpr_collectives(jaxpr): argument 570 for eqn in jaxpr.eqns: 664 jaxpr = apply_outfeed_rewriter(jaxpr) 666 nreps = jaxpr_replicas(jaxpr) [all …]
|
H A D | pxla.py | 719 jaxpr = xla.apply_outfeed_rewriter(jaxpr) 733 jaxpr.invars = jaxpr.invars[1:] # ignore dummy 734 jaxpr = xla.apply_outfeed_rewriter(jaxpr) 967 for eqn in jaxpr.eqns: 969 if len(jaxpr.eqns) > 1: 1003 for eqn in jaxpr.eqns: 1450 jaxpr = xla.apply_outfeed_rewriter(jaxpr) 1570 for eqn in jaxpr.eqns: 1617 jaxpr = xla.apply_outfeed_rewriter(jaxpr) 1625 jaxpr, _, consts = _soft_pmap_jaxpr(jaxpr, consts, in_axes, [all …]
|
H A D | sharded_jit.py | 98 jaxpr, global_out_avals, consts = pe.trace_to_jaxpr_final( 102 jaxpr, out_pvals, consts = pe.trace_to_jaxpr(fun, in_pvals, # type: ignore 106 if not jaxpr.eqns and all(outvar.aval is core.abstract_unit 107 for outvar in jaxpr.outvars): 116 nparts = pxla.reconcile_num_partitions(jaxpr, nparts) 165 c, jaxpr, None, axis_env, xla_consts,
|
H A D | batching.py | 477 def batch_jaxpr(jaxpr, axis_size, in_batched, instantiate, axis_name): argument 478 f = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) 482 for aval, b in zip(jaxpr.in_avals, in_batched)]
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | control_flow.py | 100 outvars=jaxpr.outvars, eqns=jaxpr.eqns) 739 outs = xla.jaxpr_subcomp(c, jaxpr.jaxpr, backend, axis_env, 981 res_vars = jaxpr.jaxpr.invars[:num_res] 982 non_res_vars = jaxpr.jaxpr.invars[num_res:] 986 jaxpr_aug = core.Jaxpr(jaxpr.jaxpr.constvars, aug_invars, 987 jaxpr.jaxpr.outvars, jaxpr.jaxpr.eqns) 1006 jaxpr.jaxpr, jaxpr.consts, primals, cts_out) 1099 core.check_jaxpr(jaxpr.jaxpr) 1273 reverse=reverse, length=length, jaxpr=jaxpr, 1722 cbar_abar = ad.backward_pass(jaxpr.jaxpr, jaxpr.consts, primals, b_bar) [all …]
|
H A D | lax.py | 934 operand, scatter_indices, updates, update_jaxpr=jaxpr, 969 operand, scatter_indices, updates, update_jaxpr=jaxpr, 1004 operand, scatter_indices, updates, update_jaxpr=jaxpr, 1141 jaxpr, consts, out_tree = _variadic_reduction_jaxpr( 1144 jaxpr=jaxpr, consts=consts, dimensions=tuple(dimensions)) 1152 return jaxpr, consts 1163 return jaxpr, consts, out_tree() 1247 operand, init_value, jaxpr=jaxpr, consts=consts, 1291 operand, init_value, jaxpr=jaxpr, consts=consts, 4931 jaxpr=jaxpr), new_operand_bdims [all …]
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | host_callback.py | 1098 return jaxpr 1100 mk_new_var = core.gensym([jaxpr]) 1108 invars = jaxpr.invars 1117 for eqn in jaxpr.eqns: 1155 if xla.jaxpr_uses_outfeed(cond_jaxpr.jaxpr): 1181 for jaxpr in branches), 1194 new_jaxpr_invars = new_jaxpr.jaxpr.invars 1198 new_jaxpr.jaxpr.invars = new_jaxpr_invars 1200 new_jaxpr_outvars = new_jaxpr.jaxpr.outvars 1204 new_jaxpr.jaxpr.outvars = new_jaxpr_outvars [all …]
|
H A D | maps.py | 286 jaxpr, _, consts = pe.trace_to_jaxpr_final(fun, mapped_in_avals) 288 jaxpr = subst_axis_names(jaxpr, plan.axis_subst) 290 f = lu.wrap_init(core.jaxpr_as_fun(core.ClosedJaxpr(jaxpr, consts))) 294 used_resources = _jaxpr_resources(jaxpr, resource_env) | set(it.chain(*axis_resources.values())) 384 jaxpr, mapped_out_avals, consts = trace_to_subjaxpr_dynamic( 396 call_jaxpr=convert_constvars_jaxpr(jaxpr)) 625 def _jaxpr_resources(jaxpr, resource_env) -> Set[ResourceAxisName]: argument 627 for eqn in jaxpr.eqns: 639 def subst_axis_names(jaxpr, axis_subst: Dict[AxisName, Tuple[AxisName]]): argument 640 eqns = [subst_eqn_axis_names(eqn, axis_subst) for eqn in jaxpr.eqns] [all …]
|
H A D | loops.py | 463 jaxpr, consts, env = pe.tracers_to_jaxpr(in_tracers, out_tracers) 465 closed_jaxpr = core.ClosedJaxpr(pe.convert_constvars_jaxpr(jaxpr), ()) 517 jaxpr=body_closed_jaxpr,
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/jax2tf/ |
H A D | jax2tf.py | 344 def _interpret_jaxpr(jaxpr: core.ClosedJaxpr, *args: TfVal) -> Sequence[TfVal]: 349 fun: lu.WrappedFun = lu.wrap_init(core.jaxpr_as_fun(jaxpr)) 350 out_with_avals = _interpret_fun(fun, args, jaxpr.in_avals) 1564 def _reduce_window(operand, init_value, *, jaxpr, consts, window_dimensions, argument 1586 closed_jaxpr = core.ClosedJaxpr(jaxpr, consts) 1902 branches_tf = [functools.partial(_interpret_jaxpr, jaxpr, *operands) 1903 for jaxpr in branches]
|