Speed up multilevel model in Turing.jl

Hi everybody,

I’ve coded up a multilevel model with random intercepts and slopes in Turing (see below). However, my model is very slow compared to similar models in Stan. Does anybody have any insights on how to speed things up further?

Thanks

using Turing
using Random
using MCMCChains
using LinearAlgebra: I
using StatsPlots

@model function linreg(X, y, idx; P = size(X,2), G = length(unique(idx)))

    # prior for overall noise
    σ ~ Exponential(1)
    
    # priors over means of group-level intercept and slopes 
    α ~ Normal(0, 2.5)
    β ~ filldist(Normal(0, 2.5), P)

    # standardized (non-centered) subject-level intercept and slope offsets
    αₚ ~ filldist(Normal(), G)
    βₚ ~ filldist(Normal(), P, G)

    # priors over stds of subject level intercept and slopes 
    σₐ ~ Exponential(1)            
    σᵦ ~ filldist(Exponential(1), P) 

    # construct regression coefficients
    α_ = α .+ αₚ .* σₐ
    β_ = β .+ βₚ .* σᵦ    

    # likelihood
    μ = α_[idx] + sum(X .* β_[:,idx]', dims=2)
    y ~ MvNormal(vec(μ), σ^2 * I)
end;

# generate some synthetic data to test the model
Random.seed!(123)
N = 100
G = 10
α = 5
β = [-2, 3, -4, 5]
σ = .1
σₐ = 2
σᵦ = [1 2 .1 .5]
αₚ = rand(Normal(), G).*σₐ .+ α
βₚ = rand(MvNormal(zeros(length(β)), I), G).*σᵦ' .+ β
X = randn(N*G, length(β))
idx = collect(repeat(1:G, inner=N))
μ = αₚ[idx] + sum(X .* βₚ[:,idx]', dims=2)
y = rand(MvNormal(vec(μ), σ^2 * I))

model = linreg(X, y, idx)
chn = sample(model, NUTS(), 100)
plot(group(chn, :β))
plot(group(chn, :σᵦ))

Try switching the autodiff backend to ReverseDiff via Turing.setadbackend(:reversediff). That should speed up the code significantly. It will still run slower than Stan until Enzyme or some other more performant autodiff becomes available.
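In code the switch looks like this (a minimal sketch for the Turing version used in this thread; newer Turing releases instead take an AD type directly in the sampler constructor):

using Turing, ReverseDiff   # ReverseDiff must be loaded for the backend to be selectable

Turing.setadbackend(:reversediff)
chn = sample(linreg(X, y, idx), NUTS(), 100)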


Thanks, that does help quite a bit! But you’re right, it’s still slower than Stan by a fair margin, unfortunately. I’ll keep an eye out for better autodiff functionality. Do you know if there’s any timeline for this?

Also try Turing.setrdcache(true). That should result in a further speedup.
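Caching the compiled ReverseDiff tape is only safe when the model’s control flow does not depend on the parameter values, which is the case for the model in this thread. Combined with the backend switch above, the setup is simply:

Turing.setadbackend(:reversediff)
Turing.setrdcache(true)   # reuse the compiled tape across gradient evaluations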


And no, there’s unfortunately no timeline regarding the AD :confused:

Ah nice! With this I feel like I’m close to what I’m seeing in Stan!! (but haven’t done any proper comparison)

edit: ran some benchmarks, it is definitely still slower unfortunately :confused:

Yeah. That has been my experience too. It does not scale as well as Stan does. My hope is that the situation improves when Enzyme is more mature.

Are you running into Neal’s Funnel?


I was going to check the maximum likelihood estimates of the parameters using [MixedModels.jl](https://github.com/JuliaStats/MixedModels.jl), but I encountered an error when loading the package. It appears that one of the dependencies of MixedModels triggers precompilation of DistributionsADLazyArraysExt, which lacks a definition of UnivariateDistribution:

[ Info: Precompiling ArrayInterfaceBandedMatricesExt [26e938bc-0cd5-5679-9003-44616cbf91d3]
[ Info: Precompiling DistributionsADLazyArraysExt [1640c15e-5bc7-5daf-855c-3cb1bdbdee06]
ERROR: LoadError: UndefVarError: `UnivariateDistribution` not defined
Stacktrace:
 [1] top-level scope
   @ ~/.julia/packages/DistributionsAD/e9aui/ext/DistributionsADLazyArraysExt.jl:15
 [2] include
   @ ./Base.jl:457 [inlined]
...

As a work-around I wrote the data to an Arrow file, splitting the matrix X into four columns: X1, X2, X3, and X4.
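A sketch of how the simulated data from the model above can be exported to such a file (the column names are assumed, chosen to match the table below):

using Arrow

# assumed export step: one column per predictor, plus the response and the group index
tbl_out = (X1 = X[:, 1], X2 = X[:, 2], X3 = X[:, 3], X4 = X[:, 4], y = y, idx = idx)
Arrow.write("./simulate.arrow", tbl_out)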

julia> using Arrow, MixedModels, TypedTables

julia> tbl = Arrow.Table("./simulate.arrow")
Arrow.Table with 1000 rows, 6 columns, and schema:
 :X1   Float64
 :X2   Float64
 :X3   Float64
 :X4   Float64
 :y    Float64
 :idx  Int64

julia> Table(tbl)
Table with 6 columns and 1000 rows:
      X1          X2         X3          X4          y         idx
    ┌─────────────────────────────────────────────────────────────
 1  │ 1.17808     -1.39646   0.474254    -0.641192   -2.39514  1
 2  │ -1.28548    -0.834103  0.63777     0.218158    4.30926   1
 3  │ 1.03768     0.399244   0.468157    -0.0545903  4.70301   1
 4  │ 0.978176    1.25376    0.506624    0.346546    8.45962   1
 5  │ -0.785813   -0.305753  1.06711     1.37005     8.79346   1
 6  │ -0.13633    0.522859   1.4255      0.775234    6.17984   1
 7  │ -0.436433   -0.326708  -0.60639    -0.107138   7.97244   1
 8  │ -1.89627    1.4558     0.279539    -1.85704    2.13968   1
 9  │ -0.174792   0.0676495  1.58414     1.33187     7.07517   1
 10 │ -0.0592719  -0.444041  0.39296     0.168131    4.89601   1
 11 │ -1.57775    1.63487    -0.205142   -0.175202   11.7824   1
 12 │ 0.396038    0.406323   -0.0623051  0.214006    8.48127   1
 13 │ 0.255504    0.397921   -1.23652    -0.365      10.3261   1
 14 │ -0.296876   0.814631   -0.0727651  -1.2915     3.18705   1
 15 │ 0.285478    0.346752   1.51136     -0.866273   -2.41642  1
 16 │ -1.41377    1.25227    -0.121022   1.21157     16.7508   1
 17 │ 1.11677     1.34093    -1.75145    -0.778489   11.9013   1
 18 │ 0.722617    -1.08571   -0.985866   -0.637033   4.2221    1
 19 │ -1.00617    1.21131    1.73369     2.5076      15.2832   1
 20 │ -1.1118     -0.92523   1.11891     -1.26446    -4.61929  1
 21 │ -1.23767    0.443356   1.25002     -0.0987525  3.54935   1
 22 │ 0.860856    1.00413    -1.28248    1.76232     21.1302   1
 23 │ 1.06417     -1.44945   0.46746     0.0530596   0.771947  1
 24 │ 0.303779    0.87557    -1.97205    -0.510784   13.5135   1
 25 │ -0.956855   -0.879204  -0.510858   -0.38644    5.45459   1
 26 │ 0.877251    -1.90806   1.89145     0.56419     -3.1738   1
 27 │ 1.17449     -0.732107  0.796771    1.1526      6.16885   1
 ⋮  │     ⋮           ⋮          ⋮           ⋮          ⋮       ⋮

julia> mmfit1 = fit(
           MixedModel,
           @formula(y ~ 1 + X1 + X2 + X3 + X4 + zerocorr(1 + X1 + X2 + X3 + X4 | idx)),
           tbl;
           contrasts = Dict(:idx => Grouping()),
       )
Linear mixed model fit by maximum likelihood
 y ~ 1 + X1 + X2 + X3 + X4 + zerocorr(1 + X1 + X2 + X3 + X4 | idx)
   logLik   -2 logLik     AIC       AICc        BIC    
  2857.1330 -5714.2660 -5692.2660 -5691.9988 -5638.2807

Variance components:
            Column    Variance  Std.Dev.   Corr.
idx      (Intercept)  2.3710714 1.5398284
         X1           0.4638382 0.6810566   .  
         X2           3.6333740 1.9061411   .     .  
         X3           0.0131245 0.1145624   .     .     .  
         X4           0.2135884 0.4621562   .     .     .     .  
Residual              0.0001014 0.0100706
 Number of obs: 1000; levels of grouping factors: 10

  Fixed-effects parameters:
────────────────────────────────────────────────────
                Coef.  Std. Error        z  Pr(>|z|)
────────────────────────────────────────────────────
(Intercept)   4.304     0.486937      8.84    <1e-18
X1           -2.15691   0.215369    -10.01    <1e-22
X2            3.60678   0.602775      5.98    <1e-08
X3           -3.96721   0.0362293  -109.50    <1e-99
X4            4.84147   0.146147     33.13    <1e-99
────────────────────────────────────────────────────

This only takes a few milliseconds.

Is there any value in using the MLEs to start off the MCMC iterations?

@palday, are you asking because I’m reparameterizing the model the same way one would to get around Neal’s Funnel? I don’t think I’m running into the funnel here (or am I?), but I just wanted to make the code robust enough to also sample from posteriors with funny geometry.
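For context, the two parameterizations look roughly like this for the intercept part alone (a minimal sketch, not the full model above): the centered version draws the group intercepts directly around α with scale σₐ, while the non-centered version draws standard normals and rescales them, which avoids the funnel-shaped geometry when σₐ is weakly identified.

using Turing
using LinearAlgebra: I

# centered parameterization: group intercepts drawn directly around α
@model function centered(y, idx, G)
    σ  ~ Exponential(1)
    α  ~ Normal(0, 2.5)
    σₐ ~ Exponential(1)
    αₚ ~ filldist(Normal(α, σₐ), G)
    y  ~ MvNormal(αₚ[idx], σ^2 * I)
end

# non-centered parameterization (as in the model above): standard normals rescaled by σₐ
@model function noncentered(y, idx, G)
    σ  ~ Exponential(1)
    α  ~ Normal(0, 2.5)
    σₐ ~ Exponential(1)
    αₚ ~ filldist(Normal(), G)
    α_ = α .+ αₚ .* σₐ
    y  ~ MvNormal(α_[idx], σ^2 * I)
end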

@dmbates, that sounds like an interesting idea. Rather than getting MixedModels.jl to work, I used the means of the posteriors from a previous run as the initial values for the sampler. Unfortunately this doesn’t really seem to speed things up either.

Here’s the full code I’m using in case it’s useful:

using Turing, Random, MCMCChains, LinearAlgebra, StatsPlots, ReverseDiff, BenchmarkTools

@model function linreg(X, y, idx; P = size(X,2), G = length(unique(idx)))

    # prior for overall noise
    σ ~ Exponential(1)
    
    # priors over means of group-level intercept and slopes 
    α ~ Normal(0, 2.5)
    β ~ filldist(Normal(0, 2.5), P)

    # standardized (non-centered) subject-level intercept and slope offsets
    αₚ ~ filldist(Normal(), G)
    βₚ ~ filldist(Normal(), P, G)

    # priors over stds of subject level intercept and slopes 
    σₐ ~ Exponential(1)            
    σᵦ ~ filldist(Exponential(1), P) 

    # construct regression coefficients
    α_ = α .+ αₚ .* σₐ
    β_ = β .+ βₚ .* σᵦ    

    # likelihood
    μ = α_[idx] + sum(X .* β_[:,idx]', dims=2)
    y ~ MvNormal(vec(μ), σ^2 * I)
end;

# generate some synthetic data to test the model
Random.seed!(123)
N = 100
G = 10
α = 5
β = [-2, 3, -4, 5]
σ = .1
σₐ = 2
σᵦ = [1 2 .1 .5]
αₚ = rand(Normal(), G).*σₐ .+ α
βₚ = rand(MvNormal(zeros(length(β)), I), G).*σᵦ' .+ β
X = randn(N*G, length(β))
idx = collect(repeat(1:G, inner=N))
μ = αₚ[idx] + sum(X .* βₚ[:,idx]', dims=2)
y = rand(MvNormal(vec(μ), σ^2 * I))

# set the autodiff backend to ReverseDiff and enable tape caching
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)

m = linreg(X, y, idx)

@btime chn = sample(m, NUTS(), 100)

@btime chn2 = sample(m, NUTS(), 100, init_params = mean(chn.value.data, dims=1)[1:61])

This should be fixed now :) It was a bug that was only live for a few hours, I believe. Sorry about that!

How much slower, roughly speaking? Note that it’s not surprising that it’s slower for something as simple as linear regression.

Just out of interest: why would that be expected? And what is keeping Turing.jl from bridging the gap to Stan specifically in the linear regression case? Not challenging, just curious.

Stan has its own AD written in C++ specifically for its own codebase. For “simple” models, e.g. linear regression, it’s hard for a general-purpose AD in Julia to beat an AD engine written specifically to be fast on exactly this kind of problem.

Got it, makes sense. Couldn’t this, in principle, be mitigated by writing more high-level rules in ChainRules? At least for the ADs that are hooked up to it?

In theory, yes, but in practice that is difficult. In Turing.jl there are several functions between x ~ Normal() and the resulting logpdf computation. Sure, if we defined all those rules we’d be golden, but it wouldn’t be maintainable.
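Just to illustrate what such a rule could look like, here is a hand-written reverse rule for a stand-alone Gaussian log-likelihood helper (the helper is hypothetical and not part of Turing’s internals):

using ChainRulesCore

# hypothetical helper: sum of Normal(μ, σ) log-densities over a data vector y
normal_loglik(μ, σ, y) = sum(@. -0.5 * ((y - μ) / σ)^2 - log(σ) - 0.5 * log(2π))

# a single high-level rule replaces many elementwise operations the AD would otherwise trace
function ChainRulesCore.rrule(::typeof(normal_loglik), μ, σ, y)
    z = (y .- μ) ./ σ
    val = sum(@. -0.5 * z^2 - log(σ) - 0.5 * log(2π))
    function normal_loglik_pullback(Δ)
        dμ = Δ * sum(z ./ σ)               # ∂ℓ/∂μ
        dσ = Δ * sum((z .^ 2 .- 1) ./ σ)   # ∂ℓ/∂σ
        dy = Δ .* (-z ./ σ)                # ∂ℓ/∂yᵢ
        return (NoTangent(), dμ, dσ, dy)
    end
    return val, normal_loglik_pullback
end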

With that being said, there’s tons of performance optimization we can do beyond what we’re currently doing! For example, for this linreg model we might not be far off being able to use Enzyme plus a custom, fully mutable implementation of our tracing structure. And that should be pretty darn close to Stan.
