Tree <: Collection? Looking for API feedback in DynamicExpressions

Intro

I am wondering if I could get others’ opinions on a potential interface for traversing expressions in DynamicExpressions.jl.

To give you a sense of what the package does, the original announcement thread for this package is here: [ANN] DynamicExpressions.jl: Fast evaluation of runtime-generated expressions without compilation.

DynamicExpressions.jl is used in PySR and SymbolicRegression.jl as the backend - it lets those packages evaluate arbitrary expressions without additional compilation.


Tree structure

Expressions are currently implemented as binary trees with the following rigid structure. It is built for type stability, fast traversals, and low memory usage.

mutable struct Node{T}
    degree::Int  # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
    constant::Bool  # false if variable
    val::Union{T,Nothing}  # If is a constant, this stores the actual value
    # ------------------- (possibly undefined below)
    feature::Int  # If is a variable (e.g., x in cos(x)), this stores the feature index.
    op::Int  # If operator, this is the index of the operator in operators.binary_operators, or operators.unary_operators
    l::Node{T}  # Left child node. Only defined for degree=1 or degree=2.
    r::Node{T}  # Right child node. Only defined for degree=2. 
end

(in the future I will likely change this to Node{T,D} for some max degree D, but so far there hasn’t been a strong need)
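To make the layout concrete, here is a small self-contained sketch that mirrors the struct above and builds (x1 * cos(2.5)) + -1.0 by hand. The helper constructors const_node, var_node, and op_node, the operator tables BINOPS and UNAOPS, and the naive evaluator eval_node are hypothetical names for illustration, not part of DynamicExpressions.jl's API. The point is only to show how degree, constant, val, feature, op, l, and r fit together.

```julia
# Toy mirror of the post's Node layout (not DynamicExpressions.jl's own type).
mutable struct Node{T}
    degree::Int
    constant::Bool
    val::Union{T,Nothing}
    feature::Int
    op::Int
    l::Node{T}
    r::Node{T}
    # Partial `new`: the child fields `l` and `r` start out undefined.
    Node{T}(d, c, v, f, o) where {T} = new{T}(d, c, v, f, o)
end

# Hypothetical helper constructors, for illustration only.
const_node(v::T) where {T} = Node{T}(0, true, v, 0, 0)
var_node(::Type{T}, f::Int) where {T} = Node{T}(0, false, nothing, f, 0)
function op_node(op::Int, l::Node{T}, r::Node{T}...) where {T}
    n = Node{T}(1 + length(r), false, nothing, 0, op)  # degree 1 or 2
    n.l = l
    isempty(r) || (n.r = only(r))
    return n
end

# Assumed operator tables: a node's `op` field indexes into these.
const BINOPS = (+, *)
const UNAOPS = (cos,)

# Naive recursive evaluator, only to show how the fields are read.
function eval_node(t::Node{T}, x::AbstractVector{T}) where {T}
    if t.degree == 0
        return t.constant ? t.val : x[t.feature]
    elseif t.degree == 1
        return UNAOPS[t.op](eval_node(t.l, x))
    else
        return BINOPS[t.op](eval_node(t.l, x), eval_node(t.r, x))
    end
end

# (x1 * cos(2.5)) + -1.0
tree = op_node(1, op_node(2, var_node(Float64, 1), op_node(1, const_node(2.5))), const_node(-1.0))
eval_node(tree, [3.0])  # ≈ 3.0 * cos(2.5) - 1.0
```

Leaves never touch l or r, which is why those fields can stay undefined without costing an allocation.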

Problem

Following a discussion with @jling, I realized it would be nice to have an easier way to traverse expressions and compute various properties. For example, you might want to know the number of occurrences of a particular operator (e.g., for solving this puzzle), or you might want to extract all the constants in the expression for use in an optimizer (as SymbolicRegression.jl does for BFGS optimization of a given expression).

For example, right now if you want to compute the number of nodes in an expression (i.e., constants + operators + variables), you could write the following function, which performs a depth-first search:

function count_nodes(tree::Node)::Int
    if tree.degree == 0
        return 1
    elseif tree.degree == 1
        return 1 + count_nodes(tree.l)
    else
        return 1 + count_nodes(tree.l) + count_nodes(tree.r)
    end
end

Likewise, if you wanted to compute the depth of a tree, you could write:

function count_depth(tree::Node)::Int
    if tree.degree == 0
        return 1
    elseif tree.degree == 1
        return 1 + count_depth(tree.l)
    else
        return 1 + max(count_depth(tree.l), count_depth(tree.r))
    end
end

These specific functions are already implemented in DynamicExpressions.jl, but I think there are many other properties users would want to compute which are not available in the library.

I would like it to be easier for users to write their own functions operating on these trees.

Trees as a Collection

Here is the idea for an interface: treat the tree like an ordered Julia collection. By default, the “order” of nodes in the collection would be that of a depth-first search, but I think in the future you could specify other orderings.

With this interface, I implemented four primitive functions:

  • mapreduce (aggregation is commutative),
  • tree_mapreduce (aggregation is non-commutative; needs to know parent from child),
  • filter_and_map (filters nodes, then maps, then collects), and
  • any (to lazily traverse a tree).

From these, most other core functions on collections are implemented: filter, collect, map, count, sum, all, iterate, in, length, and others.

For example, count_nodes becomes just:

count_nodes(tree::Node) = count(_ -> true, tree)  # == length(tree)

which is just as fast as the version I gave above. count_depth becomes:

count_depth(tree::Node) = tree_mapreduce(_ -> 1, (p, c...) -> p + max(c...), tree)

which also experiences no change in performance.
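To make the two primitives concrete, here is a minimal sketch on a toy copy of the Node type (not the code in tree_map.jl). Here mapreduce is a plain pre-order recursion from which count, length, and sum follow. tree_mapreduce instead hands the combining function the parent's mapped value followed by each child's result, which is what count_depth needs and what lets a printer place an operator name relative to its operands. The helper constructors, the operator-name tables, label, and to_string are hypothetical names for illustration.

```julia
# Toy mirror of the post's Node layout (not DynamicExpressions.jl's own type).
mutable struct Node{T}
    degree::Int
    constant::Bool
    val::Union{T,Nothing}
    feature::Int
    op::Int
    l::Node{T}
    r::Node{T}
    # Partial `new`: the child fields `l` and `r` start out undefined.
    Node{T}(d, c, v, f, o) where {T} = new{T}(d, c, v, f, o)
end

# Hypothetical helper constructors, for illustration only.
const_node(v::T) where {T} = Node{T}(0, true, v, 0, 0)
var_node(::Type{T}, f::Int) where {T} = Node{T}(0, false, nothing, f, 0)
function op_node(op::Int, l::Node{T}, r::Node{T}...) where {T}
    n = Node{T}(1 + length(r), false, nothing, 0, op)  # degree 1 or 2
    n.l = l
    isempty(r) || (n.r = only(r))
    return n
end

# Pre-order mapreduce; callers should only rely on it for associative, commutative `op`.
function Base.mapreduce(f::F, op::G, tree::Node) where {F,G}
    tree.degree == 0 && return f(tree)
    acc = op(f(tree), mapreduce(f, op, tree.l))
    tree.degree == 1 && return acc
    return op(acc, mapreduce(f, op, tree.r))
end

# Collection functions derived from mapreduce.
Base.count(f::F, tree::Node) where {F} = mapreduce(t -> f(t)::Bool ? 1 : 0, +, tree)
Base.length(tree::Node) = count(_ -> true, tree)
Base.sum(f::F, tree::Node) where {F} = mapreduce(f, +, tree)

# f maps every node; op(parent_value, child_results...) combines at branches.
function tree_mapreduce(f::F, op::G, tree::Node) where {F,G}
    tree.degree == 0 && return f(tree)
    tree.degree == 1 && return op(f(tree), tree_mapreduce(f, op, tree.l))
    return op(f(tree), tree_mapreduce(f, op, tree.l), tree_mapreduce(f, op, tree.r))
end

count_nodes(tree::Node) = count(_ -> true, tree)
count_depth(tree::Node) = tree_mapreduce(_ -> 1, (p, c...) -> p + max(c...), tree)

# Non-commutative use: the parent's name goes before or between its children.
const BIN_NAMES = ("+", "*")  # assumed operator names, indexed by `op`
const UNA_NAMES = ("cos",)
function label(t::Node)
    t.degree == 0 && return t.constant ? string(t.val) : "x$(t.feature)"
    return t.degree == 1 ? UNA_NAMES[t.op] : BIN_NAMES[t.op]
end
to_string(tree::Node) = tree_mapreduce(
    label,
    (p, c...) -> length(c) == 1 ? "$(p)($(c[1]))" : "($(c[1]) $(p) $(c[2]))",
    tree,
)

# (x1 * cos(2.5)) + -1.0, with binary ops (+, *) and unary ops (cos,) assumed
tree = op_node(1, op_node(2, var_node(Float64, 1), op_node(1, const_node(2.5))), const_node(-1.0))
count_nodes(tree)  # 6
count_depth(tree)  # 4
to_string(tree)    # "((x1 * cos(2.5)) + -1.0)"
```

In this toy version count expects a Bool-returning predicate, so count_nodes passes `_ -> true`; passing `_ -> 1` would throw a TypeError.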

Likewise, you could get a list of nodes with collect(tree), or extract all nodes in the tree that store constants with:

filter(t -> t.degree==0 && t.constant, tree)

or perhaps count the occurrence of a particular binary operator with:

count(t -> t.degree==2 && t.op == 1, tree)

or see if there are any negative constants in the expression with:

any(t -> t.degree==0 && t.constant && t.val<0, tree)

or take the product of all constants with:

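function prod_constants(tree::Node{T}) where {T}
    return mapreduce(*, tree) do t
        # non-constant nodes contribute the multiplicative identity
        t.degree == 0 && t.constant && return t.val
        return one(T)
    end
end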
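A sketch of how filter and a short-circuiting any might be built for a toy copy of the Node type, followed by the optimizer use case from the Problem section. In this version filter_and_map takes an explicit output vector to stay type-stable; that signature is an assumption and may differ from the package's. Because Node is mutable, filter here returns the nodes themselves rather than copies, so writing to their val fields updates the original tree. The helper constructors are hypothetical.

```julia
# Toy mirror of the post's Node layout (not DynamicExpressions.jl's own type).
mutable struct Node{T}
    degree::Int
    constant::Bool
    val::Union{T,Nothing}
    feature::Int
    op::Int
    l::Node{T}
    r::Node{T}
    # Partial `new`: the child fields `l` and `r` start out undefined.
    Node{T}(d, c, v, f, o) where {T} = new{T}(d, c, v, f, o)
end

# Hypothetical helper constructors, for illustration only.
const_node(v::T) where {T} = Node{T}(0, true, v, 0, 0)
var_node(::Type{T}, f::Int) where {T} = Node{T}(0, false, nothing, f, 0)
function op_node(op::Int, l::Node{T}, r::Node{T}...) where {T}
    n = Node{T}(1 + length(r), false, nothing, 0, op)  # degree 1 or 2
    n.l = l
    isempty(r) || (n.r = only(r))
    return n
end

# Pre-order: test `pred`, map each match with `g`, push the result into `out`.
function filter_and_map(pred::F, g::G, tree::Node, out::Vector) where {F,G}
    pred(tree) && push!(out, g(tree))
    tree.degree >= 1 && filter_and_map(pred, g, tree.l, out)
    tree.degree == 2 && filter_and_map(pred, g, tree.r, out)
    return out
end
Base.filter(pred::F, tree::Node{T}) where {F,T} = filter_and_map(pred, identity, tree, Node{T}[])

# Stops as soon as one node matches.
function Base.any(pred::F, tree::Node) where {F}
    pred(tree) && return true
    tree.degree >= 1 && any(pred, tree.l) && return true
    return tree.degree == 2 && any(pred, tree.r)
end

# (x1 * cos(2.5)) + -1.0, with binary ops (+, *) and unary ops (cos,) assumed
tree = op_node(1, op_node(2, var_node(Float64, 1), op_node(1, const_node(2.5))), const_node(-1.0))

any(t -> t.degree == 0 && t.constant && t.val < 0, tree)  # true
visited = Ref(0)
any(t -> (visited[] += 1; t.degree == 2 && t.op == 2), tree)
visited[]  # 2: the search stops at the `*` node

# Optimizer round trip: read constants out, update them, write them back.
nodes = filter(t -> t.degree == 0 && t.constant, tree)
x0 = [n.val::Float64 for n in nodes]  # [2.5, -1.0]
x_new = 2 .* x0                        # stand-in for an optimizer step
for (n, v) in zip(nodes, x_new)
    n.val = v                          # mutates `tree` in place
end
```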

iterate is implemented, so you can also do things like:

for node in tree
    if node.degree==0 && !node.constant && node.feature == 1
        println("Found feature 1 in tree!")
    end
end

which will iterate through the nodes in depth-first order.
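One way to give `for node in tree` a depth-first order without recursion is to keep an explicit stack as the iteration state, pushing the right child before the left so that the left subtree is visited first. Below is a sketch on a toy copy of the Node type, not necessarily how tree_map.jl does it. Declaring IteratorSize as SizeUnknown lets collect work without a precomputed length. The helper constructors are hypothetical.

```julia
# Toy mirror of the post's Node layout (not DynamicExpressions.jl's own type).
mutable struct Node{T}
    degree::Int
    constant::Bool
    val::Union{T,Nothing}
    feature::Int
    op::Int
    l::Node{T}
    r::Node{T}
    # Partial `new`: the child fields `l` and `r` start out undefined.
    Node{T}(d, c, v, f, o) where {T} = new{T}(d, c, v, f, o)
end

# Hypothetical helper constructors, for illustration only.
const_node(v::T) where {T} = Node{T}(0, true, v, 0, 0)
var_node(::Type{T}, f::Int) where {T} = Node{T}(0, false, nothing, f, 0)
function op_node(op::Int, l::Node{T}, r::Node{T}...) where {T}
    n = Node{T}(1 + length(r), false, nothing, 0, op)  # degree 1 or 2
    n.l = l
    isempty(r) || (n.r = only(r))
    return n
end

Base.IteratorSize(::Type{<:Node}) = Base.SizeUnknown()
Base.eltype(::Type{Node{T}}) where {T} = Node{T}

# The iteration state is a stack of nodes still to visit.
Base.iterate(tree::Node{T}) where {T} = iterate(tree, Node{T}[tree])
function Base.iterate(::Node{T}, stack::Vector{Node{T}}) where {T}
    isempty(stack) && return nothing
    t = pop!(stack)
    t.degree == 2 && push!(stack, t.r)  # push right first...
    t.degree >= 1 && push!(stack, t.l)  # ...so the left child is popped next
    return t, stack
end

# (x1 * cos(2.5)) + -1.0, with binary ops (+, *) and unary ops (cos,) assumed
tree = op_node(1, op_node(2, var_node(Float64, 1), op_node(1, const_node(2.5))), const_node(-1.0))
for node in tree
    if node.degree == 0 && !node.constant && node.feature == 1
        println("Found feature 1 in tree!")
    end
end
```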

Since every node in a tree is a tree itself, you could have traversals within traversals, like:

any(tree) do subtree
    if subtree.degree == 2 && subtree.op == 1
        return any(t -> t.degree==0&&!t.constant&&t.feature==3, subtree)
    end
    return false
end

which checks whether feature 3 appears anywhere inside a subtree whose root is binary operator 1 (the first entry in operators.binary_operators).

I also have eachindex, setindex!, and getindex implemented, with indices referring to a node's position in the tree traversal. However, I might delete those, as they seem unintuitive (and are not fast anyway, since every call to getindex results in a fresh traversal).

Implementation

You can find the code implementing this here: DynamicExpressions.jl/tree_map.jl at tree-map · SymbolicML/DynamicExpressions.jl · GitHub (coding tips appreciated too, I am always eager to learn more about optimizing Julia code)


Your Feedback

What do you think of this interface?

Would it be weird as a user if you could treat a tree structure like an ordered collection?

Would you prefer function names which are more explicit, such as tree_mapreduce, tree_any, etc.?

Would you rather be required to declare the type of ordering of nodes explicitly, like:

for node in DepthFirst(tree)  # New type DepthFirst{Node{T}}
    # ...
end

or would it be fine to assume this order by default? (Regardless, I may add DepthFirst / BreadthFirst types, with DepthFirst assumed if not given.)
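A minimal sketch of such wrapper types, assuming the names DepthFirst and BreadthFirst from the question and a toy copy of the Node type: each wrapper is a thin struct around the root, depth-first iteration uses a stack, breadth-first uses a queue, and a bare tree simply forwards to DepthFirst. This is one possible design under those assumptions; the helper constructors are hypothetical.

```julia
# Toy mirror of the post's Node layout (not DynamicExpressions.jl's own type).
mutable struct Node{T}
    degree::Int
    constant::Bool
    val::Union{T,Nothing}
    feature::Int
    op::Int
    l::Node{T}
    r::Node{T}
    # Partial `new`: the child fields `l` and `r` start out undefined.
    Node{T}(d, c, v, f, o) where {T} = new{T}(d, c, v, f, o)
end

# Hypothetical helper constructors, for illustration only.
const_node(v::T) where {T} = Node{T}(0, true, v, 0, 0)
var_node(::Type{T}, f::Int) where {T} = Node{T}(0, false, nothing, f, 0)
function op_node(op::Int, l::Node{T}, r::Node{T}...) where {T}
    n = Node{T}(1 + length(r), false, nothing, 0, op)  # degree 1 or 2
    n.l = l
    isempty(r) || (n.r = only(r))
    return n
end

# Hypothetical ordering wrappers.
struct DepthFirst{T}
    tree::Node{T}
end
struct BreadthFirst{T}
    tree::Node{T}
end

Base.iterate(it::DepthFirst{T}) where {T} = iterate(it, Node{T}[it.tree])
function Base.iterate(::DepthFirst{T}, stack::Vector{Node{T}}) where {T}
    isempty(stack) && return nothing
    t = pop!(stack)                     # LIFO: finish a subtree before its sibling
    t.degree == 2 && push!(stack, t.r)
    t.degree >= 1 && push!(stack, t.l)
    return t, stack
end

Base.iterate(it::BreadthFirst{T}) where {T} = iterate(it, Node{T}[it.tree])
function Base.iterate(::BreadthFirst{T}, queue::Vector{Node{T}}) where {T}
    isempty(queue) && return nothing
    t = popfirst!(queue)                # FIFO: finish a level before the next
    t.degree >= 1 && push!(queue, t.l)
    t.degree == 2 && push!(queue, t.r)
    return t, queue
end

# Iterating a bare tree defaults to depth-first order.
Base.iterate(tree::Node) = iterate(DepthFirst(tree))
Base.iterate(tree::Node, state) = iterate(DepthFirst(tree), state)

Base.IteratorSize(::Type{<:Node}) = Base.SizeUnknown()
Base.IteratorSize(::Type{<:DepthFirst}) = Base.SizeUnknown()
Base.IteratorSize(::Type{<:BreadthFirst}) = Base.SizeUnknown()

# (x1 * cos(2.5)) + -1.0, with binary ops (+, *) and unary ops (cos,) assumed
tree = op_node(1, op_node(2, var_node(Float64, 1), op_node(1, const_node(2.5))), const_node(-1.0))
[n.degree for n in BreadthFirst(tree)]  # [2, 2, 0, 0, 1, 0]
```

Since the wrappers only change the iteration state, order-invariant reductions such as count(pred, BreadthFirst(tree)) give the same result as with DepthFirst(tree); only order-sensitive operations see the difference.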


Thanks for any tips.


One other thought would be to require the DepthFirst(tree) wrapper only for operations specific to ordered collections, like indexing, while order-invariant operations like mapreduce could take tree directly as an argument.
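A sketch of that split, on a toy copy of the Node type and with DepthFirst as an assumed wrapper name: count accepts the bare tree because its result does not depend on visiting order, while getindex is only defined on DepthFirst, so tree[i] is a MethodError. The getindex here walks the traversal until it reaches position i, which also shows why each indexing call costs a fresh traversal. The helper constructors are hypothetical.

```julia
# Toy mirror of the post's Node layout (not DynamicExpressions.jl's own type).
mutable struct Node{T}
    degree::Int
    constant::Bool
    val::Union{T,Nothing}
    feature::Int
    op::Int
    l::Node{T}
    r::Node{T}
    # Partial `new`: the child fields `l` and `r` start out undefined.
    Node{T}(d, c, v, f, o) where {T} = new{T}(d, c, v, f, o)
end

# Hypothetical helper constructors, for illustration only.
const_node(v::T) where {T} = Node{T}(0, true, v, 0, 0)
var_node(::Type{T}, f::Int) where {T} = Node{T}(0, false, nothing, f, 0)
function op_node(op::Int, l::Node{T}, r::Node{T}...) where {T}
    n = Node{T}(1 + length(r), false, nothing, 0, op)  # degree 1 or 2
    n.l = l
    isempty(r) || (n.r = only(r))
    return n
end

# Hypothetical ordering wrapper with stack-based pre-order iteration.
struct DepthFirst{T}
    tree::Node{T}
end
Base.iterate(it::DepthFirst{T}) where {T} = iterate(it, Node{T}[it.tree])
function Base.iterate(::DepthFirst{T}, stack::Vector{Node{T}}) where {T}
    isempty(stack) && return nothing
    t = pop!(stack)
    t.degree == 2 && push!(stack, t.r)
    t.degree >= 1 && push!(stack, t.l)
    return t, stack
end
Base.IteratorSize(::Type{<:DepthFirst}) = Base.SizeUnknown()

# Order-sensitive: indexing requires an explicit ordering. Each call is O(i).
function Base.getindex(it::DepthFirst, i::Int)
    for (j, t) in enumerate(it)
        j == i && return t
    end
    throw(BoundsError(it, i))
end

# Order-invariant: accepts the bare tree and picks an order internally.
Base.count(f::F, tree::Node) where {F} = count(f, DepthFirst(tree))

# (x1 * cos(2.5)) + -1.0, with binary ops (+, *) and unary ops (cos,) assumed
tree = op_node(1, op_node(2, var_node(Float64, 1), op_node(1, const_node(2.5))), const_node(-1.0))
DepthFirst(tree)[4]              # the `cos` node, 4th in depth-first order
count(t -> t.degree == 0, tree)  # 3, no wrapper needed
```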