Type-instability of mapreduce vs. map + reduce

I’m a bit surprised that in this simple example, f is type-stable but g is not. Is this a JET.jl issue (@aviatesk) or a Julia issue?

using JET

function f(m, n)
    blocks = map(1:n) do i
        ones(m)
    end
    return reduce(hcat, blocks)
end

function g(m, n)
    return mapreduce(hcat, 1:n) do i
        ones(m)
    end
end 
julia> @assert f(10, 20) == g(10, 20)

julia> @test_opt f(10, 20)
Test Passed

julia> @test_opt g(10, 20)
JET-test failed at /home/guillaume/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterfaceTest/test/playground.jl:18
  Expression: #= /home/guillaume/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterfaceTest/test/playground.jl:18 =# JET.@test_opt g(10, 20)
  ═════ 1 possible error found ═════
  ┌ g(m::Int64, n::Int64) @ Main /home/guillaume/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterfaceTest/test/playground.jl:11
  │┌ mapreduce(f::var"#23#24"{Int64}, op::typeof(hcat), A::UnitRange{Int64}) @ Base ./reducedim.jl:357
  ││┌ mapreduce(f::var"#23#24"{Int64}, op::typeof(hcat), A::UnitRange{Int64}; dims::Colon, init::Base._InitialValue) @ Base ./reducedim.jl:357
  │││┌ _mapreduce_dim(f::var"#23#24"{Int64}, op::typeof(hcat), ::Base._InitialValue, A::UnitRange{Int64}, ::Colon) @ Base ./reducedim.jl:365
  ││││┌ _mapreduce(f::var"#23#24"{Int64}, op::typeof(hcat), ::IndexLinear, A::UnitRange{Int64}) @ Base ./reduce.jl:432
  │││││┌ mapreduce_empty_iter(f::var"#23#24"{Int64}, op::typeof(hcat), itr::UnitRange{Int64}, ItrEltype::Base.HasEltype) @ Base ./reduce.jl:380
  ││││││┌ reduce_empty_iter(op::Base.MappingRF{var"#23#24"{Int64}, typeof(hcat)}, itr::UnitRange{Int64}, ::Base.HasEltype) @ Base ./reduce.jl:384
  │││││││┌ reduce_empty(op::Base.MappingRF{var"#23#24"{Int64}, typeof(hcat)}, ::Type{Int64}) @ Base ./reduce.jl:361
  ││││││││ runtime dispatch detected: Base.mapreduce_empty(%1::var"#23#24"{Int64}, hcat, ::Int64)
  │││││││└────────────────────
  
ERROR: There was an error during testing

@Oscar_Smith and @Mason, sorry for the ping, @hill tells me you may have intuitions about this.


Yeah, this is because we have a special codepath for reduce(hcat, arr) but not for mapreduce(f, hcat, arr).

The special overloads for reduce(*cat) are really brittle and easily broken, which is why @mcabbott took on creating stack, which is what I’d recommend using instead.

h(m, n) = stack(1:n) do i
    ones(m)
end
julia> h(10, 20) == g(10, 20)
true

julia> @test_opt h(10, 20)
Test Passed

Thanks for the quick answer! I’m aware of stack but it doesn’t fit my purpose for two reasons:

First, the actual use case involves concatenating smaller matrices into a bigger matrix, without adding an extra dimension. I could stack into an Array{_,3} and then drop the last dimension, but that seems wasteful?

function g(m, n)
    return mapreduce(hcat, 1:n) do i
        ones(m, 2)  # this is a block and not just a column
    end
end 
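On the first point, a minimal sketch (plain `Array` blocks, Julia 1.9 or later where `stack` is in Base, and assuming n ≥ 1): stacking n blocks of size m×2 gives an m×2×n array, and because Julia arrays are column-major, reshaping it to m×2n yields the same matrix as hcat-ing the blocks. For an `Array`, `reshape` shares memory with its input instead of copying it. The helper names below are hypothetical.

```julia
# Hypothetical helper: block i is an m×2 matrix filled with the value i,
# so that the column order of the result is easy to check.
make_block(m, i) = fill(Float64(i), m, 2)

# hcat-based reference: concatenate all blocks horizontally.
hcat_blocks(m, n) = reduce(hcat, [make_block(m, i) for i in 1:n])

# stack-based version: stack gives an m×2×n array; reshaping it to m×(2n)
# reinterprets the same column-major memory as consecutive pairs of columns.
stack_blocks(m, n) = reshape(stack(i -> make_block(m, i), 1:n), m, :)

A = stack_blocks(3, 4)
println(size(A))                 # (3, 8)
println(A == hcat_blocks(3, 4))  # true
```

This does not address the second point: whether `stack` and `reshape` preserve static array types is a separate question.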

Second, stack doesn’t have an optimized implementation for static arrays, and it actually returns a Matrix in those cases. I would be willing to try and fix it, but I’d need some help, maybe from @mcabbott.

Is there another way out of this?

I’d probably just write a for loop in that case, or use Tullio.jl; I think it knows about StaticArrays (but I’m not sure).
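For plain `Array`s, such a loop might look like the following sketch. The name `g_loop` and the block width keyword `k = 2` are hypothetical, mirroring the m×2 block example above. It allocates the output once and copies each block into its column range, so it never touches `mapreduce`'s empty-collection path and always returns a `Matrix{Float64}`.

```julia
# Hypothetical loop-based replacement for mapreduce(hcat, ...) with m×k blocks.
function g_loop(m, n; k = 2)
    out = Matrix{Float64}(undef, m, k * n)  # allocate the full result once
    for i in 1:n
        block = ones(m, k)                  # stand-in for the i-th block
        cols = (i - 1) * k + 1 : i * k      # column range occupied by block i
        out[:, cols] = block                # copy block i into place
    end
    return out
end

println(size(g_loop(3, 4)))  # (3, 8)
println(size(g_loop(3, 0)))  # (3, 0)
```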


Unfortunately this is for DifferentiationInterface.jl out-of-place Jacobian matrices, so

  • I can’t pre-allocate and then mutate in a for loop, for instance because it would waste the benefits of static arrays
  • I can’t pull in something like Tullio.jl because the dependencies need to be absolutely minimal

Yes you can. Just allocate a MArray, mutate it, then convert to SArray.

julia> using StaticArrays, BenchmarkTools

julia> function f(::Val{m}, ::Val{n}) where {m, n}
           # M never escapes this function, so the compiler can avoid heap-allocating it
           M = MArray{Tuple{m, n}, Float64}(undef)
           for j ∈ axes(M, 2)
               for i ∈ axes(M, 1)
                   M[i, j] = 1
               end
           end
           SArray(M)  # return an immutable copy
       end
f (generic function with 1 method)

julia> @btime f(Val(10), Val(20))
  42.968 ns (0 allocations: 0 bytes)
10×20 SMatrix{10, 20, Float64, 200} with indices SOneTo(10)×SOneTo(20):
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0
 1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0  1.0

The trick here is that you can create MArray objects without any allocations if the compiler knows that the array never escapes the function body. That’s why we convert it to an SArray at the end.

In v1.12, this will get better because the compiler can reason about non-inlined function calls, but even in quite early versions of Julia, simple kernels like this, where everything is inlined, work really well.


Good to know that this works! The challenge now will be to code it in a way that is not specialized on StaticArrays but instead works the same on most AbstractArrays.
And I still think optimizing stack on SArrays would be worth doing, if anyone has pointers.


g is type-unstable because the reducer, hcat, is not called for single-element collections:

julia> g(3, 1)
3-element Vector{Float64}:
 1.0
 1.0
 1.0

julia> g(3, 2)
3×2 Matrix{Float64}:
 1.0  1.0
 1.0  1.0
 1.0  1.0

To ensure type stability, use ones(m, 1).
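A small sketch of this behavior, with the hypothetical names `g_col` and `g_mat`: for a length-1 collection, `mapreduce` returns the single mapped value without calling `hcat`, so mapping to column vectors gives a `Vector` for n = 1 and a `Matrix` otherwise, while mapping to m×1 matrices keeps the result a `Matrix`. This only fixes the n = 1 case; the empty-collection path is a separate issue.

```julia
# Mapping to column vectors: hcat is skipped when n == 1.
g_col(m, n) = mapreduce(i -> ones(m), hcat, 1:n)
# Mapping to m×1 matrices: the result is a Matrix for every n >= 1.
g_mat(m, n) = mapreduce(i -> ones(m, 1), hcat, 1:n)

println(typeof(g_col(3, 1)))  # Vector{Float64}
println(typeof(g_col(3, 2)))  # Matrix{Float64}
println(typeof(g_mat(3, 1)))  # Matrix{Float64}
println(typeof(g_mat(3, 2)))  # Matrix{Float64}
```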


That applies to the first example I gave, but apparently not to my real use case with blocks, since hcat-ing even just one always returns a matrix?

julia> using JET

julia> function g(m, n)
           return mapreduce(hcat, 1:n) do i
               ones(m, 2)
           end
       end
g (generic function with 1 method)

julia> g(3, 1)
3×2 Matrix{Float64}:
 1.0  1.0
 1.0  1.0
 1.0  1.0

julia> @test_opt g(3, 1)
JET-test failed at /home/guillaume/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/test/playground.jl:11
  Expression: #= /home/guillaume/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/test/playground.jl:11 =# JET.@test_opt g(3, 1)
  ═════ 1 possible error found ═════
  ┌ g(m::Int64, n::Int64) @ Main /home/guillaume/Work/GitHub/Julia/DifferentiationInterface.jl/DifferentiationInterface/test/playground.jl:4
  │┌ mapreduce(f::var"#35#36"{Int64}, op::typeof(hcat), A::Base.OneTo{Int64}) @ Base ./reducedim.jl:357
  ││┌ mapreduce(f::var"#35#36"{Int64}, op::typeof(hcat), A::Base.OneTo{Int64}; dims::Colon, init::Base._InitialValue) @ Base ./reducedim.jl:357
  │││┌ _mapreduce_dim(f::var"#35#36"{Int64}, op::typeof(hcat), ::Base._InitialValue, A::Base.OneTo{Int64}, ::Colon) @ Base ./reducedim.jl:365
  ││││┌ _mapreduce(f::var"#35#36"{Int64}, op::typeof(hcat), ::IndexLinear, A::Base.OneTo{Int64}) @ Base ./reduce.jl:432
  │││││┌ mapreduce_empty_iter(f::var"#35#36"{Int64}, op::typeof(hcat), itr::Base.OneTo{Int64}, ItrEltype::Base.HasEltype) @ Base ./reduce.jl:380
  ││││││┌ reduce_empty_iter(op::Base.MappingRF{var"#35#36"{Int64}, typeof(hcat)}, itr::Base.OneTo{Int64}, ::Base.HasEltype) @ Base ./reduce.jl:384
  │││││││┌ reduce_empty(op::Base.MappingRF{var"#35#36"{Int64}, typeof(hcat)}, ::Type{Int64}) @ Base ./reduce.jl:361
  ││││││││ runtime dispatch detected: Base.mapreduce_empty(%1::var"#35#36"{Int64}, hcat, ::Int64)
  │││││││└────────────────────
  
ERROR: There was an error during testing

That’s in the error path for empty collections. You can avoid that path if you can use an init: init=Array{Float64}(undef, m, 0). Using an init will also ensure that hcat is called for single-element collections, too.
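A minimal sketch of the init approach for the block example, with plain `Array`s (`g_init` is a hypothetical name). With `init` supplied, every block, including the first, is combined as `hcat(acc, block)`, and an empty range simply returns the zero-width `init`.

```julia
function g_init(m, n)
    # An m×0 matrix acts as a neutral starting value for horizontal concatenation.
    init = Matrix{Float64}(undef, m, 0)
    return mapreduce(i -> ones(m, 2), hcat, 1:n; init = init)
end

println(size(g_init(3, 0)))  # (3, 0)
println(size(g_init(3, 1)))  # (3, 2)
println(size(g_init(3, 4)))  # (3, 8)
```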


I think that’s the trick I was missing! Thank you.

Now of course I need to figure out a function which generates a zero-width equivalent for any AbstractMatrix, to use in the init…

Are you chasing down dynamic dispatches for static compilation? If not, you really don’t need to worry about this.

But if you are chasing every single dynamic dispatch down, another alternative could be to locally patch in Julia#51948; prior to that, we were dynamically fixing up a method error.

Yeah, this whole quest is for DifferentiationInterface.jl to perform optimally on StaticArrays.jl. In that case, I would ideally avoid any unnecessary runtime dispatch and even allocations.

I’m not sure I understand, what do you want me to patch there?

The magic line is @eval Base mapreduce_empty(f, op, T) = _empty_reduce_error(). But it’s probably not something you want to do in this context.

If you use a named function for the mapper, then you could overload this without piracy or shame.
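A heavily hedged sketch of what that could look like: `Base.mapreduce_empty` is an internal, undocumented function, and the three-argument form `(f, op, T)` used here is taken from the JET trace and the magic line above, so it may change between Julia versions. `ColumnBlock` and `g_named` are hypothetical names. Turning the mapper into a callable struct gives it its own type, so adding a `Base` method for that type is not type piracy.

```julia
# Hypothetical callable struct replacing the anonymous closure over m.
struct ColumnBlock
    m::Int
end
(b::ColumnBlock)(i) = ones(b.m, 2)  # the i-th m×2 block

# Internal Base hook (assumption): reached by mapreduce on empty collections.
# Returning an m×0 matrix makes the empty case produce a Matrix{Float64}.
Base.mapreduce_empty(b::ColumnBlock, ::typeof(hcat), ::Type) =
    Matrix{Float64}(undef, b.m, 0)

g_named(m, n) = mapreduce(ColumnBlock(m), hcat, 1:n)

println(size(g_named(3, 0)))  # (3, 0)
println(size(g_named(3, 2)))  # (3, 4)
```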


I think I’m gonna go with another solution and compute the first matrix block outside of the mapreduce. Less elegant but less scary. Thanks a lot for your advice!
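A minimal sketch of that idea, not necessarily the code actually used: the hypothetical `g_first` assumes n ≥ 1, computes the first block eagerly, passes it as `init`, and maps over the remaining indices, so neither the empty-collection path nor the question of a generic zero-width matrix comes up.

```julia
function g_first(m, n)
    n >= 1 || throw(ArgumentError("need at least one block"))
    block(i) = ones(m, 2)    # stand-in for the i-th m×2 block
    first_block = block(1)   # computed outside of mapreduce
    # The remaining blocks are hcat-ed onto the first one; for n == 1 the
    # range 2:n is empty and mapreduce simply returns first_block.
    return mapreduce(block, hcat, 2:n; init = first_block)
end

println(size(g_first(3, 1)))  # (3, 2)
println(size(g_first(3, 4)))  # (3, 8)
```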


Do you think there is a way to tell the compiler statically that the length of the iterator won’t be zero?

I’d take a step back: are you only chasing this because JET is telling you to chase it? I’m not sure it’s actually a real dynamic dispatch; it’s an intentional method missing error. And it’s not there in v1.11.

You might be ok to ignore it, I think.


You’re probably right. JET has been bossing me around for too long.
And I was testing on 1.10, so good catch. Let’s close this thread.