Manual type inference of hand-written IRCode

The trick I mentioned above worked quite nicely. I have written the following functions to obtain the return type of `ChainRules.rrule`:


# `argtype(ir, x)`: return the Julia type of operand `x` of a statement in `ir`, as a
# plain type suitable for passing to `Base.code_ircode`.
#
# - `Core.Argument`: the argument type stored in `ir.argtypes`.
# - `Core.SSAValue`: the type recorded for that statement in `ir.stmts.type`
#   (assumes that statement has already been typed).
# - `GlobalRef`: the type of the global's current value, looked up directly in its
#   module (no `eval`). `Core.Typeof` keeps a type-valued global such as `Float64`
#   as `Type{Float64}` instead of collapsing it to `DataType`.
# - `QuoteNode`: the type of the quoted constant.
# Any other operand kind throws an `ErrorException`.
#
# Types read from the IR may be compiler lattice elements (e.g. `Core.Const`);
# `CC.widenconst` turns them into ordinary Julia types.
argtype(ir::CC.IRCode, a::Core.Argument) = CC.widenconst(ir.argtypes[a.n])
argtype(ir::CC.IRCode, a::Core.SSAValue) = CC.widenconst(ir.stmts.type[a.id])
argtype(ir::CC.IRCode, f::GlobalRef) = Core.Typeof(getfield(f.mod, f.name))
argtype(ir::CC.IRCode, q::QuoteNode) = Core.Typeof(q.value)
argtype(ir::CC.IRCode, a) = error("argtype of $(typeof(a)) not supported")

"""
    type_of_pullback(ir, inst, optimize_until = "compact 1")

Infer the return type of `ChainRules.rrule` applied to the call statement `inst`.

`inst` must be a `:call` expression whose arguments (callee first) are operands of
`ir`. Their types are obtained with `argtype`. Then `Base.code_ircode` is run on
`ChainRules.rrule` with those types, stopping after the `optimize_until` pass.

Return the inferred type of the whole `rrule` result: a two-element tuple type
`Tuple{Y, P}` holding both the primal output `Y` and the pullback `P`. This is not
just the pullback type.

Throw an `ErrorException` if `inst` is not a `:call` expression, or if the inferred
type is not a two-element tuple type. This includes `Union{}`, i.e. inference proved
the call always throws. `only` throws if `code_ircode` does not yield exactly one
result.
"""
function type_of_pullback(ir, inst, optimize_until = "compact 1")
  # IR statements are not always `Expr`s (e.g. `ReturnNode`), so check before `.head`.
  (inst isa Expr && inst.head === :call) ||
    error("only inferring the return type of calls is supported")
  # Types of the callee and its arguments, in order, as a tuple of types.
  params = Tuple(argtype(ir, a) for a in inst.args)
  # Only the inferred return type is needed; the IRCode itself is discarded
  # (bound to `_` so the `ir` argument is not shadowed).
  _, rt = only(Base.code_ircode(ChainRules.rrule, params; optimize_until))
  # `Union{}` is a subtype of every type, so it must be rejected explicitly.
  if rt === Union{} || !(rt <: Tuple{Any,Any})
    error("The return type of pullback `ChainRules.rrule($(params))` should be tuple")
  end
  return rt
end

This seems to work nicely so far. By adding it to the code above, I am able to construct a fully typed `IRCode`:

julia> new_ir = CC.IRCode(is, cfg, ir.linetable, ir.argtypes, ir.meta, ir.sptypes)
2 1 ─ %1 = ChainRules.rrule(Main.:*, _2, _3)::Tuple{Float64, ChainRules.var"#times_pullback2#1331"{Float64, Float64}}
  │   %2 = getindex(%1, 1)::Float64                                                                               │
  │        getindex(%1, 2)::ChainRules.var"#times_pullback2#1331"{Float64, Float64}                               │
3 │   %4 = ChainRules.rrule(Main.sin, _2)::Tuple{Float64, ChainRules.var"#sin_pullback#1291"{Float64}}            │
  │   %5 = getindex(%4, 1)::Float64                                                                               │
  │        getindex(%4, 2)::ChainRules.var"#sin_pullback#1291"{Float64}                                           │
  │   %7 = ChainRules.rrule(Main.:+, %2, %5)::Tuple{Float64, ChainRules.var"#+_pullback#1319"{Bool, Bool, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}}}
  │   %8 = getindex(%7, 1)::Float64                                                                               │
  │        getindex(%7, 2)::ChainRules.var"#+_pullback#1319"{Bool, Bool, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}}
  └──      return %8                                                                                              │
6 Likes