I have tried to condense my code into a small example. I am happy to share the full code, but it is over 200 lines.
However, this is basically what it does: I declare a `determine_behaviour()` function, which takes a parameter set as input and classifies my system of interest according to how many fixed points it has, and how long it takes an SDE simulation of the system to pass a threshold (for that parameter set):
using Catalyst # Tool for creating biochemical reaction network models for DifferentialEquations.jl
using Polynomials
using StochasticDiffEq
"""
    determine_behaviour(p; m=4, l=2000)

Classify the parameter set `p` by how the model behaves for it.

Two properties are reported:
- `rt_type`: how many positive fixed points the model has (see `get_root_type`);
- `passage_type`: how many of `m` stochastic trajectories, each simulated over the
  time span `(0, l)`, pass a threshold (see `meassure_passage`).

Returns a named tuple `(rt_type = ..., passage_type = ...)` whose values are symbols.
"""
function determine_behaviour(p; m=4, l=2000)
    fixed_points = find_zeros(p)
    root_class = get_root_type(fixed_points)
    passage_class = meassure_passage(fixed_points, p, m, l)
    return (rt_type = root_class, passage_type = passage_class)
end
"""
    find_zeros(p)

Return the positive real fixed points of the model for parameter set `p`, sorted in
ascending order.

At a fixed point the A-equation τ(σ - A) = 0 gives σ = A = x, and the σ-equation becomes
x = v0 + (Sx)^n / ((Sx)^n + (Dx)^n + 1). Clearing the denominator gives the
degree-(n+1) polynomial

    -(S^n + D^n) x^(n+1) + (v0 (S^n + D^n) + S^n) x^n - x + v0 = 0,

whose real, strictly positive roots are returned.

`p` is read positionally as `(S, D, τ, v0, n, ...)`; trailing entries (e.g. a noise
scale) are ignored. `n` must be integer-valued (`3` or `3.0`); otherwise `Int(n)` throws
an `InexactError`.
"""
function find_zeros(p)
    S, D, _, v0, n = p
    k = Int(n)
    Sn, Dn = S^n, D^n
    # c[i] is the coefficient of x^(i-1) (ascending order, as `Polynomial` expects).
    c = zeros(k + 2)
    # Accumulate rather than assign: for n = 1 the x^1 and x^n terms share c[2], and
    # plain assignment would overwrite one of them, giving the wrong polynomial.
    c[1] += v0
    c[2] -= 1
    c[k+1] += v0 * (Sn + Dn) + Sn
    c[k+2] -= Sn + Dn
    candidates = roots(Polynomial(c))
    return sort(real.(filter(x -> imag(x) == 0 && real(x) > 0, candidates)))
end
"""
    get_root_type(rts)

Label a collection of fixed points by its size: `:single` for exactly one,
`:triple` for exactly three, and `:weird` for any other count (including zero).
"""
function get_root_type(rts)
    nfound = length(rts)
    return if nfound === 1
        :single
    elseif nfound === 3
        :triple
    else
        :weird
    end
end
# The model is a biochemical reaction network (declared through Catalyst).
# Species: σ and A. Parameters, in the order listed after `end`: S, D, τ, v0, n.
# Reactions (rate, reaction):
#   - σ is produced at rate v0 + (Sσ)^n / ((Sσ)^n + (DA)^n + 1), i.e. a basal rate v0
#     plus a Hill-type term that increases with σ and decreases with A;
#   - σ is degraded with unit rate constant;
#   - A is produced at rate τσ and degraded with rate constant τ (the `↔` pair).
# NOTE(review): callers pass a 6-element parameter vector (…, η); the extra η is
# presumably the noise-scaling parameter supplied via `noise_scaling` — confirm.
# NOTE(review): code that builds states positionally (e.g. `u0`) presumably relies on
# the species order (σ, A) — confirm against `species(model)`.
const model = @reaction_network begin
v0+(S*σ)^n/((S*σ)^n+(D*A)^n+1), ∅ → σ
1., σ → ∅
(τ*σ,τ), ∅ ↔ A
end S D τ v0 n
"""
    meassure_passage(rts, p, m, l)

Classify how often the system passes a threshold when it starts at its smallest positive
fixed point and is subjected to random fluctuations.

# Arguments
- `rts`: positive fixed points, sorted ascending (as returned by `find_zeros`).
- `p`: parameter vector forwarded to the SDE simulation.
- `m`: number of stochastic trajectories.
- `l`: end of the simulated time span `(0, l)`.

The threshold is `3 * rts[1]`. Returns `:Problem` when there is no fixed point or more
than three. Otherwise returns the class from `classify_passage_times`
(`:no_passage`, `:all_passage` or `:some_passage`).
"""
function meassure_passage(rts, p, m, l)
    # Nothing sensible to simulate without a fixed point or with an unexpected root count
    # (an empty `rts` previously raised a BoundsError on `rts[1]`).
    (isempty(rts) || length(rts) > 3) && return :Problem
    x0 = rts[1]
    # At a fixed point σ = A (the A-equation is τ(σ - A) = 0), so both species start at x0.
    # The model has two species; the previous one-element `rts[1:1]` did not give a full state.
    # NOTE(review): assumes the species order of `model` is (σ, A) — confirm.
    u0 = [x0, x0]
    passage_times = get_passage_times(p, u0, 3 * x0, m, l)
    return classify_passage_times(passage_times, l, m)
