Hi!
In Base we have `dropdims`, which effectively just calls `reshape`.
Today I was surprised to find that we don't have an `insertdims` that is the inverse of `dropdims`.
MLUtils.jl already has an implementation for inserting a single dimension. It should be possible to generalize it to a tuple of dimensions.
Is there a reason we don’t have that in Base?
Best,
Felix
1 Like
I tried to implement `insertdims` along the same lines as `dropdims`.
Is that something we should try to add to Base?
"""
    insertdims(A; dims)

Insert singleton (length-1) dimensions into `A` at the positions given by `dims`,
which may be a single integer or a tuple of integers. The result is produced with
`reshape`, so it shares its data with `A`. Intended as the counterpart of `dropdims`.

# Examples
```julia
julia> size(insertdims(zeros(2, 3); dims = 1))
(1, 2, 3)
```
"""
function insertdims(A; dims)
    return _insertdims(A, dims)
end
"""
    _insertdims(A::AbstractArray{T,N}, dims::NTuple{M,Int})
    _insertdims(A::AbstractArray, dim::Integer)

Return `A` reshaped so that every position listed in `dims` is a new singleton axis.

`dims` refers to positions in the *output*, which has `N + M` dimensions. Its entries
may be given in any order, but they must be unique and lie in `1:N+M`. The remaining
output positions receive the axes of `A` in their original order. The data is shared
with `A` via `reshape`; nothing is copied.

Throws `ArgumentError` if an entry of `dims` is out of range or repeated.
"""
function _insertdims(A::AbstractArray{T, N}, dims::NTuple{M, Int}) where {T, N, M}
    K = N + M  # dimensionality of the result
    for i in eachindex(dims)
        1 <= dims[i] <= K || throw(ArgumentError(
            "inserted dims must be in the range 1:$K (ndims(A) + length(dims)), got $(dims[i])"))
        for j in 1:i-1
            dims[j] == dims[i] && throw(ArgumentError("inserted dims must be unique"))
        end
    end
    # Output position `d` is either a new singleton axis or an axis of `A`. In the
    # latter case, the source axis index is `d` minus the number of singletons
    # inserted before position `d`.
    new_axes = ntuple(Val(K)) do d
        d in dims ? Base.OneTo(1) : axes(A, d - count(<(d), dims))
    end
    return reshape(A, new_axes)::AbstractArray{T, K}
end
# `Int` (not `Int64`) so the wrapped tuple matches the method above on 32-bit systems too.
_insertdims(A::AbstractArray, dim::Integer) = _insertdims(A, (Int(dim),))
# Merge two tuples, each assumed (not checked) to be in ascending order, into one
# ascending tuple. On ties the element of the first tuple is emitted first, so the merge
# is stable with respect to argument order.
_sortedmerge(::Tuple{}, ::Tuple{}) = ()
_sortedmerge(::Tuple{}, b::Tuple) = b
_sortedmerge(a::Tuple, ::Tuple{}) = a
function _sortedmerge(a::Tuple, b::Tuple)
    x, y = first(a), first(b)
    if y < x
        return (y, _sortedmerge(a, Base.tail(b))...)
    else
        return (x, _sortedmerge(Base.tail(a), b)...)
    end
end