Type-stability of recursive functions

I have challenged myself to write an (ideally type-stable) function that computes the determinant of (small) matrices using the O(n!) recursive expansion formula without resorting to @generated. I ended up with the following.

# Compute determinant of a matrix with equal-length row/column vectors `vs`
det(vs::Vararg{Any,N}) where N = _det(ntuple(identity, N), vs)

@inline _det((ind,)::NTuple{1,Int}, (v,)::NTuple{1,Any}) = getindex(v, ind)

@inline _det(inds::NTuple{N,Int}, (v, vs...)::NTuple{N,Any}) where N =
    +(map((sgn, (i, is...)) -> flipsign(v[i], sgn) * _det(is, vs),
          NTuple{N,Int}(altsigns(Int)), swapeach_1st(inds))...)

which seems to give the correct numerical result.

The helper functions used in _det are below.

using Static, ArrayInterface
const CanonicalInt = Union{StaticInt, Int}

# Like `ntuple`, but uses `StaticInt`s
sntuple(f, n::StaticInt) = _sntuple(f, n, n)
_sntuple(_, ::StaticInt, ::StaticInt{0}) = ()
_sntuple(f, n::StaticInt, i::StaticInt) =
    (f(n - i + static(1)), _sntuple(f, n, i - static(1))...)

altsigns(init) = Iterators.cycle((init, -init))
altsigns(T::Type) = altsigns(oneunit(T))
# altsigns(Int) -> (1, -1, 1, -1, ...)

# Return the input tuple with the element at index `i` moved to the front;
# the remaining elements keep their relative order.
# julia> swap_1st(static(2), ('a', 'b', 'c'))
# ('b', 'a', 'c')
function swap_1st(i::StaticInt, tup::Tuple)
    N = ArrayInterface.static_length(tup)
    1 ≤ i ≤ N || return tup
    (tup[i], sntuple(j -> tup[j], i - static(1))...,
             sntuple(j -> tup[j+i], N - i)...)
end

# Return a tuple of tuples where the i-th tuple is the input tuple with its
# i-th element moved to the front.
# julia> swapeach_1st(('a', 'b', 'c'))
# (('a', 'b', 'c'), ('b', 'a', 'c'), ('c', 'a', 'b'))
swapeach_1st(tup::NTuple{N,Any}) where {N} =
    sntuple(i -> swap_1st(i, tup), static(N))

The function sntuple is necessary to make swapeach_1st type-stable and uses the Static package.

My goal with this is for det to inline roughly as follows, generating an explicit formula for an n×n determinant at compile time without need for @generated.

det(v1, v2, v3)

_det((1,2,3), (v1, v2, v3))

(+v1[1]) * _det((2,3), (v2, v3)) +
    (-v1[2]) * _det((1,3), (v2, v3)) +
    (+v1[3]) * _det((1,2), (v2, v3))

v1[1] * ((+v2[2]) * _det((3,), (v3,)) + (-v2[3]) * _det((2,), (v3,))) -
    v1[2] * ((+v2[1]) * _det((3,), (v3,)) + (-v2[3]) * _det((1,), (v3,))) +
    v1[3] * ((+v2[1]) * _det((2,), (v3,)) + (-v2[2]) * _det((1,), (v3,)))

v1[1] * (v2[2] * v3[3] - v2[3] * v3[2]) -
    v1[2] * (v2[1] * v3[3] - v2[3] * v3[1]) +
    v1[3] * (v2[1] * v3[2] - v2[2] * v3[1])
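
To cross-check the hand expansion above numerically, here is a minimal Base-only sketch of the same cofactor expansion. It is not the code from this post: `drop`, `laplace`, and `laplace_det` are hypothetical names, and `ntuple(f, Val(N))` stands in for the Static-based helpers. No claim is made about how well it infers.

```julia
using LinearAlgebra

# Hypothetical helper: remove the i-th entry of a tuple, keeping the order.
drop(t::NTuple{N,Any}, i) where {N} =
    ntuple(j -> j < i ? t[j] : t[j+1], Val(N - 1))

# Base case: one remaining row index and one remaining column vector.
laplace((ind,)::NTuple{1,Int}, (v,)::NTuple{1,Any}) = v[ind]

# Expand along the first column vector; `inds` lists the rows still in play.
function laplace(inds::NTuple{N,Int}, vs::NTuple{N,Any}) where {N}
    v, rest = first(vs), Base.tail(vs)
    terms = ntuple(Val(N)) do k
        sgn = isodd(k) ? 1 : -1          # alternating sign by position in `inds`
        sgn * v[inds[k]] * laplace(drop(inds, k), rest)
    end
    sum(terms)
end

laplace_det(vs...) = laplace(ntuple(identity, length(vs)), vs)

A = [2 0 1; 1 3 2; 1 1 4]
@assert laplace_det(eachcol(A)...) == 18
@assert laplace_det(eachcol(A)...) ≈ LinearAlgebra.det(A)
```

For the 3x3 case this evaluates exactly the three-term formula written out above, with `v1`, `v2`, `v3` being the columns of `A`.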

Interestingly, det is type-stable for 2×2 and 3×3 matrices, but not for 4×4 matrices.

det(eachcol(rand(3,3))...)  # type stable -> Float64
det(eachcol(rand(4,4))...)  # not type stable -> Any
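
For anyone who wants to script this kind of check instead of reading @code_warntype output, here is a small sketch using `Test.@inferred` from the standard library. The functions `stable` and `unstable` are toy stand-ins I made up for illustration; they are not part of the determinant code.

```julia
using Test

stable(x) = x > 0 ? 1.0 : 2.0     # both branches return Float64
unstable(x) = x > 0 ? 1 : 1.0     # Int or Float64, depending on the value

# `@inferred` returns the value when the inferred type matches the actual type
@assert (@inferred stable(1)) == 1.0

# and throws when inference only finds a wider type (here a Union)
@test_throws ErrorException @inferred unstable(1)

# With the definitions from this post loaded, the analogous checks would be:
#   @inferred det(eachcol(rand(3,3))...)
#   @inferred det(eachcol(rand(4,4))...)
```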

Can anyone offer more info on what is happening? Is it possible to restore type stability to the 4x4 case? (It wouldn’t be very practical to use this function on much larger matrices anyway.)

Here is the output of @code_warntype on _det for the 3x3 case

julia> @code_warntype _det((1,2,3), NTuple{3}(eachcol(mat3x3)))

MethodInstance for _det(::Tuple{Int64, Int64, Int64}, ::Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}})
  from _det(inds::Tuple{Vararg{Int64, N}}, ::Tuple{Vararg{Any, N}}) where N in Main at In[11]:7
Static Parameters
  N = 3
Arguments
  #self#::Core.Const(_det)
  inds::Tuple{Int64, Int64, Int64}
  @_3::Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}
Locals
  #25::var"#25#26"{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}}
  @_5::Int64
  v::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}
  vs::Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}
Body::Float64
1 ─       nothing
│   %2  = Base.indexed_iterate(@_3, 1)::Core.PartialStruct(Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Int64}, Any[SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Core.Const(2)])
│         (v = Core.getfield(%2, 1))
│         (@_5 = Core.getfield(%2, 2))
│         (vs = Base.rest(@_3, @_5::Core.Const(2)))
│   %6  = Main.:(var"#25#26")::Core.Const(var"#25#26")
│   %7  = Core.typeof(v)::Core.Const(SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true})
│   %8  = Core.typeof(vs)::Core.Const(Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}})
│   %9  = Core.apply_type(%6, %7, %8)::Core.Const(var"#25#26"{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}})
│   %10 = v::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}
│         (#25 = %new(%9, %10, vs))
│   %12 = #25::var"#25#26"{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}}
│   %13 = Core.apply_type(Main.NTuple, $(Expr(:static_parameter, 1)), Main.Int)::Core.Const(Tuple{Int64, Int64, Int64})
│   %14 = Main.altsigns(Main.Int)::Core.Const(Base.Iterators.Cycle{Tuple{Int64, Int64}}((1, -1)))
│   %15 = (%13)(%14)::Core.Const((1, -1, 1))
│   %16 = Main.swapeach_1st(inds)::Tuple{Tuple{Int64, Int64, Int64}, Tuple{Int64, Int64, Int64}, Tuple{Int64, Int64, Int64}}
│   %17 = Main.map(%12, %15, %16)::Tuple{Float64, Float64, Float64}
│   %18 = Core._apply_iterate(Base.iterate, Main.:+, %17)::Float64
└──       return %18

and for the 4x4 case

julia> @code_warntype _det((1,2,3,4), NTuple{4}(eachcol(mat4x4)))

MethodInstance for _det(::NTuple{4, Int64}, ::NTuple{4, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}})
  from _det(inds::Tuple{Vararg{Int64, N}}, ::Tuple{Vararg{Any, N}}) where N in Main at In[10]:7
Static Parameters
  N = 4
Arguments
  #self#::Core.Const(_det)
  inds::NTuple{4, Int64}
  @_3::NTuple{4, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}
Locals
  #23::var"#23#24"{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}}
  @_5::Int64
  v::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}
  vs::Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}
Body::Any
1 ─       nothing
│   %2  = Base.indexed_iterate(@_3, 1)::Core.PartialStruct(Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Int64}, Any[SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Core.Const(2)])
│         (v = Core.getfield(%2, 1))
│         (@_5 = Core.getfield(%2, 2))
│         (vs = Base.rest(@_3, @_5::Core.Const(2)))
│   %6  = Main.:(var"#23#24")::Core.Const(var"#23#24")
│   %7  = Core.typeof(v)::Core.Const(SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true})
│   %8  = Core.typeof(vs)::Core.Const(Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}})
│   %9  = Core.apply_type(%6, %7, %8)::Core.Const(var"#23#24"{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}})
│   %10 = v::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}
│         (#23 = %new(%9, %10, vs))
│   %12 = #23::var"#23#24"{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}}
│   %13 = Core.apply_type(Main.NTuple, $(Expr(:static_parameter, 1)), Main.Int)::Core.Const(NTuple{4, Int64})
│   %14 = Main.altsigns(Main.Int)::Core.Const(Base.Iterators.Cycle{Tuple{Int64, Int64}}((1, -1)))
│   %15 = (%13)(%14)::Core.Const((1, -1, 1, -1))
│   %16 = Main.swapeach_1st(inds)::NTuple{4, NTuple{4, Int64}}
│   %17 = Main.map(%12, %15, %16)::NTuple{4, Any}
│   %18 = Core._apply_iterate(Base.iterate, Main.:+, %17)::Any
└──       return %18

Looks fine for me on 1.9.0

using Static, ArrayInterface, BenchmarkTools

# Compute determinant of a matrix with equal-length row/column vectors `vs`
det(vs::Vararg{Any,N}) where N = _det(ntuple(identity, N), vs)

@inline _det((ind,)::NTuple{1,Int}, (v,)::NTuple{1,Any}) = getindex(v, ind)

@inline _det(inds::NTuple{N,Int}, (v, vs...)::NTuple{N,Any}) where N =
    +(map((sgn, (i, is...)) -> flipsign(v[i], sgn) * _det(is, vs),
          NTuple{N,Int}(altsigns(Int)), swapeach_1st(inds))...)

using Base.Iterators: map as imap
const CanonicalInt = Union{StaticInt, Int}

# Like `ntuple`, but uses `StaticInt`s
sntuple(f, n::StaticInt) = _sntuple(f, n, n)
_sntuple(_, ::StaticInt, ::StaticInt{0}) = ()
_sntuple(f, n::StaticInt, i::StaticInt) =
    (f(n - i + static(1)), _sntuple(f, n, i - static(1))...)

altsigns(init) = Iterators.cycle((init, -init))
altsigns(T::Type) = altsigns(oneunit(T))
# altsigns(Int) -> (1, -1, 1, -1, ...)

# Return the input tuple with the element at index `i` moved to the front;
# the remaining elements keep their relative order.
# julia> swap_1st(static(2), ('a', 'b', 'c'))
# ('b', 'a', 'c')
function swap_1st(i::StaticInt, tup::Tuple)
    N = ArrayInterface.static_length(tup)
    1 ≤ i ≤ N || return tup
    (tup[i], sntuple(j -> tup[j], i - static(1))...,
            sntuple(j -> tup[j+i], N - i)...)
end

# Return a tuple of tuples where the i-th tuple is the input tuple with its
# i-th element moved to the front.
# julia> swapeach_1st(('a', 'b', 'c'))
# (('a', 'b', 'c'), ('b', 'a', 'c'), ('c', 'a', 'b'))
swapeach_1st(tup::NTuple{N,Any}) where {N} =
    sntuple(i -> swap_1st(i, tup), static(N))

@btime det(NTuple{3}(eachcol(rand(3,3)))...)  # type stable -> Float64
@btime det(NTuple{4}(eachcol(rand(4,4)))...)

@code_warntype det(NTuple{3}(eachcol(rand(3,3)))...)  # type stable -> Float64
@code_warntype det(NTuple{4}(eachcol(rand(4,4)))...)

yields

  118.053 ns (1 allocation: 128 bytes)
  174.590 ns (1 allocation: 192 bytes)
MethodInstance for det(::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, ::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, ::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true})
  from det(vs::Vararg{Any, N}) where N in Main at ...
Static Parameters
  N = 3
Arguments
  #self#::Core.Const(det)
  vs::Tuple{SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}
Body::Float64
1 ─ %1 = Main.ntuple(Main.identity, $(Expr(:static_parameter, 1)))::Core.Const((1, 2, 3))
│   %2 = Main._det(%1, vs)::Float64
└──      return %2

MethodInstance for det(::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, ::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, ::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}, ::SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true})
  from det(vs::Vararg{Any, N}) where N in Main at ...
Static Parameters
  N = 4
Arguments
  #self#::Core.Const(det)
  vs::NTuple{4, SubArray{Float64, 1, Matrix{Float64}, Tuple{Base.Slice{Base.OneTo{Int64}}, Int64}, true}}
Body::Float64
1 ─ %1 = Main.ntuple(Main.identity, $(Expr(:static_parameter, 1)))::Core.Const((1, 2, 3, 4))
│   %2 = Main._det(%1, vs)::Float64
└──      return %2

Edit: posted complete code and corrected results…

Thanks @goerch!

I think you should be getting a Float64 return value, though. det expects each vector as a separate argument. (There’s a typo in my original post that I corrected but may have made things confusing.) Does

det(NTuple{4}(eachcol(rand(4,4)))...)

infer to a Float64?

I don’t have 1.9.0 installed, but that would be great news if there have been enhancements that allow this to work.

Edit: Thanks again @goerch :grinning:! Very cool that this works in 1.9.0. Thankfully, for my potential application of this pattern, I only need the 2x2 and 3x3 cases (for now), so I can afford to wait for it.

I am wondering to what extent @generated functions, with their many gotchas, will really be needed in the future. It seems to me that you can do quite general “code generation” using just @inline, ntuple, map, and packages like Static.

BTW I tried things initially on julia 1.7.2. There, other functions with a recursive call to themselves made inside map seem to have similar issues: they seem to go type-unstable if 3 or more recursive calls occur.

In a real application, I would just use SMatrix from StaticArrays.jl, which already has these cases built-in.

Alright, I am still a little confused after trying a simpler example than my original determinant problem. I tried julia 1.7.2 as well as the nightly build and got the same results for the functions below.

The function pop returns the element of a tuple xs at position i, along with a new tuple with that element removed.

@inline pop(xs::Tuple, i::Integer) = _pop((), xs, i)
@inline _pop(out::Tuple, ::Tuple{}, i::Integer) = error("index out of bounds")
@inline _pop(out::Tuple, (x, xs...)::Tuple, i::Integer) =
    i == 1 ? (x, (out..., xs...)) : _pop((out..., x), xs, i-1)

It is type-stable only for input tuples of length three or fewer. Trying to apply it to longer tuples results in allocations.
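
To make this easy to reproduce, here is a small sketch that repeats the definitions of `pop` from above, checks the returned values, and prints what inference concludes for homogeneous `Int` tuples of increasing length. The printed types depend on the Julia version, so nothing about them is asserted.

```julia
@inline pop(xs::Tuple, i::Integer) = _pop((), xs, i)
@inline _pop(out::Tuple, ::Tuple{}, i::Integer) = error("index out of bounds")
@inline _pop(out::Tuple, (x, xs...)::Tuple, i::Integer) =
    i == 1 ? (x, (out..., xs...)) : _pop((out..., x), xs, i-1)

# Values: the popped element and the remaining tuple
@assert pop((10, 20, 30), 2) == (20, (10, 30))

# Inferred return type for each input length (version dependent)
for n in 1:6
    rt = only(Base.return_types(pop, (NTuple{n,Int}, Int)))
    println(n, " => ", rt, isconcretetype(rt) ? "" : "   (not concrete)")
end
```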

An alternative approach somehow avoids the type instability (I tested on input tuples up to length 10), but I am not sure why.

@inline deleteat(::Tuple{}, i) = error("index out of bounds")
@inline deleteat((x, xs...)::Tuple, i=1) =
    i == 1 ? xs : (x, deleteat(xs, i-1)...)

pop2(xs::Tuple, i::Integer) =
    1 ≤ i ≤ length(xs) ? (xs[i], deleteat(xs, i)) : error("index out of bounds")

I guess I am just wondering if there are rules to be aware of here…
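
One more formulation that may be worth comparing, sketched under my own assumptions rather than taken from any reply here: take the output length from the tuple's type parameter and build the result with `ntuple(f, Val(N - 1))`, so there is no recursion on the tuple at all. The names `deleteat_val` and `pop3` are hypothetical. Whether it infers on a given Julia version is something to confirm with @code_warntype.

```julia
# Hypothetical alternative: the output length comes from the type parameter N,
# so the shortened tuple is built without recursing on the input tuple.
deleteat_val(xs::NTuple{N,Any}, i::Integer) where {N} =
    ntuple(j -> j < i ? xs[j] : xs[j+1], Val(N - 1))

pop3(xs::Tuple, i::Integer) =
    1 ≤ i ≤ length(xs) ? (xs[i], deleteat_val(xs, i)) : error("index out of bounds")

@assert pop3((10, 20, 30), 2) == (20, (10, 30))
@assert pop3((1, 2, 3, 4, 5), 1) == (1, (2, 3, 4, 5))
```

The bounds check in `pop3` runs before `deleteat_val`, so the empty-tuple case never reaches `Val(N - 1)` with `N == 0`.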

I asked a similar question earlier, and Shuhei suggested following the approach used here:

which solved the problem in my case.
Except he suggested using @nospecialize(args...).
