I need to compute the gradient of an interpolation w.r.t. the knot y-positions. None of the standard packages seemed to work with Zygote, so I coded this very simple MWE:
using Zygote, ForwardDiff, BenchmarkTools
function LinearInterpolation(xdat::AbstractVector, ydat::AbstractVector{T}, extrapolation_value::T = T(NaN)) where {T}
    m = diff(ydat) ./ diff(xdat)
    x_lower, x_upper = first(xdat), last(xdat)
    function (x)
        if x_lower <= x <= x_upper
            # sets i such that x is between xdat[i] and xdat[i+1]
            i = Zygote.@ignore(max(1, searchsortedfirst(xdat, x) - 1))
            @inbounds(ydat[i] + m[i]*(x-xdat[i]))
        else
            return extrapolation_value
        end
    end
end
However, while the Zygote gradient works, it's incredibly slow: about 1000X slower than the evaluation itself, and about 100X slower than ForwardDiff, for a use-case that is typical for me:
xdat = collect(range(0,1,length=100))
ydat = rand(100)
x = rand(128,128)
@btime sum(LinearInterpolation($xdat, $ydat).($x))
# ~500μs
@btime Zygote.gradient(ydat -> sum(LinearInterpolation($xdat, ydat).($x)), $ydat)
# ~500ms
@btime ForwardDiff.gradient(ydat -> sum(LinearInterpolation($xdat, ydat).($x)), $ydat)
# ~5ms
Profiling reveals lots of time spent in Zygote internals which I’m not familiar with (e.g. _generate_pullback_via_decomposition, dynamic dispatch, and places where stack frames disappear, so I guess they are outside of Julia?), so that didn’t get me far.
I’m sure I’m just hitting a case where Zygote is known to be bad, but I’m wondering if anyone could still give some hints on how I might rewrite this to make it less disastrously slow?
Alternatively, I guess the nuclear option is to code the rrule for the entire function by hand. However, what’s the rrule of a function which returns a closure? I suppose I could also use ForwardDiff inside the rrule; are there any examples of how to do that? Thanks.
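For concreteness, here is a rough sketch of what I imagine the hand-written route could look like if the closure is avoided altogether. It assumes ChainRulesCore is available, and `interp_all` is a hypothetical helper (not part of the MWE above) that evaluates the interpolant at every element of `x` in one call. The pullback only accumulates the cotangent w.r.t. `ydat`; returning `NoTangent()` for `xdat` and `x` is a simplifying assumption that is only valid if those gradients are never needed.

```julia
using ChainRulesCore

# Hypothetical helper: evaluate the linear interpolant at every element of x.
# Out-of-range points get NaN, as in the MWE's default extrapolation value.
function interp_all(xdat::AbstractVector, ydat::AbstractVector, x::AbstractArray)
    out = similar(x, float(promote_type(eltype(ydat), eltype(x))))
    for k in eachindex(x)
        xk = x[k]
        if first(xdat) <= xk <= last(xdat)
            # i is chosen so that xdat[i] <= xk <= xdat[i+1]
            i = clamp(searchsortedfirst(xdat, xk) - 1, 1, length(xdat) - 1)
            t = (xk - xdat[i]) / (xdat[i+1] - xdat[i])
            out[k] = (1 - t) * ydat[i] + t * ydat[i+1]
        else
            out[k] = NaN
        end
    end
    return out
end

# Hand-written reverse rule. Each in-range output depends on only two knots,
# with weights (1 - t) and t, so the pullback scatters the incoming cotangent
# onto those two entries of the ydat gradient.
function ChainRulesCore.rrule(::typeof(interp_all), xdat, ydat, x)
    out = interp_all(xdat, ydat, x)
    function interp_all_pullback(dout)
        Δ = unthunk(dout)
        dydat = zeros(float(eltype(ydat)), length(ydat))
        for k in eachindex(x)
            xk = x[k]
            if first(xdat) <= xk <= last(xdat)
                i = clamp(searchsortedfirst(xdat, xk) - 1, 1, length(xdat) - 1)
                t = (xk - xdat[i]) / (xdat[i+1] - xdat[i])
                dydat[i]   += (1 - t) * Δ[k]
                dydat[i+1] += t * Δ[k]
            end
        end
        # Assumption: no gradient is needed w.r.t. xdat or x.
        return NoTangent(), NoTangent(), dydat, NoTangent()
    end
    return out, interp_all_pullback
end
```

Since Zygote consults ChainRules rules, I would expect `Zygote.gradient(y -> sum(interp_all(xdat, y, x)), ydat)` to pick this rule up instead of differentiating through the loop, but I haven’t benchmarked it.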