Hello everyone,
I updated to the newest packages of Zygote and Flux and I get error apparently related to the paralelization of an ensemble simulation
using DifferentialEquations, Flux, DiffEqFlux
using DiffEqSensitivity
using Random
function dt!(du, u, p, t)
x, y = u
α, β, δ, γ = p
du[1] = dx = α*x - β*x*y
du[2] = dy = -δ*y + γ*x*y
end
n_par=3
Random.seed!(2)
u0=rand(2,n_par)
u0[:,1] = [1.0,1.0]
tspan = (0.0, 10.0)
p = [2.2, 1.0, 2.0, 0.4]
prob_ode = ODEProblem(dt!, u0[:,1], tspan)
function test_loss(p1,prob)
function prob_func(prob, i, repeat)
@show i
remake(prob,u0=u0[:,i])
end
#define ensemble problem
ensembleprob = EnsembleProblem(prob,prob_func = prob_func)
u = Array(solve(ensembleprob, EM(),trajectories=n_par,
ensemblealg=EnsembleThreads(), p=p,
sensealg = ForwardDiffSensitivity(),
saveat = 0.1, dt=0.001))[:,end,:]
loss=sum(u)
return loss
end
#testing backprop
ps = Flux.params(p)
@time gs = gradient(ps) do
test_loss(p,prob_ode)
end
#ERROR
ERROR: Compiling Tuple{typeof(Base.Threads.threading_run),SciMLBase.var"#400#threadsfor_fun#446"{SciMLBase.var"#443#445"{Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :p, :sensealg, :saveat, :dt),Tuple{EnsembleThreads,Array{Float64,1},ForwardDiffSensitivity{0,nothing},Float64,Float64}}},EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,SciMLBase.NullParameters,ODEFunction{true,typeof(dt!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},var"#prob_func#7",typeof(SciMLBase.DEFAULT_OUTPUT_FUNC),typeof(SciMLBase.DEFAULT_REDUCTION),Nothing},EM{true},UnitRange{Int64},Int64,Int64},Tuple{UnitRange{Int64}},Array{Array{T,1} where T,1},UnitRange{Int64}}}: try/catch is not supported.
The versions of the packages are:
[aae7a2af] DiffEqFlux v1.34.0
[41bf760c] DiffEqSensitivity v6.42.0
[0c46a032] DifferentialEquations v6.16.0
[587475ba] Flux v0.11.6
[e88e6eb3] Zygote v0.6.2
For oler these versions the code works:
[aae7a2af] DiffEqFlux v1.23.0
[41bf760c] DiffEqSensitivity v6.33.0
[0c46a032] DifferentialEquations v6.15.0
[587475ba] Flux v0.11.1
[e88e6eb3] Zygote v0.5.9
Any ideas what’s wrong?