Persistent tree data structure as an alternative to Rust-like borrow checker

Hi, @MilesCranmer, I’m opening this as a separate thread so the other thread can focus on your thoughts about a Rust-like borrow checker for Julia.

I made an example of a Node type with a splice method that replaces a sub-tree of the Node with a different sub-tree. If I’m not mistaken, it avoids copying the whole tree. As @bertschi mentioned, you only need to rebuild the path to the node that is changed; the other sub-trees are reused unchanged.

Here’s the code:

struct Node
    x::Int
    l::Union{Node, Nothing}
    r::Union{Node, Nothing}
end

# `path` is a boolean vector. A `false` means go left and
# a `true` means go right.
function splice(node, new_node, path)
    _splice(node, new_node, path, 1)
end

function _splice(node, new_node, path, i)
    go_right = path[i]
    i += 1
    if i > length(path)
        if go_right
            Node(node.x, node.l, new_node)
        else
            Node(node.x, new_node, node.r)
        end
    else
        if go_right
            Node(node.x, node.l, _splice(node.r, new_node, path, i))
        else
            Node(node.x, _splice(node.l, new_node, path, i), node.r)
        end
    end
end

Here’s a small example:

julia> t = (
           Node(1,
               Node(2,
                   Node(4,
                       nothing,
                       nothing
                   ),
                   Node(5,
                       nothing,
                       nothing
                   )
               ),
               Node(3,
                   nothing,
                   nothing
               )
           )
       )
Node(1, Node(2, Node(4, nothing, nothing), Node(5, nothing, nothing)), Node(3, nothing, nothing))

julia> s = (
           Node(6,
               Node(7,
                   nothing,
                   nothing
               ),
               Node(8,
                   nothing,
                   nothing
               )
           )
       )
Node(6, Node(7, nothing, nothing), Node(8, nothing, nothing))

julia> r = splice(t, s, [false, true])
Node(1, Node(2, Node(4, nothing, nothing), Node(6, Node(7, nothing, nothing), Node(8, nothing, nothing))), Node(3, nothing, nothing))
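
One quick way to see the "persistent" part is to check that the original t is untouched after the splice. The sketch below restates Node and splice compactly so it runs on its own (this restatement also accepts an empty path). Note that `===` on an immutable struct compares contents, not addresses, so it cannot by itself prove that subtrees are physically shared; the allocation count in the benchmark (2 allocations for a path of length 2) is the better evidence for that.

```julia
struct Node
    x::Int
    l::Union{Node, Nothing}
    r::Union{Node, Nothing}
end

# Compact restatement of `splice`: rebuild only the nodes along `path`
# (false = left, true = right) and reuse every other subtree as-is.
function splice(node, new_node, path, i = 1)
    i > length(path) && return new_node
    if path[i]
        Node(node.x, node.l, splice(node.r, new_node, path, i + 1))
    else
        Node(node.x, splice(node.l, new_node, path, i + 1), node.r)
    end
end

leaf(x) = Node(x, nothing, nothing)
t = Node(1, Node(2, leaf(4), leaf(5)), leaf(3))
s = Node(6, leaf(7), leaf(8))
r = splice(t, s, [false, true])

# The old version is still intact: t.l.r is still the leaf with value 5.
println(t.l.r.x)   # 5
# The new version sees the spliced-in subtree.
println(r.l.r.x)   # 6
```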

And a small benchmark:

julia> using Chairmarks

julia> path = [false, true];

julia> @be splice(t, s, path)
Benchmark: 3968 samples with 946 evaluations
 min    17.354 ns (2 allocs: 64 bytes)
 median 20.304 ns (2 allocs: 64 bytes)
 mean   25.512 ns (2 allocs: 64 bytes, 0.20% gc time)
 max    3.832 μs (2 allocs: 64 bytes, 98.99% gc time)

Admittedly, it’s a very small benchmark. Maybe you can generate a bigger example.

I don’t know if this is fast enough for you, or if this is similar to what you’ve tried in the past.

Cool, thanks! How is the speed if you have, say, 100 nodes?

Also, how would you use this to modify a single node in the middle of a tree? Is the idea that you would rip out the child nodes and attach them to the new node at construction time?

Yeah, I think it would be something like that. I’ll see if I can add an update function (to change the value of a single node) and add some beefier benchmarks later this evening. :slight_smile:

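OK, here’s the code with a new update method that updates the value at a single node. I also restructured _splice so that the base case is the end of the path, which means an empty path now simply returns new_node.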

struct Node
    x::Int
    l::Union{Node, Nothing}
    r::Union{Node, Nothing}
end

function splice(node, new_node, path)
    _splice(node, new_node, path, 1)
end

function _splice(node, new_node, path, i)
    if i > length(path)
        new_node
    else
        go_right = path[i]
        i += 1
        if go_right
            Node(node.x, node.l, _splice(node.r, new_node, path, i))
        else
            Node(node.x, _splice(node.l, new_node, path, i), node.r)
        end
    end
end

update(node, value, path) = _update(node, value, path, 1)

function _update(node, value, path, i)
    if i > length(path)
        Node(value, node.l, node.r)
    else
        go_right = path[i]
        i += 1
        if go_right
            Node(node.x, node.l, _update(node.r, value, path, i))
        else
            Node(node.x, _update(node.l, value, path, i), node.r)
        end
    end
end

A basic example of update:

julia> t = (
           Node(1,
               Node(2,
                   Node(4,
                       nothing,
                       nothing
                   ),
                   Node(5,
                       nothing,
                       nothing
                   )
               ),
               Node(3,
                   nothing,
                   nothing
               )
           )
       )
Node(1, Node(2, Node(4, nothing, nothing), Node(5, nothing, nothing)), Node(3, nothing, nothing))

julia> s = (
           Node(6,
               Node(7,
                   nothing,
                   nothing
               ),
               Node(8,
                   nothing,
                   nothing
               )
           )
       )
Node(6, Node(7, nothing, nothing), Node(8, nothing, nothing))

julia> update(t, 100, [false])
Node(1, Node(100, Node(4, nothing, nothing), Node(5, nothing, nothing)), Node(3, nothing, nothing))
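
splice and update share the same path-copying skeleton and differ only in what happens at the end of the path. A possible refactor, sketched below, takes a function that is applied to the subtree at the end of the path. The helper name modify is hypothetical, not something from the posts above.

```julia
struct Node
    x::Int
    l::Union{Node, Nothing}
    r::Union{Node, Nothing}
end

# Hypothetical helper: rebuild the nodes along `path` (false = left,
# true = right) and replace the subtree at the end of the path with `f(subtree)`.
function modify(f, node, path, i = 1)
    i > length(path) && return f(node)
    if path[i]
        Node(node.x, node.l, modify(f, node.r, path, i + 1))
    else
        Node(node.x, modify(f, node.l, path, i + 1), node.r)
    end
end

# Both operations from the thread become one-liners.
splice(node, new_node, path) = modify(_ -> new_node, node, path)
update(node, value, path) = modify(n -> Node(value, n.l, n.r), node, path)

leaf(x) = Node(x, nothing, nothing)
t = Node(1, Node(2, leaf(4), leaf(5)), leaf(3))
u = update(t, 100, [false])   # same shape as the REPL example above
```

Because f can be any function, the same skeleton would also cover other local edits, such as swapping the children of one node. Whether passing a closure costs anything compared to the hand-written versions would need to be checked with the same benchmarks.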

Benchmarking code:

make_tree(n) = _make_tree(0, n)

function _make_tree(i, n)
    if i == n
        nothing
    else
        i += 1
        Node(rand(Int16), _make_tree(i, n), _make_tree(i, n))
    end
end

julia> s = make_tree(10);

julia> t = make_tree(10);

julia> path = [false, true, false, true, false, true, false, true, false, true];

julia> @be splice(s, t, path)
Benchmark: 3150 samples with 361 evaluations
 min    51.130 ns (10 allocs: 320 bytes)
 median 59.440 ns (10 allocs: 320 bytes)
 mean   81.948 ns (10 allocs: 320 bytes, 0.40% gc time)
 max    9.236 μs (10 allocs: 320 bytes, 98.68% gc time)

julia> path = [false, true, false, true, false, true, false, true, false];

julia> @be update(t, 99999, path)
Benchmark: 5883 samples with 213 evaluations
 min    49.488 ns (10 allocs: 320 bytes)
 median 55.944 ns (10 allocs: 320 bytes)
 mean   73.091 ns (10 allocs: 320 bytes, 0.22% gc time)
 max    13.421 μs (10 allocs: 320 bytes, 98.83% gc time)

It’s still remarkably fast for the larger trees! (s and t each have 1023 nodes.)
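
To make those numbers concrete: make_tree(10) builds a complete binary tree with 10 levels, so it has 2^10 - 1 = 1023 nodes. splice along the 10-step path rebuilds 10 nodes, and update along the 9-step path rebuilds 9 path nodes plus the updated node itself, 10 in total. That lines up with the 10 allocations and 320 bytes (32 bytes per new node) in the benchmarks. A small sketch to check the node count, restating Node and make_tree from above; count_nodes and height are hypothetical helpers:

```julia
struct Node
    x::Int
    l::Union{Node, Nothing}
    r::Union{Node, Nothing}
end

make_tree(n) = _make_tree(0, n)

function _make_tree(i, n)
    i == n && return nothing
    i += 1
    Node(rand(Int16), _make_tree(i, n), _make_tree(i, n))
end

# Hypothetical helpers for inspecting tree shape.
count_nodes(::Nothing) = 0
count_nodes(n::Node) = 1 + count_nodes(n.l) + count_nodes(n.r)

height(::Nothing) = 0
height(n::Node) = 1 + max(height(n.l), height(n.r))

t = make_tree(10)
println(count_nodes(t))   # 1023 = 2^10 - 1
println(height(t))        # 10
```

A full copy would allocate about 100 times as many nodes (1023 vs. 10), and for balanced trees the gap grows with size, since the path length grows only logarithmically.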

Thanks for working on this!

Here’s the version with mutable structs, unless I’m doing something wrong (a strong possibility, as I’m sleepy):

julia> mutable struct Node
           const x::Int
           l::Node
           r::Node
       
           Node(x::Integer) = new(x)
           Node(x::Integer, l::Node, r::Node) = new(x, l, r)
       end

julia> function make_tree(n, i=0)
           if i == n
               return Node(0)
           else
               i += 1
               return Node(rand(Int16), make_tree(n, i), make_tree(n, i))
           end
       end
make_tree (generic function with 2 methods)

julia> s = make_tree(10);

julia> t = make_tree(10);

julia> splice(s, t) = (s.l.r.l.r.l.r.l.r.l = t; s)
splice (generic function with 1 method)

julia> using BenchmarkTools

julia> @btime splice($s, $t);
  7.292 ns (0 allocations: 0 bytes)
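
The in-place version is faster, but the trade-off that motivates this thread is easy to show: anything else holding a reference to s (another task, a cached copy, and so on) sees the change immediately. A minimal sketch, using a hypothetical MNode type with Union{MNode, Nothing} children instead of undefined fields:

```julia
mutable struct MNode
    x::Int
    l::Union{MNode, Nothing}
    r::Union{MNode, Nothing}
end

mleaf(x) = MNode(x, nothing, nothing)

# In-place splice of the left child, analogous to the one-liner above.
splice_left!(root, sub) = (root.l = sub; root)

s = MNode(1, mleaf(2), mleaf(3))
snapshot = s                 # e.g. a reference held by some other code
splice_left!(s, mleaf(9))

# There is no "old version": the snapshot observes the mutation.
println(snapshot.l.x)        # 9
```

With the persistent Node, splice instead returns a new root and leaves the old one untouched, which is what makes it safe to hand the old version to other tasks without copying.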

@MilesCranmer, I’d like to point you to the implementation of ScopedValues in Base.

ScopedValues are based on just such a persistent tree data structure. For performance, the tree has a rather large fan-out plus some well-googleable tricks (keyword: HAMT, hash array mapped trie).
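
For readers who haven’t met HAMTs, one of the tricks is that each node has a large fan-out (32 slots in this sketch) but stores only the occupied children in a dense vector, using a bitmap plus a popcount to find a child’s position. A minimal sketch of just that indexing step; BitmapNode, slot, and child_index are hypothetical names, and this is not Base’s implementation:

```julia
const BITS_PER_LEVEL = 5                 # 2^5 = 32 slots per node

# A node keeps a 32-bit occupancy bitmap and a dense vector of only the
# occupied children, so empty slots cost no memory.
struct BitmapNode{T}
    bitmap::UInt32
    children::Vector{T}
end

# The 5-bit chunk of hash `h` that selects a slot (0-31) at trie level `level`.
slot(h::UInt, level::Int) = Int((h >> (BITS_PER_LEVEL * level)) & 0x1f)

# 1-based position of slot `s` in `children`, or `nothing` if the slot is empty:
# count how many occupied slots come before it.
function child_index(node::BitmapNode, s::Int)
    mask = UInt32(1) << s
    (node.bitmap & mask) == 0 && return nothing
    return count_ones(node.bitmap & (mask - UInt32(1))) + 1
end

node = BitmapNode{Symbol}(UInt32(0b10110), [:a, :b, :c])   # slots 1, 2, 4 occupied
println(child_index(node, 4))   # 3
println(child_index(node, 3))   # nothing
```

Path copying then only has to copy one short dense vector per level, and with 32-way branching the tree stays very shallow.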

Unfortunately, the Base version doesn’t handle hash collisions, so it’s not usable for general-purpose programming. It can get away with that because the keys are Symbols, and two Symbols compare equal if and only if their hashes coincide: Symbols are globally interned strings, and the hash implementation is a bijective map on the heap address.

However, this use case demonstrates very well how and when to use persistent data structures, and the advantages of garbage collection.

The basic problem is this: when creating a child task, the child must inherit the current ScopedValue dict. Copying would be O(size(scopedValueDict)), which sucks. So instead, the scopedValueDict is a persistent data structure and the child receives a snapshot. When entering a ScopedValue construction, the current task’s scopedValueDict is updated, and when catching an exception or leaving a ScopedValue construction, the scopedValueDict is restored. So we pay a small price on task creation and on all try/finally blocks, regardless of whether ScopedValues are used.
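
The snapshot idea does not depend on the HAMT details. A much simpler persistent structure (with slower lookups) shows the same behavior: entering a binding allocates one node that points at the parent scope, a child task just keeps a reference to the current scope, and leaving the binding means going back to the parent reference. This is a hypothetical sketch using an immutable linked list; Scope, push_binding, and lookup are made-up names, and this is not how Base implements it.

```julia
# Hypothetical persistent scope: an immutable linked list of bindings.
struct Scope
    key::Symbol
    value::Any
    parent::Union{Scope, Nothing}
end

# Entering a binding is O(1): one new node, everything else is shared.
push_binding(scope, key, value) = Scope(key, value, scope)

# Lookup walks towards the root, so the innermost binding wins.
# (Linear in the number of bindings here; a HAMT keeps lookups shallow.)
function lookup(scope::Union{Scope, Nothing}, key::Symbol, default)
    while scope !== nothing
        scope.key === key && return scope.value
        scope = scope.parent
    end
    return default
end

outer = push_binding(nothing, :verbosity, 1)
inner = push_binding(outer, :verbosity, 2)

# A child task captures the current scope without copying anything.
child = @async lookup(inner, :verbosity, 0)

current = outer              # "leaving" the inner binding: restore the parent
println(lookup(current, :verbosity, 0))   # 1
println(fetch(child))                     # 2, the child's snapshot is unaffected
```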

Since you are a part-time Rustacean, I’d be interested in how you would do something similar. You’d probably use ref-counted smart pointers?

I’d also like to note that this is not about avoiding race-condition bugs, in the Rust sense of “we’re gonna protect you from programmer error”; it’s a genuinely appropriate data structure even if programmers were perfect.

I don’t completely follow how this works. Do you have an example of how it could be used? I’m not opposed to any of these ideas; it just needs to be fast enough for the expression search in SymbolicRegression.jl. If it’s a little bit slower than direct mutation but much safer, then I think it’s a win. I haven’t found one that’s fast enough yet, though.

I haven’t actually thought through how I’d replicate this in Rust! I would have thought the only borrow-checker-approved approach in Rust would be an array of nodes and an array of edges, rather than a mutable binary tree.
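
For comparison, a variant of the "array of nodes" layout might look like this in Julia: children are integer indices into flat vectors instead of pointers, so a node can be changed in place by index and there is no object graph to alias. This is a hypothetical sketch (FlatTree, add_node!, subtree_sum, and the 0-means-no-child convention are all assumptions), not code from SymbolicRegression.jl.

```julia
# Hypothetical arena layout: node i has value val[i] and children
# left[i], right[i], where 0 means "no child".
struct FlatTree
    val::Vector{Float64}
    left::Vector{Int}
    right::Vector{Int}
end

FlatTree() = FlatTree(Float64[], Int[], Int[])

# Append a node and return its index.
function add_node!(t::FlatTree, v, l = 0, r = 0)
    push!(t.val, v); push!(t.left, l); push!(t.right, r)
    return length(t.val)
end

# Recursive sum over the subtree rooted at index i.
subtree_sum(t::FlatTree, i::Int) =
    i == 0 ? 0.0 : t.val[i] + subtree_sum(t, t.left[i]) + subtree_sum(t, t.right[i])

t = FlatTree()
a = add_node!(t, 2.0)
b = add_node!(t, 3.0)
root = add_node!(t, 1.0, a, b)

t.val[a] = 10.0                  # in-place update by index
println(subtree_sum(t, root))    # 14.0
```

Part of the appeal of this layout in Rust is that the vectors are the single owner and indices are plain integers; the cost is that nothing stops a stale index from pointing at a slot that has since been reused.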

This isn’t particularly surprising to me. The question, though, is how much of a bottleneck this particular operation is for you.

If before/after each splice there’s some pre/post-processing that takes, e.g., a couple hundred nanoseconds to a microsecond, then I’d look at this extra overhead as a pretty cheap price to pay to avoid the classic pitfalls of mutation-heavy multithreaded code.
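
To put rough numbers on that, here is a back-of-the-envelope calculation using the figures quoted earlier in the thread: a median of about 59 ns for the persistent splice on the 1023-node trees and about 7 ns for the in-place version. These were measured on different machines, so treat the result as a ballpark only.

```julia
# Relative slowdown of the whole step when `other_ns` of unrelated work
# surrounds each splice.
slowdown(other_ns, fast_ns, slow_ns) = (other_ns + slow_ns) / (other_ns + fast_ns) - 1

persistent_ns = 59.44   # median of `@be splice(s, t, path)` above
mutable_ns = 7.292      # `@btime splice($s, $t)` for the mutable version

for other_ns in (200.0, 1000.0)
    pct = round(100 * slowdown(other_ns, mutable_ns, persistent_ns); digits = 1)
    println(other_ns, " ns of other work: ", pct, "% slower")
end
```

With about 1 µs of surrounding work the difference comes out around 5%, while with about 200 ns it is closer to 25%.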

DynamicExpressions.jl and SymbolicRegression.jl have been through so many rounds of profiling and micro-optimization that I’m nearly starting to hear the profiler whisper to me in my sleep. It’s just so hard to improve performance in them now. I’ve tried immutable structs for the expressions multiple times, but they’ve been substantially slower every time.

Some of the more expensive parts of the code (maybe ~30-50% for a test workload) come from running Optim.jl on the symbolic expressions. For these, I create a Ref to each of the leaves holding a variable in the tree, so I don’t need to walk through the tree each time. Then it becomes extremely efficient to take the candidate parameter vector from Optim.jl and map it directly into the leaves with a single map((node, newval) -> (node[].val = newval), leafs, x). I think it would be really hard to match the performance of this with immutables, as you’d have to reconstruct the tree from the root to the leaves on every forward pass. Since there can be 100 evaluations with different values for the constants in each optimization loop, you can pick up a lot of allocations if you need to allocate new nodes each time. (And yes, I’ve tried with .val::RefValue as well; it improved things but was still no match, due to the other setproperties! calls needed throughout the evolutionary algorithm parts.)
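
For readers following along, the pattern described here looks roughly like the sketch below. The type and function names (ENode, constant_leaves!, set_constants!) are hypothetical and this is not DynamicExpressions.jl’s actual API; since the nodes are mutable, the node objects themselves act as the references instead of explicit Refs.

```julia
# Hypothetical mutable expression node. `kind` is :add or :mul for operator
# nodes, :constant for a constant leaf (uses `val`), :variable for an input
# leaf (uses `feature`).
mutable struct ENode
    kind::Symbol
    val::Float64
    feature::Int
    l::Union{ENode, Nothing}
    r::Union{ENode, Nothing}
end

constant(v) = ENode(:constant, v, 0, nothing, nothing)
variable(i) = ENode(:variable, 0.0, i, nothing, nothing)
binop(kind, l, r) = ENode(kind, 0.0, 0, l, r)

function evaluate(n::ENode, x)
    n.kind === :constant && return n.val
    n.kind === :variable && return x[n.feature]
    a = evaluate(n.l, x)
    b = evaluate(n.r, x)
    return (n.kind === :add) ? (a + b) : (a * b)
end

# Walk the tree once and remember the constant leaves.
function constant_leaves!(out, n)
    n === nothing && return out
    n.kind === :constant && push!(out, n)
    constant_leaves!(out, n.l)
    constant_leaves!(out, n.r)
    return out
end

# Each optimizer step: write the parameter vector straight into the leaves.
set_constants!(leaves, p) = foreach((leaf, v) -> (leaf.val = v), leaves, p)

tree = binop(:add, binop(:mul, constant(1.0), variable(1)), constant(0.0))  # c1 * x1 + c2
leaves = constant_leaves!(ENode[], tree)
set_constants!(leaves, [2.0, 1.0])
println(evaluate(tree, [3.0]))   # 7.0
```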

Nowadays the evaluation itself for a similarly sized data structure can often be sub-µs and result in zero allocations (due to a preallocated buffer). So, surprisingly, the tree writes and copies themselves are becoming a sizable chunk of the flame graph. I have even started to reuse a preallocated vector of nodes and copy the tree onto those nodes (another thing a borrow checker would be super useful for is ensuring safe use of such preallocated buffers).
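
A rough sketch of the preallocated-buffer idea, with hypothetical names (PNode, NodePool, copy_into!, reset!). The last lines show the kind of misuse a borrow checker is meant to catch: a tree copied into the pool is still reachable after the pool has been reset and reused, and it silently changes.

```julia
# Hypothetical mutable node and a preallocated pool of them.
mutable struct PNode
    val::Float64
    l::Union{PNode, Nothing}
    r::Union{PNode, Nothing}
end
PNode(val) = PNode(val, nothing, nothing)

mutable struct NodePool
    nodes::Vector{PNode}
    next::Int              # index of the next free node
end
NodePool(n::Int) = NodePool([PNode(0.0) for _ in 1:n], 1)
reset!(pool::NodePool) = (pool.next = 1; pool)

# Copy `src` onto nodes taken from the pool instead of allocating new ones.
# Throws a BoundsError if the pool runs out.
function copy_into!(pool::NodePool, src::Union{PNode, Nothing})
    src === nothing && return nothing
    dst = pool.nodes[pool.next]
    pool.next += 1
    dst.val = src.val
    dst.l = copy_into!(pool, src.l)
    dst.r = copy_into!(pool, src.r)
    return dst
end

tree = PNode(1.0, PNode(2.0), PNode(3.0))
pool = NodePool(8)
c = copy_into!(pool, tree)
c.l.val = 99.0
println(tree.l.val)      # 2.0: the copy is independent of the original

# The hazard: after reset!, the pool's nodes are reused and `c` changes under us.
reset!(pool)
copy_into!(pool, PNode(5.0))
println(c.val)           # 5.0
```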

@CameronBieganek Thanks for taking this up; this was exactly the example I had in mind. I also played a bit with Accessors.jl, which provides a neat syntax for working with path information, but unfortunately it seems quite a bit slower than your direct approach:

julia> using Accessors, BenchmarkTools

julia> path = @o _.l.r.l.r.l.r.l.r.l
(@o _.l.r.l.r.l.r.l.r.l)

julia> @btime set($s, $path, $t);
  3.113 μs (23 allocations: 512 bytes)

# Same with a shorter path
julia> @btime @set ($s).l.r.l.r = $t;
  892.978 ns (8 allocations: 192 bytes)

# For comparison: Your splice on my machine
julia> path = [false, true, false, true, false, true, false, true, false, true];

julia> @btime splice($s, $t, $path);
  243.096 ns (10 allocations: 320 bytes)

@MilesCranmer I had a quick look at DynamicExpressions, and for your use case, rebuilding a tree with many constants is probably not an option with immutable data structures, i.e., mutating refs will be much faster. On the other hand, there is always the option of adding another level of indirection, i.e., treating the constants as another type of variable. From what I got, variables are represented as an index in the tree and eval uses that index to look them up in the input, whereas constants are stored as literal values in the tree. Instead, you could store constants just like variables and have eval look them up in another vector, i.e., like a function with inputs and parameters. Whether to reuse that vector is then an independent concern.
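
A minimal sketch of this suggestion, assuming a hypothetical immutable Term hierarchy (not DynamicExpressions.jl’s actual node type): constants are Param(i) leaves that read p[i], just as Var(i) leaves read x[i], so an optimizer can swap in new constants without touching the tree.

```julia
# Hypothetical immutable expression tree in which constants are looked up
# in a parameter vector `p`, just as variables are looked up in the input `x`.
abstract type Term end

struct Var <: Term
    idx::Int                  # index into the inputs x
end

struct Param <: Term
    idx::Int                  # index into the parameters p
end

struct Op{F} <: Term
    f::F                      # binary operator, e.g. + or *
    l::Term
    r::Term
end

evaluate(n::Var, x, p) = x[n.idx]
evaluate(n::Param, x, p) = p[n.idx]
evaluate(n::Op, x, p) = n.f(evaluate(n.l, x, p), evaluate(n.r, x, p))

# p1 * x1 + p2
tree = Op(+, Op(*, Param(1), Var(1)), Param(2))

# The optimizer only changes `p`; the tree itself is never rebuilt.
println(evaluate(tree, [3.0], [2.0, 1.0]))   # 7.0
println(evaluate(tree, [3.0], [0.5, 4.0]))   # 5.5
```

The abstractly typed l and r fields keep the sketch short but cost type stability; a real implementation would need more care there.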