end
"""
    get_passage_times(p, u0, thres, m, l)

Simulate `m` stochastic trajectories of `model` from the initial state `u0` over the time
span `(0, l)`. Each trajectory is stopped by `terminate_sim(u0, thres)` once its first
state component has crossed `thres`.

Returns a vector with the final time of each trajectory. That is the step at which the
threshold crossing was detected, or `l` if the trajectory never crossed it.

Only these final times are used, so intermediate steps are not stored
(`save_everystep=false`). The terminating callback saves its own point
(`save_positions=(true, true)`), and the end of the time span is saved by default, so the
returned times are unchanged while far fewer states are allocated.
"""
function get_passage_times(p, u0, thres, m, l)
    # NOTE(review): building the SDEProblem from the symbolic `model` on every call may
    # trigger code generation/compilation, which does not parallelise well across threads;
    # constructing it once and using `remake` may scale better — measure before changing.
    prob = SDEProblem(model, u0, (0., l), p, noise_scaling=(@variables η)[1])
    ensemble_prob = EnsembleProblem(prob, prob_func=(prob, i, repeat) -> prob)
    sols = solve(ensemble_prob, ImplicitEM(), EnsembleSerial();
                 trajectories=m, callback=terminate_sim(u0, thres),
                 save_everystep=false)
    # `sols.u` is the vector of individual trajectory solutions.
    return [last(sol.t) for sol in sols.u]
end
"""
    classify_passage_times(passage_times, l, m)

Summarise the final times of `m` trajectories that were simulated up to time `l`.

A trajectory whose final time equals `l` is counted as never having passed the
threshold. Returns `:no_passage` if all `m` trajectories ended at `l`, `:all_passage` if
none did, and `:some_passage` otherwise.
"""
function classify_passage_times(passage_times, l, m)
    unfinished = count(==(l), passage_times)
    if unfinished === m
        return :no_passage
    elseif unfinished === 0
        return :all_passage
    else
        return :some_passage
    end
end
"""
    terminate_sim(u0, thres)

Build a `DiscreteCallback` that terminates the integration at the end of the first step
where the first state component has crossed `thres`, moving away from its starting
value `u0[1]`.

If `u0[1] < thres`, the trigger is `u[1] > thres`; otherwise it is `u[1] < thres`.
The state is saved before and after termination (`save_positions=(true, true)`).
"""
function terminate_sim(u0, thres)
    starts_below = u0[1] < thres
    condition = if starts_below
        (u, t, integrator) -> u[1] > thres
    else
        (u, t, integrator) -> u[1] < thres
    end
    return DiscreteCallback(condition, terminate!; save_positions=(true, true))
end
This is the code I use to run it with, and without, multithreading:
# Measures the performance.
# The first call to `determine_behaviour` also pays JIT-compilation cost, so its timing is
# not representative of later calls.
@time determine_behaviour([1.,1.,1.,0.1,3,0.05]) # 4.282405 seconds (13.99 M allocations: 633.620 MiB, 3.79% gc time)
# Both loops evaluate the same eight parameter sets (S = 1, ..., 8), so the threaded and
# serial timings compare identical workloads.
@time Threads.@threads for i = 1.:1.:8.
    determine_behaviour([i,1.,1.,0.1,3,0.05])
end # 41.056272 seconds (112.26 M allocations: 4.972 GiB, 3.38% gc time)
@time for i = 1.:1.:8.
    determine_behaviour([i,1.,1.,0.1,3,0.05])
