Type stability when having function as struct field

I have defined a struct called ParDemo1 to store some parameters (and do some basic calculations for default values not shown here). It is passed to the function dhdt_demo!, which is the RHS of a system of ODEs to be passed to DifferentialEquations.jl:

```julia
using Parameters   # provides @with_kw

function ℿdemo(h, par, t)
    return par.A./h.^5 .+ par.W./(par.D .- h).^2
end


@with_kw struct ParDemo1{T1<:Union{Float64,Function}}
    D::Float64 = 5.0
    A::Float64 = 0.0
    W::Float64 = 1.0
    V::T1 = 1.0
    ℿ::Function = ℿdemo
end

function dhdt_demo!(dh, h, par, t)
    isa(par.V, Function) ? V = par.V(t) : V = par.V
    ℿ = par.ℿ(h, par, t)
    @. dh = h^3 * ℿ * V
    return nothing
end
```

The field ℿ of ParDemo1 holds a function with signature ℿ(h, par, t); which function it is will change depending on the exact problem I'm solving.

Now I run the following code:

```julia
Vfun = t->1+sin(t)
par_demo = ParDemo1(V=Vfun)

h0 = ones(256)
dh = similar(h0)   # preallocated output with the same size and element type as h0

dhdt_demo!(dh, h0, par_demo, 0)
```

However, there seems to be some type instability problem:

```julia
using BenchmarkTools   # provides @btime
@btime dhdt_demo!(dh, h0, par_demo, 0)   # 1.910 μs (8 allocations: 2.33 KiB)
```

@code_warntype dhdt_demo!(dh, h0, par_demo, 0) gives the following output:

```
MethodInstance for dhdt_demo!(::Vector{Float64}, ::Vector{Float64}, ::ParDemo1{var"#11#12"}, ::Int64)
  from dhdt_demo!(dh, h, par, t) in Main at In[59]:1
Arguments
  #self#::Core.Const(dhdt_demo!)
  dh::Vector{Float64}
  h::Vector{Float64}
  par::ParDemo1{var"#11#12"}
  t::Int64
Locals
  ℿ::Any
  V::Float64
Body::Nothing
1 ─       Core.NewvarNode(:(ℿ))
│         Core.NewvarNode(:(V))
│   %3  = Base.getproperty(par, :V)::Core.Const(var"#11#12"())
│   %4  = (%3 isa Main.Function)::Core.Const(true)
│         Core.typeassert(%4, Core.Bool)
│   %6  = Base.getproperty(par, :V)::Core.Const(var"#11#12"())
│         (V = (%6)(t))
└──       goto #3
2 ─       Core.Const(:(V = Base.getproperty(par, :V)))
3 ┄ %10 = Base.getproperty(par, :ℿ)::Function
│         (ℿ = (%10)(h, par, t))
│   %12 = Core.apply_type(Base.Val, 3)::Core.Const(Val{3})
│   %13 = (%12)()::Core.Const(Val{3}())
│   %14 = Base.broadcasted(Base.literal_pow, Main.:^, h, %13)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(Base.literal_pow), Tuple{Base.RefValue{typeof(^)}, Vector{Float64}, Base.RefValue{Val{3}}}}
│   %15 = ℿ::Any
│   %16 = Base.broadcasted(Main.:*, %14, %15, V)::Any
│         Base.materialize!(dh, %16)
└──       return Main.nothing
```

It looks like the compiler cannot infer the type of the output of ℿdemo (but it has no trouble with par.V).

If I pass a named tuple instead of my self-defined struct, it seems to be fine:

```julia
par_demo2 = (D=5.0, A=0.0, W=1.0, V=Vfun, ℿ=ℿdemo)
@btime dhdt_demo!(dh, h0, par_demo2, 0)   #  1.390 μs (1 allocation: 2.12 KiB)  (any way to get rid of the allocation completely?)
```

And using code_warntype doesn’t show any ::Any.
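The reason the named tuple behaves differently can be read off its type: a NamedTuple's type records the concrete type of each value it holds, so the ℿ entry is known to be exactly typeof(ℿdemo) rather than the abstract Function. A minimal check using only Base (the variable T is just a local name for the tuple's type):

```julia
# ℿdemo and Vfun as defined in the question
ℿdemo(h, par, t) = par.A ./ h .^ 5 .+ par.W ./ (par.D .- h) .^ 2
Vfun = t -> 1 + sin(t)

par_demo2 = (D = 5.0, A = 0.0, W = 1.0, V = Vfun, ℿ = ℿdemo)

# The NamedTuple type stores the concrete type of every value
T = typeof(par_demo2)
@show fieldtype(T, :ℿ)                   # typeof(ℿdemo), not Function
@show isconcretetype(fieldtype(T, :ℿ))   # true
@show isconcretetype(fieldtype(T, :V))   # true: the closure's own type
```

This is the same information a struct with a type parameter for the function field would carry.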

Are there any ways to modify my ParDemo1 or other parts of the code so that it becomes type stable? It will be a lot more convenient for my problem if I can use a self-defined struct instead of a named tuple.
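On the remaining allocation asked about above: it most likely comes from ℿdemo itself, which builds a new 256-element array on every call (2.12 KiB is roughly 256 Float64 values plus an array header). One possible way around it, assuming ℿ can be rewritten to act on a single element of h (true for ℿdemo, not necessarily for every problem), is to evaluate it inside a loop. A sketch under that assumption; ℿdemo_scalar, ParDemoScalar and dhdt_loop! are hypothetical names, Base.@kwdef stands in for Parameters.@with_kw, and the function field already uses a type parameter F:

```julia
# Hypothetical elementwise version of ℿdemo: takes one scalar h
ℿdemo_scalar(h, par, t) = par.A / h^5 + par.W / (par.D - h)^2

# Base.@kwdef used here in place of Parameters.@with_kw
Base.@kwdef struct ParDemoScalar{T1<:Union{Float64,Function}, F}
    D::Float64 = 5.0
    A::Float64 = 0.0
    W::Float64 = 1.0
    V::T1 = 1.0
    ℿ::F = ℿdemo_scalar      # parametric, so the function's concrete type is known
end

function dhdt_loop!(dh, h, par, t)
    V = par.V isa Function ? par.V(t) : par.V
    # Evaluate ℿ element by element, so no temporary array is created
    @inbounds for i in eachindex(dh, h)
        dh[i] = h[i]^3 * par.ℿ(h[i], par, t) * V
    end
    return nothing
end

par_loop = ParDemoScalar(V = t -> 1 + sin(t))
h0 = ones(256)
dh = similar(h0)
dhdt_loop!(dh, h0, par_loop, 0)
```

With BenchmarkTools loaded, @btime dhdt_loop!($dh, $h0, $par_loop, 0) is expected to report zero allocations; the $ interpolation keeps the non-constant globals from adding overhead to the measurement. If ℿ in a real problem needs the whole vector h at once (for example finite differences), this elementwise trick does not apply directly, and writing into a preallocated buffer would be the analogous route.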

```julia
@with_kw struct ParDemo1{T1<:Union{Float64,Function}, F}
    D::Float64 = 5.0
    A::Float64 = 0.0
    W::Float64 = 1.0
    V::T1 = 1.0
    ℿ::F = ℿdemo
end
```

It works, thanks. But do you mind briefly explaining the difference (or suggest some relevant webpages/resources on this)?

isa(ℿdemo, Function) returns true, so why does having a parametric type F here make a difference? And why doesn't par.V give me this type stability issue, even though I declared its type as T1<:Union{Float64,Function} rather than just a general T1?

Function is an abstract type, so a field declared with this type is only ever known to hold "some Function", and calling it is always type unstable. But if you declare a field of type F, where F<:Function (or just F, equivalent to F where F<:Any), it will specialize on the specific function type F and be stable.
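A small, self-contained comparison of the two kinds of field declaration, using only Base; AbstractField, ParamField and call_f are made-up names for illustration, and Base.return_types is used to ask the compiler what it infers for the call:

```julia
struct AbstractField
    f::Function    # abstract: the compiler only knows "some Function"
end

struct ParamField{F}
    f::F           # F becomes typeof(f) when an instance is constructed
end

call_f(s, x) = s.f(x)

a = AbstractField(sin)
p = ParamField(sin)

@show isconcretetype(fieldtype(typeof(a), :f))   # false
@show isconcretetype(fieldtype(typeof(p), :f))   # true

# What the compiler can infer for s.f(x) in each case
@show Base.return_types(call_f, (typeof(a), Float64))   # inferred as Any
@show Base.return_types(call_f, (typeof(p), Float64))   # inferred as Float64
```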

T1<:Union{Float64,Function} is a union of a concrete type Float64 and an abstract Function, so it is not a concrete type. However, it is a type parameter of ParDemo1 rather than an actual field type, so it gets specialized when you actually construct a ParDemo1 struct. When passed the value V=1.0, it creates a ParDemo1{Float64}. If you pass a function V=foo instead, you create a ParDemo1{typeof(foo)}. In either case, T1 now takes a concrete value (instead of a non-concrete union).
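To see this specialization at construction time, here is the question's one-parameter struct rebuilt with Base.@kwdef (standing in for Parameters.@with_kw, which is assumed to behave the same way here); the field ℿ keeps its abstract Function annotation, since only V is examined:

```julia
ℿdemo(h, par, t) = par.A ./ h .^ 5 .+ par.W ./ (par.D .- h) .^ 2

# The question's struct, with Base.@kwdef standing in for Parameters.@with_kw
Base.@kwdef struct ParDemo1{T1<:Union{Float64,Function}}
    D::Float64 = 5.0
    A::Float64 = 0.0
    W::Float64 = 1.0
    V::T1 = 1.0
    ℿ::Function = ℿdemo
end

par_num = ParDemo1()                       # V = 1.0
par_fun = ParDemo1(V = t -> 1 + sin(t))    # V is a closure

@show typeof(par_num)                                 # ParDemo1{Float64}
@show fieldtype(typeof(par_fun), :V) <: Function      # true
@show isconcretetype(fieldtype(typeof(par_fun), :V))  # true: T1 is the closure's own type
```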

However, if you manually create

```julia
par = ParDemo1{Union{Float64,Function}}(1.0,2.0,3.0,sin,cos)
```

or even

```julia
par = ParDemo1{Function}(1.0,2.0,3.0,sin,cos)
```

you'll get type instability on par.V because you have explicitly forced this instance to have a non-concrete V field.
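The same one-parameter struct with the parameter forced to a non-concrete type, again using Base.@kwdef in place of @with_kw; getV is a hypothetical helper whose only purpose is to ask the compiler what it infers for par.V:

```julia
ℿdemo(h, par, t) = par.A ./ h .^ 5 .+ par.W ./ (par.D .- h) .^ 2

# The question's struct, with Base.@kwdef standing in for Parameters.@with_kw
Base.@kwdef struct ParDemo1{T1<:Union{Float64,Function}}
    D::Float64 = 5.0
    A::Float64 = 0.0
    W::Float64 = 1.0
    V::T1 = 1.0
    ℿ::Function = ℿdemo
end

# Explicitly request non-concrete parameters (positional inner constructor)
par_u = ParDemo1{Union{Float64,Function}}(1.0, 2.0, 3.0, sin, cos)
par_f = ParDemo1{Function}(1.0, 2.0, 3.0, sin, cos)

getV(par) = par.V    # hypothetical helper, only used to inspect inference

@show isconcretetype(fieldtype(typeof(par_u), :V))   # false
@show Base.return_types(getV, (typeof(par_u),))      # Union{Float64, Function}
@show Base.return_types(getV, (typeof(par_f),))      # Function
```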

The fix offered in the post above was to make the type of ℿ parametric, rather than the abstract Function used in the original post.


The problem with your initial code is that Function is an abstract type. Type inference fails because the field could hold any function, so the compiler cannot know which method will be called or what it returns. The parameter F gives the field a specific concrete type, which is fixed when the struct is constructed.

Here is a simple example showing that the type of a function is a subtype of Function.

```julia
f(x) = x

typeof(f) <: Function   # true
```
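Going one step further: every function has its own concrete type, which is why a type parameter F can capture it exactly while the abstract Function cannot. A quick check, where g is an extra function added only for comparison:

```julia
f(x) = x
g(x) = 2x    # a second function, for comparison

@show typeof(f) <: Function        # true
@show typeof(f) === typeof(g)      # false: each function has its own type
@show isconcretetype(typeof(f))    # true
@show isabstracttype(Function)     # true
```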

thinking out loud:

shouldn’t it be possible to have Function{Any, out_type} so that the resultant callable is type stable?

I feel like this is conceptually identical to the function barrier trick.
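For comparison, a minimal sketch of the function barrier trick; AbstractPar, kernel! and outer! are illustrative names. The outer function reads the abstractly typed field once, and the inner function is compiled for the concrete type of the function it receives, so only the single call across the barrier is dispatched at run time:

```julia
struct AbstractPar
    ℿ::Function      # abstractly typed field, as in the original ParDemo1
end

# Kernel: compiled separately for each concrete type of ℿ it is called with
function kernel!(dh, h, ℿ)
    for i in eachindex(dh, h)
        dh[i] = h[i]^3 * ℿ(h[i])
    end
    return dh
end

# Barrier: the concrete type of par.ℿ is unknown here, so this one call is
# dispatched at run time; everything inside kernel! is then type stable
outer!(dh, h, par::AbstractPar) = kernel!(dh, h, par.ℿ)

par_b = AbstractPar(x -> 1 / (5 - x)^2)
h = ones(4)
dh = similar(h)
outer!(dh, h, par_b)
```

The difference from a parametric field is where the concrete type gets recovered: at each barrier call instead of once, at construction.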

You can certainly annotate the output of the function call if you know its type:

```julia
y = function_of_unknown_type(x)::Float64
```
This will still require dynamic dispatch to call the function, but afterwards it will check that it is Float64 and know that it is Float64 thereafter. If the returned value is not Float64, it will throw an error complaining of the fact.
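A small sketch of the annotation approach applied to a struct with an abstract Function field; Holder, unannotated and annotated are made-up names for illustration:

```julia
struct Holder
    f::Function                       # abstract field type
end

unannotated(hd, x) = hd.f(x)
annotated(hd, x) = hd.f(x)::Float64   # dynamic call, then a checked type assertion

hd = Holder(sin)
@show annotated(hd, 1.0) == sin(1.0)                      # true
@show Base.return_types(unannotated, (Holder, Float64))   # inferred as Any
@show Base.return_types(annotated, (Holder, Float64))     # inferred as Float64

# annotated(Holder(x -> 1), 1.0) would throw a TypeError, since 1 is an Int
```

Code that uses the result of annotated can then be compiled for Float64, even though the call itself is still dynamic.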

But what you're asking about sounds more like what you have in C/C++, where a function pointer still has types associated with it. For that, look into FunctionWrappers.jl. But that package is not well documented or widely used, so you'll likely need to inspect the source code to figure out how to use it. In general, I'd discourage its use, but it's there if you want it.