Home
last modified time | relevance | path

Searched refs:PartialVal (Results 1 – 16 of 16) sorted by relevance

/dports/math/py-jax/jax-0.2.9/jax/interpreters/
H A Dpartial_eval.py41 class PartialVal(tuple): class
63 return PartialVal((None, const))
67 return PartialVal((aval, core.unit))
199 PartialVal.unknown(pval[0])
388 py_args = map(PartialVal, zip(pvs, consts))
425 assert isinstance(pval, PartialVal)
515 pvals: Sequence[PartialVal]): argument
710 pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
881 return PartialVal.known(const2)
1309 pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val)
[all …]
H A Dinvertible_ad.py60 in_pvals = tuple(pe.PartialVal.unknown(raise_to_shaped(get_aval(arg))) for arg in flat_args)
220 complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True)
223 complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals),
322 in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
H A Dad.py96 in_pvals = (tuple(pe.PartialVal.known(p) for p in primals)
97 + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace())
718 ct_pvals = [pe.PartialVal.unknown(aval) for aval in out_avals]
790 pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
H A Dxla.py660 pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args]
960 pvals = [pe.PartialVal.unknown(a) for a in avals]
987 pvals = [pe.PartialVal.unknown(a) for a in avals]
H A Dsharded_jit.py101 in_pvals = [pe.PartialVal.unknown(aval) for aval in global_abstract_args]
H A Dbatching.py483 in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
H A Dpxla.py728 pvals = [pe.PartialVal.unknown(aval) for aval in sharded_avals]
730 pval = pe.PartialVal.unknown(core.abstract_unit) # dummy value for axis env
/dports/math/py-flax/flax-0.3.3/flax/core/
H A Daxes_scan.py125 lambda x: pe.PartialVal.unknown(jax.ShapedArray(jnp.shape(x), jnp.result_type(x))),
128 lambda x: pe.PartialVal.unknown(jax.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x))),
/dports/math/py-jax/jax-0.2.9/jax/experimental/
H A Dloops.py381 flat_init_pvals = safe_map(pe.PartialVal.unknown, flat_init_avals)
391 index_var_pval = pe.PartialVal.unknown(index_var_aval)
H A Dhost_callback.py1002 [pe.JaxprTracer(trace, pe.PartialVal.known(primal_known),
/dports/math/py-flax/flax-0.3.3/flax/
H A Djax_utils.py114 in_pvals = [pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype))
/dports/math/py-jax/jax-0.2.9/jax/
H A Dcustom_derivatives.py675 in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
812 ans_pvals = [pe.PartialVal.unknown(a) for a in ans_avals]
916 in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
H A Dapi.py647 pvals = [pe.PartialVal.unknown(aval) for aval in avals]
1981 in_pvals = map(pe.PartialVal.unknown, in_avals)
2087 in_pvals = [pe.PartialVal.unknown(a) for a in in_avals]
2515 in_pvals = [pe.PartialVal.unknown(raise_to_shaped(core.get_aval(x)))
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/
H A Dcontrol_flow.py529 else pe.JaxprTracer(trace, pe.PartialVal.known(known), out_unknown.recipe)
905 out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
1572 invariant_pvals = [pe.PartialVal.known(core.unit if uk else t.pval[1])
1574 other_pvals = [pe.PartialVal.unknown(a) for a in jaxpr_1.in_avals[num_consts:]]
1638 out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None)
1734 pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
2615 in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
H A Dparallel.py1075 out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
H A Dlax.py1149 pval = pe.PartialVal.unknown(aval)
1158 pvals = safe_map(pe.PartialVal.unknown, flat_in_avals)