Help optimizing 1000x Zygote overhead for linear interpolation

I need to compute the gradient of an interpolation w.r.t. the knot y-positions. None of the standard packages seemed to work with Zygote, so I coded this very simple MWE:

using Zygote, ForwardDiff, BenchmarkTools

function LinearInterpolation(xdat::AbstractVector, ydat::AbstractVector{T}, extrapolation_value::T = T(NaN)) where {T}
    m = diff(ydat) ./ diff(xdat)
    x_lower, x_upper = first(xdat), last(xdat)
    function (x)
        if x_lower <= x <= x_upper
            # sets i such that x is between xdat[i] and xdat[i+1]
            i = Zygote.@ignore(max(1, searchsortedfirst(xdat, x) - 1))
            @inbounds(ydat[i] + m[i]*(x-xdat[i]))
        else
            return extrapolation_value
        end
    end
end

However, while the Zygote gradient works, it’s incredibly slow: about 1000x slower than the evaluation itself, and about 100x slower than ForwardDiff, for a use case that is typical for me:

xdat = collect(range(0,1,length=100))
ydat = rand(100)
x = rand(128,128)

@btime sum(LinearInterpolation($xdat, $ydat).($x))
# ~500μs

@btime Zygote.gradient(ydat -> sum(LinearInterpolation($xdat, ydat).($x)), $ydat)
# ~500ms

@btime ForwardDiff.gradient(ydat -> sum(LinearInterpolation($xdat, ydat).($x)), $ydat)
# ~5ms

Profiling reveals lots of time spent in Zygote internals which I’m not familiar with (e.g. _generate_pullback_via_decomposition, dynamic dispatch, and places where stack frames disappear, so I guess they are outside of Julia?), so that didn’t get me far.

I’m sure I’m just hitting a case where Zygote is known to be bad, but I’m wondering if anyone could still give some hints on how I might rewrite this to make it less disastrously slow?

Alternatively, I guess the nuclear option is to code the rrule for the entire function by hand. However, what’s the rrule of a function which returns a closure? I suppose I could also use ForwardDiff inside the rrule; are there any examples of how to do that? Thanks.

I recently learned that you can mix reverse and forward differentiation. For Zygote this is the magic incantation.
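A minimal sketch of mixing the two modes, assuming the incantation meant here is `Zygote.forwarddiff` (the function used in the follow-up code in this thread). The toy function `f` is made up for illustration. Zygote differentiates the outer call in reverse mode, while the block handed to `Zygote.forwarddiff` is differentiated with ForwardDiff:

```julia
using Zygote

# Toy example: the do-block is differentiated in forward mode (ForwardDiff),
# while the surrounding Zygote.gradient call runs in reverse mode.
f(y) = Zygote.forwarddiff(y) do y
    sum(abs2, y)
end

y = [1.0, 2.0, 3.0]
g, = Zygote.gradient(f, y)   # gradient of sum(y.^2) is 2y
```

As far as I understand the docstring, values closed over by the block do not receive gradients, so this only fits when the derivative is needed w.r.t. the explicit argument alone (here `y`).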

I’m not sure about this, but have you tried simply removing the closure to see how that affects the running time?

Just checked, doesn’t seem to.

Didn’t know about that one; indeed, that makes it super easy to quickly switch to ForwardDiff for the relevant part. That gets me to the order of the ForwardDiff result I had above, which may be OK for now. Thanks! Here’s what I have now for reference (combined into one function):

function LinearInterpolation(xdat::AbstractVector{TX}, ydat::AbstractVector{TY}, x::AbstractArray{TX}, extrapolation_value::TY = TY(NaN)) where {TX,TY}
    x_lower, x_upper = first(xdat), last(xdat)
    Zygote.forwarddiff(ydat) do ydat
        m = diff(ydat) ./ diff(xdat)
        map(x) do x
            if x_lower <= x <= x_upper
                # sets i such that x is between xdat[i] and xdat[i+1]
                i = max(1, searchsortedfirst(xdat, x) - 1)
                @inbounds(ydat[i] + m[i]*(x-xdat[i]))
            else
                extrapolation_value
            end
        end
    end
end
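For the "nuclear option" raised in the question, here is a rough sketch of a hand-written rule, not something anyone in the thread tested. It sidesteps the closure question by taking `x` as an argument, like the combined function above. The helper name `interp` is hypothetical, the extrapolation value is fixed to NaN to keep the signature simple, and tangents for `xdat` and `x` are deliberately not provided. It relies on the interpolated value being `(1 - t) * ydat[i] + t * ydat[i+1]`, so each incoming cotangent is split between knots `i` and `i+1`:

```julia
using ChainRulesCore, Zygote

# Hypothetical helper (not from the thread): non-closure linear interpolation.
# Points outside [first(xdat), last(xdat)] map to NaN.
function interp(xdat::AbstractVector, ydat::AbstractVector{T}, x::AbstractArray) where {T}
    map(x) do xi
        first(xdat) <= xi <= last(xdat) || return T(NaN)
        i = clamp(searchsortedlast(xdat, xi), 1, length(xdat) - 1)
        t = (xi - xdat[i]) / (xdat[i+1] - xdat[i])
        (1 - t) * ydat[i] + t * ydat[i+1]
    end
end

# Hand-written reverse rule w.r.t. ydat only (assumption: gradients w.r.t.
# xdat and x are not needed, so they are returned as NoTangent()).
function ChainRulesCore.rrule(::typeof(interp), xdat::AbstractVector,
                              ydat::AbstractVector, x::AbstractArray)
    y = interp(xdat, ydat, x)
    function interp_pullback(dy)
        dydat = zero(ydat)
        for (xi, dyi) in zip(x, unthunk(dy))
            first(xdat) <= xi <= last(xdat) || continue  # extrapolated: no dependence on ydat
            i = clamp(searchsortedlast(xdat, xi), 1, length(xdat) - 1)
            t = (xi - xdat[i]) / (xdat[i+1] - xdat[i])
            dydat[i]   += (1 - t) * dyi   # weight on the left knot
            dydat[i+1] += t * dyi         # weight on the right knot
        end
        return NoTangent(), NoTangent(), dydat, NoTangent()
    end
    return y, interp_pullback
end

xdat = [0.0, 0.5, 1.0]
ydat = [1.0, 3.0, 2.0]
x = [0.25, 0.75, 1.0]
g, = Zygote.gradient(yd -> sum(interp(xdat, yd, x)), ydat)
```

Zygote should pick up the `rrule` automatically, so the pullback is a single scatter-add loop over `x` rather than a traced pullback per element. Whether it actually beats the `Zygote.forwarddiff` version for the sizes in this thread would need benchmarking.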