Creating a Static Array with a loop

I am solving an ODE with my own version of a Dormand-Prince solver. Previously I have been using an in-place version, which has worked well for me: the only allocations come from the array that holds the result. Now I've been trying to rewrite this in-place solver using StaticArrays, and for the most part the implementation seems to work fine.

using BenchmarkTools
using StaticArrays
using LinearAlgebra

function dpsolveinplace(x0,f!,dt,n::Int;final=0,printout::Bool=true)
    #=
    Dormand-Prince solver for a differential equation
    of the form x' = f(x)
    initial condition x0
    time step dt
    number of time steps n
    final = true means only return the final value
    printout tells if you want to print out the progress
    =#
    N = length(x0)
    b = dpmat()
    kmat = zeros(typeof(x0[1]),7,N)
    ubuffer = zeros(typeof(x0[1]),N)
    fbuffer = zeros(typeof(x0[1]),N)
    res = zeros(typeof(x0[1]),n+1,N)
    res[1,:] .= x0
    a = 0
    for count in 1:n
        f!(fbuffer,@view res[count,:])
        kmat[1,:] .= fbuffer
        for j in 2:7
            for k in 1:N
                ubuffer[k] = 0.0
                for i in 1:j-1
                    ubuffer[k] += b[i,j-1]*kmat[i,k]
                end
                ubuffer[k] = dt*ubuffer[k] + res[count,k]
            end
            f!(fbuffer,ubuffer)
            for k in 1:N
                kmat[j,k] = fbuffer[k]
            end
        end
        for k in 1:N
            res[count+1,k] = res[count,k] + dt*sum(b[i,6]*kmat[i,k] for i in 1:6)
        end
        if mod(count,convert(Int,round(n/20))) == 0 && printout
            println("Progress = ",round(100*count/n,digits=2),"%")
        end
    end
    if typeof(final) == Bool
        return @view res[end,:]
    else
        return res
    end
end
function dpsolve_static(x0,f,dt,n::Int;final=0,printout::Bool=true)
    #=
    Dormand-Prince solver for a differential equation
    of the form x' = f(x)
    Uses SVectors for the results and MVector buffers internally;
    f is out-of-place and returns an SVector
    initial condition x0
    time step dt
    number of time steps n
    final = true means only return the final value
    printout tells if you want to print out the progress
    =#
    N = length(x0)
    b = dpmat()
    ubuffer = @MVector zeros(eltype(x0),N)
    buffer = copy(ubuffer)
    sumbuffer = copy(ubuffer)
    res = Array{SVector{N,eltype(x0)}}(undef,n+1)
    kmat = Array{SVector{N,eltype(x0)}}(undef,7)
    for i in 1:7
        kmat[i] = zeros(eltype(x0),N)
    end
    res[1] = copy(x0)
    a = 0
    for count in 1:n
        kmat[1] = f(res[count])
        for j in 2:7
            for k in 1:N
                ubuffer[k] = 0
                for i in 1:j-1
                    ubuffer[k] += b[i,j-1]*kmat[i][k]
                end
            end
            buffer = res[count] + dt*ubuffer
            kmat[j] = f(buffer)
        end
        sumbuffer .= 0.0
        for i in 1:6
            for k in 1:N
                sumbuffer[k] += b[i,6]*kmat[i][k]
            end
        end
        for k in 1:N
            sumbuffer[k] = sumbuffer[k]*dt + res[count][k]
        end
        res[count+1] = SVector(sumbuffer)
        if mod(count,convert(Int,round(n/20))) == 0 && printout
            println("Progress = ",round(100*count/n,digits=2),"%")
        end
    end
    a > 0 && println(a)
    if typeof(final) == Bool
        return @view res[end,:]
    else
        return res
    end
end
function dpmat() #Dormand Prince matrix
    b = @MMatrix zeros(6,6)
    b[1,1] = 1/5
    b[1,2] = 3/40
    b[2,2] = 9/40
    b[1,3] = 44/45
    b[2,3] = -56/15
    b[3,3] = 32/9
    b[1,4] = 19372/6561
    b[2,4] = -25360/2187
    b[3,4] = 64448/6561
    b[4,4] = -212/729
    b[1,5] = 9017/3168
    b[2,5] = -355/33
    b[3,5] = 46732/5247
    b[4,5] = 49/176
    b[5,5] = -5103/18656
    b[1,6] = 35/384
    b[2,6] = 0
    b[3,6] = 500/1113
    b[4,6] = 125/192
    b[5,6] = -2187/6784
    b[6,6] = 11/84
    return b
end

where I have chosen to define the Dormand-Prince coefficient matrix dpmat() as an MMatrix, because the syntax to define it directly as an SMatrix would be hideous.
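(Roughly, the direct literal would have to look something like the following, with the same coefficients as dpmat() and every zero written out, which is what I'd like to avoid.)

b = @SMatrix [1/5  3/40   44/45        19372/6561    9017/3168     35/384;
              0.0  9/40  -56/15       -25360/2187   -355/33        0.0;
              0.0  0.0    32/9         64448/6561    46732/5247    500/1113;
              0.0  0.0    0.0         -212/729       49/176        125/192;
              0.0  0.0    0.0          0.0          -5103/18656   -2187/6784;
              0.0  0.0    0.0          0.0           0.0           11/84]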

I wrote some 2D test functions to check my implementations.

function testfun_2d(x)
    x1 = x[1] + x[2]
    x2 = x[1] - x[2]
    return @SVector[x1,x2]
end
function testfun_2d!(u,x)
    u[1] = x[1] + x[2]
    u[2] = x[1] - x[2]
end
@btime dpsolveinplace([1.0,1.0],testfun_2d!,1e-3,1000,final=true,printout=false)
  138.900 μs (6 allocations: 16.58 KiB)
2-element view(::Matrix{Float64}, 1001, :) with eltype Float64:
 4.914781300625746
 2.1781835566085666
@btime dpsolve_static(@SVector[1.0,1.0],testfun_2d,1e-3,1000,final=true,printout=false)
  39.600 μs (12 allocations: 17.05 KiB)
1-element view(::Matrix{SVector{2, Float64}}, 1001, :) with eltype SVector{2, Float64}:
 [4.914781300625746, 2.1781835566085666]

I am a little unsure why the static array method allocates slightly more memory, but it’s less than 1 kB so that’s not my query.

I have a question about creating a static array where I need to run a loop to determine each entry. The array is then passed to an ODE solver. Here is an in-place example of what I am using.

function static_test!(u,x;N=6)
    u[1] = 2
    u[2] = 3
    for i in 1:N
        u[i+2] = i
    end
    u[9] = 1
    u[10] = 2
end

Suppose, in principle, that I don't know the value of N at compile time; my program will calculate N before solving the ODE. In this case, how should I generate my static array? The following, for instance, is fairly slow.

function static_test(x)
    return SVector{10,Float64}([2;3;collect(1:6);1;2])
end

where I have chosen N = 6. The two solves give (omitting the returned values)

@btime dpsolve_static(@SVector[1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0],static_test,1e-3,1000,final=0,printout=false)
 24.788 ms (335012 allocations: 10.56 MiB)
@btime dpsolveinplace(ones(10),static_test!,1e-3,1000,final=0,printout=false)
  446.700 μs (7 allocations: 79.75 KiB)

Why is there such a large difference in the memory allocations now? Is there a better way to define static_test, other than hard-coding @SVector[2,3,1,2,3,4,5,6,1,2], which isn’t generally an option?

You are allocating a vector to fill a static vector. Using

function static_test(x)
    return SVector{10,Float64}(2,3,1:6...,1,2)
end

works better for me.
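(To see where the allocations come from, you can also benchmark the right-hand-side function in isolation, e.g.

x = @SVector ones(10)
@btime static_test($x)

and compare that against the full solve.)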

Thanks for that! I didn’t realise I could put the 1:6 in the constructor.

What if I needed SVector{10,Float64}(2,3,[i^2 for i in 1:6]...,1,2) instead? The actual vector function I am trying to code up is

function rhs!(u0,x)
    n = length(u0)
    u0[1] = -x[1]*(1 + Omega*im)/kappa - 2/kappa*conj(x[3])
    u0[2] = real(x[1]*x[3]) - epsilon*x[2] + 1
    for i in 3:n-2
        u0[i] = 1/2*(x[1]*x[i+1] - conj(x[1])*x[i-1]) - epsilon*x[i]
    end
    u0[n-1] = -1/2*conj(x[1])*x[n-2] - epsilon*x[n-1]
    u0[n] = x[1]
end

where Omega, kappa and epsilon are real numbers, everything is inside another function, and I don't know of a way to avoid allocating a vector here.

Same procedure, splat a tuple comprehension

function static_test(x)
    return SVector{10,Float64}(2,3,(i^2 for i in 1:6)...,1,2)
end

Regarding your rhs!: here I don’t see a constant to use in the construction of the static vector, so I’m lost.

Edit: you probably mean something like

const Omega = 1
const kappa = 1
const epsilon = 1
function rhs(x;N=10)
    return SVector{N,ComplexF64}(
        -x[1]*(1 + Omega*im)/kappa - 2/kappa*conj(x[3]),
        real(x[1]*x[3]) - epsilon*x[2] + 1,
        (1/2*(x[1]*x[i+1] - conj(x[1])*x[i-1]) - epsilon*x[i] for i in 3:N-2)..., 
        -1/2*conj(x[1])*x[N-2] - epsilon*x[N-1],
        x[1])
end

@btime dpsolve_static(SVector{10,ComplexF64}(1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0),rhs,1e-3,1000,final=0,printout=false)

But maybe you should manage the static size of your problem with something like Val(N).
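For instance (just a sketch, with hypothetical names rhs_val and solve_static, and assuming the consts above): dispatch on the length parameter of the initial SVector once, and from then on N is a compile-time constant everywhere inside the solve:

function rhs_val(x, ::Val{N}) where {N}
    return SVector{N,ComplexF64}(
        -x[1]*(1 + Omega*im)/kappa - 2/kappa*conj(x[3]),
        real(x[1]*x[3]) - epsilon*x[2] + 1,
        (1/2*(x[1]*x[i+1] - conj(x[1])*x[i-1]) - epsilon*x[i] for i in 3:N-2)...,
        -1/2*conj(x[1])*x[N-2] - epsilon*x[N-1],
        x[1])
end

# function barrier: N comes from the type of x0, so Val(N) costs nothing here
function solve_static(x0::SVector{N}, dt, n) where {N}
    return dpsolve_static(x0, x -> rhs_val(x, Val(N)), dt, n, final=0, printout=false)
end

That way the only place where the size is a runtime value is the construction of x0 itself.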


Ah I see. Thanks for the tip about tuple comprehension. I did not know about those.

In my rhs!, I defined a vector outside the function to mutate, but that ended up performing the same as the plain in-place version, with the extra work of constructing SVectors on top. Apologies for not making that clearer.

I was originally going to manage the static size of my problem with a simple check on N, but I realised that it wouldn't be type-stable. I didn't know about using Val either. Thanks for the tips!

While the tuple comprehensions mentioned by @goerch are probably the fastest constructors, if you need to conveniently create an SMatrix at the beginning of some calculation you could also start with an MMatrix and, at the end, simply promote it to an SMatrix. That also lets you mutate the matrix while constructing it (i.e. you can build more complicated matrices), while still assuring the compiler that the matrix will stay immutable once it is returned by your function:

function dpmat() #Dormand Prince matrix
    b = @MMatrix zeros(6,6)
    b[1,1] = 1/5
    [...] #do whatever with the elements of your matrix
    b[6,6] = 11/84
    return SMatrix(b) # result is now an SMatrix and cannot be mutated anymore
end 
julia> @allocated dpmat() 
0

I cannot think of any downsides to this, other than possibly the speed of construction, which is probably slightly slower than directly creating an SMatrix.


Sometimes it is slower because LLVM will write to the stack pointer (where the MMatrix lives) before copying it to the SMatrix’s destination, but there is no reason in principle for this to be slower/for LLVM not to write directly to the SMatrix.
In simple cases, where everything can fit into the registers, I would expect LLVM to fully eliminate all stores to the stack pointer (to the MMatrix) so that this code is optimal.
Maybe it’d succeed in some slightly more complicated cases, too.

EDIT:
Checking a few simple examples, the generated code unfortunately looks quite bad.

function example(N)
    A = @MMatrix zeros(6,6)
    @inbounds for i = 1:min(N,6)
        A[i,i] = 1/N + i
    end
    SMatrix(A)
end
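(For anyone who wants to check for themselves, I was looking at output along the lines of

@code_llvm debuginfo=:none example(3)
@code_native debuginfo=:none example(3)

to see whether the stores to the MMatrix get eliminated; InteractiveUtils is loaded automatically in the REPL.)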

I'm curious, can you elaborate on that? As I understand it, the compiler cannot really optimize away the references to the mutable matrix, at least not for this loop:

Example
function example(N)
    A = @MMatrix zeros(6,6)
    @inbounds for i = 1:min(N,6)
        A[i,i] = 1. /N + i
    end
    SMatrix(A)
end

function example_3expl(N) #assuming N=3 for explicit construction (not very fair)
    a11 = 1. /N +1.
    a22 = 1. /N +2.
    a33 = 1. /N +3.

    a44 = 0.
    a55 = 0.
    a66 = 0.
    return SMatrix{6,6,Float64,36}(
        a11,0.,0.,0.,0.,0.,
        0.,a22,0.,0.,0.,0.,
        0.,0.,a33,0.,0.,0.,
        0.,0.,0.,a44,0.,0.,
        0.,0.,0.,0.,a55,0.,
        0.,0.,0.,0.,0.,a66
    )
end

function example_lessExpl(N) #not assuming N=3 for construction (fair)
    a11 = (1. /N +1.) *(1<=N)
    a22 = (1. /N +2.) *(2<=N)
    a33 = (1. /N +3.) *(3<=N)
    a44 = (1. /N +4.) *(4<=N)
    a55 = (1. /N +5.) *(5<=N)
    a66 = (1. /N +6.) *(6<=N)

    return SMatrix{6,6,Float64,36}(
        a11,0.,0.,0.,0.,0.,
        0.,a22,0.,0.,0.,0.,
        0.,0.,a33,0.,0.,0.,
        0.,0.,0.,a44,0.,0.,
        0.,0.,0.,0.,a55,0.,
        0.,0.,0.,0.,0.,a66
    )
end


function example2(N) #not using loop to better compare to example_lessExpl
    A = @MMatrix zeros(6,6)
    A[1,1] = (1. /N +1.) *(1<=N)
    A[2,2] = (1. /N +2.) *(2<=N)
    A[3,3] = (1. /N +3.) *(3<=N)
    A[4,4] = (1. /N +4.) *(4<=N)
    A[5,5] = (1. /N +5.) *(5<=N)
    A[6,6] = (1. /N +6.) *(6<=N)
    SMatrix(A)
end

using BenchmarkTools

@btime example_3expl($3) # 7.600 ns (0 allocations: 0 bytes)
@btime example_lessExpl($3) # 19.559 ns (0 allocations: 0 bytes)
@btime example($3) # 26.305 ns (0 allocations: 0 bytes)
@btime example2($3) # 19.157 ns (0 allocations: 0 bytes)

println(example(3) === example_3expl(3)) # true
println(example(3) === example_lessExpl(3)) # true
println(example(3) === example2(3)) # true

Of course directly using an SMatrix is more restrictive and will probably always force you to use fewer abstractions (and thus likely be faster).

Wait, what? There’s no such thing as a tuple comprehension. Those are generators, and do not create tuples.


You are right, of course, but…

Not sure, I’d thought these were: (1:2...,).

That’s just splatting? A comprehension normally involves some expression with iteration over a set. Not sure of the exact definition, but I wouldn’t call [x...] an array comprehension either.