Change of return syntax for loss_function in Optimization.jl

I’m using Optimization.jl to optimize the conditions of an ODE simulation against a given objective.
Before the latest update (4.0.0), one could use this return syntax:

function loss_adjoint(fullp, batch, time_batch)
    pred = predict_adjoint(fullp, time_batch)
    sum(abs2, batch .- pred), pred
end

which allowed one to retrieve an extra object, in this case pred, at each solver iteration.
In the latest version, this syntax no longer seems to be available:

function loss_adjoint(fullp, data)
    batch, time_batch = data
    pred = predict_adjoint(fullp, time_batch)
    sum(abs2, batch .- pred)
end

What would be the proper way to retrieve pred with the newest version?

Just enclose it or use a global.

What do you mean by “enclose it”?

Capture it via a closure or use a callable struct.

I am sorry, I do not see what closure to implement in this case. Could you point me to a relevant example in the docs? In that sense, the new Data Iterators and Minibatching · Optimization.jl example is less informative than the former one.

OK, so I read Common Solver Options (Solve Keyword Arguments) · Optimization.jl. In that example, an additional predict function is called twice to retrieve pred: once in the loss evaluation and once in the callback, via the current OptimizationState.

function predict(u)
    Array(solve(prob, Tsit5(), p = u))
end

function loss(u, p)
    pred = predict(u)
    sum(abs2, batch .- pred)
end

callback = function (state, l; doplot = false) #callback function to observe training
    display(l)
    # plot current prediction against data
    if doplot
        pred = predict(state.u)
        pl = scatter(t, ode_data[1, :], label = "data")
        scatter!(pl, t, pred[1, :], label = "prediction")
        display(plot(pl))
    end
    return false
end

This works, but is this what you had in mind? Calling predict twice seems like bad design.

pred = Ref{Any}() # shared storage: written by the loss, read by the callback

function predict(u)
    Array(solve(prob, Tsit5(), p = u))
end

function loss(u, p)
    pred[] = predict(u)
    sum(abs2, batch .- pred[])
end

callback = function (state, l; doplot = false) #callback function to observe training
    display(l)
    # plot current prediction against data
    if doplot
        pl = scatter(t, ode_data[1, :], label = "data")
        scatter!(pl, t, pred[][1, :], label = "prediction")
        display(plot(pl))
    end
    return false
end

Etc. You can optimize it in different ways from there.
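
For example, the closure/callable-struct route mentioned earlier could look like the following. This is just a minimal sketch reusing the predict, batch, t, and ode_data from above; LossWithPred is an illustrative name, not anything from the library:

struct LossWithPred
    pred::Ref{Any} # carries its own storage instead of a global
end

function (l::LossWithPred)(u, p)
    l.pred[] = predict(u)
    sum(abs2, batch .- l.pred[])
end

loss = LossWithPred(Ref{Any}())

callback = function (state, lval; doplot = false)
    display(lval)
    if doplot
        pl = scatter(t, ode_data[1, :], label = "data")
        scatter!(pl, t, loss.pred[][1, :], label = "prediction")
        display(plot(pl))
    end
    return false
end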

Thanks @ChrisRackauckas! Just a side question for which I have no use right now, but for the sake of the discussion: I’m guessing this approach is not compatible with Multistart optimization with EnsembleProblem · Optimization.jl, or is it?

You can definitely do this kind of thing with multistart, but if you batch multithread, then you need to take that into account in the callbacks.
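
For instance, a minimal sketch of one way to handle that, assuming each multistart run executes on its own thread. The Threads.threadid() indexing assumes tasks don't migrate between threads; a lock or task-local storage would be a safer alternative:

preds = [Ref{Any}() for _ in 1:Threads.nthreads()] # one slot per thread

function loss(u, p)
    pred = predict(u)
    preds[Threads.threadid()][] = pred # write only to this thread's slot
    sum(abs2, batch .- pred)
end

callback = function (state, l; doplot = false)
    display(l)
    if doplot
        pred = preds[Threads.threadid()][] # read this thread's latest prediction
        pl = scatter(t, ode_data[1, :], label = "data")
        scatter!(pl, t, pred[1, :], label = "prediction")
        display(plot(pl))
    end
    return false
end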