Help with IR Manipulation

Since this is a relatively complex stack, I will begin with an overarching explanation of my problem, then contextualize it further.

Keeping it Vague

Suppose I have a MistyClosure (from MistyClosures.jl) in which I want to track the most recently accessed capture of one specific type T. The closure is constructed so that only one capture of type T is valid, namely the one that represents the relevant object at the current point in the stack.

The underlying IR contains calls like the following, which I want to track during execution:

%982 =   builtin Base.getfield(_1, 104, %981)::Base.RefValue{T}
%983 =   builtin Base.getfield(%982, :x)::T

Here, if this is the most recent such call on the stack, I want to write 104 to a new element of the captures (which the IR refers to as Core.Compiler.Argument(1), printed as _1).

Identifying these calls is easy and can be done with MacroTools.jl:

MacroTools.@capture(
    instruction,
    $(GlobalRef(Base, :getfield))($(Core.Argument(1)), ref_, ssa_)
)
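For anyone who wants to try the matching without MistyClosures or Libtask installed, here is a minimal sketch in plain Julia (Base only, no MacroTools). It hand-builds an Expr shaped like statement %982 above and extracts the capture index. I am assuming the statement is stored as an `Expr(:call, ...)` whose callee is a `GlobalRef`, as the MacroTools pattern implies; `match_capture_getfield` is a hypothetical helper name.

```julia
# Statement %982 from the printout, rebuilt by hand: getfield(_1, 104, %981)
stmt = Expr(:call, GlobalRef(Base, :getfield), Core.Argument(1), 104, Core.SSAValue(981))

# Hypothetical helper: return the capture index if `inst` is getfield(_1, ref, ssa),
# otherwise return nothing.
function match_capture_getfield(inst)
    Meta.isexpr(inst, :call, 4) || return nothing      # callee + exactly three arguments
    f, obj, ref, _ = inst.args
    f isa GlobalRef && f.mod === Base && f.name === :getfield || return nothing
    obj isa Core.Argument && obj.n == 1 || return nothing  # _1, i.e. the captures
    return ref isa Int ? ref : nothing
end

match_capture_getfield(stmt)                    # 104
match_capture_getfield(:(getfield(x, 104, y)))  # returns nothing
```

Like the MacroTools pattern, this only matches the three-argument form of getfield.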

The issue now is creating a NewInstruction that writes ref from the match above into a new capture (ref is an integer index into the captures). I know I will have to rebuild the closure from the new IR, but I don’t know whether the captures will stay consistent when I do…
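To pin down what the rewritten IR should do at runtime, here is a plain-Julia model of it, a sketch under these hypothetical assumptions: the captures form a tuple of `Base.RefValue`s (matching the `getfield(_1, i)` / `getfield(ref, :x)` pairs in the IR above), and the new tracking slot is a `Ref{Int}` appended as the last capture. Because a tuple is immutable, the recorded index has to live inside a mutable slot, so the inserted instructions would roughly amount to a `getfield(_1, k)` on that slot followed by a `setfield!` of its `:x` field to 104. `tracked_load` is a hypothetical name.

```julia
# Hypothetical stand-in for a closure's captures: two RefValue{Float64} captures
# plus a trailing Ref{Int} slot that records the most recently accessed index.
captures = (Ref(1.0), Ref(2.0), Ref(0))

# Models  %a = getfield(_1, i); %b = getfield(%a, :x)  plus the extra
# instruction the rewrite would insert: store `i` in the last (tracking) slot.
function tracked_load(caps::Tuple, i::Int)
    ref = getfield(caps, i)                          # Base.RefValue of capture i
    setfield!(getfield(caps, length(caps)), :x, i)   # record i in the tracking slot
    return getfield(ref, :x)                         # the captured value itself
end

tracked_load(captures, 2)   # 2.0
captures[end][]             # 2
```

In this toy model, appending the slot at the end leaves every existing index (such as 104) unchanged, which is one way the captures could stay consistent after rebuilding the closure, assuming the real captures are indexed the same way.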

The Gory Details

For context, I’m using Libtask.jl to construct a copyable, resumable function for use in sequential importance sampling. Under the hood, Libtask rewrites the IR, inserting return nodes wherever the function is told to produce.

Using the type T from earlier, our (much simplified) stack looks like this:

function f(x::T, val, dist)
    return accumulate!(x.acc[3], val, dist)
end

function accumulate!(acc, val, dist)
    logp = logpdf(dist, val)
    produce(logp)
    acc.logprob += logp
end

I want to extract the most recently referenced value of x before the produce call in the stack. This value lives in the captures, but since x is not always updated in place, there are multiple captures of type T.

Moreover, which capture of type T is the relevant one depends on where we are in the taped task’s stack.
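A toy illustration of why several captures of type T can coexist: functions following the `!!` convention may return a new object rather than mutating their argument, so each intermediate value can end up captured separately. `ToyVarInfo` and `toy_acclogp` are hypothetical stand-ins, not DynamicPPL’s types or functions.

```julia
# Hypothetical immutable stand-in for T (e.g. an AbstractVarInfo).
struct ToyVarInfo
    logp::Float64
end

# `!!`-style update that is not in place: returns a new object instead of mutating `vi`.
toy_acclogp(vi::ToyVarInfo, x::Real) = ToyVarInfo(vi.logp + x)

vi0 = ToyVarInfo(0.0)
vi1 = toy_acclogp(vi0, -1.5)
vi2 = toy_acclogp(vi1, -0.5)

# If each intermediate is live across a produce, all of them can end up as captures.
captures = (vi0, vi1, vi2)
count(c -> c isa ToyVarInfo, captures)  # 3, but only vi2 is the current value
```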

For those kind souls brave enough to help me out, below is a painful summary of what I really want.

"Minimal" Working Example

Please note that I’m using the breaking branch of DynamicPPL.jl for experimentation here; if you don’t want that, you can remove UnlinkAll() from the InitContext.

The type I want to track is an AbstractVarInfo, which is passed through the tilde_observe!!/tilde_assume!! calls.

using Distributions
using DynamicPPL
using Libtask
using MacroTools
using MistyClosures
using Random
using Random123

## LINEAR REGRESSION MODEL #################################################################

@model function linear_regression(x, y)
    β ~ Normal(0, 1)
    σ ~ truncated(Cauchy(0, 3); lower=0)
    for t in eachindex(x)
        y[t] ~ Normal(β * x[t], σ)
    end
end

# condition the model
rng = MersenneTwister(1234)
x, y = rand(rng, 10), rand(rng, 10)
reg_model = linear_regression(x, y)

## ACCUMULATOR #############################################################################

struct ProduceLogLikelihoodAccumulator{T<:Real} <: DynamicPPL.LogProbAccumulator{T}
    logp::T
end

DynamicPPL.accumulator_name(::Type{<:ProduceLogLikelihoodAccumulator}) = :LogLikelihood
DynamicPPL.logp(acc::ProduceLogLikelihoodAccumulator) = acc.logp

function DynamicPPL.acclogp(acc::ProduceLogLikelihoodAccumulator, val)
    Libtask.produce(val)
    newacc = ProduceLogLikelihoodAccumulator(DynamicPPL.logp(acc) + val)
    return newacc
end

function DynamicPPL.accumulate_assume!!(
    acc::ProduceLogLikelihoodAccumulator, val, tval, logjac, vn, dist, template
)
    return acc
end
function DynamicPPL.accumulate_observe!!(
    acc::ProduceLogLikelihoodAccumulator, dist, val, vn
)
    return DynamicPPL.acclogp(acc, Distributions.loglikelihood(dist, val))
end

Libtask.@might_produce(DynamicPPL.accumulate_observe!!)
Libtask.@might_produce(DynamicPPL.tilde_observe!!)
Libtask.@might_produce(DynamicPPL.accloglikelihood!!)

function Libtask.might_produce(
    ::Type{
        <:Tuple{
            typeof(Base.:+),
            ProduceLogLikelihoodAccumulator,
            DynamicPPL.LogLikelihoodAccumulator,
        },
    },
)
    return true
end

Libtask.@might_produce(DynamicPPL.tilde_assume!!)
Libtask.@might_produce(DynamicPPL.evaluate!!)
Libtask.@might_produce(DynamicPPL.init!!)
Libtask.might_produce(::Type{<:Tuple{<:DynamicPPL.Model,Vararg}}) = true

## TAPED TASK ##############################################################################

function tape(rng::AbstractRNG, model::Model)
    vi = DynamicPPL.setacc!!(VarInfo(model), ProduceLogLikelihoodAccumulator())
    inner_rng = Random.seed!(Random123.Philox2x(), rand(rng, Random.Sampler(rng, UInt64)))
    inner_model = DynamicPPL.setleafcontext(
        model, InitContext(inner_rng, InitFromPrior(), UnlinkAll())
    )
    args, kwargs = DynamicPPL.make_evaluate_args_and_kwargs(
        inner_model, DynamicPPL.empty!!(vi)
    )
    return TapedTask(nothing, inner_model.f, args...; kwargs...)
end

task = tape(rng, reg_model);

I will happily provide more context if needed, and I apologize for the not-so-minimal working example.