In the real case, I'd agree, but for a complex vector, I don't think that holds anymore. In fact, the same goes for a complex scalar: for $J(z) = 1 - |z|^2 = 1 - z\bar{z}$, the Wirtinger derivative $\partial J/\partial \bar{z}$ is defined as $\frac{1}{2}\left(\frac{\partial J}{\partial \operatorname{Re}[z]} + i\,\frac{\partial J}{\partial \operatorname{Im}[z]}\right)$, see e.g. Complex Derivatives, Wirtinger View and the Chain Rule | Ekin Akyürek. So, with $z = x + iy$, that would be $-(2x + 2iy)/2 = -z$, which also matches the intuition that $z$ and $\bar{z}$ can simply be treated as independent variables. However, Zygote gives
```julia
using Zygote

J(z) = 1 - abs(z)^2
Zygote.gradient(z -> J(z), 1+1im)
# (-2.0 - 2.0im,)
```
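A package-free way to check the factor of 2 numerically: the sketch below approximates $\partial J/\partial x$ and $\partial J/\partial y$ by central differences. The helper `partials` and the step size `h` are my own, purely illustrative choices, and the sketch assumes (consistent with the output above) that Zygote's result for a real-valued `J` is the combination $\partial J/\partial x + i\,\partial J/\partial y$.

```julia
J(z) = 1 - abs(z)^2

# Central finite differences for the partial derivatives of f at z = x + iy
# (hypothetical helper, for illustration only)
function partials(f, z; h = 1e-6)
    fx = (f(z + h) - f(z - h)) / (2h)              # ∂f/∂x
    fy = (f(z + im * h) - f(z - im * h)) / (2h)    # ∂f/∂y
    return fx, fy
end

z = 1 + 1im
fx, fy = partials(J, z)
g = fx + im * fy   # ∂J/∂x + i ∂J/∂y, the combination Zygote printed
w = g / 2          # Wirtinger derivative ∂J/∂z̄ by the definition above

println(g)   # ≈ -2 - 2im, i.e. -2z
println(w)   # ≈ -1 - 1im, i.e. -z
```

Since `J` is quadratic, central differences are exact up to rounding here, so `g` and `w` agree with $-2z$ and $-z$ to many digits.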
I then found that the Zygote manual uses exactly this example in its chapter on Complex Differentiation; I just hadn't read it closely enough. The manual also gives the solution: manually defining a function for the Wirtinger derivatives. This indeed also works for the vector case. In my example script, I can add
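```julia
using Zygote, LinearAlgebra

# Wirtinger derivatives (∂f/∂z, ∂f/∂z̄), following the Zygote manual:
# pull back the cotangents 1 and im, then combine the two results
function wirtinger(func, Ψ)
    y, back = Zygote.pullback(func, Ψ)
    du = back(1.0)[1]     # pullback of the real unit cotangent
    dv = back(1.0im)[1]   # pullback of the imaginary unit cotangent
    return (du' + im*dv')/2, (du + im*dv)/2
end

# J, Ψ, Ψtgt and ∇2 are defined earlier in my example script
∇3 = wirtinger(Ψ -> J(Ψ, Ψtgt), Ψ)[2]   # ∂J/∂Ψ̄
norm(abs.(∇3 - ∇2))  # zero, as it should be
```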
which does the right thing.
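I'm still a little confused about what plain Zygote.gradient does (I might have to re-read that chapter a few more times). In the scalar case, I'm pretty sure it maps

$$\frac{\partial f(z)}{\partial z} \;\to\; \frac{\partial f(x, y)}{\partial x} + i\,\frac{\partial f(x, y)}{\partial y},$$

which is directly connected to the correct Wirtinger derivative: by the definition above, this is exactly $2\,\partial f/\partial \bar{z}$, which accounts for my missing factor of 2.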
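Writing the factor of 2 out for the example, using only the Wirtinger definition quoted at the top (my own short derivation, not taken from the manual):

$$
\frac{\partial J}{\partial x} + i\,\frac{\partial J}{\partial y}
= 2 \cdot \frac{1}{2}\left(\frac{\partial J}{\partial x} + i\,\frac{\partial J}{\partial y}\right)
= 2\,\frac{\partial J}{\partial \bar{z}}
= 2\,(-z) = -2z ,
$$

which at $z = 1 + i$ is $-2 - 2i$, the Zygote output. For a real-valued $J$, the other Wirtinger derivative is just the complex conjugate: $\frac{\partial J}{\partial z} = \frac{1}{2}\left(\frac{\partial J}{\partial x} - i\,\frac{\partial J}{\partial y}\right) = \overline{\partial J/\partial \bar{z}} = -\bar{z}$.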
On the other hand, Zygote.gradient seems to do a single backward pass, whereas wirtinger does two. So presumably I can't use plain Zygote.gradient in general.
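To make the two passes concrete, here is a package-free sketch of my current reading of them. The assumption (mine, not verified against Zygote's internals) is that `back(1)` yields the gradient of `real(f)`, packed as $\partial\,\mathrm{Re} f/\partial x + i\,\partial\,\mathrm{Re} f/\partial y$, and `back(im)` the same for `imag(f)`. The helpers `grad` and `wirtinger_fd` are hypothetical and use finite differences in place of reverse passes; the final combination is copied from `wirtinger` above.

```julia
# Gradient of a REAL-valued g at z = x + iy, packed as ∂g/∂x + i ∂g/∂y
# (central finite differences; hypothetical stand-in for one reverse pass)
function grad(g, z; h = 1e-6)
    gx = (g(z + h) - g(z - h)) / (2h)
    gy = (g(z + im * h) - g(z - im * h)) / (2h)
    return gx + im * gy
end

# Same combination as `wirtinger`, with the two "passes" replaced by the
# gradients of real(f) and imag(f) (assumed to match back(1) and back(im))
function wirtinger_fd(f, z)
    du = grad(w -> real(f(w)), z)   # stands in for back(1)
    dv = grad(w -> imag(f(w)), z)   # stands in for back(im)
    return (du' + im * dv') / 2, (du + im * dv) / 2   # (∂f/∂z, ∂f/∂z̄)
end

# Holomorphic example: ∂f/∂z = 6z + 2 = 8 + 12im at z = 1 + 2im, ∂f/∂z̄ = 0
dz, dzbar = wirtinger_fd(z -> 3z^2 + 2z + 1, 1 + 2im)

# Real-valued example: imag(J) ≡ 0, so dv vanishes
J(z) = 1 - abs(z)^2
dzJ, dzbarJ = wirtinger_fd(J, 1 + 1im)   # ≈ (-z̄, -z) = (-1 + 1im, -1 - 1im)
```

In the `J` case, `imag(J)` is identically zero, so in this sketch the first pass alone already fixes both Wirtinger derivatives; the second pass only matters for complex-valued functions.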