Hi there,
Lux's state handling, while nicely explicit, can sometimes be a bit awkward.
Here is the example from the Bayesian Lux tutorial:
```julia
using Lux, Turing, LinearAlgebra

# `nn`, `ps`, `st`, `sig` and `vector_to_parameters` are defined earlier in the tutorial.

# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    global st

    # Sample the parameters
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

    # Forward NN to make predictions
    preds, st = nn(xs, vector_to_parameters(parameters, ps), st)

    # Observe each prediction.
    for i in 1:length(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
```
How can I avoid the global variable so that the state update can still be optimized efficiently by the compiler?
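One possible workaround, sketched under assumptions rather than taken from the tutorial: pass `nn`, `ps`, `st` and `sig` into the model as ordinary arguments instead of reading globals, and discard the state returned by the forward pass. Discarding it is only valid when the state does not need to carry over between model evaluations. That holds for an all-`Dense` `Chain` like the one assumed here, whose `st` is empty. The layer sizes are an assumption, and `unflatten` is a hypothetical helper standing in for the tutorial's `vector_to_parameters`.

```julia
using Lux, Turing, LinearAlgebra, Random

# Assumed architecture: only Dense layers, so the layer state `st` is empty and never changes.
nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))
ps, st = Lux.setup(Random.default_rng(), nn)

# Hypothetical helper (stand-in for the tutorial's `vector_to_parameters`):
# rebuild a nested NamedTuple shaped like `ps` from the flat vector `v`,
# consuming entries after offset `i`. Returns (rebuilt, new_offset).
unflatten(v::AbstractVector, x::AbstractArray, i::Int) =
    (reshape(v[(i + 1):(i + length(x))], size(x)), i + length(x))

function unflatten(v::AbstractVector, nt::NamedTuple, i::Int)
    vals = ()
    for k in keys(nt)
        val, i = unflatten(v, nt[k], i)
        vals = (vals..., val)
    end
    return NamedTuple{keys(nt)}(vals), i
end

# Everything the model needs arrives as an argument, so no globals are read.
@model function bayes_nn(xs, ts, nn, ps, st, sig)
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

    θ, _ = unflatten(parameters, ps, 0)
    # The returned state is discarded: for this stateless network it equals `st`.
    preds, _ = nn(xs, θ, st)

    for i in 1:length(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
```

Sampling would then look like `sample(bayes_nn(xs, ts, nn, ps, st, sig), NUTS(), 1_000)`, with `xs` a 2×N matrix and `ts` the N labels. Because the values are now function arguments rather than untyped globals, Julia can infer their types. If the network did contain layers such as `BatchNorm`, passing `Lux.testmode(st)` should make them use their stored statistics instead of updating state. Genuinely threading an updated state through a model that Turing re-evaluates many times per sample is a separate question that this sketch does not address.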
EDIT: I opened a more general issue on Lux.jl for this question: Documentation Request: Standardize the handling of the state `st` · Issue #515 · LuxDL/Lux.jl · GitHub