Hello to all,
Following the description of batching in Flux given here, I have created an equivalent script with Lux:
using Lux, Optimization, OptimizationOptimisers, OptimizationOptimJL, OrdinaryDiffEq, SciMLSensitivity, ComponentArrays
using StableRNGs
import MLUtils: DataLoader
"""
    newtons_cooling(du, u, p, t)

In-place ODE right-hand side for Newton's law of cooling:
`dT/dt = -k * (T - T_m)`, where `p = (k, T_m)`.
"""
function newtons_cooling(du, u, p, t)
    k, temp_m = p
    du[1] = -k * (u[1] - temp_m)
end
"""
    true_sol(du, u, p, t)

Ground-truth RHS used to generate training data: Newton's cooling with
the true parameters `k = log(2)/8` and `T_m = 100.0`. The incoming `p`
is ignored.
"""
function true_sol(du, u, p, t)
    # Tuple instead of Vector: destructures identically in
    # `newtons_cooling` but avoids allocating an array on every RHS call.
    true_p = (log(2) / 8.0, 100.0)
    newtons_cooling(du, u, true_p, t)
end
# Fixed-seed RNG so weight initialization is reproducible.
rng = StableRNG(1111)
# Small MLP: 1 input -> 8 hidden (tanh) -> 1 output (tanh).
ann = Lux.Chain(Lux.Dense(1, 8, tanh), Lux.Dense(8, 1, tanh))
# Lux separates trainable parameters (pp) from non-trainable state (st).
pp, st = Lux.setup(rng, ann)
# Out-of-place neural-ODE right-hand side: the network output (first
# element of the (output, state) tuple Lux returns) scales `u` elementwise.
# NOTE(review): `ann` and `st` are captured as non-const globals, which is
# type-unstable; consider marking them `const` or closing over them.
function dudt_(u, p, t)
ann(u,p,st)[1] .* u
end
# Training callback: show the running loss each iteration; returning
# `false` tells the optimizer to keep going (never stop early here).
callback = function (state, loss)
    display(loss)
    return false
end
# Initial temperature.
u0 = [200.0]
# Number of saved time points in the ground-truth trajectory.
datasize = 30
# NOTE(review): tspan is Float32 while u0 is Float64 — mixed precision;
# presumably copied from the Flux example, but worth unifying to Float64.
tspan = (0.0f0, 1.5f0)
t = range(tspan[1], tspan[2], length = datasize)
# Generate ground-truth data by solving the true cooling law.
true_prob = ODEProblem(true_sol, u0, tspan)
ode_data = Array(solve(true_prob, Tsit5(), saveat = t))
# Out-of-place (`{false}`) problem with the neural RHS and Lux parameters.
prob = ODEProblem{false}(dudt_, u0, tspan, pp)
"""
    predict_adjoint(fullp, time_batch)

Solve the neural ODE `prob` with parameters `fullp`, saving at the times
in `time_batch`, and return the solution as a dense array.
"""
function predict_adjoint(fullp, time_batch)
    sol = solve(prob, Tsit5(), p = fullp, saveat = time_batch)
    return Array(sol)
end
"""
    loss_adjoint(fullp, batch, time_batch)

Sum-of-squares error between the data `batch` and the model prediction
evaluated at the times in `time_batch`.
"""
function loss_adjoint(fullp, batch, time_batch)
    prediction = predict_adjoint(fullp, time_batch)
    return sum(abs2, batch .- prediction)
end
k = 10  # minibatch size
# Pass the data for the batches as separate vectors wrapped in a tuple
train_loader = DataLoader((ode_data, t), batchsize = k)
numEpochs = 300
# The optimizer's objective receives (θ, p, batch...) — we ignore `p` and
# forward each minibatch to the loss.
optfun = OptimizationFunction((θ, p, batch, time_batch) -> loss_adjoint(θ, batch,
                                                                        time_batch),
                              Optimization.AutoZygote())
optprob = OptimizationProblem(optfun, ComponentArray{Float64}(pp))
using IterTools: ncycle
# Stage 1: ADAM over `numEpochs` passes through the data loader.
res1 = Optimization.solve(optprob, Optimisers.ADAM(0.05), ncycle(train_loader, numEpochs),
                          callback = callback)
# Stage 2: refine with LBFGS, warm-started from the ADAM result.
optprob2 = Optimization.OptimizationProblem(optfun, res1.u)
numEpochsLBFGS = 100
# BUG FIX: solve `optprob2` (the warm-started problem), not `optprob` —
# the original re-solved the first problem and silently discarded `res1.u`.
# Note LBFGS may still terminate before 100 epochs once its default
# convergence tolerances are met; pass stricter Optim tolerances to force
# more iterations.
res2 = Optimization.solve(optprob2, Optim.LBFGS(), ncycle(train_loader, numEpochsLBFGS),
                          callback = callback)
The script performs all 300 ADAM iterations, but only 2 LBFGS iterations. Why does LBFGS stop so early, and how can I force it to run the full 100 iterations?
Best Regards