Why can't the compiler infer the size/type of my Zip?

The issue is that `eachslice` is not type stable. For example, `typeof(eachslice(rand(3, 3), dims=1)) != typeof(eachslice(rand(3, 3), dims=2))`, even though the argument types are the same. This means that even when the compiler knows the types of `f` and `in_axes`, it has no way to determine the types of the `eachslice` iterators without knowing the *values* in `in_axes`. This is not your fault, but we can work around it by moving those values into the type domain with the `Val` type, so the compiler can see them.

To remove the type instability with minimal changes to your code:

"""
    stable_eachslice(x; dims::Val{N})

Call `eachslice(x; dims=N)` with the slice dimension passed as `Val(N)`.

Because `N` is then part of the argument *types*, the type of the returned iterator can
be inferred from the argument types alone. Plain `eachslice` needs the *value* of `dims`.
"""
function stable_eachslice(x; dims::Val{N}) where {N}
    return eachslice(x; dims = N)
end

"""
    vmap(f, in_axes)

Return a function `vmapped(args...)` that slices each `args[i]` along dimension
`in_axes[i]` (via `stable_eachslice`) and maps `f` over the corresponding slices.
Element `k` of the result is `f` applied to the `k`-th slice of every argument.

`in_axes` is expected to hold `Val`s, e.g. `(Val(2), Val(1))`, so that the slice types
are inferable from the argument types; plain integer tuples are converted by a separate
method. Like `zip`, iteration stops at the argument with the fewest slices.
"""
function vmap(f, in_axes) # no annotation on `in_axes`: it holds `Val`s, not `Int`s
    # `Vararg{Any,N}` rather than `Vararg{T,N} where T`: a single `T` would force every
    # argument to have the same concrete type (diagonal rule), so e.g. an `Int` matrix
    # and a `Float64` matrix could not be mapped together. The `N` parameter still asks
    # the compiler to specialize on the number of arguments.
    function vmapped(args::Vararg{Any,N}) where {N}
        slices_iterator = zip((stable_eachslice(arg, dims=axis) for (arg, axis) in zip(args, in_axes))...)
        return map(x -> f(x...), slices_iterator)
    end
    return vmapped
end

# Convert Ints to Vals, so the value-dependent (type-unstable) step happens once here in
# the outer call instead of inside the returned `vmapped` function.
# Restricted to *non-empty* tuples: `()` is also an `NTuple{0,Int}`, and `Val.(())`
# returns `()` again, so an unrestricted method would call itself forever
# (StackOverflowError).
vmap(f, in_axes::Tuple{Int,Vararg{Int}}) = vmap(f, Val.(in_axes))

begin
    # Example: `f` returns the product of the elementwise sums of its two arguments.
    f(a, b) = reduce(*, a .+ b)
    # Slice the first argument along dimension 2 (columns) and the second along
    # dimension 1 (rows); the integer-tuple method of `vmap` turns these into `Val`s.
    in_axes = (2, 1)
    g = vmap(f, in_axes)

    A = Matrix(reshape(1:9, 3, 3))
    B = Matrix(reshape(9 .+ (1:9), 3, 3))
end

# `@code_warntype` is defined in the InteractiveUtils stdlib, which is only loaded
# automatically in the interactive REPL; import it so the snippet also runs as a script.
# (Kept outside the `begin` block: macros are expanded before the block runs.)
using InteractiveUtils
@code_warntype g(A, B)

With some additional style changes:

# Like `eachslice`, but the slice dimension arrives as `Val(D)`, so it is known from the
# argument types and the returned iterator's type can be inferred.
function stable_eachslice(x; dims::Val{D}) where {D}
    return eachslice(x; dims = D)
end

# Build a function that slices each argument along its entry of `in_axes` (expected to
# be `Val`s) and maps `f` over the corresponding slices of all arguments.
function vmap(f, in_axes)
    function vmapped(args...)
        # One slice iterator per argument, paired with its axis; `zip` stops at the
        # shorter of `args` and `in_axes`.
        per_arg_slices = (stable_eachslice(arg, dims=axis) for (arg, axis) in zip(args, in_axes))
        return map(f, per_arg_slices...)
    end
    return vmapped
end

begin
    # Example: `f` returns the product of the elementwise sums of its two arguments.
    f(a, b) = reduce(*, a .+ b)
    # Pass the slice dimensions directly as `Val`s (columns of the first argument,
    # rows of the second), so they are part of the argument types.
    in_axes = (Val(2), Val(1))
    g = vmap(f, in_axes)

    A = Matrix(reshape(1:9, 3, 3))
    B = Matrix(reshape(9 .+ (1:9), 3, 3))
end

# `@code_warntype` is defined in the InteractiveUtils stdlib, which is only loaded
# automatically in the interactive REPL; import it so the snippet also runs as a script.
# (Kept outside the `begin` block: macros are expanded before the block runs.)
using InteractiveUtils
@code_warntype g(A, B)

Note that this works in Julia 1.9, thanks to @simonbyrne's PR https://github.com/JuliaLang/julia/pull/32310, but not in 1.8. In 1.8 and earlier, `eachslice` was even less type stable.

2 Likes