Turing MethodError with ReverseDiff but not with ForwardDiff

Hi everyone

I have a small ODE model that runs fine (albeit slowly) with Turing's ForwardDiff AD backend. However, I cannot get it to run with ReverseDiff (nor with Zygote): it throws a MethodError (see below). Is this some type problem, maybe related to the parameters being an array? What would I need to change? Thanks in advance!

Error:

ERROR: MethodError: convert(::Type{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}}, Float64}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1}}, ::ReverseDiff.TrackedReal{Float64, Float64, Nothing}) is ambiguous. Candidates:
  convert(::Type{R}, t::T) where {R<:Real, T<:ReverseDiff.TrackedReal} in ReverseDiff at C:\Users\micky\.julia\packages\ReverseDiff\E4Tzn\src\tracked.jl:260
  convert(::Type{ForwardDiff.Dual{T, V, N}}, x::Number) where {T, V, N} in ForwardDiff at C:\Users\micky\.julia\packages\ForwardDiff\5gUap\src\dual.jl:380
  convert(::Type{T}, x::Number) where T<:Number in Base at number.jl:7
  convert(::Type{ForwardDiff.Dual{T, V, N}}, x) where {T, V, N} in ForwardDiff at C:\Users\micky\.julia\packages\ForwardDiff\5gUap\src\dual.jl:379
Possible fix, define
  convert(::Type{ForwardDiff.Dual{T, V, N}}, ::T) where {T, V, N, T<:ReverseDiff.TrackedReal}
Stacktrace:

setindex!(A::Vector{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}}, Float64}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1}}, x::ReverseDiff.TrackedReal{Float64, Float64, Nothing}, i1::Int64) at .\array.jl

SEIRS2!(du::Vector{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}}, Float64}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1}}, u::Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, p::LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}, t::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}}, Float64}, Float64, 1}) at h:\My Documents\SEIRS2_NUTS_toy.jl

(::ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing})(::Vector{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}}, Float64}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1}}, ::Vararg{Any, N} where N) at C:\Users\micky.julia\packages\SciMLBase\cU5k7\src\scimlfunctions.jl

(::SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}})(du2::Vector{ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}}, Float64}, ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1}}, t::ForwardDiff.Dual{ForwardDiff.Tag{SciMLBase.TimeGradientWrapper{ODEFunction{true, typeof(SEIRS2!), LinearAlgebra.UniformScaling{Bool}, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, Nothing, typeof(SciMLBase.DEFAULT_OBSERVED), Nothing}, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, LArray{ReverseDiff.TrackedReal{Float64, Float64, Nothing}, 1, Vector{ReverseDiff.TrackedReal{Float64, Float64, Nothing}}, (:β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)}}, Float64}, Float64, 1}) at C:\Users\micky.julia\packages\SciMLBase\cU5k7\src\function_wrappers.jl

[TRUNCATED]

Code:

# Household stuff
cd(@__DIR__)
using Pkg 
Pkg.activate(".") 

using DifferentialEquations
using StatsPlots
using StatsBase
using Turing
using Distributions
using Random
using LabelledArrays
using Serialization
using LazyArrays
using ReverseDiff
using DiffEqSensitivity
using Zygote
using Memoization

# Seed the global RNG so the synthetic data drawn with `rand.` further down is reproducible
Random.seed!(422)

