Type Instability with MMatrix

Hi! I'm having trouble with a weird type instability that seems to appear in `broadcast!` when I use it on an `MMatrix` from StaticArrays.jl, but disappears if I use a normal `Matrix`. Here is my MWE:

```julia
using StaticArrays
using InteractiveUtils  # provides @code_warntype outside the REPL

# Random spin: a unit vector drawn uniformly from the sphere
function rand_spin()
    ϕ = 2π*rand()
    θ = acos(2rand()-1)
    SVector(sin(ϕ)*sin(θ), cos(ϕ)*sin(θ), cos(θ))
end

# Matrix generation
function rand_spin_matrix(N::Integer, M::Integer; mtype = :normal)
    # `undef` allocates without initializing; every entry is filled below
    if mtype == :mutable
        matrix = MMatrix{N,M,SVector{3,Float64}}(undef)
    elseif mtype == :normal
        matrix = Matrix{SVector{3,Float64}}(undef, N, M)
    else
        error("mtype must be :mutable or :normal")
    end
    for i in eachindex(matrix)
        matrix[i] = rand_spin()
    end
    return matrix
end

# Functions to broadcast
sum_mul_4(σ, h, a1,a2,a3,a4, k1,k2,k3,k4) =
    σ + h*(a1*k1 + a2*k2 + a3*k3 + a4*k4)

sum_mul_5(σ, h, a1,a2,a3,a4,a5, k1,k2,k3,k4,k5) =
    σ + h*(a1*k1 + a2*k2 + a3*k3 + a4*k4 + a5*k5)

a51 = 19372/6561
a52 = -25360/2187
a53 = 64448/6561
a54 = -212/729
a61 = 9017/3168
a62 = -355/33
a63 = 46732/5247
a64 = 49/176
a65 = -5103/18656

### Test ###

h = 0.1
σ = rand_spin_matrix(3, 3, mtype = :mutable)

k1 = copy(σ)
k2 = copy(σ)
k3 = copy(σ)
k4 = copy(σ)
k5 = copy(σ)

σ_k5 = copy(σ)
σ_k6 = copy(σ)

# Type stable
@code_warntype broadcast!(  sum_mul_4, σ_k5, σ, h,
                            a51, a52, a53, a54, k1, k2, k3, k4)
# Type unstable?
@code_warntype broadcast!(  sum_mul_5, σ_k6, σ, h,
                            a61, a62, a63, a64, a65, k1, k2, k3, k4, k5)
```

If I instead use a normal `Matrix` as the container, `σ = rand_spin_matrix(3, 3, mtype = :normal)`, the type instability disappears.

Is this a bug or am I missing something? Any help is appreciated!

From this function signature, `N` and `M` are runtime values, so the type of the `MMatrix` cannot be deduced at compile time, which causes the type instability. The current workaround is to turn them into value types, `Val{N}` and `Val{M}`, and only call the function with literals, so that the sizes are compile-time information. This is a case where constant propagation would help.
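A minimal sketch of the `Val` approach, assuming a current StaticArrays.jl where uninitialized static arrays are created with `(undef)`. `rand_spin_mmatrix` is a hypothetical name and `rand_spin` is copied from the question:

```julia
using StaticArrays

# Random unit vector, copied from the question
function rand_spin()
    ϕ = 2π*rand()
    θ = acos(2rand()-1)
    SVector(sin(ϕ)*sin(θ), cos(ϕ)*sin(θ), cos(θ))
end

# Hypothetical Val-based constructor: N and M are type parameters of the
# Val arguments, so the concrete MMatrix type is known at compile time.
function rand_spin_mmatrix(::Val{N}, ::Val{M}) where {N,M}
    matrix = MMatrix{N,M,SVector{3,Float64}}(undef)
    for i in eachindex(matrix)
        matrix[i] = rand_spin()
    end
    return matrix
end

# Call with literal sizes so Val(3) is a compile-time constant
σ = rand_spin_mmatrix(Val(3), Val(3))
```

Calling it as `rand_spin_mmatrix(Val(n), Val(m))` with runtime `n` and `m` would only move the instability to the call site, which is why the sizes should be literals.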

There's a function barrier before `broadcast!` is called, so this shouldn't matter at all, should it?
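To illustrate the function-barrier argument, here is a Base-only sketch with hypothetical function names: the outer function cannot infer whether `A` is a `Matrix{Float32}` or a `Matrix{Float64}`, but the kernel it calls is compiled separately for whichever concrete type arrives, so the loop itself stays type-stable. In the MWE, `broadcast!` plays the role of the kernel, and `@code_warntype` inspects it for the concrete argument types.

```julia
# Outer function: the element type of A depends on a runtime Bool, so the
# compiler can only infer Union{Matrix{Float32}, Matrix{Float64}} for A.
function make_and_fill(n::Integer, single::Bool)
    A = single ? zeros(Float32, n, n) : zeros(Float64, n, n)
    fill_kernel!(A)   # function barrier: dispatches on the concrete type of A
    return A
end

# Kernel: compiled separately for each concrete matrix type it receives,
# so the loop body is type-stable.
function fill_kernel!(A::AbstractMatrix{T}) where {T}
    for i in eachindex(A)
        A[i] = T(i)
    end
    return A
end

A = make_and_fill(2, false)   # Matrix{Float64}: [1.0 3.0; 2.0 4.0]
```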

This is probably just some type inference limit getting reached when the number of arguments gets too big.
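If the number of arguments is what trips inference, one possible workaround (an untested assumption, not something confirmed in this thread) is to capture the scalar arguments in a closure so that `broadcast!` only receives the arrays. A sketch with a hypothetical `stage6_closure!`; it is generic, so it should accept the `MMatrix{3,3,SVector{3,Float64}}` containers from the MWE as well as plain arrays:

```julia
# Hypothetical helper: same update as sum_mul_5 from the question, but h and
# the coefficients are captured by a closure, so broadcast! gets 6 source
# arrays instead of 12 source arguments.
function stage6_closure!(out, σ, h, a, k1, k2, k3, k4, k5)
    a1, a2, a3, a4, a5 = a   # a is a tuple (a61, a62, a63, a64, a65)
    f = (s, x1, x2, x3, x4, x5) -> s + h*(a1*x1 + a2*x2 + a3*x3 + a4*x4 + a5*x5)
    broadcast!(f, out, σ, k1, k2, k3, k4, k5)
    return out
end

# Usage with plain Float64 matrices
σ = ones(2, 2)
k = fill(2.0, 2, 2)
out = similar(σ)
stage6_closure!(out, σ, 0.5, (1.0, 1.0, 1.0, 1.0, 1.0), k, k, k, k, k)
# each entry is 1 + 0.5*(5*2) = 6.0
```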


Oh yes, it’s also hitting this:

Yes, your example is the last few lines of a 6-stage Runge-Kutta method, and this is why we can't use broadcast internally for a lot of things for now (see "Not fully broadcasting methods", issue #106 on SciML/OrdinaryDiffEq.jl on GitHub). If the code is broadcasting internally on this line, that's a bug, and we should apply the (hopefully temporary) workaround of turning it into a loop.
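A sketch of the loop workaround for the sixth-stage update from the question. `stage6_loop!` is a hypothetical name, not an OrdinaryDiffEq.jl function. Because it only uses `eachindex` and linear indexing, it should also work for the `MMatrix` and `Matrix` containers of `SVector`s from the MWE:

```julia
# Hypothetical loop replacement for
#   broadcast!(sum_mul_5, out, σ, h, a61, a62, a63, a64, a65, k1, k2, k3, k4, k5)
function stage6_loop!(out, σ, h, a61, a62, a63, a64, a65, k1, k2, k3, k4, k5)
    # eachindex with several arrays throws a DimensionMismatch if their
    # indices differ, which makes the @inbounds below safe.
    @inbounds for i in eachindex(out, σ, k1, k2, k3, k4, k5)
        out[i] = σ[i] + h*(a61*k1[i] + a62*k2[i] + a63*k3[i] + a64*k4[i] + a65*k5[i])
    end
    return out
end

# Usage with the coefficients from the question
a61, a62, a63, a64, a65 = 9017/3168, -355/33, 46732/5247, 49/176, -5103/18656
σ = rand(3, 3)
k1, k2, k3, k4, k5 = (rand(3, 3) for _ in 1:5)
out = similar(σ)
stage6_loop!(out, σ, 0.1, a61, a62, a63, a64, a65, k1, k2, k3, k4, k5)
```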


BTW, if someone has a working build of Julia master, it would be nice to check whether @jameson's PR fixes this:

https://github.com/JuliaLang/julia/pull/23912

Probably not fixed, although it's a bit harder to test without BenchmarkTools.jl. Here are the results using the functions from https://github.com/JuliaLang/julia/issues/22255#issuecomment-306814370:

```julia
julia> VERSION
v"0.7.0-DEV.2436"

julia> @time fun1!(a,b,c,d,e,f,g,h,i,j,k);
  0.000006 seconds (4 allocations: 160 bytes)

julia> @time fun2!(a,b,c,d,e,f,g,h,i,j,k,l);
  0.000015 seconds (8 allocations: 400 bytes)
```
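Without BenchmarkTools.jl, one Base-only option is `@allocated` wrapped in a function, after a warm-up call so compilation is not counted. A minimal sketch; `allocs_after_warmup` and `add!` are hypothetical names, and the explicit `where {F,N}` parameters are there to make Julia specialize on the function and vararg arguments:

```julia
# Hypothetical helper: bytes allocated by a second call to f!(args...).
# The first call compiles f! for these argument types, so it is excluded.
function allocs_after_warmup(f!::F, args::Vararg{Any,N}) where {F,N}
    f!(args...)
    return @allocated f!(args...)
end

# Example in-place kernel (fused broadcast into a preallocated output)
add!(out, a, b) = (out .= a .+ b; out)

x, y, out = ones(100), ones(100), zeros(100)
allocs_after_warmup(add!, out, x, y)   # expected to be 0 for this kernel
```

The same helper could be pointed at an in-place stage update such as the `broadcast!` calls from the MWE; a nonzero count there would be consistent with the allocations in the `@time` output above.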