Bikeshedding mapslices

I’m trying to improve the performance of mapslices, which is needed for https://github.com/JuliaLang/julia/issues/3893#issuecomment-304924903. The only changes I’ve made so far are renaming things so that I understand (a little) what’s going on, plus some small rearrangements. I haven’t changed performance as far as I can tell. @tim.holy mentioned he might have some leads (and something about a blog post?), which is good because I’m lost.

# Apply `f` to each slice of `input` spanning `sliced_dimensions` and gather the
# outputs into one array.
#
# Arguments:
# - `f`: function called on every slice.
# - `input`: the array being sliced.
# - `sliced_dimensions`: dimensions kept whole inside each slice; when it is empty
#   the call falls back to `map(f, input)`.
#
# Returns the newly allocated `result` array.
function my_mapslices(f, input::AbstractArray, sliced_dimensions::AbstractVector)
    if isempty(sliced_dimensions)
        return map(f,input)
    end

    # Per-dimension index ranges of `input`, collected into a vector.
    axes = [indices(input)...]

    # Dimensions that are not sliced; these are stepped through one index at a time.
    indexed_dimensions = setdiff([1:ndims(input);], sliced_dimensions)

    # Index of the first slice: start from the first index along every dimension ...
    input_index = Any[first(index) for index in indices(input)]

    # ... then let each sliced dimension cover its whole range.
    for dimension in sliced_dimensions
        input_index[dimension] = Base.Slice(indices(input, dimension))
    end

    # Apply the function to the first slice in order to determine the next steps
    input_slice = input[input_index...]
    first_output = f(input_slice)
    # In some cases, we can re-use the first slice for a dramatic performance
    # increase. The slice itself must be mutable and the result cannot contain
    # any mutable containers. The following errs on the side of being overly
    # strict (#18570 & #21123).
    safe_for_reuse = 
        isa(input_slice, StridedArray) && 
        (isa(first_output, Number) || 
            (isa(first_output, AbstractArray) && eltype(first_output) <: Number))

    # determine result size and allocate
    result_axes = copy(axes)
    # TODO: maybe support removing dimensions
    # Non-array and zero-dimensional outputs are wrapped in a one-element vector.
    if !isa(first_output, AbstractArray) || ndims(first_output) == 0
        first_output = [first_output]
    end
    # How many sliced dimensions the output of `f` does not cover; those become
    # singleton axes in the result.
    number_of_trivial_output_axes = 
        max(0, length(sliced_dimensions) - ndims(first_output)) 

    # NOTE(review): `axes` is built from `indices(input)`, so the `Int` branch looks
    # unreachable here -- confirm before relying on it.
    if eltype(result_axes) == Int
        result_axes[sliced_dimensions] = 
            [size(first_output)..., 
                ntuple(dimension->1, number_of_trivial_output_axes)...]
    else
        result_axes[sliced_dimensions] = 
            [indices(first_output)...,
                ntuple(dimension -> Base.OneTo(1), number_of_trivial_output_axes)...]
    end
    result = similar(first_output, tuple(result_axes...,))

    # Destination index of the first output, built the same way as `input_index`.
    result_index = Any[map(first, indices(result))...]
    for dimension in sliced_dimensions
        result_index[dimension] = indices(result, dimension)
    end

    result[result_index...] = first_output

    # Positions into `indexed_dimensions` (and into each iteration tuple).
    over_indexed_dimensions = 1:length(indexed_dimensions)

    # Write the current iteration tuple `index_i` into the non-sliced positions of
    # both index vectors.
    update_indexes!(input_index, result_index, index_i) =
        for indexed_dimension in over_indexed_dimensions
            dimension = indexed_dimensions[indexed_dimension]
            input_index[dimension] = 
                result_index[dimension] = 
                index_i[indexed_dimension]
        end

    # skip the first element, we already handled it
    indexes = Iterators.drop(CartesianRange(tuple(axes[indexed_dimensions]...) ), 1)

    if safe_for_reuse
        # when f returns an input, result[result_index...] = f(input_slice) line copies elements,
        # so we can reuse input_slice
        for index in indexes
            update_indexes!(input_index, result_index, index.I)
            # Overwrite the existing slice buffer instead of allocating a new one.
            Base._unsafe_getindex!(input_slice, input, input_index...)
            result[result_index...] = f(input_slice)
        end
    else
        # we can't guarantee safety (#18524), so allocate new storage for each slice
        for index in indexes
            update_indexes!(input_index, result_index, index.I)
            result[result_index...] = f(input[input_index...])
        end
    end

    result
end

The easiest approach will be to split it into two functions: the “outer” function does all the type-unstable stuff, then calls a type-stable “inner” function marked with @noinline (this is the “function barrier technique”). Basically, the stuff above the definition of your indexes variable is all the type-unstable stuff; the main loop is presumably the only part that’s urgently in need of inferrability for performance.

1 Like

So something like this?

# Copy the current iteration tuple `index_i` into the non-sliced positions of both
# index vectors. `indexed_dimensions[k]` names the dimension that receives
# `index_i[k]` for every `k` in `over_indexed_dimensions`. Returns `nothing`.
@noinline function update_indexes!(over_indexed_dimensions, indexed_dimensions, input_index, result_index, index_i)
    for position in over_indexed_dimensions
        target = indexed_dimensions[position]
        value = index_i[position]
        result_index[target] = value
        input_index[target] = value
    end
end


# Function-barrier kernel: fill `result` for every slice after the first one.
# `input_index` and `result_index` are mutated in place on each step through
# `update_indexes!`. Returns `result`.
@noinline function inner_mapslices!(axes, indexed_dimensions, safe_for_reuse, input_index, result_index, input_slice, input, result, f)
    positions = 1:length(indexed_dimensions)
    # skip the first element, we already handled it
    remaining = Iterators.drop(CartesianRange(tuple(axes[indexed_dimensions]...)), 1)

    if !safe_for_reuse
        # we can't guarantee safety (#18524), so allocate new storage for each slice
        for cartesian in remaining
            update_indexes!(positions, indexed_dimensions, input_index, result_index, cartesian.I)
            result[result_index...] = f(input[input_index...])
        end
        return result
    end

    # when f returns an input, the assignment into `result` copies elements,
    # so `input_slice` can be overwritten and reused
    for cartesian in remaining
        update_indexes!(positions, indexed_dimensions, input_index, result_index, cartesian.I)
        Base._unsafe_getindex!(input_slice, input, input_index...)
        result[result_index...] = f(input_slice)
    end
    return result
end

