1from autograd.extend import primitive, defvjp, vspace
2from autograd.builtins import tuple
3from autograd import make_vjp
4
@primitive
def fixed_point(f, a, x0, distance, tol):
    """Iterate the map ``f(a)`` from ``x0`` until successive iterates agree.

    Marked as an autograd ``primitive``, so autograd does not trace through
    the loop below. Its gradient comes from the VJP registered for it with
    ``defvjp`` in this module.

    Args:
        f: Callable that takes ``a`` and returns the one-argument update map
            ``x -> f(a)(x)`` to iterate.
        a: Parameters handed to ``f``.
        x0: Initial iterate.
        distance: Callable ``distance(x_new, x_old)`` whose result is compared
            against ``tol``.
        tol: Stopping threshold. Iteration continues while
            ``distance(x_new, x_old) > tol``.

    Returns:
        The first iterate whose distance to its predecessor is not greater
        than ``tol``. The update map is always applied at least once.

    Note:
        There is no iteration cap. If the map never satisfies the tolerance,
        this loops forever. A NaN distance compares False against ``tol``
        and therefore ends the loop.
    """
    step = f(a)  # build the update map once, outside the loop
    previous = x0
    current = step(previous)
    while distance(current, previous) > tol:
        previous, current = current, step(current)
    return current
12
def fixed_point_vjp(ans, f, a, x0, distance, tol):
    """Build the vector-Jacobian product of ``fixed_point`` w.r.t. ``a``.

    At the fixed point ``x* = f(a)(x*)``, the adjoint ``w`` for an output
    cotangent ``g`` solves ``w = (df/dx)^T w + g``. The cotangent for ``a``
    is then ``w`` pulled back through ``df/da``. The adjoint equation is itself
    solved with ``fixed_point``, starting from zeros in the vector space of
    ``x0`` and reusing the forward ``distance`` and ``tol``. There is no
    iteration cap. Convergence presumably relies on the forward map being
    contractive near ``x*``; this is not checked here.

    Args:
        ans: The fixed point ``x*`` returned by the forward call.
        f, a, x0, distance, tol: The arguments of the forward call.

    Returns:
        A function mapping the output cotangent ``g`` to the cotangent for
        ``a``.
    """
    def adjoint_map(packed):
        # ``packed`` is the (a, x*, g) triple assembled inside ``vjp`` below.
        # Use the unpacked ``a`` rather than the closure's, so that whatever
        # ``fixed_point`` passes in (possibly a traced value) is what gets used.
        a_inner, x_star, x_star_bar = packed
        vjp_wrt_x, _ = make_vjp(f(a_inner))(x_star)
        space = vspace(x_star)

        def adjoint_step(w):
            # One step of the adjoint iteration: w <- (df/dx)^T w + g.
            return space.add(vjp_wrt_x(w), x_star_bar)

        return adjoint_step

    def apply_f(params, x):
        # Two-argument view of ``f`` so make_vjp differentiates w.r.t. params.
        return f(params)(x)

    vjp_wrt_a, _ = make_vjp(apply_f)(a, ans)

    def vjp(g):
        # ``tuple`` is autograd.builtins.tuple (imported at file top),
        # presumably so the packed arguments stay traceable for
        # higher-order derivatives.
        packed = tuple((a, ans, g))
        w_init = vspace(x0).zeros()
        w_star = fixed_point(adjoint_map, packed, w_init, distance, tol)
        return vjp_wrt_a(w_star)

    return vjp
22
# Register fixed_point's gradient rule. Positional VJP makers map to
# argnums 0 (f), 1 (a) and 2 (x0): only ``a`` gets a real VJP. The ``None``
# entries for f and x0 presumably yield zero gradients under autograd's
# defvjp convention (confirm). For x0 this matches the math, since the fixed
# point does not depend on the initial guess when it is unique. No VJP is
# registered for ``distance`` or ``tol``.
defvjp(fixed_point, None, fixed_point_vjp, None)
24