Background
I am trying to implement a custom layer for a Flux.jl network (please point out if this exists already somewhere). What the layer should do is apply a 1x1x1 convolution to the concatenation of a bunch of tensors. To simplify the example slightly, lets work with 1d data.
xs -> Conv((1,), in => out)(cat(xs..., dims=2))
To save memory the concatenation should not be materialized.
My code
using CUDA
function linmerge1d(A, xs)
    setup = prepare_linmerge1d(A,xs)
    @cuda threads=256 linmerge1d_kernel!(setup.out, setup.A, setup.xs, setup.offsets)
    synchronize()
    setup.out
end
function linmerge1d_kernel!(out, A, xs, offsets)
    dim_space = 1
    dim_channel = 2
    dim_batch = 3
    index = threadIdx().x
    stride = blockDim().x
    for ispace in index:stride:size(out,dim_space)
        for (x, offset) in zip(xs, offsets)
            for ibatch in axes(x,dim_batch)
                for icin in axes(x, dim_channel)
                    for icout in axes(out, dim_channel)
                        @inbounds out[ispace, icout, ibatch] += A[icout, offset+icin] * x[ispace, icin, ibatch]
                    end
                end
            end
        end
    end
    return nothing
end
prepare_linmerge1d(A, xs) = prepare_linmerge1d(A, Tuple(xs))
@noinline function prepare_linmerge1d(A, xs::Tuple)
    dim_space = 1
    dim_channel = 2
    dim_batch = 3
    
    nb = size(first(xs), dim_batch)
    nspace = size(first(xs), dim_space)
    A = CuArray(A)
    xs = map(xs) do x
        @assert ndims(x) === 3
        @assert size(x,1) === nspace
        @assert size(x,3) === nb
        CuArray(x)
    end
    lens = map(x -> size(x, dim_channel), xs)
    stops = cumsum(lens)
    ncout = size(A, 1)
    @assert size(A,2) == last(stops)
    out = CUDA.zeros(nspace, ncout, nbatch)
    
    offsets = map(lens, stops) do len, stop
        stop - len
    end
    return (out=out, A=A,xs=xs, offsets=offsets)
end
nspace = 10^6
nbatch = 4
ncout = 4
xs = (
    randn(nspace, 4, nbatch),
    randn(nspace, 4, nbatch),
    randn(nspace, 4, nbatch),
    randn(nspace, 4, nbatch),
    randn(nspace, 4, nbatch),
)
A = randn(ncout, sum(x -> size(x,2), xs))
xs = map(CuArray, xs)
A = CuArray(A)
#prepare_linmerge1d(A, xs)
using BenchmarkTools
@btime linmerge1d(A, xs) #   315.913 ms (31 allocations: 1.09 KiB)
If I am not mistaken, this function is only 1e6*4*5*4*4 = 3.2e8 fma instructions. So the performance here is worse, than what I would expect from a single cpu thread. Any ideas?