# ODE model
"""
    SEIRS2!(du, u, p, t)

In-place right-hand side of a two-group SEIRS model with seasonally forced transmission.

# State layout
- `u[1:4]` = `S1, E1, I1, R1`. `S1` receives births at rate `μ*N`.
- `u[5:8]` = `S2, E2, I2, R2`. `S2` is fed by waning immunity from both `R1` and `R2`.
- `u[9]` accumulates incidence `σ*(E1 + E2)`.

Only `du[1:9]` is written. The flows in `du[1:8]` sum to zero, so total population
`sum(u[1:8])` is conserved.

# Parameters
`p` must support property access for `β, η, ω, φ, σ, μ, δ2, γ1, g2`, for example an
`LArray` or a `NamedTuple`.
- `1/p.ω` is used as the waning rate.
- `γ1 / p.g2` is the recovery rate of group 2.
- `δ2` scales the force of infection for group 2.

Seasonal forcing is `β * (1 + η*cos(2π(t - φ)/365))`.

Returns `nothing`.
"""
function SEIRS2!(du, u, p, t)

    # states: destructure straight from `u` (only the first eight entries are read),
    # avoiding the copy that `u[1:8]` would allocate on every RHS evaluation
    S1, E1, I1, R1, S2, E2, I2, R2 = u
    N1 = S1 + E1 + I1 + R1
    N2 = S2 + E2 + I2 + R2
    N = N1 + N2

    # params
    β = p.β
    η = p.η
    φ = p.φ
    ω = 1.0 / p.ω        # p.ω is supplied as a duration; the flows need its rate
    μ = p.μ
    σ = p.σ
    γ1 = p.γ1
    γ2 = γ1 / p.g2
    δ2 = p.δ2

    # force of infection with 365-periodic seasonal modulation
    βeff = β * (1.0 + η * cos(2.0 * π * (t - φ) / 365.0))
    λ1 = βeff * (I1 / N1 + I2 / N2)
    λ2 = λ1 * δ2

    # change in states, group 1
    du[1] = μ * N - λ1 * S1 - μ * S1
    du[2] = λ1 * S1 - σ * E1 - μ * E1
    du[3] = σ * E1 - γ1 * I1 - μ * I1
    du[4] = γ1 * I1 - ω * R1 - μ * R1

    # change in states, group 2 (receives waned immunity from R1 and R2)
    du[5] = ω * (R1 + R2) - λ2 * S2 - μ * S2
    du[6] = λ2 * S2 - σ * E2 - μ * E2
    du[7] = σ * E2 - γ2 * I2 - μ * I2
    du[8] = γ2 * I2 - ω * R2 - μ * R2

    du[9] = σ * (E1 + E2) # cumulative incidence

    return nothing
end

# observation model
"""
    NegativeBinomial2(ψ, incidence)

Build a negative binomial parameterised by its mean `incidence` and overdispersion `ψ`.

The result is `NegativeBinomial(r, p)` with `r = 1/ψ` and `p = 1/(1 + ψ*incidence)`.
Its mean is `incidence` and its variance is `incidence + ψ*incidence^2`.
"""
function NegativeBinomial2(ψ, incidence)
    successes = 1.0 / ψ
    success_prob = 1.0 / (1.0 + ψ * incidence)
    return NegativeBinomial(successes, success_prob)
end

# Solver settings: integrate over 20*365 time units (presumably days), saving every 7
tmin, tmax = 0.0, 20.0 * 365.0
tspan = (tmin, tmax)
solvsettings = (
    abstol = 1.0e-3,
    reltol = 1.0e-3,
    saveat = 7.0,
    # Tsit5 with automatic switching to the stiff solver Rosenbrock23 when needed
    solver = AutoTsit5(Rosenbrock23()),
)

# Parameters and initial conditions ("true" values used to simulate data)
theta_fix = [1.0 / 4.98, 1.0 / (80 * 365), 0.89, 1 / 6.16, 0.87]  # σ, μ, δ2, γ1, g2
theta_est = [0.15, 0.001, 0.28, 0.07, 365.0, 180.0]                # ψ, ρ, β, η, ω, φ
parnames = (:ψ, :ρ, :β, :η, :ω, :φ, :σ, :μ, :δ2, :γ1, :g2)
p = @LArray vcat(theta_est, theta_fix) parnames
u0 = [200_000.0, 1000.0, 1000.0, 300_000.0,   # S1, E1, I1, R1
      500_000.0, 1000.0, 1000.0, 296_000.0,   # S2, E2, I2, R2
      2000.0]                                 # cumulative incidence

# Initiate ODE problem
problem = ODEProblem(SEIRS2!, u0, tspan, p)

# Reference solve at the data-generating parameters. Only state 9 (cumulative incidence)
# is saved, on the `saveat` grid.
sol = let opts = solvsettings
    solve(problem, opts.solver;
          abstol = opts.abstol,
          reltol = opts.reltol,
          isoutofdomain = (u, p, t) -> any(<(0.0), u),
          save_idxs = 9,
          saveat = opts.saveat)
