I’m working on types that efficiently wrap Polish notation and allow fast evaluation of mathematical expressions. When I run `@profview` on the script below, most of the time inside `fpass!` is spent deciding which specific `fpass_node!` method to call for each node.
I’d like to understand where this dispatch cost comes from. As far as I can tell, I have read and applied the Performance Tips section of the Julia manual.
Here is a reproducible example:
## Definitions
"""
    QNode(op::Function, arity::Integer)
    QNode(value::Real)
    QNode(name::Symbol)

A single node of an expression written in Polish (prefix) notation.

Type parameters: `T` is the type of the stored `op`; `A` is `Val{arity}`.
- Function nodes use `Val{1}` (unary) or `Val{2}` (binary).
- Leaf nodes use `Val{0}`: either a `Float32` constant or a `Symbol` naming a
  variable.

Keeping the arity in the type lets methods dispatch on it. Because every distinct
function `op` has its own type, nodes with different operators are different
concrete types.

Throws `ArgumentError` if a function node is given an arity other than 1 or 2.
"""
mutable struct QNode{T,A}
    # NOTE(review): nothing visible mutates `op`; an immutable `struct` may suffice.
    op::T
    function QNode(op::T, arity::Integer) where {T<:Function}
        if arity == 1
            return new{T,Val{1}}(op)
        elseif arity == 2
            return new{T,Val{2}}(op)
        else
            # Previously fell through and returned `nothing`.
            throw(ArgumentError("unsupported arity $arity; expected 1 or 2"))
        end
    end
    # Numeric constants are stored as Float32 to match the evaluation buffer type.
    QNode(op::T) where {T<:Real} = new{Float32,Val{0}}(Float32(op))
    QNode(op::Symbol) = new{Symbol,Val{0}}(op)
end
"""
    TreePtr()
    TreePtr(l)
    TreePtr(l, r)

Child indices of one node in a `QExpr` program: `l` is the first child and `r`
the second. Missing children are stored as `nothing`, so a leaf has no children
and a unary node has only `l`.
"""
struct TreePtr
    l::Union{Int64,Nothing}
    r::Union{Int64,Nothing}
    # Default positional arguments generate the 0-, 1- and 2-argument forms.
    TreePtr(l = nothing, r = nothing) = new(l, r)
end
"""
    QExpr(program)
    QExpr(program, tree_ixs)

An expression stored as a flat `program` of `QNode`s in Polish (prefix) order.
Alongside it, `tree_ixs[i]` holds the child indices of `program[i]`.

Constructors:
- `QExpr(program)` derives `tree_ixs` with `parse_tree_ixs`.
- `QExpr(program, tree_ixs)` takes precomputed pointers. It throws
  `DimensionMismatch` unless there is exactly one `TreePtr` per node.

`program` may be any vector of `QNode`s and is stored as a `Vector{QNode}`. A
`Vector{QNode}` argument is stored as-is; any other vector is converted. Because
`QNode` is not a concrete type, this field has an abstract element type, so code
that evaluates the nodes has to pick each node's method at run time.
"""
struct QExpr
    program::Vector{QNode}
    tree_ixs::Vector{TreePtr}
    function QExpr(program::AbstractVector{<:QNode})
        # Widen e.g. a Vector{QNode{Symbol,Val{0}}} so it fits the field type.
        prog = convert(Vector{QNode}, program)
        return new(prog, parse_tree_ixs(prog))
    end
    function QExpr(program::AbstractVector{<:QNode}, tree_ixs::AbstractVector{TreePtr})
        # Evaluation indexes both vectors with the same node index.
        if length(tree_ixs) != length(program)
            throw(DimensionMismatch(
                "tree_ixs has length $(length(tree_ixs)), " *
                "but program has length $(length(program))",
            ))
        end
        return new(program, tree_ixs)
    end
end
# Array-like accessors for `QExpr`, all forwarded to its `program` vector.
function Base.length(qexpr::QExpr)
    return length(qexpr.program)
end

function Base.size(qexpr::QExpr)
    return size(qexpr.program)
end

function Base.getindex(qexpr::QExpr, ix)
    return getindex(qexpr.program, ix)
end
"""
    parse_tree_ixs(program) -> Vector{TreePtr}

Compute the child indices of every node of a Polish-notation `program`.

The program is scanned from last to first. `recent_unused` holds the indices of
already-scanned nodes that have not yet been claimed as a child, nearest first.
`parse_node` pops as many of them as the node's arity and prepends the node's
`TreePtr` to `result`. The current index is then pushed as a new unclaimed
operand.

An operator without enough operands after it makes `parse_node` fail when it
calls `popfirst!` on the empty `recent_unused`.
"""
function parse_tree_ixs(program::AbstractVector{<:QNode})
    # Concretely typed containers; the untyped `[]` used before gave Vector{Any}.
    result = TreePtr[]
    recent_unused = Int[]
    for ix in reverse(eachindex(program))
        parse_node(program[ix], recent_unused, result)
        pushfirst!(recent_unused, ix)
    end
    return result
