Expand a Tuple at given positions, type stable dims

question

#1

I would like to make a PR extending #29749 by generalizing eachslice to generic dimensions.

I have two questions:

  1. What’s the recommended API for specifying dimensions in a type-stable way? E.g., should I use a Tuple of Val indices, as in dims = (Val(1), Val(3))?

  2. Is there a good way to insert Colon() into a CartesianIndex at positions known at compile time (assuming there is an answer to the question above)?

For (2), I came up with a recursive solution (using tuples of numbers, assuming the insertion positions are ordered, and with no error checking):


@inline _insert_colons(before, after) = (before..., after...) # done
@inline function _insert_colons(before, after, ::Val{1}, rest...)
    _insert_colons((before..., Colon()), after, rest...)
end
@inline _decr_index(::Val{T}) where T = Val{T-1}()
@inline function _insert_colons(before, after, rest...)
    _insert_colons((before..., first(after)), Base.tail(after), map(_decr_index, rest)...)
end

insert_colons(indexes, positions) = _insert_colons((), indexes, positions...)

@code_warntype insert_colons((1, 2, 3), (Val{1}(), Val{2}()))

but it compiles into a PartialStruct.


#2

Nice! I’m more than happy to provide a bit of feedback to help make this happen.

For point 1, the trick is to lean on constant propagation instead of explicit Val usage. The key test cases, then, will be introspecting functions that call eachslice with a constant literal tuple. For example:

julia> f(;dims=()) = dims[1]+dims[2]+dims[3]
f (generic function with 1 method)

julia> g() = f(dims=(1,2,3))
g (generic function with 1 method)

julia> @code_llvm g()

;  @ REPL[5]:1 within `g'
define i64 @julia_g_12214() {
top:
  ret i64 6
}

For this to work, you need to ensure that f inlines up to the point where you use dims. Now, combining this with a more complicated algorithm (point 2) is going to be tricky. Encoding these literal constants into a typed tuple as quickly as possible is probably your best bet. Unfortunately, the naive approach with ntuple doesn’t quite do the trick:

julia> f(A;dims) = ntuple(i->i in dims ? (:) : 1, ndims(A))
f (generic function with 1 method)

julia> g(A) = f(A, dims=(2,3))
g (generic function with 1 method)

julia> @code_warntype g(rand(3,4,5))
Body::Tuple{Union{Colon, Int64},Colon,Union{Colon, Int64}}

It’s rather amazing that Julia can concretely identify the second element in that tuple, but it can’t quite follow all of the checks through the whole tuple. Even more amazing is that this all happens with the totally generic in(x, itr) definition. OK, so perhaps that’s the key: let’s try adding some specialized in methods for small tuples:

julia> Base.in(x, ::Tuple{}) = false
       Base.in(x, t::Tuple{Any}) = x == t[1]
       Base.in(x, t::NTuple{2,Any}) = x == t[1] || x == t[2]
       Base.in(x, t::NTuple{3,Any}) = x == t[1] || x == t[2] || x == t[3]
       Base.in(x, t::NTuple{4,Any}) = x == t[1] || x in Base.tail(t)

julia> @code_warntype g(rand(3,4,5))
Body::Tuple{Int64,Colon,Colon}
# ...
11 ─      return (1, Colon(), Colon())
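Adding methods to Base.in for Base’s own Tuple type from outside Base is type piracy, so in package code a local helper may be preferable. The following is a minimal sketch under that assumption; tuple_in and slice_template are hypothetical names, not Base functions. tuple_in peels one element off the tuple per call, mirroring the specialized in methods above. Whether the template is inferred concretely at a given call site still depends on inlining and constant propagation, as discussed above.

```julia
using Base: tail  # tail is not exported from Base

# Hypothetical helper: membership test that drops one tuple element per call,
# so each recursive step sees a strictly shorter tuple type.
tuple_in(x, ::Tuple{}) = false
tuple_in(x, t::Tuple) = x == first(t) || tuple_in(x, tail(t))

# Hypothetical template builder: Colon() at every position listed in dims, 1 elsewhere.
slice_template(A::AbstractArray, dims::Tuple) =
    ntuple(i -> tuple_in(i, dims) ? Colon() : 1, ndims(A))

slice_template(zeros(3, 4, 5), (2, 3))  # (1, Colon(), Colon())
```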

Now you can encode your constant literals into a Tuple of either Ints or Colons… and then use that for your type-stable algorithm. To use this “template” tuple to replace an arbitrary tuple of indices, walk through them in sync:

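using Base: tail  # tail is not exported from Base, so import it (or write Base.tail)
# Walk the template and the index tuple in lockstep: an Int slot keeps the index, a Colon slot replaces it with :
replace_colons(template::Tuple{}, idxs::Tuple{}) = ()
replace_colons(template::Tuple{Int, Vararg{Any}}, idxs::Tuple{Any, Vararg{Any}}) = (idxs[1], replace_colons(tail(template), tail(idxs))...)
replace_colons(template::Tuple{Colon, Vararg{Any}}, idxs::Tuple{Any, Vararg{Any}}) = (:, replace_colons(tail(template), tail(idxs))...)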

#3

Thanks for your help. Shouldn’t there be a Val-based (or similar) fallback interface, though, that does not rely on constant propagation? Then the caller could use the dims = (1,2,3) interface and have the constants propagate, with the rest of the code working in type space.


#4

Why though? It’ll only be fast if folks use a literal constant like Val(2). If they’re doing something dynamic with Val(d), then it’ll be slow.

If we just get things working with constant literals in the first place, then it’ll be fast if folks just write the obvious constant literals. If they’re doing something dynamic, then it’ll be just as slow to construct the (type-unstable) template tuple as it would be to construct the (type-unstable) Vals.


#5

There’s one overarching rule in creating fast and type-inferable recursive tuple functions: each step in the recursion must simplify the types of the arguments. A great way to simplify tuple arguments is to follow the idiom f(t) = (op(t[1]), f(tail(t))...). Something like _decr_index doesn’t simplify the arguments in a way that the type system can reason about: Val(1) is exactly as complex as Val(0). That’s why a template tuple is nice here: you can step through it recursively, simplifying the arguments at each step, and the whole thing stays type-stable.
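To make the idiom concrete, here is a minimal sketch; tuple_map is a hypothetical name used only for illustration (Base’s own map already works on tuples). The Tuple{} method is the base case, and every recursive call receives a tuple type that is one element shorter, which is the kind of simplification the compiler can follow.

```julia
using Base: tail  # tail is not exported from Base

# Hypothetical example of f(t) = (op(t[1]), f(tail(t))...):
# the argument tuple type shrinks by one element per call until it hits Tuple{}.
tuple_map(op, ::Tuple{}) = ()
tuple_map(op, t::Tuple) = (op(t[1]), tuple_map(op, tail(t))...)

tuple_map(x -> 2x, (1, 2.5, 3))  # (2, 5.0, 6)
```

By contrast, recursing on Val(N) into Val(N-1) leaves the argument types just as complex at each step, which is the problem with _decr_index.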


#6

Thanks, that makes sense. I will make a PR over the weekend and link it here.


#7

I must be missing something, because the way I think about this is making a subset of axes, walking that with a CartesianIndices, and putting back the Colon()s for view. If I do replacement like you suggested, I will get the same views multiple times.

E.g., for size(A) == (2, 3) and dims = (1,), I would want view(A, :, 1), view(A, :, 2), view(A, :, 3). But replacing the colons in the elements of CartesianIndices(axes(A)) still gives me 6 views (each column twice). What am I missing from your suggestion?


#8

Yeah, it’s a little tricky, especially since the thing you’re building is itself an iterator. You’ll want to build the CartesianIndices of only the dimensions not in dims and then, instead of replacing certain dimensions with a :, insert the colon at the appropriate place.
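A minimal sketch of that approach, assuming the convention from earlier in this thread that dims lists the dimensions kept whole in each slice (so a 2×3 A with dims = (1,) yields its 3 columns). slice_template, kept_axes, splice_colons and slices are hypothetical helpers, not Base API. The template decides which axes go into the CartesianIndices and where splice_colons puts the colons back; every recursive call shortens the template, in line with the rule about simplifying arguments.

```julia
using Base: tail  # tail is not exported from Base

# Hypothetical template builder: Colon() for dims kept whole in each slice, 1 for the rest.
slice_template(A::AbstractArray, dims::Tuple) = ntuple(i -> i in dims ? Colon() : 1, ndims(A))

# Keep only the axes whose template slot is an Int (the axes we iterate over).
kept_axes(::Tuple{}, ::Tuple{}) = ()
kept_axes(t::Tuple{Colon, Vararg{Any}}, axs::Tuple) = kept_axes(tail(t), tail(axs))
kept_axes(t::Tuple{Int, Vararg{Any}}, axs::Tuple) = (axs[1], kept_axes(tail(t), tail(axs))...)

# Insert colons: a Colon() slot emits Colon() without consuming an index,
# an Int slot consumes the next index from idxs.
splice_colons(::Tuple{}, ::Tuple{}) = ()
splice_colons(t::Tuple{Colon, Vararg{Any}}, idxs::Tuple) = (Colon(), splice_colons(tail(t), idxs)...)
splice_colons(t::Tuple{Int, Vararg{Any}}, idxs::Tuple) = (idxs[1], splice_colons(tail(t), tail(idxs))...)

# Hypothetical slice iterator: loop only over the non-sliced axes, then splice the colons back in.
function slices(A::AbstractArray, dims::Tuple)
    t = slice_template(A, dims)
    return (view(A, splice_colons(t, Tuple(I))...) for I in CartesianIndices(kept_axes(t, axes(A))))
end

A = reshape(1:6, 2, 3)
collect(slices(A, (1,)))  # 3 views: A[:, 1], A[:, 2], A[:, 3]
```

The Int entries in the template only mark "iterate over this axis"; their value is never used.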


#9

This would be great to have. Maybe this is obvious, but a pattern I found useful for such things is this: collect(Iterators.product(axes(B,1), Ref(:), axes(B,3))). No CartesianIndices, but I’m not sure how you would feed those to view() anyway.
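A small sketch of that pattern; the array B = reshape(1:24, 2, 3, 4) is only illustrative data. Ref(:) iterates exactly once, yielding Colon(), so every element of the product is an ordinary Tuple that already contains the colon and can be splatted straight into view:

```julia
B = reshape(1:24, 2, 3, 4)

# One index tuple per slice; the middle position is always Colon().
idxs = collect(Iterators.product(axes(B, 1), Ref(:), axes(B, 3)))

# Each element is a plain Tuple such as (1, Colon(), 1), so splat it into view.
slice_views = [view(B, I...) for I in idxs]

slice_views[1]  # the entries of B[1, :, 1], i.e. 1, 3, 5
```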