Searched refs:write_cotangent (Results 1 – 2 of 2) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/

  invertible_ad.py
    160  def write_cotangent(v, ct):
    186  map(write_cotangent, jaxpr.outvars, cotangents_in)
    259  map(write_cotangent, [v for v in eqn.invars if type(v) is not Literal], cts_out)

  ad.py
    167  def write_cotangent(prim, v, ct):
    212  map(partial(write_cotangent, 'outvars'), jaxpr.outvars, cotangents_in)
    231  map(partial(write_cotangent, eqn.primitive), eqn.invars, cts_out)
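The matches above show `write_cotangent` being called first for the jaxpr's output variables (`jaxpr.outvars` with `cotangents_in`) and then, per equation, for the equation's input variables with the cotangents produced for them (`cts_out`). The following is a minimal, hypothetical sketch of that write-and-accumulate pattern on a toy linear program of scalar equations. `Literal`, `Eqn`, `backward_pass`, the `add`/`scale` primitives and their transpose rules are illustrative assumptions, not JAX's actual implementation. Unlike the invertible_ad.py call at line 259, which filters literals before calling, this sketch skips literals inside `write_cotangent` itself.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Literal:
    """Hypothetical constant operand; constants never receive a cotangent."""
    val: float


@dataclass(frozen=True)
class Eqn:
    """Hypothetical equation: primitive name, input operands, output names."""
    prim: str
    invars: tuple   # variable names (str) or Literal constants
    outvars: tuple  # variable names (str); this sketch assumes exactly one


def transpose_add(ct, invars):
    # z = x + y is linear in both operands, so ct flows unchanged to each.
    return [ct, ct]


def transpose_scale(ct, invars):
    # z = c * x with a literal c: x receives c * ct, the literal receives nothing.
    _, c = invars
    return [c.val * ct, None]


TRANSPOSE_RULES = {"add": transpose_add, "scale": transpose_scale}


def backward_pass(eqns, outvars, cotangents_in):
    ct_env = {}  # variable name -> accumulated cotangent

    def write_cotangent(prim, v, ct):
        # `prim` only names the caller here, mirroring the ad.py signature.
        if ct is None or isinstance(v, Literal):
            return
        # A variable used several times gets one contribution per use: sum them.
        ct_env[v] = ct_env[v] + ct if v in ct_env else ct

    def read_cotangent(v):
        # Consume the accumulated value; never-written variables count as zero.
        return ct_env.pop(v, 0.0)

    # Python 3's built-in map is lazy, so explicit loops force every write.
    for v, ct in zip(outvars, cotangents_in):
        write_cotangent("outvars", v, ct)
    for eqn in reversed(eqns):
        (ct_out,) = [read_cotangent(v) for v in eqn.outvars]
        cts_in = TRANSPOSE_RULES[eqn.prim](ct_out, eqn.invars)
        for v, ct in zip(eqn.invars, cts_in):
            write_cotangent(eqn.prim, v, ct)
    # Whatever remains belongs to variables no equation produced (the inputs).
    return ct_env


eqns = [
    Eqn("add", ("a", "a"), ("y",)),
    Eqn("scale", ("y", Literal(3.0)), ("z",)),
]
print(backward_pass(eqns, ["z"], [1.0]))  # {'a': 6.0}
```

With `y = a + a` and `z = scale(y, 3.0)`, seeding `z` with cotangent `1.0` gives `y` a cotangent of 3.0; the `add` transpose then writes 3.0 to `a` twice, and the accumulating write sums them to 6.0, which matches dz/da = 6.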