Workaround for subarray/selectdim on gpu or how does Flux.Conv do it?

You can rearrange the computation in the original question so that it needs just 3 calls of the Dense layer:

using Flux
den = Flux.Dense(30 => 4)
x = [i + j / 10 for i in 1:10, j in 1:20, b in 1:5]; # 10×20×5
selectdim(x, 2, 1:3);  # 10×3×5
x2 = Flux.flatten(selectdim(x, 2, 1:3))  # 30×5
den(x2)  # 4×5

# This is what I think your function is, but please provide code which runs:
myguess(den::Dense, sizeOut::Tuple, x; dim=2, plus=2) = mapreduce(
    i -> reshape(den(Flux.flatten(selectdim(x, dim, i:i+plus))), sizeOut),
    (x1, x2) -> cat(x1, x2; dims=dim),
    1:(size(x, dim)-plus)
)
y1 = myguess(den, (4, 1, 5), x); size(y1)

# Original, overlapping windows:
# out[c, i, z] := sum(a,j) w[c, (a, j)] * x[a, i+j, z]  (j in 1:3)

# Easier problem, non-overlapping:
# out[c, i, z] := sum(a,j) w[c, (a, j)] * x[a, j + 3i, z]  (j in 1:3)
# reshape:
# out[c, i, z] := sum(aj) w[c, aj] * x[aj, 3i, z]  (j in 1:3)
# then do the offsets 3i, 3i+1, 3i+2 separately, which together cover the overlapping windows:

myfun(den, x) = hcat(
    den(reshape(x[:, 1:18, :], :, 6, 5)),
    den(reshape(x[:, 2:19, :], :, 6, 5)),
    den(reshape(x[:, 3:20, :], :, 6, 5)),
);
y3 = myfun(den, x);
y3 ≈ hcat(y1[:, 1:3:end, :], y1[:, 2:3:end, :], y1[:, 3:3:end, :])  # true

# myfun is about 10x faster than myguess, and its gradient about 15x faster

I haven’t tried this on GPU, but it ought to work. I see you have now posted more code since I started writing; something like this can surely be generalised. The keyword plus controls the number of terms you need: plus + 1 calls of the layer, so 3 for plus=2.
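A minimal sketch of how that generalisation might look, assuming a 3-dimensional input with the windows taken along dimension 2. The name myfun_general is hypothetical (not from the question), and I use a Float32 input here only to match the layer's default parameter type. It makes one call of the layer per starting offset and keeps only the complete windows for each offset:

```julia
using Flux

# Hypothetical generalisation of myfun: one call of the layer per offset,
# so plus + 1 calls in total, instead of one call per window.
function myfun_general(den, x; plus=2)
    k = plus + 1                      # window length along dimension 2
    a, n, b = size(x)
    parts = map(1:k) do s
        m = (n - s + 1) ÷ k           # number of complete windows starting at offset s
        den(reshape(x[:, s:(s + k*m - 1), :], a * k, m, b))
    end
    return cat(parts...; dims=2)      # windows grouped by offset, as in myfun
end

den = Flux.Dense(30 => 4)
x = Float32[i + j / 10 for i in 1:10, j in 1:20, b in 1:5];  # 10×20×5
y = myfun_general(den, x; plus=2);
size(y)  # (4, 18, 5)
```

For plus=2 and this x, the slices are exactly 1:18, 2:19 and 3:20, so the result should agree with the hand-written three-term version. The slices still copy, so this is a starting point rather than a tuned implementation.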

Mine has the 2nd index in a different order. If that matters, you could of course fix it by indexing afterwards, or perhaps by writing some kind of “inner not outer” variant of cat. The indexing x[:, 1:18, :] etc. could also be done better.
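To illustrate the "indexing afterwards" option, here is a sketch that restores the original window order for the sizes used in this post. The variable perm and the inline reference y1 are my own; it assumes every offset contributes the same number of windows (6 here), and I have not checked how well this kind of indexing performs on GPU:

```julia
using Flux

den = Flux.Dense(30 => 4)
x = Float32[i + j / 10 for i in 1:10, j in 1:20, b in 1:5];  # 10×20×5

# Reference: one layer call per window, in the original window order.
y1 = cat((reshape(den(Flux.flatten(x[:, i:i+2, :])), 4, 1, 5) for i in 1:18)...; dims=2);

# Three-call version: windows grouped by offset (1,4,7,… then 2,5,8,… then 3,6,9,…).
y3 = hcat(
    den(reshape(x[:, 1:18, :], :, 6, 5)),
    den(reshape(x[:, 2:19, :], :, 6, 5)),
    den(reshape(x[:, 3:20, :], :, 6, 5)),
);

# Window s + 3q of y1 sits at column 6(s-1) + q + 1 of y3,
# so interleaving the three blocks of 6 columns recovers the original order.
perm = vec(permutedims(reshape(1:18, 6, 3)));  # 1, 7, 13, 2, 8, 14, …
y3[:, perm, :] ≈ y1  # expected to hold
```

An equivalent alternative would be to reshape y3 to 4×6×3×5, swap the two middle dimensions with permutedims, and reshape back to 4×18×5.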
