What is the Flux.flatten inverse operation?

My input data is multivariate time series, so the input array is width x length x batch. It appears a Dense layer requires the data to be in a single dimension, so I thought this is what Flux.flatten is for. On the output side, I want to get it back out in the same dimensions as the input. It looks like Flux.unsqueeze might be intended to be that inverse operation, but it doesn’t result in the right shape.

```julia
using Flux

m3 = cat([1 2 3; 1 2 3], [4 5 6; 4 5 6]; dims=3) # 2x3x2 array
Flux.unsqueeze(Flux.flatten(m3); dims=1)    # 1x6x2
Flux.unsqueeze(Flux.flatten(m3); dims=2)    # 6x1x2
Flux.unsqueeze(Flux.flatten(m3); dims=3)    # 6x2x1
```
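The shapes above follow from what each function does: flattening merges every dimension except the last into one (2x3x2 becomes 6x2), and unsqueezing only inserts a new size-1 dimension, so no choice of `dims` can split the 6 back into 2x3. A Base-only sketch of that behaviour, where `flatten_like` and `unsqueeze_like` are hypothetical stand-ins assumed to mimic `Flux.flatten` and `Flux.unsqueeze`, not the library code:

```julia
# Hypothetical stand-in for Flux.flatten: merge all dims except the last (batch) one
flatten_like(x) = reshape(x, :, size(x, ndims(x)))

# Hypothetical stand-in for Flux.unsqueeze(x; dims=d): insert a size-1 dim at position d
unsqueeze_like(x, d) = reshape(x, size(x)[1:d-1]..., 1, size(x)[d:end]...)

m3 = cat([1 2 3; 1 2 3], [4 5 6; 4 5 6]; dims=3)  # 2x3x2
flat = flatten_like(m3)                           # 6x2

for d in 1:3
    println(d, " => ", size(unsqueeze_like(flat, d)))
end
# 1 => (1, 6, 2)
# 2 => (6, 1, 2)
# 3 => (6, 2, 1)
```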

It appears that I could use reshape:

```julia
reshape(reshape(m3, (2*3, 2)), (2,3,2))  # 2x3x2 -> 6x2 -> back to 2x3x2
```

but it seemed odd to me that there isn’t an inverse operation.
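`reshape` is an exact inverse as long as the original size is recorded before flattening, because both directions use the same column-major element order. A minimal sketch of that pattern; `flatten_with_size` and `unflatten` are hypothetical helper names, not part of Flux:

```julia
# Hypothetical helper: flatten like Flux.flatten, but also return the original size
function flatten_with_size(x::AbstractArray)
    sz = size(x)                          # e.g. (width, length, batch)
    return reshape(x, :, sz[end]), sz
end

# Hypothetical inverse: restore the remembered shape with reshape
unflatten(flat::AbstractArray, sz::Tuple) = reshape(flat, sz)

m3 = cat([1 2 3; 1 2 3], [4 5 6; 4 5 6]; dims=3)  # 2x3x2
flat, sz = flatten_with_size(m3)                  # 6x2 matrix and (2, 3, 2)
restored = unflatten(flat, sz)
println(size(restored), " ", restored == m3)      # (2, 3, 2) true
```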

If I understand your question correctly, there isn’t an inverse operation because flattening discards the original shape. So unless you keep that around for later, there’s no real way to “unflatten”. And if you do keep it around, reshape is the most concise way to restore the original shape 🙂
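To see why the shape cannot be recovered from the flattened array alone: two arrays with different shapes but the same elements in memory order flatten to the identical matrix. A small Base-only sketch, with `flatten_like` as a hypothetical stand-in assumed to behave like `Flux.flatten`:

```julia
flatten_like(x) = reshape(x, :, size(x, ndims(x)))  # hypothetical Flux.flatten stand-in

a = reshape(collect(1:12), 2, 3, 2)  # 2x3x2
b = reshape(collect(1:12), 3, 2, 2)  # 3x2x2: different shape, same memory order

println(flatten_like(a) == flatten_like(b))  # true: both become the same 6x2 matrix
println(size(a) == size(b))                  # false
```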

Not quite. Per the docs, Dense allows for an arbitrary number of batch dimensions after the first (input feature) dimension. If this sounds familiar, that’s because PyTorch’s `nn.Linear` does the same, just with the feature dimension last instead of first.

I’m not sure whether you consider length to be an input dimension for your problem (i.e. number of features == width x length), but if not, you can pass your input directly to Dense and the output will keep all 3 dimensions, with only the first one changing to the layer’s output size.
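A Base-only sketch of what "extra batch dimensions" means for a dense layer: the weight matrix acts along the first dimension, and every trailing index (here both length and batch) is treated as a separate sample. `dense_nd` is a hypothetical illustration of the behaviour described in the docs, not Flux's source code:

```julia
# Hypothetical illustration: y = W*x .+ b along dim 1, trailing dims treated as batch
function dense_nd(W::AbstractMatrix, b::AbstractVector, x::AbstractArray)
    x2 = reshape(x, size(x, 1), :)                    # features x (trailing dims merged)
    y2 = W * x2 .+ b                                  # ordinary matrix layer
    return reshape(y2, size(W, 1), size(x)[2:end]...) # restore the trailing dims
end

width, len, batch, out = 2, 3, 4, 5
x = rand(Float32, width, len, batch)
W = rand(Float32, out, width)
b = rand(Float32, out)

y = dense_nd(W, b, x)
println(size(y))  # (5, 3, 4): only the first dimension changes
```

With this layout each time step is transformed independently by the same weights, so nothing is mixed across `length`. In recent Flux versions the corresponding call would be a layer such as `Dense(width => out)` applied directly to the 3-d array.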


I thought that might be the case, but wanted to ask in case I was missing something.

Right, I saw that, but using batch dimensions for input dimensions would confuse me. I’ll stick with reshape.
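The flatten, then Dense, then reshape pattern for width x length x batch data can be sketched with a plain weight matrix standing in for a Flux `Dense` layer (an assumption made so the example needs only Base Julia). Here every (channel, time step) pair is one input feature:

```julia
# Hypothetical sketch: a plain matrix layer stands in for Flux's Dense
width, len, batch = 2, 3, 4
x = rand(Float32, width, len, batch)

nfeat = width * len                    # features == width x length
W = rand(Float32, nfeat, nfeat)        # maps nfeat inputs to nfeat outputs
b = zeros(Float32, nfeat)

flat = reshape(x, nfeat, batch)        # same layout as Flux.flatten(x): 6x4
yflat = W * flat .+ b                  # each output can depend on every time step
y = reshape(yflat, width, len, batch)  # back to width x length x batch

println(size(y))  # (2, 3, 4)
```

Unlike applying Dense directly to the 3-d array, this mixes information across time steps, which is the "features == width x length" case mentioned above.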

Thanks!
