We’ve had this issue for over 2 weeks now (see the dates of the previous posts in this discussion). When was this new Zygote issue introduced? I’m surprised it was not brought up during the previous lengthy discussion.
That was yesterday. Okay, I’ll mark this to take a look after Flux is all fixed up, though with the JuliaCon rush I am a bit behind.
Cheers Chris. I’ll keep investigating this. Meanwhile I hope @mcabbott or @darsnack can reproduce it with the new MWE and give some extra hints on what might be wrong.
After investigating this issue a little more, I suspect that Zygote has trouble computing the gradients of a function involving variables with very small Float64 values (e.g. 2e-16). The back() function returned by the pullback produces NaN results when applied to the NN parameters.
Is there any way to make sure Zygote can compute gradients for operations involving very small floats? I have an alternative example using a heat equation, without the extremely small float values, and it works perfectly fine. Thanks again!
Is there a case you can isolate?
Alright, so after a very long time investigating this bug, we have finally found the source. The problem appears to come from Zygote producing a NaN gradient for the sqrt function. More precisely, we were doing:
∇S = sqrt.(avg_y(dSdx).^2 .+ avg_x(dSdy).^2) # this does not work
D = Γ .* avg(H).^(n + 2) .* ∇S.^(n - 1)
So we isolated the issue, which comes from sqrt. When we changed it to:
∇S² = avg_y(dSdx).^2 .+ avg_x(dSdy).^2 # this does work
D = Γ .* avg(H).^(n + 2) .* ∇S².^((n - 1)/2)
everything works perfectly.
Is this normal? We’re extremely surprised that Zygote cannot provide a gradient for such a simple function.
That looks worth isolating into a Zygote issue. I think it comes from the broadcast implementation.
Here is a minimal example where Zygote fails when trying to differentiate sqrt().
using Zygote
using Flux
A₀ = [[1,0] [0,3]]
A₁ = [[0,0] [0,0]]
function loss(θ)
A = A₀.^θ
A = sqrt.(A)
return sqrt(Flux.Losses.mse(A, A₀; agg=sum))
end
θ = 4.0
loss_θ, back_θ = Zygote.pullback(loss, θ)
For this last case, the value of back_θ(1.0) is NaN. However, if we avoid the use of sqrt() by defining the loss function as
function loss(θ)
A = A₀.^(θ/2)
return sqrt(Flux.Losses.mse(A, A₀; agg=sum))
end
then the gradient gives the correct result.
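For reference, re-running the pullback with this second version (same A₀ and θ = 4.0 as above) gives a finite gradient, roughly 4.94, instead of NaN:
loss_θ, back_θ = Zygote.pullback(loss, θ)
back_θ(1.0)   # finite, roughly (4.94,), instead of NaN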
If you boil this down a little, your two loss functions are doing something like this (you could delete the sqrt(3^x) term here too):
julia> using Zygote
julia> withgradient(x -> sqrt(0^x) + sqrt(3^x), 4)
(val = 9.0, grad = (NaN,))
julia> withgradient(x -> 0^(x/2) + 3^(x/2), 4)
(val = 9.0, grad = (4.943755299006494,))
julia> let x = 4.001
sqrt(0^x) + sqrt(3^x)
end
9.004945113365242 # supports the 2nd answer
The reason you get NaN is that the slope of sqrt at zero is infinite. That infinity multiplies the slope of 0^x at 4, which is zero. With the 0^(x/2) version, by contrast, the slope is simply zero.
For an AD system to do better, I suppose it would need to keep track of how big an infinity the gradient of sqrt is… has anyone made such a thing?
julia> using ForwardDiff

julia> ForwardDiff.derivative(x -> sqrt(0^x) + sqrt(3^x), 4)
NaN
julia> ForwardDiff.derivative(x -> 0^(x/2) + 3^(x/2), 4)
4.943755299006494
julia> gradient(sqrt, 0) # Zygote
(Inf,)
julia> ForwardDiff.derivative(sqrt, 0) # often used in Zygote's broadcasting
Inf
julia> gradient(x -> 0^x, 4)
(0.0,)
julia> ForwardDiff.derivative(x -> 0^x, 4)
0
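Putting those pieces together: the chain rule multiplies the infinite slope of sqrt at zero by the zero slope of 0^x at 4, and in floating-point arithmetic that product is NaN:
julia> Inf * 0.0   # the product that ends up as the gradient
NaN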
Thanks for shedding some light on this, @mcabbott. For a library as flexible as Zygote, meant to provide AD for Julia source code, I think it is important to provide gradients for sqrt. It’s a pretty common function, so this bug will probably be encountered by a large number of users.
I have opened an issue following @facusapienza’s MWE.
Some even simpler examples might be:
julia> f(x) = inv(inv(x));
julia> g(x) = cbrt(x)^3;
julia> all(f(x)==x for x in -10:10)
true
julia> all(g(x)≈x for x in -10:10)
true
julia> gradient(f, 0)
(NaN,)
julia> gradient(g, 0)
(NaN,)
If the chain of functions (or rather, the chain of their derivatives) contains singularities at intermediate steps, then the final gradient will tend to be NaN, even when it’s obvious to a human that things ought to cancel.
The individual components all seem correct here. I think these examples could be fixed by returning incorrect gradients near singularities, e.g. replacing the gradient at x==0 with one slightly off of it, like gradient(cbrt, 0 + eps()). But this may have horrible consequences elsewhere, I’m not sure.
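To make that concrete, here is a minimal, purely hypothetical sketch (safe_cbrt and its adjoint are made up for illustration, not something Zygote actually does): a wrapper whose adjoint is evaluated slightly away from zero, so that Inf never enters the chain rule.
using Zygote

# Hypothetical wrapper, for illustration only
safe_cbrt(x) = cbrt(x)

Zygote.@adjoint function safe_cbrt(x)
    # evaluate the derivative at a point nudged away from the x == 0 singularity
    x′ = iszero(x) ? eps(typeof(float(x))) : x
    return cbrt(x), ȳ -> (ȳ / (3 * cbrt(x′)^2),)
end

gradient(x -> safe_cbrt(x)^3, 0.0)   # (0.0,): no NaN, but also not the true slope of 1
It avoids the NaN, but for g above it silently returns 0.0 where the true answer is 1, which is exactly the kind of consequence I’m not sure about.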