New syntax for slicing

I’d like to suggest new syntax for eachslice, by overloading getindex (similar to how : behaves). Using / (which looks like a slice) would indicate a dimension to iterate over, i.e.

array[/, :] == eachslice(array; dims=1)

In addition to convenience and clarity, this should help users write more type-stable code. Some packages (like DimensionalData and AxisArrays) allow arbitrary axes for indexing into an array. Typically, these axes are included in the type. However, this means eachslice can be type-unstable, since the types of the slices depend on the value of dims. Constant propagation fixes this in most, but not quite all, cases (as we found when testing DimensionalData). Users slicing with / instead would get guaranteed type stability, since the dimensions to slice across can be determined from the types of the arguments alone.

Would the core devs be interested in an implementation of this?
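As a rough illustration of the type-stability argument, here is a sketch that finds which positions hold / purely by dispatching on the types of the index arguments. The helpers slashdims and _slashdims are hypothetical names, not part of Base or any package. @inferred (from the Test standard library) only checks that the result type, including how many dimensions are sliced, is inferable; whether the downstream eachslice call is then fully type-stable still depends on the array type.

```julia
using Test  # for @inferred

# Hypothetical helper: positions of `/` in an index tuple such as (/, :).
# Each method dispatches on the type of the leading index, so the answer
# depends only on the argument types, never on runtime values.
slashdims(I::Tuple) = _slashdims(1, I...)

_slashdims(n::Int) = ()
_slashdims(n::Int, ::typeof(/), rest...) = (n, _slashdims(n + 1, rest...)...)
_slashdims(n::Int, ::Colon, rest...) = _slashdims(n + 1, rest...)

dims = @inferred slashdims((/, :))  # (1,), inferred as Tuple{Int}
x = reshape(1:6, 2, 3)
rows = eachslice(x; dims=dims)      # same as eachslice(x; dims=1) here
```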

This looks interesting. Could you put together a few examples comparing this with Einstein notation for different common formulas?

You might like StarSlice.jl, which should probably be updated to the new eachslice.

(The idea of *,: came, I believe, from early versions of JuliennedArrays.)

As an example, squaring a matrix could be written:

map(sum, x[:, /], x[/, :]) == @einsum y[i, k] := x[i, j] * x[j, k]

I prefer the einsum notation in this case, but it isn’t as general. The x[:, /] notation should have an advantage for operations that would be unnatural to write in Einstein notation. For example, calculating the within-group variance for a panel dataset could be:

var.(x[time=/])

Rather than:

var.(eachslice(x; dims=:time))

That’s pretty cool syntax.

@ParadaCarleton, the solution for type stability in DimensionalData.jl is to use:

eachslice(array; dims=X())

This is the same as for most other Base methods that accept a dims keyword.

This doesn’t necessarily need new syntax: / is already a function (an object, after all), so it would “just” need an overload of getindex. It might be a bit brittle, though, if someone changes what / means in their package instead of relying on the one exported from Base.
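A minimal sketch of the “just overload getindex” idea: since / is an ordinary function object, getindex can dispatch on typeof(/). To avoid type piracy on Base’s own array types, this wraps the array in a hypothetical SliceView type (not from any package) and only covers the two matrix cases.

```julia
# Hypothetical wrapper, so we don't pirate getindex for Base's Array types.
struct SliceView{A<:AbstractMatrix}
    parent::A
end

# `/` marks the dimension to iterate over; `:` keeps that dimension whole.
Base.getindex(s::SliceView, ::typeof(/), ::Colon) = eachslice(s.parent; dims=1)  # rows
Base.getindex(s::SliceView, ::Colon, ::typeof(/)) = eachslice(s.parent; dims=2)  # columns

x = [1 2; 3 4]
s = SliceView(x)
rowsums = map(sum, s[/, :])  # one sum per row
colsums = map(sum, s[:, /])  # one sum per column
```

A real implementation would also need the N-dimensional cases and the keyword form (x[time=/]) for named dimensions.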

I don’t think this works: there’s no method of sum taking two arrays, and the final container is 1-dimensional. Perhaps you meant either

using LinearAlgebra, Einsum  # dot and diag from LinearAlgebra, @einsum from Einsum.jl
slice(x) = broadcast(dot, eachslice(x; dims=2, drop=false), eachslice(x; dims=1, drop=false))
einsum(x) = @einsum y[i, k] := x[i, j] * x[j, k]
base(x) = x * x

or

slice2(x) = map(dot, eachslice(x; dims=2), eachslice(x; dims=1))
einsum2(x) = @einsum y[i] := x[i, j] * x[j, i]
base2(x) = vec(sum(x .* x'; dims=1))  # == diag(x * x)

See the package implementation above.
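A quick self-check of the two slice-based versions against Base, using only the LinearAlgebra standard library (the @einsum versions need Einsum.jl and are left out). It assumes Julia 1.9 or later, where eachslice accepts drop=false, and a real-valued x, since dot conjugates its first argument.

```julia
using LinearAlgebra  # for dot and diag

# Columns in a 1×n container, rows in an n×1 container: broadcasting gives
# result[i, k] = dot(column k, row i), which is (x * x)[i, k] for real x.
slice(x) = broadcast(dot, eachslice(x; dims=2, drop=false), eachslice(x; dims=1, drop=false))

# Pairing column i with row i gives only the diagonal of x * x.
slice2(x) = map(dot, eachslice(x; dims=2), eachslice(x; dims=1))

x = rand(4, 4)
@assert slice(x) ≈ x * x
@assert slice2(x) ≈ diag(x * x)
@assert slice2(x) ≈ vec(sum(x .* x'; dims=1))
```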

A potential alternative, eachslice(X; dims=Val(:lat)), is simple and type-stable.

Simple enough, but quite a bit longer than x[lat=/].

Yep, sorry for the error. So a possible notation for matrix multiplication using this new method would be (assuming the slices keep their orientation, as with drop=false):

dot.(x[:, /], x[/, :])

The goal of this isn’t to replace Einsum, though; I think Einstein notation is cleaner for tensor operations where it applies naturally. The main use case for / is wherever eachslice or mapslices is being used now, like summary statistics over slices.

Maybe worth noting that var has a method which takes dims. And this is much faster than working on individual rows, I think because it can pick a cache-friendly order in which to access the elements:

julia> using BenchmarkTools, Statistics

julia> let x = rand(100, 100)
         a = @btime var.(eachslice($x; dims=1))  # i.e. eachrow
         b = @btime var($x; dims=2)
         a ≈ vec(b)
       end
  16.250 μs (1 allocation: 896 bytes)
  3.833 μs (16 allocations: 2.44 KiB)
true

If you change dims so that this is eachcol, then the two are comparably fast. But if the goal is to handle arrays by dimension name, where you shouldn’t have to care which dimension comes first, then it seems you ought to worry about both (or all) cases.
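One way to see the memory-layout point: Julia arrays are column-major, so a column view walks memory one element at a time, while a row view jumps by size(x, 1) on every step. This sketch uses only Base and the Statistics standard library; it checks strides rather than timings, and confirms that the eachcol-style slices agree with var(x; dims=1).

```julia
using Statistics

x = rand(100, 100)

# A column view is contiguous in memory...
col = view(x, :, 1)
# ...while consecutive elements of a row view are size(x, 1) apart.
row = view(x, 1, :)
@assert strides(col) == (1,)
@assert strides(row) == (size(x, 1),)

# The eachcol analogue of the comparison: one variance per column.
@assert var.(eachslice(x; dims=2)) ≈ vec(var(x; dims=1))
```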

Using var with dims is a good solution for performance at the moment, but as mentioned elsewhere, having to implement dims keywords by hand for all of these reducing functions is both limiting and an ugly design. It should be possible to implement the optimizations we currently do case by case for sum, var, etc. in a more general way, but it’s not immediately clear to me how.

I think the big optimization here is that these are all reducing functions, so you can write the same operation either as map(f, eachslice(x; dims)) or as reduce(f, x; dims=Not(dims))*, where Not(dims) stands for all the other dimensions. Sometimes memory locality means that rewriting it this way is a huge improvement.

A while back, I thought it would be really nice if traits were added to the language, and could be added to functions as well as structs. If that were the case, marking reducing functions like sum and var with a Reduces trait would let us take advantage of this reorder-for-memory-locality optimization by dispatching on the function trait.

*this doesn’t work with var itself, but variance can be rewritten as a reducing function, which is what I assume the implementation with the dims keyword is doing internally.
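Julia can already approximate function traits with the “Holy traits” pattern, dispatching on typeof(f). The sketch below is hypothetical (ReduceStyle, HasDims, NoDims, reducestyle, otherdims and slicemap are made-up names): for functions opted in as having a dims method, it rewrites map(f, eachslice(x; dims)) as a single f(x; dims=...) call over the other dimensions, and otherwise falls back to mapping over slices. It assumes Julia 1.9 or later (for tuple dims in eachslice) and dims given in increasing order.

```julia
using Statistics  # for var and median

# Trait values: does f have a fast `dims` method?
abstract type ReduceStyle end
struct HasDims <: ReduceStyle end
struct NoDims <: ReduceStyle end

# Default is NoDims; specific functions opt in.
reducestyle(::Any) = NoDims()
reducestyle(::typeof(sum)) = HasDims()
reducestyle(::typeof(var)) = HasDims()

# The dimensions *not* being sliced over, i.e. the ones f reduces.
otherdims(N::Int, dims::Tuple) = Tuple(d for d in 1:N if d ∉ dims)

# Same result as map(f, eachslice(x; dims)), but dispatches on the trait.
slicemap(f, x::AbstractArray; dims::Tuple) = slicemap(reducestyle(f), f, x, dims)
slicemap(::NoDims, f, x, dims) = map(f, eachslice(x; dims=dims))
function slicemap(::HasDims, f, x, dims)
    issorted(dims) || throw(ArgumentError("this sketch assumes dims in increasing order"))
    rest = otherdims(ndims(x), dims)
    # Reduce over the remaining dimensions in one call, then drop them.
    return dropdims(f(x; dims=rest); dims=rest)
end

x = rand(3, 4, 5)
@assert slicemap(sum, x; dims=(1, 3)) ≈ map(sum, eachslice(x; dims=(1, 3)))
```

The opt-in list is still maintained by hand, so this moves the case-by-case work into trait declarations rather than removing it.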

It leads to centralize_sumabs2!, which looks a lot like a re-implementation of reduce. I guess it’s not literally reduce as it has to index the array of pre-computed means at the same time as the input & output arrays. Are you suggesting a tidier way to implement that?

The case without dims is literally mapreduce(centralizedabs2fun(m), +, A) (just above). And this is what’s called by any map over slices, of course. It pre-computes the mean, whereas OnlineStats.jl does it in one pass but seems to be slower.
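For comparison, a plain-Julia sketch of the two strategies: a two-pass variance that precomputes the mean and then does one mapreduce (the same shape as the no-dims path described above, but not Statistics’ actual code), and a one-pass update using Welford’s algorithm, in the spirit of OnlineStats.jl’s streaming approach (not its implementation). Both assume real-valued data and return the corrected (n - 1) variance.

```julia
using Statistics

# Two passes: compute the mean first, then sum the squared deviations.
function var_twopass(A)
    m = mean(A)
    return mapreduce(a -> abs2(a - m), +, A) / (length(A) - 1)
end

# One pass (Welford): update the running mean and the running sum of
# squared deviations together, element by element.
function var_welford(A)
    n = 0
    μ = 0.0
    M2 = 0.0
    for a in A
        n += 1
        δ = a - μ
        μ += δ / n
        M2 += δ * (a - μ)
    end
    return M2 / (n - 1)
end

A = randn(1000)
@assert var_twopass(A) ≈ var(A)
@assert var_welford(A) ≈ var(A)
```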

I’m not sure we’re still talking about a notation for slicing, though, so much as about how best to (a) implement various reduction-like operations efficiently and (b) call them.

For (a), it seems clear that the implementation cannot always just work on slices. For (b), one objection to an API that pretends you are working with slices, while actually doing something different, is that this seems a little fragile. It’s a magic fast path, a little like the one for reduce(vcat, xs), which has surprising ways to step off the fast path accidentally. With the “ugly design” of having a dims method if and only if there exists a faster-than-slices implementation, at least you know what you’re getting.

I assume that’s the implicit rule in Base and the standard libraries, but I might be wrong; are there examples of functions with dims methods that just always call mapslices or something similar? (It seems that median does call mapslices, but uses mutation to do a bit better than mapslices(median, x; dims).)
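A small check of the median point using only the Statistics standard library: the dims method, mapslices, and a map over eachslice all give the same per-row medians. The performance difference described above is not measured here.

```julia
using Statistics

x = rand(5, 7)

# Per-row medians three ways.
a = median(x; dims=2)                  # 5×1 matrix, via the dims method
b = mapslices(median, x; dims=2)       # 5×1 matrix
c = map(median, eachslice(x; dims=1))  # 5-element vector

@assert a ≈ b
@assert vec(a) ≈ c
```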

Quite a few do share the same implementation as sum. Maybe it would help to list others you include in “etc”?

sum, prod, all, any, maximum, minimum, median, mode, and every other kind of summary statistic.

You could definitely implement it to calculate both means and variances in one pass, and then reduce using mean(var)+var(mean), but that’s getting off-topic for sure.
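The combination step is the law of total variance. With equal-sized slices and the uncorrected (population) variance, the overall variance equals the mean of the per-slice variances plus the variance of the per-slice means. The equal sizes and corrected=false are assumptions needed for this exact identity; the sketch uses only Statistics.

```julia
using Statistics

x = rand(6, 4)  # 4 equal-sized groups (columns) of 6 observations each

# Per-slice summaries; a one-pass scheme would accumulate these together.
ms = map(mean, eachslice(x; dims=2))
vs = map(v -> var(v; corrected=false), eachslice(x; dims=2))

# Within-group part + between-group part.
combined = mean(vs) + var(ms; corrected=false)
@assert combined ≈ var(x; corrected=false)
```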

That’s exactly what my proposal is meant to fix. The current dims keyword argument is exactly that kind of magic fast path: using a dims keyword argument makes this dramatically faster than the fully supported map(sum, eachslice(x; dims)) syntax, and stands in contrast to the always-fast reduce(+, x; dims) (which is just another syntax for sum). (To be clear, I don’t think dims is always ugly; I just think it’s ugly to have to reimplement it case by case for half a dozen different functions.) The reduce syntax is better and should be encouraged, but there’s nothing wrong with trying to optimize map so new users don’t find their code slow if they use eachslice instead.