How to pass additional parameters to curve_fit in Julia?

Hi,
I have a dataset that I would like to fit to my function pbase(x,pars) with Julia’s curve_fit, whereby pars is supposed to contain a non-fitting parameter that controls which functional model is chosen. The function looks like this:

function pbase(d,pars)
    typ=pars[1]
    pimb=fill(-5.81,length(d))
    @. x=(length(pars) == 3) ? pars[2]*(log10(d)-log10(pars[3])) : pars[2]*log10(d)
    if typ == "tanh"
        @. pimb*=tanh(x)
    elseif typ == "arctan"
        @. pimb*=atan(x)
    end
    pimb
end

i.e., pars[1] contains the selector, which must not change during fitting, and pars[2:3] contain the actual fitting parameters. Is this possible at all?

I don’t think curve_fit is a default Julia function. Without knowing which optimization package you are using, I can’t really give you a specific answer, but you can probably wrap your function in another one as follows:

function helper(d,pars)
    typ = your_selector_value_here #hardcode your selector value here
    new_pars = vcat(typ,pars) #merge selector and mutable params
    pbase(d,new_pars) #call your original function with selector and mutable params
end

Anonymous functions and closures / lexical scoping are your friends here:

p1 = fixed_parameter
curve_fit((x, p) -> pbase(x, [p1; p]), xdata, ydata, p0)

(Assuming you are using the curve_fit function from LsqFit.jl.)
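For anyone who wants something to run, here is a minimal sketch of the closure approach, assuming LsqFit.jl is installed. The names `pmodel` and `make_model` and the synthetic data are made up for illustration and are not from the original post. Unlike `pbase`, this version takes the selector as its own argument instead of storing it in the parameter vector, and it throws on an unknown selector.

```julia
using LsqFit

# Hypothetical variant of the model: the selector `typ` is a separate
# argument, so the fitted vector `p` holds only numeric parameters.
function pmodel(d, typ, p)
    x = length(p) == 2 ? p[1] .* (log10.(d) .- log10(p[2])) : p[1] .* log10.(d)
    if typ == "tanh"
        return -5.81 .* tanh.(x)
    elseif typ == "arctan"
        return -5.81 .* atan.(x)
    else
        throw(ArgumentError("unknown model type: $typ"))
    end
end

# Returns a closure with the two-argument signature curve_fit expects;
# `typ` is captured and never seen by the optimizer.
make_model(typ) = (x, p) -> pmodel(x, typ, p)

# Synthetic, noise-free data purely for illustration
xdata = 10 .^ range(-1, 1; length=50)
ydata = pmodel(xdata, "tanh", [1.5, 2.0])

model = make_model("tanh")
fit = curve_fit(model, xdata, ydata, [1.0, 1.0])
println(fit.param)
```

Keeping the selector out of the parameter vector also avoids building a mixed `Vector{Any}` on every model evaluation, which is what `[p1; p]` produces when `p1` is a string.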

In my experience, “how do I pass additional parameters” is by far the most common question about numerical functions that take a function as an argument, i.e. higher-order functions (numerical integration, root finding, minimization, differential equations, and so on). I think it’s a basic gap in computer-science education for people doing computational science.


Thanks for the replies. Steven’s suggestion basically settled it for me. Yes, I am using LsqFit.jl.
I’ll add for completeness and clarity (I hope) that p1 has to be a vector (at least if there is more than one fixed parameter) and that the compiler took issue with the line

@. x=(length(pars) == 3) ? pars[2]*(log10(d)-log10(pars[3])) : pars[2]*log10(d)

from my function. Changing that into an if clause resolved that problem.
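For reference, a sketch of what the function might look like with the ternary replaced by an if clause. This is a reconstruction, not necessarily the exact code that was used. The original line likely fails because `@.` turns `x = ...` into the in-place `x .= ...`, which needs `x` to exist already, and it also dots `length` and `==`, so the condition of the (non-broadcast) ternary is no longer a single Bool.

```julia
function pbase(d, pars)
    typ = pars[1]
    # Pick the argument with an ordinary if/else on a scalar condition
    if length(pars) == 3
        x = pars[2] .* (log10.(d) .- log10(pars[3]))
    else
        x = pars[2] .* log10.(d)
    end
    pimb = fill(-5.81, length(d))
    if typ == "tanh"
        @. pimb *= tanh(x)   # in-place broadcast; pimb already exists
    elseif typ == "arctan"
        @. pimb *= atan(x)
    end
    pimb
end
```

Calling it as `pbase(xdata, ["tanh", 1.5, 2.0])` works with a mixed vector, so the closure `(x, p) -> pbase(x, [p1; p])` from above can be passed to curve_fit unchanged.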