# Correct sensealg for an ensemble problem of stochastic UODEs with diagonal noise

Hi all,

I can’t seem to get my stochastic UODE running. Here’s the setup and generation of synthetic data:

``````using Plots, DiffEqFlux, Flux, Optim, LinearAlgebra, DifferentialEquations
using Statistics
using DiffEqSensitivity
using DifferentialEquations.EnsembleAnalysis

"""
    sys!(du, u, p, t)

In-place drift of the four-state model used to generate the synthetic data.

`u` holds, in order, the nutrient-input series, the grazer-density series, the
nutrient concentration and the algae density. The first seven entries of `p` are
read as `(r, e, μ, h, ph, z, i)`; any further entries are ignored here. `t` is unused.
The derivatives are written into `du`.
"""
function sys!(du, u, p, t)
    r, e, μ, h, ph, z, i = p
    L, G, N, A = u

    # Both forcing series change in proportion to their gap from 5μ and 10μ.
    du[1] = e * 0.5 * (5μ - L)   # nutrient input time series
    du[2] = e * 0.05 * (10μ - G) # grazer density time series

    # Saturating uptake term; it leaves the nutrient pool and enters the algae pool.
    uptake = r * N / (h + N) * A
    du[3] = 0.2 * exp(L) - 0.05 * N - uptake # nutrient concentration

    # Sigmoidal loss in A, scaled by exp(G / 2) where G is the grazer series.
    grazing = 0.02 * A^z / (ph^z + A^z) * exp(G / 2.0)
    du[4] = uptake - 0.1 * A - grazing + i # algae density
end

"""
    noise!(du, u, p, t)

In-place diagonal diffusion term: the first two states receive the amplitude stored
in the last entry of `p`, the last two states receive no noise. `u` and `t` are unused.
"""
function noise!(du, u, p, t)
    σ = p[end] # shared noise amplitude
    du[1] = σ
    du[2] = σ
    # States 3 and 4 are deterministic.
    du[3] = 0.0
    du[4] = 0.0
end

# I've dropped the time frame down quite a bit here just to try to get it to run.
# (The two commented-out lines below are the original, longer configuration.)
#datasize = 30
#tspan = (0.0f0, 10.0f0)
datasize = 10
tspan = (0.0f0, 3.0f0)
# `datasize` evenly spaced save points covering the whole of `tspan`.
tsteps = range(tspan[1], tspan[2], length = datasize)
u0 = Float32[1.0, 1.0, 1.0, 1.0]

# The first seven entries are unpacked by `sys!` as (r, e, μ, h, ph, z, i);
# the last entry is read by `noise!` (via `p[end]`) as the noise amplitude.
p_ = Float32[1.1, 1.0, 0.0, 2.0, 1.0, 1.0, 1e-6, 1.0]

prob = SDEProblem(sys!, noise!, u0, tspan, p_)
ensembleprob = EnsembleProblem(prob)

# No algorithm is passed, so the solver default applies. 1000 trajectories are
# simulated and saved at `tsteps`.
solution = solve(
ensembleprob,
trajectories = 1000,
abstol = 1e-5,
reltol = 1e-5,
maxiters = 1e8,
saveat = tsteps,
)

# Ensemble mean and variance at every saved timestep; these are the training targets.
(truemean, truevar) = Array.(timeseries_steps_meanvar(solution))
``````

I’ve got 14 hardware threads, so this doesn’t take too long to run (`@btime` on the solve step is 23 seconds).

The plan is to solve for the Michaelis-Menten kinetics portions of eqs 3 and 4, so I set up my UODE following a combination of documentation pages (I’m not 100% sure there’s an example that matches my case): neural_sde, PINNs, zygote example, uode examples.

``````# Generate a NN with four inputs [L, G, N, A] and two outputs [NN1, NN2]
ann = FastChain(FastDense(4, 32, tanh), FastDense(32, 32, tanh), FastDense(32, 2))
α = initial_params(ann)

