Incorporating forcing functions in the ODE model

Hello! For the first 6 iterations, t goes from 1.0 to 36.0, but the 7th iteration never gets past t = 1.0 (and then t becomes NaN). Here are the 6th and 7th iterations:

┌ Warning: Only a single thread available: MCMC chains are not sampled in parallel
└ @ AbstractMCMC C:\Users\Bharadwaj\.julia\packages\AbstractMCMC\BPJCW\src\sample.jl:291
...
t = 1.0
t = 1.0
t = 1.0000000485997007
t = 1.000000782455182
t = 1.0000015892102145
t = 1.0000043739730673
t = 1.0000047628948012
t = 1.0000048599700748
t = 1.0000048599700748
t = 1.0000126845218953
t = 1.0000207520722193
t = 1.0000485997007478
t = 1.000052488918088
t = 1.0000534596708226
t = 1.0000534596708226
t = 1.0001317051890266
t = 1.0002123806922678
t = 1.0004908569775528
t = 1.0005297491509542
t = 1.0005394566783006
t = 1.0005394566783006
t = 1.00132191186034
t = 1.0021286668927534
t = 1.0049134297456022
t = 1.005302351479616
t = 1.0053994267530801
t = 1.0053994267530801
t = 1.0088365371217218
t = 1.0123803900483956
t = 1.024613087198903
t = 1.0263215133877168
t = 1.02674793835955
t = 1.02674793835955
t = 1.033320914060806
t = 1.0400980194422254
t = 1.0634912807889318
t = 1.0667583984028772
t = 1.0675738743921963
t = 1.0675738743921963
t = 1.0776184191354206
t = 1.0879749062619746
t = 1.123723503391586
t = 1.1287161749829902
t = 1.1299623510581849
t = 1.1299623510581849
t = 1.1436535776357943
t = 1.1577699975853792
t = 1.2064971580137658
t = 1.2133024239300352
t = 1.2150010254532748
t = 1.2150010254532748
t = 1.2331974052210632
t = 1.2519588899505842
t = 1.316719918565136
t = 1.32576448461123
t = 1.328022017799787
t = 1.328022017799787
t = 1.350755110312844
t = 1.3741941994877969
t = 1.4551014169410987
t = 1.4664009698901028
t = 1.4692213501790223
t = 1.4692213501790223
t = 1.49667689091032
t = 1.524985088186068
t = 1.6226995281800654
t = 1.6363463884151455
t = 1.63975265906907
t = 1.63975265906907
t = 1.6716701340330657
t = 1.70457883492762
t = 1.8181733265696665
t = 1.8340380048199765
t = 1.8379978451808439
t = 1.8379978451808439
t = 1.8740352558351026
t = 1.9111918407332822
t = 2.039449209086637
t = 2.057361713933873
t = 2.0618326939650586
t = 2.0618326939650586
t = 2.1015224504417334
t = 2.142444808051224
t = 2.283700897871935
t = 2.3034288123465574
t = 2.3083529205282547
t = 2.3083529205282547
t = 2.3517590358382705
t = 2.3965131671517033
t = 2.5509958011432503
t = 2.5725709431075208
t = 2.5779561212115834
t = 2.5779561212115834
t = 2.625138946291293
t = 2.6737870764977023
t = 2.8417110440174147
t = 2.865163411088264
t = 2.871017146551396
t = 2.871017146551396
t = 2.92294657072343
t = 2.976488709931862
t = 3.161305853103137
t = 3.1871175317256633
t = 3.1935601538311085
t = 3.1935601538311085
t = 3.2516983112063818
t = 3.311642001419396
t = 3.518556064624561
t = 3.547453813016783
t = 3.554666721379389
t = 3.554666721379389
t = 3.6211663831616434
t = 3.689731251831545
t = 3.926403961155965
t = 3.9594578208087094
t = 3.9677080989089184
t = 3.9677080989089184
t = 4.045692369165227
t = 4.1260985111686255
t = 4.4036450133851766
t = 4.442407332697627
t = 4.452082448326983
t = 4.452082448326983
t = 4.545380569567259
t = 4.64157614798891
t = 4.973624740974486
t = 5.019998856611925
t = 5.031573884601986
t = 5.031573884601986
t = 5.141745727043812
t = 5.255339055275757
t = 5.6474413268233725
t = 5.702202577033643
t = 5.715871042625749
t = 5.715871042625749
t = 5.8396263209218775
t = 5.967224930842109
t = 6.407670734964354
t = 6.46918367338321
t = 6.484537367446421
t = 6.484537367446421
t = 6.62597433276796
t = 6.771803750428678
t = 7.275178788498504
t = 7.345480463366507
t = 7.363027835282069
t = 7.363027835282069
t = 7.525962187567403
t = 7.693956612905078
t = 8.273840984703194
t = 8.354828001309155
t = 8.375042445749985
t = 8.375042445749985
t = 8.5213229195255
t = 8.672146265157398
t = 9.192759379898828
t = 9.265468535412204
t = 9.283616817026477
t = 9.283616817026477
t = 9.48122633281496
t = 9.684972790087558
t = 10.388266284167067
t = 10.486488696520182
t = 10.511005113849354
t = 10.511005113849354
t = 10.686642912708416
t = 10.86773530159416
t = 11.49283131865157
t = 11.580132622111035
t = 11.601923119185148
t = 11.601923119185148
t = 11.759362415500851
t = 11.921691130956793
t = 12.482018564428206
t = 12.560274246709508
t = 12.57980694723299
t = 12.57980694723299
t = 12.730463788665682
t = 12.885799413993798
t = 13.421988048409526
t = 13.496872490810867
t = 13.51556372631803
t = 13.51556372631803
t = 13.665083057882418
t = 13.819245846948805
t = 14.351386076677962
t = 14.425705116330887
t = 14.444255226717955
t = 14.444255226717955
t = 14.595494460359511
t = 14.751430564611054
t = 15.289691936515473
t = 15.364865858740002
t = 15.383629348715198
t = 15.383629348715198
t = 15.537278051749817
t = 15.695698453636444
t = 16.242535142076427
t = 16.31890669840948
t = 16.33796911911656
t = 16.33796911911656
t = 16.492581843262546
t = 16.65199620480437
t = 17.20226385036741
t = 17.27911457633326
t = 17.298296598284168
t = 17.298296598284168
t = 17.457386920931953
t = 17.62141793682967
t = 18.187621383271782
t = 18.266697713225582
t = 18.286435248270404
t = 18.286435248270404
t = 18.44667543783798
t = 18.61189203080827
t = 19.182187860759946
t = 19.26183573557228
t = 19.28171592881434
t = 19.28171592881434
t = 19.442657478241383
t = 19.608597212433242
t = 20.18138918648105
t = 20.26138567435005
t = 20.281352881777348
t = 20.281352881777348
t = 20.44286622255102
t = 20.609395505584747
t = 21.184222488586702
t = 21.264503187087993
t = 21.284541333787743
t = 21.284541333787743
t = 21.44669954243314
t = 21.613893720291124
t = 22.191015792053932
t = 22.271617024097125
t = 22.291735176305732
t = 22.291735176305732
t = 22.45466095724302
t = 22.62264654504172
t = 23.202500411358912
t = 23.283483167550195
t = 23.303696548587045
t = 23.303696548587045
t = 23.467018321288624
t = 23.63541219873249
t = 24.21667540219836
t = 24.297854987303495
t = 24.31811749704406
t = 24.31811749704406
t = 24.48209352778081
t = 24.65116198183236
t = 25.23475369370911
t = 25.316258479765523
t = 25.336602160005224
t = 25.336602160005224
t = 25.501475189015594
t = 25.671468498430016
t = 26.258252632734006
t = 26.340203274517357
t = 26.36065824081498
t = 26.36065824081498
t = 26.52559901141111
t = 26.69566216618724
t = 27.282687393215692
t = 27.364671706160795
t = 27.38513507681577
t = 27.38513507681577
t = 27.54961405341023
t = 27.719201072756068
t = 28.304582772064297
t = 28.386337548892815
t = 28.40674362709191
t = 28.40674362709191
t = 28.570918895364944
t = 28.740192774453913
t = 29.324493573959806
t = 29.40609739164113
t = 29.426465790278463
t = 29.426465790278463
t = 29.59022926530114
t = 29.75907856252949
t = 30.341913787299635
t = 30.42331292189027
t = 30.4436302314131
t = 30.4436302314131
t = 30.60770023858896
t = 30.776865587602583
t = 31.36079176220984
t = 31.442343259541857
t = 31.462698598965034
t = 31.462698598965034
t = 31.627312475896165
t = 31.797038585030123
t = 32.38290039547446
t = 32.464722224926646
t = 32.48514503953106
t = 32.48514503953106
t = 32.649904721660576
t = 32.81978116410467
t = 33.40616189615567
t = 33.488056198526294
t = 33.5084971024473
t = 33.5084971024473
t = 33.67332356226874
t = 33.843268856246134
t = 34.429887250517474
t = 34.51181474494348
t = 34.53226393363638
t = 34.53226393363638
t = 34.69671489251523
t = 34.86627302403007
t = 35.45155500811443
t = 35.53329585865198
t = 35.55369846083421
t = 35.55369846083421
t = 35.625553008639905
t = 35.69963906414142
t = 35.95536984608342
t = 35.99108536816168
t = 36.0
t = 36.0
t = 1.0
t = 1.0
t = NaN

