[ANN] SumTypes.jl v0.4

I recently started tinkering with SumTypes.jl again, and it led to a flurry of refactors and overhauls; I think the package is now getting into some pretty interesting territory.

I see this package as a proof of concept for how an implementation of RFC: native algebra data type/tagged union/sum type/enum support · JuliaLang/julia · Discussion #48883 · GitHub should work.

SumTypes.jl now pretty much fully implements the features of Rust’s Enums, though matching Rust’s full pattern matching capabilities is the main thing lacking. I have a @cases macro in SumTypes.jl for efficient destructuring and matching of sum types, but it’s not a full pattern matching system. MLStyle.jl can be used for pattern matching on sum types with some work, but I’m considering writing a new pattern matching library leveraging SumTypes.jl sometime in the vague future.


For people who have looked at SumTypes.jl before, five things have changed since the early versions that were last discussed on Discourse:

1. Singleton variants of a sum type don’t need parentheses.

2. Sum types can recursively store themselves.

Together with 1., this means that you can write a simple, type-stable linked list as

@sum_type List{T} begin 
    Nil
    Cons{T}(::T, ::List{T}) 
end
Cons(x::A, y::List{Uninit}) where {A} = Cons(x, List{A}(y))

List(first, rest...) = Cons(first, List(rest...))
List() = Nil
julia> List(1,2,3,4)
Cons(1, Cons(2, Cons(3, Cons(4, Nil::List{Int64})::List{Int64})::List{Int64})::List{Int64})::List{Int64}

3. A smarter destructuring system.

Back to our linked list, this is how you’d find the length of that list with sum types:

function Base.length(l::List) 
    @cases l begin
        Nil => 0
        Cons(_, l) => 1 + length(l)
    end
end

This definition is basically something like:

# Pseudo-code
function Base.length(l::List)
    let data = l
        throw_error_if_not_exhaustive(typeof(l), (:Nil, :Cons)) # this would error at compile time if we didn't cover every variant of the sum type.
        if get_tag(l) == tag_of(typeof(l), :Nil)
            0
        elseif get_tag(l) == tag_of(typeof(l), :Cons)
            _, l = super_special_reinterpret(Cons, l)::Cons
            1 + length(l)
        else
            error("something went wrong")
        end
    end
end
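To get an intuition for why this lowering is type stable, here's a hand-rolled mock of the same idea in plain Julia. Everything here is hypothetical (`HandList` is my stand-in, not SumTypes.jl's actual layout): one concrete struct holds a tag byte plus the payload, so every branch of the length function works on the same concrete type.

```julia
# Hand-rolled sketch of the tag-dispatch idea above, without SumTypes.jl.
# `HandList` is a hypothetical stand-in: one concrete struct, one tag byte.
struct HandList
    tag::UInt8                      # 0x00 = Nil, 0x01 = Cons
    head::Int                       # payload (ignored when tag == 0x00)
    tail::Union{HandList, Nothing}  # rest of the list (nothing for Nil)
end

hand_nil() = HandList(0x00, 0, nothing)
hand_cons(x::Integer, rest::HandList) = HandList(0x01, x, rest)

# The analogue of the `if get_tag(l) == ...` chain in the pseudo-code:
hand_length(l::HandList) = l.tag == 0x00 ? 0 : 1 + hand_length(l.tail)

hand_length(hand_cons(1, hand_cons(2, hand_nil())))  # 2
```

Note that `length` never has to branch on the *type* of `l`, only on its tag field, which is the essence of what makes the real `@cases` expansion friendly to the compiler.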

4. We now allow you to hide the variants of a sum type so that they don’t clutter your namespace.

That looks like this:

@sum_type Fruit :hidden begin
    apple
    banana
    orange
end
@sum_type Colour :hidden begin
    orange # won't conflict with the variant from Fruit!
    blue
    green
end;
julia> Fruit'.orange
orange::Fruit

julia> Colour'.orange
orange::Colour

julia> let (; orange) = Fruit'
           orange
       end
orange::Fruit

5. The memory footprint and layout of sum types are now optimized to compactify all the memory of the variants together.

We do this in a way that works safely with non-isbits types, and it even works with parametrically typed storage. A sumtype’s memory footprint is now going to be the size of the biggest variant, plus the size of a discriminator byte (or bytes if you have more than 255 variants), plus maybe some bits used for alignment purposes. Example:

@sum_type Either{A, B} begin
    Left{A}(::A)
    Right{B}(::B)
end

julia> sizeof(Either{Bool, Nothing}'.Left(true))
2

julia> sizeof(Either{Int, Int}'.Left(1))
16

julia> sizeof(Either{Int128, Int}'.Left(1))
24

julia> sizeof(Either{Int128, Tuple{Int, Int}}'.Left(1))
24
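These numbers line up with a simple back-of-envelope model (mine, not the package's actual layout code): take the biggest variant's payload, add a discriminator byte, and round the total up to the alignment granule.

```julia
# Rough size model for the sizeof examples above:
#   size ≈ align * cld(biggest_payload + tag_bytes, align)
rough_size(payload, tag, align) = align * cld(payload + tag, align)

rough_size(8, 1, 8)   # 16 -> matches sizeof(Either{Int, Int}'.Left(1))
rough_size(16, 1, 8)  # 24 -> matches sizeof(Either{Int128, Int}'.Left(1))
```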

Why care about any of this?

Well, if you like performance, here’s a little benchmark that shows how this approach can be dramatically faster than manual union splitting over an abstract type:

module AbstractTypeTest

using BenchmarkTools

abstract type AT end
Base.@kwdef struct A <: AT
    common_field::Int = 0
    a::Bool = true
    b::Int = 10
end
Base.@kwdef struct B <: AT
    common_field::Int = 0
    a::Int = 1
    b::Float64 = 1.0
    d::Complex = 1 + 1.0im # not isbits
end
Base.@kwdef struct C <: AT
    common_field::Int = 0
    b::Float64 = 2.0
    d::Bool = false
    e::Float64 = 3.0
    k::Complex{Real} = 1 + 2im # not isbits
end
Base.@kwdef struct D <: AT
    common_field::Int = 0
    b::Any = :hi # not isbits
end

foo!(xs) = for i in eachindex(xs)
    @inbounds x = xs[i]
    @inbounds xs[i] = x isa A ? B() :
                      x isa B ? C() :
                      x isa C ? D() :
                      x isa D ? A() : error()
end


xs = rand((A(), B(), C(), D()), 10000);
display(@benchmark foo!($xs);)

end

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  267.399 μs …   3.118 ms  ┊ GC (min … max):  0.00% … 90.36%
 Time  (median):     278.904 μs               ┊ GC (median):     0.00%
 Time  (mean ± σ):   316.971 μs ± 306.290 μs  ┊ GC (mean ± σ):  11.68% ± 10.74%

  █                                                             ▁
  █▆▄▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▇▇ █
  267 μs        Histogram: log(frequency) by time       2.77 ms <

 Memory estimate: 654.75 KiB, allocs estimate: 21952.

module SumTypeTest

using SumTypes, BenchmarkTools
@sum_type AT begin
    A(common_field::Int, a::Bool, b::Int)
    B(common_field::Int, a::Int, b::Float64, d::Complex)
    C(common_field::Int, b::Float64, d::Bool, e::Float64, k::Complex{Real})
    D(common_field::Int, b::Any)
end

A(;common_field=1, a=true, b=10) = A(common_field, a, b) 
B(;common_field=1, a=1, b=1.0, d=1 + 1.0im) = B(common_field, a, b, d)
C(;common_field=1, b=2.0, d=false, e=3.0, k=Complex{Real}(1 + 2im)) = C(common_field, b, d, e, k)
D(;common_field=1, b=:hi) = D(common_field, b)

foo!(xs) = for i in eachindex(xs)
    xs[i] = @cases xs[i] begin
        A => B()
        B => C()
        C => D()
        D => A()
    end
end

xs = rand((A(), B(), C(), D()), 10000);
display(@benchmark foo!($xs);)
end 

BenchmarkTools.Trial: 10000 samples with 1 evaluation.
 Range (min … max):  53.120 μs …  64.690 μs  ┊ GC (min … max): 0.00% … 0.00%
 Time  (median):     54.070 μs               ┊ GC (median):    0.00%
 Time  (mean ± σ):   54.093 μs ± 425.595 ns  ┊ GC (mean ± σ):  0.00% ± 0.00%

                ▁ ▂▂▅▇▆█▅▆▃▃                                    
  ▁▁▁▁▁▂▂▃▄▅▇▅▇▆█▇██████████▇▇▅▅▅▃▃▂▂▁▂▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁▁ ▃
  53.1 μs         Histogram: frequency by time         55.8 μs <
 Memory estimate: 0 bytes, allocs estimate: 0.

Situations where you need to build compact, type stable representations of heterogeneous data like the above benchmark are pretty common outside of the numerical code world. For instance, it’s something that caused a big bottleneck in SymbolicUtils.jl, which is why @YingboMa et al. made Unityper.jl. SumTypes.jl is similar to Unityper.jl but more flexible in the sorts of data it can enclose. For example, SumTypes.jl supports parametric types, does not require default values for fields, and can handle storing non-primitive isbits types like Tuple.

I believe @c42f also was looking into using a data structure like this to store parsed code in JuliaSyntax.jl but I’m not sure what he ended up doing.

How can you help?

The main thing that I think would help with this is trying it out and reporting things that break, or don’t feel ergonomic.

I also really need some better documentation, so even just requests for clarification on how things work would be appreciated.

39 Likes

Thanks for your work on this - I’ve been reading through the docs and it’s looking very nice.

2 Likes

Cool, I want to try this out with the path structs in Makie. I think that one would be a classic use case, as the elements have different sizes but I still want to store them efficiently in a vector (which I don’t think they are right now).

3 Likes

I am curious about a use case along the following lines, focused on still writing code as if we were using typical multiple dispatch from Julia:

Someone already has a simple ray-tracing library written in Julia which implements tens of different solid object primitives (each a separate struct with supertype Abstract3DObject). Then a render function looks something like

function render(scene::Vector{Abstract3DObject})
    for object in scene
        find_ray(object)
    ...
end

find_ray itself has a ton of methods for the various object types.

My understanding is that this is a rather bad way to do things in Julia because

  • one usually has a ton of different types stored in scene so the vector contains boxed objects even if all the objects are otherwise pretty simple
  • find_ray inside of render will be using slower dynamic dispatch because there are too many types for automatic union splitting

Can SumTypes be used here without rewriting the library to make Abstract3DObject a sumtype? Can the sumtype be created dynamically by some wrapper for the render function, so that everything except render can be written as if SumType is not involved?

I guess I am asking for a function that dynamically creates a sumtype when one of its arguments is a Vector of many structs of the same abstract type.

3 Likes

@peremato I wonder if this helps: Allocation and slow down when # of types involved increase (slower than C++ virtual methods)

2 Likes

Hmm, that would be tricky. Creating a new sum type is currently a pretty heavy process that requires defining some structs and a lot of methods on those structs, so doing it dynamically would be quite costly.

However, I suppose one way around that would be to only define new sumtypes as they appear and assume that the number of subtypes of Abstract3DObject doesn’t change very often. So you could do something like this:

function render(scene::Vector{Abstract3DObject})
    if already_made_sumtype_for_current_subtypes_of(Abstract3DObject)
        ST = get_sum_type(Abstract3DObject)
    else
        new_name = gensym()
        ST = @eval begin
            $SumTypes.@sum_type $new_name begin
                $(programmatically_generate_variants_of(subtypes(Abstract3DObject)))
            end
            get_sum_type(::Type{Abstract3DObject}) = $new_name
            $new_name
        end
    end
    Base.invokelatest() do
        compact_scene = ST.(scene)
        for object in compact_scene
            find_ray(object)
        end
        ...
    end
end

Not exactly pretty but it might work, I’m not sure. It could be interesting to make a macro that does this automatically to transform abstract types into sum types and try to get the best of both worlds. But again, this is somewhat speculative. This might not be worth it at all.
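The "only generate a new type when a new set of subtypes appears" part of the idea can be sketched in plain Julia without SumTypes.jl at all. All the names below are hypothetical, and a small `Union`-holding wrapper stands in for a generated sum type: @eval the wrapper the first time a given tuple of concrete types shows up, and pull it out of a Dict on every later call.

```julia
# Sketch of the caching idea: generate a wrapper type once per distinct
# tuple of concrete types, then reuse it. No SumTypes.jl involved; the
# wrapper just holds a small Union, standing in for a generated sum type.
const _wrapper_cache = Dict{Tuple, DataType}()

function wrapper_for(Ts::Tuple{Vararg{DataType}})
    get!(_wrapper_cache, Ts) do
        name = gensym(:CompactWrapper)
        @eval struct $name
            val::Union{$(Ts...)}
        end
        getfield(@__MODULE__, name)
    end
end

W = wrapper_for((Int, Float64))
wrapper_for((Int, Float64)) === W  # true: cached, not regenerated
```

The Dict lookup amortizes the heavy `@eval` cost, which is the same assumption Mason makes above: that the set of subtypes changes rarely relative to how often render is called.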

1 Like

This is all super cool. Thanks for the update post. I’m playing with it a bit now to experiment with writing variants of math functions that don’t throw, and I’m experiencing a surprising (to me) overhead that I’m wondering if you can help me understand:

using SumTypes, BenchmarkTools

@sum_type Result{T} begin
  Failure
  Success{T}(::T)
end

# If I annotate the return type here as ::Result{T}, this function throws an
# ambiguous method error.
function log_nothrow(x::T) where{T<:AbstractFloat}
  if x < zero(x) 
    # Because I can't annotate the return type, using the other README trick.
    f::Result{T} = Failure
    return f
  end
  Success(log(x))
end

# It seems like constructing the Success{T} has a 10ns 
# overhead compared to log(1.1).
#@benchmark log_nothrow($(1.1))

# No additional overhead in the case where the sum type doesn't have a payload.
#@benchmark log_nothrow($(-1.1))

Obviously just a toy example, but I am interested in some kind of result type like this for math functions that need to be as fast as possible, so I’m wondering if this kind of overhead is expected. A generic struct instantiation seems to have about 1 ns of overhead on my computer, so I wonder if maybe I’m doing something wrong here.

2 Likes

This is very interesting! How does this part work:

I see we transformed Vector{Abstract3DObject} into Vector{ST}. But then, how come find_ray still works? find_ray has pre-defined methods for the subtypes of Abstract3DObject, not for the variants of ST (there is one-to-one correspondence between subtypes and variants, but I do not see where that comes in).

I guess in the @eval block, you’d have to define a new method find_ray(::$new_name) that does something like:

find_ray(obj::$new_name) = @cases obj begin
    $(programmatically_insert_all_cases(find_ray))
end

1 Like

Oh, good find! Looks like there are two problems being hit in your example

  1. the return type assertion wasn’t working because I wrote two ambiguous convert methods
  2. the convert method you used instead was blocking constant propagation

These were quite easy and fast to fix, so I did that here: More `convert` fixes, parameterize on flagtype by MasonProtter · Pull Request #24 · MasonProtter/SumTypes.jl · GitHub

However, there is a further problem which gets revealed when benchmarking the actual runtime of this and not the constant propagation:

julia> let x = Ref(1.1)
           @btime log_nothrow($x[])
           @btime log($x[])
       end
  15.280 ns (0 allocations: 0 bytes)
  4.740 ns (0 allocations: 0 bytes)
0.09531017980432493

What’s going on here is an alignment problem that I was hoping I had solved but it looks like it’s back. I’m working on a solution now though and I think it’s almost ready.

5 Likes

Okay, More `convert` fixes, parameterize on flagtype by MasonProtter · Pull Request #24 · MasonProtter/SumTypes.jl · GitHub is now merged and I have fixed the alignment issue (at the cost of adding in another type parameter). @cgeoga, if you update to SumTypes.jl v0.4.2 (should be on the general registry within the next 15 minutes) you should see this:

using SumTypes

@sum_type Result{T} begin
    Failure
    Success{T}(::T)
end

function log_nothrow(x::T)::Result{T} where{T<:AbstractFloat}
    if x < zero(x) 
        return Failure
    end
    Success(log(x))
end


julia> let x = Ref(1.1)
           @btime log_nothrow($x[])
           @btime log($x[])
       end
  5.230 ns (0 allocations: 0 bytes)
  4.529 ns (0 allocations: 0 bytes)
0.09531017980432493

Before, I was adding padding bits to my struct to try to ensure good alignment of small sum types, but because I can’t communicate to LLVM that I don’t care about the padding bits and that they’re only there for alignment purposes, there was a performance loss to doing this. So now I adjust the size of the discriminator byte in order to get good alignment. I learned that trick from this blog post @jar1 shared with me about the memory layout of Rust’s Enums: Rust enum-match code generation
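The effect Mason describes can be seen with plain Julia tuples (just an illustration of alignment rules, not SumTypes.jl's actual internals): next to an 8-byte payload, a 1-byte tag and an 8-byte tag cost the same amount of space, because the 1-byte version gets 7 bytes of padding anyway.

```julia
# Alignment illustration with plain tuples: an 8-byte payload forces the
# aggregate to 8-byte alignment, so a 1-byte discriminator gets padded out
# to 8 bytes regardless. Widening the discriminator is therefore "free",
# and sidesteps LLVM having to reason about uninitialized padding bytes.
sizeof(Tuple{Float64, UInt8})   # 16 bytes: 8 payload + 1 tag + 7 padding
sizeof(Tuple{Float64, UInt64})  # 16 bytes: 8 payload + 8 tag, no padding
```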

6 Likes

Wow, that is all super cool! Thank you for the explanation, link, and super fast release to fix it.

…which makes me reluctant to be more difficult, but I’m still actually experiencing the issue. When I copy-paste that code into a REPL, I still get

julia> let x = Ref(1.1)
                  @btime log_nothrow($x[])
                  @btime log($x[])
              end
  14.628 ns (0 allocations: 0 bytes)
  4.046 ns (0 allocations: 0 bytes)

with

(jl_7uKFVl) pkg> st
Status `/tmp/jl_7uKFVl/Project.toml`
  [8e1ec7a9] SumTypes v0.4.2

julia> versioninfo()
Julia Version 1.8.5
Commit 17cfb8e65ea (2023-01-08 06:45 UTC)
Platform Info:
  OS: Linux (x86_64-linux-gnu)
  CPU: 12 × 11th Gen Intel(R) Core(TM) i5-11600K @ 3.90GHz
  WORD_SIZE: 64
  LIBM: libopenlibm
  LLVM: libLLVM-13.0.1 (ORCJIT, rocketlake)
  Threads: 1 on 12 virtual cores

Is this a 1.8 vs 1.9 issue? What can I do to help figure out why my machine/Julia aren’t getting the right optimization?

EDIT: It is a v1.8.5 vs v1.9.0-rc2 issue. On 1.9:

julia> let x = Ref(1.1)
                  @btime log_nothrow($x[])
                  @btime log($x[])
              end
  4.794 ns (0 allocations: 0 bytes)
  4.032 ns (0 allocations: 0 bytes)

So that’s interesting.

1 Like

Hm, that’s very interesting! Yes, I did all my benchmarking on v1.9, not 1.8! I can reproduce your result on v1.8. It seems that what’s going on here is a change in some interprocedural optimization heuristic from 1.8 to 1.9, because when I look at the LLVM code for log_nothrow, it’s identical between the two versions:
v1.8:

julia> @code_llvm log_nothrow(1.1)
;  @ REPL[15]:1 within `log_nothrow`
define void @julia_log_nothrow_2569({ [8 x i8], i64 }* noalias nocapture sret({ [8 x i8], i64 }) %0, double %1) #0 {
top:
;  @ REPL[15]:2 within `log_nothrow`
; ┌ @ float.jl:412 within `<`
   %2 = fcmp uge double %1, 0.000000e+00
; └
  br i1 %2, label %L4, label %L3

common.ret:                                       ; preds = %L4, %L3
;  @ REPL[15] within `log_nothrow`
  ret void

L3:                                               ; preds = %top
;  @ REPL[15]:3 within `log_nothrow`
  %3 = getelementptr inbounds { [8 x i8], i64 }, { [8 x i8], i64 }* %0, i64 0, i32 0, i64 0
  call void @llvm.memset.p0i8.i64(i8* noundef nonnull align 8 dereferenceable(16) %3, i8 0, i64 16, i1 false)
  br label %common.ret

L4:                                               ; preds = %top
;  @ REPL[15]:5 within `log_nothrow`
; ┌ @ special/log.jl:267 within `log`
   %4 = call double @j__log_2571(double %1, {}* inttoptr (i64 139985714399816 to {}*)) #0
; └
  %5 = bitcast { [8 x i8], i64 }* %0 to double*
  store double %4, double* %5, align 8
  %.sroa.2.0..sroa_idx6 = getelementptr inbounds { [8 x i8], i64 }, { [8 x i8], i64 }* %0, i64 0, i32 1
  store i64 1, i64* %.sroa.2.0..sroa_idx6, align 8
  br label %common.ret
}

v1.9:

julia> @code_llvm log_nothrow(1.1)
;  @ REPL[7]:1 within `log_nothrow`
define void @julia_log_nothrow_1013({ [8 x i8], i64 }* noalias nocapture noundef nonnull sret({ [8 x i8], i64 }) align 8 dereferenceable(16) %0, double %1) #0 {
top:
;  @ REPL[7]:2 within `log_nothrow`
; ┌ @ float.jl:535 within `<`
   %2 = fcmp uge double %1, 0.000000e+00
; └
  br i1 %2, label %L4, label %L3

common.ret:                                       ; preds = %L4, %L3
;  @ REPL[7] within `log_nothrow`
  ret void

L3:                                               ; preds = %top
;  @ REPL[7]:3 within `log_nothrow`
  %3 = getelementptr inbounds { [8 x i8], i64 }, { [8 x i8], i64 }* %0, i64 0, i32 0, i64 0
  call void @llvm.memset.p0i8.i64(i8* noundef nonnull align 8 dereferenceable(16) %3, i8 0, i64 16, i1 false)
  br label %common.ret

L4:                                               ; preds = %top
;  @ REPL[7]:5 within `log_nothrow`
; ┌ @ special/log.jl:267 within `log`
   %4 = call double @j__log_1015(double %1, {}* inttoptr (i64 139900220263176 to {}*)) #0
; └
  %5 = bitcast { [8 x i8], i64 }* %0 to double*
  store double %4, double* %5, align 8
  %.sroa.2.0..sroa_idx6 = getelementptr inbounds { [8 x i8], i64 }, { [8 x i8], i64 }* %0, i64 0, i32 1
  store i64 1, i64* %.sroa.2.0..sroa_idx6, align 8
  br label %common.ret
}

Since the benchmarks for log itself haven’t changed, I think it must be to do with IPO, but I’m not really sure. :frowning:

1 Like

Very interesting! I suppose I’ll just have to move to 1.9 now then…

1.9 is an amazing upgrade! I strongly recommend it.

In general though, this package is doing a lot of somewhat fancy code-gen stuff at compile time in order to properly support various features that I think a fully general Rust-like Enum should have. So I guess it shouldn’t be too surprising that older versions of the compiler sometimes make bad choices with it.

For the purposes of a log_nothrow though, it might be that the much simpler ErrorTypes.jl works better? It’s essentially a set of hand-made sum types for error handling.

Really, the main thing SumTypes.jl is trying to do is provide genericity, brevity, and flexibility when defining a sum type, so that might not be so great an advantage in a situation like log_nothrow (much as it pains me to say!).

2 Likes

Could you send me a link to this ray tracing library?

It was a made-up example because ray tracing is a common situation in which you need to map a function with many different methods over an array of many different types. The collider physics simulator mentioned above is kinda the same; it is just that the rays are elementary particles, not optical beams.

I do think someone on discourse was advertising their ray tracing project a year or two ago; I would suggest searching through the archive. And there are a few fun “ray tracing in $SMALL_NUMBER days” tutorials online which can be an inspiration if you would like to try yourself.

1 Like

This is pretty cool! I think the data layout with (bits, ptrs, #tag#) here is about as good as it can be with the current Julia runtime. (If we had sum types in the language it’d be better of course; the GC could look at the tag to figure out where the pointers are.)

This reminds me of the issue about teaching the GC to overlap bits types with pointers in another situation. This one:

I feel like there’s some similarities here, and maybe both cases could be served with the same low level extension to the Julia object model.

Regarding JuliaSyntax.jl, we’re explicitly tagging the expression nodes (SyntaxNode / GreenNode) there with a Kind and in principle a sum type could be used, in which case the Kind would be the sum type’s tag. However, for JuliaSyntax it’s more useful to have the Kind tag be first class and to reuse it across the system in various data structures: both as the head() of different syntax node types, and in the streaming parser’s input and output arrays (the design allows for parsing without constructing any linked tree data structures.)

1 Like

This looks great. I am thinking of ways to represent operations in a quantum circuit. I’m wondering if SumTypes might work well. The problem is similar to the example given by @Krastanov. Pulling heterogeneous objects out of a container and using dispatch to choose a branch results in runtime dispatch, which is really slow.

Because a package that works with quantum circuits is rather large, you can afford a bit more effort in rolling your own solution. It looks like the intent of SumTypes is to also support smaller one-off uses. My main use case is large collections of operators. I think a struct of vectors approach is more efficient. And SumTypes is better suited to a vector of structs approach. One thing they both try to offer over arrays of standard Julia types is more efficient storage and performant access.

The contents of a circuit can vary, but many of them are made of largely of structureless objects that could be represented by, say, an Int32. And many with an Int32 and a float parameter. But you also need to support operators with more structure.

My current best thought is MEnums.jl which I wrote with this use case in mind. This was copied from Base Enums. I didn’t know about the other enums packages at the time. But this one is a bit different, although some features overlap with EnumsX. With MEnums the following features are essential for my application.

  • Add more instances after declaring the type.
  • Specify a new module (name space) for the instances
  • Manage “blocks” (that is, ranges of ordinal numbers) of instances. Add new instances to a particular block. Check if an instance is in a particular block, etc. Organizing and supporting queries relating instances and blocks is how you emulate dispatch.

For example

@menum (Element, blocklength=10^6, numblocks=50, compactshow=true)
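The "blocks emulate dispatch" idea from the list above can be sketched in plain Julia (hypothetical helpers, not MEnums.jl's actual API): operators are plain Int32 codes, a block is just a reserved range of codes, and branching on the block replaces dynamic dispatch over heterogeneous types.

```julia
# Sketch of "blocks of ordinals emulate dispatch": operators are plain
# Int32 codes, and a block is a reserved range of codes. A range check
# then stands in for dispatch on an abstract type.
const GATE_BLOCK  = 1:1_000_000          # structureless gates
const PARAM_BLOCK = 1_000_001:2_000_000  # gates carrying one parameter

kind(op::Int32) = op in GATE_BLOCK  ? :simple :
                  op in PARAM_BLOCK ? :parameterized :
                  :other

kind(Int32(42))         # :simple
kind(Int32(1_000_042))  # :parameterized
```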

I said that the operators are structureless. But you also have to specify where and how they appear in the circuit. Some of that info might best be stored in a type instance alongside the operator type. That specification is already heterogeneous, which is why I am not using a vector of structs that includes location information. I also have to carry an array or two for the additional structure, with entries of nothing for the structureless operators.

The problems of representing circuits and raytracers are examples of a generic problem, one whose particulars might make one solution better than another. Despite my description of the problem above, I’ll probably experiment with implementing circuits with SumTypes. It’s not yet clear which solutions are workable.

MEnums is not registered in the general registry, because the name was rejected and I haven’t gotten around to wrangling through it. It could use a better name anyway. I’m open to suggestions, but that should go in another thread; I don’t want to hijack this one.

MEnums is in LapeyreRegistry, which I manage with LocalRegistry. It would be great to see more support and culture around alternatives to the general registry.

2 Likes

Funnily enough, my actual use case is the exact same as @jlapeyre’s. I was just looking for a non-quantum example :smiley:

2 Likes