end

# Fake some data from model.
# Only cumulative incidence is saved (save_idxs = 9), so successive differences of the
# saved values (`sol.u`, rather than linear-indexing the solution object) give incidence
# per `saveat` interval. Scaling by the reporting fraction ρ gives the mean of the
# NegativeBinomial2 observation model.
foo = diff(sol.u) .* p.ρ
data = rand.(NegativeBinomial2.(p.ψ, foo))
plot(foo, legend = false); scatter!(data, legend = false)

# Fit model to fake data

# Set up as Turing model
# AD backend selection: per the question above, ForwardDiff (commented out) works for
# this model, while ReverseDiff raises the reported MethodError.
#Turing.setadbackend(:forwarddiff)
Turing.setadbackend(:reversediff)
# Cache the compiled ReverseDiff tape between gradient evaluations (Turing's cache relies
# on Memoization.jl, loaded at the top of the script).
# NOTE(review): a cached tape replays the control flow recorded on its first run, so
# run-time branches such as the solution-length check in the model may be frozen.
# Verify that caching is safe here.
Turing.setrdcache(true)

@model function turingmodel(data, theta_fix, u0, problem, parnames, solvsettings)

    # Priors
    ψ ~ Beta(1, 5)                  # overdispersion of the NegativeBinomial2 likelihood
    ρ ~ Uniform(0.0, 1.0)           # reporting fraction applied to incidence
    β ~ Uniform(0.0, 1.0)
    η ~ Uniform(0.0, 1.0)
    ω ~ Uniform(1.0, 3.0 * 365.0)
    φ ~ Uniform(0.0, 364.0)

    # ψ and ρ only enter the observation model, so the ODE parameter vector is labelled
    # with parnames[3:end]. This assumes `parnames` starts with (:ψ, :ρ), as in the
    # script.
    theta_est = [β, η, ω, φ]
    p_new = @LArray vcat(theta_est, theta_fix) parnames[3:end]

    # Update problem and solve ODEs. u0 is converted to the parameters' element type so
    # the state vector can carry AD numbers (Dual / TrackedReal).
    problem_new = remake(problem, p = p_new, u0 = eltype(p_new).(u0))

    # `concrete_solve` is deprecated in favour of `solve`.
    # NOTE(review): the reported ReverseDiff error arises in a ForwardDiff time gradient
    # (TimeGradientWrapper) taken over TrackedReal states, presumably by the stiff
    # fallback Rosenbrock23. Passing a `sensealg`, or using
    # `Rosenbrock23(autodiff = false)`, may avoid it. Verify.
    sol_new = solve(problem_new,
                    solvsettings.solver;
                    abstol = solvsettings.abstol,
                    reltol = solvsettings.reltol,
                    isoutofdomain = (u, p, t) -> any(x -> x < 0.0, u),
                    save_idxs = 9,
                    saveat = solvsettings.saveat)

    # Only cumulative incidence is saved (save_idxs = 9), so successive differences of
    # the saved values give incidence per `saveat` interval.
    incidence = diff(sol_new.u)
    incidence = max.(0.0, incidence) # avoid numerical instability issue
    increp = incidence .* ρ

    # A solve that terminates early yields fewer points than `data`; reject such draws
    if size(increp, 1) == size(data, 1)
        data ~ arraydist(LazyArray(@~ @. NegativeBinomial2(ψ, increp)))
    else
        Turing.@addlogprob! -Inf
        return
    end

end

model = turingmodel(data, theta_fix, u0, problem, parnames, solvsettings)

# Three NUTS chains (NUTS(5000, 0.65): 5000 adaptation steps, target acceptance 0.65;
# 1000 samples each), run one after another and concatenated into a single Chains object
@time trace = reduce(chainscat,
                     [sample(model, NUTS(5000, 0.65), 1000; progress = true) for _ in 1:3])

Useful discussion, with lots of code improvements, thanks to @torfjelde:
https://github.com/TuringLang/Turing.jl/issues/1673

Yes, the (expanded) comment at https://github.com/TuringLang/Turing.jl/issues/1673#issuecomment-891686825 is a great post.

1 Like