Shape of input arrays for an LSTM in Flux.jl? julia 1.0

I am trying to implement an LSTM model for a regression exercise using the Flux.jl library. Building the model is fairly straightforward, but I am having trouble understanding what shape the input arrays need to have for training: an error about array dimensions is thrown, so I am not sure whether I am feeding the right input shape to the model. Does anybody have a clue why this is happening?

Here’s the code/error to reproduce the issue for some random data (20 samples, 6 input variables, 1 target variable, sequence length is 100):

using Flux

# Create training and validation sets
x_train, y_train = [rand(6, 100) for i in 1:20], [rand(1, 100) for i in 1:20]
x_valid, y_valid = [rand(6, 100) for i in 1:20], [rand(1, 100) for i in 1:20]

#Define loss function
function mseLoss(x, y)
    loss = Flux.mse(model(x), y)
    Flux.reset!(model)
    return loss
end

#Create initial model
model = Chain(
    LSTM(6, 20),
    LSTM(20, 20),
    LSTM(20, 20),
    Dense(20, 1))

#Train model
evalcb = () -> @show mseLoss(x_valid, y_valid)
Flux.train!(mseLoss, params(model), zip(x_train, y_train), Flux.ADAM(0.01), cb = Flux.throttle(evalcb, 30))
ERROR: DimensionMismatch("matrix A has dimensions (80,6), vector B has length 20")

However, if I remove the callback function from the training routine, no error is thrown.

Flux.train!(mseLoss, params(model), zip(x_train, y_train), Flux.ADAM(0.01))

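The issue is that your evalcb() function is trying to apply your loss function (mseLoss()) to the entire validation set instead of a single sample. During training, Flux.train! calls mseLoss() once per (x, y) pair taken from zip(x_train, y_train), so the model only ever sees one 6×100 matrix at a time. The callback, on the other hand, passes x_valid itself, which is a vector of 20 such matrices.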
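To see where the numbers in the error message come from, here is a minimal sketch in plain Julia, without Flux. The matrix W is only a stand-in: I am assuming the first LSTM(6, 20) layer holds its input weights as a (4*20)×6 matrix, which is what the (80,6) in the error suggests. The 20 is then simply the number of samples in x_valid.

```julia
# Stand-in for the input weights of LSTM(6, 20): 4 gates x 20 hidden units = 80 rows.
# This is an assumption for illustration, not Flux's actual internals.
W = rand(80, 6)

# Same layout as in the question: 20 samples, each a 6 x 100 matrix.
x_valid = [rand(6, 100) for i in 1:20]

# One sample: (80 x 6) * (6 x 100) is a valid matrix product.
single = W * x_valid[1]
println(size(single))  # (80, 100)

# Whole set: x_valid is a Vector of length 20, so Julia attempts a
# matrix-vector product, and the inner dimensions 6 and 20 do not match.
whole_set_fails = try
    W * x_valid
    false
catch err
    err isa DimensionMismatch
end
println(whole_set_fails)
```

So the shape of each individual sample is fine; the mismatch only appears when the whole collection is passed where a single sample is expected.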

For example, if you manually try to call mseLoss() on your validation set, it won’t work:

julia> mseLoss(x_valid, y_valid)
ERROR: DimensionMismatch("matrix A has dimensions (80,6), vector B has length 20")
Stacktrace:

But if you broadcast the call over the samples with the dot syntax, it works:

julia> mseLoss.(x_valid, y_valid)
20-element Array{Tracker.TrackedReal{Float64},1}:
 0.09329284350031428
 0.08800717849477932
 0.09536389093326372
 0.07287500246038604
 0.07941595462823466
 0.08271627911981938
 0.06494661422789759
 0.090946393431492  
 0.08429430207008291
 0.07520673733085231
 0.08991733401842267
 0.08013483351191553
 0.08518935831116269
 0.08661341277531345
 0.09158514810736933
 0.07756680475821241
 0.10052480072450713
 0.07839699178557202
 0.08355401235047255
 0.08109280394475628
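If the dot syntax is new to you, the sketch below shows what it does, again in plain Julia so it runs without Flux. Note that toy_loss is a hypothetical stand-in I made up for illustration; it is not your mseLoss() and involves no model.

```julia
using Statistics: mean

# Hypothetical stand-in for mseLoss: mean squared error between a crude
# "prediction" (the column sums of x) and the target y.
toy_loss(x, y) = mean(abs2, sum(x, dims=1) .- y)

# Same layout as in the question: 20 samples per set.
xs = [rand(6, 100) for i in 1:20]
ys = [rand(1, 100) for i in 1:20]

# The dotted call applies toy_loss to each (x, y) pair in turn...
per_sample = toy_loss.(xs, ys)

# ...which gives the same result as an explicit loop over zip(xs, ys).
by_hand = [toy_loss(x, y) for (x, y) in zip(xs, ys)]

println(length(per_sample))      # one loss value per sample
println(per_sample == by_hand)

# A single summary number for the whole set.
println(mean(per_sample))
```

Wrapping the per-sample losses in mean gives one number to print instead of twenty, which may be easier to follow while training.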

So the easy way to get your callback to work is by using the dot operator:

evalcb = () -> @show mseLoss.(x_valid, y_valid)
Flux.train!(mseLoss, params(model), zip(x_train, y_train), Flux.ADAM(0.01), cb = Flux.throttle(evalcb, 30))

Should work like a charm.

Good luck!