The NaN pops up randomly! :neutral_face:

Boom diggity. Now you’re onto it. Post the full stack trace there now. I want to see in what part of the solvers/fitters it’s hitting a NaN.

1 Like

The full stack trace is extraordinarily long, so I'm posting a link to a .txt file containing the full error.

Error_file

In the error:

  1. Line 62 is
    dy[1] = -β*w(t)*m(t)*I*S/N + λ*R; # S
  2. Line 107 is
    predicted = solve(problem, Tsit5(), saveat=1.0)
  3. Line 123 is
    chain = sample(model, NUTS(0.65), MCMCThreads(), 10000, number_of_chains);

Hope this helps! :crossed_fingers:

What’s the current code? Give me something I can copy-paste to run.

1 Like

@ChrisRackauckas Thanks for the response! Please find the attached link to the files containing all the observation data as well as the original code.
All_files

Hope this helps. Thanks in advance. :blush:

Why don’t you simply insert the code here? Every extra step (including clicking here and downloading there) reduces the chances that people will actually help, I dare say.

I actually did download the code, and those 123 lines of code could easily be inserted here, no problem. Just note that it is a good habit to enclose code in triple backticks here on Discourse.

1 Like

Thank you @zdenek_hurak for the suggestion! True, I can copy-paste the code, but the issue is that the code requires some .xlsx data files to run. Anyone who wants to run the code needs to download them anyway; that’s why I attached a link. :sweat_smile:

