Hey there,

I am trying to set up a training run for a simple time-series prediction example using an LSTM cell within a Lux chain.

My model is

```
using Lux, Random, ComponentArrays

model = Lux.Chain(Lux.Dense(2 => 2), Lux.Recurrence(Lux.LSTMCell(2 => 2)), Lux.Dense(2 => 1))
rng = Random.default_rng()
ps, st = Lux.setup(rng, model)
params = ComponentArray(ps)
```
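For reference, a single forward pass through this chain behaves as expected. This is just a minimal sanity-check sketch; the `(features, time steps, batch)` input layout matches the `reshape` call in my loss function:

```
x = rand(rng, Float32, 2, 1, 1)   # one batch element, one time step, two features
y, _ = model(x, ps, st)           # y is a 1×1 prediction
```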

For training, I have produced some dummy data from a dynamical system (single output, two inputs), and the model should be trained to estimate the corresponding time series using this loss:

```
function loss_LSTM(p)
    # For each time step: build the normalized 2-feature input, run the model,
    # then compare the concatenated predictions against the target series
    preds = map(tsteps_NODE) do x
        u = reshape([(i(x) * n(x)).^2 / in_sqr_max; (i(x)).^2 / i_sqr_max], 2, 1, 1)
        model(u, p, st)[1]
    end
    return sum(abs2, y_target - vcat(preds...))
end;
```

The prediction routine as well as the loss function work well, and I get prediction/loss results within ~1 ms (according to BenchmarkTools.jl) for my given data set.
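For completeness, this is roughly how I benchmarked it (interpolating `params` with `$` so the global doesn't distort the timing):

```
using BenchmarkTools
@btime loss_LSTM($params)   # ~1 ms on my data set
```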

The problem starts when I try to start optimizing the model, i.e., the Adam call in the following code snippet takes forever. I cancelled the code execution after 30 minutes:

```
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss_LSTM(x), adtype)
optprob = Optimization.OptimizationProblem(optf, params)
res = Optimization.solve(optprob, Adam(0.05); maxiters=50, maxtime = 120.0, callback=callback_online)
```

I do not receive any intermediate callback output, and I suspect the code is stuck in some compilation stage and never reaches the actual optimization. I have used a very similar toolchain for a NODE training with a simple feed-forward MLP inside the ODE model formulation, which worked perfectly and returned nice results within seconds.

Moreover, I have the identical problem (code execution takes forever) when I try to compute the gradient of the loss directly:

```
Zygote.gradient(p -> loss_LSTM(p), params)
```
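To check whether the time is going into Zygote's one-off compilation of the pullback rather than the gradient evaluation itself, I would try something like this diagnostic sketch (first call includes compilation, second call measures the steady-state cost):

```
using Zygote
@time Zygote.gradient(loss_LSTM, params)   # first call: includes compilation
@time Zygote.gradient(loss_LSTM, params)   # second call: runtime only
```

In my case even the first `@time` never returns, which is why I suspect the hang is in differentiation compile time rather than in the optimizer loop.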

So, I am quite puzzled: the execution of the loss function works very well, but the differentiation/optimization seems stuck. Any hint on how to get this optimization going would be highly appreciated!