Case study: Speeding up a logistic regression with RHS prior (Turing vs Numpyro) - any tricks I'm missing?

Hi @svilupp, I got a sticky Zig-Zag sampler running on the example with Tilde:

import Pkg
cd(@__DIR__)
# Pkg.activate(@__DIR__)

using Tilde, Pathfinder,  PDMats, StructArrays
using ForwardDiff
using ForwardDiff: Dual
using LinearAlgebra, Random, Statistics, StatsBase, SparseArrays
using ZigZagBoomerang
using ZigZagBoomerang: StickyBarriers, StructuredTarget, StickyUpperBounds, StickyFlow, EndTime
using MCMCChains
using ArraysOfArrays

# Configuration
Random.seed!(1)

# Sampler settings
κ = 0.01                # stickyness of the barriers at 0 (spike weight 1/κ)
T = 5000.0              # total continuous sampling time
c = 0.01                # starting value for the per-coordinate upper-bound constants
progress = true         # show progress bar
PLOT = true             # plot posterior trace
nsamples = 200          # number of grid points used to discretise the trace
Δt = T / nsamples       # spacing of that grid

# Generate mock data
println("Data...")
# Design matrix: an intercept column followed by 22 standard-normal covariates.
X = hcat(ones(20000), randn(20000, 22))
Xt = permutedims(X)     # one column per observation, as the model expects
n, d = size(X)
# True coefficients: only positions 1, 5 and 11 are non-zero.
betas = [-0.8; zeros(3); 1.0; zeros(5); 0.9; zeros(12)]
@assert length(betas) == d
# Draw each response from a Bernoulli with logit equal to the linear predictor.
y = (η -> rand(Bernoulli(logitp = η))).(X * betas)
@info "Important coef positions: $(findall(betas.!=0)), Average rate: $(mean(y))"

# Simple logistic model
# Prior: the d-fold product `Normal() ^ d` over the coefficients (the "slab"; the
# spike at 0 is not part of the model — it is supplied by the sticky sampler's barriers).
# Likelihood: one Bernoulli observation per column of `Xt`.
model_lr = @model (Xt, y) begin
    d, n = size(Xt)    # Xt is d × n: coefficients × observations
    θ ~ Normal() ^ d
    for j in 1:n
        # Linear predictor (logit) for observation j; `view` avoids copying the column.
        logitp = view(Xt, :, j)' * θ
        y[j] ~ Bernoulli(logitp = logitp)
    end
end

# Gradients
"""
    make_grads(model_lr, At, y, d)

Condition `model_lr(At, y)` on the observations `y` and build the objective
functions used by Pathfinder and the sticky Zig-Zag sampler.

Returns `(post, ℓ, ∇neglogp!, ∂neglogp)`:
- `post`: the model conditioned on `y`;
- `ℓ(θ)`: log density of `post` at `transform(as(post), θ)`;
- `∇neglogp!(g, t, x, args...)`: writes the full gradient of `-ℓ` at `x` into `g`
  (`t` and any extra arguments are ignored) and returns `nothing`;
- `∂neglogp(x, i)`: the `i`-th partial derivative of `-ℓ` at `x`.

`d` is the parameter dimension; it sizes the ForwardDiff gradient configuration.
"""
function make_grads(model_lr, At, y, d)
    post = model_lr(At, y) | (; y)
    as_post = as(post)
    obj(θ) = -Tilde.unsafe_logdensityof(post, transform(as_post, θ))
    ℓ(θ) = -obj(θ)

    # The config (and its chunk) must match the parameter dimension: ForwardDiff's
    # chunk-mode `gradient!` rejects a chunk larger than `length(x)`, so a hard-coded
    # size breaks as soon as it differs from `d`. `zeros(d)` is only a size/eltype
    # template and, unlike `rand`, does not advance the seeded global RNG.
    gconfig = ForwardDiff.GradientConfig(obj, zeros(d), ForwardDiff.Chunk{d}())
    function ∇neglogp!(g, t, x, args...)
        ForwardDiff.gradient!(g, obj, x, gconfig)
        return nothing
    end

    function ∂neglogp(x, i)
        # Single forward-mode pass seeded only in direction `i`, giving one partial.
        # should use StructArrays, seems tilde broke that
        xdual = [Dual{}(x[j], 1.0 * (i == j)) for j in eachindex(x)]
        return ForwardDiff.partials(obj(xdual))[]
    end

    return post, ℓ, ∇neglogp!, ∂neglogp
end
post, ℓ, ∇neglogp!, ∂neglogp = make_grads(model_lr, Xt, y, d)

