Hi all,
Iβm trying to convert some code from recursion to iteration keeping an internal stack. My benchmark is showing no improvement. MWE
using BenchmarkTools
mutable struct Node{T}
value::T
left::Union{Nothing, Node{T}}
right::Union{Nothing, Node{T}}
height::Int
end
const Tree{T} = Union{Nothing, Node{T}}
nil() = nothing
node(t::T, left, right, height) where T = Node{T}(t, left, right, height)
height(tree::Nothing) = 0
height(tree::Node{T}) where T = tree.height
height!(tree::Nothing) = 0
function height!(tree::Node{T}) where T
left = height(tree.left)
right = height(tree.right)
max(left, right) + 1
end
function rotateright(tree::Node{T}) where T
tree.left.right, tree.left, tree =
tree, tree.left.right, tree.left
tree.right.height = height!(tree.right)
tree
end
function rotateleft(tree::Node{T}) where T
tree.right.left, tree.right, tree =
tree, tree.right.left, tree.right
tree.left.height = height!(tree.left)
tree
end
balance(tree::Nothing) = nothing
function balance(tree::Node{T}) where T
height_left = height(tree.left)
height_right = height(tree.right)
if height_left < height_right - 1
tree = rotateleft(tree)
height_left = tree.left.height
height_right = tree.right.height
elseif height_left > height_right + 1
tree = rotateright(tree)
height_left = tree.left.height
height_right = tree.right.height
end
# _height = height!(tree)
_height = max(height_left, height_right) + 1
if tree.height != _height
tree.height = _height
end
tree
end
insert1(tree::Nothing, t::T) where T = node(t, nothing, nothing, 1)
function insert1(tree::Node{T}, t::T) where T
if t < tree.value
_height = height(tree.left)
tree.left = insert1(tree.left, t)
if height(tree.left) != _height
tree = balance(tree)
end
elseif t > tree.value
_height = height(tree.right)
tree.right = insert1(tree.right, t)
if height(tree.right) != _height
tree = balance(tree)
end
end
tree
end
insert2(tree::Nothing, t::T) where T = node(t, nothing, nothing, 1)
function insert2(tree::Node{T}, t::T) where T
stack = Vector{Tuple{Node{T}, Int, Symbol}}(undef, 64)
sp = 1
parent = tree
while parent != nothing
if t < parent.value
stack[sp] = (parent, height(parent.left), :left)
sp += 1
parent = parent.left
elseif t > parent.value
stack[sp] = (parent, height(parent.right), :right)
sp += 1
parent = parent.right
else
break
end
end
if parent == nothing
tree = node(t, nothing, nothing, 1)
else
tree = parent
end
while sp != 1
sp -= 1
parent, _height, loc = stack[sp]
if loc == :left
parent.left = tree
if height(parent.left) != _height
parent = balance(parent)
end
elseif loc == :right
parent.right = tree
if height(parent.right) != _height
parent = balance(parent)
end
end
tree = parent
end
tree
end
insert3(tree::Nothing, t::T) where T = node(t, nothing, nothing, 1)
function insert3(tree::Node{T}, t::T) where T
stack = Vector{Tuple{Ref{Node{T}}, Int, Symbol}}(undef, 64)
sp = 1
parent = tree
while parent != nothing
if t < parent.value
stack[sp] = (Ref{Node{T}}(parent), height(parent.left), :left)
sp += 1
parent = parent.left
elseif t > parent.value
stack[sp] = (Ref{Node{T}}(parent), height(parent.right), :right)
sp += 1
parent = parent.right
else
break
end
end
if parent == nothing
tree = node(t, nothing, nothing, 1)
else
tree = parent
end
while sp != 1
sp -= 1
parent_ref, _height, loc = stack[sp]
parent = parent_ref[]
if loc == :left
parent.left = tree
if height(parent.left) != _height
parent = balance(parent)
end
elseif loc == :right
parent.right = tree
if height(parent.right) != _height
parent = balance(parent)
end
end
tree = parent
end
tree
end
Base.haskey(tree::Nothing, t::T) where T = false
function Base.haskey(tree::Node{T}, t::T) where T
while tree != nothing
if t < tree.value
tree = tree.left
elseif t > tree.value
tree = tree.right
else
return true
end
end
return false
end
function test1(n)
tree = nil()
for j in 1:n
tree = insert1(tree, j)
end
@assert !haskey(tree, 0)
for j in 1:n
@assert haskey(tree, j)
end
@assert !haskey(tree, n + 1)
tree
end
function test2(n)
tree = nil()
for j in 1:n
tree = insert2(tree, j)
end
@assert !haskey(tree, 0)
for j in 1:n
@assert haskey(tree, j)
end
@assert !haskey(tree, n + 1)
tree
end
function test3(n)
tree = nil()
for j in 1:n
tree = insert3(tree, j)
end
@assert !haskey(tree, 0)
for j in 1:n
@assert haskey(tree, j)
end
@assert !haskey(tree, n + 1)
tree
end
GC.gc()
@btime test1(10000)
GC.gc()
@btime test2(10000)
GC.gc()
@btime test3(10000)
println()
with result
2.299 ms (10000 allocations: 468.75 KiB)
19.735 ms (143616 allocations: 19.79 MiB)
21.359 ms (276723 allocations: 12.21 MiB)
Is there a way to get rid of the additional allocations?