Case study: Speeding up a logistic regression with RHS prior (Turing vs Numpyro) - any tricks I'm missing?

Thank you for the extensive reply! I’m excited to play around with it (and compare the posterior results with NUTS tbh).

Are you aware of any working example with a model defined in Turing.jl (i.e., a make_grads() function leveraging the DynamicPPL interface)? I went through the repo but couldn’t find any.

And, an even more ambitious question: are there any examples leveraging the high-level interface of ZZB?

Let me see! Which ones are (most) relevant for you, @svilupp ?

  • Bouncy particle for general targets
  • ZigZag for targets with structure (sparse coefficient vectors, like here, or a lot of conditional independence between variables)

I am still working on a nice factorised Bouncy Particle sampler that combines both; for now, I have this dichotomy.

I think the most common are partially pooled generalized linear models (logistic, normal, Poisson) with hierarchical priors (usually 1-2 levels max).

In Numpyro, we would sometimes add some structural information about correlation (LKJ), but that’s much harder in Turing. We never exploit conditional independence, because we’re often working with behavioural data (different measurements of some underlying “job-to-be-done”).

Most of our models are in the range of tens of thousands to hundreds of thousands of observations (mostly with n >> p), hence this thread.

Sparsity is a nice-to-have for “decision support” modelling; I hope to add Aki’s “projpred” procedure over time to distil the signal even further. In my head, I’d willingly take a 3x slowdown to get sparsity in my model (we always have too many predictors versus what we believe is strictly necessary).


This is probably irrelevant but I thought some background info might help:

  • we’re “in the industry”, so in general, time-to-develop the model is 10x more important than time-to-run (hence the question about the API and the ask for Turing, because there is less cognitive load from using the same PPL)
  • I count iterations etc. into “time-to-develop”, so I don’t mind if it runs for 10 hours, but I like being able to get “some” feedback from the model within 5 minutes of running, to be able to iterate on the structure / priors / etc.
  • This whole thread is driven by the need to reduce “time-to-develop” (iteration time), because I want to be able to see some results within that “5-minute” budget

Tilde dev here. I don’t use Turing much. How is this harder in Turing?

I’m probably not the most qualified person to answer this, but I think it’s well described here (the Slack conversation has since disappeared).
My take is that it’s possible thanks to the amazing @sethaxen and his snippet (see the link above), but it’s non-trivial (because Bijectors requires the same dimensions, so we need to use TransformVariables.jl and go to a lower level).

Thanks for the details. I’ve tried to use Bijectors before, but the flattening thing has gotten in the way. @torfjelde has a fix here, but it looks like it hasn’t been touched in a while.

FWIW, Tilde works in terms of TransformVariables, but we’ll soon also have a new “transport API” thanks to @oschulz’s work in MeasureBase.jl. Then we’ll be able to transform any prior into a multivariate uniform or normal, updating the likelihood accordingly. The geometry gets simplified: no more funnels.
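To make the idea concrete, here it is written out by hand in Turing syntax (a hypothetical sketch of the reparameterization idea, not the MeasureBase transport API; the model name and the truncated-normal prior are my own illustrative choices). The sampler only ever sees an unconstrained, roughly standard-normal parameter, and the original prior is recovered by pushing it through the prior’s quantile function.

using Turing, Distributions

# Hand-rolled "transport" idea (hypothetical sketch, not the actual MeasureBase API):
# HMC samples an unconstrained standard-normal z; pushing it through the prior's
# quantile function gives a parameter with exactly the original prior, no Jacobian needed.
@model function transported_scale(y, prior = truncated(Normal(0, 2); lower=0))
    z ~ Normal()                              # what the sampler actually works with
    σ = quantile(prior, cdf(Normal(), z))     # σ is distributed according to `prior`
    y .~ Normal(0, σ)                         # likelihood written in terms of σ as usual
end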


Can you “donate” a second Turing model of interest?

Re your question about things being hard in Turing: I don’t know if other people have had better luck, but I’ve never been able to successfully run a Turing model that used the LKJ or (Inverse)Wishart distributions via NUTS, and even HMC often has difficulties.

For example:

using Turing, Distributions

@model function LKJ_demo(P)
    y ~ LKJ(P, 1)
end

P = 2
model = LKJ_demo(P)
m1 = sample(model, HMC(0.1, 20), MCMCThreads(), 1000, 4)
m2 = sample(model, NUTS(1000, 0.65), MCMCThreads(), 1000, 4)

When dealing with a simple 2x2 correlation matrix, both HMC and NUTS work fine.

However, if you up the dimension by 1 (to P = 3), NUTS always fails for me with a PosDef error, even if you increase the target acceptance rate to something absurd like 0.99 (one of the suggestions in the thread linked above about the LKJCholesky distribution, and also in various GitHub issues), or fiddle with the initial step size, etc. And if you increase P to some larger number, like 5, NUTS will frequently fail with a different error (especially when using a large acceptance rate): DomainError with -1.0: log will only return a complex result if called with a complex argument. Try log(Complex(x)).
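For concreteness, the failing configurations I’m describing look roughly like this (a sketch reusing the LKJ_demo model above; the errors in the comments are the ones I see):

# Sketch of the failing configurations described above
model3 = LKJ_demo(3)
m3 = sample(model3, NUTS(1000, 0.99), MCMCThreads(), 1000, 4)    # PosDef error for me

model5 = LKJ_demo(5)
m5 = sample(model5, NUTS(1000, 0.95), MCMCThreads(), 1000, 4)    # DomainError with -1.0: log ...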

If we do the same thing for the Wishart/InverseWishart distributions, we see the same general sort of pattern: things work fine in very low dimensions (e.g., P = 2 and 3), but when you get to P = 4, both HMC and NUTS fail (HMC using step sizes down to 0.001, and NUTS using acceptance rates up to 0.95).
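Again for concreteness, the kind of toy model I mean (a sketch; the degrees of freedom and the identity scale matrix are just illustrative choices):

using LinearAlgebra  # for I

@model function IW_demo(P)
    Σ ~ InverseWishart(P + 2, Matrix{Float64}(I, P, P))
end

model4 = IW_demo(4)
m_hmc  = sample(model4, HMC(0.001, 20), MCMCThreads(), 1000, 4)    # fails for me at P = 4
m_nuts = sample(model4, NUTS(1000, 0.95), MCMCThreads(), 1000, 4)  # likewise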

Now, I’ve seen people in the past post models where they have (I)Wishart priors of higher dimension (like 10+), so presumably that has worked at some point, but whenever I have tried their models, they fail for me, just like my toy examples.

Anyway, I’m sure this will get ironed out eventually, but right now the experience of modeling covariances in Turing can sometimes feel a bit… brittle, unfortunately.


See https://github.com/JuliaStats/Distributions.jl/pull/1066#issuecomment-584905987

The m = d(d-1)/2 free elements of a correlation matrix live in a subset of [-1, 1]ᵐ that has finite and strictly positive Lebesgue measure.

Can this really work with HMC/NUTS? Doesn’t the symmetry already cause a problem?

Sure, it’s very easy in Stan or Soss. I think the best way is to use an LKJCholesky.
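For reference, roughly what that looks like written in Turing syntax (a hedged sketch; whether it samples out of the box depends on your Turing/Bijectors version, which is exactly the pain point discussed above):

using Turing, LinearAlgebra

# Sketch of the Cholesky-factor parameterization of a correlation matrix
@model function LKJChol_demo(P)
    F ~ LKJCholesky(P, 1.0)          # prior on the Cholesky factor
    Ω = Matrix(F.L) * Matrix(F.L)'   # reconstruct the correlation matrix if needed downstream
    return Ω
end

m = sample(LKJChol_demo(3), NUTS(1000, 0.65), MCMCThreads(), 1000, 4)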

Is this maybe a case of Turing treating dimensionality and degrees of freedom as the same? That causes lots of problems in general.

That might be hard! GLMs are mostly what I’ve used Turing for (they make up >80% of our Bayesian pursuits).

I’ve tried a few other things in Turing as part of some work research, but went back to Numpyro:

  • Ordinal regression, but that’s more about having the likelihood implemented (I used it for survey results, e.g., here; if anything is wrong in this tutorial, it might be my fault). See the sketch after this list.
  • HMMs (2-layer HMM) where likelihoods are marginalized (because it’s a medium-sized dataset); Numpyro does the marginalization very well through enumeration over additional plate dimensions, so I quickly went back (it also runs on GPU thanks to Jax)
  • (Future desire) Exploring GPs for optimization and active learning (I’m not convinced PPL is needed; I can just use GP-specific libraries)
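For the ordinal-regression point, here’s a minimal sketch of what I mean in Turing (a hypothetical example using Turing’s OrderedLogistic likelihood; the priors and the positive-increment construction of the cutpoints are just illustrative):

using Turing

# Hypothetical ordinal-regression sketch; cutpoints are built from positive
# increments so they stay ordered without needing an ordered bijector.
@model function ordinal_reg(X, y, K)
    dim_X = size(X, 2)
    betas ~ filldist(Normal(), dim_X)
    c1 ~ Normal(0, 5)                                        # first cutpoint
    deltas ~ filldist(truncated(Normal(); lower=0), K - 2)   # positive gaps between cutpoints
    cutpoints = vcat(c1, c1 .+ cumsum(deltas))               # K - 1 ordered cutpoints
    η = X * betas
    y ~ arraydist(OrderedLogistic.(η, Ref(cutpoints)))       # y[i] ∈ 1:K
end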

Also, some repeated modelling “challenges”:

  • slightly larger models (>200k observations, thousands of RVs); I tried variational inference and the speed-up wasn’t as significant as expected
  • structure: correlation among covariates, sparsity (but that’s okay now)
  • benchmarking/testing my models (again, much better now with Tor’s benchmark suite)

In lieu of offering solutions, I’d like to try to narrow down where the problem is coming from.

TLDR: I suspect the culprit is the bijector for correlation matrices.

I’ll focus on LKJ(2, 1) because this case is really simple (only 1 free parameter) and lets us do some things analytically. Already when I sample it with NUTS as in the above example, I find some issues:

using Turing, StatsPlots, LinearAlgebra  # LinearAlgebra for `cholesky` used below

sampler = NUTS(1000, 0.65)
nchains, ndraws = 4, 1_000
@model function LKJ_basic()
    X ~ LKJ(2, 1)
end
chns_basic = sample(LKJ_basic(), sampler, MCMCThreads(), ndraws, nchains);

Now we check the post-warmup numerical error:

julia> chns_basic[:numerical_error] |> sum
38.0

38 divergent transitions is really too many. We could increase the target acceptance ratio to reduce the number (though probably not all the way to 0), but as we’ll see below, this shouldn’t be necessary. This indicates that either the geometry of the induced unconstrained distribution is bad or there is numerical instability somewhere else in the model.

For some analytical results, if X \sim \mathrm{LKJ}(2, 1), then X_{11} = X_{22} = 1 and X_{21} = X_{12} = b \sim \mathrm{Uniform}(-1, 1). We can easily check this with exact draws:

julia> Xs = rand(LKJ(2, 1), 1_000);

julia> bs = [X[2, 1] for X in Xs];

julia> plot(Uniform(-1, 1); func=cdf, label="Uniform(-1, 1)")

julia> ecdfplot!(bs; label="b");

[plot: empirical CDF of b overlaid on the Uniform(-1, 1) CDF]

For this 2x2 case, the lower Cholesky factor of X is L = \begin{pmatrix}1 & 0 \\ b & \sqrt{1 - b^2}\end{pmatrix}. b is transformed from an unconstrained parameter y with b = \tanh(y), whose derivative is \frac{\mathrm{d}b}{\mathrm{d}y} = 1 - \tanh(y)^2 = 1 - b^2.
So the induced distribution on y is p(y) = \frac{1}{2}(1 - \tanh(y)^2), which is not a hard distribution to sample:

julia> bs ≈ [cholesky(X).L[2, 1] for X in Xs]
true

julia> @model function LKJ_latent()
           y ~ Flat()
           Turing.@addlogprob! log1p(-tanh(y)^2)
       end;

julia> chns_latent = sample(LKJ_latent(), sampler, MCMCThreads(), ndraws, nchains);

julia> chns_latent[:numerical_error] |> sum
0.0

julia> ys = atanh.(bs);  # exact draws of b mapped to the unconstrained scale

julia> ecdfplot(ys; label="y exact")

julia> @df chns_latent ecdfplot!(:y; label="y NUTS")

[plot: empirical CDF of the exact y draws vs. y sampled with NUTS]

Since the geometry of the induced unconstrained distribution (what HMC actually samples) is fine, the source of numerical instability must be in the bijector itself (computing X from y), its logdetjac, the logpdf of LKJ, or the gradients of any of these.

We’ll test if the source is the bijector or its gradient by manually computing the bijection:

@model function LKJ_latent_bijector_test()
    y ~ Flat()
    b = tanh(y)
    X = [1 b; b 1]
    logJ = log1p(-b^2)
    Turing.@addlogprob! logpdf(LKJ(2, 1), X) + logJ
end
chns_bijector = sample(LKJ_latent_bijector_test(), sampler, MCMCThreads(), ndraws, nchains)

Now let’s see the results:

julia> chns_bijector[:numerical_error] |> sum
0.0

julia> ecdfplot(ys; label="y exact")

julia> @df chns_bijector ecdfplot!(:y; label="y bijector NUTS")

[plot: empirical CDF of the exact y draws vs. y from the bijector-test NUTS run]

No sampling problems, and we are indeed targeting the correct distribution! And this model included the logpdf of LKJ and its gradient, so that can’t be our problem.

By process of elimination, this would imply that the numerical instability for this simple case is in the Bijectors.jl bijector itself (or its gradient). Perhaps that’s also where the problem is for higher-dimensional correlation matrices.
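For anyone who wants to poke at this directly, the bijector itself is easy to get hold of (a rough sketch; the exact bijector type and its unconstrained representation vary across Bijectors.jl versions):

using Distributions, Bijectors

d = LKJ(2, 1)
b = bijector(d)          # the constrained -> unconstrained map for correlation matrices
X = rand(d)
Y = b(X)                 # unconstrained representation (a full matrix here, though only one entry carries information)
logabsdetjac(b, X)       # the log-Jacobian term the sampler adds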


Excellent bit of detective work! :smiley:

As a quick follow-up, I just remembered CorrBijector makes posterior improper · Issue #228 · TuringLang/Bijectors.jl · GitHub, and it seems likely to me that it’s the cause of these problems. For LKJ(2, 1), there’s just 1 degree of freedom, but Bijectors will instead sample 4. 3 of those don’t appear in any log-density term, so they’re completely unbounded on the reals. Let’s sample this with an explicit y:

@model function LKJ_latent_bijector_test2()
    y ~ filldist(Flat(), 2, 2)
    U = Bijectors._inv_link_chol_lkj(y)  # bijectors uses upper triangular factors
    b = U[1, 2]
    X = U'U
    logJ = log1p(-b^2)
    Turing.@addlogprob! logpdf(LKJ(2, 1), X) + logJ
end
julia> chns_bijector2 = sample(LKJ_latent_bijector_test2(), sampler, MCMCThreads(), ndraws, nchains)
Chains MCMC chain (1000×16×4 Array{Float64, 3}):

Iterations        = 1001:1:2000
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 19.43 seconds
Compute duration  = 67.24 seconds
parameters        = y[1,1], y[2,1], y[1,2], y[2,2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters                         mean                          std                   naive_se                        mcse         ess      rhat   ess_per_sec 
      Symbol                      Float64                      Float64                    Float64                     Float64     Float64   Float64       Float64 

      y[1,1]   -11147904746169815040.0000   103813424022984933376.0000   1641434358067363072.0000   13068514454427486208.0000      8.5923    3.5233        0.1278
      y[2,1]    -5848611813323702272.0000    56976721607190446080.0000    900881069440256640.0000    6841438468043346944.0000     10.2138    1.9909        0.1519
      y[1,2]                       0.0226                       0.8944                     0.0141                      0.0136   3895.7260    1.0001       57.9342
      y[2,2]     4836710533717538816.0000    28310642859682942976.0000    447630567300904064.0000    3509525907373429760.0000     10.7624    1.6892        0.1601

Quantiles
  parameters                          2.5%                        25.0%                      50.0%                       75.0%                        97.5% 
      Symbol                       Float64                      Float64                    Float64                     Float64                      Float64 

      y[1,1]   -248511673432142577664.0000   -12349585123783217152.0000   5015527916168183808.0000   26492620967308754944.0000   188906933481250881536.0000
      y[2,1]   -129198750965275377664.0000   -43501092142952873984.0000   -360217999869619840.0000   27753645829781962752.0000   114326851817075408896.0000
      y[1,2]                       -1.8147                      -0.5439                     0.0171                      0.5837                       1.8161
      y[2,2]    -38932511857583849472.0000    -7583069140425722880.0000    962254295777056512.0000   12060787746128771072.0000   101517804871764230144.0000

julia> chns_bijector2[:numerical_error] |> sum
29.0

So we can see how the unconstrained parameters explode, and we see a similar amount of post-warmup numerical error as with the original model.

EDIT: Let’s test the theory by using the same model but just putting a normal prior on the unused degrees of freedom:

@model function LKJ_latent_bijector_test3()
    y ~ filldist(Flat(), 2, 2)
    U = Bijectors._inv_link_chol_lkj(y)
    b = U[1, 2]
    X = U'U
    logJ = log1p(-b^2)
    Turing.@addlogprob! logpdf(LKJ(2, 1), X) + logJ
    Turing.@addlogprob! -(y[1, 1]^2 + y[2, 1]^2 + y[2, 2]^2)/2
end
julia> chns_bijector3 = sample(LKJ_latent_bijector_test3(), sampler, MCMCThreads(), ndraws, nchains)
Chains MCMC chain (1000×16×4 Array{Float64, 3}):

Iterations        = 1001:1:2000
Number of chains  = 4
Samples per chain = 1000
Wall duration     = 0.98 seconds
Compute duration  = 3.6 seconds
parameters        = y[1,1], y[2,1], y[1,2], y[2,2]
internals         = lp, n_steps, is_accept, acceptance_rate, log_density, hamiltonian_energy, hamiltonian_energy_error, max_hamiltonian_energy_error, tree_depth, numerical_error, step_size, nom_step_size

Summary Statistics
  parameters      mean       std   naive_se      mcse         ess      rhat   ess_per_sec 
      Symbol   Float64   Float64    Float64   Float64     Float64   Float64       Float64 

      y[1,1]    0.0011    0.9983     0.0158    0.0120   7042.3000    0.9995     1954.5656
      y[2,1]    0.0037    0.9945     0.0157    0.0118   5927.0604    0.9996     1645.0348
      y[1,2]   -0.0028    0.9163     0.0145    0.0118   4972.3063    1.0001     1380.0461
      y[2,2]   -0.0013    1.0046     0.0159    0.0121   6517.8198    0.9997     1808.9980

Quantiles
  parameters      2.5%     25.0%     50.0%     75.0%     97.5% 
      Symbol   Float64   Float64   Float64   Float64   Float64 

      y[1,1]   -1.9008   -0.6818    0.0110    0.6896    1.9938
      y[2,1]   -1.9811   -0.6736    0.0007    0.6845    1.9273
      y[1,2]   -1.8420   -0.5662    0.0129    0.5413    1.8464
      y[2,2]   -1.9892   -0.6589    0.0024    0.6556    2.0052


julia> chns_bijector3[:numerical_error] |> sum
0.0

So I think this confirms the theory.


Great investigation!

I think we have a few different threads running here:
a) speed tricks for logistic regression with NUTS sampler
b) alternative samplers (Pathfinder, ZigZagBoomerang’s samplers)
c) challenges with LKJ / correlation among variables


This answer is following up on thread a) - speed tricks for logistic regression.

Following the discussion here, I’ve applied Tullio.jl to the matrix-vector multiply (logits) and the logpdf calculation.

Findings:

  • It doesn’t change the “best” model (with ReverseDiff)
  • It does, however, speed up ForwardDiff-based gradient calculation (18ms → 10ms)!
  • I have tried calculating only the logpdfs + reduction in Tullio, and that was as slow as the baseline (i.e., the benefit seems to come from the matrix-vector multiplication, which Chad suggested improving at the top)

Note: Tullio should work with ForwardDiff, but I haven’t explicitly checked the grads to confirm.

# leverage tullio for matrix-vector multiply and calculating logpdf
using Tullio
@model function logreg_rhs6b(X,y)
    # these could be provided via args but it makes the function calls clunky, perhaps a parameters NamedTuple
    slab_df=4
    slab_scale=2
    eff_params=3
    dim_X=size(X,2)
    tau_0=eff_params/(dim_X-eff_params)/sqrt(size(X,1))

    # re-parametrize lambdas and lambdas squared directly to plug into lambdas_tilde
    lambdas_scale ~ filldist(InverseGamma(0.5,0.5),dim_X)
    lambdas_z ~ filldist(truncated(Normal();lower=0),dim_X)
    lambdas_sq= lambdas_z.^2 .* lambdas_scale

    # re-parametrize tau and square it directly to plug into lambdas_tilde
    tau_scale ~ InverseGamma(0.5,0.5)
    tau_z ~ truncated(Normal();lower=0)
    tau_sq = tau_z^2 .*tau_scale .* tau_0^2
    z ~ filldist(Normal(),dim_X)

    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq=slab_scale^2 * c_aux # squared already
    lambdas_tilde=sqrt.((c_sq .* lambdas_sq) ./ (c_sq .+ tau_sq .* lambdas_sq))
    betas=lambdas_tilde .* sqrt(tau_sq) .* z

    @tullio llik[i] := jax_logpdf(_,y[i]) <| (X[i,j]*betas[j]) grad=Dual
    Turing.@addlogprob! sum(llik)
    # return logits
end;
cond_model=logreg_rhs6b(X,y);
turing_suite=make_turing_suite(cond_model;adbackends=[ForwardDiffAD{40}()]) |> run

## uses only ForwardDiff as ReverseDiff crashes hard (due to Tullio)
## TULLIO used for logits and logpdf
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 2-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "evaluation" => Trial(316.750 μs)
#           "ForwardDiffAD{40, true}()" => Trial(10.315 ms)
#   "not_linked" => 2-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "evaluation" => Trial(316.667 μs)
#           "ForwardDiffAD{40, true}()" => Trial(10.624 ms)

How much guarantee do we have that nesting macros inside the @model macro will work without subtle bugs, etc.?

TL;DR: Vanilla logistic regression improved by ~20x with no tricks needed (70s in the first post to 4s now); the fastest RHS logistic regression improved slightly as well (98s now) and is much simpler to implement.


Quick update following:

devmotion’s new BernoulliLogit improved gradient evaluations and sampling ~3x (on ForwardDiff)

@model function logreg_vanilla2(X)
    dim_X = size(X, 2)
    betas ~ filldist(Normal(), dim_X)
    logits = X * betas
    y ~ arraydist(LazyArray(@~ BernoulliLogit.(logits)))
    return nothing
end;
cond_model = logreg_vanilla2(X) | (; y);
turing_suite = make_turing_suite(cond_model; adbackends=DEFAULT_ADBACKENDS) |> run
@time chain = sample(cond_model, Turing.NUTS(500, 0.65; max_depth=8), 100; progress=false);

# old version -- Dist 0.25.70
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(198.876 ms)
#           "ForwardDiffAD{100, true}()" => Trial(5.311 ms)
#           "evaluation" => Trial(1.565 ms)
#           "ForwardDiffAD{40, true}()" => Trial(5.236 ms)
#           "ReverseDiffAD{true}()" => Trial(52.074 ms)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(191.500 ms)
#           "ForwardDiffAD{100, true}()" => Trial(5.470 ms)
#           "evaluation" => Trial(1.577 ms)
#           "ForwardDiffAD{40, true}()" => Trial(5.406 ms)
#           "ReverseDiffAD{true}()" => Trial(49.968 ms)

# new version -- Dist 0.25.80
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(178.014 ms)
#           "ForwardDiffAD{100, true}()" => Trial(1.745 ms)
#           "evaluation" => Trial(662.792 μs)
#           "ForwardDiffAD{40, true}()" => Trial(1.747 ms)
#           "ReverseDiffAD{true}()" => Trial(52.078 ms)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(177.480 ms)
#           "ForwardDiffAD{100, true}()" => Trial(1.746 ms)
#           "evaluation" => Trial(663.083 μs)
#           "ForwardDiffAD{40, true}()" => Trial(1.745 ms)
#           "ReverseDiffAD{true}()" => Trial(45.275 ms)

# Similar changes to the inference time

# old Distributions.jl
# ┌ Info: Found initial step size
# └   ϵ = 0.025
#  89.040300 seconds (1.10 M allocations: 43.573 GiB, 3.02% gc time, 0.00% compilation time)

# new Distributions.jl
# ┌ Info: Found initial step size
# └   ϵ = 0.05
#  25.547634 seconds (864.86 k allocations: 329.484 MiB, 0.19% gc time, 0.01% compilation time)

Unfortunately, the custom JAX logpdf was still beating it ~2.5x when using ReverseDiff + tape, and for some reason ReverseDiff + tape didn’t help with the new BernoulliLogit (it was almost 2 orders of magnitude slower…)

# use custom logpdf
@model function logreg_vanilla4(X, y)
    dim_X = size(X, 2)
    betas ~ filldist(Normal(), dim_X)
    logits = X * betas
    Turing.@addlogprob! sum(jax_logpdf.(logits, y))
    return nothing
end;
cond_model = logreg_vanilla4(X, y);
turing_suite = make_turing_suite(cond_model; adbackends=DEFAULT_ADBACKENDS) |> run

# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(795.791 μs)
#           "ForwardDiffAD{100, true}()" => Trial(2.919 ms)
#           "evaluation" => Trial(435.500 μs)
#           "ForwardDiffAD{40, true}()" => Trial(2.917 ms)
#           "ReverseDiffAD{true}()" => Trial(650.583 μs)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(801.125 μs)
#           "ForwardDiffAD{100, true}()" => Trial(2.971 ms)
#           "evaluation" => Trial(434.750 μs)
#           "ForwardDiffAD{40, true}()" => Trial(2.929 ms)
#           "ReverseDiffAD{true}()" => Trial(650.375 μs)

Enter Tor and his investigation into ReverseDiff and Zygote taking the slow paths (explanation also here). Gradient evaluation drops ~2.5x compared to ForwardDiff, which brings inference down to 4 seconds.

# Tor's performant model with BroadcastArray
@model function logreg_vanilla6(X)
    dim_X = size(X, 2)
    betas ~ filldist(Normal(), dim_X)
    logits = X * betas
    # replacing @addlogprob! with BroadcastArray from Torfjelde's PR
    y ~ arraydist(BroadcastArray(BernoulliLogit, logits))
    return nothing
end;
cond_model = logreg_vanilla6(X) | (; y);
turing_suite = make_turing_suite(cond_model; adbackends=DEFAULT_ADBACKENDS) |> run

# 2-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#         tags: []
#         "ReverseDiffAD{false}()" => Trial(814.792 μs)
#         "ForwardDiffAD{100, true}()" => Trial(2.568 ms)
#         "evaluation" => Trial(418.792 μs)
#         "ForwardDiffAD{40, true}()" => Trial(2.574 ms)
#         "ReverseDiffAD{true}()" => Trial(619.542 μs)
# "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#         tags: []
#         "ReverseDiffAD{false}()" => Trial(788.375 μs)
#         "ForwardDiffAD{100, true}()" => Trial(2.565 ms)
#         "evaluation" => Trial(418.541 μs)
#         "ForwardDiffAD{40, true}()" => Trial(2.562 ms)
#         "ReverseDiffAD{true}()" => Trial(623.292 μs)

Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
@time chain_v = with_logger(NullLogger()) do
    chain_v = sample(cond_model, Turing.NUTS(500, 0.65; max_depth=8), 100; progress=false)
end;
# 4.276798 seconds (411.30 k allocations: 59.751 MiB, 0.38% gc time, 0.12% compilation time)

So what does that mean for the RHS model? Gradient evaluation is a tiny bit faster than in logreg_rhs4, which leads to slightly faster inference (98s vs. 104s before).

The biggest highlight? A beginner could get this speed without any shenanigans (custom logpdf, @addlogprob!, etc.); all you need is the BroadcastArray and for the PR to be merged.

# Alternative parametrization based on: https://arxiv.org/pdf/1707.01694.pdf#page11 (Appendix C2)
# and BroadcastArray from Torfjelde's PR
@model function logreg_rhs7(X)
    slab_df = 4
    slab_scale = 2
    eff_params = 3
    dim_X = size(X, 2)
    tau_0 = eff_params / (dim_X - eff_params) / sqrt(size(X, 1))

    # re-parametrize lambdas and lambdas squared directly to plug into lambdas_tilde
    lambdas_scale ~ filldist(InverseGamma(0.5, 0.5), dim_X)
    lambdas_z ~ filldist(truncated(Normal(); lower=0), dim_X)
    lambdas_sq = lambdas_z .^ 2 .* lambdas_scale

    # re-parametrize tau and square it directly to plug into lambdas_tilde
    tau_scale ~ InverseGamma(0.5, 0.5)
    tau_z ~ truncated(Normal(); lower=0)
    tau_sq = tau_z^2 .* tau_scale .* tau_0^2
    z ~ filldist(Normal(), dim_X)

    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq = slab_scale^2 * c_aux # squared already
    lambdas_tilde = sqrt.((c_sq .* lambdas_sq) ./ (c_sq .+ tau_sq .* lambdas_sq))
    betas = lambdas_tilde .* sqrt(tau_sq) .* z
    logits = X * betas
    y ~ arraydist(BroadcastArray(BernoulliLogit, logits))
    return nothing
end;
cond_model = logreg_rhs7(X) | (; y);
turing_suite = make_turing_suite(cond_model; adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(991.750 μs)
#           "ForwardDiffAD{100, true}()" => Trial(19.483 ms)
#           "evaluation" => Trial(329.375 μs)
#           "ForwardDiffAD{40, true}()" => Trial(18.484 ms)
#           "ReverseDiffAD{true}()" => Trial(609.958 μs)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(1.115 ms)
#           "ForwardDiffAD{100, true}()" => Trial(19.978 ms)
#           "evaluation" => Trial(328.125 μs)
#           "ForwardDiffAD{40, true}()" => Trial(18.959 ms)
#           "ReverseDiffAD{true}()" => Trial(595.208 μs)

@time chain = with_logger(NullLogger()) do
    chain = sample(cond_model, Turing.NUTS(500, 0.65; max_depth=8), 100; progress=false)
end;
# 98.191588 seconds (162.01 M allocations: 6.176 GiB, 0.66% gc time, 0.01% compilation time)

Can you provide that so I can examine it? I’m interested to see the difference in the customization.

It’s in the first post in the section “Logprob benchmarking”

# custom implementation in Julia
# from https://www.tensorflow.org/api_docs/python/tf/nn/sigmoid_cross_entropy_with_logits
using LogExpFunctions: log1pexp  # numerically stable log(1 + exp(x)); also re-exported by StatsFuns
jax_logpdf(x, y) = -(max(0, x) + log1pexp(-abs(x)) - x * y)

I learned about it from JAX docs when doing the comparison, but you can see that it comes from TF.

I believe that’s what Devmotion implemented in Distributions.jl.
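If anyone wants to double-check the equivalence, here’s a quick sanity check using the jax_logpdf above (this assumes Distributions ≥ 0.25.80, which exports BernoulliLogit, and LogExpFunctions for logistic):

using Distributions, LogExpFunctions

# Compare the custom logpdf against Distributions' BernoulliLogit on random logits/labels
logits = randn(1_000)
labels = rand.(Bernoulli.(logistic.(logits)))
maximum(abs, jax_logpdf.(logits, labels) .- logpdf.(BernoulliLogit.(logits), labels))
# should be ~0, i.e. identical up to floating-point rounding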