How to code a type-flexible model that is type stable?

Hi. I want to write a model where I can swap between different number formats (Float32, Float64, …) while keeping the code type stable. I have been looking at ShallowWaters.jl for hints on how best to do this. I would like to store my model variables in a struct, including the type used for the number format. Here is a minimal example that tests a few options:

```julia
using Parameters
using BenchmarkTools

"""Struct containing model variables"""
@with_kw mutable struct States{T<:AbstractFloat}

    ftype=T

    state1::T=123
    state2::T=234

end

"""Option 1: This function uses a flexible number format but is type unstable"""
function model_flexible_type_instable(indata)

    val1 = indata.state1 * indata.state2 * indata.ftype(1.2)
    return val1 * rand(indata.ftype,1)[1]

end

"""Option 2: This function uses a hard-coded number format (not what I want) and is type stable"""
function model_fixed_type_stable(indata)

    val1 = indata.state1 * indata.state2 * Float32(1.2)
    return val1 * rand(Float32,1)[1]

end

"""Option 3: This function takes the number format as input (maybe a bit inconvenient) and is type stable"""
function model_type_as_input_type_stable(::Type{T}, indata) where {T<:Real}

    val1 = indata.state1 * indata.state2 * T(1.2)
    return val1 * rand(T,1)[1]

end

states = States{Float32}()

@btime model_flexible_type_instable(states)
@code_warntype model_flexible_type_instable(states)

@btime model_fixed_type_stable(states)
@code_warntype model_fixed_type_stable(states)

@btime model_type_as_input_type_stable(states.ftype, states)
@code_warntype model_type_as_input_type_stable(states.ftype, states)
```

Option 3 is the only one that is both flexible and type stable: Option 1 is type unstable, and Option 2 has the number format hard-coded. Is there a way to get Option 1 to work? Or is there a better alternative to Option 3? Thanks for any input before I start using this in real code.
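To check type stability in a script rather than by reading @code_warntype output, one option is the @inferred macro from the Test standard library, which throws an error when the inferred return type does not match the actual one. This is a simplified sketch: it drops Parameters.jl and the rand factor, and StatesUntyped, opt1 and opt3 are hypothetical stand-ins for the struct and options above.

```julia
using Test  # standard library; provides @inferred and @test_throws

# Hypothetical stand-in for the States struct above, written without Parameters.jl.
# As in the question, ftype has no type annotation, so the compiler only sees Any.
mutable struct StatesUntyped{T<:AbstractFloat}
    ftype
    state1::T
    state2::T
end
StatesUntyped{T}() where {T<:AbstractFloat} = StatesUntyped{T}(T, 123, 234)

# Options 1 and 3, simplified (the rand factor is dropped)
opt1(indata) = indata.state1 * indata.state2 * indata.ftype(1.2)
opt3(::Type{T}, indata) where {T<:Real} = indata.state1 * indata.state2 * T(1.2)

s = StatesUntyped{Float32}()

@inferred opt3(Float32, s)                     # passes: inferred as Float32
@test_throws ErrorException @inferred opt1(s)  # inference only gets Any, so @inferred errors
```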

You probably want something like this (untested):

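```julia
# Note: this drops the ftype field; redefining an existing struct needs a fresh Julia session
Base.@kwdef mutable struct States{T<:AbstractFloat}
    state1::T=123
    state2::T=234
end

# option 4: T comes from the argument's type States{T}, so it is known at compile time.
# (Same function name as option 3, but a separate one-argument method.)
function model_type_as_input_type_stable(indata::States{T}) where {T<:Real}
    val1 = indata.state1 * indata.state2 * T(1.2)
    return val1 * rand(T,1)[1]
end
```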
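A usage sketch for option 4, assuming a fresh Julia session (the struct is defined without ftype). The same method is called with three number formats, and @inferred from the Test standard library checks that each call is type stable. This sketch also swaps rand(T,1)[1] for rand(T), which draws a single number without allocating a one-element array.

```julia
using Test  # standard library; provides @inferred

# Option 4 from the reply above, repeated so this snippet runs on its own
Base.@kwdef mutable struct States{T<:AbstractFloat}
    state1::T = 123
    state2::T = 234
end

function model_type_as_input_type_stable(indata::States{T}) where {T<:Real}
    val1 = indata.state1 * indata.state2 * T(1.2)
    return val1 * rand(T)  # one random number of type T, no 1-element array
end

# Same model code, three number formats: only the type parameter changes
for T in (Float16, Float32, Float64)
    s = States{T}()
    y = @inferred model_type_as_input_type_stable(s)  # errors if the call is not type stable
    println(T, " => ", typeof(y))
end
```

Switching the number format then only means changing States{Float32}() to States{Float64}() (or any other AbstractFloat) where the model state is created.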

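At first glance, the problem with option 1 is that the type is stored as a value (a field of your struct) instead of being part of the struct's type, which prevents type inference from working properly. Because ftype has no type annotation, all the compiler knows is that indata.ftype is Any (at runtime it is a DataType, but even that would not tell the compiler which one), and that is not enough to handle indata.ftype(1.2) efficiently. In option 4, T is a type parameter of States{T}, so it is known when the method is compiled. Does that make sense?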
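If you would rather keep an ftype field as in option 1, a possible variant (a sketch, not something proposed in this thread) is to annotate the field as Type{T}. The field's declared type then pins down which number format it holds, so indata.ftype(1.2) and rand(indata.ftype) can be inferred. TypedStates and model_flexible are hypothetical names. The field is redundant with the type parameter, so option 4 remains the simpler design.

```julia
using Test  # standard library; provides @inferred

# Hypothetical variant of the question's struct: ftype is annotated as Type{T},
# so its declared type (not just its runtime value) tells the compiler the format.
mutable struct TypedStates{T<:AbstractFloat}
    ftype::Type{T}
    state1::T
    state2::T
end
# Hypothetical convenience constructor with the question's default values
TypedStates{T}() where {T<:AbstractFloat} = TypedStates{T}(T, 123, 234)

# Option 1, unchanged apart from using rand(T) instead of rand(T, 1)[1]
function model_flexible(indata)
    val1 = indata.state1 * indata.state2 * indata.ftype(1.2)
    return val1 * rand(indata.ftype)
end

s32 = TypedStates{Float32}()
y = @inferred model_flexible(s32)  # passes: inferred and actual type are both Float32
```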


That seems to do the trick, and is also faster than option 3. I think this will work with my real code. Many thanks for the help. Awesome.
