CuArrays using @views+mul!+transpose+slicing

Hi, I ran into a problem when using CuArrays for a specific operation. The call is optimized for the CPU, and it would be nice if the same code worked on the GPU, so that there is no need to dispatch specifically on the array type.

using CUDA, LinearAlgebra

a = cu(randn(10, 10))
b = cu(randn(10, 10))
c = similar(a)

@views mul!(c[:, 2:end], a, transpose(b[1:end-1, :])) # Would like this but it errors
@views mul!(c[:, 2:end], a, transpose(b)[:, 1:end-1]) # Works fine
mul!(c[:, 2:end], a, transpose(b[1:end-1, :])) # Works fine (no @views, so both slices are copies and c itself is not updated)

The error from the GPU version is

ERROR: MethodError: no method matching generic_matmatmul!(::CuArray{Float32, 2}, ::CuArray{Float32, 2}, ::Transpose{Float32, SubArray{Float32, 2, CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, ::Bool, ::Bool)
Closest candidates are:
  generic_matmatmul!(::Union{GPUArrays.AbstractGPUArray{R, N}, Base.LogicalIndex{R, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{R, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{R, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where {var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{R, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, Adjoint{R, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{R, N}, Diagonal{R, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{R, N}, LowerTriangular{R, var"#s7"} where var"#s7"<:GPUArrays.AbstractGPUArray{R, N}, 
Transpose{R, var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{R, N}, Tridiagonal{R, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{R, N}, UnitLowerTriangular{R, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{R, N}, UnitUpperTriangular{R, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{R, N}, UpperTriangular{R, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{R, N}, PermutedDimsArray{R, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Union{GPUArrays.AbstractGPUArray{T, N}, Base.LogicalIndex{T, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{T, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{T, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where {var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{T, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", 
I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, Adjoint{T, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{T, N}, Diagonal{T, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{T, N}, LowerTriangular{T, var"#s7"} where var"#s7"<:GPUArrays.AbstractGPUArray{T, N}, Transpose{T, var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{T, N}, Tridiagonal{T, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{T, N}, UnitLowerTriangular{T, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{T, N}, UnitUpperTriangular{T, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{T, N}, UpperTriangular{T, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{T, N}, PermutedDimsArray{T, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Union{GPUArrays.AbstractGPUArray{S, N}, Base.LogicalIndex{S, var"#s5"} where var"#s5"<:GPUArrays.AbstractGPUArray, Base.ReinterpretArray{S, N, var"#s1", var"#s2", IsReshaped} where {var"#s13"<:GPUArrays.AbstractGPUArray, var"#s1", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s13", I, L} where {var"#s3", var"#s4", I, L}, var"#s13"}, IsReshaped}, Base.ReshapedArray{S, N, var"#s4", MI} where {var"#s14"<:GPUArrays.AbstractGPUArray, var"#s4"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s14", I, L} where {var"#s3", var"#s4", I, L}, var"#s14"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s14", I, L} where {var"#s3", var"#s2", I, L}, var"#s14"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, SubArray{S, N, var"#s5", I, L} where {var"#s15"<:GPUArrays.AbstractGPUArray, 
var"#s5"<:Union{Base.ReinterpretArray{var"#s2", var"#s1", var"#s11", var"#s21", IsReshaped} where {var"#s2", var"#s1", var"#s11", var"#s21"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, Base.ReshapedArray{var"#s4", var"#s3", var"#s41", MI} where {var"#s4", var"#s3", var"#s41"<:Union{Base.ReinterpretArray{var"#s1", var"#s5", var"#s11", var"#s2", IsReshaped} where {var"#s1", var"#s5", var"#s11", var"#s2"<:Union{SubArray{var"#s3", var"#s4", var"#s15", I, L} where {var"#s3", var"#s4", I, L}, var"#s15"}, IsReshaped}, SubArray{var"#s3", var"#s2", var"#s15", I, L} where {var"#s3", var"#s2", I, L}, var"#s15"}, MI<:Tuple{Vararg{Base.MultiplicativeInverses.SignedMultiplicativeInverse{Int64}, N} where N}}, var"#s15"}, I, L}, Adjoint{S, var"#s1"} where var"#s1"<:GPUArrays.AbstractGPUArray{S, N}, Diagonal{S, var"#s11"} where var"#s11"<:GPUArrays.AbstractGPUArray{S, N}, LowerTriangular{S, var"#s7"} where var"#s7"<:GPUArrays.AbstractGPUArray{S, N}, Transpose{S, var"#s6"} where var"#s6"<:GPUArrays.AbstractGPUArray{S, N}, Tridiagonal{S, var"#s12"} where var"#s12"<:GPUArrays.AbstractGPUArray{S, N}, UnitLowerTriangular{S, var"#s8"} where var"#s8"<:GPUArrays.AbstractGPUArray{S, N}, UnitUpperTriangular{S, var"#s10"} where var"#s10"<:GPUArrays.AbstractGPUArray{S, N}, UpperTriangular{S, var"#s9"} where var"#s9"<:GPUArrays.AbstractGPUArray{S, N}, PermutedDimsArray{S, N, var"#s4", var"#s3", var"#s2"} where {var"#s4", var"#s3", var"#s2"<:GPUArrays.AbstractGPUArray}} where N, ::Number, ::Number) where {T, S, R} at /home/ubuntu/.julia/packages/GPUArrays/gjXOn/src/host/linalg.jl:102
Stacktrace:
 [1] gemm_dispatch!(C::CuArray{Float32, 2}, A::CuArray{Float32, 2}, B::Transpose{Float32, SubArray{Float32, 2, CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}, alpha::Bool, beta::Bool)
   @ CUDA.CUBLAS ~/.julia/packages/CUDA/Px7QU/lib/cublas/linalg.jl:226
 [2] mul!
   @ ~/.julia/packages/CUDA/Px7QU/lib/cublas/linalg.jl:238 [inlined]
 [3] mul!(C::CuArray{Float32, 2}, A::CuArray{Float32, 2}, B::Transpose{Float32, SubArray{Float32, 2, CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}})
   @ LinearAlgebra /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.6/LinearAlgebra/src/matmul.jl:275
 [4] top-level scope
   @ REPL[74]:1

A quick timing check of the corresponding versions on the CPU gives

using BenchmarkTools, LinearAlgebra

a = randn(10, 10)
b = randn(10, 10)
c = similar(a)
@btime @views mul!($(c)[:, 2:end], $(a), transpose($(b)[1:end-1, :]))
  # 424.925 ns (0 allocations: 0 bytes)
@btime mul!($(c)[:, 2:end], $(a), transpose($(b)[1:end-1, :]))
  # 740.826 ns (2 allocations: 1.59 KiB)
@btime @views mul!($(c)[:, 2:end], $(a), transpose($(b))[:, 1:end-1])
  # 890.596 ns (4 allocations: 224 bytes)

so it would be nice if we could have the same version work well for the GPU.

This should fix it: https://github.com/JuliaGPU/GPUArrays.jl/pull/352

Do note that you’re dispatching to the fallback implementation here, which is slow. It’s better to try to dispatch to one of the CUBLAS methods. The reason that doesn’t happen here is that your type is too complex:

julia> typeof(@views(transpose(b[1:end-1, :])))
Transpose{Float32, SubArray{Float32, 2, CuArray{Float32, 2}, Tuple{UnitRange{Int64}, Base.Slice{Base.OneTo{Int64}}}, false}}

julia> @views(transpose(b[1:end-1, :])) isa StridedCuVecOrMat
false
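
To see how the wrappers stack up, here is a small CPU-only sketch. The helper name `wrapper_chain` is made up for illustration and is not part of CUDA.jl or GPUArrays; it simply follows `parent` until it reaches the underlying array. With a `CuArray` the last entry would be `:CuArray` instead of `:Array`.

```julia
using LinearAlgebra

# Hypothetical helper: list the wrapper types around an array, outermost first.
# Relies on Base's convention that `parent(x) === x` for an unwrapped array.
function wrapper_chain(x::AbstractArray)
    chain = Symbol[nameof(typeof(x))]
    while parent(x) !== x
        x = parent(x)
        push!(chain, nameof(typeof(x)))
    end
    return chain
end

b = randn(Float32, 10, 10)

# Transpose of a view: the combination from the failing call.
chain_a = wrapper_chain(transpose(view(b, 1:9, :)))   # [:Transpose, :SubArray, :Array]

# View of a transpose: same values, but the wrappers are nested the other way round.
chain_b = wrapper_chain(view(transpose(b), :, 1:9))   # [:SubArray, :Transpose, :Array]
```

Both expressions describe the same 10×9 matrix, but dispatch sees two different types, which is why seemingly similar calls can end up in very different methods.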

Also:

@views mul!(c[:, 2:end], a, transpose(b)[:, 1:end-1]) # Works fine

That one is even worse, as it falls back to the Base implementation. Always run with CUDA.allowscalar(false).

Sorry, not sure I’m completely following here. You are saying that with the PR you linked it will work, but it will not be very optimized for GPU?

This is part of a function call where a and b are parameters of the function (currently their shapes are such that b needs transposing) and c is created to store the result. I would like this function to be agnostic to the array type (I don’t want CUDA as a dependency, but it would be nice if it were possible to pass in CuArrays) and fast on both CPU and GPU if possible. The current call is the one that is optimized for the CPU, and it would be nice if the same call were handled at least okay for CuArrays.

I have no real clue about GPU programming, and that is why CuArrays are so cool: a lot of stuff just works, and it works fast and well. I don’t really know how I would check whether a CUBLAS method is used, or how I would try to target one.

With that PR, you’ll end up executing the generic GEMM method from GPUArrays, which is OK but slow. The second operation, which you said ‘works fine’, actually doesn’t: it triggers scalar iteration, which is extremely slow and should be avoided. If you want really fast GPU execution, you need to make sure your arrays are recognized as strided GPU arrays so that we can dispatch to the CUBLAS library. That involves making sure the memory is contiguous, and that you’re not using too many array wrappers (because of how Julia’s array hierarchy is currently designed, it’s hard to recognize GPU arrays when they are wrapped a bunch).
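
One way to follow that advice without depending on CUDA is sketched below. It is only a sketch under assumptions: the function name `shifted_mul!` is hypothetical, and the claim that this reaches CUBLAS for CuArrays is a guess based on the stack trace above, where the column view of c already shows up as a plain `CuArray`. The idea is to copy the row slice of b into a contiguous array, so that the only wrapper left on that argument is a single `transpose`.

```julia
using LinearAlgebra

# Hypothetical helper: writes a * transpose(b[1:end-1, :]) into c[:, 2:end].
# Uses only generic array operations, so it needs no CUDA dependency.
function shifted_mul!(c, a, b)
    bs = b[1:end-1, :]                         # copy: contiguous, same array type as b
    mul!(view(c, :, 2:end), a, transpose(bs))  # column view of c is contiguous
    return c
end

# CPU usage; with CuArrays one would pass cu(...) arrays instead.
a = randn(Float32, 10, 10)
b = randn(Float32, 10, 10)
c = zeros(Float32, 10, 10)
shifted_mul!(c, a, b)
```

The price is one extra allocation for `bs` per call, so on the CPU this is expected to be somewhat slower than the fully view-based version. When trying it with CuArrays, running with `CUDA.allowscalar(false)` as suggested above should turn any silent scalar fallback into an error.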

The “Works fine” was meant more as “this case does not error” than as a claim that it does something good. I was mostly curious as to why the method I wanted to use did not work at all, while some others that seemed very similar “worked fine”.

But thanks for the quick response, will have a look at how I can solve this.