Overloading and performance

Hi, I am trying to build a symbolic system and, along the way, I stumbled upon something that I can’t figure out.
I created several node types to represent the syntax, and I first built a tree by calling the constructors directly:


# Build the tree by calling the node constructors explicitly.
x = SymbNode(:x)
tree = AddNode(
    PowNode(x, 3),
    ProdNode(-2, x),
)

But when I measured this with @time, about 98.97% of the time was compilation, with some big allocations, and Julia was not able to optimize the tree construction unless it was exactly the same tree I had already created.
But after defining these operator overloads:


# Arithmetic operators build syntax-tree nodes whenever at least one operand
# is an `NSyntaxNode`. Every signature names `NSyntaxNode` explicitly, so none
# of these methods applies to `(Number, Number)` calls: those stay with Base's
# own methods instead of being covered by a `Union{Number, NSyntaxNode}`
# signature (type piracy on Base operators).

# `+`: when the plain number is on the left, the node operand is passed first.
Base.:+(n1::NSyntaxNode, n2::NSyntaxNode) = AddNode(n1, n2)
Base.:+(n1::NSyntaxNode, n2::Number) = AddNode(n1, n2)
Base.:+(n1::Number, n2::NSyntaxNode) = AddNode(n2, n1)

# `-`: operand order is always preserved.
Base.:-(n1::NSyntaxNode, n2::NSyntaxNode) = SubNode(n1, n2)
Base.:-(n1::NSyntaxNode, n2::Number) = SubNode(n1, n2)
Base.:-(n1::Number, n2::NSyntaxNode) = SubNode(n1, n2)

# `*`: when the plain number is on the right, the number is passed first.
Base.:*(n1::NSyntaxNode, n2::NSyntaxNode) = ProdNode(n1, n2)
Base.:*(n1::Number, n2::NSyntaxNode) = ProdNode(n1, n2)
Base.:*(n1::NSyntaxNode, n2::Number) = ProdNode(n2, n1)

# `^`: operand order is always preserved.
Base.:^(n1::NSyntaxNode, n2::NSyntaxNode) = PowNode(n1, n2)
Base.:^(n1::NSyntaxNode, n2::Number) = PowNode(n1, n2)
Base.:^(n1::Number, n2::NSyntaxNode) = PowNode(n1, n2)

# `/`: operand order is always preserved.
Base.:/(n1::NSyntaxNode, n2::NSyntaxNode) = DivNode(n1, n2)
Base.:/(n1::NSyntaxNode, n2::Number) = DivNode(n1, n2)
Base.:/(n1::Number, n2::NSyntaxNode) = DivNode(n1, n2)

And then do


# Build the tree through the overloaded arithmetic operators.
x = SymbNode(:x)
tree = 3 * x^2 - 2 * x

I suddenly got a performance boost: no more compilation time, and the tree is created in less than 1 microsecond. I don’t understand why, so I would be glad if someone could explain it.

Before you did your second experiment, did you restart Julia? Because when you ran the first code block, all the *Node methods were compiled, so then calling them again would be instant, no matter what syntax you use to call them.

Yes, I restarted Julia between the two experiments, many times.
So I don’t really understand why the first was slow but the second was fast.

The second call with the same types will usually be fast.

Where are SymbNode, PowNode etc. defined?

Here are the definitions of the different nodes:


"""
    ConstNode{T<:Number}

Leaf node wrapping the numeric constant `n`.
"""
struct ConstNode{T<:Number} <: NSyntaxNode
    n::T
end

# Convenience constructor: the type parameter is taken from the value itself.
function ConstNode(n::Number)
    return ConstNode{typeof(n)}(n)
end

# Leaf node wrapping a symbol.
struct SymbNode <: NSyntaxNode
    # The wrapped symbol.
    n::Symbol
end

