Calculating specific components of gradients

Occasionally, I work with scalar-valued functions that take multiple vectors as arguments. A simple example is:

f(\mathbf{x}, \mathbf{y}) = e^{-|\mathbf{x}-\mathbf{y}|^2}

In Julia, this can be coded as:

using LinearAlgebra  # provides norm
f(x, y) = exp(-norm(x-y)^2)

I would like to automatically differentiate f with respect to a component of an argument vector, for example

\frac{\partial f}{\partial x_i} = -2(x_i - y_i) e^{-|\mathbf{x}-\mathbf{y}|^2}

where x_i and y_i are the i-th components of \mathbf{x} and \mathbf{y}, respectively.
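Before reaching for AD, it can help to code the hand-derived partial directly and compare it with the Zygote output that follows. This is only a check of the formula, not automatic differentiation; `dfdx` is a hypothetical helper name.

```julia
using LinearAlgebra  # for norm

# The example function.
f(x, y) = exp(-norm(x - y)^2)

# Hand-derived partial: ∂f/∂x_i = -2(x_i - y_i) e^{-|x - y|^2}.
dfdx(x, y, i) = -2 * (x[i] - y[i]) * exp(-norm(x - y)^2)

x, y = [1.0, 2.0], [0.0, 2.0]
dfdx(x, y, 1)  # ≈ -0.7357588823428847
dfdx(x, y, 2)  # 0 (returned as -0.0)
```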

I can calculate the gradients at points \mathbf{x} and \mathbf{y} with Zygote.jl:

julia> gradient(f, [1., 2.], [0., 2.])
([-0.7357588823428847, -0.0], [0.7357588823428847, 0.0])

However, this calculates derivatives with respect to all components of all input variables, whereas I only need a single one. Can this unneeded computation be avoided? As usual, runtime performance is of utmost importance.



Seconding this question. I use ForwardDiff and not Zygote, but I have a really gnarly code pattern, which would probably work similarly in Zygote, that looks like this:

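using ForwardDiff

# Derivative of f with respect to coordinate j of p: splice a scalar z
# into position j of p and differentiate with respect to z.
function coord_deriv(f, p, j)
  ForwardDiff.derivative(z -> f(vcat(p[1:(j-1)], z, p[(j+1):end])), p[j])
end

# A vector of functions that give individual gradient entries of myfn:
dfuns = [p -> coord_deriv(myfn, p, j) for j in 1:length(params)]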

But this is obviously heinous. I’d be very interested to hear from somebody who has a more elegant solution.


That’s in the eye of the beholder — I think it is the right solution though.


That’s fair. I had just always imagined that the vcat would cause some unfortunate allocations. I’ve also wondered whether using generated functions and forcing p::SVector{N,D} would help the compiler get rid of them. But admittedly, after tinkering for a minute, it seems the extra allocations really are minimal compared to the baseline of doing something manually, which is pretty impressive.
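The fixed-size idea can be approximated without StaticArrays by using a plain tuple, whose length is part of its type (a tuple is swapped in for SVector here as an assumption, and `splice_coord` and `splice_vcat` are hypothetical names). The function being differentiated would then have to accept a tuple. Whether this really avoids allocations once a dual number is spliced in depends on type inference, so it is worth measuring with `@allocated` or BenchmarkTools rather than assuming it.

```julia
# Hypothetical helper: a copy of tuple `p` with entry `j` replaced by `z`.
# The length N is known from the type, so ntuple(..., Val(N)) builds the
# result as a tuple instead of a heap-allocated Vector.
splice_coord(p::NTuple{N,Any}, z, j::Integer) where {N} =
    ntuple(k -> k == j ? z : p[k], Val(N))

# The vcat-based splice from the ForwardDiff snippet, for comparison.
splice_vcat(p::AbstractVector, z, j::Integer) = vcat(p[1:(j-1)], z, p[(j+1):end])

p = (1.0, 2.0, 3.0)
splice_coord(p, 5.0, 2)          # (1.0, 5.0, 3.0)
splice_vcat(collect(p), 5.0, 2)  # [1.0, 5.0, 3.0]
```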

You are right: if you are only differentiating with respect to one variable, forward-mode AD will likely be more performant. Unfortunately, ForwardDiff.Dual is always scalar, so although you can work with a vector of Duals, performance in cases like these won’t be ideal.
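To make the forward-mode point concrete, here is a minimal, self-contained sketch of a vector of scalar duals seeded along a single direction. The `Dual` type, `g`, and `jvp` are hypothetical illustrations, not ForwardDiff's or Zygote's API, and `g` writes the squared norm as `sum(abs2, ...)` so that only a handful of rules are needed. Every component still carries its own (mostly zero) tangent, which is the per-element overhead described above.

```julia
# Minimal forward-mode dual number (hypothetical, NOT ForwardDiff.Dual):
# `v` is the primal value, `d` the tangent carried alongside it.
struct Dual
    v::Float64
    d::Float64
end

# Just enough derivative rules for the example function.
Base.:+(a::Dual, b::Dual) = Dual(a.v + b.v, a.d + b.d)
Base.:-(a::Dual, b::Dual) = Dual(a.v - b.v, a.d - b.d)
Base.:-(a::Dual) = Dual(-a.v, -a.d)
Base.abs2(a::Dual) = Dual(a.v^2, 2 * a.v * a.d)
Base.exp(a::Dual) = (ev = exp(a.v); Dual(ev, ev * a.d))

# Same value as exp(-norm(x - y)^2), written so only the rules above are used.
g(x, y) = exp(-sum(abs2, x .- y))

# Directional derivative of g at (x, y) along the seed (dx, dy).
# A one-hot seed picks out a single partial derivative.
jvp(x, y, dx, dy) = g(Dual.(x, dx), Dual.(y, dy)).d

x, y = [1.0, 2.0], [0.0, 2.0]
jvp(x, y, [1.0, 0.0], [0.0, 0.0])  # ∂g/∂x₁ ≈ -0.7357588823428847
jvp(x, y, [0.0, 0.0], [0.0, 1.0])  # ∂g/∂y₂ = 0
```

With a one-hot seed, one forward pass yields exactly one partial derivative, which is the role the unit-vector seeds play in a pushforward.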

I thought this would be a good example to show off Zygote’s forward-mode capabilities. It’s just a bit unfortunate that they aren’t hooked up to ChainRules yet, so the number of array rules is currently a bit meager. Not to worry though: as a workaround, we will just manually steal the rules we need from ChainRules and define the array rule for `-` ourselves. We can also use Flux.OneHotVector for a more efficient representation of a unit vector, although I don’t expect it to make a huge difference here.

Putting all of this together, I got the following:

julia> using Zygote, ChainRules, LinearAlgebra, Flux

julia> for f in [:norm] # steal some frules (just `norm` here)
           @eval Zygote.Forward._pushforward(dargs, ::typeof($f), x...) = frule(dargs, $f, x...)
       end

julia> Zygote.Forward.@tangent A::AbstractArray - B::AbstractArray = A-B, (Ȧ, Ḃ) -> Ȧ .- Ḃ

julia> f(x, y) = exp(-norm(x-y)^2)
f (generic function with 1 method)

julia> pf = Zygote.pushforward(f, [1., 2.], [0., 2.])
#8 (generic function with 1 method)

julia> pf(Flux.OneHotVector(1, 2), Zero()) # ∂/∂x₁

julia> pf(Zero(), Flux.OneHotVector(2, 2)) # ∂/∂y₂
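As a hand check on what these two calls compute, differentiating f along a tangent (\dot{\mathbf{x}}, \dot{\mathbf{y}}) gives

\dot f = -2 e^{-|\mathbf{x}-\mathbf{y}|^2} (\mathbf{x}-\mathbf{y}) \cdot (\dot{\mathbf{x}} - \dot{\mathbf{y}})

Seeding \dot{\mathbf{x}} = \mathbf{e}_1, \dot{\mathbf{y}} = \mathbf{0} (with \mathbf{e}_i the i-th unit vector) recovers \partial f / \partial x_1 = -2(x_1 - y_1) e^{-|\mathbf{x}-\mathbf{y}|^2}, while seeding \dot{\mathbf{x}} = \mathbf{0}, \dot{\mathbf{y}} = \mathbf{e}_2 gives

\frac{\partial f}{\partial y_2} = 2(x_2 - y_2) e^{-|\mathbf{x}-\mathbf{y}|^2}

At \mathbf{x} = (1, 2), \mathbf{y} = (0, 2) these evaluate to -2e^{-1} \approx -0.7358 and 0, the same values as the corresponding entries of the full Zygote gradient computed earlier in the thread.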