ChainRules rrule for a fairly simple function

I have some code that roughly follows this pattern: a function g iteratively calls a function f and stores each result as a column of an array, where the input to each call of f is the output of the previous call. g then returns the resulting array:

```julia
import Zygote

f(x, y) = sin.(x + y)

function g(x0, y, n)
    xs = Array{eltype(x0)}(undef, length(x0), n)
    xs[:, 1] = x0
    for i = 2:n
        xs[:, i] = f(xs[:, i - 1], y)
    end
    return xs
end

function g_zygote(x0, y, n)
    xs = Zygote.Buffer(x0, length(x0), n)
    xs[:, 1] = x0
    for i = 2:n
        xs[:, i] = f(xs[:, i - 1], y)
    end
    return copy(xs)
end

x0 = zeros(2)
y = zeros(2)
h(y) = sum(g(x0, y, 3))
h_zygote(y) = sum(g_zygote(x0, y, 3))
zgrad = Zygote.gradient(h_zygote, y)[1]  # looks good
```

I can make this code differentiable with Zygote by using Zygote.Buffer (as in g_zygote above), but I would like to define an rrule method for g so that it works with ChainRules. Note that in my case f is much more complicated than the one here, but I have already defined an rrule for it successfully. Does anyone know how to define an efficient rrule for g?

Any help would be greatly appreciated!


There you go:

```julia
using ChainRulesCore
import Zygote

function ChainRulesCore.rrule(::typeof(g), x0, y, n)
    xs = Array{eltype(x0)}(undef, length(x0), n)
    xs[:, 1] = x0

    # Forward pass: fill xs and record one pullback per step.
    # pullbacks[i] is the pullback of the step xs[:, i] -> xs[:, i + 1].
    # An explicit loop into `Vector{Any}(undef, n - 1)` is easier to read but
    # gives `eltype(pullbacks) == Any`; `map` keeps the eltype as narrow as possible.
    pullbacks = map(2:n) do i
        xs[:, i], back = Zygote.pullback(f, xs[:, i - 1], y)
        back
    end

    function g_pullback(Δxs)
        dxs = unthunk(Δxs)
        dxi = dxs[:, n]      # cotangent of the last column
        dy = ZeroTangent()   # y enters every step, so its cotangent accumulates
        for i in n-1:-1:1
            dx_prev, dy_i = pullbacks[i](dxi)  # Zygote's pullback returns (dx, dy)
            dxi = dx_prev + dxs[:, i]          # plus the direct cotangent of xs[:, i]
            dy = dy + dy_i
        end
        # No tangent for g itself or for the integer n
        return NoTangent(), dxi, dy, NoTangent()
    end

    return xs, g_pullback
end
```
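For reference, the reverse sweep in `g_pullback` is ordinary reverse-mode accumulation over the unrolled iteration. Writing $x^{(k)}$ for column $k$ of `xs`, $\bar X$ for the incoming cotangent `dxs`, and $J_x^{(k)}$, $J_y^{(k)}$ for the Jacobians of `f` with respect to its two arguments at $(x^{(k)}, y)$ (this notation is introduced here for illustration only), the forward pass and the accumulation read:

$$
x^{(1)} = x_0, \qquad x^{(k+1)} = f\big(x^{(k)}, y\big), \quad k = 1, \dots, n-1
$$

$$
\bar x^{(n)} = \bar X_{:,n}, \qquad
\bar x^{(k)} = \bar X_{:,k} + \big(J_x^{(k)}\big)^{\top} \bar x^{(k+1)}, \quad k = n-1, \dots, 1
$$

$$
\bar y = \sum_{k=1}^{n-1} \big(J_y^{(k)}\big)^{\top} \bar x^{(k+1)}, \qquad \bar x_0 = \bar x^{(1)}
$$

Each `pullbacks[k]` applied to $\bar x^{(k+1)}$ returns the pair $\big((J_x^{(k)})^{\top} \bar x^{(k+1)},\ (J_y^{(k)})^{\top} \bar x^{(k+1)}\big)$, which is why `dxi` is carried backwards step by step while `dy` is summed over all steps. The forward pass keeps $n-1$ pullbacks alive, so memory grows linearly in $n$.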

Test:

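```julia
using Zygote: gradient

x0 = rand(2)
y = rand(2)
n = 3

h_rule(y) = sum(g(x0, y, n))        # uses the rrule above
h_auto(y) = sum(g_zygote(x0, y, n))  # Zygote differentiates through Zygote.Buffer

@show only(gradient(h_rule, y))
@show only(gradient(h_auto, y))

#= Output from one run (x0 and y are random):
only(gradient(h_rule, y)) = [1.8040253084967446, 0.42600995398401]
only(gradient(h_auto, y)) = [1.8040253084967446, 0.42600995398401]
=#
```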
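To sanity-check the same reverse recurrence without Zygote or ChainRulesCore, here is a minimal dependency-free sketch for the toy `f(x, y) = sin.(x + y)` from the question only. Because both partial Jacobians of that f are `Diagonal(cos.(x + y))`, the pullback can be written by hand and compared against central finite differences. The names `f_demo`, `g_demo`, `g_demo_pullback`, `fd_grad_y` and the step size `h = 1e-6` are hypothetical choices for illustration, not part of either post.

```julia
# Toy step function from the question (hypothetical demo name).
f_demo(x, y) = sin.(x .+ y)

# Forward recurrence, same as g in the question.
function g_demo(x0, y, n)
    xs = Matrix{eltype(x0)}(undef, length(x0), n)
    xs[:, 1] = x0
    for i in 2:n
        xs[:, i] = f_demo(xs[:, i-1], y)
    end
    return xs
end

# Hand-written reverse sweep; dxs is the cotangent of the output matrix.
# (Recomputes the forward pass for simplicity.)
function g_demo_pullback(x0, y, n, dxs)
    xs = g_demo(x0, y, n)
    dxi = dxs[:, n]
    dy = zero(y)
    for i in n-1:-1:1
        # For f = sin.(x .+ y), J_x = J_y = Diagonal(cos.(x .+ y)),
        # so both transposed-Jacobian products are the same elementwise product.
        local_grad = cos.(xs[:, i] .+ y) .* dxi
        dy .+= local_grad
        dxi = local_grad .+ dxs[:, i]
    end
    return dxi, dy
end

# Central finite difference of y -> sum(g_demo(x0, y, n)).
function fd_grad_y(x0, y, n; h = 1e-6)
    grad = similar(y)
    for j in eachindex(y)
        e = zero(y)
        e[j] = h
        grad[j] = (sum(g_demo(x0, y .+ e, n)) - sum(g_demo(x0, y .- e, n))) / (2h)
    end
    return grad
end

x0 = [0.3, -0.7]
y = [0.1, 0.5]
n = 4
_, dy_hand = g_demo_pullback(x0, y, n, ones(length(x0), n))
@show dy_hand
@show fd_grad_y(x0, y, n)
```

Passing `ones(length(x0), n)` as the cotangent corresponds to differentiating `sum`, as in `h_rule` above.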

Perfect – thanks!