I’m using Optimization.jl to optimize the conditions of an ODE simulation against a given objective.
Before the latest release (4.0.0), one could return an extra value from the loss function like this:
function loss_adjoint(fullp, batch, time_batch)
    pred = predict_adjoint(fullp, time_batch)
    sum(abs2, batch .- pred), pred
end
which made it possible to retrieve an extra object, in this case pred, at each solver iteration.
In the latest version, it seems this syntax is no longer available:
function loss_adjoint(fullp, data)
    batch, time_batch = data
    pred = predict_adjoint(fullp, time_batch)
    sum(abs2, batch .- pred)
end
What would be the proper way to retrieve pred with the newest version?
Just enclose it or use a global.
What do you mean by "enclose it"?
Capture it via a closure or use a callable struct.
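For readers wondering what the callable-struct option might look like, here is a minimal sketch. It does not use Optimization.jl or an ODE solver at all: toy_predict is a hypothetical stand-in for the real predict function, and TrackedLoss and its last_pred field are names invented for this illustration. The only assumption carried over from the thread is the two-argument loss signature (u, p).

```julia
# Hypothetical stand-in for the ODE solve: a toy exponential decay model.
toy_predict(u, t) = u[1] .* exp.(-u[2] .* t)

# Callable struct that remembers the most recent prediction.
# `last_pred` is left untyped because, under automatic differentiation,
# the prediction may hold dual numbers rather than plain floats.
mutable struct TrackedLoss{D, T}
    data::D
    times::T
    last_pred::Any
end

# Convenience constructor: nothing has been predicted yet.
TrackedLoss(data, times) = TrackedLoss(data, times, nothing)

# Calling the struct like a function evaluates the loss and stores `pred`.
function (f::TrackedLoss)(u, p)
    pred = toy_predict(u, f.times)
    f.last_pred = pred
    return sum(abs2, f.data .- pred)
end

times = collect(0.0:0.5:2.0)
data = toy_predict([2.0, 0.5], times)   # synthetic "measurements"
loss = TrackedLoss(data, times)

l = loss([2.0, 0.5], nothing)           # scalar loss only
println("loss = ", l)
println("stored prediction = ", loss.last_pred)
```

The loss still returns a single scalar, and anything that holds a reference to loss (a callback, for instance) can read loss.last_pred afterwards.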
I am sorry, but I do not see what closure to implement in this case. Could you point me to a relevant example in the docs? In that sense, the new Data Iterators and Minibatching · Optimization.jl example is less informative than the former one.
Ok, so I read Common Solver Options (Solve Keyword Arguments) · Optimization.jl. In that example, an additional predict function is called twice to obtain pred: once in the loss evaluation and once in the callback, using the current OptimizationState.
function predict(u)
    Array(solve(prob, Tsit5(), p = u))
end
function loss(u, p)
    pred = predict(u)
    sum(abs2, batch .- pred)  # return only the scalar loss
end
callback = function (state, l; doplot = false) # callback function to observe training
    display(l)
    # plot current prediction against data
    if doplot
        pred = predict(state.u)  # second call to predict, this time from the callback
        pl = scatter(t, ode_data[1, :], label = "data")
        scatter!(pl, t, pred[1, :], label = "prediction")
        display(plot(pl))
    end
    return false
end
This works, but is this what you had in mind? Calling predict twice seems like bad design.
pred = Ref{Any}()
function predict(u)
    Array(solve(prob, Tsit5(), p = u))
end
function loss(u, p)
    pred[] = predict(u)
    sum(abs2, batch .- pred[])
end
callback = function (state, l; doplot = false) # callback function to observe training
    display(l)
    # plot current prediction against data
    if doplot
        pl = scatter(t, ode_data[1, :], label = "data")
        scatter!(pl, t, pred[][1, :], label = "prediction")
        display(plot(pl))
    end
    return false
end
And so on. You can optimize this in different ways from there.
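One possible variation, sketched here under assumptions, is to build the loss and the callback inside a function so that the Ref is captured by closures instead of living in a non-constant global. toy_predict is again a hypothetical stand-in for the ODE solve, make_loss_and_callback is a name made up for this sketch, and the callback ignores its state argument, so nothing here depends on Optimization.jl.

```julia
# Hypothetical stand-in for the ODE solve: a toy exponential decay model.
toy_predict(u, t) = u[1] .* exp.(-u[2] .* t)

# Build a loss and a callback that share one captured Ref.
function make_loss_and_callback(data, times)
    last_pred = Ref{Any}(nothing)   # captured by both closures below
    loss = function (u, p)
        last_pred[] = toy_predict(u, times)
        return sum(abs2, data .- last_pred[])
    end
    callback = function (state, l)
        # Reuse the stored prediction instead of calling toy_predict again.
        println("loss = ", l, ", first predicted value = ", first(last_pred[]))
        return false   # same convention as in the thread's callback
    end
    return loss, callback, last_pred
end

times = collect(0.0:0.5:2.0)
data = toy_predict([2.0, 0.5], times)   # synthetic "measurements"
loss, callback, last_pred = make_loss_and_callback(data, times)

l = loss([1.0, 0.5], nothing)
stop = callback(nothing, l)   # `nothing` stands in for the optimizer state
```

Note that the stored value is simply whatever the most recent loss call produced. With an optimizer that evaluates the loss several times per iteration, that may not correspond exactly to state.u at the moment the callback runs.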
Thanks @ChrisRackauckas! Just a side question that I have no use for right now, but for the sake of the discussion: I’m guessing this approach is not compatible with Multistart optimization with EnsembleProblem · Optimization.jl. Or is it?
You can definitely do this kind of thing with multistart, but if the batch of runs is multithreaded, then you need to take that into account in the callbacks.
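To illustrate the threading concern, here is a rough sketch that gives each start its own storage slot, so that concurrent runs never write to the same Ref. It uses plain Threads.@threads rather than the EnsembleProblem interface, and toy_predict, starts, preds and losses are all hypothetical names for this example. It only shows the storage pattern, not an actual multistart solve.

```julia
# Hypothetical stand-in for the ODE solve: a toy exponential decay model.
toy_predict(u, t) = u[1] .* exp.(-u[2] .* t)

times = collect(0.0:0.5:2.0)
data = toy_predict([2.0, 0.5], times)   # synthetic "measurements"
starts = [[1.0, 0.1], [2.0, 0.5], [3.0, 1.0]]

# One Ref per start: a single shared Ref would be overwritten by whichever
# thread evaluated its loss last.
preds = [Ref{Any}(nothing) for _ in starts]
losses = zeros(length(starts))

Threads.@threads for i in eachindex(starts)
    # Each iteration builds a loss that writes only to its own slot.
    loss = (u, p) -> begin
        preds[i][] = toy_predict(u, times)
        sum(abs2, data .- preds[i][])
    end
    losses[i] = loss(starts[i], nothing)
end

println(losses)
```

A lock around a shared container would be another option; per-run slots just avoid the contention altogether.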