Better way to name or inherit parameters with NonlinearSolve.jl?

I am using NonlinearSolve.jl for simple interval root finding. It is working well, but the way I am handling parameters is sticking out to me like a code smell.

  1. Is there a better way to pack and unpack the parameter vector p? Doing it by position seems error prone. Can I pass a NamedTuple somehow?
  2. Is there a way to just have f inherit the parameters from the for loop rather than passing them explicitly?
  3. Does it matter if the f function definition is inside or outside the main function body?

I’m still new to all of this, so any guidance is appreciated.

Full implementation below:
(Within f and the for loop, all variables are scalar Float64.)

using DataFrames, NonlinearSolve

function find_proportional_limit(table::DataFrame, searchrange::Tuple=(1.0, 1e6))

    # Define the non-linear function to be solved.
    function f(u, p)
        # Solution Variable
        σ_t = u

        # Parameters
        σ_ys = p[1]
        σ_uts = p[2]
        K = p[3]
        m_1 = p[4]
        m_2 = p[5]
        A_1 = p[6]
        A_2 = p[7]
        ϵ_p = p[8]

        # Equations
        H = KM620.H(σ_t, σ_ys, σ_uts, K)
        ϵ_1 = KM620.ϵ_1(σ_t, A_1, m_1)
        ϵ_2 = KM620.ϵ_2(σ_t, A_2, m_2)
        γ_1 = KM620.γ_1(ϵ_1, H)
        γ_2 = KM620.γ_2(ϵ_2, H)

        return γ_1 + γ_2 - ϵ_p  # = 0 (Eq. KM-620.2)
    end

    # Find the root of the f function for each temperature in the data frame.
    σ_p = Float64[]
    for row in eachrow(table)
        p = [
            row.σ_ys,
            row.σ_uts,
            row.K,
            row.m_1,
            row.m_2,
            row.A_1,
            row.A_2,
            row.ϵ_p,
        ]
        problem = IntervalNonlinearProblem(f, searchrange, p)
        solution = solve(problem)
        push!(σ_p, solution.u)
    end
    return σ_p
end

Regarding 1: I think NamedTuple works well. But if not, then you could try ComponentVector from ComponentArrays.jl.

Regarding 2: It is better not to inherit the parameter values. You can unpack the parameters in a single line with the property destructuring syntax (; σ_ys, σ_uts) = p, so it only costs one or two extra lines. I would recommend that.
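A minimal sketch of that unpacking on a NamedTuple, assuming Julia 1.7 or newer (where this destructuring syntax is available). The function name unpack_demo and the numeric values are made up for illustration and are not from the original code:

```julia
# Hypothetical parameter values, for illustration only.
p = (σ_ys = 250.0, σ_uts = 400.0, K = 1.5)

function unpack_demo(p)
    # Pull out only the fields that are needed, by name rather than by position.
    (; σ_ys, σ_uts) = p
    return σ_uts - σ_ys
end

# The order of the fields in p does not matter, because the lookup is by name.
p_reordered = (K = 1.5, σ_uts = 400.0, σ_ys = 250.0)

result = unpack_demo(p)
result_reordered = unpack_demo(p_reordered)
```

Both calls give the same value, which is the property that makes this less error prone than p[1], p[2], and so on.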

However, if you really don’t want to pass the values via the parameters, you could capture them by defining f directly inside the for loop. But that will also trigger recompilations. That’s why it is better to define f as a clean function that depends only on its inputs u and p.

Regarding 3: Here it doesn’t matter, since f doesn’t refer to anything from the outer scopes. However, if you decide to capture some variables from the outside, then it matters a lot: inside the for loop they might be captured efficiently, but outside it would use the global scope, which is slow.
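For comparison, here is a rough sketch of what the capturing variant could look like. To keep it self-contained it does not use NonlinearSolve.jl at all: residual is a hypothetical toy equation standing in for the KM620 expressions, and bisect is a hand-written bisection used only as a placeholder solver. It illustrates the closure pattern, not its performance.

```julia
# Hypothetical toy residual standing in for the real equations.
residual(u, a, b) = a * u - b

# Plain bisection, only here so the sketch runs without any packages.
# Assumes g changes sign exactly once on [lo, hi].
function bisect(g, lo, hi; tol = 1e-10)
    glo = g(lo)
    while hi - lo > tol
        mid = (lo + hi) / 2
        gmid = g(mid)
        if sign(gmid) == sign(glo)
            lo, glo = mid, gmid   # root lies in the upper half
        else
            hi = mid              # root lies in the lower half (or at mid)
        end
    end
    return (lo + hi) / 2
end

function roots_via_closure(rows)
    roots = Float64[]
    for row in rows
        # g is defined inside the loop and captures the current row,
        # so no parameter argument is passed explicitly.
        g = u -> residual(u, row.a, row.b)
        push!(roots, bisect(g, 0.0, 10.0))
    end
    return roots
end

rows = [(a = 2.0, b = 4.0), (a = 4.0, b = 2.0)]
roots = roots_via_closure(rows)
```

The explicit f(u, p) form keeps all inputs visible in the signature, which is the main reason it is recommended above.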

(I don’t know if it exists, but maybe you can use remake like for ODEProblems to avoid creating NonlinearProblems inside a for-loop.)

You should be able to pass in any form of p AFAIR. p doesn’t really matter from the root-finding context, it is passed unmodified to the function. The only consideration is if you want to differentiate wrt these parameters, then it needs to be:

  1. NamedTuple / AbstractArray – For Zygote
  2. AbstractArray – For ForwardDiff

With SciMLStructures we will be able to consolidate the interfaces I think.

Much better!

I must have read this and thought nothing else would work. The documentation examples index u by position, so I think I assumed I would have to do the same with p.

In my case, the solver seems to work the same when I pass my DataFrameRow row directly. remake also seems to work well, even though the docstring for remake does not list IntervalNonlinearProblem as an option.
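As far as I understand it, passing the row directly works because (; a, b) = p is just shorthand for property access (p.a, p.b), so any object that supports those properties can be used. A small sketch with a hypothetical struct, MaterialRow, standing in for the DataFrameRow (f_demo and the values are also made up):

```julia
# Hypothetical stand-in for a DataFrameRow: anything with matching properties.
struct MaterialRow
    σ_ys::Float64
    ϵ_p::Float64
end

function f_demo(u, p)
    # Equivalent to σ_ys = p.σ_ys; ϵ_p = p.ϵ_p
    (; σ_ys, ϵ_p) = p
    return u / σ_ys - ϵ_p
end

from_namedtuple = f_demo(100.0, (σ_ys = 200.0, ϵ_p = 0.25))
from_struct = f_demo(100.0, MaterialRow(200.0, 0.25))
```

The same function body accepts both containers without any changes.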

function find_proportional_limit(table::DataFrame, searchrange::Tuple=(1.0, 1e6))

    # Define the non-linear function to be solved.
    function f(u, p)
        # Inputs
        σ_t = u  # Solution Variable
        (; σ_ys, σ_uts, K, m_1, m_2, A_1, A_2, ϵ_p) = p  # Parameters

        # Equations
        H = KM620.H(σ_t, σ_ys, σ_uts, K)
        ϵ_1 = KM620.ϵ_1(σ_t, A_1, m_1)
        ϵ_2 = KM620.ϵ_2(σ_t, A_2, m_2)
        γ_1 = KM620.γ_1(ϵ_1, H)
        γ_2 = KM620.γ_2(ϵ_2, H)

        return γ_1 + γ_2 - ϵ_p  # = 0 (Eq. KM-620.2)
    end

    # Find the root of the f function for each temperature in the data frame.
    σ_p = Float64[]
    problem = IntervalNonlinearProblem(f, searchrange)
    for row in eachrow(table)
        problem = remake(problem, p=row)
        solution = solve(problem)
        push!(σ_p, solution.u)
    end

    return σ_p
end