Okay, I’ve managed to solve the issue. At root, the problem is that Cassette doesn’t do its IR rewriting inside of `eval`ed expressions. To solve this, “all” we have to do is overdub `eval` with something that acts like `jl_toplevel_eval_flex` does, but that performs the Cassette rewriting before executing any thunks. The problem is that Cassette’s `overdub_pass!` was never designed to work with arbitrary `eval` bodies and insists on these “arguments” things to pass the environment in. Pah!
Thus, in the great tradition of programmers everywhere, we copy and paste `overdub_pass!` and make some changes. Note that I haven’t implemented the tagging functionality because I’m lazy and don’t need it for my application.
# Adapted from Cassette.jl's `overdub_pass!`, modified to rewrite the lowered code of a
# top-level thunk (an `eval`ed block) instead of a method body.
# - Assumes a priori that the context goes in the first slot and that there are no
#   arguments.
# - No support for tagging; we don't need it.
"""
    overdub_pass!(reflection::Cassette.Reflection, context, is_invoke::Bool = false)

Rewrite the lowered code of a top-level thunk so that every `Expr(:call, f, args...)`
becomes `Cassette.overdub(context, f, args...)`. When hooks are enabled for
`typeof(context)`, the call is instead surrounded by `Cassette.prehook` and
`Cassette.posthook` calls. `Core._apply` calls are rewritten into `_apply` of
`overdub`/`prehook`/`posthook`.

`context` is a context *instance* rather than a type. It is stored directly into a new slot
1 by an assignment injected before the first statement.

First runs the context's user-provided pass on `reflection`. If that pass returns an
`Expr`, it is returned unchanged. Otherwise the `CodeInfo` it returned is mutated in place
(`slotnames`, `slotflags`, `code`, `codelocs`, `ssavaluetypes`) and returned. With the
default no-op pass this is presumably `reflection.code_info` itself (confirm against the
Cassette version in use).

`is_invoke` is accepted for signature compatibility with Cassette's pass but is unused.
Tagging is not implemented.
"""
function overdub_pass!(reflection::Cassette.Reflection,
                       context,
                       is_invoke::Bool = false)
    code_info = reflection.code_info
    context_type = typeof(context)
    # Tagging is not implemented by this pass, so `Cassette.hastagging` is not consulted.
    # A tagging-enabled context is overdubbed without any tagging rewrites.

    #=== execute user-provided pass (is a no-op by default) ===#
    code_info = Core._apply_pure(Cassette.passtype(context_type), (context_type, reflection))
    isa(code_info, Expr) && return code_info

    #=== munge the code into a valid form for `overdub_generator` ===#
    # NOTE: The slotflags set by this pass are set according to what makes sense based on the
    # compiler's actual `@code_lowered` output in practice, since this real-world output does
    # not seem to match Julia's developer documentation.

    # Prepend a single slot for the context. An `eval`ed block has no arguments, so there
    # is nothing else to unpack or shift around it.
    code_info.slotnames = Any[Cassette.OVERDUB_CONTEXT_NAME, code_info.slotnames...]
    code_info.slotflags = UInt8[0x00, code_info.slotflags...]
    n_prepended_slots = 1
    overdub_ctx_slot = Core.SlotNumber(1)

    # For the sake of convenience, the rest of this pass will translate `code_info`'s fields
    # into these overdubbed equivalents instead of updating `code_info` in-place. Then, at
    # the end of the pass, we'll reset `code_info` fields accordingly.
    overdubbed_code = Any[]
    overdubbed_codelocs = Int32[]

    #=== finish initialization of `overdubbed_code`/`overdubbed_codelocs` ===#
    # Substitute static parameters (none are passed), offset slot numbers by the number of
    # added slots, and offset statement indices by the number of statements already in
    # `overdubbed_code` (zero here).
    Base.Meta.partially_inline!(code_info.code, Any[], Tuple{}, Any[], n_prepended_slots,
                                length(overdubbed_code), :propagate)
    original_code_start_index = length(overdubbed_code) + 1
    append!(overdubbed_code, code_info.code)
    append!(overdubbed_codelocs, code_info.codelocs)

    # Inject the context from `eval`: place `ctx_slot = context` before statement 1,
    # embedding the context object itself as a literal in the IR.
    inject_count = (x, i) -> i == 1 ? 2 : nothing
    inject_stmts = (x, i) -> Any[Expr(:(=), overdub_ctx_slot, context), x]
    Cassette.insert_statements!(overdubbed_code, overdubbed_codelocs, inject_count, inject_stmts)

    #=== replace `Expr(:call, ...)` with `Expr(:call, :overdub, ...)` calls ===#
    arehooksenabled = Cassette.hashooks(context_type)
    # Number of statements that replace statement `x` at index `i`, or `nothing` to keep it.
    # Each count must equal the length of the vector built by `rewrite_stmts` below.
    rewrite_count = (x, i) -> begin
        i >= original_code_start_index || return nothing
        isassign = Base.Meta.isexpr(x, :(=))
        stmt = isassign ? x.args[2] : x
        if Base.Meta.isexpr(stmt, :call) && !(Base.Meta.isexpr(stmt.args[1], :nooverdub))
            isapplycall = Cassette.is_ir_element(stmt.args[1], GlobalRef(Core, :_apply), overdubbed_code)
            if isapplycall && arehooksenabled
                return 7
            elseif isapplycall
                return 2 + isassign
            elseif arehooksenabled
                return 4
            else
                return 1 + isassign
            end
        end
        return nothing
    end
    # Replacement statements for `x`, which lands at index `i`. `SSAValue(i + k)` refers to
    # the k-th statement of the returned vector.
    rewrite_stmts = (x, i) -> begin
        callstmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
        isapplycall = Cassette.is_ir_element(callstmt.args[1], GlobalRef(Core, :_apply), overdubbed_code)
        if isapplycall && arehooksenabled
            callf = callstmt.args[2]
            callargs = callstmt.args[3:end]
            stmts = Any[
                Expr(:call, GlobalRef(Core, :tuple), overdub_ctx_slot),
                Expr(:call, GlobalRef(Core, :tuple), callf),
                Expr(:call, GlobalRef(Core, :_apply), GlobalRef(Cassette, :prehook), Core.SSAValue(i), Core.SSAValue(i + 1), callargs...),
                Expr(:call, GlobalRef(Core, :_apply), GlobalRef(Cassette, :overdub), Core.SSAValue(i), Core.SSAValue(i + 1), callargs...),
                Expr(:call, GlobalRef(Core, :tuple), Core.SSAValue(i + 3)),
                Expr(:call, GlobalRef(Core, :_apply), GlobalRef(Cassette, :posthook), Core.SSAValue(i), Core.SSAValue(i + 4), Core.SSAValue(i + 1), callargs...),
                Base.Meta.isexpr(x, :(=)) ? Expr(:(=), x.args[1], Core.SSAValue(i + 3)) : Core.SSAValue(i + 3)
            ]
        elseif isapplycall
            callf = callstmt.args[2]
            callargs = callstmt.args[3:end]
            stmts = Any[
                Expr(:call, GlobalRef(Core, :tuple), overdub_ctx_slot, callf),
                Expr(:call, GlobalRef(Core, :_apply), GlobalRef(Cassette, :overdub), Core.SSAValue(i), callargs...),
            ]
            Base.Meta.isexpr(x, :(=)) && push!(stmts, Expr(:(=), x.args[1], Core.SSAValue(i + 1)))
        elseif arehooksenabled
            stmts = Any[
                Expr(:call, GlobalRef(Cassette, :prehook), overdub_ctx_slot, callstmt.args...),
                Expr(:call, GlobalRef(Cassette, :overdub), overdub_ctx_slot, callstmt.args...),
                Expr(:call, GlobalRef(Cassette, :posthook), overdub_ctx_slot, Core.SSAValue(i + 1), callstmt.args...),
                Base.Meta.isexpr(x, :(=)) ? Expr(:(=), x.args[1], Core.SSAValue(i + 1)) : Core.SSAValue(i + 1)
            ]
        else
            stmts = Any[
                Expr(:call, GlobalRef(Cassette, :overdub), overdub_ctx_slot, callstmt.args...),
            ]
            Base.Meta.isexpr(x, :(=)) && push!(stmts, Expr(:(=), x.args[1], Core.SSAValue(i)))
        end
        return stmts
    end
    Cassette.insert_statements!(overdubbed_code, overdubbed_codelocs, rewrite_count, rewrite_stmts)

    #=== unwrap all `Expr(:nooverdub)`s ===#
    Cassette.replace_match!(x -> x.args[1], x -> Base.Meta.isexpr(x, :nooverdub), overdubbed_code)

    #=== replace all `Expr(:contextslot)`s ===#
    Cassette.replace_match!(x -> overdub_ctx_slot, x -> Base.Meta.isexpr(x, :contextslot), overdubbed_code)

    #=== set `code_info`/`reflection` fields accordingly ===#
    code_info.code = overdubbed_code
    code_info.codelocs = overdubbed_codelocs
    code_info.ssavaluetypes = length(overdubbed_code)
    return code_info
