Batch training with Lux and multiple optimizers

Hello to all,

Following the description of batching in Flux here.

I have created a script with Lux:

using Lux, Optimization, OptimizationOptimisers, OptimizationOptimJL, OrdinaryDiffEq, SciMLSensitivity, ComponentArrays

using StableRNGs
import MLUtils: DataLoader

function newtons_cooling(du, u, p, t)
    temp = u[1]
    k, temp_m = p
    du[1] = dT = -k * (temp - temp_m)
end

function true_sol(du, u, p, t)
    true_p = [log(2) / 8.0, 100.0]
    newtons_cooling(du, u, true_p, t)
end

rng = StableRNG(1111)

ann = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1, tanh))
pp, st = Lux.setup(rng, ann)

function dudt_(u, p, t)
    ann(u, p, st)[1] .* u
end

callback = function (p, l) #callback function to observe training
    display(l)
    return false
end

u0 = [200.0]
datasize = 30
tspan = (0.0f0, 1.5f0)

t = range(tspan[1], tspan[2], length = datasize)
true_prob = ODEProblem(true_sol, u0, tspan)
ode_data = Array(solve(true_prob, Tsit5(), saveat = t))

prob = ODEProblem{false}(dudt_, u0, tspan, pp)

function predict_adjoint(fullp, time_batch)
    Array(solve(prob, Tsit5(), p = fullp, saveat = time_batch))
end

function loss_adjoint(fullp, batch, time_batch)
    pred = predict_adjoint(fullp, time_batch)
    sum(abs2, batch .- pred)
end

k = 10
# Pass the data for the batches as separate vectors wrapped in a tuple
train_loader = DataLoader((ode_data, t), batchsize = k)

numEpochs = 300

optfun = OptimizationFunction(
    (θ, p, batch, time_batch) -> loss_adjoint(θ, batch, time_batch),
    Optimization.AutoZygote())
optprob = OptimizationProblem(optfun, ComponentArray{Float64}(pp))
using IterTools: ncycle

res1 = Optimization.solve(optprob, Optimisers.ADAM(0.05), ncycle(train_loader, numEpochs),
    callback = callback)

optprob2 = Optimization.OptimizationProblem(optfun, res1.u)

numEpochsLBFGS = 100
res2 = Optimization.solve(optprob2, Optim.LBFGS(), ncycle(train_loader, numEpochsLBFGS),
    callback = callback)

The script performs the 300 ADAM iterations but only 2 LBFGS iterations. Why?
How can I force it to do the 100 iterations?

Best Regards

BFGS stops when it detects a local minimum.

Pass allow_f_increases = true, i.e. BFGS(allow_f_increases = true).


Pass allow_f_increases = true, i.e. BFGS(allow_f_increases = true).

Doing

res2 = Optimization.solve(optprob2, Optim.LBFGS(allow_f_increases = true), ncycle(train_loader, numEpochsLBFGS),
    callback = callback)

gives the following error:

ERROR: MethodError: no method matching LBFGS(; allow_f_increases::Bool)

Closest candidates are:
  LBFGS(; m, alphaguess, linesearch, P, precondprep, manifold, scaleinvH0) got unsupported keyword argument "allow_f_increases"
   @ Optim ~/.julia/packages/Optim/dBGGV/src/multivariate/solvers/first_order/l_bfgs.jl:122
  LBFGS(::Int64, ::IL, ::L, ::T, ::Tprep, ::Manifold, ::Bool) where {T, IL, L, Tprep} got unsupported keyword argument "allow_f_increases"
   @ Optim ~/.julia/packages/Optim/dBGGV/src/multivariate/solvers/first_order/l_bfgs.jl:81

Doing:

res2 = Optimization.solve(optprob2, Optim.LBFGS(), ncycle(train_loader, numEpochsLBFGS),
    callback = callback, allow_f_increases = true)

Runs, but only does 2 iterations again.

Sorry, that was my bad. It’s an option to solve. See:

  • allow_f_increases: Allow steps that increase the objective value. Defaults to false. Note that, when setting this to true, the last iterate will be returned as the minimizer even if the objective increased.
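In this script that would look like the following (a minimal sketch, reusing the variables defined above):

res2 = Optimization.solve(optprob2, Optim.LBFGS(), ncycle(train_loader, numEpochsLBFGS),
    callback = callback, allow_f_increases = true)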

If I pass it as a solve parameter, I only get 2 iterations, which is not the intended behavior. Is there another switch that needs to be set in order to force 100 iterations?

@Vaibhavdixit02 @SebastianM-C is this the weird inner vs outer iterations thing?

I don’t think so, since that was with box constraints. I’ll run the code and check what the issue is here.

What does the res2.original say? It may give you a clue on the reason for stopping early.

It says:

 * Status: success (objective increased between iterations)

 * Candidate solution
    Final objective value:     4.154752e+03

 * Found with
    Algorithm:     L-BFGS

 * Convergence measures
    |x - x'|               = 0.00e+00 ≤ 0.0e+00
    |x - x'|/|x'|          = 0.00e+00 ≤ 0.0e+00
    |f(x) - f(x')|         = 3.62e+03 ≰ 0.0e+00
    |f(x) - f(x')|/|f(x')| = 8.71e-01 ≰ 0.0e+00
    |g(x)|                 = 6.64e+04 ≰ 1.0e-08

 * Work counters
    Seconds run:   1  (vs limit Inf)
    Iterations:    1
    f(x) calls:    1078
    ∇f(x) calls:   1078

It looks like the optimizer declared convergence. You have many function calls, but they still count as a single iteration. If I understand correctly, what counts as an iteration here depends on the details of the line search used.


How can it be forced to do 100 iterations?

Any suggestions? How can LBFGS be forced to perform N iterations in a batch?

I’m not sure you can force the optimizer to do more work after it finished / declared convergence. Why would you need that?

The cost function is too high; it should decrease further. Also, it sometimes does not perform a full sweep over a batch (it does not apply LBFGS to all the data).

You can try different line search methods. The default for BFGS-type methods is very conservative (for solid theoretical reasons), but in practice much simpler methods such as quadratic backtracking or even fixed-size steps can often perform better. I don’t recall the usage off the top of my head; look at LineSearches.jl.
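Something along these lines is probably what is meant (an untested sketch; BackTracking and its order argument come from LineSearches.jl, and linesearch is the relevant LBFGS keyword):

using LineSearches

# Swap the default HagerZhang line search for a simpler backtracking one
opt = Optim.LBFGS(linesearch = LineSearches.BackTracking(order = 3))

res2 = Optimization.solve(optprob2, opt, ncycle(train_loader, numEpochsLBFGS),
    callback = callback, allow_f_increases = true)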
