Hi everyone,
I’m trying to compute the gradient for a “submodel” of a Flux model. That is, I’m defining layers whose parameters are `SubArray`s (views) of the original model’s weights.
julia> using Flux
julia> X = rand(4, 1000) |> gpu;
julia> Y = rand(2, 1000) |> gpu;
julia> model = Chain(
Dense(4, 4),
Dense(4, 2)
) |> gpu
Chain(Dense(4, 4), Dense(4, 2))
julia> submodel = Chain(
Dense(view(model[1].weight, 1:2, :), view(model[1].bias, 1:2), model[1].σ),
Dense(view(model[2].weight, :, 1:2), view(model[2].bias, :), model[2].σ)
)
Chain(Dense(4, 2), Dense(2, 2))
julia> subgrad = gradient(params(submodel)) do
Flux.mse(submodel(X), Y)
end
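This code works on the CPU, but keeping the `gpu` calls gives this scary error message (judging from the stack trace, it is raised during the backward pass, inside the pullback of the matrix multiplication in `Dense`):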
Click to expand the error message
ERROR: LoadError: MethodError: no method matching generic_matmatmul!(::CUDA.CuArray{Float32, 2}, ::LinearAlgebra.Adjoint{Float32, SubArray{Float32, 2, CUDA.CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, ::CUDA.CuArray{Float32, 2}, ::Bool, ::Bool)
Closest candidates are:
generic_matmatmul!(::Union{GPUArrays.AbstractGPUArray{R, N}, Base.LogicalIndex{R, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{R, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{R, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where {var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{R, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, LinearAlgebra.Adjoint{R, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.Diagonal{R, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.LowerTriangular{R, var"#s7"} where 
var"#s7"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.Transpose{R, var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.Tridiagonal{R, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.UnitLowerTriangular{R, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.UnitUpperTriangular{R, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{R, N}, LinearAlgebra.UpperTriangular{R, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{R, N}, PermutedDimsArray{R, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Union{GPUArrays.AbstractGPUArray{T, N}, Base.LogicalIndex{T, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{T, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{T, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where {var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{T, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", 
var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, LinearAlgebra.Adjoint{T, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.Diagonal{T, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.LowerTriangular{T, var"#s7"} where var"#s7"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.Transpose{T, var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.Tridiagonal{T, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.UnitLowerTriangular{T, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.UnitUpperTriangular{T, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{T, N}, LinearAlgebra.UpperTriangular{T, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{T, N}, PermutedDimsArray{T, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Union{GPUArrays.AbstractGPUArray{S, N}, Base.LogicalIndex{S, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{S, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{S, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where 
{var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{S, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, LinearAlgebra.Adjoint{S, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.Diagonal{S, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.LowerTriangular{S, var"#s7"} where var"#s7"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.Transpose{S, var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.Tridiagonal{S, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.UnitLowerTriangular{S, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.UnitUpperTriangular{S, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{S, N}, LinearAlgebra.UpperTriangular{S, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{S, N}, PermutedDimsArray{S, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Number, ::Number) where {T, S, R} at /user/acarvalh/home/.julia/packages/GPUArrays/0ShDd/src/host/linalg.jl:102
Stacktrace:
[1] gemm_dispatch!(C::CUDA.CuArray{Float32, 2}, A::LinearAlgebra.Adjoint{Float32, SubArray{Float32, 2, CUDA.CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, B::CUDA.CuArray{Float32, 2}, alpha::Bool, beta::Bool)
@ CUDA.CUBLAS ~/.julia/packages/CUDA/k52QH/lib/cublas/linalg.jl:226
[2] mul!
@ ~/.julia/packages/CUDA/k52QH/lib/cublas/linalg.jl:243 [inlined]
[3] mul!
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275 [inlined]
[4] *
@ /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:153 [inlined]
[5] #1525
@ ~/.julia/packages/ChainRules/vz8Io/src/rulesets/Base/arraymath.jl:35 [inlined]
[6] Thunk
@ ~/.julia/packages/ChainRulesCore/oS5wQ/src/differentials/thunks.jl:98 [inlined]
[7] unthunk
@ ~/.julia/packages/ChainRulesCore/oS5wQ/src/differentials/thunks.jl:99 [inlined]
[8] unthunk
@ ~/.julia/packages/ChainRulesCore/oS5wQ/src/differentials/thunks.jl:120 [inlined]
[9] wrap_chainrules_output
@ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:41 [inlined]
[10] map
@ ./tuple.jl:215 [inlined]
[11] wrap_chainrules_output
@ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:42 [inlined]
[12] ZBack
@ ~/.julia/packages/Zygote/6HN9x/src/compiler/chainrules.jl:77 [inlined]
[13] Pullback
@ ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:147 [inlined]
[14] (::typeof(∂(λ)))(Δ::CUDA.CuArray{Float32, 2})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[15] Pullback
@ ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:36 [inlined]
[16] (::typeof(∂(applychain)))(Δ::CUDA.CuArray{Float32, 2})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[17] Pullback
@ ~/.julia/packages/Flux/6o4DQ/src/layers/basic.jl:38 [inlined]
[18] (::typeof(∂(λ)))(Δ::CUDA.CuArray{Float32, 2})
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[19] Pullback
@ OMITTED [inlined]
[20] (::typeof(∂(#1)))(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface2.jl:0
[21] (::Zygote.var"#69#70"{Zygote.Params, typeof(∂(#1)), Zygote.Context})(Δ::Float32)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:252
[22] gradient(f::Function, args::Zygote.Params)
@ Zygote ~/.julia/packages/Zygote/6HN9x/src/compiler/interface.jl:59
[23] top-level scope
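I guess this means that the `SubArray`s are confusing the dispatch of `generic_matmatmul!()`. The offending argument is an `Adjoint` wrapping a `SubArray` of a `CuArray`. The GPUArrays method accepts an `Adjoint` of a GPU array, or a `SubArray` of a GPU array, but apparently not the two wrappers nested together.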
julia> typeof(params(model)[1])
CUDA.CuArray{Float32, 2}
julia> typeof(params(submodel)[1])
SubArray{Float32, 2, CUDA.CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}
So I’m looking for a workaround that doesn’t involve converting the `SubArray`s into `CuArray`s, since copying them would use too much memory in my use case. Do you have any ideas?