Turing.jl run not reproducible!

Hi all,

I am running a Turing.jl model (an ODE model) and the results I get change every time I run the script. I have a seed set at the beginning of the script, and even though there is more than one call to the RNG in there, I thought setting the seed up top and running the whole thing should make it reproducible. Is there a good reason why it isn't?

Here is the script:

cd(@__DIR__)
cd("..")
using Pkg; Pkg.activate(".")
using CSV, DataFramesMeta, Chain, Random
using OrdinaryDiffEq, DiffEqCallbacks, Turing, Distributions, CategoricalArrays
using GlobalSensitivity, QuadGK
using Plots, StatsPlots, MCMCChains
using Gadfly
import Cairo, Fontconfig

Random.seed!(1)

# paths
modDir = "model"
modName = "mavoPBPKGenODE"
tabDir = joinpath("deliv","table")
figDir = joinpath("deliv","figure")
modPath = mkpath(joinpath(modDir, string(modName, "_jl")))
figPath = mkpath(joinpath(figDir, string(modName, "_jl")))
tabPath = mkpath(joinpath(tabDir, string(modName, "_jl")))

# read data
dat_orig = CSV.read("data/Mavoglurant_A2121_nmpk.csv", DataFrame)
dat = dat_orig[dat_orig.ID .<= 812,:]
dat_obs = dat[dat.MDV .== 0,:]  # grab first 20 subjects ; remove missing obs

# load model
include(joinpath("..", modDir, string(modName, ".jl")))

# conditions
nSubject = length(unique(dat.ID))
doses = dat.AMT[dat.EVID .== 1,:]
rates = dat.RATE[dat.EVID .== 1,:]
durs = doses ./ rates

ncmt = 16
u0 = zeros(ncmt)
tspan = (0.0,maximum(dat.TIME))
times = []
for i in unique(dat_obs.ID); push!(times, dat_obs.TIME[dat_obs.ID .== i]); end

# fixed parameters
wts = dat.WT[dat.EVID .== 1,:]
VVBs = (5.62 .* wts ./ 100) ./ 1.040
BP = 0.61

# callback function for infusion stopping
function affect!(integrator)
    integrator.p[8] = integrator.p[8] == 0.0
end
cbs = []
for i in 1:length(unique(dat.ID)); push!(cbs, PresetTimeCallback([durs[i]], affect!)); end

# variable params
CLint = exp(7.1)
KbBR = exp(1.1);
KbMU = exp(0.3);
KbAD = exp(2);
KbBO = exp(0.03);
KbRB = exp(0.3);
p = [CLint, KbBR, KbMU, KbAD, KbBO, KbRB, wts[1], rates[1]]

# define problem
prob = ODEProblem(PBPKODE!, u0, tspan, p)

########

## sensitivity analysis ##
## create function that takes in parameters and returns endpoints for sensitivity
p_sens = p[1:6]
f_globsens = function(p_sens)
    p_all = [p_sens;[wts[1],rates[1]]]
    tmp_prob = remake(prob, p = p_all)
    tmp_sol = solve(tmp_prob, Tsit5(), save_idxs=[15], callback=cbs[1])
    auc, err = quadgk(tmp_sol, 0.0, 48.0)
    return(auc[1])
end

### conditions
n = 1000
lb = [1000.0, 1.0, 1.0, 1.0, 1.0, 1.0]
ub = [1500.0, 10.0, 10.0, 10.0, 10.0, 10.0]
sampler = GlobalSensitivity.SobolSample()
A, B = GlobalSensitivity.QuasiMonteCarlo.generate_design_matrices(n, lb, ub, sampler)

bounds = [[1000.0,1500.0],[1.0,10.0],[1.0,10.0],[1.0,10.0],[1.0,10.0],[1.0,10.0]]

#### Sobol
@time s = GlobalSensitivity.gsa(f_globsens, Sobol(), A, B)

plot_sens_total = Plots.bar(["CLint","KbBR","KbMU","KbAD","KbBO","KbRB"], s.ST, title="Total Order Indices", ylabel="Index", legend=false)
Plots.hline!([0.05], linestyle=:dash)
plot_sens_single = Plots.bar(["CLint","KbBR","KbMU","KbAD","KbBO","KbRB"], s.S1, title="First Order Indices", ylabel="Index", legend=false)
Plots.hline!([0.05], linestyle=:dash)

plot_sens = Plots.plot(plot_sens_single, plot_sens_total, xrotation = 45)
savefig(plot_sens, joinpath(figPath, "sensitivity.pdf"))

#######

## Bayesian inference ##
@model function fitPBPK(data, prob, nSubject, rates, times, wts, cbs, VVBs, BP) # data should be a Vector
    # priors
    σ ~ truncated(Cauchy(0, 0.5), 0.0, Inf) # residual error
    ĈLint ~ LogNormal(7.1,0.25)
    KbBR ~ LogNormal(1.1,0.25)
    KbMU ~ LogNormal(0.3,0.25)
    KbAD ~ LogNormal(2.0,0.25)
    KbBO ~ LogNormal(0.03, 0.25)
    KbRB ~ LogNormal(0.3, 0.25)
    ω ~ truncated(Cauchy(0, 0.5), 0.0, 1.0)
    ηᵢ ~ filldist(Normal(0.0, 1.0), nSubject)

    # individual params
    CLintᵢ = ĈLint .* exp.(ω .* ηᵢ)  # non-centered parameterization

    # simulate
    function prob_func(prob,i,repeat)
        ps = [CLintᵢ[i], KbBR, KbMU, KbAD, KbBO, KbRB, wts[i], rates[i]]
        remake(prob, p=ps, saveat=times[i], callback=cbs[i])
    end

    tmp_ensemble_prob = EnsembleProblem(prob, prob_func=prob_func)
    tmp_ensemble_sol = solve(tmp_ensemble_prob, Tsit5(), trajectories=nSubject) 

    predicted = []
    for i in 1:nSubject
        times_tmp = times[i]
        idx = findall(x -> x in times_tmp, tmp_ensemble_sol[i].t)
        tmp_sol = Array(tmp_ensemble_sol[i])[15,idx] ./ (VVBs[i]*BP/1000.0)
        append!(predicted, tmp_sol)
    end

    # likelihood
    for i = 1:length(predicted)
        data[i] ~ LogNormal(log.(max(predicted[i], 1e-12)), σ)
    end
end

mod = fitPBPK(dat_obs.DV, prob, nSubject, rates, times, wts, cbs, VVBs, BP)

# run 
## serial 
#@time mcmcchains = mapreduce(c -> sample(mod, NUTS(250,adapt_delta), 250), chainscat, 1:4)  # serial
#@time mcmcchains_prior = mapreduce(c -> sample(mod, Prior(), 250), chainscat, 1:4)  # serial

## multithreading
nsampl = 250
nchains = 4
adapt_delta = .8
@time mcmcchains = sample(mod, NUTS(nsampl,adapt_delta), MCMCThreads(), nsampl, nchains)
@time mcmcchains_prior = sample(mod, Prior(), MCMCThreads(), nsampl, nchains)  # parallel

## save mcmcchains
write(joinpath(modPath, string(modName, "chains.jls")), mcmcchains)
write(joinpath(modPath, string(modName, "chains_prior.jls")), mcmcchains_prior)

##load saved chains
#mcmcchains = read(joinpath(modPath, string(modName, "chains.jls")), Chains)


#---# diagnostics #---#
# tables
summ, quant = describe(mcmcchains)
#summ, quant = describe(mcmcchains; q = [0.05, 0.25, 0.5, 0.75, 0.95])

## summary
df_summ = DataFrame(summ)
CSV.write(joinpath(tabPath, "Summary.csv"), df_summ)

## quantiles
df_quant = DataFrame(quant)
CSV.write(joinpath(tabPath, "Quantiles.csv"), df_quant)

# plots
## trace plots
#plot_chains = StatsPlots.plot(mcmcchains[:,1:8,:])  # mcmcchains[samples, params, chains]
plot_chains1 = StatsPlots.plot(mcmcchains[:,1:4,:])
plot_chains2 = StatsPlots.plot(mcmcchains[:,5:8,:])
plot_chains = Plots.plot(plot_chains1, plot_chains2, layout = (1,2))
savefig(plot_chains, joinpath(figPath, "MCMCTrace.pdf"))

## density plots
p_post = zeros(nsampl, 8, 4)
p_prior = deepcopy(p_post)

for i in 1:4; p_post[:,:,i] = Array(mcmcchains[:,1:8,1]); end
for i in 1:4; p_prior[:,:,i] = Array(mcmcchains_prior[:,1:8,1]); end

p_post_mean = mean(p_post, dims=3)[:,:,1]
p_prior_mean = mean(p_prior, dims=3)[:,:,1]

pars = summ[1:8,1]
dens_plots = []
for i in 1:8; p = density(p_post_mean[:,i], title=pars[i], label="Posterior"); density!(p_prior_mean[:,i], label="Prior"); push!(dens_plots, p); end

dens_plots[1] = Plots.plot(dens_plots[1], xlims=(0.0,0.6))

plot_dens = Plots.plot(dens_plots[1],
                        dens_plots[2],
                        dens_plots[3],
                        dens_plots[4],
                        dens_plots[5],
                        dens_plots[6],
                        dens_plots[7],
                        dens_plots[8], 
                        layout=grid(4,2),
                        size = (650,650))
savefig(plot_dens, joinpath(figPath, "DensPlots.pdf"))


#=
## rhat
df_tmp = @orderby(DataFrame(summ)[1:8,:], :rhat)
plot_rhat = Plots.bar(string.(df_tmp.parameters), df_tmp.rhat, orientation=:h, legend = false, xlim=[0.99,1.05], xticks=[1.0,1.05], xlabel = "R̂")
Plots.vline!([1.0,1.05], linestyle=[:solid,:dash])
savefig(plot_rhat, joinpath(figPath, "rhat.pdf"))

## neff
df_tmp = @orderby(@transform!(df_tmp, :neff = :ess ./ 1000.0), :neff)
plot_neff = Plots.bar(string.(df_tmp.parameters), df_tmp.neff, orientation=:h, legend = false, xlim=[0.0,2.5], xticks=[0.0:0.25:2.5;], xlabel = "Neff/N")
Plots.vline!([0.1,0.5,1.0], linestyle=:dash)
savefig(plot_neff, joinpath(figPath, "neff.pdf"))
=#

#---# predictive checks #---#

#--# conditional on chains #--#

dat_missing = Vector{Missing}(missing, length(dat_obs.DV)) # vector of `missing`
mod_pred = fitPBPK(dat_missing, prob, nSubject, rates, times, wts, cbs, VVBs, BP)
#pred = predict(mod_pred, mcmcchains)  # posterior ; conditioned on each sample in chains
pred = predict(mod_pred, mcmcchains, include_all=false)  # include_all = false means sampling new !!
#pred_prior = predict(mod_pred, mcmcchains_prior)

### predictive checks summaries
#summarystats(pred)
summ_pred, quant_pred = describe(pred)
#summarystats(pred_prior)
#summ_pred_prior, quant_pred_prior = describe(pred_prior)

#### save
#CSV.write(joinpath(tabPath, "Summary_PPC.csv"), summ_pred)
#CSV.write(joinpath(tabPath, "Quantiles_PPC.csv"), quant_pred)

# data assembly
bins = [0, 1, 2, 3, 4, 6, 8, 10, 20, 30, 40, 50]
labels = string.(1:length(bins) - 1)

## observed
df_vpc_obs = @chain begin
    dat_obs
    @select(:ID, :TIME, :DV, :DOSE)
    @transform(:DNDV = :DV ./ :DOSE,
               :bins = cut(:TIME, bins, labels = labels))
    
    groupby(:bins)
    @transform(:lo = quantile(:DNDV, 0.05),
               :med = quantile(:DNDV, 0.5),
               :hi = quantile(:DNDV, 0.95))
end
df_vpc_obs2 = @orderby(unique(@select(df_vpc_obs, :TIME, :bins, :lo, :med, :hi)), :TIME)

## predicted
df_pred = DataFrame(pred)
df_vpc_pred = @chain begin
    df_pred
    DataFramesMeta.stack(3:ncol(df_pred))
    @orderby(:iteration, :chain)
    hcat(select(repeat(dat_obs, 1000), [:ID,:TIME,:DOSE]))
    @transform(:DNDV = :value ./ :DOSE,
    :bins = cut(:TIME, bins, labels = labels))

    groupby([:iteration, :chain, :bins])
    @transform(:lo = quantile(:DNDV, 0.05),
               :med = quantile(:DNDV, 0.5),
               :hi = quantile(:DNDV, 0.95))
    
    groupby(:bins)
    @transform(:loLo = quantile(:lo, 0.025),
               :medLo = quantile(:lo, 0.5),
               :hiLo = quantile(:lo, 0.975),
               :loMed = quantile(:med, 0.025),
               :medMed = quantile(:med, 0.5),
               :hiMed = quantile(:med, 0.975),
               :loHi = quantile(:hi, 0.025),
               :medHi = quantile(:hi, 0.5),
               :hiHi = quantile(:hi, 0.975))
end

df_vpc_pred2 = @orderby(unique(df_vpc_pred[!,[6;13:21]]), :TIME)

### plot
dat_obs2 = @transform(dat_obs, :DNDV = :DV ./ :DOSE)

set_default_plot_size(17cm, 12cm)

plot_ppc = Gadfly.plot(x=dat_obs2.TIME, y=dat_obs2.DNDV, Geom.point, Scale.y_log10, Theme(background_color="white", default_color="black"), alpha=[0.2], Guide.xlabel("Time (h)"), Guide.ylabel("Mavoglurant dose-normalized concentration (ng/mL/mg)", orientation=:vertical),
    layer(x=df_vpc_obs.TIME, y=df_vpc_obs.med, Geom.line, Theme(default_color="black")),
    layer(x=df_vpc_obs.TIME, y=df_vpc_obs.lo, Geom.line, Theme(default_color="black")),
    layer(x=df_vpc_obs.TIME, y=df_vpc_obs.hi, Geom.line, Theme(default_color="black")),
    layer(x=df_vpc_pred2.TIME, ymin=df_vpc_pred2.loMed, ymax=df_vpc_pred2.hiMed, Geom.ribbon, Theme(default_color="deepskyblue"), alpha=[0.8]),
    layer(x=df_vpc_pred2.TIME, ymin=df_vpc_pred2.loLo, ymax=df_vpc_pred2.hiLo, Geom.ribbon, Theme(default_color="deepskyblue"), alpha=[0.5]),
    layer(x=df_vpc_pred2.TIME, ymin=df_vpc_pred2.loHi, ymax=df_vpc_pred2.hiHi, Geom.ribbon, Theme(default_color="deepskyblue"), alpha=[0.5]))

plot_tmp = PDF(joinpath(figPath, "PPCCond.pdf"), 17cm, 12cm)
draw(plot_tmp, plot_ppc)


#--# new population #--#

df_params = DataFrame(mcmcchains)[:,3:10]

# save CSV
#CSV.write(joinpath(modPath, "df_params.csv"), df_params)

## new etas
ηs = reshape(rand(Normal(0.0, 1.0), nSubject*nrow(df_params)), nrow(df_params), nSubject)

array_pred = Array{Float64}(undef, nrow(df_params), nrow(dat_obs))

for j in 1:nrow(df_params)
    KbBR = df_params[j,:KbBR]
    KbMU = df_params[j,:KbMU]
    KbAD = df_params[j,:KbAD]
    KbBO = df_params[j,:KbBO]
    KbRB = df_params[j,:KbRB]

    CLintᵢ = df_params[j,:ĈLint] .* exp.(df_params[j,:ω] .* ηs[j,:])

    # simulate
    function prob_func(prob,i,repeat)
        ps = [CLintᵢ[i], KbBR, KbMU, KbAD, KbBO, KbRB, wts[i], rates[i]]
        remake(prob, p=ps, saveat=times[i], callback=cbs[i])
    end
    
    tmp_ensemble_prob = EnsembleProblem(prob, prob_func=prob_func)
    tmp_ensemble_sol = solve(tmp_ensemble_prob, Tsit5(), trajectories=nSubject)
    
    predicted = []
    for i in 1:nSubject
        times_tmp = times[i]
        idx = findall(x -> x in times_tmp, tmp_ensemble_sol[i].t)
        tmp_sol = Array(tmp_ensemble_sol[i])[15,idx] ./ (VVBs[i]*BP/1000.0)
        append!(predicted, tmp_sol)
    end

    array_pred[j, :] = rand.(LogNormal.(log.(predicted), df_params[j,:σ]))
end

df_pred_new = DataFrame(array_pred, :auto)
@transform!(df_pred_new, :iteration = 1:size(array_pred)[1])

# save version for R's vpc
df_pred_new2 = @chain begin
    df_pred_new
    DataFramesMeta.stack(1:268)
    @orderby(:iteration)
    hcat(select(repeat(dat_obs, 1000), [:ID,:TIME,:DOSE]))
    @transform(:DNDV = :value ./ :DOSE)
end

# save CSV
#CSV.write(joinpath(modPath, "df_pred.csv"), df_pred_new2)

df_vpc_pred_new = @chain begin
    df_pred_new
    DataFramesMeta.stack(1:268)
    @orderby(:iteration)
    hcat(select(repeat(dat_obs, 1000), [:ID,:TIME,:DOSE]))
    @transform(:DNDV = :value ./ :DOSE,
    :bins = cut(:TIME, bins, labels = labels))

    groupby([:iteration, :bins])
    @transform(:lo = quantile(:DNDV, 0.05),
               :med = quantile(:DNDV, 0.5),
               :hi = quantile(:DNDV, 0.95))
    
    groupby(:bins)
    @transform(:loLo = quantile(:lo, 0.025),
               :medLo = quantile(:lo, 0.5),
               :hiLo = quantile(:lo, 0.975),
               :loMed = quantile(:med, 0.025),
               :medMed = quantile(:med, 0.5),
               :hiMed = quantile(:med, 0.975),
               :loHi = quantile(:hi, 0.025),
               :medHi = quantile(:hi, 0.5),
               :hiHi = quantile(:hi, 0.975))
end

df_vpc_pred_new2 = @orderby(unique(df_vpc_pred_new[!,[5;12:20]]), :TIME)

### plot
#dat_obs2 = @transform(dat_obs, :DNDV = :DV ./ :DOSE)

set_default_plot_size(17cm, 12cm)

plot_ppc_new = Gadfly.plot(x=dat_obs2.TIME, y=dat_obs2.DNDV, Geom.point, Scale.y_log10, Theme(background_color="white", default_color="black"), alpha=[0.2], Guide.xlabel("Time (h)"), Guide.ylabel("Mavoglurant dose-normalized concentration (ng/mL/mg)", orientation=:vertical),
    layer(x=df_vpc_obs.TIME, y=df_vpc_obs.med, Geom.line, Theme(default_color="black")),
    layer(x=df_vpc_obs.TIME, y=df_vpc_obs.lo, Geom.line, Theme(default_color="black")),
    layer(x=df_vpc_obs.TIME, y=df_vpc_obs.hi, Geom.line, Theme(default_color="black")),
    layer(x=df_vpc_pred_new2.TIME, ymin=df_vpc_pred_new2.loMed, ymax=df_vpc_pred_new2.hiMed, Geom.ribbon, Theme(default_color="deepskyblue"), alpha=[0.8]),
    layer(x=df_vpc_pred_new2.TIME, ymin=df_vpc_pred_new2.loLo, ymax=df_vpc_pred_new2.hiLo, Geom.ribbon, Theme(default_color="deepskyblue"), alpha=[0.5]),
    layer(x=df_vpc_pred_new2.TIME, ymin=df_vpc_pred_new2.loHi, ymax=df_vpc_pred_new2.hiHi, Geom.ribbon, Theme(default_color="deepskyblue"), alpha=[0.5]))

plot_tmp = PDF(joinpath(figPath, "PPCPred.pdf"), 17cm, 12cm)
draw(plot_tmp, plot_ppc_new)

#--# individual plots #--#

df_cObs = @chain begin
    df_vpc_obs
    @select(:ID,:TIME,:DV,:DOSE)
end

df_cCond = @chain begin
    df_vpc_pred
    groupby([:ID,:TIME])
    @transform(:loCond = quantile(:value, 0.05),
               :medCond = quantile(:value, 0.5),
               :hiCond = quantile(:value, 0.95))
    @select(:ID, :TIME, :DOSE, :loCond, :medCond, :hiCond)
    unique()
end

df_cPred = @chain begin
    df_vpc_pred_new
    groupby([:ID,:TIME])
    @transform(:loPred = quantile(:value, 0.05),
               :medPred = quantile(:value, 0.5),
               :hiPred = quantile(:value, 0.95))
    @select(:ID, :TIME, :DOSE, :loPred, :medPred, :hiPred)
    unique()
end

# join all data
df_cAll = hcat(df_cObs, 
               @select(df_cCond, :loCond, :medCond, :hiCond),
               @select(df_cPred, :loPred, :medPred, :hiPred))

# save CSV for plotting in R
CSV.write(joinpath(modPath, "df_ind.csv"), df_cAll)

#=
# plot
plot_ind = Gadfly.plot(df_call, x=:TIME, y=:DV, xgroup=:ID, Geom.subplot_grid(Geom.point),
                       Theme(background_color="white", default_color="black"), 
                       Scale.y_log10) 
plot_ind = Gadfly.plot(df_call, x=:TIME, Theme(background_color="white", default_color="black"), Scale.y_log10, 
    layer(y=:DV, Geom.point))


plot_ind = Gadfly.plot(x=dat_obs2.TIME, y=dat_obs2.DNDV, Geom.point, Scale.y_log10, Theme(background_color="white", default_color="black"), alpha=[0.2], Guide.xlabel("Time (h)"), Guide.ylabel("Mavoglurant dose-normalized concentration (ng/mL/mg)", orientation=:vertical),
layer(x=df_vpc_obs.TIME, y=df_vpc_obs.med, Geom.line, Theme(default_color="black")),
layer(x=df_vpc_obs.TIME, y=df_vpc_obs.lo, Geom.line, Theme(default_color="black")),
layer(x=df_vpc_obs.TIME, y=df_vpc_obs.hi, Geom.line, Theme(default_color="black")),
layer(x=df_vpc_pred_new2.TIME, ymin=df_vpc_pred_new2.loMed, ymax=df_vpc_pred_new2.hiMed, Geom.ribbon, Theme(default_color="deepskyblue"), alpha=[0.8]),
layer(x=df_vpc_pred_new2.TIME, ymin=df_vpc_pred_new2.loLo, ymax=df_vpc_pred_new2.hiLo, Geom.ribbon, Theme(default_color="deepskyblue"), alpha=[0.5]),
layer(x=df_vpc_pred_new2.TIME, ymin=df_vpc_pred_new2.loHi, ymax=df_vpc_pred_new2.hiHi, Geom.ribbon, Theme(default_color="deepskyblue"), alpha=[0.5]))
=#

#######

## simulation ##

# get wts for 500 individuals
wts_sim = @chain begin 
    dat_orig 
    @subset(:EVID .== 1) 
    unique(:ID)
    @select(:WT)
    Array
end
wts_sim = sample(wts_sim, 500, replace=true)
vvbs_sim = (5.62 .* wts_sim ./ 100) ./ 1.040

# get 500 replicates of inferred parameters
df_params_sim1 = DataFrame(mcmcchains)[:,3:10]
df_params_sim2 = df_params_sim1[sample(1:nrow(df_params_sim1), 500, replace=false),:]

## new etas
ηs_sim = reshape(rand(Normal(0.0, 1.0), 500*500), 500, 500)

# create array to hold results
times_sim = [0.0:0.1:48.0;]
array_pred_sim = Array{Float64}(undef, 500, length(times_sim)*500)

# simulate a 50 mg single
## infusion callback
dose = 50.0
cb = PresetTimeCallback([10.0/60.0], affect!)

for j in 1:nrow(df_params_sim2)
    KbBR = df_params_sim2[j,:KbBR]
    KbMU = df_params_sim2[j,:KbMU]
    KbAD = df_params_sim2[j,:KbAD]
    KbBO = df_params_sim2[j,:KbBO]
    KbRB = df_params_sim2[j,:KbRB]

    CLintᵢ = df_params_sim2[j,:ĈLint] .* exp.(df_params_sim2[j,:ω] .* ηs_sim[j,:])

    # simulate
    function prob_func_sim(prob,i,repeat)
        ps = [CLintᵢ[i], KbBR, KbMU, KbAD, KbBO, KbRB, wts_sim[i], 300.0]
        remake(prob, p=ps, tspan=(0.0,48.0))
    end
    
    sim_ensemble_prob = EnsembleProblem(prob, prob_func=prob_func_sim)
    sim_ensemble_sol = solve(sim_ensemble_prob, Tsit5(), save_idxs=[15], callback=cb, saveat=times_sim, trajectories=nrow(df_params_sim2))

    predicted = []
    for i in 1:500
        idx = findall(x -> x in times_sim, sim_ensemble_sol[i].t)
        tmp_sol = Array(sim_ensemble_sol[i])[1,idx] ./ (vvbs_sim[i]*BP/1000.0)
        append!(predicted, tmp_sol)
    end

    array_pred_sim[j, :] = rand.(LogNormal.(log.(predicted), df_params_sim2[j,:σ]))
end

# get stats

df_pred_sim = DataFrame(array_pred_sim, :auto)
@transform!(df_pred_sim, :iteration = 1:size(array_pred_sim)[1])

df_pred_sim2 = @chain begin
    df_pred_sim
    DataFramesMeta.stack(1:240500)
    @orderby(:iteration)
    @transform(:ID = repeat(repeat(1:500,inner=length(times_sim)), 500),
               :TIME = repeat(times_sim, 500*500),
               :DOSE = dose)
    @transform(:DNDV = :value ./ :DOSE)

    groupby([:iteration, :TIME])
    @transform(:lo = quantile(:value, 0.05),
               :med = quantile(:value, 0.5),
               :hi = quantile(:value, 0.95))
    
    groupby(:TIME)
    @transform(:loLo = quantile(:lo, 0.025),
               :medLo = quantile(:lo, 0.5),
               :hiLo = quantile(:lo, 0.975),
               :loMed = quantile(:med, 0.025),
               :medMed = quantile(:med, 0.5),
               :hiMed = quantile(:med, 0.975),
               :loHi = quantile(:hi, 0.025),
               :medHi = quantile(:hi, 0.5),
               :hiHi = quantile(:hi, 0.975))
end

df_pred_summ = @orderby(unique(df_pred_sim2[!,[5;11:19]]), :TIME)

# save CSV
#CSV.write(joinpath(modPath, "df_pred.csv"), df_pred_new2)

### plot
#dat_obs2 = @transform(dat_obs, :DNDV = :DV ./ :DOSE)

set_default_plot_size(17cm, 12cm)

plot_pred_summ = Gadfly.plot(x=df_pred_summ.TIME, ymin=df_pred_summ.loMed, ymax=df_pred_summ.hiMed, Geom.ribbon, Scale.y_log10, Theme(default_color="deepskyblue", background_color="white"), alpha=[0.8], Guide.xlabel("Time (h)"), Guide.ylabel("Mavoglurant concentration (ng/mL)", orientation=:vertical),
    layer(x=df_pred_summ.TIME, ymin=df_pred_summ.loLo, ymax=df_pred_summ.hiLo, Geom.ribbon, Theme(default_color="deepskyblue"), alpha=[0.5]),
    layer(x=df_pred_summ.TIME, ymin=df_pred_summ.loHi, ymax=df_pred_summ.hiHi, Geom.ribbon, Theme(default_color="deepskyblue"), alpha=[0.5]),
    layer(x=df_pred_summ.TIME, y=df_pred_summ.medMed, Geom.line, Theme(default_color="black")),
    layer(x=df_pred_summ.TIME, y=df_pred_summ.medLo, Geom.line, Theme(default_color="black")),
    layer(x=df_pred_summ.TIME, y=df_pred_summ.medHi, Geom.line, Theme(default_color="black")))

plot_tmp = PDF(joinpath(figPath, "SimPred.pdf"), 17cm, 12cm)
draw(plot_tmp, plot_pred_summ)

I would suggest using StableRNG where reproducibility is important, supplying it directly to the sample function.

e.g.

using StableRNGs

rng = StableRNG(123)
@time mcmcchains = sample(rng, mod, NUTS(nsampl, adapt_delta), MCMCThreads(), nsampl, nchains)

If you do this, do you have the same issue?

Your comment and code seem to indicate that you use several threads in your sampling, which might not work with your seeding. Perhaps this Stack Overflow thread is relevant: Setting seeds in multi-threading loop in Julia.
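The usual workaround in a plain threaded loop (not Turing-specific) is to give each iteration its own seeded RNG rather than relying on the global seed, e.g. (a minimal sketch):

using Random

results = Vector{Float64}(undef, 8)
Threads.@threads for i in 1:8
    rng = MersenneTwister(1000 + i)   # seed tied to the iteration, independent of thread scheduling
    results[i] = rand(rng)
end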

I can try that. Do you think that would work with the multithreading I am using, as @Johan-Gronqvist mentioned? Would I still need a separate seed for each thread? How would I apply that in the sample function?

This should work for multithreading; I believe the implementation can be found at the link below. The rng is copied for each thread, and a seed is generated and set for each copy based on the original rng.
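Roughly, the pattern is something like this (a sketch of the idea, not the actual AbstractMCMC code):

using Random

parent_rng = MersenneTwister(123)
nchains = 4
seeds = rand(parent_rng, UInt, nchains)                          # one seed per chain, drawn from the parent rng
chain_rngs = [Random.seed!(copy(parent_rng), s) for s in seeds]  # copy the parent rng and reseed it per chain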

Got it. Will try that. Thanks.

I would suggest using StableRNG where reproducibility is important, supplying directly to the sample function.

Just a brief comment on this: StableRNG only improves reproducibility w.r.t. hardware, but IIUC the above is all on the same device, in which case StableRNG should not make any difference.

But the passing of rng to sample is the crucial part to check for reproducibility.


I don’t think this is quite right.

Streams from Julia’s default RNG should be reproducible within the same Julia version even across different hardware.

StableRNG streams are reproducible across different Julia versions.

It seems that passing rng to sample only works when running in serial and fails to reproduce results when running with multiple threads. Also, the solutions I have seen here depend on some explicit loop while using multiple threads, e.g. using foreach or Threads.@threads (see the sketch below), but won't this require the user to manually patch the output results together and compute the useful Bayesian stats that Turing.jl computes automatically, like ESS and Rhat?
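For reference, the kind of manual threaded loop I have in mind is something like this (a rough sketch, reusing mod, nsampl, and adapt_delta from the script above):

using Random, Turing

nchains = 4
chains = Vector{Any}(undef, nchains)
Threads.@threads for c in 1:nchains
    rng_c = MersenneTwister(1000 + c)                           # one seed per chain
    chains[c] = sample(rng_c, mod, NUTS(nsampl, adapt_delta), nsampl)
end
# ... and then I would have to combine and summarize `chains` myself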

Reproducibility of sampling a single chain and of sampling multiple chains (serially, with multithreading, and with multiple processes) is tested in AbstractMCMC and also in Turing itself. It should work for any RNG as long as the model does not use a shared RNG (e.g. a call to rand with the default RNG) that might mess with reproducibility. If there is a problem somewhere (e.g. if a sampler does not respect the RNG provided to sample), a simple MWE would be good.
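For example (a hypothetical sketch, not your model), a model that calls rand() on the shared default RNG inside the model body might break reproducibility even when an rng is passed to sample:

using Turing

@model function bad_model(y)          # hypothetical example
    μ ~ Normal(0, 1)
    jitter = rand()                   # shared default RNG, bypasses the rng given to `sample`
    y ~ Normal(μ + jitter, 1.0)
end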


Could you please clarify what MWE refers to?

It means “minimal working example”. This post (pinned on the discourse start page) explains it: Please read: make it easier to help you

Thanks for the clarification. So perhaps I did something wrong here! Here is an MWE. It takes about two minutes to run, but the issue persists regardless of the number of samples used.

using DifferentialEquations
using Turing
using RDatasets
using DataFramesMeta, Chain
using Random

# load Theoph datasets
Theoph = dataset("datasets", "Theoph")

data1 = @chain begin
    Theoph
    @transform(:amt = 0.0, :evid = 0)
    rename(:Subject => :ID, :Conc => :dv)
end

data_dose = @chain begin
    data1
    @subset(:Time .== 0.0)
    @transform(:evid = 1, :amt = :Dose .* :Wt, :dv = 0.0)
end

function pk1cpt!(du, u, p, t)
    depot, cent = u
    ka, CL, V = p

    du[1] = ddepot = -ka * depot
    du[2] = dcent = ka * depot / V - (CL / V) * cent
end

# set conditions
u0 = [319.365, 0.0]
p = [2.0, 4.0, 35.0]
tspan = (0.0, 25.0)

# define ODE problem and solve
prob = ODEProblem(pk1cpt!, u0, tspan, p)

times = [data1.Time[data1.ID.==string(i)] for i = 1:12]
doses = data_dose.amt
nSubject = 12
bws = data_dose.Wt

@model function fitPKPop(data, prob, nSubject, doses, times, bws)
    # priors
    ## residual error
    σ ~ truncated(Cauchy(0.0, 0.5), 0.0, 2.0)

    ## population params
    k̂a ~ LogNormal(log(2.0), 0.2)
    ĈL ~ LogNormal(log(4.0), 0.2)
    V̂ ~ LogNormal(log(35.0), 0.2)

    # IIV
    ωₖₐ ~ truncated(Cauchy(0.0, 0.5), 0.0, 2.0)

    CLᵢ = ĈL .* (bws ./ 70.0) .^ 0.75
    Vᵢ = V̂ .* (bws ./ 70.0)

    # non-centered parameterization
    ηᵢ ~ filldist(Normal(0.0, 1.0), nSubject)
    kaᵢ = k̂a .* exp.(ωₖₐ .* ηᵢ)

    function prob_func(prob, i, repeat)
        u0_tmp = [doses[i], 0.0]
        ps = [kaᵢ[i], CLᵢ[i], Vᵢ[i]]
        remake(prob, u0 = u0_tmp, p = ps, saveat = times[i])
    end

    tmp_ensemble_prob = EnsembleProblem(prob, prob_func = prob_func)
    tmp_ensemble_sol = solve(tmp_ensemble_prob, Tsit5(), trajectories = nSubject)

    predicted = reduce(vcat, [Array(tmp_ensemble_sol[i])[2, :] for i = 1:nSubject])

    # likelihood
    for i = 1:length(predicted)
        data[i] ~ Normal(max(predicted[i], 1e-12), σ)
    end
end

model_pop = fitPKPop(data1.dv, prob, nSubject, doses, times, bws)

rng = MersenneTwister(1)
@time chain_pop = sample(rng, model_pop, NUTS(250,.8), MCMCThreads(), 250, 4)

I wouldn't call that example minimal: precompiling the dependencies took more than 5 minutes, running the example took another 5 minutes, and the REPL is flooded with "Interrupted. Larger maxiters is needed" warnings from the ODE solver. Ideally, an MWE should be condensed to the absolute bare minimum needed to reproduce the issue.

Terminating a second run with Ctrl+C resulted in a Julia segfault and I don’t have time today anymore to rerun the example again.

Maybe a MWE only requires a simple EnsembleProblem call. EnsembleProblem uses multithreading by default, and possibly nested multithreading causes the issues here. You could check if the issue is fixed by (a) using a different algorithm for solving the EnsembleProblem (e.g., EnsembleSerial: Parallel Ensemble Simulations · SciML) or (b) using a different parallel sampling algorithm (e.g. MCMCSerial).
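Concretely (a sketch reusing the names from your MWE), the two checks would look like:

# (a) inside the @model: solve the ODE ensemble serially, keep MCMCThreads() for the chains
tmp_ensemble_sol = solve(tmp_ensemble_prob, Tsit5(), EnsembleSerial(), trajectories = nSubject)

# (b) keep the (default) threaded ensemble solve, but sample the chains serially
chain_pop = sample(rng, model_pop, NUTS(250, 0.8), MCMCSerial(), 250, 4)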

Sorry the example took longer than anticipated; I will try to condense it a bit and try the alternatives you suggested. I know that using MCMCSerial reproduces the results, but that is not running in parallel, is it?

I confirm that the issue must have been with the nested threading. Setting EnsembleSerial() for EnsembleProblem solves the issue.

I'm not sure doubling up on threads like that makes sense. If you're already multithreading the chains, then multithreading the solves will just oversubscribe the available threads. I'd be curious to see how that affects the RNG, but the real coding solution for this kind of code is not to nest threading primitives, which gives more robustness and performance.

Yes, I really hadn't considered the double threading here; I just missed the default threading when solving the EnsembleProblem. I'm doing some quick tests to see which gets me better performance: threading the chains or threading the solves. If I only use one threading technique, the results are indeed reproducible.