While automatic differentiation with Zygote works fine for matrices with Float64 elements, it fails for ComplexF64 elements:
```
using LinearAlgebra, Zygote

function svdtest(A)
    U, S, V = svd(A)
    a = S[1]   # largest singular value
    return a
end

A = rand(ComplexF64, 10, 10)
svdtest(A)
gradient(A -> svdtest(A), A)
```
Is there a way around this?
Complex SVD seems to be supported (because svdtest(A) works), but maybe it is implemented differently from real SVD (your MWE does indeed work for Float64). I would also have expected to find a reference to cgesvd.
Edit: ah, I think I understand: maybe that is due to the different storage formats of Julia and BLAS?
Edit: sorry, my misunderstanding. If I replace svd with a direct call to LAPACK:
```
# U,S,V = svd(A)
U, S, Vt = LAPACK.gesvd!('A', 'A', A)   # note: gesvd! returns Vᴴ (not V) and overwrites A
```
I run into the same error. It looks like the SVD for reals is somehow smarter?
Edit: IIUC, Zygote has special rules for SVD, and they are only implemented for reals.
Isn’t that simply the error you get if no AD rules are defined (e.g. via ChainRules.jl) for foreign (i.e. non-Julia) functions?
I didn’t check, but possibly the rules are defined for Reals only?
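```
using LinearAlgebra, Zygote, ChainRules

# Workaround: reuse ChainRules' SVD pullback for complex matrices.
# Caveat: ChainRules._svd_pullback is an internal (underscore-prefixed)
# function and may change between ChainRules versions.
function ChainRules.rrule(::typeof(svd), X::AbstractMatrix{<:Complex})
    F = svd(X)
    svd_pullback(ȳ) = ChainRules._svd_pullback(ȳ, F)
    return F, svd_pullback
end

function svdtest(A)
    U, S, V = svd(A)
    a = S[1]
    return a
end

A = rand(ComplexF64, 10, 10)
svdtest(A)
gradient(A -> svdtest(A), A)
```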
@Eriklw would you care to check and file an issue against ChainRules.jl?
Sorry, I naively didn’t check whether this was a recent research question.
Edit: nevertheless, I did a quick gradient check:
```
using FiniteDifferences, LinearAlgebra, Zygote, ChainRules

function ChainRules.rrule(::typeof(svd), X::AbstractMatrix{<:Complex})
    F = svd(X)
    svd_pullback(ȳ) = ChainRules._svd_pullback(ȳ, F)
    return F, svd_pullback
end

function svdtest(A)
    U, S, V = svd(A)
    a = S[1]
    return a
end

A = rand(ComplexF64, 10, 10)
svdtest(A)
a1 = gradient(A -> svdtest(A), A)                  # reverse mode via the rrule above
a2 = grad(central_fdm(5, 1), A -> svdtest(A), A)   # 5-point central finite differences
norm.(a2 .- a1)                                    # error per argument (small if the gradients agree)
```
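As an independent reference for this particular loss: for a simple (non-repeated) largest singular value, dσ₁ = Re(u₁ᴴ dA v₁). Assuming Zygote's convention of returning ∂f/∂Re(A) + i·∂f/∂Im(A) for a real-valued f, the gradient of S[1] should therefore be u₁v₁ᴴ. A minimal sketch that checks this closed form against hand-rolled central differences, using only the LinearAlgebra and Random stdlibs (the helper names top_singular_value, analytic_grad and fd_grad are hypothetical):

```julia
using LinearAlgebra, Random

Random.seed!(0)

# Loss from the thread: the largest singular value of A.
top_singular_value(A) = svdvals(A)[1]

# Closed form for a simple largest singular value:
# dσ₁ = Re(u₁ᴴ dA v₁), so with the convention ∂f/∂Re(A) + i ∂f/∂Im(A)
# the gradient is the outer product u₁ v₁ᴴ.
function analytic_grad(A)
    F = svd(A)
    return F.U[:, 1] * F.V[:, 1]'
end

# Central finite differences on the real and imaginary part of each entry.
function fd_grad(f, A; h = 1e-6)
    G = zeros(ComplexF64, size(A))
    for idx in eachindex(A)
        E = zeros(ComplexF64, size(A))
        E[idx] = 1
        dre = (f(A + h * E) - f(A - h * E)) / (2h)
        dim = (f(A + im * h * E) - f(A - im * h * E)) / (2h)
        G[idx] = dre + im * dim
    end
    return G
end

A = rand(ComplexF64, 10, 10)
G_analytic = analytic_grad(A)
G_fd = fd_grad(top_singular_value, A)
println(norm(G_analytic - G_fd))
```

Note that u₁v₁ᴴ only involves the product of the two singular vectors, so it is unchanged if u₁ and v₁ are multiplied by a common phase.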
It says that the part of the complex SVD backpropagation formula that is unique to the complex case vanishes if your loss function depends only on S (to test the complex backpropagation formula, the loss has to depend on both U and V). Perhaps try a loss function that depends on U and V, like the one suggested in that paper.
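One way to see why U and V have to enter together: for complex A, each pair of singular vectors (uₖ, vₖ) is only determined up to a common phase e^{iφₖ}, and, roughly speaking, the complex-only term is tied to this phase freedom. A loss of U alone is not even a well-defined function of A, while losses of S alone, or of U and V through the product U·Vᴴ, are. The sketch below uses only the LinearAlgebra and Random stdlibs; the loss functions are hypothetical illustrations, not the ones from that paper.

```julia
using LinearAlgebra, Random

Random.seed!(1)
A = rand(ComplexF64, 6, 6)
F = svd(A)
U, S, V = F.U, F.S, F.V

# Multiply each pair (uₖ, vₖ) by the same random phase e^{iφₖ}.
# This gives another valid SVD of the same matrix A.
D = Diagonal(cis.(2π .* rand(length(S))))
U2, V2 = U * D, V * D

# Hypothetical example losses (illustrative names, not from any package).
loss_S(U, S, V)  = sum(S)              # depends on S only
loss_UV(U, S, V) = real(sum(U * V'))   # depends on U and V via U*Vᴴ (phases cancel)
loss_U(U, S, V)  = real(sum(U))        # depends on U alone (phases do not cancel)

@show U2 * Diagonal(S) * V2' ≈ A       # still a factorization of A
@show loss_S(U, S, V) ≈ loss_S(U2, S, V2)
@show loss_UV(U, S, V) ≈ loss_UV(U2, S, V2)
@show loss_U(U, S, V) ≈ loss_U(U2, S, V2)
```

Only the first three comparisons are guaranteed to hold; the last one compares values of a loss that depends on an arbitrary phase choice.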