# Apply `f` to each slice of `input` spanning `sliced_dimensions` and gather the
# outputs into one array. This outer function does the setup (index vectors,
# first slice, result allocation) and hands the remaining slices to
# `inner_mapslices!`.
#
# Arguments:
# - `f`: function called on every slice.
# - `input`: the array being sliced.
# - `sliced_dimensions`: dimensions kept whole inside each slice; when it is empty
#   the call falls back to `map(f, input)`.
#
# Returns whatever `inner_mapslices!` returns.
function my_mapslices(f, input::AbstractArray, sliced_dimensions::AbstractVector)
    if isempty(sliced_dimensions)
        return map(f,input)
    end

    # Per-dimension index ranges of `input`, collected into a vector.
    axes = [indices(input)...]

    # Dimensions that are not sliced; these are stepped through one index at a time.
    indexed_dimensions = setdiff([1:ndims(input);], sliced_dimensions)

    # Index of the first slice: start from the first index along every dimension ...
    input_index = Any[first(index) for index in indices(input)]

    # ... then let each sliced dimension cover its whole range.
    for dimension in sliced_dimensions
        input_index[dimension] = Base.Slice(indices(input, dimension))
    end

    # Apply the function to the first slice in order to determine the next steps
    input_slice = input[input_index...]
    first_output = f(input_slice)
    # In some cases, we can re-use the first slice for a dramatic performance
    # increase. The slice itself must be mutable and the result cannot contain
    # any mutable containers. The following errs on the side of being overly
    # strict (#18570 & #21123).
    safe_for_reuse = 
        isa(input_slice, StridedArray) && 
        (isa(first_output, Number) || 
            (isa(first_output, AbstractArray) && eltype(first_output) <: Number))

    # determine result size and allocate
    result_axes = copy(axes)
    # TODO: maybe support removing dimensions
    # Non-array and zero-dimensional outputs are wrapped in a one-element vector.
    if !isa(first_output, AbstractArray) || ndims(first_output) == 0
        first_output = [first_output]
    end
    # How many sliced dimensions the output of `f` does not cover; those become
    # singleton axes in the result.
    number_of_trivial_output_axes = 
        max(0, length(sliced_dimensions) - ndims(first_output)) 

    # NOTE(review): `axes` is built from `indices(input)`, so the `Int` branch looks
    # unreachable here -- confirm before relying on it.
    if eltype(result_axes) == Int
        result_axes[sliced_dimensions] = 
            [size(first_output)..., 
                ntuple(dimension->1, number_of_trivial_output_axes)...]
    else
        result_axes[sliced_dimensions] = 
            [indices(first_output)...,
                ntuple(dimension -> Base.OneTo(1), number_of_trivial_output_axes)...]
    end
    result = similar(first_output, tuple(result_axes...,))

    # Destination index of the first output, built the same way as `input_index`.
    result_index = Any[map(first, indices(result))...]
    for dimension in sliced_dimensions
        result_index[dimension] = indices(result, dimension)
    end

    result[result_index...] = first_output

    # Everything after the first slice runs behind the function barrier.
    inner_mapslices!(axes, indexed_dimensions, safe_for_reuse, input_index, result_index, input_slice, input, result, f)
end

f = sum
input = rand(100, 100, 100, 100)
sliced_dimensions = [2, 4]

using BenchmarkTools
# `f`, `input` and `sliced_dimensions` are non-constant globals. Interpolating
# them with `$` makes the benchmark time the call itself rather than global
# lookup and dynamic dispatch.
result1 = @benchmark mapslices($f, $input, $sliced_dimensions)
result2 = @benchmark my_mapslices($f, $input, $sliced_dimensions)

Not seeing any time improvements…

Wait, nevermind. This seems to have cut down time by about 10%:

# Store the current iteration tuple `index_i` in the non-sliced positions of
# `input_index` and `result_index`. For each `k` in `over_indexed_dimensions`,
# dimension `indexed_dimensions[k]` receives `index_i[k]`. Returns `nothing`.
@noinline function update_indexes!(over_indexed_dimensions, indexed_dimensions,
    input_index, result_index, index_i)

    for slot in over_indexed_dimensions
        which_dimension = indexed_dimensions[slot]
        new_value = index_i[slot]
        result_index[which_dimension] = new_value
        input_index[which_dimension] = new_value
    end
end


# Function-barrier kernel: fill `result` for every element of `indexes`.
# The caller is expected to have dropped the first element from `indexes`
# already. `input_index` and `result_index` are mutated in place on each step.
# `axes` is accepted but not used in the body. Returns `result`.
@noinline function inner_mapslices!(axes, indexes, over_indexed_dimensions,
    indexed_dimensions, safe_for_reuse, input_index, result_index, input_slice,
    input, result, f)

    if !safe_for_reuse
        # we can't guarantee safety (#18524), so allocate new storage for each slice
        for cartesian in indexes
            update_indexes!(over_indexed_dimensions, indexed_dimensions,
                input_index, result_index, cartesian.I)
            result[result_index...] = f(input[input_index...])
        end
        return result
    end

    # when f returns an input, the assignment into `result` copies elements,
    # so `input_slice` can be overwritten and reused
    for cartesian in indexes
        update_indexes!(over_indexed_dimensions, indexed_dimensions,
            input_index, result_index, cartesian.I)
        Base._unsafe_getindex!(input_slice, input, input_index...)
        result[result_index...] = f(input_slice)
    end
    return result
end

