Bayesian parameter estimation of an ODE with a multivariate Bernoulli likelihood

Hi,

I am trying to estimate the parameters of an ODE model with N equations. The model's output is a vector of N probabilities at each time step: coordinate i is the probability that patch i is occupied at time t. The observed data are yearly 0/1 values meaning empty/occupied, so I am using a Bernoulli distribution for the likelihood.
The model output is “continuous”, and I compare it with the 0/1 observed data: I take t = 15 August of year X from the model output and compare it to the 0/1 value of that year X. Also, the model is non-autonomous: it has a parameter that depends on time, which I interpolate, and the ODE then calls this interpolation to compute the parameter's value. So I have three main issues:

1. There is no multivariate Bernoulli distribution.
2. I do not have much observed data.
3. The system is large, so everything takes ages; solving the ODE alone takes minutes.

I do not know if I am doing something wrong and that is why it never finishes, or whether it is just slow due to the cost of the integration method.
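For issue 1, my understanding is that a "multivariate Bernoulli" with independent coordinates is just a product of univariate Bernoullis, one per patch. If I am not mistaken, Distributions.jl can build this with `product_distribution` (a small sketch with made-up probabilities):

``````
using Distributions

# Independent 0/1 outcomes, one per patch, with occupancy probabilities p_occ
p_occ = [0.2, 0.7, 0.9]
d = product_distribution(Bernoulli.(p_occ))

rand(d)               # draws a 0/1 vector of length 3
logpdf(d, [0, 1, 1])  # joint log-likelihood of one observed vector
``````

This should be equivalent to writing `data[j, i] ~ Bernoulli(p[j])` coordinate by coordinate, which is what I do below.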

This is how I code this using fake input data:

``````# Code that integrates the Hanski model to obtain the solutions
# With input: (from code input_Hanski.R)
#     . Distance matrix
#     . Flow matrix
#     . Area vector
#     . R_M vector
# And estimate the parameters of the model from the presence absence data
# Load pkgs and remove data -----------------------------------------------------
import Pkg
using DifferentialEquations
using DataFrames
using CSV
using Plots
using Shapefile
import GeoDataFrames as GDF
using LinearAlgebra
using ODE
using Interpolations
using DiffEqParamEstim
using SparseArrays
using ParameterizedFunctions
#using RecursiveArrayTools
using Optimization
using Distributions

# Constants
m_c = 0.001 # probability of mosquito in a car
alp = 1/200 # average natural dispersal distance

# Create fake data
end_ind = 5
eta = rand(Uniform(0, 1000),end_ind, end_ind)
eta[diagind(eta)] .= 0

# Set some entries to zero
eta[1,4] = 0
eta[2,3] = 0
eta[5,1:end] .= 0

N = size(eta, 1) # Number of patches
d_1 = rand(Uniform(0, 100),end_ind, end_ind)
d_1[5,:] .= rand(Uniform(900, 1000), end_ind)

# Choose number of patches and IC
pop_init = zeros(N)
pop_init[1] = 1

# Fake time series
num_years = 6
num_days = 365*num_years
days = 1:num_days
days_per_year = 365
dates = collect(1:num_days)

#  Fake seasonal
R_Ms = zeros(num_days,end_ind)
R_Ms[1:end,1] .= (30 .+ 10 * sin.(2 * π * days / days_per_year) + 2. * randn(num_days))./10
R_Ms[1:end,2] .= (30 .+ 10 * sin.(2 * π * days / days_per_year) + 2. * randn(num_days))./30
R_Ms[1:end,3] .= (30 .+ 10 * sin.(2 * π * days / days_per_year) + 2. * randn(num_days))./40
R_Ms[1:end,4] .= (30 .+ 10 * sin.(2 * π * days / days_per_year) + 2. * randn(num_days))./10
R_Ms[1:end,5] .= (30 .+ 10 * sin.(2 * π * days / days_per_year) + 2. * randn(num_days))./100

plot(days, R_Ms[1:end,1], title="Positive Seasonal Pattern with Random Variation", xlabel="Days", ylabel="Value")
# Create an array to store interpolated functions
interpolated_functions = []

# Perform interpolation for each location
for i in 1:end_ind
    # Extract the seasonal R_M values for the current location
    R_Ms_val = R_Ms[:, i]
    dates_num = 1:size(dates, 1)

    # Perform linear interpolation
    itp = LinearInterpolation(dates_num, R_Ms_val, extrapolation_bc = Flat())
    # Store the interpolated function
    push!(interpolated_functions, itp)
end

plot(days, hcat([itp.(days) for itp in interpolated_functions]...))

# Integration non autonomous
function fun_na!(du, u, p, t)
    R_M_vec = [interpolated_functions[i](t) for i in 1:N]

    mat = exp.(-alp*d_1)
    mat[diagind(mat)] .= 0 # Set diagonal of the distance kernel to zero
    eta1 = 1.0./(1.0 .+ exp.(-p[4].*eta.+p[5]))
    Cd = p[1].*mat*u # Natural dispersal
    Ch = p[2].*(m_c*eta1)*u # Human mobility
    du .= R_M_vec.*(Cd .+ Ch) .* (1 .- u) - p[3] .* u
    nothing
end

# Set parameters
t0=0.0
tf=2000.0
tspan = (t0, tf)
t_vect=1:tf
u = pop_init
p = [0.001,0.001,0.001,0.5,500]

# Create the ode model
prob = DifferentialEquations.ODEProblem(fun_na!, u, tspan, p)
rtol=1e-14
alg= DP8() #For low tolerances, Rodas4()

# Test Integrate ODE
@time sol = DifferentialEquations.solve(prob,alg,reltol = rtol)
plot(sol)

# Create the vector of observation times. Assuming we observe on August 15 (or 16, depending on leap years)
t_obs = zeros(num_years)
t_obs[1] = 227
for i = 2:length(t_obs)
    t_obs[i] = t_obs[i-1] + 365
end

# Create fake obs if p(t)>0.5 then 1
matrix_obs = zeros(end_ind,num_years)

# Loop over patches and years
for i in 1:end_ind
    for j in 1:num_years
        # Occupancy probability of patch i at observation time j
        p_occ = sol(t_obs[j])[i]
        if p_occ > 0.5
            matrix_obs[i, j:end] .= 1
        end
    end
end

using Turing
using StatsPlots

# Estimation Bayesian --------------------------------
@model function fitlv(data, prob)
    # Prior distributions.
    c_d_dist ~ Uniform(0.1e-4, 0.1)
    c_h_dist ~ Uniform(0.1e-4, 0.1)
    e_dist ~ Uniform(0.1e-7, 0.1)
    a_dist ~ Uniform(0.1, 0.8)
    b_dist ~ Uniform(450, 550)

    # Simulate the Hanski model.
    p = [c_d_dist, c_h_dist, e_dist, a_dist, b_dist]
    prob = DifferentialEquations.ODEProblem(fun_na!, u, tspan, p)
    predicted = DifferentialEquations.solve(prob, alg, reltol = rtol)

    # Observations.
    for i in 1:length(t_obs)
        predicted_value = predicted(t_obs[i])
        for j in 1:length(predicted_value)
            print("\nBernoulli\n")
            print(Bernoulli(clamp(predicted_value[j], 0, 1)))
            print("\n\n")
            # Clamp to ensure a valid probability: it can dip to -7e-10, which should be zero
            data[j, i] ~ Bernoulli(clamp(predicted_value[j], 0, 1))
        end
    end

    # Alternative I tried:
    #for i in 1:length(t_obs)
    #    data[:, i] ~ MvNormal(predicted(t_obs[i]))
    #end

    return nothing
end

# Save the Bayesian model to fit with the observed data and the ode model
model = fitlv(matrix_obs, prob)

# Sample 3 independent chains with forward-mode automatic differentiation (the default).
iterations = 1000
print("Before MCMC")
chain = sample(model, NUTS(), MCMCSerial(), iterations, 3;
               init_params = [0.01, 0.01, 0.001, 0.45, 490], progress = false, verbose = false)
``````

I only skimmed your code, but two things seemed unusual to me:

1. Do you really need to rebuild arrays like `R_M_vec` inside the ODE function `fun_na!`? You are doing that every time the integrator takes a step. If it is possible to do it once, outside, before you integrate the ODE, it will be significantly faster.
2. The print statements inside the quadratic loop in the model might be slowing it down non-negligibly.

Try profiling it using ProfileView.jl:
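Something like this (a minimal sketch; `@profview` opens a flame graph showing where time is spent, and `prob`, `alg`, `rtol` are the ones from your script):

``````
using ProfileView

# Run once first so compilation time does not dominate the profile
solve(prob, alg, reltol = rtol)

# Profile the actual solve; the widest bars in the flame graph are the hot spots
@profview solve(prob, alg, reltol = rtol)
``````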

EDIT: these lines from `fun_na!` are an example of matrix operations that you repeat every timestep without needing to:

``````
mat = exp.(-alp*d_1)
mat[diagind(mat)] .= 0
eta1 = 1.0./(1.0 .+ exp.(-p[4].*eta.+p[5]))
``````

The algorithm that solves the differential equation is adaptive, so I need to evaluate `R_M_vec` at every step. As for the prints, I added them afterwards to see what was going on; they were just a check, and I do know they slow things down a lot.

Regards,
Marta

Have you tried moving those three lines, whose operations do not depend on `u` or `t`, outside the integrator? I suspect that might speed things up.

Do you mean this line?

`R_M_vec = [interpolated_functions[i](t) for i in 1:N]`

Because as far as I know it cannot be moved out of there: the integrator needs the value of `R_M_vec` at each time step, and the time steps are decided inside the integration.

No, not that one; I understood your remark that with the adaptive scheme it depends on `t` (although you could try a fixed-step scheme just to test the speed).
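For example (a quick sketch; `RK4` with `adaptive = false` purely as a timing comparison against your adaptive `DP8()` solve, using the `prob` from your script):

``````
# Fixed-step run, only to compare timings with the adaptive solve
@time sol_fixed = solve(prob, RK4(), adaptive = false, dt = 0.5)
``````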
I meant these three lines (the ones I quoted in my edit above): they do not depend on `u` or `t`, and as for the last line, the parameters only change each time you resample them for the next parameter-inference step.
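A sketch of what I mean (hypothetical names; it reuses `d_1`, `eta`, `alp`, `m_c`, `N` and `interpolated_functions` from your script): compute the distance kernel once, compute `eta1` once per parameter draw inside the Turing model, and pass both to the ODE function through `p`:

``````
using LinearAlgebra

# The distance kernel never changes: compute it once, ever.
mat = exp.(-alp * d_1)
mat[diagind(mat)] .= 0

function fun_na_fast!(du, u, p, t)
    c_d, c_h, e, mat, eta1 = p  # precomputed pieces travel inside p
    # Only this still has to be rebuilt per step, because it depends on t
    R_M_vec = [interpolated_functions[i](t) for i in 1:N]
    Cd = c_d .* (mat * u)            # natural dispersal
    Ch = c_h .* (m_c * eta1) * u     # human mobility
    du .= R_M_vec .* (Cd .+ Ch) .* (1 .- u) .- e .* u
    nothing
end

# Inside the Turing model, once per sampled (a_dist, b_dist):
#   eta1 = 1.0 ./ (1.0 .+ exp.(-a_dist .* eta .+ b_dist))
#   prob = ODEProblem(fun_na_fast!, u, tspan,
#                     (c_d_dist, c_h_dist, e_dist, mat, eta1))
``````

That way `eta1` is built once per likelihood evaluation instead of once per integrator step.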

OK, the first two I can move outside, and it does make it faster, but not the last one, since it contains `p[4]` and `p[5]`, which are the parameters that I want to estimate.