Trampolines

Interesting. It seems one can get allocation-free TCO for mutually recursive functions in Julia this way. MWE:

using BenchmarkTools

"""
    Done{T}

Terminal trampoline value: wraps the final result `res` of a computation.
"""
struct Done{T}
    res::T
end
"""
    Call{F<:Function,T<:Tuple}

Deferred call: a function `f` together with the tuple `args` it is to be applied to.
"""
struct Call{F<:Function,T<:Tuple}
    f::F
    args::T
end

# Perform a deferred call: apply the stored function to the splatted stored arguments.
function _call(call::Call)
    return call.f(call.args...)
end
# Extract the result carried by a bounce: a pending `Call` has none (`nothing`),
# a `Done` yields its wrapped value.
_return(::Call) = nothing
_return(done::Done) = done.res

# Iteration step with state: the state is the bounce produced by the previous step.
# Once that bounce is a `Done`, iteration stops; otherwise the pending call is
# executed and its result is both the yielded item and the next state.
function Base.iterate(::Call, bounce)
    bounce isa Done && return nothing
    next = _call(bounce)
    return next, next
end
# Initial iteration step: execute the starting call; its result is yielded and
# also becomes the iteration state.
function Base.iterate(call::Call)
    first_bounce = _call(call)
    return (first_bounce, first_bounce)
end

"""
    trampoline(it)

Iterate `it` to exhaustion, applying `_return` to every item, and return the
value obtained from the last item (`nothing` if `it` yields no items).
"""
function trampoline(it)
    result = nothing
    step = iterate(it)
    while step !== nothing
        bounce, state = step
        result = _return(bounce)
        step = iterate(it, state)
    end
    return result
end

"""
    rewriterhs(names, ::Val{:call}, args)

Rewrite a call expression found in tail position.

`args` are the arguments of the `:call` expression, i.e. the callee followed by
its call arguments.

- If the callee is listed in `names`, the call is turned into the expression
  `Call(<callee>step, (call arguments...))`, deferring it to the trampoline.
- Any other call is an ordinary computation whose value is the final result,
  so it is left intact and wrapped as `Done(<call>)`.
"""
function rewriterhs(names, ::Val{:call}, args)
    callee = args[1]
    if callee in names
        return Expr(:call, :Call, Symbol(callee, :step), Expr(:tuple, args[2:end]...))
    end
    # The original branch invoked an undefined `rewrite` function here. A call
    # to a function outside `names` ends the recursion, so its value must be
    # returned to the trampoline inside a `Done`.
    return Expr(:call, :Done, Expr(:call, args...))
end
"""
    rewritelhs(names, ::Val{:call}, args)

Rewrite the call expression of a function signature. When the function name
(`args[1]`) is in `names`, the name gets the suffix `step` and every parameter
is passed through `rewritelhs`; otherwise the call is rebuilt unchanged.
"""
function rewritelhs(names, ::Val{:call}, args)
    fname = args[1]
    fname in names || return Expr(:call, args...)
    params = map(arg -> rewritelhs(names, arg), args[2:end])
    return Expr(:call, Symbol(fname, :step), params...)
end

"""
    rewriterhs(names, ::Val{:if}, args)

Rewrite an `:if` expression: the condition (`args[1]`) is kept as is, while
every remaining argument (the branches) is rewritten with `rewriterhs`.
"""
function rewriterhs(names, ::Val{:if}, args)
    condition = args[1]
    branches = map(branch -> rewriterhs(names, branch), args[2:end])
    return Expr(:if, condition, branches...)
end

"""
    rewriterhs(names, ::Val{:block}, args)

Rewrite a `:block` expression found in tail position.

Only the last statement of a block determines its value, so only the last
statement that is not a `LineNumberNode` is rewritten with `rewriterhs`; all
preceding statements are kept verbatim. A block without such a statement is
rebuilt unchanged.
"""
function rewriterhs(names, ::Val{:block}, args)
    tail = findlast(stmt -> !(stmt isa LineNumberNode), args)
    tail === nothing && return Expr(:block, args...)
    # Copy so that the argument vector of the input expression is not mutated.
    statements = collect(Any, args)
    # Rewriting every statement (as before) wrapped intermediate statements in
    # `Done`/`Call` and failed on heads without a rewrite method (e.g. `=`).
    statements[tail] = rewriterhs(names, args[tail])
    return Expr(:block, statements...)
end

"""
    rewritelhs(names, ::Val{:kw}, args)

Reduce a `:kw` expression (a parameter with a default value) to its first
argument, dropping the default.
"""
function rewritelhs(names, ::Val{:kw}, args)
    parameter = args[1]
    return parameter
end

# Line-number nodes carry no value; they pass through the rewrite untouched.
rewriterhs(names, lineNumberNode::LineNumberNode) = lineNumberNode
"""
    rewriterhs(names, literal)

Fallback for a tail-position node that has no more specific `rewriterhs`
method: the node is a final value and is wrapped as `Done(literal)`.

This covers numeric literals (the only case handled previously) as well as
other literal nodes such as strings, characters and `QuoteNode`s, which used
to raise a `MethodError`.
"""
function rewriterhs(names, literal)
    return Expr(:call, :Done, literal)
