Type-stability of recursive functions

Looks fine to me on Julia 1.9.0:

using Static, ArrayInterface, BenchmarkTools

# Determinant of the N×N matrix given as `N` equal-length vectors `vs` (its
# rows or, equivalently, its columns), computed by recursive cofactor
# expansion starting with every entry index `1:N` still in play.
function det(vs::Vararg{Any,N}) where {N}
    all_indices = ntuple(identity, Val(N))
    return _det(all_indices, vs)
end

# Base case of the expansion: a 1×1 minor. The determinant is just the entry
# of the single remaining vector at the single remaining index.
@inline function _det(inds::NTuple{1,Int}, cols::NTuple{1,Any})
    row = only(inds)
    return only(cols)[row]
end

# Cofactor (Laplace) expansion along the first vector `v` of the remaining
# vectors. `inds` holds the entry indices still in play. For the k-th index,
# `swapeach_1st(inds)` supplies `(i, is...)`, where `is` are the other indices
# in their original order, and `altsigns(Int)` supplies the k-th sign
# (+1, -1, +1, ...). Each term is ±v[i] times the determinant of the minor
# formed by `is` and the remaining vectors `vs`.
# The sign is applied as `ifelse(signbit(sgn), -x, x)` rather than
# `flipsign(x, sgn)`. `flipsign` is only defined for `Real` arguments, so it
# rejected complex entries; this form yields the same values for real entries.
@inline _det(inds::NTuple{N,Int}, (v, vs...)::NTuple{N,Any}) where N =
    +(map((sgn, (i, is...)) -> ifelse(signbit(sgn), -v[i], v[i]) * _det(is, vs),
          NTuple{N,Int}(altsigns(Int)), swapeach_1st(inds))...)

using Base.Iterators: map as imap
# Union of the two integer types used for indices in this code: plain `Int`
# and Static.jl's compile-time `StaticInt`.
# NOTE(review): not referenced by any code shown here; likely a leftover.
# Confirm before removing.
const CanonicalInt = Union{Int, StaticInt}

# Like `ntuple`, but the length `n` is a `StaticInt`, and `f` is called with
# the indices `n - i + static(1)` for `i = n, ..., static(1)`. Those indices
# are computed from `StaticInt`s, so index arithmetic inside `f` can stay in
# the type domain.
# As with `ntuple`, a negative length throws an `ArgumentError`. Without this
# check the countdown in `_sntuple` would never reach `StaticInt{0}` and would
# recurse until the stack overflows.
function sntuple(f, n::StaticInt)
    n < 0 && throw(ArgumentError("tuple length should be ≥ 0, got $n"))
    return _sntuple(f, n, n)
end
# `i` counts how many elements are still to be produced; the next one is
# element number `n - i + 1`.
_sntuple(_, ::StaticInt, ::StaticInt{0}) = ()
_sntuple(f, n::StaticInt, i::StaticInt) =
    (f(n - i + static(1)), _sntuple(f, n, i - static(1))...)

# Infinite iterator alternating between `init` and `-init`:
# init, -init, init, -init, ...
function altsigns(init)
    return Iterators.cycle((init, -init))
end
# Type form: start from `oneunit(T)`, e.g. altsigns(Int) -> 1, -1, 1, -1, ...
function altsigns(T::Type)
    return altsigns(oneunit(T))
end

# Return `tup` with its `i`-th element moved to the front; the other elements
# keep their original relative order. For `i == 2` this is the same as
# swapping the first two elements. For larger `i` it is a rotation of the
# first `i` elements, not a pairwise swap (see the second example).
# If `i` lies outside `1:length(tup)`, `tup` is returned unchanged.
# julia> swap_1st(static(2), ('a', 'b', 'c'))
# ('b', 'a', 'c')
# julia> swap_1st(static(3), ('a', 'b', 'c'))
# ('c', 'a', 'b')
function swap_1st(i::StaticInt, tup::Tuple)
    len = ArrayInterface.static_length(tup)
    in_range = 1 ≤ i ≤ len
    in_range || return tup
    before = sntuple(k -> tup[k], i - static(1))   # elements 1:i-1
    after = sntuple(k -> tup[k + i], len - i)       # elements i+1:len
    return (tup[i], before..., after...)
end

# Return a tuple whose `i`-th entry is `swap_1st(static(i), tup)`, i.e. `tup`
# with its `i`-th element moved to the front and the rest kept in order.
# julia> swapeach_1st(('a', 'b', 'c'))
# (('a', 'b', 'c'), ('b', 'a', 'c'), ('c', 'a', 'b'))
function swapeach_1st(tup::NTuple{N,Any}) where {N}
    front_at = i -> swap_1st(i, tup)
    return sntuple(front_at, static(N))
end

# Time only `det`: the column views are built in `setup`, which BenchmarkTools
# runs outside the timed region, so the `rand` matrix allocation is not
# included in the measurement.
# NOTE: the timings quoted below were measured with the matrix construction
# inside the timed expression. That most likely accounts for the single
# allocation reported for each call.
@btime det(cols...) setup=(cols = NTuple{3}(eachcol(rand(3, 3))))
@btime det(cols...) setup=(cols = NTuple{4}(eachcol(rand(4, 4))))

# Check inference: a `Body::Float64` line means the result type is inferred
# concretely (type stable) for that matrix size.
@code_warntype det(NTuple{3}(eachcol(rand(3, 3)))...)
@code_warntype det(NTuple{4}(eachcol(rand(4, 4)))...)

yields

  118.053 ns (1 allocation: 128 bytes)
  174.590 ns (1 allocation: 192 bytes)
MethodInstance for det(::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, ::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, ::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true})
  from det(vs::Vararg{Any, N}) where N in Main at ...
Static Parameters
  N = 3
Arguments
  #self#::Core.Const(det)
  vs::Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}
Body::Float64
1 ─ %1 = Main.ntuple(Main.identity, $(Expr(:static_parameter, 1)))::Core.Const((1, 2, 3))
│   %2 = Main._det(%1, vs)::Float64
└──      return %2

MethodInstance for det(::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, ::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, ::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, ::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true})
  from det(vs::Vararg{Any, N}) where N in Main at ...
Static Parameters
  N = 4
Arguments
  #self#::Core.Const(det)
  vs::NTuple{4, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}
Body::Float64
1 ─ %1 = Main.ntuple(Main.identity, $(Expr(:static_parameter, 1)))::Core.Const((1, 2, 3, 4))
│   %2 = Main._det(%1, vs)::Float64
└──      return %2

Edit: posted complete code and corrected results…

2 Likes