Efficiently copying a tree with shared children

What is the most efficient way to copy a tree with shared children? While this is technically a graph rather than a tree, there is a clear root node and directionality, so it’s more intuitive to think of it as a tree with shared child nodes.

I am attempting to go from expression trees to expression graphs in SymbolicRegression.jl (pull request here).

Here is a simplified version of my data structure for a tree (full version here):

mutable struct Node
    degree::Int  # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
    val::Float32  # If is a leaf, this stores the value
    op::Int  # enum over operators
    l::Node  # Left child node. Only defined for degree=1 or degree=2.
    r::Node  # Right child node. Only defined for degree=2. 

    Node(val::Float32) = new(0, val)
    Node(op::Int, l::Node) = new(1, 0f0, op, l)
    Node(op::Int, l::Node, r::Node) = new(2, 0f0, op, l, r)
end

One can create a tree that has multiple nodes with the same child node. For example, the expression
\cos(x - 3.2 y) + \cos(2 (x - 3.2 y)) could be stored as the following tree:

Here, both the cos on the left branch and the * on the right branch link to the same shared - node (the one computing x - 3.2 y). Now, since this is implemented as a tree, how should it be copied? The normal way of copying this tree would be:

function copy_node(tree::Node)::Node
    if tree.degree == 0
        Node(copy(tree.val))
    elseif tree.degree == 1
        Node(copy(tree.op), copy_node(tree.l))
    else
        Node(
            copy(tree.op),
            copy_node(tree.l),
            copy_node(tree.r),
        )
    end
end

The problem with this is that it breaks the topology: shared child nodes are duplicated! For example:

base_tree = Node(1, Node(1f0)) # e.g., if 1 is cos, cos(1.0)
tree = Node(1, base_tree, base_tree) # e.g., if 1 is +, cos(1.0) + cos(1.0)
objectid(tree.l) == objectid(tree.r) # true!

ctree = copy_node(tree)
objectid(ctree.l) == objectid(ctree.r) # false!

So, I use the following trick with IdDict:

function copy_node(tree::Node, id_map::IdDict{Node,Node})::Node
    get!(id_map, tree) do
        if tree.degree == 0
            Node(copy(tree.val))
        elseif tree.degree == 1
            Node(copy(tree.op), copy_node(tree.l, id_map))
        else
            Node(
                copy(tree.op),
                copy_node(tree.l, id_map),
                copy_node(tree.r, id_map),
            )
        end
    end
end

What this does is maintain a dictionary mapping each original node (keyed by its objectid) to its copy. If a particular node has already been copied, the existing copy is returned instead of making a new one. This works!

ctree = copy_node(tree, IdDict{Node,Node}())
objectid(ctree.l) == objectid(ctree.r) # true
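
Checking only the two children of the root does not cover deeper sharing. As a rough sketch, one way to compare the topology of the original and the copy is to count the distinct node objects reachable from each root. The helper name count_unique_nodes is hypothetical and not part of the library; Node and copy_node are repeated (with the copy calls on val and op dropped) so the block runs on its own:

```julia
mutable struct Node
    degree::Int
    val::Float32
    op::Int
    l::Node
    r::Node

    Node(val::Float32) = new(0, val)
    Node(op::Int, l::Node) = new(1, 0f0, op, l)
    Node(op::Int, l::Node, r::Node) = new(2, 0f0, op, l, r)
end

function copy_node(tree::Node, id_map::IdDict{Node,Node})::Node
    get!(id_map, tree) do
        if tree.degree == 0
            Node(tree.val)
        elseif tree.degree == 1
            Node(tree.op, copy_node(tree.l, id_map))
        else
            Node(tree.op, copy_node(tree.l, id_map), copy_node(tree.r, id_map))
        end
    end
end

# Hypothetical helper: number of distinct node objects reachable from `tree`.
function count_unique_nodes(tree::Node, seen::IdDict{Node,Nothing}=IdDict{Node,Nothing}())
    haskey(seen, tree) && return length(seen)  # already visited via another parent
    seen[tree] = nothing
    tree.degree >= 1 && count_unique_nodes(tree.l, seen)
    tree.degree == 2 && count_unique_nodes(tree.r, seen)
    return length(seen)
end

base_tree = Node(1, Node(1f0))
tree = Node(1, base_tree, base_tree)
ctree = copy_node(tree, IdDict{Node,Node}())

count_unique_nodes(tree) == count_unique_nodes(ctree)  # same number of distinct nodes
ctree.l === ctree.r  # sharing is preserved in the copy
ctree.l !== tree.l   # and the copy does not alias the original
```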

However, this is significantly slower than a normal copy. I don’t understand the reason for this, because it’s only storing the reference information in an IdDict. If anything, I would almost expect it to be faster, since it results in fewer allocations. Is there a faster way I can copy a tree while preserving its topology?

Here’s a benchmark:

using BenchmarkTools

base_tree = Node(8, Node(1f0), Node(2, Node(3f0)))
tree = Node(2, base_tree, Node(2, Node(1, base_tree, base_tree), Node(2f0)))

@btime [copy_node($tree) for i=1:1_000];
# 101.292 μs (16001 allocations: 757.94 KiB)
@btime [copy_node($tree, IdDict{Node,Node}()) for i=1:1_000];
# 300.042 μs (10001 allocations: 711.06 KiB)

Even with fewer total allocations, for some reason this topology-preserving copy is 3x slower. Some part of this may just be constructing IdDict{Node,Node}() - it seems like that part accounts for about 30 μs. But is there a more efficient way to do this?
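
If constructing the IdDict turns out to be a noticeable part of the cost, one thing that might be worth measuring is reusing a single dictionary across copies and clearing it with empty! in between. This is only a sketch: copy_many is a hypothetical helper name, and no speedup is claimed here. Node and copy_node are repeated so the block runs on its own:

```julia
mutable struct Node
    degree::Int
    val::Float32
    op::Int
    l::Node
    r::Node

    Node(val::Float32) = new(0, val)
    Node(op::Int, l::Node) = new(1, 0f0, op, l)
    Node(op::Int, l::Node, r::Node) = new(2, 0f0, op, l, r)
end

function copy_node(tree::Node, id_map::IdDict{Node,Node})::Node
    get!(id_map, tree) do
        if tree.degree == 0
            Node(tree.val)
        elseif tree.degree == 1
            Node(tree.op, copy_node(tree.l, id_map))
        else
            Node(tree.op, copy_node(tree.l, id_map), copy_node(tree.r, id_map))
        end
    end
end

# Hypothetical helper: make `n` independent copies with one scratch dictionary.
function copy_many(tree::Node, n::Int)
    id_map = IdDict{Node,Node}()
    out = Vector{Node}(undef, n)
    for i in 1:n
        empty!(id_map)  # forget the previous copy so each result is a fresh tree
        out[i] = copy_node(tree, id_map)
    end
    return out
end

base_tree = Node(8, Node(1f0), Node(2, Node(3f0)))
tree = Node(2, base_tree, Node(2, Node(1, base_tree, base_tree), Node(2f0)))
copies = copy_many(tree, 3)
```

Without the empty! call, every iteration would return the same cached root, so the clearing step is what keeps the copies independent.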

Thanks!
Miles

(Also I know that I don’t need to copy the tree.val and tree.op - that is there in case I decide to generalize the types in the future, and they actually need to be copied. copy(val) is just compiled to val so it doesn’t actually affect performance.)

It could be mostly from the hashing in IdDict:

@btime [IdDict([$tree => $tree.l])[$tree] for i=1:1_000];
# 112.250 μs (3001 allocations: 414.19 KiB)

Why is this so slow? Is there a faster way of doing this? (I note that it’s the same for Dict{UInt,Node}([objectid(tree) => tree.l]))

Note that this type of structure is often called a DAG (directed acyclic graph).


Note here that there is only one root node, whereas a DAG could have multiple root nodes (e.g., Multitree - Wikipedia).

(I don’t know if there is a specific name for this type of tree… but if you know please mention it!)

I played around with some other types of dictionaries and couldn’t find any speedup. I think this just might be at performance limits for such an operation.

The other option is to forgo the Node type (with linked children) and instead store this as an actual Graph type, with a set of nodes and set of edges (with indices to each node). Then copying would simply be copying a couple of arrays. However, such a change would require a massive refactor of my library so I don’t think this is feasible.
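
To make the array-based alternative concrete, here is a minimal sketch. FlatNode, FlatGraph and copy_graph are hypothetical names chosen for illustration, and children are stored as indices into a node vector (0 meaning "no child"), which is an assumed encoding rather than anything from the library:

```julia
# Hypothetical flat representation: children are indices into `nodes` (0 = no child).
struct FlatNode
    degree::Int
    val::Float32
    op::Int
    l::Int
    r::Int
end

struct FlatGraph
    nodes::Vector{FlatNode}
    root::Int
end

# FlatNode is an immutable bits type, so copying the vector copies every node.
# The child indices carry over unchanged, so shared children stay shared.
copy_graph(g::FlatGraph) = FlatGraph(copy(g.nodes), g.root)

# cos(1.0) + cos(1.0), with the cos node shared by both sides of the +
nodes = [
    FlatNode(0, 1f0, 0, 0, 0),  # index 1: constant 1.0
    FlatNode(1, 0f0, 1, 1, 0),  # index 2: unary op 1 applied to node 1
    FlatNode(2, 0f0, 1, 2, 2),  # index 3: binary op 1 applied to node 2 twice
]
g = FlatGraph(nodes, 3)
cg = copy_graph(g)
```

Here the copy never needs a dictionary, because identity is encoded by position in the array rather than by object address.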

Did you already try the standard depth first traversal method to clone a directed acyclic graph?

Depth-first traversal is what both versions of copy_node are doing. One of them preserves shared children nodes (slow), and one duplicates shared children (fast). I am trying to figure out why the one which preserves shared children is 3x as slow, since tree copying is done quite frequently in the library.

I see. objectid is not a very lightweight call. It calls into jl_object_id_, which is implemented here:

Rather than using IdDict perhaps you would be better off implementing hash and isequal for Node.

The tricky part is I don’t want nodes with the same content to be treated as the same node; otherwise I couldn’t have two identical (but separate) nodes in a tree. This is why I need something like objectid.
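
A small sketch of the distinction, using a hypothetical Leaf type as a stand-in for a node: two mutable objects with identical content are still different objects, and an IdDict keeps them as separate keys.

```julia
mutable struct Leaf  # hypothetical stand-in for a tree node
    val::Float32
end

a = Leaf(1f0)
b = Leaf(1f0)  # same content as `a`, but a separate object

d = IdDict{Leaf,Symbol}()
d[a] = :first
d[b] = :second

a.val == b.val  # content matches
a === b         # false: different objects
length(d)       # the IdDict holds two entries, one per object
```

A content-based hash and isequal would collapse these two entries into one, which is what has to be avoided here.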

Another option is to add an ID field to each node that specifies a random UInt. But I will need to ensure whenever a node changes, that the ID field updates - which will be a big refactor.

Maybe one solution is to construct a tree of IDs, using a simple counter for the ID, then use that ID instead of objectid?


Edit: Turns out this is even slower than using an IdDict

It shaves about 20% off the time if I add a key field to each node (randomly generated) and use that as the dictionary key instead of objectid. But it is still about 2.5x as slow as a normal copy. Is the hashing (or array allocation) alone really that expensive?

const KeyType = UInt

mutable struct Node
    key::KeyType
    degree::Int  # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
    val::Float32  # If is a leaf, this stores the value
    op::Int  # enum over operators
    l::Node  # Left child node. Only defined for degree=1 or degree=2.
    r::Node  # Right child node. Only defined for degree=2. 

    Node(val::Float32) = new(rand(KeyType), 0, val)
    Node(op::Int, l::Node) = new(rand(KeyType), 1, 0f0, op, l)
    Node(op::Int, l::Node, r::Node) = new(rand(KeyType), 2, 0f0, op, l, r)
end

function copy_node(tree::Node)::Node
    if tree.degree == 0
        Node(copy(tree.val))
    elseif tree.degree == 1
        Node(copy(tree.op), copy_node(tree.l))
    else
        Node(
            copy(tree.op),
            copy_node(tree.l),
            copy_node(tree.r),
        )
    end
end

function copy_node(tree::Node, id_map::Dict{KeyType,Node})::Node
    get!(id_map, tree.key) do
        if tree.degree == 0
            Node(copy(tree.val))
        elseif tree.degree == 1
            Node(copy(tree.op), copy_node(tree.l, id_map))
        else
            Node(
                copy(tree.op),
                copy_node(tree.l, id_map),
                copy_node(tree.r, id_map),
            )
        end
    end
end

using BenchmarkTools

base_tree = Node(8, Node(1f0), Node(2, Node(3f0)))
tree = Node(2, base_tree, Node(2, Node(1, base_tree, base_tree), Node(2f0)))

@btime [copy_node($tree) for i=1:1_000];
# 126.417 μs (16001 allocations: 1007.94 KiB)
# (old: 101.292 μs (16001 allocations: 757.94 KiB))
@btime [copy_node($tree, Dict{KeyType,Node}()) for i=1:1_000];
# 232.208 μs (12001 allocations: 1023.56 KiB)
# (old: 300.042 μs (10001 allocations: 711.06 KiB))

Just implement hash(node::Node) and isequal(a::Node, b::Node)

This is the internal implementation of hash, which is overkill for what you are trying to do:

hash(x::Any) = hash(x, zero(UInt))
hash(w::WeakRef, h::UInt) = hash(w.value, h)

## hashing general objects ##

hash(@nospecialize(x), h::UInt) = hash_uint(3h - objectid(x))

hash(x::Symbol) = objectid(x)

## core data hashing functions ##

function hash_64_64(n::UInt64)
    a::UInt64 = n
    a = ~a + a << 21
    a =  a ⊻ a >> 24
    a =  a + a << 3 + a << 8
    a =  a ⊻ a >> 14
    a =  a + a << 2 + a << 4
    a =  a ⊻ a >> 28
    a =  a + a << 31
    return a
end

function hash_64_32(n::UInt64)
    a::UInt64 = n
    a = ~a + a << 18
    a =  a ⊻ a >> 31
    a =  a * 21
    a =  a ⊻ a >> 11
    a =  a + a << 6
    a =  a ⊻ a >> 22
    return a % UInt32
end

function hash_32_32(n::UInt32)
    a::UInt32 = n
    a = a + 0x7ed55d16 + a << 12
    a = a ⊻ 0xc761c23c ⊻ a >> 19
    a = a + 0x165667b1 + a << 5
    a = a + 0xd3a2646c ⊻ a << 9
    a = a + 0xfd7046c5 + a << 3
    a = a ⊻ 0xb55a4f09 ⊻ a >> 16
    return a
end

if UInt === UInt64
    hash_uint64(x::UInt64) = hash_64_64(x)
    hash_uint(x::UInt)     = hash_64_64(x)
else
    hash_uint64(x::UInt64) = hash_64_32(x)
    hash_uint(x::UInt)     = hash_32_32(x)
end

## efficient value-based hashing of integers ##

hash(x::Int64,  h::UInt) = hash_uint64(bitcast(UInt64, x)) - 3h
hash(x::UInt64, h::UInt) = hash_uint64(x) - 3h
hash(x::Union{Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32}, h::UInt) = hash(Int64(x), h)

function hash_integer(n::Integer, h::UInt)
    h ⊻= hash_uint((n % UInt) ⊻ h)
    n = abs(n)
    n >>>= sizeof(UInt) << 3
    while n != 0
        h ⊻= hash_uint((n % UInt) ⊻ h)
        n >>>= sizeof(UInt) << 3
    end
    return h
end

## symbol & expression hashing ##

if UInt === UInt64
    hash(x::Expr, h::UInt) = hash(x.args, hash(x.head, h + 0x83c7900696d26dc6))
    hash(x::QuoteNode, h::UInt) = hash(x.value, h + 0x2c97bf8b3de87020)
else
    hash(x::Expr, h::UInt) = hash(x.args, hash(x.head, h + 0x96d26dc6))
    hash(x::QuoteNode, h::UInt) = hash(x.value, h + 0x469d72af)
end

I tried setting hash(x::KeyType) = x.val but it didn’t really help… I guess it’s not the hash that’s slow, maybe it’s the dict allocation.

For example:

using Random

struct KeyType
    val::UInt
end
Base.convert(::Type{UInt}, key::KeyType) = key.val
Base.convert(::Type{KeyType}, key::UInt) = KeyType(key)
Base.rand(::Type{KeyType}) = KeyType(rand(UInt))
Base.hash(key::KeyType) = key.val
Base.hash(key::KeyType, h::UInt) = key.val + h
const DictType = Dict

mutable struct Node
    key::KeyType
    degree::Int  # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
    val::Float32  # If is a leaf, this stores the value
    op::Int  # enum over operators
    l::Node  # Left child node. Only defined for degree=1 or degree=2.
    r::Node  # Right child node. Only defined for degree=2. 

    Node(val::Float32) = new(rand(KeyType), 0, val)
    Node(op::Int, l::Node) = new(rand(KeyType), 1, 0f0, op, l)
    Node(op::Int, l::Node, r::Node) = new(rand(KeyType), 2, 0f0, op, l, r)
end

function copy_node(tree::Node)::Node
    if tree.degree == 0
        Node(copy(tree.val))
    elseif tree.degree == 1
        Node(copy(tree.op), copy_node(tree.l))
    else
        Node(
            copy(tree.op),
            copy_node(tree.l),
            copy_node(tree.r),
        )
    end
end

function copy_node(tree::Node, id_map::DictType{KeyType,Node})::Node
    get!(id_map, tree.key) do
        if tree.degree == 0
            Node(copy(tree.val))
        elseif tree.degree == 1
            Node(copy(tree.op), copy_node(tree.l, id_map))
        else
            Node(
                copy(tree.op),
                copy_node(tree.l, id_map),
                copy_node(tree.r, id_map),
            )
        end
    end
end

using BenchmarkTools

base_tree = Node(8, Node(1f0), Node(2, Node(3f0)))
tree = Node(2, base_tree, Node(2, Node(1, base_tree, base_tree), Node(2f0)))

@btime [copy_node($tree) for i=1:1_000];
#  126.333 μs (16001 allocations: 1007.94 KiB)
@btime [copy_node($tree, DictType{KeyType,Node}()) for i=1:1_000];
#  242.542 μs (12001 allocations: 1023.56 KiB)

How about isequal?

If it’s the allocation, perhaps consider an immutable struct?

How would the immutable struct be constructed though? Presumably you’d need a list of objectid’s at some point.

By isequal do you mean to test node equality? That would mean that you couldn’t have separate copies of the same node content in a tree - you need to use === here.

const NodeCounter = Ref{UInt}(0)
function get_node_key()
   NodeCounter[] += 1
   return NodeCounter[]
end

struct Node
    key::UInt
    degree::Int  # 0 for constant/variable, 1 for cos/sin, 2 for +/* etc.
    val::Float32  # If is a leaf, this stores the value
    op::Int  # enum over operators
    l::Base.RefValue{Node}  # Left child node. Only defined for degree=1 or degree=2.
    r::Base.RefValue{Node}  # Right child node. Only defined for degree=2.
    Node(val::Float32) = new(get_node_key(), 0, val)
    Node(op::Int, l::Node) = new(get_node_key(), 1, 0f0, op, Ref(l))
    Node(op::Int, l::Node, r::Node) = new(get_node_key(), 2, 0f0, op, Ref(l), Ref(r))
end
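Base.isequal(a::Node, b::Node) = a.key == b.key
# Keep hash consistent with isequal so that Node can be used as a Dict key.
Base.hash(a::Node, h::UInt) = hash(a.key, h)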

Ah, you mean making the expressions themselves immutable, I understand now. I’m not sure this would work because you need to frequently mutate single nodes inside large expressions (and optimize constants) - would be very expensive to construct the tree from scratch each time.

By the way - what does the Base.RefValue{Node} do here? I thought simply putting Node for the type of l and r would make those attributes act as references… Is RefValue somehow faster?
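
One thing the Ref fields do allow in an immutable struct, regardless of whether it was the intent above, is rebinding a child without rebuilding the parent. A minimal sketch with a hypothetical INode type that mirrors the struct above, cut down to one child:

```julia
struct INode  # hypothetical immutable node, reduced to a key and one child
    key::UInt
    l::Base.RefValue{INode}
    INode(key::UInt) = new(key)                 # leaf: `l` is left undefined
    INode(key::UInt, l::INode) = new(key, Ref(l))
end

leaf_a = INode(UInt(1))
leaf_b = INode(UInt(2))
parent = INode(UInt(3), leaf_a)

parent.l[] = leaf_b     # allowed: this mutates the Ref's contents, not the struct
# parent.key = UInt(4)  # would throw, since fields of an immutable struct cannot be reassigned
```

Fields such as val would still be fixed after construction in this layout, so optimizing constants in place would need a similar wrapper or a mutable struct.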