Correct sensealg for an ensemble problem of stochastic UODEs with diagonal noise

Hi all,

I can’t seem to get my stochastic UODE running. Here’s the setup and generation of synthetic data:

using Plots, DiffEqFlux, Flux, Optim, LinearAlgebra, DifferentialEquations
using Statistics
using DiffEqSensitivity
using DifferentialEquations.EnsembleAnalysis

function sys!(du, u, p, t)
    r, e, μ, h, ph, z, i = p
    du[1] = e * 0.5 * (5μ - u[1]) # nutrient input time series
    du[2] = e * 0.05 * (10μ - u[2]) # grazer density time series
    du[3] = 0.2 * exp(u[1]) - 0.05 * u[3] - r * u[3] / (h + u[3]) * u[4] # nutrient concentration
    du[4] =
        r * u[3] / (h + u[3]) * u[4] - 0.1 * u[4] -
        0.02 * u[4]^z / (ph^z + u[4]^z) * exp(u[2] / 2.0) + i #Algae density
end

function noise!(du, u, p, t)
    du[1] = p[end] # n
    du[2] = p[end] # n
    du[3] = 0.0
    du[4] = 0.0
end


# I've dropped the time frame down quite a bit here just to try to get it to run.
#datasize = 30
#tspan = (0.0f0, 10.0f0)
datasize = 10
tspan = (0.0f0, 3.0f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
u0 = Float32[1.0, 1.0, 1.0, 1.0]

p_ = Float32[1.1, 1.0, 0.0, 2.0, 1.0, 1.0, 1e-6, 1.0]

prob = SDEProblem(sys!, noise!, u0, tspan, p_)
ensembleprob = EnsembleProblem(prob)

solution = solve(
    ensembleprob,
    EnsembleThreads();
    trajectories = 1000,
    abstol = 1e-5,
    reltol = 1e-5,
    maxiters = 1e8,
    saveat = tsteps,
)

(truemean, truevar) = Array.(timeseries_steps_meanvar(solution))

I’ve got 14 hardware threads, so this doesn’t take too long to run (@btime on the solve step is 23 seconds).

The plan is to solve for the Michaelis-Menton kinetics portions of eqs 3 and 4, so I set up my UODE following a combination of documentation pages (I’m not 100% sure there’s an example that matches my case): neural_sde, PINNs, zygote example, uode examples.

# Generate a NN with four inputs [L, G, N, A] and two outputs [NN1, NN2]
ann = FastChain(FastDense(4, 32, tanh), FastDense(32, 32, tanh), FastDense(32, 2))
α = initial_params(ann)

function dudt_(du, u, p, t)
    r, e, μ, h, ph, z, i = p_

    # we replace each of our MM values with the result of the NN.
    MM = ann(u, p)

    du[1] = e * 0.5 * (5μ - u[1]) # nutrient input time series
    du[2] = e * 0.05 * (10μ - u[2]) # grazer density time series
    du[3] = 0.2 * exp(u[1]) - 0.05 * u[3] - MM[1] # nutrient concentration
    du[4] = MM[2] - 0.1 * u[4] - 0.02 * u[4]^z / (ph^z + u[4]^z) * exp(u[2] / 2.0) + i #Algae density
end
function noise_(du, u, p, t)
    du[1] = p_[end]
    du[2] = p_[end]
    du[3] = 0.0
    du[4] = 0.0
end

prob_nn = SDEProblem(dudt_, noise_, u0, tspan, p = nothing)

function loss(θ)
    tmp_prob = remake(prob_nn, p = θ)
    ensembleprob = EnsembleProblem(tmp_prob)
    tmp_sol = solve(
        ensembleprob,
        EnsembleThreads();
        saveat = tsteps,
        trajectories = 100, # Drop down a little for troubleshooting
       # sensealg = ReverseDiffAdjoint(),
    #    sensealg = ForwardDiffSensitivity(),
    )
    (tmp_mean, tmp_var) = Array.(timeseries_steps_meanvar(tmp_sol))
    sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var), tmp_mean
end

# Test
loss(α) # @btime 360 ms

const losses = []

callback(θ, l, pred) = begin
    push!(losses, l)
    println("Current loss after $(length(losses)) iterations: $(losses[end])")
    false
end

