Transforming arbitrary Julia code to lazy evaluation

I’ve been experimenting with methods for transforming arbitrary Julia code for lazy evaluation (repo at https://github.com/tbenst/Thunks.jl). Here’s an example of what I’d like to be able to do in a fraction of a second:

add1(x) = x + 1
abc = @thunk begin
  b = identity(10)
  z = sleep(10)
  a = identity(collect(1:b))[8:end]
  c,d,e,f = [x-1 for x in 1:4]
  abc = sum(add1.(a) .* [c,d,f])
end
@time @assert reify(abc) == 43

However, transforming arbitrary code including dot broadcasting of functions or infix operators, variadic return values, and indexing is a bit tricky.

The core implementation is very terse:

mutable struct Thunk
    f::Any # usually a Function, but could be any callable
    args::Tuple # args will be passed to f
    kwargs::Dict # kwargs will be passed to f
    evaluated::Bool # false until computed, then true
    result::Any # cache result once computed
    Thunk(f, args, kwargs) = new(f, args, kwargs, false, nothing)
end

function thunk(f)
    (args...; kwargs...) -> Thunk(f, args, kwargs)
end

"""
    reify(thunk::Thunk)
    reify(value::Any)

Reify a thunk into a value.

In other words, compute the value of the expression.

We walk through the thunk's arguments and keywords, recursively evaluating each one,
and then evaluating the thunk's function with the evaluated arguments.
"""
function reify(thunk::Thunk)
    if thunk.evaluated
        return thunk.result
    else
        args = [reify(x) for x in thunk.args]
        kwargs = Dict(k => reify(v) for (k,v) in thunk.kwargs)
        thunk.result = thunk.f(args...; kwargs...)
        thunk.evaluated = true
        return thunk.result
    end
end

function reify(value)
    value
end

This does allow any arbitrary Julia expression to be written, albeit in an ugly, manual fashion:

b = thunk(identity)(10)
z = thunk(sleep)(10)
a = thunk(b->identity(collect(1:b))[8:end])(b)
res = thunk(()->[x-1 for x in 1:4])()
c,d,e,f = [thunk(getindex)(res,1), thunk(getindex)(res,2),
    thunk(getindex)(res,3), thunk(getindex)(res,4)]
abc = thunk((a,c,d,f)->sum(add1.(a) .* [c,d,f]))(a,c,d,f)
@time @assert reify(abc) == 43

As a programmer writing this transform, I have to parse each line, and look for the existence of a symbol that is a Thunk, and then I use each of these symbols as a variable when wrapping in a function.

In theory, this should be very fast and easy to do thanks to the homoiconic nature of Julia.


"Walk an AST and find all unique symbols."
function find_symbols_in_ast(ex)
    # accumulate sub-expressions for recursion
    expressions = [ex]
    # accumulate found symbols
    symbols = []

    # since no Tail Call Optimization in Julia,
    # we write recursion in a while loop
    while length(expressions) > 0
        ex = pop!(expressions)
        first, rest = _find_symbols_in_ast(ex)
        if typeof(first) == Symbol
            # got value, no more recursion
            push!(symbols, first)
        end
        if ~isnothing(rest)
            # recur
            expressions = vcat(expressions, rest)
        end
    end
    return unique(symbols)
end

"Helper function following `cons` and `nil` pattern from Lisp."
function _find_symbols_in_ast(ex)
    if typeof(ex) == Expr
        head, args = ex.head, ex.args
        return nothing, args
    elseif typeof(ex) == Symbol
        return ex, nothing
    else
        return nothing, nothing
    end
end

isthunk(x) = typeof(x) == Thunk

"Safely evaluate symbols that may not have an assignment."
function safe_eval_isthunk(ex)
    try
        return isthunk(eval(ex))
    catch
        return false
    end
end

"Return array of symbols that are assigned to a Thunk."
function find_thunks_in_ast(ex)
    symbols = find_symbols_in_ast(ex)
    filter(safe_eval_isthunk, symbols)
end

Unfortunately, find_thunks_in_ast does not work if Thunks are assigned to variables in the local scope, because “Julia’s eval always evaluates code [in] the current module’s scope, not your local scope.”

Any ideas how to sidestep this problem? I think this is an example where allowing eval to access variables in a local scope has massive ergonomic benefit and virtually no performance cost, and I’m not sure there are any easy alternatives…
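To make the scoping issue concrete, here is a minimal sketch. The function and variable names are made up for illustration. It shows eval failing to see a local variable, alongside Base.@locals (available since Julia 1.1), which returns the defined local variables as a Dict at run time. Note that this is a run-time facility, so it does not by itself tell a macro anything at expansion time; a macro could only emit code that calls it.

```julia
# Illustrative names only; nothing here is part of Thunks.jl.
function probe_local_scope()
    hidden_local_for_demo = 42

    # eval runs in the enclosing module's global scope, so the local is invisible.
    from_eval = try
        eval(:hidden_local_for_demo)
    catch err
        err
    end

    # Base.@locals() returns a Dict{Symbol,Any} of the locals defined so far.
    from_locals = Base.@locals()[:hidden_local_for_demo]

    return from_eval, from_locals
end

from_eval, from_locals = probe_local_scope()
```

Here from_eval holds the error thrown by eval, while from_locals holds the local value.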

Many thanks for your help!

Can you describe a bit more what you want this to do? You say you’d like this example to run “in a fraction of a second”, by which I think you mean that z = sleep(10) should be eliminated, since z isn’t used. Then this isn’t just delayed evaluation, it’s some kind of dead code elimination too. Does it assume all functions visible to the macro are pure, i.e. that I haven’t defined, before the macro, something like this:

CNT = Ref(1)
add1(x) = x + (CNT[]+=1)
sleep(x) = CNT[] += x

I’m also not sure I follow how deep the recursive un-thunking is supposed to go. If I define, say, f(x) = rand()>0.5 ? 0 : @thunk factorial(x), and call this somewhere within @thunk, is the goal that the final result will still be just e.g. a number?

I mean that z = sleep(10) should not be evaluated until it is needed. In the example I wrote, the value of abc is needed due to the call reify(abc), and to evaluate abc we need a,c,d,f, which in turn depend on b. Since z has not yet been needed, it remains unevaluated. This is classic lazy evaluation or call-by-need, albeit with the minor twist that memoization is occurring locally in a struct, rather than globally memoizing args to a function.

Does it assume all functions visible to the macro are pure

One could do side effects, but they will only execute up to 1 time, so with your functions, CNT would be 1 until we reify(abc), and then CNT == 4 subsequently, if we call reify(abc) one or more times.

julia> using Thunks

julia> begin
         CNT = Ref(1)
         add1(x) = x + (CNT[]+=1)
         sleep(x) = CNT[] += x
         b = thunk(identity)(10)
         z = thunk(sleep)(10)
         a = thunk(b->identity(collect(1:b))[8:end])(b)
         res = thunk(()->[x-1 for x in 1:4])()
         c,d,e,f = [thunk(getindex)(res,1), thunk(getindex)(res,2),
           thunk(getindex)(res,3), thunk(getindex)(res,4)]
         abc = thunk((a,c,d,f)->sum(add1.(a) .* [c,d,f]))(a,c,d,f)
       end;

julia> CNT
Base.RefValue{Int64}(1)

julia> reify(abc)
54

julia> CNT
Base.RefValue{Int64}(4)

julia> reify(abc)
54

julia> CNT
Base.RefValue{Int64}(4)

To your final question: with g(x) = rand()>0.5 ? 0 : @thunk factorial(x), the result of each call to g is random, but once a given thunk has been reified its value is cached:

julia> a = g(5)
0

julia> b = g(5)
Thunk(factorial, (5,), Dict{Union{}, Union{}}(), false, nothing)

julia> aa = @thunk a + 1
Thunk(+, (0, 1), Dict{Union{}, Union{}}(), false, nothing)

julia> reify(aa)
1

julia> bb = @thunk b + 1
Thunk(+, (Thunk(factorial, (5,), Dict{Union{}, Union{}}(), false, nothing), 1), Dict{Union{}, Union{}}(), false, nothing)

julia> reify(bb)
121

julia> reify(aa)
1

julia> reify(bb)
121

Although you would get random behavior if you made a new thunk on each call:

julia> reify(g(5))
120

julia> reify(g(5))
0

julia> reify(g(5))
120

julia> reify(g(5))
0

Edit: in case it isn’t clear, the Thunk implementation itself is operational and fine. What is less certain is the metaprogramming: how to write the @thunk macro so that it can rewrite my first code block. The simplest approach, wrapping each line in a function, is hampered by the inability of eval to reference local scope, i.e. I can’t figure out which symbols in an expression are thunks.

So your definition of needed is then “is called by another function”, rather than “is necessary to get the same results as without @thunk”. Which is OK if you know to expect this, of course, but not safe to freely use on arbitrary code.

(Also, the macro does not return one thing, but expands out to a block such that a,b,c,d,e,f,z are all defined in its scope afterwards, I think — so the abc = on the first line is redundant.)

I don’t think this is knowable at macro expansion time. For instance if the macro is called within a function definition, then the local variables it refers to aren’t defined at all.

Is it obvious you must know these types? It might be possible to write everything in terms of dispatch, insert everywhere in the tree some function like maybe_unthunk(x)=x and maybe_unthunk(x::Thunk)=...?
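A rough sketch of how the dispatch idea could look, under some loud assumptions: LazyValue, force, wrap_reads and @lazy are hypothetical names, not the Thunks.jl API, and the walker only handles calls, indexing, and tuple/array literals, leaving every other syntax form untouched. The point is that the macro never needs to know which locals are lazy. It wraps each symbol that is read in force(...), and dispatch decides at run time, while the closure captures locals lexically.

```julia
# Hypothetical names throughout; this is not the Thunks.jl implementation.
mutable struct LazyValue
    compute::Any       # zero-argument closure producing the value
    evaluated::Bool
    result::Any
    LazyValue(compute) = new(compute, false, nothing)
end

force(x) = x                        # plain values pass straight through
function force(x::LazyValue)        # lazy values are computed once, then cached
    if !x.evaluated
        x.result = x.compute()
        x.evaluated = true
    end
    return x.result
end

# Rewrite an expression so that symbols being read are wrapped in force(...).
wrap_reads(x) = x                   # literals and other non-Expr nodes
wrap_reads(s::Symbol) = s in (:end, :begin) ? s : Expr(:call, force, s)
function wrap_reads(ex::Expr)
    if ex.head == :call
        # args[1] is the callee (or infix operator); only the arguments are wrapped
        return Expr(:call, ex.args[1], map(wrap_reads, ex.args[2:end])...)
    elseif ex.head in (:ref, :vect, :tuple)
        return Expr(ex.head, map(wrap_reads, ex.args)...)
    else
        # lambdas, comprehensions, dot calls, keywords, ... are left as written
        return ex
    end
end

macro lazy(ex)
    body = wrap_reads(ex)
    return esc(:($LazyValue(() -> $body)))
end

function demo()
    b = @lazy identity(10)
    a = @lazy collect(1:b)[8:end]   # b is a local LazyValue; no eval needed
    total = @lazy sum(a) + b
    return force(total)
end
```

Extending wrap_reads to binding forms such as x -> ..., comprehensions, and assignments is where this gets harder, because symbols that are being bound must not be wrapped.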

Otherwise you may want to go with a generated function. But this sounds messy.

With the package, I get this error:

ERROR: LoadError: AssertionError: unexpected head (ref != :call) for: (identity(collect(1:b)))[8:end]

To turn more things into function calls, you might want things like Meta.lower(Main, :(A[i,end])) (with __module__ presumably). Infix operators should be ordinary calls though.
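For reference, a small sketch of how the syntax forms mentioned in this thread parse. It uses only Meta.parse and the head field of Expr; the snippet strings are illustrative. Infix and dotted infix operators come out as :call, while indexing, dot calls, destructuring assignment, and comprehensions each have their own head, which is what a macro asserting head == :call trips over.

```julia
# Each entry pairs a source snippet with the head of the Expr it parses to.
snippets = [
    "a + b"                => :call,
    "a .* b"               => :call,
    "f.(x)"                => Symbol("."),
    "A[8:end]"             => :ref,
    "c, d = xs"            => :(=),
    "[x - 1 for x in 1:4]" => :comprehension,
]

heads = [Meta.parse(src).head for (src, _) in snippets]

for ((src, _), head) in zip(snippets, heads)
    println(rpad(src, 22), " => :", head)
end
```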

Thanks for the thoughts and helpful tips!

So your definition of needed is then “is called by another function”, rather than “is necessary to get the same results as without @thunk”. Which is OK if you know to expect this, of course, but not safe to freely use on arbitrary code.

Not sure I agree. My definition of “needed” is intended to be similar to Haskell’s, albeit with reify / unthunk triggering the need. It also should be safe to use on arbitrary code, and I’d like it to support arbitrary syntax. I have caching by default because I want it for my use cases, but it’s not hard to remove this if someone wanted to do a lazy Monte Carlo or other sampling algorithm.
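As a sketch of what removing the cache could look like (Deferred and run_once are hypothetical names, not part of Thunks.jl), an immutable struct with no evaluated/result fields re-runs the function on every request, which is the behavior a sampling algorithm would want:

```julia
# Hypothetical non-caching variant of a thunk.
struct Deferred
    f::Any        # callable
    args::Tuple   # arguments, possibly Deferred themselves
end

run_once(x) = x                                          # plain values pass through
run_once(d::Deferred) = d.f(map(run_once, d.args)...)    # recompute every time

calls = Ref(0)
draw() = (calls[] += 1; calls[])     # stand-in for a random draw; counts its calls

d = Deferred(+, (Deferred(draw, ()), 100))
first_value = run_once(d)
second_value = run_once(d)
```

Because nothing is cached, the two values differ: draw runs once per call to run_once.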

This is precisely what I’m trying to overcome. I pasted a lot of info about my particular use case because I don’t see an alternative for doing lazy evaluation, in the Programming Language Theory sense, of arbitrary Julia expressions without the ability to determine if local variables are Thunks. The two options I see right now are:

  • enforce that all Thunks have to be global variables, but otherwise support all Julia expressions
  • allow Thunks to be local variables, but only support a subset of valid Julia expressions

But hopefully there’s a better solution!

Is it obvious you must know these types? It might be possible to write everything in terms of dispatch, insert everywhere in the tree some function like maybe_unthunk(x)=x and maybe_unthunk(x::Thunk)=... ?

Yes, that’s the basis for how I recursively reify args and kwargs 🙂

Otherwise you may want to go with a generated function. But this sounds messy.

Interesting, I haven’t looked into generated functions yet, thanks for the tip.

With the package, I get this error:

Yes, the very first code block in the first post on this thread is aspirational and highlights syntax that’s not supported. That’s why I manually rewrote it with explicit thunk(...) calls in the second code block of that post.

One could add a getindex method for a thunk so that thunk(identity)(collect(1:b))[8:end] is ok, but it’s hard to do this in general. For instance, imagine that instead the line was a = @thunk map(x->2*x, identity(collect(1:b))[8:end]). If we could scan symbols, and know that b is a thunk, then the rewrite is trivial.

julia> Meta.lower(Main, :(A[i,end])) 
:($(Expr(:thunk, CodeInfo(
    @ none within `top-level scope'
1 ─ %1 = Base.lastindex(A, 2)
│   %2 = Base.getindex(A, i, %1)
└──      return %2
))))

Whoa, cool, I had no idea the Julia AST has a notion of a :thunk. I can’t find it documented anywhere.