Efficient GPU dense-sparse matmul differentiable with Zygote

I’m trying to train some deep learning graph models and I’m looking for an efficient way of performing a dense-sparse matrix product on the GPU that’s Zygote-differentiable with respect to the entries of the dense matrix. I was surprised that x * A doesn’t work, and some experimentation with various transposes didn’t clarify exactly how I should be doing this.

Minimal working example below. Thanks in advance.

using CUDA, Zygote, Random, SparseArrays
CUDA.allowscalar(false)

n, m = 100, 13

# get a symmetric A
A = sprand(n, n, 0.1)
A = A + A'
A = CUDA.CUSPARSE.CuSparseMatrixCSC(A)

# sparse-dense works fine - just not what I'm after
x = CuMatrix(rand(n, m))
gradient(x -> sum(A * x), x)

xt = copy(x')  # make an [m x n] matrix
gradient(x -> sum(x * A), xt)              # indexerror - first stacktrace
gradient(x -> sum(A * x'), xt)             # works fine
gradient(x -> sum((A * x')'), xt)          # indexerror - second stacktrace
gradient(x -> sum(copy((A * x')')), xt)    # works

Stacktrace for gradient(x -> sum(x * A), xt)

ERROR: scalar getindex is disallowed
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] assertscalar(op::String)
    @ GPUArrays ~/.julia/packages/GPUArrays/0ShDd/src/host/indexing.jl:62
  [3] getindex(::CuArray{Float64, 2}, ::Int64, ::Int64)
    @ GPUArrays ~/.julia/packages/GPUArrays/0ShDd/src/host/indexing.jl:104
  [4] _generic_matmatmul!(C::Matrix{Float64}, tA::Char, tB::Char, A::CuArray{Float64, 2}, B::CUDA.CUSPARSE.CuSparseMatrixCSC{Float64}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:835
  [5] generic_matmatmul!(C::Matrix{Float64}, tA::Char, tB::Char, A::CuArray{Float64, 2}, B::CUDA.CUSPARSE.CuSparseMatrixCSC{Float64}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:802
  [6] mul!
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:302 [inlined]
  [7] mul!
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275 [inlined]
  [8] *
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:153 [inlined]
  [9] rrule
    @ ~/.julia/packages/ChainRules/vz8Io/src/rulesets/Base/arraymath.jl:40 [inlined]
 [10] chain_rrule
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:89 [inlined]
 [11] macro expansion
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0 [inlined]
 [12] _pullback(::Zygote.Context, ::typeof(*), ::CuArray{Float64, 2}, ::CUDA.CUSPARSE.CuSparseMatrixCSC{Float64})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:9
 [13] _pullback
    @ ./REPL[15]:1 [inlined]
 [14] _pullback(ctx::Zygote.Context, f::var"#5#6", args::CuArray{Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [15] _pullback(f::Function, args::CuArray{Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:34
 [16] pullback(f::Function, args::CuArray{Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:40
 [17] gradient(f::Function, args::CuArray{Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:58
 [18] top-level scope
    @ REPL[15]:1
 [19] top-level scope
    @ ~/.julia/packages/CUDA/k52QH/src/initialization.jl:81

Stacktrace for gradient(x -> sum((A * x')'), xt)

ERROR: scalar getindex is disallowed
Stacktrace:
  [1] error(s::String)
    @ Base ./error.jl:33
  [2] assertscalar(op::String)
    @ GPUArrays ~/.julia/packages/GPUArrays/0ShDd/src/host/indexing.jl:62
  [3] getindex(::CuArray{Float64, 2}, ::Int64, ::Int64)
    @ GPUArrays ~/.julia/packages/GPUArrays/0ShDd/src/host/indexing.jl:104
  [4] _generic_matmatmul!(C::CuArray{Float64, 2}, tA::Char, tB::Char, A::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, B::CuArray{Float64, 2}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:835
  [5] generic_matmatmul!(C::CuArray{Float64, 2}, tA::Char, tB::Char, A::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, B::CuArray{Float64, 2}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
    @ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:802
  [6] mul!
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:302 [inlined]
  [7] mul!
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275 [inlined]
  [8] *
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:153 [inlined]
  [9] #1523
    @ ~/.julia/packages/ChainRules/vz8Io/src/rulesets/Base/arraymath.jl:31 [inlined]
 [10] Thunk
    @ ~/.julia/packages/ChainRulesCore/1qau5/src/differentials/thunks.jl:98 [inlined]
 [11] unthunk
    @ ~/.julia/packages/ChainRulesCore/1qau5/src/differentials/thunks.jl:99 [inlined]
 [12] unthunk
    @ ~/.julia/packages/ChainRulesCore/1qau5/src/differentials/thunks.jl:120 [inlined]
 [13] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:41 [inlined]
 [14] map
    @ ./tuple.jl:215 [inlined]
 [15] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:42 [inlined]
 [16] ZBack
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:77 [inlined]
 [17] Pullback
    @ ./REPL[16]:1 [inlined]
 [18] (::typeof(∂(#7)))(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [19] (::Zygote.var"#41#42"{typeof(∂(#7))})(Δ::Float64)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
 [20] gradient(f::Function, args::CuArray{Float64, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [21] top-level scope
    @ REPL[16]:1
 [22] top-level scope
    @ ~/.julia/packages/CUDA/k52QH/src/initialization.jl:81

AFAIK CUSPARSE only supports sparse × dense products, and there's no native GPU implementation of the sparse array types, so the dense × sparse case falls back to the generic matrix-matrix multiplication — which does scalar indexing, hence the error.
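Given that, one workaround is to route the dense × sparse product through the sparse × dense kernel that CUSPARSE does support, using the identity x * A == (A' * x')'. A minimal sketch under the assumptions of the MWE above (the helper name is mine; the copy materialises the lazy adjoint wrapper into a plain CuMatrix, which is what made the last MWE line work):

```julia
# Sketch: dense × sparse on GPU via CUSPARSE's sparse × dense kernel.
# Uses the identity x * A == (A' * x')'. Since A here is symmetric,
# A' == A and we can skip transposing the sparse matrix.
dense_times_sparse(x, A) = copy((A * x')')

xA = dense_times_sparse(xt, A)  # [m x n] result, no scalar indexing
gradient(x -> sum(dense_times_sparse(x, A)), xt)  # works, per the MWE above
```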

Turns out the issue is related to an optimisation of the gradient of sum: the pullback is seeded with a FillArrays.Fill (visible in the second stacktrace), which then hits the generic matmul path and scalar indexing. If I change sum((A * x')') to sum((A * x')' * CuMatrix(rand(n, m))) it works fine.

I’m unsure if that optimisation is helpful on GPU or if it should be considered a bug, but it’s narrow enough that it only really affected my minimal working example (the motivating problem I had featured a different indexerror that I’ve since tracked down).
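If the transpose-and-copy dance gets tedious, another option is to wrap it in a helper with a hand-written ChainRules pullback, so Zygote never differentiates through the sparse multiply at all. A hedged sketch, not a tested implementation — the helper name is mine, it assumes a recent ChainRulesCore, and A is treated as a non-differentiable constant, matching the original goal of differentiating only with respect to the dense matrix:

```julia
using ChainRulesCore

# Hypothetical helper: x * A computed as (A * x')', valid here because
# A is symmetric (so A' == A).
spmm(x, A) = copy((A * x')')

function ChainRulesCore.rrule(::typeof(spmm), x, A)
    y = spmm(x, A)
    function spmm_pullback(ȳ)
        # For y = x * A the adjoint is x̄ = ȳ * A', i.e. (A * ȳ')' by the
        # same transpose trick. Densify ȳ first, since sum's gradient may
        # seed the pullback with a FillArrays.Fill (as in the stacktraces).
        ȳd = ȳ isa CuArray ? ȳ : CuArray(collect(ȳ))
        x̄ = copy((A * ȳd')')
        return (NoTangent(), x̄, NoTangent())
    end
    return y, spmm_pullback
end

gradient(x -> sum(spmm(x, A)), xt)  # uses the custom pullback
```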