end
# A bare symbol in tail position is a final value: wrap it as `Done(symbol)`.
rewriterhs(names, symbol::Symbol) = :(Done($symbol))
# A plain parameter name in a signature is kept unchanged.
rewritelhs(names, symbol::Symbol) = symbol
# Dispatch on the head of the expression: forwards to the
# `rewriterhs(names, ::Val{head}, args)` method for this head.
function rewriterhs(names, expr::Expr)
    head = Val{expr.head}()
    return rewriterhs(names, head, expr.args)
end
# Dispatch on the head of the expression: forwards to the
# `rewritelhs(names, ::Val{head}, args)` method for this head.
function rewritelhs(names, expr::Expr)
    head = Val{expr.head}()
    return rewritelhs(names, head, expr.args)
end

# Name of the function being defined: the first argument of the signature,
# where the signature is the first argument of the definition expression.
function funcname(expr)
    signature = expr.args[1]
    return signature.args[1]
end
# Parameters of the function being defined: everything after the name in the
# signature, where the signature is the first argument of the definition.
function funcargs(expr)
    signature = expr.args[1]
    parts = signature.args
    return parts[2:lastindex(parts)]
end

"""
    tcostep(names, expr)

Build the step-function definition for the function definition `expr`: the
signature (`expr.args[1]`) is rewritten with `rewritelhs`, the body
(`expr.args[2]`) with `rewriterhs`, and both are recombined under the
original expression head.
"""
function tcostep(names, expr)
    head = expr.head
    step_signature = rewritelhs(names, expr.args[1])
    step_body = rewriterhs(names, expr.args[2])
    return Expr(head, step_signature, step_body)
end
"""
    tcotrampoline(names, expr)

Build the entry-point definition for the function definition `expr`. It keeps
the original signature; its body is a single call
`trampoline(Call(<name>step, (parameters...)))`, where each parameter of the
signature has been passed through `rewritelhs`.
"""
function tcotrampoline(names, expr)
    head = expr.head
    signature = expr.args[1]
    stepname = Symbol(funcname(expr), :step)
    forwarded = map(arg -> rewritelhs(names, arg), funcargs(expr))
    start = Expr(:call, :Call, stepname, Expr(:tuple, forwarded...))
    body = Expr(:block, Expr(:call, :trampoline, start))
    return Expr(head, signature, body)
end

"""
    @tco names expr

Expand the function definition `expr` into two definitions: the step function
produced by `tcostep` and the entry point produced by `tcotrampoline`. Both
are escaped, so they are defined in the scope of the macro call.

`names` is an expression that evaluates to the collection of function names
whose calls are to be rewritten, e.g. `[:even, :odd]`.
"""
macro tco(names, expr)
    # Evaluate the name list exactly once, and in the module of the macro call
    # rather than the module defining the macro, so names visible at the call
    # site resolve as expected.
    recursive = Core.eval(__module__, names)
    return Expr(:block,
                esc(tcostep(recursive, expr)),
                esc(tcotrampoline(recursive, expr)))
end

"""
    @dump expr

Print the unevaluated expression `expr` when the macro is expanded. The
expansion itself is the return value of `println`, i.e. `nothing`.
"""
macro dump(expr)
    return println(expr)
end

# Self-recursive examples. Each definition lists its own name, so the recursive
# call in tail position is rewritten into a `Call` to the generated `...step`
# function. The parameters with default values serve as accumulators.
@tco [:fac] (fac(n, r=1) = n == 0 ? r : fac(n - 1, n * r))
@tco [:pow] (pow(x, n, r=1) = n == 0 ? r : pow(x, n - 1, x * r))
@tco [:fib] (fib(n, r1=1, r2=1) = n == 1 ? r1 : (n == 2 ? r2 : fib(n - 1, r2, r1 + r2)))
# Mutually recursive pair. Both names are listed for each definition so that a
# call to either function is rewritten.
@tco [:even, :odd] (even(n) = n == 0 ? true : odd(n - 1))
@tco [:even, :odd] (odd(n) = n == 0 ? false : even(n - 1))

# Benchmarks. Each section prints a label and then times one call with `@btime`;
# the globals are interpolated with `$` into the benchmarked expression.

# NOTE(review): 63! exceeds the range of a 64-bit Int, so the value returned here
# has presumably wrapped around; only the timing is of interest — confirm.
n = 63
print("fac ")
@btime fac($n)

# Floating-point base, so `pow` accumulates in floating point.
x = 2.
n = 64
print("pow ")
@btime pow($x, $n)

n = 10
print("fib ")
@btime fib($n)

# 2^17 alternating bounces between `even` and `odd`.
n = 2^17
print("even ")
@btime even($n)
print("odd ")
@btime odd($n)

with results

fac   47.976 ns (0 allocations: 0 bytes)
pow   41.802 ns (0 allocations: 0 bytes)
fib   10.210 ns (0 allocations: 0 bytes)
even   83.800 μs (0 allocations: 0 bytes)
odd   83.800 μs (0 allocations: 0 bytes)

The macros are probably brittle, because I’m new to writing them.

2 Likes