Error when getting grads of an ensemble simulation with the newest packages

Hello everyone,

I updated to the newest versions of Zygote and Flux, and I now get an error that is apparently related to the parallelization of an ensemble simulation:

using DifferentialEquations, Flux,  DiffEqFlux
using DiffEqSensitivity
using Random

# In-place right-hand side of a two-state system: writes the derivative of
# `u` into `du`, using the four coefficients unpacked from `p`. `t` is unused.
function dt!(du, u, p, t)
    prey, predator = u
    α, β, δ, γ = p
    du[1] = α * prey - β * prey * predator
    du[2] = -δ * predator + γ * prey * predator
end

# Number of ensemble trajectories; `u0` holds one initial state per column.
n_par = 3
Random.seed!(2)  # fixed seed so the random initial states are reproducible
u0 = rand(2, n_par)
u0[:, 1] .= [1.0, 1.0]  # overwrite the first column with a fixed initial state
tspan = (0.0, 10.0)
p = [2.2, 1.0, 2.0, 0.4]  # unpacked as α, β, δ, γ inside `dt!`
# Base problem built from the first column; no parameters are attached here.
prob_ode = ODEProblem(dt!, u0[:, 1], tspan)

"""
    test_loss(p1, prob)

Solve an ensemble of `n_par` trajectories derived from `prob` with parameters
`p1` and return the sum of the entries selected by `[:, end, :]` from the
solution array.

# Arguments
- `p1`: parameter vector passed to the solver through the `p` keyword.
- `prob`: base problem; trajectory `i` is a `remake` of it whose initial state
  is column `i` of the global `u0`.

Reads the globals `u0` and `n_par`.
"""
function test_loss(p1, prob)
    # Trajectory `i` starts from the i-th column of the global `u0`.
    function prob_func(prob, i, repeat)
        @show i
        return remake(prob, u0=u0[:, i])
    end

    ensembleprob = EnsembleProblem(prob, prob_func=prob_func)

    # Pass the function argument `p1` to the solver. The original passed the
    # global `p` here, which silently ignored whatever the caller supplied.
    # NOTE(review): `EM()` is the Euler–Maruyama stepper, normally used for
    # SDE problems — confirm it is intended for the problem passed in.
    sol = solve(ensembleprob, EM();
                trajectories=n_par,
                ensemblealg=EnsembleThreads(),
                p=p1,
                sensealg=ForwardDiffSensitivity(),
                saveat=0.1,
                dt=0.001)

    # Keep the last index along the second axis for every trajectory
    # (presumably the final saved time point — confirm the axis order).
    u = Array(sol)[:, end, :]
    return sum(u)
end

#testing backprop
# Collect `p` as an implicit parameter set and differentiate the loss with
# respect to it; `@time` reports the cost of the whole gradient call.
ps = Flux.params(p)
loss_closure = () -> test_loss(p, prob_ode)
@time gs = gradient(loss_closure, ps)

#ERROR
ERROR: Compiling Tuple{typeof(Base.Threads.threading_run),SciMLBase.var"#400#threadsfor_fun#446"{SciMLBase.var"#443#445"{Base.Iterators.Pairs{Symbol,Any,NTuple{5,Symbol},NamedTuple{(:ensemblealg, :p, :sensealg, :saveat, :dt),Tuple{EnsembleThreads,Array{Float64,1},ForwardDiffSensitivity{0,nothing},Float64,Float64}}},EnsembleProblem{ODEProblem{Array{Float64,1},Tuple{Float64,Float64},true,SciMLBase.NullParameters,ODEFunction{true,typeof(dt!),LinearAlgebra.UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,typeof(SciMLBase.DEFAULT_OBSERVED),Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},SciMLBase.StandardODEProblem},var"#prob_func#7",typeof(SciMLBase.DEFAULT_OUTPUT_FUNC),typeof(SciMLBase.DEFAULT_REDUCTION),Nothing},EM{true},UnitRange{Int64},Int64,Int64},Tuple{UnitRange{Int64}},Array{Array{T,1} where T,1},UnitRange{Int64}}}: try/catch is not supported.

The versions of the packages are:
[aae7a2af] DiffEqFlux v1.34.0
[41bf760c] DiffEqSensitivity v6.42.0
[0c46a032] DifferentialEquations v6.16.0
[587475ba] Flux v0.11.6
[e88e6eb3] Zygote v0.6.2

With these older versions the code works:
[aae7a2af] DiffEqFlux v1.23.0
[41bf760c] DiffEqSensitivity v6.33.0
[0c46a032] DifferentialEquations v6.15.0
[587475ba] Flux v0.11.1
[e88e6eb3] Zygote v0.5.9

Any ideas what’s wrong?

It was a regression from the SciMLBase change. It is fixed in the pull request "Fix AD of ensemble problems with threading" by ChrisRackauckas (Pull Request #647 · SciML/DiffEqBase.jl · GitHub).

Great, thanks!