Searched refs:backward_pass (Results 1 – 8 of 8) sorted by relevance
/dports/math/py-jax/jax-0.2.9/jax/interpreters/ |
H A D | ad.py | 121 arg_cts = backward_pass(jaxpr, consts, dummy_args, cts) 163 def backward_pass(jaxpr: core.Jaxpr, consts, primals_in, cotangents_in): function 545 fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr) 576 cotangents_out = backward_pass(tangent_jaxpr.jaxpr, (), primals_in + residuals, cotangents_in) 595 fun = lu.hashable_partial(lu.wrap_init(backward_pass), call_jaxpr)
|
/dports/math/py-autograd/autograd-1.3/autograd/ |
H A D | core.py | 14 def vjp(g): return backward_pass(g, end_node) 17 def backward_pass(g, end_node): function
|
/dports/math/py-ssm/ssm-0.0.1/ssm/ |
H A D | primitives.py | 15 from ssm.messages import forward_pass, backward_pass, \ 76 backward_pass(log_Ps, ll, betas)
|
H A D | messages.pyx | 72 cpdef backward_pass(double[:,:,::1] log_Ps, function
|
/dports/devel/py-PeachPy/PeachPy-01d1515/peachpy/x86_64/ |
H A D | function.py | 558 def backward_pass(self, processing_function, instructions, input_state): member in Function._analize.BasicBlock 563 input_block.backward_pass(processing_function, instructions, output_state) 603 self.backward_pass(propogate_sse_backward, instructions, avx_state) 616 self.backward_pass(propogate_avx_backward, instructions, avx_state)
|
/dports/math/py-jax/jax-0.2.9/jax/ |
H A D | custom_derivatives.py | 373 return ad.backward_pass(fun_jaxpr.jaxpr, fun_jaxpr.consts, args, cts)
|
H A D | api.py | 1995 in_cotangents = ad.backward_pass(jaxpr, consts, dummies, out_cotangents)
|
/dports/math/py-jax/jax-0.2.9/jax/_src/lax/ |
H A D | control_flow.py | 1005 cts_in = ad.backward_pass( 1722 cbar_abar = ad.backward_pass(jaxpr.jaxpr, jaxpr.consts, primals, b_bar)
|