Lux + Turing: How to avoid using a global variable for the state `st`

Hi there,

Lux's state handling, while lovely and explicit, sometimes seems a bit difficult.
Here is the example from the Bayesian Lux tutorial:

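```julia
using Lux, Turing, LinearAlgebra

# `nn`, `ps`, `st`, `sig` and `vector_to_parameters` are defined earlier in the tutorial.

# Specify the probabilistic model.
@model function bayes_nn(xs, ts)
    global st  # non-constant global: this is what the question wants to avoid

    # Sample the parameters
    nparameters = Lux.parameterlength(nn)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

    # Forward NN to make predictions (and overwrite the global state)
    preds, st = nn(xs, vector_to_parameters(parameters, ps), st)

    # Observe each prediction.
    for i in 1:length(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
```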

How can I avoid the global variable in a way that still lets the compiler optimize the state update efficiently?

EDIT: I opened a more general issue on Lux.jl for this question: Documentation Request: Standardize the handling of the state `st` · Issue #515 · LuxDL/Lux.jl · GitHub

Use the utility documented in Utilities | LuxDL Docs. It can handle both the case where the state type changes and the case where the state type is fixed.
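A minimal sketch of the tutorial model without the global, assuming the utility on the linked page is `StatefulLuxLayer` with the Lux 1.x constructor `StatefulLuxLayer{true}(model, ps, st)`. The wrapper keeps the state inside itself, so the Turing model only deals with parameters. The network architecture, the `sig` argument, the `Functors.fmap` dependency and the `vector_to_parameters` helper below are hypothetical stand-ins for the tutorial's own definitions.

```julia
using Lux, Turing, Random, LinearAlgebra
using Functors: fmap

# Hypothetical stand-in for the tutorial's network.
const nn = Chain(Dense(2 => 3, tanh), Dense(3 => 2, tanh), Dense(2 => 1, sigmoid))
ps, st = Lux.setup(Random.default_rng(), nn)

# Wrap network and state once. `{true}` promises that the state type never
# changes between calls, so calls through the wrapper can stay type stable.
const smodel = StatefulLuxLayer{true}(nn, ps, st)

# Hypothetical helper in the spirit of the tutorial's `vector_to_parameters`:
# slices the flat vector `v` into arrays shaped like the leaves of `ps`.
function vector_to_parameters(v::AbstractVector, ps::NamedTuple)
    offset = Ref(0)
    return fmap(ps) do x
        n = length(x)
        y = reshape(v[(offset[] + 1):(offset[] + n)], size(x))
        offset[] += n
        return y
    end
end

@model function bayes_nn(xs, ts, smodel, ps_template, sig)
    # Sample the parameters
    nparameters = Lux.parameterlength(ps_template)
    parameters ~ MvNormal(zeros(nparameters), Diagonal(abs2.(sig .* ones(nparameters))))

    # No global: the wrapper threads the state internally and stores the
    # updated state in `smodel.st` after each call.
    preds = smodel(xs, vector_to_parameters(parameters, ps_template))

    # Observe each prediction.
    for i in eachindex(ts)
        ts[i] ~ Bernoulli(preds[i])
    end
end
```

Passing `smodel` as a model argument keeps non-constant globals out of the model body. Because the wrapper mutates its stored state on every call, sharing one instance across threads is probably not safe.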

That said, if your state changes type across iterations, that probably points to a problem in the model implementation. Most state type changes should be performed with `update_state` / `testmode` outside of the main model calls.
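A short sketch of that pattern, assuming Lux 1.x where both functions are reachable as `Lux.testmode` and `Lux.update_state`. The `Chain` below is a hypothetical example; it includes a `Dropout` layer because `Dropout` keeps a `training` flag in its state. Switching that flag changes the state's type (`Val(true)` vs `Val(false)`), which is why it is done once, before the forward calls, rather than inside them.

```julia
using Lux, Random

# Hypothetical model whose state contains a `training` flag (from Dropout).
model = Chain(Dense(4 => 8, relu), Dropout(0.5f0), Dense(8 => 2))
ps, st = Lux.setup(Random.default_rng(), model)

# Change the state type once, outside the main model calls.
st_test = Lux.testmode(st)                                  # every `training` entry -> Val(false)
st_train = Lux.update_state(st_test, :training, Val(true))  # flips it back explicitly

# Inside the loop the state type now stays fixed.
x = rand(Float32, 4, 3)
y, _ = model(x, ps, st_test)  # dropout acts as the identity in test mode
```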
