Searched refs:PartialVal (Results 1 – 16 of 16) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | partial_eval.py | 41 class PartialVal(tuple): class 63 return PartialVal((None, const)) 67 return PartialVal((aval, core.unit)) 199 PartialVal.unknown(pval[0]) 388 py_args = map(PartialVal, zip(pvs, consts)) 425 assert isinstance(pval, PartialVal) 515 pvals: Sequence[PartialVal]): argument 710 pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val) 881 return PartialVal.known(const2) 1309 pvals = [PartialVal.unknown(aval) if uk else PartialVal.known(val) [all …]
|
H A D | invertible_ad.py | 60 in_pvals = tuple(pe.PartialVal.unknown(raise_to_shaped(get_aval(arg))) for arg in flat_args) 220 complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), instantiate=True) 223 complete_ivjp_flat, map(pe.PartialVal.unknown, in_avals), 322 in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
|
H A D | ad.py | 96 in_pvals = (tuple(pe.PartialVal.known(p) for p in primals) 97 + tuple(pe.PartialVal.unknown(get_aval(p).at_least_vspace()) 718 ct_pvals = [pe.PartialVal.unknown(aval) for aval in out_avals] 790 pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
|
H A D | xla.py | 660 pvals: Sequence[pe.PartialVal] = [pe.PartialVal.unknown(aval) for aval in abstract_args] 960 pvals = [pe.PartialVal.unknown(a) for a in avals] 987 pvals = [pe.PartialVal.unknown(a) for a in avals]
|
H A D | sharded_jit.py | 101 in_pvals = [pe.PartialVal.unknown(aval) for aval in global_abstract_args]
|
H A D | batching.py | 483 in_pvals = [pe.PartialVal.unknown(aval) for aval in avals_in]
|
H A D | pxla.py | 728 pvals = [pe.PartialVal.unknown(aval) for aval in sharded_avals] 730 pval = pe.PartialVal.unknown(core.abstract_unit) # dummy value for axis env
|
/dports/math/py-flax/flax-0.3.3/flax/core/ |
H A D | axes_scan.py | 125 lambda x: pe.PartialVal.unknown(jax.ShapedArray(jnp.shape(x), jnp.result_type(x))), 128 lambda x: pe.PartialVal.unknown(jax.ShapedArray(jnp.shape(x)[1:], jnp.result_type(x))),
|
/dports/math/py-jax/jax-0.2.9/jax/experimental/ |
H A D | loops.py | 381 flat_init_pvals = safe_map(pe.PartialVal.unknown, flat_init_avals) 391 index_var_pval = pe.PartialVal.unknown(index_var_aval)
|
H A D | host_callback.py | 1002 [pe.JaxprTracer(trace, pe.PartialVal.known(primal_known),
|
/dports/math/py-flax/flax-0.3.3/flax/ |
H A D | jax_utils.py | 114 in_pvals = [pe.PartialVal.unknown(jax.ShapedArray(x.shape, x.dtype))
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | custom_derivatives.py | 675 in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals] 812 ans_pvals = [pe.PartialVal.unknown(a) for a in ans_avals] 916 in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
|
H A D | api.py | 647 pvals = [pe.PartialVal.unknown(aval) for aval in avals] 1981 in_pvals = map(pe.PartialVal.unknown, in_avals) 2087 in_pvals = [pe.PartialVal.unknown(a) for a in in_avals] 2515 in_pvals = [pe.PartialVal.unknown(raise_to_shaped(core.get_aval(x)))
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | control_flow.py | 529 else pe.JaxprTracer(trace, pe.PartialVal.known(known), out_unknown.recipe) 905 out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None) 1572 invariant_pvals = [pe.PartialVal.known(core.unit if uk else t.pval[1]) 1574 other_pvals = [pe.PartialVal.unknown(a) for a in jaxpr_1.in_avals[num_consts:]] 1638 out_tracers = [pe.JaxprTracer(trace, pe.PartialVal((pv, const)), None) 1734 pvals = [pe.PartialVal.unknown(aval) for aval in in_avals] 2615 in_pvals = [pe.PartialVal.unknown(aval) for aval in in_avals]
|
H A D | parallel.py | 1075 out_tracer = pe.JaxprTracer(trace, pe.PartialVal.unknown(out_aval), None)
|
H A D | lax.py | 1149 pval = pe.PartialVal.unknown(aval) 1158 pvals = safe_map(pe.PartialVal.unknown, flat_in_avals)
|