Great investigation!
I think we have a few different threads running here:
a) speed tricks for logistic regression with NUTS sampler
b) alternative samplers (Pathfinder, ZigZagBoomerang’s samplers)
c) challenges with LKJ / correlation among variables
This answer follows up on thread a): speed tricks for logistic regression.
Following the discussion here, I’ve used Tullio.jl for the matrix-vector multiply (logits) and the logpdf calculation.
Findings:
- It doesn’t change which model is “best” (that is still the one with ReverseDiff)
- It does, however, speed up the ForwardDiff-based gradient calculation (18 ms → 10 ms)!
- I also tried using Tullio only for the logpdfs + reduction, and that was as slow as the baseline (i.e., the benefit seems to come from the matrix-vector multiplication, which Chad suggested improving at the top)
Note: Tullio should work with ForwardDiff, but I haven’t explicitly checked the gradients to confirm.
# leverage tullio for matrix-vector multiply and calculating logpdf
using Tullio
@model function logreg_rhs6b(X,y)
    # these could be provided via args but it makes the function calls clunky, perhaps a parameters NamedTuple
    slab_df=4
    slab_scale=2
    eff_params=3
    dim_X=size(X,2)
    tau_0=eff_params/(dim_X-eff_params)/sqrt(size(X,1))

    # re-parametrize lambdas and lambdas squared directly to plug into lambdas_tilde
    lambdas_scale ~ filldist(InverseGamma(0.5,0.5),dim_X)
    lambdas_z ~ filldist(truncated(Normal();lower=0),dim_X)
    lambdas_sq = lambdas_z.^2 .* lambdas_scale

    # re-parametrize tau and square it directly to plug into lambdas_tilde
    tau_scale ~ InverseGamma(0.5,0.5)
    tau_z ~ truncated(Normal();lower=0)
    tau_sq = tau_z^2 * tau_scale * tau_0^2

    z ~ filldist(Normal(),dim_X)
    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq = slab_scale^2 * c_aux # squared already
    lambdas_tilde = sqrt.((c_sq .* lambdas_sq) ./ (c_sq .+ tau_sq .* lambdas_sq))
    betas = lambdas_tilde .* sqrt(tau_sq) .* z

    # logits via matrix-vector product, with the logpdf applied as a finaliser
    # (jax_logpdf is defined outside this snippet)
    @tullio llik[i] := jax_logpdf(_,y[i]) <| (X[i,j]*betas[j]) grad=Dual
    Turing.@addlogprob! sum(llik)
    # return logits
end;
cond_model=logreg_rhs6b(X,y);
turing_suite=make_turing_suite(cond_model;adbackends=[ForwardDiffAD{40}()]) |> run
## uses only ForwardDiff as ReverseDiff crashes hard (due to Tullio)
## TULLIO used for logits and logpdf
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 2-element BenchmarkTools.BenchmarkGroup:
#     tags: []
#     "evaluation" => Trial(316.750 μs)
#     "ForwardDiffAD{40, true}()" => Trial(10.315 ms)
#   "not_linked" => 2-element BenchmarkTools.BenchmarkGroup:
#     tags: []
#     "evaluation" => Trial(316.667 μs)
#     "ForwardDiffAD{40, true}()" => Trial(10.624 ms)
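
On the unchecked gradients mentioned in the note: one way to check them would be to compare ForwardDiff gradients of the Tullio expression against a plain broadcast version on toy data. The sketch below is standalone and is not the model above. `bernoulli_logit_logpdf` is a hypothetical stand-in for `jax_logpdf` (I'm assuming a Bernoulli log-density parameterised by the logit), and `X`, `y`, `betas` are random.

```julia
using Tullio, ForwardDiff, Random

# Hypothetical stand-in for jax_logpdf (assumption): Bernoulli log-density given a logit
bernoulli_logit_logpdf(logit, y) = y * logit - log1p(exp(logit))

# Tullio version: matrix-vector product, with the logpdf applied as a finaliser
function loglik_tullio(betas, X, y)
    @tullio llik[i] := bernoulli_logit_logpdf(_, y[i]) <| (X[i, j] * betas[j]) grad=Dual
    return sum(llik)
end

# Plain reference version using a regular matrix-vector product and broadcasting
loglik_plain(betas, X, y) = sum(bernoulli_logit_logpdf.(X * betas, y))

# Toy data, unrelated to the dataset used in the benchmarks
Random.seed!(1)
X = randn(50, 5)
y = Float64.(rand(50) .< 0.5)
betas = randn(5)

g_tullio = ForwardDiff.gradient(b -> loglik_tullio(b, X, y), betas)
g_plain = ForwardDiff.gradient(b -> loglik_plain(b, X, y), betas)

# Largest absolute difference between the two gradients
@show maximum(abs.(g_tullio .- g_plain))
```

If the two gradients agree to floating-point tolerance, that would support using the Tullio version with ForwardDiff. It says nothing about ReverseDiff, which crashes here.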