Recursion to iteration

Hi all,

I’m trying to convert some code from recursion to iteration keeping an internal stack. My benchmark is showing no improvement. MWE

using BenchmarkTools

mutable struct Node{T}
    value::T
    left::Union{Nothing, Node{T}}
    right::Union{Nothing, Node{T}}
    height::Int
end
const Tree{T} = Union{Nothing, Node{T}}

nil() = nothing
node(t::T, left, right, height) where T = Node{T}(t, left, right, height)

height(tree::Nothing) = 0
height(tree::Node{T}) where T = tree.height

height!(tree::Nothing) = 0
function height!(tree::Node{T}) where T
    left = height(tree.left)
    right = height(tree.right)
    max(left, right) + 1
end

function rotateright(tree::Node{T}) where T
    tree.left.right, tree.left, tree =
        tree, tree.left.right, tree.left
    tree.right.height = height!(tree.right)
    tree
end
function rotateleft(tree::Node{T})  where T
    tree.right.left, tree.right, tree =
        tree, tree.right.left, tree.right
    tree.left.height = height!(tree.left)
    tree
end

balance(tree::Nothing) = nothing
function balance(tree::Node{T}) where T
    height_left = height(tree.left)
    height_right = height(tree.right)
    if height_left < height_right - 1
        tree = rotateleft(tree)
        height_left = tree.left.height
        height_right = tree.right.height
    elseif height_left > height_right + 1
        tree = rotateright(tree)
        height_left = tree.left.height
        height_right = tree.right.height
    end
    # _height = height!(tree)
    _height = max(height_left, height_right) + 1
    if tree.height != _height
        tree.height = _height
    end
    tree
end

insert1(tree::Nothing, t::T) where T = node(t, nothing, nothing, 1)
function insert1(tree::Node{T}, t::T) where T
    if t < tree.value
        _height = height(tree.left)
        tree.left = insert1(tree.left, t)
        if height(tree.left) != _height
            tree = balance(tree)
        end
    elseif t > tree.value
        _height = height(tree.right)
        tree.right = insert1(tree.right, t)
        if height(tree.right) != _height
            tree = balance(tree)
        end
    end
    tree
end

insert2(tree::Nothing, t::T) where T = node(t, nothing, nothing, 1)
function insert2(tree::Node{T}, t::T) where T
    stack = Vector{Tuple{Node{T}, Int, Symbol}}(undef, 64)
    sp = 1
    parent = tree
    while parent != nothing
        if t < parent.value
            stack[sp] = (parent, height(parent.left), :left)
            sp += 1
            parent = parent.left
        elseif t > parent.value
            stack[sp] = (parent, height(parent.right), :right)
            sp += 1
            parent = parent.right
        else
            break
        end
    end
    if parent == nothing
        tree = node(t, nothing, nothing, 1)
    else
        tree = parent
    end
    while sp != 1
        sp -= 1
        parent, _height, loc = stack[sp]
        if loc == :left
            parent.left = tree
            if height(parent.left) != _height
                parent = balance(parent)
            end
        elseif loc == :right
            parent.right = tree
            if height(parent.right) != _height
                parent = balance(parent)
            end
        end
        tree = parent
    end
    tree
end

insert3(tree::Nothing, t::T) where T = node(t, nothing, nothing, 1)
function insert3(tree::Node{T}, t::T) where T
    stack = Vector{Tuple{Ref{Node{T}}, Int, Symbol}}(undef, 64)
    sp = 1
    parent = tree
    while parent != nothing
        if t < parent.value
            stack[sp] = (Ref{Node{T}}(parent), height(parent.left), :left)
            sp += 1
            parent = parent.left
        elseif t > parent.value
            stack[sp] = (Ref{Node{T}}(parent), height(parent.right), :right)
            sp += 1
            parent = parent.right
        else
            break
        end
    end
    if parent == nothing
        tree = node(t, nothing, nothing, 1)
    else
        tree = parent
    end
    while sp != 1
        sp -= 1
        parent_ref, _height, loc = stack[sp]
        parent = parent_ref[]
        if loc == :left
            parent.left = tree
            if height(parent.left) != _height
                parent = balance(parent)
            end
        elseif loc == :right
            parent.right = tree
            if height(parent.right) != _height
                parent = balance(parent)
            end
        end
        tree = parent
    end
    tree
end

Base.haskey(tree::Nothing, t::T) where T = false
function Base.haskey(tree::Node{T}, t::T) where T
    while tree != nothing
        if t < tree.value
            tree = tree.left
        elseif t > tree.value
            tree = tree.right
        else
            return true
        end
    end
    return false
end

