Issue with Zygote over ForwardDiff.derivative

I’m having some trouble getting Zygote over ForwardDiff.derivative to work.

I’m going to refer to this prior post. The following code used to fail with ERROR: setindex! not defined for ForwardDiff.Partials{1,Float64}. The fix was to define some additional adjoints, as @ChrisRackauckas showed in DiffEqFlux.jl.

using Flux, ForwardDiff

f = Chain(x -> fill(x, 3), Dense(3, 3, softplus))
df(x) = ForwardDiff.derivative(f, x)

x = rand()
f(x)   # works
df(x)  # works
gs = gradient(() -> sum(df(x)), params(f))  # used to fail with the setindex! error

However, the code above now runs without the additional adjoints, but every gradient it returns is `nothing`.

One of my projects used a similar Zygote-over-ForwardDiff.derivative approach, but it no longer trains because all of the gradients are `nothing`. Something seems to have changed, but I don’t know where to start. Unfortunately, I don’t have the old Project or Manifest files, so I don’t know which versions I was using.

What’s your full MWE?

The example above fails in the way I describe: gs should contain the gradients with respect to params(f), but it contains `nothing` instead.

My actual use case is closer to the following: computing the dot product between the gradient (with respect to x) and a vector dx, which is the directional derivative along dx.

using Flux
using ForwardDiff

net = Chain(Dense(2, 128, relu), Dense(128, 128, relu), Dense(128, 1))
p, re = Flux.destructure(net)

x = randn(Float32, 2, 128)
dx = randn(Float32, 2, 128)

grads = Flux.gradient(p -> sum(ForwardDiff.derivative(h -> re(p)(x + h*dx), 0.0f0)), p)

This used to fail unless I defined some extra adjoints as in DiffEqFlux. Now it just returns `nothing`.
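For reference, the quantity inside `sum(...)` is a directional derivative: d/dh sum(f(x + h*dx)) at h = 0 equals dot(∇ₓ sum(f(x)), dx). A quick ForwardDiff-only check of that identity follows, using a hypothetical toy function `g` in place of `x -> sum(re(p)(x))`. No Zygote or Flux is involved, so it says nothing about the `nothing` gradients themselves:

```julia
using ForwardDiff, LinearAlgebra

# Hypothetical toy scalar function standing in for x -> sum(re(p)(x)).
g(x) = sum(tanh.(x) .^ 2)

x  = randn(2, 4)
dx = randn(2, 4)

# One forward-mode derivative along the direction dx ...
ddir = ForwardDiff.derivative(h -> g(x + h * dx), 0.0)

# ... matches the full gradient dotted with dx.
gdot = dot(ForwardDiff.gradient(g, x), dx)

println(ddir, " ≈ ", gdot)
```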

If you add the DiffEqFlux adjoints does it work?

No, that doesn’t make a difference.

Interesting. @mcabbott would you know something about what might’ve changed?

Yes, this won’t work, sadly. The warning from `Zygote.forwarddiff` is:

Note that the function `f` will *drop gradients* for any closed-over values.

and that’s what’s being used here. That is, it’s forward-over-forward, and takes derivatives only with respect to the explicit parameter, not to anything closed over (since ForwardDiff is unaware of those).
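A minimal scalar sketch of the difference, assuming a toy function `t -> a * t^2` in place of the network, with `a` standing in for the parameters. The closed-over case is the one this thread reports as returning `nothing`, so no output is claimed for it here:

```julia
using Zygote, ForwardDiff

# (1) `a` is closed over by the function handed to ForwardDiff.derivative.
#     Zygote's rule only differentiates the explicit argument, so the
#     sensitivity with respect to `a` is dropped (reported as `nothing`).
g_closed = Zygote.gradient(a -> ForwardDiff.derivative(t -> a * t^2, 1.0), 3.0)

# (2) Differentiating with respect to the explicit argument does work:
#     d/dh [d/dt (3t^2) at t = h] = d/dh (6h) = 6.
g_explicit = Zygote.gradient(h -> ForwardDiff.derivative(t -> 3.0 * t^2, h), 1.0)

@show g_closed g_explicit
```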

Making it give errors when f closes over anything would be better. Making it actually work… I’m not sure, might be possible? Does DiffEqFlux.jl have (pirate?) code which handles this?

Yes.

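using ForwardDiff, ZygoteRules

# Adjoints pirated in DiffEqFlux.jl. A cotangent of a Dual is itself stored as a
# Dual with the roles swapped: its `value` holds the sensitivity of the partial,
# and its `partials` hold the sensitivity of the value.
ZygoteRules.@adjoint function ForwardDiff.Dual{T}(x, ẋ::Tuple) where T
  @assert length(ẋ) == 1  # only single-partial Duals are handled
  ForwardDiff.Dual{T}(x, ẋ), ḋ -> (ḋ.partials[1], (ḋ.value,))
end

# Pullback of `d.partials`
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:partials}) where T =
  d.partials, ṗ -> (ForwardDiff.Dual{T}(ṗ[1], 0),)

# Pullback of `d.value`
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:value}) where T =
  d.value, ẋ -> (ForwardDiff.Dual{T}(0, ẋ),)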

All of our pirate code is in DiffEqFlux.jl (src/DiffEqFlux.jl at v1.44.0, SciML/DiffEqFlux.jl on GitHub), and we should upstream some of it.
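The adjoints above act on the two fields of a `ForwardDiff.Dual`: its `value` and its `partials`. A small sketch of those fields, using a hypothetical tag type `MyTag` only to label the Dual, and the same `Dual{T}(x, ẋ::Tuple)` constructor the first adjoint targets:

```julia
using ForwardDiff

struct MyTag end  # hypothetical tag type, only used to label the Dual

# Value 3.0 with a single partial 1.0 (seeding dx/dx = 1).
d = ForwardDiff.Dual{MyTag}(3.0, (1.0,))

# Squaring carries both f(3) = 9 and f'(3) * 1 = 6.
y = d^2

# The two fields read by the literal_getproperty adjoints:
println(y.value)        # primal value
println(y.partials[1])  # first (and only) partial
```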


Is there any possibility of a workaround? This used to work ~1 year ago.

The only other approach I have been able to make work is ReverseDiff over Zygote, but for some reason this is super slow (I’ll create another thread about this).
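One workaround not proposed in the thread, sketched here under the assumption that forward mode over all parameters is affordable, is to let ForwardDiff perform the outer differentiation as well, so the parameter becomes an explicit argument instead of a closed-over value. The scalar toy below is illustrative only and has not been checked against the Flux model above:

```julia
using ForwardDiff

# Toy stand-in: d/dt (a * t^2) at t = 1 is 2a, so its derivative
# with respect to the closed-over `a` should be 2.
inner(a) = ForwardDiff.derivative(t -> a * t^2, 1.0)

# Forward-over-forward: the outer ForwardDiff call makes `a` explicit,
# and ForwardDiff's tags keep the two perturbations apart.
da = ForwardDiff.derivative(inner, 3.0)

# For a parameter vector, the analogous (untested) call would be
#   ForwardDiff.gradient(p -> sum(ForwardDiff.derivative(h -> re(p)(x + h*dx), 0f0)), p)
println(da)
```

The cost of this approach grows with the number of parameters, since ForwardDiff.gradient propagates partials in chunks, so it is at best a stopgap for small models.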