Hi all,
I can’t seem to get my stochastic UODE running. Here’s the setup and generation of synthetic data:
using Plots, DiffEqFlux, Flux, Optim, LinearAlgebra, DifferentialEquations
using Statistics
using DiffEqSensitivity
using DifferentialEquations.EnsembleAnalysis
function sys!(du, u, p, t)
r, e, μ, h, ph, z, i = p
du[1] = e * 0.5 * (5μ - u[1]) # nutrient input time series
du[2] = e * 0.05 * (10μ - u[2]) # grazer density time series
du[3] = 0.2 * exp(u[1]) - 0.05 * u[3] - r * u[3] / (h + u[3]) * u[4] # nutrient concentration
du[4] =
r * u[3] / (h + u[3]) * u[4] - 0.1 * u[4] -
0.02 * u[4]^z / (ph^z + u[4]^z) * exp(u[2] / 2.0) + i #Algae density
end
function noise!(du, u, p, t)
du[1] = p[end] # n
du[2] = p[end] # n
du[3] = 0.0
du[4] = 0.0
end
# I've dropped the time frame down quite a bit here just to try to get it to run.
#datasize = 30
#tspan = (0.0f0, 10.0f0)
datasize = 10
tspan = (0.0f0, 3.0f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
u0 = Float32[1.0, 1.0, 1.0, 1.0]
p_ = Float32[1.1, 1.0, 0.0, 2.0, 1.0, 1.0, 1e-6, 1.0]
prob = SDEProblem(sys!, noise!, u0, tspan, p_)
ensembleprob = EnsembleProblem(prob)
solution = solve(
ensembleprob,
EnsembleThreads();
trajectories = 1000,
abstol = 1e-5,
reltol = 1e-5,
maxiters = 1e8,
saveat = tsteps,
)
(truemean, truevar) = Array.(timeseries_steps_meanvar(solution))
I’ve got 14 hardware threads, so this doesn’t take too long to run (@btime
on the solve step is 23 seconds).
The plan is to solve for the Michaelis-Menton kinetics portions of eqs 3 and 4, so I set up my UODE following a combination of documentation pages (I’m not 100% sure there’s an example that matches my case): neural_sde, PINNs, zygote example, uode examples.
# Generate a NN with four inputs [L, G, N, A] and two outputs [NN1, NN2]
ann = FastChain(FastDense(4, 32, tanh), FastDense(32, 32, tanh), FastDense(32, 2))
α = initial_params(ann)
function dudt_(du, u, p, t)
r, e, μ, h, ph, z, i = p_
# we replace each of our MM values with the result of the NN.
MM = ann(u, p)
du[1] = e * 0.5 * (5μ - u[1]) # nutrient input time series
du[2] = e * 0.05 * (10μ - u[2]) # grazer density time series
du[3] = 0.2 * exp(u[1]) - 0.05 * u[3] - MM[1] # nutrient concentration
du[4] = MM[2] - 0.1 * u[4] - 0.02 * u[4]^z / (ph^z + u[4]^z) * exp(u[2] / 2.0) + i #Algae density
end
function noise_(du, u, p, t)
du[1] = p_[end]
du[2] = p_[end]
du[3] = 0.0
du[4] = 0.0
end
prob_nn = SDEProblem(dudt_, noise_, u0, tspan, p = nothing)
function loss(θ)
tmp_prob = remake(prob_nn, p = θ)
ensembleprob = EnsembleProblem(tmp_prob)
tmp_sol = solve(
ensembleprob,
EnsembleThreads();
saveat = tsteps,
trajectories = 100, # Drop down a little for troubleshooting
# sensealg = ReverseDiffAdjoint(),
# sensealg = ForwardDiffSensitivity(),
)
(tmp_mean, tmp_var) = Array.(timeseries_steps_meanvar(tmp_sol))
sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var), tmp_mean
end
# Test
loss(α) # @btime 360 ms
const losses = []
callback(θ, l, pred) = begin
push!(losses, l)
println("Current loss after $(length(losses)) iterations: $(losses[end])")
false
end
res1 = DiffEqFlux.sciml_train(
loss,
α,
ADAM(0.1),
cb = callback,
maxiters = 200,
# allow_f_increases = true,
)
If I’ve got this set up correctly, I cannot figure out what sensealg
works here. If I choose the default (i.e. omit a sensealg
keyword which is happening in the loss
function above), I get an unhelpful error stating:
Training 100%|█████████████████████████████████████████████████████████████| Time: 0:00:54
ERROR: Mutating arrays is not supported
Stacktrace
Stacktrace:
[1] error(::String) at ./error.jl:33
[2] (::Zygote.var"#372#373")(::Nothing) at /home/tim/.julia/packages/Zygote/c0awc/src/lib/array.jl:65
[3] (::Zygote.var"#2265#back#374"{Zygote.var"#372#373"})(::Nothing) at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[4] timeseries_steps_meanvar at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/ensemble_analysis.jl:71 [inlined]
[5] (::typeof(∂(timeseries_steps_meanvar)))(::Tuple{Array{Float64,2},Array{Float64,2}}) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[6] loss at ./REPL[22]:12 [inlined]
[7] (::typeof(∂(loss)))(::Tuple{Float64,Nothing}) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[8] #150 at /home/tim/.julia/packages/Zygote/c0awc/src/lib/lib.jl:191 [inlined]
[9] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{typeof(∂(loss)),Tuple{Tuple{Nothing},Tuple{}}}})(::Tuple{Float64,Nothing}) at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
[10] #74 at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:120 [inlined]
[11] (::typeof(∂(λ)))(::Float64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[12] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:177
[13] gradient(::Function, ::Zygote.Params) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:54
[14] macro expansion at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:119 [inlined]
[15] macro expansion at /home/tim/.julia/packages/ProgressLogging/BBN0b/src/ProgressLogging.jl:328 [inlined]
[16] (::DiffEqFlux.var"#73#78"{typeof(callback),Int64,Bool,Bool,typeof(loss),Array{Float32,1},Zygote.Params})() at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:64
[17] with_logstate(::Function, ::Any) at ./logging.jl:408
[18] with_logger at ./logging.jl:514 [inlined]
[19] maybe_with_logger(::DiffEqFlux.var"#73#78"{typeof(callback),Int64,Bool,Bool,typeof(loss),Array{Float32,1},Zygote.Params}, ::LoggingExtras.TeeLogger{Tuple{LoggingExtras.EarlyFilteredLogger{TerminalLoggers.TerminalLogger,DiffEqFlux.var"#68#70"},LoggingExtras.EarlyFilteredLogger{Logging.ConsoleLogger,DiffEqFlux.var"#69#71"}}}) at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:39
[20] sciml_train(::Function, ::Array{Float32,1}, ::ADAM, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:63
[21] top-level scope at REPL[26]:2
According to the DiffEqSensitivity docs, I should be using QuadratureAdjoint
for my system size, which if I’m not mistaken is what the defaults actually picked, since I get the Mutating arrays not supported error
using it. At what point am I mutating arrays?
ReverseDiffAdjoint
gives me an equally opaque error somewhere deep:
Training 100%|█████████████████████████████████████████████████████████████| Time: 0:01:12
ERROR: TaskFailedException:
MethodError: no method matching randn(::RandomNumbers.Xorshifts.Xoroshiro128Plus, ::Type{ReverseDiff.TrackedReal{Float32,Float32,Nothing}})
Closest candidates are:
randn(::Random.AbstractRNG, ::Type{T}, ::Tuple{Vararg{Int64,N}} where N) where T at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:201
randn(::Random.AbstractRNG, ::Type{T}, ::Integer, ::Integer...) where T at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:204
randn(::Random.AbstractRNG) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:38
...
Stacktrace
Stacktrace:
[1] randn!(::RandomNumbers.Xorshifts.Xoroshiro128Plus, ::Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:178
[2] wiener_randn! at /home/tim/.julia/packages/DiffEqNoiseProcess/KaHvM/src/wiener.jl:9 [inlined]
[3] INPLACE_WHITE_NOISE_DIST(::Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}, ::NoiseProcess{ReverseDiff.TrackedReal{Float32,Float32,Nothing},2,Float32,Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},Array{Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},1},typeof(DiffEqNoiseProcess.INPLACE_WHITE_NOISE_DIST),typeof(DiffEqNoiseProcess.INPLACE_WHITE_NOISE_BRIDGE),true,ResettableStacks.ResettableStack{Tuple{Float32,Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}},true},ResettableStacks.ResettableStack{Tuple{Float32,Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}},true},RSWM{Float64},Nothing,RandomNumbers.Xorshifts.Xoroshiro128Plus}, ::Float32, ::Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::Float32, ::RandomNumbers.Xorshifts.Xoroshiro128Plus) at /home/tim/.julia/packages/DiffEqNoiseProcess/KaHvM/src/wiener.jl:42
[4] setup_next_step! at /home/tim/.julia/packages/DiffEqNoiseProcess/KaHvM/src/noise_interfaces/noise_process_interface.jl:145 [inlined]
[5] setup_next_step! at /home/tim/.julia/packages/StochasticDiffEq/Abmgl/src/integrators/integrator_utils.jl:2 [inlined]
[6] __init(::SDEProblem{Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1},Tuple{Float32,Float32},true,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}, ::SOSRI, ::Array{Any,1}, ::Array{Any,1}, ::Type{T} where T, ::Type{Val{true}}; saveat::StepRangeLen{Float32,Float64,Float64}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_noise::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::Nothing, dense::Bool, calck::Bool, dt::Float32, adaptive::Bool, gamma::Rational{Int64}, abstol::Nothing, reltol::Nothing, qmax::Rational{Int64}, qmin::Rational{Int64}, qoldinit::Rational{Int64}, fullnormalize::Bool, failfactor::Int64, beta2::Rational{Int64}, beta1::Rational{Int64}, delta::Rational{Int64}, maxiters::Int64, dtmax::Float32, dtmin::Float32, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, force_dtmin::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, initialize_integrator::Bool, seed::UInt64, alias_u0::Bool, alias_jumps::Bool, kwargs::Base.Iterators.Pairs{Symbol,Bool,Tuple{Symbol},NamedTuple{(:default_set,),Tuple{Bool}}}) at /home/tim/.julia/packages/StochasticDiffEq/Abmgl/src/solve.jl:562
[7] #__solve#97 at /home/tim/.julia/packages/StochasticDiffEq/Abmgl/src/solve.jl:6 [inlined]
[8] __solve(::SDEProblem{Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1},Tuple{Float32,Float32},true,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}, ::Nothing; default_set::Bool, kwargs::Base.Iterators.Pairs{Symbol,StepRangeLen{Float32,Float64,Float64},Tuple{Symbol},NamedTuple{(:saveat,),Tuple{StepRangeLen{Float32,Float64,Float64}}}}) at /home/tim/.julia/packages/DifferentialEquations/fpohE/src/default_solve.jl:7
[9] #solve_call#456 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:65 [inlined]
[10] #solve_up#458 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:92 [inlined]
[11] #solve#457 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:74 [inlined]
[12] (::DiffEqSensitivity.var"#reversediff_adjoint_forwardpass#168"{Base.Iterators.Pairs{Symbol,StepRangeLen{Float32,Float64,Float64},Tuple{Symbol},NamedTuple{(:saveat,),Tuple{StepRangeLen{Float32,Float64,Float64}}}},SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},Nothing,Tuple{}})(::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}) at /home/tim/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/concrete_solve.jl:315
[13] ReverseDiff.GradientTape(::Function, ::Tuple{Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.GradientConfig{Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}) at /home/tim/.julia/packages/ReverseDiff/jFRo1/src/api/tape.jl:207
[14] GradientTape at /home/tim/.julia/packages/ReverseDiff/jFRo1/src/api/tape.jl:204 [inlined]
[15] _concrete_solve_adjoint(::SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}, ::Nothing, ::ReverseDiffAdjoint, ::Array{Float32,1}, ::Array{Float32,1}; kwargs::Base.Iterators.Pairs{Symbol,StepRangeLen{Float32,Float64,Float64},Tuple{Symbol},NamedTuple{(:saveat,),Tuple{StepRangeLen{Float32,Float64,Float64}}}}) at /home/tim/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/concrete_solve.jl:326
[16] #_solve_adjoint#478 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:277 [inlined]
[17] #adjoint#469 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:241 [inlined]
[18] _pullback at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:53 [inlined]
[19] adjoint at /home/tim/.julia/packages/Zygote/c0awc/src/lib/lib.jl:188 [inlined]
[20] _pullback at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
[21] #solve#457 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:74 [inlined]
[22] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve#457", ::ReverseDiffAdjoint, ::Nothing, ::Nothing, ::Base.Iterators.Pairs{Symbol,StepRangeLen{Float32,Float64,Float64},Tuple{Symbol},NamedTuple{(:saveat,),Tuple{StepRangeLen{Float32,Float64,Float64}}}}, ::typeof(solve), ::SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}, ::Nothing) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[23] adjoint at /home/tim/.julia/packages/Zygote/c0awc/src/lib/lib.jl:188 [inlined]
[24] _pullback at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
[25] _pullback(::Zygote.Context, ::DiffEqBase.var"#solve##kw", ::NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}, ::typeof(solve), ::SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}, ::Nothing) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[26] #batch_func#360 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:143 [inlined]
[27] _pullback(::Zygote.Context, ::DiffEqBase.var"##batch_func#360", ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}}, ::typeof(DiffEqBase.batch_func), ::Int64, ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Nothing) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[28] #365 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:187 [inlined]
[29] #499 at /home/tim/.julia/packages/Zygote/c0awc/src/lib/array.jl:180 [inlined]
[30] iterate at ./generator.jl:47 [inlined]
[31] _collect(::UnitRange{Int64}, ::Base.Generator{UnitRange{Int64},Zygote.var"#499#503"{Zygote.Context,DiffEqBase.var"#365#366"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}},EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Nothing}}}, ::Base.EltypeUnknown, ::Base.HasShape{1}) at ./array.jl:699
[32] collect_similar(::UnitRange{Int64}, ::Base.Generator{UnitRange{Int64},Zygote.var"#499#503"{Zygote.Context,DiffEqBase.var"#365#366"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}},EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Nothing}}}) at ./array.jl:628
[33] map(::Function, ::UnitRange{Int64}) at ./abstractarray.jl:2162
[34] ∇map(::Zygote.Context, ::Function, ::UnitRange{Int64}) at /home/tim/.julia/packages/Zygote/c0awc/src/lib/array.jl:180
[35] adjoint at /home/tim/.julia/packages/Zygote/c0awc/src/lib/array.jl:196 [inlined]
[36] _pullback at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
[37] #solve_batch#364 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:186 [inlined]
[38] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve_batch#364", ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}}, ::typeof(DiffEqBase.solve_batch), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Nothing, ::EnsembleSerial, ::UnitRange{Int64}, ::Int64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0 (repeats 2 times)
[39] #369 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:213 [inlined]
[40] _pullback(::Zygote.Context, ::DiffEqBase.var"#369#371"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}},EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Nothing,UnitRange{Int64},Int64,Int64}, ::Int64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[41] #639 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/init.jl:233 [inlined]
[42] macro expansion at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:221 [inlined]
[43] (::DiffEqBase.var"#510#threadsfor_fun#372"{DiffEqBase.var"#639#644"{Zygote.Context,DiffEqBase.var"#369#371"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}},EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Nothing,UnitRange{Int64},Int64,Int64}},Tuple{UnitRange{Int64}},Array{Tuple{Any,typeof(∂(λ))},1},UnitRange{Int64}})(::Bool) at ./threadingconstructs.jl:81
[44] (::DiffEqBase.var"#510#threadsfor_fun#372"{DiffEqBase.var"#639#644"{Zygote.Context,DiffEqBase.var"#369#371"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}},EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Nothing,UnitRange{Int64},Int64,Int64}},Tuple{UnitRange{Int64}},Array{Tuple{Any,typeof(∂(λ))},1},UnitRange{Int64}})() at ./threadingconstructs.jl:48
Stacktrace:
[1] wait at ./task.jl:267 [inlined]
[2] threading_run(::Function) at ./threadingconstructs.jl:34
[3] macro expansion at ./threadingconstructs.jl:93 [inlined]
[4] tmap(::Function, ::UnitRange{Int64}) at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:220
[5] ∇tmap(::Zygote.Context, ::Function, ::UnitRange{Int64}) at /home/tim/.julia/packages/DiffEqBase/3iigH/src/init.jl:233
[6] adjoint at /home/tim/.julia/packages/DiffEqBase/3iigH/src/zygote.jl:55 [inlined]
[7] _pullback(::Zygote.Context, ::typeof(DiffEqBase.tmap), ::Function, ::UnitRange{Int64}) at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47
[8] #solve_batch#367 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:207 [inlined]
[9] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve_batch#367", ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}}, ::typeof(DiffEqBase.solve_batch), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Nothing, ::EnsembleThreads, ::UnitRange{Int64}, ::Int64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[10] macro expansion at ./timing.jl:233 [inlined]
[11] #__solve#359 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:108 [inlined]
[12] _pullback(::Zygote.Context, ::DiffEqBase.var"##__solve#359", ::Int64, ::Int64, ::Int64, ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}}, ::typeof(DiffEqBase.__solve), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Nothing, ::EnsembleThreads) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0 (repeats 2 times)
[13] #__solve#358 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:87 [inlined]
[14] _pullback(::Zygote.Context, ::DiffEqBase.var"##__solve#358", ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:saveat, :trajectories, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},Int64,ReverseDiffAdjoint}}}, ::typeof(DiffEqBase.__solve), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::EnsembleThreads) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[15] adjoint at /home/tim/.julia/packages/Zygote/c0awc/src/lib/lib.jl:188 [inlined]
[16] _pullback at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
[17] #solve#459 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:100 [inlined]
[18] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve#459", ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:saveat, :trajectories, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},Int64,ReverseDiffAdjoint}}}, ::typeof(solve), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::EnsembleThreads) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[19] adjoint at /home/tim/.julia/packages/Zygote/c0awc/src/lib/lib.jl:188 [inlined]
[20] _pullback at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
[21] _pullback(::Zygote.Context, ::DiffEqBase.var"#solve##kw", ::NamedTuple{(:saveat, :trajectories, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},Int64,ReverseDiffAdjoint}}, ::typeof(solve), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::EnsembleThreads) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
Had to cut this due to the text limit, I can paste the rest in a subsequent post if needed.
ForwardDiffSensitivity
does not fail, but also does not solve anything. I’ve left the run going overnight and don’t even manage to get to the first callback. It just sits and spools forever.
So if all of these are failing, have I set up this problem incorrectly? Or is there a better strategy to solve this particular system, or a better sensealg
setup here?