function test1(n)
    tree = nil()
    for j in 1:n
        tree = insert1(tree, j)
    end
    @assert !haskey(tree, 0)
    for j in 1:n
        @assert haskey(tree, j)
    end
    @assert !haskey(tree, n + 1)
    tree
end

function test2(n)
    tree = nil()
    for j in 1:n
        tree = insert2(tree, j)
    end
    @assert !haskey(tree, 0)
    for j in 1:n
        @assert haskey(tree, j)
    end
    @assert !haskey(tree, n + 1)
    tree
end

function test3(n)
    tree = nil()
    for j in 1:n
        tree = insert3(tree, j)
    end
    @assert !haskey(tree, 0)
    for j in 1:n
        @assert haskey(tree, j)
    end
    @assert !haskey(tree, n + 1)
    tree
end

GC.gc()
@btime test1(10000)
GC.gc()
@btime test2(10000)
GC.gc()
@btime test3(10000)
println()

with result

  2.299 ms (10000 allocations: 468.75 KiB)
  19.735 ms (143616 allocations: 19.79 MiB)
  21.359 ms (276723 allocations: 12.21 MiB)

Is there a way to get rid of the additional allocations?

On a second thought, maybe eliminating the internal stack and keeping the parent as part of the node structure is the way to go.

one slightly unrelated observation: your balancing code is much less efficient than it could be. currently your height calculation is O(2^height) while it could be O(1)

Thanks for your feedback. It would be very helpful to know, where I’m going wrong.

the basic is that you should rely on the already calculated heights that haven’t changed rather than recalculating them.

I’d start with finding out where the allocations are coming from. You can either benchmark smaller pieces of code instead of everything at once or use the Profile stdlib (Profiling Β· The Julia Language).

Also note that @btime as well as @benchmark already run your code in a loop and report statistics over all runs.

Another part would be finding out whether your code is type stable - a cursory look using @code_warntype should help here.

1 Like

Thank you,

problem is creating and copying the tuples to the stack. Regarding type stability: at least I don’t see any indication of dynamic dispatch in Juno profiler.

Fair critique. But height calculation is not done recursively and shouldn’t change the complexity of the algorithm therefore.

That doesn’t fit my model of how allocation in julia works or how it’s counted, do you mind explaining how you got to that conclusion?

Juno profiler indicates allocations there or below. Should I post profile logs?

Edit: @code_warntype shows problems at least for insert3

Variables
  #self#::Core.Const(insert3)
  tree@_2::Node{Int64}
  t::Int64
  parent::Any
  sp::Int64
  stack::Vector{Tuple{Ref{Node{Int64}}, Int64, Symbol}}
  @_7::Int64
  loc::Symbol
  _height::Int64
  parent_ref::Ref{Node{Int64}}
  tree@_11::Any

