Type instability inside function with function as input argument

The following is part of the code for some FEM calculations:

"""
    assemble_Ab(x, a_fun, c_fun, f_fun, order::Integer=1)

Allocate the `N×N` matrix `A` and length-`N` vector `b` for the mesh nodes `x`
(`N = length(x)`). For every element `[x[j], x[j+1]]`, evaluate the coefficient
functions `a_fun`, `c_fun` and `f_fun` at its quadrature points.

Only linear elements (`order == 1`, two-point Gauss–Legendre quadrature) are
implemented. Any other `order` throws an `ArgumentError` rather than silently
returning zero arrays.

Each function argument is called directly in the body, so this method is specialized
on it. If a coefficient function reads a non-`const`, untyped global, its result (and
everything computed from it) is still inferred as `Any` here.

# Returns
`(A, b)`, a `Matrix{Float64}` and a `Vector{Float64}`.

# Throws
`ArgumentError` if `order != 1`.
"""
function assemble_Ab(x, a_fun, c_fun, f_fun, order::Integer=1)
    order == 1 || throw(ArgumentError("only order = 1 is supported, got order = $order"))

    N = length(x)
    A = zeros(N, N)
    b = zeros(N)

    n_element = N - 1
    # Two-point Gauss–Legendre rule on the reference interval [-1, 1]. Tuples keep
    # the rule off the heap (the original built two Vectors on every call).
    weights = (1.0, 1.0)
    abscissas = (-1 / sqrt(3), 1 / sqrt(3))

    for j in 1:n_element
        xl = x[j]
        xr = x[j+1]

        for (w, s) in zip(weights, abscissas)
            # Map the reference point s ∈ [-1, 1] onto [xl, xr] and scale the
            # weight by the Jacobian (xr - xl)/2.
            xq = (xl * (1 - s) + xr * (1 + s)) / 2
            wq = w * (xr - xl) / 2

            aq = a_fun(xq)
            cq = c_fun(xq)
            fq = f_fun(xq)

            cross_term = aq * cq * fq   # do sth. with aq, cq, fq for demo purpose
            # more calculations involving aq, cq, fq for filling A and b
        end
    end

    return A, b
end


# Coefficient functions for `assemble_Ab`.
#
# The one-argument forms read `L` as a global defined elsewhere. Unless that global is
# declared `const` (or type-annotated, e.g. `L::Float64 = 1.0` on Julia ≥ 1.8), its
# type is unknown to the compiler. `cfun1(x)` and `ffun1(x)` are then inferred as
# `Any`, which is the instability shown for `cq` and `fq` by `@code_warntype`.
#
# The two-argument forms take the length explicitly and are type-stable. Pass them as
# closures over a local or `const` value, e.g. `x -> cfun1(x, len)`.
cfun1(x, len) = 2 + sin(2 * pi * x / len)
cfun1(x) = cfun1(x, L)
afun1(x) = 2 * x
ffun1(x, len) = 2 + 0.5 * sin(2 * pi * x / len)
ffun1(x) = ffun1(x, L)

x = range(0, 1, length=101);

# The assembly routine is `assemble_Ab`; the original call named `assemble_A`, which
# is not defined anywhere in this code. NOTE(review): `cfun1`/`ffun1` read the global
# `L`, which must be defined (ideally as `const`) before this call.
A1, b1 = assemble_Ab(x, afun1, cfun1, ffun1, 1);

When I run `@code_warntype assemble_Ab(x, afun1, cfun1, ffun1, 1)`, I get the following (this output was captured with `afun2` as the first coefficient function):

MethodInstance for assemble_Ab(::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}, ::typeof(afun2), ::typeof(cfun1), ::typeof(ffun1), ::Int64)
  from assemble_Ab(x, a_fun, c_fun, f_fun, order::Integer) in Main at In[90]:1
Arguments
  #self#::Core.Const(assemble_Ab)
  x::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}
  a_fun::Core.Const(afun2)
  c_fun::Core.Const(cfun1)
  f_fun::Core.Const(ffun1)
  order::Int64
Locals
  @_7::Union{Nothing, Tuple{Int64, Int64}}
  abscissas::Vector{Float64}
  weights::Vector{Float64}
  n_quad::Int64
  n_element::Int64
  b::Vector{Float64}
  A::Matrix{Float64}
  N::Int64
  @_15::Union{Nothing, Tuple{Int64, Int64}}
  j::Int64
  xr::Float64
  xl::Float64
  q::Int64
  cross_term::Any     # problematic
  fq::Any     # problematic
  cq::Any     # problematic
  aq::Float64
  wq::Float64
  xq::Float64
Body::Tuple{Matrix{Float64}, Vector{Float64}}
1 ─       Core.NewvarNode(:(@_7))
│         Core.NewvarNode(:(abscissas))
│         Core.NewvarNode(:(weights))
│         Core.NewvarNode(:(n_quad))
│         Core.NewvarNode(:(n_element))
│         (N = Main.length(x))
│         (A = Main.zeros(N, N))
│         (b = Main.zeros(N))
│   %9  = (order == 1)::Bool
└──       goto #8 if not %9
2 ─       (n_element = N - 1)
│         (n_quad = 2)
│         (weights = Base.vect(1.0, 1.0))
│   %14 = Main.sqrt(3)::Float64
│   %15 = (-1 / %14)::Float64
│   %16 = Main.sqrt(3)::Float64
│   %17 = (1 / %16)::Float64
│         (abscissas = Base.vect(%15, %17))
│   %19 = (1:n_element)::Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])
│         (@_7 = Base.iterate(%19))
│   %21 = (@_7 === nothing)::Bool
│   %22 = Base.not_int(%21)::Bool
└──       goto #8 if not %22
3 ┄ %24 = @_7::Tuple{Int64, Int64}
│         (j = Core.getfield(%24, 1))
│   %26 = Core.getfield(%24, 2)::Int64
│         (xl = Base.getindex(x, j))
│   %28 = (j + 1)::Int64
│         (xr = Base.getindex(x, %28))
│   %30 = (1:n_quad::Core.Const(2))::Core.Const(1:2)
│         (@_15 = Base.iterate(%30))
│   %32 = (@_15::Core.Const((1, 1)) === nothing)::Core.Const(false)
│   %33 = Base.not_int(%32)::Core.Const(true)
└──       goto #6 if not %33
4 ┄ %35 = @_15::Tuple{Int64, Int64}
│         (q = Core.getfield(%35, 1))
│   %37 = Core.getfield(%35, 2)::Int64
│   %38 = xl::Float64
│   %39 = Base.getindex(abscissas, q)::Float64
│   %40 = (1 - %39)::Float64
│   %41 = (%38 * %40)::Float64
│   %42 = xr::Float64
│   %43 = Base.getindex(abscissas, q)::Float64
│   %44 = (1 + %43)::Float64
│   %45 = (%42 * %44)::Float64
│   %46 = (%41 + %45)::Float64
│         (xq = %46 / 2)
│   %48 = Base.getindex(weights, q)::Float64
│   %49 = (xr - xl)::Float64
│   %50 = (%48 * %49)::Float64
│         (wq = %50 / 2)
│         (aq = (a_fun)(xq))
│         (cq = (c_fun)(xq))
│         (fq = (f_fun)(xq))
│         (cross_term = aq * cq * fq)
│         (@_15 = Base.iterate(%30, %37))
│   %57 = (@_15 === nothing)::Bool
│   %58 = Base.not_int(%57)::Bool
└──       goto #6 if not %58
5 ─       goto #4
6 ┄       (@_7 = Base.iterate(%19, %26))
│   %62 = (@_7 === nothing)::Bool
│   %63 = Base.not_int(%62)::Bool
└──       goto #8 if not %63
7 ─       goto #3
8 ┄ %66 = Core.tuple(A, b)::Tuple{Matrix{Float64}, Vector{Float64}}
└──       return %66

It looks like the compiler can't infer the types of `fq = f_fun(xq)` and `cq = c_fun(xq)`, so they (and any subsequent calculations involving `fq` and `cq`) become type `Any`. Is there a problem with the main function `assemble_Ab`, or does it have something to do with the particular choice of `ffun1` and `cfun1`? They contain `sin` and look more complicated, while `afun1(x) = 2*x` is clearly simpler.

function assemble_Ab(x, a_fun::A, c_fun::C, f_fun::F, order::Integer=1) where {A, C, F}

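I'm guessing you're hinting that the compiler does not automatically specialize on `::Type`, `::Function`, or `::Vararg` arguments. However, the docs also say that specialization does happen when such an argument is used inside the method, rather than merely passed through to another function. I would think calling `a_fun`, `f_fun`, and `c_fun` counts. Indeed, the `@code_warntype` output shows `a_fun::Core.Const(afun2)`, `c_fun::Core.Const(cfun1)` and `f_fun::Core.Const(ffun1)`, so the method was specialized. What stands out to me is that `aq` is type-stable even though `a_fun` is used in exactly the same way as the others, so let's look at the input methods: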

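The clear difference is that `cfun1` and `ffun1` read `L` as a global variable rather than taking it as an argument. `L`'s assignment wasn't posted. If it isn't `const` (or, on Julia ≥ 1.8, type-annotated as in `L::Float64 = 1.0`), its type can't be inferred, and that causes type instability in any method that reads it.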

Oh yeah… you can't have a non-`const` global variable there if you want type stability. I guess that's a possibility I overlooked.