This is a mwe of code which integrates a network of ~100 - ~1000
interacting particles that can grow and die. In general, the connections are arbitrary and growth/death has logic but here I just connect every particle and destroy a random particle at each step. Timing this gives
17.380744 seconds (809.67 M allocations: 37.817 GiB, 10.43% gc time, 2.14% compilation time: 5% of which was recompilation)
and a profile shows things like this
I’m not entirely sure what can be done to reduce the number of allocations here or to speed it up - it seems to be complaining about the *
operation itself.
Code
using OrdinaryDiffEq
using LinearAlgebra: norm
using Random: seed!
mutable struct SpringParams{T<:Real}
k::Function
L::T
connections::Dict{Int64, Vector{Int64}}
end
function spring_force(
xy1::SubArray{T, 1, Vector{T}, Tuple{UnitRange{R}}},
xy2::SubArray{T, 1, Vector{T}, Tuple{UnitRange{R}}},
parameters::SpringParams{T}) where {T <:Real, R<:Integer}
d = norm(xy1 - xy2)
return parameters.k(d)*(parameters.L/d - 1)*(xy1 - xy2)
end
function Network!(du, u, p::SpringParams, t)
for i = 1:Integer(length(u)/2)
# note that x, y = u[2*i-1], u[2*i]
du[2*i-1] = 1.0
du[2*i] = 0.5
@views for j in p.connections[i]
du[2*i-1:2*i] .= du[2*i-1:2*i] .+ spring_force(u[2*i-1:2*i], u[2*j-1:2*j], p)
end
end
end
##################
seed!(1234)
ics = Iterators.product(range(1.0, 3.0, length = 20), range(5.0, 9.0, length = 20)) |> collect |> vec |> x -> [float(x[i][j]) for i = 1:length(x) for j = 1:2]
tspan = (0.0, 1.0)
n_items = length(ics)/2 |> Integer
sp = SpringParams(k -> 1.0, 1.0, Dict(i => [j for j = 1:n_items if j != i] for i = 1:n_items))
cb = DiscreteCallback(
(u, t, integrator) -> true,
integrator -> begin
n_items = Integer(length(integrator.u)/2)
idx = rand(1:n_items)
deleteat!(integrator, 2*idx-1) # x coordinate
deleteat!(integrator, 2*idx-1) # y coordinate (now where x used to be)
integrator.p.connections = Dict(i => [j for j = 1:n_items-1 if j != i] for i = 1:n_items-1) # have to update connections to take into account growth/death
nothing
end
)
prob = ODEProblem(Network!, ics, tspan, sp)
@time sol = solve(prob, Tsit5(), callback = cb);