Towards implementing CSE in the Julia compiler

Implementing CSE on Julia IR specifically could be a pretty good introduction to the compiler for someone. We already have all the analysis needed (it’s the same analysis used for dead code elimination), so it’s mostly a matter of a little plumbing.

I was curious how difficult it is for a novice to do such a project, as described by Oscar.

It seems the dead code elimination pass in Julia is not too difficult to find. On one hand, it is a relatively high-level function of just 150 lines with relatively easy-to-follow logic. On the other hand, it seems to require knowing a lot of terms, conventions, and definitions that are used only in the Julia internals and are unfamiliar to most folks.

Could one of the core developers give us a minimal example of what it looks like to use this adce_pass! function “manually”? E.g., how can I see the effect of adce_pass on this method:

function f(x::Int)
    if false
        return 0
    end
    if isa(x, Float64)
        return 1
    end
    return 2
end

How do I transform that method into IRCode that I can then call adce_pass on?

Lastly, can I modify adce_pass in order to have it print debug statements and rely on Revise to take care of reloading it without having to restart (or recompile) Julia?

Maybe this should be split off into a separate thread…

function f(x::Int)
    if false
        return 0
    end
    if isa(x, Float64)
        return 1
    end
    return 2
end

julia> IR = only(Base.code_ircode(f, (Int,), optimize_until="SROA")).first
2 1 ─     goto #3 if not false
  2 ─     nothing::Nothing
5 3 ┄     goto #5 if not false
  4 ─     nothing::Nothing
8 5 ┄     return 2

julia> Core.Compiler.adce_pass!(IR)

Thank you! If I may follow up with some silly questions:

julia> function f(x::Int)
           y = x^x
           if false
               return 0
           end
           if isa(x, Float64)
               return 1
           end
           return 2
       end
f (generic function with 1 method)

julia> Base.code_ircode(f, (Int,)) |> only |> first |> Core.Compiler.adce_pass!
2 1 ─     invoke Base.power_by_squaring(_2::Int64, _2::Int64)::Int64     │╻ ^
3 └──     goto #3 if not false                                           │ 
  2 ─     nothing::Nothing                                               │ 
6 3 ┄     goto #5 if not false                                           │ 
  4 ─     nothing::Nothing                                               │ 
9 5 ┄     return 2

1. I would have expected the power_by_squaring call to be eliminated. Why is it still there, even after a redundant adce_pass! (given that code_ircode already runs it)?

2. Why are all these goto statements with constant conditions not removed?

It doesn’t get eliminated because power_by_squaring can throw (a DomainError) when the exponent is negative. If you change it to x^7, it will be eliminated. The gotos are there for reasons of annoying everyone who works on our IR.
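
To make the throwing path concrete, here is a small sketch in plain Base Julia. The names `pow_self` and `throws_domain_error` are made up for illustration. With a negative, non-literal integer exponent, integer `^` raises a DomainError, so the compiler cannot prove the call nothrow and has to keep it.

```julia
# Hypothetical helper mirroring `y = x^x` from the example in this thread.
pow_self(x::Int) = x^x

# Returns true if calling `f` throws a DomainError, false otherwise.
function throws_domain_error(f)
    try
        f()
        return false
    catch err
        return err isa DomainError
    end
end

pow_self(3)                                # an ordinary call, no exception
throws_domain_error(() -> pow_self(-2))    # negative integer exponent: this path throws
```

A literal exponent such as `x^7` goes through a different lowering, which is presumably why that variant can be proven safe to delete.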

Is there a way to use a Dict{Expr, Int} in the compiler? I want to map between expressions and their SSAValues. IdDict doesn’t seem to work, since two Exprs with the same value don’t necessarily have the same objectid.
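
The difference between the two dictionary types can be seen outside the compiler with a few lines of Base Julia. This is only an illustration of the lookup semantics, not compiler code.

```julia
# Two separately constructed but structurally identical expressions.
a = Expr(:call, :+, :x, :y)
b = Expr(:call, :+, :x, :y)

a == b     # structural equality holds
a === b    # identity does not: they are distinct mutable objects

by_value    = Dict{Expr, Int}(a => 1)     # keyed by isequal/hash
by_identity = IdDict{Expr, Int}(a => 1)   # keyed by ===/objectid

haskey(by_value, b)       # found: same structure
haskey(by_identity, b)    # not found: different object
```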

Could someone define the abbreviation CSE?

Should be Common Subexpression Elimination.

I don’t think so. You might be able to use an IdDict of the args of the call.

Silly question, but if it really is Common Subexpression Elimination, don’t we need to be sure there are no side-effects? Is that info sufficiently available at the compiler level for CSE to be worthwhile?

Yeah. That’s what the effect system tells you. Specifically, you can only CSE expressions that are consistent, side effect free, terminate and nothrow. (the last two conditions go away if you don’t CSE across control flow).
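
As a rough sketch of that rule, the check boils down to a bitmask test. The flag constants below are hypothetical stand-ins with made-up values, not the compiler's actual IR_FLAG_* definitions, and `cse_eligible` is an illustrative name.

```julia
# Hypothetical stand-ins for per-statement effect bits (values are made up).
const FLAG_CONSISTENT  = 0x01
const FLAG_EFFECT_FREE = 0x02
const FLAG_NOTHROW     = 0x04
const FLAG_TERMINATES  = 0x08

# True if every bit in `mask` is set in `flag`.
has_all(flag, mask) = (flag & mask) == mask

# A statement must always be consistent and effect-free to be CSE'd.
# Nothrow and termination are only required when the reuse crosses control flow.
function cse_eligible(flag; across_control_flow::Bool)
    required = FLAG_CONSISTENT | FLAG_EFFECT_FREE
    if across_control_flow
        required |= FLAG_NOTHROW | FLAG_TERMINATES
    end
    return has_all(flag, required)
end

may_throw = FLAG_CONSISTENT | FLAG_EFFECT_FREE
cse_eligible(may_throw; across_control_flow = false)   # eligible within one block
cse_eligible(may_throw; across_control_flow = true)    # not eligible across blocks
```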

btw this is the very rough version I have right now (it works in the REPL but not in the compiler because of the Dict)

import Core: SSAValue
import Core.Compiler: IRCode, IR_FLAG_EFFECT_FREE, IR_FLAG_NOTHROW, IR_FLAG_CONSISTENT, naive_idoms, is_meta_expr_head

# Available Expressions Data Flow Analysis
function available_expressions!(ir::IRCode)
    # @assert eachindex(ir.cfg.blocks) isa OneTo

    outflow = [Dict{Expr, Int}() for _ in eachindex(ir.cfg.blocks)]
    ssamap = IdDict{Int, Int}()

    # which is faster?
    # dom_tree = construct_domtree(ir.cfg.blocks)
    idoms = naive_idoms(ir.cfg.blocks)

    worklist = trues(length(ir.cfg.blocks))
    worklist_active_size = length(ir.cfg.blocks)

    block_idx = firstindex(ir.cfg.blocks)
    while worklist_active_size > 0
        if !worklist[block_idx]
            block_idx = findnext(worklist, block_idx)
        end
        worklist[block_idx] = false
        worklist_active_size -= 1

        block = ir.cfg.blocks[block_idx]

        # idom_idx = domtree.idoms_bb[block_idx]
        idom_idx = idoms[block_idx]
        outflow[block_idx] = idom_idx > 0 ? copy(outflow[idom_idx]) : Dict{Expr, Int}()

        changed = transfer_and_cse_pass!(ir, outflow[block_idx], ssamap, block.stmts)

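        next_idx = block_idx + 1
        if changed
            for succ_idx in block.succs
                # Compare against the block just processed, not the next one.
                if idoms[succ_idx] == block_idx
                    # Count the successor before marking it, so the active size stays in sync.
                    worklist_active_size += !worklist[succ_idx]
                    worklist[succ_idx] = true
                    next_idx = min(next_idx, succ_idx)
                end
            end
        end
        block_idx = next_idx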
    end

    return ir
end

# Transfer function and CSE pass
function transfer_and_cse_pass!(ir::IRCode, inflow, ssamap, stmt_idxs)
    changed = false
    for i in stmt_idxs
        stmt = ir.stmts.stmt[i]

        stmt isa Expr || continue

        # TODO: also require IR_FLAG_CONSISTENT (and termination when CSE crosses control flow)
        required = IR_FLAG_EFFECT_FREE | IR_FLAG_NOTHROW
        effect_free = (ir.stmts.flag[i] & required) == required
        effect_free || continue

        # copy propagation
        # taken from renumber_ir_elements!
        el = stmt
        if el.head === :(=) && el.args[2] isa Expr
            el = el.args[2]::Expr
        end
        if el.head !== :enter && !is_meta_expr_head(el.head)
            args = el.args
            for j in eachindex(args)
                arg = args[j]
                if isa(arg, SSAValue)
                    # Write through `args` so the rewrite also applies when `el` is an assignment's right-hand side.
                    args[j] = SSAValue(get(ssamap, arg.id, arg.id))
                end
            end
        end

        if haskey(inflow, stmt)
            ssamap[i] = inflow[stmt]
            ir.stmts.stmt[i] = SSAValue(inflow[stmt])
        else
            inflow[stmt] = i
            changed = true
        end
    end

    return changed
end

Nice! I think you can probably make this work in the compiler by replacing the Expr with the svec of the function call and the args and then using an IdDict.
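
A minimal sketch of that idea, runnable in the REPL. The helper name `call_key` is made up for illustration. It relies on two assumptions: that SimpleVectors with identical elements are indistinguishable under `===` (my understanding of how Julia treats them), and that the callee and arguments are values compared by content, such as Symbols and SSAValues.

```julia
import Core: SSAValue, SimpleVector

# Build an identity-comparable key from a call expression.
# Assumption: every element of `ex.args` is compared by content under `===`
# (Symbols, SSAValues, isbits literals), as is typical for IR operands.
call_key(ex::Expr) = Core.svec(ex.args...)

seen = IdDict{SimpleVector, Int}()

first_call  = Expr(:call, :+, SSAValue(1), SSAValue(2))
second_call = Expr(:call, :+, SSAValue(1), SSAValue(2))
swapped     = Expr(:call, :+, SSAValue(2), SSAValue(1))

seen[call_key(first_call)] = 3   # pretend the first call is statement %3

hit  = get(seen, call_key(second_call), nothing)   # same callee and args: a CSE candidate
miss = get(seen, call_key(swapped), nothing)       # argument order differs: no match
```

The key ignores the expression head, so a real pass would need to restrict this to `:call` expressions or include the head in the key.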
