Unrecognized gradient using Zygote for AD with Universal Differential Equations

If you boil this down a little, your two loss functions are doing something like this (you could delete the sqrt(3^x) term too; it isn't needed to reproduce the NaN):

```julia
julia> using Zygote

julia> withgradient(x -> sqrt(0^x) + sqrt(3^x), 4)
(val = 9.0, grad = (NaN,))

julia> withgradient(x -> 0^(x/2) + 3^(x/2), 4)
(val = 9.0, grad = (4.943755299006494,))

julia> let x = 4.001
         sqrt(0^x) + sqrt(3^x)
       end
9.004945113365242  # (9.004945 - 9) / 0.001 ≈ 4.945, which supports the 2nd answer
```
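The `let` block is a one-sided difference with step 0.001. A slightly more careful check, as a sketch: a central difference (the step `h = 1e-6` is an arbitrary choice), compared against the analytic slope of the only term that survives near x = 4, since 0^x is identically zero for x > 0.

```julia
# Same function as the first loss above.
f(x) = sqrt(0^x) + sqrt(3^x)

# Central finite-difference estimate of f'(4); h is an arbitrary small step.
h = 1e-6
fd = (f(4 + h) - f(4 - h)) / (2h)

# Analytic value: near x = 4 the 0^x term is exactly zero, and
# d/dx 3^(x/2) = 3^(x/2) * log(3) / 2, which is 9 * log(3) / 2 at x = 4.
exact = 9 * log(3) / 2
```

Both `fd` and `exact` come out at roughly 4.9438, matching Zygote's answer for the 0^(x/2) version.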

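The reason you get NaN is that 0^4 == 0, so sqrt is evaluated at zero, where its slope is infinite. The chain rule multiplies that infinity by the slope of 0^x at x = 4, which is zero, and Inf * 0 is NaN in floating-point arithmetic. Whereas with the 0^(x/2) version there is no sqrt, and the slope of that term is simply zero.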
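Written out by hand, the product the chain rule asks for looks like this (plain Julia, no AD package involved; the inner slope of 0.0 is taken from what Zygote and ForwardDiff report for 0^x at 4):

```julia
# Outer factor: d/du sqrt(u) = 1 / (2 * sqrt(u)), evaluated at u = 0^4 = 0.
outer_slope = 1 / (2 * sqrt(0.0))   # 1 / 0.0 == Inf

# Inner factor: slope of 0^x at x = 4, which is zero.
inner_slope = 0.0

# Chain rule product for sqrt(0^x) at x = 4.
chain = outer_slope * inner_slope   # Inf * 0.0 is NaN
```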

For an AD system to do better, I suppose it would need to keep track of how big an infinity the gradient of sqrt is… has anyone made such a thing?

```julia
julia> using ForwardDiff

julia> ForwardDiff.derivative(x -> sqrt(0^x) + sqrt(3^x), 4)
NaN

julia> ForwardDiff.derivative(x -> 0^(x/2) + 3^(x/2), 4)
4.943755299006494

julia> gradient(sqrt, 0)  # Zygote
(Inf,)

julia> ForwardDiff.derivative(sqrt, 0)  # often used in Zygote's broadcasting
Inf

julia> gradient(x -> 0^x, 4)
(0.0,)

julia> ForwardDiff.derivative(x -> 0^x, 4)
0
```
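To see why forward mode hits the same wall, here is a toy dual number. This is a sketch only: `MiniDual`, `f1` and `f2` are hypothetical names, and this is not ForwardDiff's actual `Dual` type, just the same basic idea of carrying a value and a derivative together through each operation.

```julia
# Hypothetical minimal forward-mode dual number: value and derivative.
struct MiniDual
    val::Float64
    der::Float64
end

# sqrt rule: slope 1 / (2 * sqrt(u)) times the incoming derivative.
# At u = 0 the slope is Inf, so an incoming derivative of 0.0 gives Inf * 0.0 = NaN.
Base.sqrt(d::MiniDual) = MiniDual(sqrt(d.val), d.der * (1 / (2 * sqrt(d.val))))

# Constant base a >= 0 raised to a dual exponent: d/dx a^x = a^x * log(a).
# For a == 0 (and x > 0) the function is identically zero, so the slope is set to 0,
# mirroring the 0.0 that Zygote and ForwardDiff report for 0^x at 4.
function Base.:^(a::Real, d::MiniDual)
    y = float(a)^d.val
    slope = iszero(a) ? 0.0 : y * log(a)
    MiniDual(y, slope * d.der)
end

Base.:+(p::MiniDual, q::MiniDual) = MiniDual(p.val + q.val, p.der + q.der)
Base.:/(d::MiniDual, c::Real) = MiniDual(d.val / c, d.der / c)

# The two formulations from the question.
f1(x) = sqrt(0^x) + sqrt(3^x)
f2(x) = 0^(x/2) + 3^(x/2)

# Seed the derivative with 1.0 to get d/dx at x = 4.
r1 = f1(MiniDual(4.0, 1.0))   # der is NaN: the sqrt rule sees Inf * 0.0
r2 = f2(MiniDual(4.0, 1.0))   # der is about 4.9438
```

Even with the 0^x slope handled correctly, the NaN appears one step later inside the sqrt rule, which only sees a value of 0.0 and an incoming derivative of 0.0; it has no record of how fast its input was approaching zero.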