Performance optimizations with runtime dispatch

I’m working on types that efficiently wrap Polish (prefix) notation and allow fast evaluation of mathematical expressions. When I run `@profview` on the following script, I see that most of the time inside `fpass!` is spent deciding which specific method to call for each node (i.e. runtime dispatch).

I’m trying to understand the cost of the dispatch here. As far as I can tell, I have read and tried to apply the Performance Tips section of the Julia manual.

Here is a reproducible example:

## Definitions

"""
    QNode(op::Function, arity::Integer)
    QNode(op::Real)
    QNode(op::Symbol)

A single node of an expression stored in Polish (prefix) notation.

The type parameters lift what the node is into the type domain so that methods can
dispatch on it: `T` is the type of the stored `op`, and `A` is `Val{arity}`.

- Function nodes wrap a unary (`arity == 1`) or binary (`arity == 2`) function.
- Constant nodes store their value converted to `Float32` and have arity `Val{0}`.
- Variable nodes store the variable name as a `Symbol` and have arity `Val{0}`.

Throws `ArgumentError` if a function node is given an arity other than 1 or 2.
"""
mutable struct QNode{T,A}
    op::T

    function QNode(op::T, arity::Integer) where {T<:Function}
        arity == 1 && return new{T,Val{1}}(op)
        arity == 2 && return new{T,Val{2}}(op)
        # Fail loudly: only arities 1 and 2 are handled for function nodes, and returning
        # `nothing` here would only surface later as an obscure MethodError.
        throw(ArgumentError("function nodes must have arity 1 or 2, got $arity"))
    end
    QNode(op::T) where {T<:Real} = new{Float32,Val{0}}(Float32(op))
    QNode(op::Symbol) = new{Symbol,Val{0}}(op)
end

"""
    TreePtr()
    TreePtr(l)
    TreePtr(l, r)

Operand pointers of one expression node: `l` and `r` are indices (into the enclosing
program) of the left and right operands, or `nothing` when the node has no such operand.
"""
struct TreePtr
    l::Union{Int64,Nothing}
    r::Union{Int64,Nothing}

    # Default arguments expand to the zero-, one- and two-argument constructors.
    TreePtr(l=nothing, r=nothing) = new(l, r)
end

"""
    QExpr(program::AbstractVector{<:QNode})
    QExpr(program::AbstractVector{<:QNode}, tree_ixs::AbstractVector{TreePtr})

An expression in Polish (prefix) notation together with its operand pointers.

`program[i]` is the `i`-th node and `tree_ixs[i]` holds the indices of its operands. With
one argument the pointers are computed by `parse_tree_ixs`; with two they are taken as
given.

Any vector of nodes is accepted, including a homogeneous one such as
`Vector{QNode{Symbol,Val{0}}}`; it is converted to `Vector{QNode}`. A `Vector{QNode}` is
stored as-is, without copying.

Throws `DimensionMismatch` if `tree_ixs` and `program` have different lengths.
"""
struct QExpr
    program::Vector{QNode}
    tree_ixs::Vector{TreePtr}

    function QExpr(program::AbstractVector{<:QNode})
        # `Vector` is invariant, so a concretely typed node vector is not a `Vector{QNode}`;
        # `convert` bridges that and is a no-op when the type already matches.
        nodes = convert(Vector{QNode}, program)
        return new(nodes, parse_tree_ixs(nodes))
    end

    function QExpr(program::AbstractVector{<:QNode}, tree_ixs::AbstractVector{TreePtr})
        # Every node needs exactly one TreePtr; catch a mismatch at construction time.
        length(tree_ixs) == length(program) || throw(DimensionMismatch(
            "tree_ixs has length $(length(tree_ixs)), but program has length $(length(program))",
        ))
        return new(program, tree_ixs)
    end
end
# Array-like accessors: a QExpr reports the length/size of its node vector, and indexing
# it returns the stored node(s).
function Base.length(qexpr::QExpr)
    nodes = qexpr.program
    return length(nodes)
end

function Base.size(qexpr::QExpr)
    nodes = qexpr.program
    return size(nodes)
end

function Base.getindex(qexpr::QExpr, ix)
    nodes = qexpr.program
    return nodes[ix]
end

"""
    parse_tree_ixs(program::Vector{QNode}) -> Vector{TreePtr}

Compute the operand pointers for every node of a prefix-notation `program`.

The program is scanned from right to left. `recent_unused` is a stack (top at index 1)
of node indices whose values have not yet been consumed as an operand. For each node,
`parse_node` pops as many entries as the node's arity, then the node's own index is
pushed. Entry `i` of the result is the `TreePtr` for `program[i]`.

If the program lacks operands, `popfirst!` on the empty stack throws an error.
"""
function parse_tree_ixs(program::Vector{QNode})
    # Concretely typed buffers: an untyped `[]` is a `Vector{Any}`, which boxes every
    # element and makes the returned vector's element type `Any`.
    result = TreePtr[]
    recent_unused = Int[]
    for ix in reverse(eachindex(program))
        parse_node(program[ix], recent_unused, result)
        pushfirst!(recent_unused, ix)
    end
    return result
end

# `parse_node` prepends the operand pointers for one node to `result`. It takes operand
# indices from the front of `recent_unused` (one per unit of arity, left operand first)
# and returns `result`.

# Leaf (arity 0): no operands.
function parse_node(::QNode{<:Any,Val{0}}, recent_unused, result)
    return pushfirst!(result, TreePtr())
end

# Unary node: consumes one operand.
function parse_node(::QNode{<:Any,Val{1}}, recent_unused, result)
    operand = popfirst!(recent_unused)
    return pushfirst!(result, TreePtr(operand))
