Hi everyone,
I’m currently implementing JAX’s function transformations in Julia. In my code I dispatch on the number of the inputs to ensure that the length of slices_iterator is known at compile time. However this appears not to be the case. When I @code_warntype my function, I get an ::Any return.
I understand this instability comes from the ... splat. Can anyone suggest how to work around this? I was trying to think of places to add function barriers but given that my function is so short I’m not sure how to approach it.
function vmap(f, in_axes::NTuple{N,Int}) where {N}
    function vmapped(args::Vararg{T,N}) where {T}
        slices_iterator = zip((eachslice(arg, dims=axis) for (arg, axis) in zip(args, in_axes))...)
        return map(x -> f(x...), slices_iterator)
    end
    return vmapped
end
begin
    f(a, b) = reduce(*, a .+ b)
    in_axes = (2, 1)
    g = vmap(f, in_axes)
    A = Matrix(reshape(1:9, 3, 3))
    B = Matrix(reshape(9 .+ (1:9), 3, 3))
end
@code_warntype g(A, B)
Output:
MethodInstance for (::var"#vmapped#317"{2, typeof(f), Tuple{Int64, Int64}})(::Matrix{Int64}, ::Matrix{Int64})
  from (::var"#vmapped#317"{N})(args::Vararg{T, N}) where {T, N} in Main at /Users/smit/.julia/dev/JAXTransformations/src/vmap.jl:26
Static Parameters
  T = Matrix{Int64}
  N = 2
Arguments
  #self#::var"#vmapped#317"{2, typeof(f), Tuple{Int64, Int64}}
  args::Tuple{Matrix{Int64}, Matrix{Int64}}
Locals
  #316::var"#316#319"{typeof(f)}
  #315::var"#315#318"
  slices_iterator::Base.Iterators.Zip
Body::Any
1 ─       (#315 = %new(Main.:(var"#315#318")))
│   %2  = #315::Core.Const(var"#315#318"())
│   %3  = Core.getfield(#self#, :in_axes)::Tuple{Int64, Int64}
│   %4  = Main.zip(args, %3)::Base.Iterators.Zip{Tuple{Tuple{Matrix{Int64}, Matrix{Int64}}, Tuple{Int64, Int64}}}
│   %5  = Base.Generator(%2, %4)::Base.Generator{Base.Iterators.Zip{Tuple{Tuple{Matrix{Int64}, Matrix{Int64}}, Tuple{Int64, Int64}}}, var"#315#318"}
│         (slices_iterator = Core._apply_iterate(Base.iterate, Main.zip, %5))
│   %7  = Main.:(var"#316#319")::Core.Const(var"#316#319")
│   %8  = Core.getfield(#self#, :f)::Core.Const(f)
│   %9  = Core.typeof(%8)::Core.Const(typeof(f))
│   %10 = Core.apply_type(%7, %9)::Core.Const(var"#316#319"{typeof(f)})
│   %11 = Core.getfield(#self#, :f)::Core.Const(f)
│         (#316 = %new(%10, %11))
│   %13 = #316::Core.Const(var"#316#319"{typeof(f)}(f))
│   %14 = Main.map(%13, slices_iterator)::Any
└──       return %14
