I just wanted to show off a fun little macro I wrote. For a course on stochastic simulation, I’m writing small functions that simulate some process. For performance reasons I don’t want to store all the intermediate values for the different time steps, but it’s nice to have that option for debugging or visualization. One way to avoid essentially copy-pasting the function is to take an optional `Val` argument and let the compiler remove the unnecessary code based on it:
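```julia
# C is a compile-time constant, so the `if C` branches are removed when C == false.
# (Named C rather than T so it doesn't clash with the end time T.)
function test(h, T = 10.0, collect::Val{C} = Val(false)) where {C}
    t = 0.0
    B = 0.0
    if C # create lists to store values
        ts = [t]
        Bs = [B]
    end
    while (t += h) < T
        # do computation, then store values
        if C
            push!(ts, t)
            push!(Bs, B)
        end
    end
    # ts and Bs only exist when collecting
    return C ? (B, ts, Bs) : B
end
```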
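For a self-contained, runnable instance of the same pattern, here is a sketch with an explicit Euler step for dx/dt = -x standing in for the real computation (`euler_decay` and its arguments are made up for illustration). Because `C` is a type parameter, each `Val` choice gets its own specialized method, so the non-collecting method never allocates `ts` or `xs`.

```julia
# Hypothetical example: explicit Euler for dx/dt = -x, optionally recording the trajectory.
function euler_decay(h, T, record::Val{C} = Val(false)) where {C}
    t, x = 0.0, 1.0
    if C
        ts, xs = [t], [x]   # only allocated when recording
    end
    while (t += h) < T
        x -= h * x          # Euler step for dx/dt = -x
        if C
            push!(ts, t)
            push!(xs, x)
        end
    end
    return C ? (x, ts, xs) : x
end

x_final = euler_decay(0.25, 1.0)                    # plain result only
x_rec, ts, xs = euler_decay(0.25, 1.0, Val(true))   # result plus trajectory
```

Choosing `h = 0.25` keeps the time grid exact in floating point. `Base.return_types(euler_decay, (Float64, Float64, Val{false}))` is one way to inspect what inference makes of each variant.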
But I got tired of the boilerplate and wanted to play with macros. The macro `@collectible` takes a function definition whose last statement is a `begin` block. The generated function returns a closure with this block as its body. You mark the variables to trace (scalar values only) with `@collect`; every later assignment to them (`=`, `+=`, `-=`, `*=`, `/=`) is then recorded. Furthermore, the return expression needs to be annotated with `@return` so that the recorded values are automatically appended to the returned tuple.
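Roughly speaking, and ignoring the exact expressions the macro emits, the generated function has the following shape. This is a hand-written sketch for a hypothetical `counter` definition (shown in the comment), not literal macro output:

```julia
# Hypothetical input:
#
#   @collectible function counter(n)
#       begin
#           @collect x = 0
#           for _ in 1:n
#               x += 1
#           end
#           @return x
#       end
#   end
#
# Hand-written approximation of the function the macro would generate:
function counter(n, collect = false)
    if collect
        # collecting version: every assignment to `x` is also pushed to `x_collect`
        sampler = () -> begin
            x = 0
            x_collect = [x]
            for _ in 1:n
                x += 1
                push!(x_collect, x)
            end
            return (x, x_collect)
        end
        return sampler, [:x]   # closure plus the names of the collected variables
    else
        # plain version: annotations stripped, nothing stored
        return () -> begin
            x = 0
            for _ in 1:n
                x += 1
            end
            return x
        end
    end
end
```

So `counter(3)()` runs the plain block, while `counter(3, true)` hands back the tracing closure together with the variable names.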
```julia
@collectible function geo_brown_sampler(h; S₀ = 10.0, μ = 0.03, σ² = 0.05, T = 10.0)
    h_root = sqrt(h)
    σ = sqrt(σ²)
    coeff = μ - σ² / 2
    begin
        @collect t = 0.0 # current time
        @collect B = 0.0 # last Brownian motion (with var σ²) value
        @collect S = S₀  # S(previous time)
        # approximate integral (trapezoidal rule): sum of (last + current) * h / 2
        res = S / 2
        while (t += h) < T
            # step time by h ~> add N(0, hσ²) to B
            B += σ * h_root * randn()
            S = S₀ * exp(coeff*t + B)
            res += S
        end
        # step time to T
        Δt = T - t + h
        B += σ * sqrt(Δt) * randn() # ~ N(0, Δt*σ²)
        res = h*(res - S/2) + Δt*S/2
        S = S₀ * exp(coeff*T + B)
        @return (res + Δt * S / 2) / T
    end
end
```
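Spelled out, the block estimates the time average of S over [0, T] with the trapezoidal rule on the grid h, 2h, …, nh, plus one shorter final step of length Δt (the symbols n and S_k are introduced here only to write the formula down):

$$
\frac{1}{T}\int_0^T S(t)\,dt \;\approx\; \frac{1}{T}\left[h\left(\frac{S_0}{2} + \sum_{k=1}^{n-1} S_k + \frac{S_n}{2}\right) + \Delta t\,\frac{S_n + S(T)}{2}\right],
\qquad \Delta t = T - nh,
$$

where $S_k = S(kh)$, $n \ge 1$ is the number of full steps ($nh < T \le (n+1)h$), and $S(t) = S_0 \exp\big((\mu - \sigma^2/2)\,t + B(t)\big)$ with $B(t) \sim \mathcal{N}(0, \sigma^2 t)$. In the code, `res - S/2` after the loop is exactly the bracketed trapezoid sum, and `Δt = T - t + h` equals $T - nh$ because the loop exits with $t = (n+1)h$.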
Calling `geo_brown_sampler(0.1)` yields a function that you can simply call to get one execution of the last block, i.e.
```julia
julia> samp = geo_brown_sampler(0.01)
#13 (generic function with 1 method)

julia> samp()
8.629606496567117
```
To capture the intermediate values, call `geo_brown_sampler(0.1, true)`. This returns a closure that additionally records and returns the values of `t`, `B` and `S`, together with the names of the collected variables:
```julia
julia> samp, fields = geo_brown_sampler(0.1, true)
(var"#12#14"{...}), [:t, :B, :S])

julia> typeof(samp())
Tuple{Float64, Vector{Float64}, Vector{Float64}, Vector{Float64}}

julia> res, ts, Bs, Ss, = samp()
(6.818649790188291, [0.0, 0.1, 0.2, 0.30000000000000004, 0.4, 0.5, 0.6, 0.7, 0.7999999999999999, 0.8999999999999999 … 9.199999999999983, 9.299999999999983, 9.399999999999983, 9.499999999999982, 9.599999999999982, 9.699999999999982, 9.799999999999981, 9.89999999999998, 9.99999999999998, 10.09999999999998], [0.0, 0.05367575273944805, 0.06641594050575657, -0.06714949643519959, -0.12156273686299154, -0.128799816021638, -0.1438791273339596, -0.015214851807525215, 0.0726950857462695, 0.09459776489835521 … -0.266423679427629, -0.24859874313909594, -0.30138334429243946, -0.2220425312669843, -0.13891534284124804, -0.02179723866907865, 0.016912993572113823, -0.039082064894935924, -0.10904399911600285, -0.1090439920631063], [10.0, 10.556701227677545, 10.697403339869576, 9.364590390065167, 8.87308339131126, 8.813505641475004, 8.685942946608579, 9.883534998985112, 10.79712806370992, 11.041742437649269 … 8.021788592750955, 8.170142494074037, 7.753945020092463, 8.398411339621086, 9.130950076563764, 10.270624747050425, 10.681337786183443, 10.104723907317648, 9.42665291805097, 9.426652984536178])
```
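Since the second return value lists the collected names in the same order as the returned lists, one option is to pair them up in a `NamedTuple`. In this sketch `samp` and `fields` are dummy stand-ins (made-up values) for the outputs of `geo_brown_sampler(0.1, true)`, so it runs without the macro:

```julia
# Dummy stand-ins for `samp, fields = geo_brown_sampler(0.1, true)` (values are placeholders).
samp = () -> (1.0, [0.0, 0.1], [0.0, 0.2], [10.0, 11.0])
fields = [:t, :B, :S]

out = samp()
res = out[1]                                        # the @return value
traces = NamedTuple{Tuple(fields)}(out[2:end])      # (t = ..., B = ..., S = ...)
```

Afterwards `traces.S` gives the recorded values of `S` without having to remember the tuple positions.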
I’ll be the first one to admit that the macro itself is quite janky and extremely fragile.
```julia
using MacroTools: postwalk
using DataStructures: OrderedDict

macro collectible(expr)
    @assert expr.head == :function
    @assert expr.args[2].args[end] isa Expr &&
            expr.args[2].args[end].head == :block "last line should be block"
    # append optional positional argument `collect = false`
    push!(expr.args[1].args, Expr(:kw, :collect, false))
    # remove annotations from the returned body
    body_orig = expr.args[2].args[end]
    body_none = postwalk(body_orig) do x
        x isa Expr && x.head == :macrocall || return x
        if x.args[1] == Symbol("@collect")
            x.args[end]
        elseif x.args[1] == Symbol("@return")
            # `return` takes a single expression, tuples included
            Expr(:return, x.args[end])
        else
            x
        end
    end
    # compute collecting body
    vars = OrderedDict{Symbol,Symbol}() # keep insertion order
    body_collect = postwalk(body_orig) do x
        x isa Expr || return x
        if x.head == :macrocall && x.args[1] == Symbol("@collect")
            # process `@collect var = [...]`
            @assert x.args[end].head == :(=) && length(x.args[end].args) == 2 && x.args[end].args[1] isa Symbol
            var = x.args[end].args[1]
            var_list = get!(vars, var) do
                Symbol(var, "_collect")
            end
            quote
                $(x.args[end])
                $var_list = [$var]
                $var
            end
        elseif x.head == :macrocall && x.args[1] == Symbol("@return")
            # process `@return [...]`: append the collected lists to the returned tuple
            if x.args[end] isa Expr && x.args[end].head == :tuple
                Expr(:return, Expr(:tuple, x.args[end].args..., values(vars)...))
            else
                Expr(:return, Expr(:tuple, x.args[end], values(vars)...))
            end
        elseif x.head ∈ [:(=), :+=, :-=, :*=, :/=] &&
               (var = x.args[1]) isa Symbol &&
               (var_list = get(vars, var, nothing)) !== nothing
            # store the new value of the collected variable
            quote
                $x
                push!($var_list, $var)
                $var
            end
        else
            x
        end
    end
    # construct new return expression
    expr.args[2].args[end] = quote
        if collect
            (() -> $body_collect), $(keys(vars))
        else
            () -> $body_none
        end
    end
    # escape so the generated names resolve in the caller's scope
    return esc(expr)
end
```
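The central trick, rewriting each update of a marked variable into "update, then push", can be tried in isolation. This sketch avoids the MacroTools and DataStructures dependencies by using a hand-rolled post-order walk; `postwalk_base`, `trace_updates` and `UPDATE_HEADS` are hypothetical names, and it only traces a single variable:

```julia
# Hypothetical stand-in for MacroTools.postwalk: rewrite children first, then the node itself.
postwalk_base(f, x) = x isa Expr ? f(Expr(x.head, map(a -> postwalk_base(f, a), x.args)...)) : f(x)

# Assignment heads treated as updates (the same set the macro checks)
const UPDATE_HEADS = (:(=), Symbol("+="), Symbol("-="), Symbol("*="), Symbol("/="))

# After every update of `var`, also push its new value onto the vector named `list`.
function trace_updates(ex, var::Symbol, list::Symbol)
    postwalk_base(ex) do x
        if x isa Expr && x.head in UPDATE_HEADS && x.args[1] === var
            Expr(:block, x, :(push!($list, $var)), var)
        else
            x
        end
    end
end

body = quote
    s = 0
    for i in 1:3
        s += i
    end
end

traced = trace_updates(body, :s, :s_trace)
# evaluate inside a `let` so that `s` and `s_trace` are local variables
result = eval(Expr(:let, Expr(:block), quote
    s_trace = Int[]
    $traced
    (s, s_trace)
end))
```

The `for` header `i in 1:3` is itself an `:(=)` expression, so marking a loop variable this way would break the loop syntax; the sketch simply never marks `i`.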