Using Cassette through `eval`

I’m developing a dynamic analysis system that needs to track function definitions and invocations, including those made inside eval statements. So far I’ve been using Cassette to track function invocations to good effect, but I have run into something of a roadblock with eval.

The problem is that Cassette won’t recursively overdub into evaled blocks, and it’s not easy to design a program transformation that converts evaled ASTs into a function that can be overdubbed via invokelatest. I think the latter is technically feasible, but it requires handling a large number of AST constructs, so I would rather avoid it if possible.

Is there a good way to overdub evaled bodies, or an alternative to Cassette that is able to intercept invocations in a similar way?

A quick example that shows the issue is:

using Cassette
Cassette.@context Ctx;
Cassette.prehook(::Ctx, f, args...) = println(f, args)
Cassette.overdub(Ctx(), eval, :(1/2))

Tracing of function calls stops after it hits the eval.

Okay, I’ve managed to solve the issue. At root, the problem is that Cassette doesn’t do its IR rewriting inside of evaled expressions. To solve this, “all” we have to do is overdub eval with something that acts like jl_toplevel_eval_flex does but does the Cassette rewriting before executing any thunks. The problem is that Cassette’s overdub_pass! was never designed to work with random eval bodies and insists on these “arguments” things to pass the environment in. Pah!
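To make the "thunk" concrete: lowering an ordinary expression yields an `Expr(:thunk, ...)` wrapping a `CodeInfo`, and that `CodeInfo` is what needs the Cassette rewrite before it runs. The following is a minimal sketch using only Base's `Meta.lower` (the code later in this post calls the internal `jl_expand_with_loc` instead). The exact statements printed differ between Julia versions.

```julia
# Lower a quoted expression, roughly as eval does before running it.
ex = :(1 / 2)
lowered = Meta.lower(Main, ex)

# For an ordinary expression, the result is Expr(:thunk, CodeInfo).
code_info = lowered.args[1]

# The statements a Cassette-style pass would rewrite live in `code_info.code`.
for (i, stmt) in enumerate(code_info.code)
    println(i, ": ", stmt)
end

# Evaluating the thunk executes the lowered code.
result = Core.eval(Main, lowered)
```

Any rewriting therefore has to happen between the lowering step and the final `Core.eval` of the thunk.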

Thus, in the great tradition of programmers everywhere, we copy and paste overdub_pass! and make some changes. Note that I haven’t implemented the tagging functionality because I’m lazy and don’t need it for my application.

# the following comes from Cassette.jl and implements the actual overdub pass
# this has been modified to rewrite non-method-bodies 
# this assumes a priori that the context is in the first slot and that there are no arguments 
# no support for tagging, we don't need it
function overdub_pass!(reflection::Cassette.Reflection,
                       context,
                       is_invoke::Bool = false)
    code_info = reflection.code_info
    context_type = typeof(context)

    istaggingenabled = Cassette.hastagging(context_type)

    #=== execute user-provided pass (is a no-op by default) ===#
    code_info = Core._apply_pure(Cassette.passtype(context_type), (context_type, reflection))
    isa(code_info, Expr) && return code_info

    #=== munge the code into a valid form for `overdub_generator` ===#

    # NOTE: The slotflags set by this pass are set according to what makes sense based on the
    # compiler's actual `@code_lowered` output in practice, since this real-world output does
    # not seem to match Julia's developer documentation.

    # construct new slotnames/slotflags for added slots
    # there aren't any arguments (this is an eval'ed block) so we don't have to worry about it
    code_info.slotnames = Any[Cassette.OVERDUB_CONTEXT_NAME, code_info.slotnames...]
    code_info.slotflags = UInt8[0x00, code_info.slotflags...]
    n_prepended_slots = 1
    overdub_ctx_slot = Core.SlotNumber(1)

    # For the sake of convenience, the rest of this pass will translate `code_info`'s fields
    # into these overdubbed equivalents instead of updating `code_info` in-place. Then, at
    # the end of the pass, we'll reset `code_info` fields accordingly.
    overdubbed_code = Any[]
    overdubbed_codelocs = Int32[]

    #=== finish initialization of `overdubbed_code`/`overdubbed_codelocs` ===#

    # substitute static parameters, offset slot numbers by number of added slots, and
    # offset statement indices by the number of additional statements
    Base.Meta.partially_inline!(code_info.code, Any[], Tuple{}, Any[], n_prepended_slots, length(overdubbed_code), :propagate)

    original_code_start_index = length(overdubbed_code) + 1

    append!(overdubbed_code, code_info.code)
    append!(overdubbed_codelocs, code_info.codelocs)

    # inject the context from eval
    stmtcount = (x, i) -> if i == 1 return 2 else nothing end
    newstmts = (x, i) -> begin
        return [Expr(:(=), overdub_ctx_slot, context), x]
    end
    Cassette.insert_statements!(overdubbed_code, overdubbed_codelocs, stmtcount, newstmts)

    #=== replace `Expr(:call, ...)` with `Expr(:call, :overdub, ...)` calls ===#
    arehooksenabled = Cassette.hashooks(context_type)
    stmtcount = (x, i) -> begin
        i >= original_code_start_index || return nothing
        isassign = Base.Meta.isexpr(x, :(=))
        stmt = isassign ? x.args[2] : x
        if Base.Meta.isexpr(stmt, :call) && !(Base.Meta.isexpr(stmt.args[1], :nooverdub))
            isapplycall = Cassette.is_ir_element(stmt.args[1], GlobalRef(Core, :_apply), overdubbed_code)
            if isapplycall && arehooksenabled
                return 7
            elseif isapplycall
                return 2 + isassign
            elseif arehooksenabled
                return 4
            else
                return 1 + isassign
            end
        end
        return nothing
    end
    newstmts = (x, i) -> begin
        callstmt = Base.Meta.isexpr(x, :(=)) ? x.args[2] : x
        isapplycall = Cassette.is_ir_element(callstmt.args[1], GlobalRef(Core, :_apply), overdubbed_code)
        if isapplycall && arehooksenabled
            callf = callstmt.args[2]
            callargs = callstmt.args[3:end]
            stmts = Any[
                Expr(:call, GlobalRef(Core, :tuple), overdub_ctx_slot),
                Expr(:call, GlobalRef(Core, :tuple), callf),
                Expr(:call, GlobalRef(Core, :_apply), GlobalRef(Cassette, :prehook), Core.SSAValue(i), Core.SSAValue(i + 1), callargs...),
                Expr(:call, GlobalRef(Core, :_apply), GlobalRef(Cassette, :overdub), Core.SSAValue(i), Core.SSAValue(i + 1), callargs...),
                Expr(:call, GlobalRef(Core, :tuple), Core.SSAValue(i + 3)),
                Expr(:call, GlobalRef(Core, :_apply), GlobalRef(Cassette, :posthook), Core.SSAValue(i), Core.SSAValue(i + 4), Core.SSAValue(i + 1), callargs...),
                Base.Meta.isexpr(x, :(=)) ? Expr(:(=), x.args[1], Core.SSAValue(i + 3)) : Core.SSAValue(i + 3)
            ]
        elseif isapplycall
            callf = callstmt.args[2]
            callargs = callstmt.args[3:end]
            stmts = Any[
                Expr(:call, GlobalRef(Core, :tuple), overdub_ctx_slot, callf),
                Expr(:call, GlobalRef(Core, :_apply), GlobalRef(Cassette, :overdub), Core.SSAValue(i), callargs...),
            ]
            Base.Meta.isexpr(x, :(=)) && push!(stmts, Expr(:(=), x.args[1], Core.SSAValue(i + 1)))
        elseif arehooksenabled
            stmts = Any[
                Expr(:call, GlobalRef(Cassette, :prehook), overdub_ctx_slot, callstmt.args...),
                Expr(:call, GlobalRef(Cassette, :overdub), overdub_ctx_slot, callstmt.args...),
                Expr(:call, GlobalRef(Cassette, :posthook), overdub_ctx_slot, Core.SSAValue(i + 1), callstmt.args...),
                Base.Meta.isexpr(x, :(=)) ? Expr(:(=), x.args[1], Core.SSAValue(i + 1)) : Core.SSAValue(i + 1)
            ]
        else
            stmts = Any[
                Expr(:call, GlobalRef(Cassette, :overdub), overdub_ctx_slot, callstmt.args...),
            ]
            Base.Meta.isexpr(x, :(=)) && push!(stmts, Expr(:(=), x.args[1], Core.SSAValue(i)))
        end
        return stmts
    end
    Cassette.insert_statements!(overdubbed_code, overdubbed_codelocs, stmtcount, newstmts)


    #=== unwrap all `Expr(:nooverdub)`s ===#

    Cassette.replace_match!(x -> x.args[1], x -> Base.Meta.isexpr(x, :nooverdub), overdubbed_code)

    #=== replace all `Expr(:contextslot)`s ===#

    Cassette.replace_match!(x -> overdub_ctx_slot, x -> Base.Meta.isexpr(x, :contextslot), overdubbed_code)

    #=== set `code_info`/`reflection` fields accordingly ===#

    code_info.code = overdubbed_code
    code_info.codelocs = overdubbed_codelocs
    code_info.ssavaluetypes = length(overdubbed_code)

    return code_info
end

The main signature difference is that instead of taking the context type (the usual overdub_pass! takes the type because it is called from a generated method), we take an actual context. Moreover, an evaled thunk has no function arguments, so we only need to inject one slot into the thunk and can embed the context directly into that slot.
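The core of the pass is spotting `Expr(:call, ...)` statements, including those on the right-hand side of an assignment, so they can be replaced with `overdub` calls. As a rough illustration of just that scanning step (not of the rewriting itself), here is a sketch using only Base. The helper name `call_statements` is made up for this example, and the number of calls found can vary between Julia versions.

```julia
# Lower a small expression and grab its CodeInfo.
code_info = Meta.lower(Main, :(sin(1.0) + 2)).args[1]

# Hypothetical helper: collect every statement that is a call,
# looking through `lhs = rhs` in the same way the pass above does.
function call_statements(code)
    calls = Any[]
    for stmt in code
        inner = Meta.isexpr(stmt, :(=)) ? stmt.args[2] : stmt
        Meta.isexpr(inner, :call) && push!(calls, inner)
    end
    return calls
end

calls = call_statements(code_info.code)
foreach(println, calls)
```

Each of these statements is a candidate that the real pass expands into a prehook/overdub/posthook sequence.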

From here, we can implement our eval-alike, which is, if you squint very hard, effectively a lazy reimplementation of jl_toplevel_eval_flex that doesn’t work properly with line numbers.

function trace_eval(mod, expr, ctx)
	if !(expr isa Expr)
		return Core.eval(mod, expr)
	end

	if Meta.isexpr(expr, :(.))
		return Core.eval(mod, expr)
	end

	if ccall(:jl_needs_lowering, Int32, (Any,), expr) == 1
		expr = ccall(:jl_expand_with_loc, Any, (Any, Any, String, Int32), 
			         expr, mod, "none", 0)
	end

	if Meta.isexpr(expr, :using) || Meta.isexpr(expr, :import) ||
		Meta.isexpr(expr, :export) || Meta.isexpr(expr, :global) || Meta.isexpr(expr, :const) ||
		Meta.isexpr(expr, :error) || Meta.isexpr(expr, :incomplete)
		return Core.eval(mod, expr)
	elseif Meta.isexpr(expr, :module)
		# create a new module and recurse, passing the context along
		newmod = Core.eval(mod, :(module $(expr.args[2]) end))
		if !(expr.args[3] isa Array)
			trace_eval(newmod, expr.args[3], ctx)
		else
			for iexpr in expr.args[3]
				trace_eval(newmod, iexpr, ctx)
			end
		end
		return newmod
	elseif Meta.isexpr(expr, :toplevel)
		res = nothing
		for entry in expr.args
			res = trace_eval(mod, entry, ctx)
		end
		return res
	end
	# must be a thunk with inner code
	# execution flow:
	# 1: rewrite with cassette
	# 2: compile & execute
	refl = Cassette.Reflection(Tuple{typeof(dummy_method)}, dummy_method_reference, [], expr.args[1])
	overdub_pass!(refl, ctx, false)
	return Core.eval(mod, expr) # the IR has been rewritten internally
end

The dummy_method is just a placeholder to fill in the corresponding fields of Reflection; it’s not actually used by anything.
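The definitions of `dummy_method` and `dummy_method_reference` are not shown here. One plausible way to set them up, assuming `Reflection` only needs a signature type and some `Method` object in those fields, would be something like:

```julia
# Assumed placeholder definitions; any zero-argument function would do.
dummy_method() = nothing

# Take the single Method object belonging to the placeholder function.
dummy_method_reference = first(methods(dummy_method))

# The signature type passed as the first Reflection argument in trace_eval.
dummy_signature = Tuple{typeof(dummy_method)}
```

The name `dummy_signature` is introduced purely for illustration; `trace_eval` builds that tuple type inline.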

With this done, we can implement Cassette-through-eval as follows:

Cassette.overdub(ctx::MyContext, ::typeof(Core.eval), mod, expr) = trace_eval(mod, expr, ctx)
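This works because every Julia function has its own singleton type, so a method taking `::typeof(Core.eval)` is selected only for that one function. The same dispatch pattern can be seen without Cassette in the sketch below, where `intercept` is a made-up stand-in for `overdub` and the "context" is just a vector that records evaled expressions.

```julia
# Hypothetical stand-in for Cassette.overdub: by default, just call through.
intercept(ctx, f, args...) = f(args...)

# More specific method: chosen only when the callee is Core.eval.
function intercept(ctx, ::typeof(Core.eval), mod::Module, expr)
    push!(ctx, expr)            # record what is about to be evaled
    return Core.eval(mod, expr)
end

seen = Any[]
a = intercept(seen, +, 1, 2)                    # falls through to the generic method
b = intercept(seen, Core.eval, Main, :(3 * 4))  # hits the specialised method
```

In the real version, the specialised method hands the expression to `trace_eval` instead of calling `Core.eval` directly.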

Done!