Selectdim behaviour for arrays of dimension greater than 2

I just noticed a type-instability with selectdim. Consider the function

f(A) = selectdim(A, 1, 1)

For a matrix, we have that the return type of f is a 1D SubArray (as expected):

julia> A = rand(2,2);

julia> @code_warntype f(A)
Variables
  #self#::Core.Compiler.Const(f, false)
  A::Array{Float64,2}

Body::SubArray{Float64,1,Array{Float64,2},Tuple{Int64,Base.Slice{Base.OneTo{Int64}}},true}
1 ─ %1 = Main.selectdim(A, 1, 1)::SubArray{Float64,1,Array{Float64,2},Tuple{Int64,Base.Slice{Base.OneTo{Int64}}},true}
└──      return %1

However, if we call f on an array of dimension 3, the return type of f is a union of a 1D SubArray and a 2D SubArray:

julia> A = rand(2,2,2);

julia> @code_warntype f(A)
Variables
  #self#::Core.Compiler.Const(f, false)
  A::Array{Float64,3}

Body::Union{SubArray{Float64,2,Array{Float64,3},Tuple{Int64,Base.Slice{Base.OneTo{Int64}},Base.Slice{Base.OneTo{Int64}}},true}, SubArray{Float64,1,Array{Float64,3},Tuple{Int64,Base.Slice{Base.OneTo{Int64}},Int64},true}}
1 ─ %1 = Main.selectdim(A, 1, 1)::Union{SubArray{Float64,2,Array{Float64,3},Tuple{Int64,Base.Slice{Base.OneTo{Int64}},Base.Slice{Base.OneTo{Int64}}},true}, SubArray{Float64,1,Array{Float64,3},Tuple{Int64,Base.Slice{Base.OneTo{Int64}},Int64},true}}
└──      return %1

I expected the return type of f to be a 2D SubArray (and hence type stable), which would match the behaviour of @views A[1,:,:], no?
Am I missing something here? When would f ever return a 1D SubArray?

> return type of f to be a 2D SubArray (hence be type stable)

SubArray also (necessarily) keeps track of the types of the indices you used, as part of its own type:

julia> typeof(selectdim(A, 1, 1))
SubArray{Float64, 2, Array{Float64, 3}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}}, true}

julia> typeof(selectdim(A, 2, 1))
SubArray{Float64, 2, Array{Float64, 3}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64, Base.Slice{Base.OneTo{Int64}}}, false}

julia> typeof(selectdim(A, 1, 1)) == typeof(selectdim(A, 2, 1))
false
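A minimal check of the point above, assuming a 3D Array{Float64,3} like the one in the question. The dimension passed to selectdim decides where the Int sits in the index-type parameter of the SubArray, and with a literal dimension selectdim(A, 1, 1) has exactly the same type as view(A, 1, :, :), i.e. the @views A[1,:,:] from the question. idx_types is a hypothetical helper written only for this illustration.

```julia
A = rand(2, 2, 2)

v1 = selectdim(A, 1, 1)   # equivalent to view(A, 1, :, :)
v2 = selectdim(A, 2, 1)   # equivalent to view(A, :, 1, :)

# Both views are 2D...
@assert ndims(v1) == 2 && ndims(v2) == 2

# ...but the 4th type parameter of SubArray (the index types) differs,
# so the two views have different concrete types.
# idx_types is a hypothetical helper, not part of Base.
idx_types(v::SubArray) = typeof(v).parameters[4]
@assert idx_types(v1) == Tuple{Int, Base.Slice{Base.OneTo{Int}}, Base.Slice{Base.OneTo{Int}}}
@assert idx_types(v2) == Tuple{Base.Slice{Base.OneTo{Int}}, Int, Base.Slice{Base.OneTo{Int}}}
@assert typeof(v1) != typeof(v2)

# With a literal dimension, selectdim gives exactly the type of the equivalent view
@assert typeof(v1) == typeof(view(A, 1, :, :))
@assert v1 == view(A, 1, :, :)
```

So "2D" alone does not pin down the type; the position of the integer index does.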

Now, that said, it feels like we should be “saved” by constant propagation with the hard-coded dimension in f(A) = selectdim(A, 1, 1), and indeed we almost are! To get our list of indices, we recursively walk through each dimension, ask if it’s equal to the selected one, and conditionally use an int or a slice. The compiler happens to propagate constants through the first two recursions (two dimensions!) but then gives up. That is, it statically figured out that the indices in our SubArray are definitely going to be a 1 and then a :… and then it gave up. The third dimension could be either a 1 or a :, depending on the value of d (the dimension argument to selectdim). Yes, it feels funny, because the compiler knew the value of d for the first two dimensions, but has now forgotten it.
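The recursive walk described above can be sketched as follows. swapat is a hypothetical name, written in the spirit of the Base._setindex helper discussed later in the thread; it is not a copy of the Base code. Each call compares the remaining index with 1, keeps either the replacement value or the current element, then recurses on the tail with i - 1. For inference to produce one concrete tuple type, it has to carry the constant i through every level of that arithmetic, which is where it can give up.

```julia
# Hypothetical helper mirroring the recursive approach:
# replace the element at position `i` of the varargs list with `v`.
swapat(v, i::Int) = ()                        # base case: no elements left
swapat(v, i::Int, head, tail...) =
    (ifelse(i == 1, v, head), swapat(v, i - 1, tail...)...)

# Build selectdim-style indices for a 3D array:
# integer index 1 in dimension 1, full slices everywhere else
A = rand(2, 2, 2)
idxs = swapat(1, 1, map(Base.Slice, axes(A))...)

@assert idxs[1] === 1
@assert idxs[2] isa Base.Slice && idxs[3] isa Base.Slice
@assert view(A, idxs...) == view(A, 1, :, :)
```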

Type stability isn’t about what a function will actually return; it’s about what the compiler can prove it might return.
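One way to see the difference between "what a call returns" and "what inference can prove" is to pass the dimension as a runtime argument, so constant propagation has nothing to work with. g is a hypothetical wrapper; Base.return_types and Test.@inferred are existing tools, although Base.return_types is an internal reflection function whose printed output may differ between Julia versions.

```julia
using Test

# Hypothetical wrapper: the dimension is only known at runtime
g(A, d) = selectdim(A, d, 1)

A = rand(2, 2, 2)

# Every actual call returns a 2D view...
@assert ndims(g(A, 1)) == 2
@assert ndims(g(A, 3)) == 2

# ...but from the argument types alone (no constant d), inference
# cannot name a single concrete return type
rt = only(Base.return_types(g, (Array{Float64,3}, Int)))
@assert !isconcretetype(rt)

# @inferred turns that gap into an error, which is handy in test suites
inferred_ok = try
    @inferred g(A, 1)
    true
catch
    false
end
@assert !inferred_ok
```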


Absolutely :slight_smile: I thought constant propagation would yield type stability.

What feels funny to me is that in the example I gave, the compiler hesitates between a 1D SubArray (SubArray{Float64,1,(...)}) and a 2D SubArray (SubArray{Float64,2,(...)}), when it seems that it should only be concerned with subtypes of SubArray{Float64,2,(...)}.

The follow-up question, then, is why the compiler gave up after two dimensions.
Is there a workaround?

The rubber hits the road here:

https://github.com/JuliaLang/julia/blob/c708ca20cadc4c9e54edd643ae92f2fbf8bdc131/base/tuple.jl#L57

This is a recursive function that takes a value v, an index i, and a varargs list of arguments (the splatted axes) and will swap out the element at index i for value v. We start off with i being a literal constant — but as we progress, we do arithmetic on it… and at some point the compiler gives up (perhaps intentionally hitting a heuristic that’s trading effort — compile time — for return). Typically, the best tool for working with tuples is ntuple — that has the most smarts baked into it these days I think. We can try hacking in a replacement here:

julia> @inline function Base._setindex(v, i::Int, args...)
           ntuple(dim->ifelse(i==dim, v, args[dim]), length(args))
       end

julia> f(A)
2×2 view(::Array{Float64, 3}, 1, :, :) with eltype Float64:
 0.901717  0.427261
 0.130018  0.669248

julia> @code_warntype f(A)
Variables
  #self#::Core.Const(f)
  A::Array{Float64, 3}

Body::SubArray{Float64, 2, Array{Float64, 3}, Tuple{Int64, Base.Slice{Base.OneTo{Int64}}, Base.Slice{Base.OneTo{Int64}}}, true}

This is probably something we should propose doing…
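If you want the ntuple behaviour without redefining a Base internal (which changes Base._setindex for every caller in the session), a sketch of the same idea is to build the index tuple in your own helper. swapat_ntuple and myselectdim are hypothetical names; Val(N) hands ntuple the tuple length as a compile-time constant. Unlike selectdim, this sketch performs no argument or bounds checks on the dimension.

```julia
# Hypothetical non-recursive helper: replace position `i` of `args` with `v`
@inline swapat_ntuple(v, i::Int, args::Vararg{Any,N}) where {N} =
    ntuple(j -> ifelse(j == i, v, args[j]), Val(N))

# Hypothetical selectdim replacement built on it (no argument/bounds checks)
@inline myselectdim(A::AbstractArray, d::Int, i) =
    view(A, swapat_ntuple(i, d, map(Base.Slice, axes(A))...)...)

A = rand(2, 2, 2)
h(A) = myselectdim(A, 1, 1)   # literal dimension, as in f(A) above

v = h(A)
@assert v == view(A, 1, :, :)
@assert typeof(v) == typeof(selectdim(A, 1, 1))
# In an interactive session, `@code_warntype h(A)` shows whether the
# literal 1 was propagated into a single concrete return type.
```

The design choice is the same one as in the hack above: a flat ntuple over a known length avoids asking inference to follow the i - 1 arithmetic through a chain of recursive calls.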


Indeed, your fix looks very neat! :slight_smile: I’ll tag your reply as the solution for this post. Thanks for your help!

Just as a follow-up: I opened this PR on GitHub. It is a faithful transcription of @mbauman’s post just above (cf. the original issue opened on GitHub and linked to the PR).

Additionally, the PR would really benefit from reviews :slight_smile:.

EDIT: to add some motivation, here is a small showcase of the performance improvement:

# julia 1.6.1

using BenchmarkTools

f(A) = selectdim(A, 1, 1)

A = rand(10, 10, 10);

# @code_warntype f(A)
@btime f($A); # 148.503 ns (4 allocations: 128 bytes)

# apply the one-line change from the PR by overwriting the Base internal (for demonstration only)
Base._setindex(v, i::Int, args...) = ntuple(dim -> ifelse(i == dim, v, args[dim]), length(args))

# @code_warntype f(A)
@btime f($A); # 19.774 ns (0 allocations: 0 bytes)