Backpropagating through an orthogonal projection operator

I have a neural network whose loss function requires orthogonal projections onto the output of the network. The code runs fine on the CPU, albeit slowly. Porting it to the GPU is difficult and I’m stuck. Does anyone have a suggestion for how this loss function can be implemented on the GPU?

For the loss function I have two options:
1.) Using P_A = A * (transpose(A) * A)^{-1} * transpose(A), which equals A * pinv(A) when A has full column rank. Differentiating pinv on the CPU with

using LinearAlgebra
using CUDA
using Zygote

N = 16
V = rand(Float32, (2*N, N))

gradient(x -> sum(pinv(x)), V)

works fine. But the GPU version

Vg = V |> CuArray
gradient(x -> sum(pinv(x)), Vg)

hits this issue: "Support for LinearAlgebra.pinv" (JuliaGPU/CUDA.jl issue #883).
It’s apparently fixed in the nightly build, but that build fails its tests on my system and is unusable.
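For option 1, pinv itself may not be needed. When A has full column rank, pinv(A) * y equals (A' * A) \ (A' * y), so the projection only needs matrix products and one small N × N linear solve. The sketch below is a hypothetical illustration, not a tested GPU implementation: proj_loss, proj_loss_grad and the weight vector w are names introduced here (w = ones(2N) turns the loss into sum(P_A * y)). It also writes out the gradient by hand, dL/dA = (I - P_A) w (pinv(A) y)' + (I - P_A) y (pinv(A) w)', which could in principle be attached as a custom rule (for example a ChainRulesCore rrule) so that the AD system never has to differentiate pinv or qr. Whether `\` and these products run efficiently on CuArrays is an untested assumption.

```julia
using LinearAlgebra

# Hypothetical helper: loss L(A) = w' * P_A * y, where
# P_A = A * (A'A)^{-1} * A' is the orthogonal projector onto range(A).
# Uses a linear solve instead of pinv; assumes A has full column rank.
function proj_loss(A::AbstractMatrix, y::AbstractVector, w::AbstractVector)
    c = (A' * A) \ (A' * y)   # least-squares coefficients, c = pinv(A) * y
    return dot(w, A * c)      # w' * (P_A * y)
end

# Hypothetical hand-derived gradient of proj_loss with respect to A:
#   dL/dA = (I - P_A) * w * (pinv(A) * y)' + (I - P_A) * y * (pinv(A) * w)'
function proj_loss_grad(A::AbstractMatrix, y::AbstractVector, w::AbstractVector)
    M  = A' * A
    cy = M \ (A' * y)         # pinv(A) * y
    cw = M \ (A' * w)         # pinv(A) * w
    ry = y - A * cy           # (I - P_A) * y, the residual of y
    rw = w - A * cw           # (I - P_A) * w, the residual of w
    return rw * cy' + ry * cw'
end
```

Forming A' * A squares the condition number of A, so for ill-conditioned network outputs (especially in Float32) this can lose more accuracy than a QR-based projection; comparing proj_loss_grad against central finite differences of proj_loss is a cheap sanity check.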

2.) Using P_A * y = sum over i of (q_i' * y) * q_i, where the q_i are the columns of the thin Q factor from the QR factorization A = Q * R.
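The column-sum formula and the matrix form Q * (Q' * y) are the same computation. One pitfall worth spelling out: only the first N columns of Q (the thin factor) span range(A); if the sum runs over all 2N columns of the full square Q, the projector becomes the identity and P_A * y is just y. The helper names below are hypothetical.

```julia
using LinearAlgebra

# Hypothetical helper: P_A * y as an explicit sum over the columns q_i of Q,
#   P_A * y = sum_i (q_i' * y) * q_i
function project_colsum(Q::AbstractMatrix, y::AbstractVector)
    return sum(dot(Q[:, i], y) * Q[:, i] for i in 1:size(Q, 2))
end

# Same projection in matrix form: two matrix-vector products
project_matrix(Q::AbstractMatrix, y::AbstractVector) = Q * (Q' * y)

# Example with a 6 × 3 matrix of full column rank
A = [1.0 2 0; 0 1 1; 1 0 1; 2 1 0; 0 0 3; 1 1 1]
y = [1.0, -2.0, 0.5, 3.0, 1.0, -1.0]

Qthin = Matrix(qr(A).Q)                      # 6 × 3, columns span range(A)
Qfull = qr(A).Q * Matrix{Float64}(I, 6, 6)   # 6 × 6, columns span all of R^6

project_colsum(Qthin, y) ≈ project_matrix(Qthin, y)   # true
project_matrix(Qfull, y) ≈ y                          # true: full Q gives the identity
```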

Here I can’t even get the gradient to work on the CPU. A minimal example:

using LinearAlgebra
using Zygote

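function f2(V, y)
    # thin QR decomposition: the first N columns of Q span range(V)
    Q, _ = qr(V)
    Q = Matrix(Q)        # materialize the thin (2N × N) Q factor
    # project y onto range(V): P_V * y = Q * (Q' * y)
    p = Q * (Q' * y)
    return sum(p)
end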

N = 16
y = zeros(Float32, 2 * N)
y[1] = 1.0
V = rand(Float32, (2*N, N)) 
gradient(x -> f2(x, y), V)

But computing the gradient fails with an error coming from the QR factorization:

ERROR: LoadError: Mutating arrays is not supported
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] (::Zygote.var"#399#400")(#unused#::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/lib/array.jl:58
  [3] (::Zygote.var"#2253#back#401"{Zygote.var"#399#400"})(Δ::Nothing)
    @ Zygote ~/.julia/packages/ZygoteRules/OjfTt/src/adjoint.jl:59
  [4] Pullback
    @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/dense.jl:138 [inlined]
  [5] (::typeof(∂(triu!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
  [6] Pullback
    @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/generic.jl:435 [inlined]
  [7] (::typeof(∂(triu!)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
  [8] Pullback
    @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/qr.jl:435 [inlined]
  [9] (::typeof(∂(getproperty)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
 [10] Pullback
    @ ~/.julia/packages/ZygoteRules/OjfTt/src/ZygoteRules.jl:11 [inlined]
 [11] Pullback
    @ /Users/julia/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/qr.jl:127 [inlined]
 [12] (::typeof(∂(iterate)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
 [13] Pullback
    @ ./tuple.jl:94 [inlined]
 [14] (::typeof(∂(indexed_iterate)))(Δ::Nothing)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/Zygote/fjuG8/src/tools/builtins.jl:17 [inlined]
 [16] Pullback
    @ ~/source/repos/picfun/test_QR_slice_backprop.jl:21 [inlined]
 [17] (::typeof(∂(f2)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface2.jl:0
 [18] Pullback
    @ ~/source/repos/picfun/test_QR_slice_backprop.jl:56 [inlined]
 [19] (::Zygote.var"#41#42"{typeof(∂(#1))})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface.jl:41
 [20] gradient(f::Function, args::Matrix{Float32})
    @ Zygote ~/.julia/packages/Zygote/fjuG8/src/compiler/interface.jl:59
 [21] top-level scope

Is there any way to implement such a loss in Julia that runs on the GPU?

The pinv issue looks to be fixed by "Make istriu, istril, and pinv more gpu friendly" by pabloferz (JuliaLang/julia pull request #39756, linked from the CUDA.jl issue); I just tested it locally. For QR support, the PRs to follow are "Implement QR pullback" by Kolaru (JuliaDiff/ChainRules.jl pull request #306) and "Support for qr decomposition pullback" by rkube (JuliaDiff/ChainRules.jl pull request #469).
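Until a QR pullback lands, one possible stopgap (an untested suggestion, not something taken from the linked PRs) is to build the orthonormal basis without qr, using modified Gram-Schmidt written so that it never mutates an array: vectors are rebound and columns appended with hcat instead of being overwritten in place. Zygote can generally differentiate loops, indexing, dot, norm and hcat, so a loss built this way should avoid the "Mutating arrays is not supported" error, but that is an assumption, as is its behavior on CuArrays, where the O(N²) small operations will likely be slow for large N. Modified Gram-Schmidt in Float32 can also lose orthogonality for ill-conditioned V. The names gram_schmidt and f2_gs are hypothetical.

```julia
using LinearAlgebra

# Hypothetical helper: orthonormal basis of range(V) via modified Gram-Schmidt.
# No array is modified in place: v and Q are rebound to new arrays each step,
# which is the style Zygote needs (it rejects setindex!-style mutation).
function gram_schmidt(V::AbstractMatrix)
    Q = V[:, 1:1] ./ norm(V[:, 1])        # first basis vector, as an m × 1 matrix
    for j in 2:size(V, 2)
        v = V[:, j]
        # subtract the component along each existing basis vector, one at a time
        for i in 1:size(Q, 2)
            q = Q[:, i]
            v = v - dot(q, v) * q
        end
        Q = hcat(Q, v ./ norm(v))         # append the normalized new direction
    end
    return Q
end

# Hypothetical variant of f2 with the qr call replaced by gram_schmidt
function f2_gs(V, y)
    Q = gram_schmidt(V)
    return sum(Q * (Q' * y))              # sum of P_V * y
end
```

With Zygote loaded, the gradient call would then mirror the original example, gradient(x -> f2_gs(x, y), V).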