# ChainRules rrule for a fairly simple function

I have some code that roughly follows this pattern: a function `g` calls a function `f` iteratively, feeding the output of each call into the next one, stores every result as a column of an array, and returns that array:

``````
import Zygote

f(x, y) = sin.(x + y)

function g(x0, y, n)
    xs = Array{eltype(x0)}(undef, length(x0), n)
    xs[:, 1] = x0
    for i = 2:n
        xs[:, i] = f(xs[:, i - 1], y)
    end
    return xs
end

function g_zygote(x0, y, n)
    xs = Zygote.Buffer(x0, length(x0), n)
    xs[:, 1] = x0
    for i = 2:n
        xs[:, i] = f(xs[:, i - 1], y)
    end
    return copy(xs)
end

x0 = zeros(2)
y = zeros(2)
h(y) = sum(g(x0, y, 3))
h_zygote(y) = sum(g_zygote(x0, y, 3))
``````

I can make this code differentiable with Zygote using `Zygote.Buffer` (and `g_zygote`, as above), but I would like to define an `rrule` method for `g` to make this work with ChainRules. Note that in my case, the function `f` is much more complicated than the one here, but I have successfully defined the `rrule` for that. Does anyone know how to define an efficient `rrule` for `g`?
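For reference, here is a minimal standalone sketch of how the `Zygote.Buffer` version is differentiated. It repeats the definitions so that it runs on its own. The variable name `dy` is just illustrative. At `y = 0` every `cos` factor is 1, so the second column contributes 1 and the third contributes 2 to each gradient component.

```julia
import Zygote

f(x, y) = sin.(x + y)

# Same loop as `g`, but writing into a Zygote.Buffer so that Zygote can
# differentiate through the indexed assignments.
function g_zygote(x0, y, n)
    xs = Zygote.Buffer(x0, length(x0), n)
    xs[:, 1] = x0
    for i = 2:n
        xs[:, i] = f(xs[:, i - 1], y)
    end
    return copy(xs)
end

x0 = zeros(2)
y = zeros(2)
h_zygote(y) = sum(g_zygote(x0, y, 3))

# Zygote.gradient returns a tuple with one entry per argument.
dy, = Zygote.gradient(h_zygote, y)
# dy ≈ [3.0, 3.0]
```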

Any help would be greatly appreciated!

There you go:

``````
using ChainRulesCore
import Zygote

function ChainRulesCore.rrule(::typeof(g), x0, y, n)
    xs = Array{eltype(x0)}(undef, length(x0), n)
    xs[:, 1] = x0

    # Two versions of evaluating the primal. In both, pullbacks[i] belongs to
    # the step that maps xs[:, i] to xs[:, i+1].
    if false
        # Explicit version. Easier to read, but `eltype(pullbacks) == Any`
        pullbacks = Vector{Any}(undef, n - 1)
        for i = 2:n
            xs[:, i], pullbacks[i-1] = Zygote.pullback(f, xs[:, i-1], y)
        end
    else
        # Collecting with `map` makes sure `eltype(pullbacks)` is as narrow as possible
        pullbacks = map(2:n) do i
            xs[:, i], back = Zygote.pullback(f, xs[:, i-1], y)
            back
        end
    end

    # Assumes n >= 2.
    function g_pullback(dxs)
        # Walk backwards over the steps. Each column's own cotangent is added
        # to what flows back from the later columns; dy accumulates over steps.
        dxi, dy = pullbacks[n-1](dxs[:, n]) .+ (dxs[:, n-1], ZeroTangent())
        for i = reverse(1:n-2)
            dxi, dy = pullbacks[i](dxi) .+ (dxs[:, i], dy)
        end
        # ChainRulesCore >= 1.0 names. Older releases spelled these
        # NO_FIELDS, Zero() and DoesNotExist().
        return NoTangent(), dxi, dy, NoTangent()
    end

    return xs, g_pullback
end
``````
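Since the question mentions that an `rrule` for `f` already exists, the same reverse sweep can presumably be written without Zygote by calling `rrule` on each step. The following is only a sketch under some assumptions: it uses the ChainRulesCore 1.x names (`NoTangent`), and `f_demo`, `g_demo`, and the hand-written rule for `f_demo` are hypothetical stand-ins chosen so the block does not clash with the definitions above. It finishes with a finite-difference check on one component of the gradient.

```julia
using ChainRulesCore

# Hypothetical stand-in for the real `f`, with a hand-written rule.
f_demo(x, y) = sin.(x + y)

function ChainRulesCore.rrule(::typeof(f_demo), x, y)
    c = cos.(x + y)
    # The derivative of sin.(x + y) w.r.t. both x and y is cos.(x + y), elementwise.
    f_demo_pullback(dz) = (NoTangent(), c .* dz, c .* dz)
    return sin.(x + y), f_demo_pullback
end

function g_demo(x0, y, n)
    xs = Array{eltype(x0)}(undef, length(x0), n)
    xs[:, 1] = x0
    for i = 2:n
        xs[:, i] = f_demo(xs[:, i - 1], y)
    end
    return xs
end

function ChainRulesCore.rrule(::typeof(g_demo), x0, y, n)
    xs = Array{eltype(x0)}(undef, length(x0), n)
    xs[:, 1] = x0
    # pullbacks[i] belongs to the step that maps column i to column i + 1.
    pullbacks = map(2:n) do i
        xs[:, i], back = rrule(f_demo, xs[:, i - 1], y)
        back
    end

    function g_demo_pullback(dxs)
        dxi = dxs[:, n]   # cotangent of the last column
        dy = zero(y)      # accumulated cotangent of y
        for i = reverse(1:n-1)
            _, dx_prev, dyi = pullbacks[i](dxi)
            dxi = dx_prev + dxs[:, i]   # add the direct cotangent of column i
            dy = dy + dyi
        end
        return NoTangent(), dxi, dy, NoTangent()
    end

    return xs, g_demo_pullback
end

# Compare against a central finite difference of sum(g_demo(x0, y, n)).
x0, y, n = [0.1, 0.2], [0.3, 0.4], 4
xs, back = rrule(g_demo, x0, y, n)
_, dx0, dy, _ = back(ones(size(xs)))

h = 1e-6
e1 = [1.0, 0.0]
fd = (sum(g_demo(x0, y + h * e1, n)) - sum(g_demo(x0, y - h * e1, n))) / (2h)
@assert isapprox(dy[1], fd; atol = 1e-6)
```

Writing the sweep as an explicit loop with a separate `dy` accumulator also covers `n == 1`, where there are no steps and the pullback simply returns the cotangent of the first column.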

Test:

``````
x0 = rand(2)
y = rand(2)
n = 3

h_rule(y) = sum(g(x0, y, n))
h_auto(y) = sum(g_zygote(x0, y, n))