We’ve had this issue for over 2 weeks now (see the date of the previous posts in this discussion). When was this new Zygote issue introduced? I’m surprised it was not brought up during the previous lengthy discussion.
That was yesterday. Okay, I’ll mark this to take a look after Flux is all fixed up, though with the JuliaCon rush I’m a bit behind.
So after investigating this issue a bit more, I suspect that Zygote has trouble computing gradients of a function involving variables with very small Float64 values (e.g. 2e-16). The `back()` function returned by the pullback produces NaN results when applied to the NN parameters.
Is there any way to make sure Zygote can compute gradients for operations involving very small floats? I have an alternative example using a heat equation without the super small float values and it works perfectly fine. Thanks again!
Is there a case you can isolate?
Alright, so after a very long time investigating this bug we have finally found the source. The problem appears to come from Zygote producing a NaN gradient for the `sqrt` function. More precisely, we were doing:
```julia
∇S = sqrt.(avg_y(dSdx).^2 .+ avg_x(dSdy).^2) # this does not work
D = Γ .* avg(H).^(n + 2) .* ∇S.^(n - 1)
```
So we isolated the issue to `sqrt`. When we changed it to:
```julia
∇S² = avg_y(dSdx).^2 .+ avg_x(dSdy).^2 # this does work
D = Γ .* avg(H).^(n + 2) .* ∇S².^((n - 1)/2)
```
everything works perfectly.
Is this normal? We’re extremely surprised that Zygote cannot provide a gradient for such a simple function.
that looks worth isolating to a Zygote issue. I think it’s from the broadcast implementation.
Here is a minimal example where Zygote fails when trying to differentiate:

```julia
using Zygote
using Flux

A₀ = [[1,0] [0,3]]
A₁ = [[0,0] [0,0]]

function loss(θ)
    A = A₀.^θ
    A = sqrt.(A)
    return sqrt(Flux.Losses.mse(A, A₀; agg=sum))
end

θ = 4.0
loss_θ, back_θ = Zygote.pullback(loss, θ)
```
For this last case, the gradient returned by the pullback is NaN. However, if we avoid the use of `sqrt()` by defining the loss function as:
```julia
function loss(θ)
    A = A₀.^(θ/2)
    return sqrt(Flux.Losses.mse(A, A₀; agg=sum))
end
```
then the gradient gives the correct result.
If you boil this down a little, your two loss functions are doing something like this (you could delete the `sqrt(3^x)` term here too):
```julia
julia> using Zygote

julia> withgradient(x -> sqrt(0^x) + sqrt(3^x), 4)
(val = 9.0, grad = (NaN,))

julia> withgradient(x -> 0^(x/2) + 3^(x/2), 4)
(val = 9.0, grad = (4.943755299006494,))

julia> let x = 4.001
           sqrt(0^x) + sqrt(3^x)
       end
9.004945113365242  # supports the 2nd answer
```
The reason you get NaN is that the slope of `sqrt` at zero is infinite. That infinity multiplies the slope of `0^x` at 4, which is zero, and `Inf * 0` is `NaN`. Whereas with the `0^(x/2)` version, the slope is simply zero.
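This `Inf * 0` mechanism is just ordinary IEEE-754 arithmetic, not anything Zygote-specific; a minimal sketch of the chain-rule product (in Python here purely for illustration, the same float semantics hold in Julia):

```python
import math

# Chain rule for sqrt(0**x) at x = 4:
#   outer: d/du sqrt(u) = 1 / (2*sqrt(u)) -> +inf at u = 0
#   inner: d/dx 0**x                      -> 0   for x > 0
outer = math.inf
inner = 0.0
print(outer * inner)  # nan, since IEEE 754 defines inf * 0 as NaN
```

Any AD system that composes these two finite-looking rules numerically will hit the same NaN.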
For an AD system to do better, I suppose it would need to keep track of how big an infinity the gradient of `sqrt` is… has anyone made such a thing?
```julia
julia> ForwardDiff.derivative(x -> sqrt(0^x) + sqrt(3^x), 4)
NaN

julia> ForwardDiff.derivative(x -> 0^(x/2) + 3^(x/2), 4)
4.943755299006494

julia> gradient(sqrt, 0)  # Zygote
(Inf,)

julia> ForwardDiff.derivative(sqrt, 0)  # often used in Zygote's broadcasting
Inf

julia> gradient(x -> 0^x, 4)
(0.0,)

julia> ForwardDiff.derivative(x -> 0^x, 4)
0
```
Thanks for shedding some light on this, @mcabbott. For a library as flexible as Zygote, meant to provide AD for Julia source code, I think it is important to provide gradients for `sqrt`. It’s a pretty common function, so this bug will probably be encountered by a large number of users.
Some even simpler examples might be:
```julia
julia> f(x) = inv(inv(x));

julia> g(x) = cbrt(x)^3;

julia> all(f(x) == x for x in -10:10)
true

julia> all(g(x) ≈ x for x in -10:10)
true

julia> gradient(f, 0)
(NaN,)

julia> gradient(g, 0)
(NaN,)
```
If the chain of functions (or rather, the chain of their derivatives) contains singularities at intermediate steps, then the final gradient will tend to be NaN, even when it’s obvious to a human that things ought to cancel.
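Working the chain rule by hand for `f(x) = inv(inv(x))` at `x = 0` shows the cancellation failing; a sketch of the float arithmetic (in Python for illustration, with `inv(0)` written as `inf` directly, matching Julia’s `1/0.0 == Inf`):

```python
import math

# f(x) = inv(inv(x)), evaluated at x = 0:
u = math.inf          # intermediate value inv(0) -> +inf
inner = -math.inf     # d/dx 1/x = -1/x**2 -> -inf at x = 0
outer = -0.0          # d/du 1/u = -1/u**2 -> -0.0 at u = inf
print(inner * outer)  # nan, even though f(x) == x has slope exactly 1
```

Each factor is the correct one-sided limit; only their product is indeterminate.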
The individual components all seem correct here. I think these examples could be fixed by returning slightly incorrect gradients near singularities, e.g. replacing the gradient at `x == 0` with one evaluated slightly off it, like `gradient(cbrt, 0 + eps())`. But this may have horrible consequences elsewhere; I’m not sure.
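To illustrate the nudge idea: the analytic derivative of `cbrt`, `1/(3*x^(2/3))`, is infinite exactly at zero but large-and-finite one machine epsilon away, so downstream `Inf * 0` products would become `finite * 0 == 0` instead of NaN. A hypothetical sketch (Python, with `dcbrt` standing in for the AD rule):

```python
import sys, math

def dcbrt(x):
    """Analytic derivative of cbrt: 1 / (3 * x**(2/3)); +inf at x == 0."""
    return math.inf if x == 0 else 1.0 / (3.0 * x ** (2.0 / 3.0))

eps = sys.float_info.epsilon
print(dcbrt(0.0))        # inf -> becomes NaN once multiplied by a zero slope
print(dcbrt(0.0 + eps))  # huge but finite, so chain-rule products stay finite
```

The cost, as noted, is that the gradient reported at the singular point is then wrong by a large (if finite) amount.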