# Apply `f` to each slice of `input` spanning `sliced_dimensions` and gather the
# outputs into one array. This outer function does the setup (index vectors,
# first slice, result allocation, iteration range) and hands the remaining slices
# to `inner_mapslices!`.
#
# Arguments:
# - `f`: function called on every slice.
# - `input`: the array being sliced.
# - `sliced_dimensions`: dimensions kept whole inside each slice; when it is empty
#   the call falls back to `map(f, input)`.
#
# Returns whatever `inner_mapslices!` returns.
function my_mapslices(f, input::AbstractArray, sliced_dimensions::AbstractVector)
    if isempty(sliced_dimensions)
        return map(f,input)
    end

    # Per-dimension index ranges of `input`, collected into a vector.
    axes = [indices(input)...]

    # Dimensions that are not sliced; these are stepped through one index at a time.
    indexed_dimensions = setdiff([1:ndims(input);], sliced_dimensions)

    # Index of the first slice: start from the first index along every dimension ...
    input_index = Any[first(index) for index in indices(input)]

    # ... then let each sliced dimension cover its whole range.
    for dimension in sliced_dimensions
        input_index[dimension] = Base.Slice(indices(input, dimension))
    end

    # Apply the function to the first slice in order to determine the next steps
    input_slice = input[input_index...]
    first_output = f(input_slice)
    # In some cases, we can re-use the first slice for a dramatic performance
    # increase. The slice itself must be mutable and the result cannot contain
    # any mutable containers. The following errs on the side of being overly
    # strict (#18570 & #21123).
    safe_for_reuse = 
        isa(input_slice, StridedArray) && 
        (isa(first_output, Number) || 
            (isa(first_output, AbstractArray) && eltype(first_output) <: Number))

    # determine result size and allocate
    result_axes = copy(axes)
    # TODO: maybe support removing dimensions
    # Non-array and zero-dimensional outputs are wrapped in a one-element vector.
    if !isa(first_output, AbstractArray) || ndims(first_output) == 0
        first_output = [first_output]
    end
    # How many sliced dimensions the output of `f` does not cover; those become
    # singleton axes in the result.
    number_of_trivial_output_axes = 
        max(0, length(sliced_dimensions) - ndims(first_output)) 

    # NOTE(review): `axes` is built from `indices(input)`, so the `Int` branch looks
    # unreachable here -- confirm before relying on it.
    if eltype(result_axes) == Int
        result_axes[sliced_dimensions] = 
            [size(first_output)..., 
                ntuple(dimension->1, number_of_trivial_output_axes)...]
    else
        result_axes[sliced_dimensions] = 
            [indices(first_output)...,
                ntuple(dimension -> Base.OneTo(1), number_of_trivial_output_axes)...]
    end
    result = similar(first_output, tuple(result_axes...,))

    # Destination index of the first output, built the same way as `input_index`.
    result_index = Any[map(first, indices(result))...]
    for dimension in sliced_dimensions
        result_index[dimension] = indices(result, dimension)
    end

    result[result_index...] = first_output

    # Iterate over the non-sliced dimensions, dropping the first element because
    # the first slice has been handled above.
    indexes = Iterators.drop(CartesianRange(tuple(axes[indexed_dimensions]...) ), 1)
    # Positions into `indexed_dimensions` (and into each iteration tuple).
    over_indexed_dimensions = 1:length(indexed_dimensions)

    inner_mapslices!(axes, indexes, over_indexed_dimensions, 
        indexed_dimensions, safe_for_reuse, input_index, 
        result_index, input_slice, input, result, f)
end

This also works but no performance improvement:

# Build the index tuples for one iteration step.
#
# Walks dimensions `1:rank`. A dimension listed in `indexed_dimensions` takes the
# next entry of `iterated_index` in both outputs; any other dimension keeps the
# corresponding entry of `input_index` / `result_index`.
#
# Returns the pair `(new_input_index, new_result_index)`.
@noinline function replace_tuples(
    rank, indexed_dimensions,
    input_index, result_index, iterated_index)

    # `i` is the dimension being filled; `j` is the next unread position of
    # `iterated_index`.
    function replace_tuples_recur(new_input_index = (), new_result_index = (), i = 1, j = 1)
        if i > rank
            return new_input_index, new_result_index
        elseif i in indexed_dimensions
            return replace_tuples_recur(
                (new_input_index..., iterated_index[j]),
                (new_result_index..., iterated_index[j]),
                i + 1, j + 1)
        else
            return replace_tuples_recur(
                (new_input_index..., input_index[i]),
                (new_result_index..., result_index[i]),
                i + 1, j)
        end
    end
    return replace_tuples_recur()
end

# Function-barrier kernel: fill `result` for every element of `indexes`.
# `input_index` and `result_index` are not mutated; `replace_tuples` builds new
# index tuples for each step. Returns `result`.
@noinline function inner_mapslices!(f, input, result,
    rank, indexed_dimensions,
    indexes, input_index, result_index,
    safe_for_reuse, input_slice
)
    if !safe_for_reuse
        # we can't guarantee safety (#18524), so allocate new storage for each slice
        for cartesian in indexes
            source_index, target_index = replace_tuples(
                rank, indexed_dimensions,
                input_index, result_index, cartesian.I)
            result[target_index...] = f(input[source_index...])
        end
        return result
    end

    # when f returns an input, the assignment into `result` copies elements,
    # so `input_slice` can be overwritten and reused
    for cartesian in indexes
        source_index, target_index = replace_tuples(
            rank, indexed_dimensions,
            input_index, result_index, cartesian.I)
        Base._unsafe_getindex!(input_slice, input, source_index...)
        result[target_index...] = f(input_slice)
    end
    return result
end

# Apply `f` to each slice of `input` spanning `sliced_dimensions` and gather the
# outputs into one array. This outer function does the setup (index vectors,
# first slice, result allocation, iteration range) and hands the remaining slices
# to `inner_mapslices!`, passing the index vectors as tuples.
#
# Arguments:
# - `f`: function called on every slice.
# - `input`: the array being sliced.
# - `sliced_dimensions`: dimensions kept whole inside each slice; when it is empty
#   the call falls back to `map(f, input)`.
#
# Returns whatever `inner_mapslices!` returns.
function my_mapslices(f, input::AbstractArray, sliced_dimensions::AbstractVector)
    if isempty(sliced_dimensions)
        return map(f,input)
    end

    # Per-dimension index ranges of `input`, collected into a vector.
    axes = [indices(input)...]
    rank = ndims(input)
    # Dimensions that are not sliced; these are stepped through one index at a time.
    indexed_dimensions = setdiff([1:rank;], sliced_dimensions)

    # Index of the first slice: first index along every dimension, except that each
    # sliced dimension covers its whole range.
    input_index = Any[first(index) for index in indices(input)]
    for dimension in sliced_dimensions
        input_index[dimension] = Base.Slice(indices(input, dimension))
    end

    # Apply the function to the first slice in order to determine the next steps
    input_slice = input[input_index...]
    first_output = f(input_slice)
    # In some cases, we can re-use the first slice for a dramatic performance
    # increase. The slice itself must be mutable and the result cannot contain
    # any mutable containers. The following errs on the side of being overly
    # strict (#18570 & #21123).
    safe_for_reuse = 
        isa(input_slice, StridedArray) && 
        (isa(first_output, Number) || 
            (isa(first_output, AbstractArray) && eltype(first_output) <: Number))

    # determine result size and allocate
    result_axes = copy(axes)
    # TODO: maybe support removing dimensions
    # Non-array and zero-dimensional outputs are wrapped in a one-element vector.
    if !isa(first_output, AbstractArray) || ndims(first_output) == 0
        first_output = [first_output]
    end
    # How many sliced dimensions the output of `f` does not cover; those become
    # singleton axes in the result.
    number_of_trivial_output_axes = 
        max(0, length(sliced_dimensions) - ndims(first_output)) 

    # NOTE(review): `axes` is built from `indices(input)`, so the `Int` branch looks
    # unreachable here -- confirm before relying on it.
    if eltype(result_axes) == Int
        result_axes[sliced_dimensions] = 
            [size(first_output)..., 
                ntuple(dimension->1, number_of_trivial_output_axes)...]
    else
        result_axes[sliced_dimensions] = 
            [indices(first_output)...,
                ntuple(dimension -> Base.OneTo(1), number_of_trivial_output_axes)...]
    end
    result = similar(first_output, tuple(result_axes...,))

    # Destination index of the first output, built the same way as `input_index`.
    result_index = Any[map(first, indices(result))...]
    for dimension in sliced_dimensions
        result_index[dimension] = indices(result, dimension)
    end

    result[result_index...] = first_output

    # skip the first element, we already handled it
    indexes = Iterators.drop(CartesianRange(tuple(axes[indexed_dimensions]...) ), 1)

    # `(v...,)` with the trailing comma turns the vector into a tuple; the bare
    # `(v...)` spelling is deprecated.
    inner_mapslices!(f, input, result,
        rank, indexed_dimensions,
        indexes, (input_index...,), (result_index...,),
        safe_for_reuse, input_slice, 
    )
