Hi everyone,
I’m currently implementing JAX’s function transformations in Julia. In my code I dispatch on the number of the inputs to ensure that the length of slices_iterator
is known at compile time. However this appears not to be the case. When I @code_warntype
my function, I get an ::Any
return.
I understand this instability comes from the ...
splat. Can anyone suggest how to work around this? I was trying to think of places to add function barriers but given that my function is so short I’m not sure how to approach it.
function vmap(f, in_axes::NTuple{N,Int}) where {N}
function vmapped(args::Vararg{T,N}) where {T}
slices_iterator = zip((eachslice(arg, dims=axis) for (arg, axis) in zip(args, in_axes))...)
return map(x -> f(x...), slices_iterator)
end
return vmapped
end
begin
f(a, b) = reduce(*, a .+ b)
in_axes = (2, 1)
g = vmap(f, in_axes)
A = Matrix(reshape(1:9, 3, 3))
B = Matrix(reshape(9 .+ (1:9), 3, 3))
end
@code_warntype g(A, B)
Output:
MethodInstance for (::var"#vmapped#317"{2, typeof(f), Tuple{Int64, Int64}})(::Matrix{Int64}, ::Matrix{Int64})
from (::var"#vmapped#317"{N})(args::Vararg{T, N}) where {T, N} in Main at /Users/smit/.julia/dev/JAXTransformations/src/vmap.jl:26
Static Parameters
T = Matrix{Int64}
N = 2
Arguments
#self#::var"#vmapped#317"{2, typeof(f), Tuple{Int64, Int64}}
args::Tuple{Matrix{Int64}, Matrix{Int64}}
Locals
#316::var"#316#319"{typeof(f)}
#315::var"#315#318"
slices_iterator::Base.Iterators.Zip
Body::Any
1 ─ (#315 = %new(Main.:(var"#315#318")))
│ %2 = #315::Core.Const(var"#315#318"())
│ %3 = Core.getfield(#self#, :in_axes)::Tuple{Int64, Int64}
│ %4 = Main.zip(args, %3)::Base.Iterators.Zip{Tuple{Tuple{Matrix{Int64}, Matrix{Int64}}, Tuple{Int64, Int64}}}
│ %5 = Base.Generator(%2, %4)::Base.Generator{Base.Iterators.Zip{Tuple{Tuple{Matrix{Int64}, Matrix{Int64}}, Tuple{Int64, Int64}}}, var"#315#318"}
│ (slices_iterator = Core._apply_iterate(Base.iterate, Main.zip, %5))
│ %7 = Main.:(var"#316#319")::Core.Const(var"#316#319")
│ %8 = Core.getfield(#self#, :f)::Core.Const(f)
│ %9 = Core.typeof(%8)::Core.Const(typeof(f))
│ %10 = Core.apply_type(%7, %9)::Core.Const(var"#316#319"{typeof(f)})
│ %11 = Core.getfield(#self#, :f)::Core.Const(f)
│ (#316 = %new(%10, %11))
│ %13 = #316::Core.Const(var"#316#319"{typeof(f)}(f))
│ %14 = Main.map(%13, slices_iterator)::Any
└── return %14