Is there a differentiable implementation of matrix square root?

Hello!

I’m trying to make the following code work:

using ForwardDiff, LinearAlgebra

# Fidelity between two density matrices
function fidelity(ρ::AbstractMatrix, σ::AbstractMatrix)
    sqrt_ρ = sqrt(ρ)
    abs2(tr(sqrt(sqrt_ρ * σ * sqrt_ρ)))
end

# Matrix representation of a Bloch vector
function matrix_representation(r)
    [(1+r[1]) (r[2]-r[3]im); (r[2]+r[3]im) (1-r[1])] ./ 2
end

# Fidelity between a density matrix and a Bloch vector
function fidelity(ρ::AbstractMatrix, r::AbstractVector)
    fidelity(ρ, matrix_representation(r))
end

# Gradient of the fidelity
function ∇fidelity(ρ::AbstractMatrix, r::AbstractVector)
    ForwardDiff.gradient(r -> fidelity(ρ, r), r)
end

r = [0, 0, 0]
ρ = matrix_representation(r)
∇fidelity(ρ, r)

In it, I attempt to calculate the gradient of the fidelity between two quantum states, differentiating with respect to the Bloch vector of one of them. Unfortunately, I get the error

ERROR: MethodError: no method matching eigen!(::Hermitian{Complex{ForwardDiff.Dual{…}}, Matrix{Complex{…}}}; sortby::Nothing)

which I believe means that the sqrt(::AbstractMatrix) method is not differentiable by ForwardDiff.

Could someone point me to a differentiable implementation of such a method, or propose a workaround?

Note: this 2×2 case is only an example; I actually need a method that works in arbitrary dimensions.

Thanks!

You can take a look at Enzyme.jl or DifferentiableFactorizations.jl, although I’m unsure how either behaves with complex input.

Since you only need a 2x2 matrix square root, you can use the analytical formula, which should be differentiable by ForwardDiff etc. This will be much more efficient anyway than forming a generic (albeit Hermitian) matrix and taking the square root (via eigenvalues).
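
For reference, here is a minimal sketch of that analytical formula (via Cayley-Hamilton), assuming the input is positive semidefinite (as a density matrix is) so that the principal scalar square roots are the right branches; the helper name sqrt2x2 is just for illustration:

using LinearAlgebra

# Principal square root of a 2×2 PSD matrix: if R = sqrt(A), then
# Cayley-Hamilton gives A + det(R)*I = tr(R)*R, with
# det(R) = sqrt(det(A)) and tr(R) = sqrt(tr(A) + 2*det(R)).
function sqrt2x2(A::AbstractMatrix)
    s = sqrt(det(A))       # det(sqrt(A))
    t = sqrt(tr(A) + 2s)   # tr(sqrt(A))
    (A + s * I) / t
end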

Actually, this 2x2 formula is already included in StaticArrays.jl, so you can just use an SMatrix — which you should probably be using anyway for such small fixed-size matrices — and it should work (and be much faster).

Changing your code to:

using StaticArrays
function matrix_representation(r)
    @SMatrix[(1+r[1]) (r[2]-r[3]im); (r[2]+r[3]im) (1-r[1])] ./ 2
end

gives

julia> ∇fidelity(ρ, [0,0,0])
3-element Vector{Float64}:
 0.0
 0.0
 0.0

julia> ∇fidelity(ρ, [0.1,0.2,0.3])
3-element Vector{Float64}:
 -0.053916386601719206
 -0.10783277320343838
 -0.1617491598051576

which matches a finite-difference check:

julia> r = [0.1,0.2,0.3]; dr = randn(3) * 1e-8;

julia> isapprox(fidelity(ρ,r+dr) - fidelity(ρ,r), ∇fidelity(ρ,r) ⋅ dr, rtol=1e-5)
true

:cherry_blossom::cherry_blossom::cherry_blossom::rabbit::egg::basket::rabbit::egg::basket::rabbit::egg::basket::rabbit::egg::basket::cherry_blossom::cherry_blossom::cherry_blossom:

Only after writing the post did I realize that I hadn’t made it explicit that the 2×2 case was only an example, and that I actually need this to work in arbitrary dimensions. I thought I had edited it to add a note saying so, but I must have forgotten to save the edit. Anyway, thank you for the response; I didn’t know that StaticArrays had those optimized methods! Happy Easter!! :rabbit: :rabbit: :rabbit: :egg: :egg: :egg:

If you are in high dimensions (with correspondingly many parameters), then you probably don’t want forward-mode AD (à la ForwardDiff.jl), since in that case the cost of the gradient scales as the function cost times the number of parameters. Instead, you want reverse-mode AD (à la Zygote.jl, ReverseDiff.jl, or Enzyme.jl), where the cost of the gradient scales with the function cost alone, independent of the number of parameters.

ChainRules.jl (used by Zygote.jl) already has a rule for differentiating the sqrt of a Hermitian matrix. (Don’t forget to wrap your matrix in Hermitian.)
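
For example, here is a minimal (untested) sketch of how the same gradient might look with Zygote.jl, using explicit Hermitian wrappers so that the ChainRules rule for the Hermitian square root is used; the name fidelity_h is just illustrative, and it reuses matrix_representation and ρ from above:

using Zygote, LinearAlgebra

# Fidelity with explicit Hermitian wrappers, so that the ChainRules rule
# for sqrt(::Hermitian) applies during reverse-mode differentiation.
function fidelity_h(ρ, r)
    σ = matrix_representation(r)
    sqrt_ρ = sqrt(Hermitian(ρ))
    abs2(tr(sqrt(Hermitian(sqrt_ρ * σ * sqrt_ρ))))
end

Zygote.gradient(r -> fidelity_h(ρ, r), [0.1, 0.2, 0.3])[1]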

See here for a discussion of Complex numbers in Enzyme: FAQ · Enzyme.jl