Gradient of sum of singular values of a matrix with CUDA.jl

Any idea why this doesn’t work with CUDA?

julia> using Flux, LinearAlgebra

julia> x = randn(10, 10);

julia> xgpu = gpu(x);

julia> gradient(x -> sum(svd(x).S), xgpu)
ERROR: MethodError: no method matching svd_rev(::CUDA.CUSOLVER.CuSVD{Float32,Float32,CUDA.CuArray{Float32,2}}, ::ChainRulesCore.Zero, ::CUDA.CuArray{Float32,1}, ::ChainRulesCore.Zero)
Closest candidates are:
  svd_rev(::SVD, ::Any, ::Any, ::Any) at /home/arl/.julia/packages/ChainRules/fxzix/src/rulesets/LinearAlgebra/factorization.jl:235

but works completely fine on the CPU?

julia> gradient(x -> sum(svd(x).S), x)
([0.22422035794513845 -0.22840843902956595 … -0.11853353221499852 -0.2644796980043797; 0.3925933912813142 -0.17862913793640262 … -0.07040796485757393 -0.2394096734163877; … ; 0.1322218981131396 0.04551545691787621 … 0.046312344885702884 0.27236765206728764; -0.014331187321234687 -0.08331098168443869 … 0.28966886173762263 -0.2966424859947306],)

Julia and package versions:

julia> versioninfo()
Julia Version 1.5.3
Commit 788b2c77c1 (2020-11-09 13:37 UTC)
Platform Info:
  OS: Linux (x86_64-pc-linux-gnu)
  CPU: Intel(R) Xeon(R) CPU E5-1630 v3 @ 3.70GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-9.0.1 (ORCJIT, haswell)

(jl_zGZp9B) pkg> st
Status `/tmp/jl_zGZp9B/Project.toml`
  [587475ba] Flux v0.11.6

ChainRules calls svd_rev, which it defines only for LinearAlgebra.SVD, while CUDA.jl returns its own factorization type (CUDA.CUSOLVER.CuSVD) from svd.
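One way to see the mismatch is to compare the factorization types directly (a quick check; with the package versions in this thread, svd on a CuArray returns the CUSOLVER type):

using CUDA, LinearAlgebra
typeof(svd(rand(4, 4)))       # LinearAlgebra.SVD{...}, the type svd_rev dispatches on
typeof(svd(CUDA.rand(4, 4)))  # CUDA.CUSOLVER.CuSVD{...}, which has no svd_rev method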

Thanks! It works once I define svd_rev for CUDA.CUSOLVER.CuSVD by adapting the original svd_rev.

using CUDA, ChainRules, Flux, LinearAlgebra, Test
CUDA.allowscalar(false)

function ChainRules.svd_rev(USV::CUDA.CUSOLVER.CuSVD{T}, Ū, s̄, V̄) where T
    U = USV.U
    s = USV.S
    V = USV.V
    Vt = USV.Vt

    m = size(U,1)
    k = length(s)
    # The original CPU rule builds F elementwise:
    #   F = T[i == j ? 1 : inv(@inbounds s[j]^2 - s[i]^2) for i = 1:k, j = 1:k]
    # which would need scalar indexing on the GPU, so build it with broadcasts instead:
    s2 = s .^ 2
    F = inv.((s2 .- reshape(s2, 1, :))' .+ CuArray(Diagonal(ones(T, k))))

    # We do a lot of matrix operations here, so we'll try to be memory-friendly and do
    # as many of the computations in-place as possible. Benchmarking shows that the in-
    # place functions here are significantly faster than their out-of-place, naively
    # implemented counterparts, and allocate no additional memory.
    Ut = U'
    FUᵀŪ = ChainRules._mulsubtrans!!(Ut*Ū, F)  # F .* (UᵀŪ - ŪᵀU)
    FVᵀV̄ = ChainRules._mulsubtrans!!(Vt*V̄, F)  # F .* (VᵀV̄ - V̄ᵀV)
    # build I - UUᵀ and I - VVᵀ with broadcasts in place of ChainRules._eyesubx!
    ImUUᵀ = CuArray(Diagonal(ones(T, m))) .- U*Ut
    ImVVᵀ = CuArray(Diagonal(ones(T, m))) .- V*Vt

    S = Diagonal(s)
    S̄ = s̄ isa AbstractZero ? s̄ : Diagonal(s̄)

    # TODO: consider using MuladdMacro here
    Ā = add!!(U * FUᵀŪ * S, ImUUᵀ * (Ū / S)) * Vt
    Ā = add!!(Ā, U * S̄ * Vt)
    Ā = add!!(Ā, U * add!!(S * FVᵀV̄ * Vt, (S \ V̄') * ImVVᵀ))

    return Ā
end

x = randn(Float32, 10, 10)
xgpu = gpu(x)

g1 = gradient(x -> sum(svd(x).S), x)
g2 = gradient(x -> sum(svd(x).S), xgpu)

@test g1[1] ≈ Array(g2[1]) # Test Passed
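
For extra confidence beyond the CPU/GPU comparison, the CPU gradient can also be checked against central differences. A quick sketch, where fd_gradient is a throwaway helper written for this post, not part of any package:

function fd_gradient(f, x; h=1f-3)
    g = similar(x)
    for i in eachindex(x)
        e = zeros(eltype(x), size(x)); e[i] = h
        g[i] = (f(x .+ e) - f(x .- e)) / (2h)   # central difference in direction i
    end
    return g
end

@test fd_gradient(x -> sum(svd(x).S), x) ≈ g1[1] atol=1e-2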

Thanks, I just came across the same problem.
I had to change AbstractZero to ChainRules.AbstractZero and add!! to ChainRulesCore.add!! to get it working.
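
For concreteness, the affected lines in the function body above then read (the remaining add!! calls need the same qualification):

S̄ = s̄ isa ChainRules.AbstractZero ? s̄ : Diagonal(s̄)
Ā = ChainRulesCore.add!!(U * FUᵀŪ * S, ImUUᵀ * (Ū / S)) * Vt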

@ymtoo Reading your code, I’m confused about how calling ChainRules._mulsubtrans!! is compatible with CUDA.allowscalar(false). That routine explicitly loops over indices, which should trigger a scalar-indexing error.

Also, when you construct S = Diagonal(s), which is a CPU wrapper type, are the calls Ū / S and S \ V̄' resolved to CPU or CUDA routines?

@rkube You’re right. The svd_rev above only handles the gradient of the singular values of a matrix. Since Ū = Zero() and V̄ = Zero() in that case, the method can be simplified further:

function ChainRules.svd_rev(USV::CUDA.CUSOLVER.CuSVD{T}, Ū::Zero, s̄::CuArray{T,1}, V̄::Zero) where T
    U = USV.U
    Vt = USV.Vt
    # the signature already restricts s̄ to a CuArray, so no AbstractZero check is needed
    Ā = U * Diagonal(s̄) * Vt
    return Ā
end
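
This matches the identity ∂σᵢ/∂A = uᵢvᵢᵀ: with s̄ all ones, the pullback collapses to U*Vᵀ. A quick sanity check with a nontrivial s̄, reusing x and xgpu from above (a throwaway test, not from the original post):

w = CUDA.rand(Float32, 10)   # weight each singular value differently
gw_cpu = gradient(x -> sum(Array(w) .* svd(x).S), x)
gw_gpu = gradient(x -> sum(w .* svd(x).S), xgpu)
@test gw_cpu[1] ≈ Array(gw_gpu[1])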

Hence, the following, which needs the U cotangent, still doesn’t work:

julia> g3 = gradient(x -> sum(svd(x).U), xgpu)
ERROR: LoadError: MethodError: no method matching svd_rev(::CUDA.CUSOLVER.CuSVD{Float32, Float32, CuArray{Float32, 2}}, ::CuArray{Float32, 2}, ::Zero, ::Zero)
Closest candidates are:
  svd_rev(::SVD, ::Any, ::Any, ::Any) at /home/ymtoo/.julia/packages/ChainRules/h2PiT/src/rulesets/LinearAlgebra/factorization.jl:234
  svd_rev(::CUDA.CUSOLVER.CuSVD{T, Tr, A} where {Tr, A<:AbstractMatrix{T}}, ::Zero, ::CuArray{T, 1}, ::Zero) where T at /home/ymtoo/Projects/tmp/test-svd.jl:5
Stacktrace:
 [1] (::ChainRules.var"#svd_pullback#1901"{CUDA.CUSOLVER.CuSVD{Float32, Float32, CuArray{Float32, 2}}})(Ȳ::Composite{Any, NamedTuple{(:U, :S, :V), Tuple{CuArray{Float32, 2}, Zero, Zero}}})
   @ ChainRules ~/.julia/packages/ChainRules/h2PiT/src/rulesets/LinearAlgebra/factorization.jl:210
 [2] (::Zygote.ZBack{ChainRules.var"#svd_pullback#1901"{CUDA.CUSOLVER.CuSVD{Float32, Float32, CuArray{Float32, 2}}}})(dy::NamedTuple{(:U, :S, :V), Tuple{CuArray{Float32, 2}, Nothing, Nothing}})
   @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:77
 [3] Pullback
   @ ~/Projects/tmp/test-svd.jl:71 [inlined]
 [4] (::typeof(∂(#15)))(Δ::Float32)
   @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [5] (::Zygote.var"#41#42"{typeof(∂(#15))})(Δ::Float32)
   @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [6] gradient(f::Function, args::CuArray{Float32, 2})
   @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [7] top-level scope
   @ ~/Projects/tmp/test-svd.jl:71
in expression starting at /home/ymtoo/Projects/tmp/test-svd.jl:71

Thank you for the clarification.

So with Ū::Zero and V̄::Zero, the calls to ChainRules._mulsubtrans!! in your earlier code dispatch to _mulsubtrans!!(X::AbstractZero, F::AbstractZero) = X and not to _mulsubtrans!!(X::AbstractMatrix{<:Real}, F::AbstractMatrix{<:Real}). The latter uses explicit array indexing and would trigger the error enabled by CUDA.allowscalar(false).

Hi, I have just run across your adjoint and it really helped me! Just a quick note for anyone coming after me: this works for the svd of square matrices only. To make it work for non-square ones, we need to replace the m in the computation of ImVVᵀ with size(V, 1).
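
Concretely, the fix would change the projector lines in the function above to something like this (a sketch, keeping the same variable names):

n = size(V, 1)                                  # second dimension of the input matrix
ImUUᵀ = CuArray(Diagonal(ones(T, m))) .- U*Ut   # I - UUᵀ is m × m
ImVVᵀ = CuArray(Diagonal(ones(T, n))) .- V*Vt   # I - VVᵀ is n × n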