Flux models with `SubArray` parameters don't work on GPU

Hi everyone,

I’m trying to compute the gradient for a “submodel” of a Flux model. That is, I’m defining layers with SubArrays of the original weights as parameters.

julia> using Flux

julia> X = rand(4, 1000) |> gpu;

julia> Y = rand(2, 1000) |> gpu;

julia> model = Chain(
           Dense(4, 4),
           Dense(4, 2)
       ) |> gpu
Chain(Dense(4, 4), Dense(4, 2))

julia> submodel = Chain(
           Dense(view(model[1].weight, 1:2, :), view(model[1].bias, 1:2), model[1].σ),
           Dense(view(model[2].weight, :, 1:2), view(model[2].bias, :), model[2].σ)
       )
Chain(Dense(4, 2), Dense(2, 2))

julia> subgrad = gradient(params(submodel)) do
           Flux.mse(submodel(X), Y)
       end

This code works on CPU, but keeping the gpu calls gives this scary error message:

ERROR: LoadError: MethodError: no method matching generic_matmatmul!(::CUDA.CuArray{Float32, 2}, ::LinearAlgebra.Adjoint{Float32, SubArray{Float32, 2, CUDA.CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, ::CUDA.CuArray{Float32, 2}, ::Bool, ::Bool)
Closest candidates are:
  generic_matmatmul!(::Union{GPUArrays.AbstractGPUArray{R, N}, Base.LogicalIndex{R, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{R, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{R, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where {var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{R, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, LinearAlgebra.Adjoint{R, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.Diagonal{R, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.LowerTriangular{R, var"#s7"} where 
var"#s7"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.Transpose{R, var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.Tridiagonal{R, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.UnitLowerTriangular{R, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.UnitUpperTriangular{R, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.UpperTriangular{R, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{R, N}, PermutedDimsArray{R, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Union{GPUArrays.AbstractGPUArray{T, N}, Base.LogicalIndex{T, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{T, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{T, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where {var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{T, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", 
var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, LinearAlgebra.Adjoint{T, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.Diagonal{T, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.LowerTriangular{T, var"#s7"} where var"#s7"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.Transpose{T, var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.Tridiagonal{T, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.UnitLowerTriangular{T, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.UnitUpperTriangular{T, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.UpperTriangular{T, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{T, N}, PermutedDimsArray{T, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Union{GPUArrays.AbstractGPUArray{S, N}, Base.LogicalIndex{S, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{S, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{S, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where 
{var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{S, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, LinearAlgebra.Adjoint{S, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.Diagonal{S, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.LowerTriangular{S, var"#s7"} where var"#s7"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.Transpose{S, var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.Tridiagonal{S, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.UnitLowerTriangular{S, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.UnitUpperTriangular{S, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.UpperTriangular{S, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{S, N}, PermutedDimsArray{S, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Number, ::Number) where {T, S, R} at /user/acarvalh/home/.julia/packages/GPUArrays/0ShDd/src/host/linalg.jl:102
Stacktrace:
  [1] gemm_dispatch!(C::CUDA.CuArray{Float32, 2}, A::LinearAlgebra.Adjoint{Float32, SubArray{Float32, 2, CUDA.CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, B::CUDA.CuArray{Float32, 2}, alpha::Bool, beta::Bool)
    @ CUDA.CUBLAS ~/.julia/packages/CUDA/k52QH/lib/cublas/linalg.jl:226
  [2] mul!
    @ ~/.julia/packages/CUDA/k52QH/lib/cublas/linalg.jl:243 [inlined]
  [3] mul!
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275 [inlined]
  [4] *
    @ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:153 [inlined]
  [5] #1525
    @ ~/.julia/packages/ChainRules/vz8Io/src/rulesets/Base/arraymath.jl:35 [inlined]
  [6] Thunk
    @ ~/.julia/packages/ChainRulesCore/oS5wQ/src/differentials/thunks.jl:98 [inlined]
  [7] unthunk
    @ ~/.julia/packages/ChainRulesCore/oS5wQ/src/differentials/thunks.jl:99 [inlined]
  [8] unthunk
    @ ~/.julia/packages/ChainRulesCore/oS5wQ/src/differentials/thunks.jl:120 [inlined]
  [9] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:41 [inlined]
 [10] map
    @ ./tuple.jl:215 [inlined]
 [11] wrap_chainrules_output
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:42 [inlined]
 [12] ZBack
    @ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:77 [inlined]
 [13] Pullback
    @ ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:147 [inlined]
 [14] (::typeof(∂(λ)))(Δ::CUDA.CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [15] Pullback
    @ ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:36 [inlined]
 [16] (::typeof(∂(applychain)))(Δ::CUDA.CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [17] Pullback
    @ ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:38 [inlined]
 [18] (::typeof(∂(λ)))(Δ::CUDA.CuArray{Float32, 2})
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [19] Pullback
    @ OMITTED [inlined]
 [20] (::typeof(∂(#1)))(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
 [21] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(#1)), Zygote.Context})(Δ::Float32)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:252
 [22] gradient(f::Function, args::Zygote.Params)
    @ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
 [23] top-level scope

I guess it means that the SubArrays are causing a type mismatch in generic_matmatmul!(): no method accepts the wrapped array type that is passed as the second argument.

julia> typeof(params(model)[1])
CUDA.CuArray{Float32, 2}

julia> typeof(params(submodel)[1])
SubArray{Float32, 2, CUDA.CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}
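
To make the types in the stack trace easier to read, here is a minimal CPU-only sketch (plain `Matrix` instead of `CuArray`, so it does not reproduce the error itself). It only illustrates how taking the adjoint of a view yields a doubly wrapped type, `Adjoint` around `SubArray` around the parent array, which is the shape of the second argument in the failing `generic_matmatmul!` call above. The variable names `W`, `Wv` and `Wa` are just for illustration.

```julia
using LinearAlgebra

# Stand-in for a Dense layer's weight matrix (CPU array, assumption for illustration).
W = rand(Float32, 4, 4)

# A view of the first two rows, as in the submodel construction.
Wv = view(W, 1:2, :)

# The backward pass of a matrix product multiplies by the adjoint of the weights,
# which wraps the view once more.
Wa = Wv'

@show typeof(Wv)
@show typeof(Wa)

# The adjoint wraps the view, and the view wraps the original matrix.
@show Wa isa Adjoint{Float32, <:SubArray}   # true
@show parent(parent(Wa)) === W              # true
```

On GPU the innermost type would be `CuArray` instead of `Matrix`, and the log above suggests that this nested combination was not covered by the available methods at the time.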

So, I’m trying to find a way around this issue that doesn’t involve converting the SubArrays into CuArrays since that would imply too much memory usage in my use case. Do you have any clues?
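
For readers wondering why the views matter for memory, here is a small CPU-only sketch of the difference between a view and a slice copy in base Julia. It assumes nothing beyond `Base`; the names `W`, `Wv` and `Wc` are hypothetical. The same sharing semantics are what the submodel relies on.

```julia
# Stand-in for a weight matrix.
W = zeros(Float32, 4, 4)

# A view shares storage with W; indexing with ranges makes an independent copy.
Wv = view(W, 1:2, :)
Wc = W[1:2, :]

# Writing through the view changes the parent, but not the earlier copy.
Wv[1, 1] = 1f0

@show W[1, 1] == 1f0     # true: the view writes into W
@show Wc[1, 1] == 0f0    # true: the copy is independent
@show parent(Wv) === W   # true: no new weight storage behind the view
```

So materialising every SubArray as its own array would duplicate the selected weights, which is the extra memory cost I would like to avoid.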

Did you find any workaround for this?

Views may have poor data locality and thus not play well with many of the CUDA routines Flux uses under the hood. If the main concern is a lack of VRAM, CUDA.jl 3.4 now has support for unified arrays, which move data between CPU and GPU as needed. It won’t be as fast as using device allocations exclusively, but that hardly matters if the choice is between running slightly slower and not running at all.

Another, possibly complementary but more primitive, approach is to move the model’s weights off the GPU after the submodel is constructed. My understanding of the OP is that the two are not used simultaneously.

It works now. I just tried it.

Many of the packages involved have been updated since I created the topic (4 months ago). Most notably, CUDA.jl went from version 3.1.0 to 3.3.4.
I haven’t touched the project in the meantime, so I don’t know which update solved the issue.

Thank you all for the help.

The package updates fixed the problem.

Thank you for the answer, though. I didn’t know about unified arrays.
