Backpropagating through orthogonal projection operator

I have a neural network whose loss function requires computing orthogonal projections onto the column space of the network's output. The code runs fine on the CPU, albeit slowly. Porting it to the GPU is proving difficult and I'm stuck. Maybe someone here has a suggestion for how this function can be implemented on the GPU?

For the loss function I have two options:
1.) Using P_A = A * (transpose(A) * A)^{-1} * transpose(A)
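For a forward-pass sanity check of option 1 on the CPU: when V has full column rank, pinv(V) equals (V'V)^{-1} V', so the projector can be built directly from pinv. This is just a sketch with my own variable names:

```julia
using LinearAlgebra

N = 16
V = rand(Float32, 2 * N, N)

# Assumes V has full column rank, so pinv(V) == inv(V' * V) * V'
P = V * pinv(V)            # orthogonal projector onto col(V)

y = randn(Float32, 2 * N)
p = P * y
@assert P * p ≈ p          # projectors are idempotent: P * (P * y) == P * y
```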

using Linear Algebra
using CUDA
using Zygote

N = 16
V = rand(Float32, (2*N, N))

gradient(x -> sum(pinv(x)), V)

works fine. But the GPU version

Vg = V |> CuArray
gradient(x -> sum(pinv(x)), Vg)

hits this issue: Support for LinearAlgebra.pinv · Issue #883 · JuliaGPU/CUDA.jl · GitHub
It's apparently fixed in the nightly, but the nightly doesn't pass its tests on my system and the build is unusable.

2.) Using P_A * y = sum((q_i' * y) * q_i, over i), where the q_i are the columns of the thin Q factor of the QR factorization of A
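The forward pass of option 2 can be checked against option 1 on the CPU. A sketch, assuming V has full column rank; `Matrix(qr(V).Q)` materializes the thin Q factor:

```julia
using LinearAlgebra

N = 16
V = rand(Float32, 2 * N, N)
y = zeros(Float32, 2 * N)
y[1] = 1f0

Q = Matrix(qr(V).Q)          # thin Q, size (2N, N)
p_qr  = Q * (Q' * y)         # sum_i (q_i' * y) * q_i
p_inv = V * (pinv(V) * y)    # option 1 for comparison
@assert p_qr ≈ p_inv         # both compute P_A * y
```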

Here I can't even get the loss function to work on the CPU. A minimal example:

using LinearAlgebra
using Zygote

function f2(V, y)
    # QR-decomposition
    Q, _ = qr(V)
    # project
    f = ((Q' * y)' * Q')
    return sum(f)
end

N = 16
y = zeros(Float32, 2 * N)
y[1] = 1.0
V = rand(Float32, (2*N, N)) 
gradient(x -> f2(x, y), V)

But this gives me an error coming from the QR factorization:

ERROR: LoadError: Mutating arrays is not supported
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#399#400")(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/lib/array.jl:58
  [3] (::Zygote.var"#2253#back#401"{Zygote.var"#399#400"})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/dense.jl:138 [inlined]
  [5] (::typeof(∂(triu!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
  [6] Pullback
    @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/generic.jl:435 [inlined]
  [7] (::typeof(∂(triu!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
  [8] Pullback
    @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/qr.jl:435 [inlined]
  [9] (::typeof(∂(getproperty)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
 [11] Pullback
    @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/qr.jl:127 [inlined]
 [12] (::typeof(∂(iterate)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
 [13] Pullback
    @ ./tuple.jl:94 [inlined]
 [14] (::typeof(∂(indexed_iterate)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/Zygote/fjuG8/src/tools/builtins.jl:17 [inlined]
 [16] Pullback
    @ ~/source/repos/picfun/test_QR_slice_backprop.jl:21 [inlined]
 [17] (::typeof(∂(f2)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/source/repos/picfun/test_QR_slice_backprop.jl:56 [inlined]
 [19] (::Zygote.var"#41#42"{typeof(∂(#1))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface.jl:41
 [20] gradient(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface.jl:59
 [21] top-level scope

Is there any way to implement such a loss in Julia that runs on the GPU?

The pinv issue looks to be fixed by Make istriu, istril, and pinv more gpu friendly by pabloferz · Pull Request #39756 · JuliaLang/julia · GitHub (linked from the CUDA.jl issue); I just tested it locally. For QR support, the PRs to follow are Implement QR pullback by Kolaru · Pull Request #306 · JuliaDiff/ChainRules.jl · GitHub and Support for qr decomposition pullback by rkube · Pull Request #469 · JuliaDiff/ChainRules.jl · GitHub.

The loss function is numerically more stable when calculated with QR instead of pinv. I implemented the loss function using the code from Pull Request #469 for JuliaDiff/ChainRules.jl.
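In case it helps others stuck on the same wall before the QR pullback lands: a workaround I'd consider (my own sketch, not from the PRs above) is to route the projection through a linear solve on the small N×N Gram matrix. This avoids both `pinv` and `qr`, and `\` on a square dense matrix is something Zygote has a pullback for. It assumes V has full column rank, and it squares the condition number of V, so it is less numerically stable than the QR route:

```julia
using LinearAlgebra
using Zygote

# Apply P_V * y without forming pinv(V) or a QR factorization.
# Assumes V has full column rank so V' * V is invertible.
project(V, y) = V * ((V' * V) \ (V' * y))

N = 16
V = rand(Float32, 2 * N, N)
y = zeros(Float32, 2 * N)
y[1] = 1f0

g, = gradient(v -> sum(project(v, y)), V)
@assert size(g) == size(V)
```

On the GPU, this only needs matrix products and a dense solve, which CUDA.jl dispatches to CUBLAS/CUSOLVER, but I have not verified the pullback end to end on a CuArray, so treat that part as untested.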