end

# Binary node: consumes two operands; the first popped is the left one.
function parse_node(::QNode{<:Any,Val{2}}, recent_unused, result)
    lhs = popfirst!(recent_unused)
    rhs = popfirst!(recent_unused)
    return pushfirst!(result, TreePtr(lhs, rhs))
end

"""
    fpass!(expr::QExpr, forwards::Vector{Float32}, row::NamedTuple)

Evaluate `expr` in place. After the call, `forwards[i]` holds the value computed for
node `i`; for a single well-formed expression, `forwards[1]` is its result. Variable
nodes are looked up by name in `row`.

Nodes are visited from last to first. In prefix notation an operator's operands sit at
higher indices, so they have already been written when the operator runs.

Throws `DimensionMismatch` unless `forwards` has exactly one slot per node.
Returns `nothing`.
"""
function fpass!(expr::QExpr, forwards::Vector{Float32}, row::NamedTuple)
    # `fpass_node!` reads and writes `forwards` under `@inbounds`, so an undersized buffer
    # would corrupt memory instead of raising. Validate once, outside the hot loop.
    nnodes = length(expr.program)
    length(forwards) == nnodes || throw(DimensionMismatch(
        "forwards has length $(length(forwards)), expected $nnodes (one slot per node)",
    ))
    for node_ix in reverse(eachindex(expr.program))
        # `expr.program` has abstract element type `QNode`, so selecting the
        # `fpass_node!` method is a dynamic dispatch, once per node.
        fpass_node!(expr[node_ix], node_ix, expr.tree_ixs[node_ix], forwards, row)
    end
    return nothing
end

# Variable leaf: fetch the value named by `node.op` from `row` and store it at
# `forwards[node_ix]` (`setindex!` converts it to Float32). Returns the fetched value.
# The store skips bounds checks, so `node_ix` must be a valid index of `forwards`.
function fpass_node!(
    node::QNode{Symbol,Val{0}},
    node_ix::Int,
    tree_ptr::TreePtr,
    forwards::Vector{Float32},
    row::NamedTuple,
)
    value = row[node.op]
    @inbounds forwards[node_ix] = value
    return value
end

# Constant leaf: copy the stored Float32 into `forwards[node_ix]` and return it.
# The store skips bounds checks, so `node_ix` must be a valid index of `forwards`.
function fpass_node!(
    node::QNode{Float32,Val{0}},
    node_ix::Int,
    tree_ptr::TreePtr,
    forwards::Vector{Float32},
    row::NamedTuple,
)
    constant = node.op
    @inbounds forwards[node_ix] = constant
    return constant
end

# Unary operator: apply `node.op` to the already-computed operand at
# `forwards[tree_ptr.l]`, store the result at `forwards[node_ix]`, and return it.
# Bounds checks are skipped; both indices must be valid for `forwards`.
function fpass_node!(
    node::QNode{<:Function,Val{1}},
    node_ix::Int,
    tree_ptr::TreePtr,
    forwards::Vector{Float32},
    row::NamedTuple,
)
    @inbounds begin
        operand = forwards[tree_ptr.l]
        forwards[node_ix] = node.op(operand)
    end
end

# Binary operator: apply `node.op` to the already-computed operands at
# `forwards[tree_ptr.l]` and `forwards[tree_ptr.r]`, store the result at
# `forwards[node_ix]`, and return it. Bounds checks are skipped; all three indices must
# be valid for `forwards`.
function fpass_node!(
    node::QNode{<:Function,Val{2}},
    node_ix::Int,
    tree_ptr::TreePtr,
    forwards::Vector{Float32},
    row::NamedTuple,
)
    @inbounds begin
        lhs = forwards[tree_ptr.l]
        rhs = forwards[tree_ptr.r]
        forwards[node_ix] = node.op(lhs, rhs)
    end
end

## Profiling
# exp(x) + 0.1 in prefix notation: [+, exp, x, 0.1]
qexpr = QExpr([QNode(+, 2), QNode(exp, 1), QNode(:x), QNode(0.1)])
row = (x=3,)

forwards = Vector{Float32}(undef, length(qexpr))

fpass!(qexpr, forwards, row)

println(forwards)
# Float64 reference value; forwards[1] should agree to Float32 precision.
println(exp(3) + 0.1)

# Function barrier for the benchmark loop. In global scope, `qexpr`, `forwards` and `row`
# are untyped non-const globals, so every iteration would add global lookups and a
# dynamic dispatch on `fpass!` itself, which would blur the profile of `fpass!`'s
# per-node dispatch.
function profile_loop!(qexpr, forwards, row, n)
    for _ in 1:n
        fpass!(qexpr, forwards, row)
    end
    return forwards
end

# Compile the loop before profiling so compilation does not show up in the profile.
profile_loop!(qexpr, forwards, row, 1)

@profview profile_loop!(qexpr, forwards, row, 1_000_000)

I would suggest avoiding runtime dispatch altogether. Instead, select the function to call with a switch-style `if`/`elseif` chain over the node kinds (Julia has no `switch` statement), making sure there is no type instability when you call the function for a specific case, or use something like Virtual.jl.

I would recommend looking at how DynamicExpressions.jl handles this with its `Node` type.

If your `fpass_node!` functions were more complicated, you might benefit from the techniques described in “Parsing Protobuf at 2+GB/s: How I Learned To Love Tail Calls in C” (reverberate.org) and “Building the fastest Lua interpreter… automatically!” (sillycross.github.io). The relevant parts are mainly those about tail calls, though I don’t know whether that can be done nicely in Julia.

Maybe MixedStructTypes.jl could help you? It should help you avoid dynamic dispatch.