res1 = DiffEqFlux.sciml_train(
    loss,
    α,
    ADAM(0.1),
    cb = callback,
    maxiters = 200,
   # allow_f_increases = true,
)

If I’ve got this set up correctly, I cannot figure out what sensealg works here. If I choose the default (i.e. omit a sensealg keyword which is happening in the loss function above), I get an unhelpful error stating:

Training 100%|█████████████████████████████████████████████████████████████| Time: 0:00:54
ERROR: Mutating arrays is not supported
Stacktrace
Stacktrace:
 [1] error(::String) at ./error.jl:33
 [2] (::Zygote.var"#372#373")(::Nothing) at /home/tim/.julia/packages/Zygote/c0awc/src/lib/array.jl:65
 [3] (::Zygote.var"#2265#back#374"{Zygote.var"#372#373"})(::Nothing) at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [4] timeseries_steps_meanvar at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/ensemble_analysis.jl:71 [inlined]
 [5] (::typeof(∂(timeseries_steps_meanvar)))(::Tuple{Array{Float64,2},Array{Float64,2}}) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [6] loss at ./REPL[22]:12 [inlined]
 [7] (::typeof(∂(loss)))(::Tuple{Float64,Nothing}) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [8] #150 at /home/tim/.julia/packages/Zygote/c0awc/src/lib/lib.jl:191 [inlined]
 [9] (::Zygote.var"#1693#back#152"{Zygote.var"#150#151"{typeof(∂(loss)),Tuple{Tuple{Nothing},Tuple{}}}})(::Tuple{Float64,Nothing}) at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:49
 [10] #74 at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:120 [inlined]
 [11] (::typeof(∂(λ)))(::Float64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [12] (::Zygote.var"#54#55"{Zygote.Params,Zygote.Context,typeof(∂(λ))})(::Float64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:177
 [13] gradient(::Function, ::Zygote.Params) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface.jl:54
 [14] macro expansion at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:119 [inlined]
 [15] macro expansion at /home/tim/.julia/packages/ProgressLogging/BBN0b/src/ProgressLogging.jl:328 [inlined]
 [16] (::DiffEqFlux.var"#73#78"{typeof(callback),Int64,Bool,Bool,typeof(loss),Array{Float32,1},Zygote.Params})() at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:64
 [17] with_logstate(::Function, ::Any) at ./logging.jl:408
 [18] with_logger at ./logging.jl:514 [inlined]
 [19] maybe_with_logger(::DiffEqFlux.var"#73#78"{typeof(callback),Int64,Bool,Bool,typeof(loss),Array{Float32,1},Zygote.Params}, ::LoggingExtras.TeeLogger{Tuple{LoggingExtras.EarlyFilteredLogger{TerminalLoggers.TerminalLogger,DiffEqFlux.var"#68#70"},LoggingExtras.EarlyFilteredLogger{Logging.ConsoleLogger,DiffEqFlux.var"#69#71"}}}) at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:39
 [20] sciml_train(::Function, ::Array{Float32,1}, ::ADAM, ::Base.Iterators.Cycle{Tuple{DiffEqFlux.NullData}}; cb::Function, maxiters::Int64, progress::Bool, save_best::Bool) at /home/tim/.julia/packages/DiffEqFlux/8UHw5/src/train.jl:63
 [21] top-level scope at REPL[26]:2

According to the DiffEqSensitivity docs, I should be using QuadratureAdjoint for my system size, which if I’m not mistaken is what the defaults actually picked, since I get the Mutating arrays not supported error using it. At what point am I mutating arrays?

ReverseDiffAdjoint gives me an equally opaque error somewhere deep:

Training 100%|█████████████████████████████████████████████████████████████| Time: 0:01:12
ERROR: TaskFailedException:
MethodError: no method matching randn(::RandomNumbers.Xorshifts.Xoroshiro128Plus, ::Type{ReverseDiff.TrackedReal{Float32,Float32,Nothing}})                                                           
Closest candidates are:
  randn(::Random.AbstractRNG, ::Type{T}, ::Tuple{Vararg{Int64,N}} where N) where T at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:201
  randn(::Random.AbstractRNG, ::Type{T}, ::Integer, ::Integer...) where T at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:204
  randn(::Random.AbstractRNG) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:38
  ...
Stacktrace
Stacktrace:
 [1] randn!(::RandomNumbers.Xorshifts.Xoroshiro128Plus, ::Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:178
 [2] wiener_randn! at /home/tim/.julia/packages/DiffEqNoiseProcess/KaHvM/src/wiener.jl:9 [inlined]
 [3] INPLACE_WHITE_NOISE_DIST(::Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}, ::NoiseProcess{ReverseDiff.TrackedReal{Float32,Float32,Nothing},2,Float32,Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},Array{Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},1},typeof(DiffEqNoiseProcess.INPLACE_WHITE_NOISE_DIST),typeof(DiffEqNoiseProcess.INPLACE_WHITE_NOISE_BRIDGE),true,ResettableStacks.ResettableStack{Tuple{Float32,Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}},true},ResettableStacks.ResettableStack{Tuple{Float32,Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1},Array{ReverseDiff.TrackedReal{Float32,Float32,Nothing},1}},true},RSWM{Float64},Nothing,RandomNumbers.Xorshifts.Xoroshiro128Plus}, ::Float32, ::Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::Float32, ::RandomNumbers.Xorshifts.Xoroshiro128Plus) at /home/tim/.julia/packages/DiffEqNoiseProcess/KaHvM/src/wiener.jl:42
 [4] setup_next_step! at /home/tim/.julia/packages/DiffEqNoiseProcess/KaHvM/src/noise_interfaces/noise_process_interface.jl:145 [inlined]                                                             
 [5] setup_next_step! at /home/tim/.julia/packages/StochasticDiffEq/Abmgl/src/integrators/integrator_utils.jl:2 [inlined]                                                                             
 [6] __init(::SDEProblem{Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1},Tuple{Float32,Float32},true,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}, ::SOSRI, ::Array{Any,1}, ::Array{Any,1}, ::Type{T} where T, ::Type{Val{true}}; saveat::StepRangeLen{Float32,Float64,Float64}, tstops::Tuple{}, d_discontinuities::Tuple{}, save_idxs::Nothing, save_everystep::Bool, save_noise::Bool, save_on::Bool, save_start::Bool, save_end::Bool, callback::Nothing, dense::Bool, calck::Bool, dt::Float32, adaptive::Bool, gamma::Rational{Int64}, abstol::Nothing, reltol::Nothing, qmax::Rational{Int64}, qmin::Rational{Int64}, qoldinit::Rational{Int64}, fullnormalize::Bool, failfactor::Int64, beta2::Rational{Int64}, beta1::Rational{Int64}, delta::Rational{Int64}, maxiters::Int64, dtmax::Float32, dtmin::Float32, internalnorm::typeof(DiffEqBase.ODE_DEFAULT_NORM), isoutofdomain::typeof(DiffEqBase.ODE_DEFAULT_ISOUTOFDOMAIN), unstable_check::typeof(DiffEqBase.ODE_DEFAULT_UNSTABLE_CHECK), verbose::Bool, force_dtmin::Bool, timeseries_errors::Bool, dense_errors::Bool, advance_to_tstop::Bool, stop_at_next_tstop::Bool, initialize_save::Bool, progress::Bool, progress_steps::Int64, progress_name::String, progress_message::typeof(DiffEqBase.ODE_DEFAULT_PROG_MESSAGE), userdata::Nothing, initialize_integrator::Bool, seed::UInt64, alias_u0::Bool, alias_jumps::Bool, kwargs::Base.Iterators.Pairs{Symbol,Bool,Tuple{Symbol},NamedTuple{(:default_set,),Tuple{Bool}}}) at /home/tim/.julia/packages/StochasticDiffEq/Abmgl/src/solve.jl:562
 [7] #__solve#97 at /home/tim/.julia/packages/StochasticDiffEq/Abmgl/src/solve.jl:6 [inlined]
 [8] __solve(::SDEProblem{Array{ReverseDiff.TrackedReal{Float32,Float32,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}},1},Tuple{Float32,Float32},true,ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}, ::Nothing; default_set::Bool, kwargs::Base.Iterators.Pairs{Symbol,StepRangeLen{Float32,Float64,Float64},Tuple{Symbol},NamedTuple{(:saveat,),Tuple{StepRangeLen{Float32,Float64,Float64}}}}) at /home/tim/.julia/packages/DifferentialEquations/fpohE/src/default_solve.jl:7
 [9] #solve_call#456 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:65 [inlined]
 [10] #solve_up#458 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:92 [inlined]
 [11] #solve#457 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:74 [inlined]
 [12] (::DiffEqSensitivity.var"#reversediff_adjoint_forwardpass#168"{Base.Iterators.Pairs{Symbol,StepRangeLen{Float32,Float64,Float64},Tuple{Symbol},NamedTuple{(:saveat,),Tuple{StepRangeLen{Float32,Float64,Float64}}}},SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},Nothing,Tuple{}})(::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}) at /home/tim/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/concrete_solve.jl:315
 [13] ReverseDiff.GradientTape(::Function, ::Tuple{Array{Float32,1},Array{Float32,1}}, ::ReverseDiff.GradientConfig{Tuple{ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}},ReverseDiff.TrackedArray{Float32,Float32,1,Array{Float32,1},Array{Float32,1}}}}) at /home/tim/.julia/packages/ReverseDiff/jFRo1/src/api/tape.jl:207
 [14] GradientTape at /home/tim/.julia/packages/ReverseDiff/jFRo1/src/api/tape.jl:204 [inlined]
 [15] _concrete_solve_adjoint(::SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}, ::Nothing, ::ReverseDiffAdjoint, ::Array{Float32,1}, ::Array{Float32,1}; kwargs::Base.Iterators.Pairs{Symbol,StepRangeLen{Float32,Float64,Float64},Tuple{Symbol},NamedTuple{(:saveat,),Tuple{StepRangeLen{Float32,Float64,Float64}}}}) at /home/tim/.julia/packages/DiffEqSensitivity/WiCRA/src/local_sensitivity/concrete_solve.jl:326
 [16] #_solve_adjoint#478 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:277 [inlined]
 [17] #adjoint#469 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:241 [inlined]
 [18] _pullback at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:53 [inlined]
 [19] adjoint at /home/tim/.julia/packages/Zygote/c0awc/src/lib/lib.jl:188 [inlined]
 [20] _pullback at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [21] #solve#457 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:74 [inlined]
 [22] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve#457", ::ReverseDiffAdjoint, ::Nothing, ::Nothing, ::Base.Iterators.Pairs{Symbol,StepRangeLen{Float32,Float64,Float64},Tuple{Symbol},NamedTuple{(:saveat,),Tuple{StepRangeLen{Float32,Float64,Float64}}}}, ::typeof(solve), ::SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}, ::Nothing) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [23] adjoint at /home/tim/.julia/packages/Zygote/c0awc/src/lib/lib.jl:188 [inlined]
 [24] _pullback at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [25] _pullback(::Zygote.Context, ::DiffEqBase.var"#solve##kw", ::NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}, ::typeof(solve), ::SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing}, ::Nothing) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [26] #batch_func#360 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:143 [inlined]                                                                                
 [27] _pullback(::Zygote.Context, ::DiffEqBase.var"##batch_func#360", ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}}, ::typeof(DiffEqBase.batch_func), ::Int64, ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Nothing) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [28] #365 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:187 [inlined]                                                                                           
 [29] #499 at /home/tim/.julia/packages/Zygote/c0awc/src/lib/array.jl:180 [inlined]
 [30] iterate at ./generator.jl:47 [inlined]
 [31] _collect(::UnitRange{Int64}, ::Base.Generator{UnitRange{Int64},Zygote.var"#499#503"{Zygote.Context,DiffEqBase.var"#365#366"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}},EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Nothing}}}, ::Base.EltypeUnknown, ::Base.HasShape{1}) at ./array.jl:699
 [32] collect_similar(::UnitRange{Int64}, ::Base.Generator{UnitRange{Int64},Zygote.var"#499#503"{Zygote.Context,DiffEqBase.var"#365#366"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}},EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Nothing}}}) at ./array.jl:628
 [33] map(::Function, ::UnitRange{Int64}) at ./abstractarray.jl:2162
 [34] ∇map(::Zygote.Context, ::Function, ::UnitRange{Int64}) at /home/tim/.julia/packages/Zygote/c0awc/src/lib/array.jl:180
 [35] adjoint at /home/tim/.julia/packages/Zygote/c0awc/src/lib/array.jl:196 [inlined]
 [36] _pullback at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [37] #solve_batch#364 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:186 [inlined]                                                                               
 [38] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve_batch#364", ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}}, ::typeof(DiffEqBase.solve_batch), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Nothing, ::EnsembleSerial, ::UnitRange{Int64}, ::Int64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0 (repeats 2 times)                                                                                                  
 [39] #369 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:213 [inlined]                                                                                           
 [40] _pullback(::Zygote.Context, ::DiffEqBase.var"#369#371"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}},EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Nothing,UnitRange{Int64},Int64,Int64}, ::Int64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [41] #639 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/init.jl:233 [inlined]
 [42] macro expansion at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:221 [inlined]                                                                                
 [43] (::DiffEqBase.var"#510#threadsfor_fun#372"{DiffEqBase.var"#639#644"{Zygote.Context,DiffEqBase.var"#369#371"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}},EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Nothing,UnitRange{Int64},Int64,Int64}},Tuple{UnitRange{Int64}},Array{Tuple{Any,typeof(∂(λ))},1},UnitRange{Int64}})(::Bool) at ./threadingconstructs.jl:81
 [44] (::DiffEqBase.var"#510#threadsfor_fun#372"{DiffEqBase.var"#639#644"{Zygote.Context,DiffEqBase.var"#369#371"{Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}},EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing},Nothing,UnitRange{Int64},Int64,Int64}},Tuple{UnitRange{Int64}},Array{Tuple{Any,typeof(∂(λ))},1},UnitRange{Int64}})() at ./threadingconstructs.jl:48
