[Guide] Using sum types to define dynamic C APIs for Julia 1.12 `--trim`

I dug deeper, and LightSumTypes.jl does indeed fail in the general case; sorry for having been wrong. For example, I think inference fails completely here:

julia> using LightSumTypes

julia> @sumtype X(Bool, Int, Vector{Bool}, Vector{Int})

julia> xs = [X([1,2]), X(true)]

julia> Base.sum(x::X) = sum(variant(x));

julia> @code_warntype sum.(xs)
MethodInstance for (::var"##dotfunction#230#3")(::Vector{X})
  from (::var"##dotfunction#230#3")(x1) @ Main none:0
Arguments
  #self#::Core.Const(var"##dotfunction#230#3"())
  x1::Vector{X}
Body::AbstractVector
1 ─ %1 = Base.broadcasted(Main.sum, x1)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(sum), Tuple{Vector{X}}}
│   %2 = Base.materialize(%1)::AbstractVector
└──      return %2

LightSumTypes.jl works efficiently when dealing with fields of structs, but I never actually tried to work something out for this use case, and I’m afraid that it currently doesn’t work as expected.

Though, I think we can easily patch this by adding something like the following to the source code:

function apply(f::Function, sumt)
   v = $LightSumTypes.unwrap(sumt)
   $(branchs(variants, :(return f(v))))
end
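To make the idea concrete without the macro machinery, here is a hand-written sketch of the kind of branching such an apply would expand to. Wrapped and apply_sketch are hypothetical stand-ins I made up for illustration, not LightSumTypes.jl API; the struct just stores the variants in a Union-typed field.

```julia
# Hypothetical stand-in for a sum type: a struct with a small Union field.
struct Wrapped
    value::Union{Bool, Int, Vector{Bool}, Vector{Int}}
end

# Explicit isa-branches: inside each branch `v` has a concrete type,
# so each call `f(v)` can be resolved statically.
function apply_sketch(f::F, w::Wrapped) where {F}
    v = w.value
    if v isa Bool
        return f(v)
    elseif v isa Int
        return f(v)
    elseif v isa Vector{Bool}
        return f(v)
    else
        # Only remaining variant; the assertion documents the assumption.
        return f(v::Vector{Int})
    end
end

ws = [Wrapped([1, 2]), Wrapped(true)]
totals = apply_sketch.(sum, ws)
```

The branches look redundant, but that is the point: the compiler sees one concrete type per branch instead of a four-way Union at a single call site.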

With the apply patch, @code_warntype returns:

julia> Base.sum(x::X) = apply(sum, x);

julia> @code_warntype apply.(sum, xs)
MethodInstance for (::var"##dotfunction#230#1")(::typeof(sum), ::Vector{X})
  from (::var"##dotfunction#230#1")(x1, x2) @ Main none:0
Arguments
  #self#::Core.Const(var"##dotfunction#230#1"())
  x1::Core.Const(sum)
  x2::Vector{X}
Body::Vector{Int64}
1 ─ %1 = Base.broadcasted(Main.apply, x1, x2)::Base.Broadcast.Broadcasted{Base.Broadcast.DefaultArrayStyle{1}, Nothing, typeof(apply), Tuple{Base.RefValue{typeof(sum)}, Vector{X}}}
│   %2 = Base.materialize(%1)::Vector{Int64}
└──      return %2

I don’t have much time at the moment, but if someone wants to help by implementing a more structured version of this idea, I would be glad to include something like this :slight_smile:

This will be harder, though, for functions that take multiple inputs, some of which are not sum types. We would basically have to pattern match on the arguments automatically. Maybe that is possible with generated functions?
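As a rough sketch of the generated-function idea, assuming the variants live in a Union-typed field: the generator reads the variants from the field type and emits one isa-branch per variant, forwarding any extra non-sum-type arguments unchanged. WrappedNum and apply_args are hypothetical names, Base.uniontypes is an unexported Base helper, and I have not checked how this interacts with `--trim`.

```julia
# Hypothetical wrapper with a Union-typed field.
struct WrappedNum
    value::Union{Int, Float64, Vector{Int}}
end

# Inside the generator, `w` is bound to the *type* WrappedNum, so the
# variants can be read off the field type when the method is generated.
@generated function apply_args(f, w::WrappedNum, args...)
    variants = Base.uniontypes(fieldtype(w, :value))
    # Build a nested ternary chain, innermost fallback first.
    body = :(error("unexpected variant"))
    for T in reverse(variants)
        body = :(v isa $T ? f(v, args...) : $body)
    end
    return quote
        v = w.value
        $body
    end
end

six = apply_args(*, WrappedNum(2), 3)          # extra argument forwarded as-is
tripled = apply_args(*, WrappedNum([1, 2]), 3)
```

This only handles a single wrapped argument in a fixed position; several wrapped arguments would need nested branching over every combination of variants.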


It also occurred to me yesterday that the code

sum(m::MyMatrix2{T}) where {T} = sum(variant(m))

is not type grounded. If we rewrite this function as

function sum(m::MyMatrix2{T}) where {T}
  mvar = variant(m)
  sum(mvar)
end

notably, we cannot infer the type of mvar just from the type MyMatrix2{T}, which by definition means this is not type grounded. JET still seems fine with this, though, so I wonder whether trimming ends up working or not. I’ll make some MWEs to test trimming.
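Here is a small way to check the "not type grounded" part, using a hypothetical wrapper (MyMat and unwrap_sketch are stand-ins I made up, not the actual MyMatrix2 or variant). It asks inference what it knows about the unwrapped value given only the argument type.

```julia
# Hypothetical wrapper: the variant is stored in a Union-typed field.
struct MyMat
    value::Union{Matrix{Float64}, Vector{Float64}}
end

# Stand-in for `variant(m)`.
unwrap_sketch(m::MyMat) = m.value

# What inference can say about the result from the argument type alone.
inferred = only(Base.return_types(unwrap_sketch, (MyMat,)))

# A Union is not a concrete type, so the intermediate variable is not
# concretely inferred even if the enclosing function's result is.
grounded = isconcretetype(inferred)
```

Here inferred is the two-member Union and grounded is false, which matches the definition above: the function's output may well be inferable while an intermediate variable is not.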

I can confirm that both versions with Moshi and LightSumTypes are trimmable without any verifier errors or problems at runtime. Check out my example here which includes usage from C run on CI (check the make step here).

Now, I’m not sure how this relates to my post just above. I suppose that for these calls the compiler somehow manages to produce type-grounded code through lowering or other compiler tricks…


My guess is that sum(variant(m)) is optimized with Union-splitting, which branches to statically dispatched calls for very small instabilities (one input with 3 types by default, at least 2 types due to nothing semantics). I alluded to it earlier, but try running JET on the following and trimming it to see whether it Union-splits as well:

using LinearAlgebra  # provides Diagonal

Base.@ccallable function matrixsum_cc(sz::Cint, ptr::Ptr{Cdouble}, matflag::Cint)::Cdouble
    m = build_matrix(sz, ptr, matflag)
    return sum(m)
end

function build_matrix(sz::Integer, ptr::Ptr{Cdouble}, matflag::Cint)
    if matflag == 1
        # Unsafe pointer logic to create the Matrix
        m_dense = unsafe_wrap(Matrix{Cdouble}, ptr, (sz, sz))
        return m_dense
    elseif matflag == 2
        m_diag = Diagonal(unsafe_wrap(Vector{Cdouble}, ptr, sz))
        return m_diag
    else
        error("Matrix (1) or Diagonal (2) only.")
    end
end
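For anyone who wants to poke at the inference side without pointers or a C caller, here is a pure-Julia analogue of the snippet above. build_sketch and sum_sketch are hypothetical names, and copying out of a Vector is just a stand-in for unsafe_wrap; I am not stating what the calls report, since that is exactly what is worth checking on your Julia version.

```julia
using LinearAlgebra: Diagonal

# Returns one of two concrete types depending on a runtime flag,
# mirroring build_matrix but without unsafe pointer logic.
function build_sketch(data::Vector{Float64}, n::Int, flag::Int)
    if flag == 1
        return reshape(data[1:n*n], n, n)   # Matrix{Float64}
    elseif flag == 2
        return Diagonal(data[1:n])          # Diagonal{Float64, Vector{Float64}}
    else
        error("Matrix (1) or Diagonal (2) only.")
    end
end

sum_sketch(data, n, flag) = sum(build_sketch(data, n, flag))

# Inspect what inference reports for the builder and for the caller.
builder_types = Base.return_types(build_sketch, (Vector{Float64}, Int, Int))
caller_types = Base.return_types(sum_sketch, (Vector{Float64}, Int, Int))

dense_total = sum_sketch([1.0, 2.0, 3.0, 4.0], 2, 1)
diag_total = sum_sketch([1.0, 2.0, 3.0, 4.0], 2, 2)
```

If the builder reports a two-member Union while the caller still reports a concrete return type, that is consistent with the Union-splitting explanation.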

That’s exactly what I see in my experiments. With a small number of types the compiler is able to use Union-splitting, though as I showed above it fails with more types. So we need something like the apply trick I mentioned, which would be type-grounded as far as I can tell. Based on this discussion, I opened an issue to track this feature in LightSumTypes.jl: "Function to perform branching automatically to solve inference failures" (Issue #120, JuliaDynamics/LightSumTypes.jl on GitHub).

Here is an attempt at the apply function, and it seems to work perfectly: "Add apply function for type stable calls of functions" by Tortar (Pull Request #121, JuliaDynamics/LightSumTypes.jl on GitHub).

I’m actually thinking that LightSumTypes.jl is too opinionated in some of its implementation. As soon as I find the time, I will work on its successor, with the most minimalistic interface I can think of, in this repo (empty for now): Tortar/WrappedUnions.jl (Wrap a Union and Enjoy Type-Stability) on GitHub :slight_smile:


Made it! WrappedUnions.jl. If you have any suggestions, please open an issue (maybe I’m cluttering this thread too much); I would be glad to discuss! By the way, it’s 70 LOC :smiley:
