Backpropagating through orthogonal projection operator

Hi,
I have a neural network whose loss function requires computing orthogonal projections onto the output of the network. The code runs fine on the CPU, albeit slowly. Porting it to the GPU is proving difficult and I’m stuck. Maybe someone here has a suggestion for how this function could be implemented on the GPU?

For the loss function I have two options:
1.) Using P_A = A * (transpose(A) * A)^{-1} * transpose(A)
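For illustration only, here is a minimal sketch of a projection-based loss in this form (proj_loss is just a placeholder name, not my actual loss): for a full-column-rank A, A * (transpose(A) * A)^{-1} * transpose(A) == A * pinv(A), so the projection can be written with pinv:

using LinearAlgebra

# Sketch: orthogonal projection of y onto col(A), assuming A has full column rank,
# in which case A * inv(A' * A) * A' == A * pinv(A).
proj_loss(A, y) = sum(A * (pinv(A) * y))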

using Linear Algebra
using CUDA
using Zygote

N=16
V = rand(Float32, (2*N, N))

gradient(x -> sum(pinv(x)), V)

works fine. But the GPU version

Vg = V |> CuArray
gradient(x -> sum(pinv(x)), Vg)

runs into this issue: Support for LinearAlgebra.pinv · Issue #883 · JuliaGPU/CUDA.jl · GitHub
It is apparently fixed on nightly, but the nightly build fails its tests on my system and is unusable.

2.) Using P_A * y = sum_i (q_i' * y) * q_i, where the q_i are the columns of Q from the QR factorization of A
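Spelled out as code, this is just a sum over the orthonormal columns q_i of the thin Q factor (a rough sketch with placeholder naming, not my actual loss):

using LinearAlgebra

# Sketch: P_A * y as a sum over the orthonormal columns q_i of the thin Q factor.
function project_qr(A, y)
    Q = Matrix(qr(A).Q)   # thin Q; its columns q_i form an orthonormal basis of col(A)
    return sum(Q[:, i] * (Q[:, i]' * y) for i in axes(Q, 2))
end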

Here I can’t even get the loss function to work on the CPU. A minimal example reproducing the problem is


using LinearAlgebra
using Zygote


function f2(V, y)
    # QR-decomposition
    Q, _ = qr(V)
    # project y and sum the components
    f = ((Q' * y)' * Q')
    return sum(f)
end

N = 16
y = zeros(Float32, 2 * N)
y[1] = 1.0
V = rand(Float32, (2*N, N)) 
gradient(x -> f2(x, y), V)

But this gives me an error, coming from the QR factorization:

ERROR: LoadError: Mutating arrays is not supported
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#399#400")(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/lib/array.jl:58
  [3] (::Zygote.var"#2253#back#401"{Zygote.var"#399#400"})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/dense.jl:138 [inlined]
  [5] (::typeof(∂(triu!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
  [6] Pullback
    @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/generic.jl:435 [inlined]
  [7] (::typeof(∂(triu!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
  [8] Pullback
    @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/qr.jl:435 [inlined]
  [9] (::typeof(∂(getproperty)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
 [11] Pullback
    @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/qr.jl:127 [inlined]
 [12] (::typeof(∂(iterate)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
 [13] Pullback
    @ ./tuple.jl:94 [inlined]
 [14] (::typeof(∂(indexed_iterate)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/Zygote/fjuG8/src/tools/builtins.jl:17 [inlined]
 [16] Pullback
    @ ~/source/repos/picfun/test_QR_slice_backprop.jl:21 [inlined]
 [17] (::typeof(∂(f2)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/source/repos/picfun/test_QR_slice_backprop.jl:56 [inlined]
 [19] (::Zygote.var"#41#42"{typeof(∂(#1))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface.jl:41
 [20] gradient(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface.jl:59
 [21] top-level scope

Is there any way to implement such a loss in Julia that runs on the GPU?

The pinv issue looks to be fixed by https://github.com/JuliaLang/julia/pull/39756 (linked from the CUDA.jl issue); I just tested it locally. For QR support, the PRs to follow are https://github.com/JuliaDiff/ChainRules.jl/pull/306 and https://github.com/JuliaDiff/ChainRules.jl/pull/469.

The loss function is numerically more stable when calculated with QR instead of pinv, so I implemented it using the code from JuliaDiff/ChainRules.jl PR 469.
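For reference, a rough sketch of the QR-based loss written with the thin factor (placeholder naming; whether this exact form is covered by the new qr rules is something to double-check):

using LinearAlgebra

# Sketch: loss based on projecting y onto col(V) via the thin QR factor.
function qr_proj_loss(V, y)
    F = qr(V)
    Q = Matrix(F.Q)          # thin Q with orthonormal columns spanning col(V)
    return sum(Q * (Q' * y)) # sum of the components of P_V * y
end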