Flux models with `SubArray` parameters don't work on GPU

Hi everyone,

I’m trying to compute the gradient for a “submodel” of a Flux model. That is, I’m defining layers with SubArrays of the original weights as parameters.

julia> using Flux

julia> X = rand(4, 1000) |> gpu;

julia> Y = rand(2, 1000) |> gpu;

julia> model = Chain(
           Dense(4, 4),
           Dense(4, 2)
       ) |> gpu
Chain(Dense(4, 4), Dense(4, 2))

julia> submodel = Chain(
           Dense(view(model[1].weight, 1:2, :), view(model[1].bias, 1:2), model[1].σ),
           Dense(view(model[2].weight, :, 1:2), view(model[2].bias, :), model[2].σ)
       )
Chain(Dense(4, 2), Dense(2, 2))

julia> subgrad = gradient(params(submodel)) do
           Flux.mse(submodel(X), Y)
       end

This code works on CPU, but keeping the gpu calls gives this scary error message:

Click to expand the error message
ERROR: LoadError: MethodError: no method matching generic_matmatmul!(::CUDA.CuArray{Float32, 2}, ::LinearAlgebra.Adjoint{Float32, SubArray{Float32, 2, CUDA.CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, ::CUDA.CuArray{Float32, 2}, ::Bool, ::Bool)
Closest candidates are:
  generic_matmatmul!(::Union{GPUArrays.AbstractGPUArray{R, N}, Base.LogicalIndex{R, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{R, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{R, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where {var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{R, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, LinearAlgebra.Adjoint{R, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.Diagonal{R, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.LowerTriangular{R, var"#s7"} where 
var"#s7"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.Transpose{R, var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.Tridiagonal{R, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.UnitLowerTriangular{R, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.UnitUpperTriangular{R, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.UpperTriangular{R, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{R, N}, PermutedDimsArray{R, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Union{GPUArrays.AbstractGPUArray{T, N}, Base.LogicalIndex{T, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{T, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{T, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where {var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{T, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", 
var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, LinearAlgebra.Adjoint{T, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.Diagonal{T, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.LowerTriangular{T, var"#s7"} where var"#s7"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.Transpose{T, var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.Tridiagonal{T, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.UnitLowerTriangular{T, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.UnitUpperTriangular{T, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.UpperTriangular{T, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{T, N}, PermutedDimsArray{T, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Union{GPUArrays.AbstractGPUArray{S, N}, Base.LogicalIndex{S, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{S, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{S, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where 
{var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{S, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, LinearAlgebra.Adjoint{S, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.Diagonal{S, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.LowerTriangular{S, var"#s7"} where var"#s7"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.Transpose{S, var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.Tridiagonal{S, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.UnitLowerTriangular{S, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.UnitUpperTriangular{S, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.UpperTriangular{S, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{S, N}, PermutedDimsArray{S, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Number, ::Number) where {T, S, R} at /user/acarvalh/home/.julia/packages/GPUArrays/0ShDd/src/host/linalg.jl:102
Stacktrace:
  [1] gemm_dispatch!(C::CUDA.CuArray{Float32, 2}, A::LinearAlgebra.Adjoint{Float32, SubArray{Float32, 2, CUDA.CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, B::CUDA.CuArray{Float32, 2}, alpha::Bool, beta::Bool)
    @ CUDA.CUBLAS ~/.julia/packages/CUDA/k52QH/lib/cublas/linalg.jl:226
  [2] mul!
    @ ~/.julia/packages/CUDA/k52QH/lib/cublas/linalg.jl:243 [inlined]
  [3] mul!
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275 [inlined]
  [4] *
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:153 [inlined]
  [5] #1525
    @ ~/.julia/packages/ChainRules/vz8Io/src/rulesets/Base/arraymath.jl:35 [inlined]
  [6] Thunk
    @ ~/.julia/packages/ChainRulesCore/oS5wQ/src/differentials/thunks.jl:98 [inlined]
  [7] unthunk
    @ ~/.julia/packages/ChainRulesCore/oS5wQ/src/differentials/thunks.jl:99 [inlined]
  [8] unthunk
    @ ~/.julia/packages/ChainRulesCore/oS5wQ/src/differentials/thunks.jl:120 [inlined]
  [9] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:41 [inlined]
 [10] map
    @ ./tuple.jl:215 [inlined]
 [11] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:42 [inlined]
 [12] ZBack
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:77 [inlined]
 [13] Pullback
    @ ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:147 [inlined]
 [14] (::typeof(∂(λ)))(Δ::CUDA.CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:36 [inlined]
 [16] (::typeof(∂(applychain)))(Δ::CUDA.CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:38 [inlined]
 [18] (::typeof(∂(λ)))(Δ::CUDA.CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [19] Pullback
    @ OMITTED [inlined]
 [20] (::typeof(∂(#1)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(#1)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:252
 [22] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [23] top-level scope

I guess it means that the SubArrays are causing type confusion for generic_matmatmul!().

julia> typeof(params(model)[1])
CUDA.CuArray{Float32, 2}

julia> typeof(params(submodel)[1])
SubArray{Float32, 2, CUDA.CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}

So, I’m trying to find a way around this issue that doesn’t involve converting the SubArrays into CuArrays since that would imply too much memory usage in my use case. Do you have any clues?

2 Likes