Turing performance vs numpyro on horseshoe regression

Hi all,

I’ve benchmarked a simple implementation of regression with a horseshoe prior over the coefficients in both numpyro and Turing. I found numpyro to be over an order of magnitude faster, which surprised me. Here’s the Turing implementation - I'd be grateful for feedback on anything obviously incorrect here:

using Distributions
using Turing
# using TuringBenchmarking
using Statistics
using LinearAlgebra
using Enzyme
using ReverseDiff


function get_data()
    N = 100   # observations
    P = 1000  # predictors
    D = 5     # number of non-zero coefficients

    X = rand(Normal(0, 1), N, P)
    W = zeros(P)
    W[1:D] .= [0.2, -0.3, 0.35, -0.4, 0.45]

    # noise variance chosen so that the true R² is 0.615
    Y = X * W + rand(Normal(0, sqrt(1 - .615)), N)

    return X, Y, W
end

@model function model(X; N::Int64, P::Int64)

    # horseshoe prior: local scales λ and global scale τ
    λ ~ filldist(truncated(Cauchy(0.0, 1.0); lower=0), P)
    τ ~ truncated(Cauchy(0.0, 1.0); lower=0)

    W ~ MvNormal(zeros(P), I * λ.^2 * τ^2)

    μ = X * W

    σ ~ truncated(Cauchy(0.0, 1.0); lower=0)

    Y ~ MvNormal(μ, I * σ^2)
end


function inference()
    @time X, Y, W = get_data()

    @time m = model(X; N=size(X, 1), P=size(X, 2)) | (; Y=Y)

    # turing_suite = make_turing_suite(m)
    # run(turing_suite)
    # @time chain = sample(m, NUTS(0.85; adtype=Turing.Essential.AutoEnzyme()), MCMCThreads(), 500, 2)
    @time chain = sample(m, NUTS(0.85; adtype=AutoReverseDiff(true)), MCMCThreads(), 500, 2)

    return chain
end

# inference()

I’ve installed the Turing branch with Enzyme AD functionality, though it's not clear to me whether that is considered ready for use yet. The numpyro implementation assumes the exact same data-generating process, and I believe the max tree depth in NUTS is the same in both.
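
For what it's worth, if you'd rather pin that down than rely on defaults: numpyro's NUTS uses a maximum tree depth of 10, and Turing's NUTS exposes what I believe is the same setting via a max_depth keyword (the keyword name is an assumption on my part - check the docs for your version), e.g.

@time chain = sample(m, NUTS(0.85; max_depth=10, adtype=AutoReverseDiff(true)), MCMCThreads(), 500, 2)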

Thanks!

2 Likes

This is not an isolated example, unfortunately. I have also found Numpyro to be at least an order of magnitude faster than Turing on any non-trivial modeling problem, regardless of the AD backend and despite using as much vectorization and as many multivariate-distribution optimizations as I can think of.

I much prefer Numpyro’s “plate” syntax for inducing conditional independencies in the model. It is clear and easy to understand both from a programming and mathematical perspective.

My general impression is that the performance issues in Turing are multifaceted.

1 Like

Thanks. I was wondering whether something specific to this model, perhaps the truncated calls or similar, was responsible for the poor performance. But perhaps numpyro is just that fast. It's not clear to me how to develop intuition for which models Turing will be about as fast as numpyro, and for which it will be an order of magnitude slower (like the example above).

1 Like

Hi, just a side note, but when you say you're comparing two implementations it might be a good idea to provide both of them.

1 Like

Fair enough, here’s the numpyro implementation:

import os
import time
import numpy as np
import jax.numpy as jnp
import jax.random as random
import numpyro
import numpyro.distributions as dist 
from numpyro.infer import MCMC, NUTS, Predictive
from numpyro.diagnostics import print_summary, summary


def get_data():
    N = 100
    P = 1000
    D = 5
    X = np.random.randn(N, P)
    W = np.zeros(P)
    W[:D] = [0.2, -0.3, 0.35, -0.4, 0.45]

    # true Rsq of 0.615
    Y = jnp.dot(X, W) + jnp.sqrt(1 - .615) * random.normal(random.PRNGKey(0), (N,))

    return X, Y, W

def model(X, Y):
    N, P = X.shape

    lambdas = numpyro.sample("lambdas", dist.HalfCauchy(jnp.ones(P)))
    tau = numpyro.sample("tau", dist.HalfCauchy(jnp.ones(1)))

    betas  = numpyro.sample("betas", dist.Normal(jnp.zeros(P), lambdas * tau))

    mu = jnp.dot(X, betas)
    sigma = numpyro.sample("sigma", dist.HalfCauchy(jnp.ones(1)))

    numpyro.sample("Y", dist.Normal(mu, sigma), obs=Y)

if __name__ == "__main__":

    numpyro.set_host_device_count(2)

    X, Y, W = get_data()

    start = time.time()
    kernel = NUTS(model)
    mcmc = MCMC(kernel, num_warmup=1000, num_samples=500, num_chains=2)
    mcmc.run(random.PRNGKey(0), X, Y)
    mcmc.print_summary()
    summary_dict = summary(mcmc.get_samples(), group_by_chain=False)

    end = time.time()
    seconds = end - start
    print("Time: ", seconds)
    # Takes 122 seconds, 167 divergences
    error = np.sum(np.square(summary_dict["betas"]["mean"][:5] - W[:5]))
    print("Error: ", error)
    # 0.55
    mean_ess = np.mean(summary_dict["betas"]["n_eff"][:5])
    print("Ess per second: ", mean_ess / seconds)
    # 0.063

When comparing different implementations like this, I highly recommend comparing just the gradient computations, as these dominate the runtime; everything else involves stochasticity, so it can be difficult to get consistent estimates.

Doing so, I'm indeed seeing that the gradient computation with Numpyro is much faster: ~33µs vs. ~3.2ms (~100x).
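
For anyone wanting to reproduce this, the commented-out TuringBenchmarking lines in the original script are roughly what I mean. A sketch (the adbackends keyword and the accepted backend objects may differ between TuringBenchmarking versions, so treat the details as assumptions):

using Turing, TuringBenchmarking, ReverseDiff

# Conditioned model, exactly as in the original post.
X, Y, _ = get_data()
m = model(X; N=size(X, 1), P=size(X, 2)) | (; Y=Y)

# Benchmarks just the log-density evaluation and its gradient for each backend,
# which is what dominates the cost of NUTS.
suite = make_turing_suite(m; adbackends=[AutoForwardDiff(), AutoReverseDiff(true)])
run(suite)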

The expectation is that Turing.jl will benefit greatly from Enzyme, so this gap should shrink considerably within the coming year, but Enzyme support is unfortunately not yet ready :confused:

EDIT: I did a quick @profile of the computation and noticed that a single broadcast is taking up most of the runtime.

If I make

    W ~ MvNormal(Zeros{T}(P), I * λ.^2 * τ^2)

into

    W ~ MvNormal(Zeros{T}(P), I)

we shave off 10X, i.e. it now only takes ~330µs :grimacing:

So we might be hitting a particularly bad pullback somewhere in MvNormal wrt. the covariance matrix.
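
In case it's useful to others, this is roughly how I'm profiling the gradient call itself, going through the LogDensityProblems interface that DynamicPPL implements. A sketch; the exact constructor names may differ between versions:

using DynamicPPL, LogDensityProblems, LogDensityProblemsAD, ReverseDiff
using Profile

# Wrap the conditioned model as a log-density problem and attach an AD backend.
X, Y, _ = get_data()
m = model(X; N=size(X, 1), P=size(X, 2)) | (; Y=Y)
ℓ = DynamicPPL.LogDensityFunction(m)
∇ℓ = LogDensityProblemsAD.ADgradient(:ReverseDiff, ℓ)

θ = randn(LogDensityProblems.dimension(ℓ))
LogDensityProblems.logdensity_and_gradient(∇ℓ, θ)  # run once to compile
@profile for _ in 1:1_000
    LogDensityProblems.logdensity_and_gradient(∇ℓ, θ)
end
Profile.print(mincount=50)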

5 Likes

Very interesting - thanks for taking a look at the gradient computation cost. Unfortunately, the variance specification is fairly important for the prior, so it can't be removed in the interest of speed.

For now, I'll eagerly await the integration of Enzyme into the Turing ecosystem and cheer on the developers from the sidelines.

Oh, for sure, I'm not suggesting the variance specification be removed! I just noticed in the profiling that it seemed to take up most of the runtime, so I was curious how much of a difference it made.

But thanks for the model! Definitely one we should keep an eye on when benchmarking.

2 Likes

Actually, you don't need to drop this to get the desired performance; you can just replace

    W ~ MvNormal(zeros(P), I * λ.^2 * τ^2)

with

    W_inner ~ MvNormal(zeros(P), I)
    W = W_inner .* λ * τ
    Turing.@addlogprob! -sum(log.(λ * τ))

This gets us down to ~285µs per gradient call on my end, i.e. “only” ~8x slower than numpyro with Jax.

This will work nicely with HMC and other samplers, but will have the side-effect that sampling from the prior using, e.g. rand(model), won’t give the “right” result.
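
One practical consequence of this rewrite is that the chain now contains W_inner, λ, and τ rather than W, so W has to be reconstructed from the draws afterwards. A sketch, assuming chain is the result of sample(...) from the original script and that MCMCChains returns the W_inner and λ columns in index order (worth double-checking):

using MCMCChains, Statistics

# Rebuild W = W_inner .* λ * τ from the posterior draws.
W_inner_draws = Array(group(chain, :W_inner))   # (iterations * chains) × P
λ_draws = Array(group(chain, :λ))
τ_draws = vec(chain[:τ])
W_draws = W_inner_draws .* λ_draws .* τ_draws
W_mean = vec(mean(W_draws; dims=1))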

1 Like

This will work nicely with HMC and other samplers, but will have the side-effect that sampling from the prior using, e.g. rand(model), won’t give the “right” result.

As a final note, because Turing.jl is all Julia code, you can “easily” work around this in the model as follows:

"""
    is_prior_sampler(context)

Returns `true` if the context is a prior sampler, `false` otherwise.
"""
is_prior_sampler(context) = is_prior_sampler(DynamicPPL.NodeTrait(context), context)
is_prior_sampler(::DynamicPPL.IsLeaf, context) = false
is_prior_sampler(::DynamicPPL.IsParent, context) = is_prior_sampler(DynamicPPL.childcontext(context))
is_prior_sampler(context::DynamicPPL.SamplingContext) = context.sampler isa DynamicPPL.SampleFromPrior

@model function model(X; N::Int64, P::Int64)
    λ ~ filldist(truncated(Cauchy(0.0, 1.0); lower=0), P)
    τ ~ truncated(Cauchy(0.0, 1.0); lower=0)

    # HACK: Replace the following
    #   W ~ MvNormal(zeros(P), Diagonal(λ.^2 * τ^2))
    # with
    if is_prior_sampler(__context__)
        # Ensures that sampling from the prior is correct.
        W_inner ~ MvNormal(zeros(P), Diagonal(λ.^2 * τ^2))
        W = W_inner
    else
        # Faster path used when not sampling from the prior.
        W_inner ~ MvNormal(zeros(P), I)
        W = W_inner .* λ * τ
        Turing.@addlogprob! -sum(log.(λ * τ))
    end

    μ = X * W

    σ ~ truncated(Cauchy(0.0, 1.0); lower=0)

    Y ~ MvNormal(μ, I * σ^2)
end

Here you get correctness when sampling from the prior + performance improvement when using gradient samplers.

But this is accessing some internals, e.g. __context__ (this is always available in a @model and defines the “evaluation mode”), so it's not something I recommend doing at the moment. The first thing is to get faster AD, and then we'll write a proper performance guide plus a bit on the internals that might be useful.

And just for the record: I also think it's very annoying that such a minor difference (i.e. avoiding passing the covariance matrix into MvNormal) makes such a huge difference :confused: I've been on the receiving end of this myself many times (I once brought an MCMC run down from a week to a day by discovering a type instability far down the call stack of the gradient computation).
But one “good” thing here is that we can at least profile these computations without any issues (which is how I identified the above), and with a bit of experience and guidance we can get pretty decent performance in most cases.

EDIT: Turns out you can get Enzyme to work on this model. Include

using Enzyme
Enzyme.API.typeWarning!(false)
Enzyme.API.runtimeActivity!(true)

and then use AutoEnzyme.

This brings the runtime down to ~220µs on my end. But this is all with Float64; if we could use Float32 like the jax code does (nothing technical stops this, but there are still too many hardcoded Float64 usages in Turing.jl at the moment for it to “just work”), we'd probably bring it down to 3-4x slower than numpyro, which isn't too bad.

5 Likes

I will note that that branch (and likely the first cut of Enzyme + Turing) will be a bit rough performance-wise at first. There will likely be nontrivial perf improvements from working together, e.g. having Turing define EnzymeRules to mark functions as not needing a derivative or to otherwise pass in additional information.
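
For context, the mechanism being referred to looks roughly like the following, as I understand the EnzymeCore API; some_check here is a made-up helper purely for illustration, not anything Turing actually defines:

using EnzymeCore

# Hypothetical helper whose result never affects the log density, e.g. a sanity check.
some_check(x) = all(isfinite, x)

# Tell Enzyme it never needs to differentiate through this function.
EnzymeCore.EnzymeRules.inactive(::typeof(some_check), args...) = nothing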

2 Likes

This is great - I've learned a lot from this interaction. It seems that, with some tweaks to circumvent slow gradient computations, one can get within an order of magnitude of numpyro on this model.

I did have one question on this:

W_inner ~ MvNormal(zeros(P), I)
W = W_inner .* λ * τ
Turing.@addlogprob! -sum(log.(λ * τ))

I think the idea here is to fix the log density increment of the MvNormal(0, I) prior by adding in the log determinant of the covariance. I think this is slightly off; the log determinant of the covariance would be:

-sum(log.(λ.^2 * τ^2))

Also isn’t the kernel slightly off?

I.e. if the covariance is Σ = Diagonal(I * λ.^2 * τ^2), and we replace that with just I, the difference is:

W^t Σ^-1 W = W^t(Σ^-1 - I)W + W^tW 

So seems like we’d need to add W^t(Σ^-1 - I)W in with Turing.@addlogprob! as well.

Sorry, ignore my post: yes, that does work - it adds in the log determinant of the Jacobian of the inverse transform. Thanks again!
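
To spell that out for anyone reading later: the inverse transform is W_inner = W ./ (λ * τ), whose Jacobian is diagonal with entries 1 ./ (λ * τ), so

log |det J| = -sum(log.(λ * τ))

which is exactly the quantity added via Turing.@addlogprob!. The factor of two I was worried about above is absorbed by the -1/2 in front of the log determinant in the Gaussian log density, since -0.5 * sum(log.(λ.^2 * τ^2)) == -sum(log.(λ * τ)).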