Cassette: compiler pass to translate branch to function

I would like to use Cassette.jl to implement a compiler pass. The task is to translate code that contains a branch into a `map` over a function that contains the branch, but only when certain types are involved. For example, I would like to transform

```julia
code = Meta.@lower if x > 0
    return x^2
else
    return -x^2
end
```

to

```julia
code2 = Meta.@lower map(x.particles) do x
    if x > 0
        return x^2
    else
        return -x^2
    end
end
```

where the translation should only be done if `x` is of the special type defined below:

```julia
struct Particles{T} <: Real
    particles::Vector{T}
end
```

Without this transformation, code like the following predictably fails:

```julia
Base.:(^)(p::Particles, r) = Particles(p.particles .^ r)
# Elementwise comparison; map(>, p.particles, r) would stop after one element for a scalar r
Base.:(>)(p::Particles, r) = Particles(p.particles .> r)
p = Particles(randn(10))
function negsquare(x)
    if x > 0
        return x^2
    else
        return -x^2
    end
end
```

```julia
julia> negsquare(p)
ERROR: TypeError: non-boolean (Particles) used in boolean context
```

If the code were translated to

```julia
function negsquare(x)
    Particles(map(x.particles) do x
        if x > 0
            return x^2
        else
            return -x^2
        end
    end)
end
```

```julia
julia> negsquare(p)
Particles([0.404953, 0.210984, -1.00176, 0.253796, -0.00620389, 0.831144, -0.0240916, -1.90169, 0.875192, 1.4788])
```

I would get the desired result.
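Leaving Cassette aside for a moment, the runtime behaviour that `code2` is meant to have can be expressed with ordinary dispatch. This is a minimal sketch, not a compiler pass: `mapbranch` and `negsquare_scalar` are hypothetical names, and the sketch assumes the branchy code has already been isolated as a scalar function `f`, which is exactly the outlining step a pass would have to perform automatically.

```julia
struct Particles{T} <: Real
    particles::Vector{T}
end

# Scalar version of the branchy function; it only ever sees plain numbers,
# so the `if` condition is always a Bool.
function negsquare_scalar(x)
    if x > 0
        return x^2
    else
        return -x^2
    end
end

# Hypothetical helper: ordinary values are passed straight to `f`...
mapbranch(f, x) = f(x)
# ...while a Particles argument gets `f` mapped over its samples, so the
# branch is evaluated once per sample.
mapbranch(f, x::Particles) = Particles(map(f, x.particles))

mapbranch(negsquare_scalar, 2.0)                       # plain number
mapbranch(negsquare_scalar, Particles([-1.0, 2.0]))    # per-sample branch
```

The dispatch on `::Particles` plays the role of the type condition; what the Cassette pass would add is performing this outlining and mapping automatically for branches inside third-party code.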

Before anyone asks why I don't just map `negsquare` over `p.particles`: the branch might occur in an arbitrary third-party function whose input is an arbitrary third-party structure, with my type `Particles` appearing as a field somewhere deep down in that structure. The example above is deliberately trivial.
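Since `Particles` can be buried in a field of a third-party structure, a type-based check for whether a call involves `Particles` would need to recurse through field types rather than look only at the top-level argument types. A minimal sketch, assuming concrete struct types and arrays are the only containers of interest; `contains_particles` is a hypothetical name, and fields declared with abstract or `Union` types are not inspected (their runtime contents are invisible at the type level). `fieldtypes` requires Julia 1.1 or later.

```julia
struct Particles{T} <: Real
    particles::Vector{T}
end

# Hypothetical helper: can a value of type T contain a Particles somewhere
# in its (nested) fields? Purely type-based.
function contains_particles(::Type{T}, seen = Set{Any}()) where {T}
    T <: Particles && return true
    T in seen && return false   # guard against self-referential types
    push!(seen, T)
    # Arrays: look at the element type rather than internal fields.
    T <: AbstractArray && return contains_particles(eltype(T), seen)
    # Only concrete struct types have a definite list of field types.
    (isconcretetype(T) && Base.isstructtype(T)) || return false
    return any(FT -> contains_particles(FT, seen), fieldtypes(T))
end
```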

So far I have made some progress with Cassette: I can leave alone code that does not operate on `Particles`, and I can identify branches in the code. However, I cannot figure out how to transform the relevant code into the `map` statement above. Here is my attempt, with my open points marked “QUESTION” and “TODO” in the comments:

```julia
using Cassette

Cassette.@context Ctx

contains_branch(ir::Core.CodeInfo) = any(contains_branch, ir.code)
contains_branch(ex::Expr) = ex.head == :gotoifnot # Unfortunately, there seem to be more ways of branching in the IR
contains_branch(_) = false
branch_target(ex) = ex.args[2] # Return the statement index targeted by the gotoifnot

"My custom compiler pass"
function mapif(::Type{<:Ctx}, reflection::Cassette.Reflection)
    ir = reflection.code_info
    any(x -> x <: Particles, reflection.signature.parameters) || (return ir) # No Particles involved in this call
    contains_branch(ir) || (return ir) # If there is no branch, we leave the code alone

    stmtcount = function (stmt, i)
        contains_branch(stmt) || (return nothing)
        return 1 # QUESTION: Does a single function call replace the branch?
    end

    newstmts = function (stmt, i)
        # The branch body starts one after the index of the gotoifnot and ends one before the branch target
        @show branch_body = ir.code[i+1:branch_target(stmt)-1]
        # TODO: put the branch body into a map function; somehow get rid of all statements that were moved into the function
        [stmt] # Must return a vector whose length equals the value returned by stmtcount
    end

    Cassette.insert_statements!(ir.code, ir.codelocs, stmtcount, newstmts) # QUESTION: Is it right to pass in the entire ir.code so that all SSAValues are updated?
    ir
end

const mapifpass = Cassette.@pass mapif
ctx = Ctx(pass=mapifpass)
Cassette.overdub(ctx, negsquare, p)
```
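On the branch-detection side, the encoding of a conditional branch in lowered code differs between Julia versions: older versions (the ones Cassette targeted) use `Expr(:gotoifnot, cond, dest)`, while newer ones (I believe from 1.6 on) use a `Core.GotoIfNot` node with fields `cond` and `dest`. Unconditional jumps are a separate node type, `Core.GotoNode`. The sketch below accepts both conditional encodings and is exercised on plain `code_lowered` output rather than inside a Cassette pass; `isbranch`, `isjump` and this two-method `branch_target` are hypothetical names, not Cassette API.

```julia
# Conditional branch in lowered IR: Expr(:gotoifnot, cond, dest) on older
# Julia versions, Core.GotoIfNot on newer ones (assumed 1.6+).
isbranch(ex::Expr) = ex.head === :gotoifnot
isbranch(stmt) = isdefined(Core, :GotoIfNot) && stmt isa Core.GotoIfNot

# Statement index jumped to when the condition is false.
branch_target(ex::Expr) = ex.args[2]
branch_target(stmt) = stmt.dest

# Unconditional jumps are a different node type.
isjump(stmt) = stmt isa Core.GotoNode

function negsquare_scalar(x)
    if x > 0
        return x^2
    else
        return -x^2
    end
end

ci = first(code_lowered(negsquare_scalar, (Float64,)))
for (i, stmt) in enumerate(ci.code)
    isbranch(stmt) && println("statement $i: branch, false target = ", branch_target(stmt))
end
```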

Can anyone with Cassette experience point me in the right direction? I find it hard to interpret the difference between `code` and `code2` above, but that difference should indicate what my compiler pass needs to do.

A gist with all the code above is available here.