Body::Any
1 ──       (tree@_11 = tree@_2)
β”‚    %2  = Core.apply_type(Main.Node, $(Expr(:static_parameter, 1)))::Core.Const(Node{Int64})
β”‚    %3  = Core.apply_type(Main.Ref, %2)::Core.Const(Ref{Node{Int64}})
β”‚    %4  = Core.apply_type(Main.Tuple, %3, Main.Int, Main.Symbol)::Core.Const(Tuple{Ref{Node{Int64}}, Int64, Symbol})
β”‚    %5  = Core.apply_type(Main.Vector, %4)::Core.Const(Vector{Tuple{Ref{Node{Int64}}, Int64, Symbol}})
β”‚          (stack = (%5)(Main.undef, 64))
β”‚          (sp = 1)
└───       (parent = tree@_11::Node{Int64})
2 ┄─ %9  = (parent::Tree{Int64} != Main.nothing)::Bool
└───       goto #9 if not %9
3 ── %11 = Base.getproperty(parent::Tree{Int64}, :value)::Int64
β”‚    %12 = (t < %11)::Bool
└───       goto #5 if not %12
4 ── %14 = Core.apply_type(Main.Node, $(Expr(:static_parameter, 1)))::Core.Const(Node{Int64})
β”‚    %15 = Core.apply_type(Main.Ref, %14)::Core.Const(Ref{Node{Int64}})
β”‚    %16 = (%15)(parent::Tree{Int64})::Base.RefValue{Node{Int64}}
β”‚    %17 = Base.getproperty(parent::Tree{Int64}, :left)::Tree{Int64}
β”‚    %18 = Main.height(%17)::Int64
β”‚    %19 = Core.tuple(%16, %18, :left)::Core.PartialStruct(Tuple{Base.RefValue{Node{Int64}}, Int64, Symbol}, Any[Base.RefValue{Node{Int64}}, Int64, Core.Const(:left)])    
β”‚          Base.setindex!(stack, %19, sp)
β”‚          (sp = sp + 1)
β”‚          (parent = Base.getproperty(parent::Tree{Int64}, :left))
└───       goto #8
5 ── %24 = Base.getproperty(parent::Tree{Int64}, :value)::Int64
β”‚    %25 = (t > %24)::Bool
└───       goto #7 if not %25
6 ── %27 = Core.apply_type(Main.Node, $(Expr(:static_parameter, 1)))::Core.Const(Node{Int64})
β”‚    %28 = Core.apply_type(Main.Ref, %27)::Core.Const(Ref{Node{Int64}})
β”‚    %29 = (%28)(parent::Tree{Int64})::Base.RefValue{Node{Int64}}
β”‚    %30 = Base.getproperty(parent::Tree{Int64}, :right)::Tree{Int64}
β”‚    %31 = Main.height(%30)::Int64
β”‚    %32 = Core.tuple(%29, %31, :right)::Core.PartialStruct(Tuple{Base.RefValue{Node{Int64}}, Int64, Symbol}, Any[Base.RefValue{Node{Int64}}, Int64, Core.Const(:right)])  
β”‚          Base.setindex!(stack, %32, sp)
β”‚          (sp = sp + 1)
β”‚          (parent = Base.getproperty(parent::Tree{Int64}, :right))
└───       goto #8
7 ──       goto #9
8 ┄─       goto #2
9 ┄─ %39 = (parent::Tree{Int64} == Main.nothing)::Bool
└───       goto #11 if not %39
10 ─       (tree@_11 = Main.node(t, Main.nothing, Main.nothing, 1))
└───       goto #12
11 ─       (tree@_11 = parent::Tree{Int64})
12 β”„ %44 = (sp != 1)::Bool
└───       goto #21 if not %44
13 ─       (sp = sp - 1)
β”‚    %47 = Base.getindex(stack, sp)::Tuple{Ref{Node{Int64}}, Int64, Symbol}
β”‚    %48 = Base.indexed_iterate(%47, 1)::Core.PartialStruct(Tuple{Ref{Node{Int64}}, Int64}, Any[Ref{Node{Int64}}, Core.Const(2)])
β”‚          (parent_ref = Core.getfield(%48, 1))
β”‚          (@_7 = Core.getfield(%48, 2))
β”‚    %51 = Base.indexed_iterate(%47, 2, @_7::Core.Const(2))::Core.PartialStruct(Tuple{Int64, Int64}, Any[Int64, Core.Const(3)])
β”‚          (_height = Core.getfield(%51, 1))
β”‚          (@_7 = Core.getfield(%51, 2))
β”‚    %54 = Base.indexed_iterate(%47, 3, @_7::Core.Const(3))::Core.PartialStruct(Tuple{Symbol, Int64}, Any[Symbol, Core.Const(4)])
β”‚          (loc = Core.getfield(%54, 1))
β”‚          (parent = Base.getindex(parent_ref))
β”‚    %57 = (loc == :left)::Bool
└───       goto #17 if not %57
14 ─       Base.setproperty!(parent, :left, tree@_11)
β”‚    %60 = Base.getproperty(parent, :left)::Any
β”‚    %61 = Main.height(%60)::Int64
β”‚    %62 = (%61 != _height)::Bool
└───       goto #16 if not %62
15 ─       (parent = Main.balance(parent))
16 β”„       goto #20
17 ─ %66 = (loc == :right)::Bool
└───       goto #20 if not %66
18 ─       Base.setproperty!(parent, :right, tree@_11)
β”‚    %69 = Base.getproperty(parent, :right)::Any
β”‚    %70 = Main.height(%69)::Int64
β”‚    %71 = (%70 != _height)::Bool
└───       goto #20 if not %71
19 ─       (parent = Main.balance(parent))
20 β”„       (tree@_11 = parent)
└───       goto #12
21 ─       return tree@_11

It seems to me the problem starts with dereferencing the reference:

parent = parent_ref[]

Here is what I see for insert2

and here is the profile result for insert3

and I understand neither of them.

Here is what I found out about insert2: type inference does not seem to figure out that parent is a Node and not a Union(Nothing, Node). So after adding

            @assert isa(parent, Node)
            stack[sp] = (parent, height(parent.right), :right)

insert2 works as intended. In insert3 type inference also does not conclude that parent_ref[] is a Node if parent_ref is a Ref{Node}. Here helps changing parent = parent_ref[] to parent = convert(Node{T}, parent_ref[]). Are these solutions idiomatic?

Final code:

using BenchmarkTools
using StaticArrays

mutable struct Node{T}
    value::T
    left::Union{Nothing, Node{T}}
    right::Union{Nothing, Node{T}}
    height::Int
