Inference bug when using nested generators?

The following code does not infer:

f(n) = (i^2 for i in 1:n)
g(n) = sum(x^2 for x in f(n))
@code_warntype g(3)

Output:

MethodInstance for g(::Int64)
  from g(n) in Main at /home/jonathan/AlphaZero.jl/redesign/private/inference.jl:2
Arguments
  #self#::Core.Const(g)
  n::Int64
Locals
  #61::var"#61#62"
Body::Any
1 ─      (#61 = %new(Main.:(var"#61#62")))
│   %2 = #61::Core.Const(var"#61#62"())
│   %3 = Main.f(n)::Core.PartialStruct(Base.Generator{UnitRange{Int64}, var"#59#60"}, Any[Core.Const(var"#59#60"()), Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])])
│   %4 = Base.Generator(%2, %3)::Core.PartialStruct(Base.Generator{Base.Generator{UnitRange{Int64}, var"#59#60"}, var"#61#62"}, Any[Core.Const(var"#61#62"()), Core.PartialStruct(Base.Generator{UnitRange{Int64}, var"#59#60"}, Any[Core.Const(var"#59#60"()), Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])])])
│   %5 = Main.sum(%4)::Any
└──      return %5

Does anyone have an explanation why type inference would give up in such a scenario? I am using Julia 1.7.3.
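If you want to check this outside the REPL, here is a minimal script sketch. It uses `Test.@inferred`, which throws when the inferred return type is not the concrete runtime type. Whether it throws depends on the Julia version: the report above is for 1.7.3, and newer versions may behave differently. The variable name `g_is_inferred` is just for illustration.

```julia
using Test  # provides @inferred

f(n) = (i^2 for i in 1:n)
g(n) = sum(x^2 for x in f(n))

# The value is correct either way: 1^4 + 2^4 + 3^4 == 98
@assert g(3) == 98

# @inferred throws an ErrorException when the inferred return type does not
# match the runtime type; the outcome may differ between Julia versions.
g_is_inferred = try
    @inferred g(3)
    true
catch err
    err isa ErrorException || rethrow()
    false
end
println("g(3) return type inferred concretely: ", g_is_inferred)
```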


Yeah it’s weird. The return type is inferred with any of these changes to g:

  1. change the generator into an array comprehension that builds a temporary Vector: [x^2 for x in f(n)]
  2. don’t square each element (x for x in f(n))
  3. replace sum with first
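The three variants above can be written side by side. This is a sketch; the names `g_vec`, `g_plain` and `g_first` are made up for the example, and the printed return types come from `Base.return_types`, so they reflect whatever Julia version you run it on.

```julia
f(n) = (i^2 for i in 1:n)

# The three variants of g reported to infer on Julia 1.7.3
g_vec(n)   = sum([x^2 for x in f(n)])  # (1) build a temporary Vector first
g_plain(n) = sum(x for x in f(n))      # (2) no squaring in the outer generator
g_first(n) = first(x^2 for x in f(n))  # (3) first instead of sum

for h in (g_vec, g_plain, g_first)
    # Base.return_types lists the inferred return type of each matching method
    println(h, " => ", Base.return_types(h, (Int,)))
end
```

Note that (2) and (3) compute something different from g, so only (1) is a drop-in replacement.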

With (3), you can make an equivalent workaround where the return type is inferred:

julia> function g2(n)
         y = (x^2 for x in f(n))
         summ = zero(first(y))
         for el in y
           summ += el
         end
         summ
       end

julia> all([g(i) == g2(i) for i in 1:100])
true

julia> using Test

julia> @inferred g2(27)
3142062

julia> @inferred g(27)
ERROR: return type Int64 does not match inferred return type Any

I think the reason is this:

julia> f(n) = (i^2 for i in 1:n)
f (generic function with 1 method)

julia> gen = f(3)
Base.Generator{UnitRange{Int64}, var"#5#6"}(var"#5#6"(), 1:3)

julia> eltype(gen)
Any

This hits the generic fallback definition of eltype, which always returns Any. I don’t know why Generator uses the fallback, but I imagine it is because a specialized eltype would have to figure out the return type of the generator’s function for every possible element produced by the iterator it wraps.

Perhaps this could be improved by feeding the eltype of the iterator to inference for the function of the generator? It’s definitely a tricky thing to get right in general though, since that syntax is basically creating a closure, which may capture more state from the surrounding code as well. Additionally, explicitly calling into inference for this is often discouraged, I think - it forces dynamism.
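To make the idea concrete, here is a sketch of feeding the iterator's element type to inference for the generator's function, applied recursively so nested generators work. `inferred_eltype` is a hypothetical helper, not something in Base, and `Base.promote_op` is an internal Base utility that asks inference for a return type, so this carries exactly the caveat above: it relies on inference at runtime.

```julia
f(n) = (i^2 for i in 1:n)

# Hypothetical helper: for a plain iterator use its eltype; for a Generator,
# ask inference what g.f returns when given the inner iterator's element type.
inferred_eltype(itr) = eltype(itr)
inferred_eltype(g::Base.Generator) = Base.promote_op(g.f, inferred_eltype(g.iter))

gen = f(3)
nested = (x^2 for x in gen)

println(eltype(gen))             # the generic fallback (Any in the thread)
println(inferred_eltype(gen))    # Int (Int64 on a 64-bit machine)
println(inferred_eltype(nested)) # Int as well, thanks to the recursion
```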


Scratch that - just adding

sum(g::Generator) = sum(g.f, g.iter)

makes inference happy:

julia> Base.sum(g::Base.Generator) = sum(g.f, g.iter)

julia> f(n) = (i^2 for i in 1:n)
f (generic function with 1 method)

julia> g(n) = sum(x^2 for x in f(n))
g (generic function with 1 method)

julia> @code_warntype g(3)
MethodInstance for g(::Int64)
  from g(n)
     @ Main REPL[3]:1
Arguments
  #self#::Core.Const(g)
  n::Int64
Locals
  #5::var"#5#6"
Body::Int64
1 ─      (#5 = %new(Main.:(var"#5#6")))
│   %2 = #5::Core.Const(var"#5#6"())
│   %3 = Main.f(n)::Core.PartialStruct(Base.Generator{UnitRange{Int64}, var"#3#4"}, Any[Core.Const(var"#3#4"()), Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])])
│   %4 = Base.Generator(%2, %3)::Core.PartialStruct(Base.Generator{Base.Generator{UnitRange{Int64}, var"#3#4"}, var"#5#6"}, Any[Core.Const(var"#5#6"()), Core.PartialStruct(Base.Generator{UnitRange{Int64}, var"#3#4"}, Any[Core.Const(var"#3#4"()), Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])])])
│   %5 = Main.sum(%4)::Int64
└──      return %5

I don’t know if this was just forgotten or is deliberately done, so maybe it’s worth an issue/short PR for discussion?
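For anyone who wants this in their own code without redefining `Base.sum` (which is type piracy), a local helper does the same unwrapping. This is a sketch; `gsum` is a hypothetical name, and it forwards keyword arguments such as `init` to `sum(f, itr)`.

```julia
f(n) = (i^2 for i in 1:n)

# Sum a Generator by handing its function and inner iterator to sum(f, itr),
# the same trick as the Base.sum(g::Base.Generator) method above.
gsum(g::Base.Generator; kw...) = sum(g.f, g.iter; kw...)
gsum(itr; kw...) = sum(itr; kw...)  # anything else goes straight to sum

g(n) = gsum(x^2 for x in f(n))
```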


sum(x for x in f(5)) is inferable while sum(x^2 for x in f(5)) isn’t. Both use the most generic sum(a; kw...) = sum(identity, a; kw...) in reduce.jl, and I followed the trail through mapreduce, mapfoldl, mapfoldl_impl, and a recursive _xfadjoint. The type instability seems to be rooted in MappingRF’s second type parameter not being inferred in _xfadjoint for the x^2 case.

I think it runs Base._xfadjoint(Base.BottomRF(Base.add_sum), Base.Generator(identity, (x^2 for x in f(5)) ) ). The x^2 case returns a MappingRF instance that contains another MappingRF instance, while the x case returns an unnested MappingRF instance. Test.@inferred is able to tell that the x^2 case is not inferable, but @code_warntype prints blue text for the call’s return type, a weird discrepancy I’ve noticed before in much stranger circumstances.
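To picture what that recursion does, here is a user-level sketch that peels nested Generator layers and composes their functions, so the reduction runs directly over the innermost iterator. It is only loosely analogous to `_xfadjoint` (Base wraps the reducing operator in nested MappingRF objects rather than composing functions), and `peel` is a hypothetical name, not Base code.

```julia
f(n) = (i^2 for i in 1:n)

# Peel Generator layers, composing outer ∘ inner functions as we go down.
peel(g::Base.Generator) = peel(g.iter, g.f)
peel(g::Base.Generator, outer) = peel(g.iter, outer ∘ g.f)
peel(itr, fn) = (fn, itr)  # reached a non-Generator iterator

fn, inner = peel(x^2 for x in f(3))
println(inner)          # 1:3
println(sum(fn, inner)) # 98, same as sum(x^2 for x in f(3))
```

Each extra generator layer adds one more function to the composition, which mirrors the extra MappingRF layer in the x^2 case.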


Thanks everyone for having a look!
I opened an issue: Inference suboptimality when using nested generators · Issue #45748 · JuliaLang/julia · GitHub.

If you supply the init keyword to sum, inference works. To me this suggests the problem is that sum is not able to infer which type it should use for initialization.

What call, exactly? @code_warntype sum(x^2 for x in f(5); init=0) infers ::Any for me.

julia> f(n) = (i^2 for i in 1:n)
f (generic function with 1 method)

julia> g(n) = sum(x^2 for x in f(n), init=0)
g (generic function with 1 method)

julia> @code_warntype g(3)
MethodInstance for g(::Int64)
  from g(n) in Main at REPL[3]:1
Arguments
  #self#::Core.Const(g)
  n::Int64
Locals
  #5::var"#5#6"
Body::Int64

The difference was using a comma vs semicolon before the keyword argument. It says the same line in the source, but the printed method looks really different. Honestly not sure what’s going on.

julia> @which sum(x^2 for x in f(3), init=0) # type-stable
sum(a; kw...) in Base at reduce.jl:532

julia> @which sum(x^2 for x in f(3); init=0) # type-unstable
(::Base.var"#sum##kw")(::Any, ::typeof(sum), a) in Base at reduce.jl:532

The 0 replaces the default init=_InitialValue() in mapfoldl and is passed on to foldl_impl. It doesn’t seem to help @code_warntype infer mapfoldl, though.

I can reproduce the comma vs semicolon difference on 1 week old master. Moreover, the result of @code_warntype is different if g has been called, for example:

julia> f(n) = (i^2 for i in 1:n)
f (generic function with 1 method)

julia> g(n) = sum(x^2 for x in f(n); init=0)
g (generic function with 1 method)

julia> @code_warntype g(3) # does not infer
MethodInstance for g(::Int64)
  from g(n) in Main at REPL[2]:1
Arguments
  #self#::Core.Const(g)
  n::Int64
Locals
  #5::var"#5#6"
Body::Any
1 ─       (#5 = %new(Main.:(var"#5#6")))
│   %2  = #5::Core.Const(var"#5#6"())
│   %3  = Main.f(n)::Core.PartialStruct(Base.Generator{UnitRange{Int64}, var"#3#4"}, Any[Core.Const(var"#3#4"()), Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])])
│   %4  = Base.Generator(%2, %3)::Core.PartialStruct(Base.Generator{Base.Generator{UnitRange{Int64}, var"#3#4"}, var"#5#6"}, Any[Core.Const(var"#5#6"()), Core.PartialStruct(Base.Generator{UnitRange{Int64}, var"#3#4"}, Any[Core.Const(var"#3#4"()), Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])])])
│   %5  = (:init,)::Core.Const((:init,))
│   %6  = Core.apply_type(Core.NamedTuple, %5)::Core.Const(NamedTuple{(:init,)})
│   %7  = Core.tuple(0)::Core.Const((0,))
│   %8  = (%6)(%7)::Core.Const((init = 0,))
│   %9  = Core.kwfunc(Main.sum)::Core.Const(Base.var"#sum##kw"())
│   %10 = (%9)(%8, Main.sum, %4)::Any
└──       return %10

julia> g(4)
354

julia> @code_warntype g(3) # now it infers?
MethodInstance for g(::Int64)
  from g(n) in Main at REPL[2]:1
Arguments
  #self#::Core.Const(g)
  n::Int64
Locals
  #5::var"#5#6"
Body::Int64
1 ─       (#5 = %new(Main.:(var"#5#6")))
│   %2  = #5::Core.Const(var"#5#6"())
│   %3  = Main.f(n)::Core.PartialStruct(Base.Generator{UnitRange{Int64}, var"#3#4"}, Any[Core.Const(var"#3#4"()), Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])])
│   %4  = Base.Generator(%2, %3)::Core.PartialStruct(Base.Generator{Base.Generator{UnitRange{Int64}, var"#3#4"}, var"#5#6"}, Any[Core.Const(var"#5#6"()), Core.PartialStruct(Base.Generator{UnitRange{Int64}, var"#3#4"}, Any[Core.Const(var"#3#4"()), Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])])])
│   %5  = (:init,)::Core.Const((:init,))
│   %6  = Core.apply_type(Core.NamedTuple, %5)::Core.Const(NamedTuple{(:init,)})
│   %7  = Core.tuple(0)::Core.Const((0,))
│   %8  = (%6)(%7)::Core.Const((init = 0,))
│   %9  = Core.kwfunc(Main.sum)::Core.Const(Base.var"#sum##kw"())
│   %10 = (%9)(%8, Main.sum, %4)::Int64
└──       return %10

which is definitely mysterious to me.


The issue with semicolon vs comma is simply Julia’s handling of keyword arguments. More information here. This is not really relevant to the inference issue of sum, as far as I can tell.

Oh dear, this is awful. Typing sum(i for i in 1:3, init=0) parses init=0 as a second iteration clause of the generator, i.e. it is equivalent to

itr = (i for i in 1:3, init in 0)
sum(itr)

A different issue is then Julia’s kwsorter. So, to recap:

  • Without the init keyword, there is runtime dispatch to find the element type that should be used in the sum function
  • With the init keyword, the call through Julia’s keyword-argument sorter (kwsorter) causes runtime dispatch. This can be solved by wrapping g in another function
  • If the init keyword is not separated with a semicolon, it is parsed as an extra iteration clause of the generator, without any keyword arguments
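A small sketch of the last point: the keyword form versus what the comma form turns into. To avoid depending on parser details, the comma form is written out explicitly with an `init in 10` clause, as in the equivalence above; the names `with_kw`, `no_kw` and `g_init` are illustrative only. A nonzero init makes the difference visible.

```julia
f(n) = (i^2 for i in 1:n)

# Semicolon: init is a real keyword argument of sum, so 10 + (1 + 2 + 3)
with_kw = sum((i for i in 1:3); init = 10)

# What `sum(i for i in 1:3, init=10)` means: an extra iteration clause over
# the single number 10, so `init` is just an unused loop variable
no_kw = sum(i for i in 1:3, init in 10)

# Passing init correctly also makes the empty case work (f(0) is empty)
g_init(n) = sum((x^2 for x in f(n)); init = 0)
```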

Certainly. But the issue that @code_warntype infers differently before and after calling g is independent of the kwarg, consider:

julia> f(n) = (i^2 for i in 1:n);

julia> g(n) = sum(x^2 for x in f(n));

julia> (@code_typed g(3))[2]
Any

julia> g(6)
2275

julia> (@code_typed g(3))[2]
Int64

So, ultimately, I’m not sure whether these codes do infer or not.
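A script version of this check, as a sketch: it asks `Base.return_types` for g’s inferred return type before and after running g once. The thread used `@code_typed` on Julia 1.7, so this may or may not reproduce the Any-then-Int64 flip on your version; it only prints what it finds.

```julia
f(n) = (i^2 for i in 1:n)
g(n) = sum(x^2 for x in f(n))

# Inferred return type of g(::Int) before g has ever been executed
before = only(Base.return_types(g, (Int,)))

result = g(6)  # run g once: 1^4 + 2^4 + ... + 6^4 == 2275

# Inferred return type after the call (may differ between Julia versions)
after = only(Base.return_types(g, (Int,)))

println("before: ", before, ", after: ", after, ", g(6) = ", result)
```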

For information, testing with @inferred errors even after executing g(6).


Is this not the expected behavior?

Hmm, now it seems like this JET.jl bug may be caused by an underlying Julia issue: Result of @report_opt depends on whether there were previous @report_opt calls · Issue #352 · aviatesk/JET.jl · GitHub

This does not look like reasonable behavior to me. When init is not specified in sum(xs), a default like zero(eltype(xs)) should be used whose type is inferrable.

Another issue is the puzzling @code_warntype behavior reported by @Liozou. It is unclear to me whether or not these two issues are related.

That is currently the case:

julia> sum(i for i in 6:5) # note that 6:5 is empty
0

so the question is whether the eltype of the generator in g is inferred in the call to sum… And the answer is no, since g(0) errors, calling for an explicit init kwarg. I agree that this should infer though – but that’s precisely the point of this thread.
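As an illustration of the "zero(eltype(xs))"-style default, here is a sketch that derives init from the inferred return type of the generator’s function. `sum_or_zero` is a hypothetical helper, it only handles a single Generator layer, and `Base.promote_op` is an internal Base utility.

```julia
# Hypothetical helper: pick init = zero(T), where T is what inference says
# g.f returns for an element of the underlying iterator.
function sum_or_zero(g::Base.Generator)
    T = Base.promote_op(g.f, eltype(g.iter))
    # zero(T) throws a MethodError when T is Any, e.g. for a nested
    # Generator whose inner eltype is the Any fallback
    return sum(g; init = zero(T))
end
```

With this, the empty case no longer needs an explicit init, e.g. `sum_or_zero(i^2 for i in 6:5)` returns 0.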

Note that the following errors too, so the inference issue might not come from f at all:

julia> sum(i^2 for i in 6:5)
ERROR: MethodError: reducing over an empty collection is not allowed; consider supplying `init` to the reducer
Stacktrace:

julia> eltype(f(1))
Any

Regardless, you are right. It should be able to infer, even without the init keyword. And I’m still stumped as to why it suddenly infers after you call it once.

This eltype is just due to eltype(::Generator) hitting the generic fallback definition. Adding a specialized version does not fix this, as I already mentioned above:

julia> Base.eltype(g::Base.Generator) = Union{Core.Compiler.return_types(g.f, (eltype(g.iter),))...}

julia> f(n) = (i^2 for i in 1:n)
f (generic function with 1 method)

julia> g(n) = sum(x^2 for x in f(n))
g (generic function with 1 method)

julia> @code_warntype g(3)
MethodInstance for g(::Int64)
  from g(n)
     @ Main REPL[3]:1
Arguments
  #self#::Core.Const(g)
  n::Int64
Locals
  #5::var"#5#6"
Body::Any
1 ─      (#5 = %new(Main.:(var"#5#6")))
│   %2 = #5::Core.Const(var"#5#6"())
│   %3 = Main.f(n)::Core.PartialStruct(Base.Generator{UnitRange{Int64}, var"#3#4"}, Any[Core.Const(var"#3#4"()), Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])])
│   %4 = Base.Generator(%2, %3)::Core.PartialStruct(Base.Generator{Base.Generator{UnitRange{Int64}, var"#3#4"}, var"#5#6"}, Any[Core.Const(var"#5#6"()), Core.PartialStruct(Base.Generator{UnitRange{Int64}, var"#3#4"}, Any[Core.Const(var"#3#4"()), Core.PartialStruct(UnitRange{Int64}, Any[Core.Const(1), Int64])])])
│   %5 = Main.sum(%4)::Any
└──      return %5

However, calling eltype(f(3)) then subsequently does make this infer, without ever calling g.

Either way, specializing sum for Generator instead of going through sum(identity, gen) already fixes this. I’m not sure how much more general you want to go with this investigation though.

I’m inclined to think that this has to do with some invalidation/specialization that happens when g or f is actually called.

Ack, syntax overlap strikes again! Putting proper parentheses around the generator expression does what was intended (the parentheses may only be omitted when the generator is the only argument of the call):

julia> @which sum((i for i in 1:3), init=0)
(::Base.var"#sum##kw")(::Any, ::typeof(sum), a) in Base at reduce.jl:532

julia> @which sum(i for i in 1:3; init=0)
(::Base.var"#sum##kw")(::Any, ::typeof(sum), a) in Base at reduce.jl:532

Anyway, back to the actual post.

Fixing the generic implementation could benefit many more types. I’m not sure which types, though: summing over a Generator is the only case I can think of, and if the fix is only in _xfadjoint, that function is already a Generator-peeling recursion. Then again, the outermost Generator layer passed to _xfadjoint is created during the process, so I suppose even a non-Generator input gets wrapped in one and ends up there.

I can reproduce this on v1.7.3, it’s similar to the thread I linked when I said “much stranger circumstances”, which also involves summing over Generator.
