Hi all,
I'm trying to estimate parameters in a logistic growth model with varying parameters using DifferentialEquations.jl and Turing.jl. Starting from the simplest case, I'd like to estimate a separate carrying capacity K for each equation in the system, with the K values drawn from a common distribution whose parameters would ideally also be inferred from the data. I've tried adapting this multilevel model example, https://github.com/StatisticalRethinkingJulia/TuringModels.jl/blob/master/scripts/13/m13.2.jl, to use a DiffEq solver, and I'm running into the following error:
```
ERROR: TypeError: in typeassert, expected Float64, got a value of type ForwardDiff.Dual{Nothing,Float64,6}
```
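For context, messages that mention `ForwardDiff.Dual` together with `Float64` usually point to a container whose element type is fixed to `Float64`: NUTS differentiates the model by running it on ForwardDiff dual numbers, and a dual cannot be stored as a plain `Float64`. The sketch below reproduces that general pitfall outside Turing; the two functions are hypothetical and only illustrate the mechanism, and the exact exception type may differ from the `TypeError` above depending on where the conversion happens.

```julia
using ForwardDiff

# Container hard-coded to Float64, like `Vector{Float64}(undef, ...)` in a model.
function sum_doubled_float64(x)
    buf = Vector{Float64}(undef, length(x))
    for i in eachindex(x)
        buf[i] = 2 * x[i]   # fails when x[i] is a ForwardDiff.Dual
    end
    return sum(buf)
end

# Same computation, but the container's element type follows the input,
# so it can hold dual numbers during differentiation.
function sum_doubled_generic(x)
    buf = similar(x)        # element type is eltype(x): Float64 or Dual
    for i in eachindex(x)
        buf[i] = 2 * x[i]
    end
    return sum(buf)
end

x = [1.0, 2.0, 3.0]
ForwardDiff.gradient(sum_doubled_generic, x)   # [2.0, 2.0, 2.0]
# ForwardDiff.gradient(sum_doubled_float64, x) throws instead.
```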
Here’s a general example of what I’m trying to do:
```julia
using DifferentialEquations, Turing, MCMCChains

# Logistic growth: shared rate r = p[1], one carrying capacity per equation in p[2:end]
function f(u, p, t)
    r, k = p[1], p[2:end]
    return r*u .* (1 .- u ./ k)
end

# Simulate data from n equations whose K values are drawn around k = 25
n = 4
u0 = rand(n)
tspan = (0.0, 20.0)
r = .75; k = 25; γ = 2
g = rand(Normal(k, γ), n)
p = [r; g]
dt = .5
prob = ODEProblem(f, u0, tspan, p)
sol = solve(prob, RK4(), saveat = dt, reltol = 1e-6)
data = Array(sol) + .1randn(size(Array(sol)))

@model function fn_lg(data, prob, ::Type{T} = Float64) where {T <: Real}
    # Population-level hyperpriors
    γ ~ Exponential(5)
    k ~ Normal(30, 10)
    # One carrying capacity per equation (note: n is an Int, so length(n) == 1)
    theta = Vector{Float64}(undef, length(n))
    for j = 1:length(n)
        theta[j] ~ truncated(Normal(k, γ), 0, 100)
    end
    σ ~ InverseGamma(2, 1)
    r ~ truncated(Normal(1, 1), 0, 2)
    # Assemble the ODE parameter vector [r; theta...]
    p = [r]
    for i = 1:length(n)
        push!(p, convert(Float64, false ? theta[i] : theta[i]))
    end
    # Re-solve the ODE with the sampled parameters
    predicted = solve(prob, Rosenbrock23(autodiff=false), reltol=1e-3; p=p, saveat = dt)
    # Gaussian likelihood at each saved time point
    for i = 1:length(predicted)
        data[:,i] ~ MvNormal(predicted[i], σ)
    end
end

model = fn_lg(data, prob)
chain = @time sample(model, Turing.NUTS(40, .65), 50, progress=true)
```
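For comparison, here is one possible type-generic version of the model above, sketched under assumptions rather than offered as the definitive fix, and not checked against the package versions in this post. Relative to the original it (1) allocates `theta` as `Vector{T}`, using the `T` the model signature already declares, so the vector can hold ForwardDiff duals; (2) loops over `1:n` instead of `1:length(n)`, since `length` of an integer is 1; (3) builds `p` as `[r; theta]` with no `convert(Float64, ...)`; (4) uses `remake` to give `u0` the same element type as `p`; and (5) passes `n` and `dt` as model arguments instead of reading globals. The remaining swaps are my own choices: `Tsit5()` instead of `Rosenbrock23` because this logistic system does not look stiff, `MvNormal(μ, σ^2 * I)` because recent Distributions.jl deprecates the scalar-σ form, `predicted.u[i]` for the per-timestep states, and `NUTS(0.65)` with the default adaptation length.

```julia
using DifferentialEquations, Turing, LinearAlgebra, Random

Random.seed!(1)

# Logistic growth: shared rate r = p[1], one carrying capacity per equation in p[2:end]
function f(u, p, t)
    r, k = p[1], p[2:end]
    return r .* u .* (1 .- u ./ k)
end

# Simulated data, generated the same way as in the question
n = 4
u0 = rand(n)
tspan = (0.0, 20.0)
dt = 0.5
p_true = [0.75; rand(Normal(25, 2), n)]
prob = ODEProblem(f, u0, tspan, p_true)
sol = solve(prob, RK4(); saveat = dt, reltol = 1e-6)
data = Array(sol) .+ 0.1 .* randn(size(Array(sol)))

@model function fn_lg(data, prob, n, dt, ::Type{T} = Float64) where {T <: Real}
    # Population-level hyperpriors
    γ ~ Exponential(5)
    k ~ Normal(30, 10)
    # Element type T (not Float64), so theta can hold ForwardDiff duals under NUTS
    theta = Vector{T}(undef, n)
    for j in 1:n
        theta[j] ~ truncated(Normal(k, γ), 0, 100)
    end
    σ ~ InverseGamma(2, 1)
    r ~ truncated(Normal(1, 1), 0, 2)
    # Parameter vector keeps whatever number type r and theta have
    p = [r; theta]
    # Promote u0 to the parameters' element type before solving
    prob_i = remake(prob; u0 = convert.(eltype(p), prob.u0), p = p)
    predicted = solve(prob_i, Tsit5(); reltol = 1e-3, saveat = dt)
    # Isotropic Gaussian likelihood at each saved time point (σ is a standard deviation)
    for i in eachindex(predicted.u)
        data[:, i] ~ MvNormal(predicted.u[i], σ^2 * I)
    end
end

model = fn_lg(data, prob, n, dt)
chain = sample(model, NUTS(0.65), 100; progress = false)
```

If `Rosenbrock23` is preferred, it should work the same way once the parameter and state types agree, though I have not verified that here.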
I may have missed something in the documentation and/or in DiffEqBayes.turing_inference() (https://github.com/SciML/DiffEqBayes.jl/blob/master/src/turing_inference.jl), which handles arrays of parameters, but I can't pin down what exactly I'm missing when writing the model out explicitly with Turing.@model. The error seems to be related to the type of the parameters passed to the ODE solver (?), but I could really use some help interpreting it and finding the fix.
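To check the ODE side in isolation, the sketch below differentiates a small logistic solve with respect to its parameter vector using ForwardDiff directly, without Turing. The parameter values, the two-equation setup and the helper name `final_total` are made up for illustration. The key step is converting `u0` to `eltype(p)` and passing both through `remake`, so that the solver state can carry dual numbers; newer DiffEq versions may already perform this promotion automatically, in which case the explicit `convert` is simply harmless.

```julia
using OrdinaryDiffEq, ForwardDiff

# Logistic growth: shared rate p[1], one carrying capacity per equation in p[2:end]
logistic(u, p, t) = p[1] .* u .* (1 .- u ./ p[2:end])

prob = ODEProblem(logistic, [0.5, 0.5], (0.0, 20.0), [0.75, 25.0, 30.0])

# Hypothetical helper: final state summed over equations, as a function of p
function final_total(p)
    # Give u0 the same element type as p (Float64 or ForwardDiff.Dual)
    u0 = convert.(eltype(p), prob.u0)
    sol = solve(remake(prob; u0 = u0, p = p), Tsit5(); reltol = 1e-8, abstol = 1e-8)
    return sum(sol.u[end])
end

# Gradient of the final total with respect to [r, K1, K2]
g = ForwardDiff.gradient(final_total, [0.75, 25.0, 30.0])
```

Since both equations are close to their carrying capacities by t = 20, the derivatives with respect to the two K entries should come out close to 1, and the derivative with respect to r small and positive.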
I'd greatly appreciate any thoughts or advice!