Interesting. It seems one can get allocation-free TCO (tail-call optimization) for mutually recursive functions in Julia this way. MWE:
using BenchmarkTools
"""
    Done{T}

Final bounce of a trampolined computation; `res` holds the finished result.
"""
struct Done{T}
    res::T
end
"""
    Call{F<:Function,T<:Tuple}

Deferred function call: the function `f` together with the tuple `args` of
arguments it is to be applied to.
"""
struct Call{F<:Function,T<:Tuple}
    f::F
    args::T
end
# Perform the deferred call: apply the stored function to the stored arguments.
function _call(call::Call{F,T}) where {F<:Function,T<:Tuple}
    target = call.f
    return target(call.args...)
end
# A pending `Call` carries no result yet.
function _return(call::Call{F,T}) where {F<:Function,T<:Tuple}
    return nothing
end

# A `Done` bounce yields the wrapped result.
function _return(done::Done{T}) where {T}
    return done.res
end
# Iteration step for a trampolined `Call`. The state `bounce` is the previously
# produced bounce: once it is a `Done` the iteration stops, otherwise the pending
# call is executed and its outcome is both the yielded item and the next state.
function Base.iterate(::Call{F,T}, bounce) where {F<:Function,T<:Tuple}
    bounce isa Done && return nothing
    next = _call(bounce)
    return (next, next)
end
# Start of the iteration: execute the initial call and use its outcome as both
# the first yielded item and the iteration state.
function Base.iterate(call::Call{F,T}) where {F<:Function,T<:Tuple}
    first_bounce = _call(call)
    return (first_bounce, first_bounce)
end
"""
    trampoline(it)

Iterate `it` to exhaustion, passing every produced bounce to `_return`, and
return the value obtained for the last bounce (`nothing` if `it` is empty).
"""
function trampoline(it)
    result = nothing
    step = iterate(it)
    while step !== nothing
        bounce, state = step
        result = _return(bounce)
        step = iterate(it, state)
    end
    return result
end
"""
    rewriterhs(names, ::Val{:call}, args)

Rewrite a call expression that sits in tail position of a function body.

`args` are the arguments of the `:call` expression, i.e. the callee followed by
its arguments. If the callee is one of `names`, the call is turned into a
deferred `Call(<callee>step, (arguments...,))`. Any other call is left as is and
its value is wrapped in `Done(...)`, because it is the final result of the
computation.
"""
function rewriterhs(names, ::Val{:call}, args)
    callee = args[1]
    if callee in names
        # Tail call to a trampolined function: defer it instead of calling it.
        return Expr(:call, :Call, Symbol(callee, :step), Expr(:tuple, args[2:end]...))
    else
        # Ordinary call (e.g. `n * r`): evaluate it normally; its value is the result.
        return Expr(:call, :Done, Expr(:call, args...))
    end
end
"""
    rewritelhs(names, ::Val{:call}, args)

Rewrite the call expression forming a function signature. When the function name
(`args[1]`) is listed in `names`, the name gets the suffix `step` and every
parameter is rewritten with `rewritelhs`; otherwise the call is rebuilt unchanged.
"""
function rewritelhs(names, ::Val{:call}, args)
    fname = args[1]
    fname in names || return Expr(:call, args...)
    params = [rewritelhs(names, param) for param in args[2:end]]
    return Expr(:call, Symbol(fname, :step), params...)
end
"""
    rewriterhs(names, ::Val{:if}, args)

Rewrite a conditional in tail position: the condition (`args[1]`) is kept as is,
every remaining argument (the branches) is rewritten with `rewriterhs`.
"""
function rewriterhs(names, ::Val{:if}, args)
    condition = args[1]
    branches = [rewriterhs(names, branch) for branch in args[2:end]]
    return Expr(:if, condition, branches...)
end
"""
    rewriterhs(names, ::Val{:block}, args)

Rewrite a block in tail position.

Only the last statement of a block is in tail position, so only that statement
is rewritten with `rewriterhs`. All preceding statements (line number nodes,
assignments, calls made for their side effects, ...) are kept untouched. An
empty block stays an empty block.
"""
function rewriterhs(names, ::Val{:block}, args)
    isempty(args) && return Expr(:block)
    leading = args[1:end-1]
    tail = rewriterhs(names, args[end])
    return Expr(:block, leading..., tail)
end
"""
    rewritelhs(names, ::Val{:kw}, args)

Rewrite a parameter with a default value (`name = default`): only the first
argument, the parameter itself, is kept and the default is dropped.
"""
function rewritelhs(names, ::Val{:kw}, args)
    param = args[1]
    return param