Nevertheless this is the code:

#=
Section 1: Import required packages
=#

using Turing, Distributions, DifferentialEquations, Interpolations
using MCMCChains, Plots, StatsPlots
using CSV, XLSX, DataFrames
using Random
Random.seed!(18431)

#=
Section 2: Load the weekly observations and the two forcing series, truncated
to the modelled horizon, and build interpolants for the forcing series
=#

total_weeks = 36;   # Number of weekly time points that are modelled
N = 67081000;       # Population size

# --- Observations -----------------------------------------------------------
my_data = DataFrame(XLSX.readtable("UK_weekly_15_July.xlsx","Sheet1"; infer_eltypes = true)...);

y_time = 1:1:total_weeks;                              # Time points (weeks)
y_S    = Float64.(my_data.Susceptible)[1:total_weeks];    # Susceptible
y_D    = Float64.(my_data.Deceased)[1:total_weeks];       # Deceased
y_HC   = Float64.(my_data.Hosp_critical)[1:total_weeks];  # Critical hospitalizations
y_T    = Float64.(my_data.Hosp_total)[1:total_weeks];     # Total hospitalizations
y_HNC  = y_T - y_HC;                                      # Non-critical = total - critical

# One row per week; columns: S, D, HC, HNC
observation_data = [y_S y_D y_HC y_HNC];

# --- Forcing series -----------------------------------------------------------
wet_data = DataFrame(XLSX.readtable("Wetdata.xlsx","Wetdata"; infer_eltypes = true)...);
mobil_data = DataFrame(XLSX.readtable("Mobdata.xlsx","Mobdata"; infer_eltypes = true)...);

IPTCC = wet_data.Normalized_IPTCC[1:total_weeks];  # first forcing series
mobil = mobil_data.Mean[1:total_weeks];             # second forcing series

# Linear B-spline interpolants so the forcing can be evaluated at the
# fractional times the ODE solver steps through.
wet_forcing = interpolate(IPTCC, BSpline(Linear()));
mobil_forcing = interpolate(mobil, BSpline(Linear()));
forcing_params = (wet_forcing, mobil_forcing);

#=
Section 3: Define the model and the respective parameters
=#

"""
    epidemic_wildtype(dy, y, p, t)

In-place right-hand side of the wild-type epidemic model, in the
`f(du, u, p, t)` form accepted by `ODEProblem`.

State vector `y` (the order the observation columns are matched against):

1. `S`  - susceptible
2. `E`  - exposed
3. `I`  - infectious
4. `Hₙ` - non-critically hospitalized (HNC)
5. `Hᵪ` - critically hospitalized (HC)
6. `R`  - recovered
7. `D`  - deceased (accumulates, has no outflow)

Parameters `p = (β, λ, α, γ, θᵪ, θₙ, γᵪ, γₙ, δᵪ, w, m)`, where `w` and `m` are
callables of time (forcing functions) that scale the transmission rate `β`.

Writes the derivatives into `dy` and returns `nothing`. The flows are balanced,
so `sum(dy) == 0` up to rounding (the total over all seven compartments is
conserved).
"""
function epidemic_wildtype(dy, y, p, t)
    # Slot 4 is HNC and slot 5 is HC; the names must follow that order so that
    # deaths and the per-compartment recovery rates act on the right compartment.
    S, E, I, Hₙ, Hᵪ, R, D = y
    β, λ, α, γ, θᵪ, θₙ, γᵪ, γₙ, δᵪ, w, m = p
    N = 67081000

    # Force of infection, evaluated once: it leaves S and enters E.
    infection = β * w(t) * m(t) * I * S / N

    dy[1] = -infection + λ * R                    # S
    dy[2] = infection - α * E                     # E
    dy[3] = α * E - (γ + θᵪ + θₙ) * I             # I
    dy[4] = θₙ * I - γₙ * Hₙ                      # HNC
    dy[5] = θᵪ * I - (γᵪ + δᵪ) * Hᵪ               # HC
    dy[6] = γ * I + γₙ * Hₙ + γᵪ * Hᵪ - λ * R     # R
    dy[7] = δᵪ * Hᵪ                               # D
    return nothing