"""
    AddNode(n1, n2)

Node for the sum `n1 + n2`.

The constructor simplifies eagerly, so the result is not always an `AddNode`:
two `ConstNode`s fold into one `ConstNode` holding their sum, and an operand
for which `iszero` is true is dropped in favour of the other operand.

Operands of any `NodeType` are accepted; they are converted with `_make_node`.
"""
struct AddNode <: NSyntaxNode
    n1::NSyntaxNode
    n2::NSyntaxNode

    function AddNode(n1::NSyntaxNode, n2::NSyntaxNode)
        # Constant folding takes priority over the identity checks.
        (n1 isa ConstNode && n2 isa ConstNode) && return ConstNode(n1.n + n2.n)
        # NOTE(review): relies on an `iszero` method that accepts `NSyntaxNode`;
        # none is visible here -- confirm it is defined elsewhere.
        iszero(n1) && return n2
        iszero(n2) && return n1
        return new(n1, n2)
    end
end

function AddNode(n1::NodeType, n2::NodeType)
    lhs = _make_node(n1)
    rhs = _make_node(n2)
    return AddNode(lhs, rhs)
end

"""
    SubNode(n1, n2)

Node for the difference `n1 - n2`.

The constructor simplifies eagerly, so the result is not always a `SubNode`:
two `ConstNode`s fold into one `ConstNode` holding their difference, and `n1`
is returned unchanged when `iszero(n2)` is true.

Operands of any `NodeType` are accepted; they are converted with `_make_node`.
"""
struct SubNode <: NSyntaxNode
    n1::NSyntaxNode
    n2::NSyntaxNode

    function SubNode(n1::NSyntaxNode, n2::NSyntaxNode)
        # Constant folding takes priority over the identity check.
        (n1 isa ConstNode && n2 isa ConstNode) && return ConstNode(n1.n - n2.n)
        # NOTE(review): relies on an `iszero` method that accepts `NSyntaxNode`;
        # none is visible here -- confirm it is defined elsewhere.
        iszero(n2) && return n1
        return new(n1, n2)
    end
end

function SubNode(n1::NodeType, n2::NodeType)
    lhs = _make_node(n1)
    rhs = _make_node(n2)
    return SubNode(lhs, rhs)
end

"""
    ProdNode(n1, n2)

Node for the product `n1 * n2`.

The constructor simplifies eagerly, so the result is not always a `ProdNode`:
two `ConstNode`s fold into one `ConstNode` holding their product, an operand
for which `iszero` is true collapses the product to `ConstNode(0)`, and an
operand for which `isone` is true is dropped in favour of the other operand.

Operands of any `NodeType` are accepted; they are converted with `_make_node`.
"""
struct ProdNode <: NSyntaxNode
    n1::NSyntaxNode
    n2::NSyntaxNode

    function ProdNode(n1::NSyntaxNode, n2::NSyntaxNode)
        # Constant folding takes priority over the identity checks.
        (n1 isa ConstNode && n2 isa ConstNode) && return ConstNode(n1.n * n2.n)
        # NOTE(review): relies on `iszero`/`isone` methods that accept
        # `NSyntaxNode`; none is visible here -- confirm they are defined elsewhere.
        (iszero(n1) || iszero(n2)) && return ConstNode(0)
        isone(n1) && return n2
        isone(n2) && return n1
        return new(n1, n2)
    end
end

function ProdNode(n1::NodeType, n2::NodeType)
    lhs = _make_node(n1)
    rhs = _make_node(n2)
    return ProdNode(lhs, rhs)
end

"""
    DivNode(n1, n2)

Node for the quotient `n1 / n2`.

The constructor simplifies eagerly, so the result is not always a `DivNode`:
two `ConstNode`s fold into one `ConstNode` holding `n1.n / n2.n`, a numerator
for which `iszero` is true gives `ConstNode(0)`, and `n1` is returned unchanged
when `isone(n2)` is true.

Operands of any `NodeType` are accepted; they are converted with `_make_node`.

# Throws
- `ArgumentError` if `iszero(n2)` is true.
"""
struct DivNode <: NSyntaxNode
    n1::NSyntaxNode
    n2::NSyntaxNode

    function DivNode(n1::NSyntaxNode, n2::NSyntaxNode)
        # NOTE(review): relies on `iszero`/`isone` methods that accept
        # `NSyntaxNode`; none is visible here -- confirm they are defined elsewhere.
        if iszero(n2)
            # Julia has no `ZeroDivisionError`; referencing that name raised an
            # `UndefVarError` instead of the intended error.
            throw(ArgumentError("cannot create a node that divides by 0"))
        elseif n1 isa ConstNode && n2 isa ConstNode
            return ConstNode(n1.n / n2.n)
        elseif iszero(n1)
            return ConstNode(0)
        elseif isone(n2)
            return n1
        else
            return new(n1, n2)
        end
    end
