Using Flux to optimize a function of the Singular Values

Here’s how you can get the gradient working:

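```julia
using Flux, Zygote, ForwardDiff
# Loading GenericLinearAlgebra adds svdvals methods for generic element types,
# including the ForwardDiff.Dual numbers used below.
using GenericLinearAlgebra: svdvals

# Run svdvals in forward mode (ForwardDiff) inside Zygote's reverse pass
svdvals2(X) = Zygote.forwarddiff(svdvals, X)

m = 20
X = randn(m, m)
model = Dense(m, m)            # not used by the loss below
loss(X) = sum(svdvals2(X))     # nuclear norm ‖X‖_*

loss(X)

gradient(loss, X)              # returns a 1-tuple; gradient(loss, X)[1] is the m×m gradient
```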
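For the nuclear norm there is a closed form you can use to sanity-check the AD result: if X is square and invertible with SVD X = U Σ Vᵀ, then each singular value satisfies dσᵢ = uᵢᵀ dX vᵢ, so the gradient of sum(σᵢ) is U Vᵀ. The sketch below checks that formula against central finite differences using only the LinearAlgebra standard library; `nuclear_norm_grad` and `fd_grad` are hypothetical helper names, not part of any package.

```julia
using LinearAlgebra
using Random

Random.seed!(1)

# Hypothetical helper: closed-form gradient of sum(svdvals(X)).
# Assumes X is square and invertible, so U * V' is uniquely defined.
function nuclear_norm_grad(X::AbstractMatrix)
    F = svd(X)
    return F.U * F.Vt
end

# Hypothetical helper: central finite differences, one entry at a time.
# O(length(X)) SVDs, so only suitable for small checks.
function fd_grad(f, X::AbstractMatrix; h = 1e-6)
    G = zeros(size(X))
    for i in eachindex(X)
        Xp = copy(X); Xp[i] += h
        Xm = copy(X); Xm[i] -= h
        G[i] = (f(Xp) - f(Xm)) / (2h)
    end
    return G
end

nuc(X) = sum(svdvals(X))   # nuclear norm

X = randn(5, 5)
G_exact = nuclear_norm_grad(X)
G_fd = fd_grad(nuc, X)
err = maximum(abs.(G_exact .- G_fd))   # expected to be tiny (finite-difference error only)
println(err)
```

With the setup from the post loaded, `gradient(loss, X)[1] ≈ nuclear_norm_grad(X)` should hold for a random square X.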

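This uses `Zygote.forwarddiff` to get the `svdvals` gradient with forward-mode AD, which works because the GenericLinearAlgebra `svdvals` is generic Julia code that accepts dual numbers. Forward mode scales with the number of input entries (m² here), so this gets slow for large matrices. It'd be nice to have direct reverse-mode support for this, so I opened ChainRules issues for SVD and svdvals.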
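Until such a rule exists, a direct reverse-mode rule can be sketched with ChainRulesCore, which Zygote picks up automatically. This is a minimal sketch, not the rule ChainRules ships or will ship: it uses a hypothetical wrapper `my_svdvals` (to avoid overriding `LinearAlgebra.svdvals` itself) and assumes real input with distinct, nonzero singular values, where dσᵢ = uᵢᵀ dX vᵢ gives the pullback X̄ = U Diagonal(σ̄) Vᵀ.

```julia
using LinearAlgebra
using ChainRulesCore
using Zygote

# Hypothetical wrapper so we don't add a rule to LinearAlgebra.svdvals itself
my_svdvals(X::AbstractMatrix) = svdvals(X)

# Reverse rule: dσᵢ = uᵢ' * dX * vᵢ, so the cotangent of X is U * Diagonal(σ̄) * V'.
# Assumes distinct, nonzero singular values (degenerate cases are not handled).
function ChainRulesCore.rrule(::typeof(my_svdvals), X::AbstractMatrix{<:Real})
    F = svd(X)
    function my_svdvals_pullback(Δs)
        σbar = unthunk(Δs)
        Xbar = F.U * Diagonal(σbar) * F.Vt
        return NoTangent(), Xbar
    end
    return F.S, my_svdvals_pullback
end

nuc_loss(X) = sum(my_svdvals(X))   # nuclear norm

X = randn(20, 20)
g = Zygote.gradient(nuc_loss, X)[1]

# For the nuclear norm the gradient should equal U * V'
F = svd(X)
println(g ≈ F.U * F.Vt)
```

The rule costs one SVD per forward pass and two matrix products per pullback, independent of how many entries X has, which is the main advantage over the forward-mode workaround.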
