# Manual type inference of hand-written IRCode

Hello,
I am trying to write something like Petite Zygote by manipulating `Core.Compiler.IRCode`, to get a sense of how it works. I am considering this as part of a survey, which I might show to students, of where the compiler infrastructure is heading, but I am not sure I will succeed. I also do not know whether it is better to ask here on Discourse.
Anyway, here is my question.
Letβs say I have a function

```julia
function foo(x, y)
    z = x * y
    z + sin(x)
end
```

and its IRCode

``````2 1 β %1 = (_2 * _3)::Float64                                                                                     β
3 β   %2 = Main.sin(_2)::Float64                                                                                  β
β   %3 = (%1 + %2)::Float64                                                                                     β
βββ      return %3                                                                                              β
``````
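For anyone following along, a typed `IRCode` for `foo` can be obtained with `Base.code_ircode`. This is a minimal sketch assuming default settings: the optimizer then runs to completion, so `*` and `+` show up inlined as `Base.mul_float`/`Base.add_float` rather than as the plain calls printed above. The un-inlined form corresponds to stopping earlier via the `optimize_until` keyword, whose accepted pass names are internal and may differ between Julia versions.

```julia
function foo(x, y)
    z = x * y
    z + sin(x)
end

# code_ircode returns a vector of `IRCode => return type` pairs,
# one per matching method; foo(::Float64, ::Float64) has exactly one.
ir, rt = only(Base.code_ircode(foo, (Float64, Float64)))

println(ir)  # typed IRCode in the same printed format as above
println(rt)  # inferred return type of the whole call
```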

I transform it to

``````2 1 β %1 = rrule(Main.:*, _2, _3)::Any                                                                            β
β   %2 = getindex(%1, 1)::Any                                                                                   β
β        getindex(%1, 2)::Any                                                                                   β
3 β   %4 = rrule(Main.sin, _2)::Any                                                                               β
β   %5 = getindex(%4, 1)::Any                                                                                   β
β        getindex(%4, 2)::Any                                                                                   β
β   %7 = rrule(Main.:+, %2, %5)::Any                                                                            β
β   %8 = getindex(%7, 1)::Any                                                                                   β
β        getindex(%7, 2)::Any                                                                                   β
βββ      return %8                                                                                              β
``````

Here I have set all types to `Any`. Is there a function that performs type inference when I know the types of the arguments? If not, how can I perform type inference manually? The trouble is that I know the called function only through a `Core.GlobalRef`, not by its type, which is what `Base.code_ircode` expects and which I wanted to use to obtain the return type.
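To make "manual" type inference concrete, here is a toy sketch. It is not how `Core.Compiler` works: it walks the lowered code of a branch-free function, resolves each `GlobalRef` to its value, and asks `Base.return_types` for the result type of every call, given the operand types computed so far. `operand_type`, `rhs_type` and `infer_straightline` are hypothetical helpers written for this illustration; the sketch assumes no control flow and callees with singleton types (ordinary functions such as `sin`).

```julia
# Toy, flow-insensitive type "inference" for branch-free lowered code.
# Hypothetical helpers for illustration only: real inference works on a
# type lattice and handles control flow, which this sketch does not.

# Type of an operand, given the types known so far for slots (arguments
# and locals) and for earlier statements (SSA values).
function operand_type(x, slottypes, ssatypes)
    x isa Core.SlotNumber && return slottypes[x.id]
    x isa Core.SSAValue && return ssatypes[x.id]
    # Resolve a global reference to its current value, then take its type.
    x isa GlobalRef && return typeof(getproperty(x.mod, x.name))
    return typeof(x)  # literal constants such as 2 or 1.0
end

# Type produced by a statement or by the right-hand side of an assignment.
function rhs_type(ex, slottypes, ssatypes)
    Meta.isexpr(ex, :call) || return operand_type(ex, slottypes, ssatypes)
    ft = operand_type(ex.args[1], slottypes, ssatypes)
    # A function like `sin` has a singleton type whose only instance is the function.
    Base.issingletontype(ft) || error("callee of type $ft is not a singleton")
    argtypes = Tuple(operand_type(a, slottypes, ssatypes) for a in ex.args[2:end])
    return only(Base.return_types(ft.instance, argtypes))
end

function infer_straightline(f, argtypes::Tuple)
    ci = only(code_lowered(f, argtypes))
    # Slot 1 is the function itself, then the arguments, then local variables.
    nlocals = length(ci.slotnames) - 1 - length(argtypes)
    slottypes = Any[typeof(f), argtypes..., fill(Union{}, nlocals)...]
    ssatypes = Any[]
    rettype = Union{}
    for stmt in ci.code
        if Meta.isexpr(stmt, :(=))  # local = rhs
            T = rhs_type(stmt.args[2], slottypes, ssatypes)
            slottypes[stmt.args[1].id] = T
        elseif stmt isa Core.ReturnNode
            T = rettype = operand_type(stmt.val, slottypes, ssatypes)
        else
            T = rhs_type(stmt, slottypes, ssatypes)
        end
        push!(ssatypes, T)  # one entry per statement, so SSAValue(i) maps to ssatypes[i]
    end
    return rettype, ssatypes
end

function foo(x, y)
    z = x * y
    z + sin(x)
end

rettype, ssatypes = infer_straightline(foo, (Float64, Float64))
println(rettype)
```

The answer should agree with `Base.return_types(foo, (Float64, Float64))`; the point is only to show where the argument types and the resolved `GlobalRef`s enter.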

I originally posted this on Slack and am moving the discussion here so that it is archived. Below I will post the important answers and clearly denote their authors, since obviously I cannot post them on their behalf.

Probably not, really, since the question is underspecified once you have an `IRCode` stripped of its `MethodInstance`. The `MethodInstance` is needed to provide context to inference.

Itβs very hacky and not recommended, but you can do something like:

```julia
julia> function foo(x, y)
           z = x * y
           z + sin(x)
       end;

julia> ir, = only(Base.code_ircode(foo, (Float64, Float64)));

julia> ir.stmts.type .= Any;  # erase all inferred statement types

julia> ir
2 1 ─ %1 = Base.mul_float(_2, _3)::Any                    │╻ *
3 │   %2 = invoke Main.sin(_2::Float64)::Any              │
  │   %3 = Base.add_float(%1, %2)::Any                    │╻ +
  └──      return %3                                      │

julia> interp = Core.Compiler.NativeInterpreter();

julia> world = Core.Compiler.get_world_counter(interp);

julia> mi = only(methods(foo)).specializations;  # with a single specialization, this is the MethodInstance itself

julia> argtypes = Any[Core.Const(foo), Float64, Float64];  # lattice types of #self#, x and y

julia> src = first(@code_typed foo(1.0, 2.0));  # typed CodeInfo: supplies the MethodInfo and world bounds

julia> method_info = Core.Compiler.MethodInfo(src);

julia> irsv = Core.Compiler.IRInterpretationState(interp, method_info, ir, mi, argtypes, world, src.min_world, src.max_world);

julia> irsv.argtypes_refined .|= true;  # force reinference

julia> Core.Compiler.ir_abstract_constant_propagation(interp, irsv);

julia> irsv.ir
2 1 ─ %1 = Base.mul_float(_2, _3)::Float64                │╻ *
3 │   %2 = invoke Main.sin(_2::Float64)::Float64          │
  │   %3 = Base.add_float(%1, %2)::Float64                │╻ +
  └──      return %3                                      │
```

My problem is clearly that I do not understand enough of Julia's internals, which means I am mainly hacking my way through it.

I was looking at the reverse pass of Diffractor: it converts its `IRCode` to `CodeInfo` and passes that on. This puzzles me a bit, because `CodeInfo` refers to called functions through `GlobalRef`s, just as `IRCode` does, so I do not see how type inference obtains a handle to the method. To be concrete, take the `CodeInfo` of `foo`:

```julia
julia> ci = @code_lowered foo(1.0, 1.0);

julia> ci.code[1].args[2].args[1]
:(Main.:*)

julia> ci.code[1].args[2].args[1] |> typeof
GlobalRef
```

I just do not have a good mental picture.
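One way to build that mental picture is to list every `GlobalRef` in the lowered code and look each one up. The `collect_globalrefs!` helper below is hypothetical, written for this illustration; it recurses into nested `Expr`s, so it also finds the callee inside an assignment such as `z = x * y`, and then resolves each reference (module plus symbol) to the object it names.

```julia
function foo(x, y)
    z = x * y
    z + sin(x)
end

# Hypothetical helper: push every GlobalRef found in `x`, recursing into Exprs.
function collect_globalrefs!(out::Vector{GlobalRef}, x)
    if x isa GlobalRef
        push!(out, x)
    elseif x isa Expr
        foreach(arg -> collect_globalrefs!(out, arg), x.args)
    end
    return out
end

ci = only(code_lowered(foo, (Float64, Float64)))
refs = GlobalRef[]
foreach(stmt -> collect_globalrefs!(refs, stmt), ci.code)

for ref in refs
    # A GlobalRef only names a binding; looking it up yields the function object.
    println(ref, " => ", getproperty(ref.mod, ref.name))
end
```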

I think the answer is that `CodeInfo` is the input data structure for inference, while `IRCode` is the input format for optimization.
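A quick way to see this split, assuming only the standard reflection functions, is to compare the untyped `CodeInfo` produced by lowering with the `CodeInfo` returned by `code_typed`, which has been through inference and optimization and carries inferred types in its `ssavaluetypes` field.

```julia
function foo(x, y)
    z = x * y
    z + sin(x)
end

# Output of lowering: the untyped CodeInfo that inference takes as input.
lowered = only(code_lowered(foo, (Float64, Float64)))

# Result after inference and optimization, again stored as a CodeInfo.
typed, rettype = only(code_typed(foo, (Float64, Float64)))

println(lowered)
println(typed)
println(typed.ssavaluetypes)  # inferred type of each statement
println(rettype)
```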

Because it performs a static eval on the `GlobalRef` and, if the binding is defined, uses the value it finds.
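A rough sketch of that idea, using only public reflection functions (`isdefined`, `isconst`, `getproperty`). This is not the compiler's own code: the hypothetical `resolve_globalref` returns the bound value when the binding exists and is constant, and `nothing` otherwise. Restricting it to constant bindings reflects the assumption that inference can only rely on values that cannot be reassigned later.

```julia
# Hypothetical helper mimicking a "static eval" of a GlobalRef.
# Only constant bindings are resolved; a non-constant global could be
# reassigned, so its current value is not a safe basis for inference.
# (A real implementation would also have to tell a constant equal to
# `nothing` apart from "not resolvable".)
function resolve_globalref(ref::GlobalRef)
    isdefined(ref.mod, ref.name) || return nothing  # binding does not exist
    isconst(ref.mod, ref.name) || return nothing    # value may still change
    return getproperty(ref.mod, ref.name)
end

resolve_globalref(GlobalRef(Base, :sin))                  # returns the function `sin`
resolve_globalref(GlobalRef(Base, :no_such_binding_xyz))  # returns nothing
```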

Thanks, Valentin, for the answer. It makes a lot of sense. Where can I find the code that performs the static eval on the `GlobalRef`? I would like to hijack it.

Also, if you think my toying around is useless (I am doing it for educational purposes), please tell me. My goal is to lower the bar a bit for people who want to mess around with `IRCode` and the internals in general.

No I think itβs very worthwhile, we donβt have the material to point people to and soften their journey and lowering the barrier of contributing is something that would be great!

I see, so it does something like this to get the function handle?

```julia
julia> ci.code[1].args[2].args[1] |> eval
* (generic function with 334 methods)
```

This might make it fly.

The trick I mentioned above worked quite nicely. I have written the following function to get the return type of `ChainRules.rrule`:

```julia
import ChainRules
const CC = Core.Compiler

# Type of an operand of an IRCode statement. `widenconst` turns lattice
# elements such as `Core.Const(foo)` into plain Julia types.
argtype(ir::CC.IRCode, a::Core.Argument) = CC.widenconst(ir.argtypes[a.n])
argtype(ir::CC.IRCode, a::Core.SSAValue) = CC.widenconst(ir.stmts.type[a.id])
argtype(ir::CC.IRCode, f::GlobalRef) = typeof(eval(f))
argtype(ir::CC.IRCode, a) = error("argtype of $(typeof(a)) not supported")

"""
    type_of_pullback(ir, inst)

Infer the return type of `ChainRules.rrule` applied to the call `inst`.
"""
function type_of_pullback(ir, inst, optimize_until = "compact 1")
    inst.head != :call && error("only inferring the return type of calls is supported")
    # inst.args[1] is the callee, so params == (typeof(f), argument types...)
    params = tuple([argtype(ir, a) for a in inst.args]...)
    (_, rt) = only(Base.code_ircode(ChainRules.rrule, params; optimize_until))
    if !(rt <: (Tuple{A,B} where {A,B}))
        error("The return type of pullback `ChainRules.rrule($(params))` should be a tuple")
    end
    rt
end
```

This seems to work quite nicely so far. Adding this to the code above, I am able to construct a fully typed `IRCode` (`is` is my transformed instruction stream and `cfg` its control-flow graph):

```julia
julia> new_ir = CC.IRCode(is, cfg, ir.linetable, ir.argtypes, ir.meta, ir.sptypes)
2 1 ─ %1 = ChainRules.rrule(Main.:*, _2, _3)::Tuple{Float64, ChainRules.var"#times_pullback2#1331"{Float64, Float64}}
  │   %2 = getindex(%1, 1)::Float64
  │        getindex(%1, 2)::ChainRules.var"#times_pullback2#1331"{Float64, Float64}
3 │   %4 = ChainRules.rrule(Main.sin, _2)::Tuple{Float64, ChainRules.var"#sin_pullback#1291"{Float64}}
  │   %5 = getindex(%4, 1)::Float64
  │        getindex(%4, 2)::ChainRules.var"#sin_pullback#1291"{Float64}
  │   %7 = ChainRules.rrule(Main.:+, %2, %5)::Tuple{Float64, ChainRules.var"#+_pullback#1319"{Bool, Bool, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}}}
  │   %8 = getindex(%7, 1)::Float64
  │        getindex(%7, 2)::ChainRules.var"#+_pullback#1319"{Bool, Bool, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}, ChainRulesCore.ProjectTo{Float64, @NamedTuple{}}}
  └──      return %8
```

If it matters, I find these questions and Discourse discussions quite educational and good future references for me. I would also like to learn more about these low-level details of the Julia compiler and how to hack it, but I haven't been able to wrap my head around most of it. Please keep it up.

I believe replacing `eval` in `argtype(ir::CC.IRCode, f::GlobalRef) = typeof(eval(f))` with

```julia
argtype(ir::CC.IRCode, f::GlobalRef) = typeof(getproperty(f.mod, f.name))
```

will work as well and might be considered better practice.
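A small check, not a benchmark: for a `GlobalRef`, all three lookups below should return the same object. `getproperty` (or `getglobal`, available since Julia 1.9) reads the binding directly, whereas `eval` is the general-purpose top-level evaluator, which is more machinery than a plain binding lookup needs.

```julia
ref = GlobalRef(Base, :sin)

# Three ways to get the object a GlobalRef names:
via_eval        = eval(ref)                      # general top-level evaluation
via_getproperty = getproperty(ref.mod, ref.name) # direct binding lookup
via_getglobal   = getglobal(ref.mod, ref.name)   # direct binding lookup (Julia 1.9+)

println(via_eval === via_getproperty === via_getglobal)
```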

Thanks, I have updated my code.
