Manual type inference of hand-written IRCode

Hello,
I am trying to write something like Petite Zygote by manipulating Core.Compiler.IRCode, to get a sense of how it works. I am considering it as part of a survey for students of where the compiler infrastructure is heading, but I am not sure I will succeed. I also do not know whether it is better to ask here on Discourse.
Anyway, here is my question.
Let’s say I have a function

function foo(x,y)
  z = x * y
  z + sin(x)
end

and its IRCode

2 1 ─ %1 = (_2 * _3)::Float64
3 │   %2 = Main.sin(_2)::Float64
  │   %3 = (%1 + %2)::Float64
  └──      return %3
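To reproduce a dump like the one above, a minimal sketch, assuming Julia 1.9 or newer, where the unexported `Base.code_ircode` is available. It returns one `IRCode => return type` pair per matching method. The exact statements vary between Julia versions, and the default call runs the full optimizer, so `*` and `+` usually show up already inlined as `Base.mul_float` and `Base.add_float` rather than as the calls shown above.

```julia
function foo(x, y)
    z = x * y
    z + sin(x)
end

# One `IRCode => inferred return type` pair per matching method;
# with concrete argument types exactly one method matches.
ir, rt = only(Base.code_ircode(foo, (Float64, Float64)))

println(ir)   # the typed (and optimized) IR of foo
println(rt)   # the inferred return type
```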

I transform it to

2 1 ─ %1 = rrule(Main.:*, _2, _3)::Any
  │   %2 = getindex(%1, 1)::Any
  │        getindex(%1, 2)::Any
3 │   %4 = rrule(Main.sin, _2)::Any
  │   %5 = getindex(%4, 1)::Any
  │        getindex(%4, 2)::Any
  │   %7 = rrule(Main.:+, %2, %5)::Any
  │   %8 = getindex(%7, 1)::Any
  │        getindex(%7, 2)::Any
  └──      return %8

Here I have set all types to Any. Is there a function that performs type inference given the types of the arguments? If not, how can I perform type inference manually? The trouble is that the called functions are referenced by a Core.GlobalRef rather than by their types, which is what Base.code_ircode needs in order to give me the return type.
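When the function value itself is at hand (rather than a GlobalRef), asking inference for a return type from argument types alone is a short call to `Base.return_types`. The sketch below wraps it in a hypothetical helper `infer_return_type` (not part of any library). Note that it does not address the missing-MethodInstance problem raised in the answers below; it only covers the case where the callee and the argument types are known.

```julia
"""
    infer_return_type(f, argtypes...)

Hypothetical helper: infer the return type of `f(args...)` when only the
argument *types* are known, without running the call.
"""
function infer_return_type(f, argtypes::Type...)
    # One inferred return type per method that may match the signature.
    rts = Base.return_types(f, Tuple{argtypes...})
    # Abstract argument types can match several methods; join their results.
    return Union{rts...}
end

println(infer_return_type(*, Float64, Float64))
println(infer_return_type(sin, Float64))
println(infer_return_type(+, Int, Float64))
```

With concrete argument types, as in the rewritten IR above, exactly one method matches and the `Union` collapses to a single type.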
Thanks for answers in advance.

I originally posted this on Slack and am moving the discussion here so that it is archived. Below I post the important answers and clearly attribute them to their authors, since I obviously cannot post on their behalf.

@jameson has answered:
Probably not really, since the question is underspecified once you have an IRCode stripped of its MethodInstance. The MethodInstance is needed to provide context to inference

@aviatesk has answered:
It’s very hacky and not recommended, but you can do something like:

julia> function foo(x,y)
         z = x * y
         z + sin(x)
       end;

julia> ir, = only(Base.code_ircode(foo, (Float64,Float64,)));

julia> ir.stmts.type .= Any;

julia> ir
2 1 ─ %1 = Base.mul_float(_2, _3)::Any                                │╻ *
3 │   %2 = invoke Main.sin(_2::Float64)::Any                          │
  │   %3 = Base.add_float(%1, %2)::Any                                │╻ +
  └──      return %3                                                  │

julia> interp = Core.Compiler.NativeInterpreter();

julia> world = Core.Compiler.get_world_counter(interp);

julia> mi = only(methods(foo)).specializations;

julia> argtypes = Any[Core.Const(foo), Float64, Float64];

julia> src = first(@code_typed foo(1.0, 2.0));

julia> method_info = Core.Compiler.MethodInfo(src);

julia> irsv = Core.Compiler.IRInterpretationState(interp, method_info, ir, mi, argtypes, world, src.min_world, src.max_world);

julia> irsv.argtypes_refined .|= true; # force reinference

julia> Core.Compiler.ir_abstract_constant_propagation(interp, irsv);

julia> irsv.ir
2 1 ─ %1 = Base.mul_float(_2, _3)::Float64                            │╻ *
3 │   %2 = invoke Main.sin(_2::Float64)::Float64                      │
  │   %3 = Base.add_float(%1, %2)::Float64                            │╻ +
  └──      return %3                                                  │

My problem is clearly that I do not understand enough of Julia's internals, so I am mainly hacking my way through it.

I was looking at the reverse pass of Diffractor: it converts its IRCode to CodeInfo and passes that on. This puzzles me a bit, because CodeInfo, just like IRCode, refers to the called functions through GlobalRef. I therefore do not know how type inference obtains a handle to the method. To be concrete, if I take the CodeInfo of foo as

ci = @code_lowered foo(1.0,1.0)
julia> ci.code[1].args[2].args[1]
:(Main.:*)

julia> ci.code[1].args[2].args[1] |> typeof
GlobalRef

I just do not have a good mental picture.
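One way to build that picture is to list every GlobalRef in the lowered code of `foo` together with the value it currently points to. A small sketch using `code_lowered`; the helper `collect_globalrefs!` is a made-up name for illustration, and it only recurses into `Expr` arguments, which is enough for a function this small (other statement kinds are simply skipped).

```julia
function foo(x, y)
    z = x * y
    z + sin(x)
end

ci = only(code_lowered(foo, (Float64, Float64)))

# Hypothetical helper: recursively collect every GlobalRef in a lowered statement.
collect_globalrefs!(acc, x) = acc                     # slots, SSA values, literals, ...
collect_globalrefs!(acc, g::GlobalRef) = push!(acc, g)
function collect_globalrefs!(acc, ex::Expr)
    for a in ex.args
        collect_globalrefs!(acc, a)
    end
    return acc
end

refs = GlobalRef[]
for stmt in ci.code
    collect_globalrefs!(refs, stmt)
end

# Each GlobalRef names a (module, symbol) binding; looking it up yields the function.
for g in refs
    println(g, " :: ", typeof(getproperty(g.mod, g.name)))
end
```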

I think the answer is that CodeInfo is the input data structure for inference, while IRCode is the input format for optimization.
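The split can be observed by asking for the same function at different stages of the pipeline. A sketch, assuming Julia 1.9+ for `Base.code_ircode`: lowered and typed code both come back as `Core.CodeInfo`, while the optimizer's representation is an `IRCode`.

```julia
function foo(x, y)
    z = x * y
    z + sin(x)
end

sig = (Float64, Float64)

lowered = only(code_lowered(foo, sig))       # untyped CodeInfo: what inference consumes
typed, rt1 = only(code_typed(foo, sig))      # CodeInfo after inference and optimization
ir, rt2 = only(Base.code_ircode(foo, sig))   # IRCode: the optimizer's data structure

println(typeof(lowered))
println(typeof(typed))
println(typeof(ir))
```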

Because it performs a static eval of the GlobalRef and, if the binding is defined, uses the value it finds?

Thanks, Valentin, for the answer. It makes a lot of sense. Where can I find the code that performs the static eval of a GlobalRef? I would like to try to hijack it.

Also, if you think that my toying is useless (I do it for educational purposes), better tell me. My goal is to lower the bar a bit for people who want to mess around with IRCode and the internals in general.

No, I think it's very worthwhile. We don't have material to point people to that would soften their journey, and lowering the barrier to contributing would be great!


I see, so it does something like this to get the function handle?

julia> ci.code[1].args[2].args[1] |> eval
* (generic function with 334 methods)

This might make it fly.

The trick I mentioned above worked quite nicely. I have written the following functions to get the return type of ChainRules.rrule:


using ChainRules            # provides ChainRules.rrule
const CC = Core.Compiler

# Type of a single argument of a call statement in `ir`
argtype(ir::CC.IRCode, a::Core.Argument) = ir.argtypes[a.n]
argtype(ir::CC.IRCode, a::Core.SSAValue) = ir.stmts.type[a.id]
argtype(ir::CC.IRCode, f::GlobalRef) = typeof(eval(f))
argtype(ir::CC.IRCode, a) = error("argtype of $(typeof(a)) not supported")

"""
  type_of_pullback(ir, inst)

  infer the return type of `ChainRules.rrule` applied to the call `inst`
"""
function type_of_pullback(ir, inst, optimize_until = "compact 1")
  inst.head != :call && error("only inferring the return type of calls is supported")
  # argument types of the call, including the type of the called function itself
  params = tuple([argtype(ir, a) for a in inst.args]...)
  (_, rt) = only(Base.code_ircode(ChainRules.rrule, params; optimize_until))
  # rrule must return a 2-tuple (primal value, pullback)
  if !(rt <: Tuple{Any,Any})
    error("The return type of `ChainRules.rrule($(params))` should be a 2-tuple")
  end
  rt
end

This seems to work quite nicely so far. Adding this to the code above, I am able to construct a fully typed IRCode

julia> new_ir = CC.IRCode(is, cfg, ir.linetable, ir.argtypes, ir.meta, ir.sptypes)
2 1 ─ %1 = ChainRules.rrule(Main.:*, _2, _3)::Tuple{Float64, ChainRules.var"#times_pullback2#1331"{Float64, Float64}}
  │   %2 = getindex(%1, 1)::Float64
  │        getindex(%1, 2)::ChainRules.var"#times_pullback2#1331"{Float64, Float64}
3 │   %4 = ChainRules.rrule(Main.sin, _2)::Tuple{Float64, ChainRules.var"#sin_pullback#1291"{Float64}}
  │   %5 = getindex(%4, 1)::Float64
  │        getindex(%4, 2)::ChainRules.var"#sin_pullback#1291"{Float64}
  │   %7 = ChainRules.rrule(Main.:+, %2, %5)::Tuple{Float64, ChainRules.var"#+_pullback#1319"{Bool, Bool, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}}}
  │   %8 = getindex(%7, 1)::Float64
  │        getindex(%7, 2)::ChainRules.var"#+_pullback#1319"{Bool, Bool, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}}
  └──      return %8

If it matters, I find these questions and Discourse discussions quite educational and good future references for me. I would also like to learn more about these low-level details of the Julia compiler and how to hack it, but I haven't been able to wrap my head around most of it. Please keep it up.

I believe replacing eval in argtype(ir::CC.IRCode, f::GlobalRef) = typeof(eval(f)) with

argtype(ir::CC.IRCode, f::GlobalRef) = typeof(getproperty(f.mod, f.name))

will work as well and might be considered better practice.
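The two ways of resolving a GlobalRef can be compared directly. A minimal sketch with a hypothetical `resolve` helper that checks `isdefined` first, so that an undefined binding yields `nothing` instead of throwing; on Julia 1.9+ the builtin `getglobal(ref.mod, ref.name)` should be an equivalent lookup.

```julia
# Hypothetical helper: look up the value a GlobalRef points to without calling `eval`.
function resolve(ref::GlobalRef)
    # An undefined binding cannot be resolved statically; report that with `nothing`.
    isdefined(ref.mod, ref.name) || return nothing
    return getproperty(ref.mod, ref.name)
end

ref = GlobalRef(Base, :sin)
f = resolve(ref)
println(f === eval(ref))   # both routes yield the same function object
println(typeof(f))         # the type that `argtype` reports for this GlobalRef

println(resolve(GlobalRef(Main, :not_defined_anywhere)) === nothing)
```

Unlike `eval`, the lookup only reads a binding and cannot execute arbitrary code, which is one reason it is the safer choice inside a compiler pass.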

Thanks, I have updated my code.