Integrating network of particles+springs with varying number of particles

This is an MWE that integrates a network of roughly 100 to 1000 interacting particles that can grow and die. In general, the connections are arbitrary and growth/death follows some logic, but here I just connect every particle to every other one and destroy a random particle at each step. Timing this gives

17.380744 seconds (809.67 M allocations: 37.817 GiB, 10.43% gc time, 2.14% compilation time: 5% of which was recompilation)

and a profile shows things like this: [profile screenshot]

I’m not sure what can be done to reduce the number of allocations or to speed this up; the profile seems to be complaining about the * operation itself.

Code

using OrdinaryDiffEq
using LinearAlgebra: norm
using Random: seed!

mutable struct SpringParams{T<:Real}
    k::Function
    L::T
    connections::Dict{Int64, Vector{Int64}}
end

function spring_force(
    xy1::SubArray{T, 1, Vector{T}, Tuple{UnitRange{R}}},
    xy2::SubArray{T, 1, Vector{T}, Tuple{UnitRange{R}}},
    parameters::SpringParams{T}) where {T <:Real, R<:Integer}
    d = norm(xy1 - xy2)
    return parameters.k(d)*(parameters.L/d - 1)*(xy1 - xy2)
end

function Network!(du, u, p::SpringParams, t)
    for i = 1:Integer(length(u)/2)
        # note that x, y = u[2*i-1], u[2*i]

        du[2*i-1] = 1.0
        du[2*i] = 0.5

        @views for j in p.connections[i]
            du[2*i-1:2*i] .= du[2*i-1:2*i] .+ spring_force(u[2*i-1:2*i], u[2*j-1:2*j], p)
        end
    end
end

##################
seed!(1234)

ics = Iterators.product(range(1.0, 3.0, length = 20), range(5.0, 9.0, length = 20)) |> collect |> vec |> x -> [float(x[i][j]) for i = 1:length(x) for j = 1:2]
tspan = (0.0, 1.0)
n_items = length(ics)/2 |> Integer
sp = SpringParams(k -> 1.0, 1.0, Dict(i => [j for j = 1:n_items if j != i] for i = 1:n_items))

cb = DiscreteCallback(
    (u, t, integrator) -> true, 
    integrator -> begin
        n_items =  Integer(length(integrator.u)/2)
        idx = rand(1:n_items)
        deleteat!(integrator, 2*idx-1) # x coordinate
        deleteat!(integrator, 2*idx-1) # y coordinate (now where x used to be)
        integrator.p.connections = Dict(i => [j for j = 1:n_items-1 if j != i] for i = 1:n_items-1) # have to update connections to take into account growth/death
        nothing 
    end
)

prob = ODEProblem(Network!, ics, tspan, sp)
@time sol = solve(prob, Tsit5(), callback = cb);

You’ll want to make these array builds lazy: xy1 - xy2 allocates a new temporary array every time it is evaluated, and u[2*i-1:2*i] needs to be a view.

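You probably want to avoid declaring the field k with the abstract type Function, since the compiler then cannot specialize calls to p.k. Make the function type a parameter of the struct instead. You might also be able to squeeze out a little more performance if you make your struct not mutable.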

struct SpringParams{T<:Real, Tk<:Function}
    k::Tk
    L::T
    connections::Dict{Int64, Vector{Int64}}
end
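To see the difference between the two field declarations, here is a minimal sketch. The struct names `LooseParams` and `TightParams` are hypothetical and only used for illustration; the check uses `fieldtype` and `isconcretetype` from Base.

```julia
# Field typed as the abstract type Function: the concrete closure type is hidden.
struct LooseParams
    k::Function
end

# Field typed by a parameter: each closure gets its own concrete field type.
struct TightParams{Tk<:Function}
    k::Tk
end

loose = LooseParams(d -> 1.0)
tight = TightParams(d -> 1.0)

loose_is_concrete = isconcretetype(fieldtype(typeof(loose), :k))  # false
tight_is_concrete = isconcretetype(fieldtype(typeof(tight), :k))  # true
```

With a concrete field type the compiler knows which method `p.k(d)` refers to, which is usually what removes the dynamic dispatch in a hot loop.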

You can probably skip the complicated type annotation for spring_force; I don’t think it’s helping:

function spring_force(xy1, xy2, parameters)
    d = norm(xy1 - xy2)
    return parameters.k(d)*(parameters.L/d - 1)*(xy1 - xy2)
end

After some tinkering, this is the fastest version I found. Now I get

1.007681 seconds (1.51 M allocations: 360.787 MiB, 15.35% gc time, 48.35% compilation time)

which is a significant improvement. I made everything a scalar (I couldn’t figure out how to compute xy1 - xy2 without allocating), annotated the type of k, and made SpringParams immutable. Thank you @ChrisRackauckas and @JonasWickman.

Code

using OrdinaryDiffEq
using LinearAlgebra: norm
using Random: seed!

struct SpringParams{T<:Real, Tk<:Function}
    k::Tk
    L::T
    connections::Dict{Int64, Vector{Int64}}
end

function Network!(du, u, p::SpringParams, t)
    for i = 1:Integer(length(u)/2)
        # note that x, y = u[2*i-1], u[2*i]

        du[2*i-1] = 1.0
        du[2*i] = 0.5

        for j in p.connections[i]
            d = sqrt((u[2*i-1] - u[2*j-1])^2 + (u[2*i] - u[2*j])^2)
            fac = (p.k(d))*(p.L/d - 1)

            du[2*i-1] += fac*(u[2*i-1] - u[2*j-1])
            du[2*i] += fac*(u[2*i] - u[2*j])
        end
    end
end


##################
seed!(1234)

ics = Iterators.product(range(1.0, 3.0, length = 20), range(5.0, 9.0, length = 20)) |> collect |> vec |> x -> [float(x[i][j]) for i = 1:length(x) for j = 1:2]
tspan = (0.0, 1.0)
n_items = length(ics)/2 |> Integer

sp = SpringParams(k -> 1.0, 1.0, Dict(i => [j for j = 1:n_items if j != i] for i = 1:n_items))

cb = DiscreteCallback(
    (u, t, integrator) -> true, 
    integrator -> begin
        n_items =  Integer(length(integrator.u)/2)
        idx = rand(1:n_items)
        deleteat!(integrator, 2*idx-1) # x coordinate
        deleteat!(integrator, 2*idx-1) # y coordinate (now where x used to be)
        delete!(integrator.p.connections, n_items) # particles are renumbered 1:n_items-1, so drop the last key
        for i = 1:n_items-1
            integrator.p.connections[i] = [j for j = 1:n_items-1 if j != i]
        end
        nothing 
    end
)

prob = ODEProblem(Network!, ics, tspan, sp)
@time global sol = solve(prob, Tsit5(), callback = cb);

You can use a generator and take a norm over the generator expression.
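A minimal sketch of that suggestion, assuming the two points are passed as any indexable pair of coordinates (the vectors below are made-up test values):

```julia
using LinearAlgebra: norm

xy1 = [3.0, 4.0]
xy2 = [0.0, 0.0]

# The generator yields the coordinate differences one at a time, so the
# temporary array that `xy1 - xy2` would create is never built.
d = norm(a - b for (a, b) in zip(xy1, xy2))

# Equivalent formulation with sum/abs2, shown only as a cross-check.
d_check = sqrt(sum(abs2, a - b for (a, b) in zip(xy1, xy2)))
```

The same pattern should work with views into `u` in place of `xy1` and `xy2`.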


A few comments (sorry for typos - I’m on mobile).

Half of this is compilation time and probably most allocations are caused by the compilation as well. Run the @time twice to get accurate results.
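As a small illustration of the compilation effect, the sketch below times the same call twice. `work` is a hypothetical stand-in for the expensive call (`solve(prob, ...)` in this thread), not part of the original code.

```julia
# Hypothetical stand-in for the expensive call being timed.
work(n) = sum(sqrt(i) for i in 1:n)

t_first = @elapsed work(10^6)   # first call: includes compiling `work`
t_second = @elapsed work(10^6)  # second call: runs already-compiled code

println("first: ", t_first, " s, second: ", t_second, " s")
```

The second number is the one to compare between versions of the code.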

StaticArrays.jl would likely have helped here.
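One way this could look, as a rough sketch that assumes the StaticArrays.jl package is installed. The function name `spring_force_static` and the test values are hypothetical; the force law is the one from the original `spring_force`.

```julia
using StaticArrays
using LinearAlgebra: norm

# Spring force between two 2D points. SVector arithmetic returns new
# SVectors, which are plain immutable values rather than heap arrays.
function spring_force_static(xy1::SVector{2}, xy2::SVector{2}, k, L)
    delta = xy1 - xy2
    d = norm(delta)
    return k(d) * (L / d - 1) * delta
end

u = [0.0, 0.0, 3.0, 4.0]                   # two particles: (0, 0) and (3, 4)
pts = reinterpret(SVector{2, Float64}, u)  # same memory, viewed as 2D points

f = spring_force_static(pts[1], pts[2], d -> 1.0, 1.0)
```

Whether this beats the hand-written scalar version would need to be measured; the sketch only shows how the vector form can be kept without building temporary arrays.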

I am not convinced that a Dict is the best choice of data structure here. It’s probably worth experimenting a bit with different strategies for handling the deletion of elements. Do you delete a significant fraction of springs? How sparse is your connection graph? I would imagine that if you delete only a few springs, you could just keep their entries in the state vector and stop updating them (maybe mask them out at the very end) by changing the connection graph. If your connections are sparse, I would recommend using SparseArrays.jl or just a list of Tuples or so.

If you say springs, then the connections will always be symmetric. So you can gain a factor of 2 by iterating over all connections once and updating both participating particles. However, the Dict-based data structure isn’t really suitable for this.

Thank you for your comments, I greatly appreciate it.

> Half of this is compilation time and probably most allocations are caused by the compilation as well. Run the @time twice to get accurate results.

For sure. You just need to be mindful that sp is modified by the first call (the callback rewrites its connections), but that’s not too big of a deal here.

> StaticArrays.jl

I tried this, but I can never get it to be as fast as just working with scalar quantities - probably I am doing something wrong.

Regarding connections: in the real situation, in principle any number of particles could grow or die. However, I think one can get away with only forming connections between each particle’s (say) 10 nearest neighbors. Originally the connections were stored in a matrix, but I found it cumbersome to constantly resize it. The factor of 2 in the connections calculation is certainly annoying to leave on the table. Something like this works, but it only gives a slight speedup.

function Network2!(du, u, p::SpringParams, t)
    for i = 1:Integer(length(u)/2)
        du[2*i-1] = 1.0
        du[2*i] = 0.5
    end

    for i = 1:Integer(length(u)/2)
        # note that x, y = u[2*i-1], u[2*i]

        for j in filter(x -> x > i, p.connections[i])
            d = sqrt((u[2*i-1] - u[2*j-1])^2 + (u[2*i] - u[2*j])^2)
            fac = (p.k(d))*(p.L/d - 1)
            f1 = fac*(u[2*i-1] - u[2*j-1])
            f2 = fac*(u[2*i] - u[2*j])

            du[2*i-1] += f1
            du[2*i] += f2
            du[2*j-1] += -f1
            du[2*j] += -f2     
        end
    end
end

What if you use just a list of the coupled indices, i.e. a Vector{NTuple{2,Int}}, to track the connections? Then you simply iterate over all the connections and update both particles. No need to filter anything in the hot loop.
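A rough sketch of what that could look like, reusing the force law and the constant drift terms from Network2!. The name `network_edges!` and the convention that each spring is stored once as `(i, j)` are assumptions made for this example; `k` and `L` are passed directly instead of through SpringParams to keep it self-contained.

```julia
# Each spring appears exactly once in `edges` as a pair of particle indices.
function network_edges!(du, u, edges::Vector{NTuple{2, Int}}, k, L)
    for i in 1:length(u) ÷ 2
        du[2i - 1] = 1.0
        du[2i] = 0.5
    end
    for (i, j) in edges
        dx = u[2i - 1] - u[2j - 1]
        dy = u[2i] - u[2j]
        d = sqrt(dx^2 + dy^2)
        fac = k(d) * (L / d - 1)
        # Equal and opposite forces on the two ends of the spring.
        du[2i - 1] += fac * dx
        du[2i] += fac * dy
        du[2j - 1] -= fac * dx
        du[2j] -= fac * dy
    end
    return nothing
end

u = [0.0, 0.0, 3.0, 4.0]  # two particles: (0, 0) and (3, 4)
du = similar(u)
edges = [(1, 2)]
network_edges!(du, u, edges, d -> 1.0, 1.0)
```

The inner loop touches each spring once and needs no `filter` call, so nothing is allocated per step.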

BTW: Is this logic for updating the connections correct?

I think you need to shift all connection indices that are larger than the one you deleted, no?
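For the edge-list representation suggested above, the shift could be sketched as follows. `remove_particle!` is a hypothetical helper, not something from the original code; the same renumbering idea would apply to the keys and values of the Dict.

```julia
# Remove particle `idx` from an edge list and renumber the remaining particles.
function remove_particle!(edges::Vector{NTuple{2, Int}}, idx::Int)
    # Drop every spring attached to the deleted particle.
    filter!(e -> e[1] != idx && e[2] != idx, edges)
    # Indices above `idx` move down by one because the state vector shrank.
    shift(n) = n > idx ? n - 1 : n
    map!(e -> (shift(e[1]), shift(e[2])), edges, edges)
    return edges
end

edges = [(1, 2), (1, 3), (2, 3), (3, 4)]
remove_particle!(edges, 2)  # old particles 3 and 4 become 2 and 3
```

In the MWE this issue is hidden because all connections are rebuilt from scratch after every deletion; it matters once the connections carry real structure.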