I was wondering if there is a way to access the learning rate through a callback function when using the Lux.jl and Optimization.jl packages. For instance, in the line below:
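(The original line of code was not preserved here; the following is a minimal sketch of what such a one-line training call typically looks like with Lux.jl and Optimization.jl. The model, data, and loss are illustrative placeholders, and it assumes the recent Optimization.jl callback signature of `(state, loss)`.)

```julia
using Lux, Optimization, OptimizationOptimisers, ComponentArrays, Random, Zygote

# Illustrative placeholders: a tiny model, random data, and an MSE loss
rng = Random.default_rng()
model = Lux.Dense(2 => 1)
ps, st = Lux.setup(rng, model)
x = rand(rng, Float32, 2, 32)
y = rand(rng, Float32, 1, 32)

function loss(θ, _)
    ŷ, _ = model(x, θ, st)
    return sum(abs2, ŷ .- y) / length(y)
end

optf = OptimizationFunction(loss, Optimization.AutoZygote())
prob = OptimizationProblem(optf, ComponentArray(ps))

# The single training line in question: the learning rate is fixed at 1e-3
# inside Adam and is not visible from within the callback
res = Optimization.solve(prob, OptimizationOptimisers.Adam(1e-3);
                         callback = (state, l) -> (println(l); false),
                         maxiters = 100)
```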
I have set the learning rate to 1e-3, and I would like to access it through the callback function to use learning rate decay. I have found the following relevant resources:
However, both resources train the model with an explicit loop, whereas in my code the training happens in the single line pasted above. I would therefore need to access the learning rate through the callback function. Do you have any suggestions?
It should be possible, but it needs a change in the OptimizationOptimisers wrapper: we currently don't pass the optimiser state as an argument to the callback function. It would be a small change, and it would then let you do it the way it is described in the Flux docs. Can you open an issue in Optimization.jl and I'll create a PR.
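For illustration, here is a sketch of what such a callback could look like once the wrapper passes the state through. The extra `opt_state` argument is hypothetical (it depends on the proposed change); `Optimisers.adjust!` is the existing Optimisers.jl API for changing a hyperparameter of an already-built state tree.

```julia
using Optimisers

# Hypothetical callback: assumes OptimizationOptimisers is changed to pass
# the Optimisers.jl state tree (`opt_state`) as an extra callback argument.
function callback(state, loss_val, opt_state)
    # Read the current learning rate from the Adam rule and decay it by 1%.
    # `opt_state.rule.eta` assumes the state is a single Optimisers.Leaf,
    # as it is when optimising a flat parameter vector.
    η = opt_state.rule.eta
    Optimisers.adjust!(opt_state, η * 0.99)
    return false  # returning false tells Optimization.jl to keep training
end
```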
In the Flux.jl documentation on learning rate decay, an optimiser state is created with Flux.setup and then adjusted throughout training, as described in the link below.
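For reference, that pattern from the Flux docs looks roughly like this (the model, data, and decay schedule here are illustrative; `Flux.adjust!` is Flux's re-export of `Optimisers.adjust!`):

```julia
using Flux

# Illustrative model and data
model = Chain(Dense(2 => 8, relu), Dense(8 => 1))
x, y = rand(Float32, 2, 32), rand(Float32, 1, 32)

# Build the optimiser state once, then mutate it during training
opt_state = Flux.setup(Adam(1e-3), model)

for epoch in 1:100
    grads = Flux.gradient(m -> Flux.mse(m(x), y), model)
    Flux.update!(opt_state, model, grads[1])
    # Exponential decay of the learning rate each epoch
    Flux.adjust!(opt_state, 1e-3 * 0.9f0^epoch)
end
```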