# Speeding up MCMC for logistic regression with large datasets using Turing or some alternative

Hi, I am running some logistic regression MCMC sampling chains using Turing and am wondering how best to optimise them (multiprocessing etc.) to get maximum performance. My code currently looks like this:

```julia
using Turing, LinearAlgebra  # Diagonal comes from LinearAlgebra
using Parameters             # assumed source of @unpack
# real_train, synth_train, labels, len_real and len_synth are defined elsewhere

# Lazily yield every (a1, a2) pair whose sum does not exceed max_sum
function get_conditional_pairs(l1, l2; max_sum=1)
    return ((a1, a2) for a1 in l1 for a2 in l2 if a1 + a2 <= max_sum)
end

struct Data{Ty_real, TX_real, Ty_synth, TX_synth}
    y_real::Ty_real
    X_real::TX_real
    y_synth::Ty_synth
    X_synth::TX_synth
end

struct KLDParams{Tw, Tσ}
    w::Tw
    σ::Tσ
end

@model logistic_regression(data, params) = begin
    @unpack y_real, X_real, y_synth, X_synth = data
    @unpack w, σ = params
    coefs ~ MvNormal(zeros(size(X_real)[2]), Diagonal(repeat([σ], size(X_real)[2])))

    # Real-data log-likelihood, plus the synthetic-data log-likelihood scaled by w
    @logpdf() += sum(logpdf.(BinomialLogit.(1, X_real * coefs), y_real))
    @logpdf() += w * sum(logpdf.(BinomialLogit.(1, X_synth * coefs), y_synth))
end

real_αs = [0.1, 0.25, 0.5, 1.0]
synth_αs = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 1.0]

for (real_α, synth_α) in get_conditional_pairs(real_αs, synth_αs)
    input_data = Data(
        Matrix(real_train[1:floor(Int32, len_real * real_α), labels]),
        Matrix(real_train[1:floor(Int32, len_real * real_α), Not(labels)]),
        Matrix(synth_train[1:floor(Int32, len_synth * synth_α), labels]),
        Matrix(synth_train[1:floor(Int32, len_synth * synth_α), Not(labels)])
    )
    weighted_synth_params = KLDParams(0.5, 10.0)
    # Pass the params defined just above (the original referenced an undefined no_synth_params)
    weighted_synth_result = sample(logistic_regression(input_data, weighted_synth_params), HMC(0.05, 10), 6000)
end
```

The dataset I’m using is around 200k rows in total, so understandably it will be slow at large alphas; I just want to make sure I am doing all I can to optimise. I do have access to a pretty large cluster, and some GPUs if needed. What kind of speed-up can I achieve using Turing, or some other approach that may be more performant? @trappmartin I made a new question for this as my previous one was resolved, but I wonder if you have any more good insight for this?

Running HMC / NUTS in Turing is often limited by the automatic differentiation (`AD`) backend that you use.

You can find some pointers here:
https://turing.ml/dev/docs/using-turing/performancetips
https://turing.ml/dev/docs/using-turing/autodiff

@mohamed82008 and @Kai_Xu might be able to give further tips and details regarding inference on a GPU.

I am having some issues getting my code to work with reverse_diff enabled, which I believe would be a good idea since I have around 35 parameters. I get the following error:

```
ERROR: MethodError: no method matching zero(::Type{Any})
Closest candidates are:
zero(::Type{Union{Missing, T}}) where T at missing.jl:105
zero(::Type{Missing}) at missing.jl:103
zero(::Type{LibGit2.GitHash}) at /Users/sabae/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/LibGit2/src/oid.jl:220
...
Stacktrace:
[1] zero(::Type{Any}) at ./missing.jl:105
[2] reduce_empty(::typeof(+), ::Type) at ./reduce.jl:223
[4] mapreduce_empty(::typeof(identity), ::Function, ::Type) at ./reduce.jl:247
[5] _mapreduce(::typeof(identity), ::typeof(Base.add_sum), ::IndexLinear, ::Array{Any,2}) at ./reduce.jl:301
[6] _mapreduce_dim at ./reducedim.jl:312 [inlined]
[7] #mapreduce#584 at ./reducedim.jl:307 [inlined]
[8] mapreduce at ./reducedim.jl:307 [inlined]
[9] _sum at ./reducedim.jl:657 [inlined]
[10] _sum at ./reducedim.jl:656 [inlined]
[11] #sum#587 at ./reducedim.jl:652 [inlined]
[12] sum(::Array{Any,2}) at ./reducedim.jl:652
[13] macro expansion at ./REPL[11]:8 [inlined]
[15] #_#3 at /Users/harrisonwilde/.julia/packages/DynamicPPL/xwKXl/src/model.jl:24 [inlined]
[16] Model at /Users/harrisonwilde/.julia/packages/DynamicPPL/xwKXl/src/model.jl:24 [inlined]
[17] runmodel! at /Users/harrisonwilde/.julia/packages/DynamicPPL/xwKXl/src/varinfo.jl:602 [inlined]
[18] runmodel! at /Users/harrisonwilde/.julia/packages/DynamicPPL/xwKXl/src/varinfo.jl:598 [inlined]
[20] #20 at /Users/harrisonwilde/.julia/packages/Tracker/cpxco/src/back.jl:148 [inlined]
[22] forward(::Function, ::Array{Float64,1}) at /Users/harrisonwilde/.julia/packages/Tracker/cpxco/src/back.jl:148
[25] ∂logπ∂θ at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/hmc.jl:401 [inlined]
[29] #find_good_eps at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/hmc.jl:0 [inlined]
[35] Sampler at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/hmc.jl:302 [inlined]
[37] sample at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/Inference.jl:148 [inlined]
[38] #sample#1 at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/Inference.jl:136 [inlined]
[40] top-level scope at ./REPL[35]:14
```

from this model and code:

```julia
@model weighted_logistic_regression(data, params) = begin

    @unpack y_real, X_real, y_synth, X_synth = data
    @unpack w, μs, σs = params

    coefs ~ MvNormal(μs, σs)
    @logpdf() += sum(logpdf.(BinomialLogit.(1, X_real * coefs), y_real))
    @logpdf() += w * sum(logpdf.(BinomialLogit.(1, X_synth * coefs), y_synth))
end

...

len_real = size(real_train)[1]
len_synth = size(synth_train)[1]
len_test = size(real_test)[1]

num_chains = 3
real_αs = [0.1, 0.25, 0.5]
synth_αs = [0.0, 0.05, 0.1, 0.25, 0.5]
for (real_α, synth_α) in get_conditional_pairs(real_αs, synth_αs)
    input_data = Data(
        Int.(Matrix(real_train[1:floor(Int32, len_real * real_α), labels])),
        Matrix(real_train[1:floor(Int32, len_real * real_α), Not(labels)]),
        Int.(Matrix(synth_train[1:floor(Int32, len_synth * synth_α), labels])),
        Matrix(synth_train[1:floor(Int32, len_synth * synth_α), Not(labels)])
    )
    params = WeightedKLDParams(
        0.5,
        ones(size(real_train)[2] - size(labels)[1]),
        Diagonal(repeat([2.0], size(real_train)[2] - size(labels)[1]))
    )
    # weighted_chain = mapreduce(c -> sample(weighted_logistic_regression(input_data, weighted_params), DynamicNUTS(), 5000), chainscat, 1:num_chains)
    weighted_chain = sample(weighted_logistic_regression(input_data, params), NUTS(500, 0.651), 5000)
    write("weighted_chains_real" * string(real_α) * "_synth" * string(synth_α), weighted_chain)
    params = WeightedKLDParams(
        1.0,
        ones(size(real_train)[2] - size(labels)[1]),
        Diagonal(repeat([2.0], size(real_train)[2] - size(labels)[1]))
    )
    # weighted_chain = mapreduce(c -> sample(weighted_logistic_regression(input_data, kld_params), DynamicNUTS(), 5000), chainscat, 1:num_chains)
    naive_chain = sample(weighted_logistic_regression(input_data, params), NUTS(500, 0.651), 5000)
    write("naive_chains_real" * string(real_α) * "_synth" * string(synth_α), naive_chain)
end
```

Any ideas on how to get reverse diff working? Is there clear documentation anywhere describing how to use it and which methods are supported?
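One possible reading of the stack trace above, offered as a sketch rather than a confirmed diagnosis: the trace ends in `sum(::Array{Any,2})` calling `reduce_empty`, which Julia only reaches when the collection is empty, and `synth_αs` includes `0.0`, so the slice `1:floor(Int32, len_synth * synth_α)` would select no synthetic rows on that iteration. The snippet below reproduces that kind of failure in plain Julia, without Turing. The names `empty_terms` and `guarded_sum` and the value given to `len_synth` are made up for illustration and are not part of the original code or of any library.

```julia
# Hypothetical stand-in for the per-row log-likelihood terms when zero
# synthetic rows are selected: an empty 0×1 array whose element type is Any.
empty_terms = Array{Any,2}(undef, 0, 1)

# Summing it throws, because Julia has no zero element for the type Any.
sum_failed = try
    sum(empty_terms)
    false
catch err
    true
end

# One possible guard (an illustrative helper, not a Turing function):
# return a typed zero when there is nothing to sum.
guarded_sum(terms) = isempty(terms) ? 0.0 : sum(terms)

# The row range the loop would build for synth_α = 0.0
len_synth = 1000  # made-up size, for illustration only
rows = 1:floor(Int32, len_synth * 0.0)

println("sum threw an error: ", sum_failed)
println("guarded sum: ", guarded_sum(empty_terms))
println("row range is empty: ", isempty(rows))
```

Whether a guard like this is enough inside the model under reverse-mode AD is untested here; skipping the synthetic term entirely when `synth_α` is `0.0` would be another option to try.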