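Hi, I ran into a problem using `CuArray`s (CUDA.jl) for a specific operation: an in-place `mul!` into a view of the output, where the second factor is the transpose of a row-sliced view. This pattern is efficient on the CPU. It would be nice if the same code also worked on the GPU, so there would be no need to dispatch on the array type.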
a = cu(randn(10, 10))
b = cu(randn(10, 10))
c = similar(a)
# Desired form: both the destination and the second factor are views, so nothing is copied.
@views mul!(c[:, 2:end], a, transpose(b[1:end-1, :])) # Would like this but it errors
# Works: here the second factor is a view of the transpose, not a transpose of a view.
@views mul!(c[:, 2:end], a, transpose(b)[:, 1:end-1]) # Works fine
# Works, but copies `b[1:end-1, :]`. The destination must still be a view: a plain
# `c[:, 2:end]` is a temporary copy, so `mul!` would write into it and leave `c` unchanged.
mul!(@view(c[:, 2:end]), a, transpose(b[1:end-1, :])) # Works fine
The failing GPU call throws the following error:
ERROR: MethodError: no method matching generic_matmatmul!(::CuArray{Float32, 2}, ::CuArray{Float32, 2}, ::Transpose{Float32, SubArray{Float32, 2, CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, ::Bool, ::Bool)
Closest candidates are:
generic_matmatmul!(::Union{GPUArrays.AbstractGPUArray{R, N}, Base.LogicalIndex{R, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{R, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{R, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where {var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{R, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, Adjoint{R, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{R, N}, Diagonal{R, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{R, N}, LowerTriangular{R, var"#s7"} where var"#s7"<:GPUArrays.AbstractGPUArray{R, N}, Transpose{R, 
var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{R, N}, Tridiagonal{R, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{R, N}, UnitLowerTriangular{R, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{R, N}, UnitUpperTriangular{R, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{R, N}, UpperTriangular{R, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{R, N}, PermutedDimsArray{R, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Union{GPUArrays.AbstractGPUArray{T, N}, Base.LogicalIndex{T, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{T, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{T, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where {var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{T, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where 
{var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, Adjoint{T, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{T, N}, Diagonal{T, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{T, N}, LowerTriangular{T, var"#s7"} where var"#s7"<:GPUArrays.AbstractGPUArray{T, N}, Transpose{T, var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{T, N}, Tridiagonal{T, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{T, N}, UnitLowerTriangular{T, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{T, N}, UnitUpperTriangular{T, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{T, N}, UpperTriangular{T, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{T, N}, PermutedDimsArray{T, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Union{GPUArrays.AbstractGPUArray{S, N}, Base.LogicalIndex{S, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{S, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{S, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where {var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{S, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, 
var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, Adjoint{S, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{S, N}, Diagonal{S, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{S, N}, LowerTriangular{S, var"#s7"} where var"#s7"<:GPUArrays.AbstractGPUArray{S, N}, Transpose{S, var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{S, N}, Tridiagonal{S, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{S, N}, UnitLowerTriangular{S, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{S, N}, UnitUpperTriangular{S, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{S, N}, UpperTriangular{S, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{S, N}, PermutedDimsArray{S, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Number, ::Number) where {T, S, R} at /home/ubuntu/.julia/packages/GPUArrays/gjXOn/src/host/linalg.jl:102
Stacktrace:
[1] gemm_dispatch!(C::CuArray{Float32, 2}, A::CuArray{Float32, 2}, B::Transpose{Float32, SubArray{Float32, 2, CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, alpha::Bool, beta::Bool)
@ CUDA.CUBLAS ~/.julia/packages/CUDA/Px7QU/lib/cublas/linalg.jl:226
[2] mul!
@ ~/.julia/packages/CUDA/Px7QU/lib/cublas/linalg.jl:238 [inlined]
[3] mul!(C::CuArray{Float32, 2}, A::CuArray{Float32, 2}, B::Transpose{Float32, SubArray{Float32, 2, CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})
@ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275
[4] top-level scope
@ REPL[74]:1
Timing the corresponding versions on the CPU with BenchmarkTools' `@btime` gives:
# CPU timing: use plain `Array`s. Wrapping these in `cu(...)` would move them to the GPU,
# where the first (`@views`) call below hits the error reported above.
a = randn(10, 10)
b = randn(10, 10)
c = similar(a)
# Views for both the destination and the second factor: no copies.
@btime @views mul!($(c)[:, 2:end], $(a), transpose($(b)[1:end-1, :]))
# 424.925 ns (0 allocations: 0 bytes)
# No views: both `b[1:end-1, :]` and the destination `c[:, 2:end]` are copied.
# Note that this call therefore writes into a temporary and never updates `c` itself.
@btime mul!($(c)[:, 2:end], $(a), transpose($(b)[1:end-1, :]))
# 740.826 ns (2 allocations: 1.59 KiB)
# View of the transpose instead of transpose of a view (the form that works on the GPU).
@btime @views mul!($(c)[:, 2:end], $(a), transpose($(b))[:, 1:end-1])
# 890.596 ns (4 allocations: 224 bytes)
So on the CPU the first version (views everywhere) is the fastest and allocation-free. It would be nice if that same code also worked, and performed well, on the GPU.