end

f = sum
input = rand(100, 100, 100, 100)
sliced_dimensions = [2, 4]

using BenchmarkTools 
# Check both implementations agree before timing them.
Test.@test mapslices(f, input, sliced_dimensions) == my_mapslices(f, input, sliced_dimensions)
# `f`, `input` and `sliced_dimensions` are non-constant globals. Interpolating
# them with `$` makes the benchmark time the call itself rather than global
# lookup and dynamic dispatch.
result1 = @benchmark mapslices($f, $input, $sliced_dimensions)
result2 = @benchmark my_mapslices($f, $input, $sliced_dimensions)

Note how I wrapped replace_tuples_recur inside replace_tuples. This seemed to prevent the ridiculous amount of codegen that made it impossible to compile replace_tuples_recur.

Presumably the improvement depends on how you’re testing it; if I use

# 1000x1000 random test matrix used for the timings quoted below.
A = rand(1000, 1000)

then I see about a 2x improvement over mapslices(mean, A, 1); moreover, it’s within a factor of 2 of sum(A, 1) which isn’t too bad.

However, for a smaller matrix (e.g., 5x5) you have ~50x overhead compared to sum(A, 1). To do better, you’re going to have to switch most of the array-based index logic to tuples, because only with tuples can the compiler infer the dimensionality of your indexing operations. On the “real” mapslices function, sliced_dimensions should be a Dims tuple (NOT an AbstractVector). You will likely want to mimic some of the logic in broadcast.jl, for example in these lines. The main trick here is “lispy tuple iteration,” where you peel off the first element and process it, then recursively call the same function on the remaining elements until you’ve exhausted them all (Tuple{} is the type of an empty tuple). With generous use of @inline, the compiler converts the recursive sequence into something you might write by hand, e.g., map(f, x) where x is a 3-tuple would get converted into (f(x[1]), f(x[2]), f(x[3])) which is pretty hard to beat in terms of efficiency.

3 Likes

Great! Ok, here’s another version with a replace function modeled after those broadcast lines:

# Lispy tuple recursion: walk `old_index` and `should_replace` in lockstep. Where
# the flag is `true`, emit the next unused entry of `replace_index`; otherwise
# keep the entry of `old_index`. Recursion stops when `old_index` is exhausted.
@inline replace_tuple(old_index::Tuple{}, replace_index, should_replace) = ()
@inline function replace_tuple(old_index, replace_index, should_replace)
    take_replacement = should_replace[1]
    if !take_replacement
        # keep the old entry; the replacement tuple is not consumed
        return (old_index[1],
            replace_tuple(Base.tail(old_index), replace_index, Base.tail(should_replace))...)
    end
    # consume one entry of the replacement tuple
    return (replace_index[1],
        replace_tuple(Base.tail(old_index), Base.tail(replace_index), Base.tail(should_replace))...)
end

# Function-barrier kernel: fill `result` for every element of `indexes`.
# For each step the iteration tuple is merged into `input_index` and
# `result_index` by `replace_tuple`, guided by `dimension_is_indexed`.
# Returns `result`.
@noinline function inner_mapslices!(f, input, result,
    dimension_is_indexed, indexes, input_index, result_index,
    safe_for_reuse, input_slice
)
    if !safe_for_reuse
        # we can't guarantee safety (#18524), so allocate new storage for each slice
        for cartesian in indexes
            result[replace_tuple(result_index, cartesian.I, dimension_is_indexed)...] =
                f(input[replace_tuple(input_index, cartesian.I, dimension_is_indexed)...])
        end
        return result
    end

    # when f returns an input, the assignment into `result` copies elements,
    # so `input_slice` can be overwritten and reused
    for cartesian in indexes
        Base._unsafe_getindex!(input_slice, input, replace_tuple(input_index, cartesian.I, dimension_is_indexed)...)
        result[replace_tuple(result_index, cartesian.I, dimension_is_indexed)...] = f(input_slice)
    end
    return result
end

