Compatibility of an alternative representation of expressions with SymbolicUtils.jl

Hello,

With my student, we are trying to reproduce Facebook's paper (GitHub - facebookresearch/neural-rewriter: Learning to Perform Local Rewriting for Combinatorial Optimization) in Julia. At the moment we use Metatheory.jl for matching and for the rule system, but since we operate on Julia's native Expr, the code is not type stable and hence slow. We have therefore started to experiment with an alternative representation of Expr, which is less expressive (restricted to at most two children) but sufficient for this project. The idea is that we represent each Expr as a structure

```julia
struct OnlyNode
    head::Symbol
    iscall::Bool
    v::Float32
    left::NodeID
    right::NodeID
end
```

where NodeID is just an integer identifying a concrete expression. Each expression is thus represented by a unique NodeID, and its child subexpressions, if it has any, are NodeIDs as well. I expect I have implemented something that is already known to the community.
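For reference, this scheme is essentially hash-consing: structurally equal subexpressions are stored once and always get the same id. Below is a minimal, self-contained sketch in plain Julia. The field of NodeID, the internals of NodeCache, and the helpers `intern!`, `NONODE` and `NUMHEAD` are assumptions for illustration, not our actual implementation.

```julia
# Hypothetical sketch: hash-consing Julia Expr trees into OnlyNode records.
# NodeID's field, NodeCache's internals, `intern!`, `NONODE` and `NUMHEAD`
# are assumptions for illustration only.

struct NodeID
    id::UInt32                          # index into NodeCache.nodes (0 = no node)
end

struct OnlyNode
    head::Symbol
    iscall::Bool
    v::Float32
    left::NodeID
    right::NodeID
end

const NONODE = NodeID(0)                # sentinel for a missing child
const NUMHEAD = Symbol("#num")          # head used for numeric leaves

struct NodeCache
    nodes::Vector{OnlyNode}             # NodeID(i) refers to nodes[i]
    index::Dict{OnlyNode,NodeID}        # structurally equal nodes share one id
end
NodeCache() = NodeCache(OnlyNode[], Dict{OnlyNode,NodeID}())

# Return the existing id of `node`, or store it and assign a fresh id.
function intern!(nc::NodeCache, node::OnlyNode)
    get!(nc.index, node) do
        push!(nc.nodes, node)
        NodeID(UInt32(length(nc.nodes)))
    end
end

Base.get!(nc::NodeCache, x::Real) =
    intern!(nc, OnlyNode(NUMHEAD, false, Float32(x), NONODE, NONODE))

Base.get!(nc::NodeCache, s::Symbol) =
    intern!(nc, OnlyNode(s, false, 0f0, NONODE, NONODE))

# Calls with one or two arguments; children are interned first (bottom-up).
function Base.get!(nc::NodeCache, ex::Expr)
    if ex.head === :call && 2 <= length(ex.args) <= 3
        f = ex.args[1]::Symbol
        left = get!(nc, ex.args[2])
        right = length(ex.args) == 3 ? get!(nc, ex.args[3]) : NONODE
        return intern!(nc, OnlyNode(f, true, 0f0, left, right))
    end
    error("unsupported expression: $ex")
end
```

Interning `:(f(a, b))` twice returns the same NodeID without adding new nodes, so within one cache, structural equality of two expressions reduces to comparing two integers.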

My question is whether I can make this compatible with the rule system of SymbolicUtils.jl. I have extended TermInterface.jl for my representation, and on the example from the documentation it behaves as follows.

First I insert the expression into the node cache:

```julia
julia> nc = NodeCache();

julia> ex = get!(nc, :(f(a, b)))
NodeID(0x000003ca)
```

and the TermInterface functions then behave like this:

```julia
julia> head(ex)
:call

julia> children(ex)
3-element Vector{Any}:
 :f
 NodeID(0x000003c8)
 NodeID(0x000003c9)

julia> operation(ex)
:f

julia> arguments(ex)
2-element Vector{NodeID}:
 NodeID(0x000003c8)
 NodeID(0x000003c9)

julia> isexpr(ex)
true

julia> iscall(ex)
true
```

But it does not work with the rule-rewriting system. For example, if I try the example from SymbolicUtils:

```julia
julia> ex = get!(nc, :(sin(x + y)))
NodeID(0x000003ce)

julia> r = @rule sin(~x + ~y) => sin(~x)*cos(~y) + cos(~x)*sin(~y);

julia> r(ex)
```

it returns nothing. Can someone suggest what I am doing wrong?

Thanks a lot in advance,
Tomas

Tracking down the problem, I think it is rather simple in the end, but I do not know how to resolve it.

My representation works with Symbols. Consider this example:

```julia
using SymbolicUtils

@syms x y
nc = NodeCache()
sx = sin(x + y)                 # SymbolicUtils term
ex = get!(nc, :(sin(x + y)))    # my representation
```

Now `head(sx)` returns the function `sin`, whereas my representation returns the Symbol `:sin`. If I construct the rule

```julia
r = @rule sin(~x + ~y) => sin(~x)*cos(~y) + cos(~x)*sin(~y);
```

it matches the function `sin` and not the Symbol `:sin`. This of course makes perfect sense, but I do not know whether there is an easy fix. I could probably extend `isequal`, for example

```julia
# makes isequal(:sin, sin) return true (only in this argument order)
Base.isequal(x::Symbol, ::typeof(sin)) = x == :sin
```

Is this the right solution?
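An alternative I am considering, instead of extending `Base.isequal` for the Base-owned types `Symbol` and `typeof(sin)` (which is type piracy), is to translate the stored Symbol into the function it names whenever `operation` is called, so that the matcher compares `sin` with `sin`. A minimal sketch in plain Julia; the names `OPERATIONS`, `symbol_to_function` and `function_to_symbol` are hypothetical, and I have not tested this against the SymbolicUtils matcher:

```julia
# Hypothetical helpers (an assumption, untested with SymbolicUtils): map the
# operator Symbols stored in OnlyNode.head to the Julia functions they name,
# so that an `operation` method could return `sin` instead of `:sin`.
const OPERATIONS = Dict{Symbol,Function}(
    :sin => sin, :cos => cos,
    :+ => +, :- => -, :* => *, :/ => /, :^ => ^,
)

# Unknown heads (e.g. an uninterpreted `f`) fall back to the Symbol itself.
symbol_to_function(s::Symbol) = get(OPERATIONS, s, s)

# Reverse direction, e.g. for turning a function back into a node head.
const SYMBOLS = Dict{Function,Symbol}(f => s for (s, f) in OPERATIONS)
function_to_symbol(f) = get(SYMBOLS, f, f)
```

The `operation` method for NodeID would then return `symbol_to_function` of the node's head rather than the raw Symbol. I also suspect, without having checked, that the right-hand side of a rule is evaluated as ordinary Julia code, so `sin(~x)` would call `sin` on a NodeID and would need a method such as `Base.sin(::NodeID)` as well.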