Fixing type stability involving functions

I have the following codes

using Parameters

function spatiotemporal(x, t, a_coefs::AbstractMatrix, b_coefs::AbstractMatrix)

    ns = size(a_coefs,1) - 1
    nt = Int((size(a_coefs,2) - 1)/2)

    c = zero(eltype(a_coefs))

    for m = 0:ns, r = -nt:nt
        a, b = a_coefs[m+1,r+nt+1], b_coefs[m+1,r+nt+1]
        c += a*cos(2*pi*(m*x + r*t)) + b*sin(2*pi*(m*x + r*t))
    end

    return c
end

@with_kw struct Test3{T1,T2}
    V::T1
    x::T2 = range(0, 1, length=129)[1:end-1]
    ns::Int64 = 10
    nt::Int64 = 10
end

function dhdt_1!(dh, h, par, t)
    if isa(par.V, Function)
        V = par.V(t)
    else
        V = par.V
    end
    
    dh .= V .+ 1.0
    return nothing
end

function costFun!(C, par)
    ns, nt = par.ns, par.nt
    n_total = (ns+1)*(2*nt+1)
    a_coefs = reshape(C[1:n_total], ns+1, 2*nt+1)
    b_coefs = reshape(C[n_total+1:end], ns+1, 2*nt+1)
    
    par2 = @set par.V = t -> spatiotemporal.(par1.x, t, (a_coefs,), (b_coefs,))
    
    # solves ODE using dhdt_1!(dh, h, par2, t)
end

Basically, I’m doing an optimization over the cost function costFun! where in each iteration, the unknown variable C which contains the Fourier coefficients is used to construct a function stored in par2.V via the general function spatiotemporal. The struct par2 which now contains the newly created function with the updated Fourier coefficients is used to solve a stiff ODE, where the results will be used to compute the cost function.

However, it looks like there are type stability issues with par.V. With the following:

a_coefs = [1. 2. 3.; 3. 4. 5.]
b_coefs = [1. 2. 3.; 4. 5. 6.]

h0 = ones(128)
dh = similar(dh)
par1 = Test3(V = t->1.0)
par2 = @set par1.V = t -> spatiotemporal.(par1.x, t, (a_coefs,), (b_coefs,))

Using par1 is good after checking with @code_warntype dhdt_1!(dh, h0, par1, 0.0). However, for par2 where par2.V is a function constructed using a_coefs and b_coefs, @code_warntype dhdt_1!(dh, h0, par2, 0.0) gives

MethodInstance for dhdt_1!(::Vector{Float64}, ::Vector{Float64}, ::Test3{var"#61#62", StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}, ::Int64)
  from dhdt_1!(dh, h, par, t) in Main at In[102]:10
Arguments
  #self#::Core.Const(dhdt_1!)
  dh::Vector{Float64}
  h::Vector{Float64}
  par::Test3{var"#61#62", StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}}
  t::Int64
Locals
  V::Any
Body::Nothing
1 ─      Core.NewvarNode(:(V))
│   %2 = Base.getproperty(par, :V)::Core.Const(var"#61#62"())
│   %3 = (%2 isa Main.Function)::Core.Const(true)
│        Core.typeassert(%3, Core.Bool)
│   %5 = Base.getproperty(par, :V)::Core.Const(var"#61#62"())
│        (V = (%5)(t))
└──      goto #3
2 ─      Core.Const(:(V = Base.getproperty(par, :V)))
3 ┄ %9 = Base.broadcasted(Main.:+, V, 1.0)::Any
│        Base.materialize!(dh, %9)
└──      return Main.nothing

V is type Any, so it looks like it can’t infer the type of a_coefs and b_coefs. In my actual problem, a_coefs and b_coefs will come from C passed to the cost function.

So are there any ways to fix this type stability issues?

If the whole a_coefs... par2 = ... block is wrapped into a local scope, I have no type instability warning on my machine.
The issue is, I think, that the closure in par2 uses global variables.