Why can't the compiler infer the size/type of my Zip?

Hi everyone,

I’m currently implementing JAX’s function transformations in Julia. In my code I dispatch on the number of inputs to ensure that the length of slices_iterator is known at compile time. However, this does not appear to be the case: when I run @code_warntype on my function, I get an ::Any return type.

I understand this instability comes from the ... splat. Can anyone suggest how to work around this? I was trying to think of places to add function barriers but given that my function is so short I’m not sure how to approach it.

function vmap(f, in_axes::NTuple{N,Int}) where {N}
    function vmapped(args::Vararg{T,N}) where {T}
        slices_iterator = zip((eachslice(arg, dims=axis) for (arg, axis) in zip(args, in_axes))...)
        return map(x -> f(x...), slices_iterator)
    end
    return vmapped
end

begin
    f(a, b) = reduce(*, a .+ b)
    in_axes = (2, 1)
    g = vmap(f, in_axes)

    A = Matrix(reshape(1:9, 3, 3))
    B = Matrix(reshape(9 .+ (1:9), 3, 3))
end

@code_warntype g(A, B)

Output:

MethodInstance for (::var"#vmapped#317"{2, typeof(f), Tuple{Int64, Int64}})(::Matrix{Int64}, ::Matrix{Int64})
  from (::var"#vmapped#317"{N})(args::Vararg{T, N}) where {T, N} in Main at /Users/smit/.julia/dev/JAXTransformations/src/vmap.jl:26
Static Parameters
  T = Matrix{Int64}
  N = 2
Arguments
  #self#::var"#vmapped#317"{2, typeof(f), Tuple{Int64, Int64}}
  args::Tuple{Matrix{Int64}, Matrix{Int64}}
Locals
  #316::var"#316#319"{typeof(f)}
  #315::var"#315#318"
  slices_iterator::Base.Iterators.Zip
Body::Any
1 ─       (#315 = %new(Main.:(var"#315#318")))
│   %2  = #315::Core.Const(var"#315#318"())
│   %3  = Core.getfield(#self#, :in_axes)::Tuple{Int64, Int64}
│   %4  = Main.zip(args, %3)::Base.Iterators.Zip{Tuple{Tuple{Matrix{Int64}, Matrix{Int64}}, Tuple{Int64, Int64}}}
│   %5  = Base.Generator(%2, %4)::Base.Generator{Base.Iterators.Zip{Tuple{Tuple{Matrix{Int64}, Matrix{Int64}}, Tuple{Int64, Int64}}}, var"#315#318"}
│         (slices_iterator = Core._apply_iterate(Base.iterate, Main.zip, %5))
│   %7  = Main.:(var"#316#319")::Core.Const(var"#316#319")
│   %8  = Core.getfield(#self#, :f)::Core.Const(f)
│   %9  = Core.typeof(%8)::Core.Const(typeof(f))
│   %10 = Core.apply_type(%7, %9)::Core.Const(var"#316#319"{typeof(f)})
│   %11 = Core.getfield(#self#, :f)::Core.Const(f)
│         (#316 = %new(%10, %11))
│   %13 = #316::Core.Const(var"#316#319"{typeof(f)}(f))
│   %14 = Main.map(%13, slices_iterator)::Any
└──       return %14

Please don’t… What actual problem do you have that you think needs a vmap, and that only vmap can solve?

Btw, fundamentally, this won’t work because you’re not forcing specialization on f, and also because the number of arguments is not a property of typeof(f): we don’t have any parametric type information regarding methods.

The syntax of using vmap is really nice :slight_smile:

I don’t really understand what you mean by this. Could you explain more?

Let’s say we define a function foo as follows:

julia> foo(x) = x == sin ? 1 : 2
foo (generic function with 1 method)

julia> foo(sin)
1

julia> foo(5)
2

julia> ms = methods(foo).ms
[1] foo(x) in Main at REPL[10]:1

julia> ms[1].specializations
svec(MethodInstance for foo(::Function), MethodInstance for foo(::Int64), nothing, nothing, nothing, nothing, nothing, nothing)

We see that foo is only specialized for a Function and for an Int64. What it is not specialized for is typeof(sin).

We can specialize for a specific function by adding a type parameter.

julia> foo(::F) where F <: Function = 3
foo (generic function with 2 methods)

julia> foo(sin)
3

julia> ms = methods(foo).ms
[1] foo(::F) where F<:Function in Main at REPL[25]:1
[2] foo(x) in Main at REPL[10]:1

julia> ms[1].specializations
svec(MethodInstance for foo(::typeof(sin)), nothing, nothing, nothing, nothing, nothing, nothing, nothing)

Specialization means that Julia creates a compiled version of the function, a method instance, specifically for a certain type of argument.

In Julia each named function has its own type, but generally other functions are not specialized on that specific type.
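To make the point above concrete, here is a minimal sketch with two hypothetical higher-order functions (passthrough and specialized are illustrative names, not from the thread). It relies on the documented heuristic that Julia may skip specializing on a Function argument that is only passed along to another call, and that a type parameter forces specialization.

```julia
# Hypothetical examples, not from the thread.

# `f` is only passed through to `map`, never called here, so Julia may
# compile a single instance for `f::Function`.
passthrough(f, xs) = map(f, xs)

# The type parameter F forces a separate compiled instance for each
# concrete function type, e.g. typeof(sin).
specialized(f::F, xs) where {F} = map(f, xs)

# Both versions compute the same result; only the compilation strategy differs.
@assert passthrough(sin, [0.0]) == specialized(sin, [0.0]) == [0.0]

# Each named function has its own concrete type under the abstract `Function`.
@assert typeof(sin) != typeof(cos)
@assert sin isa Function && cos isa Function
```

The same `f::F ... where {F}` pattern could be applied to the outer vmap so that the returned closure is compiled for the specific f it wraps.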


The issue is that eachslice is not type stable. For example, typeof(eachslice(rand(3, 3), dims=1)) != typeof(eachslice(rand(3, 3), dims=2)), even though the argument types are the same. This means that even when the compiler knows the types of f and in_axes, it has no way to figure out the types of the eachslice iterators without knowing the values of in_axes. This is not your fault, but we can hack a way around it by telling the compiler those values using the Val type.
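A small check of the claim above, as a sketch: the return type of eachslice depends on the value of dims, whereas wrapping that value in Val moves it into the type, which is all the compiler gets to see.

```julia
A = rand(3, 3)

# Same argument types, different `dims` values, different return types.
rows = eachslice(A; dims=1)
cols = eachslice(A; dims=2)
@assert typeof(A) == Matrix{Float64}
@assert typeof(rows) != typeof(cols)

# A plain Int carries no information about its value in its type...
@assert typeof(1) == typeof(2)

# ...whereas Val lifts the value into a type parameter.
@assert typeof(Val(1)) == Val{1}
@assert typeof(Val(1)) != typeof(Val(2))
```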

To minimally change your code and remove the type instability:

# New function
stable_eachslice(x; dims::Val{N}) where N = eachslice(x; dims=N)

function vmap(f, in_axes) # remove type annotation
    function vmapped(args::Vararg{T,N}) where {T,N}
        slices_iterator = zip((stable_eachslice(arg, dims=axis) for (arg, axis) in zip(args, in_axes))...)
        return map(x -> f(x...), slices_iterator)
    end
    return vmapped
end

# Convert Ints to Vals to move the type instability into the outer function
vmap(f, in_axes::NTuple{N,Int}) where {N} = vmap(f, Val.(in_axes))

begin
    f(a, b) = reduce(*, a .+ b)
    in_axes = (2, 1)
    g = vmap(f, in_axes)

    A = Matrix(reshape(1:9, 3, 3))
    B = Matrix(reshape(9 .+ (1:9), 3, 3))
end

@code_warntype g(A, B)

To also make some style changes:

stable_eachslice(x; dims::Val{N}) where N = eachslice(x; dims=N)

function vmap(f, in_axes)
    function vmapped(args...)
        map(f, (stable_eachslice(arg, dims=axis) for (arg, axis) in zip(args, in_axes))...)
    end
end

begin
    f(a, b) = reduce(*, a .+ b)
    in_axes = (Val(2), Val(1))
    g = vmap(f, in_axes)

    A = Matrix(reshape(1:9, 3, 3))
    B = Matrix(reshape(9 .+ (1:9), 3, 3))
end

@code_warntype g(A, B)

Note that this works in 1.9 thanks to @simonbyrne’s https://github.com/JuliaLang/julia/pull/32310 but not 1.8. In 1.8 and earlier, eachslice was even worse.
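For comparison, here is a variant sketch that avoids splatting a generator altogether by mapping over the argument and axis tuples, so the splat always has a statically known length. The names slices and vmap_sketch are hypothetical, and the helper takes Val as a positional argument rather than a keyword. This is only one possible arrangement, not the approach from the posts above.

```julia
# Hypothetical helper: dispatch extracts D from Val{D}, so `dims` is a
# compile-time constant inside the call.
slices(x, ::Val{D}) where {D} = eachslice(x; dims=D)

# Hypothetical vmap variant. `f::F` forces specialization on the wrapped function.
function vmap_sketch(f::F, in_axes::Tuple{Vararg{Val}}) where {F}
    # `map` over two tuples returns a tuple, so the splat has a known length.
    return (args...) -> map(f, map(slices, args, in_axes)...)
end

f(a, b) = reduce(*, a .+ b)
A = Matrix(reshape(1:9, 3, 3))
B = Matrix(reshape(9 .+ (1:9), 3, 3))

g = vmap_sketch(f, (Val(2), Val(1)))
result = g(A, B)

# Reference result: pair column i of A with row i of B explicitly.
expected = [f(view(A, :, i), view(B, i, :)) for i in 1:3]
@assert result == expected
```

Running @code_warntype g(A, B) on this version is a quick way to see whether the return type is inferred on a given Julia version.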


Thank you both for the explanations! I feel like I understand specialisation much better now :slight_smile: