There you go:
using ChainRulesCore
# Reverse-mode rule for `g(x0, y, n)`.
#
# The primal is the `length(x0) × n` matrix `xs` whose first column is `x0` and whose
# column `i` (for `i = 2:n`) is the value returned by `pullback(f, xs[:, i-1], y)`
# (presumably `Zygote.pullback`, i.e. `f(xs[:, i-1], y)` — confirm against the
# definitions of `f` and `g`, which are not part of this snippet).
#
# Returns `(xs, g_pullback)`. `g_pullback(dxs)` takes the cotangent of `xs` (indexed
# column-wise like `xs`) and returns `(NO_FIELDS, dx0, dy, DoesNotExist())`, i.e. the
# step count `n` is treated as non-differentiable.
function ChainRulesCore.rrule(::typeof(g), x0, y, n)
    xs = Array{eltype(x0)}(undef, length(x0), n)
    xs[:, 1] = x0

    # Forward sweep: every iterate must be written into `xs`, otherwise columns 2:n of
    # the returned primal stay uninitialised. `map` narrows `eltype(pullbacks)` the same
    # way a comprehension does, and because `xs` is mutated (not rebound) no captured
    # variable has to be boxed.
    # `pullbacks[k]` is the pullback of the step from column `k` to column `k + 1`.
    pullbacks = map(2:n) do i
        xi, back = pullback(f, xs[:, i - 1], y)
        xs[:, i] = xi
        return back
    end

    function g_pullback(dxs)
        # Seed with the cotangent of the last column; nothing has flowed into `y` yet.
        dxi, dy = dxs[:, n], Zero()
        # Reverse sweep: pull the accumulated cotangent back through step `i -> i + 1`,
        # then add the direct contribution of column `i` and the `y`-cotangent gathered
        # so far. For `n == 1` the loop is empty and the seed is returned unchanged.
        for i in reverse(1:(n - 1))
            dxi, dy = pullbacks[i](dxi) .+ (dxs[:, i], dy)
        end
        return NO_FIELDS, dxi, dy, DoesNotExist()
    end

    return xs, g_pullback
end
Test:
# Compare the gradient with respect to `y` obtained through `g` with the one obtained
# through `g_zygote`; the two printed vectors are expected to coincide.
n = 3
x0, y = rand(2), rand(2)

# Scalar objective built on `g`.
function h_rule(y)
    return sum(g(x0, y, n))
end

# Same scalar objective built on `g_zygote`.
function h_auto(y)
    return sum(g_zygote(x0, y, n))
end

@show only(gradient(h_rule, y))
@show only(gradient(h_auto, y))

# Output recorded from one run (`x0` and `y` are random, so the numbers differ per run):
#   only(gradient(h_rule, y)) = [1.8040253084967446, 0.42600995398401]
#   only(gradient(h_auto, y)) = [1.8040253084967446, 0.42600995398401]