I’m trying to improve the performance of `mapslices`, which is needed for https://github.com/JuliaLang/julia/issues/3893#issuecomment-304924903. The only changes I’ve made so far are renaming things so that I understand (a little of) what’s going on, plus some small rearrangements. As far as I can tell, I haven’t changed the performance. @tim.holy mentioned he might have some leads (and something about a blog post?), which is good because I’m lost.
"""
    my_mapslices(f, input::AbstractArray, sliced_dimensions::AbstractVector)

Apply `f` to every slice of `input` taken along `sliced_dimensions` and collect the
results into a newly allocated array.

`f` is first applied to the slice at the first index of every non-sliced dimension; the
shape of that first result determines the shape of the output:

- the non-sliced dimensions of the output keep the axes of `input`;
- the sliced dimensions take the axes of the first result, padded with singleton axes
  when the result has fewer dimensions than `sliced_dimensions`;
- a result that is not an array (or is zero-dimensional) is wrapped in a one-element
  vector.

If `sliced_dimensions` is empty, this returns `map(f, input)`.
"""
function my_mapslices(f, input::AbstractArray, sliced_dimensions::AbstractVector)
    if isempty(sliced_dimensions)
        return map(f, input)
    end

    # Deliberately not called `axes`, so the local cannot shadow a Base function.
    input_axes = [indices(input)...]
    indexed_dimensions = setdiff([1:ndims(input);], sliced_dimensions)

    # Index of the current input slice: a scalar in every indexed dimension and a full
    # `Slice` in every sliced dimension.
    input_index = Any[first(index) for index in input_axes]
    for dimension in sliced_dimensions
        input_index[dimension] = Base.Slice(indices(input, dimension))
    end

    # Apply the function to the first slice in order to determine the next steps
    input_slice = input[input_index...]
    first_output = f(input_slice)

    # In some cases, we can re-use the first slice for a dramatic performance
    # increase. The slice itself must be mutable and the result cannot contain
    # any mutable containers. The following errs on the side of being overly
    # strict (#18570 & #21123).
    safe_for_reuse =
        isa(input_slice, StridedArray) &&
        (isa(first_output, Number) ||
         (isa(first_output, AbstractArray) && eltype(first_output) <: Number))

    # TODO: maybe support removing dimensions
    if !isa(first_output, AbstractArray) || ndims(first_output) == 0
        first_output = [first_output]
    end

    # Determine the result axes and allocate. `input_axes` always holds index ranges
    # (never plain `Int` sizes), so only the range-based form is needed here.
    number_of_trivial_output_axes =
        max(0, length(sliced_dimensions) - ndims(first_output))
    result_axes = copy(input_axes)
    result_axes[sliced_dimensions] =
        [indices(first_output)...,
         ntuple(dimension -> Base.OneTo(1), number_of_trivial_output_axes)...]
    result = similar(first_output, tuple(result_axes...,))

    # Index of the current output slice, built the same way as `input_index`.
    result_index = Any[map(first, indices(result))...]
    for dimension in sliced_dimensions
        result_index[dimension] = indices(result, dimension)
    end
    result[result_index...] = first_output

    # Move both index vectors to the position `index_i`, a tuple holding one entry per
    # indexed (non-sliced) dimension.
    function update_indexes!(input_index, result_index, index_i)
        for (position, dimension) in enumerate(indexed_dimensions)
            input_index[dimension] = result_index[dimension] = index_i[position]
        end
    end

    # skip the first element, we already handled it
    indexes =
        Iterators.drop(CartesianRange(tuple(input_axes[indexed_dimensions]...)), 1)

    if safe_for_reuse
        # when f returns an array, the `result[result_index...] = f(input_slice)` line
        # copies elements, so we can reuse input_slice as a scratch buffer
        for index in indexes
            update_indexes!(input_index, result_index, index.I)
            Base._unsafe_getindex!(input_slice, input, input_index...)
            result[result_index...] = f(input_slice)
        end
    else
        # we can't guarantee safety (#18524), so allocate new storage for each slice
        for index in indexes
            update_indexes!(input_index, result_index, index.I)
            result[result_index...] = f(input[input_index...])
        end
    end

    return result
end