I recently started tinkering with SumTypes.jl again, which led to a flurry of refactors and overhauls, and I think the package is now getting into some pretty interesting territory.
I see this package as a proof of concept for how an implementation of "RFC: native algebra data type/tagged union/sum type/enum support" (JuliaLang/julia, Discussion #48883 on GitHub) should work.
SumTypes.jl now pretty much fully implements the features of Rust’s enums; the main thing still lacking is Rust’s full pattern matching. I have a @cases macro in SumTypes.jl for efficient destructuring and matching of sum types, but it’s not a full pattern matching system. MLStyle.jl can be used for pattern matching on sum types with some work, but I’m considering writing a new pattern matching library that leverages SumTypes.jl sometime in the vague future.
For people who have looked at SumTypes.jl before, five things have changed since the early versions that were last discussed on Discourse:
1. Singleton variants of a sum type don’t need parentheses.
2. Sum types can recursively store themselves.
Together with 1., this means that you can write a simple, type stable linked list as
using SumTypes
@sum_type List{T} begin
    Nil
    Cons{T}(::T, ::List{T})
end
Cons(x::A, y::List{Uninit}) where {A} = Cons(x, List{A}(y))
List(first, rest...) = Cons(first, List(rest...))
List() = Nil
julia> List(1,2,3,4)
Cons(1, Cons(2, Cons(3, Cons(4, Nil::List{Int64})::List{Int64})::List{Int64})::List{Int64})::List{Int64}
3. A smarter destructuring system.
Back to our linked list, this is how you’d find the length of that list with sum types:
function Base.length(l::List)
    @cases l begin
        Nil => 0
        Cons(_, l) => 1 + length(l)
    end
end
This definition expands to something roughly like:
# Pseudo-code
function Base.length(l::List)
    let data = l
        # This would error at compile time if we didn't cover every variant of the sum type.
        throw_error_if_not_exhaustive(typeof(l), (:Nil, :Cons))
        if get_tag(l) == tag_of(typeof(l), :Nil)
            0
        elseif get_tag(l) == tag_of(typeof(l), :Cons)
            _, l = super_special_reinterpret(Cons, l)::Cons
            1 + length(l)
        else
            error("something went wrong")
        end
    end
end
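To make the pseudo-code above concrete, here is a minimal hand-rolled version of the same branch-on-tag-then-destructure pattern in plain Julia, without SumTypes.jl. The names `Shape`, `Circle`, `Rect`, and `area` and the tag encoding are hypothetical and only for illustration; this is a sketch of the general idea, not the code that @cases actually generates.

```julia
# Hypothetical hand-rolled tagged union: one concrete struct holding a tag
# byte plus enough payload fields for the largest variant.
struct Shape
    tag::UInt8   # 0x00 = Circle, 0x01 = Rect (assumed encoding)
    a::Float64
    b::Float64
end

# "Variant constructors" that fill in the tag and payload.
Circle(r) = Shape(0x00, r, 0.0)
Rect(w, h) = Shape(0x01, w, h)

# Branch on the tag, then destructure the payload for that variant.
function area(s::Shape)
    if s.tag == 0x00
        r = s.a
        return pi * r^2
    elseif s.tag == 0x01
        w, h = s.a, s.b
        return w * h
    else
        error("unknown tag $(s.tag)")
    end
end

area(Rect(2.0, 3.0))  # 6.0
```

Every value here has the same concrete isbits type, so the branch is an ordinary integer comparison rather than a dynamic type check.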
4. You can now hide the variants of a sum type so that they don’t clutter your namespace.
That looks like this:
@sum_type Fruit :hidden begin
    apple
    banana
    orange
end
@sum_type Colour :hidden begin
    orange # won't conflict with the variant from Fruit!
    blue
    green
end;
julia> Fruit'.orange
orange::Fruit
julia> Colour'.orange
orange::Colour
julia> let (; orange) = Fruit'
           orange
       end
orange::Fruit
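One way to get this kind of `Type'.variant` syntax in plain Julia is to overload `adjoint` (the function behind the postfix `'` operator) on the type itself and return a namespace-like object. The sketch below does that with a NamedTuple. `MyFruit` and `MyColour` are hypothetical stand-ins, and this is not necessarily how SumTypes.jl implements :hidden.

```julia
# Hypothetical stand-ins for two enum-like types with a clashing variant name.
struct MyFruit
    tag::UInt8
end
struct MyColour
    tag::UInt8
end

# `T'` calls `adjoint(T)`; return a NamedTuple that acts as a namespace.
Base.adjoint(::Type{MyFruit}) =
    (apple = MyFruit(0x00), banana = MyFruit(0x01), orange = MyFruit(0x02))
Base.adjoint(::Type{MyColour}) =
    (orange = MyColour(0x00), blue = MyColour(0x01), green = MyColour(0x02))

MyFruit'.orange    # a MyFruit value
MyColour'.orange   # a MyColour value, with no name clash

# Property destructuring (Julia 1.7+) pulls a variant into local scope.
let (; orange) = MyFruit'
    orange
end
```

Nothing named `orange` is defined at module level here, which is the point: the variants are only reachable through their parent type.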
5. The memory layout of sum types is now optimized so that all the variants share the same compact storage.
We do this in a way that works safely with non-isbits types, and it even works with parametrically typed storage. A sum type’s memory footprint is now the size of the biggest variant, plus the size of a discriminator byte (or bytes if you have more than 255 variants), plus possibly some padding for alignment. Example:
@sum_type Either{A, B} begin
    Left{A}(::A)
    Right{B}(::B)
end
julia> sizeof(Either{Bool, Nothing}'.Left(true))
2
julia> sizeof(Either{Int, Int}'.Left(1))
16
julia> sizeof(Either{Int128, Int}'.Left(1))
24
julia> sizeof(Either{Int128, Tuple{Int, Int}}'.Left(1))
24
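The sizes above are consistent with a simple back-of-the-envelope model: take the largest variant, add the tag, and round up to the alignment. The helper names below (`round_up`, `estimated_footprint`) are hypothetical, and the alignment values are assumptions chosen to match the REPL output above on a 64-bit machine, not something SumTypes.jl exposes.

```julia
# Round `n` up to the next multiple of `align`.
round_up(n, align) = cld(n, align) * align

# Rough model: largest variant + tag bytes, padded to `align` bytes.
# `align` is an assumption; the real value depends on field types and platform.
function estimated_footprint(variant_sizes; tag_bytes = 1, align = 8)
    return round_up(maximum(variant_sizes) + tag_bytes, align)
end

estimated_footprint((1, 0); align = 1)  # Either{Bool, Nothing}: 1 + 1 = 2
estimated_footprint((8, 8))             # Either{Int, Int}: 9 padded to 16
estimated_footprint((16, 8))            # Either{Int128, Int}: 17 padded to 24
estimated_footprint((16, 16))           # Either{Int128, Tuple{Int, Int}}: 24
```

The tuples passed in are the `sizeof` values of each variant's payload, so the last two cases come out equal because only the largest variant matters.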
Why care about any of this?
Well, if you like performance, here’s a little benchmark showing that this approach can be dramatically faster than manual union splitting over an abstract type (roughly 5x at the median in the runs below, 54 μs versus 279 μs, with zero allocations):
module AbstractTypeTest
using BenchmarkTools
abstract type AT end
Base.@kwdef struct A <: AT
    common_field::Int = 0
    a::Bool = true
    b::Int = 10
end
Base.@kwdef struct B <: AT
    common_field::Int = 0
    a::Int = 1
    b::Float64 = 1.0
    d::Complex = 1 + 1.0im # not isbits
end
Base.@kwdef struct C <: AT
    common_field::Int = 0
    b::Float64 = 2.0
    d::Bool = false
    e::Float64 = 3.0
    k::Complex{Real} = 1 + 2im # not isbits
end
Base.@kwdef struct D <: AT
    common_field::Int = 0
    b::Any = :hi # not isbits
end
foo!(xs) = for i in eachindex(xs)
    @inbounds x = xs[i]
    @inbounds xs[i] = x isa A ? B() :
                      x isa B ? C() :
                      x isa C ? D() :
                      x isa D ? A() : error()
end
xs = rand((A(), B(), C(), D()), 10000);
display(@benchmark foo!($xs);)
end
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 267.399 μs … 3.118 ms ┊ GC (min … max): 0.00% … 90.36%
Time (median): 278.904 μs ┊ GC (median): 0.00%
Time (mean ± σ): 316.971 μs ± 306.290 μs ┊ GC (mean ± σ): 11.68% ± 10.74%
█ ▁
█▆▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▇ █
267 μs Histogram: log(frequency) by time 2.77 ms <
Memory estimate: 654.75 KiB, allocs estimate: 21952.
module SumTypeTest
using SumTypes, BenchmarkTools
@sum_type AT begin
    A(common_field::Int, a::Bool, b::Int)
    B(common_field::Int, a::Int, b::Float64, d::Complex)
    C(common_field::Int, b::Float64, d::Bool, e::Float64, k::Complex{Real})
    D(common_field::Int, b::Any)
end
A(;common_field=1, a=true, b=10) = A(common_field, a, b)
B(;common_field=1, a=1, b=1.0, d=1 + 1.0im) = B(common_field, a, b, d)
C(;common_field=1, b=2.0, d=false, e=3.0, k=Complex{Real}(1 + 2im)) = C(common_field, b, d, e, k)
D(;common_field=1, b=:hi) = D(common_field, b)
foo!(xs) = for i in eachindex(xs)
    xs[i] = @cases xs[i] begin
        A => B()
        B => C()
        C => D()
        D => A()
    end
end
xs = rand((A(), B(), C(), D()), 10000);
display(@benchmark foo!($xs);)
end
BenchmarkTools.Trial: 10000 samples with 1 evaluation.
Range (min … max): 53.120 μs … 64.690 μs ┊ GC (min … max): 0.00% … 0.00%
Time (median): 54.070 μs ┊ GC (median): 0.00%
Time (mean ± σ): 54.093 μs ± 425.595 ns ┊ GC (mean ± σ): 0.00% ± 0.00%
▁ ▂▂▅▇▆█▅▆▃▃
▁▁▁▁▁▂▂▃▄▅▇▅▇▆█▇██████████▇▇▅▅▅▃▃▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
53.1 μs Histogram: frequency by time 55.8 μs <
Memory estimate: 0 bytes, allocs estimate: 0.
Situations where you need to build compact, type stable representations of heterogeneous data, like the benchmark above, are pretty common outside the world of numerical code. For instance, this caused a big bottleneck in SymbolicUtils.jl, which is why @YingboMa et al. made Unityper.jl. SumTypes.jl is similar to Unityper.jl but more flexible in the sorts of data it can enclose: it supports parametric types, does not require default values for fields, and can store non-primitive isbits types like Tuple.
I believe @c42f was also looking into using a data structure like this to store parsed code in JuliaSyntax.jl, but I’m not sure what he ended up doing.
How can you help?
The most helpful thing would be to try the package out and report anything that breaks or doesn’t feel ergonomic.
I also really need some better documentation, so even just requests for clarification on how things work would be appreciated.