[Guide] Using sum types to define dynamic C APIs for Julia 1.12 `--trim`

I have recently spent a lot of time trying to get one of my libraries ready for the Julia 1.12 `--trim` feature, so that I can ship it as a C library and wrap it into a Python module. During this work I learned a lot and would like to share a pattern I have found that helps design @ccallable interfaces with dynamic types, for example a matrix type that might be either Matrix or Diagonal. This might also be of interest to folks following my notes on trimming in Creating fully self-contained and pre-compiled library - #17 by RomeoV.

To illustrate, we’ll write a C-callable function matrixsum_cc that computes the sum of a square matrix. The matrix is passed via a raw pointer Ptr{Cdouble} and a size parameter, but crucially may have additional structure given through a Cint enum value. In our example, it can be a dense matrix or a diagonal matrix. We want to leverage this information by converting it into a Matrix or Diagonal, respectively, so that we can load it correctly and use Julia’s dispatch for computation.

In short, we eventually want to write a function

@ccallable function matrixsum_cc(sz::Cint, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::Cdouble
    m = build_matrix(sz, ptr, mattype)
    return sum(m)
end

that is trimmable, i.e., that we can pass to juliac --trim to get a static C library or executable.

The issue here is that this function is by default not “type grounded”, i.e., not all variable types within the function are determinable from the input types. In particular, m will be either of type Matrix or Diagonal, depending on the value of mattype (which is just a Cint). We therefore need a way to go from this formulation, where the type information is carried in a runtime value, to one that is fully “type grounded” and compatible with multiple dispatch. We can do this by leveraging “closed” sum types.
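
To make the problem concrete, here is a minimal sketch of the naive version (my illustration, not from the original post), where the concrete type of m depends on a runtime value rather than on the argument types:

using LinearAlgebra

function matrixsum_naive(sz::Integer, ptr::Ptr{Cdouble}, mattype::Cint)::Cdouble
    # `m` will be a Matrix or a Diagonal depending on the *value* of `mattype`,
    # so its type cannot be inferred from the argument types alone.
    if mattype == 1
        m = unsafe_wrap(Matrix{Cdouble}, ptr, (sz, sz))
    else
        m = Diagonal(unsafe_wrap(Vector{Cdouble}, ptr, sz))
    end
    return sum(m)
end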

We first need to build a datatype for our build_matrix function to return. Since the result can be either a dense or a diagonal matrix, we define a “closed” sum type, i.e., a sum type that enumerates all possibilities. We will use Moshi.jl, although other libraries such as SumTypes.jl should also work.

import Pkg
Pkg.activate(; temp=true)
Pkg.add(["Moshi", "CEnum", "JET"])

using LinearAlgebra
import Moshi.Data: @data
@data MyMatrix{T<:Number} begin
    DenseMat(Matrix{T})
    DiagMat(Diagonal{T, Vector{T}})  # <- Make sure this is a concrete type. `Diagonal{T}` wouldn't be enough.
end

For convenience we also define a CEnum that lets us pass in the type of the matrix from C as a Cint while referring to it in Julia via an enum. We wrap it in a module just to keep it in its own namespace.

module MatType
    using CEnum
    @cenum Enum::Cint begin
        dense = 1
        diag = 2
    end
end
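
Just to spell out the mapping (my addition; this assumes CEnum's usual integer conversions): the named constants are backed by the Cint values above, so the integer arriving from C corresponds directly to an enum value.

julia> Cint(MatType.diag)
2
julia> MatType.Enum(1) == MatType.dense
true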

Now we are ready to write our build_matrix function. The function takes in the MatType enum, builds the Julia matrix, and returns it wrapped in the sum type so that the function is type stable. Notice that all internal variables have a clearly inferable type, making the function additionally type grounded, a stronger property than type stability.

function build_matrix(sz::Integer, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::MyMatrix.Type{Cdouble}
    if mattype == MatType.dense
        # Unsafe pointer logic to create the Matrix
        m_dense = unsafe_wrap(Matrix{Cdouble}, ptr, (sz, sz))
        return MyMatrix.DenseMat(m_dense) # Wrap it in the sum type variant
    elseif mattype == MatType.diag
        m_diag = Diagonal(unsafe_wrap(Vector{Cdouble}, ptr, sz))
        return MyMatrix.DiagMat(m_diag) # Wrap it
    else
        error("Unmatched MatType")
    end
end
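
As a quick check (my addition), build_matrix returns the wrapped sum type for either variant, which is guaranteed by the return-type annotation:

julia> A = rand(3, 3);
julia> build_matrix(3, pointer(A), MatType.dense) isa MyMatrix.Type{Cdouble}
true
julia> d = rand(3);
julia> build_matrix(3, pointer(d), MatType.diag) isa MyMatrix.Type{Cdouble}
true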

Now comes the magic. For each variant of our sum type, we are able to call our kernel (here the sum function) with a concrete type! Crucially, to maintain type stability, our kernel needs to return the same type for each variant (here T=Float64). Instead of sum, this could also be something like a linear solve mat \ b, or a whole solver or simulation process.

import Base: sum
import Moshi.Match: @match
function sum(m::MyMatrix.Type{T})::T where {T}
    @match m begin
        MyMatrix.DenseMat(mat) => sum(mat)
        MyMatrix.DiagMat(mat) => sum(mat)
    end
end
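
As a quick sanity check (my addition), dispatch onto the wrapped value works as expected:

julia> sum(MyMatrix.DiagMat(Diagonal([1.0, 2.0, 3.0])))
6.0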

Finally, we’re able to write our C interface, where we can pass in a raw pointer and the MatType enum value (an integer) and get our result in a type-stable way. Notice how m now has a concrete type, MyMatrix.Type{Float64}! This makes the entire code type grounded and ready for trimming.

import Base: @ccallable
@ccallable function matrixsum_cc(sz::Cint, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::Cdouble
    m = build_matrix(sz, ptr, mattype)
    return sum(m)
end

To validate the type stability, we can check with @code_warntype, or directly test with JET.jl.

julia> using JET
julia> densemat = rand(3,3);
julia> @test_opt matrixsum_cc(Cint(3), pointer(densemat), MatType.dense)
Test Passed
julia> diagmat = Diagonal(rand(3));
julia> @test_opt matrixsum_cc(Cint(3), pointer(diagmat.diag), MatType.diag)
Test Passed

That’s it! :juliabouncing: :llvm:

PS: A copy-pastable version of all the code is here.

Thank you for sharing, @RomeoV! I wonder if this code pattern could be abstracted into a macro that rewrites normal Julia function definitions into a sum-type approach that is C-callable.

If we have to rewrite code with wrapper sum-types to achieve static compilation, we are basically reviving the two-language problem. The difference is that we are rewriting code using Julia syntax.

Maybe I didn’t get this right, but to me it seems that no code needs to be rewritten.
Please correct me if I’m wrong, but I think all the implemented functionality can stay the same; it is just wrapped in an additional type-stable/grounded layer.

I’m not arguing that a macro that does this automatically wouldn’t be nice, but calling this the revived two-language problem seems a bit far-fetched.

This is equivalent to writing wrappers in a different language, with restricted types. That is what I tried to say. Ideally, no extra work would be needed from the Julia developer other than calling a macro that creates wrappers automatically for a specific subset of selected types.

it seems that no code needs to be rewritten

Yes, that’s right. New compatibility code has to be written only at the “boundary” of the API. None of the internal Julia code needs to be restructured (apart from the fact that all internal Julia functions also need to be type grounded, for trimmability).

Let’s say we want to make 100 different functions C-callable. Yes, then we’d need to write 100 custom interface functions. Further, the interface functions can look more complicated. Let’s say we are passing two arrays, not just one. Then we’d have to do something like

using .MyMatrix: DenseMat, DiagMat
function sum(m::MyMatrix.Type{T}, m2::MyMatrix.Type{T})::T where {T}
    @match (m, m2) begin
        (DenseMat(mat), DenseMat(mat2)) => sum(mat)+sum(mat2)
        (DenseMat(mat), DiagMat(mat2)) => sum(mat)+sum(mat2)
        (DiagMat(mat), DenseMat(mat2)) => sum(mat)+sum(mat2)
        (DiagMat(mat), DiagMat(mat2)) => sum(mat)+sum(mat2)
    end
end

So, obviously there’s some combinatorial explosion here, with $(n_\mathrm{variants})^{n_\mathrm{variables}}$ choices. One could definitely macro generate this.

Could we also just generate all the required functions automatically for any Julia function?
Let’s consider what we need to turn our original sum(m::AbstractMatrix) into matrixsum_cc(sz, ptr, mattype).

  1. First, we need to enumerate exactly which concrete matrix types we need to compile for. This is the @data ... part, which we can’t really make more concise through another macro. In normal Julia code, this is determined automatically by checking which concrete types the sum function ends up getting called with, but for a C interface this has to be done manually somehow.
  2. Second, we have to define an enum which maps each variant to a Cint to make it ccallable (MatType above). This could in theory be automated with a macro, but it’s important that it stays readable.
  3. Third, we have to define how to load each matrix from the tuple (sz, ptr, mattype). This can just be written once and reused, without macros. In fact, it’s exactly our build_matrix function!
  4. Finally, we have to use the @match magic to call our kernel function with concrete types. If we have only univariate “wrapper” sumtypes (just one variable per sum-type variant, nothing nested etc) I think we could definitely generate this.

So perhaps a macro could look something like this:

@ccallablevariants MyMatrix{T<:Number} Cint begin
    DenseMat(Matrix{T})=1  # will have enum value `Cint==1`
    DiagMat(Diagonal{T, Vector{T}})=2  # will have enum value `Cint==2`
end 
# make this also generate `MyMatrix.Enum`, with `MyMatrix.Enum.DenseMat == Cint(1)` etc.

# make `construct_from_ctypes` a standardized function for `@generate_ccallable` below
function construct_from_ctypes(::Type{MyMatrix.Type})
    function build_matrix(sz::Integer, ptr::Ptr{Cdouble}, mattype::MyMatrix.Enum)::MyMatrix.Type{Cdouble}
        # as above ...
    end
    return (; sz=Integer, ptr=Ptr{Cdouble}, mattype=MyMatrix.Enum), build_matrix
end
    
# now this can generate all the other code
@generate_ccallable sum(MyMatrix{Float64})

or for multiple parameters and a mix of sum types and regular types

foo(mat::AbstractMatrix, mat2::AbstractMatrix, otherparam::Integer) = ...
@generate_ccallable foo(::MyMatrix{Float64}, ::MyMatrix{Float64}, otherparam::Cint)

# the call above would generate something like
function foo(sz_1::Integer, ptr_1::Ptr{Cdouble}, mattype_1::MyMatrix.Enum,
             sz_2::Integer, ptr_2::Ptr{Cdouble}, mattype_2::MyMatrix.Enum,
             otherparam::Cint)
    # ...
end

Not sure if correct, but it seems LightSumTypes.jl could be a bit better here: you wouldn’t have to deal with the combinatorial explosion. You would just write sum(variant(m)) + sum(variant(m2)) instead of defining all the matches.

Thanks, using LightSumTypes.jl indeed looks like an even cleaner alternative!

using LightSumTypes
@sumtype MyMatrix2{T}(
    Matrix{T},
    Diagonal{T, Vector{T}}
)

# even simpler now
sum(m::MyMatrix2{T}) where {T} = sum(variant(m))

function build_matrix2(sz::Integer, ptr::Ptr{Cdouble}, mattype::MatType.Enum)
    if mattype == MatType.dense
        # Unsafe pointer logic to create the Matrix
        m_dense = unsafe_wrap(Matrix{Cdouble}, ptr, (sz, sz))
        return MyMatrix2{Float64}(m_dense) # Wrap it in the sum type variant
    elseif mattype == MatType.diag
        m_diag = Diagonal(unsafe_wrap(Vector{Cdouble}, ptr, sz))
        return MyMatrix2{Float64}(m_diag) # Wrap it
    else
        error("Unmatched MatType")
    end
end

@ccallable function matrixsum2_cc(sz::Cint, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::Cdouble
    m = build_matrix2(sz, ptr, mattype)
    return sum(m)
end

@test_opt matrixsum2_cc(Cint(3), pointer(diagmat.diag), MatType.diag)
# Test Passed
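
The dense variant can be exercised the same way (my addition, reusing densemat from the earlier JET session):

@test_opt matrixsum2_cc(Cint(3), pointer(densemat), MatType.dense)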

EDIT: Wow, LightSumTypes.jl is only about 100 LOC in total! I think it’s definitely a great choice here.

Nice! Yes, indeed I didn’t particularly try to write the cleanest source code possible; we could definitely try to get it down to 50 LOC. In a sense, the approach is very simple, and in my opinion it could ideally be adopted in Base if sensible. The question is whether the concept of a “ClosedUnion” is generally useful to have in Julia itself. Your post makes me think that maybe it is.

Yes, fully agree.

But that’s just not what is usually referred to as the “two-language problem”.
Sure, it might also involve the notion of wrapping code (e.g. C in Python), but the core idea is that prototyping happens in a language that is fast to work in but slow to execute (Python), followed by (at least partly) rewriting in a low-level, compiled, fast-to-execute language (C) for production.

If a LightSumType is better than a Union, why isn’t it the default? What are the disadvantages of LightSumTypes compared to a Union?

I had mixed results comparing performance between the two approaches; usually LightSumTypes is better though, and with Julia 1.10 it is much better. Apart from that, a Union doesn’t require the boilerplate LightSumTypes needs, and a Union is a covariant abstract type while LightSumTypes creates invariant types. I think these characteristics make a Union usually easier to work with.
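
To make that last point concrete, here is a small sketch (my illustration, not from the thread; it assumes the MyMatrix2 and variant definitions from the post above are in scope): with a Union alias, existing values already are instances of the union, whereas the @sumtype version requires explicit wrapping and unwrapping.

using LinearAlgebra

# A Union alias: any Matrix{T} or Diagonal{T, Vector{T}} is already an instance,
# no constructor call or unwrapping needed.
const MyMatrixUnion{T} = Union{Matrix{T}, Diagonal{T, Vector{T}}}
rand(3, 3) isa MyMatrixUnion{Float64}        # true

# The @sumtype version wraps on the way in and needs `variant` on the way out.
m = MyMatrix2{Float64}(rand(3, 3))
variant(m) isa Matrix{Float64}               # true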
