I’m training some graph deep-learning models and I’m looking for an efficient way to compute a dense–sparse matrix product on the GPU that is Zygote-differentiable with respect to the entries of the dense matrix. I was surprised that `x * A`
doesn’t work, and experimenting with various transposes didn’t make it clear how I should be doing this.
Minimal working example below. Thanks in advance.
# Minimal reproducer: Zygote gradients of dense * sparse matrix products on the GPU.
using CUDA, Zygote, Random, SparseArrays
# Make scalar indexing of GPU arrays an error, so any fallback to a generic
# element-by-element code path fails loudly ("scalar getindex is disallowed").
CUDA.allowscalar(false)
n, m = 100, 13
# get a symmetric A
A = sprand(n, n, 0.1)
A = A + A'
# Move the sparse matrix to the GPU in CSC format.
A = CUDA.CUSPARSE.CuSparseMatrixCSC(A)
# sparse-dense works fine - just not what I'm after
x = CuMatrix(rand(n, m))
gradient(x -> sum(A * x), x)
xt = copy(x') # make an [m x n] matrix
# Fails already in the forward pass: per the first stack trace, CuArray * CuSparseMatrixCSC
# reaches LinearAlgebra's generic _generic_matmatmul! (with a CPU Matrix{Float64} as the
# output), which indexes the CuArray element by element.
gradient(x -> sum(x * A), xt) # "scalar getindex is disallowed" - first stacktrace
gradient(x -> sum(A * x'), xt) # works fine
# Mathematically equal to x * A because A is symmetric: (A * x')' == x * A' == x * A.
# The forward pass runs, but per the second stack trace the pullback (a ChainRules
# arraymath.jl thunk) multiplies a FillArrays.Fill cotangent by a CuArray through the
# generic _generic_matmatmul!, which triggers scalar indexing.
gradient(x -> sum((A * x')'), xt) # "scalar getindex is disallowed" - second stacktrace
# Materialising the adjoint with copy lets the same (x * A-equivalent) computation
# differentiate without error.
gradient(x -> sum(copy((A * x')')), xt) # works
Stacktrace for gradient(x -> sum(x * A), xt)
ERROR: scalar getindex is disallowed
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] assertscalar(op::String)
@ GPUArrays ~/.julia/packages/GPUArrays/0ShDd/src/host/indexing.jl:62
[3] getindex(::CuArray{Float64, 2}, ::Int64, ::Int64)
@ GPUArrays ~/.julia/packages/GPUArrays/0ShDd/src/host/indexing.jl:104
[4] _generic_matmatmul!(C::Matrix{Float64}, tA::Char, tB::Char, A::CuArray{Float64, 2}, B::CUDA.CUSPARSE.CuSparseMatrixCSC{Float64}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
@ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:835
[5] generic_matmatmul!(C::Matrix{Float64}, tA::Char, tB::Char, A::CuArray{Float64, 2}, B::CUDA.CUSPARSE.CuSparseMatrixCSC{Float64}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
@ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:802
[6] mul!
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:302 [inlined]
[7] mul!
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275 [inlined]
[8] *
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:153 [inlined]
[9] rrule
@ ~/.julia/packages/ChainRules/vz8Io/src/rulesets/Base/arraymath.jl:40 [inlined]
[10] chain_rrule
@ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:89 [inlined]
[11] macro expansion
@ ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0 [inlined]
[12] _pullback(::Zygote.Context, ::typeof(*), ::CuArray{Float64, 2}, ::CUDA.CUSPARSE.CuSparseMatrixCSC{Float64})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:9
[13] _pullback
@ ./REPL[15]:1 [inlined]
[14] _pullback(ctx::Zygote.Context, f::var"#5#6", args::CuArray{Float64, 2})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[15] _pullback(f::Function, args::CuArray{Float64, 2})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:34
[16] pullback(f::Function, args::CuArray{Float64, 2})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:40
[17] gradient(f::Function, args::CuArray{Float64, 2})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:58
[18] top-level scope
@ REPL[15]:1
[19] top-level scope
@ ~/.julia/packages/CUDA/k52QH/src/initialization.jl:81
Stacktrace for gradient(x -> sum((A * x')'), xt)
ERROR: scalar getindex is disallowed
Stacktrace:
[1] error(s::String)
@ Base ./error.jl:33
[2] assertscalar(op::String)
@ GPUArrays ~/.julia/packages/GPUArrays/0ShDd/src/host/indexing.jl:62
[3] getindex(::CuArray{Float64, 2}, ::Int64, ::Int64)
@ GPUArrays ~/.julia/packages/GPUArrays/0ShDd/src/host/indexing.jl:104
[4] _generic_matmatmul!(C::CuArray{Float64, 2}, tA::Char, tB::Char, A::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, B::CuArray{Float64, 2}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
@ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:835
[5] generic_matmatmul!(C::CuArray{Float64, 2}, tA::Char, tB::Char, A::FillArrays.Fill{Float64, 2, Tuple{Base.OneTo{Int64}, Base.OneTo{Int64}}}, B::CuArray{Float64, 2}, _add::LinearAlgebra.MulAddMul{true, true, Bool, Bool})
@ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:802
[6] mul!
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:302 [inlined]
[7] mul!
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275 [inlined]
[8] *
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:153 [inlined]
[9] #1523
@ ~/.julia/packages/ChainRules/vz8Io/src/rulesets/Base/arraymath.jl:31 [inlined]
[10] Thunk
@ ~/.julia/packages/ChainRulesCore/1qau5/src/differentials/thunks.jl:98 [inlined]
[11] unthunk
@ ~/.julia/packages/ChainRulesCore/1qau5/src/differentials/thunks.jl:99 [inlined]
[12] unthunk
@ ~/.julia/packages/ChainRulesCore/1qau5/src/differentials/thunks.jl:120 [inlined]
[13] wrap_chainrules_output
@ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:41 [inlined]
[14] map
@ ./tuple.jl:215 [inlined]
[15] wrap_chainrules_output
@ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:42 [inlined]
[16] ZBack
@ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:77 [inlined]
[17] Pullback
@ ./REPL[16]:1 [inlined]
[18] (::typeof(∂(#7)))(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[19] (::Zygote.var"#41#42"{typeof(∂(#7))})(Δ::Float64)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:41
[20] gradient(f::Function, args::CuArray{Float64, 2})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
[21] top-level scope
@ REPL[16]:1
[22] top-level scope
@ ~/.julia/packages/CUDA/k52QH/src/initialization.jl:81