Type stability in a function involving `selectdim`

I'm trying to write a function that applies the trapezoidal rule along a given dimension of an array:

"""
    trapzNd(A, t; dims=1)

Integrate `A` along dimension `dims` with the trapezoidal rule, using `t` as the
sample points (one per slice of `A` along `dims`).

Return an array with the same number of dimensions as `A` whose size along `dims` is 1.

Throw a `DimensionMismatch` if `length(t) != size(A, dims)`. Without this check a
`t` of length 2 would make `dt` a singleton along `dims`, and it would silently
broadcast against every slice.
"""
function trapzNd(A, t; dims=1)
    N = size(A, dims)
    length(t) == N || throw(DimensionMismatch(
        "length(t) = $(length(t)) must equal size(A, dims) = $N"))

    # Keep the shape of `dt` in a Tuple, not a Vector: a Tuple's length is part of its
    # type, so `reshape` can infer the dimensionality of its result.
    dt_shape = Base.setindex(ntuple(_ -> 1, ndims(A)), N - 1, dims)
    dt = reshape(diff(t), dt_shape)

    # Left and right endpoints of every interval along `dims`.
    lo = selectdim(A, dims, 1:N-1)
    hi = selectdim(A, dims, 2:N)
    return 0.5 * sum((lo + hi) .* dt; dims=dims)
end

Running `@code_warntype` gives:

# Sample inputs: a 256×501 matrix of ones and 501 evenly spaced points on [0, 1].
A_test = ones(Float64, 256, 501)
t_test = range(0, 1; length=501)

# Show the inferred types for an integration along the second dimension.
@code_warntype trapzNd(A_test, t_test; dims = 2)

MethodInstance for (::var"#trapzNd##kw")(::NamedTuple{(:dims,), Tuple{Int64}}, ::typeof(trapzNd), ::Matrix{Float64}, ::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64})
  from (::var"#trapzNd##kw")(::Any, ::typeof(trapzNd), A, t) in Main at In[99]:10
Arguments
  _::Core.Const(var"#trapzNd##kw"())
  @_2::NamedTuple{(:dims,), Tuple{Int64}}
  @_3::Core.Const(trapzNd)
  A::Matrix{Float64}
  t::StepRangeLen{Float64, Base.TwicePrecision{Float64}, Base.TwicePrecision{Float64}, Int64}
Locals
  dims::Int64
  @_7::Int64
Body::Any
1 ─ %1  = Base.haskey(@_2, :dims)::Core.Const(true)
│         Core.typeassert(%1, Core.Bool)
│         (@_7 = Base.getindex(@_2, :dims))
└──       goto #3
2 ─       Core.Const(:(@_7 = 1))
3 ┄ %6  = @_7::Int64
│         (dims = %6)
│   %8  = (:dims,)::Core.Const((:dims,))
│   %9  = Core.apply_type(Core.NamedTuple, %8)::Core.Const(NamedTuple{(:dims,)})
│   %10 = Base.structdiff(@_2, %9)::Core.Const(NamedTuple())
│   %11 = Base.pairs(%10)::Core.Const(Base.Pairs{Symbol, Union{}, Tuple{}, NamedTuple{(), Tuple{}}}())
│   %12 = Base.isempty(%11)::Core.Const(true)
│         Core.typeassert(%12, Core.Bool)
└──       goto #5
4 ─       Core.Const(:(Base.kwerr(@_2, @_3, A, t)))
5 ┄ %16 = Main.:(var"#trapzNd#10")(dims, @_3, A, t)::Any
└──       return %16

I tried removing the parts involving `t` and `dt` from the function, but with `sum(selectdim(A, dims, 1:N-1) + selectdim(A, dims, 2:N), dims=dims)` the return type is still inferred as `Any`. I guess the problem is `selectdim`: Julia fails to infer the output type because the shape of the result depends on the argument `dims`, whose value is only known at runtime? If that's the case, how can I make the function type stable?

This is an unfortunate issue that I've run into as well. The most straightforward fix seems to be allowing a `Val` (or a similar type) for `dims`, so that its value can propagate in the type domain. I couldn't find an issue for it on the bug tracker, so I opened Add type stable `selectdim(x, ::Val, ...)` overload · Issue #46215 · JuliaLang/julia · GitHub.

`selectdim(x, ::Val, ...)` should be const-prop-able (i.e. amenable to constant propagation) since Julia 1.7, IIRC.
`selectdim` does not accept a `Val` argument yet, but we can still write:

"""
    trapzNd(A, t; dims::Val{Dims}=Val(1)) where {Dims}

Integrate `A` along dimension `Dims` with the trapezoidal rule, using `t` as the
sample points (one per slice of `A` along that dimension).

The dimension is passed as `Val(d)`, so its value is part of the keyword's type and
inference can resolve the `selectdim` calls.

Throw a `DimensionMismatch` if `length(t) != size(A, Dims)`.
"""
function trapzNd(A, t; dims::Val{Dims}=Val(1)) where {Dims}
    # Use a fresh name instead of rebinding `dims` (a `Val`) to an `Int`.
    d = Int(Dims)
    N = size(A, d)
    length(t) == N || throw(DimensionMismatch(
        "length(t) = $(length(t)) must equal size(A, $d) = $N"))

    # Caching the shape in a `Vector` loses its length from the type, so build it as a
    # `Tuple` and update the entry for `d` with `Base.setindex` instead.
    dt_shape = Base.setindex(ntuple(one, ndims(A)), N - 1, d)
    dt = reshape(diff(t), dt_shape)
    return 0.5 * sum((selectdim(A, d, 1:N-1) + selectdim(A, d, 2:N)) .* dt; dims=d)
end
# `Val(2)` carries the dimension in its type, so the result type is inferable.
trapzNd(A_test, t_test, dims = Val(2))

In fact, once you fix the instability caused by `dt_shape::Vector`, you can define the function as

"""
    trapzNd_cp(A, t; dims = 1)

Integrate `A` along dimension `dims` with the trapezoidal rule, using `t` as the
sample points (one per slice of `A` along `dims`).

The method is marked `Base.@constprop :aggressive` so that a compile-time constant
`dims` at the call site (e.g. a literal inside another function) can be propagated
into the `selectdim` calls, letting inference determine the result type.

Throw a `DimensionMismatch` if `length(t) != size(A, dims)`.
"""
Base.@constprop :aggressive function trapzNd_cp(A, t; dims = 1)
    N = size(A, dims)
    length(t) == N || throw(DimensionMismatch(
        "length(t) = $(length(t)) must equal size(A, dims) = $N"))

    # A Tuple shape (unlike a Vector) keeps its length in the type, so `reshape`
    # has an inferable number of dimensions.
    dt_shape = Base.setindex(ntuple(one, ndims(A)), N - 1, dims)
    dt = reshape(diff(t), dt_shape)
    return 0.5 * sum((selectdim(A, dims, 1:N-1) + selectdim(A, dims, 2:N)) .* dt; dims = dims)
end
# Fix the integration dimension to the literal 2 so that constant propagation can
# carry it into `trapzNd_cp`.
function f(A, t)
    return trapzNd_cp(A, t, dims = 2)
end

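`f(A_test, t_test)` will be type stable even though calling `trapzNd_cp(A_test, t_test; dims = 2)` directly at top level is not. Inside `f`, the literal `2` is a compile-time constant that the compiler can propagate into `trapzNd_cp`.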

Maybe use Static.jl?