using Lux, ComponentArrays, Turing, Random, MCMCChains, FillArrays, LinearAlgebra
using ReverseDiff, Tracker
# Select Tracker as the automatic-differentiation backend used by Turing.
# NOTE(review): `Turing.setadbackend` belongs to older Turing releases; newer ones
# presumably take the backend through the sampler's `adtype` keyword — confirm
# against the installed version.
Turing.setadbackend(:tracker)

# `nn`, `ps_ax` and `st` are read inside the body of `bayes_nn` on every
# log-density evaluation. Declaring them `const` gives the compiler concrete types
# for them; as plain globals they are untyped (`Any`) and every call on them is
# dispatched dynamically.
const nn = Chain(Dense(1 => 32, gelu), Dense(32 => 1))

# Initial parameters (as returned by `Lux.setup`) and layer states.
const ps_nt, st = Lux.setup(Random.default_rng(), nn)

# Flat view of the parameters, plus the axes needed to rebuild that structure
# from a plain parameter vector.
const ps = ComponentArray(ps_nt)
const ps_ax = getaxes(ps)
# Bayesian neural network built on the Lux model `nn`.
#
# Arguments:
#   xs          - network inputs (the call in this script passes a 1×128 matrix)
#   ts          - observed targets, one per prediction
#   nparameters - length of the flat parameter vector
#   alpha       - prior precision; the prior covariance is `I / alpha`
@model function bayes_nn(xs, ts, nparameters; alpha = 0.09)
    # Zero-mean Gaussian prior over the flattened network parameters.
    parameters ~ MvNormal(Zeros(nparameters), I / alpha)

    # Re-attach the component axes so the flat vector can be used as Lux parameters.
    weights = ComponentArray(parameters, ps_ax)

    # Keep only the first element of what the network call returns, then drop
    # its leading dimension.
    output = first(nn(xs, weights, st))
    means = dropdims(output; dims = 1)

    # Each target gets its own Normal centred on the matching prediction
    # (`Normal(μ)` uses unit standard deviation).
    ts .~ Normal.(means)
end
# Placeholder data: 128 one-element Float32 input vectors and 128 Float32 targets,
# all drawn with `rand`.
sx = map(_ -> rand(Float32, 1), 1:128)
sy = map(_ -> rand(Float32), 1:128)

# Number of posterior samples to draw.
N = 2000

# Stack the inputs column-wise and sample the Lux-based model with NUTS.
ch = sample(
    bayes_nn(reduce(hcat, sx), sy, length(ps)),
    NUTS(),
    N,
);
With Tracker, sampling takes ~7.5 minutes (I just used random data and arbitrary dimensions for sx and sy), but the equivalent ReverseDiff version seems to take about 2 hours.
The equivalent Flux version (with Tracker) takes ~10 hours:
import Flux
# Flux counterpart of the Lux network: same layer sizes and activation.
const nn_flux = Flux.Chain(Flux.Dense(1 => 32, gelu), Flux.Dense(32 => 1))

# `params` is the flat parameter vector and `re` rebuilds a network from such a
# vector. `re` is called inside the body of `bayes_nn_flux` on every log-density
# evaluation, so it is declared `const` to give the compiler a concrete type
# rather than an untyped (`Any`) global.
const params, re = Flux.destructure(nn_flux)
# Bayesian neural network built on the Flux reconstructor `re`.
#
# Arguments:
#   xs          - network inputs (the call in this script passes a 1×128 matrix)
#   ts          - observed targets, one per prediction
#   nparameters - length of the flat parameter vector
#   alpha       - prior precision; the prior covariance is `I / alpha`
@model function bayes_nn_flux(xs, ts, nparameters; alpha = 0.09)
    # Zero-mean Gaussian prior over the flattened network parameters.
    parameters ~ MvNormal(Zeros(nparameters), I / alpha)

    # Rebuild the network from the sampled flat vector and evaluate it on the
    # inputs, then drop the leading dimension of the output.
    network = re(parameters)
    means = dropdims(network(xs); dims = 1)

    # Each target gets its own Normal centred on the matching prediction
    # (`Normal(μ)` uses unit standard deviation).
    ts .~ Normal.(means)
end
# Sample the Flux-based model with NUTS on the same column-stacked inputs.
ch = sample(
    bayes_nn_flux(reduce(hcat, sx), sy, length(params)),
    NUTS(),
    N,
);