I am trying to implement a custom layer for a
Flux.jl network (please point out if this exists already somewhere). What the layer should do is apply a 1x1x1 convolution to the concatenation of a bunch of tensors. To simplify the example slightly, lets work with 1d data.
xs -> Conv((1,), in => out)(cat(xs..., dims=2))
To save memory the concatenation should not be materialized.
using CUDA function linmerge1d(A, xs) setup = prepare_linmerge1d(A,xs) @cuda threads=256 linmerge1d_kernel!(setup.out, setup.A, setup.xs, setup.offsets) synchronize() setup.out end function linmerge1d_kernel!(out, A, xs, offsets) dim_space = 1 dim_channel = 2 dim_batch = 3 index = threadIdx().x stride = blockDim().x for ispace in index:stride:size(out,dim_space) for (x, offset) in zip(xs, offsets) for ibatch in axes(x,dim_batch) for icin in axes(x, dim_channel) for icout in axes(out, dim_channel) @inbounds out[ispace, icout, ibatch] += A[icout, offset+icin] * x[ispace, icin, ibatch] end end end end end return nothing end prepare_linmerge1d(A, xs) = prepare_linmerge1d(A, Tuple(xs)) @noinline function prepare_linmerge1d(A, xs::Tuple) dim_space = 1 dim_channel = 2 dim_batch = 3 nb = size(first(xs), dim_batch) nspace = size(first(xs), dim_space) A = CuArray(A) xs = map(xs) do x @assert ndims(x) === 3 @assert size(x,1) === nspace @assert size(x,3) === nb CuArray(x) end lens = map(x -> size(x, dim_channel), xs) stops = cumsum(lens) ncout = size(A, 1) @assert size(A,2) == last(stops) out = CUDA.zeros(nspace, ncout, nbatch) offsets = map(lens, stops) do len, stop stop - len end return (out=out, A=A,xs=xs, offsets=offsets) end nspace = 10^6 nbatch = 4 ncout = 4 xs = ( randn(nspace, 4, nbatch), randn(nspace, 4, nbatch), randn(nspace, 4, nbatch), randn(nspace, 4, nbatch), randn(nspace, 4, nbatch), ) A = randn(ncout, sum(x -> size(x,2), xs)) xs = map(CuArray, xs) A = CuArray(A) #prepare_linmerge1d(A, xs) using BenchmarkTools @btime linmerge1d(A, xs) # 315.913 ms (31 allocations: 1.09 KiB)
If I am not mistaken, this function is only
1e6*4*5*4*4 = 3.2e8 fma instructions. So the performance here is worse, than what I would expect from a single cpu thread. Any ideas?