Inference bug when using nested generators?

@descend shows that our compiler fails to infer Base._xfadjoint in this case.

# Strip one `Generator` layer from `itr` and recurse on the iterator it wraps.
# A layer whose mapping function is `identity` leaves `op` untouched; any other
# layer is folded into the reducing function as `MappingRF(itr.f, op)`.
function _xfadjoint(op, itr::Generator)
    inner = itr.iter
    itr.f === identity && return _xfadjoint(op, inner)
    return _xfadjoint(MappingRF(itr.f, op), inner)
end

It unwraps the nested Generator recursively and forms the new MappingRF.

julia> a = (-i for i in 1:3); b = (-i for i in a); c = (-i for i in b);

julia> Base.return_types(Base._xfadjoint, Base.typesof(+, c))
1-element Vector{Any}:
 Tuple{Base.MappingRF{var"#7#8"}, UnitRange{Int64}}

julia> Base.return_types(Base._xfadjoint, Base.typesof(+, b))
1-element Vector{Any}:
 Tuple{Base.MappingRF{var"#7#8", Base.MappingRF{var"#9#10", typeof(+)}}, UnitRange{Int64}}

Failure starts with c (three nested layers), which matches our inference constraint on recursion.

2 Likes

I think this modified version would make our compiler much happier:

# Split `itr` into its innermost iterator and a function that wraps a reducing
# function with every layer found by `_xfadjoint_unwrap`, then return the
# wrapped `op` together with that innermost iterator.
function _xfadjoint(op, itr)
    itr′, wrap = _xfadjoint_unwrap(itr)
    return wrap(op), itr′
end

# Fallback for iterators with no recognized wrapper layer: the iterator is
# returned as-is, paired with `identity` as the (no-op) wrapping function.
function _xfadjoint_unwrap(itr)
    return (itr, identity)
end
# Unwrap a `Generator` layer. The wrapped iterator `itr.iter` is unwrapped
# first; a layer whose function is `identity` contributes nothing, and any
# other layer composes `Fix1(MappingRF, itr.f)` onto the wrapper built for the
# inner layers.
function _xfadjoint_unwrap(itr::Generator)
    itr′, wrap = _xfadjoint_unwrap(itr.iter)
    itr.f === identity && return (itr′, wrap)
    return (itr′, wrap ∘ Fix1(MappingRF, itr.f))
end
# Unwrap a `Filter` layer: unwrap the wrapped iterator `itr.itr`, then compose
# `Fix1(FilteringRF, itr.flt)` onto the wrapper built for the inner layers.
function _xfadjoint_unwrap(itr::Filter)
    itr′, wrap = _xfadjoint_unwrap(itr.itr)
    return (itr′, wrap ∘ Fix1(FilteringRF, itr.flt))
end
# Unwrap a `Flatten` layer: unwrap the wrapped iterator `itr.it`, then compose
# `FlatteningRF` onto the wrapper built for the inner layers.
function _xfadjoint_unwrap(itr::Flatten)
    itr′, wrap = _xfadjoint_unwrap(itr.it)
    return (itr′, wrap ∘ FlatteningRF)
end

But first we need to fix "ComposedFunction type inference regression from v1.6 LTS" (Issue #45715 · JuliaLang/julia · GitHub) and make Base.Fix1(f, Int) type-stable.

2 Likes