Least squares with SciML NonlinearSolve.jl - how to?

OK, I have figured it out. Here is a working example:

using NonlinearSolve, Plots

"""
    curve_fit(model, u0, xdata, ydata, p)

Fit `model` to the points `(xdata, ydata)` by nonlinear least squares.

`model` is called element-wise as `model(x, u, p)`, where `u` is the parameter
vector being optimized and `p` is passed through unchanged. `u0` is the initial
guess for `u`.

Returns the named tuple `(; sol, fit)`: `sol` is the object returned by `solve`
and `fit` is the model evaluated at `xdata` with the fitted parameters `sol.u`.
The solver's outcome is not inspected here.
"""
function curve_fit(model, u0, xdata, ydata, p)
    # In-place residual: model prediction minus observation at every sample.
    function residuals!(du, u, packed)
        xs, ys, extra = packed
        du .= model.(xs, Ref(u), Ref(extra)) .- ys
        return nothing
    end

    # The residuals are real-valued even when `ydata` holds integers (or is an
    # integer range), so hand the solver a floating-point buffer; a bare
    # `similar(ydata)` would keep the integer element type.
    T = eltype(ydata)
    resid_prototype = similar(ydata, T <: Number ? float(T) : T)

    fn = NonlinearFunction(residuals!; resid_prototype)
    prob = NonlinearLeastSquaresProblem(fn, u0, (xdata, ydata, p))
    sol = solve(prob)
    fit = model.(xdata, Ref(sol.u), Ref(p))
    return (; sol, fit)
end

Define a model:

"""
    expmodel(x, u, t₀=0)

Evaluate an exponential with a constant offset at `x`.

`u` supplies the parameters by position: offset `u[1]`, amplitude `u[2]` and
time constant `u[3]`. `t₀` is subtracted from `x` before scaling.
"""
function expmodel(x, u, t₀=0)
    offset, amplitude, timeconst = u[1], u[2], u[3]
    return offset + amplitude * exp(-(x - t₀) / timeconst)
end

Synthesize data:

# Ground-truth parameters for the synthetic data set; `t₀` is the x-offset.
y0, a, τ, t₀ = 0.0, 1.0, 2.0, 0
xs = range(0, 3; step = 0.1);
# Noise-free model values at every sample point.
ys0 = [expmodel(x, [y0, a, τ], t₀) for x in xs];
# Perturb each sample with its own Gaussian draw, scaled by 0.02.
ys = ys0 .+ 0.02 .* randn.()

Solve, plot:

# Fit, starting the solver from the same parameter values that generated the data.
sol, fit = curve_fit(expmodel, [y0, a, τ], xs, ys, t₀)

# Overlay the noisy samples and the fitted curve against `xs`.
plot(xs, [ys, fit])


and enjoy :slight_smile:

2 Likes