Case study: Speeding up a logistic regression with RHS prior (Turing vs Numpyro) - any tricks I'm missing?

TL;DR Vanilla logistic regression improved by ~20x with no tricks needed (from 70s in the first post to 4s now); the fastest RHS logistic regression improved slightly as well (98s now) and is now much simpler to implement.


Quick update following:

devmotion’s new BernoulliLogit improved gradient evaluation and sampling by ~3x (with ForwardDiff):
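For readers jumping in here: the snippets below assume a preamble roughly like the one sketched next. This is my assumption of the packages involved (make_turing_suite and DEFAULT_ADBACKENDS come from TuringBenchmarking.jl), not the exact setup from the first post.

using Turing, Distributions
using LazyArrays                 # LazyArray, BroadcastArray, @~
using TuringBenchmarking         # make_turing_suite, DEFAULT_ADBACKENDS
using BenchmarkTools
using ReverseDiff                # enables the :reversediff AD backend
using Logging: NullLogger, with_logger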

@model function logreg_vanilla2(X)
    dim_X = size(X, 2)
    betas ~ filldist(Normal(), dim_X)
    logits = X * betas
    # lazily broadcast BernoulliLogit over the logits (LazyArray + @~ from LazyArrays.jl)
    y ~ arraydist(LazyArray(@~ BernoulliLogit.(logits)))
    return nothing
end;
cond_model = logreg_vanilla2(X) | (; y);
turing_suite = make_turing_suite(cond_model; adbackends=DEFAULT_ADBACKENDS) |> run
@time chain = sample(cond_model, Turing.NUTS(500, 0.65; max_depth=8), 100; progress=false);

# old version -- Distributions.jl 0.25.70
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(198.876 ms)
#           "ForwardDiffAD{100, true}()" => Trial(5.311 ms)
#           "evaluation" => Trial(1.565 ms)
#           "ForwardDiffAD{40, true}()" => Trial(5.236 ms)
#           "ReverseDiffAD{true}()" => Trial(52.074 ms)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(191.500 ms)
#           "ForwardDiffAD{100, true}()" => Trial(5.470 ms)
#           "evaluation" => Trial(1.577 ms)
#           "ForwardDiffAD{40, true}()" => Trial(5.406 ms)
#           "ReverseDiffAD{true}()" => Trial(49.968 ms)

# new version -- Distributions.jl 0.25.80
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(178.014 ms)
#           "ForwardDiffAD{100, true}()" => Trial(1.745 ms)
#           "evaluation" => Trial(662.792 μs)
#           "ForwardDiffAD{40, true}()" => Trial(1.747 ms)
#           "ReverseDiffAD{true}()" => Trial(52.078 ms)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(177.480 ms)
#           "ForwardDiffAD{100, true}()" => Trial(1.746 ms)
#           "evaluation" => Trial(663.083 μs)
#           "ForwardDiffAD{40, true}()" => Trial(1.745 ms)
#           "ReverseDiffAD{true}()" => Trial(45.275 ms)

# Similar improvement in the inference time

# old Distributions.jl
# ┌ Info: Found initial step size
# └   ϵ = 0.025
#  89.040300 seconds (1.10 M allocations: 43.573 GiB, 3.02% gc time, 0.00% compilation time)

# new Distributions.jl
# ┌ Info: Found initial step size
# └   ϵ = 0.05
#  25.547634 seconds (864.86 k allocations: 329.484 MiB, 0.19% gc time, 0.01% compilation time)

Unfortunately, the custom JAX-style logpdf was still beating it ~2.5x with ReverseDiff+tape, and for some reason ReverseDiff+tape didn’t get the same benefit with the new BernoulliLogit (it was almost 2 orders of magnitude slower…)
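(jax_logpdf is the Bernoulli-logit log-density ported from the Numpyro/JAX version in the first post; a minimal sketch of such a function, assuming the standard logit parameterization and not the original definition, could look like this:)

using LogExpFunctions: log1pexp

# Bernoulli log-density in terms of the logit:
# log p(y | logit) = y * logit - log(1 + exp(logit))
jax_logpdf(logit, y) = y * logit - log1pexp(logit)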

# use custom logpdf
@model function logreg_vanilla4(X, y)
    dim_X = size(X, 2)
    betas ~ filldist(Normal(), dim_X)
    logits = X * betas
    Turing.@addlogprob! sum(jax_logpdf.(logits, y))
    return nothing
end;
cond_model = logreg_vanilla4(X, y);
turing_suite = make_turing_suite(cond_model; adbackends=DEFAULT_ADBACKENDS) |> run

# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(795.791 μs)
#           "ForwardDiffAD{100, true}()" => Trial(2.919 ms)
#           "evaluation" => Trial(435.500 μs)
#           "ForwardDiffAD{40, true}()" => Trial(2.917 ms)
#           "ReverseDiffAD{true}()" => Trial(650.583 μs)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(801.125 μs)
#           "ForwardDiffAD{100, true}()" => Trial(2.971 ms)
#           "evaluation" => Trial(434.750 μs)
#           "ForwardDiffAD{40, true}()" => Trial(2.929 ms)
#           "ReverseDiffAD{true}()" => Trial(650.375 μs)

Enter Tor and his investigation into ReverseDiff and Zygote taking the slow paths (explanation also here). Gradient evaluation goes down ~2.5x compared to ForwardDiff, which brings inference down to ~4 seconds.

# Tor's performant model with BroadcastArray
@model function logreg_vanilla6(X)
    dim_X = size(X, 2)
    betas ~ filldist(Normal(), dim_X)
    logits = X * betas
    # replacing @addlogprob! with BroadcastArray from Torfjelde's PR
    y ~ arraydist(BroadcastArray(BernoulliLogit, logits))
    return nothing
end;
cond_model = logreg_vanilla6(X) | (; y);
turing_suite = make_turing_suite(cond_model; adbackends=DEFAULT_ADBACKENDS) |> run

# 2-element BenchmarkTools.BenchmarkGroup:
# tags: []
# "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#         tags: []
#         "ReverseDiffAD{false}()" => Trial(814.792 μs)
#         "ForwardDiffAD{100, true}()" => Trial(2.568 ms)
#         "evaluation" => Trial(418.792 μs)
#         "ForwardDiffAD{40, true}()" => Trial(2.574 ms)
#         "ReverseDiffAD{true}()" => Trial(619.542 μs)
# "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#         tags: []
#         "ReverseDiffAD{false}()" => Trial(788.375 μs)
#         "ForwardDiffAD{100, true}()" => Trial(2.565 ms)
#         "evaluation" => Trial(418.541 μs)
#         "ForwardDiffAD{40, true}()" => Trial(2.562 ms)
#         "ReverseDiffAD{true}()" => Trial(623.292 μs)

# switch to ReverseDiff with a compiled (cached) tape
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)
@time chain_v = with_logger(NullLogger()) do
    chain_v = sample(cond_model, Turing.NUTS(500, 0.65; max_depth=8), 100; progress=false)
end;
# 4.276798 seconds (411.30 k allocations: 59.751 MiB, 0.38% gc time, 0.12% compilation time)

So what does that mean for the RHS model? Gradient evaluation is a tiny bit faster than in logreg_rhs4, which leads to slightly faster inference (98s vs 104s before).

The biggest highlight? A beginner could get this speed without any shenanigans (custom logpdf, @addlogprob!, etc.); all you need is BroadcastArray and for the PR to be merged.

# Alternative parametrization based on: https://arxiv.org/pdf/1707.01694.pdf#page11 (Appendix C2)
# and BroadcastArray from Torfjelde's PR
@model function logreg_rhs7(X)
    slab_df = 4
    slab_scale = 2
    eff_params = 3
    dim_X = size(X, 2)
    tau_0 = eff_params / (dim_X - eff_params) / sqrt(size(X, 1))

    # re-parametrize lambdas and square them directly to plug into lambdas_tilde
    lambdas_scale ~ filldist(InverseGamma(0.5, 0.5), dim_X)
    lambdas_z ~ filldist(truncated(Normal(); lower=0), dim_X)
    lambdas_sq = lambdas_z .^ 2 .* lambdas_scale

    # re-parametrize tau and square it directly to plug into lambdas_tilde
    tau_scale ~ InverseGamma(0.5, 0.5)
    tau_z ~ truncated(Normal(); lower=0)
    tau_sq = tau_z^2 .* tau_scale .* tau_0^2
    z ~ filldist(Normal(), dim_X)

    c_aux ~ InverseGamma(0.5 * slab_df, 0.5 * slab_df)
    c_sq = slab_scale^2 * c_aux # squared already
    lambdas_tilde = sqrt.((c_sq .* lambdas_sq) ./ (c_sq .+ tau_sq .* lambdas_sq))
    betas = lambdas_tilde .* sqrt(tau_sq) .* z
    logits = X * betas
    y ~ arraydist(BroadcastArray(BernoulliLogit, logits))
    return nothing
end;
cond_model = logreg_rhs7(X) | (; y);
turing_suite = make_turing_suite(cond_model; adbackends=DEFAULT_ADBACKENDS) |> run
# 2-element BenchmarkTools.BenchmarkGroup:
#   tags: []
#   "linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(991.750 μs)
#           "ForwardDiffAD{100, true}()" => Trial(19.483 ms)
#           "evaluation" => Trial(329.375 μs)
#           "ForwardDiffAD{40, true}()" => Trial(18.484 ms)
#           "ReverseDiffAD{true}()" => Trial(609.958 μs)
#   "not_linked" => 5-element BenchmarkTools.BenchmarkGroup:
#           tags: []
#           "ReverseDiffAD{false}()" => Trial(1.115 ms)
#           "ForwardDiffAD{100, true}()" => Trial(19.978 ms)
#           "evaluation" => Trial(328.125 μs)
#           "ForwardDiffAD{40, true}()" => Trial(18.959 ms)
#           "ReverseDiffAD{true}()" => Trial(595.208 μs)

@time chain = with_logger(NullLogger()) do
    chain = sample(cond_model, Turing.NUTS(500, 0.65; max_depth=8), 100; progress=false)
end;
# 98.191588 seconds (162.01 M allocations: 6.176 GiB, 0.66% gc time, 0.01% compilation time)