Performance optimizations with runtime dispatch

I’m working on types that efficiently wrap Polish notation and allow for fast evaluation of mathematical expressions. If I run @profview on the following script, I see that most of the time inside fpass! is spent deciding which specific method to call, depending on the node.

I’m looking to understand the dispatch component here; I have read the performance tips section of the Julia manual and tried to apply it.

Here is reproducible code:

## Definitions

mutable struct QNode{T,A}
    op::T

    function QNode(op::T, arity::Int) where {T<:Function}
        if arity == 1
            return new{T,Val{1}}(op)
        elseif arity == 2
            return new{T,Val{2}}(op)
        else
            throw(ArgumentError("arity must be 1 or 2"))
        end
    end
    QNode(op::T) where {T<:Real} = new{Float32,Val{0}}(Float32(op))
    QNode(op::Symbol) = new{Symbol,Val{0}}(op)
end

struct TreePtr
    l::Union{Int64,Nothing}
    r::Union{Int64,Nothing}

    TreePtr() = new(nothing, nothing)
    TreePtr(l) = new(l, nothing)
    TreePtr(l, r) = new(l, r)
end

struct QExpr
    program::Vector{QNode}
    tree_ixs::Vector{TreePtr}

    QExpr(program::Vector{QNode}) = new(
        program,
        parse_tree_ixs(program),
    )
    QExpr(program::Vector{QNode}, tree_ixs::Vector{TreePtr}) = new(
        program,
        tree_ixs,
    )
end
Base.length(qexpr::QExpr) = length(qexpr.program)
Base.size(qexpr::QExpr) = size(qexpr.program)
Base.getindex(qexpr::QExpr, ix) = qexpr.program[ix]

function parse_tree_ixs(program::Vector{QNode})
    result = []
    recent_unused = []
    for ix in reverse(eachindex(program))
        n = program[ix]
        parse_node(n, recent_unused, result)
        pushfirst!(recent_unused, ix)
    end
    result
end

parse_node(node::QNode{T,Val{0}}, recent_unused, result) where {T} = pushfirst!(result, TreePtr())
parse_node(node::QNode{T,Val{1}}, recent_unused, result) where {T} = pushfirst!(result, TreePtr(popfirst!(recent_unused)))
parse_node(node::QNode{T,Val{2}}, recent_unused, result) where {T} = pushfirst!(result, TreePtr(popfirst!(recent_unused), popfirst!(recent_unused)))

function fpass!(expr::QExpr, forwards::Vector{Float32}, row::NamedTuple)
    for node_ix in reverse(eachindex(expr.program))
        fpass_node!(expr[node_ix], node_ix, expr.tree_ixs[node_ix], forwards, row)
    end
end

function fpass_node!(node::QNode{Symbol,Val{0}}, node_ix::Int, tree_ptr::TreePtr, forwards::Vector{Float32}, row::NamedTuple)
    @inbounds begin
        forwards[node_ix] = row[node.op]
    end
end

function fpass_node!(node::QNode{Float32,Val{0}}, node_ix::Int, tree_ptr::TreePtr, forwards::Vector{Float32}, row::NamedTuple)
    @inbounds begin
        forwards[node_ix] = node.op
    end
end

function fpass_node!(node::QNode{<:Function,Val{1}}, node_ix::Int, tree_ptr::TreePtr, forwards::Vector{Float32}, row::NamedTuple)
    @inbounds begin
        forwards[node_ix] = node.op(forwards[tree_ptr.l])
    end
end

function fpass_node!(node::QNode{<:Function,Val{2}}, node_ix::Int, tree_ptr::TreePtr, forwards::Vector{Float32}, row::NamedTuple)
    @inbounds begin
        forwards[node_ix] = node.op(forwards[tree_ptr.l], forwards[tree_ptr.r])
    end
end

## Profiling
qexpr = QExpr([QNode(+, 2), QNode(exp, 1), QNode(:x), QNode(0.1)])
row = (x=3,)

forwards = Vector{Float32}(undef, length(qexpr))

fpass!(qexpr, forwards, row)

println(forwards)
println(exp(3) + 0.1)

@profview for _ in 1:1_000_000
    fpass!(qexpr, forwards, row)
end

I would suggest eschewing runtime dispatch. Instead, use a switch case to call the appropriate function (and make sure there’s no type instability when calling a function for a specific case), or use something like Virtual.jl.
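
For concreteness, here is a minimal sketch of that idea against the code above. FlatNode, the opcode constants, and fpass_flat! are hypothetical names of mine (TreePtr is reused from the original post): each node carries a small integer opcode, and a chain of branches replaces method dispatch, so every call site is type-stable.

const ADD, EXP, CONST, VAR = Int8(1), Int8(2), Int8(3), Int8(4)

struct FlatNode
    opcode::Int8
    value::Float32   # only meaningful for CONST nodes
    name::Symbol     # only meaningful for VAR nodes
end

function fpass_flat!(nodes::Vector{FlatNode}, tree_ixs::Vector{TreePtr},
                     forwards::Vector{Float32}, row::NamedTuple)
    @inbounds for ix in reverse(eachindex(nodes))
        n = nodes[ix]
        ptr = tree_ixs[ix]
        if n.opcode == ADD
            forwards[ix] = forwards[ptr.l] + forwards[ptr.r]
        elseif n.opcode == EXP
            forwards[ix] = exp(forwards[ptr.l])
        elseif n.opcode == CONST
            forwards[ix] = n.value
        else  # VAR
            forwards[ix] = row[n.name]
        end
    end
end

# The profiled expression exp(x) + 0.1 would become:
# nodes = [FlatNode(ADD, 0f0, :_), FlatNode(EXP, 0f0, :_),
#          FlatNode(VAR, 0f0, :x), FlatNode(CONST, 0.1f0, :_)]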

I would recommend taking a look at how DynamicExpressions.jl does things with its Node type.

If your fpass_node! functions were more complicated, you might benefit from the techniques used in Parsing Protobuf at 2+GB/s: How I Learned To Love Tail Calls in C (reverberate.org) and Building the fastest Lua interpreter… automatically! (sillycross.github.io), mainly the material about tail calls, though I don’t know if you could do that nicely in Julia.

Maybe MixedStructTypes.jl could help you? It should help you avoid dynamic dispatch.

Thanks for your input, Miles! Tree structures are nice for mutation rules, but due to their recursive definition they are not as quick to evaluate as Polish-notation expressions.

Not sure about the memory impact, but if your Node type had a pointer to move up and down in a Polish program representation, then an evaluation like this would use fewer CPU cycles:

mutable struct Node{T}
    #... other stuff not included
    prog_up::Union{Node{T}, Nothing}
    prog_down::Union{Node{T}, Nothing}
end

function eval_expr(node::Node, df::DataFrame)
    # walk linearly up the program; eval_node (definition omitted) evaluates
    # a single node against the data
    while node !== nothing
        eval_node(node, df)
        node = node.prog_up
    end
end

The Union with Nothing might not be the approach you’d choose, though.

Fields of non-concrete type destroy performance.

The first thing to do is create a representation of your expressions that stores all, or almost all, data as values, not in the type domain. Executing these should proceed without any run-time dispatch.

Only after this works, and if the performance still isn’t satisfying, should you consider run-time dispatch/compilation. If it turns out that the additional performance made possible by compiled versions of your expressions is necessary, then and only then design another expression representation, this time moving some of the data into the type domain. Again, executing these should proceed without any run-time dispatch; run-time dispatch should only happen while converting the first representation into the second.
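
A compact sketch of that two-stage idea, with hypothetical names (ValChain, Compiled, compile, evaluate) and a chain of unary ops for brevity: compile is the single place where run-time dispatch happens, and the function barrier into evaluate means the hot code is compiled for the concrete type it receives.

# Stage 1: value-domain program; all fields concretely typed.
struct ValChain
    opcodes::Vector{Int8}   # say 1 = exp, 2 = sqrt
end

# Stage 2: the ops moved into the type domain as a concretely-typed tuple.
struct Compiled{Ops<:Tuple}
    ops::Ops
end

# The only dynamic step: the return type depends on run-time data, so the
# compiler cannot infer it. Call it once per expression.
compile(c::ValChain) =
    Compiled(Tuple(op == Int8(1) ? exp : sqrt for op in c.opcodes))

# Function barrier: inside evaluate, Ops is concrete, so the recursion over
# the tuple unrolls at compile time with no run-time dispatch.
apply(x) = x
apply(x, f, fs...) = apply(f(x), fs...)
evaluate(c::Compiled, x) = apply(x, c.ops...)

chain = compile(ValChain(Int8[1, 2]))   # one dynamically-dispatched call…
evaluate(chain, 2.0)                    # …then sqrt(exp(2.0)) runs statically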

Thanks for the tips! I did indeed end up eschewing runtime dispatch completely in favour of using switch cases like Zentrik suggested.

Yes, which is also why I mentioned the Union with Nothing not being an implementer’s first choice. Excuse the bad code example; the point I was making is that, for a given mathematical expression, evaluating it in Polish notation will be faster than evaluating it as a binary tree. In general, this is because you can evaluate it in a single pass and avoid the deeper call stack that recursion brings. A minimal contrast is sketched below.
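
For contrast, a minimal recursive tree evaluator (TLeaf and TBin are hypothetical types, not from any of the code above): every level of the tree adds a call frame, and the Any-typed child fields also force dynamic dispatch at each node, unlike a flat loop over a Polish-notation vector.

struct TLeaf
    v::Float32
end

struct TBin{F}
    f::F
    l::Any   # abstractly-typed children: each teval call dispatches at run time
    r::Any
end

teval(n::TLeaf) = n.v
teval(n::TBin) = n.f(teval(n.l), teval(n.r))   # one stack frame per tree level

tree = TBin(+, TBin(*, TLeaf(2f0), TLeaf(3f0)), TLeaf(1f0))   # 2 * 3 + 1
teval(tree)   # 7.0f0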


I thought that small unions had been greatly optimized, and that unions with Nothing had little penalty.
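
That matches my understanding. One quick way to poke at it (Base.isbitsunion is an internal function, so this is just for inspection, not API):

x = Union{Int64,Nothing}[1, nothing, 3]
Base.isbitsunion(eltype(x))   # true: elements are stored inline with a type tag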

This may not be your performance problem, but untyped [] are a red flag, performance-wise. Can’t you type the container?
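
A sketch of that change applied to parse_tree_ixs from the original post; only the two container constructors differ:

function parse_tree_ixs(program::Vector{QNode})
    result = TreePtr[]       # was: result = []
    recent_unused = Int[]    # was: recent_unused = []
    for ix in reverse(eachindex(program))
        n = program[ix]
        parse_node(n, recent_unused, result)
        pushfirst!(recent_unused, ix)
    end
    result
end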


There should be no run time dispatch, but I think the extra branching may still slow things down considerably.

In the second expression representation I discussed above, the one with some data moved into the type domain, you would presumably move the expression tree structure itself into the type domain in the case of mathematical expressions. At that point the distinction between Polish notation and more general binary trees disappears after compilation.

That’s an interesting thought. I’m not sure I’d know how to get the equation represented as a type.

A while back I also imagined that if you compile the expression once, just like a user-written 2*x + 1 for instance, then that should compile away anything about any data structures and leave just assembly math, so insanely fast. To that end, I looked at packages like SciML/RuntimeGeneratedFunctions.jl (functions generated at runtime without world-age issues or overhead), but for lots of dynamically generated expressions, the compile-time overhead nullified any possible benefits. But if I missed something, I’d be more than happy to learn!
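
For reference, this is the pattern I mean, per the RuntimeGeneratedFunctions.jl README; every dynamically generated expression still pays normal Julia compile time on its first call, which is the overhead I’m referring to:

using RuntimeGeneratedFunctions
RuntimeGeneratedFunctions.init(@__MODULE__)

# build the expression at run time, then turn it into a callable function
ex = :(function f_generated(x)
    2 * x + 1
end)
f = @RuntimeGeneratedFunction(ex)
f(3.0)   # 7.0; the first call triggers compilation of the new method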


I created a small package here:

The README happens to conclude with an example expression whose type is a singleton type, meaning it is completely contained in the type domain. NB: that example depends on the fact that the relevant operations are also of singleton type.

Probably gonna register this soon, after checking off some of the to-dos.


Package announcement:

Very cool package. I’m playing around with it, and the ideas are interesting.

I now think I did not mention some desiderata of the original design, which is what led me to include the forwards in my original example. Given that you can evaluate an expression quickly, the actual problem to solve with that ability is optimizing parameters in the expression using stochastic gradient descent. Polish notation gives me handy indices for referring to the output of every node in the expression, which I will need for the reverse pass when backpropagating the gradients.

In CallableExpressions.jl, my intuition is that, since the expression itself is now neatly packed into the types, the ability to locate and read/edit the outputs and gradient updates is lost. Interested to hear your take on this.

Though I will try shoving one of your expression types through Zygote and see what happens.

EDIT: On my machine, the forward evaluation of the example expression in your README takes around 90ns, and finding the gradient of that using Zygote wrt x and y takes 25µs. An implementation like the one above measures 200ns on the forward evaluation, but only 600ns to calculate the gradient (also using Zygote, but separately on each node).
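
Roughly what "Zygote separately on each node" looks like for my original QNode representation (a hypothetical sketch, not the exact code I measured): since forwards holds every node’s output by index, each node only needs its local partials, scaled by the gradient flowing in from its parent. Here grads is assumed to be seeded with 1f0 at the root.

using Zygote

function bpass_node!(node::QNode{<:Function,Val{2}}, node_ix::Int,
                     tree_ptr::TreePtr, forwards::Vector{Float32},
                     grads::Vector{Float32})
    # local partials of the binary op w.r.t. its two child outputs;
    # nodes are visited root-first (increasing index), so grads[node_ix]
    # is already final here
    dl, dr = Zygote.gradient(node.op, forwards[tree_ptr.l], forwards[tree_ptr.r])
    grads[tree_ptr.l] += grads[node_ix] * dl
    grads[tree_ptr.r] += grads[node_ix] * dr
end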


I studied your code some more and wrote a way to get the Polish notation into the type system as well. The difference is that I can also maintain the output of every node. The overall weakness of these approaches seems to be quickly growing compile times: 16 nodes already take 70ms on my machine just for the compilation, which mostly offsets the crazy speeds you can get later.

The code:

struct ProgEnd end
struct Term{next,value,ix} end
struct Variable{next,name,ix} end
struct Unary{next,op,ix,childix} end
struct Binary{next,op,ix,lchildix,rchildix} end


# binary tree:     exp(x) + y
# polish notation: + exp x y
pend = ProgEnd()
p = Binary{pend,+,1,2,4}()
e = Unary{p,exp,2,3}()
v2 = Variable{e,:y,4}()
v1 = Variable{v2,:x,3}()

function evalexpr!(::Vector{Float32}, ::ProgEnd, subs)
end

function evalexpr!(fwds::Vector{Float32}, ::Binary{next,op,ix,lix,rix}, subs) where {next,op,ix,lix,rix}
    @inbounds fwds[ix] = op(fwds[lix], fwds[rix])
    evalexpr!(fwds, next, subs)
end

function evalexpr!(fwds::Vector{Float32}, ::Unary{next,op,ix,lix}, subs) where {next,op,ix,lix}
    @inbounds fwds[ix] = op(fwds[lix])
    evalexpr!(fwds, next, subs)
end

function evalexpr!(fwds::Vector{Float32}, ::Term{next,value,ix}, subs) where {next,value,ix}
    @inbounds fwds[ix] = value
    evalexpr!(fwds, next, subs)
end

function evalexpr!(fwds::Vector{Float32}, ::Variable{next,name,ix}, subs::NT) where {next,name,ix,NT<:NamedTuple}
    @inbounds fwds[ix] = subs[name]
    evalexpr!(fwds, next, subs)
end

fwds = zeros(Float32, 4)
data = (; x=2, y=3)
evalexpr!(fwds, v1, data)
# on my machine, 3ms to compile, 10ns to evaluate

It’s easy to get AD in there as well, but that adds another 100ms to compilation times! For now, I think this is as good as it gets. I’d love to know more Julia internals; the compiled output for math expressions is so simple that I feel I should be able to jump in at some level lower than Julia code and avoid some of the compilation overhead that way.