Errors with Flux RNN setup

Hi everyone, I’m currently working on a final project that involves developing an RNN to represent a parameterized dynamic model. I was able to do this fairly easily with a feedforward ANN in Flux, but I have run into major issues trying to build a recurrent model.

Before building the actual model, I have been playing around with a simple test case to get something working. My data has 5 features, 1 target, and a sequence length of 20. I generate the data myself, so I have full control over the number of batches.

My current setup is as follows:

n_feat = 5;
n_batches = 10; % 10 batches arbitrarily chosen for test model
seq_len = 20; 
n_targ = 1;

m = Chain(RNN(n_feat, n_hidden), Dense(n_hidden, n_targ))  

function loss(x, y)
     Flux.reset!(m)
     sum(mse(m(xi), yi) for (xi, yi) in zip(x, y))
end

x = [rand(Float32, n_feat, n_batches) for i = 1:seq_len] % random data used to build test model
y = [rand(Float32, n_targ, n_batches) for i = 1:seq_len]
data = zip(x,y)

Flux.reset!(m)
ps = params(m)
opt = ADAM(1e-3)
Flux.train!(loss, ps, data, opt) 
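
For what it's worth, the shapes come out the way I intend, with one features × batch matrix per timestep:

size(x[1])    # (5, 10): n_feat × n_batches at one timestep
length(x)     # 20: seq_len
size(y[1])    # (1, 10): n_targ × n_batches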

When I run this I get the following error:

MethodError: no method matching (::Flux.RNNCell{typeof(tanh), Matrix{Float32}, Vector{Float32}, Matrix{Float32}})(::Matrix{Float32}, ::Vector{Matrix{Float32}}) 

Closest candidates are:   (::Flux.RNNCell{F, A, V, var"#s263"} where var"#s263"<:AbstractMatrix{T})(::Any, ::Union{AbstractMatrix{T}, AbstractVector{T}, Flux.OneHotArray}) where {F, A, V, T} 

I am not sure how to proceed from here. My intuition is that I am somehow preparing the data incorrectly, so the types are not what train! expects. Any guidance would be greatly appreciated!

The comment character in Julia is #, not %.

You didn’t define n_hidden; I added n_hidden = 10 to test.

Your loss function expects to be passed the entire dataset at once; you can call loss(x, y) directly and get a value. Flux.train! iterates over whatever you provide as data and calls loss(d) (or loss(d...) if d is a tuple) on each element. With the data shapes in your example, that results in scalars being passed to your model, because you have now zipped twice (once creating data and once inside loss): each element of data is already a single (x[t], y[t]) pair of matrices, so the zip inside loss iterates over their individual scalar entries.
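
Roughly, train! does something like this per element (a simplified sketch; the real implementation also handles callbacks):

# simplified sketch of Flux.train!(loss, ps, data, opt)
for d in data                                  # here d == (x[t], y[t])
    gs = Flux.gradient(() -> loss(d...), ps)   # i.e. loss(x[t], y[t])
    Flux.update!(opt, ps, gs)
end

With that in mind, I suspect what you meant to do is: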

function loss(x, y)
     Flux.reset!(m)
     Flux.mse(m(x), y)
end

This works as expected in the train! loop.
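
For reference, here is the whole test script with all three fixes applied (# comments, n_hidden defined, and the per-element loss), assuming a Flux version that still provides the implicit-parameters params/train! API used in your snippet:

using Flux

n_feat = 5
n_batches = 10    # 10 batches arbitrarily chosen for test model
seq_len = 20
n_targ = 1
n_hidden = 10     # was undefined in the original snippet

m = Chain(RNN(n_feat, n_hidden), Dense(n_hidden, n_targ))

# loss for a single (input, target) pair, i.e. one timestep of the sequence
function loss(x, y)
    Flux.reset!(m)
    Flux.mse(m(x), y)
end

x = [rand(Float32, n_feat, n_batches) for _ in 1:seq_len]   # random test data
y = [rand(Float32, n_targ, n_batches) for _ in 1:seq_len]
data = zip(x, y)

Flux.reset!(m)
ps = Flux.params(m)
opt = ADAM(1e-3)
Flux.train!(loss, ps, data, opt)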

You can test a single call of the loss with:

loss(first(zip(x, y))...)
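
Here first(zip(x, y)) is just the first (input, target) pair, so splatting it reproduces exactly the call train! makes for each element of data:

d = first(zip(x, y))    # d == (x[1], y[1]): a 5×10 input and a 1×10 target
loss(d...)              # identical to loss(x[1], y[1]); returns a scalar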