Unrecognized gradient using Zygote for AD with Universal Differential Equations

We’ve had this issue for over 2 weeks now (see the dates of the previous posts in this discussion). When was this new Zygote issue introduced? I’m surprised it was not brought up during the previous lengthy discussion.

That was yesterday. Okay, I’ll mark this to take a look after Flux is all fixed up, though with the JuliaCon rush I am a bit behind :sweat_smile:

Cheers Chris. I’ll keep investigating this. Meanwhile I hope @mcabbott or @darsnack can reproduce it with the new MWE and give some extra hints on what might be wrong :grinning_face_with_smiling_eyes: :pray:

So after investigating this issue a bit more, I suspect that Zygote has trouble computing the gradients of functions involving very small Float64 values (e.g. 2e-16). The back() function returned by the pullback produces NaN results when applied to the NN parameters.

Is there any way to make sure Zygote can compute gradients for operations involving very small floats? I have an alternative example using a heat equation without the super small float values and it works perfectly fine. Thanks again!

Is there a case you can isolate?

Alright, so after a very long time investigating this bug we have finally found the source. The problem appears to come from Zygote producing a NaN gradient for the sqrt function. More precisely, we were doing:

∇S = sqrt.(avg_y(dSdx).^2 .+ avg_x(dSdy).^2)  # this does not work
D = Γ .* avg(H).^(n + 2) .* ∇S.^(n - 1) 

So we isolated the issue to sqrt. When we changed the code to:

∇S² = avg_y(dSdx).^2 .+ avg_x(dSdy).^2    # this does work
D = Γ .* avg(H).^(n + 2) .* ∇S².^((n - 1)/2)

everything works perfectly.
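To see why rewriting the exponent helps, here is the chain rule worked by hand in plain Julia (no AD package, so these are hand-written derivatives, not Zygote's rules). It is a sketch under stated assumptions: the scalars `a` and `b` are hypothetical stand-ins for `avg_y(dSdx)` and `avg_x(dSdy)`, both set to exactly zero, and `n = 3` is an assumed exponent since the post does not give its value. Going through `sqrt` multiplies an infinite factor by zero, while the rewritten form never builds that infinity.

```julia
# Hypothetical scalar stand-ins: a ~ avg_y(dSdx), b ~ avg_x(dSdy).
# n = 3 is an assumed exponent for illustration only.
n = 3
a, b = 0.0, 0.0          # both slopes exactly zero
s = a^2 + b^2            # plays the role of ∇S² in the post
ds_da = 2a               # d(∇S²)/da = 0 at a = 0

# Original form: ∇S = sqrt(s), then ∇S^(n - 1), differentiated step by step.
∇S = sqrt(s)
d1 = (n - 1) * ∇S^(n - 2) * (1 / (2 * sqrt(s))) * ds_da   # 0 * Inf * 0 -> NaN

# Rewritten form: s^((n - 1)/2), differentiated in a single step.
d2 = ((n - 1) / 2) * s^((n - 1) / 2 - 1) * ds_da           # 1 * 1 * 0 -> 0.0
```

With `n = 3` the inner exponent `(n - 1)/2 - 1` is zero, so no infinite factor appears in the rewritten form; for smaller `n` it would reappear.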

Is this normal? We’re extremely surprised that Zygote cannot provide a gradient for such a simple function.

That looks worth isolating into a Zygote issue. I think it comes from the broadcast implementation.

Here is a minimal example where Zygote fails when trying to differentiate sqrt().

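using Zygote
using Flux

A₀ = [[1,0] [0,3]]   # 2×2 matrix [1 0; 0 3], built by concatenating two columns
A₁ = [[0,0] [0,0]]

function loss(θ)
    A = A₀.^θ        # elementwise power
    A = sqrt.(A)     # elementwise sqrt: this is the step that yields the NaN gradient
    return sqrt(Flux.Losses.mse(A, A₀; agg=sum))
end

θ = 4.0
loss_θ, back_θ = Zygote.pullback(loss, θ)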

In this case, the value of back_θ(1.0) is NaN. However, if we avoid the use of sqrt() by defining the loss function as

function loss(θ)
    A = A₀.^(θ/2)    # same values as sqrt.(A₀.^θ), but without calling sqrt
    return sqrt(Flux.Losses.mse(A, A₀; agg=sum))
end

then the gradient gives the correct result.
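As an AD-free sanity check that the loss really has a finite slope at θ = 4, a central finite difference in plain Julia can be compared against the closed form. This is a sketch under an assumption: `Flux.Losses.mse(A, A₀; agg=sum)` is written out by hand as the sum of squared differences, and `loss_fd` is a hypothetical name used only here.

```julia
# Same loss as the MWE, with mse(...; agg=sum) written out as a sum of squares.
A₀ = [[1,0] [0,3]]

function loss_fd(θ)
    A = sqrt.(A₀ .^ θ)
    return sqrt(sum(abs2, A .- A₀))
end

# Central finite difference: no chain rule, so no 0 * Inf products.
h = 1e-6
θ = 4.0
fd = (loss_fd(θ + h) - loss_fd(θ - h)) / (2h)

# Near θ = 4 the loss reduces to 3^(θ/2) - 3, whose slope at θ = 4 is 9 log(3) / 2.
exact = 9 * log(3) / 2
```

Both numbers agree to several digits, so NaN is an artifact of how the derivative is assembled, not a property of the loss.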

If you boil this down a little, your two loss functions are doing something like this (you could delete the sqrt(3^x) term here too):

julia> using Zygote

julia> withgradient(x -> sqrt(0^x) + sqrt(3^x), 4)
(val = 9.0, grad = (NaN,))

julia> withgradient(x -> 0^(x/2) + 3^(x/2), 4)
(val = 9.0, grad = (4.943755299006494,))

julia> let x = 4.001
         sqrt(0^x) + sqrt(3^x)
       end
9.004945113365242  # supports the 2nd answer

The reason you get NaN is that the slope of sqrt at zero is infinite. That infinity multiplies the slope of 0^x at 4, which is zero, and in floating point Inf * 0 is NaN. With the 0^(x/2) version there is no infinite factor, so the slope is simply zero.
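The mechanism can be reproduced without any AD package. In this sketch the derivative of sqrt is written by hand (`dsqrt` is a hypothetical helper, not Zygote's rule), and the slope of 0^x at x = 4 is hard-coded as zero because 0^x is identically zero for x > 0. It also shows that a tiny but nonzero input gives a large yet finite slope; only an exact zero produces Inf.

```julia
# Hand-written derivative of sqrt: d/du sqrt(u) = 1 / (2 sqrt(u)).
dsqrt(u) = 1 / (2 * sqrt(u))

x = 4.0
u = 0.0^x        # inner function 0^x, which is 0.0 here
du_dx = 0.0      # true slope of 0^x at x = 4 (hard-coded, not computed)

slope_sqrt = dsqrt(u)          # Inf: sqrt has an infinite slope at exactly zero
chain = dsqrt(u) * du_dx       # NaN: Inf * 0 is NaN in IEEE 754 arithmetic
tiny = dsqrt(2e-16)            # large but finite
```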

For an AD system to do better, I suppose it would need to keep track of how big an infinity the gradient of sqrt is… Has anyone made such a thing?

julia> using ForwardDiff

julia> ForwardDiff.derivative(x -> sqrt(0^x) + sqrt(3^x), 4)

julia> ForwardDiff.derivative(x -> 0^(x/2) + 3^(x/2), 4)

julia> gradient(sqrt, 0)  # Zygote

julia> ForwardDiff.derivative(sqrt, 0)  # often used in Zygote's broadcasting

julia> gradient(x -> 0^x, 4)

julia> ForwardDiff.derivative(x -> 0^x, 4)

Thanks for shedding some light on this @mcabbott. For a library as flexible as Zygote, meant to provide AD for Julia source code, I think it would be important to provide well-behaved gradients for sqrt. It’s a pretty common function, so this bug will probably be encountered by a large number of users.

I have opened an issue following @facusapienza’s MWE.

Some even simpler examples might be:

julia> using Zygote

julia> f(x) = inv(inv(x));
julia> g(x) = cbrt(x)^3;

julia> all(f(x)==x for x in -10:10)
julia> all(g(x)≈x for x in -10:10)

julia> gradient(f, 0)
julia> gradient(g, 0)

If the chain of functions (or rather, the chain of their derivatives) contains singularities at intermediate steps, then the final gradient will tend to be NaN, even if it’s obvious to a human that things ought to cancel.
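Here is the chain rule for both examples worked out by hand at x = 0, as a sketch: `dinv`, `dcbrt` and `dcube` are hypothetical helpers with textbook derivatives, not Zygote's actual rules, which may differ in detail. Each product pairs an infinity with a zero, so both come out NaN even though f and g are (essentially) the identity with slope 1. Away from the singularity the same chain is exact.

```julia
# Hand-written derivatives of the pieces (illustrative, not Zygote's rules).
dinv(x) = -1 / x^2              # d/dx inv(x)
dcbrt(x) = 1 / (3 * cbrt(x)^2)  # d/dx cbrt(x)
dcube(u) = 3 * u^2              # d/du u^3

x = 0.0

# f(x) = inv(inv(x)) == x. Chain rule: inv'(inv(x)) * inv'(x) = (-0.0) * (-Inf)
df = dinv(inv(x)) * dinv(x)     # NaN

# g(x) = cbrt(x)^3 ≈ x. Chain rule: cube'(cbrt(x)) * cbrt'(x) = 0.0 * Inf
dg = dcube(cbrt(x)) * dcbrt(x)  # NaN

# Away from zero the same chain gives the expected slope.
df_away = dinv(inv(2.0)) * dinv(2.0)   # 1.0
```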

The individual components all seem correct here. I think these examples could be fixed by returning incorrect gradients near singularities, e.g. replacing the gradient at x == 0 with one evaluated slightly off it, like gradient(cbrt, 0 + eps()). But I’m not sure; this may have horrible consequences elsewhere.