Dear Michael,
Thank you for your advice on getting a better solution.
- Does the original ODE that you solved have any time-dependent function (i.e., a function that depends on time, such as an event)? If that’s the case, I’m in the same boat as you; the topic is here.
In my case, there is no time-dependent function like the one in your case.
- Try a lower learning rate for ADAM once you get to the point you’re at now. I’ve had quite a bit of success by doing the initial learning pass with the default learning rate, and after that initial convergence I can often eke out a substantial amount of improvement by dropping the learning rate by a factor of 10 or so. This will take a bit of experimentation.
I thought that dropping the learning rate was already handled by “ADAM(0.001, (0.9, 0.999))” in my code. I will try modifying the code so that the learning rate varies with the loss value or the iteration number.
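As far as I understand, the 0.001 in ADAM(0.001, (0.9, 0.999)) is a fixed base step size: ADAM rescales each parameter's update, but it does not lower that value over time, so the drop has to be scheduled by hand. Below is a minimal sketch of the staged schedule in plain Julia. The helper name `stepped_lr` and the values of `drop_at` and `factor` are my own assumptions for illustration, not tuned settings and not part of any library.

```julia
# Hypothetical helper: step-decay schedule for the ADAM step size.
# base_lr, drop_at, and factor are illustrative values, not tuned ones.
function stepped_lr(iter::Integer; base_lr=1e-3, drop_at=300, factor=0.1)
    # Use the base rate up to and including `drop_at`, then the reduced rate.
    return iter <= drop_at ? base_lr : base_lr * factor
end

# Print the rate that would be used at a few iterations of a 600-iteration run.
for iter in (1, 300, 301, 600)
    println("iteration ", iter, ": learning rate = ", stepped_lr(iter))
end
```

The returned value would be used as the first argument when building the optimizer for the second training pass, for example ADAM(stepped_lr(301), (0.9, 0.999)), with training restarted from the parameters found in the first pass.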
- Use a different activation function. I’ve found that swish and σ work better than relu and tanh in my personal ODEs that I’ve tried to solve.
I will try using swish as the activation function.
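For reference, swish is the input multiplied by the sigmoid, swish(x) = x · σ(x). The sketch below writes both out by hand in plain Julia so the shape is easy to inspect; the names `my_sigmoid` and `my_swish` are hypothetical stand-ins chosen to avoid clashing with the functions Flux already exports.

```julia
# Hand-written stand-ins for illustration only; Flux provides its own versions.
my_sigmoid(x) = 1 / (1 + exp(-x))
my_swish(x) = x * my_sigmoid(x)

# Compare swish with relu and tanh at a few sample points.
my_relu(x) = max(zero(x), x)
for x in (-2.0, -0.5, 0.0, 0.5, 2.0)
    println("x = ", x,
            "  swish = ", my_swish(x),
            "  relu = ", my_relu(x),
            "  tanh = ", tanh(x))
end
```

Unlike relu, swish is smooth and lets small negative values through, and unlike tanh it does not saturate for large positive inputs.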
- Try a different model architecture. From personal experience, bigger isn’t always better. Your model is pretty large for an ODE (at least from what I’ve seen). How complicated is the original system model? That should help inform how expressive your model needs to be. Also consider playing around with different layer sizes. 32 → 16 → 8 → 2, etc. This has achieved good results in certain cases for me, where a model with more parameters struggled somewhat to train.
I have already reduced the problem size: the original ODE solution contains time series of mass fractions for many chemical species, but I used only the mass fraction of hydrogen as the training target for the neural ODE. I will also try reducing the number of data points in the hydrogen mass fraction series and rerun the neural ODE code.
Does “different layer sizes” mean a layer structure like the following?
Chain(
    # x -> -tanh.(x*10),
    Dense(2, 32, swish),
    Dense(32, 16, swish),
    Dense(16, 8, swish),
    Dense(8, 2)  # input size 8 matches the previous layer's output
)
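In a chain like this, each Dense(in, out) layer needs its input size to equal the previous layer's output size. Here is a small sanity check in plain Julia for the 2 → 32 → 16 → 8 → 2 layout, which also counts the trainable parameters. The helper names `sizes_match` and `param_count` are hypothetical and not part of Flux.

```julia
# Hypothetical helpers in plain Julia (no Flux needed).
# Each pair is the (input, output) size of one Dense layer.
layers = [(2, 32), (32, 16), (16, 8), (8, 2)]

# True when every layer's input size equals the previous layer's output size.
sizes_match(ls) = all(ls[i][2] == ls[i + 1][1] for i in 1:length(ls) - 1)

# A Dense(in, out) layer has in*out weights plus out biases.
param_count(ls) = sum(i * o + o for (i, o) in ls)

println("sizes match: ", sizes_match(layers))
println("parameters:  ", param_count(layers))
```

Running the same check on a candidate layout before training is a cheap way to catch a size mismatch and to compare how many parameters different layouts have.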