end

DivNode(n1::NodeType, n2::NodeType) = DivNode(_make_node(n1), _make_node(n2))

"""
    PowNode(n1, n2)

Node for the power `n1 ^ n2`.

The constructor simplifies eagerly, so the result is not always a `PowNode`:
two `ConstNode`s fold into one `ConstNode` holding `n1.n ^ n2.n`, a base for
which `iszero` is true gives `ConstNode(0)`, and an exponent for which `iszero`
is true gives `ConstNode(1)`. The checks run in that order.

Operands of any `NodeType` are accepted; they are converted with `_make_node`.
"""
struct PowNode <: NSyntaxNode
    n1::NSyntaxNode
    n2::NSyntaxNode

    function PowNode(n1::NSyntaxNode, n2::NSyntaxNode)
        # Constant folding takes priority over the identity checks.
        (n1 isa ConstNode && n2 isa ConstNode) && return ConstNode(n1.n ^ n2.n)
        # NOTE(review): relies on an `iszero` method that accepts `NSyntaxNode`;
        # none is visible here -- confirm it is defined elsewhere.
        iszero(n1) && return ConstNode(0)
        iszero(n2) && return ConstNode(1)
        return new(n1, n2)
    end
end

function PowNode(n1::NodeType, n2::NodeType)
    base = _make_node(n1)
    exponent = _make_node(n2)
    return PowNode(base, exponent)
end



I’m not knowledgeable on these symbolics, but let me point out that I spot a few abstract container types in there; perhaps performance may improve if you make these parametric types instead?

Yeah, more or less. But parametrizing the types would create tons of different dispatches (to create a simple AddNode, we would have a method for each combination of parameters), which may slow the process down.
So abstract types were the suitable architecture for this.
Now I still have the same question: why did overloading operators and using them to create the syntax tree make it so much faster?

It probably didn’t. The first call compiles all these methods, the second re-uses them. Even if these methods are being called via other newly defined Base operators.

I think. We still can’t run what you are running to check we’re on the same page. Perhaps you can edit the code block above to ensure that it runs on a fresh session (defining NodeType, NSyntaxNode etc) all the way through to your timing.

You can actually run the code if you’re interested.
You just need to copy everything I sent and add these


# Root of the node hierarchy. The declaration needs its closing `end`;
# without it the snippet does not parse.
abstract type NSyntaxNode end

# Everything the node constructors accept as an operand.
const NodeType = Union{Number, NSyntaxNode, Symbol, Expr}
# Operand types handled by the overloaded arithmetic operators.
const NodeorNumber = Union{Number, NSyntaxNode}

# Convert an operand to a node: nodes pass through, numbers become constants.
_make_node(n::NSyntaxNode) = n
_make_node(n::Number) = ConstNode(n)

It should work.

Not directly answering your question, but you might be interested in GitHub - SymbolicML/DynamicExpressions.jl: Ridiculously fast symbolic expressions if you are looking for a production-grade library.

Sorry, the actual full code


"Supertype of the code spaces."
abstract type AbstractCodeSpace end

"Supertype of every syntax-tree node."
abstract type NSyntaxNode end

"Everything the node constructors accept as an operand."
const NodeType = Union{NSyntaxNode, Number, Symbol, Expr}

"Operand types handled by the overloaded arithmetic operators."
const NodeorNumber = Union{NSyntaxNode, Number}

"""
    NSyntaxTree{T<:NSyntaxNode}

Wrapper holding the root node of a syntax tree.
"""
struct NSyntaxTree{T<:NSyntaxNode}
    root::T
