Slow MCMC sampling of complex polynomial model [Turing.jl]

In my model functions I use several complicated functions (long equations for the Q and cd functions, with many terms). Basically (see the Q_generated function) they are polynomials in the parameters v and w, where v is a scalar and w is a vector of values (one w for each value of the independent variable x, the temperature). I want to perform efficient MCMC (NUTS) using Turing.jl. (For now the “experimental” data is only simulated with added noise, but I’ll use some real data later.)

What are the best practices for conducting such inference?
Currently, my code runs reasonably well for N_pep=7. However, as I incorporate more complexity into the model (with longer and more complex polynomials, e.g. N_pep=17), the sampling speed drops significantly. Is there a more effective way to write these polynomials so that Julia can process them more efficiently (pre-compilation of functions, vectorization, …)?

Here is my sample code (try for N_pep=7 and N_pep=17)…

using LinearAlgebra
using Turing, MCMCChains, Random, Plots, StatsPlots, Statistics
#include("myfunctions.jl")
#using MyFunctions


"""
    w_generated(x, dG, dH, dcp)

Return `exp(-1 / R * (dG / T0 + dH * (1 / x - 1 / T0) + dcp * (1 - T0 / x - log(x / T0))))`.

`R` and `T0` are not arguments; they are read from the enclosing (global) scope.
The function is scalar; broadcast it (`w_generated.(x, ...)`) to evaluate it over
a vector of `x` values.
"""
function w_generated(x, dG, dH, dcp)
    # The three additive contributions, one per thermodynamic parameter.
    term_dG = dG / T0
    term_dH = dH * (1 / x - 1 / T0)
    term_dcp = dcp * (1 - T0 / x - log(x / T0))

    exponent = -1 / R * (term_dG + term_dH + term_dcp)
    return exp(exponent)
end

"""
    spectro_temp(P0, Pt, temps)

Return the linear function `P0 + Pt * (temps - (T0 - 273.15))` of `temps`.

`P0` is the value at `temps == T0 - 273.15` and `Pt` is the slope. `T0` is read
from the enclosing (global) scope.
"""
function spectro_temp(P0, Pt, temps)
    reference = T0 - 273.15
    shifted = temps - reference
    return P0 + Pt * shifted
end

"""
    Q_generated(v, w, Npep)

Evaluate the hard-coded bivariate polynomial `Q(v, w)` for system size `Npep`.

The polynomial is stored in nested Horner form: for every power of `v` the
coefficient is itself a polynomial in `w`, evaluated with `evalpoly`, and the
outer polynomial in `v` is evaluated with `evalpoly` as well. Coefficient tuples
are listed from the constant term upwards (`w^0, w^1, ...`). Compared with a flat
sum of monomials this needs one multiply-add per coefficient and no power calls.

The function is scalar in `v` and `w`; broadcast it to evaluate it over a vector.

# Throws
- `ArgumentError` if `Npep` is not one of the implemented sizes (7 or 17).
"""
function Q_generated(v, w, Npep)
    if Npep == 7
        c2 = evalpoly(w, (21, 5, 4, 3, 2, 1))
        c3 = evalpoly(w, (20, 12, 6, 2))
        c4 = evalpoly(w, (12, 9, 3))
        return evalpoly(v, (1, 7, c2, c3, c4))
    elseif Npep == 17
        c2 = evalpoly(w, (136, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1))
        c3 = evalpoly(w, (210, 182, 156, 132, 110, 90, 72, 56, 42, 30, 20, 12, 6, 2))
        c4 = evalpoly(
            w, (1092, 1014, 858, 715, 585, 468, 364, 273, 195, 130, 78, 39, 13)
        )
        # The v^5 and v^6 coefficients have no w^0 term, hence the leading 0.
        c5 = evalpoly(w, (0, 1320, 1485, 1440, 1260, 1008, 735, 480, 270, 120, 33))
        c6 = evalpoly(w, (0, 3960, 3780, 3024, 2100, 1260, 630, 240, 54))
        return evalpoly(v, (1, 17, c2, c3, c4, c5, c6))
    else
        # Previously an unsupported size fell through and returned `nothing`,
        # which only failed later (and obscurely) in the caller's arithmetic.
        throw(ArgumentError("Q_generated is only implemented for Npep = 7 or 17, got $Npep"))
    end
end

