Hello all,
I have a question related to the Turing ODE tutorial:
using Turing, DifferentialEquations, StatsPlots, LinearAlgebra, Random
# Seed the global RNG: both the synthetic observation noise (`randn`) and the MCMC
# sampling below draw from it, so this makes the run repeatable.
Random.seed!(14);
"""
    lotka_volterra(du, u, p, t)

In-place right-hand side of the Lotka–Volterra predator–prey system.

`u = (prey, predator)` is the current state and `p = (α, β, γ, δ)` the parameters;
the time derivatives are written into `du`:

    d(prey)/dt     = (α - β * predator) * prey
    d(predator)/dt = (δ * prey - γ) * predator

`t` is not used (the system is autonomous) but is part of the in-place
`ODEProblem` signature. Returns `nothing`.
"""
function lotka_volterra(du, u, p, t)
    α, β, γ, δ = p
    prey = u[1]
    predator = u[2]
    du[1] = (α - β * predator) * prey     # prey: growth minus predation
    du[2] = (δ * prey - γ) * predator     # predator: growth from prey minus death
    return nothing
end
# Ground-truth setup for the synthetic experiment.
u0 = [1.0, 1.0]               # initial (prey, predator) populations
p = [1.5, 1.0, 3.0, 1.0]      # true (α, β, γ, δ)
tspan = (0.0, 10.0)
prob = ODEProblem(lotka_volterra, u0, tspan, p)

# Solve on a regular 0.1 grid, then add Gaussian noise (std 0.8) to every entry
# of the 2×N trajectory matrix to obtain the "observed" data.
sol = solve(prob, Tsit5(); saveat=0.1)
odedata = let clean = Array(sol)
    clean .+ 0.8 .* randn(size(clean))
end
# Bayesian Lotka–Volterra model.
#   data : 2×N matrix; column i holds the observed (prey, predator) pair at the
#          i-th saved time of `solve(prob; saveat=0.1)`.
#   prob : ODEProblem whose parameters are replaced by each posterior draw.
# Observation model: data[:, i] ~ MvNormal(u(tᵢ), σ²I), i.e. σ is the noise std.
@model function fitlv(data, prob)
    # Prior distributions.
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5, 0.5); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.2, 0.5); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.5); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.5); lower=0, upper=2)

    # Simulate the Lotka–Volterra model with the sampled parameters.
    p = [α, β, γ, δ]
    predicted = solve(prob, Tsit5(); p=p, saveat=0.1)

    # If the solver stops early for this draw, the trajectory has fewer saved
    # points than `data` has columns. Looping over the shorter trajectory would
    # silently drop likelihood terms, so reject the draw instead.
    if length(predicted.u) != size(data, 2)
        Turing.@addlogprob! -Inf
        return nothing
    end

    # Observations: one bivariate Gaussian per saved time point.
    for i in axes(data, 2)
        data[:, i] ~ MvNormal(predicted.u[i], σ^2 * I)
    end
    return nothing
end
# Condition the model on the noisy observations.
model = fitlv(odedata, prob)
# Sample 3 independent chains with forward-mode automatic differentiation (the default).
# `MCMCSerial()` runs the chains one after another; 500 is the number of draws per chain.
@time chain = sample(model, NUTS(), MCMCSerial(), 500, 3; progress=true)
Now imagine that we do not have the same number of observations for prey and predators in `odedata`, so some values are missing. I have read that missing values are quite problematic when calculating the likelihood.
Is it possible to just concatenate the prey and predator series and use `Normal` instead of `MvNormal`, as follows?
Thanks
# Lotka–Volterra model with a flattened, univariate likelihood.
#   data : vector of length 2N laid out as [prey at each saved time; predator at
#          each saved time], on the same saveat=0.1 grid as the simulation.
#          Unobserved points may be `missing`. Keep them in place rather than
#          dropping them, so that index i still refers to the same time and species.
#   prob : ODEProblem whose parameters are replaced by each posterior draw.
# Each observed entry is data[i] ~ Normal(pred[i], σ), with σ the noise standard
# deviation. For fully observed data this gives the same likelihood as the
# MvNormal(u(tᵢ), σ²I) formulation.
@model function fitlv_test(data, prob)
    # Prior distributions.
    σ ~ InverseGamma(2, 3)
    α ~ truncated(Normal(1.5, 0.5); lower=0.5, upper=2.5)
    β ~ truncated(Normal(1.2, 0.5); lower=0, upper=2)
    γ ~ truncated(Normal(3.0, 0.5); lower=1, upper=4)
    δ ~ truncated(Normal(1.0, 0.5); lower=0, upper=2)

    # Simulate the Lotka–Volterra model with the sampled parameters.
    p = [α, β, γ, δ]
    predicted = solve(prob, Tsit5(); p=p, saveat=0.1)
    # Flatten to [prey over time; predator over time] to match `data`.
    pred = vcat(predicted[1, :], predicted[2, :])

    # A length mismatch means the solver stopped early for this draw, or the data
    # is not laid out on the full time grid. Reject the draw instead of comparing
    # against a truncated prediction.
    if length(pred) != length(data)
        Turing.@addlogprob! -Inf
        return nothing
    end

    # Observations.
    for i in eachindex(data)
        # Unobserved points contribute nothing to the likelihood.
        ismissing(data[i]) && continue
        # `Normal` is parameterised by the standard deviation (unlike `MvNormal`,
        # which takes a covariance), so pass σ, not σ^2.
        data[i] ~ Normal(pred[i], σ)
    end
    return nothing
end
# Flatten the 2×N observation matrix into one vector: all prey values first, then
# all predator values, both on the saveat=0.1 grid used inside `fitlv_test`.
data_test = vcat(odedata[1,:],odedata[2,:])
model_test = fitlv_test(data_test, prob)
# Same sampler settings as for `fitlv`, so the two posteriors can be compared directly.
@time chain_test = sample(model_test, NUTS(), MCMCSerial(), 500, 3; progress=true)