This looks like JuliaGPU/Adapt.jl issue #21, "Use with multiple wrappers". The tl;dr is that one wrapper, like transpose(::CuArray), is usually fine, but with two or more, dispatch no longer sends the call to the CUDA routines. Instead you get the slow generic_matmatmul!, which works by indexing.
I think this is the multiplication in an ordinary Dense layer. I don't see this reshape in the trace: https://github.com/FluxML/Flux.jl/blob/15c85908fdc8766c71141c5cee24726b12b0583b/src/layers/basic.jl#L175-L176. Do you know what is creating a reshaped PermutedDimsArray?
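A reshaped PermutedDimsArray is exactly the two-wrapper case. A minimal CPU-only sketch of the nesting, assuming plain Arrays behave like CuArrays as far as the wrapper types go (only the dispatch target would differ on the GPU). The helper wrapper_depth is hypothetical, not part of Base or Adapt.jl; it just follows parent until it reaches the underlying storage.

```julia
using LinearAlgebra  # provides transpose / Transpose

# Plain Arrays stand in for CuArrays here (assumption: the wrapper
# nesting is the same; only the dispatch target would differ on GPU).
M = rand(Float32, 3, 2)
A = rand(Float32, 2, 3, 4)

t1 = transpose(M)                    # one wrapper: Transpose around a Matrix
p  = PermutedDimsArray(A, (2, 1, 3)) # lazy permutation, size (3, 2, 4)
r  = reshape(p, 3, 8)                # ReshapedArray around the PermutedDimsArray

# Hypothetical helper (not in Base or Adapt.jl): count wrapper layers by
# following `parent` until it returns the array itself (true for Array).
function wrapper_depth(x::AbstractArray)
    depth = 0
    while parent(x) !== x
        x = parent(x)
        depth += 1
    end
    return depth
end

println(typeof(r))
println("depth of M:  ", wrapper_depth(M))
println("depth of t1: ", wrapper_depth(t1))
println("depth of r:  ", wrapper_depth(r))
```

Built from a CuArray instead, r would presumably be the two-layer case from the issue, where multiplication falls back to generic_matmatmul!.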
Oh now it’s changed, with this:
What types are X and m.f(X)? batched_mul may reshape things, but always from 2 to 3 dimensions, not from 3 to 2.