Stacktrace:
 [1] wait at ./task.jl:267 [inlined]
 [2] threading_run(::Function) at ./threadingconstructs.jl:34
 [3] macro expansion at ./threadingconstructs.jl:93 [inlined]
 [4] tmap(::Function, ::UnitRange{Int64}) at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:220
 [5] ∇tmap(::Zygote.Context, ::Function, ::UnitRange{Int64}) at /home/tim/.julia/packages/DiffEqBase/3iigH/src/init.jl:233
 [6] adjoint at /home/tim/.julia/packages/DiffEqBase/3iigH/src/zygote.jl:55 [inlined]
 [7] _pullback(::Zygote.Context, ::typeof(DiffEqBase.tmap), ::Function, ::UnitRange{Int64}) at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47
 [8] #solve_batch#367 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:207 [inlined]
 [9] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve_batch#367", ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}}, ::typeof(DiffEqBase.solve_batch), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Nothing, ::EnsembleThreads, ::UnitRange{Int64}, ::Int64) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [10] macro expansion at ./timing.jl:233 [inlined]
 [11] #__solve#359 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:108 [inlined]
 [12] _pullback(::Zygote.Context, ::DiffEqBase.var"##__solve#359", ::Int64, ::Int64, ::Int64, ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol},NamedTuple{(:saveat, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},ReverseDiffAdjoint}}}, ::typeof(DiffEqBase.__solve), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::Nothing, ::EnsembleThreads) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0 (repeats 2 times)
 [13] #__solve#358 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/ensemble/basic_ensemble_solve.jl:87 [inlined]
 [14] _pullback(::Zygote.Context, ::DiffEqBase.var"##__solve#358", ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:saveat, :trajectories, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},Int64,ReverseDiffAdjoint}}}, ::typeof(DiffEqBase.__solve), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::EnsembleThreads) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [15] adjoint at /home/tim/.julia/packages/Zygote/c0awc/src/lib/lib.jl:188 [inlined]
 [16] _pullback at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [17] #solve#459 at /home/tim/.julia/packages/DiffEqBase/3iigH/src/solve.jl:100 [inlined]
 [18] _pullback(::Zygote.Context, ::DiffEqBase.var"##solve#459", ::Base.Iterators.Pairs{Symbol,Any,Tuple{Symbol,Symbol,Symbol},NamedTuple{(:saveat, :trajectories, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},Int64,ReverseDiffAdjoint}}}, ::typeof(solve), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::EnsembleThreads) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0
 [19] adjoint at /home/tim/.julia/packages/Zygote/c0awc/src/lib/lib.jl:188 [inlined]
 [20] _pullback at /home/tim/.julia/packages/ZygoteRules/6nssF/src/adjoint.jl:47 [inlined]
 [21] _pullback(::Zygote.Context, ::DiffEqBase.var"#solve##kw", ::NamedTuple{(:saveat, :trajectories, :sensealg),Tuple{StepRangeLen{Float32,Float64,Float64},Int64,ReverseDiffAdjoint}}, ::typeof(solve), ::EnsembleProblem{SDEProblem{Array{Float32,1},Tuple{Float32,Float32},true,Array{Float32,1},Nothing,SDEFunction{true,typeof(dudt_),typeof(noise_),UniformScaling{Bool},Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing,Nothing},typeof(noise_),Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}},Nothing},typeof(DiffEqBase.DEFAULT_PROB_FUNC),typeof(DiffEqBase.DEFAULT_OUTPUT_FUNC),typeof(DiffEqBase.DEFAULT_REDUCTION),Nothing}, ::EnsembleThreads) at /home/tim/.julia/packages/Zygote/c0awc/src/compiler/interface2.jl:0

