Optimising iterative solution of small system

Hello again, I’m solving iteratively this system of 6 equations as part of a larger simulation

Toy data
using StaticArrays
using LinearAlgebra

mutable struct V{T}
    u::T
    A::T
    A0::T
    γ::T
    β::T
end

struct N{T}
    U::MVector{6, T}
    W::MVector{3, T}
    F::MVector{6, T}
    J::MArray{Tuple{6, 6}, T, 2, 36}
end

function N()
    U = @MArray zeros(6)
    W = @MArray zeros(3)
    F = @MArray zeros(6)
    J = @MArray zeros(6,6) 
    J .+= I(6)
    N(U, W, F, J)
end

function getVs(n::Int64)
    u1 = zeros(n)
    u2 = zeros(n)
    u3 = zeros(n)

    r1 = 0.018105
    r2 = 0.015799
    r3 = 0.008561

    A1 = zeros(n) .+ pi * r1^2 .+ rand(n)
    A2 = zeros(n) .+ pi * r2^2 .+ rand(n)
    A3 = zeros(n) .+ pi * r3^2 .+ rand(n)

    A01 = zeros(n) .+ pi * r1^2
    A02 = zeros(n) .+ pi * r2^2
    A03 = zeros(n) .+ pi * r3^2

    σ = 0.5
    Eh0 = 445.0
    ρ = 1060.0

    β1 = sqrt.(pi ./ A1) .* Eh0 / (1 - σ^2)
    γ1 = β1 ./ (3ρ * r1 * sqrt(pi))

    β2 = sqrt.(pi ./ A2) .* Eh0 / (1 - σ^2)
    γ2 = β2 ./ (3ρ .* r2sqrt(pi))

    β3 = sqrt.(pi ./ A3) .* Eh0 / (1 - σ^2)
    γ3 = β3 ./ (3ρ .* r3 * sqrt(pi))

    V(u1, A1, A01, γ1, β1), V(u2, A2, A02, γ2, β2), V(u3, A3, A03, γ3, β3)
end
Utility functions
function w!(n::N, k)
    n.W[1] = n.U[1] + 4 * k[1] * n.U[4]
    n.W[2] = n.U[2] - 4 * k[2] * n.U[5]
    n.W[3] = n.U[3] - 4 * k[3] * n.U[6]
end

function jac!(n::N, v1::V, v2::V, v3::V, k)
    n.J[1, 4] = 4k[1]
    n.J[2, 5] = -4k[2]
    n.J[3, 6] = -4k[3]

    n.J[4, 1] = (n.U[4] * n.U[4] * n.U[4] * n.U[4])
    n.J[4, 2] = -(n.U[5] * n.U[5] * n.U[5] * n.U[5])
    n.J[4, 3] = -(n.U[6] * n.U[6] * n.U[6] * n.U[6])
    n.J[4, 4] = 4n.U[1] * (n.U[4] * n.U[4] * n.U[4])
    n.J[4, 5] = -4n.U[2] * (n.U[5] * n.U[5] * n.U[5])
    n.J[4, 6] = -4n.U[3] * (n.U[6] * n.U[6] * n.U[6])

    n.J[5, 4] = 2v1.β[end] * n.U[4] / sqrt(v1.A0[end])
    n.J[5, 5] = -2v2.β[1] * n.U[5] / sqrt(v2.A0[1])

    n.J[6, 4] = 2v1.β[end] * n.U[4] / sqrt(v1.A0[end])
    n.J[6, 6] = -2v3.β[1] * n.U[6] / sqrt(v3.A0[1])
end

