DifferentialEquations.jl has facilities for solving an ensemble of problems in parallel, each with different parameters (see the docs). In further testing of my prior code, I added the option to run parallel threads using EnsembleThreads(), as documented on that docs page. The following code crashes as the size of the problem increases, for reasons that I cannot find.
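For context, the documented ensemble pattern looks roughly like this (a toy sketch, not my model; the decay ODE and the parameter sweep are placeholders):

using OrdinaryDiffEq
decay!(du, u, p, t) = (du .= -p .* u)                  # toy scalar decay ODE
prob = ODEProblem(decay!, [1.0], (0.0, 1.0), 0.5)
# give trajectory i its own parameter value
prob_func(prob, i, repeat) = remake(prob, p = 0.1 * i)
eprob = EnsembleProblem(prob, prob_func = prob_func)
sol = solve(eprob, Tsit5(), EnsembleThreads(), trajectories = 8)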
To run the code:
import FrenchFlagThreads as fft
Of the following calls, the first four work and the last two crash; the first lines of each crash report are shown below. Note also that in the REPL with Revise, one of the first four commands sometimes fails on the first run but succeeds when rerun immediately, suggesting some sort of compilation issue as a separate problem.
fft.main(S=3, B=256, threads=false, ensemble=false)
fft.main(S=3, B=256, threads=false, ensemble=true)
fft.main(S=3, B=256, threads=true, ensemble=true)
fft.main(S=7, B=256, threads=false, ensemble=false)
Those first four work (apart from occasionally needing a rerun because of the REPL/compilation issue noted above).
fft.main(S=7, B=256, threads=false, ensemble=true)
fails with output:
0.25135885013308756
0.2627287236602456
0.2510416786361418
0.25157018945790177
ERROR: BoundsError: attempt to access 13-element Vector{Vector{Float64}} at index [14]
Stacktrace:
[1] getindex
@ ./array.jl:805 [inlined]
[2] (::DiffEqSensitivity.var"#253#259"{Vector{Matrix{Float64}}, Vector{Vector{Float64}}, Vector{Float64}})(i::Int64)
@ DiffEqSensitivity /opt/julia/packages/DiffEqSensitivity/uakCr/src/concrete_solve.jl:527
[3] _mapreduce(f::DiffEqSensitivity.var"#253#259"{Vector{Matrix{Float64}}, Vector{Vector{Float64}}, Vector{Float64}}, op::typeof(Base.add_sum), #unused#::IndexLinear, A::Base.OneTo{Int64})
@ Base ./reduce.jl:411
fft.main(S=7, B=256, threads=true, ensemble=true)
fails with output (presumably the same underlying error, surfacing from a worker task):
0.25135885013308756
0.2627287236602456
0.2510416786361418
0.25157018945790177
ERROR: TaskFailedException
Stacktrace:
[1] wait
@ ./task.jl:322 [inlined]
[2] threading_run(func::Function)
@ Base.Threads ./threadingconstructs.jl:34
[3] macro expansion
@ ./threadingconstructs.jl:93 [inlined]
[4] tmap(::Function, ::Vector{typeof(∂(λ))}, ::Vararg{Any, N} where N)
@ SciMLBase /opt/julia/packages/SciMLBase/x3z0g/src/ensemble/basic_ensemble_solve.jl:220
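To surface the full failure behind the truncated TaskFailedException, one can run the call under a try/catch; showerror on the exception includes the nested task error, which should name the real exception raised on the worker thread (I would guess the same BoundsError as in the serial ensemble case):

try
    fft.main(S=7, B=256, threads=true, ensemble=true)
catch e
    showerror(stdout, e)   # includes the "nested task error" section
end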
Here is the code:
module FrenchFlagThreads
using Random: seed!
using RandomNumbers.Xorshifts
using DifferentialEquations, DiffEqFlux
using DiffEqSensitivity, OrdinaryDiffEq, ForwardDiff
using Statistics
r = Xoroshiro128Plus(0x1234567890abcdef)
const final_t = 3.0 # final time
const tspan = (0.0,final_t) # range of times
const atol = 1e-6 # abs tol for diffeq
const rtol = 1e-4 # rel tol for diffeq
####################################################################
phi(x) = 1.0 ./ (exp.(-x) .+ 1.0)   # elementwise logistic sigmoid
function ode!(ds, s, w, t, input)
    # ds/dt = phi(w*s) - s + input, with w the S×S weight matrix
    # and input the constant external input for this replicate
    ds .= phi(w*s) .- s .+ input
end
function newbatch(B, S; stochastic=false)
    initial = 0.1 .* ones(B,S)
    if stochastic
        inputvec = 2.0 .* rand(r,B)
    else
        inputvec = [range(0,2,length=B);] # need semicolon to expand range
    end
    # make inputs [B,S] with first column as input to y0 at t=0, other cols as zeros
    input = hcat(inputvec, zeros(B,S-1))
    output = [(inputvec[i] >= 0.5 && inputvec[i] <= 1.5) ? 1.0 : 0.0 for i in 1:B]
    # map fails because of bug in autodiff
    # output = map(x -> convert(Float64,x >= 0.5 && x <= 1.5), inputvec)
    return input, initial, output
end
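# For reference, a sketch of what newbatch returns for B=4, S=3 (deterministic):
#   input   :: 4×3 Matrix, column 1 the graded input over [0,2], other columns zero
#   initial :: 4×3 Matrix, all entries 0.1
#   output  :: length-4 Vector, 1.0 where 0.5 <= inputvec[i] <= 1.5, else 0.0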
function loss(p, B, S, stoch_batch; threads=false, use_ensemble=false)
    input, initial, output = newbatch(B, S, stochastic=stoch_batch)
    local result   # declare here so assignments inside the loops below persist
    if use_ensemble
        prob = ODEProblem((ds, s, p, t) -> ode!(ds, s, w(p,S), t, input[1,:]),
                          initial[1,:], tspan, p)
        function prob_func(prob, i, repeat)
            remake(prob, f = (ds, s, p, t) -> ode!(ds, s, w(p,S), t, input[i,:]),
                   u0 = initial[i,:])
        end
        ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
        ensemble = threads ? EnsembleThreads() : EnsembleSerial()
        soln = solve(ensemble_prob, Tsit5(), ensemble, save_on = false, trajectories = B,
                     reltol=rtol, abstol=atol)
        # soln is a vector of length B; each trajectory's final state is at [end]
        # store each batch element as a column because Julia is column major and so faster
        result = soln[1][end]
        for i in 2:B
            result = hcat(result, soln[i][end])
        end
    else
        for i in 1:B
            prob = ODEProblem((ds, s, p, t) -> ode!(ds, s, w(p,S), t, input[i,:]),
                              initial[i,:], tspan, p)
            soln = solve(prob, Tsit5(), save_on = false, reltol=rtol, abstol=atol)
            # store each batch element as a column because Julia is column major and so faster
            result = (i == 1) ? soln[end] : hcat(result, soln[end])
        end
    end
    # the last row of result (the last species of each replicate) is the phenotype;
    # compare the phenotype to the expected output
    loss_val = mean(abs2.(a(p) .* result[end,:] .- output))
    # return loss; sciml_train passes it to the callback
    return loss_val
end
function callback(p, l)
    display(l)
    return false
end
w(p, S) = reshape(p[1:end-1],S,S)
a(p) = p[end]
# parameters p:
# a = p[end] = 1.0
# w = reshape(p[1:end-1],S,S)
function main(;B::Int=4, S::Int=3, stoch_batch::Bool=false,
              rnd_entropy::Bool=false, threads::Bool=false, ensemble::Bool=false)
    rnd_entropy ? seed!(r) : seed!(r, 0x1234567437ab88ef)
    p = ones(S*S+1)
    p[1:end-1] = 0.1 .* randn(r, S*S)
    # try without ADAM arg, should then be ADAM -> BFGS by default
    # or do sequence with ADAM then BFGS
    opt_result = DiffEqFlux.sciml_train(p -> loss(p, B, S, stoch_batch;
                                                  threads=threads, use_ensemble=ensemble),
                                        p, ADAM(0.1); cb = callback, maxiters = 50)
    finalp = opt_result.u
    println()
    print(opt_result)
    println()
    print("Final param = ", typeof(finalp))
    println()
    println("Loss calculated from final param = ",
            loss(finalp, B, S, stoch_batch; threads=threads, use_ensemble=ensemble))
    return w(opt_result.u,S)
end

end # module FrenchFlagThreads
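One more note: EnsembleThreads() only runs in parallel when Julia is started with multiple threads, which can be checked with the standard Base call:

Threads.nthreads()

Either way, the ensemble=true failure also occurs with EnsembleSerial() (the threads=false case above), so threading itself does not appear to be the root cause.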