Case study: Speeding up a logistic regression with RHS prior (Turing vs Numpyro) - any tricks I'm missing?

Great investigation!

I think we have a few different threads running here:
a) speed tricks for logistic regression with NUTS sampler
b) alternative samplers (Pathfinder, ZigZagBoomerang’s samplers)
c) challenges with LKJ / correlation among variables


This answer follows up on thread a) - speed tricks for logistic regression.

Following the discussion here, I applied Tullio.jl to the matrix-vector multiply (logits) and the logpdf calculation.

Findings:

  • It doesn’t change the “best” model (with ReverseDiff).
  • It does, however, speed up the ForwardDiff-based gradient calculation (18 ms → 10 ms)!
  • I also tried computing only the logpdfs + reduction in Tullio, and that was as slow as the baseline (i.e., the benefit seems to come from the matrix-vector multiplication, which Chad suggested improving at the top of the thread).
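For reference, the core of the trick is replacing the BLAS matrix-vector multiply with an `@tullio` kernel; a minimal standalone sketch (with made-up data purely to illustrate):

```julia
using Tullio

# hypothetical small problem, just to show the kernel
X = randn(100, 5)
betas = randn(5)

# einsum-style matrix-vector multiply: logits[i] = sum_j X[i,j] * betas[j]
@tullio logits[i] := X[i, j] * betas[j]

@assert logits ≈ X * betas
```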

Note: Tullio should work with ForwardDiff, but I haven’t explicitly checked the grads to confirm.
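For completeness, `jax_logpdf` used below is the numerically stable Bernoulli-logit log-density from earlier in the thread; the exact definition there may differ, but a minimal sketch is:

```julia
using LogExpFunctions: log1pexp

# log-density of y ∈ {0,1} under Bernoulli(logit = x):
# y = 1 → -log1pexp(-x) = log σ(x);  y = 0 → -log1pexp(x) = log(1 - σ(x))
jax_logpdf(x, y) = -log1pexp((1 - 2y) * x)
```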

# leverage tullio for matrix-vector multiply and calculating logpdf
using Tullio
@model function logreg_rhs6b(X,y)
    # these could be provided via args but it makes the function calls clunky, perhaps a parameters NamedTuple
    slab_df=4
    slab_scale=2
    eff_params=3
    dim_X=size(X,2)
    tau_0=eff_params/(dim_X-eff_params)/sqrt(size(X,1))

    # re-parametrize lambdas and lambdas squared directly to plug into lambdas_tilde
    lambdas_scale ~ filldist(InverseGamma(0.5, 0.5), dim_X)
    lambdas_z ~ filldist(truncated(Normal(); lower=0), dim_X)
    lambdas_sq = lambdas_z .^ 2 .* lambdas_scale

    # re-parametrize tau and square it directly to plug into lambdas_tilde
    tau_scale ~ InverseGamma(0.5, 0.5)
    tau_z ~ truncated(Normal(); lower=0)
    tau_sq = tau_z^2 * tau_scale * tau_0^2
    z ~ filldist(Normal(),dim_X)

    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq = slab_scale^2 * c_aux # squared already
    lambdas_tilde = sqrt.((c_sq .* lambdas_sq) ./ (c_sq .+ tau_sq .* lambdas_sq))
    betas = lambdas_tilde .* sqrt(tau_sq) .* z

    # jax_logpdf: numerically stable Bernoulli-logit logpdf (defined earlier in the thread)
    @tullio llik[i] := jax_logpdf(_,y[i]) <| (X[i,j]*betas[j]) grad=Dual
    Turing.@addlogprob! sum(llik)
    # return logits
end;
cond_model = logreg_rhs6b(X, y);
# make_turing_suite comes from TuringBenchmarking.jl
turing_suite = make_turing_suite(cond_model; adbackends=[ForwardDiffAD{40}()]) |> run

## uses only ForwardDiff as ReverseDiff crashes hard (due to Tullio)
## TULLIO used for logits and logpdf
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 2-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "evaluation" => Trial(316.750 μs)
#           "ForwardDiffAD{40, true}()" => Trial(10.315 ms)
#   "not_linked" => 2-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "evaluation" => Trial(316.667 μs)
#           "ForwardDiffAD{40, true}()" => Trial(10.624 ms)
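With the suite confirming ForwardDiff works, actually sampling the Tullio variant would look something like this (a sketch only: `cond_model` is from above, and the NUTS settings are placeholders I haven’t benchmarked):

```julia
using Turing

# ReverseDiff crashes on the Tullio kernel, so pin the AD backend to ForwardDiff
Turing.setadbackend(:forwarddiff)

# 500 adaptation steps, 0.65 target acceptance — placeholder settings
chain = sample(cond_model, NUTS(500, 0.65), 1_000)
```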