However, this calculates derivatives with respect to all components of all input variables, whereas I need only a single component. Can this unneeded computation be avoided? As usual, runtime performance is of utmost importance.
Seconding this question. I use ForwardDiff and not Zygote, but I have a really gnarly code pattern, which would probably work similarly in Zygote, that looks like this:
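using ForwardDiff

# Derivative of f with respect to the j-th coordinate of p, holding the others fixed.
function coord_deriv(f, p, j)
    ForwardDiff.derivative(z -> f(vcat(p[1:(j-1)], z, p[(j+1):end])), p[j])
end
# A vector of functions that give individual gradient entries:
dfuns = [p -> coord_deriv(myfn, p, j) for j in 1:length(params)]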
But this is obviously heinous. I’d be very interested to hear from somebody who has a more elegant solution.
That’s fair. I’ve just always imagined that the vcat would make some unfortunate allocations. I’ve always wondered if using generated functions and forcing p::SVector{N,D} would help the compiler get rid of them. But admittedly, after tinkering for a minute, it seems like the extra allocations really are minimal compared to the baseline of doing something manually. Which is pretty impressive.
You are right, if you are only differentiating wrt one variable, forward mode AD will likely be more performant. Unfortunately, ForwardDiff.Dual is always scalar, so although you can work with a vector of Duals, in cases like these, performance won’t be ideal.
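For anyone following along, here is roughly why forward mode fits the single-component case, written in standard notation (the symbols v and e_j are mine, not from the posts above). For a scalar-valued f, one forward pass seeded with a tangent vector v returns the directional derivative along v, so seeding with the j-th unit vector gives exactly one partial derivative at about the cost of one extra function evaluation:

```latex
\dot{f} \;=\; \nabla f(x)^{\top} v ,
\qquad
v = e_j \;\Longrightarrow\; \dot{f} \;=\; \frac{\partial f}{\partial x_j}(x)
```

Reverse mode, by contrast, produces the whole gradient in one pass, which is wasted work when only one entry is wanted.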
I thought this would be a good example to show off Zygote’s forward diff capabilities. It’s just a bit unfortunate that it’s not hooked up to ChainRules yet, so the number of array rules is currently a bit meager. Not to worry though: as a workaround, we will just manually steal the rules we need from ChainRules and define the array rule for `-` ourselves. We can also use Flux.OneHotVector for a more efficient representation of a unit vector, although I don’t expect it to make a huge difference here.
Putting all of this together, I got the following:
julia> using Zygote, ChainRules, LinearAlgebra, Flux
julia> for f in [:norm] # steal some frules (just `norm` here)
         @eval Zygote.Forward._pushforward(dargs, ::typeof($f), x...) = frule(dargs, $f, x...)
       end
julia> Zygote.Forward.@tangent A::AbstractArray - B::AbstractArray = A-B, (Ȧ, Ḃ) -> Ȧ .- Ḃ
julia> f(x, y) = exp(-norm(x-y)^2)
f (generic function with 1 method)
julia> pf = Zygote.pushforward(f, [1., 2.], [0., 2.])
#8 (generic function with 1 method)
julia> pf(Flux.OneHotVector(1, 2), Zero()) # ∂/∂x₁
-0.7357588823428847
julia> pf(Zero(), Flux.OneHotVector(2, 2)) # ∂/∂y₂
-0.0
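As a sanity check, both numbers can be reproduced by hand from the definition of f, using nothing beyond the chain rule:

```latex
\begin{aligned}
f(x, y) &= \exp\!\left(-\lVert x - y \rVert^{2}\right), \\
\frac{\partial f}{\partial x_1} &= -2\,(x_1 - y_1)\,\exp\!\left(-\lVert x - y \rVert^{2}\right), \\
\frac{\partial f}{\partial y_2} &= \phantom{-}2\,(x_2 - y_2)\,\exp\!\left(-\lVert x - y \rVert^{2}\right).
\end{aligned}
```

At x = (1, 2) and y = (0, 2) we have ‖x − y‖² = 1, so ∂f/∂x₁ = −2e⁻¹ ≈ −0.7358 and ∂f/∂y₂ = 0, which agrees with the two pushforward results up to the sign of zero.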