function f!(n::N, v1::V, v2::V, v3::V, k)
    n.F[1] = n.U[1] + 4k[1] * n.U[4] - n.W[1]
    n.F[2] = n.U[2] - 4k[2] * n.U[5] - n.W[2]
    n.F[3] = n.U[3] - 4k[3] * n.U[6] - n.W[3]
    n.F[4] =
        n.U[1] * (n.U[4] * n.U[4] * n.U[4] * n.U[4]) - n.U[2] * (n.U[5] * n.U[5] * n.U[5] * n.U[5]) -
        n.U[3] * (n.U[6] * n.U[6] * n.U[6] * n.U[6])
    n.F[5] =
        v1.β[end] * (n.U[4] * n.U[4] / sqrt(v1.A0[end]) - 1) -
        (v2.β[1] * (n.U[5] * n.U[5] / sqrt(v2.A0[1]) - 1))
    n.F[6] =
        v1.β[end] * (n.U[4] * n.U[4] / sqrt(v1.A0[end]) - 1) -
        (v3.β[1] * (n.U[6] * n.U[6] / sqrt(v3.A0[1]) - 1))
end

function u!(n::N, v1::V, v2::V, v3::V)
    n.U[1] = v1.u[end]
    n.U[2] = v2.u[1]
    n.U[3] = v3.u[1]
    n.U[4] = sqrt(sqrt(v1.A[end]))
    n.U[5] = sqrt(sqrt(v2.A[1]))
    n.U[6] = sqrt(sqrt(v3.A[1]))
end
Implementation
function alloc!(n::N, v1::V, v2::V, v3::V, k)
    u!(n, v1, v2, v3)
    w!(n, k)
    jac!(n, v1, v2, v3, k)
    f!(n, v1, v2, v3, k)
end

function NR!(n::N, v1::V, v2::V, v3::V, k)
    while norm(n.F)>1e-5
        n.U .+= n.J \ (-n.F)
        w!(n, k)
        f!(n, v1, v2, v3, k)
    end
end

function update!(v1::V, v2::V, v3::V, n::N)
    v1.u[end] = n.U[1]
    v2.u[1] = n.U[2]
    v3.u[1] = n.U[3]

    v1.A[end] = n.U[4] * n.U[4] * n.U[4] * n.U[4]
    v2.A[1] = n.U[5] * n.U[5] * n.U[5] * n.U[5]
    v3.A[1] = n.U[6] * n.U[6] * n.U[6] * n.U[6]
end

function solve!(n::N, v1::V, v2::V, v3::V)
    k = (sqrt(1.5*v1.γ[end]), sqrt(1.5*v2.γ[1]), sqrt(1.5*v3.γ[1]))
    alloc!(n, v1, v2, v3, k)

    NR!(n, v1, v2, v3, k)

    update!(v1, v2, v3, n)
end

Benchmarking the solver I get

julia> v1,v2,v3=getVs(100);n=N();

julia> @btime solve!($n, $v1, $v2, $v3)
  75.288 ns (0 allocations: 0 bytes)

and

julia> @btime alloc!($n, $v1, $v2, $v3, $k)
  47.169 ns (0 allocations: 0 bytes)
 
julia> @btime NR!($n, $v1, $v2, $v3, $k)
  14.059 ns (0 allocations: 0 bytes)
 
julia> @btime update!($v1, $v2, $v3, $n.U)
  7.202 ns (0 allocations: 0 bytes)

therefore the slow part is the building of relevant vectors in alloc!.

However this is not what I see when profiling the simulation as most of the time is spent in NR! when calling \ in

n.U .+= n.J \ (-n.F)

Any advice on how can I go about optimising \?

Profiling only runs the code once so it’s possible that the samples were noisy. Have you tried profiling a repeating loop instead?

Hard to say without a more complete flame graph but the trouble could be that \ allocates, you might wanna try ldiv! instead

https://docs.julialang.org/en/v1/stdlib/LinearAlgebra/#LinearAlgebra.ldiv!

1 Like

It looks like you aren’t updating n.J in this loop? If so, you can just compute its LU factorization once and re-use it.

Also, it seems like it would be a lot easier, and possibly faster, to use SVector and SArray rather than MVector and MArray in this code. Then you can ditch all the “in-place” operations, write in a more natural style, and it still won’t allocate.

3 Likes

EDIT: The above poster beat me to this.