# Pathfinding
println("Pathfinder...")
init_scale = 1
# Reuse a `pf_result` already defined in this session (e.g. from an earlier run of
# the script) instead of refitting Pathfinder every time.
if !@isdefined pf_result
    @time pf_result = pathfinder(ℓ; dim=d, init_scale)
end
# Diagonal of Pathfinder's fitted covariance; its inverse, stored sparse, is the Γ
# used by the ZigZag flow and by the upper bounds.
M = PDMats.PDiagMat(diag(pf_result.fit_distribution.Σ))
Γ = sparse(inv(M))
# Dense alternative: Γ = sparse(inv(pf_result.fit_distribution.Σ))
μ = pf_result.fit_distribution.μ
# Start at the Pathfinder mean, but as an independent array: `μ` belongs to the
# cached `pf_result` and is also handed to `ZigZag(Γ, μ)`, so the start position
# must not alias it. NOTE(review): confirm whether stickyzz copies u0 itself.
x0 = copy(μ)
# Initial velocity: standard-normal draws unwhitened by M.
v0 = PDMats.unwhiten(M, randn(length(x0)))



# Sticky sampler
println("Sticky sampler...")
# Fix the dimension from the starting point first, so every size below agrees with x0.
d = length(x0)
# One barrier per coordinate at 0 (sticky from both sides, stickyness κ): this is
# where the spike at zero of the spike-and-slab prior enters the sampler.
barriers = [StickyBarriers((0.0, 0.0), (:sticky, :sticky), (κ, κ)) for i in 1:d]
t0 = fill(0.0, d)
u0 = (t0, x0, v0)
# Dependency structure: partial derivative i is declared to depend on all of 1:d.
target = StructuredTarget([i => 1:d for i in 1:d], ∂neglogp)
flow = StickyFlow(ZigZag(Γ, μ))
strong_upperbounds = false
adapt = true
multiplier = 1.7 # increase bounds
G = target.G
# Second neighbourhood structure taken from the sparsity pattern of Γ
# (stored row indices of each column).
G1 = [i => rowvals(Γ)[nzrange(Γ, i)] for i in axes(Γ, 1)]
upper_bounds = StickyUpperBounds(G, G1, Γ, fill(c, d); adapt=adapt, strong=strong_upperbounds, multiplier=multiplier)
end_time = EndTime(T)
# Sparse product of row i of Γ with x; not referenced elsewhere in this script.
∇ϕ(x, i) = ZigZagBoomerang.idot(Γ, i, x) # sparse computation
elapsed_time = @elapsed begin
    trace, _, _, acc = @time stickyzz(u0, target, flow, upper_bounds, barriers, end_time; progress=progress)
end
@info "Upper bounds: $(upper_bounds.c)"
println("acc ", acc.acc/acc.num)


# Plot continuous trace
if PLOT
    # Split the continuous trace into event times and the positions at those times.
    ts, xs = ZigZagBoomerang.sep(collect(trace))
    println("Plot...")
    colors = [:green, :red, :blue, :violet]
    using GLMakie
    fig1 = fig = Figure()
    r = 1:length(ts)
    ax = Axis(fig[1,1], title = "trace")
    # Coordinates to show: 1, 5 and 11 have non-zero true coefficients, 2 does not.
    is = [1, 2, 5, 11]
    for (k, coord) in enumerate(is)
        col = colors[k]
        tr = ts[r]
        # Sampled path of this coordinate (solid) ...
        lines!(ax, tr, getindex.(xs[r], coord), color=col)
        # ... against its true value (dashed, semi-transparent).
        truth = fill(betas[coord], length(r))
        lines!(ax, tr, truth, linestyle=:dash, color = (col, 0.5))
    end
    display(fig)
end

# Your samples
# Discretise the trace on a grid of spacing Δt, keep only the positions, and stack
# those vectors side by side into one matrix (one column per grid point).
samples = let grid = ZigZagBoomerang.discretise(trace, Δt)
    positions = ZigZagBoomerang.sep(grid)[2]
    flatview(VectorOfSimilarVectors(positions))
end

# Chains takes iterations along the first axis, hence the transpose; the sampler's
# wall-clock time is recorded as the chain's stop time.
chain = setinfo(MCMCChains.Chains(samples'), (; start_time=0.0, stop_time=elapsed_time));
chain

Doesn’t seem like a bad idea :slight_smile: Note that I replaced the horseshoe prior with a spike-and-slab prior (spike at 0 with weight 1/κ). The sampler itself handles the spike, so the likelihood/model only contains the slab (Normal) part.

(Link to gist with project)

3 Likes