You can reshape the problem in the original question so that it needs just 3 calls of the Dense layer:
using Flux
den = Flux.Dense(30 => 4)
x = [i + j / 10 for i in 1:10, j in 1:20, b in 1:5]; # 10×20×5
selectdim(x, 2, 1:3); # 10×3×5
x2 = Flux.flatten(selectdim(x, 2, 1:3)) # 30×5
den(x2) # 4×5
# This is what I think your function is, but please provide code which runs:
myguess(den::Dense, sizeOut::Tuple, x; dim=2, plus=2) = mapreduce(
i -> reshape(den(Flux.flatten(selectdim(x, dim, i:i+plus))), sizeOut),
(x1, x2) -> cat(x1, x2; dims=dim),
1:(size(x, dim)-plus)
)
y1 = myguess(den, (4, 1, 5), x); size(y1)
# Original, overlapping windows:
# out[c, i, z] := sum(a,j) w[c, (a, j)] * x[a, i+j, z] (j in 1:3)
# Easier problem, non-overlapping:
# out[c, i, z] := sum(a,j) w[c, (a, j)] * x[a, j + 3i, z] (j in 1:3)
# reshape:
# out[c, i, z] := sum(aj) w[c, aj] * x[aj, 3i, z] (j in 1:3)
# then do 3i, 3i+1, 3i+2 separately:
myfun(den, x) = hcat(
den(reshape(x[:, 1:18, :], :, 6, 5)),
den(reshape(x[:, 2:19, :], :, 6, 5)),
den(reshape(x[:, 3:20, :], :, 6, 5)),
);
y3 = myfun(den, x);
y3 ≈ hcat(y1[:, 1:3:end, :], y1[:, 2:3:end, :], y1[:, 3:3:end, :]) # true
# about 10x faster than myguess, gradient about 15x faster
I didn’t try this on GPU, but it ought to work. I see now that more code has appeared since I started writing… something like this can surely be generalised. plus controls the number of terms you need: plus + 1 calls of the layer, so 3 calls for plus = 2.
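A rough sketch of how that generalisation might look, assuming a 3-dimensional input with the windows along the 2nd dimension. The name `strided_windows` and the `layer` argument are made up for illustration; `layer` is assumed to be anything that maps a features × columns matrix to an outputs × columns matrix, which a Dense layer does. The output columns come out grouped by starting offset, just like in myfun.

```julia
# Hypothetical generalisation of `myfun` (names are illustrative, not from Flux).
# `layer` is assumed to map a (features × columns) matrix to (outputs × columns).
function strided_windows(layer, x::AbstractArray{<:Any,3}; plus::Integer=2)
    w = plus + 1                  # window length along dimension 2
    nwin = size(x, 2) - plus      # number of overlapping windows, as in `myguess`
    @assert nwin >= 1 "input is too short for this window length"
    parts = map(1:w) do s
        n = length(s:w:nwin)      # windows starting at s, s+w, s+2w, ...
        n == 0 && return nothing  # this offset contributes no window
        cols = s:(s + n * w - 1)  # the n*w columns covered by those windows
        # Column-major reshape turns each run of w columns into one long column.
        xs = reshape(x[:, cols, :], :, n * size(x, 3))
        ys = layer(xs)
        reshape(ys, size(ys, 1), n, size(x, 3))
    end
    return cat(filter(!isnothing, parts)...; dims=2)
end
```

With the arrays above, `strided_windows(den, x; plus=2)` ought to reproduce y3, although I have only checked the logic with a plain matrix multiplication in place of the Dense layer.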
My version has the 2nd index in a different order. If that matters, you could of course fix it… by indexing afterwards, or perhaps by writing some kind of “inner not outer” variant of cat. The indexing x[:, 1:18, :] etc. could also be done better.
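The "indexing afterwards" option could be sketched like this. The helper name `restore_order` is hypothetical, and it assumes the columns are grouped by starting offset, as myfun produces them: all windows starting at 1, 4, 7, … first, then those starting at 2, 5, 8, …, and so on.

```julia
# Hypothetical helper: put the grouped columns back into natural window order.
function restore_order(y::AbstractArray{<:Any,3}; plus::Integer=2)
    w = plus + 1
    nwin = size(y, 2)
    # order[p] is the window index stored in column p of the grouped result.
    order = reduce(vcat, [collect(s:w:nwin) for s in 1:w])
    # invperm(order)[i] is the column that holds window i.
    return y[:, invperm(order), :]
end
```

With the arrays above, I would expect `restore_order(y3) ≈ y1` to hold. This costs one extra copy of the output; an interleaving variant of cat would avoid that copy.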