Automatic differentiation of complex matrix SVD fails

While automatic differentiation with Zygote works fine for matrices with Float64 elements, it fails for ComplexF64 elements:
```julia
using LinearAlgebra, Zygote

function svdtest(A)
    U, S, V = svd(A)
    a = S[1]
    return a
end

A = rand(ComplexF64, 10, 10)
svdtest(A)
gradient(A -> svdtest(A), A)
```
Is there a way around this?

Here is the error message I see:

ERROR: LoadError: Can't differentiate foreigncall expression

Interesting: it looks like LAPACK’s cgesvd isn’t referenced in svd.jl.

I see the same error message. Do you know how to interpret this? Is an SVD for a complex matrix simply not supported?

Complex SVD itself seems to be supported (since svdtest(A) works), but it may be implemented differently than the real SVD (your MWE does indeed work for Float64). And I would have expected to find a reference to cgesvd.

Edit: ah, I think I understand: maybe that is due to different storage formats in Julia and BLAS?
Edit: sorry, my misunderstanding. If I replace svd with a direct call to LAPACK

    # U,S,V = svd(A)
    U, S, V = LAPACK.gesvd!('A', 'A', A)

I run into the same error. Looks like the SVD for reals is somehow smarter?

Edit: Zygote has special rules for SVD and they are only implemented for reals IIUC.

Isn’t that simply the error you get if no AD rules are defined (e.g. via ChainRules.jl) for foreign (i.e. non-Julia) functions?
I didn’t check, but possibly the rules are defined for Real matrices only?
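
One quick way to check is to call rrule directly: it returns nothing when no rule matches the argument types, and a (result, pullback) pair when one does. A small sketch, assuming the SVD rule lives in ChainRules:

```julia
using LinearAlgebra, ChainRules

# rrule returns nothing when no rule is defined for these argument types.
ChainRules.rrule(svd, rand(Float64, 3, 3))      # a (result, pullback) pair if a rule exists
ChainRules.rrule(svd, rand(ComplexF64, 3, 3))   # nothing if the rule only covers Real matrices
```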


A simple hack seems to work around the problem:

```julia
using LinearAlgebra, Zygote, ChainRules

function ChainRules.rrule(::typeof(svd), X::AbstractMatrix{<:Complex})
    F = svd(X)
    svd_pullback(ȳ) = ChainRules._svd_pullback(ȳ, F)
    return F, svd_pullback
end

function svdtest(A)
    U, S, V = svd(A)
    a = S[1]
    return a
end

A = rand(ComplexF64, 10, 10)
svdtest(A)
gradient(A -> svdtest(A), A)
```

@Eriklw would you care to check and file an issue against ChainRules.jl?


Indeed, it is currently only defined for Real matrices in ChainRules.

The BackwardsLinalg.jl package extends it to complex matrices.

Note that it is more subtle than the rule defined by @goerch, see the references:

https://giggleliu.github.io/2019/04/02/einsumbp.html

Probably it would be good for someone to make a proper PR of the complex AD rule from BackwardsLinalg.jl to get it into ChainRules.


Sorry, I naively didn’t check whether this was a recent research question.

Edit: nevertheless, I did a quick gradient check:

```julia
using FiniteDifferences, LinearAlgebra, Zygote, ChainRules

function ChainRules.rrule(::typeof(svd), X::AbstractMatrix{<:Complex})
    F = svd(X)
    svd_pullback(ȳ) = ChainRules._svd_pullback(ȳ, F)
    return F, svd_pullback
end

function svdtest(A)
    U, S, V = svd(A)
    a = S[1]
    return a
end

A = rand(ComplexF64, 10, 10)
svdtest(A)
a1 = gradient(A -> svdtest(A), A)
a2 = grad(central_fdm(5, 1), A -> svdtest(A), A)

norm.(a2 .- a1)
```

yielding

(4.390263449925908e-12,)

Coincidence?

Does the BackwardsLinalg.jl package work for you?

I didn’t get it to work with the current Zygote; older versions errored out.

Please refer to the end of: [1909.02659] Automatic Differentiation for Complex Valued SVD

It says that the part of the complex SVD backpropagation formula that is unique to the complex case is zero if your loss function only depends on S (in fact, the loss has to depend on both U and V to test the complex backpropagation formula). Perhaps try a loss function that depends on U and V, such as the one they suggest in that paper.
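
For instance, something along these lines might do (a sketch of my own, not the exact loss from the paper; the abs2 calls keep it insensitive to the arbitrary phases of the singular vectors), compared against finite differences as above:

```julia
using FiniteDifferences, LinearAlgebra, Zygote

# Loss that depends on U and V as well as S, so the complex-specific
# part of the SVD pullback actually contributes.
function svdtest_uv(A)
    U, S, V = svd(A)
    return S[1] + abs2(U[1, 1]) + abs2(V[1, 1])
end

A = rand(ComplexF64, 10, 10)
a1 = gradient(svdtest_uv, A)
a2 = grad(central_fdm(5, 1), svdtest_uv, A)
norm.(a2 .- a1)
```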


I haven’t tested it in a long time, so I wouldn’t be surprised. Ideally the rule should be ported to ChainRules, but that shouldn’t be difficult to write in terms of svd_back in https://github.com/GiggleLiu/BackwardsLinalg.jl/blob/master/src/svd.jl.
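
An untested sketch of what such a port might look like, assuming BackwardsLinalg.svd_back takes (U, S, V, dU, dS, dV) and returns the cotangent of A as in that file, and that the incoming cotangent exposes U, S and V components (both assumptions would need checking against the actual package):

```julia
using LinearAlgebra, ChainRulesCore
import BackwardsLinalg

function ChainRulesCore.rrule(::typeof(svd), A::AbstractMatrix{<:Complex})
    F = svd(A)
    function svd_pullback(ΔF)
        Δ = ChainRulesCore.unthunk(ΔF)
        # Replace absent component cotangents with zeros of the right shape.
        orzero(x, like) = (x === nothing || x isa ChainRulesCore.AbstractZero) ? zero(like) : x
        dA = BackwardsLinalg.svd_back(F.U, F.S, F.V,
                                      orzero(Δ.U, F.U),
                                      orzero(Δ.S, F.S),
                                      orzero(Δ.V, F.V))
        return ChainRulesCore.NoTangent(), dA
    end
    return F, svd_pullback
end
```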