end

# Convenience constructor: the type parameter is taken from the root node.
function NSyntaxTree(n::NSyntaxNode)
    return NSyntaxTree{typeof(n)}(n)
end

"""
    SymbolicSpace{T}()
    SymbolicSpace{T}(ex::Expr)

Code space holding a syntax tree (`code`) and a dictionary from symbols to
values of type `T` (`var`).

Without arguments the tree is the constant `0`; with an `Expr` the tree is
built by `totree(ex)`. The dictionary starts empty in both cases.
"""
mutable struct SymbolicSpace{T} <: AbstractCodeSpace
    code::NSyntaxTree
    const var::Dict{Symbol,T}

    ## Constructors

    function SymbolicSpace{T}() where {T}
        return new{T}(NSyntaxTree(ConstNode(0)), Dict{Symbol,T}())
    end

    function SymbolicSpace{T}(ex::Expr) where {T}
        return new{T}(totree(ex), Dict{Symbol,T}())
    end
end

## Better to use immutable structs for these. Type parameters reduce the need for the compiler to infer field types.


"""
    ConstNode{T<:Number}

Leaf node wrapping the numeric constant `n`.
"""
struct ConstNode{T<:Number} <: NSyntaxNode
    n::T
end

# Convenience constructor: the type parameter is taken from the value itself.
function ConstNode(n::Number)
    return ConstNode{typeof(n)}(n)
end

# Leaf node wrapping a symbol.
struct SymbNode <: NSyntaxNode
    # The wrapped symbol.
    n::Symbol
end

"""
    AddNode(n1, n2)

Node for the sum `n1 + n2`.

The constructor simplifies eagerly, so the result is not always an `AddNode`:
two `ConstNode`s fold into one `ConstNode` holding their sum, and an operand
for which `iszero` is true is dropped in favour of the other operand.

Operands of any `NodeType` are accepted; they are converted with `_make_node`.
"""
struct AddNode <: NSyntaxNode
    n1::NSyntaxNode
    n2::NSyntaxNode

    function AddNode(n1::NSyntaxNode, n2::NSyntaxNode)
        # Constant folding takes priority over the identity checks.
        (n1 isa ConstNode && n2 isa ConstNode) && return ConstNode(n1.n + n2.n)
        # NOTE(review): relies on an `iszero` method that accepts `NSyntaxNode`;
        # none is visible here -- confirm it is defined elsewhere.
        iszero(n1) && return n2
        iszero(n2) && return n1
        return new(n1, n2)
    end
end

function AddNode(n1::NodeType, n2::NodeType)
    lhs = _make_node(n1)
    rhs = _make_node(n2)
    return AddNode(lhs, rhs)
end

"""
    SubNode(n1, n2)

Node for the difference `n1 - n2`.

The constructor simplifies eagerly, so the result is not always a `SubNode`:
two `ConstNode`s fold into one `ConstNode` holding their difference, and `n1`
is returned unchanged when `iszero(n2)` is true.

Operands of any `NodeType` are accepted; they are converted with `_make_node`.
"""
struct SubNode <: NSyntaxNode
    n1::NSyntaxNode
    n2::NSyntaxNode

    function SubNode(n1::NSyntaxNode, n2::NSyntaxNode)
        # Constant folding takes priority over the identity check.
        (n1 isa ConstNode && n2 isa ConstNode) && return ConstNode(n1.n - n2.n)
        # NOTE(review): relies on an `iszero` method that accepts `NSyntaxNode`;
        # none is visible here -- confirm it is defined elsewhere.
        iszero(n2) && return n1
        return new(n1, n2)
    end
end

function SubNode(n1::NodeType, n2::NodeType)
    lhs = _make_node(n1)
    rhs = _make_node(n2)
    return SubNode(lhs, rhs)
end

