Implementing learning rate scheduling in the NeuralPDE Julia package

Hi everyone, I have a question about implementing learning rate scheduling in the NeuralPDE framework during training. I want to progressively reduce the learning rate of the ADAM optimizer using a decay factor. How do I implement this in NeuralPDE? I did not find an effective way to do it, and it isn't mentioned anywhere in the documentation. Can anyone please help?

Just take the resulting OptimizationProblem and use it in a loop with different Adam calls, as you would in any other machine learning example. What did you try?
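For concreteness, here is a minimal sketch of such a loop (not from the NeuralPDE docs). It assumes `Adam` from OptimizationOptimisers (older setups used `ADAM` from OptimizationFlux), and the names `lr0`, `decay`, and `nstages` are illustrative:

```julia
using Optimization, OptimizationOptimisers

# Multi-stage training: each stage runs Adam at a smaller learning rate,
# warm-started from the previous stage's result via remake.
function train_with_decay(prob; lr0 = 1e-2, decay = 0.5, nstages = 4, iters = 500)
    lr = lr0
    res = nothing
    for _ in 1:nstages
        res = Optimization.solve(prob, Adam(lr); maxiters = iters)
        prob = remake(prob, u0 = res.u)  # warm-start the next stage from the trained parameters
        lr *= decay                      # reduce the learning rate for the next stage
    end
    return res
end
```

Here `prob` would be the problem returned by `NeuralPDE.discretize`; a callback can be passed to `Optimization.solve` as in the usual examples.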

But every time I have to remake the problem, right? I think this is computationally intensive, and the remake function wasn't working for me; it was taking a lot of time to remake the problem. I'm using Flux instead of Lux because it's easier to save the network parameters and continue training from them.

Currently, I do multi-stage training: say 500 iterations with one learning rate, and when that optimization completes I start the next round with the trained parameters and a new learning rate. ADAM is fast, but in the last round, when I switch to LBFGS to fine-tune the parameters, the optimization becomes extremely slow, maybe due to stiffness in the loss landscape; I can't figure out why this is happening in my case. Any help is highly appreciated.
Thanks in advance!
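For reference, a sketch of the fine-tuning round I mean (assuming `LBFGS` from OptimizationOptimJL; `prob` and `res` come from the earlier Adam rounds):

```julia
using Optimization, OptimizationOptimJL

# Final round: warm-start LBFGS from the Adam-trained parameters.
prob_ft = remake(prob, u0 = res.u)  # res.u holds the parameters from the Adam stages
res_ft = Optimization.solve(prob_ft, LBFGS(); maxiters = 200)
```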

remake is pretty quick as it’s just a pointer change. It should be around 10ns. Can you show what you’re doing?

Is it that remake just works for Lux networks and not Flux? Here is exactly what I'm doing:

Create the symbolic problem from the pdeSystem and the discretization strategy to be utilized:

```julia
prob = NeuralPDE.discretize(pdeSystem, discretization)
res = Optimization.solve(prob, opt, callback = callback, maxiters = numIters - 1)
prob = remake(prob, u0 = res.u)
res = Optimization.solve(prob, opt, callback = callback, maxiters = numIters - 1)
```

But here, the network parameters (res.u) are just a flattened vector, since I'm using Flux…
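For context, that flat vector relates to the Flux network through the usual `Flux.destructure` pattern (a sketch; `chain` stands for the Flux network passed to the discretization):

```julia
using Flux

# Flux.destructure splits a model into a flat parameter vector and a
# reconstructor function; calling the reconstructor on a vector of the
# same length rebuilds the network with those parameters.
θ0, re = Flux.destructure(chain)  # chain: the Flux network used in the discretization
trained_net = re(res.u)           # rebuild the trained network from the flat vector
```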

That should work fine with Flux as well. Do you have a quick MWE to look at?