Case study: Speeding up a logistic regression with RHS prior (Turing vs Numpyro) - any tricks I'm missing?

Inspired by the recent long discussion on Slack about Turing's performance, I tried to beat my go-to implementation in Numpyro (and failed).

I’d appreciate any tips or tricks to speed it up further; see the Findings section below.
There were also several surprises that I cannot explain; see the Surprises section.

Scenario: logistic regression with a regularized horseshoe (RHS) prior (https://arxiv.org/pdf/1707.01694.pdf) on 20,000 observations and 23 covariates, of which only a few (3/23) are relevant
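For readers less familiar with the RHS prior, here is a plain-Julia sketch of the deterministic part of the Turing models below: the global scale τ₀ computed from the expected number of relevant covariates, the regularized local scales λ̃, and the non-centered coefficients β = λ̃·τ·z. The helper names `rhs_tau0`, `rhs_lambda_tilde` and `rhs_betas` are hypothetical; the code mirrors the model code in this post (including the absence of a noise-scale factor in τ₀), not the paper's full notation.

```julia
# Plain-Julia sketch of the deterministic part of the RHS models below
# (hypothetical helper names; mirrors the Turing code, not a library API).

"Global scale tau_0 as computed inside the models (no noise-scale factor)."
rhs_tau0(eff_params, dim_X, n_obs) = eff_params / (dim_X - eff_params) / sqrt(n_obs)

"Regularized local scales: lambda_tilde^2 = c^2 lambda^2 / (c^2 + tau^2 lambda^2)."
rhs_lambda_tilde(lambdas, tau, c_sq) =
    sqrt.((c_sq .* lambdas .^ 2) ./ (c_sq .+ tau^2 .* lambdas .^ 2))

"Non-centered coefficients beta = lambda_tilde * tau * z, with z standard normal."
rhs_betas(z, lambdas, tau, c_sq) = rhs_lambda_tilde(lambdas, tau, c_sq) .* tau .* z

tau0 = rhs_tau0(3, 23, 20_000)            # settings used in the post
println("tau_0 = ", tau0)

# Small local scale: lambda_tilde ≈ lambda (plain horseshoe behaviour)
lt_small = rhs_lambda_tilde([1e-3], 0.1, 4.0)
# Huge local scale: lambda_tilde saturates at c / tau = 2 / 0.1 = 20 (slab behaviour)
lt_large = rhs_lambda_tilde([1e6], 0.1, 4.0)
```

The two limits show what the slab regularization does: for small τλ the prior behaves like the plain horseshoe, while very large local scales are capped at c/τ, so β ≈ c·z.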

Process:

  • Write several different parameterizations in Turing and a comparison model in Numpyro (called from Julia via PythonCall)
  • Benchmark gradient evaluation (with the fantastic TuringBenchmarking.jl script by Torfjelde)
  • Run a comparable single chain (500 adaptation steps, target acceptance rate 0.65, max tree depth 8, 100 samples)

Findings:

  • Numpyro baseline: 35 seconds (model equivalent to logreg_rhs2; ESS comparable to Turing)
  • Fastest Turing implementation: 266 seconds (logreg_rhs3, alternative parametrization, ReverseDiff with tape compilation, but also an oddly large step size ε)
  • The logprob call seems to dominate evaluation time, and the default is surprisingly slow: the out-of-the-box version takes 1.25 ms, a custom implementation 0.3 ms, and the Numpyro (JAX) version 0.2 ms

Surprises

  • The alternative parametrization (logreg_rhs3) has more parameters (~70), so ForwardDiff gets slower but ReverseDiff gets faster (is it easier to sample?)
  • The fastest Turing implementation with ReverseDiff picks a surprisingly large step size (ε = 0.8) compared to ForwardDiff on the same model; this happens on every subsequent run
  • Adding the intercept as a separate variable (alpha in model logreg_rhs3b) leads to a 10x slowdown for ReverseDiff (1.9 ms to 19 ms)
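One possible workaround for the intercept slowdown, untested with ReverseDiff, is to keep a separate prior on alpha but fold it into a single matrix-vector product over a design matrix with a column of ones, built once outside the model. The sketch below only checks, in plain Julia, that both forms give the same logits; the variable names are illustrative. Note that the augmented matrix is just the original X, so inside a model this amounts to logreg_rhs3 with a different prior on the first coefficient, e.g. using `vcat(alpha, betas)`; whether that avoids the slowdown on a ReverseDiff tape would have to be benchmarked.

```julia
using Random

Random.seed!(1)
n, p = 1_000, 5
Xs = randn(n, p)               # covariates without the column of ones (as in X[:, 2:end])
alpha = -0.8                   # separate intercept, as in logreg_rhs3b
betas = randn(p)

# Form used in logreg_rhs3b: matrix-vector product, then broadcast the intercept
logits_separate = alpha .+ Xs * betas

# Hypothetical alternative: augmented design built once (outside the model),
# then a single product with the stacked coefficient vector
X1 = hcat(ones(n), Xs)
logits_folded = X1 * vcat(alpha, betas)

println("max abs difference: ", maximum(abs.(logits_separate .- logits_folded)))
```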

Tested optimizations:

  • Turing performance tips: vectorized as much as possible, with and without LazyArrays (someone mentioned possible performance regressions?)
  • Other Turing tips: switching to @addlogprob!, turning logging on/off, Dense vs. Diag metric, initialization with Pathfinder.jl
  • Different AD backends (with the help of TuringBenchmarking.jl)
  • Alternative parameterizations of the RHS prior (variants C1 and C2 in the paper)
  • Standard Julia tricks: type stability of the evaluate!! call and of the cond_model.f() call, fewer allocations (not sure what else would be worth it without sacrificing readability)
  • Alternative inference packages: tried MuseInference.jl and BouncyParticle in ZigZagBoomerang.jl but couldn’t get either to work
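Variant C2 rests on the scale-mixture identity that a half-normal variable times the square root of an InverseGamma(1/2, 1/2) variable is half-Cauchy; this is what lambdas_z and lambdas_scale encode in logreg_rhs3 below. A quick Monte Carlo check, assuming only Base plus the Random and Statistics stdlibs: InverseGamma(1/2, 1/2) draws are generated as 1/Z² for standard-normal Z, since Z² ~ χ²₁ = Gamma(shape 1/2, rate 1/2).

```julia
using Random
using Statistics: mean, median

Random.seed!(42)
N = 1_000_000

# Variant C2 building blocks, as in logreg_rhs3:
lambdas_z = abs.(randn(N))            # half-normal, like truncated(Normal(); lower=0)
lambdas_scale = 1 ./ randn(N) .^ 2    # 1/Z^2 with Z ~ N(0,1), i.e. InverseGamma(0.5, 0.5)
lambdas = lambdas_z .* sqrt.(lambdas_scale)

# Standard half-Cauchy: CDF(x) = (2/pi) * atan(x), so the median is 1
med = median(lambdas)
tail = mean(lambdas .> 10)            # should be close to 1 - (2/pi) * atan(10) ≈ 0.063
println("median = ", med, ", P(lambda > 10) = ", tail)
```

Both summaries should land on the standard half-Cauchy values, so C1 and C2 describe the same prior; only the parametrization the sampler sees differs.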

Some of the results and the MWE are attached below.

Model benchmarking: fastest is logreg_rhs3 with ReverseDiff (266 seconds)
(sampling calls use ForwardDiff unless stated otherwise)

## Define models
# Vanilla logistic regression for comparisons
@model function logreg_vanilla(X)
    dim_X=size(X,2)
    betas ~ filldist(Normal(),dim_X)
    y ~ arraydist(LazyArray(@~ BernoulliLogit.(X*betas)))
    return X*betas
end;
cond_model=logreg_vanilla(X) | (;y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(167.703 ms)
#           "ForwardDiffAD{100, true}()" => Trial(5.079 ms)
#           "evaluation" => Trial(1.334 ms)
#           "ForwardDiffAD{40, true}()" => Trial(5.077 ms)
#           "ReverseDiffAD{true}()" => Trial(52.727 ms)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(162.366 ms)
#           "ForwardDiffAD{100, true}()" => Trial(5.082 ms)
#           "evaluation" => Trial(1.354 ms)
#           "ForwardDiffAD{40, true}()" => Trial(5.113 ms)
#           "ReverseDiffAD{true}()" => Trial(63.573 ms)

@time chain_v=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=true);
# ┌ Info: Found initial step size
# └   ϵ = 0.025
# Sampling 100%|███████████████████████████████████████████████████████████████████████| Time: 0:01:09
#  70.540812 seconds (7.91 M allocations: 32.934 GiB, 2.61% gc time, 4.52% compilation time: 96% of which was recompilation)


# Logistic regression + Regularized horseshoe Prior
# Based on: https://arxiv.org/pdf/1707.01694.pdf#page11 (Appendix C1)
@model function logreg_rhs(X)
    slab_df=4
    slab_scale=2
    eff_params=3
    dim_X=size(X,2)
    tau_0=eff_params/(dim_X-eff_params)/sqrt(size(X,1))

    lambdas ~ filldist(truncated(Cauchy();lower=0),dim_X)
    tau ~ truncated(Cauchy(0,tau_0);lower=0)
    z ~ filldist(Normal(),dim_X)

    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq=slab_scale^2 * c_aux # squared already
    lambdas_tilde=sqrt.((c_sq .* lambdas.^2) ./ (c_sq .+ tau.^2 .* lambdas.^2))
    betas=lambdas_tilde .* tau .* z
    y ~ arraydist(LazyArray(@~ BernoulliLogit.(X*betas)))
    return X*betas
end;
cond_model=logreg_rhs(X) | (;y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(169.728 ms)
#           "ForwardDiffAD{100, true}()" => Trial(22.894 ms)
#           "evaluation" => Trial(1.230 ms)
#           "ForwardDiffAD{40, true}()" => Trial(12.122 ms)
#           "ReverseDiffAD{true}()" => Trial(62.633 ms)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(169.774 ms)
#           "ForwardDiffAD{100, true}()" => Trial(23.266 ms)
#           "evaluation" => Trial(1.241 ms)
#           "ForwardDiffAD{40, true}()" => Trial(12.566 ms)
#           "ReverseDiffAD{true}()" => Trial(49.232 ms)
@time chain_rhs1=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=true);
# ┌ Info: Found initial step size
# └   ϵ = 0.025
# Sampling 100%|███████████████████████████████████████████████████████████████████████| Time: 0:58:42
# 3532.049475 seconds (69.87 M allocations: 4.075 TiB, 6.48% gc time, 0.50% compilation time)


# same model as logreg_rhs, but the log-probability of the observed data is accumulated directly with @addlogprob!
@model function logreg_rhs2(X,y)
    slab_df=4
    slab_scale=2
    eff_params=3
    dim_X=size(X,2)
    tau_0=eff_params/(dim_X-eff_params)/sqrt(size(X,1))

    lambdas ~ filldist(truncated(Cauchy();lower=0),dim_X)
    tau ~ truncated(Cauchy(0,tau_0);lower=0)
    z ~ filldist(Normal(),dim_X)

    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq=slab_scale^2 * c_aux # squared already
    lambdas_tilde=sqrt.((c_sq .* lambdas.^2) ./ (c_sq .+ tau.^2 .* lambdas.^2))
    betas=lambdas_tilde .* tau .* z
    # push logprob directly
    Turing.@addlogprob! sum(logpdf.(BernoulliLogit.(X*betas),y))
    return X*betas
end;
cond_model=logreg_rhs2(X,y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(4.037 ms)
#           "ForwardDiffAD{100, true}()" => Trial(24.744 ms)
#           "evaluation" => Trial(1.266 ms)
#           "ForwardDiffAD{40, true}()" => Trial(14.086 ms)
#           "ReverseDiffAD{true}()" => Trial(3.124 ms)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(4.111 ms)
#           "ForwardDiffAD{100, true}()" => Trial(24.957 ms)
#           "evaluation" => Trial(1.267 ms)
#           "ForwardDiffAD{40, true}()" => Trial(14.871 ms)
#           "ReverseDiffAD{true}()" => Trial(3.120 ms)
@time chain_rhs2=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=true);
# ┌ Info: Found initial step size
# └   ϵ = 0.0125
# Sampling 100%|███████████████████████████████████████████████████████████████████████| Time: 0:37:35
# 2261.932322 seconds (41.61 M allocations: 2.624 TiB, 6.30% gc time, 0.52% compilation time: 40% of which was recompilation)



# Logistic regression + Regularized horseshoe Prior
# Based on: https://arxiv.org/pdf/1707.01694.pdf#page11 (Appendix C2)
@model function logreg_rhs3(X,y)
    slab_df=4
    slab_scale=2
    eff_params=3
    dim_X=size(X,2)
    tau_0=eff_params/(dim_X-eff_params)/sqrt(size(X,1))

    # re-parametrize lambdas and compute lambdas squared directly to plug into lambdas_tilde
    lambdas_scale ~ filldist(InverseGamma(0.5,0.5),dim_X)
    lambdas_z ~ filldist(truncated(Normal();lower=0),dim_X)
    lambdas_sq= lambdas_z.^2 .* lambdas_scale

    # re-parametrize tau and square it directly to plug into lambdas_tilde
    tau_scale ~ InverseGamma(0.5,0.5)
    tau_z ~ truncated(Normal();lower=0)
    tau_sq = tau_z^2 .*tau_scale .* tau_0^2
    z ~ filldist(Normal(),dim_X)

    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq=slab_scale^2 * c_aux # squared already
    lambdas_tilde=sqrt.((c_sq .* lambdas_sq) ./ (c_sq .+ tau_sq .* lambdas_sq))
    betas=lambdas_tilde .* sqrt(tau_sq) .* z

    # push logprob directly
    Turing.@addlogprob! sum(logpdf.(BernoulliLogit.(X*betas),y))
    return X*betas
end;
cond_model=logreg_rhs3(X,y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(2.559 ms)
#           "ForwardDiffAD{100, true}()" => Trial(40.301 ms)
#           "evaluation" => Trial(1.269 ms)
#           "ForwardDiffAD{40, true}()" => Trial(39.164 ms)
#           "ReverseDiffAD{true}()" => Trial(1.901 ms)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(2.658 ms)
#           "ForwardDiffAD{100, true}()" => Trial(40.992 ms)
#           "evaluation" => Trial(1.269 ms)
#           "ForwardDiffAD{40, true}()" => Trial(39.619 ms)
#           "ReverseDiffAD{true}()" => Trial(1.885 ms)
@time chain_rhs3=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=true);
# ┌ Info: Found initial step size
# └   ϵ = 0.025
# Sampling 100%|███████████████████████████████████████████████████████████████████████| Time: 0:58:42
# 3532.049475 seconds (69.87 M allocations: 4.075 TiB, 6.48% gc time, 0.50% compilation time)

# Given how fast ReverseDiff was in the gradient benchmark, I also sampled with ReverseDiff
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
@time chain_rhs3alt=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=true);
# ┌ Info: Found initial step size
# └   ϵ = 0.8
# Sampling 100%|███████████████████████████████████████████████████████████████████████| Time: 0:04:21
# 266.116006 seconds (185.96 M allocations: 8.047 GiB, 0.35% gc time, 3.71% compilation time)


# Last iteration: what if the intercept is a separate random variable (alpha)?
@model function logreg_rhs3b(X,y)
    slab_df=4
    slab_scale=2
    eff_params=3
    dim_X=size(X,2)
    tau_0=eff_params/(dim_X-eff_params)/sqrt(size(X,1))

    # re-parametrize lambdas and compute lambdas squared directly to plug into lambdas_tilde
    lambdas_scale ~ filldist(InverseGamma(0.5,0.5),dim_X)
    lambdas_z ~ filldist(truncated(Normal();lower=0),dim_X)
    lambdas_sq= lambdas_z.^2 .* lambdas_scale

    # re-parametrize tau and square it directly to plug into lambdas_tilde
    tau_scale ~ InverseGamma(0.5,0.5)
    tau_z ~ truncated(Normal();lower=0)
    tau_sq = tau_z^2 .*tau_scale .* tau_0^2
    z ~ filldist(Normal(),dim_X)

    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq=slab_scale^2 * c_aux # squared already
    lambdas_tilde=sqrt.((c_sq .* lambdas_sq) ./ (c_sq .+ tau_sq .* lambdas_sq))
    betas=lambdas_tilde .* sqrt(tau_sq) .* z

    # define intercept as a separate random variable
    alpha ~ Normal(0,1)
    # push logprob directly
    Turing.@addlogprob! sum(logpdf.(BernoulliLogit.(alpha.+X*betas),y))
    return X*betas
end;
cond_model=logreg_rhs3b(X[:,2:end],y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(20.004 ms)
#           "ForwardDiffAD{100, true}()" => Trial(37.831 ms)
#           "evaluation" => Trial(1.272 ms)
#           "ForwardDiffAD{40, true}()" => Trial(35.002 ms)
#           "ReverseDiffAD{true}()" => Trial(19.039 ms)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(19.219 ms)
#           "ForwardDiffAD{100, true}()" => Trial(37.483 ms)
#           "evaluation" => Trial(1.273 ms)
#           "ForwardDiffAD{40, true}()" => Trial(35.279 ms)
#           "ReverseDiffAD{true}()" => Trial(19.077 ms)

Numpyro model: 35 seconds

# Numpyro model
using CondaPkg
CondaPkg.add("numpyro")

using PythonCall
jnp=pyimport("jax.numpy")

# Define model and do a test run
(mcmc,posterior_samples)=@pyexec (X=jnp.array(X),y=jnp.array(y)) => """
global numpyro,dist,jnp,random,MCMC,NUTS,Predictive;
global logreg_rhs2py;

import numpyro
import numpyro.distributions as dist
import jax.numpy as jnp
import jax.random as random
from numpyro.infer import MCMC, NUTS, Predictive
import time

def logreg_rhs2py(X,y=None):
    dim_X = X.shape[1]
    slab_df=4
    slab_scale=2
    eff_params=3
    tau_0=eff_params/(dim_X-eff_params)/jnp.sqrt(X.shape[0])

    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(dim_X)))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))

    c_aux = numpyro.sample("c_aux", dist.InverseGamma(0.5*slab_df,0.5*slab_df))
    c_sq=slab_scale**2 * c_aux
    lambdas_tilde=jnp.sqrt((c_sq * lambdas**2) / (c_sq + tau**2 * lambdas**2))
    z = numpyro.sample("z", dist.Normal(0.0, jnp.ones(dim_X)))
    betas = numpyro.deterministic("betas", tau * lambdas_tilde * z)

    logits = numpyro.deterministic("logits",jnp.dot(X, betas))
    y_obs=numpyro.sample("y", dist.Bernoulli(logits=logits), obs=y)

    return y_obs

# test correct shapes of implementation
with numpyro.handlers.seed(rng_seed=0):
    temp=logreg_rhs2py(X,y)
    print("output shape: ",temp.shape)
    print("output proba: ",temp.mean())

# Fit NUTS
rng_key = random.PRNGKey(1)
start = time.time()
kernel = NUTS(logreg_rhs2py,target_accept_prob=0.65,max_tree_depth=8)
mcmc = MCMC(kernel, num_warmup=500, num_samples=100,num_chains=1,jit_model_args=True)
mcmc.run(rng_key, X,y,extra_fields=("accept_prob",'num_steps','diverging'))

print("Mean accept prob:", jnp.mean(mcmc.get_extra_fields()["accept_prob"]))
print(f'Number of divergences: {mcmc.get_extra_fields()["diverging"].sum():.0f};')
print(f'Number of transitions: {(mcmc.get_extra_fields()["num_steps"] + 1 == 2 ** 8).sum():.0f} that hit the maximum tree depth (max_tree_depth=8).')
mcmc.print_summary(exclude_deterministic=True)
# SAMPLES
posterior_samples = mcmc.get_samples()
print("MCMC elapsed time:", time.time() - start)
""" => (mcmc,posterior_samples)


# Number of divergences: 3
# MCMC elapsed time: 34.69472002983093
# After inspecting the chains (mcmc.print_summary()), the results are comparable to Turing.jl (which also had divergences; the chains are too short)

Logprob benchmarking: Jax at 200μs, custom at 320μs, Turing default 1.25ms!

# Benchmark logprob call itself in Julia
@btime sum(logpdf.(BernoulliLogit.($X*$betas),$y));
# 1.251 ms (4 allocations: 312.59 KiB)

# custom implementation in Julia
# from https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
jax_logpdf(x,y)=-(max(0,x) +log1pexp(-abs(x)) -x*y)

@btime sum(jax_logpdf.($X*$betas,$y));
#  324.208 μs (4 allocations: 312.59 KiB)

# logprob evaluation with Numpyro/Jax
@time @pyexec (X=jnp.array(X),betas=jnp.array(betas),y=jnp.array(y)) => """
import time
start = time.time()
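# note: JAX dispatches work asynchronously; calling ll.block_until_ready() before stopping the timer would ensure the full computation is timed
ll=dist.Bernoulli(logits=jnp.dot(X,betas)).log_prob(y).sum()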
print("Elapsed time (ms):", (time.time() - start)*1000)
print(ll)
""";
# Elapsed time (ms): 0.1988410949707031
# loglik value matches Julia
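The identity behind jax_logpdf: for y ∈ {0, 1} and logit x, log p(y | x) = x·y − log(1 + eˣ), and log(1 + eˣ) = max(x, 0) + log1p(e^(−|x|)), which never exponentiates a large positive number. A Base-only sketch (log1p in place of LogExpFunctions.log1pexp; `stable_logpdf` and `naive_logpdf` are illustrative names) that checks the formula against the textbook one and shows where the naive form breaks. It only demonstrates numerical equivalence; it does not explain why the BernoulliLogit path is slower.

```julia
# Bernoulli log-likelihood with logit x and y in {0, 1}:
#   log p(y | x) = x*y - log(1 + exp(x)) = x*y - max(x, 0) - log1p(exp(-abs(x)))
# Same formula as jax_logpdf above, using Base's log1p instead of log1pexp.
stable_logpdf(x, y) = x * y - max(x, zero(x)) - log1p(exp(-abs(x)))

# Textbook form via the logistic function, for comparison
naive_logpdf(x, y) = y * log(1 / (1 + exp(-x))) + (1 - y) * log(1 - 1 / (1 + exp(-x)))

# Both agree for moderate logits
for x in (-5.0, -0.3, 0.0, 0.7, 4.0), y in (0, 1)
    @assert isapprox(stable_logpdf(x, y), naive_logpdf(x, y); atol = 1e-12)
end

# For a large logit the naive form hits log(0) = -Inf, the stable one stays finite
lp_large = stable_logpdf(800.0, 0)    # -800.0
lp_naive = naive_logpdf(800.0, 0)     # -Inf
```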

Setup script and data generation

using DataFramesMeta
using Random
using Turing
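using LogExpFunctions: logit, logistic, log1pexp  # log1pexp is needed by jax_logpdf
using BenchmarkTools      # BenchmarkGroup, @benchmarkable, @btime
using Turing: DynamicPPL  # VarInfo, SampleFromPrior, evaluate!!, link!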
using LazyArrays
using LogDensityProblems
import ReverseDiff
import ForwardDiff
using Turing.Essential: ForwardDiffAD, TrackerAD, ReverseDiffAD, ZygoteAD, CHUNKSIZE

# Code below for benchmarking gradients copied from https://github.com/torfjelde/TuringBenchmarking.jl
const DEFAULT_ADBACKENDS = [
    ForwardDiffAD{40}(),    # chunksize=40
    ForwardDiffAD{100}(),   # chunksize=100
    ReverseDiffAD{false}(), # rdcache=false
    ReverseDiffAD{true}()   # rdcache=true
]

# Code below for benchmarking gradients copied from https://github.com/torfjelde/TuringBenchmarking.jl
"""
    make_turing_suite(model; kwargs...)

Create default benchmark suite for `model`.

# Keyword arguments
- `adbackends`: a collection of adbackends to use. Defaults to `$(DEFAULT_ADBACKENDS)`.
- `run_once=true`: if `true`, the body of each benchmark will be run once to avoid
  compilation to be included in the timings (this may occur if compilation runs
  longer than the allowed time limit).
- `save_grads=false`: if `true` and `run_once` is `true`, the gradients from the initial
  execution will be saved and returned as the second return-value. This is useful if you
  want to check correctness of the gradients for different backends.

# Notes
- A separate "parameter" instance (`DynamicPPL.VarInfo`) will be created for _each test_.
  Hence if you have a particularly large model, you might want to only pass one `adbackend`
  at the time.
"""
function make_turing_suite(
    model::DynamicPPL.Model;
    adbackends = DEFAULT_ADBACKENDS, run_once = true, save_grads = false
)
    suite = BenchmarkGroup()
    suite["not_linked"] = BenchmarkGroup()
    suite["linked"] = BenchmarkGroup()

    grads = Dict(:not_linked => Dict(), :linked => Dict())

    vi_orig = DynamicPPL.VarInfo(model)
    spl = DynamicPPL.SampleFromPrior()

    for adbackend in adbackends
        vi = DynamicPPL.VarInfo(vi_orig, spl, vi_orig[spl])
        f = LogDensityProblems.ADgradient(
            adbackend,
            Turing.LogDensityFunction(vi, model, spl, DynamicPPL.DefaultContext())
        )
        θ = vi[spl]

        if run_once
            ℓ, ∇ℓ = LogDensityProblems.logdensity_and_gradient(f, θ)

            if save_grads
                grads[:not_linked][adbackend] = (ℓ, ∇ℓ)
            end
        end
        suite["not_linked"]["$(adbackend)"] = @benchmarkable $(LogDensityProblems.logdensity_and_gradient)($f, $θ)

        # Need a separate `VarInfo` for the linked version since otherwise we risk the
        # `vi` from above being mutated.
        vi_linked = deepcopy(vi)
        DynamicPPL.link!(vi_linked, spl)
        f_linked = LogDensityProblems.ADgradient(
            adbackend,
            Turing.LogDensityFunction(vi_linked, model, spl, DynamicPPL.DefaultContext())
        )
        θ_linked = vi_linked[spl]
        if run_once
            ℓ, ∇ℓ = LogDensityProblems.logdensity_and_gradient(f_linked, θ_linked)

            if save_grads
                grads[:linked][adbackend] = (ℓ, ∇ℓ)
            end
        end
        suite["linked"]["$(adbackend)"] = @benchmarkable $(LogDensityProblems.logdensity_and_gradient)($f_linked, $θ_linked)
    end

    # Also benchmark just standard model evaluation because why not.
    suite["not_linked"]["evaluation"] = @benchmarkable $(DynamicPPL.evaluate!!)($model, $vi_orig, $(DynamicPPL.DefaultContext()))
    DynamicPPL.link!(vi_orig, spl)
    suite["linked"]["evaluation"] = @benchmarkable $(DynamicPPL.evaluate!!)($model, $vi_orig, $(DynamicPPL.DefaultContext()))

    return save_grads ? (suite, grads) : suite
end


# Generate mock data
X=hcat(ones(20000),randn(20000,22))
betas=vcat([-0.8],zeros(3),[1],zeros(5),[1],zeros(12))
@assert length(betas)==size(X,2)
y=(X*betas) .|> x->rand(BernoulliLogit(x))
@info "Important coef positions: $(findall(betas.!=0)), Average rate: $(mean(y))"


Wow, that’s a very thorough investigation! Unfortunately, I don’t have time to go through it in detail right now. I noticed one thing in the simple model (logreg_vanilla), however, that should improve performance immediately: You compute X*betas twice. I assume (untested) that it would be more efficient to just use something like

@model function logreg_vanilla(X)
    dim_X=size(X,2)
    betas ~ filldist(Normal(),dim_X)
    z = X*betas
    y ~ arraydist(LazyArray(@~ BernoulliLogit.(z)))
    return z
end

(I’m not sure how much one gains from using LazyArrays compared with a simple map.) I also imagine that an improved BernoulliLogit implementation, possibly with custom derivatives, could speed up the model. You could test this by e.g. computing the logpdf manually inside the model (using @addlogprob! and the alternative logpdf implementation without any BernoulliLogit).
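A plain-Julia sketch of that suggestion outside of Turing: the logits are computed once into a preallocated buffer with `mul!`, the log-likelihood uses the stable formula, and the analytic gradient Xᵀ(y − σ(Xβ)) is returned, i.e. what a hand-written derivative for this term would need to produce. The helper `loglik_and_grad!` and the `sigmoid` function are hypothetical, not Turing or Distributions API; a central finite-difference check confirms the gradient.

```julia
using LinearAlgebra: mul!
using Random

# Stable Bernoulli-logit term (same formula as jax_logpdf) and logistic function
stable_logpdf(x, y) = x * y - max(x, zero(x)) - log1p(exp(-abs(x)))
sigmoid(x) = 1 / (1 + exp(-x))

"""
    loglik_and_grad!(logits, X, betas, y)

Hypothetical helper: fill the preallocated `logits` with X * betas (computed once),
then return the log-likelihood and its gradient X' * (y - sigmoid.(logits)).
"""
function loglik_and_grad!(logits, X, betas, y)
    mul!(logits, X, betas)                       # single in-place matrix-vector product
    ll = sum(stable_logpdf(l, yi) for (l, yi) in zip(logits, y))
    grad = X' * (y .- sigmoid.(logits))          # d/dbeta of sum(x*y - log(1 + exp(x)))
    return ll, grad
end

# Small synthetic problem in the spirit of the post's mock data
Random.seed!(0)
X = hcat(ones(200), randn(200, 3))
betas = [-0.8, 1.0, 0.0, 0.5]
y = Float64.(rand(200) .< sigmoid.(X * betas))
logits = similar(y)

ll, g = loglik_and_grad!(logits, X, betas, y)

# Central finite differences as an independent check of the analytic gradient
h = 1e-6
fd = map(eachindex(betas)) do j
    e = zeros(length(betas)); e[j] = h
    (loglik_and_grad!(logits, X, betas .+ e, y)[1] -
     loglik_and_grad!(logits, X, betas .- e, y)[1]) / (2h)
end
```

Hooking such a gradient into Turing's AD machinery is not shown here and would need to be benchmarked like the models above.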


Thank you! The vanilla model is there only for a sense check to understand the cost of RHS prior structure.

To your point, I assumed that the return call is irrelevant for the sampling step (tilde_observe), since the model is rewritten anyway. I thought it would only matter if I use generated_quantities().
I have tried changing it in the past and haven’t observed much difference.

The fastest implementation actually uses the @addlogprob! (logreg_rhs3)

> The vanilla model is there only for a sense check to understand the cost of RHS prior structure.

That was my impression as well, but the other models also seem to recompute X*betas.

> To your point, I assumed that the return call is irrelevant for the sampling step (tilde_observe), since the model is rewritten anyway. I thought it would only matter if I use generated_quantities()
> I have tried changing it in the past and haven’t observed much difference.

Turing always runs the full model; it just discards (doesn’t store) return values during sampling.

> The fastest implementation actually uses the @addlogprob! (logreg_rhs3)

Yes, I saw that. However, it still uses BernoulliLogit. My suggestion was to use the apparently more performant logpdf implementation without BernoulliLogit.


The results are in!

  • Vanilla log. reg. sampling time went from 87 s (logreg_vanilla below) to 11 s (logreg_vanilla5); the chains are comparable (there were some odd artifacts, which I blame on the single short chain, but the general results match)
  • Gradient evaluation benchmarking is an excellent predictor: 5 ms → 0.7 ms
  • Tricks involved: no duplicated logits computation, custom logpdf, @addlogprob!, ReverseDiff with tape compilation, no logging
  • For the RHS model, this translates into ~100 s of sampling (vs. 266 s in Turing previously and 35 s in Numpyro), i.e., we’re within 3x of Numpyro

Any other tips & tricks?

On duplicating the logits multiplication
Fascinating! You were right that even that little duplication can be costly (in a simple model). Previously, I benchmarked by calling model.f and wasn’t able to detect the difference, but on logreg_vanilla I see almost a 20% speed-up in the gradient computation just by creating the logits only once.

Step size changes after switching AD backends
I have noticed that the HMC step size for the same model tends to jump by an order of magnitude when I switch between AD backends (ForwardDiff → ReverseDiff and vice versa), e.g., from ε = 0.05 to ε = 0.8.

Can you think of a reason why that could be? Is there some leakage, or is it a computational artefact?

Vanilla log. reg. benchmarking

# same setup as before
# Vanilla logistic regression for comparisons
@model function logreg_vanilla(X)
    dim_X=size(X,2)
    betas ~ filldist(Normal(),dim_X)
    y ~ arraydist(LazyArray(@~ BernoulliLogit.(X*betas)))
    return X*betas
end;
cond_model=logreg_vanilla(X) | (;y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(184.207 ms)
#           "ForwardDiffAD{100, true}()" => Trial(5.176 ms)
#           "evaluation" => Trial(1.528 ms)
#           "ForwardDiffAD{40, true}()" => Trial(5.176 ms)
#           "ReverseDiffAD{true}()" => Trial(51.965 ms)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(185.341 ms)
#           "ForwardDiffAD{100, true}()" => Trial(5.173 ms)
#           "evaluation" => Trial(1.530 ms)
#           "ForwardDiffAD{40, true}()" => Trial(5.170 ms)
#           "ReverseDiffAD{true}()" => Trial(46.953 ms)
@time chain_v=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=true);
# ┌ Info: Found initial step size
# └   ϵ = 0.05
# Sampling 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:01:21
#  87.865645 seconds (49.62 M allocations: 42.934 GiB, 2.18% gc time, 16.92% compilation time: 0% of which was recompilation)

# create logits only once
@model function logreg_vanilla2(X)
    dim_X=size(X,2)
    betas ~ filldist(Normal(),dim_X)
    logits=X*betas
    y ~ arraydist(LazyArray(@~ BernoulliLogit.(logits)))
    return logits
end;
cond_model=logreg_vanilla2(X) | (;y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(31.634 ms)
#           "ForwardDiffAD{100, true}()" => Trial(3.908 ms)
#           "evaluation" => Trial(1.050 ms)
#           "ForwardDiffAD{40, true}()" => Trial(3.904 ms)
#           "ReverseDiffAD{true}()" => Trial(10.359 ms)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(31.310 ms)
#           "ForwardDiffAD{100, true}()" => Trial(3.913 ms)
#           "evaluation" => Trial(1.047 ms)
#           "ForwardDiffAD{40, true}()" => Trial(3.914 ms)
#           "ReverseDiffAD{true}()" => Trial(10.573 ms)
@time chain_v2=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=true);
# ┌ Info: Found initial step size
# └   ϵ = 0.05
# Sampling 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:54
#  56.102031 seconds (12.17 M allocations: 37.305 GiB, 1.64% gc time, 9.82% compilation time)

# wrap logpdf in @addlogprob
@model function logreg_vanilla3(X,y)
    dim_X=size(X,2)
    betas ~ filldist(Normal(),dim_X)
    logits=X*betas
    Turing.@addlogprob! sum(logpdf.(BernoulliLogit.(logits),y))
    return logits
end;
cond_model=logreg_vanilla3(X,y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(1.700 ms)
#           "ForwardDiffAD{100, true}()" => Trial(3.875 ms)
#           "evaluation" => Trial(1.121 ms)
#           "ForwardDiffAD{40, true}()" => Trial(3.880 ms)
#           "ReverseDiffAD{true}()" => Trial(1.550 ms)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(1.692 ms)
#           "ForwardDiffAD{100, true}()" => Trial(3.875 ms)
#           "evaluation" => Trial(1.120 ms)
#           "ForwardDiffAD{40, true}()" => Trial(3.874 ms)
#           "ReverseDiffAD{true}()" => Trial(1.551 ms)
@time chain_v3=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=true);
# ┌ Info: Found initial step size
# └   ϵ = 0.05
# Sampling 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:52
#  54.932298 seconds (15.37 M allocations: 67.761 GiB, 2.91% gc time, 12.23% compilation time)

# use custom logpdf
@model function logreg_vanilla4(X,y)
    dim_X=size(X,2)
    betas ~ filldist(Normal(),dim_X)
    logits=X*betas
    Turing.@addlogprob! sum(jax_logpdf.(logits,y))
    return logits
end;
cond_model=logreg_vanilla4(X,y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(861.041 μs)
#           "ForwardDiffAD{100, true}()" => Trial(3.034 ms)
#           "evaluation" => Trial(445.000 μs)
#           "ForwardDiffAD{40, true}()" => Trial(3.036 ms)
#           "ReverseDiffAD{true}()" => Trial(716.625 μs)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(862.834 μs)
#           "ForwardDiffAD{100, true}()" => Trial(3.042 ms)
#           "evaluation" => Trial(445.459 μs)
#           "ForwardDiffAD{40, true}()" => Trial(3.051 ms)
#           "ReverseDiffAD{true}()" => Trial(720.000 μs)
@time chain_v4=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=true);
# ┌ Info: Found initial step size
# └   ϵ = 0.05
# Sampling 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| Time: 0:00:37
#  39.722312 seconds (15.48 M allocations: 64.849 GiB, 3.30% gc time, 17.18% compilation time)


# use custom logpdf + reversediff with tape compilation, suppress logging
using Logging
Logging.disable_logging(Logging.Warn);
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
@model function logreg_vanilla5(X,y)
    dim_X=size(X,2)
    betas ~ filldist(Normal(),dim_X)
    logits=X*betas
    Turing.@addlogprob! sum(jax_logpdf.(logits,y))
    return logits
end;
cond_model=logreg_vanilla5(X,y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(873.792 μs)
#           "ForwardDiffAD{100, true}()" => Trial(3.043 ms)
#           "evaluation" => Trial(462.708 μs)
#           "ForwardDiffAD{40, true}()" => Trial(3.049 ms)
#           "ReverseDiffAD{true}()" => Trial(734.083 μs)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(875.041 μs)
#           "ForwardDiffAD{100, true}()" => Trial(3.049 ms)
#           "evaluation" => Trial(462.750 μs)
#           "ForwardDiffAD{40, true}()" => Trial(3.055 ms)
#           "ReverseDiffAD{true}()" => Trial(732.875 μs)
@time chain_v5=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=false);
# 11.020877 seconds (14.19 M allocations: 1.032 GiB, 0.74% gc time, 45.83% compilation time)

Log. reg. with RHS prior benchmarking

# Logistic regression + Regularized horseshoe Prior
# Based on: https://arxiv.org/pdf/1707.01694.pdf#page11 (Appendix C1)
@model function logreg_rhs(X)
    slab_df=4
    slab_scale=2
    eff_params=3
    dim_X=size(X,2)
    tau_0=eff_params/(dim_X-eff_params)/sqrt(size(X,1))

    lambdas ~ filldist(truncated(Cauchy();lower=0),dim_X)
    tau ~ truncated(Cauchy(0,tau_0);lower=0)
    z ~ filldist(Normal(),dim_X)

    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq=slab_scale^2 * c_aux # squared already
    lambdas_tilde=sqrt.((c_sq .* lambdas.^2) ./ (c_sq .+ tau.^2 .* lambdas.^2))
    betas=lambdas_tilde .* tau .* z
    y ~ arraydist(LazyArray(@~ BernoulliLogit.(X*betas)))
    return X*betas
end;
cond_model=logreg_rhs(X) | (;y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(171.180 ms)
#           "ForwardDiffAD{100, true}()" => Trial(21.854 ms)
#           "evaluation" => Trial(1.235 ms)
#           "ForwardDiffAD{40, true}()" => Trial(10.736 ms)
#           "ReverseDiffAD{true}()" => Trial(48.728 ms)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(173.224 ms)
#           "ForwardDiffAD{100, true}()" => Trial(22.077 ms)
#           "evaluation" => Trial(1.231 ms)
#           "ForwardDiffAD{40, true}()" => Trial(11.172 ms)
#           "ReverseDiffAD{true}()" => Trial(58.029 ms)
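For reference, here is the C1 parameterization that `logreg_rhs` codes, transcribed from the model above rather than from the paper. Here ν = `slab_df`, s = `slab_scale`, p₀ = `eff_params`, D = `dim_X` and n is the number of observations:

```math
\begin{aligned}
\lambda_j &\sim \mathrm{C}^+(0, 1), \quad \tau \sim \mathrm{C}^+(0, \tau_0), \quad z_j \sim \mathcal{N}(0, 1), \quad c_{\mathrm{aux}} \sim \mathrm{Inv\text{-}Gamma}(\nu/2, \nu/2),\\
c^2 &= s^2\, c_{\mathrm{aux}}, \qquad \tau_0 = \frac{p_0}{D - p_0}\,\frac{1}{\sqrt{n}},\\
\tilde\lambda_j^2 &= \frac{c^2 \lambda_j^2}{c^2 + \tau^2 \lambda_j^2}, \qquad \beta_j = z_j\, \tau\, \tilde\lambda_j, \qquad y_i \sim \mathrm{BernoulliLogit}(x_i^\top \beta).
\end{aligned}
```

When τ²λⱼ² ≪ c², the regularized scale satisfies λ̃ⱼ ≈ λⱼ, which is plain horseshoe shrinkage. When τ²λⱼ² ≫ c², λ̃ⱼ ≈ c/τ, so βⱼ ≈ c·zⱼ: large coefficients are regularized by an N(0, c²) slab. Assuming D = 23 columns (as in the mock data later in the thread) and n = 20000, τ₀ = (3/20)/√20000 ≈ 1.06 × 10⁻³.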


# stop duplication of logits, custom logpdf, run in reversediff + tape compilation
# Alternative parametrization based on: https://arxiv.org/pdf/1707.01694.pdf#page11 (Appendix C2)
@model function logreg_rhs4(X,y)
    # these could be provided via args but it makes the function calls clunky, perhaps a parameters NamedTuple
    slab_df=4
    slab_scale=2
    eff_params=3
    dim_X=size(X,2)
    tau_0=eff_params/(dim_X-eff_params)/sqrt(size(X,1))

    # re-parametrize lambdas and lambdas squared directly to plug into lambdas_tilde
    lambdas_scale ~ filldist(InverseGamma(0.5,0.5),dim_X)
    lambdas_z ~ filldist(truncated(Normal();lower=0),dim_X)
    lambdas_sq= lambdas_z.^2 .* lambdas_scale

    # re-parametrize tau and square it directly to plug into lambdas_tilde
    tau_scale ~ InverseGamma(0.5,0.5)
    tau_z ~ truncated(Normal();lower=0)
    tau_sq = tau_z^2 .*tau_scale .* tau_0^2
    z ~ filldist(Normal(),dim_X)

    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq=slab_scale^2 * c_aux # squared already
    lambdas_tilde=sqrt.((c_sq .* lambdas_sq) ./ (c_sq .+ tau_sq .* lambdas_sq))
    betas=lambdas_tilde .* sqrt(tau_sq) .* z
    logits=X*betas
    Turing.@addlogprob! sum(jax_logpdf.(logits,y))
    return logits
end;
cond_model=logreg_rhs4(X,y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(921.792 μs)
#           "ForwardDiffAD{100, true}()" => Trial(20.925 ms)
#           "evaluation" => Trial(303.125 μs)
#           "ForwardDiffAD{40, true}()" => Trial(18.665 ms)
#           "ReverseDiffAD{true}()" => Trial(632.583 μs)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(1.044 ms)
#           "ForwardDiffAD{100, true}()" => Trial(21.179 ms)
#           "evaluation" => Trial(303.209 μs)
#           "ForwardDiffAD{40, true}()" => Trial(18.930 ms)
#           "ReverseDiffAD{true}()" => Trial(620.083 μs)

@time chain_rhs4=sample(cond_model, Turing.NUTS(500,0.65;max_depth=8),100;progress=false);
# 103.736903 seconds (201.11 M allocations: 9.029 GiB, 0.85% gc time, 14.20% compilation time)
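The C2 trick in `logreg_rhs4` relies on a scale mixture: if z ~ N(0, 1) and s ~ InverseGamma(1/2, 1/2), then z·√s ~ Cauchy(0, 1). Taking |z| therefore gives a half-Cauchy. The same construction scaled by τ₀ gives τ ~ C⁺(0, τ₀).

Below is a minimal stdlib-only Monte Carlo check, written as a sketch. It assumes that 1/s ~ Gamma(shape 1/2, rate 1/2) = χ²₁, so an InverseGamma(1/2, 1/2) draw can be generated as 1/w² with w ~ N(0, 1). The empirical quantiles are then compared with the exact half-Cauchy quantiles tan(πp/2).

```julia
using Random, Statistics

rng = MersenneTwister(42)
n = 1_000_000

z = abs.(randn(rng, n))        # half-normal draws, playing the role of lambdas_z / tau_z
s = 1 ./ randn(rng, n) .^ 2    # InverseGamma(0.5, 0.5) draws via 1/χ²₁
lambda = z .* sqrt.(s)         # should be half-Cauchy(0, 1)

# Half-Cauchy(0, 1) has CDF (2/π)·atan(x), hence quantile tan(π p / 2)
for p in (0.25, 0.5, 0.75)
    println(p, ": empirical ", quantile(lambda, p), ", exact ", tan(π * p / 2))
end
```

Squaring on the sampler's side, as `logreg_rhs4` does with `lambdas_sq = lambdas_z.^2 .* lambdas_scale`, avoids a `sqrt` followed by squaring again inside `lambdas_tilde`.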


Great work, but like you, I’m still curious why we’re a factor of 3 slower than NumPyro. I might be missing something obvious, though.

Yes, same here. But to be honest, I’m not looking for parity.

If I could have:

  • a 3x speed-up out of the box (e.g., a gradient benchmark built into Turing.jl that tells me which AD backend to use, and logprob performance on par with JAX)
  • mostly plain Julia syntax with very few extra quirks (LazyArrays, @addlogprob, etc.)
  • some basic HMC diagnostics out of the box (divergences, etc., à la Stan)

I would pick Turing every time!
I’m keen to learn more about these performance tips (and generalize them a bit), so I don’t have to experiment as much next time.


In Tilde.jl, @mschauer and I had been doing something like

model_lr = @model (Xt, y) begin
    d, n = size(Xt)
    θ ~ Normal() ^ d
    for j in 1:n
        logitp = view(Xt, :, j)' * θ
        y[j] ~ Bernoulli(logitp = logitp)
    end
end

So the big differences are:

  • You pass Xt instead of X, so each observation is a contiguous column in memory (Julia arrays are column-major) and can be walked efficiently
  • Instead of a matrix-vector product that allocates the full logit vector on each evaluation, you only compute the dot product needed for each observation

Maybe something similar could work in Turing?
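To make the layout idea concrete outside of Tilde or Turing, here is a minimal stdlib-only sketch. All function names are hypothetical. It computes the same logistic log-likelihood two ways: once with a single BLAS matrix-vector product on X (n × d), and once by looping over the columns of Xt = X' (d × n) with one `dot` per observation. The loop version builds no logit vector. Views into a column are typically allocation-free in recent Julia versions, but that is worth confirming with a profiler.

```julia
using LinearAlgebra, Random

# log(1 + exp(x)) without overflow (hypothetical helper)
log1pexp_stable(x) = x > zero(x) ? x + log1p(exp(-x)) : log1p(exp(x))

# Vectorized: one BLAS matrix-vector product, allocates the n-vector of logits
function loglik_matvec(X::AbstractMatrix, θ::AbstractVector, y::AbstractVector)
    logits = X * θ
    return sum(y .* logits .- log1pexp_stable.(logits))
end

# Loop on Xt = X' (d × n): each observation is one contiguous column,
# and dot(view, θ) returns a scalar, so no array is built per observation
function loglik_loop(Xt::AbstractMatrix, θ::AbstractVector, y::AbstractVector)
    acc = zero(promote_type(eltype(Xt), eltype(θ)))
    @inbounds for j in axes(Xt, 2)
        l = dot(view(Xt, :, j), θ)
        acc += y[j] * l - log1pexp_stable(l)
    end
    return acc
end

rng = MersenneTwister(0)
X = randn(rng, 1_000, 23)
Xt = Matrix(X')              # transpose once, outside the model
θ = 0.3 .* randn(rng, 23)
y = rand(rng, 0:1, 1_000)
println(loglik_matvec(X, θ, y), " ≈ ", loglik_loop(Xt, θ, y))
```

Which version is faster depends on the AD backend, as the rest of the thread discusses. The two only need to agree numerically.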


That’s a great idea, but I think I’m missing a trick.

The code is type-stable, but when I add the loop, allocations actually go up and speed goes down.
I’m not sure where they are coming from, because:

  • I have transposed the data so that the loop walks down columns (column-major)
  • I use views to avoid allocations
  • the inner function (logprob) is type-stable and fast (no allocations)
  • I removed the return of the logits (returning them would require allocating X*betas)

Allocations go from ~900 (with the matrix-vector product) to 260k (with the loop), i.e. roughly 13 extra allocations per iteration over the 20,000 observations. I could unroll the dot product as well, but that doesn’t sound like the right direction.

@model function logreg_rhs5(X,y,::Type{T} = Float64) where {T}
    # these could be provided via args but it makes the function calls clunky, perhaps a parameters NamedTuple
    slab_df=4
    slab_scale=2
    eff_params=3
    dim_X=size(X,1) ## flipped axes!
    tau_0=eff_params/(dim_X-eff_params)/sqrt(size(X,2))

    # re-parametrize lambdas and lambdas squared directly to plug into lambdas_tilde
    lambdas_scale ~ filldist(InverseGamma(0.5,0.5),dim_X)
    lambdas_z ~ filldist(truncated(Normal();lower=0),dim_X)
    lambdas_sq= lambdas_z.^2 .* lambdas_scale

    # re-parametrize tau and square it directly to plug into lambdas_tilde
    tau_scale ~ InverseGamma(0.5,0.5)
    tau_z ~ truncated(Normal();lower=0)
    tau_sq = tau_z^2 .*tau_scale .* tau_0^2
    z ~ filldist(Normal(),dim_X)

    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq=slab_scale^2 * c_aux # squared already
    lambdas_tilde=sqrt.((c_sq .* lambdas_sq) ./ (c_sq .+ tau_sq .* lambdas_sq))
    betas=reshape(lambdas_tilde .* sqrt(tau_sq) .* z,1,:)
    llik=zero(T)
    @simd for i in 1:size(X,2)
        @inbounds llik+= jax_logpdf(only(betas*@view(X[:,i])),y[i])
    end
    # @info "Sense check" llik sum(jax_logpdf.(betas*X|>vec,y))
    Turing.@addlogprob! llik
    return nothing
end;
cond_model=logreg_rhs5(copy(X'),y);
turing_suite=make_turing_suite(cond_model;adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(54.755 ms)
#           "ForwardDiffAD{100, true}()" => Trial(40.511 ms)
#           "evaluation" => Trial(3.510 ms)
#           "ForwardDiffAD{40, true}()" => Trial(48.977 ms)
#           "ReverseDiffAD{true}()" => Trial(21.696 ms)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(54.233 ms)
#           "ForwardDiffAD{100, true}()" => Trial(40.231 ms)
#           "evaluation" => Trial(3.507 ms)
#           "ForwardDiffAD{40, true}()" => Trial(47.588 ms)
#           "ReverseDiffAD{true}()" => Trial(21.879 ms)

I’m not following the thread closely, so my comment may be pointless, but be careful here. A matrix-vector product calls extremely optimized BLAS routines, which are probably faster than manually written alternatives, even if they allocate something. If you need in-place alternatives, there are lower-level in-place functions (e.g. `mul!`) that call BLAS. Another option is to use MKL.
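A minimal sketch of the in-place route, using only `LinearAlgebra.mul!`. The logit buffer is allocated once and reused, and `mul!(C, A, B)` overwrites it with `X * β`, dispatching to BLAS gemv for `Float64` inputs. The sizes mirror the thread's 20000 × 23 design. The data are random and purely illustrative.

```julia
using LinearAlgebra, Random

rng = MersenneTwister(1)
X = randn(rng, 20_000, 23)
β = randn(rng, 23)

# Buffer allocated once and reused across log-density evaluations
logits = Vector{Float64}(undef, size(X, 1))

# In-place matrix-vector product: overwrites `logits` with X * β, no new array
mul!(logits, X, β)
```

The catch for gradient-based sampling is that this buffer has element type `Float64`. It therefore cannot hold ForwardDiff dual numbers or ReverseDiff tracked values during a gradient evaluation, which is the problem PreallocationTools.jl is meant to solve.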


That’s my sense as well.

Also, the winning AD backend is ReverseDiff, which I think prefers vectorized code, so there would be no net benefit.

ReverseDiff.jl neither needs nor benefits from a vectorized code style if tape compilation is used. In fact, fully non-allocating calls (mixing in PreallocationTools.jl), with only BLAS/LAPACK calls otherwise, will generally be fastest.


Thank you for the clarification! I must have mixed it up with Tracker/Zygote.

Are you aware of any example combining Turing.jl and PreallocationTools.jl? I’ve looked at it before, but it wasn’t clear to me how to plug it into the usual Turing interface.

No, but it’s no different from the normal usage of PreallocationTools, so you’d just define your caches with it.
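To illustrate the idea behind such caches without depending on PreallocationTools.jl itself, here is a hypothetical stdlib-only sketch. `TypedBufferCache` and `getbuffer` are made-up names, not that package's API. The cache keeps one buffer per element type and creates it lazily. A `Float64` evaluation and a dual-number gradient evaluation would then each get a buffer they can write into, and repeated calls reuse it.

```julia
# Hypothetical minimal cache: one reusable buffer per element type, created lazily.
# PreallocationTools.jl offers more complete versions of this idea.
struct TypedBufferCache
    n::Int
    buffers::Dict{DataType,Any}
end
TypedBufferCache(n::Integer) = TypedBufferCache(n, Dict{DataType,Any}())

# Return the buffer for element type T, allocating it only on first use
function getbuffer(c::TypedBufferCache, ::Type{T}) where {T}
    return get!(() -> Vector{T}(undef, c.n), c.buffers, T)::Vector{T}
end

cache = TypedBufferCache(5)
b1 = getbuffer(cache, Float64)
b2 = getbuffer(cache, Float64)   # same object as b1, no new allocation
b3 = getbuffer(cache, BigFloat)  # separate buffer for a different number type
```

Inside a model, one would call something like `buf = getbuffer(cache, eltype(betas)); mul!(buf, X, betas)`. This sketch is not thread-safe, and we have not checked how it interacts with ReverseDiff tape compilation.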

Great point. I didn’t compare many alternatives; this was just a quick implementation starting from the idea that maybe we should avoid allocations. From there, the transposed layout seemed much quicker, which makes sense because it’s more cache-friendly.

I thought this would call BLAS vector-vector multiply at each step, but if that’s not the case it should clearly be changed. I didn’t try anything in-place, because sometimes AD can have a harder time with that.

Also, this was with ForwardDiff, since the ZigZagBoomerang sampler we used usually doesn’t need the full gradient but just individual partials at a given step.
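A single partial ∂ᵢ of the negative log-likelihood is cheap for this model, because it has a closed form: ∂ᵢ(−log p) = Σⱼ (σ(xⱼ'θ) − yⱼ)·xⱼᵢ. The thread's code obtains the same quantity with a ForwardDiff `Dual` seeded only at coordinate i. As a swapped-in illustration, the stdlib-only sketch below uses the analytic formula instead and checks it against a central finite difference. The function names are hypothetical.

```julia
using LinearAlgebra, Random

sigmoid(x) = 1 / (1 + exp(-x))
log1pexp_stable(x) = x > zero(x) ? x + log1p(exp(-x)) : log1p(exp(x))

# Negative log-likelihood of the logistic model, data stored as Xt (d × n)
function neglogp(θ, Xt, y)
    acc = 0.0
    for j in axes(Xt, 2)
        l = dot(view(Xt, :, j), θ)
        acc -= y[j] * l - log1pexp_stable(l)
    end
    return acc
end

# Partial derivative w.r.t. θ[i] only: sum_j (sigmoid(x_j'θ) - y_j) * x_ji
function partial_neglogp(θ, i, Xt, y)
    acc = 0.0
    for j in axes(Xt, 2)
        l = dot(view(Xt, :, j), θ)
        acc += (sigmoid(l) - y[j]) * Xt[i, j]
    end
    return acc
end

rng = MersenneTwister(2)
Xt = randn(rng, 5, 300)
θ = 0.3 .* randn(rng, 5)
y = rand(rng, 0:1, 300)

# Central finite difference along coordinate i = 2 as a sanity check
i, h = 2, 1e-6
e = zeros(5); e[i] = 1.0
fd = (neglogp(θ .+ h .* e, Xt, y) - neglogp(θ .- h .* e, Xt, y)) / (2h)
println(partial_neglogp(θ, i, Xt, y), " ≈ ", fd)
```

Each partial still costs a full pass over the data here, since every logit depends on all of θ.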

Hi @svilupp, I got a sticky Zig-Zag running on the example with Tilde:

import Pkg
cd(@__DIR__)
# Pkg.activate(@__DIR__)

using Tilde, Pathfinder,  PDMats, StructArrays
using ForwardDiff
using ForwardDiff: Dual
using LinearAlgebra, Random, Statistics, StatsBase, SparseArrays
using ZigZagBoomerang
using ZigZagBoomerang: StickyBarriers, StructuredTarget, StickyUpperBounds, StickyFlow, EndTime
using MCMCChains
using ArraysOfArrays

# Configuration 
Random.seed!(1)
κ = 0.01 # stickyness
T = 5000.0 # sampling time
c = 0.01
progress = true # show progress bar
PLOT = true # plot posterior trace
nsamples = 200
Δt = T/nsamples

# Generate mock data
println("Data...")
X = hcat(ones(20000),randn(20000,22))
Xt = Matrix(X')
d = size(X, 2)
n = size(X, 1)
betas = vcat([-0.8], zeros(3), [1.0], zeros(5), [0.9], zeros(12))
@assert length(betas) == d
y = (X*betas) .|> x->rand(Bernoulli(logitp = x))
@info "Important coef positions: $(findall(betas.!=0)), Average rate: $(mean(y))"

# Simple logistic model
model_lr = @model (Xt, y) begin
    d, n = size(Xt)
    θ ~ Normal() ^ d
    for j in 1:n
        logitp = view(Xt, :, j)' * θ
        y[j] ~ Bernoulli(logitp = logitp)
    end
end

# Gradients
function make_grads(model_lr, At, y, d)    
    post = model_lr(At, y) | (;y)
    as_post = as(post)
    obj(θ) = -Tilde.unsafe_logdensityof(post, transform(as_post, θ))
    ℓ(θ) = -obj(θ)

    # input length and chunk size must match d (23 columns here), not a hardcoded 25
    gconfig = ForwardDiff.GradientConfig(obj, zeros(d), ForwardDiff.Chunk{d}())
    function ∇neglogp!(y, t, x, args...)
        ForwardDiff.gradient!(y, obj, x, gconfig)
        return
    end

    ith = zeros(d)
    function ∂neglogp(x,i)
        # should use StructArrays, seems tilde broke that
        ForwardDiff.partials(obj([Dual{}(x[j], 1.0*(i==j)) for j in eachindex(x)]))[]
    end
    
    post, ℓ, ∇neglogp!, ∂neglogp
end
post, ℓ, ∇neglogp!, ∂neglogp =  make_grads(model_lr, Xt, y, d)    

# Pathfinding
println("Pathfinder...")
init_scale = 1
if !@isdefined pf_result
    @time pf_result = pathfinder(ℓ; dim=d, init_scale)
end
M = PDMats.PDiagMat(diag(pf_result.fit_distribution.Σ))
Γ = sparse(inv(M))
#Γ = sparse(inv(pf_result.fit_distribution.Σ))
x0 = μ = pf_result.fit_distribution.μ
v0 = PDMats.unwhiten(M, randn(length(x0)))



# Sticky sampler
println("Sticky sampler...")
barriers = [StickyBarriers((0.0, 0.0), (:sticky, :sticky), (κ, κ)) for i in 1:d]
d = length(x0)
t0 = fill(0.0, d)
u0 = (t0, x0, v0) 
target = StructuredTarget([i => 1:d for i in 1:d], ∂neglogp)
flow = StickyFlow(ZigZag(Γ, μ))
strong_upperbounds = false
adapt = true
multiplier = 1.7 # increase bounds
G = target.G
G1 = [i => rowvals(Γ)[nzrange(Γ, i)] for i in axes(Γ, 1)]
upper_bounds = StickyUpperBounds(G, G1, Γ, fill(c, d); adapt=adapt, strong = strong_upperbounds, multiplier= multiplier)
end_time = EndTime(T)
∇ϕ(x, i) = ZigZagBoomerang.idot(Γ, i, x) # sparse computation
elapsed_time = @elapsed begin
trace, _, _, acc = @time stickyzz(u0, target, flow, upper_bounds, barriers, end_time; progress=progress)
end
@info "Upper bounds: $(upper_bounds.c)"
println("acc ", acc.acc/acc.num)


# Plot continuous trace
if PLOT
   
    ts, xs = ZigZagBoomerang.sep(collect(trace))
    println("Plot...")
    colors = [:green, :red, :blue, :violet]
    using GLMakie
    fig1 = fig = Figure()
    r = 1:length(ts)
    ax = Axis(fig[1,1], title = "trace")
    is = [1, 2, 5, 11]
    for i in 1:length(is)
        lines!(ax, ts[r], getindex.(xs[r], is[i]), color=colors[i])
        lines!(ax, ts[r], fill(betas[is[i]], length(r)), linestyle=:dash, color = (colors[i], 0.5))
    end
    display(fig)
end

# Your samples
samples = flatview(VectorOfSimilarVectors(ZigZagBoomerang.sep(ZigZagBoomerang.discretise(trace, Δt))[2]))

chain = MCMCChains.Chains(samples')
chain = setinfo(chain, (;start_time=0.0, stop_time = elapsed_time));
chain

Doesn’t seem like a bad idea! Note that I replaced the horseshoe prior with a spike-and-slab prior, with a spike at 0 of weight 1/κ, which the sticky sampler handles directly (the model above only specifies the slab, i.e. the Normal part).
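Here is one way to write down what the sticky sampler then targets. This is our reading, assuming the usual sticky-PDMP form in which each coordinate's reference measure is Lebesgue measure plus a Dirac mass at zero with weight 1/κ, and φ is the N(0, 1) slab density from `model_lr`:

```math
\mu(d\theta) \propto p(y \mid \theta) \prod_{i=1}^{d} \varphi(\theta_i)\left(d\theta_i + \kappa^{-1}\,\delta_0(d\theta_i)\right)
```

Under that assumption, the prior probability that a single coefficient sits exactly at zero is

```math
P(\theta_i = 0) = \frac{\varphi(0)/\kappa}{1 + \varphi(0)/\kappa},
```

which for κ = 0.01 and φ(0) = 1/√(2π) ≈ 0.399 gives ≈ 39.89/40.89 ≈ 0.976.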

(Link to gist with project)


That’s awesome! Thank you - I’ll try it over the weekend.

I’ve tried the Bouncy Particle Sampler, based on the example with the Turing interface, but failed (link).

A few questions:

  • What timings did you get to produce the 50-100ish effective samples? (I appreciate it’s different with PDMPs)
  • What was the challenge with using RHS prior?

About 15 seconds, with room for improvement given more effort.

Haha, no challenge; I just spent hours of my life figuring out how to avoid a continuous approximation to a spike and slab. See "ZigZagBoomerang.jl - parallel inference and variable selection | Moritz Schauer | JuliaCon2021" on YouTube (or watch the entire video). That is why I have an implementation of the ZigZag that simply assumes a discrete spike at 0 and a continuous slab around it.