"""
    cd_signal(v, w, H1, H2, C, Npep)

Evaluate the hard-coded signal polynomial in `v` and `w` for system size `Npep`.

Every monomial `v^i * w^j` is weighted by a linear combination of `H1`, `H2`
and `C`. The function is scalar in all arguments; broadcast it to evaluate it
over vectors.

For `Npep == 17` the terms are accumulated as one partial sum per power of `v`
(the terms themselves and their order are unchanged), so no single `+`
expression has more than 15 operands.

# Throws
- `ArgumentError` if `Npep` is not one of the implemented sizes (7 or 17).
"""
function cd_signal(v, w, H1, H2, C, Npep)
    # The result is returned directly; the former local named `cd` shadowed
    # `Base.cd` inside this function.
    if Npep == 7
        return (
            96 * C * v^4 + 160 * C * v^3 + 168 * C * v^2 + 56 * C * v + 8 * C +
            8 * H1 * v^4 * w^2 +
            2 * v^4 * w^2 * (3 * C + 5 * H1) +
            36 * v^4 * w * (C + H1) +
            4 * v^3 * w^3 * (C + 3 * H1) +
            6 * v^3 * w^2 * (3 * C + 5 * H1) +
            48 * v^3 * w * (C + H1) +
            2 * v^2 * w^5 * (3 * H1 + H2) +
            2 * v^2 * w^4 * (C + 6 * H1 + H2) +
            6 * v^2 * w^3 * (C + 3 * H1) +
            4 * v^2 * w^2 * (3 * C + 5 * H1) +
            20 * v^2 * w * (C + H1)
        )
    elseif Npep == 17
        # Terms that do not depend on w.
        t0 = C * (19656 * v^4 + 3780 * v^3 + 2448 * v^2 + 306 * v + 18)

        # v^6 terms
        t6 = (
            v^6 * w^8 * (252 * C + 216 * H1 + 180 * H2) +
            v^6 * w^7 * (1380 * C + 2340 * H1 + 600 * H2) +
            v^6 * w^6 * (4320 * C + 6030 * H1 + 990 * H2) +
            v^6 * w^5 * (10080 * C + 11340 * H1 + 1260 * H2) +
            v^6 * w^4 * (19220 * C + 17640 * H1 + 840 * H2) +
            v^6 * w^3 * (31752 * C + 22680 * H1) +
            v^6 * w^2 * (45360 * C + 22680 * H1) +
            7920 * v^6 * w * (7 * C + 2 * H1)
        )

        # v^5 terms
        t5 = (
            v^5 * w^10 * (84 * C + 343.8 * H1 + 166.2 * H2) +
            v^5 * w^9 * (432 * C + 1224 * H1 + 504 * H2) +
            v^5 * w^8 * (1260 * C + 2722.5 * H1 + 877.5 * H2) +
            v^5 * w^7 * (2760 * C + 4680 * H1 + 1200 * H2) +
            v^5 * w^6 * (5040 * C + 7035 * H1 + 1155 * H2) +
            v^5 * w^5 * (8064 * C + 9072 * H1 + 1008 * H2) +
            v^5 * w^4 * (11592 * C + 10584 * H1 + 504 * H2) +
            v^5 * w^3 * (15120 * C + 10800 * H1) +
            v^5 * w^2 * (17820 * C + 8910 * H1) +
            2640 * v^5 * w * (7 * C + 2 * H1)
        )

        # v^4 terms
        t4 = (
            v^4 * w^12 * (6 * C + 138.5 * H1 + 89.5 * H2) +
            v^4 * w^11 * (66 * C + 396 * H1 + 240 * H2) +
            v^4 * w^10 * (228 * C + 735.6 * H1 + 416.4 * H2) +
            v^4 * w^9 * (540 * C + 1200 * H1 + 600 * H2) +
            v^4 * w^8 * (1050 * C + 1721.25 * H1 + 828.75 * H2) +
            v^4 * w^7 * (1806 * C + 3268 * H1 + 840 * H2) +
            v^4 * w^6 * (2856 * C + 4474 * H1 + 812 * H2) +
            v^4 * w^5 * (4320 * C + 3456 * H1 + 720 * H2) +
            v^4 * w^4 * (6030 * C + 4050 * H1 + 450 * H2) +
            v^4 * w^3 * (8750 * C + 4620 * H1) +
            v^4 * w^2 * (10956 * C + 4488 * H1) +
            2028 * v^4 * w * (7 * C + 2 * H1)
        )

        # v^3 terms
        t3 = (
            v^3 * w^13 * (4 * C + 12 * H1 + 20 * H2) +
            v^3 * w^12 * (18 * C + 36 * H1 + 54 * H2) +
            v^3 * w^11 * (48 * C + 72 * H1 + 96 * H2) +
            v^3 * w^10 * (100 * C + 120 * H1 + 140 * H2) +
            v^3 * w^9 * (180 * C + 180 * H1 + 180 * H2) +
            v^3 * w^8 * (294 * C + 252 * H1 + 210 * H2) +
            v^3 * w^7 * (448 * C + 336 * H1 + 224 * H2) +
            v^3 * w^6 * (648 * C + 432 * H1 + 216 * H2) +
            v^3 * w^5 * (900 * C + 540 * H1 + 180 * H2) +
            v^3 * w^4 * (1210 * C + 660 * H1 + 110 * H2) +
            v^3 * w^3 * (1584 * C + 792 * H1) +
            v^3 * w^2 * (2028 * C + 780 * H1) +
            364 * v^3 * w * (7 * C + 2 * H1)
        )

        # v^2 terms
        t2 = (
            v^2 * w^15 * (6 * H1 + 12 * H2) +
            v^2 * w^14 * (2 * C + 12 * H1 + 22 * H2) +
            v^2 * w^13 * (6 * C + 18 * H1 + 30 * H2) +
            v^2 * w^12 * (12 * C + 24 * H1 + 36 * H2) +
            v^2 * w^11 * (20 * C + 30 * H1 + 40 * H2) +
            v^2 * w^10 * (30 * C + 36 * H1 + 42 * H2) +
            v^2 * w^9 * (42 * C + 42 * H1 + 42 * H2) +
            v^2 * w^8 * (56 * C + 48 * H1 + 40 * H2) +
            v^2 * w^7 * (72 * C + 54 * H1 + 36 * H2) +
            v^2 * w^6 * (90 * C + 60 * H1 + 30 * H2) +
            v^2 * w^5 * (110 * C + 66 * H1 + 22 * H2) +
            v^2 * w^4 * (132 * C + 72 * H1 + 12 * H2) +
            v^2 * w^3 * (156 * C + 78 * H1) +
            v^2 * w^2 * (182 * C + 70 * H1) +
            30 * v^2 * w * (7 * C + 2 * H1)
        )

        return t0 + t6 + t5 + t4 + t3 + t2
    else
        # Previously an unsupported size fell through and returned `nothing`.
        throw(ArgumentError("cd_signal is only implemented for Npep = 7 or 17, got $Npep"))
    end
