Using Flux to optimize a function of the singular values

Update: now using the latest Flux.

I’d like to optimize a function of the singular values of a matrix, but the Flux training loop throws an error.

Here’s my code:

using Flux, ForwardDiff, LinearAlgebra, GenericLinearAlgebra
m = 20
X = randn(m,m)
model = Dense(m,m)
loss(X) = sum(GenericLinearAlgebra.svdvals(model(X))) ## nuclear norm
data = [X] 
opt = ADAM()
for i = 1:10
      Flux.train!(loss, params(model), X, opt) ## also tried zip(X)
end

This gives the following error:

ERROR: MethodError: no method matching (::Dense{typeof(identity),Array{Float32,2},Array{Float32,1}})(::Float64)
Closest candidates are:
Any(::AbstractArray{T,N} where N) where {T<:Union{Float32, Float64}, W<:(AbstractArray{T,N} where N)} at C:\Users\rajnr\.julia\packages\Flux\Fj3bt\src\layers\basic.jl:133
Any(::AbstractArray{#s107,N} where N where #s107<:AbstractFloat) where {T<:Union{Float32, Float64}, W<:(AbstractArray{T,N} where N)} at C:\Users\rajnr\.julia\packages\Flux\Fj3bt\src\layers\basic.jl:136
Any(::AbstractArray) at C:\Users\rajnr\.julia\packages\Flux\Fj3bt\src\layers\basic.jl:121
Stacktrace:
[1] macro expansion at C:\Users\rajnr\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0 [inlined]
[2] _pullback(::Zygote.Context, ::Dense{typeof(identity),Array{Float32,2},Array{Float32,1}}, ::Float64) at C:\Users\rajnr\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:7
[3] loss at .\REPL[65]:1 [inlined]
[4] _pullback(::Zygote.Context, ::typeof(loss), ::Float64) at C:\Users\rajnr\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
[5] adjoint at C:\Users\rajnr\.julia\packages\Zygote\YeCEW\src\lib\lib.jl:179 [inlined]
[6] _pullback at C:\Users\rajnr\.julia\packages\ZygoteRules\6nssF\src\adjoint.jl:47 [inlined]
[7] #17 at C:\Users\rajnr\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:89 [inlined]
[8] _pullback(::Zygote.Context, ::Flux.Optimise.var"#17#25"{typeof(loss),Float64}) at C:\Users\rajnr\.julia\packages\Zygote\YeCEW\src\compiler\interface2.jl:0
[9] pullback(::Function, ::Zygote.Params) at C:\Users\rajnr\.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:174
[10] gradient(::Function, ::Zygote.Params) at C:\Users\rajnr\.julia\packages\Zygote\YeCEW\src\compiler\interface.jl:54
[11] macro expansion at C:\Users\rajnr\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:88 [inlined]
[12] macro expansion at C:\Users\rajnr\.julia\packages\Juno\f8hj2\src\progress.jl:134 [inlined]
[13] train!(::typeof(loss), ::Zygote.Params, ::Array{Float64,2}, ::ADAM; cb::Flux.Optimise.var"#18#26") at C:\Users\rajnr\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:81
[14] train!(::Function, ::Zygote.Params, ::Array{Float64,2}, ::ADAM) at C:\Users\rajnr\.julia\packages\Flux\Fj3bt\src\optimise\train.jl:79
[15] top-level scope at .\REPL[66]:2
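The trace shows `Dense` being called on a single `Float64`. `Flux.train!` iterates over its data argument and calls `loss` on each element; iterating a `Matrix` yields its entries one at a time (and `zip(X)` yields 1-tuples of entries), so `loss` never sees the whole matrix. The `data = [X]` vector defined above is what should be passed. Separately, Zygote (at least at the time of this thread) could not reverse-differentiate `svdvals`, which is what the later reply works around. A minimal sketch combining both fixes, assuming a Flux release that still accepts the implicit-parameters form `train!(loss, params, data, opt)` used here; `svdvals_fd` is a hypothetical helper name:

```julia
using Flux, Zygote, ForwardDiff
using LinearAlgebra, GenericLinearAlgebra

# Hypothetical helper: compute svdvals normally on the forward pass, but take
# its gradient with ForwardDiff instead of Zygote's reverse mode.
svdvals_fd(A) = Zygote.forwarddiff(LinearAlgebra.svdvals, A)

m = 20
X = randn(m, m)
model = Dense(m, m)
loss(x) = sum(svdvals_fd(model(x)))  ## nuclear norm of model(x)

# train! calls loss on each element of `data`. Wrapping X in a vector makes
# the single element the whole matrix; passing X itself (or zip(X)) would
# feed one scalar entry per step, which caused the MethodError above.
data = [X]
opt = ADAM()
for i in 1:10
    Flux.train!(loss, Flux.params(model), data, opt)
end
```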

ForwardDiff is able to compute the gradients, though. For example,

h(W) = ForwardDiff.gradient(W -> loss(W*X),W)

h(randn(20,20))

will work. So will

h(Y) = ForwardDiff.gradient(Y -> loss(Y),Y)
h(randn(20,20))

It returns the right answer: the gradient of the loss (the nuclear norm of model(Y)) with respect to the input matrix Y.
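One way to sanity-check the ForwardDiff path is on the nuclear norm alone, without the `Dense` layer. For a full-rank square matrix A = U Σ V', the gradient of ‖A‖_* = Σᵢ σᵢ(A) is U V'. A minimal sketch comparing the two, assuming (as in the thread) that `LinearAlgebra.svdvals` falls back to GenericLinearAlgebra's generic code for ForwardDiff's `Dual` numbers; `nuclear` is a hypothetical name:

```julia
using LinearAlgebra, GenericLinearAlgebra, ForwardDiff, Random

# Nuclear norm = sum of singular values. With Dual-number inputs,
# svdvals dispatches to GenericLinearAlgebra's non-LAPACK implementation.
nuclear(A) = sum(LinearAlgebra.svdvals(A))

Random.seed!(1)
A = randn(20, 20)                     # full rank with probability 1
G = ForwardDiff.gradient(nuclear, A)  # forward-mode gradient, same shape as A

# Closed form for a full-rank square matrix: d‖A‖_*/dA = U * V'
F = svd(A)                            # ordinary LAPACK SVD on Float64
closed_form = F.U * F.Vt

@show maximum(abs, G - closed_form)   # should be close to zero
```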

How can I close the loop so that Flux can optimize the above loss function?

Thanks,

Raj

Thanks @baggepinnen for the tip to update to the latest Flux. It fixed the TrackedArrays problem.

Welcome to this forum! See this post


to make the code more readable :slight_smile:

It appears that you are using a very old version of Flux; later versions no longer use TrackedArrays. Maybe you could try updating Flux to the latest version?

Here’s how you can get the gradient working:

using Flux, Zygote, ForwardDiff
using GenericLinearAlgebra: svdvals

svdvals2(X) = Zygote.forwarddiff(svdvals, X)

m = 20
X = randn(m,m)
model = Dense(m,m) ## not used below; this gradient is taken with respect to X directly
loss(X) = sum(svdvals2(X)) ## nuclear norm

loss(X)

gradient(loss, X)

This uses Zygote.forwarddiff to compute the svdvals gradient with ForwardDiff. It’d be nice to have direct support for this, so I opened ChainRules issues for SVD and svdvals.
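For comparison, direct reverse-mode support could take roughly the following shape. This is a sketch under stated assumptions, not the ChainRules implementation: it assumes every singular value is simple (distinct and nonzero), in which case dσᵢ = uᵢ' dA vᵢ, so the pullback of a cotangent Δ on the singular values is U Diagonal(Δ) V'. With Δ = ones, this reduces to the U V' nuclear-norm gradient. `svdvals_rev` is a hypothetical wrapper so the rule does not overwrite `LinearAlgebra.svdvals` itself:

```julia
using LinearAlgebra, Zygote, Random

# Hypothetical wrapper, so the custom rule does not pirate LinearAlgebra.svdvals.
svdvals_rev(A::AbstractMatrix) = svdvals(A)

# Reverse-mode rule: forward pass via LAPACK svd, pullback U * Diagonal(Δ) * V'.
Zygote.@adjoint function svdvals_rev(A::AbstractMatrix)
    F = svd(A)
    # Δ .* F.Vt scales row i of Vt by Δ[i], i.e. Diagonal(Δ) * Vt.
    return F.S, Δ -> (F.U * (Δ .* F.Vt),)
end

Random.seed!(0)
A = randn(5, 5)
G = Zygote.gradient(X -> sum(svdvals_rev(X)), A)[1]

# Central finite-difference check on a single entry of the gradient.
E = zeros(5, 5); E[2, 3] = 1.0
h = 1e-6
fd = (sum(svdvals(A + h * E)) - sum(svdvals(A - h * E))) / (2h)
@show G[2, 3] fd
```

The finite-difference check only probes one entry; it is there to confirm the pullback formula, not to replace a proper gradient test.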

@MikeInnes that works great.

The Zygote.forwarddiff(svdvals, X) call doesn’t do what one initially thinks it does: it seems to evaluate svdvals(X) rather than diff(svdvals(X)). Perhaps it needs a better name, such as Zygote.evalf(svdvals, X) or Zygote.map(svdvals, X)? Just a thought :slight_smile:

forwarddiff(f, x) is a little weird because it is the same as f(x) on the forward pass; it only affects how gradients are calculated, i.e. using forward rather than reverse mode.

I’d love to have a name that gets that idea across better, although it might be fundamentally unintuitive until you’re exposed to the idea of functions that (only) affect the backwards pass.
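To illustrate that forward/backward distinction, here is a small sketch using a hypothetical function `cumsum_total` that mutates an array internally, much as the generic SVD code does. Zygote's reverse mode does not support that kind of array mutation, but wrapping the call in `Zygote.forwarddiff` leaves the returned value unchanged and only swaps in ForwardDiff for the gradient:

```julia
using Zygote
using ForwardDiff  # loaded explicitly; Zygote.forwarddiff uses it for the gradient

# Hypothetical example: writes into an array (setindex!), which Zygote's
# reverse mode cannot differentiate through.
function cumsum_total(x)
    y = similar(x)
    s = zero(eltype(x))
    for i in eachindex(x)
        s += x[i]
        y[i] = s          # in-place mutation
    end
    return sum(y)         # for length 3: 3x₁ + 2x₂ + x₃
end

x = [1.0, 2.0, 3.0]

# Forward pass: forwarddiff(f, x) simply returns f(x).
@show Zygote.forwarddiff(cumsum_total, x) == cumsum_total(x)

# Backward pass: the gradient comes from ForwardDiff rather than reverse mode.
g = Zygote.gradient(x -> Zygote.forwarddiff(cumsum_total, x), x)[1]
@show g
```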

Ah, I see… that is an idea I haven’t been exposed to, but it seems like a great name could also expose the user to that idea (if they wanted to know more).

How about map4autodiff, map4forwardpass, or map4backpass?