You should be able to use batched_mul for this, but someone needs to merge #191 first:
julia> using NNlib # v0.7.4
julia> batched_mul(rand(2,3,5), rand(3,4,1))
ERROR: DimensionMismatch("batch size mismatch")
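Until that is merged, here is a rough sketch of a workaround for the case above, where the second array has a single batch slice that should be reused for every slice of the first. The name batched_mul_bcast is made up for illustration and is not part of NNlib; it only uses LinearAlgebra from the standard library, and I have not benchmarked it.

```julia
using LinearAlgebra  # for mul!

# Hypothetical helper (not an NNlib function): multiply every slice A[:, :, k]
# by the single slice B[:, :, 1]. Only handles size(B, 3) == 1.
function batched_mul_bcast(A::AbstractArray{<:Any,3}, B::AbstractArray{<:Any,3})
    size(B, 3) == 1 ||
        throw(DimensionMismatch("this sketch only handles size(B, 3) == 1"))
    size(A, 2) == size(B, 1) ||
        throw(DimensionMismatch("inner matrix dimensions must match"))

    T = promote_type(eltype(A), eltype(B))
    C = similar(A, T, size(A, 1), size(B, 2), size(A, 3))
    B1 = view(B, :, :, 1)  # the one matrix shared by all batches

    for k in axes(A, 3)
        # Write A[:, :, k] * B1 into C[:, :, k] without allocating temporaries
        mul!(view(C, :, :, k), view(A, :, :, k), B1)
    end
    return C
end

A = rand(2, 3, 5)
B = rand(3, 4, 1)
C = batched_mul_bcast(A, B)
size(C)  # (2, 4, 5)
```

With NNlib loaded, batched_mul(A, repeat(B, 1, 1, size(A, 3))) should also avoid the error, since both batch sizes then match, at the cost of copying B once per batch.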
Edit – in reply to a deleted comment.
See also a recent thread. However, I have no idea whether this is the bottleneck here.