"""Differentiable fixed-point iteration for autograd.

``fixed_point`` is registered as an autograd primitive. Its gradient with
respect to ``a`` is not obtained by backpropagating through every forward
iteration; instead ``fixed_point_vjp`` solves a second fixed-point problem
for the adjoint at the converged point ``ans``.
"""
from autograd.extend import primitive, defvjp, vspace
# autograd's traceable tuple; deliberately shadows the builtin so that the
# packed (a, ans, g) triple in fixed_point_vjp can carry boxed values.
from autograd.builtins import tuple
from autograd import make_vjp


@primitive
def fixed_point(f, a, x0, distance, tol, max_iter=None):
    """Iterate ``x <- f(a)(x)`` from ``x0`` until successive iterates are close.

    Args:
        f: Curried map; ``f(a)`` must return a one-argument function of ``x``.
        a: Parameters of the map. This is the only argument with a VJP
            registered (see ``fixed_point_vjp``).
        x0: Starting iterate.
        distance: Callable ``distance(x, x_prev)`` whose result is compared
            against ``tol``.
        tol: Stopping threshold. Iteration continues while
            ``distance(x, x_prev) > tol``.
        max_iter: Optional cap on the number of applications of ``f(a)``.
            ``None`` (the default) means no cap. Note that ``f(a)`` is always
            applied at least once, before the first stopping test.

    Returns:
        The first iterate ``x`` for which ``distance(x, x_prev) > tol`` is
        false.

    Raises:
        RuntimeError: if ``max_iter`` is given and the stopping test is still
            unmet after ``max_iter`` applications of ``f(a)``.
    """
    step = f(a)
    # The first application happens before the loop so there is a previous
    # iterate to measure distance against.
    x, x_prev = step(x0), x0
    n_steps = 1
    while distance(x, x_prev) > tol:
        # Check the cap only after the stopping test, so a step that lands
        # within tolerance on the last allowed iteration still succeeds.
        if max_iter is not None and n_steps >= max_iter:
            raise RuntimeError(
                "fixed_point did not converge within {} iterations".format(max_iter))
        x, x_prev = step(x), x
        n_steps += 1
    return x


def fixed_point_vjp(ans, f, a, x0, distance, tol, max_iter=None):
    """Build the vector-Jacobian product of ``fixed_point`` w.r.t. ``a``.

    Args:
        ans: The converged forward result ``x*`` returned by ``fixed_point``.
        f, a, x0, distance, tol, max_iter: The arguments of the forward call.
            ``distance``, ``tol`` and ``max_iter`` are reused for the adjoint
            solve.

    Returns:
        A function mapping a cotangent ``g`` (same vector space as ``ans``)
        to the cotangent for ``a``.
    """
    def rev_iter(params):
        # Curried map for the adjoint problem: w -> vjp_x(w) + x_star_bar,
        # where vjp_x is the VJP of f(a) with respect to x at x_star. Its
        # fixed point satisfies w = vjp_x(w) + g.
        # NOTE(review): convergence of this iteration presumably needs the
        # same contraction property as the forward map — confirm for callers.
        a, x_star, x_star_bar = params
        vjp_x, _ = make_vjp(f(a))(x_star)
        vs = vspace(x_star)
        return lambda g: vs.add(vjp_x(g), x_star_bar)

    # VJP of (a, x) -> f(a)(x) w.r.t. its first argument (a), evaluated at x*.
    vjp_a, _ = make_vjp(lambda x, y: f(x)(y))(a, ans)

    def vjp(g):
        # (a, ans, g) travel together through argument 1 of the reverse
        # fixed_point call, the argument that has a VJP registered.
        # Presumably this is what allows higher-order derivatives.
        w = fixed_point(rev_iter, tuple((a, ans, g)),
                        vspace(x0).zeros(), distance, tol, max_iter)
        return vjp_a(w)

    return vjp


# VJPs are given positionally by argnum: f (0) -> None, a (1) ->
# fixed_point_vjp, x0 (2) -> None. distance, tol and max_iter are not listed.
defvjp(fixed_point, None, fixed_point_vjp, None)