Either Zygote’s gradients for a scalar function with respect to a complex vector are off by a factor of two, I’ve been doing matrix algebra wrong for the last 10 years, or I’m just being stupid here. Consider the following example:
using LinearAlgebra
using Zygote
# Two random normalized complex state vectors |Ψ⟩, |Ψtgt⟩
N = 10
Ψ = rand(N) .* exp.(2π .* 1im .* rand(N)); Ψ ./= norm(Ψ)
Ψtgt = rand(N) .* exp.(2π .* 1im .* rand(N)); Ψtgt ./= norm(Ψtgt)
# Functional J = 1 - |⟨Ψtgt|Ψ⟩|² = 1 - ⟨Ψtgt|Ψ⟩⟨Ψ|Ψtgt⟩
J(Ψ, Ψtgt) = 1 - abs(Ψtgt ⋅ Ψ)^2
# Zygote gradient
∇1 = Zygote.gradient(ϕ -> J(ϕ, Ψtgt), Ψ)[1]
# Manual gradient: ∂/∂⟨Ψ| (1 - ⟨Ψtgt|Ψ⟩⟨Ψ|Ψtgt⟩) = - ⟨Ψtgt|Ψ⟩ |Ψtgt⟩
grad(Ψ, Ψtgt) = -1 .* (Ψtgt ⋅ Ψ) .* Ψtgt
∇2 = grad(Ψ, Ψtgt)
norm(abs.(∇2 - ∇1)) # should be zero, but is not
norm(abs.(∇2 - ∇1/2)) # zero
I’m using the notation ⟨a|b⟩ ≡ a ⋅ b for the inner product here (borrowed from quantum mechanics: |b⟩ is a (column) vector, and ⟨b| would be the conjugate-transpose (row) vector, the “co-state”).
I was always taught (and have taught others) that if you have a functional J(|Ψ⟩), you can do the derivatives by simply treating states and co-states as independent variables. I actually never bothered to check entrywise that this derivative w.r.t. the co-state matches the Wirtinger derivative, but that was always the assumption.
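For what it’s worth, that entrywise check can be sketched with plain finite differences, without involving Zygote at all. The helper `fd_partials` and the step size `h` below are made up for this illustration (they are not part of any library). The sketch compares the manual co-state derivative against two candidates: the Wirtinger derivative with respect to the conjugate, ∂J/∂Ψₖ* = ½ (∂J/∂Re Ψₖ + i ∂J/∂Im Ψₖ), and the same combination without the factor ½.

```julia
using LinearAlgebra

# Same functional as in the example above
J(Ψ, Ψtgt) = 1 - abs(Ψtgt ⋅ Ψ)^2

# Hypothetical helper (not a library function): central finite differences
# of a real-valued f with respect to the real and imaginary part of each entry of Ψ
function fd_partials(f, Ψ; h=1e-6)
    g_re = zeros(length(Ψ))
    g_im = zeros(length(Ψ))
    for k in eachindex(Ψ)
        e = zeros(ComplexF64, length(Ψ)); e[k] = 1
        g_re[k] = (f(Ψ + h * e) - f(Ψ - h * e)) / (2h)
        g_im[k] = (f(Ψ + h * im * e) - f(Ψ - h * im * e)) / (2h)
    end
    return g_re, g_im
end

N = 10
Ψ = rand(N) .* exp.(2π .* 1im .* rand(N)); Ψ ./= norm(Ψ)
Ψtgt = rand(N) .* exp.(2π .* 1im .* rand(N)); Ψtgt ./= norm(Ψtgt)

g_re, g_im = fd_partials(ϕ -> J(ϕ, Ψtgt), Ψ)

# Wirtinger derivative w.r.t. the conjugate: ½ (∂/∂Re + i ∂/∂Im)
wirtinger = (g_re .+ im .* g_im) ./ 2
# The same combination without the factor ½
stacked = g_re .+ im .* g_im

# Manual co-state derivative from above
manual = -(Ψtgt ⋅ Ψ) .* Ψtgt

norm(wirtinger - manual)     # small, up to finite-difference error
norm(stacked - 2 .* manual)  # small as well
```

So the manual derivative does match the Wirtinger derivative entrywise, and the combination ∂J/∂Re Ψₖ + i ∂J/∂Im Ψₖ is exactly twice as large. The sketch does not settle which of the two Zygote intends to return; that is a question of convention.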
So, am I doing something wrong here, or is there something wrong with Zygote?