Speeding up MCMC for logistic regression with large datasets using Turing or some alternative

Hi, I am running some logistic regression MCMC sampling chains using Turing and wondering how best to optimise / use multiprocessing etc. to get maximum performance. My code currently looks like this:

"""
    get_conditional_pairs(l1, l2; max_sum=1)

Return a lazy iterator over all pairs `(a1, a2)` with `a1` drawn from `l1`
and `a2` drawn from `l2`, keeping only those pairs whose sum is at most
`max_sum`. Pairs are produced with `a1` varying slowest (outer loop order).
"""
function get_conditional_pairs(l1, l2; max_sum=1)
	all_pairs = Iterators.flatten(((a1, a2) for a2 in l2) for a1 in l1)
	return Iterators.filter(pair -> pair[1] + pair[2] <= max_sum, all_pairs)
end

# Immutable container bundling the two training splits fed to the model:
# the real data (y_real, X_real) and the synthetic data (y_synth, X_synth).
# One type parameter per field keeps every field concretely typed, so Julia
# specialises the model code on the actual array types passed in.
struct Data{Ty_real, TX_real, Ty_synth, TX_synth}
    y_real::Ty_real    # real-data responses (labels)
    X_real::TX_real    # real-data design matrix
    y_synth::Ty_synth  # synthetic-data responses (labels)
    X_synth::TX_synth  # synthetic-data design matrix
end

# Hyper-parameters for the weighted-likelihood model: `w` tempers the
# synthetic-data log-likelihood and `σ` sets the prior scale on the
# regression coefficients. Parametric fields keep the struct concrete.
struct KLDParams{Tw, Tσ}
    w::Tw  # weight applied to the synthetic-data log-likelihood term
    σ::Tσ  # prior scale (used as the covariance diagonal in the model)
end

# Bayesian logistic regression over a mix of real and synthetic rows.
# The synthetic-data log-likelihood is tempered by the weight `w` (a
# weighted / power likelihood); `σ` sets the prior spread on the coefficients.
# NOTE(review): `σ` is placed directly on the covariance diagonal below, so
# it acts as a *variance*, not a standard deviation — confirm this is intended.
@model logistic_regression(data, params) = begin
	@unpack y_real, X_real, y_synth, X_synth = data
	@unpack w, σ = params
    # Zero-mean multivariate normal prior over all regression coefficients.
    coefs ~ MvNormal(zeros(size(X_real)[2]), Diagonal(repeat([σ], size(X_real)[2])))

    # Real-data likelihood: Bernoulli on the logit scale (Binomial with n = 1).
    @logpdf() += sum(logpdf.(BinomialLogit.(1, X_real * coefs), y_real))
    # Synthetic-data likelihood, scaled by the tempering weight w.
    @logpdf() += w * sum(logpdf.(BinomialLogit.(1, X_synth * coefs), y_synth))
end

# Fractions of the real and synthetic training sets to use per run.
real_αs = [0.1, 0.25, 0.5, 1.0]
synth_αs = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 1.0]

# Run one HMC chain per (real_α, synth_α) pair whose αs sum to at most 1.
for (real_α, synth_α) in get_conditional_pairs(real_αs, synth_αs)
	# Slice off the first ⌊len * α⌋ rows of each split; `labels` selects the
	# response columns and `Not(labels)` the predictors.
	input_data = Data(
		Matrix(real_train[1:floor(Int, len_real * real_α), labels]), 
		Matrix(real_train[1:floor(Int, len_real * real_α), Not(labels)]), 
		Matrix(synth_train[1:floor(Int, len_synth * synth_α), labels]), 
		Matrix(synth_train[1:floor(Int, len_synth * synth_α), Not(labels)])
	)
	weighted_synth_params = KLDParams(0.5, 10.0)
	# BUG FIX: the original passed the undefined variable `no_synth_params`
	# to the model; the params constructed on the previous line are intended.
	weighted_synth_result = sample(logistic_regression(input_data, weighted_synth_params), HMC(0.05, 10), 6000)
end

The dataset I’m using is around 200k rows total, so understandably it will be slow at large alphas; I’m just looking to ensure I am doing all I can to optimise. I do have access to a pretty large cluster, and some GPUs if needed — what kind of speed-up can I achieve using Turing, or some other approach that may be more performant? @trappmartin I made a new question for this as my previous one was resolved, but I wonder if you have any more good insight on this?

Running HMC / NUTS in Turing is often limited by the AD backend that you use.

You can find some pointers here:
https://turing.ml/dev/docs/using-turing/performancetips
https://turing.ml/dev/docs/using-turing/autodiff

