Hi everyone
I have a small ODE model, which runs fine (albeit slow) with Turing forwarddiff AD. However, I cannot get it to run with reversediff (and neither with Zygote). It throws a MethodError (see below). Is this some type problem, maybe related to the parameters being an array? What would I need to change? TIA
Error:
ERROR: MethodError: convert(::Type{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}}, Float64}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1}}, ::ReverseDiff.TrackedReal{Float64, Float64, Nothing}) is ambiguous. Candidates:
convert(::Type{R}, t::T) where {R<:Real, T<:ReverseDiff.TrackedReal} in ReverseDiff at C:\Users\micky\.julia\packages\ReverseDiff\E4Tzn\src\tracked.jl:260
convert(::Type{ForwardDiff.Dual{T, V, N}}, x::Number) where {T, V, N} in ForwardDiff at C:\Users\micky\.julia\packages\ForwardDiff\5gUap\src\dual.jl:380
convert(::Type{T}, x::Number) where T<:Number in Base at number.jl:7
convert(::Type{ForwardDiff.Dual{T, V, N}}, x) where {T, V, N} in ForwardDiff at C:\Users\micky\.julia\packages\ForwardDiff\5gUap\src\dual.jl:379
Possible fix, define
convert(::Type{ForwardDiff.Dual{T, V, N}}, ::T) where {T, V, N, T<:ReverseDiff.TrackedReal}
Stacktrace:
setindex!(A::Vector{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}}, Float64}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1}}, x::ReverseDiff.TrackedReal{Float64, Float64, Nothing}, i1::Int64) at .\array.jl
SEIRS2!(du::Vector{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}}, Float64}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1}}, u::Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, p::LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}, t::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}}, Float64}, Float64, 1}) at h:\My Documents\SEIRS2_NUTS_toy.jl
(::ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing})(::Vector{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}}, Float64}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1}}, ::Vararg{Any, N} where N) at C:\Users\micky.julia\packages\SciMLBase\cU5k7\src\scimlfunctions.jl
(::SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}})(du2::Vector{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}}, Float64}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1}}, t::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}}, Float64}, Float64, 1}) at C:\Users\micky.julia\packages\SciMLBase\cU5k7\src\function_wrappers.jl
[TRUNCATED]
Code:
# Household stuff
cd(@__DIR__)
using Pkg
Pkg.activate(".")
using DifferentialEquations
using StatsPlots
using StatsBase
using Turing
using Distributions
using Random
using LabelledArrays
using Serialization
using LazyArrays
using ReverseDiff
using DiffEqSensitivity
using Zygote
using Memoization
Random.seed!(422)
# ODE model
function SEIRS2!(du,u,p,t)
# states
(S1, E1, I1, R1, S2, E2, I2, R2) = u[1:8]
N1 = S1 + E1 + I1 + R1
N2 = S2 + E2 + I2 + R2
N = N1 + N2
# params
β = p.β
η = p.η
φ = p.φ
ω = 1.0/p.ω
μ = p.μ
σ = p.σ
γ1 = p.γ1
γ2 = γ1 / p.g2
δ2 = p.δ2
# FOI
βeff = β * (1.0+η*cos(2.0*π*(t-φ)/365.0))
λ1 = βeff*(I1/N1 + I2/N2)
λ2 = λ1 * δ2
# change in states
du[1] = μ*N - λ1*S1 - μ*S1
du[2] = λ1*S1 - σ*E1 - μ*E1
du[3] = σ*E1 - γ1*I1 - μ*I1
du[4] = γ1*I1 - ω*R1 - μ*R1
du[5] = ω*(R1 + R2) - λ2*S2 - μ*S2
du[6] = λ2*S2 - σ*E2 - μ*E2
du[7] = σ*E2 - γ2*I2 - μ*I2
du[8] = γ2*I2 - ω*R2 - μ*R2
du[9] = (σ*(E1 + E2)) # cumulative incidence
end
# observation model
function NegativeBinomial2(ψ, incidence)
p = 1.0/(1.0 + ψ*incidence)
r = 1.0/ψ
return NegativeBinomial(r, p)
end
# Solver settings
tmin = 0.0
tmax = 20.0*365.0
tspan = (tmin, tmax)
solvsettings = (abstol = 1.0e-3,
reltol = 1.0e-3,
saveat = 7.0,
solver = AutoTsit5(Rosenbrock23()))
# Initiate ODE problem
theta_fix = [1.0/4.98, 1.0/(80*365), 0.89, 1/6.16, 0.87]
theta_est = [0.15, 0.001, 0.28, 0.07, 365.0, 180.0]
parnames = (:ψ, :ρ, :β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)
p = @LArray [theta_est; theta_fix] parnames
u0 = [200_000.0,1000.0,1000.0,300_000.0, 500_000.0, 1000.0,1000.0, 296_000, 2000.0]
# Initiate ODE problem
problem = ODEProblem(SEIRS2!,u0,tspan,p)
sol = solve(problem,
solvsettings.solver,
abstol=solvsettings.abstol,
reltol=solvsettings.reltol,
isoutofdomain=(u,p,t)->any(x->x<0.0,u),
save_idxs=9,
saveat=solvsettings.saveat)
# Fake some data from model
foo = (sol[2:end] - sol[1:(end-1)]) .* p.ρ
data = rand.(NegativeBinomial2.(p.ψ, foo))
plot(foo, legend = false); scatter!(data,legend = false)
# Fit model to fake data
# Set up as Turing model
#Turing.setadbackend(:forwarddiff)
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
@model function turingmodel(data, theta_fix, u0, problem, parnames, solvsettings)
# Priors
ψ ~ Beta(1,5)
ρ ~ Uniform(0.0,1.0)
β ~ Uniform(0.0,1.0)
η ~ Uniform(0.0,1.0)
ω ~ Uniform(1.0, 3.0*365.0)
φ ~ Uniform(0.0,364.0)
theta_est = [β,η,ω,φ]
p_new = @LArray vcat(theta_est, theta_fix) parnames[3:end]
# Update problem and solve ODEs
problem_new = remake(problem, p=p_new, u0=eltype(p_new).(u0))
sol_new = concrete_solve(problem_new,
solvsettings.solver,
abstol=solvsettings.abstol,
reltol=solvsettings.reltol,
isoutofdomain=(u,p,t)->any(x->x<0.0,u),
save_idxs=9,
saveat=solvsettings.saveat)
incidence = sol_new[2:end] - sol_new[1:(end-1)]
incidence = max.(0.0, incidence) # avoid numerical instability issue
increp = incidence .* ρ
if size(increp,1)==size(data,1)
data ~ arraydist(LazyArray(@~ @. NegativeBinomial2(ψ, increp)))
else
Turing.@addlogprob! -Inf
return
end
end
model = turingmodel(data, theta_fix, u0, problem, parnames, solvsettings)
@time trace = mapreduce(c -> sample(model, NUTS(5000, 0.65), 1000, progress=true), chainscat, 1:3)