end
# Prepend the TreePtr for `node` to `result`, taking its children from the front
# of `recent_unused`. The node's arity is read from its `Val` type parameter.
# Each method returns `result`.

# Leaf: no children.
function parse_node(node::QNode{<:Any,Val{0}}, recent_unused, result)
    return pushfirst!(result, TreePtr())
end

# Unary: one child.
function parse_node(node::QNode{<:Any,Val{1}}, recent_unused, result)
    child = popfirst!(recent_unused)
    return pushfirst!(result, TreePtr(child))
end

# Binary: the nearest unclaimed index is the left child, the next is the right.
function parse_node(node::QNode{<:Any,Val{2}}, recent_unused, result)
    left = popfirst!(recent_unused)
    right = popfirst!(recent_unused)
    return pushfirst!(result, TreePtr(left, right))
end
"""
    fpass!(expr::QExpr, forwards::Vector{Float32}, row::NamedTuple)

Evaluate `expr`, writing the value of node `i` into `forwards[i]`.

Variable leaves read their value from `row`. Nodes are visited from last to
first, and children always have larger indices than their parent, so each
child is computed before its parent reads it. For a program whose root is the
first node, `forwards[1]` then holds the value of the whole expression.

Throws `DimensionMismatch` if `forwards` is shorter than the program.

Performance note: `expr.program` is a `Vector{QNode}`, whose element type is
abstract, so the `fpass_node!` call for each node is resolved by dynamic
dispatch. Removing that cost needs a concretely typed node representation,
not changes to this loop.
"""
function fpass!(expr::QExpr, forwards::Vector{Float32}, row::NamedTuple)
    program = expr.program
    tree_ixs = expr.tree_ixs
    # fpass_node! writes forwards[node_ix] for every node index.
    # Reject a short buffer before anything is written.
    if length(forwards) < length(program)
        throw(DimensionMismatch(
            "forwards has length $(length(forwards)), " *
            "but the expression has $(length(program)) nodes",
        ))
    end
    for node_ix in reverse(eachindex(program))
        fpass_node!(program[node_ix], node_ix, tree_ixs[node_ix], forwards, row)
    end
    return nothing
end
"""
    fpass_node!(node, node_ix, tree_ptr, forwards, row)

Evaluate one node and store its value in `forwards[node_ix]`.

Behavior per node type:
- `QNode{Symbol,Val{0}}`: variable leaf; stores `row[node.op]`.
- `QNode{Float32,Val{0}}`: constant leaf; stores `node.op`.
- `QNode{<:Function,Val{1}}`: applies `node.op` to `forwards[tree_ptr.l]`.
- `QNode{<:Function,Val{2}}`: applies `node.op` to `forwards[tree_ptr.l]` and
  `forwards[tree_ptr.r]`.

Child values must already be in `forwards`. Each method returns the value of its
assignment expression.

Bounds checks are deliberately kept. `node_ix` is not proven to be a valid index
of `forwards`, and `tree_ptr` comes from caller-supplied data, so `@inbounds`
here could read or write out of bounds. Their cost is small next to the dynamic
dispatch that selects these methods.
"""
function fpass_node! end

function fpass_node!(
    node::QNode{Symbol,Val{0}}, node_ix::Int, tree_ptr::TreePtr,
    forwards::Vector{Float32}, row::NamedTuple,
)
    forwards[node_ix] = row[node.op]
end

function fpass_node!(
    node::QNode{Float32,Val{0}}, node_ix::Int, tree_ptr::TreePtr,
    forwards::Vector{Float32}, row::NamedTuple,
)
    forwards[node_ix] = node.op
end

function fpass_node!(
    node::QNode{<:Function,Val{1}}, node_ix::Int, tree_ptr::TreePtr,
    forwards::Vector{Float32}, row::NamedTuple,
)
    forwards[node_ix] = node.op(forwards[tree_ptr.l])
end

function fpass_node!(
    node::QNode{<:Function,Val{2}}, node_ix::Int, tree_ptr::TreePtr,
    forwards::Vector{Float32}, row::NamedTuple,
)
    forwards[node_ix] = node.op(forwards[tree_ptr.l], forwards[tree_ptr.r])
end
## Profiling
qexpr = QExpr([QNode(+, 2), QNode(exp, 1), QNode(:x), QNode(0.1)])
row = (x=3,)
forwards = Vector{Float32}(undef, length(qexpr))
fpass!(qexpr, forwards, row)
println(forwards)
# Float64 reference for forwards[1], which was computed in Float32.
println(exp(3) + 0.1)

# Run the hot loop inside a function. At top level, `qexpr`, `forwards` and `row`
# are untyped globals, so every `fpass!` call would itself be dynamically
# dispatched and inflate the dispatch time seen in the profile.
function fpass_repeatedly!(expr::QExpr, forwards::Vector{Float32}, row::NamedTuple, n::Integer)
    for _ in 1:n
        fpass!(expr, forwards, row)
    end
    return nothing
end

# Compile before profiling so compilation does not show up in the profile.
fpass_repeatedly!(qexpr, forwards, row, 1)
@profview fpass_repeatedly!(qexpr, forwards, row, 1_000_000)