end
# Line number nodes carry no value; they are passed through unchanged.
rewriterhs(names, lineNumberNode::LineNumberNode) = lineNumberNode
# A numeric literal in tail position is a final result: wrap it as `Done(number)`.
rewriterhs(names, number::Number) = :(Done($number))
# A variable in tail position is a final result: wrap it as `Done(symbol)`.
rewriterhs(names, symbol::Symbol) = :(Done($symbol))
# A plain parameter name in a signature is kept unchanged.
rewritelhs(names, symbol::Symbol) = symbol
# Dispatch on the head of the expression: the head is lifted into a `Val` so that
# the matching `rewriterhs(names, ::Val{head}, args)` method is selected.
function rewriterhs(names, expr::Expr)
    head = expr.head
    return rewriterhs(names, Val(head), expr.args)
end
# Dispatch on the head of the expression: the head is lifted into a `Val` so that
# the matching `rewritelhs(names, ::Val{head}, args)` method is selected.
function rewritelhs(names, expr::Expr)
    head = expr.head
    return rewritelhs(names, Val(head), expr.args)
end
# Name of the function defined by `expr`: the first argument of the signature
# call, which itself is the first argument of the definition.
function funcname(expr)
    signature = expr.args[1]
    return signature.args[1]
end
# Parameters of the function defined by `expr`: everything after the name in the
# signature call.
function funcargs(expr)
    signature = expr.args[1]
    return signature.args[2:end]
end
"""
    tcostep(names, expr)

Build the step-function definition for the function definition `expr`: the
signature (`expr.args[1]`) is rewritten with `rewritelhs`, the body
(`expr.args[2]`) with `rewriterhs`, and both are combined under the original head.
"""
function tcostep(names, expr)
    stepsig = rewritelhs(names, expr.args[1])
    stepbody = rewriterhs(names, expr.args[2])
    return Expr(expr.head, stepsig, stepbody)
end
"""
    tcotrampoline(names, expr)

Build the entry-point definition for the function definition `expr`. It keeps
the original signature, and its body calls `trampoline` on a `Call` of the
`<name>step` function with the (rewritten) parameters as arguments.
"""
function tcotrampoline(names, expr)
    stepname = Symbol(funcname(expr), :step)
    params = [rewritelhs(names, arg) for arg in funcargs(expr)]
    entry = Expr(:call, :Call, stepname, Expr(:tuple, params...))
    body = Expr(:block, Expr(:call, :trampoline, entry))
    return Expr(expr.head, expr.args[1], body)
end
"""
    @tco names expr

Define a trampolined version of the function definition `expr`.

`names` is an expression evaluating to the collection of function names (as
symbols) whose tail calls are to be deferred, e.g. `[:even, :odd]`. The macro
emits two definitions: the `<name>step` function built by `tcostep` and the
entry point with the original signature built by `tcotrampoline`.
"""
macro tco(names, expr)
    # Evaluate the name list once, in the module where the macro is used, so that
    # anything it refers to is looked up in the caller's scope.
    resolved = Core.eval(__module__, names)
    return Expr(:block,
        esc(tcostep(resolved, expr)),
        esc(tcotrampoline(resolved, expr)))
end
"""
    @dump expr

Debugging aid: print the unevaluated expression `expr` while the macro is being
expanded. The macro itself expands to `nothing`.
"""
macro dump(expr)
    println(stdout, expr)
    return nothing
end
# Trampolined example functions. The first macro argument lists the function
# names whose tail calls are turned into deferred `Call`s; the mutually recursive
# `even`/`odd` pair therefore lists both names in each definition.
@tco [:fac] (fac(n, r=1) = n == 0 ? r : fac(n - 1, n * r))
@tco [:pow] (pow(x, n, r=1) = n == 0 ? r : pow(x, n - 1, x * r))
@tco [:fib] (fib(n, r1=1, r2=1) = n == 1 ? r1 : (n == 2 ? r2 : fib(n - 1, r2, r1 + r2)))
@tco [:even, :odd] (even(n) = n == 0 ? true : odd(n - 1))
@tco [:even, :odd] (odd(n) = n == 0 ? false : even(n - 1))
# Benchmarks. Each one prints a label and then times a single call; the globals
# are passed with `$` interpolation, as recommended for BenchmarkTools.
n = 63
print("fac ")
@btime fac($n)
x = 2.
n = 64
print("pow ")
@btime pow($x, $n)
n = 10
print("fib ")
@btime fib($n)
# 2^17 bounces between `even` and `odd`.
n = 2^17
print("even ")
@btime even($n)
print("odd ")
@btime odd($n)
with results
fac 47.976 ns (0 allocations: 0 bytes)
pow 41.802 ns (0 allocations: 0 bytes)
fib 10.210 ns (0 allocations: 0 bytes)
even 83.800 μs (0 allocations: 0 bytes)
odd 83.800 μs (0 allocations: 0 bytes)
The macros are probably brittle, because I’m new to writing them.