Macro capturing intermediate values

I just wanted to show off a fun little macro I wrote. For a course on stochastic simulation, I’m writing small functions that simulate some process. For performance reasons I don’t want to record all the intermediate values corresponding to different time steps, but it’s nice to have that option for debugging or visualization. One way to avoid essentially copy-pasting the function is to take an optional Val argument and compile away the unnecessary parts based on it:

function test(h = 0.1, T = 10.0, collect::Val{C} = Val(false)) where C
    t = 0.0
    B = 0.0
    if C # create lists to store values
        ts = [t]
        Bs = [B]
    end
    while (t += h) < T
        # do computation, then store values
        if C
            push!(ts, t)
            push!(Bs, B)
        end
    end
    # C is a compile-time constant, so only one branch survives compilation
    return C ? (B, ts, Bs) : B
end
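A minimal sketch of the same Val pattern with the computation filled in, assuming it is a plain Brownian increment with variance h. The name `brownian` and the `rng` keyword are illustrative additions, not part of the code above.

```julia
using Random

# Simulate a Brownian path with step h up to time T.
# C is a type parameter, so `if C` is resolved at compile time and the
# non-collecting method contains no vector code at all.
function brownian(h, T, ::Val{C} = Val(false); rng = Random.default_rng()) where C
    t = 0.0
    B = 0.0
    if C # lists only exist in the collecting method
        ts = [t]
        Bs = [B]
    end
    while (t += h) < T
        B += sqrt(h) * randn(rng) # increment ~ N(0, h)
        if C
            push!(ts, t)
            push!(Bs, B)
        end
    end
    return C ? (B, ts, Bs) : B
end
```

With the same seed, `brownian(0.1, 1.0; rng = MersenneTwister(1))` returns only the final value, while `brownian(0.1, 1.0, Val(true); rng = MersenneTwister(1))` returns that same value plus the two trajectories.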

But I got tired of the boilerplate and wanted to play with macros. The macro @collectible takes a function definition whose last statement is a block. The generated function returns a lambda with this block as its body. You mark the variables you want to trace (scalar values only) with @collect; the macro then tracks every assignment to them (plain `=` as well as `+=`, `-=`, `*=` and `/=`) and stores the intermediate values. Furthermore, the return expression needs to be annotated with @return so that the collected values are appended to it automatically.
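To make this concrete, here is roughly what the macro turns a small function into, written out by hand. The toy function `countdown` is hypothetical, and the expansion is simplified: the real one also wraps each rewritten assignment in its own block and carries line-number nodes.

```julia
# Hand-written equivalent of the (hypothetical) input
#
#   @collectible function countdown(n)
#       begin
#           @collect k = n
#           while k > 0
#               k -= 1
#           end
#           @return k
#       end
#   end
function countdown(n, collect = false)
    if collect
        # collecting variant: `@collect` starts a list, every update of k
        # pushes onto it, and `@return` appends the list to the result
        f = () -> begin
            k = n
            k_collect = [k]
            while k > 0
                k -= 1
                push!(k_collect, k)
            end
            return (k, k_collect)
        end
        return f, [:k]
    else
        # plain variant: the annotations are simply stripped
        return () -> begin
            k = n
            while k > 0
                k -= 1
            end
            return k
        end
    end
end
```

Here `countdown(3)()` gives `0`, while `countdown(3, true)` gives a closure returning `(0, [3, 2, 1, 0])` together with the names `[:k]`.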

@collectible function geo_brown_sampler(h; S₀ = 10.0, μ = 0.03, σ² = 0.05, T = 10.0)
    h_root = sqrt(h)
    σ = sqrt(σ²)
    coeff = μ - σ² / 2

    begin
        @collect t = 0.0 # current time
        @collect B = 0.0 # last Brownian motion (with var σ²) value
        @collect S = S₀ # S(previous time)

        # trapezoidal rule: each step contributes h * (last + current) / 2
        res = S / 2
        while (t += h) < T
            # step time by h ~> add N(0, hσ²) to B
            B += σ * h_root * randn()

            S = S₀ * exp(coeff*t + B)
            res += S
        end

        # step time to T
        Δt = T - t + h
        B += σ * sqrt(Δt) * randn() # ~ N(0, Δt*σ²)

        res = h*(res - S/2) + Δt*S/2
        S = S₀ * exp(coeff*T + B)
        @return (res + Δt * S / 2) / T
    end
end

Calling geo_brown_sampler(0.01) yields a function that you simply call to execute the last block, e.g.

julia> samp = geo_brown_sampler(0.01)
#13 (generic function with 1 method)
julia> samp()
8.629606496567117

To capture the intermediate values, call geo_brown_sampler(0.1, true). This returns a lambda that additionally records and returns the values of t, B and S, together with the names of the collected variables:

julia> samp, fields = geo_brown_sampler(0.1, true)
(var"#12#14"{...}), [:t, :B, :S])
julia> typeof(samp())
Tuple{Float64, Vector{Float64}, Vector{Float64}, Vector{Float64}}
julia> res, ts, Bs, Ss, = samp()
(6.818649790188291, [0.0, 0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.6, 0.7, 0.7999999999999999, 0.8999999999999999  …  9.199999999999983, 9.299999999999983, 9.399999999999983, 9.499999999999982, 9.599999999999982, 9.699999999999982, 9.799999999999981, 9.89999999999998, 9.99999999999998, 10.09999999999998], [0.0, 0.05367575273944805, 0.06641594050575657, -0.06714949643519959, -0.12156273686299154, -0.128799816021638, -0.1438791273339596, -0.015214851807525215, 0.0726950857462695, 0.09459776489835521  …  -0.266423679427629, -0.24859874313909594, -0.30138334429243946, -0.2220425312669843, -0.13891534284124804, -0.02179723866907865, 0.016912993572113823, -0.039082064894935924, -0.10904399911600285, -0.1090439920631063], [10.0, 10.556701227677545, 10.697403339869576, 9.364590390065167, 8.87308339131126, 8.813505641475004, 8.685942946608579, 9.883534998985112, 10.79712806370992, 11.041742437649269  …  8.021788592750955, 8.170142494074037, 7.753945020092463, 8.398411339621086, 9.130950076563764, 10.270624747050425, 10.681337786183443, 10.104723907317648, 9.42665291805097, 9.426652984536178])
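The names in `fields` come in the same order as the vectors following the first return value, so they can be zipped into a NamedTuple for plotting or inspection. A sketch, using a hypothetical stand-in closure with the same return shape instead of the real sampler:

```julia
# Hypothetical stand-in for `geo_brown_sampler(0.1, true)`: a closure with the
# same return shape (result, then one vector per field) plus the field names.
samp, fields = (() -> (1.5, [0.0, 0.1], [0.0, 0.2], [10.0, 10.5])), [:t, :B, :S]

out = samp()
res = out[1]          # the value passed to @return
traces = out[2:end]   # one vector per collected variable, in `fields` order

# Label each trace with its variable name.
nt = NamedTuple{Tuple(fields)}(traces)
nt.S                  # the S trajectory
```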

I’ll be the first one to admit that the macro itself is quite janky and extremely fragile.

using MacroTools: postwalk
using DataStructures

macro collectible(expr)
    @assert expr.head == :function
    @assert expr.args[2].args[end] isa Expr && 
        expr.args[2].args[end].head == :block "last line should be block"

    # add collect argument
    push!(expr.args[1].args, Expr(:kw, :collect, false))

    # remove annotations from return body
    body_orig = expr.args[2].args[end]
    body_none = postwalk(body_orig) do x
        if x isa Expr && x.head == :macrocall && x.args[1] == Symbol("@collect")
            x.args[end]                  # `@collect v = ...` -> `v = ...`
        elseif x isa Expr && x.head == :macrocall && x.args[1] == Symbol("@return")
            Expr(:return, x.args[end])   # `@return e` -> `return e` (tuples included)
        else
            x
        end
    end

    # compute collecting body
    vars = OrderedDict{Symbol,Symbol}() # keep insertion order
    body_collect = postwalk(body_orig) do x
        x isa Expr || return x
        if x.head == :macrocall && x.args[1] == Symbol("@collect")
            # process `@collect var = [...]`
            @assert x.args[end].head == :(=) && length(x.args[end].args) == 2 && x.args[end].args[1] isa Symbol
            var = x.args[end].args[1]
            var_list = get!(vars, var) do
                Symbol(var, "_collect")
            end

            quote
                $(x.args[end])
                $var_list = [$var]
                $var
            end
        elseif x.head == :macrocall && x.args[1] == Symbol("@return")
            # process `@return [...]`
            if x.args[end] isa Expr && x.args[end].head == :tuple
                Expr(:return, Expr(:tuple, x.args[end].args..., values(vars)...))
            else
                Expr(:return, Expr(:tuple, x.args[end], values(vars)...))
            end
        elseif x.head ∈ [:(=), :+=, :-=, :*=, :/=] && 
                (var = x.args[1]) isa Symbol && 
                (var_list = get(vars, var, nothing)) !== nothing
            # after an update of a collected variable, push its new value
            quote
                $x
                push!($var_list, $var)
                $var
            end
        else
            x
        end
    end

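    # construct new return expression
    expr.args[2].args[end] = quote
        if collect
            # closure plus a fresh vector with the names of the collected variables
            (() -> $body_collect), [$(QuoteNode.(collect(keys(vars)))...)]
        else
            () -> $body_none
        end
    end

    # escape so that names resolve in the caller's module, not the macro's
    return esc(expr)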
end
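The heart of the macro is the post-order rewrite of assignments to tracked variables. The sketch below isolates that step, assuming a hand-rolled walker in place of `MacroTools.postwalk` so that it needs no packages; `walk` and `track` are hypothetical helper names. It also shows why the last stored time can lie at or beyond the horizon: the update inside the `while` condition is recorded before the comparison fails.

```julia
# Minimal post-order walker (same traversal order as MacroTools.postwalk):
# children are rewritten first, then `f` is applied to the rebuilt node.
walk(f, x) = x isa Expr ? f(Expr(x.head, map(a -> walk(f, a), x.args)...)) : f(x)

# Rewrite every (compound) assignment to `var` so that its new value is
# pushed onto `list` right afterwards; the block still evaluates to `var`.
function track(ex, var::Symbol, list::Symbol)
    walk(ex) do x
        if x isa Expr && x.head in (:(=), :+=, :-=, :*=, :/=) && x.args[1] === var
            Expr(:block, x, :(push!($list, $var)), var)
        else
            x
        end
    end
end

ex = track(quote
    ts = Float64[]
    t = 0.0
    while (t += 0.25) < 1.0
    end
    (t, ts)
end, :t, :ts)

# Evaluate inside `let` so that t and ts are locals, not globals.
result = eval(quote
    let
        $ex
    end
end)
```

Here `result` is `(1.0, [0.0, 0.25, 0.5, 0.75, 1.0])`: the final update to `1.0` is stored even though the loop body never runs for it, just like the `10.09999999999998` at the end of `ts` above.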