Lux.jl LSTM timeseries prediction: AD / optimization does not start

Hey there,

I am trying to set up a training run for a simple time series prediction example using an LSTM cell inside a Lux chain.

My model is

using Lux, Random, ComponentArrays

model = Lux.Chain(Lux.Dense(2 => 2), Lux.Recurrence(Lux.LSTMCell(2 => 2)), Lux.Dense(2 => 1))
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
params = ComponentArray(ps)

For training, I have produced some dummy data from a dynamical system (single output, two inputs), and the model should be trained to estimate the corresponding time series using this loss:

function loss_LSTM(p)
    # for each time step x: build the (2, 1, 1) network input from i(x) and n(x),
    # run the model, stack the predictions, and compare against y_target
    return sum(abs2, y_target - vcat(map(x -> model(reshape([(i(x)*n(x)).^2/in_sqr_max; (i(x)).^2/i_sqr_max], 2, 1, 1), p, st)[1], tsteps_NODE)...))
end;

The prediction routine as well as the loss function work fine on their own; I get prediction / loss results within ~1 ms according to BenchmarkTools.jl for my data set.
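For reference, the timing comes from a plain BenchmarkTools call on the loss, something like:

using BenchmarkTools

@btime loss_LSTM($params)   # ~1 ms for my data set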

The problem starts when I try to start optimizing the model, i.e., the Adam call in the following code snippet takes forever. I cancelled the code execution after 30 minutes:

using Optimization, OptimizationOptimisers, Zygote   # assuming Adam comes from OptimizationOptimisers

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_LSTM(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(optprob, Adam(0.05); maxiters=50, maxtime=120.0, callback=callback_online)
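For completeness, callback_online is just a simple monitoring callback in the usual Optimization.jl form, roughly along these lines (a minimal sketch, the exact body is assumed and not important here):

# sketch of a print-and-continue callback: return false so the solver keeps iterating
callback_online = function (p, l)
    println("current loss: ", l)
    return false
end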

I do not receive any intermediate callback output, and I suspect the code is stuck compiling and never reaches the actual optimization stage. I have used a very similar toolchain for a NODE training with a simple feed-forward MLP inside the ODE model formulation, and that worked perfectly, returning nice results within seconds.

Moreover, I have the identical problem (code execution takes forever) when I simply try to calculate the gradient of the loss:

Zygote.gradient(p -> loss_LSTM(p), params)

So, I am quite puzzled because the execution of the loss function works very well but the differentiation / optimization seems stuck. Any hint on how to get this optimization going is highly appreciated!

I think I found the issue: the map() call within the loss function

function loss_LSTM(p)
    return sum(abs2, y_target - vcat(map(x -> model(reshape([(i(x)*n(x)).^2/in_sqr_max; (i(x)).^2/i_sqr_max], 2, 1, 1), p, st)[1], tsteps_NODE)...))
end;

seemed to conflict with the AD calculation. Providing the input data as a simple pre-calculated array solved the issue.


Hi,
Could you please elaborate on this a bit?
"Providing the input data as a simple pre-calculated array solved the issue."
Did you mean calculating and reshaping the inputs explicitly and then passing them to the network?
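Something roughly like this, precomputing the inputs outside the loss (just a sketch of my guess, reusing the names from your snippet)?

# precompute all network inputs once, outside the differentiated loss
inputs = [reshape([(i(x)*n(x)).^2/in_sqr_max; (i(x)).^2/i_sqr_max], 2, 1, 1) for x in tsteps_NODE]

function loss_LSTM(p)
    # only the model evaluation remains inside the code Zygote has to differentiate
    pred = vcat(map(u -> model(u, p, st)[1], inputs)...)
    return sum(abs2, y_target - pred)
end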