Comparison to NumPyro

Has anyone else tried comparing Turing with NumPyro? I am working with a dataset of about N = 5,000 and fitting a Bayesian MLP on top of a fine-tuned language model, which I treat as a fixed feature extractor.

Fitting even a small MLP (64 => 32 => 1), or the even smaller 1 => 32 => 1 network below, on 1,000 data points takes about 10 hours of sampling.

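using Turing, Flux, FillArrays, LinearAlgebra, MCMCChains

# Small MLP: 1 input => 32 hidden units (gelu) => 1 output
nn_initial = Chain(Dense(1 => 32, gelu), Dense(32 => 1))
# Flatten the weights into one vector; reconstruct(v) rebuilds the network from a vector v
parameters_initial, reconstruct = Flux.destructure(nn_initial)

@model function bayes_nn(xs, ts, nparameters, reconstruct; alpha = 0.09)
    # Isotropic Gaussian prior on all weights, covariance I / alpha
    parameters ~ MvNormal(Zeros(nparameters), I / alpha)
    nn = reconstruct(parameters)
    preds = dropdims(nn(xs); dims = 1)

    # Unit-variance Gaussian likelihood for each target
    @. ts ~ Normal(preds)
end;

N = 2000
# sx: vector of 1-element feature vectors, sy: vector of targets
ch = sample(bayes_nn(reduce(hcat, sx), sy, length(parameters_initial), reconstruct), NUTS(), N);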
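Written out, the model above is the following. This reads I / alpha as the prior covariance (the second argument of MvNormal(μ, Σ)), writes f_θ for the network returned by reconstruct(θ), and uses n for the number of data points to avoid a clash with the N = 2000 draws. For the 1 => 32 => 1 network, θ has 32 + 32 + 32 + 1 = 97 entries.

```latex
\begin{aligned}
\theta &\sim \mathcal{N}\!\left(0,\ \alpha^{-1} I\right), \qquad \alpha = 0.09 \;\Rightarrow\; \operatorname{sd}(\theta_k) = 1/\sqrt{0.09} = 1/0.3 \approx 3.33 \\
t_i \mid \theta &\sim \mathcal{N}\!\left(f_\theta(x_i),\ 1\right), \qquad i = 1, \dots, n \\
\log p(\theta \mid x, t) &= -\frac{\alpha}{2}\lVert\theta\rVert^2 - \frac{1}{2}\sum_{i=1}^{n}\bigl(t_i - f_\theta(x_i)\bigr)^2 + \text{const}
\end{aligned}
```

NUTS needs the gradient of this log-posterior with respect to θ at every leapfrog step, usually many times per draw, so any overhead in the AD backend is paid over and over.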

An equivalent NumPyro model took about 20 minutes for NUTS sampling.

Are there any optimizations I can make here? Is the slowdown coming from Zygote, from Turing, or from my code?

I really appreciate using Turing for smaller work because the syntax is so nice!


Try using Lux.jl instead, following this example:

The likely culprit is Zygote, which is more or less broken here; see here. You can try ReverseDiff instead, which should be much faster.

ReverseDiff made the run time worse: the projected time is 47 hours, even for the toy 1 => 32 => 1 example. This is with Memoization loaded and rdcache enabled.

I’m also curious: do Turing and NumPyro use different NUTS implementations?


Did you set compiled mode?

I suppose not, since I didn’t know that was possible. How do I do that?

Edit: I found some documentation on ReverseDiff.jl’s GitHub. It is unclear to me whether it is enough to compile the tape for the Flux model, or whether additional work is needed around the Turing model itself.
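To make the compiled-tape idea concrete at the ReverseDiff level, here is a minimal sketch of ReverseDiff's own API (ReverseDiff.GradientTape, ReverseDiff.compile, ReverseDiff.gradient!) on a toy log-density. The network, data, and the names logdensity, H, XS, TS and NTHETA are hypothetical stand-ins, not the Flux or Turing model from this thread, and wiring a compiled tape into Turing is not shown. A compiled tape replays the operations recorded for one input, so it is only valid when the function's control flow does not depend on the parameter values.

```julia
using ReverseDiff

# Hypothetical toy problem: a 1 => H => 1 tanh network with its weights packed
# into one flat vector θ = [w1 (H); b1 (H); w2 (H); b2 (1)].
const H = 8
const XS = collect(range(-1.0, 1.0; length = 50))
const TS = sin.(XS)
const NTHETA = 3H + 1

# Unnormalized log-posterior: standard-normal prior + unit-variance Gaussian likelihood.
# Loop bounds depend only on sizes, never on the values in θ, so a tape recorded
# at one θ remains valid for every θ of the same length.
function logdensity(θ::AbstractVector)
    lp = -0.5 * sum(abs2, θ)
    for (x, t) in zip(XS, TS)
        pred = θ[3H + 1]                        # output bias b2
        for j in 1:H
            h = tanh(θ[j] * x + θ[H + j])       # hidden unit j: tanh(w1_j * x + b1_j)
            pred += θ[2H + j] * h               # add w2_j * h
        end
        lp -= 0.5 * abs2(t - pred)
    end
    return lp
end

θ0 = 0.1 .* sin.(1:NTHETA)                      # any input of the right length
tape = ReverseDiff.GradientTape(logdensity, θ0) # record the operations once
ctape = ReverseDiff.compile(tape)               # compile the recorded tape

# Reuse the compiled tape at a different point; the gradient is written into g.
θ1 = 0.2 .* cos.(1:NTHETA)
g = similar(θ1)
ReverseDiff.gradient!(g, ctape, θ1)
```

The design choice is that recording and compiling happen once, outside the sampling loop, and every later gradient call only replays the tape.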

using Lux, ComponentArrays, Turing, Random, MCMCChains, FillArrays, LinearAlgebra
using ReverseDiff, Tracker

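# Use Tracker as Turing's AD backend
Turing.setadbackend(:tracker)

# Lux keeps the parameters (ps) and state (st) outside the network
nn = Chain(Dense(1 => 32, gelu), Dense(32 => 1))
ps_nt, st = Lux.setup(Random.default_rng(), nn)
# Flatten the parameters into one vector; ps_ax stores the axes needed to rebuild the structure
ps = ComponentArray(ps_nt)
ps_ax = getaxes(ps)

@model function bayes_nn(xs, ts, nparameters; alpha = 0.09)
    parameters ~ MvNormal(Zeros(nparameters), I / alpha)
    # Rewrap the flat sample with the saved axes; Lux returns (output, state), so take the output
    preds = dropdims(first(nn(xs, ComponentArray(parameters, ps_ax), st)); dims = 1)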
    @. ts ~ Normal(preds)
end

# Some dummy data
sx = [rand(Float32, 1) for _ in 1:128]
sy = [rand(Float32) for _ in 1:128]

N = 2000
ch = sample(bayes_nn(reduce(hcat, sx), sy, length(ps)), NUTS(), N);

Using Tracker, I get about 7.5 minutes (with random dummy data and dimensions for sx and sy). The equivalent ReverseDiff version, however, seems to take about 2 hours.

The equivalent Flux version (with Tracker) takes about 10 hours:

import Flux

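# Reuses gelu, the dummy data (sx, sy) and N from the Lux example above
nn_flux = Flux.Chain(Flux.Dense(1 => 32, gelu), Flux.Dense(32 => 1))
# re(v) rebuilds the Flux network from a flat parameter vector v
params, re = Flux.destructure(nn_flux)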

@model function bayes_nn_flux(xs, ts, nparameters; alpha = 0.09)
    parameters ~ MvNormal(Zeros(nparameters), I / alpha)
    preds = dropdims(re(parameters)(xs); dims = 1)
    @. ts ~ Normal(preds)
end

ch = sample(bayes_nn_flux(reduce(hcat, sx), sy, length(params)), NUTS(), N);

Thanks, this is helpful! I suspected it was related to Flux, since Turing doesn’t otherwise have huge speed problems.


I’m curious: do you have any thoughts on why this might be happening? It seems to be an interaction between the AD package and the NN framework. No need to go out of your way; I’ve just been trying to get my head around the practical side of these things, and wanted to ask in case you (or anyone else) have an explanation or intuition.

This particular issue is caused by how Restructure “restructures” the neural network. If you inspect the reconstructed network, its weights are stored as an Array of TrackedReals (a valid data structure in Tracker). This means that every function call inside the network uses generic fallback methods instead of BLAS in the forward pass, and, IIRC, the backward pass again does not use the array-level rules.
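The BLAS-versus-generic-fallback point can be illustrated without Tracker at all. The sketch below uses a hypothetical scalar wrapper, Boxed, which is not Tracker's TrackedReal: it only mimics the fact that the array's element type is not a BLAS float. The same numbers are multiplied once as Float64 arrays (handled by BLAS) and once as Boxed arrays (handled by LinearAlgebra's generic Julia loop). A real TrackedReal additionally records every scalar operation for the backward pass, so it carries more overhead than this wrapper.

```julia
using LinearAlgebra

# Hypothetical stand-in for "an array of tracked scalars": a plain Float64 in a
# wrapper type. BLAS only handles Float32/Float64/ComplexF32/ComplexF64 elements,
# so arrays of Boxed fall back to LinearAlgebra's generic multiplication loops.
struct Boxed <: Number
    v::Float64
end
Base.:+(a::Boxed, b::Boxed) = Boxed(a.v + b.v)
Base.:*(a::Boxed, b::Boxed) = Boxed(a.v * b.v)
Base.zero(::Type{Boxed}) = Boxed(0.0)
Base.zero(::Boxed) = Boxed(0.0)

n = 500
W = [sin(i + j) for i in 1:n, j in 1:n]   # dense n×n Float64 matrix
x = [cos(j) for j in 1:n]

Wb = Boxed.(W)                             # same numbers, non-BLAS element type
xb = Boxed.(x)

y_blas = W * x          # Matrix{Float64} * Vector{Float64}: BLAS gemv
y_generic = Wb * xb     # Matrix{Boxed} * Vector{Boxed}: generic Julia loop

# Time both after the warm-up calls above; the numbers are machine-dependent.
t_blas = @elapsed W * x
t_generic = @elapsed Wb * xb
println("BLAS: $(t_blas) s, generic fallback: $(t_generic) s")
```

Both paths compute the same product; only the kernel that runs differs, which is the effect described above for a network whose weights are an Array of TrackedReals.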

As for why ReverseDiff with Lux takes 2 hours, I think that is due to how ComponentArrays wraps ReverseDiff. I have been meaning to sort that out but haven’t had the bandwidth to look at it recently. A workaround for using ReverseDiff is what I do in this tutorial: Bayesian Neural Network | LuxDL Docs
