Implementation of Spectral Normalization for Machine Learning

Greetings.

I have migrated most of my machine learning work from PyTorch to Flux and have not felt the need to look back (especially as I move towards SciML). However, there is a small hiccup I’d like to inquire about:

Spectral normalization (SN) is a technique used in machine learning, specifically in Generative Adversarial Networks, to stabilize the training of the discriminator. As an exercise, I tried to write my own Dense layer with spectral normalization. However, I find myself unable to train the network, because my implementation of the spectral norm does not seem to be differentiable with Flux’s gradient function, although ForwardDiff’s gradient handles it without problems.

Below is a minimal example I wrote to illustrate the problem:

using Flux: gradient
using GenericLinearAlgebra: svdvals

function snorm(X)
    return svdvals(X)[1]
end

X = rand(3, 5)  # Generate a random matrix to analyze.

println(snorm(X))  # This executes without issue.

println(gradient(snorm, X))  # This throws a "Can't differentiate foreigncall expression" error.

The stacktrace is as follows:

ERROR: Can't differentiate foreigncall expression
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] Pullback
    @ /usr/share/julia/stdlib/v1.7/LinearAlgebra/src/lapack.jl:1667 [inlined]
  [3] (::typeof(∂(gesdd!)))(Δ::Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
  [4] Pullback
    @ /usr/share/julia/stdlib/v1.7/LinearAlgebra/src/svd.jl:211 [inlined]
  [5] (::typeof(∂(svdvals!)))(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
  [6] Pullback
    @ /usr/share/julia/stdlib/v1.7/LinearAlgebra/src/svd.jl:238 [inlined]
  [7] (::typeof(∂(svdvals)))(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
  [8] Pullback
    @ ./REPL[10]:2 [inlined]
  [9] (::Zygote.var"#57#58"{typeof(∂(snorm))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:41
 [10] gradient(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:76
 [11] top-level scope
    @ REPL[13]:1
 [12] top-level scope
    @ ~/.julia/packages/CUDA/bki2w/src/initialization.jl:52

However, if the exact same steps are taken using ForwardDiff’s gradient instead, no error is reported.

using ForwardDiff: gradient
using GenericLinearAlgebra: svdvals

function snorm(X)
    return svdvals(X)[1]
end

X = rand(3, 5)  # Generate a random matrix to analyze.

println(snorm(X))  # This executes without issue.

println(gradient(snorm, X))  # No "Can't differentiate foreigncall expression" error this time.

My research machine runs Arch Linux, with Julia installed from the binary package in the AUR, and otherwise works flawlessly. The error is also reproduced on my personal machine, which runs macOS. Both Julia installations are up to date, and the packages were updated yesterday.

Any help is appreciated, whether correcting a mistake I made or suggesting an alternative implementation that works. This is my first time posting here, so please point out if I have done anything out of line.

The norm you’re using is called opnorm in Julia; perhaps that function is differentiable?

Thank you for the idea. I didn’t think to check whether it was implemented as opnorm, since the spectral norm is a specific operator norm. (Edit: Sorry, I realize it works like the regular norm in that you pass the order. However, the default seems to be 2, so everything here stands.)
Unfortunately, the result is the same in the end.

ERROR: Can't differentiate foreigncall expression
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] Pullback
    @ /usr/share/julia/stdlib/v1.7/LinearAlgebra/src/lapack.jl:1667 [inlined]
  [3] (::typeof(∂(gesdd!)))(Δ::Tuple{Nothing, Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}}, Nothing})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
  [4] Pullback
    @ /usr/share/julia/stdlib/v1.7/LinearAlgebra/src/svd.jl:211 [inlined]
  [5] (::typeof(∂(svdvals!)))(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
  [6] Pullback
    @ /usr/share/julia/stdlib/v1.7/LinearAlgebra/src/svd.jl:238 [inlined]
  [7] (::typeof(∂(svdvals)))(Δ::Zygote.OneElement{Float64, 1, Tuple{Int64}, Tuple{Base.OneTo{Int64}}})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
  [8] Pullback
    @ /usr/share/julia/stdlib/v1.7/LinearAlgebra/src/generic.jl:708 [inlined]
  [9] (::typeof(∂(opnorm2)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [10] Pullback
    @ /usr/share/julia/stdlib/v1.7/LinearAlgebra/src/generic.jl:768 [inlined]
 [11] (::typeof(∂(opnorm)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface2.jl:0
 [12] Pullback
    @ /usr/share/julia/stdlib/v1.7/LinearAlgebra/src/generic.jl:767 [inlined]
 [13] (::Zygote.var"#57#58"{typeof(∂(opnorm))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:41
 [14] gradient(f::Function, args::Matrix{Float64})
    @ Zygote ~/.julia/packages/Zygote/FPUm3/src/compiler/interface.jl:76
 [15] dopnorm(X::Matrix{Float64})
    @ Main ./REPL[4]:1
 [16] top-level scope
    @ REPL[5]:1
 [17] top-level scope
    @ ~/.julia/packages/CUDA/bki2w/src/initialization.jl:52

The problem is that there is currently no rrule defined for svdvals.
A workaround would be to define:

using LinearAlgebra: svd

function snorm(X)
    return svd(X).S[1]
end

which computes the full SVD (for which an rrule is defined).
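For completeness, here is a small end-to-end sketch of this workaround. The analytic comparison at the end assumes the largest singular value is simple, which holds almost surely for a random matrix:

```julia
using LinearAlgebra: svd
using Zygote: gradient

# Spectral norm via the full SVD, for which Zygote has an adjoint.
snorm(X) = svd(X).S[1]

X = rand(3, 5)

# Zygote can now differentiate through the full SVD.
g = gradient(snorm, X)[1]

# Where sigma_1 is simple, its gradient is u1 * v1',
# built from the leading singular vectors.
F = svd(X)
@assert g ≈ F.U[:, 1] * F.V[:, 1]'
```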

Interesting. This reminds me of this thread, where I learned that differentiation of at least the real SVD seems to be supported by Zygote.

Thank you very much! It works perfectly now!

The rule for SVD is here. It would probably be very simple to add one for svdvals, perhaps calling the same function to do the work, svd_rev(nt, NoTangent(), S̄, NoTangent()), or perhaps keeping just the few lines that are nonzero there.
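As a rough illustration of what the nonzero part of such a rule would look like: the sensitivity of the singular values alone reduces to Ā = U · Diagonal(S̄) · Vᵀ. The sketch below defines that pullback with plain LinearAlgebra and checks it against finite differences; the ChainRulesCore wrapper shown in comments is only an assumed outline, not a tested contribution:

```julia
using LinearAlgebra: svd, svdvals, Diagonal, norm

# Pullback of svdvals: given cotangents S̄ for the singular values,
# the cotangent of A keeps only the diagonal block of the full SVD
# reverse rule: Ā = U * Diagonal(S̄) * V'.
function svdvals_pullback(A, S̄)
    F = svd(A)
    return F.U * Diagonal(S̄) * F.Vt
end

# In ChainRulesCore this would presumably be wrapped roughly as:
#   function ChainRulesCore.rrule(::typeof(svdvals), A)
#       S = svdvals(A)
#       pullback(S̄) = (NoTangent(), svdvals_pullback(A, S̄))
#       return S, pullback
#   end

# Central-difference check on the largest singular value:
A = rand(3, 5)
S̄ = [1.0, 0.0, 0.0]  # select sigma_1
g = svdvals_pullback(A, S̄)

h = 1e-6
fd = similar(A)
for i in eachindex(A)
    E = zeros(size(A)); E[i] = h
    fd[i] = (svdvals(A + E)[1] - svdvals(A - E)[1]) / (2h)
end
@assert norm(g - fd) < 1e-4
```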

Thank you for the explanation and the link.
Reusing the same function is probably the way to go, since, to my knowledge, computing the gradient of the spectral norm (where it exists) requires the U and V matrices obtained from the full SVD.