Hello,
Below is a moderately large Turing model in which the ODE system is solved over 51 time points, with a discrete callback at t = 20.
#======================================
PARAMETER ESTIMATION FROM SYNTHETIC DATA
======================================#
#=
Section 1: Importing packages, files, and defining variables
=#
using DifferentialEquations, Interpolations, XLSX, DataFrames, StatsPlots
using Distributions, DistributionsAD, MCMCChains, Memoization, Turing
using ReverseDiff
# Turing.setrdcache(true) # Uncomment to enable the ReverseDiff tape cache
Turing.setadbackend(:reversediff) # Turing is already loaded above
include("model_file.jl") # Loads the ODE model function our_model_v1!
tot_weeks = 51 # Number of time points
N = 1000000 # Parameter - population size
V = 2 # Parameter - number of species
tspan = (1.0, Float64(tot_weeks)) # Time period for simulation of ODEs
# The next two variables, IPTCC and mobil, are linearly interpolated forcing
# functions that are used in the model. They are read from two separate
# files, and the synthetic data was generated using the same variables.
# IPTCC is a forcing function
wet_data = DataFrame(XLSX.readtable("Wetdata.xlsx","Wetdata"; infer_eltypes = true)...)
IPTCC = wet_data.Normalized_IPTCC
IPTCC = IPTCC[1:tot_weeks]
IPTCC = Float64.(IPTCC)
IPTCC = interpolate(IPTCC, BSpline(Linear()))
# mobil is another forcing function
mobil_data = DataFrame(XLSX.readtable("Mobdata.xlsx","Mobdata"; infer_eltypes = true)...)
mobil = mobil_data.Mean
mobil = mobil[1:tot_weeks]
mobil = Float64.(mobil)
mobil = interpolate(mobil, BSpline(Linear()))
# Reading the observation data
synth_data = DataFrame(XLSX.readtable("synthData_version1.xlsx","Sheet1"; infer_eltypes = true)...)
#=
Section 2: Define the observation model
=#
@model function truth_data_fitting!(data, ODEtspan, num_species, tot_pop, interp_IPTCC, interp_mobil)
    param_prior_df = DataFrame() # A dataframe container for parameters
    # Priors of parameters - ODE model with two species
    priorβ₁ ~ Uniform(0, 7)
    priorβ₂ ~ Uniform(0, 7)
    priorα ~ Uniform(1.4, 3.5)
    priorθₙ ~ Uniform(0.0008, 0.6993)
    priorθᵢ ~ Uniform(0.000001, 0.14)
    priorγₙ ~ Uniform(0.35, 0.7)
    priorγᵢ ~ Uniform(0.35, 0.7)
    priorγ ~ Uniform(0.8, 3.4965)
    priorλ ~ Uniform(7/1095, 7/730)
    priorδᵢ ~ Uniform(0.35, 0.7)
    # Priors of parameters - Observation model
    σ₁ ~ Gamma(1.0, 5.0) # For observation 1
    σ₂ ~ Gamma(1.0, 5.0) # For observation 2
    σ₃ ~ Gamma(1.0, 5.0) # For observation 3
    σ₄ ~ Gamma(1.0, 5.0) # For observation 4
    σ₅ ~ Gamma(1.0, 5.0) # For observation 5
    # Pushing the ODE model priors to the data frame
    param_prior_df.β = [priorβ₁; priorβ₂]
    param_prior_df.α = [priorα; priorα]
    param_prior_df.θₙ = [priorθₙ; priorθₙ]
    param_prior_df.θᵢ = [priorθᵢ; priorθᵢ]
    param_prior_df.γₙ = [priorγₙ; priorγₙ]
    param_prior_df.γᵢ = [priorγᵢ; priorγᵢ]
    param_prior_df.γ = [priorγ; priorγ]
    param_prior_df.λ = [priorλ; priorλ]
    param_prior_df.δᵢ = [priorδᵢ; priorδᵢ]
    # Final parameter container
    const_params = [num_species, tot_pop, interp_IPTCC, interp_mobil]
    final_parameters = [const_params; param_prior_df]
    # Defining the initial conditions and the problem
    I0₁ ~ Uniform(0.1, 100)
    I0₂ ~ Uniform(0.1, 100)
    u0 = zeros(eltype(I0₁), 5*num_species+5)
    u0[1] = tot_pop
    u0[4] = I0₁
    u0[12] = I0₁
    inference_problem = ODEProblem(our_model_v1!, u0, ODEtspan, final_parameters)
    # Defining the callback function to introduce change in species 2 at time point 20
    v2_seed_time = 20.0
    condition(u, t, integrator) = t == v2_seed_time
    function affect!(integrator)
        integrator.u[5] += I0₂ # For the model equation
        integrator.u[13] += I0₂ # For the cumulative counter
    end
    cb = DiscreteCallback(condition, affect!; save_positions=(false, false))
    # Solve the model
    inference_solution = solve(inference_problem, Tsit5(), saveat = 1.0, callback = cb, tstops=[v2_seed_time])
    # Inference using multivariate normal distributions
    data[:,1] ~ MvNormal(inference_solution[12,:], σ₁*(sqrt.(inference_solution[12,:])))
    data[:,2] ~ MvNormal(inference_solution[13,:], σ₂*(sqrt.(inference_solution[13,:])))
    data[:,3] ~ MvNormal(inference_solution[14,:], σ₃*(sqrt.(inference_solution[14,:])))
    data[:,4] ~ MvNormal(inference_solution[15,:], σ₄*(sqrt.(inference_solution[15,:])))
    data[:,5] ~ MvNormal(inference_solution[11,:], σ₅*(sqrt.(inference_solution[11,:])))
end
#=
Section 3: Run the inference procedure and save the chains
=#
combo1_model = truth_data_fitting!(synth_data, tspan, V, N, IPTCC, mobil)
chains = sample(combo1_model, NUTS(0.65), 100; progress=true)
When sampling with `Turing.setrdcache(true)`, it throws a `DomainError` for trying to evaluate `sqrt` of a negative number. But my code doesn't contain any branches or loops that run for different time periods.
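To illustrate the error itself outside Turing, here is a minimal sketch in plain Julia. It assumes (I have not confirmed this) that some ODE solution values dip slightly below zero before they reach `sqrt` in the observation model. The vector `sol_row` is made up for illustration and is not output from my model.

```julia
# Hypothetical values standing in for one row of the ODE solution;
# the tiny negative entry mimics solver round-off near zero.
sol_row = [4.0, 1.0e-9, -1.0e-12]

# sqrt of a negative Float64 throws a DomainError in Julia.
threw = try
    sqrt.(sol_row)
    false
catch err
    err isa DomainError
end

# Clamping at zero before taking the square root avoids the error.
safe_sd = sqrt.(max.(sol_row, 0.0))
```

Clamping like this only sidesteps the exception. A standard deviation of exactly zero may still cause trouble in `MvNormal`, so a small positive floor might be needed instead of `0.0`.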
If I run without caching instead, sampling stalls: the following warning repeats indefinitely and the sampler never settles on a step size.
┌ Warning: The current proposal will be rejected due to numerical error(s).
│ isfinite.((θ, r, ℓπ, ℓκ)) = (true, true, false, true)
└ @ AdvancedHMC ~/.julia/packages/AdvancedHMC/HQHnm/src/hamiltonian.jl:47
Can someone please, please help me out? Big thanks in advance.