end
The main signature difference is that, instead of taking the context type (the usual `overdub_pass!` is called from a generated function), we take an actual context. Moreover, there are no effective function arguments, so we only need to inject one slot into the thunk, and we can embed the context directly into that slot.
From here, we can implement our `eval`-alike which, if you squint very hard, is effectively a lazy reimplementation of `jl_toplevel_eval_flex` that doesn’t handle line numbers properly.
"""
    trace_eval(mod, expr, ctx)

Evaluate `expr` in module `mod` like `Core.eval(mod, expr)`. The difference is that every
top-level thunk produced by lowering is first rewritten by `overdub_pass!`, so the calls it
makes are overdubbed with the Cassette context `ctx`.

Only part of `jl_toplevel_eval_flex` is mirrored:
- Non-`Expr` values and dotted references (`a.b`) are handed straight to `Core.eval`.
- `module` expressions create the module (preserving `module` vs. `baremodule`).
  Each body statement is then trace-evaluated inside it, and the new module is returned.
- `toplevel` expressions trace-evaluate each entry in order and return the last result
  (`nothing` if empty).
- Any other non-thunk form left after lowering (`using`, `import`, `export`, `global`,
  `const`, `error`, `incomplete`, ...) is handed to `Core.eval`.

Line-number information is not preserved, since lowering uses file `"none"`, line `0`.
"""
function trace_eval(mod, expr, ctx)
    # Literals, symbols, line nodes and dotted references contain no calls to rewrite.
    if !(expr isa Expr) || Meta.isexpr(expr, :(.))
        return Core.eval(mod, expr)
    end
    if ccall(:jl_needs_lowering, Int32, (Any,), expr) == 1
        # The C side takes `const char *file`. Declaring the argument as `String` would
        # pass the boxed Julia object, so use `Cstring` to pass NUL-terminated bytes.
        expr = ccall(:jl_expand_with_loc, Any, (Any, Any, Cstring, Int32),
                     expr, mod, "none", 0)
    end
    if Meta.isexpr(expr, :module)
        # Expr(:module, std_imports::Bool, name::Symbol, body)
        std_imports, name, body = expr.args[1], expr.args[2], expr.args[3]
        # Create an empty module first (keeping the baremodule flag). Then evaluate the
        # body one statement at a time, so each statement is traced and top-level-only
        # forms (`using`, nested `module`, ...) are never lowered as part of a block.
        newmod = Core.eval(mod, Expr(:module, std_imports, name, Expr(:block)))
        stmts = if Meta.isexpr(body, :block)
            body.args
        elseif body isa AbstractArray
            body
        else
            (body,)
        end
        for stmt in stmts
            trace_eval(newmod, stmt, ctx)
        end
        return newmod
    elseif Meta.isexpr(expr, :toplevel)
        result = nothing
        for entry in expr.args
            result = trace_eval(mod, entry, ctx)
        end
        return result
    elseif !Meta.isexpr(expr, :thunk)
        # `using`/`import`/`export`/`global`/`const`/`error`/`incomplete`, or anything else
        # lowering left untouched. There is no CodeInfo to rewrite.
        return Core.eval(mod, expr)
    end
    # A lowered thunk. Rewrite its CodeInfo with Cassette, then compile and execute it.
    # `dummy_method`/`dummy_method_reference` only fill in Reflection's method fields.
    refl = Cassette.Reflection(Tuple{typeof(dummy_method)}, dummy_method_reference, [], expr.args[1])
    rewritten = overdub_pass!(refl, ctx, false)
    # Run whatever the pass returned. The context's pass may hand back a different
    # CodeInfo (or an Expr) instead of mutating `expr.args[1]` in place.
    return Core.eval(mod, rewritten isa Core.CodeInfo ? Expr(:thunk, rewritten) : rewritten)
end
`dummy_method` is just a placeholder used to fill in the corresponding field of `Reflection`; it isn’t actually used by anything.
With that done, we can implement Cassette-through-`eval` as follows:
# When code running under a `MyContext` calls `Core.eval(mod, expr)`, evaluate `expr` via
# `trace_eval` so the `eval`ed code is rewritten with the same context `ctx`.
function Cassette.overdub(ctx::MyContext, ::typeof(Core.eval), mod, expr)
    return trace_eval(mod, expr, ctx)
end
Done!