Question about Flux.jl data inputs


I’m playing around with Flux.jl, and I came across an error that I have some questions about.

With a Dense layer, I can provide a matrix input in which each column represents a single input sample. Additionally, the number of columns can change between calls. For example,

using Flux
m = Dense(3, 2)
m(rand(3, 10))
m(rand(3, 2))

is perfectly fine.

Similarly, with RNN and LSTM (and presumably with other recurrent networks), I can provide a matrix input in which each column represents a single input sample. However, the number of columns must stay the same across subsequent calls (although vector inputs seem to be treated differently). For example,

m = RNN(3, 2)
m(rand(3, 10))
m(rand(3, 2))                    # ERROR: DimensionMismatch

fails with a DimensionMismatch error.

The error originally occurred because I was using training and test sets in matrix form with different numbers of columns. It seems that the issue can be sidestepped with something like the following

columns(A::Matrix) = (@view(A[:, j]) for j = 1 : size(A, 2))
m.(columns(rand(3, 10)))
m.(columns(rand(3, 2)))

if the data is in matrix form.
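For what it's worth, Julia already ships a built-in eachcol (since Julia 1.1) that yields the same column views as the columns helper above. The sketch below is runnable without Flux: it uses a hypothetical stand-in function f in place of a model. Note that broadcasting calls the model once per column, so with a stateful recurrent layer each call would presumably see the hidden state left behind by the previous column. The columns then act like consecutive time steps rather than independent samples.

```julia
# The `columns` helper from the question, next to Julia's built-in `eachcol`.
# Both produce views of the columns of A, one per column.
columns(A::Matrix) = (@view(A[:, j]) for j = 1:size(A, 2))

A = rand(3, 10)

# Hypothetical stand-in for a model applied to a single column (not Flux).
f(x) = 2 .* x

ys1 = f.(columns(A))     # Vector holding 10 per-column results
ys2 = f.(eachcol(A))     # same results via the built-in iterator

# Stack the per-column results back into a 3×10 matrix if needed.
Y = reduce(hcat, ys2)
```

With a stateless layer such as Dense this is equivalent to calling the layer on the whole matrix at once; with a recurrent layer it changes what the columns mean.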

It is unclear to me why the error occurs with recurrent networks. Should I not be using matrices as I described? Should there not be an error with the recurrent networks? Is the code working with the Dense network as an artifact of the implementation? Is there a recommended shape for the data? For instance, other deep-learning frameworks seem to insist that data be coerced into a tensor (e.g., with specific dimensions corresponding to things such as batches and timesteps).

I apologize if I’ve missed something simple somewhere, but I’ve been getting some cryptic errors, and this seems to be the root cause of one of them.


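I think the problem is that you are feeding different batch sizes into the RNN layer one after the other (i.e., batches of size 10, 1, and 2). The Recur object returned by RNN stores the hidden state between calls, and that state takes on the batch size of the input it was last called with, so each new batch must match the size of the previous ones. A Dense layer, by contrast, keeps no state between calls, so every call is independent and any number of columns works.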

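The reason it “works” for a batch size of 1 (i.e., a vector input) is probably an unhappy accident of the implementation's use of broadcasting: a length-2 vector can be broadcast against a 2×10 hidden state, whereas a 2×2 matrix cannot.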

If you want to try your model with a different batch size, you’ll first have to reset the hidden state with Flux.reset!(m).
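To make the mechanism concrete, here is a minimal toy recurrent layer in plain Julia (no Flux needed). ToyRNN and reset_state! are hypothetical names, and this is only a sketch that mimics the behaviour described above, not Flux's actual Recur/RNN code (whose details differ between Flux versions). The hidden state starts as a vector, picks up a batch dimension after the first matrix call, and then rejects a matrix with a different number of columns until it is reset.

```julia
# Hypothetical toy stateful RNN, NOT Flux's implementation.
mutable struct ToyRNN
    Wi::Matrix{Float64}   # input-to-hidden weights (nout × nin)
    Wh::Matrix{Float64}   # hidden-to-hidden weights (nout × nout)
    h0::Vector{Float64}   # initial hidden state: a vector, i.e. no batch dimension
    h::VecOrMat{Float64}  # current hidden state; becomes nout × batch after a matrix call
end

ToyRNN(nin::Int, nout::Int) =
    ToyRNN(randn(nout, nin), randn(nout, nout), zeros(nout), zeros(nout))

# One step: the new state is stored and returned. The `.+` broadcasts, so a
# vector combines with a matrix of any width, but two matrices with different
# column counts raise a DimensionMismatch.
function (m::ToyRNN)(x::AbstractVecOrMat)
    m.h = tanh.(m.Wi * x .+ m.Wh * m.h)
    return m.h
end

# Restore the initial, batch-free state (similar in spirit to Flux.reset!).
reset_state!(m::ToyRNN) = (m.h = copy(m.h0); m)

m = ToyRNN(3, 2)
size(m(rand(3, 10)))   # (2, 10): the stored state is now 2×10
# m(rand(3, 2))        # would throw DimensionMismatch (2×2 .+ 2×10)
size(m(rand(3)))       # (2, 10): a vector broadcasts against the 2×10 state
reset_state!(m)
size(m(rand(3, 2)))    # (2, 2): works after the reset
```

The vector case shows why batch size 1 appears to work: the result silently inherits the old batch width instead of raising an error.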