I’m having some trouble getting Zygote over ForwardDiff.derivative to work.
I’m going to refer to this prior post. The following code used to fail with ERROR: setindex! not defined for ForwardDiff.Partials{1,Float64}. The fix was to define some additional adjoints, as @ChrisRackauckas showed in DiffEqFlux.jl.
using Flux, ForwardDiff
f = Chain(x -> fill(x, 3), Dense(3, 3, softplus))
df(x) = ForwardDiff.derivative(f, x)
x = rand()
f(x) #Works
df(x) #Works
gs = gradient(() -> sum(df(x)), params(f)) #Fails
However, the code above now runs without the additional adjoints, but the gradients returned are nothing.
One of my codes used a similar Zygote over ForwardDiff.derivative idea, but it no longer trains as all of the gradients are nothing. Something seems to have changed, but I don’t know where to start. Unfortunately, I don’t have the old Project or Manifest files, so I don’t know what versions I was using.
The example above fails in the way I describe. gs should have the gradients with respect to params(f), but has nothing instead.
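Concretely, indexing the returned gradients with the model’s parameters shows the problem (this check is my own snippet, not from the original post):

[gs[p] for p in params(f)]  # every entry is nothing instead of an array of gradients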
My example is more along the lines of this, computing the dot product between the gradient and a vector, which is equivalent to the directional derivative.
using Flux
using ForwardDiff
net = Chain(Dense(2, 128, relu), Dense(128, 128, relu), Dense(128, 1))
p, re = Flux.destructure(net)
x = randn(Float32, 2, 128)
dx = randn(Float32, 2, 128)
grads = Flux.gradient(p -> sum(ForwardDiff.derivative(h -> re(p)(x + h*dx), 0.0f0)), p)
This used to fail without defining some extra adjoints as in DiffEqFlux. It now just gives nothing.
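As a sanity check (my own snippet; the names jvp and J are not from the post), the inner ForwardDiff.derivative really is the directional derivative, i.e. the Jacobian-vector product of the network output with respect to x in the direction dx:

jvp = ForwardDiff.derivative(h -> re(p)(x + h * dx), 0.0f0)
J = ForwardDiff.jacobian(v -> vec(re(p)(reshape(v, size(x)))), vec(x))  # full Jacobian wrt x
jvp ≈ reshape(J * vec(dx), size(jvp))  # true, up to floating-point error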
If you add the DiffEqFlux adjoints, does it work?
No, that doesn’t make a difference.
Interesting. @mcabbott would you know something about what might’ve changed?
Yes, this won’t work, sadly. The warning from Zygote.forwarddiff is:
Note that the function `f` will *drop gradients* for any closed-over values.
and that’s what’s being used here. That is, it’s forward-over-forward, and takes derivatives only with respect to the explicit parameter, not to anything closed over (since ForwardDiff is unaware of those).
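To illustrate with a toy example of my own (not from the thread): ForwardDiff only propagates derivatives through the explicit argument of the function it is handed, so anything the closure captures is a constant as far as it is concerned, and the outer Zygote gradient with respect to that captured value is exactly the part that gets dropped.

using ForwardDiff, Zygote

w = 3.0
g(h) = w * h^2                    # w is closed over, not an argument
ForwardDiff.derivative(g, 1.0)    # 6.0: derivative w.r.t. the explicit argument h only

# This is the pattern from this thread: per the warning above, the gradient with respect
# to the closed-over w is the part that is dropped, which is why the Flux parameters in
# the examples above come back as nothing.
Zygote.gradient(w -> ForwardDiff.derivative(h -> w * h^2, 1.0), 3.0)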
Making it give errors when f closes over anything would be better. Making it actually work… I’m not sure, might be possible? Does DiffEqFlux.jl have (pirate?) code which handles this?
Yes.
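# The adjoints below (quoted from DiffEqFlux) tell Zygote how to differentiate through
# constructing a ForwardDiff.Dual and through reading its value/partials fields;
# they assume ForwardDiff and ZygoteRules are already loaded.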
ZygoteRules.@adjoint function ForwardDiff.Dual{T}(x, ẋ::Tuple) where T
@assert length(ẋ) == 1
ForwardDiff.Dual{T}(x, ẋ), ḋ -> (ḋ.partials[1], (ḋ.value,))
end
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:partials}) where T =
d.partials, ṗ -> (ForwardDiff.Dual{T}(ṗ[1], 0),)
ZygoteRules.@adjoint ZygoteRules.literal_getproperty(d::ForwardDiff.Dual{T}, ::Val{:value}) where T =
d.value, ẋ -> (ForwardDiff.Dual{T}(0, ẋ),)
All of our pirate code is at https://github.com/SciML/DiffEqFlux.jl/blob/v1.44.0/src/DiffEqFlux.jl#L60-L74, and we should upstream some of it.
Is there any possibility of a workaround? This used to work ~1 year ago.
The only other approach I have been able to make work is ReverseDiff over Zygote, but for some reason this is super slow (I’ll create another thread about this).
Has there been any update or progress on this? I recently ran into a similar problem, which I posted in Nested and different AD methods altogether: How to add AD calculations inside my loss function when using neural differential equations? and which I am still trying to make work. Thanks!