"""
    ProdNode(n1, n2)

Node for the product `n1 * n2`.

The constructor simplifies eagerly, so the result is not always a `ProdNode`:
two `ConstNode`s fold into one `ConstNode` holding their product, an operand
for which `iszero` is true collapses the product to `ConstNode(0)`, and an
operand for which `isone` is true is dropped in favour of the other operand.

Operands of any `NodeType` are accepted; they are converted with `_make_node`.
"""
struct ProdNode <: NSyntaxNode
    n1::NSyntaxNode
    n2::NSyntaxNode

    function ProdNode(n1::NSyntaxNode, n2::NSyntaxNode)
        # Constant folding takes priority over the identity checks.
        (n1 isa ConstNode && n2 isa ConstNode) && return ConstNode(n1.n * n2.n)
        # NOTE(review): relies on `iszero`/`isone` methods that accept
        # `NSyntaxNode`; none is visible here -- confirm they are defined elsewhere.
        (iszero(n1) || iszero(n2)) && return ConstNode(0)
        isone(n1) && return n2
        isone(n2) && return n1
        return new(n1, n2)
    end
end

function ProdNode(n1::NodeType, n2::NodeType)
    lhs = _make_node(n1)
    rhs = _make_node(n2)
    return ProdNode(lhs, rhs)
end

"""
    DivNode(n1, n2)

Node for the quotient `n1 / n2`.

The constructor simplifies eagerly, so the result is not always a `DivNode`:
two `ConstNode`s fold into one `ConstNode` holding `n1.n / n2.n`, a numerator
for which `iszero` is true gives `ConstNode(0)`, and `n1` is returned unchanged
when `isone(n2)` is true.

Operands of any `NodeType` are accepted; they are converted with `_make_node`.

# Throws
- `ArgumentError` if `iszero(n2)` is true.
"""
struct DivNode <: NSyntaxNode
    n1::NSyntaxNode
    n2::NSyntaxNode

    function DivNode(n1::NSyntaxNode, n2::NSyntaxNode)
        # NOTE(review): relies on `iszero`/`isone` methods that accept
        # `NSyntaxNode`; none is visible here -- confirm they are defined elsewhere.
        if iszero(n2)
            # Julia has no `ZeroDivisionError`; referencing that name raised an
            # `UndefVarError` instead of the intended error.
            throw(ArgumentError("cannot create a node that divides by 0"))
        elseif n1 isa ConstNode && n2 isa ConstNode
            return ConstNode(n1.n / n2.n)
        elseif iszero(n1)
            return ConstNode(0)
        elseif isone(n2)
            return n1
        else
            return new(n1, n2)
        end
    end
end

DivNode(n1::NodeType, n2::NodeType) = DivNode(_make_node(n1), _make_node(n2))

"""
    PowNode(n1, n2)

Node for the power `n1 ^ n2`.

The constructor simplifies eagerly, so the result is not always a `PowNode`:
two `ConstNode`s fold into one `ConstNode` holding `n1.n ^ n2.n`, a base for
which `iszero` is true gives `ConstNode(0)`, and an exponent for which `iszero`
is true gives `ConstNode(1)`. The checks run in that order.

Operands of any `NodeType` are accepted; they are converted with `_make_node`.
"""
struct PowNode <: NSyntaxNode
    n1::NSyntaxNode
    n2::NSyntaxNode

    function PowNode(n1::NSyntaxNode, n2::NSyntaxNode)
        # Constant folding takes priority over the identity checks.
        (n1 isa ConstNode && n2 isa ConstNode) && return ConstNode(n1.n ^ n2.n)
        # NOTE(review): relies on an `iszero` method that accepts `NSyntaxNode`;
        # none is visible here -- confirm it is defined elsewhere.
        iszero(n1) && return ConstNode(0)
        iszero(n2) && return ConstNode(1)
        return new(n1, n2)
    end
end

function PowNode(n1::NodeType, n2::NodeType)
    base = _make_node(n1)
    exponent = _make_node(n2)
    return PowNode(base, exponent)
end

## TODO: add checks for non-positive values of n
"""
    LnNode(n)

Unary node wrapping the operand `n` of an `ln` call (presumably the natural
logarithm -- confirm). No simplification or domain check is performed.

An operand of any `NodeType` is accepted; it is converted with `_make_node`.
"""
struct LnNode <: NSyntaxNode
    n::NSyntaxNode

    function LnNode(n::NSyntaxNode)
        return new(n)
    end
