Using Flux to optimize a function of the Singular Values

Update – using latest Flux.

I’d like to optimize a function of the singular values of a matrix. The Flux training loop barfs.

Here’s my code:

using Flux, ForwardDiff, LinearAlgebra, GenericLinearAlgebra
m = 20
X = randn(m,m)
model = Dense(m,m)
loss(X) = sum(GenericLinearAlgebra.svdvals(model(X))) ## nuclear norm
data = [X] 
opt = ADAM()
for i = 1:10
    Flux.train!(loss, params(model), X, opt) ## also tried zip(X)
end

This gives the following error:

ERROR: MethodError: no method matching (::Dense{typeof(identity),Array{Float32,2},Array{Float32,1}})(::Float64)
Closest candidates are:
Any(::AbstractArray{T,N} where N) where {T<:Union{Float32, Float64}, W<:(AbstractArray{T,N} where N)} at C:\Users\rajnr.julia\packages\Flux\Fj3bt\src\layers\basic.jl:133
Any(::AbstractArray{#s107,N} where N where #s107<:AbstractFloat) where {T<:Union{Float32, Float64}, W<:(AbstractArray{T,N} where N)} at C:\Users\rajnr.julia\packages\Flux\Fj3bt\src\layers\basic.jl:136
Any(::AbstractArray) at C:\Users\rajnr.julia\packages\Flux\Fj3bt\src\layers\basic.jl:121
[1] macro expansion at C:\Users\rajnr.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0 [inlined]
[2] _pullback(::Zygote.Context, ::Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}, ::Float64) at C:\Users\rajnr.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:7
[3] loss at .\REPL[65]:1 [inlined]
[4] _pullback(::Zygote.Context, ::typeof(loss), ::Float64) at C:\Users\rajnr.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
[5] adjoint at C:\Users\rajnr.julia\packages\Zygote\YeCEW\src\lib\lib.jl:179 [inlined]
[6] _pullback at C:\Users\rajnr.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
[7] #17 at C:\Users\rajnr.julia\packages\Flux\Fj3bt\src\optimise\train.jl:89 [inlined]
[8] _pullback(::Zygote.Context, ::Flux.Optimise.var"#17#25"{typeof(loss),Float64}) at C:\Users\rajnr.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
[9] pullback(::Function, ::Zygote.Params) at C:\Users\rajnr.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:174
[10] gradient(::Function, ::Zygote.Params) at C:\Users\rajnr.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:54
[11] macro expansion at C:\Users\rajnr.julia\packages\Flux\Fj3bt\src\optimise\train.jl:88 [inlined]
[12] macro expansion at C:\Users\rajnr.julia\packages\Juno\f8hj2\src\progress.jl:134 [inlined]
[13] train!(::typeof(loss), ::Zygote.Params, ::Array{Float64,2}, ::ADAM; cb::Flux.Optimise.var"#18#26") at C:\Users\rajnr.julia\packages\Flux\Fj3bt\src\optimise\train.jl:81
[14] train!(::Function, ::Zygote.Params, ::Array{Float64,2}, ::ADAM) at C:\Users\rajnr.julia\packages\Flux\Fj3bt\src\optimise\train.jl:79
[15] top-level scope at .\REPL[66]:2

The AD is able to compute gradients, though. For example,

h(W) = ForwardDiff.gradient(W -> loss(W*X),W)


will work. So will

h(Y) = ForwardDiff.gradient(Y -> loss(Y),Y)

It returns the right answer: the gradient of the nuclear norm with respect to the matrix.
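As a sanity check on that claim, here is a minimal sketch that compares the ForwardDiff gradient with the closed form. It relies on the standard result that, for a full-rank matrix with SVD Y = U*S*V', the gradient of the nuclear norm is U*V'. The names `nuclear`, `G_closed` and `err` are made up for this illustration, and the 5×5 size is arbitrary.

```julia
using LinearAlgebra, ForwardDiff
import GenericLinearAlgebra

# Nuclear norm: sum of the singular values (same svdvals call as in the post).
nuclear(Y) = sum(GenericLinearAlgebra.svdvals(Y))

Y = randn(5, 5)                       # a random matrix is full rank almost surely
G = ForwardDiff.gradient(nuclear, Y)  # forward-mode gradient

F = svd(Y)                            # standard LinearAlgebra SVD of the Float64 matrix
G_closed = F.U * F.Vt                 # closed-form gradient U*V'

err = maximum(abs.(G .- G_closed))    # should be close to machine precision
```

If `err` is not small, the matrix is probably rank deficient (or nearly so), where the nuclear norm is not differentiable.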

How can I close the loop so Flux can optimize the above loss function?



Thanks @baggepinnen for the tip to update to latest Flux. It fixed the TrackedArrays problem.

It appears that you are using a very old version of Flux; later versions do not use TrackedArrays anymore. Maybe you could try updating Flux to the latest version?

Here’s how you can get the gradient working:

using Flux, Zygote, ForwardDiff
using GenericLinearAlgebra: svdvals

svdvals2(X) = Zygote.forwarddiff(svdvals, X)

m = 20
X = randn(m,m)
model = Dense(m,m)
loss(X) = sum(svdvals2(X)) ## nuclear norm


gradient(loss, X)

This plugs in forwarddiff to get the svdvals gradient. It’d be nice to have direct support for this, so I opened ChainRules issues for SVD and svdvals.
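To close the loop with the training call from the original question, something like the sketch below should work. It assumes the implicit-parameter API of the Flux release used in this thread (`params`, `ADAM`, `train!(loss, ps, data, opt)`); newer Flux releases have changed that API. The stack trace above shows `loss` being called with a `Float64`, which suggests `train!` was iterating over the individual entries of `X`, so here the whole matrix is wrapped as a single one-argument batch. The `Float32` input and the `before`/`after` variables are choices made for this illustration.

```julia
using Flux, Zygote
using GenericLinearAlgebra: svdvals

# Same value as svdvals(A); only the gradient is computed in forward mode.
svdvals2(A) = Zygote.forwarddiff(svdvals, A)

m = 20
X = randn(Float32, m, m)             # Float32 to match the Dense layer's parameters
model = Dense(m, m)
loss(x) = sum(svdvals2(model(x)))    # nuclear norm of the layer output

# train! iterates over `data` and splats each element into `loss`,
# so the matrix is wrapped in a one-element tuple inside a vector.
data = [(X,)]
opt = ADAM()
ps = Flux.params(model)

before = loss(X)
for i in 1:10
    Flux.train!(loss, ps, data, opt)
end
after = loss(X)                      # compare with `before` to see the effect
```

Passing `X` directly, as in the question, makes each `Float64` entry a separate "batch", which is what triggers the `MethodError`.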


@MikeInnes that works great.

Zygote.forwarddiff(svdvals,X) doesn’t do what one initially thinks it does (it seems to evaluate svdvals(X) instead of diff(svdvals(X))). Perhaps it needs a better name – Zygote.evalf(svdvals,X) or,X) ? Just a thought :)

forwarddiff(f, x) is a little weird because it is the same as f(x) for the forward pass – it only affects how gradients are calculated, i.e. using forward rather than reverse mode.

I’d love to have a name that gets that idea across better, although it might be fundamentally unintuitive until you’re exposed to the idea of functions that (only) affect the backwards pass.
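A small sketch of that behaviour, using a toy function instead of svdvals: the wrapped call returns the same value as the plain call, and the gradient also agrees, but inside the wrapper it is computed in forward mode. The names `square`, `plain`, `wrapped`, `g_reverse` and `g_forward` are invented for this example.

```julia
using Zygote

square(v) = v .^ 2
x = [1.0, 2.0, 3.0]

# Forward pass: forwarddiff(f, x) just evaluates f(x).
plain   = square(x)
wrapped = Zygote.forwarddiff(square, x)

# Backward pass: same gradient, different machinery.
# g_reverse uses Zygote's reverse mode throughout;
# g_forward uses forward mode for the wrapped call.
g_reverse = Zygote.gradient(v -> sum(square(v)), x)[1]
g_forward = Zygote.gradient(v -> sum(Zygote.forwarddiff(square, v)), x)[1]
```

For a function Zygote can already differentiate, the wrapper changes nothing visible; it matters for functions like svdvals, where reverse mode had no rule to use.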

Aah, I see… that is an idea I haven’t been exposed to, but it seems like a great name could also expose the user to that idea (if they wanted to know more).

how about map4autodiff, map4forwardpass, map4backpass ?