[ANN] LightSumTypes.jl v4

Hi all,

I’d like to share with you that version 3 of DynamicSumTypes.jl is now out! The package allows you to enclose multiple types into one with the @sumtype macro. Working with it is almost like working with a Union of types, but without dynamic dispatch, which can greatly improve the performance of some programs that use multiple types. Since it doesn’t require anything too special, an existing package based on multiple types should need only minimal restructuring to use it. The best part is probably that the macro doesn’t manipulate syntax much: it really just wraps already defined types, as in this very simple example:

struct A end
struct B end
@sumtype AT(A,B)

Here, A and B are defined outside of the macro. I haven’t seen another package in this space take this approach. It is powerful because we can then use plain multiple dispatch on the variant of the sum type, with no need for any other special macro.
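To illustrate what the wrapping buys you without depending on the package, here is a hand-written sketch of roughly what `@sumtype Shape(Circle, Square)` amounts to, modeled on the `@macroexpand` output shown later in this thread. The names `Shape`, `Circle`, `Square`, `variant` and `area` are hypothetical stand-ins (the package's own accessor is `DynamicSumTypes.variant`, and the macro generates more methods than this).

```julia
# Variants are ordinary structs, defined outside any macro.
struct Circle
    r::Float64
end

struct Square
    side::Float64
end

# Hypothetical hand-written wrapper: a concrete struct with one Union-typed field.
struct Shape
    variants::Union{Circle, Square}
end

# Unwrap with explicit `isa` branches, so each branch returns a concretely typed value.
@inline function variant(s::Shape)
    v = getfield(s, :variants)
    if v isa Circle
        return v
    elseif v isa Square
        return v
    end
    error("unreachable")
end

# Plain multiple dispatch on the variants, no special macro needed.
area(c::Circle) = pi * c.r^2
area(q::Square) = q.side^2

# Vector{Shape} has a concrete element type; only `variant` returns a small Union.
total_area(shapes::Vector{Shape}) = sum(area(variant(s)) for s in shapes)

shapes = [Shape(Circle(1.0)), Shape(Square(2.0))]
total_area(shapes)
```

Because `Vector{Shape}` has a concrete element type, the only Union the compiler has to deal with is the small one returned by `variant`, which it can split into direct calls.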

Please see the ReadMe for a walkthrough on how it works and some micro and macro benchmarks of the approach.

I already made a previous announcement post for this package, but this is a new and much better methodology. While the previous implementation was a bit too hackish, the new one is very simple: just 50 lines of code for the core of it (removing 1000 lines of handcrafted code gave me some mixed emotions, though :melting_face:). The interface is also very simple, mostly straightforward Julia code, while being very performant: it should compete with SumTypes.jl, Virtual.jl and Unityper.jl, and at least in my tests it was actually faster than them.

Does the performance degrade if I have a large number of subtypes, say 100+?

Let’s find out, @liuyxpp. This is the performance of a toy model with Union and with @sumtype, relative to the same model with only 1 type. While the number of types changes, the number of operations performed is always the same.

As you can see, the optimization somewhat breaks down between 20 and 30 types for some reason.

I partly understand the issue now: the model operates on mutable types, and by profiling I saw that the setproperty! generated by @sumtype becomes much slower from 26 types on. I don’t understand why at the moment. getproperty, instead, is very fast even at 100 types, so the performance drop shouldn’t affect immutable structs.

OK, cracked it: it’s simply solved by inlining the function. Here is the result now:

I will push the changes soon. I think inlining big switches may increase compile time, so in the future an :aggressive option in the macro could enable this inlining, without doing it by default.
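A minimal sketch of the fix described above, assuming a wrapper over mutable variants as in the benchmark model: the forwarding `setproperty!` (and `getproperty`) are marked `@inline`, so the `isa` chain is inlined at each call site. `Trader`, `Buyer` and `Seller` are hypothetical, and `getfield` stands in for the package's internal unwrapping; the opt-in `:aggressive` option mentioned above is only an idea and is not shown here.

```julia
mutable struct Buyer
    money::Int
end

mutable struct Seller
    money::Int
end

# Immutable wrapper around one of the mutable variants (hypothetical, hand-written).
struct Trader
    variants::Union{Buyer, Seller}
end

# Reads: `getfield` accesses the raw field, bypassing the overload defined here.
@inline function Base.getproperty(t::Trader, s::Symbol)
    v = getfield(t, :variants)
    if v isa Buyer
        return getproperty(v, s)
    elseif v isa Seller
        return getproperty(v, s)
    end
    error("unreachable")
end

# Writes: `@inline` lets every call site see the isa branches directly.
@inline function Base.setproperty!(t::Trader, s::Symbol, x)
    v = getfield(t, :variants)
    if v isa Buyer
        return setproperty!(v, s, x)
    elseif v isa Seller
        return setproperty!(v, s, x)
    end
    error("unreachable")
end

# Same shape as the `exchange!` function in the benchmark code of this thread.
function exchange!(a::Trader, b::Trader)
    v1 = a.money
    v2 = b.money
    a.money = v2
    b.money = v1
    return nothing
end

t1, t2 = Trader(Buyer(10)), Trader(Seller(20))
exchange!(t1, t2)
```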

Wow! Appears quite general, can this in principle replace the Union handling in Julia itself?

Yes, actually I think so! But maybe this would stress the compiler too much in very complicated programs with fundamental types. I’m not sure; I’m just guessing why it hasn’t been done already (and even if that is the case, offering it as an option seems legit?)

Notice that I think Julia already does something like this in some cases, but unreliably: I have some very simple programs where a Union almost matches @sumtype performance, and others, like the one I’m showing, where it does not.

Unfortunately, I’m not enough of a compiler expert to understand how to intervene there, even if I would like to :laughing:

Wow, I should immediately adopt DynamicSumTypes.jl in one of my projects. It is amazing!

The drawback compared to “real” tagged unions supported by the compiler would still be that the union field needs a boxed storage, right? I’m not sure how well Julia optimizes in the face of such union fields.

No, Julia does not require boxed storage if you have a finite number of concrete types in the sum type.

Here’s an example with SumTypes.jl, but I assume the same is true of DynamicSumTypes.jl as well:

julia> using SumTypes, BenchmarkTools

julia> @sum_type Foo begin
           A(::Int)
           B(::Float64)
           C(::ComplexF64, ::Int8)
           D(::Char)
           E(::Int128)
           F(::Float32)
           G(::Char, ::Char, ::Char)
           H(::Tuple{UInt8, UInt8})
       end

julia> @btime let i = rand(1:8)
           if i == 1
               A(i)
           elseif i == 2
               B(i/2)
           elseif i == 3
               C(i/2 + im, Int8(i))
           elseif i == 4
               D(Char(i+1))
           elseif i == 5
               E(Int128(i^5))
           elseif i == 6
               F(Float32(i*2/3))
           elseif i ==7
               G(Char(i), Char(i+1), Char(i-1))
           elseif i == 8
               H((i%UInt8, i%UInt8))
           end
       end
  10.169 ns (0 allocations: 0 bytes)
B(1.0)::Foo

just to confirm, on my PC:

Benchmark
julia> using BenchmarkTools, DynamicSumTypes

julia> struct A
           x::Int
       end

julia> struct B
           x::Float64
       end

julia> struct C
           x::ComplexF64
           y::Int8
       end

julia> struct D
           x::Char
       end

julia> struct E
           x::Int128
       end

julia> struct F
           x::Float32
       end

julia> struct G
           x::Char
           y::Char
           z::Char
       end

julia> struct H
           x::Tuple{UInt8, UInt8}
       end

julia> @sumtype S(A,B,C,D,E,F,G,H)
S

julia> @btime let i = rand(1:8)
                  if i == 1
                      S(A(i))
                  elseif i == 2
                      S(B(i/2))
                  elseif i == 3
                      S(C(i/2 + im, Int8(i)))
                  elseif i == 4
                      S(D(Char(i+1)))
                  elseif i == 5
                      S(E(Int128(i^5)))
                  elseif i == 6
                      S(F(Float32(i*2/3)))
                  elseif i ==7
                      S(G(Char(i), Char(i+1), Char(i-1)))
                  elseif i == 8
                      S(H((i%UInt8, i%UInt8)))
                  end
              end

11.715 ns (0 allocations: 0 bytes)

Maybe interestingly, this less verbose version also works fine:

Benchmark 2
@btime rand(
           S.((
            A(i), 
            B(i/2), 
            C(i/2 + im, Int8(i)), 
            D(Char(i+1)), 
            E(Int128(i^5)), 
            F(Float32(i*2/3)),
            G(Char(i), Char(i+1), Char(i-1)), 
            H((i%UInt8, i%UInt8))
           ))
       ) setup = (i = rand(1:8))

11.794 ns (0 allocations: 0 bytes)

Let me know about your experience if you can. I think this will show whether the approach is really robust: I would bet it is, because it passed all of my tests, but external verification would probably be even better :slight_smile:

Does this zero allocation behavior also apply to storage in vectors? So can those unions be stored contiguously in memory?

Yep
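As far as I understand, this works because when all variants are immutable `isbits` structs, Julia can store a small Union of them inline (the payload plus a type tag) rather than as pointers to boxed objects. A quick check under that assumption, with hypothetical types `P` and `Q`; `Base.isbitsunion` is an internal, unexported helper that may change between Julia versions. Mutable variants, like the agents in the benchmark model, are still heap-allocated objects referenced by pointer.

```julia
# Two immutable, plain-data variants (hypothetical).
struct P
    x::Int
end

struct Q
    y::Float64
end

# Internal helper: true if values of this Union can be stored inline, e.g. in arrays.
inline_union = Base.isbitsunion(Union{P, Q})

# The element type matches the Union-typed field a sum type wrapper would hold.
v = Union{P, Q}[P(1), Q(2.5), P(3)]

value(p::P) = Float64(p.x)
value(q::Q) = q.y

total(xs) = sum(value, xs)

total(v)
```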

It was cool to implement this package, and it works quite well on Julia <= 1.10, where it really patches Union-splitting.

However, I’m seeing an incredible improvement in Union-splitting in 1.11rc1 for some reason: it seems to almost always work automatically, which is really cool. For example, all the benchmarks related to it in Union splitting vs C++ are okay now, with no dynamic dispatch anywhere.

Through extensive testing I found some generators with dynamic dispatch, but nothing more than that. Some insights from this package could perhaps lead to further improvements, e.g. in some cases the storage required is much less than what Julia currently achieves with multiple types, but Union-splitting works quite well now! So if you can update to Julia 1.11, I would first just try to work with a Union, even though I think this approach offers some more guarantees :slight_smile:

The battle is indeed much harsher there:

Code
# The following simple model has a variable number of agent types,
# but there is no removing or creating of additional agents.
# It creates a model that has the same number of agents and does
# overall the same number of operations, but these operations
# are split in a varying number of agents. It shows how much of a
# performance hit is to have many different agent types.

using Agents, DynamicSumTypes, Random, BenchmarkTools

tn = 100
agent_types_s = [Symbol(:Agent, i) for i in 1:tn]


for (i, t) in enumerate(agent_types_s)
    eval(:(@agent struct $t(GridAgent{2})
               money::Any
           end))
end

for i in 1:tn
    eval(:(@sumtype $(Symbol(:AgentAll, i))($(agent_types_s[1:i]...)) <: AbstractAgent))
end

const agent_types = Tuple(eval.(agent_types_s))
const agent_all_t = NamedTuple(Symbol(:AgentAll, i) => eval(Symbol(:AgentAll, i)) for i in 1:tn)

function initialize_model_1(;n_agents=600,dims=(5,5))
    space = GridSpace(dims)
    model = StandardABM(Agent1, space; agent_step!,
                        scheduler=Schedulers.Randomly(),
                        rng = Xoshiro(42), warn=false)
    for id in 1:n_agents
        add_agent!(Agent1, model, 10)
    end
    return model
end

function initialize_model_sum(;n_agents=600, n_types=1, dims=(5,5))
    agents_used = agent_types[1:n_types]
    agent_all = agent_all_t[Symbol(:AgentAll, n_types)]
    space = GridSpace(dims)
    model = StandardABM(agent_all, space; agent_step!,
                        scheduler=Schedulers.Randomly(), warn=false,
                        rng = Xoshiro(42))
    agents_per_type = div(n_agents, n_types)
    for A in agents_used
        add_agents!(A, model, agents_per_type, agent_all)
    end
    return model
end

function initialize_model_n(;n_agents=600, n_types=1, dims=(5,5))
    agents_used = agent_types[1:n_types]
    space = GridSpace(dims)
    model = StandardABM(Union{agents_used...}, space; agent_step!,
                        scheduler=Schedulers.Randomly(), warn=false,
                        rng = Xoshiro(42))
    agents_per_type = div(n_agents, n_types)
    for A in agents_used
        add_agents!(A, model, agents_per_type)
    end
    return model
end

function add_agents!(A, model, n)
    for _ in 1:n
        a = A(model, random_position(model), 10 + rand(abmrng(model), 1:10))
        add_agent!(a, model)
    end
    return nothing
end
function add_agents!(A, model, n, W)
    for _ in 1:n
        a = W(A(model, random_position(model), 10 + rand(abmrng(model), 1:10)))
        add_agent!(a, model)
    end
    return nothing
end

function agent_step!(agent, model)
    move!(agent, model)
    agents = agents_in_position(agent.pos, model)
    for a in agents
        exchange!(agent, a)
    end
    return nothing
end

function move!(agent, model)
    cell = random_nearby_position(agent, model)
    move_agent!(agent, cell, model)
    return nothing
end

function exchange!(agent, other_agent)
    v1 = agent.money
    v2 = other_agent.money
    agent.money = v2
    other_agent.money = v1
    return nothing
end

function run_simulation_1(n_steps)
    model = initialize_model_1()
    Agents.step!(model, n_steps)
end

function run_simulation_sum(n_steps; n_types)
    model = initialize_model_sum(; n_types=n_types)
    Agents.step!(model, n_steps)
end

function run_simulation_n(n_steps; n_types)
    model = initialize_model_n(; n_types=n_types)
    Agents.step!(model, n_steps)
end

# %% Run the simulation, do performance estimate, first with 1, then with many
n_steps = 50
n_types = [2,3,4,5,10,20,30,40,50,60,70,80,90,100]

time_1 = @belapsed run_simulation_1($n_steps)
times_n = Float64[]
times_multi_s = Float64[]
for n in n_types
    println(n)
    t = @belapsed run_simulation_n($n_steps; n_types=$n)
    push!(times_n, t/time_1)
    t_sum = @belapsed run_simulation_sum($n_steps; n_types=$n)
    println(t/time_1, " ", t_sum/time_1)
    push!(times_multi_s, t_sum/time_1)
end

println("relative time of model with 1 type: 1.0")
for (n, t1, t2) in zip(n_types, times_n, times_multi_s)
    println("relative time of model with $n types: $t1")
    println("relative time of model with $n @sumtype: $t2")
end

using CairoMakie
fig, ax = CairoMakie.scatterlines(n_types, times_n; label = "Union");
scatterlines!(ax, n_types, times_multi_s; label = "@sumtype")
ax.xlabel = "# types"
ax.ylabel = "time relative to 1 type"
ax.title = "Union types vs @sumtype"
axislegend(ax; position = :lt)
ax.yticks = 0:1:ceil(Int, maximum(times_n))
ax.xticks = n_types
fig
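Following the advice above to try a plain Union first on Julia 1.11, one cheap sanity check is to ask inference what a loop over a `Vector{Union{...}}` returns. A concrete result type is a good sign (though not a proof) that the calls inside were union-split rather than left to dynamic dispatch; `@code_warntype` and benchmarks like the one above remain the real test. `Cat`, `Dog`, `score` and `total_score` are hypothetical.

```julia
struct Cat
    lives::Int
end

struct Dog
    weight::Float64
end

score(c::Cat) = Float64(c.lives)
score(d::Dog) = d.weight / 2

function total_score(animals::Vector{Union{Cat, Dog}})
    s = 0.0
    for a in animals
        s += score(a)   # `a` is Union{Cat, Dog}; the compiler can split this call
    end
    return s
end

animals = Union{Cat, Dog}[Cat(9), Dog(30.0)]

# Reflection: the inferred return type(s) for this argument signature.
rt = Base.return_types(total_score, (Vector{Union{Cat, Dog}},))
```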

Just to tell you something maybe expected but interesting anyway: @sumtype is much faster to compile, even in 1.11 (this uses the code I posted in my last comment):

julia> # this uses a `Union`
       @time @eval run_simulation_n(50; n_types=100);
 56.261693 seconds (50.98 M allocations: 2.121 GiB, 0.80% gc time, 99.88% compilation time)

julia> @time @eval run_simulation_n(50; n_types=100);
  0.045467 seconds (31.43 k allocations: 1.767 MiB)

julia> # this uses `@sumtype`
       @time @eval run_simulation_sum(50; n_types=100);
  6.715498 seconds (24.48 M allocations: 2.250 GiB, 3.09% gc time, 99.17% compilation time)

julia> @time @eval run_simulation_sum(50; n_types=100);
  0.052121 seconds (30.84 k allocations: 1.671 MiB)

I think this is a direct consequence of the “wrapping” technique used in @sumtype.

Hi Tortar,

Can you briefly explain how @sumtype is implemented under the hood? We have an example of an agent system in our class at the Czech Technical University, and we cannot get away from dynamic dispatch, because the agents are of different types. We tried to implement the solution with sum types, but it becomes too messy. Your package might be a nice solution, but I would like to understand how it works under the hood, i.e. which mechanism it implements.

Thanks a lot for the answer,

Tomas

Hi Tomas, do you have your model available publicly somewhere? I would like to take a look. Anyway, the mechanism is simple; you can macro-expand it to see its inner workings:

julia> using Agents, DynamicSumTypes, Random, BenchmarkTools

julia> struct A end

julia> struct B end

julia> @macroexpand @sumtype AT(A,B)
quote
    #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:41 =#
    struct AT <: Any
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:42 =#
        variants::Union{A, B}
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:43 =#
        AT(v) = begin
                #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:43 =#
                if v isa A
                    return new(v)
                end
                elseif v isa B
                    return new(v)
                end
                error("THIS_SHOULD_BE_UNREACHABLE")
            end
    end
    #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:45 =#
    function DynamicSumTypes.variant(sumt::AT)
        $(Expr(:meta, :inline))
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:45 =#
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:46 =#
        v = DynamicSumTypes.unwrap(sumt)
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:47 =#
        if v isa A
            return v
        end
        elseif v isa B
            return v
        end
        error("THIS_SHOULD_BE_UNREACHABLE")
    end
    #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:49 =#
    function Base.getproperty(sumt::AT, s::Symbol)
        $(Expr(:meta, :inline))
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:49 =#
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:50 =#
        v = DynamicSumTypes.unwrap(sumt)
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:51 =#
        if v isa A
            return Base.getproperty(v, s)
        end
        elseif v isa B
            return Base.getproperty(v, s)
        end
        error("THIS_SHOULD_BE_UNREACHABLE")
    end
    #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:53 =#
    function Base.setproperty!(sumt::AT, s::Symbol, value)
        $(Expr(:meta, :inline))
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:53 =#
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:54 =#
        v = DynamicSumTypes.unwrap(sumt)
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:55 =#
        if v isa A
            return Base.setproperty!(v, s, value)
        end
        elseif v isa B
            return Base.setproperty!(v, s, value)
        end
        error("THIS_SHOULD_BE_UNREACHABLE")
    end
    #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:57 =#
    function Base.propertynames(sumt::AT)
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:57 =#
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:58 =#
        v = DynamicSumTypes.unwrap(sumt)
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:59 =#
        if v isa A
            return Base.propertynames(v)
        end
        elseif v isa B
            return Base.propertynames(v)
        end
        error("THIS_SHOULD_BE_UNREACHABLE")
    end
    #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:61 =#
    function Base.hasproperty(sumt::AT, s::Symbol)
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:61 =#
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:62 =#
        v = DynamicSumTypes.unwrap(sumt)
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:63 =#
        if v isa A
            return Base.hasproperty(v, s)
        end
        elseif v isa B
            return Base.hasproperty(v, s)
        end
        error("THIS_SHOULD_BE_UNREACHABLE")
    end
    #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:65 =#
    function Base.copy(sumt::AT)
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:65 =#
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:66 =#
        v = DynamicSumTypes.unwrap(sumt)
        #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:67 =#
        if v isa A
            return AT(Base.copy(v))
        end
        elseif v isa B
            return AT(Base.copy(v))
        end
        error("THIS_SHOULD_BE_UNREACHABLE")
    end
    #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:69 =#
    DynamicSumTypes.variantof(sumt) = begin
            #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:69 =#
            typeof(variant(sumt))
        end
    #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:70 =#
    DynamicSumTypes.allvariants(sumt::Type{AT}) = begin
            #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:70 =#
            tuple(A, B)
        end
    #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:71 =#
    DynamicSumTypes.is_sumtype(sumt::Type{AT}) = begin
            #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:71 =#
            true
        end
    #= /home/bob/.julia/dev/DynamicSumTypes/src/DynamicSumTypes.jl:72 =#
    AT
end

In practice, I think what I did “forces” the compiler to apply Union-splitting by guarding every access to the types with if-else branches.
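To make the mechanism concrete, here is a much-simplified macro sketch that generates the same kind of `isa`-guarded code as the expansion above, but only the wrapper struct and a `variant` accessor. `@wrapsum`, `isa_chain`, `Color`, `Red` and `Green` are hypothetical names; the real `@sumtype` also handles supertypes, property forwarding, `copy` and more.

```julia
# Build `if v isa T1 ... elseif v isa T2 ... else error(...) end` as nested :if nodes.
function isa_chain(v::Symbol, types, branch)
    ex = :(error("unreachable"))
    for T in reverse(collect(types))
        ex = Expr(:if, :($v isa $T), branch(T), ex)
    end
    return ex
end

# Usage: `@wrapsum Name(T1, T2, ...)` defines `struct Name` and a `variant(::Name)` method.
macro wrapsum(call)
    name = call.args[1]
    types = call.args[2:end]
    body = isa_chain(:v, types, T -> :(return v))
    return esc(quote
        struct $name
            variants::Union{$(types...)}
        end
        @inline function variant(s::$name)
            v = getfield(s, :variants)
            $body
        end
    end)
end

struct Red end
struct Green
    level::Int
end

@wrapsum Color(Red, Green)

# Ordinary dispatch on whatever `variant` returns.
describe(::Red) = "red"
describe(g::Green) = "green $(g.level)"

describe(variant(Color(Green(3))))
```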

Thanks a lot, I understand now.

The code written and maintained by my student is here

and the description can be found in lab notes here

The problem is that our World contains a Dict{Int,A}, where A is an agent type, and the agent is dynamically resolved when needed. Unfortunately, I have not written the code myself, but Niklas @nmheim would be able to answer any questions.

Thanks for your interest. It would be nice to show students a better solution than what we have.
Tomas

I’ve taken a look at your code. I think the problem is that your Dict, as implemented, is a so-called open-world scenario: it has the abstract type Agent as its value type, so you will end up with dynamic dispatch in any case, because an Agent can be anything. As a first step, you can make it closed-world with something like

function World(agents::Vector{<:Agent})
    max_id = maximum(a.id for a in agents)
    types = Set(typeof(a) for a in agents)
    agent_type = Union{types...}
    World(Dict{Int, agent_type}(a.id=>a for a in agents), max_id)
end

instead of

function World(agents::Vector{<:Agent})
    max_id = maximum(a.id for a in agents)
    World(Dict(a.id=>a for a in agents), max_id)
end

Here, I used a Union but the same can be done with @sumtype. Hope this is helpful :slight_smile:
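A self-contained sketch of the closed-world constructor above. It assumes that `World` is parametric in its agent type (the struct definition is not shown in this thread, so `World{A}` here is an assumption); with a field typed `Dict{Int, Agent}`, the Dict would simply be converted back to the abstract value type. `Sheep`, `Wolf` and `energy` are hypothetical.

```julia
abstract type Agent end

struct Sheep <: Agent
    id::Int
    energy::Float64
end

struct Wolf <: Agent
    id::Int
    energy::Float64
end

# Assumed definition: the Dict's value type is a type parameter, so a Union can be kept.
mutable struct World{A<:Agent}
    agents::Dict{Int, A}
    max_id::Int
end

# Closed-world constructor: the value type is the Union of the concrete types present.
function World(agents::Vector{<:Agent})
    max_id = maximum(a.id for a in agents)
    types = Set(typeof(a) for a in agents)
    agent_type = Union{types...}
    World(Dict{Int, agent_type}(a.id => a for a in agents), max_id)
end

energy(a::Sheep) = a.energy
energy(a::Wolf) = 2 * a.energy

# Iterating values of a Dict with a small Union value type lets calls be union-split.
total_energy(w::World) = sum(energy, values(w.agents))

w = World(Agent[Sheep(1, 1.0), Wolf(2, 3.0), Sheep(3, 0.5)])
```

With @sumtype, the same constructor would use the wrapper type as the Dict's value type instead of the Union.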