@mohamed82008 and @Kai_Xu might be able to give further tips and details regarding inference on a GPU.

1 Like

I am having some issues with getting my code to work with reverse_diff enabled, which I believe would be a good idea given that I have around 35 params. I get the following error:

ERROR: MethodError: no method matching zero(::Type{Any})
Closest candidates are:
  zero(::Type{Union{Missing, T}}) where T at missing.jl:105
  zero(::Type{Missing}) at missing.jl:103
  zero(::Type{LibGit2.GitHash}) at /Users/sabae/buildbot/worker/package_macos64/build/usr/share/julia/stdlib/v1.3/LibGit2/src/oid.jl:220
  ...
Stacktrace:
 [1] zero(::Type{Any}) at ./missing.jl:105
 [2] reduce_empty(::typeof(+), ::Type) at ./reduce.jl:223
 [3] reduce_empty(::typeof(Base.add_sum), ::Type) at ./reduce.jl:230
 [4] mapreduce_empty(::typeof(identity), ::Function, ::Type) at ./reduce.jl:247
 [5] _mapreduce(::typeof(identity), ::typeof(Base.add_sum), ::IndexLinear, ::Array{Any,2}) at ./reduce.jl:301
 [6] _mapreduce_dim at ./reducedim.jl:312 [inlined]
 [7] #mapreduce#584 at ./reducedim.jl:307 [inlined]
 [8] mapreduce at ./reducedim.jl:307 [inlined]
 [9] _sum at ./reducedim.jl:657 [inlined]
 [10] _sum at ./reducedim.jl:656 [inlined]
 [11] #sum#587 at ./reducedim.jl:652 [inlined]
 [12] sum(::Array{Any,2}) at ./reducedim.jl:652
 [13] macro expansion at ./REPL[11]:8 [inlined]
 [14] (::var"##inner_function#429#11")(::DynamicPPL.VarInfo{NamedTuple{(:coefs,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:coefs},Int64},Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{DynamicPPL.VarName{:coefs},1},TrackedArray{…,Array{Float64,1}},Array{Set{DynamicPPL.Selector},1}}}},Tracker.TrackedReal{Float64}}, ::DynamicPPL.Sampler{NUTS{Turing.Core.TrackerAD,(),AdvancedHMC.Adaptation.DiagEuclideanMetric},Turing.Inference.SamplerState{DynamicPPL.VarInfo{NamedTuple{(:coefs,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:coefs},Int64},Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{DynamicPPL.VarName{:coefs},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}}}, ::DynamicPPL.DefaultContext, ::DynamicPPL.Model{var"##inner_function#429#11",NamedTuple{(:data, :params),Tuple{Data{Array{Int64,2},Array{Float64,2},Array{Int64,2},Array{Float64,2}},WeightedKLDParams{Float64,Array{Float64,1},Diagonal{Float64,Array{Float64,1}}}}},DynamicPPL.ModelGen{(:data, :params),var"###weighted_logistic_regression#437",NamedTuple{(),Tuple{}}},Val{()}}) at /Users/harrisonwilde/.julia/packages/DynamicPPL/xwKXl/src/compiler.jl:602
 [15] #_#3 at /Users/harrisonwilde/.julia/packages/DynamicPPL/xwKXl/src/model.jl:24 [inlined]
 [16] Model at /Users/harrisonwilde/.julia/packages/DynamicPPL/xwKXl/src/model.jl:24 [inlined]
 [17] runmodel! at /Users/harrisonwilde/.julia/packages/DynamicPPL/xwKXl/src/varinfo.jl:602 [inlined]
 [18] runmodel! at /Users/harrisonwilde/.julia/packages/DynamicPPL/xwKXl/src/varinfo.jl:598 [inlined]
 [19] (::Turing.Core.var"#f#8"{DynamicPPL.VarInfo{NamedTuple{(:coefs,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:coefs},Int64},Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{DynamicPPL.VarName{:coefs},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64},DynamicPPL.Model{var"##inner_function#429#11",NamedTuple{(:data, :params),Tuple{Data{Array{Int64,2},Array{Float64,2},Array{Int64,2},Array{Float64,2}},WeightedKLDParams{Float64,Array{Float64,1},Diagonal{Float64,Array{Float64,1}}}}},DynamicPPL.ModelGen{(:data, :params),var"###weighted_logistic_regression#437",NamedTuple{(),Tuple{}}},Val{()}},DynamicPPL.Sampler{NUTS{Turing.Core.TrackerAD,(),AdvancedHMC.Adaptation.DiagEuclideanMetric},Turing.Inference.SamplerState{DynamicPPL.VarInfo{NamedTuple{(:coefs,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:coefs},Int64},Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{DynamicPPL.VarName{:coefs},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}}}})(::TrackedArray{…,Array{Float64,1}}) at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/core/ad.jl:136
 [20] #20 at /Users/harrisonwilde/.julia/packages/Tracker/cpxco/src/back.jl:148 [inlined]
 [21] forward(::Tracker.var"#20#22"{Turing.Core.var"#f#8"{DynamicPPL.VarInfo{NamedTuple{(:coefs,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:coefs},Int64},Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{DynamicPPL.VarName{:coefs},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64},DynamicPPL.Model{var"##inner_function#429#11",NamedTuple{(:data, :params),Tuple{Data{Array{Int64,2},Array{Float64,2},Array{Int64,2},Array{Float64,2}},WeightedKLDParams{Float64,Array{Float64,1},Diagonal{Float64,Array{Float64,1}}}}},DynamicPPL.ModelGen{(:data, :params),var"###weighted_logistic_regression#437",NamedTuple{(),Tuple{}}},Val{()}},DynamicPPL.Sampler{NUTS{Turing.Core.TrackerAD,(),AdvancedHMC.Adaptation.DiagEuclideanMetric},Turing.Inference.SamplerState{DynamicPPL.VarInfo{NamedTuple{(:coefs,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:coefs},Int64},Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{DynamicPPL.VarName{:coefs},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}}}}}, ::Tracker.Params) at /Users/harrisonwilde/.julia/packages/Tracker/cpxco/src/back.jl:135
 [22] forward(::Function, ::Array{Float64,1}) at /Users/harrisonwilde/.julia/packages/Tracker/cpxco/src/back.jl:148
 [23] gradient_logp_reverse(::Array{Float64,1}, ::DynamicPPL.VarInfo{NamedTuple{(:coefs,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:coefs},Int64},Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{DynamicPPL.VarName{:coefs},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}, ::DynamicPPL.Model{var"##inner_function#429#11",NamedTuple{(:data, :params),Tuple{Data{Array{Int64,2},Array{Float64,2},Array{Int64,2},Array{Float64,2}},WeightedKLDParams{Float64,Array{Float64,1},Diagonal{Float64,Array{Float64,1}}}}},DynamicPPL.ModelGen{(:data, :params),var"###weighted_logistic_regression#437",NamedTuple{(),Tuple{}}},Val{()}}, ::DynamicPPL.Sampler{NUTS{Turing.Core.TrackerAD,(),AdvancedHMC.Adaptation.DiagEuclideanMetric},Turing.Inference.SamplerState{DynamicPPL.VarInfo{NamedTuple{(:coefs,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:coefs},Int64},Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{DynamicPPL.VarName{:coefs},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}}}) at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/core/ad.jl:140
 [24] gradient_logp at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/core/ad.jl:73 [inlined]
 [25] ∂logπ∂θ at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/hmc.jl:401 [inlined]
 [26] ∂H∂θ at /Users/harrisonwilde/.julia/packages/AdvancedHMC/haUrH/src/hamiltonian.jl:28 [inlined]
 [27] phasepoint at /Users/harrisonwilde/.julia/packages/AdvancedHMC/haUrH/src/hamiltonian.jl:59 [inlined]
 [28] #find_good_eps#6(::Int64, ::typeof(AdvancedHMC.find_good_eps), ::Random._GLOBAL_RNG, ::AdvancedHMC.Hamiltonian{AdvancedHMC.Adaptation.DiagEuclideanMetric{Float64,Array{Float64,1}},Turing.Inference.var"#logπ#49"{DynamicPPL.VarInfo{NamedTuple{(:coefs,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:coefs},Int64},Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{DynamicPPL.VarName{:coefs},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64},DynamicPPL.Sampler{NUTS{Turing.Core.TrackerAD,(),AdvancedHMC.Adaptation.DiagEuclideanMetric},Turing.Inference.SamplerState{DynamicPPL.VarInfo{NamedTuple{(:coefs,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:coefs},Int64},Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{DynamicPPL.VarName{:coefs},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}}},DynamicPPL.Model{var"##inner_function#429#11",NamedTuple{(:data, :params),Tuple{Data{Array{Int64,2},Array{Float64,2},Array{Int64,2},Array{Float64,2}},WeightedKLDParams{Float64,Array{Float64,1},Diagonal{Float64,Array{Float64,1}}}}},DynamicPPL.ModelGen{(:data, 
:params),var"###weighted_logistic_regression#437",NamedTuple{(),Tuple{}}},Val{()}}},Turing.Inference.var"#∂logπ∂θ#48"{DynamicPPL.VarInfo{NamedTuple{(:coefs,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:coefs},Int64},Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{DynamicPPL.VarName{:coefs},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64},DynamicPPL.Sampler{NUTS{Turing.Core.TrackerAD,(),AdvancedHMC.Adaptation.DiagEuclideanMetric},Turing.Inference.SamplerState{DynamicPPL.VarInfo{NamedTuple{(:coefs,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:coefs},Int64},Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{DynamicPPL.VarName{:coefs},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}}},DynamicPPL.Model{var"##inner_function#429#11",NamedTuple{(:data, :params),Tuple{Data{Array{Int64,2},Array{Float64,2},Array{Int64,2},Array{Float64,2}},WeightedKLDParams{Float64,Array{Float64,1},Diagonal{Float64,Array{Float64,1}}}}},DynamicPPL.ModelGen{(:data, :params),var"###weighted_logistic_regression#437",NamedTuple{(),Tuple{}}},Val{()}}}}, ::Array{Float64,1}) at /Users/harrisonwilde/.julia/packages/AdvancedHMC/haUrH/src/trajectory.jl:608
 [29] #find_good_eps at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/hmc.jl:0 [inlined]
 [30] #find_good_eps#7 at /Users/harrisonwilde/.julia/packages/AdvancedHMC/haUrH/src/trajectory.jl:668 [inlined]
 [31] find_good_eps at /Users/harrisonwilde/.julia/packages/AdvancedHMC/haUrH/src/trajectory.jl:668 [inlined]
 [32] #HMCState#52(::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::Type{Turing.Inference.HMCState}, ::DynamicPPL.Model{var"##inner_function#429#11",NamedTuple{(:data, :params),Tuple{Data{Array{Int64,2},Array{Float64,2},Array{Int64,2},Array{Float64,2}},WeightedKLDParams{Float64,Array{Float64,1},Diagonal{Float64,Array{Float64,1}}}}},DynamicPPL.ModelGen{(:data, :params),var"###weighted_logistic_regression#437",NamedTuple{(),Tuple{}}},Val{()}}, ::DynamicPPL.Sampler{NUTS{Turing.Core.TrackerAD,(),AdvancedHMC.Adaptation.DiagEuclideanMetric},Turing.Inference.SamplerState{DynamicPPL.VarInfo{NamedTuple{(:coefs,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:coefs},Int64},Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{DynamicPPL.VarName{:coefs},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}}}, ::Random._GLOBAL_RNG) at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/hmc.jl:552
 [33] Turing.Inference.HMCState(::DynamicPPL.Model{var"##inner_function#429#11",NamedTuple{(:data, :params),Tuple{Data{Array{Int64,2},Array{Float64,2},Array{Int64,2},Array{Float64,2}},WeightedKLDParams{Float64,Array{Float64,1},Diagonal{Float64,Array{Float64,1}}}}},DynamicPPL.ModelGen{(:data, :params),var"###weighted_logistic_regression#437",NamedTuple{(),Tuple{}}},Val{()}}, ::DynamicPPL.Sampler{NUTS{Turing.Core.TrackerAD,(),AdvancedHMC.Adaptation.DiagEuclideanMetric},Turing.Inference.SamplerState{DynamicPPL.VarInfo{NamedTuple{(:coefs,),Tuple{DynamicPPL.Metadata{Dict{DynamicPPL.VarName{:coefs},Int64},Array{MvNormal{Float64,PDMats.PDiagMat{Float64,Array{Float64,1}},Array{Float64,1}},1},Array{DynamicPPL.VarName{:coefs},1},Array{Float64,1},Array{Set{DynamicPPL.Selector},1}}}},Float64}}}, ::Random._GLOBAL_RNG) at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/hmc.jl:533
 [34] DynamicPPL.Sampler(::NUTS{Turing.Core.TrackerAD,(),AdvancedHMC.Adaptation.DiagEuclideanMetric}, ::DynamicPPL.Model{var"##inner_function#429#11",NamedTuple{(:data, :params),Tuple{Data{Array{Int64,2},Array{Float64,2},Array{Int64,2},Array{Float64,2}},WeightedKLDParams{Float64,Array{Float64,1},Diagonal{Float64,Array{Float64,1}}}}},DynamicPPL.ModelGen{(:data, :params),var"###weighted_logistic_regression#437",NamedTuple{(),Tuple{}}},Val{()}}, ::DynamicPPL.Selector) at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/hmc.jl:310
 [35] Sampler at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/hmc.jl:302 [inlined]
 [36] #sample#2(::Type, ::Nothing, ::Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}, ::typeof(sample), ::Random._GLOBAL_RNG, ::DynamicPPL.Model{var"##inner_function#429#11",NamedTuple{(:data, :params),Tuple{Data{Array{Int64,2},Array{Float64,2},Array{Int64,2},Array{Float64,2}},WeightedKLDParams{Float64,Array{Float64,1},Diagonal{Float64,Array{Float64,1}}}}},DynamicPPL.ModelGen{(:data, :params),var"###weighted_logistic_regression#437",NamedTuple{(),Tuple{}}},Val{()}}, ::NUTS{Turing.Core.TrackerAD,(),AdvancedHMC.Adaptation.DiagEuclideanMetric}, ::Int64) at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/Inference.jl:149
 [37] sample at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/Inference.jl:148 [inlined]
 [38] #sample#1 at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/Inference.jl:136 [inlined]
 [39] sample(::DynamicPPL.Model{var"##inner_function#429#11",NamedTuple{(:data, :params),Tuple{Data{Array{Int64,2},Array{Float64,2},Array{Int64,2},Array{Float64,2}},WeightedKLDParams{Float64,Array{Float64,1},Diagonal{Float64,Array{Float64,1}}}}},DynamicPPL.ModelGen{(:data, :params),var"###weighted_logistic_regression#437",NamedTuple{(),Tuple{}}},Val{()}}, ::NUTS{Turing.Core.TrackerAD,(),AdvancedHMC.Adaptation.DiagEuclideanMetric}, ::Int64) at /Users/harrisonwilde/.julia/packages/Turing/azHIm/src/inference/Inference.jl:136
 [40] top-level scope at ./REPL[35]:14

from this model and code:

# Weighted Bayesian logistic regression: as `logistic_regression`, but the
# prior means `μs` and prior covariance `σs` are supplied explicitly via
# `params`, alongside the synthetic-data tempering weight `w`.
@model weighted_logistic_regression(data, params) = begin

	@unpack y_real, X_real, y_synth, X_synth = data
	@unpack w, μs, σs = params

    # Multivariate normal prior with user-supplied means and covariance.
    coefs ~ MvNormal(μs, σs)
    # Real-data likelihood: Bernoulli on the logit scale (Binomial, n = 1).
    @logpdf() += sum(logpdf.(BinomialLogit.(1, X_real * coefs), y_real))
    # Synthetic-data likelihood, scaled by the tempering weight w.
    @logpdf() += w * sum(logpdf.(BinomialLogit.(1, X_synth * coefs), y_synth))
end

...

# Row counts of each split.
len_real = size(real_train, 1)
len_synth = size(synth_train, 1)
len_test = size(real_test, 1)

Turing.setadbackend(:reverse_diff)
num_chains = 3  # kept for multi-chain runs (see mapreduce/chainscat pattern)
real_αs = [0.1, 0.25, 0.5]
synth_αs = [0.0, 0.05, 0.1, 0.25, 0.5]

# Number of predictor columns: every column that is not a label column.
n_coefs = size(real_train, 2) - size(labels, 1)

# Prior parameters for a given synthetic-data weight `w`: unit prior means
# and a diagonal prior covariance of 2.0 per coefficient. Factored out to
# avoid duplicating the construction for the weighted and naive runs.
make_params(w) = WeightedKLDParams(w, ones(n_coefs), Diagonal(fill(2.0, n_coefs)))

for (real_α, synth_α) in get_conditional_pairs(real_αs, synth_αs)
	# Label columns are converted to Int matrices; predictors stay Float64.
	input_data = Data(
		Int.(Matrix(real_train[1:floor(Int32, len_real * real_α), labels])), 
		Matrix(real_train[1:floor(Int32, len_real * real_α), Not(labels)]), 
		Int.(Matrix(synth_train[1:floor(Int32, len_synth * synth_α), labels])), 
		Matrix(synth_train[1:floor(Int32, len_synth * synth_α), Not(labels)])
	)
	# w = 0.5 tempers the synthetic likelihood ("weighted" run); w = 1.0
	# treats synthetic rows exactly like real ones ("naive" run). The
	# weighted chain is sampled and written first, as before.
	for (w, prefix) in ((0.5, "weighted"), (1.0, "naive"))
		chain = sample(weighted_logistic_regression(input_data, make_params(w)), NUTS(500, 0.651), 5000)
		write(prefix * "_chains_real" * string(real_α) * "_synth" * string(synth_α), chain)
	end
end

Any ideas on how to get reverse diff working / is there a clear bit of doc anywhere to describe how to use it / what methods are supported?