I don’t think this has a lot to do with the \ you identified as slow. But I find MArray to be not that performant compared to SArray. You might consider seeing if you can rewrite the whole thing using SVector/SMatrix. I also find most operations to be a bit more natural to write when you aren’t constantly trying to mutate things.

More broadly, small mutable types are almost never better than (though sometimes tie) their immutable counterparts. So I’ll recommend you make all your types here immutable.

So rather than updating N or a V, you would simply make a whole new one and throw the old one out. While this may sound wasteful, it’s actually often faster than mutating elements of a mutable type (because the compiler can still choose to mutate the elements when it implements the code). This is because mutables (usually) exist on the heap so must constantly write their contents there. Immutable types do not have this requirement, so are free to operate in registers/stack/heap, as the compiler deems best.

2 Likes

Thanks all for your replies, in order

I’m profiling the whole simulation for a short period, enough to call solve! tens of thousands of times.

I ditched N and all the ! functions and converted all the MArray to SArrays but I ended up writing this thing

getU(v1::V, v2::V, v3::V) = SVector{6,Float64}(v1.u[end], v2.u[1], v3.u[1], sqrt(sqrt(v1.A[end])), sqrt(sqrt(v2.A[1])), sqrt(sqrt(v3.A[1])))

getU(U::SVector{6, Float64}, dU::SVector{6, Float64}) = SVector{6,Float64}(U[1]+dU[1], U[2]+dU[2], U[3]+dU[3], U[4]+dU[4], U[5]+dU[5], U[6]+dU[6])

where the latter is used in the inner solver loop

function NR(U, W, J, F, k, v1::V, v2::V, v3::V)
    while norm(F)>1e-5
        dU = J \ (-F)
        U = getU(U, dU)
        W = getW(U, k)
        F = getF(v1, v2, v3, U, k, W)
    end
    U
end

function solve!(v1::V, v2::V, v3::V)
    k = (sqrt(1.5*v1.γ[end]), sqrt(1.5*v2.γ[1]), sqrt(1.5*v3.γ[1]))
    U = getU(v1, v2, v3)
    W = getW(U, k)
    J = getJ(v1, v2, v3, U, k)
    F = getF(v1, v2, v3, U, k, W)

    U = NR(U, W, J, F, k, v1, v2, v3)

    update!(v1, v2, v3, U)
end

so I basically swapped n.U .+= n.J \ (-n.F) with

dU = J \ (-F)
U = getU(U, dU) # i.e., U = U + dU

Is that the way of doing this?

1 Like

It will be just as fast to do:

getU(U, dU) = U + dU

(The argument-type declarations aren’t accomplishing anything here, either.) Of course, once it simplifies this much it’s not clear why you need a function at all, as opposed to just replacing U = getU(U, dU) with

U += dU

There are probably other similar simplifications you can make to your code. In general, with SVector you can write code in a pretty natural vectorized way and it will be fast. (Note that there is no need for “in-place” operations like .+= here.)

Since you aren’t modifying J in this loop, why not do:

function NR(U, W, J, F, k, v1::V, v2::V, v3::V)
    J_LU = lu(J)
    while norm(F)>1e-5
        dU = J_LU \ (-F)

so that you can re-use the Gaussian elimination on J? Or even just

    Jinv = inv(J)
    while norm(F)>1e-5
        dU = Jinv \ (-F)

(which could be faster if you do a lot of iterations).

(This looks like a multivariate Newton iteration, with J being the Jacobian matrix — in this case, I’m surprised that J doesn’t change on every iteration?)

(This looks like a multivariate Newton iteration, with J being the Jacobian matrix — in this case, I’m surprised that
J
doesn’t change on every iteration?)

@stevengj good catch! Indeed, J should be updated at every iteration

while norm(F)>1e-5
    U += J \ (-F)
    W = getW(U, k)
    F = getF(v1, v2, v3, U, k, W)
    J = getJ(v1, v2, v3, U, k)
end

unfortunately I cannot precompute inv(J).

And thanks for pointing out I can use += with a SArray, I never tried as I tought you’d need a mutable structure for that.

