Factor of two in Zygote complex gradient

In the real case, I’d agree, but for a complex vector, I don’t think that holds anymore. In fact, the same issue shows up for a complex scalar: for J(z) = 1 - |z|² = 1 - z z̄, the Wirtinger derivative ∂J/∂z̄ is defined as (∂J/∂Re[z] + 𝕚 ∂J/∂Im[z])/2, see e.g. Complex Derivatives, Wirtinger View and the Chain Rule | Ekin Akyürek. So, with z = x + 𝕚y, that would be -(2x + 𝕚 2y)/2 = -z, which also matches the intuition that z and z̄ can simply be treated as independent variables. However, Zygote gives

```julia
using Zygote

J(z) = 1 - abs(z)^2
Zygote.gradient(z -> J(z), 1+1im)    # returns (-2.0 - 2.0im,)
```
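As a package-free cross-check of the two conventions, here is a minimal sketch that approximates the partial derivatives of J with central finite differences (the helper `partials` and the step `h` are my own, not part of Zygote). It compares the Wirtinger derivative ∂J/∂z̄ = (∂J/∂x + 𝕚 ∂J/∂y)/2 with the undivided combination ∂J/∂x + 𝕚 ∂J/∂y, which matches the value Zygote printed above.

```julia
# Central finite-difference partials of a real-valued function f at complex z,
# treating f as a function of x = real(z) and y = imag(z).
function partials(f, z; h = 1e-6)
    dfdx = (f(z + h) - f(z - h)) / (2h)
    dfdy = (f(z + im * h) - f(z - im * h)) / (2h)
    return dfdx, dfdy
end

J(z) = 1 - abs(z)^2

z = 1.0 + 1.0im
dfdx, dfdy = partials(J, z)

zygote_style = dfdx + im * dfdy           # ∂J/∂x + 𝕚 ∂J/∂y, approx. -2 - 2im
wirtinger_zbar = (dfdx + im * dfdy) / 2   # ∂J/∂z̄, approx. -z = -1 - 1im

println(zygote_style)
println(wirtinger_zbar)
```

The two printed values differ by exactly the factor of 2 in question.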

I then found that the Zygote manual uses exactly this example in its chapter on Complex Differentiation; I just hadn’t read it closely enough. The manual also gives a solution: manually defining a Wirtinger derivative. This works for the vector case, too. In my example script, I can add

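```julia
using Zygote, LinearAlgebra

# Wirtinger derivatives of `func` at Ψ via two pullback passes:
# back(1.0) corresponds to the gradient of the real part of func,
# back(1.0im) to the gradient of its imaginary part.
function wirtinger(func, Ψ)
    y, back = Zygote.pullback(func, Ψ)
    du = back(1.0)[1]
    dv = back(1.0im)[1]
    return (du' + im*dv')/2, (du + im*dv)/2   # (∂f/∂z, ∂f/∂z̄)
end

∇3 = wirtinger(Ψ -> J(Ψ, Ψtgt), Ψ)[2]   # ∂J/∂Ψ̄
norm(∇3 - ∇2)    # zero, as it should be
```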

which does the right thing.
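The same two-pass construction can be mimicked without Zygote, which makes it easier to see why two passes are needed once the function itself is complex-valued. In this sketch, finite differences stand in for Zygote's pullbacks, and the names `realgrad` and `wirtinger_fd` are hypothetical. `du` is the gradient of Re f and `dv` the gradient of Im f, each in the ∂/∂x + 𝕚 ∂/∂y convention, mirroring the roles of `back(1.0)[1]` and `back(1.0im)[1]` in the `wirtinger` function above.

```julia
# Gradient of a real-valued g at complex z in the ∂g/∂x + 𝕚 ∂g/∂y convention,
# approximated with central finite differences.
function realgrad(g, z; h = 1e-6)
    dgdx = (g(z + h) - g(z - h)) / (2h)
    dgdy = (g(z + im * h) - g(z - im * h)) / (2h)
    return dgdx + im * dgdy
end

# Finite-difference analogue of the two-pass `wirtinger`:
# returns (∂f/∂z, ∂f/∂z̄) for a possibly complex-valued scalar function f.
function wirtinger_fd(f, z)
    du = realgrad(w -> real(f(w)), z)   # gradient of Re f
    dv = realgrad(w -> imag(f(w)), z)   # gradient of Im f
    return (conj(du) + im * conj(dv)) / 2, (du + im * dv) / 2
end

z = 1.0 + 2.0im

# Holomorphic example: f(z) = z^2 has ∂f/∂z = 2z and ∂f/∂z̄ = 0.
dz, dzbar = wirtinger_fd(w -> w^2, z)

# Real-valued example: J(z) = 1 - |z|^2 has ∂J/∂z = -z̄ and ∂J/∂z̄ = -z.
dJz, dJzbar = wirtinger_fd(w -> 1 - abs(w)^2, z)
```

For z², dropping `dv` would give a wrong ∂f/∂z̄, which is why a single pass is not enough for complex-valued functions.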

I’m still a little confused about what plain Zygote.gradient does (I might have to re-read that chapter a few more times). In the scalar case, I’m pretty sure it does

gradient(f, z) -> ∂f(x, y)/∂x + 𝕚 ∂f(x, y)/∂y,   with z = x + 𝕚y

which is directly related to the correct Wirtinger derivative: since ∂/∂z̄ = (∂/∂x + 𝕚 ∂/∂y)/2, it is exactly 2 ∂f/∂z̄, i.e., my missing factor of 2.
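Spelling out the definitions shows where the factor comes from. With z = x + 𝕚y, the Wirtinger operators are the two averages below, so the undivided combination ∂f/∂x + 𝕚 ∂f/∂y is exactly 2 ∂f/∂z̄; for J(z) = 1 - z z̄ at z = 1 + 𝕚 this reproduces Zygote's -2 - 2𝕚:

```latex
\begin{aligned}
\frac{\partial}{\partial z} &= \frac{1}{2}\left(\frac{\partial}{\partial x} - i\,\frac{\partial}{\partial y}\right),
\qquad
\frac{\partial}{\partial \bar z} = \frac{1}{2}\left(\frac{\partial}{\partial x} + i\,\frac{\partial}{\partial y}\right), \\
\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y} &= 2\,\frac{\partial f}{\partial \bar z}, \\
J(z) = 1 - z\bar z:\quad 2\,\frac{\partial J}{\partial \bar z} &= -2z = -2 - 2i \quad \text{at } z = 1 + i.
\end{aligned}
```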

On the other hand, Zygote.gradient seems to do a single backward pass, whereas wirtinger does two (seeded with 1.0 and 1.0im). So I probably can’t use plain Zygote.gradient in general.
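Whether the second pass matters depends on the function: for a real-valued loss, Im J is identically zero, so `dv` vanishes and ∂J/∂Ψ̄ is just half of the one-pass gradient; for complex-valued functions both passes are needed. Here is a package-free sketch of the vector case, assuming a hypothetical fidelity-style loss J(Ψ) = 1 - |⟨ϕ, Ψ⟩|² with a fixed target ϕ (not the loss from my script), for which ∂J/∂Ψ̄ = -⟨ϕ, Ψ⟩ ϕ. Finite differences stand in for Zygote, and `realgrad_vec` is a hypothetical helper.

```julia
using LinearAlgebra

# Componentwise gradient of a real-valued g at a complex vector Ψ, in the
# ∂g/∂x_k + 𝕚 ∂g/∂y_k convention, approximated with central finite differences.
function realgrad_vec(g, Ψ::AbstractVector{<:Complex}; h = 1e-6)
    n = length(Ψ)
    grad = zeros(ComplexF64, n)
    for k in 1:n
        e = zeros(ComplexF64, n)
        e[k] = 1                      # unit perturbation of component k
        dgdx = (g(Ψ + h * e) - g(Ψ - h * e)) / (2h)
        dgdy = (g(Ψ + im * h * e) - g(Ψ - im * h * e)) / (2h)
        grad[k] = dgdx + im * dgdy
    end
    return grad
end

ϕ = ComplexF64[0.6, 0.8im, 0.0]            # hypothetical fixed target state
Ψ = ComplexF64[0.5 + 0.5im, 0.5, -0.5im]   # point at which we differentiate

# Hypothetical real-valued loss: J(Ψ) = 1 - |⟨ϕ, Ψ⟩|², with ⟨ϕ, Ψ⟩ = dot(ϕ, Ψ).
J(Ψ) = 1 - abs2(dot(ϕ, Ψ))

du = realgrad_vec(v -> real(J(v)), Ψ)   # one-pass, Zygote-style gradient
dv = realgrad_vec(v -> imag(J(v)), Ψ)   # second pass: Im J is zero, so dv is zero

dJdΨbar = (du + im * dv) / 2            # Wirtinger ∂J/∂Ψ̄
expected = -dot(ϕ, Ψ) * ϕ               # analytic ∂J/∂Ψ̄ = -⟨ϕ, Ψ⟩ ϕ
```

Under these assumptions, `du / 2` alone already equals `expected`, so for real-valued losses the single pass plus the factor 1/2 would suffice.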
