Improving performance for training Universal Differential Equation

I’m trying to understand the performance bottlenecks in training universal differential equations. Following this guide, benchmarking the training shows that optimizing the parameters of the Lux network requires a large number of allocations:

@btime res1 = Optimization.solve(optprob, ADAM(), maxiters = 100) 
417.338 ms (2205218 allocations: 201.85 MiB)

I’m going to run a large number of similar optimizations and would like to reduce the time to solution for the optimization problem. My problems are low-dimensional, and I need to optimize over a large number of given datasets.
Does anyone know how to minimize the number of allocations?

I’m thinking about using StaticArrays, as described here, but I’m unsure whether they work with Lux. Another option I’d like to look into is SimpleChains, but I haven’t figured out yet how to make its interface compatible with Optimization.
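For the SimpleChains route, Lux ships an adaptor that converts a Lux model to a SimpleChains backend while keeping the usual Lux `setup`/apply interface, so the rest of an Optimization.jl pipeline should not need to change. A minimal sketch, assuming a small MLP like the one in the guide (the layer sizes and the `static` input dimension here are placeholders; check the Lux docs for `ToSimpleChainsAdaptor`):

```julia
using Lux, SimpleChains, Random

# A small Lux MLP, as in the missing-physics guide (sizes are illustrative)
lux_model = Chain(Dense(2 => 16, tanh), Dense(16 => 2))

# Convert to a SimpleChains backend; the input dimension must be
# statically known (hence `static(2)`)
adaptor = ToSimpleChainsAdaptor((static(2),))
sc_model = adaptor(lux_model)

# Same Lux-style interface afterwards, so it can drop into the
# existing UDE / Optimization.jl setup
rng = Random.default_rng()
ps, st = Lux.setup(rng, sc_model)
y, _ = sc_model(rand(Float32, 2), ps, st)
```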

Does anyone have experience with optimizing for this situation?

It depends on your model. What does your code look like? Do you have an MWE to play with?

I’m using the missing physics guide as an MWE. Basically I’m wondering if there is a way to keep the number of allocations constant when varying the number of iterations for the optimizer. I’m seeing that they increase linearly with the number of steps:

@btime Optimization.solve(optprob, ADAM(), maxiters=1);
  3.591 ms (21770 allocations: 2.00 MiB)
@btime Optimization.solve(optprob, ADAM(), maxiters=2);
  7.202 ms (43489 allocations: 4.01 MiB)
@btime Optimization.solve(optprob, ADAM(), maxiters=3);
  11.022 ms (65808 allocations: 6.04 MiB)
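The linear growth is expected: each ADAM step re-solves the ODE and runs a reverse-mode pass, and both allocate. One way to see the per-step cost in isolation is to benchmark a single loss and gradient evaluation; here `loss` and `p` stand in for the loss function and parameter vector from the guide:

```julia
using BenchmarkTools, Zygote

# Interpolate with $ so @btime measures the call, not global lookup
@btime $loss($p)                    # one forward ODE solve
@btime Zygote.gradient($loss, $p)   # forward solve + reverse pass
```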

You’d have to do things like cache the ODE solver state, which is difficult with Zygote and would require setting everything up with Enzyme.
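For reference, switching the AD backend on the `OptimizationFunction` is only a one-line change, though actually benefiting from Enzyme typically also means writing the ODE right-hand side in-place and reusing preallocated buffers, which Zygote cannot differentiate. A sketch, where `loss` and `p0` stand in for the guide's loss function and initial parameters:

```julia
using Optimization, OptimizationOptimisers

# Zygote, as in the guide (no mutation allowed in the loss):
optf = OptimizationFunction((p, _) -> loss(p), Optimization.AutoZygote())

# Enzyme, which can handle mutating, cache-reusing code:
optf = OptimizationFunction((p, _) -> loss(p), Optimization.AutoEnzyme())

optprob = OptimizationProblem(optf, p0)
res = Optimization.solve(optprob, ADAM(), maxiters = 100)
```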