end

#=
Section 4: Define the priors and the Bayesian model
=#

Turing.setadbackend(:forwarddiff)

# Turing model that fits `epidemic_wildtype` to weekly observations.
#
# Arguments
# - `observ_data`: matrix with one row per week; row `i` is compared with the
#   ODE solution saved at `t = i`. Columns: 1 susceptible, 2 deceased,
#   3 critically hospitalized (HC), 4 non-critically hospitalized (HNC).
# - `w_forcing`, `m_forcing`: callables of time, passed to the ODE as forcing
#   functions that scale the transmission rate.
#
# The ODE is integrated over t ∈ [1.0, 36.0]. If the solver gives up on a
# parameter draw (e.g. NaN or unstable derivatives), it saves fewer points than
# there are weeks. Such draws are rejected with log density -Inf rather than
# being scored against a truncated trajectory.
@model function fitting_epidemic_wildtype(observ_data, w_forcing, m_forcing)
    # Priors of model parameters
    β ~ truncated(Normal(0.65, 0.1), 0, 2)
    λ ~ truncated(Normal(0.5, 0.1), 0, 5)
    α ~ truncated(Normal(0.25, 0.1), 0.1, 0.5)
    γ ~ truncated(Normal(0.05, 0.1), 0, 5)
    γₙ ~ Uniform(0.05, 0.1)
    γᵪ ~ Uniform(0.05, 0.1)
    θₙ ~ Uniform(0.09, 0.75)
    θᵪ ~ Uniform(0.09, 0.75)
    δᵪ ~ Uniform(0.1, 0.8)

    p = (β, λ, α, γ, θᵪ, θₙ, γᵪ, γₙ, δᵪ, w_forcing, m_forcing);

    # Priors of standard deviations
    σ₁ ~ InverseGamma(1, 1) # Susceptible
    σ₂ ~ InverseGamma(1, 1) # Deceased
    σ₃ ~ InverseGamma(2, 3) # Critically hospitalized
    σ₄ ~ InverseGamma(1, 1) # Non-critically hospitalized

    # Initial conditions, state order: S, E, I, HNC, HC, R, D
    N = 67081000;
    S0 = N;
    I0 = 100;
    y0 = [S0, 0, I0, 0, 0, 0, 0];
    # Use the element type of the sampled parameters so AD number types can
    # flow through the ODE state.
    y0 = typeof(β).(y0);

    # Solve the model and compare with observed data
    problem = ODEProblem(epidemic_wildtype, y0, (1.0,36.0), p)
    predicted = solve(problem, Tsit5(), saveat=1.0)

    # An aborted solve stops before t = 36 and saves fewer points than the
    # observed weeks; reject the draw instead of fitting a partial trajectory.
    n_weeks = min(size(observ_data, 1), 36)
    if length(predicted.t) < n_weeks
        Turing.@addlogprob! -Inf
        return nothing
    end

    for i in 1:n_weeks
        observ_data[i,1] ~ Normal(predicted[1,i], σ₁)  # susceptible <- S
        observ_data[i,2] ~ Normal(predicted[7,i], σ₂)  # deceased    <- D
        observ_data[i,3] ~ Normal(predicted[5,i], σ₃)  # HC          <- slot 5
        observ_data[i,4] ~ Normal(predicted[4,i], σ₄)  # HNC         <- slot 4
    end
end

#=
Section 5: Run the model-inference system and save the chains
=#

# Sample the posterior with NUTS (target acceptance ratio 0.65), drawing
# 10 000 samples per chain. MCMCThreads spreads chains across threads; when
# Julia has only one thread, AbstractMCMC warns and samples serially.
number_of_chains = 1;
model = fitting_epidemic_wildtype(observation_data, wet_forcing, mobil_forcing);
chain = sample(model, NUTS(0.65), MCMCThreads(), 10_000, number_of_chains);

This is fixed by SciML/OrdinaryDiffEq.jl pull request #1516 ("Better NaN detection", by ChrisRackauckas) on GitHub. Basically, Turing was sampling in bad regions of parameter space, and the ODE solver needed a better way to give up on parameters that produced NaN derivatives.

2 Likes

Thank you!