Factor of two in Zygote complex gradient

Either Zygote’s gradients for a scalar function with respect to a complex vector are off by a factor of two, I’ve been doing matrix algebra wrong for the last 10 years, or I’m just being stupid here. Consider the following example:

using LinearAlgebra
using Zygote

# Two random normalized complex state vectors |Ψ⟩, |Ψtgt⟩
N = 10
Ψ = rand(N) .* exp.(2π .* 1im .* rand(N)); Ψ ./= norm(Ψ)
Ψtgt = rand(N) .* exp.(2π .* 1im .* rand(N)); Ψtgt ./= norm(Ψtgt)

# Functional J = 1 - |⟨Ψtgt|Ψ⟩|² = 1 - ⟨Ψtgt|Ψ⟩⟨Ψ|Ψtgt⟩
J(Ψ, Ψtgt) = 1 - abs(Ψtgt ⋅ Ψ)^2

# Zygote gradient
∇1 = Zygote.gradient(ϕ -> J(ϕ, Ψtgt), Ψ)[1]

# Manual gradient: ∂/∂⟨Ψ| (1 - ⟨Ψtgt|Ψ⟩⟨Ψ|Ψtgt⟩) = - ⟨Ψtgt|Ψ⟩ |Ψtgt⟩
grad(Ψ, Ψtgt) =  -1 .* (Ψtgt ⋅ Ψ) .* Ψtgt
∇2 = grad(Ψ, Ψtgt)

norm(abs.(∇2 - ∇1))    # should be zero, but is not
norm(abs.(∇2 - ∇1/2))  # zero

I’m using the notation ⟨a|b⟩ ≡ a ⋅ b for the inner product (borrowed from quantum mechanics): |b⟩ is a (column) vector, and ⟨b| is its conjugate-transpose (row) vector, the “co-state”. Julia’s ⋅ (LinearAlgebra.dot) conjugates its first argument, so a ⋅ b really is ⟨a|b⟩.
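A quick check of that convention, assuming only the LinearAlgebra standard library. The vectors a and b are arbitrary illustrative values; the point is that a ⋅ b equals a' * b, i.e. the co-state ⟨a| applied to |b⟩.

```julia
using LinearAlgebra

# Two small complex vectors standing in for |a⟩ and |b⟩ (values chosen arbitrarily)
a = [1 + 2im, 3 - 1im]
b = [2 - 1im, 1im]

# dot (⋅) conjugates its first argument, so this is ⟨a|b⟩
braket = a ⋅ b   # -1 - 2im

# The same number written with the conjugate-transpose "co-state" ⟨a| = a'
braket == a' * b                # true
braket == sum(conj.(a) .* b)    # true
```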

I was always taught (and have taught others) that for a functional J(|Ψ⟩) you can take the derivatives by simply treating states and co-states as independent variables. I never actually checked entrywise that this derivative with respect to the co-state matches the Wirtinger derivative, but that was always the assumption.
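As a minimal sketch (not part of the original post), the co-state derivative can be checked entrywise against the Wirtinger derivative ∂J/∂z̄ₖ = (∂J/∂xₖ + 𝕚 ∂J/∂yₖ)/2 using central finite differences, so no autodiff is involved. The helper wirtinger_conj_fd and the step size h are hypothetical choices for illustration; only LinearAlgebra is needed.

```julia
using LinearAlgebra

# Same functional as in the question: J(Ψ) = 1 - |⟨Ψtgt|Ψ⟩|²
J(Ψ, Ψtgt) = 1 - abs(Ψtgt ⋅ Ψ)^2

# Hypothetical helper: ∂f/∂z̄ₖ = (∂f/∂xₖ + i ∂f/∂yₖ)/2, approximated by central
# finite differences in the real and imaginary part of each entry zₖ = xₖ + i yₖ
function wirtinger_conj_fd(f, z; h = 1e-6)
    g = zeros(ComplexF64, length(z))
    for k in eachindex(z)
        e = zeros(ComplexF64, length(z)); e[k] = 1   # unit vector in entry k
        dfdx = (f(z + h * e) - f(z - h * e)) / (2h)
        dfdy = (f(z + im * h * e) - f(z - im * h * e)) / (2h)
        g[k] = (dfdx + im * dfdy) / 2
    end
    return g
end

N = 10
Ψ = rand(N) .* exp.(2π .* 1im .* rand(N)); Ψ ./= norm(Ψ)
Ψtgt = rand(N) .* exp.(2π .* 1im .* rand(N)); Ψtgt ./= norm(Ψtgt)

# Co-state derivative from the question: -⟨Ψtgt|Ψ⟩ |Ψtgt⟩
∇manual = -(Ψtgt ⋅ Ψ) .* Ψtgt
∇fd = wirtinger_conj_fd(ϕ -> J(ϕ, Ψtgt), Ψ)

norm(∇fd - ∇manual)   # ≈ 0, up to finite-difference error
```

If the two agree, the manual gradient ∇2 in the question is the Wirtinger derivative, and Zygote’s ∇1 is twice that.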

So, am I doing something wrong here, or is there something wrong with Zygote?

I think your manual gradient is missing a factor of 2. The variable you’re differentiating with respect to shows up twice, so you need the product rule (or, equivalently, when looking at the absolute value squared, the chain rule through the square gives a 2).

http://www.matrixcalculus.org/ seems to agree (although it assumes everything is real):

[Screenshot of the matrixcalculus.org result, not preserved.]

In the real case I’d agree, but for a complex vector I don’t think that holds. In fact, the same issue already appears for a complex scalar: for J(z) = 1 - |z|² = 1 - z z̄, the Wirtinger derivative ∂J/∂z̄ is defined as (∂J/∂Re[z] + 𝕚 ∂J/∂Im[z])/2, see e.g. “Complex Derivatives, Wirtinger View and the Chain Rule” by Ekin Akyürek. With z = x + 𝕚y, that gives -(2x + 𝕚 2y)/2 = -z, which also matches the intuition that z and z̄ can simply be treated as independent variables. However, Zygote gives

J(z) = 1 - abs(z)^2
Zygote.gradient(z -> J(z), 1+1im)
(-2.0 - 2.0im,)
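The scalar numbers can be reproduced without Zygote. This is a sketch using central finite differences; the helpers partial_x and partial_y and the step h are hypothetical. The combination ∂J/∂x + 𝕚 ∂J/∂y gives the -2 - 2im above, and halving it gives the Wirtinger derivative -z.

```julia
# Central finite differences in the real and imaginary directions (h is an arbitrary step size)
partial_x(f, z; h = 1e-6) = (f(z + h) - f(z - h)) / (2h)
partial_y(f, z; h = 1e-6) = (f(z + im * h) - f(z - im * h)) / (2h)

J(z) = 1 - abs(z)^2
z0 = 1.0 + 1.0im

# ∂J/∂x + 𝕚 ∂J/∂y: the combination Zygote.gradient reported above for this real-valued J
g_gradient = partial_x(J, z0) + im * partial_y(J, z0)   # ≈ -2 - 2im, i.e. -2z

# Wirtinger derivative ∂J/∂z̄ = (∂J/∂x + 𝕚 ∂J/∂y)/2
g_wirtinger = g_gradient / 2                           # ≈ -1 - 1im, i.e. -z
```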

I then found that the Zygote manual uses exactly this example in its chapter on Complex Differentiation; I just hadn’t read it closely enough. The manual also shows how to define a Wirtinger derivative manually, and this works for the vector case as well. In my example script, I can add

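# Wirtinger derivative, following the Zygote manual's chapter on Complex Differentiation.
# Returns the pair (∂f/∂z, ∂f/∂z̄); for a vector Ψ the first entry is an adjoint (row) vector.
function wirtinger(func, Ψ)
    y, back = Zygote.pullback(func, Ψ)
    du = back(1.0)[1]    # pullback of the real unit seed
    dv = back(1.0im)[1]  # pullback of the imaginary unit seed
    return (du' + im*dv')/2, (du + im*dv)/2
end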

∇3 = wirtinger(Ψ -> J(Ψ, Ψtgt), Ψ)[2]
norm(abs.(∇3 - ∇2))    # zero, as it should be

which does the right thing.
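For intuition, here is a hypothetical finite-difference analogue of the wirtinger helper for a scalar f: ℂ → ℂ, returning (∂f/∂z, ∂f/∂z̄) with ∂/∂z = (∂/∂x - 𝕚 ∂/∂y)/2 and ∂/∂z̄ = (∂/∂x + 𝕚 ∂/∂y)/2. It is a sketch, not Zygote’s implementation, and the two test functions are arbitrary examples.

```julia
# Hypothetical finite-difference analogue of `wirtinger` (scalar input only)
function wirtinger_fd(f, z; h = 1e-6)
    dfdx = (f(z + h) - f(z - h)) / (2h)            # ∂f/∂x
    dfdy = (f(z + im * h) - f(z - im * h)) / (2h)  # ∂f/∂y
    return (dfdx - im * dfdy) / 2, (dfdx + im * dfdy) / 2
end

z0 = 1.0 + 2.0im

# Holomorphic f: ∂f/∂z is the ordinary derivative 6z + 2, and ∂f/∂z̄ vanishes
wirtinger_fd(z -> 3z^2 + 2z + 1, z0)   # ≈ (8 + 12im, 0)

# Non-holomorphic, real-valued f = z z̄: the pair is (z̄, z)
wirtinger_fd(abs2, z0)                 # ≈ (1 - 2im, 1 + 2im)
```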

I’m still a little confused about what plain Zygote.gradient does (I might have to re-read that chapter a few more times). In the scalar case, I’m pretty sure it computes

∂f(z)/∂z -> ∂f(x, y)/∂x + 𝕚 ∂f(x, y)/∂y

which seems trivially related to the correct Wirtinger derivative ∂f/∂z̄ (up to my missing factor of 2).
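Written out from the definition used earlier in the thread (a short derivation, not from the original posts), the quantity ∂f/∂x + 𝕚 ∂f/∂y is exactly twice ∂f/∂z̄, which for the scalar example reproduces Zygote’s -2 - 2im:

```latex
\begin{aligned}
\frac{\partial f}{\partial \bar z} &= \frac{1}{2}\Bigl(\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y}\Bigr)
\;\Longrightarrow\;
\frac{\partial f}{\partial x} + i\,\frac{\partial f}{\partial y} = 2\,\frac{\partial f}{\partial \bar z},\\
f(z) = 1 - z\bar z:\quad 2\,\frac{\partial f}{\partial \bar z} &= -2z = -2 - 2i \quad\text{at } z = 1 + i.
\end{aligned}
```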

On the other hand, Zygote.gradient seems to do a single backpropagation pass, whereas wirtinger does two, so I probably can’t use plain Zygote.gradient in general.


Ah I see, I did mentally skip over the Wirtinger part. To me it’s less natural to define two derivatives like that, but googling it, it seems pretty common and probably has its uses.

See also the Discourse thread “Taking Complex Autodiff Seriously in ChainRules” (post #62 by Mason) and the “Complex numbers” page of the ChainRules documentation (since Zygote is built on top of ChainRules).


Ok, I think I now understand the connection between Zygote.gradient and the Wirtinger derivative. It seems like I can in fact use a single call to Zygote.gradient, because my function J is real-valued, $J \in \mathbb{R}$ (whereas the above wirtinger function also handles the more general case $J \in \mathbb{C}$). Specifically:

  1. Zygote.gradient(func, Ψ) is equivalent to y, back = Zygote.pullback(func, Ψ); back(1.0)[1], according to the Zygote documentation.

  2. Zygote.pullback falls back to ChainRules.rrule (or, at least, is compatible with ChainRules.rrule).

  3. Given a function $J: \mathbb{C}^N \rightarrow \mathbb{R}$ and y, back = Zygote.pullback(J, z⃗₀), the pullback of the imaginary unit is zero, that is, dv = back(1im)[1] = 0 in the wirtinger function:

    • According to the ChainRules documentation, for a function $\mathbb{C} \rightarrow \mathbb{C}$ defined as $f(x+iy) = u(x,y) + i\,v(x,y)$, the pullback returned by rrule maps the cotangent $\Delta u + i\,\Delta v$ to $\Delta u \, \tfrac{\partial u}{\partial x} + \Delta v \, \tfrac{\partial v}{\partial x} + i \, \Bigl(\Delta u \, \tfrac{\partial u}{\partial y} + \Delta v \, \tfrac{\partial v}{\partial y} \Bigr)$. In our case, $\Delta u = 0$ and $\Delta v = 1$ (because we’re feeding in 1im), and $u = J$, $v = 0$ (because $J \in \mathbb{R}$). Thus the entire pullback is zero.
    • For a function of a vector argument, the same formula applies entrywise, with $z_k = x_k + i\,y_k$ in place of $z$, so back(1im)[1] is still the zero vector.

Thus, in the wirtinger function, du = Zygote.gradient(func, Ψ)[1] and dv = 0, so the two outputs are simply du'/2 and du/2. That is, ∂J/∂z̄ is Zygote’s gradient divided by two, which is exactly the factor I was observing.
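To see item 3 concretely, here is a sketch that emulates, with finite differences, the pullback formula quoted from the ChainRules documentation for a scalar f: ℂ → ℂ. fd_pullback is a hypothetical helper, not ChainRules’ or Zygote’s actual machinery; it only reproduces the formula so the zero imaginary-seed pullback and the resulting (du'/2, du/2) pair can be checked by hand.

```julia
# Hypothetical emulation of the quoted convention for f(x + iy) = u(x, y) + i v(x, y):
#   back(Δu + iΔv) = Δu ∂u/∂x + Δv ∂v/∂x + i (Δu ∂u/∂y + Δv ∂v/∂y)
function fd_pullback(f, z; h = 1e-6)
    dfdx = (f(z + h) - f(z - h)) / (2h)            # ∂u/∂x + i ∂v/∂x
    dfdy = (f(z + im * h) - f(z - im * h)) / (2h)  # ∂u/∂y + i ∂v/∂y
    function back(Δ)
        Δu, Δv = real(Δ), imag(Δ)
        return Δu * real(dfdx) + Δv * imag(dfdx) + im * (Δu * real(dfdy) + Δv * imag(dfdy))
    end
    return back
end

J(z) = 1 - abs(z)^2                 # real-valued, so v ≡ 0
back = fd_pullback(J, 1.0 + 1.0im)

du = back(1.0)     # ≈ -2 - 2im, the same number Zygote.gradient returned for this J
dv = back(1.0im)   # zero: Δu = 0 and v ≡ 0, so every term vanishes

# Plugging into the formulas of the wirtinger function
dJdz, dJdzbar = (du' + im * dv') / 2, (du + im * dv) / 2   # ≈ (-1 + 1im, -1 - 1im) = (-z̄, -z)
```

With v ≡ 0 the imaginary seed contributes nothing, so the pair reduces to (du'/2, du/2).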