end # 34.791720 seconds (111.91 M allocations: 4.950 GiB, 3.30% gc time) — measured with the old body, which reused S = 1 in every iteration; re-measure
Garbage collection time is only about 3%.
The following is closer to my real use case, but takes slightly longer to run: in practice I loop over more parameter sets and write the results to disk:
# Technically more like this, but this has slightly longer run time.
# For each S = i, evaluate eight values of D = j and serialize the results to "data_i_$i".
# Both loops write the same files, so the serial run overwrites the threaded output.
using Serialization
@time Threads.@threads for i = 1:1.:8.
    # The comprehension yields a concretely typed
    # Vector{NamedTuple{(:rt_type, :passage_type), Tuple{Symbol,Symbol}}},
    # so both loops serialize the same element type (previously the serial loop used Vector{Any}).
    behaviours = [determine_behaviour([i,j,1.,0.1,3,0.05]) for j = 1:8]
    serialize("data_i_$i", behaviours)
end # 383.943875 seconds (920.99 M allocations: 42.692 GiB, 3.03% gc time)
@time for i = 1:1.:8.
    behaviours = [determine_behaviour([i,j,1.,0.1,3,0.05]) for j = 1:8]
    serialize("data_i_$i", behaviours)
end # 320.130658 seconds (900.79 M allocations: 41.852 GiB, 3.02% gc time)
Finally, here is a more minimal example that depends only on StochasticDiffEq:
using StochasticDiffEq
"""
    f(du, u, p, t)

In-place drift of the minimal SDE, with state `u = (σ, A)` and parameters
`p = (S, D, τ, v0, n, η)` (η is unused here).

    dσ/dt = v0 + (Sσ)^n / ((Sσ)^n + (DA)^n + 1) - σ
    dA/dt = τ (σ - A)
"""
function f(du, u, p, t)
    σ, A = u
    S, D, τ, v0, n, η = p
    activation = (S * σ)^n / ((S * σ)^n + (D * A)^n + 1)
    du[1] = v0 + activation - σ
    du[2] = τ * (σ - A)
end
"""
    g(du, u, p, t)

In-place 2×4 noise matrix of the minimal SDE, with state `u = (σ, A)` and parameters
`p = (S, D, τ, v0, n, η)`.

Rows are the states (σ, A). Columns are the four reactions: σ production, σ degradation,
A production and A degradation. Each nonzero entry is ±η times the square root of that
reaction's rate, which matches the chemical Langevin form; all other entries are zero.
"""
function g(du, u, p, t)
    σ, A = u
    S, D, τ, v0, n, η = p
    fill!(du, 0)
    # σ row: production (Hill-type rate) and unit-rate degradation.
    du[1,1] = η * sqrt(v0 + (S*σ)^n / ((S*σ)^n + (D*A)^n + 1))
    du[1,2] = -η * sqrt(σ)
    # A row: production at rate τσ and degradation at rate τA.
    du[2,3] = η * sqrt(τ*σ)
    du[2,4] = -η * sqrt(τ*A)
end
"""
    run_sims(p)

Solve eight independent trajectories of the minimal SDE (drift `f`, noise `g` with four
noise columns). Each starts from `[1, 1]` and runs over `(0, 50000)` using `ImplicitEM`,
one after another (`EnsembleSerial`). Returns the ensemble solution.
"""
function run_sims(p)
    tspan = (0., 50000.)
    problem = SDEProblem(f, g, [1.,1.], tspan, p; noise_rate_prototype=zeros(2, 4))
    ensemble = EnsembleProblem(problem; prob_func=(prob, i, repeat) -> prob)
    return solve(ensemble, ImplicitEM(), EnsembleSerial(); trajectories=8)
end
# Single call; the first call also pays JIT-compilation cost.
@time run_sims([1.,1.,1.,0.1,3,0.05]) # 2.346126 seconds (4.73 M allocations: 253.060 MiB, 11.12% gc time)
# Same eight parameter sets (S = 1, ..., 8), threaded vs. serial.
@time Threads.@threads for s in 1.0:8.0
    run_sims([s, 1.0, 1.0, 0.1, 3, 0.05])
end # 22.173271 seconds (58.21 M allocations: 2.946 GiB, 1.01% gc time)
@time for s in 1.0:8.0
    run_sims([s, 1.0, 1.0, 0.1, 3, 0.05])
end # 33.169600 seconds (58.22 M allocations: 2.947 GiB, 5.07% gc time)
This also fails to utilise multithreading.
(Everything is run on an 8-core Linux machine, using 8 threads.)