[Guide] Using sum types to define dynamic C APIs for Julia 1.12 `--trim`

I have recently spent a lot of time getting one of my libraries ready for the Julia 1.12 --trim feature so that I can ship it as a C library and wrap it into a Python module. During this work I learned a lot, and I would like to share a pattern I found that helps in designing @ccallable interfaces with dynamic types, for example a matrix type that might be either Matrix or Diagonal. This might be of interest to folks who also followed my notes on trimming in Creating fully self-contained and pre-compiled library - #17 by RomeoV.

To illustrate, we’ll write a C-callable function matrixsum_cc that computes the sum of a square matrix. The matrix is passed via a raw pointer Ptr{Cdouble} and a size parameter, but crucially may have additional structure given through a Cint enum value. In our example, it can be a dense matrix or a diagonal matrix. We want to leverage this information by converting it into a Matrix or Diagonal, respectively, so that we can load it correctly and use Julia’s dispatch for computation.

In short, in the end we want to write a function

@ccallable function matrixsum_cc(sz::Cint, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::Cdouble
    m = build_matrix(sz, ptr, mattype)
    return sum(m)
end

that is trimmable, i.e., one that we can pass to juliac --trim to get a static C library or executable.

The issue here is that this function is by default not “type grounded”, i.e., not all variable types within the function are determinable from the input types alone. In particular, m will be either a Matrix or a Diagonal, depending on the value of mattype (which is just a Cint). We therefore need a way to go from this value-dependent formulation (storing the type info in a runtime variable) to a formulation that is fully “type grounded” and compatible with multiple dispatch. We can do this by leveraging “closed” sum types.
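To make this concrete, here is a minimal sketch (stdlib only; the function name is made up for illustration) of the naive version without sum types, where inference can only produce a non-concrete Union as the return type:

```julia
using LinearAlgebra

# Naive version: the return type depends on the runtime *value* of `mattype`,
# so the best inference can do is a Union of the two concrete matrix types.
function build_matrix_naive(sz::Integer, ptr::Ptr{Cdouble}, mattype::Cint)
    if mattype == 1
        return unsafe_wrap(Matrix{Cdouble}, ptr, (sz, sz))
    else
        return Diagonal(unsafe_wrap(Vector{Cdouble}, ptr, sz))
    end
end

# The inferred return type is a non-concrete Union:
rt = only(Base.return_types(build_matrix_naive, (Int, Ptr{Cdouble}, Cint)))
@assert !isconcretetype(rt)
```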

We first need to build a datatype for our build_matrix function to return. Since the result can be either a dense or a diagonal matrix, we define a “closed” sum type, i.e., a sum type that enumerates all possibilities. We will use Moshi.jl, although other libraries such as SumTypes.jl [EDIT: and LightSumTypes.jl, see below] should also work.

import Pkg
Pkg.activate(; temp=true)
Pkg.add(["Moshi", "CEnum", "JET"])

using LinearAlgebra
import Moshi.Data: @data
@data MyMatrix{T<:Number} begin
    DenseMat(Matrix{T})
    DiagMat(Diagonal{T, Vector{T}})  # <- Make sure this is a concrete type. `Diagonal{T}` wouldn't be enough.
end

For convenience we also define a CEnum that lets us pass in the type of the matrix from C as a Cint while referring to it in Julia via an enum. We wrap it in a module just to keep it in its own namespace.

module MatType
    using CEnum
    @cenum Enum::Cint begin
        dense = 1
        diag = 2
    end
end
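As an aside: if you want to drop the CEnum dependency, Base’s @enum with an explicit Cint base type behaves similarly (a sketch, with a made-up module name; note that CEnum is generally the safer choice for C interop, since e.g. Base enums throw on conversion from an out-of-range integer):

```julia
module MatTypeBase  # hypothetical stdlib-only alternative to the CEnum version
    @enum Enum::Cint begin
        dense = 1
        diag = 2
    end
end

@assert Cint(MatTypeBase.dense) == Cint(1)  # converts to the C integer value
@assert MatTypeBase.Enum(2) == MatTypeBase.diag
```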

Now we are ready to write our build_matrix function. The function takes in the MatType enum, builds the Julia matrix, and returns the matrix wrapped in a sum type so that the function is type stable. Notice that all internal variables have a clearly inferable type, making the function additionally type grounded, a stronger property than type stability.

function build_matrix(sz::Integer, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::MyMatrix.Type{Cdouble}
    if mattype == MatType.dense
        # Unsafe pointer logic to create the Matrix
        m_dense = unsafe_wrap(Matrix{Cdouble}, ptr, (sz, sz))
        return MyMatrix.DenseMat(m_dense) # Wrap it in the sum type variant
    elseif mattype == MatType.diag
        m_diag = Diagonal(unsafe_wrap(Vector{Cdouble}, ptr, sz))
        return MyMatrix.DiagMat(m_diag) # Wrap it
    else
        error("Unmatched MatType")
    end
end
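Independent of the sum-type machinery, the pointer logic can be sanity-checked from pure Julia. unsafe_wrap reinterprets existing memory without copying, so the wrapped array must not outlive the underlying buffer (hence GC.@preserve in this sketch):

```julia
using LinearAlgebra

A = rand(3, 3)
GC.@preserve A begin
    # Dense case: wrap the same 9 Float64s as a 3x3 Matrix (no copy).
    m = unsafe_wrap(Matrix{Cdouble}, pointer(A), (3, 3))
    @assert m == A
end

d = rand(3)
GC.@preserve d begin
    # Diagonal case: only the diagonal entries live in the buffer.
    D = Diagonal(unsafe_wrap(Vector{Cdouble}, pointer(d), 3))
    @assert sum(D) ≈ sum(d)
end
```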

Now comes the magic. For each variant of our sum type, we can call our kernel (here the sum function) with a concrete type! Crucially, to maintain type stability, our kernel needs to return the same type for each variant (here T=Float64). Instead of sum, this could also be something like mat \ b (a linear solve) or a whole solver or simulation process.

import Base: sum
import Moshi.Match: @match
function sum(m::MyMatrix.Type{T})::T where {T}
    @match m begin
        MyMatrix.DenseMat(mat) => sum(mat)
        MyMatrix.DiagMat(mat) => sum(mat)
    end
end

Finally we’re able to write our C interface, where we can pass in a raw pointer and the MatType enum value (an integer) and get our result in a type-stable way. Notice how m now has a concrete type: MyMatrix.Type{Cdouble}! This makes the entire code type grounded and ready for trimming.

import Base: @ccallable
@ccallable function matrixsum_cc(sz::Cint, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::Cdouble
    m = build_matrix(sz, ptr, mattype)
    return sum(m)
end

To validate the type stability, we can check with @code_warntype, or directly test with JET.jl.

julia> using JET
julia> densemat = rand(3,3);
julia> @test_opt matrixsum_cc(Cint(3), pointer(densemat), MatType.dense)
Test Passed
julia> diagmat = Diagonal(rand(3));
julia> @test_opt matrixsum_cc(Cint(3), pointer(diagmat.diag), MatType.diag)
Test Passed
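If you’d rather stay within the stdlib, Test.@inferred gives a weaker check: it only verifies the return type (type stability), not full groundedness. A toy sketch with made-up helper functions:

```julia
using Test, LinearAlgebra

# The return type depends on the runtime value of `flag`, so @inferred errors out.
unstable(flag) = flag ? rand(3, 3) : Diagonal(rand(3))
@test_throws ErrorException (@inferred unstable(true))

# After reducing both branches to a common type, inference succeeds.
stable(flag) = sum(flag ? rand(3, 3) : Diagonal(rand(3)))
@inferred stable(true)  # Float64 regardless of the branch taken
```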

That’s it! :juliabouncing: :llvm:

PS: A copy-pastable version of all the code is here.

30 Likes

Thank you for sharing @RomeoV ! I wonder if this code pattern could be abstracted out in a macro that rewrites normal Julia function definitions into a sum-type approach that is C-callable.

If we have to rewrite code with wrapper sum-types to achieve static compilation, we are basically reviving the two-language problem. The difference is that we are rewriting code using Julia syntax.

Maybe I didn’t get this right, but to me it seems that no code needs to be rewritten.
Please correct me if I’m wrong but I think all the implemented functionality can stay the same and it is just wrapped in an additional type stable/grounded layer.

I’m not arguing that a macro that does this automatically wouldn’t be nice, but calling this the revived two-language problem seems a bit far-fetched.

This is equivalent to writing wrappers in a different language, with restricted types. That is what I tried to say. Ideally, no extra work would be needed from the Julia developer other than calling a macro that creates wrappers automatically for a specific subset of selected types.

2 Likes

it seems that no code needs to be rewritten

Yes, that’s right. New compatibility code has to be written only at the “boundary” of the API. None of the internal Julia code needs to be restructured (besides that all Julia functions need to be type grounded too, for trimmability).

1 Like

Let’s say we want to make 100 different functions C-callable. Yes, then we’d need to make 100 custom interface functions. Further, the interface functions can look more complicated. Let’s say we are passing two arrays, not just one. Then we’d have to do something like

using .MyMatrix: DenseMat, DiagMat
function sum(m::MyMatrix.Type{T}, m2::MyMatrix.Type{T})::T where {T}
    @match (m, m2) begin
        (DenseMat(mat), DenseMat(mat2)) => sum(mat)+sum(mat2)
        (DenseMat(mat), DiagMat(mat2)) => sum(mat)+sum(mat2)
        (DiagMat(mat), DenseMat(mat2)) => sum(mat)+sum(mat2)
        (DiagMat(mat), DiagMat(mat2)) => sum(mat)+sum(mat2)
    end
end

So, obviously there’s some combinatorial explosion here, with n_variants^n_variables choices. One could definitely macro-generate this.

Could we also just generate all the required functions automatically for any Julia function?
Let’s consider what we need to turn our original sum(m::AbstractMatrix) into matrixsum_cc(sz, ptr, mattype).

  1. First, we need to enumerate exactly which concrete matrix types we need to compile for. This is the @data ... part, which we can’t really make more concise through another macro. In normal Julia code, this is determined automatically by checking which concrete types the sum function ends up getting called with, but for a C interface this has to be done manually somehow.
  2. Second, we have to define an enum which maps each variant to a Cint to make it ccallable (MatType above). This could in theory be automated with a macro, but it’s important that it stays readable.
  3. Third, we have to define how to load each matrix from the tuple (sz, ptr, mattype). This can just be written once and reused, without macros. In fact, it’s exactly our build_matrix function!
  4. Finally, we have to use the @match magic to call our kernel function with concrete types. If we have only univariate “wrapper” sum types (just one variable per sum-type variant, nothing nested, etc.) I think we could definitely generate this.

So perhaps a macro could look something like this:

@ccallablevariants MyMatrix{T<:Number} Cint begin
    DenseMat(Matrix{T})=1  # will have enum value `Cint==1`
    DiagMat(Diagonal{T, Vector{T}})=2  # will have enum value `Cint==2`
end 
# make this also generate `MyMatrix.Enum`, with `MyMatrix.Enum.DenseMat == Cint(1)` etc.

# make `construct_from_ctypes` a standardized function for `@generate_ccallable` below
function construct_from_ctypes(::Type{MyMatrix.Type})
    function build_matrix(sz::Integer, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::MyMatrix.Type{Cdouble}
        # as above ...
    end
    return (; sz=Integer, ptr=Ptr{Cdouble}, mattype=MyMatrix.Enum), build_matrix
end
    
# now this can generate all the other code
@generate_ccallable sum(MyMatrix{Float64})

or for multiple parameters and a mix of sum types and regular types

foo(mat::AbstractMatrix, mat2::AbstractMatrix, otherparam::Integer) = ...
@generate_ccallable foo(::MyMatrix{Float64}, ::MyMatrix{Float64}, otherparam::Cint)

# the call above would generate something like
function foo(sz_1::Integer, ptr_1::Ptr{Cdouble}, mattype_1::MyMatrix.Enum,
             sz_2::Integer, ptr_2::Ptr{Cdouble}, mattype_2::MyMatrix.Enum,
             otherparam::Cint)
    # ...
end
1 Like

Not sure if correct, but it seems LightSumTypes.jl could be a bit better: you wouldn’t have to deal with the combinatorial explosion. You would just write sum(variant(m))+sum(variant(m2)) instead of defining all the matches.

2 Likes

Thanks, using LightSumTypes.jl indeed looks like an even cleaner alternative!

using LightSumTypes
@sumtype MyMatrix2{T}(
    Matrix{T},
    Diagonal{T, Vector{T}}
)

# even simpler now
sum(m::MyMatrix2{T}) where {T} = sum(variant(m))

function build_matrix2(sz::Integer, ptr::Ptr{Cdouble}, mattype::MatType.Enum)
    if mattype == MatType.dense
        # Unsafe pointer logic to create the Matrix
        m_dense = unsafe_wrap(Matrix{Cdouble}, ptr, (sz, sz))
        return MyMatrix2{Cdouble}(m_dense) # Wrap it in the sum type variant
    elseif mattype == MatType.diag
        m_diag = Diagonal(unsafe_wrap(Vector{Cdouble}, ptr, sz))
        return MyMatrix2{Cdouble}(m_diag) # Wrap it
    else
        error("Unmatched MatType")
    end
end

@ccallable function matrixsum2_cc(sz::Cint, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::Cdouble
    m = build_matrix2(sz, ptr, mattype)
    return sum(m)
end

@test_opt matrixsum2_cc(Cint(3), pointer(diagmat.diag), MatType.diag)
# Test Passed

EDIT: Wow and LightSumTypes.jl is like 100 loc or so total! I think it’s definitely a great choice here.

3 Likes

Nice! Yes, indeed I didn’t particularly try to write the cleanest source code possible; we could definitely try to arrive at 50 LOC. In a sense, the approach is very simple, and in my opinion it could ideally be adopted in Base if sensible. The question is whether the concept of a “ClosedUnion” is generally useful to have in Julia itself. Your post makes me think that maybe it is.

3 Likes

Yes, fully agree.

But that’s just not what is usually referred to as the “two-language problem”.
Sure, it might also involve the notion of wrapping code (e.g. C in python), but the core idea is that prototyping happens in a fast to work in but slow to execute language (python) followed by (at least partly) rewriting in a low-level, compiled, fast to execute language (C) for production.

If a LightSumType is better than a Union, why isn’t it the default? What are the disadvantages of LightSumTypes compared to a Union?

I had mixed results comparing performance between the two approaches; usually LightSumTypes is better, though, and with Julia 1.10 it is actually much better. Apart from that, a Union doesn’t require the boilerplate LightSumTypes needs, and a Union is a covariant abstract type while LightSumTypes creates invariant types. I think these characteristics make a Union usually easier to work with.
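The covariance point can be made concrete with a stdlib-only sketch (Wrapped is a made-up stand-in for what a sum-type macro generates):

```julia
using LinearAlgebra

const MatUnion = Union{Matrix{Float64}, Diagonal{Float64, Vector{Float64}}}

# A Union is covariant: its members are already subtypes, no wrapping needed.
@assert rand(3, 3) isa MatUnion
@assert Matrix{Float64} <: MatUnion

# A sum type is a distinct, invariant wrapper: values must be wrapped explicitly.
struct Wrapped
    value::MatUnion
end
@assert !(rand(3, 3) isa Wrapped)
@assert Wrapped(rand(3, 3)) isa Wrapped
```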

2 Likes

Thank you for the write up!

How important is type-groundedness here? While your functions are type-grounded, somewhere deeper down something must not be grounded, because at some point there must be a variable which is either a Matrix or a Diagonal, no? Or are Moshi and LightSumTypes doing some magic to make it so?

The original intention leading me to write this post was to use juliac --trim to build a statically compiled C library or executable. Now, if you have played around with trimming you know that it’s not so easy to get the trimming verifier to work for any Julia code. However, it seems to me now that these three are roughly equivalent:

  • the trimming verifier succeeds for some entrypoint(...)
  • JET.@test_opt entrypoint(...) passes without warnings
  • the entire callgraph reachable from entrypoint(...) is type grounded

To recall the definition of type groundedness once more, let me just directly quote the paper [Sec 3]:

Effectively, there are two competing, type-related properties of function bodies. To address this confusion, we define and refer to them using two distinct terms:
• type stability is when a function’s return type depends only on its argument types, and
• type groundedness is when every variable’s type depends only on the argument types.
Although type stability is strictly weaker than type groundedness, for the purposes of this paper, we are interested in both properties. The latter, type groundedness, is useful for performance of the function itself, as it implies that unboxing, devirtualization, and inlining can occur. The former, type stability, allows the function to be used efficiently by other functions: namely, type-grounded functions may call a function that is only type stable but not grounded.
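A minimal toy example of the distinction: the function below is type stable (it always returns Float64) but not type grounded, because an intermediate variable’s type depends on a runtime value:

```julia
# Type stable: the return type is always Float64.
# Not type grounded: `x` is Union{Int, Float64} depending on `flag`'s value.
function stable_not_grounded(flag::Bool)::Float64
    x = flag ? 1 : 2.0
    return x + 0.5
end

@assert stable_not_grounded(true) == 1.5
@assert stable_not_grounded(false) == 2.5
@assert only(Base.return_types(stable_not_grounded, (Bool,))) === Float64
```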

Now let me address your point.

While your functions are type-grounded, somewhere deeper down something must not be grounded because at some point there must be a variable which is either a Matrix or a Diagonal no?

Julia’s dispatch mechanism makes it so that we are type grounded all the way down. Let’s consider the snippet from above again:

function sum(m::MyMatrix.Type{T})::T where {T}
    @match m begin
        MyMatrix.DenseMat(mat) => sum(mat)  # mat is a concrete type: `Matrix{Float64}`!
        MyMatrix.DiagMat(mat) => sum(mat)  # mat is a concrete type: `Diagonal{Float64, Vector{Float64}}`!
    end
end

Notice how on the inside, for each branch, sum is called with a concrete type, i.e., no sum-type business. This makes the entire inner sum(mat) call type grounded “all the way down”, for each branch. Then, since both branches return the same type (this is important), our outer sum also returns the same type regardless of the branch, making it type grounded, too!

I get that. But my point was that MyMatrix presumably has a field of type Union{Matrix{...}, Diagonal{...}} and so the bit of code (which lives in Moshi or LightSumTypes) which branches on the type briefly has to handle a variable whose type is unknown.

So my question is how does Moshi or LightSumTypes make accessing the underlying value grounded?

Do they basically do something like the following?

if mymatrix.value isa Matrix
   ...
elseif mymatrix.value isa Diagonal
   ...
end

And if so, is this grounded enough for --trim even though mymatrix.value cannot have an inferred concrete type at the outset? Is type-checking exempt from the notion of type-groundedness?

1 Like

Yes, they do basically that. I would indeed like to see if this is enough for --trim for big sum types. Don’t know the answers to your other questions :sweat_smile:
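A stdlib-only sketch of that mechanism (names made up): even though the field itself is a small Union, inference narrows the type inside each isa branch, so every call within a branch sees a concrete type, and since both branches return Float64, the whole function is grounded:

```julia
using LinearAlgebra

struct MiniSum  # made-up stand-in for a sum-type wrapper
    value::Union{Matrix{Float64}, Diagonal{Float64, Vector{Float64}}}
end

function minisum(m::MiniSum)
    v = m.value
    if v isa Matrix{Float64}
        return sum(v)  # v is concretely Matrix{Float64} here
    else
        return sum(v)  # v is concretely Diagonal{Float64, Vector{Float64}} here
    end
end

@assert minisum(MiniSum(ones(2, 2))) == 4.0
@assert minisum(MiniSum(Diagonal([1.0, 2.0]))) == 3.0
@assert only(Base.return_types(minisum, (MiniSum,))) === Float64
```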

Do these sum-type packages rely on low-level tricks beyond “pure” Julia, e.g. ccall, to achieve type stability and efficient storage layout?

No, as far as I can tell, this is not the case for Moshi.jl, SumTypes.jl, or LightSumTypes.jl. The only one that does something more low-level is Unityper.jl, but it is actually worse than the others performance-wise.

1 Like

Does Union-splitting the two possible concrete types for m not remove dynamic dispatches? Or is Union-split code not trimmable for another reason?

In general cases though I agree with this approach to static code.

It’s only the two-language problem if we’re writing essentially the same code in another language. Putting aside the already addressed caveat that there aren’t 2 languages, I would argue that working with sum types, even if wholly internally like this, is semantically different enough from dynamically dispatching over runtime types. We wouldn’t call edits for multithreading or GPU arrays to be a two-language problem for the same reason. As nice as easy @static_please <expr> and @multithread_please <expr> macros would be, programming isn’t that declarative.

I think the juliac talks did speculate on more aggressive whole-program optimizations that branch over more types in an inferred Union. But it wasn’t clear what the limitations there are; I believe I’ve read that Julia is too slow to build if the compiler generally tries to compile 10 possibilities instead of falling back to dynamic dispatch. Sum types on the other hand can be isolated to specified contexts instead, and they at their best would require about the same code as anything handled by a macro or compiler at their best (sum(variant(m)) is a lot shorter than a trivial match statement repeating an expression, though I’m not sure how variant calls are type-stable).
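On that last point: a variant-style accessor doesn’t need to return a concrete type; it can return a small Union, and the compiler union-splits the subsequent call at each call site. A stdlib-only sketch (names made up):

```julia
using LinearAlgebra

struct MiniSum2  # made-up stand-in for a sum-type wrapper
    value::Union{Matrix{Float64}, Diagonal{Float64, Vector{Float64}}}
end

# Returns a small Union -- that is fine: callers get union-split automatically.
getvariant(m::MiniSum2) = m.value

# Each `sum` call site compiles to a branch over the two concrete types,
# and both branches return Float64, so the whole call is grounded.
sum2(m::MiniSum2, m2::MiniSum2) = sum(getvariant(m)) + sum(getvariant(m2))

@assert sum2(MiniSum2(ones(2, 2)), MiniSum2(Diagonal([1.0, 2.0]))) == 7.0
@assert only(Base.return_types(sum2, (MiniSum2, MiniSum2))) === Float64
```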

1 Like