Had to cut this due to the text limit, I can paste the rest in a subsequent post if needed.

ForwardDiffSensitivity does not fail, but also does not solve anything. I’ve left the run going overnight and don’t even manage to get to the first callback. It just sits and spools forever.

So if all of these are failing, have I set up this problem incorrectly? Or is there a better strategy to solve this particular system, or a better sensealg setup here?

Not for SDEs. For diagonal noise SDEs likely the right thing is TrackerAdjoint. QuadratureAdjoint doesn’t exist for SDEs given the weird rules of SDE integration, so we should throw a better error for that case. You can try BacksolveAdjoint or InterpolatingAdjoint, but you’ll need to switch to a commutative noise SDE solver because the reverse pass is no longer diagonal. We haven’t done all of the timings yet but @frankschae might have more comments.

ReverseDiffAdjoint is really what we should be using, so @mohamed82008 could we get help fixing that?

timeseries_steps_meanvar mutates as I found out yesterday: EnsembleProblem with timepoint_meanvar errors due to internal mutation · Issue #446 · SciML/DiffEqFlux.jl · GitHub . I’ll probably patch this over the weekend.

1 Like

Hi! You could try replacing the timeseries_steps_meanvar function by

tmp_sol = Array(solve(
        ensembleprob,
        EM(),
        dt = 0.001f0,
        EnsembleThreads();
      ...)
(tmp_mean, tmp_var) = mean(tmp_sol,dims=3)[:,:], var(tmp_sol,dims=3)[:,:]

With that change, Zygote can compute the gradients:

Zygote.gradient(p->loss(p)[1], α)

So this should solve the mutation issues until timeseries_steps_meanvar is patched.

Please note that I have chosen the 0.5 order EM() solver for Ito SDEs. This has essentially two reasons:
(1) As Chris said, the reverse process will have non-diagonal noise, therefore not all solvers are applicable.
(2) We have some problems with the reconstruction of the noise process with higher order solvers for Ito SDEs. (see https://github.com/SciML/DiffEqNoiseProcess.jl/pull/62). So, if you want to use
BacksolveAdjoint or InterpolatingAdjoint for SDEs in the Ito sense, one currently should choose EM() to be on the safe side…

Note also that the default sensitivity algorithm will be InterpolatingAdjoint. For SDEs written in the Ito sense, we have just implemented the tests for BacksolveAdjoint so far (https://github.com/SciML/DiffEqSensitivity.jl/pull/317)… More docs and tests should come soon :slight_smile:

I am not sure what’s going wrong with ReverseDiffAdjoint and ForwardDiffSensitivity – Except that ForwardDiffSensitivity scales badly with the numbers of parameters… Does it work for fewer network parameters?

Regarding performance, it’s probably best if you do your own benchmarks for the different adjoint choices (I’d be interested to see them as well). Some small suggestions:

– setting return values to nothing, i.e.,

function noise_(du, u, p, t)
    du[1] = p_[end]
    du[2] = p_[end]
    du[3] = 0.0
    du[4] = 0.0
    return nothing
end

– make sure not to mix up different float types like in

du[2] = e * 0.05 * (10μ - u[2]) -> 0.05 is Float64

– localize all parameters in the drift and diffusion function

r, e, μ, h, ph, z, i  = p[1:7]
p_nn = @view p[8:end]

(while of course just optimizing the network parameters.)

Thanks both! I’ll take a look at these suggestions and report back.

It’s not immediately obvious to me what’s wrong. @Libbum could you please open an issue to track this?

do we have a good way to make it skip tracking of randoms?

Yes. We can define a custom gradient of rand or its equivalent to be nothing.

However, there has been a discussion before about whether this makes sense in general or not. https://github.com/JuliaDiff/ChainRules.jl/issues/262 and https://github.com/TuringLang/DistributionsAD.jl/issues/123

I would’ve thought I handled it via https://github.com/SciML/DiffEqNoiseProcess.jl/blob/master/src/init.jl#L12-L16 , but it seems that’s not being picked up?

You would need one for TrackedReal as well. The user is using inplace = true so it’s using Array{<:TrackedReal} which is probably a bigger problem performance-wise but it should work.

1 Like

@Libbum try this PR branch:

https://github.com/SciML/DiffEqNoiseProcess.jl/pull/73

I released a patch with this.

Thanks for the assistance everyone!

I’m using the patched version of DiffEqNoiseProcess.jl @ChrisRackauckas, although since we seem to be tracking more than one problem in this thread I can’t say for sure if that’s assisted or not.

Here’s what’s working:

 function loss(θ)
    tmp_prob = remake(prob_nn, p = θ)
    ensembleprob = EnsembleProblem(tmp_prob)
    tmp_sol = Array(solve(
        ensembleprob,
        EM(); 
        dt = tsteps.step,
        trajectories = 100,
        sensealg = InterpolatingAdjoint(),
       ))                                  
    tmp_mean = mean(tmp_sol,dims=3)[:,:]   
    tmp_var = var(tmp_sol,dims=3)[:,:]     
    sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var), tmp_mean
end

Which is essentially what @frankschae has suggested above (remove the mutating functions, use EM() and InterpolatingAdjoint()).

If I follow correctly, the TrackedReals addition should assist when using ReverseDiffAdjoint correct?
The error I’m getting there is now:

 ERROR: TaskFailedException:
MethodError: no method matching randn(::RandomNumbers.Xorshifts.Xoroshiro128Plus, ::Type{ReverseDiff.TrackedReal{Float32,Float32,Nothing}})                                                           
Closest candidates are:
  randn(::Random.AbstractRNG, ::Type{T}, ::Tuple{Vararg{Int64,N}} where N) where T at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:201
  randn(::Random.AbstractRNG, ::Type{T}, ::Integer, ::Integer...) where T at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:204
  randn(::Random.AbstractRNG) at /buildworker/worker/package_linux64/build/usr/share/julia/stdlib/v1.5/Random/src/normal.jl:38
  ...

@mohamed82008, concerning your latest comment: would it be better to not use inplace in this instance? I don’t quite follow if Chris’ changes has resolved that worry or not. If you think this problem requires an issue to track it - which package are we looking at here? DiffEqSensitivity? Or should I just put it on DifferentialEquations?

This looks like something that wouldn’t happen in my fix? Are you on DiffEqNoiseProcess v5.4.2? If so, please give a code that I can just copy paste to see the error, otherwise I’m not going to go through the effort to find a reproducer.

Yeah, I believe so. But I’ll add it explicitly and try again. If not I’ll put up a full MWE.

Fails with 5.4.2—yes.

(uode) pkg> st
Status `/mnt/turtle/scratch/uode/Project.toml`
  [ff3c4d4f] AutoOptimize v0.1.0 `https://github.com/SciML/AutoOptimize.jl.git#master`
  [052768ef] CUDA v1.3.3
  [2445eb08] DataDrivenDiffEq v0.4.1
  [a93c6f00] DataFrames v0.22.1
  [7806a523] DecisionTree v0.10.10
  [2b5f629d] DiffEqBase v6.48.2
  [aae7a2af] DiffEqFlux v1.24.0
  [071ae1c0] DiffEqGPU v1.8.0
  [77a26b50] DiffEqNoiseProcess v5.4.2
  [41bf760c] DiffEqSensitivity v6.34.0
  [587475ba] Flux v0.11.1
  [033835bb] JLD2 v0.3.0
  [2fda8390] LsqFit v0.11.0
  [961ee093] ModelingToolkit v3.21.0
  [429524aa] Optim v1.2.0
  [1dea7af3] OrdinaryDiffEq v5.45.1
  [789caeaf] StochasticDiffEq v6.26.0
  [5d786b92] TerminalLoggers v0.1.2
  [37e2e46d] LinearAlgebra
  [56ddb016] Logging
  [2f01184e] SparseArrays
  [10745b16] Statistics

Current script:

using DiffEqFlux, Flux, Optim, LinearAlgebra
using DiffEqNoiseProcess
using StochasticDiffEq
using Statistics
using DiffEqSensitivity
using DiffEqBase.EnsembleAnalysis

function sys!(du, u, p, t)
    r, e, μ, h, ph, z, i = p
    du[1] = e * 0.5 * (5μ - u[1]) # nutrient input time series
    du[2] = e * 0.05 * (10μ - u[2]) # grazer density time series
    du[3] = 0.2 * exp(u[1]) - 0.05 * u[3] - r * u[3] / (h + u[3]) * u[4] # nutrient concentration
    du[4] =
        r * u[3] / (h + u[3]) * u[4] - 0.1 * u[4] -
        0.02 * u[4]^z / (ph^z + u[4]^z) * exp(u[2] / 2.0) + i #Algae density
end

function noise!(du, u, p, t)
    du[1] = p[end] # n
    du[2] = p[end] # n
    du[3] = 0.0
    du[4] = 0.0
end

datasize = 10
tspan = (0.0f0, 3.0f0)
tsteps = range(tspan[1], tspan[2], length = datasize)
u0 = Float32[1.0, 1.0, 1.0, 1.0]

p_ = Float32[1.1, 1.0, 0.0, 2.0, 1.0, 1.0, 1e-6, 1.0]

prob = SDEProblem(sys!, noise!, u0, tspan, p_)
ensembleprob = EnsembleProblem(prob)

solution = solve(
    ensembleprob,
    SOSRI(),
    EnsembleThreads();  
    trajectories = 1000,
    abstol = 1e-5,
    reltol = 1e-5, 
    maxiters = 1e8, 
    saveat = tsteps,
)

(truemean, truevar) = Array.(timeseries_steps_meanvar(solution))

ann = FastChain(FastDense(4, 32, tanh), FastDense(32, 32, tanh), FastDense(32, 2))
α = initial_params(ann)

function dudt_(du, u, p, t)
    r, e, μ, h, ph, z, i = p_

    MM = ann(u, p)

    du[1] = e * 0.5 * (5μ - u[1]) # nutrient input time series
    du[2] = e * 0.05 * (10μ - u[2]) # grazer density time series
    du[3] = 0.2 * exp(u[1]) - 0.05 * u[3] - MM[1] # nutrient concentration
    du[4] = MM[2] - 0.1 * u[4] - 0.02 * u[4]^z / (ph^z + u[4]^z) * exp(u[2] / 2.0) + i #Algae density
    return nothing
end
function noise_(du, u, p, t)
    du[1] = p_[end]
    du[2] = p_[end]
    du[3] = 0.0
    du[4] = 0.0
    return nothing
end

prob_nn = SDEProblem(dudt_, noise_, u0, tspan, p = nothing)

function loss(θ)
    tmp_prob = remake(prob_nn, p = θ)
    ensembleprob = EnsembleProblem(tmp_prob)
    tmp_sol = Array(solve(
        ensembleprob,
        EM();
        dt = tsteps.step,
        trajectories = 100,
        sensealg = ReverseDiffAdjoint(),
       ))
    tmp_mean = mean(tmp_sol,dims=3)[:,:]
    tmp_var = var(tmp_sol,dims=3)[:,:]
    sum(abs2, truemean - tmp_mean) + 0.1 * sum(abs2, truevar - tmp_var), tmp_mean
end

const losses = []
callback(θ, l, pred) = begin
    push!(losses, l)
    if length(losses)%50 == 0
        println("Current loss after $(length(losses)) iterations: $(losses[end])")
    end
    false
end

res1 = DiffEqFlux.sciml_train(
    loss,
    α,
    ADAM(0.1),
    cb = callback,
    maxiters = 200,
)

This is the overload you need: https://github.com/SciML/DiffEqNoiseProcess.jl/pull/75 . But now I’m going to punt it over to @mohamed82008 since it’s erroring because the output of the function is an Array{TrackedReal} so it can’t increment_deriv!, and reduce(vcat,output) isn’t giving a TrackedArray.

1 Like

Thanks Chris!. @mohamed82008, I’m tracking the issue Chris just mentioned, so happy to provide details there if needed.