Intro
I am wondering if I could get others’ opinions on a potential interface for traversing expressions in DynamicExpressions.jl.
For a sense of what the package does, see the original announcement thread: [ANN] DynamicExpressions.jl: Fast evaluation of runtime-generated expressions without compilation.
DynamicExpressions.jl is the backend for PySR and SymbolicRegression.jl; it lets those packages evaluate arbitrary expressions without additional compilation.
Tree structure
Expressions are currently implemented as binary trees with the following rigid structure. It is built for type stability, fast traversals, and low memory usage.
```julia
mutable struct Node{T}
    degree::Int  # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
    constant::Bool  # false if variable
    val::Union{T,Nothing}  # If this is a constant, this stores the actual value
    # ------------------- (possibly undefined below)
    feature::Int  # If this is a variable (e.g., x in cos(x)), this stores the feature index.
    op::Int  # If this is an operator, the index of the operator in operators.binary_operators or operators.unary_operators
    l::Node{T}  # Left child node. Only defined for degree=1 or degree=2.
    r::Node{T}  # Right child node. Only defined for degree=2.
end
```
(In the future I will likely change this to `Node{T,D}` for some max degree `D`, but so far there hasn’t been a strong need.)
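To try the examples in this post without installing anything, here is a minimal sketch of a stand-in `Node{T}` with the same fields. The constructors are hypothetical helpers for illustration, not DynamicExpressions.jl's API; they use partial `new(...)` so that fields a node does not need (such as `l` and `r` on a leaf) stay undefined, as the comments in the struct above describe. The operator indices assume binary operator 1 is `+` and unary operator 1 is `cos`.

```julia
mutable struct Node{T}
    degree::Int             # 0 = leaf, 1 = unary operator, 2 = binary operator
    constant::Bool          # leaves only: true for a constant, false for a variable
    val::Union{T,Nothing}   # value of a constant leaf
    feature::Int            # feature index of a variable leaf
    op::Int                 # operator index for degree 1 or 2
    l::Node{T}              # left child (degree >= 1)
    r::Node{T}              # right child (degree == 2)

    # Hypothetical constructors; fields left out of new(...) stay undefined.
    Node(val::T) where {T<:Number} = new{T}(0, true, val)                     # constant leaf
    Node{T}(feature::Int) where {T} = new{T}(0, false, nothing, feature)      # variable leaf
    Node(op::Int, l::Node{T}) where {T} = new{T}(1, false, nothing, 0, op, l) # unary node
    Node(op::Int, l::Node{T}, r::Node{T}) where {T} = new{T}(2, false, nothing, 0, op, l, r)
end

# x1 + cos(2.0)
x1 = Node{Float64}(1)
tree = Node(1, x1, Node(1, Node(2.0)))

println(tree.degree)         # 2
println(tree.r.l.val)        # 2.0
println(isdefined(x1, :l))   # false: a leaf has no children
```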
Problem
Following a discussion with @jling, I realized it would be nice to have an easier way to traverse expressions and compute various properties. For example, you might want to know the number of occurrences of a particular operator (e.g., for solving this puzzle), or you might want to extract all the constants in an expression to pass to an optimizer (as SymbolicRegression.jl does for BFGS optimization of a given expression).
For example, to compute the number of nodes in an expression today (i.e., constants + operators + variables), you could write the following function, which performs a depth-first search:
```julia
function count_nodes(tree::Node)::Int
    if tree.degree == 0
        return 1
    elseif tree.degree == 1
        return 1 + count_nodes(tree.l)
    else
        return 1 + count_nodes(tree.l) + count_nodes(tree.r)
    end
end
```
Likewise, to compute the depth of a tree, you could write:
```julia
function count_depth(tree::Node)::Int
    if tree.degree == 0
        return 1
    elseif tree.degree == 1
        return 1 + count_depth(tree.l)
    else
        return 1 + max(count_depth(tree.l), count_depth(tree.r))
    end
end
```
These specific functions are already implemented in DynamicExpressions.jl, but I think there are many other properties users would want to compute that are not available in the library.
I would like it to be easier for users to write their own functions operating on these trees.
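For comparison, here is one more hand-written traversal in the current style: a sketch that collects every constant in depth-first order, as an optimizer would need. `collect_constants` and `collect_constants!` are hypothetical names, and `Node` is a minimal stand-in for the struct above with illustrative constructors (not the package's API). It repeats the same three-way branch on `degree` as `count_nodes` and `count_depth`, which is the boilerplate a generic traversal interface would remove.

```julia
mutable struct Node{T}
    degree::Int
    constant::Bool
    val::Union{T,Nothing}
    feature::Int
    op::Int
    l::Node{T}
    r::Node{T}
    # Hypothetical constructors: constant leaf, variable leaf, unary node, binary node.
    Node(val::T) where {T<:Number} = new{T}(0, true, val)
    Node{T}(feature::Int) where {T} = new{T}(0, false, nothing, feature)
    Node(op::Int, l::Node{T}) where {T} = new{T}(1, false, nothing, 0, op, l)
    Node(op::Int, l::Node{T}, r::Node{T}) where {T} = new{T}(2, false, nothing, 0, op, l, r)
end

# Append every constant value to `out`, visiting the left subtree before the right.
function collect_constants!(out::Vector{T}, tree::Node{T}) where {T}
    if tree.degree == 0
        tree.constant && push!(out, tree.val)
    elseif tree.degree == 1
        collect_constants!(out, tree.l)
    else
        collect_constants!(out, tree.l)
        collect_constants!(out, tree.r)
    end
    return out
end

collect_constants(tree::Node{T}) where {T} = collect_constants!(T[], tree)

# (x1 * 3.0) + cos(2.0), assuming binary ops [+, *] and unary ops [cos]
tree = Node(1, Node(2, Node{Float64}(1), Node(3.0)), Node(1, Node(2.0)))
println(collect_constants(tree))   # [3.0, 2.0]
```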
Trees as a Collection
Here is the idea for an interface: treat the tree like an ordered Julia collection. By default, the “order” of nodes in the collection would be that of a depth-first search, but I think in the future you could specify other orderings.
With this interface, I implemented four primitive functions:
- `mapreduce` (aggregation is commutative),
- `tree_mapreduce` (aggregation is non-commutative; it needs to distinguish the parent from its children),
- `filter_and_map` (filters nodes, then maps, then collects), and
- `any` (to lazily traverse a tree).
From these, most other core functions on collections are implemented: `filter`, `collect`, `map`, `count`, `sum`, `all`, `iterate`, `in`, `length`, and others.
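To make the layering concrete, here is a minimal recursive sketch of how derived functions could sit on top of the primitives. It assumes `tree_mapreduce(f, op, tree)` maps every node with `f` and combines results bottom-up as `op(parent_result, child_results...)`, matching the `count_depth` example that follows; the implementations are plain recursion written for illustration, not the code in tree_map.jl, and `Node` is a stand-in with hypothetical constructors. The commutative `mapreduce` only needs a binary `op`, so it folds each node's child results into the parent's value.

```julia
mutable struct Node{T}
    degree::Int
    constant::Bool
    val::Union{T,Nothing}
    feature::Int
    op::Int
    l::Node{T}
    r::Node{T}
    # Hypothetical constructors: constant leaf, variable leaf, unary node, binary node.
    Node(val::T) where {T<:Number} = new{T}(0, true, val)
    Node{T}(feature::Int) where {T} = new{T}(0, false, nothing, feature)
    Node(op::Int, l::Node{T}) where {T} = new{T}(1, false, nothing, 0, op, l)
    Node(op::Int, l::Node{T}, r::Node{T}) where {T} = new{T}(2, false, nothing, 0, op, l, r)
end

# Non-commutative primitive: op receives the parent's value first, then each child's result.
function tree_mapreduce(f, op, tree::Node)
    tree.degree == 0 && return f(tree)
    tree.degree == 1 && return op(f(tree), tree_mapreduce(f, op, tree.l))
    return op(f(tree), tree_mapreduce(f, op, tree.l), tree_mapreduce(f, op, tree.r))
end

# Commutative primitive: fold the children's results into the parent's with a binary op.
Base.mapreduce(f, op, tree::Node) = tree_mapreduce(f, (p, c...) -> foldl(op, c; init=p), tree)

# Depth-first (parent, left, right): keep nodes passing `keep`, store map_fn(node) in `out`.
function filter_and_map(keep, map_fn, tree::Node, out=[])
    keep(tree) && push!(out, map_fn(tree))
    tree.degree >= 1 && filter_and_map(keep, map_fn, tree.l, out)
    tree.degree == 2 && filter_and_map(keep, map_fn, tree.r, out)
    return out
end

# Derived collection functions (hypothetical definitions)
Base.length(tree::Node) = mapreduce(_ -> 1, +, tree)
Base.count(f, tree::Node) = mapreduce(t -> f(t) ? 1 : 0, +, tree)
Base.sum(f, tree::Node) = mapreduce(f, +, tree)
Base.collect(tree::Node{T}) where {T} = filter_and_map(_ -> true, identity, tree, Node{T}[])
Base.filter(f, tree::Node{T}) where {T} = filter_and_map(f, identity, tree, Node{T}[])
Base.map(f, tree::Node) = filter_and_map(_ -> true, f, tree)

# x1 + cos(2.0), assuming binary operator 1 is + and unary operator 1 is cos
tree = Node(1, Node{Float64}(1), Node(1, Node(2.0)))
println(length(tree))                      # 4
println(count(t -> t.degree == 0, tree))   # 2
```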
For example, `count_nodes` becomes just:
```julia
count_nodes(tree::Node) = count(_ -> true, tree)  # == length(tree)
```
which is just as fast as the version I gave above. `count_depth` becomes:
```julia
count_depth(tree::Node) = tree_mapreduce(_ -> 1, (p, c...) -> p + max(c...), tree)
```
which also shows no change in performance.
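The parent/child distinction is what `mapreduce` cannot express. As a sketch, here is an expression printer built on a plain recursive `tree_mapreduce` with the `(parent, children...)` convention used above. The name lists `UNARY` and `BINARY` are hypothetical stand-ins for `operators.unary_operators` and `operators.binary_operators`, and `Node` is a stand-in with illustrative constructors. Swapping the order of the child strings would print a different expression, which is exactly why this reduction is non-commutative.

```julia
mutable struct Node{T}
    degree::Int
    constant::Bool
    val::Union{T,Nothing}
    feature::Int
    op::Int
    l::Node{T}
    r::Node{T}
    # Hypothetical constructors: constant leaf, variable leaf, unary node, binary node.
    Node(val::T) where {T<:Number} = new{T}(0, true, val)
    Node{T}(feature::Int) where {T} = new{T}(0, false, nothing, feature)
    Node(op::Int, l::Node{T}) where {T} = new{T}(1, false, nothing, 0, op, l)
    Node(op::Int, l::Node{T}, r::Node{T}) where {T} = new{T}(2, false, nothing, 0, op, l, r)
end

# Plain recursive sketch: op(parent_value, child_results...)
function tree_mapreduce(f, op, tree::Node)
    tree.degree == 0 && return f(tree)
    tree.degree == 1 && return op(f(tree), tree_mapreduce(f, op, tree.l))
    return op(f(tree), tree_mapreduce(f, op, tree.l), tree_mapreduce(f, op, tree.r))
end

const UNARY = ["cos"]        # hypothetical operator names, indexed by `op`
const BINARY = ["+", "*"]

# Map each node to its own label: an operator name, a constant, or a variable name.
function label(t::Node)
    t.degree == 2 && return BINARY[t.op]
    t.degree == 1 && return UNARY[t.op]
    return t.constant ? string(t.val) : "x$(t.feature)"
end

# Combine the parent's label with its children's strings; argument order matters here.
combine(op, arg) = "$op($arg)"
combine(op, left, right) = "($left $op $right)"

to_string(tree::Node) = tree_mapreduce(label, combine, tree)

tree = Node(1, Node(2, Node{Float64}(1), Node(3.0)), Node(1, Node(2.0)))
println(to_string(tree))   # ((x1 * 3.0) + cos(2.0))
```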
Likewise, you could get a list of nodes with `collect(tree)`, or extract all nodes in the tree that store constants with:
```julia
filter(t -> t.degree == 0 && t.constant, tree)
```
or count the occurrences of a particular binary operator with:
```julia
count(t -> t.degree == 2 && t.op == 1, tree)
```
or check whether there are any negative constants in the expression with:
```julia
any(t -> t.degree == 0 && t.constant && t.val < 0, tree)
```
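or take the product of all constants with:
```julia
# Wrapped in a function (hypothetical name) so that T, the element type of tree, is defined
function prod_constants(tree::Node{T}) where {T}
    return mapreduce(*, tree) do t
        t.degree == 0 && t.constant && return t.val
        return one(T)  # non-constant nodes contribute the multiplicative identity
    end
end
```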
`iterate` is implemented, so you can also do things like:
```julia
for node in tree
    if node.degree == 0 && !node.constant && node.feature == 1
        println("Found feature 1 in tree!")
    end
end
```
which iterates through the nodes in depth-first order.
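One way `iterate` could traverse without recursion is to carry an explicit stack as the iteration state. This is a sketch under the assumption that "depth-first order" means pre-order (parent, then left subtree, then right subtree); `Node` is a stand-in with hypothetical constructors. Pushing the right child before the left one makes the left subtree come out first, and declaring `SizeUnknown` lets `collect` and comprehensions work without calling `length`.

```julia
mutable struct Node{T}
    degree::Int
    constant::Bool
    val::Union{T,Nothing}
    feature::Int
    op::Int
    l::Node{T}
    r::Node{T}
    # Hypothetical constructors: constant leaf, variable leaf, unary node, binary node.
    Node(val::T) where {T<:Number} = new{T}(0, true, val)
    Node{T}(feature::Int) where {T} = new{T}(0, false, nothing, feature)
    Node(op::Int, l::Node{T}) where {T} = new{T}(1, false, nothing, 0, op, l)
    Node(op::Int, l::Node{T}, r::Node{T}) where {T} = new{T}(2, false, nothing, 0, op, l, r)
end

# Iteration state: a stack of nodes still to visit; each new loop starts from [tree].
function Base.iterate(tree::Node{T}, stack=Node{T}[tree]) where {T}
    isempty(stack) && return nothing
    node = pop!(stack)
    node.degree == 2 && push!(stack, node.r)   # push right first...
    node.degree >= 1 && push!(stack, node.l)   # ...so left is popped next
    return node, stack
end

Base.IteratorSize(::Type{<:Node}) = Base.SizeUnknown()
Base.eltype(::Type{Node{T}}) where {T} = Node{T}

tree = Node(1, Node{Float64}(1), Node(1, Node(2.0)))   # x1 + cos(2.0)
for node in tree
    if node.degree == 0 && !node.constant && node.feature == 1
        println("Found feature 1 in tree!")
    end
end
println([node.degree for node in tree])   # [2, 0, 1, 0]
```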
Since every node in a tree is itself a tree, you can have traversals within traversals, like:
```julia
any(tree) do subtree
    if subtree.degree == 2 && subtree.op == 1
        return any(t -> t.degree == 0 && !t.constant && t.feature == 3, subtree)
    end
    return false
end
```
which checks whether feature 3 appears anywhere inside a subtree whose root applies binary operator 1 (i.e., `op == 1`).
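A minimal sketch of a short-circuiting `any`, which is what keeps nested queries like this cheap: the traversal stops at the first node where the predicate returns `true`. It is plain recursion written for illustration, not the package's implementation; `Node` is a stand-in with hypothetical constructors, `has_x3_under_op1` is a hypothetical name, and binary operators 1 and 2 are assumed to be `+` and `*`.

```julia
mutable struct Node{T}
    degree::Int
    constant::Bool
    val::Union{T,Nothing}
    feature::Int
    op::Int
    l::Node{T}
    r::Node{T}
    # Hypothetical constructors: constant leaf, variable leaf, unary node, binary node.
    Node(val::T) where {T<:Number} = new{T}(0, true, val)
    Node{T}(feature::Int) where {T} = new{T}(0, false, nothing, feature)
    Node(op::Int, l::Node{T}) where {T} = new{T}(1, false, nothing, 0, op, l)
    Node(op::Int, l::Node{T}, r::Node{T}) where {T} = new{T}(2, false, nothing, 0, op, l, r)
end

# Visit the node itself, then the left subtree, then the right; stop at the first match.
function Base.any(f, tree::Node)
    f(tree) && return true
    tree.degree >= 1 && any(f, tree.l) && return true
    tree.degree == 2 && any(f, tree.r) && return true
    return false
end

# The nested query from the post, wrapped in a function
function has_x3_under_op1(tree::Node)
    return any(tree) do subtree
        if subtree.degree == 2 && subtree.op == 1
            return any(t -> t.degree == 0 && !t.constant && t.feature == 3, subtree)
        end
        return false
    end
end

a = Node(1, Node{Float64}(1), Node(1, Node{Float64}(3)))   # x1 + cos(x3)
b = Node(1, Node(2, Node{Float64}(3), Node{Float64}(1)))   # cos(x3 * x1)
println(has_x3_under_op1(a))   # true
println(has_x3_under_op1(b))   # false
```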
I also have `eachindex`, `setindex!`, and `getindex` implemented, which refer to a node's position in the tree traversal. However, I might delete those, as they seem unintuitive (and are not fast anyway, since every call to `getindex` requires a fresh traversal).
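To make that cost concrete, here is a hypothetical `getindex` that follows the traversal order: it restarts the traversal from the root and walks past `i - 1` nodes, so a single lookup is O(i), and a loop such as `for i in 1:length(tree)` costs O(n²) in total, compared with O(n) for plain iteration. The stack-based pre-order `iterate` and the `Node` stand-in are illustrative assumptions, not the package's code.

```julia
mutable struct Node{T}
    degree::Int
    constant::Bool
    val::Union{T,Nothing}
    feature::Int
    op::Int
    l::Node{T}
    r::Node{T}
    # Hypothetical constructors: constant leaf, variable leaf, unary node, binary node.
    Node(val::T) where {T<:Number} = new{T}(0, true, val)
    Node{T}(feature::Int) where {T} = new{T}(0, false, nothing, feature)
    Node(op::Int, l::Node{T}) where {T} = new{T}(1, false, nothing, 0, op, l)
    Node(op::Int, l::Node{T}, r::Node{T}) where {T} = new{T}(2, false, nothing, 0, op, l, r)
end

# Pre-order iteration with an explicit stack as state
function Base.iterate(tree::Node{T}, stack=Node{T}[tree]) where {T}
    isempty(stack) && return nothing
    node = pop!(stack)
    node.degree == 2 && push!(stack, node.r)
    node.degree >= 1 && push!(stack, node.l)
    return node, stack
end

# Index i means "the i-th node in traversal order"; every call starts a fresh traversal.
function Base.getindex(tree::Node, i::Int)
    for (k, node) in enumerate(tree)
        k == i && return node
    end
    throw(BoundsError(tree, i))
end

tree = Node(1, Node{Float64}(1), Node(1, Node(2.0)))   # x1 + cos(2.0)
println(tree[3].degree)   # 1 (the cos node)
```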
Implementation
You can find the code implementing this here: DynamicExpressions.jl/tree_map.jl at tree-map · SymbolicML/DynamicExpressions.jl · GitHub (coding tips are appreciated too; I am always eager to learn more about optimizing Julia code).
Your Feedback
What do you think of this interface?
Would it be weird as a user if you could treat a tree structure like an ordered collection?
Would you prefer more explicit function names, such as `tree_mapreduce`, `tree_any`, etc.?
Would you rather be required to declare the node ordering explicitly, like:
```julia
for node in DepthFirst(tree)  # New type DepthFirst{Node{T}}
    # ...
end
```
or would it be fine to assume this order by default? (Regardless, I may add `DepthFirst` / `BreadthFirst` types, with `DepthFirst` assumed if not given.)
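For what it is worth, such wrapper types could be thin. A sketch, assuming `DepthFirst` means pre-order and `BreadthFirst` means level order: each wrapper holds the root and carries a stack or a queue as its iteration state, and iterating a bare `Node` simply forwards to `DepthFirst`, which gives the "assumed if not given" behavior. The type names follow the post; the implementation and the `Node` stand-in are illustrative.

```julia
mutable struct Node{T}
    degree::Int
    constant::Bool
    val::Union{T,Nothing}
    feature::Int
    op::Int
    l::Node{T}
    r::Node{T}
    # Hypothetical constructors: constant leaf, variable leaf, unary node, binary node.
    Node(val::T) where {T<:Number} = new{T}(0, true, val)
    Node{T}(feature::Int) where {T} = new{T}(0, false, nothing, feature)
    Node(op::Int, l::Node{T}) where {T} = new{T}(1, false, nothing, 0, op, l)
    Node(op::Int, l::Node{T}, r::Node{T}) where {T} = new{T}(2, false, nothing, 0, op, l, r)
end

struct DepthFirst{T}
    tree::Node{T}
end
struct BreadthFirst{T}
    tree::Node{T}
end

# Depth-first (pre-order): a stack, last in first out.
function Base.iterate(it::DepthFirst{T}, stack=Node{T}[it.tree]) where {T}
    isempty(stack) && return nothing
    node = pop!(stack)
    node.degree == 2 && push!(stack, node.r)
    node.degree >= 1 && push!(stack, node.l)
    return node, stack
end

# Breadth-first (level order): a queue, first in first out.
function Base.iterate(it::BreadthFirst{T}, queue=Node{T}[it.tree]) where {T}
    isempty(queue) && return nothing
    node = popfirst!(queue)
    node.degree >= 1 && push!(queue, node.l)
    node.degree == 2 && push!(queue, node.r)
    return node, queue
end

Base.IteratorSize(::Type{<:DepthFirst}) = Base.SizeUnknown()
Base.IteratorSize(::Type{<:BreadthFirst}) = Base.SizeUnknown()
Base.IteratorSize(::Type{<:Node}) = Base.SizeUnknown()

# A bare tree defaults to depth-first order.
Base.iterate(tree::Node, state...) = iterate(DepthFirst(tree), state...)

tree = Node(1, Node(2, Node{Float64}(1), Node(3.0)), Node(1, Node(2.0)))   # (x1 * 3.0) + cos(2.0)
println([n.degree for n in DepthFirst(tree)])     # [2, 2, 0, 0, 1, 0]
println([n.degree for n in BreadthFirst(tree)])   # [2, 2, 1, 0, 0, 0]
```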
Thanks for any tips.