I would like to write a macro which subsets a collection of variables (with shared dimensions). My use case can be simplified to a dictionary of vectors where each vector has the same length, for example:
ds = Dict(:a => 1:10,:b => 2:11)
I would like to subset the variable ds based on the values associated to :a
and :b
(and not based on the indices). The macro syntax is the following
ds2 = @select(ds,a < 5)
which corresponds to
Dict( (k,v[ds[:a] .< 5]) for (k,v) in ds )
One should also be able to substitute variable prefixed by $
as it is the case with the @btime
macro (from BenchmarkTools):
limit = 5
ds2 = @select(ds,a < $limit)
The code below is an implementation of this @select
macro. I would like to know if it is possible to avoid the evil eval
function and if it is possible to simplify the code (as there is a lot of quoting/escaping).
Thank you for sharing your ideas!
Below is a short implementation of the @select
macro. In my case, the values of the dictionary are arrays (not just vectors) with named dimensions but I think the simplified case (using vectors) should be sufficient to show my question here.
# ds will always by a dictionary whoes values are
# arrays of the same size
ds = Dict(
:a => 1:10,
:b => 2:11)
@assert all(sz -> sz == size(first(values(ds))),size.(values(ds)))
# helper function to recursively scan the expression
function scan_exp!(exp::Symbol,varnames,found)
if exp in varnames
push!(found,exp)
end
return found
end
function scan_exp!(exp::Expr,varnames,found)
for arg in exp.args
scan_exp!(arg,varnames,found)
end
return found
end
# neither Expr nor Symbol
scan_exp!(exp,varnames,found) = nothing
scan_exp(exp::Expr,varnames) = scan_exp!(exp::Expr,varnames,Symbol[])
function scan_coordinate_name(exp,coordinate_names)
params = scan_exp(exp,coordinate_names)
@assert length(params) == 1
param = params[1]
return param
end
macro select(ds,condition)
exp2 = Meta.quot(condition)
quote
coord_names = keys($ds)
exp = $(esc(exp2))
param = scan_coordinate_name(exp,coord_names)
fun = eval(Expr(:->,param,exp))
# avoid world age problem
ind = Base.invokelatest(findall,fun,ds[param])
Dict( (k,v[ind]) for (k,v) in ds )
end
end
function test_fun(ds)
ds2 = @select(ds,a < 5)
@show ds2
end
function test_fun2(ds)
limit = 5
ds2 = @select(ds,a < $limit)
@show ds2
end
# both function should return
# Dict(:a => [1, 2, 3, 4], :b => [2, 3, 4, 5])
#
# which corresponds to
#
# Dict( (k,v[ds[:a] .< 5]) for (k,v) in ds )
test_fun(ds)
test_fun2(ds)