I would like to use Cassette.jl to implement a compiler pass. The task is to (conditioned on types) translate code that contains a branch into a map over a function that contains the branch. An example of the transformation I would like to do is from
code = Meta.@lower if x > 0
    return x^2
else
    return -x^2
end
to
code2 = Meta.@lower map(x.particles) do x
    if x > 0
        return x^2
    else
        return -x^2
    end
end
where the translation should only be done if x is of a special type, defined below:
struct Particles{T} <: Real
    particles::Vector{T}
end
Without this transformation, code like the following fails predictably
Base.:(^)(p::Particles,r) = Particles(p.particles.^r)
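Base.:(>)(p::Particles, r) = Particles(map(x -> x > r, p.particles)) # compare every particle with r (map(>, p.particles, r) stops after one element for scalar r)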
p = Particles(randn(10))
function negsquare(x)
    if x > 0
        return x^2
    else
        return -x^2
    end
end
julia> negsquare(p)
ERROR: TypeError: non-boolean (Particles) used in boolean context
If the code was translated to
function negsquare(x)
    Particles(map(x.particles) do x
        if x > 0
            return x^2
        else
            return -x^2
        end
    end)
end
julia> negsquare(p)
Particles([0.404953, 0.210984, -1.00176, 0.253796, -0.00620389, 0.831144, -0.0240916, -1.90169, 0.875192, 1.4788])
I would get the desired result.
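For reference, the target semantics can be written by hand without any compiler pass. The sketch below is only an illustration: `map_particles` and `branchy` are hypothetical names introduced here, and the `Particles` definition repeats the one above so that the block runs on its own.

```julia
# Same definition as earlier in this post, repeated so the block is self-contained.
struct Particles{T} <: Real
    particles::Vector{T}
end

# Hypothetical helper: apply a scalar function to every particle and rewrap the result.
map_particles(f, p::Particles) = Particles(map(f, p.particles))

# A scalar function with a branch; it only ever sees plain numbers, never a Particles value.
function branchy(x)
    if x > 0
        return x^2
    else
        return -x^2
    end
end

q = map_particles(branchy, Particles([1.0, -2.0, 3.0]))
println(q.particles)
```

This is the behaviour the compiler pass should produce automatically whenever a branch condition depends on a `Particles` value.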
Before we start the discussion on why I don't just map negsquare over p.particles: the branch statement might occur in an arbitrary third-party function whose input is an arbitrary third-party structure, with my type Particles occurring as a field somewhere deep down in that structure. The example here is deliberately trivial, for demonstration.
I have so far made some progress with Cassette: I can filter out code that does not operate on Particles, and I can identify branches in the code. However, I cannot figure out how to transform the relevant code into the map statement above. My attempt, with three “QUESTION” and “TODO” markers inserted as comments:
using Cassette
Cassette.@context Ctx

contains_branch(ir::Core.CodeInfo) = any(contains_branch, ir.code)
contains_branch(ex::Expr) = ex.head == :gotoifnot # Unfortunately, there seem to be more ways of branching in the IR
contains_branch(any) = false
branch_target(ex) = ex.args[2] # Return the goto statement target index

"My custom compiler pass"
function mapif(::Type{<:Ctx}, reflection::Cassette.Reflection)
    ir = reflection.code_info
    any(x -> x <: Particles, reflection.signature.parameters) || (return ir) # No particles included in this call
    contains_branch(ir) || (return ir) # If there is no branch we leave the code alone
    stmtcount = function (stmt, i)
        contains_branch(stmt) || (return nothing)
        return 1 # QUESTION: One function call replaces the branch?
    end
    newstmts = function (stmt, i)
        @show branch_body = ir.code[i+1:branch_target(stmt)-1] # The branch body starts one after the index of the gotoifnot and ends one before the branch target
        # TODO: put the branch body into a map function; have to somehow get rid of all statements that were moved into the function
        [stmt] # Must have the length returned by stmtcount
    end
    Cassette.insert_statements!(ir.code, ir.codelocs, stmtcount, newstmts) # QUESTION: Is it good to send in the entire ir.code so that all SSAValues are updated?
    ir
end

mapifpass = Cassette.@pass mapif
ctx = Ctx(pass=mapifpass)
Cassette.overdub(ctx, negsquare, p)
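On the comment about other ways of branching: as far as I know, Julia 1.6 and later represent the conditional branch as a `Core.GotoIfNot` node instead of `Expr(:gotoifnot, ...)`, and unconditional jumps are `Core.GotoNode`s. A minimal sketch of branch detection that tries to cover both representations, using only Base reflection (the names `is_branch`, `branch_dest`, and `negsquare_demo` are made up for this example and are not part of Cassette):

```julia
# Fallback: most statements are not conditional branches.
is_branch(stmt) = false

# Older representation (assumed for Julia 1.0 to 1.5): Expr(:gotoifnot, cond, dest)
is_branch(ex::Expr) = ex.head === :gotoifnot
branch_dest(ex::Expr) = ex.args[2]

# Newer representation (Julia 1.6+): a dedicated node with fields `cond` and `dest`
if isdefined(Core, :GotoIfNot)
    is_branch(::Core.GotoIfNot) = true
    branch_dest(g::Core.GotoIfNot) = g.dest
end

function negsquare_demo(x)
    if x > 0
        return x^2
    else
        return -x^2
    end
end

# Lowered (untyped) code for the single method of negsquare_demo
ci = first(code_lowered(negsquare_demo, Tuple{Float64}))
branches = findall(is_branch, ci.code)
for i in branches
    println("statement ", i, " branches to statement ", branch_dest(ci.code[i]))
end
```

The same two helpers could stand in for `contains_branch` and `branch_target` inside the pass, assuming the statements Cassette hands to the pass use the same representation as `code_lowered`.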
Can anyone with Cassette experience point me in the right direction? I find it hard to interpret the difference between code and code2 above, but this difference should indicate what I need to do with my compiler pass.
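One way to make the difference between `code` and `code2` easier to read is to compare which expression heads occur in the two lowered forms. As far as I understand, the `do` block is lowered to a closure, so its body ends up in a separate method definition (an `Expr(:method, ...)` holding its own `CodeInfo`) rather than as statements in the enclosing code. That suggests replacing statements in place may not be enough on its own, though the exact lowered form varies between Julia versions. `collect_heads!` is a helper name invented for this sketch:

```julia
# Recursively collect the heads of all Expr nodes reachable from lowered code,
# descending into nested CodeInfo objects (e.g. closure bodies).
function collect_heads!(heads, x)
    if x isa Core.CodeInfo
        foreach(s -> collect_heads!(heads, s), x.code)
    elseif x isa Expr
        push!(heads, x.head)
        foreach(a -> collect_heads!(heads, a), x.args)
    end
    return heads
end

code = Meta.@lower if x > 0
    return x^2
else
    return -x^2
end

code2 = Meta.@lower map(x.particles) do x
    if x > 0
        return x^2
    else
        return -x^2
    end
end

heads  = collect_heads!(Set{Symbol}(), code)
heads2 = collect_heads!(Set{Symbol}(), code2)

# Heads that only appear once the branch is wrapped in a do block
println(setdiff(heads2, heads))
```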
A gist with all code above is available here