"""
    dudt_(du, u, p, t)

In-place drift of the universal SDE. It mirrors `sys!`, except that the two
Michaelis-Menten terms are replaced by the two outputs of the network `ann`.

`p` carries only the network weights. The fixed model coefficients are read from
the global `p_`. `t` is unused.
"""
function dudt_(du, u, p, t)
    # Only e, μ, ph, z and i are still needed once the NN replaces the MM terms.
    e, μ = p_[2], p_[3]
    ph, z, i = p_[5], p_[6], p_[7]

    # we replace each of our MM values with the result of the NN.
    nn_out = ann(u, p)
    uptake_nn = nn_out[1]
    growth_nn = nn_out[2]

    du[1] = e * 0.5 * (5μ - u[1])   # nutrient input time series
    du[2] = e * 0.05 * (10μ - u[2]) # grazer density time series
    du[3] = 0.2 * exp(u[1]) - 0.05 * u[3] - uptake_nn # nutrient concentration
    du[4] =
        growth_nn - 0.1 * u[4] -
        0.02 * u[4]^z / (ph^z + u[4]^z) * exp(u[2] / 2.0) + i # algae density
end
"""
    noise_(du, u, p, t)

In-place diagonal diffusion term for the universal SDE. The amplitude comes from the
last entry of the global `p_` (not from `p`, which holds the network weights) and is
applied to states 1 and 2 only.
"""
function noise_(du, u, p, t)
    σ = p_[end] # fixed noise amplitude, not trained
    du[1] = σ
    du[2] = σ
    # States 3 and 4 carry no noise.
    du[3] = 0.0
    du[4] = 0.0
end

# Template problem for training; `loss` swaps in the network weights with
# `remake(prob_nn, p = θ)` before each solve.
# NOTE(review): `p` is given here as a keyword rather than positionally — confirm that
# `SDEProblem` stores it as the parameter slot and not as a pass-through solver keyword.
prob_nn = SDEProblem(dudt_, noise_, u0, tspan, p = nothing)

# Ensemble loss for the universal SDE.
# Takes the network weights θ and returns a tuple `(scalar loss, ensemble mean)`.
# NOTE: a reply further down in this thread reports that `timeseries_steps_meanvar`
# mutates internally, which is what triggers Zygote's "Mutating arrays is not
# supported" error for this version of the function.
function loss(θ)
# Put the candidate weights into the problem template.
tmp_prob = remake(prob_nn, p = θ)
ensembleprob = EnsembleProblem(tmp_prob)
# No solver algorithm and no `sensealg` are specified, so the defaults are used.
tmp_sol = solve(
ensembleprob,
saveat = tsteps,
trajectories = 100, # Drop down a little for troubleshooting
#    sensealg = ForwardDiffSensitivity(),
)
# Mean and variance across trajectories at each saved timestep.
(tmp_mean, tmp_var) = Array.(timeseries_steps_meanvar(tmp_sol))
# Squared error on the means plus a 0.1-weighted squared error on the variances.
# The mean is returned as a second value, which the training callback receives as `pred`.
sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var), tmp_mean
end

# Test
loss(α) # @btime 360 ms

# History of loss values; `callback` appends one entry per call.
const losses = []

"""
    callback(θ, l, pred)

Training callback: record the loss `l` in `losses` and print a progress line.

`θ` (current parameters) and `pred` (the second value returned by `loss`) are accepted
but unused. Always returns `false` (presumably telling the optimiser to keep going —
confirm against the `sciml_train` callback contract).
"""
function callback(θ, l, pred)
    push!(losses, l)
    # Use `\$(...)` interpolation without a backslash; the escaped form printed the
    # literal text instead of the iteration count and loss value.
    println("Current loss after $(length(losses)) iterations: $(losses[end])")
    return false
end

# Train the network weights, starting from the initial parameters α.
# `callback` runs during training and at most 200 iterations are performed.
# No optimiser is passed, so the package default is used.
res1 = DiffEqFlux.sciml_train(
loss,
α,
cb = callback,
maxiters = 200,
# allow_f_increases = true,
)
``````

If I’ve got this set up correctly, I cannot figure out what `sensealg` works here. If I choose the default (i.e. omit a `sensealg` keyword which is happening in the `loss` function above), I get an unhelpful error stating:

