Comparison to numpyro

using Lux, ComponentArrays, Turing, Random, MCMCChains, FillArrays, LinearAlgebra
using ReverseDiff, Tracker

# Select Tracker as the AD backend Turing uses for the gradient-based sampler below.
Turing.setadbackend(:tracker)

# Two-layer MLP: 1 input -> 32 hidden units (gelu activation) -> 1 output.
nn = Chain(Dense(1 => 32, gelu), Dense(32 => 1))
# Lux keeps parameters and state outside the layer object; `setup` creates both for `nn`.
ps_nt, st = Lux.setup(Random.default_rng(), nn)
# Wrap the parameters in a ComponentArray; `length(ps)` is later passed as the model's
# parameter count.
ps = ComponentArray(ps_nt)
# Keep the axes so a flat parameter vector can be re-wrapped with the same structure
# inside the model.
ps_ax = getaxes(ps)

# Bayesian neural network built on the Lux model `nn`.
#
# Arguments:
#   xs          - network inputs, passed straight to `nn`
#   ts          - observed targets, one per prediction
#   nparameters - length of the flat parameter vector to sample
#   alpha       - prior precision; the prior covariance is `I / alpha`
#
# Reads the globals `nn`, `ps_ax` and `st` defined at the top of the script.
@model function bayes_nn(xs, ts, nparameters; alpha = 0.09)
    # Zero-mean isotropic Gaussian prior over all network weights, as one flat vector.
    parameters ~ MvNormal(Zeros(nparameters), I / alpha)

    # Re-attach the saved axes so the flat sample has the structure `nn` was set up with.
    weights = ComponentArray(parameters, ps_ax)
    # The network call returns a tuple; only its first element (the output) is used here.
    output = first(nn(xs, weights, st))
    # Drop the leading dimension of the output to get one prediction per observation.
    preds = dropdims(output; dims = 1)

    # Element-wise likelihood: each target is Normal around its prediction. The
    # single-argument `Normal(μ)` form leaves the standard deviation at its default.
    ts .~ Normal.(preds)
end

# Some dummy data
# 128 inputs, each a 1-element Float32 vector of uniform random values.
sx = [rand(Float32, 1) for _ in 1:128]
# 128 scalar Float32 targets, drawn independently of the inputs.
sy = [rand(Float32) for _ in 1:128]

# Number of samples to draw from the posterior.
N = 2000
# `reduce(hcat, sx)` stacks the inputs column-wise into a 1×128 matrix for the network.
# The trailing `;` suppresses printing of the resulting chain at the REPL.
ch = sample(bayes_nn(reduce(hcat, sx), sy, length(ps)), NUTS(), N);

Using Tracker, I get ~7.5 minutes. (I just used random data and dimensions for sx and sy.) But the equivalent ReverseDiff version seems to take about 2 hours.

The equivalent Flux version (with Tracker) takes ~10 hours.

import Flux

# Flux counterpart of the Lux network: same layer sizes and the same gelu activation.
nn_flux = Flux.Chain(Flux.Dense(1 => 32, gelu), Flux.Dense(32 => 1))
# `destructure` returns the flat parameter vector and `re`, which rebuilds the network
# from such a vector.
params, re = Flux.destructure(nn_flux)

# Bayesian neural network built on the Flux model, via the global reconstructor `re`.
#
# Arguments:
#   xs          - network inputs, passed straight to the rebuilt network
#   ts          - observed targets, one per prediction
#   nparameters - length of the flat parameter vector to sample
#   alpha       - prior precision; the prior covariance is `I / alpha`
@model function bayes_nn_flux(xs, ts, nparameters; alpha = 0.09)
    # Zero-mean isotropic Gaussian prior over all network weights, as one flat vector.
    parameters ~ MvNormal(Zeros(nparameters), I / alpha)

    # Rebuild a network from the sampled flat weights, then evaluate it on the inputs.
    net = re(parameters)
    # Drop the leading dimension of the output to get one prediction per observation.
    preds = dropdims(net(xs); dims = 1)

    # Element-wise likelihood: each target is Normal around its prediction. The
    # single-argument `Normal(μ)` form leaves the standard deviation at its default.
    ts .~ Normal.(preds)
end

# Same data, sampler and sample count as the Lux run; only the model differs.
# Note that this rebinds `ch`, replacing the chain from the Lux run.
ch = sample(bayes_nn_flux(reduce(hcat, sx), sy, length(params)), NUTS(), N);
7 Likes