x += anything has identical meaning as x = x + anything. Both mean that x + anything is first computed, and then the variable x is reassigned to “point to” this new quantity. So it really doesn’t matter whether x is mutable or not. In your case, there is no heap allocation incurred from this because the new right-hand side is an SArray, and is therefore stack allocated, not heap allocated.

2 Likes

Amazing, thank you all, the \ is flame plot bar is now way smaller than before!

Regarding converting the entire code to static arrays, in reality I have a quite big struct for V, something like

mutable struct Vessel
    label::SubString{String}
    tosave::Bool

    #Topological notation
    sn::Int64
    tn::Int64

    #Numerical constants
    M::Int64
    dx::Float64
    invDx::Float64
    halfDx::Float64

    #Physical constants
    beta::Vector{Float64}
    Cv::Vector{Float64}
    viscoelastic::Bool
    gamma::Vector{Float64}
    gamma_ghost::Vector{Float64}
    A0::Vector{Float64}
    tapered::Bool
    dA0dx::Vector{Float64}
    dTaudx::Vector{Float64}
    Pext::Float64
    gamma_profile::Int64

    #Iterative solution
    A::Vector{Float64}
    Q::Vector{Float64}
    u::Vector{Float64}
    P::Vector{Float64}

    #Riemann invariants
    W1M0::Float64
    W2M0::Float64

    #Ghost cells
    U00A::Float64
    U00Q::Float64
    UM1A::Float64
    UM1Q::Float64

    #Saving locations
    node2::Int64
    node3::Int64
    node4::Int64

    #Peripheral boundary condition
    usewk3::Bool
    Rt::Float64
    R1::Float64
    R2::Float64
    total_peripheral_resistance::Float64
    inlet_impedance_matching::Bool
    Cc::Float64
    Pc::Float64

    #MUSCLArrays
    fluxA::Vector{Float64}
    fluxQ::Vector{Float64}

    vA::Vector{Float64}
    vQ::Vector{Float64}

    dUA::Vector{Float64}
    dUQ::Vector{Float64}

    slopesA::Vector{Float64}
    slopesQ::Vector{Float64}

    Al::Vector{Float64}
    Ar::Vector{Float64}

    Ql::Vector{Float64}
    Qr::Vector{Float64}

    Fl::Vector{Float64}
    Fr::Vector{Float64}

    # waveforms
    waveforms::Dict{String, Array{Float64, 2}}
end

and I have several (can be hundreds) of these in a single

struct Network
    graph::SimpleDiGraph{Int64}
    edges::Vector{Graphs.SimpleGraphs.SimpleEdge{Int64}}
    vessels::Dict{Tuple{Int,Int},Vessel}
    blood::Blood
    heart::Heart
    Ccfl::Float64
end

at runtime I use the Network.graph to access the different Vessels and update them in place, therefore the mutable struct Vessel definition.

The various Vector{Float64} have all the same size in a single Vessel, but different Vessels can have different lengths. The lengths are know at runtime, not before.

Shall I refactor the code so that I have a struct Vessel with only SArray and remove all the inplace operations?

75 nanoseconds is…pretty fast already. Not saying it can’t be made smaller. But the scope for improvement is pretty small. Just a big picture comment.

Did you try using ModelingToolkit + static arrays? It would be a mixture of:

The exmple here is a bit hard to work with… so I didn’t take the time to fully demonstrate it. But I know that here, tearing will eliminate 4 of these equations and you’ll end up with just the last two equations, and solving a 2 equation Newton will be faster and more stable than the 6 equation version. Writing down the analytical solution for one of them in terms of the quadratic (which MTK cannot do right now) would also give you a scalar rootfinding problem which you can then do really fast.

So I think your only next options are not just numeric but symbolic-numeric, i.e. analytically simplify the problem to be solved before solving it.

1 Like

Thanks @ChrisRackauckas I’ll give it a try and possibly open a new thread for that.

Meanwhile using immutable static arrays did the trick