end

# Simulated dataset
Random.seed!(12)

# `R` and `T0` are read as globals inside `w_generated` and `spectro_temp`, which
# are called for every data point at every log-density evaluation. Declaring them
# `const` gives them a fixed type; a non-const global is untyped, which forces
# dynamic dispatch on every read inside those hot functions.
const R = 0.001987 # constant R
const T0 = 273.15 # constant T0

x = collect(273.15:1:373.15) # temperatures
xc = x .- T0
N = length(x) # number of data points (101 for the grid above)

N_pep = 7 #CHOOSE SYSTEM - MODEL (N_pep=7,17,22,27,32)

# Simulated parameter values
v_true = 0.048
dG_true = -0.24
dH_true = -1.0
dcp_true = -0.01

h1_true = -15000
c_true = 2000
h2_true = -40000

# Linear functions of xc with fixed slopes (20, 100, -50).
H1t_true = spectro_temp.(h1_true, 20, xc)
H2t_true = spectro_temp.(h2_true, 100, xc)
Ct_true = spectro_temp.(c_true, -50, xc)

σ_true = 250 # add noise

# Simulate the model
w_true = w_generated.(x, dG_true, dH_true, dcp_true)
Q_true = Q_generated.(v_true, w_true, N_pep)

#SIMULATED data: noise-free signal plus Gaussian noise of standard deviation σ_true
y = (cd_signal.(v_true, w_true, H1t_true, H2t_true, Ct_true, N_pep) ./ (N_pep .* Q_true)) .+ randn(N) .* σ_true

#Plot the simulated data
p1 = plot(xc, y, seriestype = :scatter, title = "Simulated Data", xlabel = "x", ylabel = "y", legend=false)
# Pass the plot handle explicitly (as the later savefig calls do) instead of
# relying on the implicit "current plot".
savefig(p1, "CD_simulated_data.png") # save the plot to a file