end

function LnNode(n::NodeType)
    operand = _make_node(n)
    return LnNode(operand)
end

"""
    LogNode(n)

Unary node wrapping the operand `n` of a `log` call (presumably the base-10
logarithm -- confirm). No simplification or domain check is performed.

An operand of any `NodeType` is accepted; it is converted with `_make_node`.
"""
struct LogNode <: NSyntaxNode
    n::NSyntaxNode

    function LogNode(n::NSyntaxNode)
        return new(n)
    end
end

function LogNode(n::NodeType)
    operand = _make_node(n)
    return LogNode(operand)
end
### Spaces function 

"""
    setvar(space::SymbolicSpace{T}, s::Symbol, val)

Store `val`, converted to `T`, under the key `s` in `space.var` and return the
converted value.
"""
function setvar(space::SymbolicSpace{T}, s::Symbol, val) where {T}
    converted = convert(T, val)
    space.var[s] = converted
    return converted
end

####### Operations

## Since I will likely rename the fields, it is safer in this case to rely on their order.
"""
    getindex(n::NSyntaxNode, I::Integer)

Return the `I`-th field of `n`, counted in declaration order, so that `n[1]`
is the first field whatever its name.
"""
function Base.getindex(n::NSyntaxNode, I::Integer)
    name = fieldnames(typeof(n))[I]
    return getfield(n, name)
end

## TODO : Add more operator when the code base will be ready
# `getop(node)` returns the symbol used as the callee when the node is turned
# back into an `Expr`. One method is generated per (node type, symbol) pair.
for (nodetype, opsym) in (
    (:AddNode, :+),
    (:SubNode, :-),
    (:ProdNode, :*),
    (:DivNode, :/),
    (:PowNode, :^),
    (:LnNode, :log),
    (:LogNode, :log10),
)
    @eval getop(::$nodetype) = $(QuoteNode(opsym))
end


# Arithmetic operators build syntax-tree nodes whenever at least one operand
# is an `NSyntaxNode`. Every signature names `NSyntaxNode` explicitly, so none
# of these methods applies to `(Number, Number)` calls: those stay with Base's
# own methods instead of being covered by a `Union{Number, NSyntaxNode}`
# signature (type piracy on Base operators).

# `+`: when the plain number is on the left, the node operand is passed first.
Base.:+(n1::NSyntaxNode, n2::NSyntaxNode) = AddNode(n1, n2)
Base.:+(n1::NSyntaxNode, n2::Number) = AddNode(n1, n2)
Base.:+(n1::Number, n2::NSyntaxNode) = AddNode(n2, n1)

# `-`: operand order is always preserved.
Base.:-(n1::NSyntaxNode, n2::NSyntaxNode) = SubNode(n1, n2)
Base.:-(n1::NSyntaxNode, n2::Number) = SubNode(n1, n2)
Base.:-(n1::Number, n2::NSyntaxNode) = SubNode(n1, n2)

# `*`: when the plain number is on the right, the number is passed first.
Base.:*(n1::NSyntaxNode, n2::NSyntaxNode) = ProdNode(n1, n2)
Base.:*(n1::Number, n2::NSyntaxNode) = ProdNode(n1, n2)
Base.:*(n1::NSyntaxNode, n2::Number) = ProdNode(n2, n1)

# `^`: operand order is always preserved.
Base.:^(n1::NSyntaxNode, n2::NSyntaxNode) = PowNode(n1, n2)
Base.:^(n1::NSyntaxNode, n2::Number) = PowNode(n1, n2)
Base.:^(n1::Number, n2::NSyntaxNode) = PowNode(n1, n2)

# `/`: operand order is always preserved.
Base.:/(n1::NSyntaxNode, n2::NSyntaxNode) = DivNode(n1, n2)
Base.:/(n1::NSyntaxNode, n2::Number) = DivNode(n1, n2)
Base.:/(n1::Number, n2::NSyntaxNode) = DivNode(n1, n2)

