Curious if anyone else has tried comparing Turing with NumPyro? I am working with a dataset of about N = 5,000 and fitting a Bayesian MLP on top of a fine-tuned language model, which I am treating as a fixed feature extractor.
Attempting to fit even a small MLP (64 => 32 => 1), or the even smaller 1 => 32 => 1 shown below, on 1,000 data points results in ~10 hours of sampling time.
using Flux, Turing, MCMCChains
using LinearAlgebra: I
using FillArrays: Zeros

# Small MLP: 1 input => 32 hidden units (gelu) => 1 output
nn_initial = Chain(Dense(1, 32, gelu), Dense(32, 1))

# Flatten the network into a parameter vector plus a reconstructor
parameters_initial, reconstruct = Flux.destructure(nn_initial)

@model function bayes_nn(xs, ts, nparameters, reconstruct; alpha=0.09)
    # Isotropic Gaussian prior over all weights, variance 1/alpha
    parameters ~ MvNormal(Zeros(nparameters), I / alpha)
    nn = reconstruct(parameters)
    preds = dropdims(nn(xs), dims=1)
    # Gaussian likelihood with unit observation noise
    @. ts ~ Normal(preds, 1.0)
end;

N = 2000  # number of NUTS samples (not the dataset size)
ch = sample(bayes_nn(hcat(sx...), sy, length(parameters_initial), reconstruct), NUTS(), N);
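(For anyone wanting to reproduce the timing without waiting hours: a short pilot chain gives a rough per-sample cost, e.g. the sketch below; sx and sy are my raw inputs/targets, and the 50-sample count is arbitrary.)

# Pilot run: extrapolate the cost of the full chain from a short one
model = bayes_nn(hcat(sx...), sy, length(parameters_initial), reconstruct)
@time sample(model, NUTS(), 50; progress=false);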
An equivalent NumPyro model took about 20 minutes for NUTS sampling.
Are there any optimizations I can apply here? Is the bottleneck coming from Zygote, from Turing, or from something in my code?
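One thing on my list is trying a different AD backend. If I understand correctly, recent Turing releases accept an ADTypes adtype keyword on NUTS (older versions used Turing.setadbackend(:reversediff) instead), and ReverseDiff with a compiled tape is often suggested for models with many parameters. A minimal sketch of the comparison I have in mind, assuming that keyword is available in my version:

using Turing, ReverseDiff, Zygote  # backend packages must be loaded to use them

# Same model as above
model = bayes_nn(hcat(sx...), sy, length(parameters_initial), reconstruct)

# Compare AD backends on short pilot chains (sample counts are arbitrary)
for adtype in (AutoForwardDiff(), AutoReverseDiff(; compile=true), AutoZygote())
    println(adtype)
    @time sample(model, NUTS(; adtype=adtype), 50; progress=false)
end

(My understanding is that compile=true requires the model to have no value-dependent control flow, which should hold for a plain MLP.)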
I really appreciate using Turing for smaller work because the syntax is so nice!