Background
I am trying to implement a custom layer for a Flux.jl network (please point out if this already exists somewhere). The layer should apply a 1x1x1 convolution to the concatenation of a bunch of tensors. To simplify the example slightly, let's work with 1D data.
xs -> Conv((1,), in => out)(cat(xs..., dims=2))
To save memory, the concatenation should not be materialized.
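For reference, the fully materialized version would look roughly like the following (a minimal sketch with made-up sizes, just to pin down the semantics; the layout is Flux's usual (space, channels, batch)):

using Flux
xs = (randn(Float32, 100, 4, 2), randn(Float32, 100, 3, 2))  # (space, channels, batch)
cin = sum(x -> size(x, 2), xs)                               # total number of input channels
layer = Conv((1,), cin => 8)                                 # length-1 kernel = channel-mixing matrix
y = layer(cat(xs..., dims=2))                                # materializes the concatenation
size(y)                                                      # (100, 8, 2)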
My code
using CUDA
function linmerge1d(A, xs)
    # move inputs to the GPU (if needed), allocate the zeroed output, then run the kernel
    setup = prepare_linmerge1d(A, xs)
    @cuda threads=256 linmerge1d_kernel!(setup.out, setup.A, setup.xs, setup.offsets)
    synchronize()
    setup.out
end
function linmerge1d_kernel!(out, A, xs, offsets)
    dim_space = 1
    dim_channel = 2
    dim_batch = 3
    index = threadIdx().x
    stride = blockDim().x
    # each thread handles a strided subset of the spatial indices
    for ispace in index:stride:size(out, dim_space)
        for (x, offset) in zip(xs, offsets)
            for ibatch in axes(x, dim_batch)
                for icin in axes(x, dim_channel)
                    for icout in axes(out, dim_channel)
                        # 1x1 convolution: the channels of x correspond to
                        # columns offset+1:offset+size(x, dim_channel) of A
                        @inbounds out[ispace, icout, ibatch] += A[icout, offset+icin] * x[ispace, icin, ibatch]
                    end
                end
            end
        end
    end
    return nothing
end
prepare_linmerge1d(A, xs) = prepare_linmerge1d(A, Tuple(xs))

@noinline function prepare_linmerge1d(A, xs::Tuple)
    dim_space = 1
    dim_channel = 2
    dim_batch = 3
    nb = size(first(xs), dim_batch)
    nspace = size(first(xs), dim_space)
    A = CuArray(A)
    xs = map(xs) do x
        @assert ndims(x) === 3
        @assert size(x, dim_space) === nspace
        @assert size(x, dim_batch) === nb
        CuArray(x)
    end
    lens = map(x -> size(x, dim_channel), xs)
    stops = cumsum(lens)
    ncout = size(A, 1)
    @assert size(A, 2) == last(stops)
    out = CUDA.zeros(nspace, ncout, nb)
    # column offset of each tensor's channel block within A
    offsets = map(lens, stops) do len, stop
        stop - len
    end
    return (out=out, A=A, xs=xs, offsets=offsets)
end
nspace = 10^6
nbatch = 4
ncout = 4
xs = (
    randn(nspace, 4, nbatch),
    randn(nspace, 4, nbatch),
    randn(nspace, 4, nbatch),
    randn(nspace, 4, nbatch),
    randn(nspace, 4, nbatch),
)
A = randn(ncout, sum(x -> size(x,2), xs))
xs = map(CuArray, xs)
A = CuArray(A)
#prepare_linmerge1d(A, xs)
using BenchmarkTools
@btime linmerge1d(A, xs) # 315.913 ms (31 allocations: 1.09 KiB)
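To sanity-check the kernel output against the materialized version, something like the following can be used (truncating the spatial size so the CPU reference stays cheap; the reference is just a batched matmul):

xs_small = map(x -> Array(x)[1:100, :, :], xs)
A_small = Array(A)
xcat = cat(xs_small..., dims=2)                        # materialized concatenation
ref = similar(xcat, size(xcat, 1), size(A_small, 1), size(xcat, 3))
for b in axes(xcat, 3)
    # out[s, co, b] = sum_ci A[co, ci] * xcat[s, ci, b]
    ref[:, :, b] = xcat[:, :, b] * transpose(A_small)
end
maximum(abs, Array(linmerge1d(A_small, xs_small)) .- ref)  # small, up to Float32 rounding (CUDA.zeros gives a Float32 output buffer)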
If I am not mistaken, this function only needs 1e6 * 4 * 5 * 4 * 4 = 3.2e8 FMA operations (spatial points × input channels × tensors × output channels × batch), so the performance here is worse than what I would expect from a single CPU thread.
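Working that out (please correct me if my arithmetic is off): 3.2e8 FMAs in ~316 ms is about 1e9 FMA/s, i.e. roughly 2 GFLOP/s, which even a single CPU core with SIMD FMAs should beat comfortably. Any ideas?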