# `toexpr` converts a tree or a node back into a Julia value or `Expr`.

# A tree is converted through its root node.
function toexpr(tree::NSyntaxTree)
    return toexpr(tree.root)
end

# Leaves convert to the value they wrap.
toexpr(n::Union{ConstNode,SymbNode}) = n.n

# Fallback for two-operand nodes: a call to `getop(n)` on both converted operands.
function toexpr(n::NSyntaxNode)
    op = getop(n)
    lhs = toexpr(n.n1)
    rhs = toexpr(n.n2)
    return Expr(:call, op, lhs, rhs)
end

# One-operand nodes: a call to `getop(n)` on the converted first field.
function toexpr(n::Union{LnNode,LogNode})
    op = getop(n)
    return Expr(:call, op, toexpr(n[1]))
end

"""
    totree(ex::Expr)

Convert `ex` to a node with `_make_node` and wrap the result in an `NSyntaxTree`.
"""
function totree(ex::Expr)
    root = _make_node(ex)
    return NSyntaxTree(root)
end

# `_make_node` converts an operand to a node.

# Leaves: nodes pass through, numbers become constants, symbols become symbol nodes.
_make_node(n::NSyntaxNode) = n
_make_node(n::Number) = ConstNode(n)
_make_node(s::Symbol) = SymbNode(s)

# Operator calls: the operator symbol is lifted to `Val` so that each operator
# gets its own method.
_make_node(::Val{:+}, n1::NodeType, n2::NodeType) = AddNode(n1, n2)
_make_node(::Val{:-}, n1::NodeType, n2::NodeType) = SubNode(n1, n2)
_make_node(::Val{:*}, n1::NodeType, n2::NodeType) = ProdNode(n1, n2)
_make_node(::Val{:/}, n1::NodeType, n2::NodeType) = DivNode(n1, n2)
_make_node(::Val{:^}, n1::NodeType, n2::NodeType) = PowNode(n1, n2)
_make_node(::Val{:ln}, n::NodeType) = LnNode(n)
# NOTE(review): `getop(::LnNode)` is `:log`, so an expression produced from an
# `LnNode` comes back through this method as a `LogNode` -- confirm which of
# the two mappings is the intended one.
_make_node(::Val{:log}, n::NodeType) = LogNode(n)
# `getop(::LogNode)` is `:log10`; accept it so that the expression produced
# from a `LogNode` can be converted back to a `LogNode`.
_make_node(::Val{:log10}, n::NodeType) = LogNode(n)
"""
    _make_node(ex::Expr)

Convert the call expression `ex` to a node by dispatching on its operator
symbol (lifted to `Val`) and its operands.

One- and two-operand calls are forwarded as they are. Calls to `+` or `*` with
more than two operands (the form Julia's parser produces for `a + b + c`) are
folded from the left into nested two-operand nodes.

# Throws
- `ArgumentError` if `ex` is not a call expression, if its operator is not a
  plain `Symbol`, or if the number of operands is not supported.
"""
function _make_node(ex::Expr)
    ch = ex.args
    if ex.head !== :call || isempty(ch)
        throw(ArgumentError("cannot convert `$ex` to a node: only call expressions are supported"))
    end
    op = ch[1]
    if !(op isa Symbol)
        throw(ArgumentError("cannot convert `$ex` to a node: the operator must be a plain symbol"))
    end

    nargs = length(ch) - 1
    if nargs == 1
        return _make_node(Val(op), ch[2])
    elseif nargs == 2
        return _make_node(Val(op), ch[2], ch[3])
    elseif nargs > 2 && (op === :+ || op === :*)
        # `a + b + c` is parsed as a single n-ary call; rebuild it as ((a + b) + c).
        return foldl((lhs, rhs) -> _make_node(Val(op), lhs, rhs), @view ch[2:end])
    else
        # Previously this case fell through and silently returned `nothing`.
        throw(ArgumentError("cannot convert `$ex` to a node: unsupported number of operands ($nargs)"))
    end
end