Trampolines

Hi guys,

I was playing around with trampolines, and the following program

using BenchmarkTools

# Naive (non-tail) recursive baselines.
# NOTE: with `Int` input, `fac` overflows silently for n > 20 (so fac(63) below
# is a wrapped-around value); both functions recurse until stack overflow for
# negative `n`.
fac(n) = n == 0 ? 1 : n * fac(n - 1)
# `one(x)` keeps the result type equal to `typeof(x)` also for n == 0
# (a literal `1` would make `pow(2.0, 0)` return an `Int`).
pow(x, n) = n == 0 ? one(x) : x * pow(x, n - 1)

# Baseline timings for the plain recursive definitions; `$` interpolates the
# globals so BenchmarkTools times the call rather than global-variable access.
n = 63
@btime(fac($n))

x, n = 2., 64
@btime(pow($x, $n))

# Run `f` as a trampoline: every tuple that `f` returns is splatted back into
# the next call of `f`; the first non-tuple result is returned as the answer.
# (The bound `T <: Union{Tuple, R} where R` admits any type, so `args` is
# effectively untyped.)
function trampoline(f::F, args::T) where {F <: Function, T <: Union{Tuple, R} where R}
   state = f(args...)
   while state isa Tuple
      state = f(state...)
   end
   return state
end

# Trampoline steps: return the accumulator `r` when finished, otherwise a tuple
# holding the arguments of the next step (consumed by `trampoline` above).
fact(n, r=1) = n == 0 ? r : (n - 1, n * r)
# The accumulator defaults to `one(x)` so `powt` returns `typeof(x)` also for
# n == 0 (a literal `1` would make `powt(2.0, 0)` return an `Int`).
powt(x, n, r=one(x)) = n == 0 ? r : (x, n - 1, x * r)

# Verify each trampolined version against its recursive counterpart, then
# time it.
n = 63
@assert trampoline(fact, (n,)) == fac(n)
@btime(trampoline(fact, ($n,)))

x, n = 2., 64
@assert trampoline(powt, (x, n)) == pow(x, n)
@btime(trampoline(powt, ($x, $n)))

showed

  368.932 ns (0 allocations: 0 bytes)
  525.131 ns (0 allocations: 0 bytes)
  26.707 ns (0 allocations: 0 bytes)
  33.434 ns (0 allocations: 0 bytes)

Am I doing something wrong here?

1 Like

Can you specify what you think the issue is?

I’m surprised by the performance difference: the trampolined versions run more than ten times faster than the recursive ones.

Loops are much easier to optimize than recursion, and the trampoline turns the recursion into a loop (no new stack frame per step). I don’t think you should be surprised.

By the way, the type signature for `args` in `trampoline` is exactly equivalent to `::Any`, and thus to no type signature at all: `T <: Union{Tuple, R} where R` admits every type, because `R` can be anything.

This is essentially a variation of the method that https://github.com/TakekazuKATO/TailRec.jl uses. My first experiments with trampolines and mutual recursion didn’t work out, but by (ab)using the iterator interface I came up with

using BenchmarkTools

# Plain mutually recursive parity test (each step adds a stack frame).
even(n) = n == 0 || odd(n - 1)
odd(n) = n != 0 && even(n - 1)

# Baseline timings for the non-trampolined mutual recursion.
n = 2^17
@btime(even($n))
@btime(odd($n))

"""
    Done(res)

Final bounce of a trampoline; `res` holds the computed result.
"""
struct Done{T}
   res::T
end

"""
    Call(f, args)

Pending bounce of a trampoline: `f(args...)` still has to be evaluated.
"""
struct Call{F<:Function, T<:Tuple}
   f::F
   args::T
end

# Evaluate a pending bounce by applying its function to its stored arguments.
function _call(call::Call{F, T}) where {F <: Function, T <: Tuple}
   f, args = call.f, call.args
   return f(args...)
end
# Unwrap the result carried by a final bounce.
function _return(done::Done{T}) where T
   return getfield(done, :res)
end

# Iterating a `Call` runs it as a trampoline: every step evaluates the pending
# bounce and yields the result (another `Call` or a final `Done`). The yielded
# bounce doubles as the iteration state, so iteration stops right after the
# first `Done` has been yielded.
function Base.iterate(call::Call)
   bounce = _call(call)
   return bounce, bounce
end
function Base.iterate(::Call, bounce)
   bounce = _call(bounce)
   return bounce, bounce
end
# A `Done` state means the previous step already produced the final result;
# dispatching on it replaces a runtime `isa` check.
Base.iterate(::Call, ::Done) = nothing

# The number of bounces is only known after running the chain; without this,
# `collect` and friends would try to call the nonexistent `length(::Call)`.
Base.IteratorSize(::Type{<:Call}) = Base.SizeUnknown()

# Run the trampoline `it` to completion and return its last bounce (the final
# `Done`). Restricted to `Call`, a type owned here: an untyped
# `Base.last(it)` would be type piracy and change `last` for every argument
# type that lacks a more specific method.
function Base.last(it::Call)
   y = nothing
   for x in it
      y = x
   end
   return y
end

# Trampolined parity test: each step returns either the final `Done` result or
# a `Call` describing the next (mutually recursive) step.
function event(n)
   n == 0 && return Done(true)
   return Call(oddt, (n - 1,))
end
function oddt(n)
   n == 0 && return Done(false)
   return Call(event, (n - 1,))
end

# Verify the trampolined pair against the recursive one at n = 2^17, then time
# it at n = 2^18.
n = 2^17
@assert last(Call(event, (n,))).res == even(n)
@assert last(Call(oddt, (n,))).res == odd(n)

n = 2^18
@btime(last(Call(event, ($n,))))
@btime(last(Call(oddt, ($n,))))

resulting in

  675.700 μs (0 allocations: 0 bytes)
  675.700 μs (0 allocations: 0 bytes)
  1.314 ms (1 allocation: 16 bytes)
  1.258 ms (1 allocation: 16 bytes)

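Note that n = 2^18 causes a stack overflow for the non-TCO’ed version on my system. That is why the recursive versions above are benchmarked with n = 2^17 and the trampolined ones with n = 2^18; keep this in mind when comparing the timings.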

Edit: this technique seems slightly better than the one used by `@bounce` in https://github.com/MikeInnes/Lazy.jl

Edit: pirating `Base.last` was probably not a good idea.

Interesting. It seems one can get allocation-free TCO for mutually recursive functions in Julia this way. MWE:

using BenchmarkTools

"""
    Done(res)

Final bounce of a trampoline; `res` holds the computed result.
"""
struct Done{T}
   res::T
end

"""
    Call(f, args)

Pending bounce of a trampoline: `f(args...)` still has to be evaluated.
"""
struct Call{F<:Function, T<:Tuple}
   f::F
   args::T
end

# Evaluate a pending bounce by applying its function to its stored arguments.
function _call(call::Call{F, T}) where {F <: Function, T <: Tuple}
   f, args = call.f, call.args
   return f(args...)
end
# Unwrap a bounce: an intermediate `Call` carries no result yet (`nothing`),
# a final `Done` carries the result value.
function _return(call::Call{F, T}) where {F <: Function, T <: Tuple}
   return nothing
end
function _return(done::Done{T}) where T
   return getfield(done, :res)
end

# Iterating a `Call` runs it as a trampoline: every step evaluates the pending
# bounce and yields the result (another `Call` or a final `Done`). The yielded
# bounce doubles as the iteration state, so iteration stops right after the
# first `Done` has been yielded.
function Base.iterate(call::Call)
   bounce = _call(call)
   return bounce, bounce
end
function Base.iterate(::Call, bounce)
   bounce = _call(bounce)
   return bounce, bounce
end
# A `Done` state means the previous step already produced the final result;
# dispatching on it replaces a runtime `isa` check.
Base.iterate(::Call, ::Done) = nothing

# The number of bounces is only known after running the chain; without this,
# `collect` and friends would try to call the nonexistent `length(::Call)`.
Base.IteratorSize(::Type{<:Call}) = Base.SizeUnknown()

# Drive the bounce iterator `it` to completion and return the unwrapped value
# of its last bounce (`nothing` when `it` yields no bounces).
function trampoline(it)
   result = nothing
   next = iterate(it)
   while next !== nothing
      bounce, state = next
      result = _return(bounce)
      next = iterate(it, state)
   end
   return result
end

# Rewrite a call that sits in tail position.
# - A call to one of the functions in `names` becomes a pending bounce
#   `Call(<name>step, (args...,))` instead of a real, stack-growing call.
# - Any other call computes the final value, so it is wrapped in `Done`.
#   (Previously this branch called the undefined `rewrite`, so any such
#   expression raised an UndefVarError during macro expansion.)
function rewriterhs(names, ::Val{:call}, args)
   if args[1] in names
      Expr(:call, :Call, Symbol(args[1], :step),
           Expr(:tuple, args[2:end]...))
   else
      Expr(:call, :Done, Expr(:call, args...))
   end
end
# Rewrite a definition's signature: a function listed in `names` is renamed to
# `<name>step` and its parameters are rewritten (dropping defaults); any other
# call is rebuilt unchanged.
function rewritelhs(names, ::Val{:call}, args)
   fname = args[1]
   fname in names || return Expr(:call, args...)
   params = (rewritelhs(names, a) for a in args[2:end])
   return Expr(:call, Symbol(fname, :step), params...)
end

# Conditionals: the condition (args[1]) is evaluated normally and kept as is;
# only the branches are in tail position and get rewritten.
function rewriterhs(names, ::Val{:if}, args)
   Expr(:if, args[1], (rewriterhs(names, arg) for arg in args[2:end])...)
end

# `elseif` clauses of a long-form `if` have the same layout as `:if`.
function rewriterhs(names, ::Val{:elseif}, args)
   Expr(:elseif, args[1], (rewriterhs(names, arg) for arg in args[2:end])...)
end

# Blocks: only the last expression (the block's value) is in tail position;
# earlier statements are kept verbatim. Rewriting them as well would wrap
# non-tail values in `Done` or fail on heads such as `=`.
function rewriterhs(names, ::Val{:block}, args)
   isempty(args) && return Expr(:block)
   Expr(:block, args[1:end-1]..., rewriterhs(names, args[end]))
end

# A parameter with a default value (`r=1`) becomes the bare name `r`.
function rewritelhs(names, ::Val{:kw}, args)
   args[1]
end

# Leaves of the tail-position rewrite.
# Line-number metadata passes through unchanged.
function rewriterhs(names, lineNumberNode::LineNumberNode)
   lineNumberNode
end
# A literal or variable in tail position is a final value: wrap it in `Done`.
function rewriterhs(names, number::Number)
   Expr(:call, :Done, number)
end
function rewriterhs(names, symbol::Symbol)
   Expr(:call, :Done, symbol)
end
# Any other non-`Expr` leaf (string, char, quoted symbol, ...) in tail
# position is also a final value.
function rewriterhs(names, literal)
   Expr(:call, :Done, literal)
end
# Plain parameter names in a signature stay as they are.
function rewritelhs(names, symbol::Symbol)
   symbol
end
# Expressions dispatch on their head via `Val` to the specific rewriters.
function rewriterhs(names, expr::Expr)
   rewriterhs(names, Val(expr.head), expr.args)
end
function rewritelhs(names, expr::Expr)
   rewritelhs(names, Val(expr.head), expr.args)
end

# For a definition `f(args...) = body` (no `where` clause), the first argument
# of the definition expression is the call signature `f(args...)`.
function funcname(expr)
   signature = expr.args[1]
   return signature.args[1]
end
function funcargs(expr)
   signature = expr.args[1]
   return signature.args[2:end]
end

# Build the `<name>step` definition: rename the signature and turn every
# tail-position result of the body into a `Call`/`Done` bounce.
function tcostep(names, expr)
   lhs = rewritelhs(names, expr.args[1])
   rhs = rewriterhs(names, expr.args[2])
   return Expr(expr.head, lhs, rhs)
end
# Build the user-facing wrapper: it keeps the original signature (including
# defaults) and runs `trampoline(Call(<name>step, (params...,)))`.
function tcotrampoline(names, expr)
   stepname = Symbol(funcname(expr), :step)
   params = Expr(:tuple, (rewritelhs(names, arg) for arg in funcargs(expr))...)
   bounce = Expr(:call, :Call, stepname, params)
   body = Expr(:block, Expr(:call, :trampoline, bounce))
   return Expr(expr.head, expr.args[1], body)
end

# `@tco names def` emits two definitions for `def`: the `<name>step` function
# (built by `tcostep`) and a wrapper with the original name that trampolines
# it (built by `tcotrampoline`). `names` lists every function taking part in
# the (mutual) tail recursion. It is evaluated once, in the caller's module,
# so it may also be a constant defined there.
macro tco(names, expr)
   fnames = Core.eval(__module__, names)
   Expr(:block,
        esc(tcostep(fnames, expr)),
        esc(tcotrampoline(fnames, expr)))
end

# Debug helper: print the unevaluated expression at macro-expansion time; the
# macro itself expands to `nothing`.
# NOTE(review): same name as `Base.@dump`; confirm the shadowing is intended.
macro dump(expr)
   println(expr)
   return nothing
end

# Each `@tco` line defines a `<name>step` function, whose tail calls return
# `Call`/`Done` bounces, plus a `<name>` wrapper that runs
# `trampoline(Call(<name>step, args))`. The vector lists every function that
# takes part in the (mutual) tail recursion.
# NOTE(review): fac and pow never terminate for negative n, and fib never
# terminates for n < 1. Under the trampoline this loops forever instead of
# overflowing the stack.
@tco [:fac] (fac(n, r=1) = n == 0 ? r : fac(n - 1, n * r))
@tco [:pow] (pow(x, n, r=1) = n == 0 ? r : pow(x, n - 1, x * r))
@tco [:fib] (fib(n, r1=1, r2=1) = n == 1 ? r1 : (n == 2 ? r2 : fib(n - 1, r2, r1 + r2)))
@tco [:even, :odd] (even(n) = n == 0 ? true : odd(n - 1))
@tco [:even, :odd] (odd(n) = n == 0 ? false : even(n - 1))

# Time the TCO'd versions; each label is printed right before its timing line.
n = 63
print("fac ")
@btime(fac($n))

x, n = 2., 64
print("pow ")
@btime(pow($x, $n))

n = 10
print("fib ")
@btime(fib($n))

n = 2^17
print("even ")
@btime(even($n))
print("odd ")
@btime(odd($n))

with results

fac   47.976 ns (0 allocations: 0 bytes)
pow   41.802 ns (0 allocations: 0 bytes)
fib   10.210 ns (0 allocations: 0 bytes)
even   83.800 μs (0 allocations: 0 bytes)
odd   83.800 μs (0 allocations: 0 bytes)

The macros are probably brittle, since I’m new to writing macros.

2 Likes

The next question is what to do with this. I see a couple of options:

Opinions?