Inspired by the recent long discussion on Slack around the performance of Turing, I’ve decided to try to beat my go-to implementation in Numpyro (and failed).

I’d appreciate any tips or tricks on how to speed it up further - see the Findings section below.

Also, there were several surprises that I cannot explain - see section Surprises.

**Scenario**: logistic regression with regularized horseshoe prior (link) with 20000 observations, where only a few predictors are relevant (3 of 23)

**Process**:

- Write several different parameterizations in Turing and a comparison one in Numpyro (called with PythonCall from Julia)
- Benchmark gradient evaluation (with the fantastic TuringBenchmarking.jl script from Torfjelde)
- Run a comparable single chain (500 adaptation steps, 65% target accept prob., max tree depth of 8, 100 samples)

**Findings**:

- Numpyro baseline is **35 seconds** (à la logreg_rhs2) (ESS is comparable to Turing's)
- Fastest Turing implementation sampled in **266 seconds** (logreg_rhs3, alternative parametrization, ReverseDiff + tape compile, but also an oddly high `eps` step size)
- Logprob call seems to dominate the evaluation time and the default is oddly slow
  - Out-of-the-box version takes 1.25ms, custom implementation 0.3ms, Numpyro (via JAX) version takes 0.2ms

**Surprises**

- Alternative parametrization (logreg_rhs3) creates more variables (70ish), so ForwardDiff gets slower but ReverseDiff gets faster (easier to sample?)
- The fastest Turing implementation with ReverseDiff picked a weirdly high step size (eps=0.8) compared to ForwardDiff (same model); it happens on all subsequent runs
- Adding the intercept as a separate variable (`alpha` in model logreg_rhs3b) leads to a 10x slowdown for ReverseDiff (1.9ms to 19ms)

**Tested optimizations:**

- Turing performance tips: vectorized as much as possible, with and without LazyArray (someone mentioned some possible regression problems?)
- Other Turing tips: switch to @addlogprob, turning on/off logging, Dense and Diag metric, initialization with Pathfinder.jl
- Different AD backends (with the help of TuringBenchmarking.jl)
- Alternative parameterizations of RHS (variants C1 and C2 in the paper)
- Julia standard tricks: type stability of the `@evaluate!!` call and also of the `cond_model.f()` call, reduced some allocations (not sure what else would be worth it without loss of readability)
- Alternative inference packages: tried MuseInference.jl and BouncyParticle in ZigZagBoomerang.jl but couldn’t get either to work
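For readers unfamiliar with the alternative parameterization: logreg_rhs3 below replaces each half-Cauchy draw with a half-normal draw multiplied by the square root of an InverseGamma(0.5, 0.5) draw. The following is only a rough sketch in plain Python (standard library only, independent of Turing and Numpyro) that checks by simulation that the two constructions agree. The function names and the seed are made up for illustration.

```python
import math
import random
import statistics

def half_cauchy_direct(rng):
    """Draw |Cauchy(0, 1)| through the inverse CDF, F^{-1}(u) = tan(pi * u / 2)."""
    return math.tan(math.pi * rng.random() / 2.0)

def half_cauchy_decomposed(rng):
    """Draw the same variable as half-normal * sqrt(InverseGamma(0.5, 0.5))."""
    z = abs(rng.gauss(0.0, 1.0))             # half-normal part
    gamma_draw = rng.gammavariate(0.5, 2.0)  # Gamma(shape=0.5, scale=2)
    inv_gamma_draw = 1.0 / gamma_draw        # InverseGamma(shape=0.5, scale=0.5)
    return z * math.sqrt(inv_gamma_draw)

def sample_median(draw, n, seed):
    """Median of n draws from `draw`, using a dedicated seeded generator."""
    rng = random.Random(seed)
    return statistics.median(draw(rng) for _ in range(n))

# The median of a standard half-Cauchy is exactly 1, so both should be close to it.
print("direct:    ", sample_median(half_cauchy_direct, 20_000, seed=0))
print("decomposed:", sample_median(half_cauchy_decomposed, 20_000, seed=1))
```

The median is used for the comparison because the half-Cauchy has no finite mean, so sample averages would not settle down.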

Some of the results and the MWE are attached below.

**Model benchmarking: Fastest logreg_rhs3 with ReverseDiff 266 seconds**

(sampling calls are done with ForwardDiff unless stated otherwise)

```
## Define models
# Vanilla logistic regression for comparisons
@model function logreg_vanilla(X)
    dim_X=size(X,2)
    betas ~ filldist(Normal(),dim_X)
    y ~ arraydist(LazyArray(@~ BernoulliLogit.(X*betas)))
    return X*betas
end;
cond_model=logreg_vanilla(X) | (;y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "linked" => 5-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "ReverseDiffAD{false}()" => Trial(167.703 ms)
# "ForwardDiffAD{100, true}()" => Trial(5.079 ms)
# "evaluation" => Trial(1.334 ms)
# "ForwardDiffAD{40, true}()" => Trial(5.077 ms)
# "ReverseDiffAD{true}()" => Trial(52.727 ms)
# "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "ReverseDiffAD{false}()" => Trial(162.366 ms)
# "ForwardDiffAD{100, true}()" => Trial(5.082 ms)
# "evaluation" => Trial(1.354 ms)
# "ForwardDiffAD{40, true}()" => Trial(5.113 ms)
# "ReverseDiffAD{true}()" => Trial(63.573 ms)
@time chain_v=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=true);
# ┌ Info: Found initial step size
# └ ϵ = 0.025
# Sampling 100%|███████████████████████████████████████████████████████████████████████| Time: 0:01:09
# 70.540812 seconds (7.91 M allocations: 32.934 GiB, 2.61% gc time, 4.52% compilation time: 96% of which was recompilation)
# Logistic regression + Regularized horseshoe Prior
# Based on: https://arxiv.org/pdf/1707.01694.pdf#page11 (Appendix C1)
@model function logreg_rhs(X)
    slab_df=4
    slab_scale=2
    eff_params=3
    dim_X=size(X,2)
    tau_0=eff_params/(dim_X-eff_params)/sqrt(size(X,1))
    lambdas ~ filldist(truncated(Cauchy();lower=0),dim_X)
    tau ~ truncated(Cauchy(0,tau_0);lower=0)
    z ~ filldist(Normal(),dim_X)
    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq=slab_scale^2 * c_aux # squared already
    lambdas_tilde=sqrt.((c_sq .* lambdas.^2) ./ (c_sq .+ tau.^2 .* lambdas.^2))
    betas=lambdas_tilde .* tau .* z
    y ~ arraydist(LazyArray(@~ BernoulliLogit.(X*betas)))
    return X*betas
end;
cond_model=logreg_rhs(X) | (;y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "linked" => 5-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "ReverseDiffAD{false}()" => Trial(169.728 ms)
# "ForwardDiffAD{100, true}()" => Trial(22.894 ms)
# "evaluation" => Trial(1.230 ms)
# "ForwardDiffAD{40, true}()" => Trial(12.122 ms)
# "ReverseDiffAD{true}()" => Trial(62.633 ms)
# "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "ReverseDiffAD{false}()" => Trial(169.774 ms)
# "ForwardDiffAD{100, true}()" => Trial(23.266 ms)
# "evaluation" => Trial(1.241 ms)
# "ForwardDiffAD{40, true}()" => Trial(12.566 ms)
# "ReverseDiffAD{true}()" => Trial(49.232 ms)
@time chain_rhs1=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=true);
# ┌ Info: Found initial step size
# └ ϵ = 0.025
# Sampling 100%|███████████████████████████████████████████████████████████████████████| Time: 0:58:42
# 3532.049475 seconds (69.87 M allocations: 4.075 TiB, 6.48% gc time, 0.50% compilation time)
# same model as logreg_rhs but changed logprob. accumulation for observed data
@model function logreg_rhs2(X,y)
    slab_df=4
    slab_scale=2
    eff_params=3
    dim_X=size(X,2)
    tau_0=eff_params/(dim_X-eff_params)/sqrt(size(X,1))
    lambdas ~ filldist(truncated(Cauchy();lower=0),dim_X)
    tau ~ truncated(Cauchy(0,tau_0);lower=0)
    z ~ filldist(Normal(),dim_X)
    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq=slab_scale^2 * c_aux # squared already
    lambdas_tilde=sqrt.((c_sq .* lambdas.^2) ./ (c_sq .+ tau.^2 .* lambdas.^2))
    betas=lambdas_tilde .* tau .* z
    # push logprob directly
    Turing.@addlogprob! sum(logpdf.(BernoulliLogit.(X*betas),y))
    return X*betas
end;
cond_model=logreg_rhs2(X,y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "linked" => 5-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "ReverseDiffAD{false}()" => Trial(4.037 ms)
# "ForwardDiffAD{100, true}()" => Trial(24.744 ms)
# "evaluation" => Trial(1.266 ms)
# "ForwardDiffAD{40, true}()" => Trial(14.086 ms)
# "ReverseDiffAD{true}()" => Trial(3.124 ms)
# "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "ReverseDiffAD{false}()" => Trial(4.111 ms)
# "ForwardDiffAD{100, true}()" => Trial(24.957 ms)
# "evaluation" => Trial(1.267 ms)
# "ForwardDiffAD{40, true}()" => Trial(14.871 ms)
# "ReverseDiffAD{true}()" => Trial(3.120 ms)
@time chain_rhs2=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=true);
# ┌ Info: Found initial step size
# └ ϵ = 0.0125
# Sampling 100%|███████████████████████████████████████████████████████████████████████| Time: 0:37:35
# 2261.932322 seconds (41.61 M allocations: 2.624 TiB, 6.30% gc time, 0.52% compilation time: 40% of which was recompilation)
# Logistic regression + Regularized horseshoe Prior
# Based on: https://arxiv.org/pdf/1707.01694.pdf#page11 (Appendix C2)
@model function logreg_rhs3(X,y)
    slab_df=4
    slab_scale=2
    eff_params=3
    dim_X=size(X,2)
    tau_0=eff_params/(dim_X-eff_params)/sqrt(size(X,1))
    # re-parametrize lambdas and square them directly to plug into lambdas_tilde
    lambdas_scale ~ filldist(InverseGamma(0.5,0.5),dim_X)
    lambdas_z ~ filldist(truncated(Normal();lower=0),dim_X)
    lambdas_sq= lambdas_z.^2 .* lambdas_scale
    # re-parametrize tau and square it directly to plug into lambdas_tilde
    tau_scale ~ InverseGamma(0.5,0.5)
    tau_z ~ truncated(Normal();lower=0)
    tau_sq = tau_z^2 .*tau_scale .* tau_0^2
    z ~ filldist(Normal(),dim_X)
    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq=slab_scale^2 * c_aux # squared already
    lambdas_tilde=sqrt.((c_sq .* lambdas_sq) ./ (c_sq .+ tau_sq .* lambdas_sq))
    betas=lambdas_tilde .* sqrt(tau_sq) .* z
    # push logprob directly
    Turing.@addlogprob! sum(logpdf.(BernoulliLogit.(X*betas),y))
    return X*betas
end;
cond_model=logreg_rhs3(X,y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "linked" => 5-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "ReverseDiffAD{false}()" => Trial(2.559 ms)
# "ForwardDiffAD{100, true}()" => Trial(40.301 ms)
# "evaluation" => Trial(1.269 ms)
# "ForwardDiffAD{40, true}()" => Trial(39.164 ms)
# "ReverseDiffAD{true}()" => Trial(1.901 ms)
# "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "ReverseDiffAD{false}()" => Trial(2.658 ms)
# "ForwardDiffAD{100, true}()" => Trial(40.992 ms)
# "evaluation" => Trial(1.269 ms)
# "ForwardDiffAD{40, true}()" => Trial(39.619 ms)
# "ReverseDiffAD{true}()" => Trial(1.885 ms)
@time chain_rhs3=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=true);
# ┌ Info: Found initial step size
# └ ϵ = 0.025
# Sampling 100%|███████████████████████████████████████████████████████████████████████| Time: 0:58:42
# 3532.049475 seconds (69.87 M allocations: 4.075 TiB, 6.48% gc time, 0.50% compilation time)
# Given that ReverseDiff was so fast, I sampled also with ReverseDiff
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
@time chain_rhs3alt=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=true);
# ┌ Info: Found initial step size
# └ ϵ = 0.8
# Sampling 100%|███████████████████████████████████████████████████████████████████████| Time: 0:04:21
# 266.116006 seconds (185.96 M allocations: 8.047 GiB, 0.35% gc time, 3.71% compilation time)
# Last iteration - what if intercept is a separate random variable (alpha)
@model function logreg_rhs3b(X,y)
    slab_df=4
    slab_scale=2
    eff_params=3
    dim_X=size(X,2)
    tau_0=eff_params/(dim_X-eff_params)/sqrt(size(X,1))
    # re-parametrize lambdas and square them directly to plug into lambdas_tilde
    lambdas_scale ~ filldist(InverseGamma(0.5,0.5),dim_X)
    lambdas_z ~ filldist(truncated(Normal();lower=0),dim_X)
    lambdas_sq= lambdas_z.^2 .* lambdas_scale
    # re-parametrize tau and square it directly to plug into lambdas_tilde
    tau_scale ~ InverseGamma(0.5,0.5)
    tau_z ~ truncated(Normal();lower=0)
    tau_sq = tau_z^2 .*tau_scale .* tau_0^2
    z ~ filldist(Normal(),dim_X)
    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq=slab_scale^2 * c_aux # squared already
    lambdas_tilde=sqrt.((c_sq .* lambdas_sq) ./ (c_sq .+ tau_sq .* lambdas_sq))
    betas=lambdas_tilde .* sqrt(tau_sq) .* z
    # define intercept as a separate random variable
    alpha ~ Normal(0,1)
    # push logprob directly
    Turing.@addlogprob! sum(logpdf.(BernoulliLogit.(alpha.+X*betas),y))
    return X*betas
end;
cond_model=logreg_rhs3b(X[:,2:end],y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "linked" => 5-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "ReverseDiffAD{false}()" => Trial(20.004 ms)
# "ForwardDiffAD{100, true}()" => Trial(37.831 ms)
# "evaluation" => Trial(1.272 ms)
# "ForwardDiffAD{40, true}()" => Trial(35.002 ms)
# "ReverseDiffAD{true}()" => Trial(19.039 ms)
# "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "ReverseDiffAD{false}()" => Trial(19.219 ms)
# "ForwardDiffAD{100, true}()" => Trial(37.483 ms)
# "evaluation" => Trial(1.273 ms)
# "ForwardDiffAD{40, true}()" => Trial(35.279 ms)
# "ReverseDiffAD{true}()" => Trial(19.077 ms)
```
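To make the prior in the models above a bit more concrete, here is a small sketch in plain Python (standard library only) of the two quantities every `logreg_rhs*` variant computes: the global scale `tau_0` and the regularized local scale `lambdas_tilde`. The helper name `coefficient_scale` and the example inputs are illustrative assumptions; the formulas are copied from the Julia code.

```python
import math

def coefficient_scale(lam, tau, c_sq):
    """Prior scale of one coefficient: tau * lambda_tilde, as in the models above."""
    lam_tilde = math.sqrt((c_sq * lam**2) / (c_sq + tau**2 * lam**2))
    return tau * lam_tilde

# Same constants as the post: 20000 observations, 23 columns, 3 expected effects.
n_obs, dim_x, eff_params = 20_000, 23, 3
tau_0 = eff_params / (dim_x - eff_params) / math.sqrt(n_obs)

# Slab with slab_scale = 2 and c_aux fixed at 1 for illustration, so c_sq = 4.
c_sq = 2.0**2 * 1.0

print("tau_0:", tau_0)
# A small local scale leaves the plain horseshoe scale tau * lambda almost untouched ...
print("lambda = 1:   ", coefficient_scale(1.0, tau_0, c_sq))
# ... while a huge local scale is capped near sqrt(c_sq), i.e. the slab scale.
print("lambda = 1e12:", coefficient_scale(1e12, tau_0, c_sq))
```

The second print shows the regularization at work: however large `lambda` gets, the coefficient scale never exceeds `sqrt(c_sq)`.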

**Numpyro model: 35 seconds**

```
# Numpyro model
using CondaPkg
CondaPkg.add("numpyro")
using PythonCall
jnp=pyimport("jax.numpy")
# Define model and do a test run
(mcmc,posterior_samples)=@pyexec (X=jnp.array(X),y=jnp.array(y)) => """
global numpyro,dist,jnp,random,MCMC,NUTS,Predictive;
global logreg_rhs2py;
import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp
import jax.random as random
from numpyro.infer import MCMC, NUTS, Predictive
import time
def logreg_rhs2py(X,y=None):
    dim_X = X.shape[1]
    slab_df=4
    slab_scale=2
    eff_params=3
    tau_0=eff_params/(dim_X-eff_params)/jnp.sqrt(X.shape[0])
    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(dim_X)))
    # NOTE: tau_0 is computed above but not used as the scale here,
    # whereas the Turing models use truncated(Cauchy(0,tau_0);lower=0)
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))
    c_aux = numpyro.sample("c_aux", dist.InverseGamma(0.5*slab_df,0.5*slab_df))
    c_sq=slab_scale**2 * c_aux
    lambdas_tilde=jnp.sqrt((c_sq * lambdas**2) / (c_sq + tau**2 * lambdas**2))
    z = numpyro.sample("z", dist.Normal(0.0, jnp.ones(dim_X)))
    betas = numpyro.deterministic("betas", tau * lambdas_tilde * z)
    logits = numpyro.deterministic("logits",jnp.dot(X, betas))
    y_obs=numpyro.sample("y", dist.Bernoulli(logits=logits), obs=y)
    return y_obs
# test correct shapes of implementation
with numpyro.handlers.seed(rng_seed=0):
    temp=logreg_rhs2py(X,y)
print("output shape: ",temp.shape)
print("output proba: ",temp.mean())
# Fit NUTS
rng_key = random.PRNGKey(1)
start = time.time()
kernel = NUTS(logreg_rhs2py,target_accept_prob=0.65,max_tree_depth=8)
mcmc = MCMC(kernel, num_warmup=500, num_samples=100,num_chains=1,jit_model_args=True)
mcmc.run(rng_key, X,y,extra_fields=("accept_prob",'num_steps','diverging'))
print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
print(f'Number of divergences: {mcmc.get_extra_fields()["diverging"].sum():.0f};')
# max_tree_depth=8 above, so a saturated tree has 2**8 - 1 steps
print(f'Number of transitions: {(mcmc.get_extra_fields()["num_steps"] + 1 == 2 ** 8).sum():.0f} that exceeded the maximum treedepth.')
mcmc.print_summary(exclude_deterministic=True)
# SAMPLES
posterior_samples = mcmc.get_samples()
print("MCMC elapsed time:", time.time() - start)
""" => (mcmc,posterior_samples)
# Number of divergences: 3
# MCMC elapsed time: 34.69472002983093
# After inspecting the chains (mcmc.print_summary()), it is comparable to Turing.jl results (there were also divergences -- chains are too short)
```

**Logprob benchmarking: Jax at 200μs, custom at 320μs, Turing default 1.25ms!**

```
# Benchmark logprob call itself in Julia
@btime sum(logpdf.(BernoulliLogit.($X*$betas),$y));
# 1.251 ms (4 allocations: 312.59 KiB)
# custom implementation in Julia
# from https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
jax_logpdf(x,y)=-(max(0,x) +log1pexp(-abs(x)) -x*y)
@btime sum(jax_logpdf.($X*$betas,$y));
# 324.208 μs (4 allocations: 312.59 KiB)
# logprob evaluation with Numpyro/Jax
@time @pyexec (X=jnp.array(X),betas=jnp.array(betas),y=jnp.array(y)) => """
import time
start = time.time()
ll=dist.Bernoulli(logits=jnp.dot(X,betas)).log_prob(y).sum()
print("Elapsed time (ms):", (time.time() - start)*1000)
print(ll)
""";
# Elapsed time (ms): 0.1988410949707031
# loglik value matches Julia
```
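The custom `jax_logpdf` above relies on a standard rewriting of the Bernoulli log-likelihood in terms of the logit `x`. As a rough sanity check, the sketch below (plain Python, standard library only; both function names are made up for illustration) compares it against the textbook form that goes through the sigmoid, and shows that the rewritten form stays finite for extreme logits.

```python
import math

def naive_bernoulli_logit_logpdf(x, y):
    """Textbook form: y*log(p) + (1-y)*log(1-p) with p = sigmoid(x)."""
    p = 1.0 / (1.0 + math.exp(-x))
    return y * math.log(p) + (1 - y) * math.log(1.0 - p)

def stable_bernoulli_logit_logpdf(x, y):
    """Same quantity as -(max(0, x) + log1p(exp(-|x|)) - x*y), mirroring jax_logpdf."""
    return -(max(0.0, x) + math.log1p(math.exp(-abs(x))) - x * y)

# For moderate logits the two forms agree to floating-point precision.
for x in (-3.0, -0.5, 0.0, 0.7, 4.0):
    for y in (0, 1):
        diff = naive_bernoulli_logit_logpdf(x, y) - stable_bernoulli_logit_logpdf(x, y)
        assert abs(diff) < 1e-12

# For extreme logits the naive form breaks down (exp overflows or log(0) appears),
# while the rewritten form only ever exponentiates a non-positive number.
print(stable_bernoulli_logit_logpdf(800.0, 0))
print(stable_bernoulli_logit_logpdf(-800.0, 1))
```

This only checks numerical equivalence and says nothing about why the default `logpdf.(BernoulliLogit.(...), y)` path is slower in the benchmark above.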

**Setup script and data generation**

```
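using DataFramesMeta
using Random
using Turing
using BenchmarkTools # provides BenchmarkGroup, @benchmarkable and @btime
using LogExpFunctions: logit, logistic, log1pexp # log1pexp is needed by the custom logpdf
using LazyArrays
using LogDensityProblems
import ReverseDiff
import ForwardDiff
using Turing.Essential: ForwardDiffAD, TrackerAD, ReverseDiffAD, ZygoteAD, CHUNKSIZE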
# Code below for benchmarking gradients copied from https://github.com/torfjelde/TuringBenchmarking.jl
const DEFAULT_ADBACKENDS = [
    ForwardDiffAD{40}(), # chunksize=40
    ForwardDiffAD{100}(), # chunksize=100
    ReverseDiffAD{false}(), # rdcache=false
    ReverseDiffAD{true}() # rdcache=true
]
# Code below for benchmarking gradients copied from https://github.com/torfjelde/TuringBenchmarking.jl
"""
    make_turing_suite(model; kwargs...)

Create default benchmark suite for `model`.

# Keyword arguments
- `adbackends`: a collection of adbackends to use. Defaults to `$(DEFAULT_ADBACKENDS)`.
- `run_once=true`: if `true`, the body of each benchmark will be run once to avoid
  compilation to be included in the timings (this may occur if compilation runs
  longer than the allowed time limit).
- `save_grads=false`: if `true` and `run_once` is `true`, the gradients from the initial
  execution will be saved and returned as the second return-value. This is useful if you
  want to check correctness of the gradients for different backends.

# Notes
- A separate "parameter" instance (`DynamicPPL.VarInfo`) will be created for _each test_.
  Hence if you have a particularly large model, you might want to only pass one `adbackend`
  at the time.
"""
function make_turing_suite(
    model::DynamicPPL.Model;
    adbackends = DEFAULT_ADBACKENDS, run_once = true, save_grads = false
)
    suite = BenchmarkGroup()
    suite["not_linked"] = BenchmarkGroup()
    suite["linked"] = BenchmarkGroup()
    grads = Dict(:not_linked => Dict(), :linked => Dict())
    vi_orig = DynamicPPL.VarInfo(model)
    spl = DynamicPPL.SampleFromPrior()
    for adbackend in adbackends
        vi = DynamicPPL.VarInfo(vi_orig, spl, vi_orig[spl])
        f = LogDensityProblems.ADgradient(
            adbackend,
            Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
        )
        θ = vi[spl]
        if run_once
            ℓ, ∇ℓ = LogDensityProblems.logdensity_and_gradient(f, θ)
            if save_grads
                grads[:not_linked][adbackend] = (ℓ, ∇ℓ)
            end
        end
        suite["not_linked"]["$(adbackend)"] = @benchmarkable $(LogDensityProblems.logdensity_and_gradient)($f, $θ)
        # Need a separate `VarInfo` for the linked version since otherwise we risk the
        # `vi` from above being mutated.
        vi_linked = deepcopy(vi)
        DynamicPPL.link!(vi_linked, spl)
        f_linked = LogDensityProblems.ADgradient(
            adbackend,
            Turing.LogDensityFunction(vi_linked, model, spl, DynamicPPL.DefaultContext())
        )
        θ_linked = vi_linked[spl]
        if run_once
            ℓ, ∇ℓ = LogDensityProblems.logdensity_and_gradient(f_linked, θ_linked)
            if save_grads
                grads[:linked][adbackend] = (ℓ, ∇ℓ)
            end
        end
        suite["linked"]["$(adbackend)"] = @benchmarkable $(LogDensityProblems.logdensity_and_gradient)($f_linked, $θ_linked)
    end
    # Also benchmark just standard model evaluation because why not.
    suite["not_linked"]["evaluation"] = @benchmarkable $(DynamicPPL.evaluate!!)($model, $vi_orig, $(DynamicPPL.DefaultContext()))
    DynamicPPL.link!(vi_orig, spl)
    suite["linked"]["evaluation"] = @benchmarkable $(DynamicPPL.evaluate!!)($model, $vi_orig, $(DynamicPPL.DefaultContext()))
    return save_grads ? (suite, grads) : suite
end
# Generate mock data
X=hcat(ones(20000),randn(20000,22))
betas=vcat([-0.8],zeros(3),[1],zeros(5),[1],zeros(12))
@assert length(betas)==size(X,2)
y=(X*betas) .|> x->rand(BernoulliLogit(x))
@info "Important coef positions: $(findall(betas.!=0)), Average rate: $(mean(y))"
```