# Apply `f` to each slice of `input` spanning `sliced_dimensions` and gather the
# outputs into one array. This outer function does the setup (index vectors,
# first slice, result allocation, iteration range, per-dimension flags) and hands
# the remaining slices to `inner_mapslices!`, passing everything as tuples.
#
# Arguments:
# - `f`: function called on every slice.
# - `input`: the array being sliced.
# - `sliced_dimensions`: dimensions kept whole inside each slice; when it is empty
#   the call falls back to `map(f, input)`.
#
# Returns whatever `inner_mapslices!` returns.
function my_mapslices(f, input::AbstractArray, sliced_dimensions::AbstractVector)
    if isempty(sliced_dimensions)
        return map(f,input)
    end

    # Per-dimension index ranges of `input`, collected into a vector.
    axes = [indices(input)...]
    rank = ndims(input)

    # Index of the first slice: first index along every dimension, except that each
    # sliced dimension covers its whole range.
    input_index = Any[first(index) for index in indices(input)]
    for dimension in sliced_dimensions
        input_index[dimension] = Base.Slice(indices(input, dimension))
    end

    # Apply the function to the first slice in order to determine the next steps
    input_slice = input[input_index...]
    first_output = f(input_slice)
    # In some cases, we can re-use the first slice for a dramatic performance
    # increase. The slice itself must be mutable and the result cannot contain
    # any mutable containers. The following errs on the side of being overly
    # strict (#18570 & #21123).
    safe_for_reuse = 
        isa(input_slice, StridedArray) && 
        (isa(first_output, Number) || 
            (isa(first_output, AbstractArray) && eltype(first_output) <: Number))

    # determine result size and allocate
    result_axes = copy(axes)
    # TODO: maybe support removing dimensions
    # Non-array and zero-dimensional outputs are wrapped in a one-element vector.
    if !isa(first_output, AbstractArray) || ndims(first_output) == 0
        first_output = [first_output]
    end
    # How many sliced dimensions the output of `f` does not cover; those become
    # singleton axes in the result.
    number_of_trivial_output_axes = 
        max(0, length(sliced_dimensions) - ndims(first_output)) 

    # NOTE(review): `axes` is built from `indices(input)`, so the `Int` branch looks
    # unreachable here -- confirm before relying on it.
    if eltype(result_axes) == Int
        result_axes[sliced_dimensions] = 
            [size(first_output)..., 
                ntuple(dimension->1, number_of_trivial_output_axes)...]
    else
        result_axes[sliced_dimensions] = 
            [indices(first_output)...,
                ntuple(dimension -> Base.OneTo(1), number_of_trivial_output_axes)...]
    end
    result = similar(first_output, tuple(result_axes...,))

    # Destination index of the first output, built the same way as `input_index`.
    result_index = Any[map(first, indices(result))...]
    for dimension in sliced_dimensions
        result_index[dimension] = indices(result, dimension)
    end

    result[result_index...] = first_output

    # skip the first element, we already handled it
    indexes = Iterators.drop(CartesianRange(tuple(axes[setdiff([1:rank;], sliced_dimensions)]...) ), 1)

    # One flag per dimension: `true` where the dimension is iterated, `false`
    # where it is sliced.
    dimension_is_indexed = map(1:rank) do dimension
        !(dimension in sliced_dimensions)
    end

    # `(v...,)` with the trailing comma turns the vector into a tuple; the bare
    # `(v...)` spelling is deprecated.
    inner_mapslices!(f, input, result,
        (dimension_is_indexed...,), indexes, (input_index...,), (result_index...,),
        safe_for_reuse, input_slice, 
    )
end

f = sum
input = rand(5, 5)
sliced_dimensions = 1

using BenchmarkTools 
# `my_mapslices` only accepts an `AbstractVector` of dimensions, so the scalar is
# wrapped here exactly as in the benchmark below.
Test.@test mapslices(f, input, sliced_dimensions) == my_mapslices(f, input, [sliced_dimensions])
# `f`, `input` and `sliced_dimensions` are non-constant globals. Interpolating
# them with `$` makes the benchmark time the call itself rather than global
# lookup and dynamic dispatch.
result1 = @benchmark mapslices($f, $input, $sliced_dimensions)
result2 = @benchmark my_mapslices($f, $input, [$sliced_dimensions])
result3 = @benchmark sum($input, 1)

Doesn’t seem to have done the trick?

And another attempt:

# Lispy tuple recursion: step through `old_index` and `should_replace` together.
# A `true` flag emits the next unused entry of `replace_index`; a `false` flag
# keeps the entry of `old_index`. An empty `old_index` ends the recursion.
@inline replace_tuple(old_index::Tuple{}, replace_index, should_replace) = ()
@inline function replace_tuple(old_index, replace_index, should_replace)
    use_new = should_replace[1]
    if use_new
        # one entry of the replacement tuple is consumed
        return (replace_index[1],
            replace_tuple(Base.tail(old_index), Base.tail(replace_index), Base.tail(should_replace))...)
    else
        # the replacement tuple is passed on untouched
        return (old_index[1],
            replace_tuple(Base.tail(old_index), replace_index, Base.tail(should_replace))...)
    end
end

# Inlinable kernel: fill `result` for every element of `indexes`.
# For each step the iteration tuple is merged into `input_index` and
# `result_index` by `replace_tuple`, guided by `dimension_is_indexed`.
# Returns `result`.
@inline function inner_mapslices!(f, input, result,
    dimension_is_indexed, indexes, input_index, result_index,
    safe_for_reuse, input_slice
)
    if !safe_for_reuse
        # we can't guarantee safety (#18524), so allocate new storage for each slice
        for cartesian in indexes
            result[replace_tuple(result_index, cartesian.I, dimension_is_indexed)...] =
                f(input[replace_tuple(input_index, cartesian.I, dimension_is_indexed)...])
        end
        return result
    end

    # when f returns an input, the assignment into `result` copies elements,
    # so `input_slice` can be overwritten and reused
    for cartesian in indexes
        Base._unsafe_getindex!(input_slice, input, replace_tuple(input_index, cartesian.I, dimension_is_indexed)...)
        result[replace_tuple(result_index, cartesian.I, dimension_is_indexed)...] = f(input_slice)
    end
    return result
end

# Non-inlined entry point that forwards every argument unchanged to
# `inner_mapslices!` and returns its result.
@noinline function inner_mapslices_noinline!(f, input, result,
    dimension_is_indexed, indexes, input_index, result_index,
    safe_for_reuse, input_slice
)
    return inner_mapslices!(f, input, result,
        dimension_is_indexed, indexes, input_index, result_index,
        safe_for_reuse, input_slice
    )
end

