Curious if anyone else has tried comparing Turing.jl with NumPyro? I'm working with a dataset of about N = 5,000 and fitting a Bayesian MLP on top of a fine-tuned language model, which I treat as a fixed feature extractor.
Fitting even a small MLP (64 => 32 => 1), or the even smaller 1 => 32 => 1 model below, on just 1,000 data points takes roughly 10 hours of sampling time.
```julia
using Flux, Turing, Distributions, FillArrays, LinearAlgebra, MCMCChains

nn_initial = Chain(Dense(1, 32, gelu), Dense(32, 1))
parameters_initial, reconstruct = Flux.destructure(nn_initial)

@model function bayes_nn(xs, ts, nparameters, reconstruct; alpha=0.09)
    # Isotropic Gaussian prior over all weights and biases
    parameters ~ MvNormal(Zeros(nparameters), I / alpha)
    nn = reconstruct(parameters)
    preds = dropdims(nn(xs), dims=1)
    # Gaussian likelihood; Normal(μ) defaults to unit variance
    @. ts ~ Normal(preds)
end

N = 2000
ch = sample(bayes_nn(hcat(sx...), sy, length(parameters_initial), reconstruct), NUTS(), N)
```
An equivalent NumPyro model took about 20 minutes of NUTS sampling.
Are there any optimizations I can apply here? Is the bottleneck coming from Zygote, from Turing, or from something in my model?
I really enjoy using Turing for smaller projects because the syntax is so nice!