``````Training 100%|█████████████████████████████████████████████████████████████| Time: 0:00:54
ERROR: Mutating arrays is not supported
``````
Stacktrace
``````Stacktrace:
[1] error(::String) at ./error.jl:33
[2] (::Zygote.var"#372#373")(::Nothing) at /home/tim/.julia/packages/Zygote/c0awc/src/lib/array.jl:65
[4] timeseries_steps_meanvar at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/ensemble_analysis.jl:71 [inlined]
[5] (::typeof(∂(timeseries_steps_meanvar)))(::Tuple{Array{Float64,2},Array{Float64,2}}) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[6] loss at ./REPL[22]:12 [inlined]
[7] (::typeof(∂(loss)))(::Tuple{Float64,Nothing}) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[8] #150 at /home/tim/.julia/packages/Zygote/c0awc/src/lib/lib.jl:191 [inlined]
[10] #74 at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:120 [inlined]
[11] (::typeof(∂(λ)))(::Float64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[12] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:177
[14] macro expansion at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:119 [inlined]
[15] macro expansion at /home/tim/.julia/packages/ProgressLogging/BBN0b/src/ProgressLogging.jl:328 [inlined]
[16] (::DiffEqFlux.var"#73#78"{typeof(callback),Int64,Bool,Bool,typeof(loss),Array{Float32,1},Zygote.Params})() at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:64
[17] with_logstate(::Function, ::Any) at ./logging.jl:408
[18] with_logger at ./logging.jl:514 [inlined]
[19] maybe_with_logger(::DiffEqFlux.var"#73#78"{typeof(callback),Int64,Bool,Bool,typeof(loss),Array{Float32,1},Zygote.Params}, ::LoggingExtras.TeeLogger{Tuple{LoggingExtras.EarlyFilteredLogger{TerminalLoggers.TerminalLogger,DiffEqFlux.var"#68#70"},LoggingExtras.EarlyFilteredLogger{Logging.ConsoleLogger,DiffEqFlux.var"#69#71"}}}) at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:39
[20] sciml_train(::Function, ::Array{Float32,1}, ::ADAM, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:63
[21] top-level scope at REPL[26]:2
``````

According to the DiffEqSensitivity docs, I should be using `QuadratureAdjoint` for my system size, which if I’m not mistaken is what the defaults actually picked, since I get the `Mutating arrays not supported error` using it. At what point am I mutating arrays?

`ReverseDiffAdjoint` gives me an equally opaque error somewhere deep:

``````Training 100%|█████████████████████████████████████████████████████████████| Time: 0:01:12
MethodError: no method matching randn(::RandomNumbers.Xorshifts.Xoroshiro128Plus, ::Type{ReverseDiff.TrackedReal{Float32,Float32,Nothing}})
Closest candidates are:
randn(::Random.AbstractRNG, ::Type{T}, ::Tuple{Vararg{Int64,N}} where N) where T at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:201
randn(::Random.AbstractRNG, ::Type{T}, ::Integer, ::Integer...) where T at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:204
randn(::Random.AbstractRNG) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:38
...
``````
Stacktrace
``````Stacktrace:
[1] randn!(::RandomNumbers.Xorshifts.Xoroshiro128Plus, ::Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:178
[2] wiener_randn! at /home/tim/.julia/packages/DiffEqNoiseProcess/KaHvM/src/wiener.jl:9 [inlined]
[3] INPLACE_WHITE_NOISE_DIST(::Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}, ::NoiseProcess{ReverseDiff.TrackedReal{Float32,Float32,Nothing},2,Float32,Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},Array{Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},1},typeof(DiffEqNoiseProcess.INPLACE_WHITE_NOISE_DIST),typeof(DiffEqNoiseProcess.INPLACE_WHITE_NOISE_BRIDGE),true,ResettableStacks.ResettableStack{Tuple{Float32,Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}},true},ResettableStacks.ResettableStack{Tuple{Float32,Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}},true},RSWM{Float64},Nothing,RandomNumbers.Xorshifts.Xoroshiro128Plus}, ::Float32, ::Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::Float32, ::RandomNumbers.Xorshifts.Xoroshiro128Plus) at /home/tim/.julia/packages/DiffEqNoiseProcess/KaHvM/src/wiener.jl:42
[4] setup_next_step! at /home/tim/.julia/packages/DiffEqNoiseProcess/KaHvM/src/noise_interfaces/noise_process_interface.jl:145 [inlined]
[5] setup_next_step! at /home/tim/.julia/packages/StochasticDiffEq/Abmgl/src/integrators/integrator_utils.jl:2 [inlined]
[6] __init(::SDEProblem{Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1},Tuple{Float32,Float32},true,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}, ::SOSRI, ::Array{Any,1}, ::Array{Any,1}, ::Type{T} where T, ::Type{Val{true}}; saveat::StepRangeLen{Float32,Float64,Float64}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_noise::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::Nothing, dense::Bool, calck::Bool, dt::Float32, adaptive::Bool, gamma::Rational{Int64}, abstol::Nothing, reltol::Nothing, qmax::Rational{Int64}, qmin::Rational{Int64}, qoldinit::Rational{Int64}, fullnormalize::Bool, failfactor::Int64, beta2::Rational{Int64}, beta1::Rational{Int64}, delta::Rational{Int64}, maxiters::Int64, dtmax::Float32, dtmin::Float32, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, force_dtmin::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, initialize_integrator::Bool, seed::UInt64, alias_u0::Bool, alias_jumps::Bool, kwargs::Base.Iterators.Pairs{Symbol,Bool,Tuple{Symbol},NamedTuple{(:default_set,),Tuple{Bool}}}) at /home/tim/.julia/packages/StochasticDiffEq/Abmgl/src/solve.jl:562
[7] #__solve#97 at /home/tim/.julia/packages/StochasticDiffEq/Abmgl/src/solve.jl:6 [inlined]
[8] __solve(::SDEProblem{Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1},Tuple{Float32,Float32},true,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}, ::Nothing; default_set::Bool, kwargs::Base.Iterators.Pairs{Symbol,StepRangeLen{Float32,Float64,Float64},Tuple{Symbol},NamedTuple{(:saveat,),Tuple{StepRangeLen{Float32,Float64,Float64}}}}) at /home/tim/.julia/packages/DifferentialEquations/fpohE/src/default_solve.jl:7
[9] #solve_call#456 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:65 [inlined]
[10] #solve_up#458 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:92 [inlined]
[11] #solve#457 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:74 [inlined]
[21] #solve#457 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:74 [inlined]
[22] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve#457", ::ReverseDiffAdjoint, ::Nothing, ::Nothing, ::Base.Iterators.Pairs{Symbol,StepRangeLen{Float32,Float64,Float64},Tuple{Symbol},NamedTuple{(:saveat,),Tuple{StepRangeLen{Float32,Float64,Float64}}}}, ::typeof(solve), ::SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}, ::Nothing) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[25] _pullback(::Zygote.Context, ::DiffEqBase.var"#solve##kw", ::NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}, ::typeof(solve), ::SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}, ::Nothing) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[26] #batch_func#360 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:143 [inlined]
[27] _pullback(::Zygote.Context, ::DiffEqBase.var"##batch_func#360", ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}}, ::typeof(DiffEqBase.batch_func), ::Int64, ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Nothing) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[28] #365 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:187 [inlined]
[29] #499 at /home/tim/.julia/packages/Zygote/c0awc/src/lib/array.jl:180 [inlined]
[30] iterate at ./generator.jl:47 [inlined]
[31] _collect(::UnitRange{Int64}, ::Base.Generator{UnitRange{Int64},Zygote.var"#499#503"{Zygote.Context,DiffEqBase.var"#365#366"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}},EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Nothing}}}, ::Base.EltypeUnknown, ::Base.HasShape{1}) at ./array.jl:699
[32] collect_similar(::UnitRange{Int64}, ::Base.Generator{UnitRange{Int64},Zygote.var"#499#503"{Zygote.Context,DiffEqBase.var"#365#366"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}},EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Nothing}}}) at ./array.jl:628
[33] map(::Function, ::UnitRange{Int64}) at ./abstractarray.jl:2162
[34] ∇map(::Zygote.Context, ::Function, ::UnitRange{Int64}) at /home/tim/.julia/packages/Zygote/c0awc/src/lib/array.jl:180
[37] #solve_batch#364 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:186 [inlined]
[38] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve_batch#364", ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}}, ::typeof(DiffEqBase.solve_batch), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Nothing, ::EnsembleSerial, ::UnitRange{Int64}, ::Int64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0 (repeats 2 times)
[39] #369 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:213 [inlined]
[40] _pullback(::Zygote.Context, ::DiffEqBase.var"#369#371"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}},EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Nothing,UnitRange{Int64},Int64,Int64}, ::Int64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[41] #639 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/init.jl:233 [inlined]
[42] macro expansion at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:221 [inlined]
Stacktrace:
[3] macro expansion at ./threadingconstructs.jl:93 [inlined]
[4] tmap(::Function, ::UnitRange{Int64}) at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:220
[5] ∇tmap(::Zygote.Context, ::Function, ::UnitRange{Int64}) at /home/tim/.julia/packages/DiffEqBase/3iigH/src/init.jl:233
[7] _pullback(::Zygote.Context, ::typeof(DiffEqBase.tmap), ::Function, ::UnitRange{Int64}) at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47
[8] #solve_batch#367 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:207 [inlined]
[9] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve_batch#367", ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}}, ::typeof(DiffEqBase.solve_batch), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Nothing, ::EnsembleThreads, ::UnitRange{Int64}, ::Int64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
[10] macro expansion at ./timing.jl:233 [inlined]
[11] #__solve#359 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:108 [inlined]
[12] _pullback(::Zygote.Context, ::DiffEqBase.var"##__solve#359", ::Int64, ::Int64, ::Int64, ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}}, ::typeof(DiffEqBase.__solve), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Nothing, ::EnsembleThreads) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0 (repeats 2 times)
[13] #__solve#358 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:87 [inlined]
[17] #solve#459 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:100 [inlined]
``````