# MODEL
# Turing model for the observations `y` measured at `x` (with `xc` the shifted
# copy of `x` used by `spectro_temp`).
#
# Arguments:
#   x, xc  - vectors of the independent variable (absolute and shifted)
#   y      - observed signal, same length as `x`
#   N_pep  - system size selecting the polynomials in `Q_generated`/`cd_signal`
#   N      - number of data points; kept for backward compatibility with existing
#            callers, the likelihood below is written for the whole vector `y`
#
# The priors are centred on the global `*_true` values defined by the script.
@model function model_cp(x, xc, y, N_pep, N)
    # Half-normal prior; the keyword form avoids an infinite upper bound.
    σ ~ truncated(Normal(0, 300); lower=0)

    #TD parameters
    v ~ Normal(v_true, 0.1)
    dG ~ Normal(dG_true, 0.2)
    dH ~ Normal(dH_true, 0.5)
    dcp ~ Normal(dcp_true, 0.1)
    #Spectroscopic parameters
    h1 ~ Normal(h1_true, 1000)
    h2 ~ Normal(h2_true, 1000)
    c ~ Normal(c_true, 300)

    #Deterministic params
    w = w_generated.(x, dG, dH, dcp)
    Q = Q_generated.(v, w, N_pep)

    H1t = spectro_temp.(h1, 20, xc)
    H2t = spectro_temp.(h2, 100, xc)
    Ct = spectro_temp.(c, -50, xc)

    # Mean signal for all data points in one broadcast.
    μ = cd_signal.(v, w, H1t, H2t, Ct, N_pep) ./ (N_pep .* Q)

    # One multivariate statement with independent components of common standard
    # deviation σ: the same likelihood as a loop of `y[n] ~ Normal(μ[n], σ)`,
    # but a single `~` instead of one per observation.
    y ~ MvNormal(μ, σ^2 * I)

    return nothing
end

# MCMC
model = model_cp(x, xc, y, N_pep, N)
# NOTE: the original call also passed `burn_in=200`. That is not a keyword that
# `sample` acts on, so it had no effect and has been removed. NUTS() already runs
# an adaptation (warm-up) phase and discards those draws by default; to control
# its length explicitly, use e.g. `NUTS(200, 0.65)`.
chains = sample(model, NUTS(), 1000)

# Plot the results
p = plot(chains) # MCMC diagnostics plots
savefig(p, "CD_mcmc_diagnostics.png") # Save the diagnostics plot

# Posterior means of the sampled parameters
v_hat = mean(chains[:v])  # estimated v
dG_hat = mean(chains[:dG])  # estimated dG
dH_hat = mean(chains[:dH])  # estimated dH
dcp_hat = mean(chains[:dcp])  # estimated dcp
σ_hat = mean(chains[:σ])  # estimated σ

h1_hat = mean(chains[:h1])
h2_hat = mean(chains[:h2])
c_hat = mean(chains[:c])

# Print labelled values; the previous `println(v_hat,dG_hat,dH_hat,dcp_hat)` ran
# the numbers together with no separator, and σ_hat was never reported.
println("v = ", v_hat, ", dG = ", dG_hat, ", dH = ", dH_hat, ", dcp = ", dcp_hat, ", σ = ", σ_hat)

# Plot the data with the best fit line
w_hat = w_generated.(x, dG_hat, dH_hat, dcp_hat)
Q_hat = Q_generated.(v_hat, w_hat, N_pep)

H1t_hat = spectro_temp.(h1_hat, 20, xc)
H2t_hat = spectro_temp.(h2_hat, 100, xc)
Ct_hat = spectro_temp.(c_hat, -50, xc)

# Model prediction evaluated at the posterior means
y_hat = cd_signal.(v_hat, w_hat, H1t_hat, H2t_hat, Ct_hat, N_pep) ./ (N_pep .* Q_hat)

p = scatter(xc, y, label="Data")  # plot the original data
plot!(xc, y_hat, label="Best Fit Line")  # add the best fit line
savefig(p, "CD_data_and_best_fit.png")  # Save the data and best fit plot

I never used it myself but you can give RxInfer.jl a try. The following video is very didactic:

https://youtu.be/_vVHWzK9NEI

1 Like

There are some performance tips on the Turing website: Performance Tips

But these need to be updated quite significantly; will hopefully get around to that somewhat soonish.

Generally speaking, I recommend using BenchmarkTools.jl to benchmark the different components of your model. Then you can also try out GitHub - TuringLang/TuringBenchmarking.jl to benchmark your model for different AD backends.