I have recently spent a lot of time trying to get one of my libraries ready for the Julia 1.12 `--trim` feature, so that I can ship it as a C library and wrap it in a Python module. During this work I learned a lot, and I would like to share a pattern I found that helps design `@ccallable` interfaces with dynamic types, for example a matrix type that might be either a `Matrix` or a `Diagonal`. This might also interest folks who read my notes on trimming (Creating fully self-contained and pre-compiled library - #17 by RomeoV).
To illustrate, we’ll write a C-callable function `matrixsum_cc` that computes the sum of a square matrix. The matrix is passed via a raw pointer `Ptr{Cdouble}` and a size parameter, but crucially it may have additional structure, given through a `Cint` enum value. In our example, it can be a dense matrix or a diagonal matrix. We want to leverage this information by converting it into a `Matrix` or a `Diagonal`, respectively, so that we can load it correctly and use Julia’s dispatch for the computation.
In short, we ultimately want to write a function

```julia
@ccallable function matrixsum_cc(sz::Cint, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::Cdouble
    m = build_matrix(sz, ptr, mattype)
    return sum(m)
end
```

that is trimmable, i.e., that we can pass to `juliac --trim` to get a C library or executable.
The issue is that this function is by default not “type grounded”, i.e., not all variable types within the function can be determined from the input types. In particular, `m` will be either a `Matrix` or a `Diagonal`, depending on the value of `mattype` (which is just a `Cint`). We therefore need a way to go from a formulation where the type information is stored in a runtime value to one that is fully “type grounded” and compatible with multiple dispatch. We can do this by leveraging “closed” sum types.
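To make the problem concrete, here is a small stdlib-only sketch. `build_matrix_naive` is a hypothetical stand-in for `build_matrix` (it takes a Julia `Vector` instead of a raw pointer and compares `mattype` against plain integers). Asking inference for its return type yields a `Union` of two types rather than one concrete type, which is what “not type grounded” means here.

```julia
using LinearAlgebra

# Hypothetical "naive" builder: which branch runs depends on the runtime
# value of `mattype`, so inference can only prove a Union return type.
function build_matrix_naive(sz::Int, data::Vector{Float64}, mattype::Cint)
    if mattype == 1
        return reshape(data, sz, sz)   # Matrix{Float64}, sharing memory with `data`
    else
        return Diagonal(data)          # Diagonal{Float64, Vector{Float64}}
    end
end

# Ask inference what it can prove about the result for these argument types.
naive_rt = only(Base.return_types(build_matrix_naive, (Int, Vector{Float64}, Cint)))
println(naive_rt)
println(isconcretetype(naive_rt))  # false
```

Everything downstream of such a call sees `m` as one of several possible types, which is what the sum type below is meant to avoid.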
First, we need a data type for our `build_matrix` function to return. Since it can be either a dense or a diagonal matrix, we define a “closed” sum type, i.e., a sum type that enumerates all possibilities. We will use Moshi.jl, although other libraries such as SumTypes.jl should also work.
```julia
import Pkg
Pkg.activate(; temp=true)
Pkg.add(["Moshi", "CEnum", "JET"])

using LinearAlgebra
import Moshi.Data: @data

@data MyMatrix{T<:Number} begin
    DenseMat(Matrix{T})
    DiagMat(Diagonal{T, Vector{T}})  # <- Make sure this is a concrete type. `Diagonal{T}` wouldn't be enough.
end
```
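The comment on `DiagMat` is easy to check with nothing but LinearAlgebra: `Diagonal{T}` still leaves the storage type open, while `Diagonal{T, Vector{T}}` is concrete, and it is also exactly the type that `Diagonal(v)` produces for a `Vector`. A quick sketch:

```julia
using LinearAlgebra

# `Diagonal{Float64}` is a UnionAll: the storage type V is still a free parameter.
open_storage = isconcretetype(Diagonal{Float64})
# Pinning the storage type gives a concrete type that a sum-type variant can hold.
pinned_storage = isconcretetype(Diagonal{Float64, Vector{Float64}})
# Wrapping a Vector{Float64} produces exactly that concrete type.
matches_constructor = typeof(Diagonal(rand(3))) === Diagonal{Float64, Vector{Float64}}

println(open_storage)         # false
println(pinned_storage)       # true
println(matches_constructor)  # true
```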
For convenience, we also define a `CEnum` that allows us to pass the matrix type from C as a `Cint` but refer to it in Julia as an enum. We wrap it in a module just to keep it in its own namespace.
```julia
module MatType
using CEnum
@cenum Enum::Cint begin
    dense = 1
    diag = 2
end
end
```
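To see why a `Cint`-backed enum can cross the C boundary, here is a sketch using Base’s `@enum` as a stand-in for CEnum (the module `MatTypeBase` and type name `Kind` are hypothetical). It shows the same 4-byte layout as a C `int` and the integer-to-member conversion. One difference worth knowing: Base enums throw on unlisted values, whereas, as far as I know, CEnum.jl accepts arbitrary integers, which is one reason an explicit `else` branch like the one in `build_matrix` is useful.

```julia
# Stand-in using Base's @enum (not CEnum) with a Cint base type.
module MatTypeBase
@enum Kind::Cint begin
    dense = 1
    diag = 2
end
end

same_layout = sizeof(MatTypeBase.Kind) == sizeof(Cint)        # 4 bytes, like a C int
from_int = MatTypeBase.Kind(Cint(2)) === MatTypeBase.diag      # integer -> enum member
back_to_int = Cint(MatTypeBase.dense)                          # enum member -> integer

# Base enums reject values that are not listed.
rejects_unlisted = try
    MatTypeBase.Kind(Cint(3))
    false
catch err
    err isa ArgumentError
end

println(same_layout)       # true
println(from_int)          # true
println(back_to_int)       # 1
println(rejects_unlisted)  # true
```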
Now we are ready to write our `build_matrix` function. It takes the `MatType` enum, builds the Julia matrix, and returns the matrix wrapped in the sum type so that the function is type stable. Notice that all internal variables have a clearly inferable type, which additionally makes the function type grounded, a stronger property than type stability.
```julia
function build_matrix(sz::Integer, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::MyMatrix.Type{Cdouble}
    if mattype == MatType.dense
        # Wrap the caller's buffer (sz*sz values, column-major) as a Matrix without copying
        m_dense = unsafe_wrap(Matrix{Cdouble}, ptr, (sz, sz))
        return MyMatrix.DenseMat(m_dense)  # Wrap it in the sum type variant
    elseif mattype == MatType.diag
        # Only the sz diagonal entries are read from the buffer
        m_diag = Diagonal(unsafe_wrap(Vector{Cdouble}, ptr, sz))
        return MyMatrix.DiagMat(m_diag)  # Wrap it
    else
        error("Unmatched MatType")
    end
end
```
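Two details of `unsafe_wrap` are worth keeping in mind here: the wrapped arrays alias the caller’s memory (nothing is copied), and a `Matrix` is read in column-major order. The sketch below uses a Julia-owned buffer as a stand-in for memory handed over by C, kept alive with `GC.@preserve`, and calls `unsafe_wrap` the same way `build_matrix` does (with `Int` dimensions for simplicity).

```julia
using LinearAlgebra

# A Julia-owned buffer standing in for memory handed over by a C caller.
buf = Cdouble[1, 2, 3, 4]

top_right, diag_sum, aliased = GC.@preserve buf begin
    ptr = pointer(buf)
    M = unsafe_wrap(Matrix{Cdouble}, ptr, (2, 2))        # column-major: [1 3; 2 4]
    D = Diagonal(unsafe_wrap(Vector{Cdouble}, ptr, 2))   # reads only buf[1:2]
    buf[4] = 40.0                                        # mutate the shared memory
    (M[1, 2], sum(D), M[2, 2] == 40.0)
end

println(top_right)  # 3.0
println(diag_sum)   # 3.0
println(aliased)    # true
```

If the C side stores its matrix row-major, the wrapped `Matrix` is therefore its transpose. That is harmless for `sum`, but it matters for a kernel like `mat \ b`.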
Now comes the magic. For each variant of our sum type, we can call our kernel (here the `sum` function) with a concrete type! Crucially, to maintain type stability, our kernel needs to return the same type for each variant (here `T=Float64`). Instead of `sum`, this could also be something like `mat \ b` or a whole solver or simulation process.
```julia
import Base: sum
import Moshi.Match: @match

function sum(m::MyMatrix.Type{T})::T where {T}
    @match m begin
        MyMatrix.DenseMat(mat) => sum(mat)
        MyMatrix.DiagMat(mat) => sum(mat)
    end
end
```
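The property this step relies on is that every branch calls the kernel on a concrete type and all branches agree on the result type. As a rough illustration without Moshi, here is a hypothetical stand-in: a concrete wrapper struct `UnionMatrix` whose field is a small `Union` over the closed set of matrix types, so Julia can split on the two cases. This is a different encoding from Moshi’s generated type, and I have not checked it against `juliac --trim`; it only demonstrates the inference result.

```julia
using LinearAlgebra

# Hypothetical stand-in for MyMatrix: a concrete wrapper over a closed Union.
struct UnionMatrix{T}
    data::Union{Matrix{T}, Diagonal{T, Vector{T}}}
end

# No return-type annotation, so inference has to prove the result type itself.
kernel_sum(m::UnionMatrix) = sum(m.data)

wrapper_is_concrete = isconcretetype(UnionMatrix{Float64})
rt = only(Base.return_types(kernel_sum, (UnionMatrix{Float64},)))
diag_result = kernel_sum(UnionMatrix{Float64}(Diagonal([1.0, 2.0])))

println(wrapper_is_concrete)  # true
println(rt)                   # Float64
println(diag_result)          # 3.0
```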
Finally, we can write our C interface, to which we pass a raw pointer and the `MatType` enum value (an integer), and get our result in a type-stable way. Notice how `m` now has a concrete type: `MyMatrix.Type{Float64}`! This makes the entire code type grounded and ready for trimming.
```julia
import Base: @ccallable

@ccallable function matrixsum_cc(sz::Cint, ptr::Ptr{Cdouble}, mattype::MatType.Enum)::Cdouble
    m = build_matrix(sz, ptr, mattype)
    return sum(m)
end
```
To validate type stability, we can check with `@code_warntype`, or test directly with JET.jl.
```julia
julia> using JET

julia> densemat = rand(3,3);

julia> @test_opt matrixsum_cc(Cint(3), pointer(densemat), MatType.dense)
Test Passed

julia> diagmat = Diagonal(rand(3));

julia> @test_opt matrixsum_cc(Cint(3), pointer(diagmat.diag), MatType.diag)
Test Passed
```
That’s it!
PS: A copy-pastable version of all the code is here.