Had to cut this due to the text limit, I can paste the rest in a subsequent post if needed.

`ForwardDiffSensitivity` does not fail, but also does not solve anything. I’ve left the run going overnight and don’t even manage to get to the first callback. It just sits and spools forever.

So if all of these are failing, have I set up this problem incorrectly? Or is there a better strategy to solve this particular system, or a better `sensealg` setup here?

Not for SDEs. For diagonal noise SDEs likely the right thing is `TrackerAdjoint`. `QuadratureAdjoint` doesn’t exist for SDEs given the weird rules of SDE integration, so we should throw a better error for that case. You can try `BacksolveAdjoint` or `InterpolatingAdjoint`, but you’ll need to switch to a commutative noise SDE solver because the reverse pass is no longer diagonal. We haven’t done all of the timings yet but @frankschae might have more comments.

`ReverseDiffAdjoint` is really what we should be using, so @mohamed82008 could we get help fixing that?

`timeseries_steps_meanvar` mutates as I found out yesterday: https://github.com/SciML/DiffEqFlux.jl/issues/446 . I’ll probably patch this over the weekend.

1 Like

Hi! You could try replacing the `timeseries_steps_meanvar` function by

``````tmp_sol = Array(solve(
ensembleprob,
EM(),
dt = 0.001f0,
...)
(tmp_mean, tmp_var) = mean(tmp_sol,dims=3)[:,:], var(tmp_sol,dims=3)[:,:]
``````

With that change, Zygote can compute the gradients:

``````Zygote.gradient(p->loss(p)[1], α)
``````

So this should solve the mutation issues until `timeseries_steps_meanvar` is patched.

Please note that I have chosen the 0.5 order `EM()` solver for Ito SDEs. This has essentially two reasons:
(1) As Chris said, the reverse process will have non-diagonal noise, therefore not all solvers are applicable.
(2) We have some problems with the reconstruction of the noise process with higher order solvers for Ito SDEs. (see https://github.com/SciML/DiffEqNoiseProcess.jl/pull/62). So, if you want to use
`BacksolveAdjoint` or `InterpolatingAdjoint` for SDEs in the Ito sense, one currently should choose `EM()` to be on the safe side…

Note also that the default sensitivity algorithm will be `InterpolatingAdjoint`. For SDEs written in the Ito sense, we have just implemented the tests for `BacksolveAdjoint` so far (https://github.com/SciML/DiffEqSensitivity.jl/pull/317)… More docs and tests should come soon

I am not sure what’s going wrong with `ReverseDiffAdjoint` and `ForwardDiffSensitivity` – except that `ForwardDiffSensitivity` scales badly with the number of parameters… Does it work for fewer network parameters?

Regarding performance, it’s probably best if you do your own benchmarks for the different adjoint choices (I’d be interested to see them as well). Some small suggestions:

– setting return values to `nothing`, i.e.,

``````function noise_(du, u, p, t)
# Diagonal diffusion: states 1 and 2 get the amplitude stored in the last entry of
# the global `p_`; states 3 and 4 get none.
du[1] = p_[end]
du[2] = p_[end]
du[3] = 0.0
du[4] = 0.0
# Explicit `nothing` so the function does not return the value of the last assignment.
return nothing
end
``````

– make sure not to mix up different float types like in

``````du[2] = e * 0.05 * (10μ - u[2]) -> 0.05 is Float64
``````

– localize all parameters in the drift and diffusion function

``````r, e, μ, h, ph, z, i  = p[1:7]
p_nn = @view p[8:end]
``````

(while of course just optimizing the network parameters.)

Thanks both! I’ll take a look at these suggestions and report back.

It’s not immediately obvious to me what’s wrong. @Libbum could you please open an issue to track this?

do we have a good way to make it skip tracking of randoms?

Yes. We can define a custom gradient of `rand` or its equivalent to be nothing.

However, there has been a discussion before about whether this makes sense in general or not. https://github.com/JuliaDiff/ChainRules.jl/issues/262 and https://github.com/TuringLang/DistributionsAD.jl/issues/123

I would’ve thought I handled it via https://github.com/SciML/DiffEqNoiseProcess.jl/blob/master/src/init.jl#L12-L16 , but it seems that’s not being picked up?

You would need one for `TrackedReal` as well. The user is using `inplace = true` so it’s using `Array{<:TrackedReal}` which is probably a bigger problem performance-wise but it should work.

1 Like

@Libbum try this PR branch:

I released a patch with this.

Thanks for the assistance everyone!

I’m using the patched version of `DiffEqNoiseProcess.jl` @ChrisRackauckas, although since we seem to be tracking more than one problem in this thread I can’t say for sure if that’s assisted or not.

Here’s what’s working:

`````` function loss(θ)
# Ensemble loss without `timeseries_steps_meanvar`.
# Takes the network weights θ and returns `(scalar loss, ensemble mean)`.
tmp_prob = remake(prob_nn, p = θ)
ensembleprob = EnsembleProblem(tmp_prob)
# Fixed-step Euler–Maruyama with dt taken from the data grid. No `saveat` is passed.
# NOTE(review): this assumes the fixed-step output lands exactly on `tsteps`, so that
# `tmp_mean` has the same shape as `truemean` — confirm.
tmp_sol = Array(solve(
ensembleprob,
EM();
dt = tsteps.step,
trajectories = 100,
))
# Reduce over the third axis (presumably the trajectory axis — confirm);
# `[:,:]` drops the singleton dimension left behind by `dims=3`.
tmp_mean = mean(tmp_sol,dims=3)[:,:]
tmp_var = var(tmp_sol,dims=3)[:,:]
# Squared error on the means plus a 0.1-weighted squared error on the variances.
sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var), tmp_mean
end
``````

Which is essentially what @frankschae has suggested above (remove the mutating functions, use `EM()` and `InterpolatingAdjoint()`).

If I follow correctly, the `TrackedReals` addition should assist when using `ReverseDiffAdjoint` correct?
The error I’m getting there is now:

`````` ERROR: TaskFailedException:
MethodError: no method matching randn(::RandomNumbers.Xorshifts.Xoroshiro128Plus, ::Type{ReverseDiff.TrackedReal{Float32,Float32,Nothing}})
Closest candidates are:
randn(::Random.AbstractRNG, ::Type{T}, ::Tuple{Vararg{Int64,N}} where N) where T at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:201
randn(::Random.AbstractRNG, ::Type{T}, ::Integer, ::Integer...) where T at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:204
randn(::Random.AbstractRNG) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:38
...
``````

@mohamed82008, concerning your latest comment: would it be better to not use inplace in this instance? I don’t quite follow if Chris’ changes have resolved that worry or not. If you think this problem requires an issue to track it - which package are we looking at here? DiffEqSensitivity? Or should I just put it on DifferentialEquations?

This looks like something that wouldn’t happen in my fix? Are you on DiffEqNoiseProcess v5.4.2? If so, please give a code that I can just copy paste to see the error, otherwise I’m not going to go through the effort to find a reproducer.

Yeah, I believe so. But I’ll add it explicitly and try again. If not I’ll put up a full MWE.

Fails with 5.4.2—yes.

``````(uode) pkg> st
Status `/mnt/turtle/scratch/uode/Project.toml`
[ff3c4d4f] AutoOptimize v0.1.0 `https://github.com/SciML/AutoOptimize.jl.git#master`
[052768ef] CUDA v1.3.3
[a93c6f00] DataFrames v0.22.1
[7806a523] DecisionTree v0.10.10
[2b5f629d] DiffEqBase v6.48.2
[aae7a2af] DiffEqFlux v1.24.0
[071ae1c0] DiffEqGPU v1.8.0
[77a26b50] DiffEqNoiseProcess v5.4.2
[41bf760c] DiffEqSensitivity v6.34.0
[587475ba] Flux v0.11.1
[033835bb] JLD2 v0.3.0
[2fda8390] LsqFit v0.11.0
[961ee093] ModelingToolkit v3.21.0
[429524aa] Optim v1.2.0
[1dea7af3] OrdinaryDiffEq v5.45.1
[789caeaf] StochasticDiffEq v6.26.0
[5d786b92] TerminalLoggers v0.1.2
[37e2e46d] LinearAlgebra
[56ddb016] Logging
[2f01184e] SparseArrays
[10745b16] Statistics
``````

Current script:

``````using DiffEqFlux, Flux, Optim, LinearAlgebra
using DiffEqNoiseProcess
using StochasticDiffEq
using Statistics
using DiffEqSensitivity
using DiffEqBase.EnsembleAnalysis

"""
    sys!(du, u, p, t)

In-place drift of the four-state model used to generate the synthetic data.

`u` holds, in order, the nutrient-input series, the grazer-density series, the
nutrient concentration and the algae density. The first seven entries of `p` are
read as `(r, e, μ, h, ph, z, i)`; any further entries are ignored here. `t` is unused.
The derivatives are written into `du`.
"""
function sys!(du, u, p, t)
    r, e, μ, h, ph, z, i = p
    L, G, N, A = u

    # Both forcing series change in proportion to their gap from 5μ and 10μ.
    du[1] = e * 0.5 * (5μ - L)   # nutrient input time series
    du[2] = e * 0.05 * (10μ - G) # grazer density time series

    # Saturating uptake term; it leaves the nutrient pool and enters the algae pool.
    uptake = r * N / (h + N) * A
    du[3] = 0.2 * exp(L) - 0.05 * N - uptake # nutrient concentration

    # Sigmoidal loss in A, scaled by exp(G / 2) where G is the grazer series.
    grazing = 0.02 * A^z / (ph^z + A^z) * exp(G / 2.0)
    du[4] = uptake - 0.1 * A - grazing + i # algae density
end

"""
    noise!(du, u, p, t)

In-place diagonal diffusion term: the first two states receive the amplitude stored
in the last entry of `p`, the last two states receive no noise. `u` and `t` are unused.
"""
function noise!(du, u, p, t)
    σ = p[end] # shared noise amplitude
    du[1] = σ
    du[2] = σ
    # States 3 and 4 are deterministic.
    du[3] = 0.0
    du[4] = 0.0
end

datasize = 10
tspan = (0.0f0, 3.0f0)
# `datasize` evenly spaced save points covering the whole of `tspan`.
tsteps = range(tspan[1], tspan[2], length = datasize)
u0 = Float32[1.0, 1.0, 1.0, 1.0]

# The first seven entries are unpacked by `sys!` as (r, e, μ, h, ph, z, i);
# the last entry is read by `noise!` (via `p[end]`) as the noise amplitude.
p_ = Float32[1.1, 1.0, 0.0, 2.0, 1.0, 1.0, 1e-6, 1.0]

prob = SDEProblem(sys!, noise!, u0, tspan, p_)
ensembleprob = EnsembleProblem(prob)

# Synthetic data: 1000 trajectories solved with SOSRI and saved at `tsteps`.
solution = solve(
ensembleprob,
SOSRI(),
trajectories = 1000,
abstol = 1e-5,
reltol = 1e-5,
maxiters = 1e8,
saveat = tsteps,
)

# Ensemble mean and variance at every saved timestep; these are the training targets.
(truemean, truevar) = Array.(timeseries_steps_meanvar(solution))

ann = FastChain(FastDense(4, 32, tanh), FastDense(32, 32, tanh), FastDense(32, 2))
α = initial_params(ann)

"""
    dudt_(du, u, p, t)

In-place drift of the universal SDE. It mirrors `sys!`, except that the two
Michaelis-Menten terms are replaced by the two outputs of the network `ann`.

`p` carries only the network weights. The fixed model coefficients are read from
the global `p_`. `t` is unused. Returns `nothing`.
"""
function dudt_(du, u, p, t)
    # Only e, μ, ph, z and i are still needed once the NN replaces the MM terms.
    e, μ = p_[2], p_[3]
    ph, z, i = p_[5], p_[6], p_[7]

    nn_out = ann(u, p)
    uptake_nn = nn_out[1]
    growth_nn = nn_out[2]

    du[1] = e * 0.5 * (5μ - u[1])   # nutrient input time series
    du[2] = e * 0.05 * (10μ - u[2]) # grazer density time series
    du[3] = 0.2 * exp(u[1]) - 0.05 * u[3] - uptake_nn # nutrient concentration
    du[4] =
        growth_nn - 0.1 * u[4] -
        0.02 * u[4]^z / (ph^z + u[4]^z) * exp(u[2] / 2.0) + i # algae density
    return nothing
end
"""
    noise_(du, u, p, t)

In-place diagonal diffusion term for the universal SDE. The amplitude comes from the
last entry of the global `p_` (not from `p`, which holds the network weights) and is
applied to states 1 and 2 only. Returns `nothing`.
"""
function noise_(du, u, p, t)
    σ = p_[end] # fixed noise amplitude, not trained
    du[1] = σ
    du[2] = σ
    # States 3 and 4 carry no noise.
    du[3] = 0.0
    du[4] = 0.0
    return nothing
end

# Template problem for training; `loss` swaps in the network weights with
# `remake(prob_nn, p = θ)` before each solve.
# NOTE(review): `p` is given here as a keyword rather than positionally — confirm that
# `SDEProblem` stores it as the parameter slot and not as a pass-through solver keyword.
prob_nn = SDEProblem(dudt_, noise_, u0, tspan, p = nothing)

# Ensemble loss without `timeseries_steps_meanvar`.
# Takes the network weights θ and returns `(scalar loss, ensemble mean)`.
function loss(θ)
tmp_prob = remake(prob_nn, p = θ)
ensembleprob = EnsembleProblem(tmp_prob)
# Fixed-step Euler–Maruyama with dt taken from the data grid. No `saveat` is passed.
# NOTE(review): this assumes the fixed-step output lands exactly on `tsteps`, so that
# `tmp_mean` has the same shape as `truemean` — confirm.
tmp_sol = Array(solve(
ensembleprob,
EM();
dt = tsteps.step,
trajectories = 100,
))
# Reduce over the third axis (presumably the trajectory axis — confirm);
# `[:,:]` drops the singleton dimension left behind by `dims=3`.
tmp_mean = mean(tmp_sol,dims=3)[:,:]
tmp_var = var(tmp_sol,dims=3)[:,:]
# Squared error on the means plus a 0.1-weighted squared error on the variances.
sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var), tmp_mean
end

# History of loss values; `callback` appends one entry per call.
const losses = []

"""
    callback(θ, l, pred)

Training callback: record the loss `l` in `losses` and print a progress line on
every 50th call.

`θ` (current parameters) and `pred` (the second value returned by `loss`) are accepted
but unused. Always returns `false` (presumably telling the optimiser to keep going —
confirm against the `sciml_train` callback contract).
"""
function callback(θ, l, pred)
    push!(losses, l)
    # Report only every 50th iteration to keep the log short.
    if length(losses) % 50 == 0
        # Use `\$(...)` interpolation without a backslash; the escaped form printed the
        # literal text instead of the iteration count and loss value.
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    return false
end

res1 = DiffEqFlux.sciml_train(
loss,
α,
This is the overload you need: https://github.com/SciML/DiffEqNoiseProcess.jl/pull/75 . But now I’m going to punt it over to @mohamed82008 since it’s erroring because the output of the function is an Array{TrackedReal} so it can’t `increment_deriv!`, and `reduce(vcat,output)` isn’t giving a TrackedArray.