Learning rate decay in callback function

I was wondering whether there is a way to access the learning rate through a callback function when using the Lux.jl and Optimization.jl packages. For instance, in the line below:

```julia
Optimization.solve(optprob, ADAM(1e-3), callback = callback, progress = true, maxiters = sw)
```

I have set the learning rate to 1e-3, and I would like to access it through the callback function so that I can apply learning rate decay. I have found the following relevant resources:


However, both resources train the model with an explicit loop, whereas my training happens in a single call, namely the line of code pasted above. Therefore, I would need to access the learning rate through a callback function. Do you have any suggestions?
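As a concrete example of the kind of decay meant here, a common choice is a step schedule that multiplies the initial rate by a factor γ every s iterations. A minimal sketch in plain Julia; the helper name `step_decay` and the values of γ and s are illustrative assumptions, not part of any package:

```julia
# Hypothetical helper: step decay of the learning rate.
# η0: initial rate, γ: decay factor, s: iterations per decay step.
# `iter` is 1-based, so iterations 1..s use η0, s+1..2s use η0*γ, and so on.
step_decay(iter; η0 = 1e-3, γ = 0.5, s = 100) = η0 * γ^((iter - 1) ÷ s)

for iter in (1, 100, 101, 201)
    println("iteration $iter: learning rate ", step_decay(iter))
end
```

Any other schedule (exponential, cosine, ...) can be dropped in as long as it maps an iteration count to a rate.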


It should be possible, but it needs a change in the OptimizationOptimisers wrapper. We currently don’t pass the optimiser state as an argument to the callback function. That would be a small change, and it would then let you do it as described in the Flux docs. Can you open an issue in Optimization.jl? I’ll create a PR.
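For comparison, the explicit-loop pattern from the Flux docs is built on Optimisers.jl: `Optimisers.setup` creates the optimiser state tree, `Optimisers.update!` applies one step, and `Optimisers.adjust!` changes the learning rate stored in that tree in place. A minimal sketch on a toy quadratic loss with a hand-written gradient; the loss, the helper names `quad_grad` and `train`, the decay factor and the iteration count are illustrative assumptions:

```julia
using Optimisers

# Toy problem (assumption): minimise sum(abs2, θ .- 3); its gradient is 2 .* (θ .- 3)
quad_grad(θ) = 2 .* (θ .- 3)

function train(θ; η0 = 0.1, iters = 500)
    opt_state = Optimisers.setup(Adam(η0), θ)   # for a plain vector this is a single Leaf
    for iter in 1:iters
        # step decay: halve the learning rate every 100 iterations
        Optimisers.adjust!(opt_state, η0 * 0.5^((iter - 1) ÷ 100))
        opt_state, θ = Optimisers.update!(opt_state, θ, quad_grad(θ))
    end
    return θ, opt_state
end

θ, st = train(zeros(3))
println(θ)
```

The point is that `opt_state` is the object `adjust!` needs; inside `Optimization.solve` it is created internally, which is why the wrapper has to hand it to the callback.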

The Flux.jl documentation on learning rate decay creates an optimiser state with Flux.setup, which is then updated throughout training, as described in the link below.


The optimization part of my code is as follows:

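```julia
using Optimization, OptimizationOptimisers, Optimisers, ComponentArrays, Zygote

# `loss`, `u0`, `p` (the Lux parameters) and `callback` are defined elsewhere in my code
adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, p) -> loss(x, u0), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentArray(p))
opt = Adam(1e-2)
res1 = Optimization.solve(optprob, opt, callback = callback)
```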

This means I never call Flux.setup. Additionally, I am using Lux.jl.

What is the equivalent of opt_state in the code snippet above, so that I can perform learning rate decay through Lux.adjust!(…)?

This works now; see Optimization.jl/lib/OptimizationOptimisers/test/runtests.jl on the master branch of SciML/Optimization.jl on GitHub.
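A minimal sketch of learning rate decay from inside the callback, on a toy objective. It assumes a recent Optimization.jl in which the callback receives an `OptimizationState` as its first argument, that `state.iter` counts iterations from 1, and that for Optimisers.jl solvers the `state.original` field holds the Optimisers.jl state tree (the same object as `opt_state` in the Flux docs). These are assumptions; the linked test file is the authoritative reference for your installed version. The toy loss and the decay schedule are illustrative only:

```julia
using Optimization, OptimizationOptimisers, Optimisers, Zygote

# Toy objective (assumption): minimise sum(abs2, x .- 3)
toy_loss(x, p) = sum(abs2, x .- 3)

optf = OptimizationFunction(toy_loss, Optimization.AutoZygote())
optprob = OptimizationProblem(optf, zeros(3))

const η0 = 0.1

# Assumption: `state.original` is the Optimisers.jl state tree and `state.iter` starts at 1
function decay_callback(state, l)
    # step decay: halve the learning rate every 100 iterations
    Optimisers.adjust!(state.original, η0 * 0.5^((state.iter - 1) ÷ 100))
    return false   # returning true would stop the optimisation early
end

res = solve(optprob, Optimisers.Adam(η0); callback = decay_callback, maxiters = 500)
println(res.u)
```

In the Lux setup above, the same callback body should apply unchanged, since `adjust!` walks whatever state tree the solver builds for the `ComponentArray`.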