# Apply `f` to each slice of `input` spanning `sliced_dimensions` and gather the
# outputs into one array. This outer function does the setup (index vectors,
# first slice, result allocation, iteration range, per-dimension flags) and hands
# the remaining slices to `inner_mapslices_noinline!`, passing everything as
# tuples.
#
# Arguments:
# - `f`: function called on every slice.
# - `input`: the array being sliced.
# - `sliced_dimensions`: dimensions kept whole inside each slice; when it is empty
#   the call falls back to `map(f, input)`.
#
# Returns whatever `inner_mapslices_noinline!` returns.
function my_mapslices(f, input::AbstractArray, sliced_dimensions::AbstractVector)
    if isempty(sliced_dimensions)
        return map(f,input)
    end

    # Per-dimension index ranges of `input`, collected into a vector.
    axes = [indices(input)...]
    rank = ndims(input)

    # Index of the first slice: first index along every dimension, except that each
    # sliced dimension covers its whole range.
    input_index = Any[first(index) for index in indices(input)]
    for dimension in sliced_dimensions
        input_index[dimension] = Base.Slice(indices(input, dimension))
    end

    # Apply the function to the first slice in order to determine the next steps
    input_slice = input[input_index...]
    first_output = f(input_slice)
    # In some cases, we can re-use the first slice for a dramatic performance
    # increase. The slice itself must be mutable and the result cannot contain
    # any mutable containers. The following errs on the side of being overly
    # strict (#18570 & #21123).
    safe_for_reuse = 
        isa(input_slice, StridedArray) && 
        (isa(first_output, Number) || 
            (isa(first_output, AbstractArray) && eltype(first_output) <: Number))

    # determine result size and allocate
    result_axes = copy(axes)
    # TODO: maybe support removing dimensions
    # Non-array and zero-dimensional outputs are wrapped in a one-element vector.
    if !isa(first_output, AbstractArray) || ndims(first_output) == 0
        first_output = [first_output]
    end
    # How many sliced dimensions the output of `f` does not cover; those become
    # singleton axes in the result.
    number_of_trivial_output_axes = 
        max(0, length(sliced_dimensions) - ndims(first_output)) 

    # NOTE(review): `axes` is built from `indices(input)`, so the `Int` branch looks
    # unreachable here -- confirm before relying on it.
    if eltype(result_axes) == Int
        result_axes[sliced_dimensions] = 
            [size(first_output)..., 
                ntuple(dimension->1, number_of_trivial_output_axes)...]
    else
        result_axes[sliced_dimensions] = 
            [indices(first_output)...,
                ntuple(dimension -> Base.OneTo(1), number_of_trivial_output_axes)...]
    end
    result = similar(first_output, tuple(result_axes...,))

    # Destination index of the first output, built the same way as `input_index`.
    result_index = Any[map(first, indices(result))...]
    for dimension in sliced_dimensions
        result_index[dimension] = indices(result, dimension)
    end

    result[result_index...] = first_output

    # skip the first element, we already handled it
    indexes = Iterators.drop(CartesianRange(tuple(axes[setdiff([1:rank;], sliced_dimensions)]...) ), 1)

    # One flag per dimension: `true` where the dimension is iterated, `false`
    # where it is sliced.
    dimension_is_indexed = map(1:rank) do dimension
        !(dimension in sliced_dimensions)
    end

    # `(v...,)` with the trailing comma turns the vector into a tuple; the bare
    # `(v...)` spelling is deprecated.
    inner_mapslices_noinline!(f, input, result,
        (dimension_is_indexed...,), indexes, (input_index...,), (result_index...,),
        safe_for_reuse, input_slice, 
    )
end

Any other suggestions?

Here’s another attempt with value types:

# Tuple replacement driven by `Val` flags, so the choice is made by dispatch.
#
# `replace_tuple(old_index, replace_index, should_replace)` steps through
# `old_index` and the `Val{true}()` / `Val{false}()` flags together. A
# `Val{true}` position takes the next unused entry of `replace_index`; a
# `Val{false}` position keeps the old entry. Once `replace_index` is used up,
# whatever is left of `old_index` is returned as is.
@inline function dispatch_on_value(::Val{true}, first_old_index, tail_old_index, replace_index, should_replace)
    replacement = replace_index[1]
    return (replacement,
        replace_tuple(tail_old_index, Base.tail(replace_index), should_replace)...)
end
@inline function dispatch_on_value(::Val{false}, first_old_index, tail_old_index, replace_index, should_replace)
    return (first_old_index,
        replace_tuple(tail_old_index, replace_index, should_replace)...)
end

@inline replace_tuple(old_index, replace_index::Tuple{}, should_replace) = old_index
@inline function replace_tuple(old_index, replace_index, should_replace)
    flag = should_replace[1]
    return dispatch_on_value(flag, old_index[1], Base.tail(old_index),
        replace_index, Base.tail(should_replace))
end

# Inlinable kernel: fill `result` for every element of `indexes`.
# For each step the iteration tuple is merged into `input_index` and
# `result_index` by `replace_tuple`, guided by `dimension_is_indexed`.
# Returns `result`.
@inline function inner_mapslices!(f, input, result,
    dimension_is_indexed, indexes, input_index, result_index,
    safe_for_reuse, input_slice
)
    if !safe_for_reuse
        # we can't guarantee safety (#18524), so allocate new storage for each slice
        for cartesian in indexes
            result[replace_tuple(result_index, cartesian.I, dimension_is_indexed)...] =
                f(input[replace_tuple(input_index, cartesian.I, dimension_is_indexed)...])
        end
        return result
    end

    # when f returns an input, the assignment into `result` copies elements,
    # so `input_slice` can be overwritten and reused
    for cartesian in indexes
        Base._unsafe_getindex!(input_slice, input, replace_tuple(input_index, cartesian.I, dimension_is_indexed)...)
        result[replace_tuple(result_index, cartesian.I, dimension_is_indexed)...] = f(input_slice)
    end
    return result
end

# Inspect type inference of the kernel.
# NOTE(review): `dimension_is_indexed`, `indexes`, `input_index`, `result_index`,
# `safe_for_reuse` and `input_slice` are locals of `my_mapslices` in the
# surrounding code, so they have to be defined at the prompt before this runs.
# `(v...,)` with the trailing comma builds the tuple; the bare `(v...)` spelling
# is deprecated.
@code_warntype inner_mapslices!(f, input, result,
    (dimension_is_indexed...,), indexes, (input_index...,), (result_index...,),
    safe_for_reuse, input_slice
)

# Non-inlined entry point that forwards every argument unchanged to
# `inner_mapslices!` and returns its result.
@noinline function inner_mapslices_noinline!(f, input, result,
    dimension_is_indexed, indexes, input_index, result_index,
    safe_for_reuse, input_slice
)
    return inner_mapslices!(f, input, result,
        dimension_is_indexed, indexes, input_index, result_index,
        safe_for_reuse, input_slice
    )
end

