I realized I had forgotten to change the output dimension of the neural network to 1 (my single state variable, wr) in:
```julia
chain = Lux.Chain(
    Lux.Dense(1, n, Lux.σ),
    Lux.Dense(n, n, Lux.σ),
    Lux.Dense(n, n, Lux.σ),
    Lux.Dense(n, 1)  # output dimension 1: the single state variable wr
)
```
With that change, plus the modification to the datasets mentioned in my previous comment and the input interpolation suggested by Chris Rackauckas, the code now runs. The results are still not what I expected, so there is probably an issue with this implementation, but that's another story. I will open a new topic to continue the discussion and add the link below.