end
const Tree{T} = Union{Nothing, Node{T}}

nil() = nothing
node(t::T, left, right, height) where T = Node{T}(t, left, right, height)

height(tree::Nothing) = 0
height(tree::Node{T}) where T = tree.height

height!(tree::Nothing) = 0
function height!(tree::Node{T}) where T
    left = height(tree.left)
    right = height(tree.right)
    max(left, right) + 1
end

function rotateright(tree::Node{T}) where T
    tree.left.right, tree.left, tree =
        tree, tree.left.right, tree.left
    tree.right.height = height!(tree.right)
    tree
end
function rotateleft(tree::Node{T})  where T
    tree.right.left, tree.right, tree =
        tree, tree.right.left, tree.right
    tree.left.height = height!(tree.left)
    tree
end

balance(tree::Nothing) = nothing
function balance(tree::Node{T}) where T
    height_left = height(tree.left)
    height_right = height(tree.right)
    if height_left < height_right - 1
        tree = rotateleft(tree)
        height_left = tree.left.height
        height_right = tree.right.height
    elseif height_left > height_right + 1
        tree = rotateright(tree)
        height_left = tree.left.height
        height_right = tree.right.height
    end
    # _height = height!(tree)
    _height = max(height_left, height_right) + 1
    if tree.height != _height
        tree.height = _height
    end
    tree
end

insert1(tree::Nothing, t::T) where T = node(t, nothing, nothing, 1)
function insert1(tree::Node{T}, t::T) where T
    if t < tree.value
        _height = height(tree.left)
        tree.left = insert1(tree.left, t)
        if height(tree.left) != _height
            tree = balance(tree)
        end
    elseif t > tree.value
        _height = height(tree.right)
        tree.right = insert1(tree.right, t)
        if height(tree.right) != _height
            tree = balance(tree)
        end
    end
    tree
end

insert2(tree::Nothing, t::T, stack::Vector{Tuple{Node{T}, Int, Symbol}}) where T = node(t, nothing, nothing, 1)
function insert2(tree::Node{T}, t::T, stack::Vector{Tuple{Node{T}, Int, Symbol}}) where T
    sp = 1
    parent = tree
    while parent != nothing
        if t < parent.value
            @assert isa(parent, Node)
            @inbounds stack[sp] = (parent, height(parent.left), :left)
            sp += 1
            parent = parent.left
        elseif t > parent.value
            @assert isa(parent, Node)
            @inbounds stack[sp] = (parent, height(parent.right), :right)
            sp += 1
            parent = parent.right
        else
            break
        end
    end
    if parent == nothing
        tree = node(t, nothing, nothing, 1)
    else
        tree = parent
    end
    while sp != 1
        sp -= 1
        @inbounds parent, _height, loc = stack[sp]
        if loc == :left
            parent.left = tree
            if height(tree) != _height
                tree = balance(parent)
            else
                @inbounds tree = stack[1][1]
                break
            end
        elseif loc == :right
            parent.right = tree
            if height(tree) != _height
                tree = balance(parent)
            else
                @inbounds tree = stack[1][1]
                break
            end
        end
    end
    tree
end

Base.haskey(tree::Nothing, t::T) where T = false
function Base.haskey(tree::Node{T}, t::T) where T
    while tree != nothing
        if t < tree.value
            tree = tree.left
        elseif t > tree.value
            tree = tree.right
        else
            return true
        end
    end
    return false
end

function test1(n)
    tree = nil()
    for j in 1:n
        tree = insert1(tree, j)
    end
    @assert !haskey(tree, 0)
    for j in 1:n
        @assert haskey(tree, j)
    end
    @assert !haskey(tree, n + 1)
    tree
end

function test2(n)
    tree = nil()
    stack = Vector{Tuple{Node{Int}, Int, Symbol}}(undef, 64)
    for j in 1:n
        tree = insert2(tree, j, stack)
    end
    @assert !haskey(tree, 0)
    for j in 1:n
        @assert haskey(tree, j)
    end
    @assert !haskey(tree, n + 1)
    tree
end

for n in [10000, 100000, 1000000]
    GC.gc()
    @btime test1($n)
    GC.gc()
    @btime test2($n)
end

with output

  2.177 ms (10000 allocations: 468.75 KiB)
  4.018 ms (10001 allocations: 470.34 KiB)
  28.577 ms (100000 allocations: 4.58 MiB)
  42.629 ms (100001 allocations: 4.58 MiB)
  380.003 ms (1000000 allocations: 45.78 MiB)
  466.184 ms (1000001 allocations: 45.78 MiB)

My conclusion: probably only makes sense for very large n.