# Apply `f` to each slice of `input` spanning `sliced_dimensions` and gather the
# outputs into one array. This outer function does the setup (index vectors,
# first slice, result allocation, iteration range, per-dimension `Val` flags) and
# hands the remaining slices to `inner_mapslices_noinline!`, passing everything
# as tuples.
#
# Arguments:
# - `f`: function called on every slice.
# - `input`: the array being sliced.
# - `sliced_dimensions`: dimensions kept whole inside each slice; when it is empty
#   the call falls back to `map(f, input)`.
#
# Returns whatever `inner_mapslices_noinline!` returns.
function my_mapslices(f, input::AbstractArray, sliced_dimensions::AbstractVector)
    if isempty(sliced_dimensions)
        return map(f,input)
    end

    # Per-dimension index ranges of `input`, collected into a vector.
    axes = [indices(input)...]
    rank = ndims(input)

    # Index of the first slice: first index along every dimension, except that each
    # sliced dimension covers its whole range.
    input_index = Any[first(index) for index in indices(input)]
    for dimension in sliced_dimensions
        input_index[dimension] = Base.Slice(indices(input, dimension))
    end

    # Apply the function to the first slice in order to determine the next steps
    input_slice = input[input_index...]
    first_output = f(input_slice)
    # In some cases, we can re-use the first slice for a dramatic performance
    # increase. The slice itself must be mutable and the result cannot contain
    # any mutable containers. The following errs on the side of being overly
    # strict (#18570 & #21123).
    safe_for_reuse = 
        isa(input_slice, StridedArray) && 
        (isa(first_output, Number) || 
            (isa(first_output, AbstractArray) && eltype(first_output) <: Number))

    # determine result size and allocate
    result_axes = copy(axes)
    # TODO: maybe support removing dimensions
    # Non-array and zero-dimensional outputs are wrapped in a one-element vector.
    if !isa(first_output, AbstractArray) || ndims(first_output) == 0
        first_output = [first_output]
    end
    # How many sliced dimensions the output of `f` does not cover; those become
    # singleton axes in the result.
    number_of_trivial_output_axes = 
        max(0, length(sliced_dimensions) - ndims(first_output)) 

    # NOTE(review): `axes` is built from `indices(input)`, so the `Int` branch looks
    # unreachable here -- confirm before relying on it.
    if eltype(result_axes) == Int
        result_axes[sliced_dimensions] = 
            [size(first_output)..., 
                ntuple(dimension->1, number_of_trivial_output_axes)...]
    else
        result_axes[sliced_dimensions] = 
            [indices(first_output)...,
                ntuple(dimension -> Base.OneTo(1), number_of_trivial_output_axes)...]
    end
    result = similar(first_output, tuple(result_axes...,))

    # Destination index of the first output, built the same way as `input_index`.
    result_index = Any[map(first, indices(result))...]
    for dimension in sliced_dimensions
        result_index[dimension] = indices(result, dimension)
    end

    result[result_index...] = first_output

    # skip the first element, we already handled it
    indexes = Iterators.drop(CartesianRange(tuple(axes[setdiff([1:rank;], sliced_dimensions)]...) ), 1)

    # One `Val` flag per dimension: `Val{true}()` where the dimension is iterated,
    # `Val{false}()` where it is sliced.
    dimension_is_indexed = map(1:rank) do dimension
        if dimension in sliced_dimensions
            Val{false}()
        else
            Val{true}()
        end
    end

    # `(v...,)` with the trailing comma turns the vector into a tuple; the bare
    # `(v...)` spelling is deprecated.
    inner_mapslices_noinline!(f, input, result,
        (dimension_is_indexed...,), indexes, (input_index...,), (result_index...,),
        safe_for_reuse, input_slice
    )
end

f = sum
input = rand(5, 5)
sliced_dimensions = [1]

using BenchmarkTools 
# Check both implementations agree before timing them.
Test.@test mapslices(f, input, sliced_dimensions) == my_mapslices(f, input, sliced_dimensions)
# `f`, `input` and `sliced_dimensions` are non-constant globals. Interpolating
# them with `$` makes the benchmark time the call itself rather than global
# lookup and dynamic dispatch.
result1 = @benchmark mapslices($f, $input, $sliced_dimensions)
result2 = @benchmark my_mapslices($f, $input, $sliced_dimensions)
result3 = @benchmark sum($input, $sliced_dimensions)

@tim.holy

I worked really hard on this and this is about the fastest I can get.

import Base.tail

# Logical negation lifted to the type domain: flips a `Val` boolean flag.
not_tuple(::Val{true}) = Val{false}()
not_tuple(::Val{false}) = Val{true}()

# `ifelse` selected by dispatch: `Val{true}` picks `new`, `Val{false}` picks `old`.
ifelse_tuple(::Val{true}, old, new) = new
ifelse_tuple(::Val{false}, old, new) = old

# Element-wise `ifelse_tuple` over three tuples walked in lockstep. Recursion ends
# when `old` and `new` are both empty.
ifelse_vectorized(switch, old::Tuple{}, new::Tuple{}) = ()
function ifelse_vectorized(switch, old, new)
    head = ifelse_tuple(first(switch), first(old), first(new))
    rest = ifelse_vectorized(tail(switch), tail(old), tail(new))
    return (head, rest...)
end

# Select the entries of `into` whose flag in `index` is `Val{true}`, keeping
# their order. Recursion ends when `into` is empty.
getindex_tuple(into::Tuple{}, index) = ()
function getindex_tuple(into, index)
    selected_rest = getindex_tuple(tail(into), tail(index))
    with_head = (first(into), selected_rest...)
    return ifelse_tuple(first(index), selected_rest, with_head)
end


# Replace the flagged entries of `old` with successive entries of `new`.
# A `Val{true}` flag in `switch` takes (and consumes) the next entry of `new`; a
# `Val{false}` flag keeps the entry of `old`. When `new` runs out, the rest of
# `old` is returned unchanged; when `old` runs out, the result ends.
setindex_tuple(old::Tuple{}, new::Tuple{}, switch::Tuple{}) = ()
setindex_tuple(old::Tuple{}, new, switch::Tuple{}) = ()
setindex_tuple(old, new::Tuple{}, switch) = old
function setindex_tuple(old, new, switch)
    flag = first(switch)
    head = ifelse_tuple(flag, first(old), first(new))
    rest = setindex_tuple(
        tail(old),
        ifelse_tuple(flag, new, tail(new)),
        tail(switch))
    return (head, rest...)
end

# Replace every flagged entry of `old` with the same value `new`; entries whose
# flag is `Val{false}` are kept. Recursion ends when `old` and `switch` are empty.
setindex_recycle(old::Tuple{}, new, switch::Tuple{}) = ()
function setindex_recycle(old, new, switch)
    flag = first(switch)
    head = ifelse_tuple(flag, first(old), new)
    rest = setindex_recycle(tail(old), new, tail(switch))
    return (head, rest...)
end

# Four-argument form: like the three-argument `setindex_tuple`, but once `new`
# has run out, the remaining flagged entries of `old` are filled with `default`
# (through `setindex_recycle`) instead of being kept.
setindex_tuple(old::Tuple{}, new::Tuple{}, switch::Tuple{}, default) = ()
setindex_tuple(old::Tuple{}, new, switch::Tuple{}, default) = ()
setindex_tuple(old, new::Tuple{}, switch, default) = setindex_recycle(old, default, switch)
function setindex_tuple(old, new, switch, default)
    flag = first(switch)
    head = ifelse_tuple(flag, first(old), first(new))
    rest = setindex_tuple(
        tail(old),
        ifelse_tuple(flag, new, tail(new)),
        tail(switch),
        default)
    return (head, rest...)
end

# Build the index of the first slice: for each dimension use a whole-range
# `Base.Slice` where `colon_indices` holds `Val{true}`, and the first index of
# that dimension otherwise.
function slice_indices(array_indices, colon_indices)
    first_positions = map(first, array_indices)
    whole_ranges = map(Base.Slice, array_indices)
    return ifelse_vectorized(colon_indices, first_positions, whole_ranges)
