Are there idioms in Julia for fast Algebraic Data Types (ADT)?

Coming from functional programming, one of the features I have missed the most in Julia is Algebraic Data Types. Packages such as MLStyle do a pretty good job of adding them to the language, but in this post I am more interested in how to emulate them using idiomatic Julia code.

Some Julia users have expressed skepticism about adding ADTs to the language on the grounds that, in many use cases, pattern matching can be emulated using multiple dispatch. For example, we can define a type for simple expressions as follows:

abstract type Expression end

struct Const <: Expression
    value :: Int
end

struct Add <: Expression
    lhs :: Expression
    rhs :: Expression
end

evaluate(e::Const) = e.value
evaluate(e::Add) = evaluate(e.lhs) + evaluate(e.rhs)

evaluate(Add(Const(2), Const(3))) # evaluates to `5`

One difference between this code and traditional ADTs is that Expression here is not closed: anyone can add a new subtype at any moment. This prevents static exhaustiveness checks, but that is not what worries me here. What I am worried about is performance. Indeed:

  • The fields of Add have abstract types, which is normally a big performance red flag (a quick check is sketched just after this list)
  • Even in the absence of a recursive definition, users will want to manipulate objects such as vectors of expressions (Vector{Expression}), and the same problem arises there.
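
To make the first point concrete, here is a quick check of my own, using the definitions above:

isconcretetype(fieldtype(Add, :lhs))  # false: the field is typed with the abstract Expression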

A tentative fix would be to do something like this:

abstract type Expression end

const AnyExpression = Union{Const, Add}

struct Const <: Expression
    value :: Int
end

struct Add <: Expression
    lhs :: AnyExpression
    rhs :: AnyExpression
end

Unfortunately, this code does not compile because AnyExpression and Add are mutually recursive. Also, assuming it compiled, I am not sure how much I can trust Julia’s implementation to always do the smart thing with those union types, especially when they get bigger (imagine an expression definition with 10 cases).
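
One partial workaround I can think of (a minimal sketch, in a fresh session) is to declare the union after its member structs and use it only for storage; this sidesteps the ordering problem but does not make the fields of Add concrete:

abstract type Expression end

struct Const <: Expression
    value :: Int
end

struct Add <: Expression
    lhs :: Expression   # fields stay abstract, so the recursive definition is legal
    rhs :: Expression
end

# Declared after its members, so this compiles; it is used for storage only.
const AnyExpression = Union{Const, Add}

evaluate(e::Const) = e.value
evaluate(e::Add) = evaluate(e.lhs) + evaluate(e.rhs)

# A Vector{AnyExpression} gives the compiler a chance to union-split the
# calls to `evaluate` instead of doing full dynamic dispatch.
exprs = AnyExpression[Const(1), Add(Const(2), Const(3))]
sum(evaluate, exprs)  # 6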

So, here are my questions:

  • Are you aware of any idiom to define fast ADTs in Julia?
  • How much of a performance penalty would you expect when going with the naive solution I sketched in the first listing?
  • In the case of non-recursive ADTs, would the trick of defining a union type to gather all cases always lead to efficient code (when working with Vector{AnyExpression} for example)?
  • Do packages such as MLStyle address the performance issues I am pointing out? (This may be a good question for @thautwarm)

Edit: I found a way to encode closed recursive ADTs in Julia and benchmarked this solution against a naive solution. Unfortunately, it performs worse (4.7 μs vs 3.3 μs for the naive solution), probably due to having to do more allocations.

Edit 2: I also benchmarked Julia against OCaml on manipulating ADTs and Julia is only 2x slower. This makes me feel better about encoding ADTs in Julia.

9 Likes

I don’t know much about ADTs, but if I understand correctly, your specific problem here can be easily solved using parametric types.

abstract type Expression end
struct Const{T} <: Expression
    value :: T
end
struct Add{L, R} <: Expression
    lhs::L
    rhs::R
end

evaluate(e::Const) = e.value
evaluate(e::Add) = evaluate(e.lhs) + evaluate(e.rhs)
julia> @code_warntype evaluate(Add(Const(2), Const(3)))
Variables
  #self#::Core.Compiler.Const(evaluate, false)
  e::Add{Const{Int64},Const{Int64}}

Body::Int64
1 ─ %1 = Base.getproperty(e, :lhs)::Const{Int64}
│   %2 = Main.evaluate(%1)::Int64
│   %3 = Base.getproperty(e, :rhs)::Const{Int64}
│   %4 = Main.evaluate(%3)::Int64
│   %5 = (%2 + %4)::Int64
└──      return %5

2 Likes

@Mason Unfortunately, this solution does not solve the problem of working with vectors of expressions. Indeed, you would be forced to manipulate a Vector{Expression} and the same problems would arise again.

If you want to efficiently handle containers of heterogeneous types, you need to use Tuples instead.

julia> typeof((Add(Const(3), Const(4.0)), Const(3)))
Tuple{Add{Const{Int64},Const{Float64}},Const{Int64}}

julia> (Add(Const(3), Const(4.0)), Const(3)) isa NTuple{2, Expression}
true

Vector just doesn’t seem like an appropriate data structure for this kind of thing anyway, since it’s heap allocated and mutable with no static size.
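
As a small sketch of why that helps (using the parametric Const/Add and evaluate definitions from above): reductions over a tuple can stay fully inferred, because the tuple type records each element's concrete type.

exprs = (Add(Const(2), Const(3)), Const(40))
mapreduce(evaluate, +, exprs)  # 45; for small tuples the compiler can typically unroll and infer this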

1 Like

This is an interesting attempt, but this solution lacks flexibility. You cannot expect the type of every container of expressions you will have to manipulate to be statically inferrable in general.

And I don’t understand why Vector would not be an appropriate container here. In many languages, a list of expressions would just be stored as an array of (tag, pointer) pairs. Sure, it involves indirections and allocations but you can’t really do any better unless you know the types of every expression you want to store statically.

I guess it would help if you explained what you were actually trying to do.

Oh, if you’re happy with that, just use Vector, that’s exactly what it will do. I only suggested Tuple because I thought you seemed concerned about the overhead and allocations associated with an array of an abstract type.

For reference, Julia’s own expression type Expr stores a Symbol head and a Vector{Any} of arguments.

julia> dump(:(1 + 2))
Expr
  head: Symbol call
  args: Array{Any}((3,))
    1: Symbol +
    2: Int64 1
    3: Int64 2

This is fine, it just means that the output type of indexing into args isn’t statically inferrable.

4 Likes

The main thing that makes this painful in Julia is that performance tends to be quite dependent on type inference, because our dynamic dispatches are so costly (everything is a generic function, so method tables can be huge). However, clever use of things like function barriers and @nospecialize can help a lot.
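
For illustration, here is a minimal, generic sketch of the function-barrier idea (the names process_all and kernel are hypothetical, not from any package): the loop itself is type-unstable, but each element pays only one dynamic dispatch before reaching a fully specialized method.

kernel(x::Int) = float(x)      # one method per concrete type, each body fully type-stable
kernel(x::Float64) = x

function process_all(xs::Vector{Any})
    total = 0.0
    for x in xs
        total += kernel(x)     # one dynamic dispatch here, then fast specialized code
    end
    return total
end

process_all(Any[1, 2.5, 3])    # 6.5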

1 Like

Concretely, I want the following function to be as fast as possible:

function evaluate_sum(es::Vector{Expression})
  s = 0
  for e in es
    s += evaluate(e)
  end
  return s
end

In a functional language where Expression is a closed ADT, the “dispatch” that happens at every call to evaluate costs almost nothing (just making a switch on an integer tag). I am worried that this may not be true in Julia. I guess I should run concrete benchmarks to see how problematic it is in practice.
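
For concreteness, here is a rough sketch of my own (the names are hypothetical) of what such a closed representation amounts to: a single struct with an integer tag, where the "dispatch" is just a branch on that tag instead of a method-table lookup.

struct TaggedExpr
    tag :: UInt8                        # 0x01 = constant, 0x02 = addition
    value :: Int                        # payload used when tag == 0x01
    lhs :: Union{TaggedExpr, Nothing}   # payloads used when tag == 0x02
    rhs :: Union{TaggedExpr, Nothing}
end

tconst(v) = TaggedExpr(0x01, v, nothing, nothing)
tadd(l, r) = TaggedExpr(0x02, 0, l, r)

function teval(e::TaggedExpr)
    if e.tag == 0x01          # the whole "dispatch" is this integer comparison
        return e.value
    else
        return teval(e.lhs) + teval(e.rhs)
    end
end

teval(tadd(tconst(2), tconst(3)))  # 5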

2 Likes

I am worried about the cost of dispatch, exactly. Any idea on how to make it as small as possible in my evaluate_sum example?

How fast are you looking for this to be? It’s already quite fast:

abstract type Expression end
struct Const{T} <: Expression
    value :: T
end
struct Add{L, R} <: Expression
    lhs::L
    rhs::R
end


evaluate(e::Const) = e.value
evaluate(e::Add) = evaluate(e.lhs) + evaluate(e.rhs)

function evaluate_sum(exprs::Vector{Expression})
    s = 0
    for e in exprs
        s += evaluate(e)
    end
    s
end
julia> es = [Const(1)
             Add(Const(2), Const(3))
             Add(Add(Const(-1), Const(3)), Const(4))
             Const(40)]
4-element Array{Expression,1}:
 Const{Int64}(1)
 Add{Const{Int64},Const{Int64}}(Const{Int64}(2), Const{Int64}(3))
 Add{Add{Const{Int64},Const{Int64}},Const{Int64}}(Add{Const{Int64},Const{Int64}}(Const{Int64}(-1), Const{Int64}(3)), Const{Int64}(4))
 Const{Int64}(40)

julia> @btime evaluate_sum(es)
  127.472 ns (0 allocations: 0 bytes)
52

I’m not really sure what to compare it to though.

1 Like

Turns out it’s much better to just go with your original idea and have abstract storage:

abstract type Expression end
struct Const <: Expression
    value :: Int
end
struct Add <: Expression
    lhs :: Expression
    rhs :: Expression
end


evaluate(e::Const) = e.value
evaluate(e::Add)   = evaluate(e.lhs) + evaluate(e.rhs)

function evaluate_sum(exprs::Vector{Expression})
    s = 0
    for e in exprs
        x = let e = e
            evaluate(e)
        end
        s += x
    end
    s
end

es = [Const(1)
      Add(Const(2), Const(3))
      Add(Add(Const(-1), Const(3)), Const(4))
      Const(40)]
julia> @btime evaluate_sum(es);
15.159 ns (0 allocations: 0 bytes)

(note you’ll need to restart Julia to run this due to type redefinitions)

2 Likes

I ran the following benchmark to compare the cost of doing dispatch on closed unions with the cost of doing dispatch on abstract types:

using BenchmarkTools

abstract type SignedInteger end

struct Pos <: SignedInteger
    abs :: UInt64
end

struct Neg <: SignedInteger
    abs :: UInt64
end

const AnySignedInteger = Union{Pos, Neg}

const posvec = [Pos(i) for i in 1:100]
const negvec = [Neg(i) for i in 1:100]

value(x::Pos) = Int64(x.abs)
value(x::Neg) = -Int64(x.abs)

function test_open()
    return sum(value(x) for x in SignedInteger[posvec; negvec])
end

function test_closed()
    return sum(value(x) for x in AnySignedInteger[posvec; negvec])
end

println("Testing open version")
@btime test_open()
println("Testing closed version")
@btime test_closed()

The result:

Testing open version
  1.792 μs (202 allocations: 4.92 KiB)
Testing closed version
  697.020 ns (2 allocations: 2.02 KiB)

Conclusion: it is about 2.6x faster to do dispatch on a closed union type. I actually expected more of a difference, which makes me think that my naive solution would actually not be prohibitively slow compared to something smarter.

Interesting. This means that abstract storage is not as slow as I would have thought.
Do you have any idea why this is faster than the version where you add type parameters to Add and Const?

No, I’m actually a bit perplexed, as this code:

abstract type Expression end
struct Const <: Expression
    value :: Int
end
struct Add{L, R} <: Expression
    lhs :: L
    rhs :: R
end

const AddAbstract = Add{Expression, Expression}


evaluate(e::Const) = e.value
evaluate(e::Add)   = evaluate(e.lhs) + evaluate(e.rhs)

function evaluate_sum(exprs::Vector{Expression})
    s = 0
    for e in exprs
        s += evaluate(e)
    end
    s
end

es = [Const(1)
      AddAbstract(Const(2), Const(3))
      AddAbstract(AddAbstract(Const(-1), Const(3)), Const(4))
      Const(40)]
julia> @btime evaluate_sum($es);
49.006 ns (0 allocations: 0 bytes)

produces @code_warntype output identical to the other version, but is slower. This suggests that perhaps the problem is on the LLVM side.

Update: I found a way to encode closed recursive ADTs in Julia and benchmarked this solution against a naive solution. Unfortunately, it performs worse (4.7 μs vs 3.3 μs for the naive solution), probably due to having to do more allocations.

If anyone finds a better way, please tell me!

Naive solution

abstract type Expression end

struct Const <: Expression
    value :: Int
end

struct Var <: Expression
    varname ::String
end

struct Add <: Expression
    lhs :: Expression
    rhs :: Expression
end

evaluate(e::Const, env) = e.value
evaluate(e::Var, env) = env[e.varname]
evaluate(e::Add, env) = evaluate(e.lhs, env) + evaluate(e.rhs, env)

function sum_of_ints(n)
    if n == 1
        return Const(1)
    else
        return Add(Const(n), sum_of_ints(n - 1))
    end
end

using BenchmarkTools
@btime evaluate(sum_of_ints(100), Dict{String, Int}())
3.228 μs (202 allocations: 5.23 KiB)

Solution that encodes recursive closed ADTs

struct Const
    value :: Int
end

struct Var
    varname ::String
end

struct Add{E}
    lhs :: E
    rhs :: E
end

struct Expression
    ctor :: Union{Const, Var, Add{Expression}}
end

mkConst(value) = Expression(Const(value))
mkAdd(lhs, rhs) = Expression(Add{Expression}(lhs, rhs))
mkVar(var) = Expression(Var(var))

evaluate(e::Expression, env) = evaluate(e.ctor, env)
evaluate(e::Const, env) = e.value
evaluate(e::Var, env) = env[e.varname]
evaluate(e::Add, env) = evaluate(e.lhs, env) + evaluate(e.rhs, env)

function sum_of_ints(n)
    if n == 1
        return mkConst(1)
    else
        return mkAdd(mkConst(n), sum_of_ints(n - 1))
    end
end

using BenchmarkTools
@btime evaluate(sum_of_ints(100), Dict{String, Int}())
4.681 μs (396 allocations: 8.25 KiB)

Maybe I’m missing something, but isn’t it nice that the simple solution is faster? At any rate, I don’t think the performance difference between your two solutions is particularly large.

@CameronBieganek What this experiment indicates, in my opinion, is that the compiler misses an opportunity to optimize the second version. In a perfect world, the second version should be faster, as the compiler would leverage the fact that an expression can be nothing other than a Const, a Var, or an Add, and make dynamic dispatch very fast based on this.
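
To illustrate what I mean, here is my own sketch, reusing the naive Const/Var/Add definitions above: for a closed set of cases, an ideal compiler could turn dispatch into a handful of type checks, as in this manual version.

function evaluate_split(e, env)
    if e isa Const                      # dispatch becomes a few `isa` checks
        return e.value
    elseif e isa Var
        return env[e.varname]
    elseif e isa Add
        return evaluate_split(e.lhs, env) + evaluate_split(e.rhs, env)
    else
        error("unknown expression case")
    end
end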

So a more interesting comparison would be to compare the time it takes to evaluate the naive version in Julia with an equivalent program written in a language with native ADTs such as OCaml, Haskell or Rust. I am going to try this now.

So I did the experiment in OCaml, which is about twice as fast as the Julia version (1.61 μs vs 3.23 μs). This is actually not too bad for Julia and this makes me feel better about using ADTs in Julia.

Benchmark Code

(* Assumes the Base and Stdio libraries are available
   (the code uses List.Assoc.find_exn, Caml.Sys.time and Stdio.printf). *)
open Base

type expr =
  | Const of int
  | Var of string
  | Add of expr * expr

let rec evaluate expr env =
  match expr with
  | Const v -> v
  | Var x -> List.Assoc.find_exn env ~equal:String.equal x
  | Add (lhs, rhs) -> evaluate lhs env + evaluate rhs env

let rec sum_of_ints = function
  | 1 -> Const 1
  | n -> Add (Const n, sum_of_ints (n - 1))

let profile n =
  let acc = ref 0 in
  let t = Caml.Sys.time () in
  for i = 1 to n do
    acc := !acc + evaluate (sum_of_ints 100) []
  done;
  let dt = (Caml.Sys.time () -. t) /. (Float.of_int n) in
  Stdio.printf "Average time: %.3f μs" (dt *. 1e6);
  acc

let _ = profile 1000000
Average time: 1.653 μs
3 Likes

This is similar to open types, and static exhaustiveness checking wouldn’t be affected if your analyzer can walk through the whole program.

I’m sorry that MLStyle didn’t address this performance issue.

Actually, I did consider the questions you raised here, but due to the restrictions of Julia I haven’t really found an approach.

1 Like

Also, there is an alternative technique to ADTs, called tagless final.

The ADT approach is called the initial approach in some contexts, and tagless final is accordingly called the final approach.

For your code, we can use tagless final to achieve a type-stable Julia program:

# An "algebra" of interpretations: one field per constructor of the ADT.
struct SYM{F1, F2}
    constant :: F1
    add :: F2
end

# A term is a function that takes an algebra and interprets itself with it.
function constant(v)
    function (sym::SYM)
        sym.constant(v)
    end
end

function add(term1, term2)
    function (sym::SYM)
        sym.add(term1(sym), term2(sym))
    end
end

# self algebra: interpreting a term with it rebuilds the term itself
self = SYM(constant, add)

# an evaluator is just another algebra: constants evaluate to their value,
# additions to the sum of their interpreted sub-terms
evaluate =
    let constant(v::Int) = v,
        add(l::Int, r::Int) = l + r
        SYM(constant, add)
    end


println(add(constant(2), constant(3))(evaluate))
@code_warntype add(constant(2), constant(3))(evaluate)

There are no red marks (type instabilities); try the above code in your Julia shell:

5
Variables
  #self#::var"#17#18"{var"#15#16"{Int64},var"#15#16"{Int64}}
  sym::Core.Compiler.Const(SYM{var"#constant#19",var"#add#20"}(var"#constant#19"(), var"#add#20"()), false)

Body::Int64
1 ─ %1 = Base.getproperty(sym, :add)::Core.Compiler.Const(var"#add#20"(), false)
│   %2 = Core.getfield(#self#, :term1)::var"#15#16"{Int64}
│   %3 = (%2)(sym)::Int64
│   %4 = Core.getfield(#self#, :term2)::var"#15#16"{Int64}
│   %5 = (%4)(sym)::Int64
│   %6 = (%1)(%3, %5)::Int64
└──      return %6
6 Likes