Manual type inference of hand-written IRCode

I am trying to write something like Petite Zygote by manipulating `Core.Compiler.IRCode`, to get a sense of how it works. I am considering it as part of a survey, to show students where the compiler infrastructure is heading, but I am not sure I will succeed. I also do not know whether it is better to ask here on Discourse.
Anyway, here is my question.
Let’s say I have a function

```julia
function foo(x,y)
  z = x * y
  z + sin(x)
end
```

and its IRCode

```
2 1 ─ %1 = (_2 * _3)::Float64
3 │   %2 = Main.sin(_2)::Float64
  │   %3 = (%1 + %2)::Float64
  └──      return %3
```

I transform it to

```
2 1 ─ %1 = rrule(Main.:*, _2, _3)::Any
  │   %2 = getindex(%1, 1)::Any
  │        getindex(%1, 2)::Any
3 │   %4 = rrule(Main.sin, _2)::Any
  │   %5 = getindex(%4, 1)::Any
  │        getindex(%4, 2)::Any
  │   %7 = rrule(Main.:+, %2, %5)::Any
  │   %8 = getindex(%7, 1)::Any
  │        getindex(%7, 2)::Any
  └──      return %8
```

In the transformed IR, I set all types to `Any`. Is there a function that performs type inference given the types of the arguments? If not, how can I perform type inference manually? The trouble is that I know the called function only as a `Core.GlobalRef`, not by its type, which is what `Base.code_ircode` needs in order to give me the return type.
Thanks for answers in advance.
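A minimal sketch (not from the original post) of the per-statement rewrite shown in the question, done on plain `Expr` objects. The helper name `rrule_stmts`, the placeholder `GlobalRef(Main, :rrule)` and the use of `Base.getindex` are assumptions for illustration. Renumbering later SSA uses (e.g. `%1 + %2` becoming `rrule(+, %2, %5)`) and inserting the new statements into an `IRCode` are not handled.

```julia
# Hypothetical helper: expand one call statement `f(args...)` into
#   %k   = rrule(f, args...)
#   %k+1 = getindex(%k, 1)   # primal value
#   %k+2 = getindex(%k, 2)   # pullback
# Constructing a GlobalRef does not require the binding to exist yet.
function rrule_stmts(call::Expr, k::Int; rrule_ref::GlobalRef = GlobalRef(Main, :rrule))
    call.head === :call || error("expected a :call expression, got :$(call.head)")
    getidx = GlobalRef(Base, :getindex)
    tup = Core.SSAValue(k)                      # will hold the (primal, pullback) tuple
    return Any[
        Expr(:call, rrule_ref, call.args...),   # the original callee becomes rrule's first argument
        Expr(:call, getidx, tup, 1),
        Expr(:call, getidx, tup, 2),
    ]
end

# `%1 = _2 * _3` from the original IR, written as an Expr
stmt = Expr(:call, GlobalRef(Main, :*), Core.Argument(2), Core.Argument(3))
new_stmts = rrule_stmts(stmt, 1)
```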

I originally posted this on Slack and am moving the discussion here so that it is archived. Below I post the important answers and clearly attribute them to their authors, since obviously I cannot post on their behalf.


@jameson has answered:
Probably not really, since the question is underspecified once you have an IRCode stripped of its MethodInstance. The MethodInstance is needed to provide context to inference.

@aviatesk has answered:
It’s very hacky and not recommended, but you can do something like:

```julia
julia> function foo(x,y)
         z = x * y
         z + sin(x)
       end

julia> ir, = only(Base.code_ircode(foo, (Float64,Float64,)));

julia> ir.stmts.type .= Any;

julia> ir
2 1 ─ %1 = Base.mul_float(_2, _3)::Any                                │╻ *
3 │   %2 = invoke Main.sin(_2::Float64)::Any                          │
  │   %3 = Base.add_float(%1, %2)::Any                                │╻ +
  └──      return %3                                                  │

julia> interp = Core.Compiler.NativeInterpreter();

julia> world = Core.Compiler.get_world_counter(interp);

julia> mi = only(methods(foo)).specializations; # a single MethodInstance here, since foo has only been specialized for (Float64, Float64)

julia> argtypes = Any[Core.Const(foo), Float64, Float64];

julia> src = first(@code_typed foo(1.0, 2.0));

julia> method_info = Core.Compiler.MethodInfo(src);

julia> irsv = Core.Compiler.IRInterpretationState(interp, method_info, ir, mi, argtypes, world, src.min_world, src.max_world);

julia> irsv.argtypes_refined .|= true; # force reinference

julia> Core.Compiler.ir_abstract_constant_propagation(interp, irsv);

julia> ir
2 1 ─ %1 = Base.mul_float(_2, _3)::Float64                            │╻ *
3 │   %2 = invoke Main.sin(_2::Float64)::Float64                      │
  │   %3 = Base.add_float(%1, %2)::Float64                            │╻ +
  └──      return %3                                                  │
```

My problem is clearly that I do not understand enough of Julia's internals, so I am mostly hacking my way through.

I was looking at the reverse pass of Diffractor: it converts its IRCode to CodeInfo and passes it on. This puzzles me a bit, because CodeInfo refers to the called functions by `GlobalRef`, just like IRCode does. I therefore do not understand how type inference obtains a handle to the method. To be concrete, take the CodeInfo of `foo`:

```julia
julia> ci = @code_lowered foo(1.0,1.0)

julia> ci.code[1].args[2].args[1]

julia> ci.code[1].args[2].args[1] |> typeof
```

I just do not have a good mental picture.
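A small sketch (not from the thread) that shows what inference is actually handed: it walks the lowered `CodeInfo` of `foo` and collects every `GlobalRef`. It recurses through the expressions instead of indexing like `ci.code[1].args[2].args[1]`, because the exact nesting of lowered code differs between Julia versions. The helper `collect_globalrefs!` is hypothetical.

```julia
function foo(x, y)
    z = x * y
    z + sin(x)
end

# Hypothetical helper: recursively collect every GlobalRef in a lowered
# statement. Leaves that are neither Expr nor GlobalRef (slots, SSA values,
# ReturnNode, literals) contribute nothing.
collect_globalrefs!(acc::Vector{GlobalRef}, @nospecialize(x)) = acc
collect_globalrefs!(acc::Vector{GlobalRef}, g::GlobalRef) = push!(acc, g)
function collect_globalrefs!(acc::Vector{GlobalRef}, ex::Expr)
    for a in ex.args
        collect_globalrefs!(acc, a)
    end
    return acc
end

ci = only(code_lowered(foo, (Float64, Float64)))   # a Core.CodeInfo
refs = GlobalRef[]
for stmt in ci.code
    collect_globalrefs!(refs, stmt)
end

# Each GlobalRef only names a binding (module + symbol), not a Method
for g in refs
    println(g.mod, ".", g.name)
end
```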

I think the answer is that CodeInfo is the input data structure for inference, while IRCode is the input format for optimization.
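To see the two data structures side by side, here is a short sketch (not from the thread) using the standard reflection functions. It assumes Julia 1.9 or later for `Base.code_ircode`; `optimize = false` asks `code_typed` for the inferred but unoptimized `CodeInfo`.

```julia
function foo(x, y)
    z = x * y
    z + sin(x)
end

argt = (Float64, Float64)

# Lowered, untyped CodeInfo: the form inference starts from
lowered = only(code_lowered(foo, argt))

# CodeInfo after inference but before optimization; its `ssavaluetypes`
# now holds an inferred type per statement
typed, rt_typed = only(code_typed(foo, argt; optimize = false))

# IRCode: the data structure the optimizer works on
ir, rt_ir = only(Base.code_ircode(foo, argt))

# Fields read by the `argtype` helpers later in the thread
ir.argtypes      # types of _1 (the function itself), _2 and _3
ir.stmts.type    # one inferred type per SSA statement
```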

Because inference statically evaluates the `GlobalRef` and, if the binding is defined, uses the value it finds?
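To make this concrete, a small sketch (not from the thread) of what a "static" lookup of a `GlobalRef` can look like in user code. It is not the compiler's own implementation, and the helper `resolve_globalref` is hypothetical. As far as I understand, inference can only rely on the value when the binding is `const`, which is the case for function names but not for ordinary globals, so the sketch also checks `isconst`. It assumes Julia 1.9 or later for `getglobal`.

```julia
# Hypothetical helper: if the binding exists in its module, return its
# current value wrapped in `Some`, otherwise return `nothing`.
function resolve_globalref(g::GlobalRef)
    isdefined(g.mod, g.name) || return nothing
    return Some(getglobal(g.mod, g.name))
end

# A function name is a constant binding, so its value can be relied upon
fref = GlobalRef(@__MODULE__, :resolve_globalref)
something(resolve_globalref(fref)) === resolve_globalref   # true
isconst(fref.mod, fref.name)                              # true

# A plain global is not constant: it may be reassigned at any time
counter = 0
isconst(@__MODULE__, :counter)                            # false

# An undefined binding cannot be resolved
resolve_globalref(GlobalRef(@__MODULE__, :no_such_binding)) === nothing   # true
```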

Thanks, Valentin, for the answer. It makes a lot of sense. Where can I find the code that performs the static evaluation of the `GlobalRef`? I would try to hijack it.

Also, if you think that my toying around is useless (I am doing it for educational purposes), please tell me. My goal is to lower the bar a bit for people who want to mess around with IRCode and the compiler internals in general.


No, I think it's very worthwhile. We don't have material to point people to that would smooth their journey, and lowering the barrier to contributing would be great!



I see, so it does something like this to get the function handle?

```julia
julia> ci.code[1].args[2].args[1] |> eval
* (generic function with 334 methods)
```

This might make it fly.

The trick I mentioned above worked quite nicely. I have written the following functions to get the return type of `ChainRules.rrule`:

```julia
using ChainRules
const CC = Core.Compiler

# Type of a single operand of a statement in `ir`
argtype(ir::CC.IRCode, a::Core.Argument) = ir.argtypes[a.n]
argtype(ir::CC.IRCode, a::Core.SSAValue) = ir.stmts.type[a.id]
argtype(ir::CC.IRCode, f::GlobalRef) = typeof(eval(f))   # resolve the binding, take the type of its value
argtype(ir::CC.IRCode, a) = error("argtype of $(typeof(a)) not supported")

"""
    type_of_pullback(ir, inst)

Infer the return type of `ChainRules.rrule` applied to the call `inst`,
i.e. the type of the `(primal, pullback)` tuple.
"""
function type_of_pullback(ir, inst, optimize_until = "compact 1")
  inst.head != :call && error("only inferring the return type of calls is supported")
  params = tuple([argtype(ir, a) for a in inst.args]...)
  (_, rt) = only(Base.code_ircode(ChainRules.rrule, params; optimize_until))
  if !(rt <: Tuple{A,B} where {A,B})
    error("The return type of pullback `ChainRules.rrule($(params))` should be a tuple")
  end
  return rt
end
```

This seems to work quite nicely so far. Adding it to the code above, I am able to construct a fully typed IRCode:

```julia
julia> new_ir = CC.IRCode(is, cfg, ir.linetable, ir.argtypes, ir.meta, ir.sptypes)
2 1 ─ %1 = ChainRules.rrule(Main.:*, _2, _3)::Tuple{Float64, ChainRules.var"#times_pullback2#1331"{Float64, Float64}}
  │   %2 = getindex(%1, 1)::Float64
  │        getindex(%1, 2)::ChainRules.var"#times_pullback2#1331"{Float64, Float64}
3 │   %4 = ChainRules.rrule(Main.sin, _2)::Tuple{Float64, ChainRules.var"#sin_pullback#1291"{Float64}}
  │   %5 = getindex(%4, 1)::Float64
  │        getindex(%4, 2)::ChainRules.var"#sin_pullback#1291"{Float64}
  │   %7 = ChainRules.rrule(Main.:+, %2, %5)::Tuple{Float64, ChainRules.var"#+_pullback#1319"{Bool, Bool, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}}}
  │   %8 = getindex(%7, 1)::Float64
  │        getindex(%7, 2)::ChainRules.var"#+_pullback#1319"{Bool, Bool, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}}
  └──      return %8
```
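For readers without ChainRules installed, a minimal sketch (not from the thread) of the same idea applied to an arbitrary call: resolve the `GlobalRef` callee, build a `Tuple` type from the argument types, and ask inference through `Base.return_types`. The helper `infer_call` and its `argtype` callback are hypothetical; with an `IRCode` at hand, the callback could be `a -> argtype(ir, a)` from the methods above. The `only` assumes concrete argument types, so that a single method matches. It assumes Julia 1.9 or later for `getglobal`.

```julia
# Hypothetical helper: infer the return type of `f(args...)` when `f` is a
# GlobalRef. `argtype` maps each operand (Core.Argument, SSAValue, ...) to a type.
function infer_call(ex::Expr, argtype)
    ex.head === :call || error("only :call expressions are supported")
    fref = ex.args[1]
    fref isa GlobalRef || error("the callee must be a GlobalRef")
    f = getglobal(fref.mod, fref.name)            # static lookup of the callee
    tt = Tuple{map(argtype, ex.args[2:end])...}   # operand types as a Tuple type
    return only(Base.return_types(f, tt))         # one matching method for concrete types
end

# `%2 = Main.sin(_2)` with `_2::Float64`
ex = Expr(:call, GlobalRef(Base, :sin), Core.Argument(2))
infer_call(ex, _ -> Float64)
```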

If it matters, I find these questions and Discourse discussions very educational and good future references for me. I would also like to learn more about these low-level details of the Julia compiler and how to hack it, but I haven't been able to wrap my head around most of it. Please keep it up.