end

# Make sure the output of a slice function is an array with at least one
# dimension: arrays pass through untouched, while zero-dimensional arrays and
# non-array values are placed in a one-element vector.
maybe_wrap(a::AbstractArray{<:Any, 0}) = [a]
maybe_wrap(a::AbstractArray) = a
maybe_wrap(value) = [value]

# Apply `f` to every slice of `A` and assemble the results into one array.
#
# `colon_indices` is a tuple with one `Val` flag per dimension of `A`:
# `Val{true}()` marks a dimension that is sliced (a colon), `Val{false}()` one
# that is iterated over. `f` receives a view of `A`, not a copy.
function fast_mapslices(f, A, colon_indices)
    input_axes = indices(A)
    first_input_index = slice_indices(input_axes, colon_indices)

    # The first slice is evaluated up front so the output can be allocated
    # from the shape and element type of its result.
    first_result = maybe_wrap(f(@view A[first_input_index...]))

    # Output axes: the input axes with the sliced dimensions replaced by the
    # axes of the first result, padded with `Base.OneTo(1)` when the result
    # has fewer dimensions than were sliced.
    output_axes = setindex_tuple(
        input_axes, indices(first_result), colon_indices, Base.OneTo(1))
    output = similar(first_result, output_axes...)
    first_output_index = slice_indices(indices(output), colon_indices)
    output[first_output_index...] .= first_result

    # Visit the remaining positions of the non-sliced dimensions; the first
    # position was handled above and is skipped.
    iterated = not_tuple.(colon_indices)
    remaining = Iterators.drop(CartesianRange(getindex_tuple(input_axes, iterated)), 1)
    for position in remaining
        coordinates = position.I
        output_index = setindex_tuple(first_output_index, coordinates, iterated)
        input_index = setindex_tuple(first_input_index, coordinates, iterated)
        output[output_index...] .= f(@view A[input_index...])
    end
    return output
end

# Map `f` over the slices of `A` taken along the dimensions listed in `dims`.
#
# With empty `dims` this is a plain element-wise `map(f, A)`. Otherwise `dims`
# is translated into one `Val` flag per dimension of `A` (`Val{true}()` for the
# dimensions contained in `dims`) and the work is delegated to `fast_mapslices`.
function my_mapslices(f, A, dims)
    isempty(dims) && return map(f, A)
    # `ntuple` builds the flag tuple directly instead of splatting a range
    # into a tuple literal.
    flags = ntuple(dim -> (dim in dims) ? Val{true}() : Val{false}(), ndims(A))
    return fast_mapslices(f, A, flags)
end

# `@values a b c` expands to the tuple `(Val(a), Val(b), Val(c))`.
#
# Each argument is escaped so that it is evaluated in the caller's scope;
# without `esc`, a local variable passed to the macro would be looked up as a
# global of the module defining the macro.
macro values(args...)
    wrapped = map(arg -> :($Val($(esc(arg)))), args)
    return Expr(:tuple, wrapped...)
end

using BenchmarkTools

# Benchmark fixtures: reduce a 5×5 random matrix along its first dimension.
const f = sum
const A = rand(5, 5)
const dims = (1,)

# Inspect type inference of the Val-based implementation.
@code_warntype fast_mapslices(f, A, @values true false)

# Fixtures are interpolated with `$` so that each benchmark measures the call
# itself rather than access to global variables.
@btime mapslices($f, $A, $dims)
@btime my_mapslices($f, $A, $dims)
@btime sum($A, $dims)
@btime fast_mapslices($f, $A, @values true false)
3 Likes

Yes, that’s the right way to implement mapslices. Nice work! The only thing I see to complain about is the fact that with true/false, it’s hard to remember which is which. I find myself leaning towards mapslices(f, A, (:, *, :)) for a 3-dimensional A being equivalent to mapslices(f, A, [1, 3]). The reasoning is that “dims is where the colons go” so you might as well use : directly. The * indicates that this slot is to be replaced by a value, meant to be reminiscent of a wildcard search. I’m less thrilled about * than :, but I haven’t yet thought of anything better.

Of course you can “translate” (essentially for free) into Val{true}()/Val{false}() in your internal implementation, this is just about the user-facing API.

This looks great! I was working on something that turned out to be exactly the same. Besides the interface issue that Tim brought up, the only other point I see for debate is your decision to use views for slicing the input array. Most of the time I think it’s a win, but there are combinations of f and A for which that would slow things down a lot. I’m thinking about a case where f must access each element of a slice more than once. In that case performance would suffer if the slice was not contiguous in memory. On the bright side I think this simplifies the code…you were able to eliminate the ugliness of deciding whether to reuse the first slice of A.

Ok, I implemented your suggestion. I’ve put up a package JuliennedArrays which uses most of the same code except that it’s iterator-based. I think I lost some of the performance gains I had made… If I understood the code, I think there would be further help to be found in mapreducedims.

I was checking the indexing and hadn’t noticed the use of view. An important feature of the current implementation is that it’s guaranteed to be non-mutating no matter what function the user provides:

julia> a = reshape(collect(1:30), 5, 6)
5×6 Array{Int64,2}:
 1   6  11  16  21  26
 2   7  12  17  22  27
 3   8  13  18  23  28
 4   9  14  19  24  29
 5  10  15  20  25  30

julia> f(b) = (fill!(b, 0); 5)
f (generic function with 1 method)

julia> mapslices(f, a, 1)
1×6 Array{Int64,2}:
 5  5  5  5  5  5

julia> a
5×6 Array{Int64,2}:
 1   6  11  16  21  26
 2   7  12  17  22  27
 3   8  13  18  23  28
 4   9  14  19  24  29
 5  10  15  20  25  30

This is in direct analogy to map, which can’t mutate the underlying array. (We may want a mapslices! similar to map!.)

1 Like

Ok, I’ve updated the JuliennedArrays package to go back to the _unsafe_getindex strategy. I’ve gotten rid of all the allocations I think too. Views didn’t seem to be consistently improving performance anyway. I’m seeing 5-20X speedups.

1 Like

Still not beating mapreducedims for functions like sum though. Presumably there is a fast way to do reductions without having to make a temporary array (or a view) at all.

You might be running into cache issues, which will always plague mapslices by comparison to hand-tuned implementations (i.e., there isn’t an easy solution to this problem). @Cody-G has written some benchmarks that account for this issue; I think it would be good to check them.

Just wanted to mention that mapslices is causing type instability in my code.

z = rand(3, 44)
@code_warntype mapslices(norm, z, 1)

and you see the red Any’s.

I imagine this will be a performance issue in my functions that use mapslices. Is there a way around this? Thanks.