NUTS algorithm in Turing.jl

Dear all,

Can someone explain to me whether n_adapts in the NUTS algorithm NUTS(n_adapts, δ), as used in Turing's sample function, is the number of iterations or the number of accepted samples before the main sampling? When I set n_adapts to 100 and print out completed iterations inside @model, I can see that the main sampling still hasn't started even after 100 iterations are completed. I'm confused because I quickly checked with Gemini, and it says n_adapts is a number of iterations, not a number of accepted samples, before the main sampling begins.

Thank you for your help.

I’m not sure if I totally understand the question here, but I’ll try my best to answer what I can.

n_adapts refers to the number of iterations (calls to AbstractMCMC.step) where we adaptively tune the Euclidean metric. Therefore, if you set n_adapts to 100, then the first 100 states along the Markov chain will be used to estimate the covariance.

By default these states are thrown away, but from what I’ve seen they are always accepted. If you want a better look at the states used in adaptation, feel free to add the discard_adapt kwarg to your sampler.

using Random, Turing

rng = MersenneTwister(1234)
chain = sample(rng, model, NUTS(100, 0.65), 250; discard_adapt=false)
describe(chain[101:end])

The first 100 elements of the chain are used in adaptation, and are usually discarded without the additional flag. If you were to exclude this flag, it would perform 350 iterations and return the last 250 elements.

Hope that helps

That's also what I found out, but then consider the following situation: I call chain = sample(model, NUTS(100, 0.65), 250; discard_adapt=false) and print iterations inside @model, where iteration_counter = Ref(0) is defined outside of @model:

iteration_counter = Ref(0)

@model function xyz(....)
    # other stuff
    iteration_counter[] += 1
    if iteration_counter[] % 100 == 0
        @info "$(iteration_counter[])"
    end
end

model = xyz(....)

chain = sample(model, NUTS(100, 0.65), 250; discard_adapt=false)

When I get this @info, it means 100 iterations are done, right? But even after that, I see sampling doesn't start. Why is that? Am I missing something?

Or maybe a better question: is there a way to track warm-up? If I give it 1000 warm-up iterations, can I extract an update like "500 steps already done, 500 remaining"?

When I get this @info, it means 100 iterations are done, right? But even after that, I see sampling doesn't start. Why is that? Am I missing something?

Assuming you add 1 to iteration_counter[] upon each model execution, what you're tracking is the number of model evaluations, i.e. calls to model.f(...). This counter accumulates not only computations of the log-likelihood, but also sampling from the model, as well as any extra executions performed by automatic differentiation (AD).
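To make the distinction concrete, here's a minimal pure-Julia sketch (no Turing involved; `logdensity` and `fake_nuts_step` are made-up stand-ins, and "7 leapfrog steps per iteration" is an arbitrary assumption) showing why a counter incremented inside the model body grows much faster than the number of MCMC iterations:

```julia
# One NUTS iteration can evaluate the log density many times (leapfrog
# steps, tree doubling, AD passes), so a counter inside the model grows
# much faster than the iteration count.
eval_counter = Ref(0)

# Stand-in for the model's log-density computation.
function logdensity(x)
    eval_counter[] += 1
    return -x^2 / 2
end

# Stand-in for one MCMC iteration: pretend the sampler takes 7 leapfrog steps.
function fake_nuts_step(x)
    for _ in 1:7
        logdensity(x)
    end
    return x
end

for iteration in 1:100
    fake_nuts_step(randn())
end

eval_counter[]  # 700 model evaluations for only 100 iterations
```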

is there a way to track warm-up?

Yes. Using my above code, the first 100 elements of the chain are the "warm-up" period. The subsequent draws are made with the fully adapted Euclidean metric.

adaptation_stats = chain[1:100]
posterior_sample = chain[101:end]

okay understood. Thank you for your explanation.

Yes. Using my above code, the first 100 elements of the chain are the "warm-up" period. The subsequent draws are made with the fully adapted Euclidean metric.

But this is information I get only after finishing my run. Let's say I have a complicated model that takes a very long time, and I would like to track the number of NUTS accepted samples for the current iteration count. How can I do this?

You can track progress by setting the keyword argument progress=true. By default this should be true, but you can check the AbstractMCMC.jl docs for more details. If you use an IDE like VS Code, the progress bar is displayed in the lower left-hand corner instead of in the terminal.

You could also do some fancy things with implementing a custom logger. The gory details can be seen in the unit tests.
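If you want to roll your own, here's a minimal sketch of Julia's standard AbstractLogger interface from the Logging stdlib (CountingLogger is a made-up name, and this toy version only counts log records; a real sampler logger would inspect the message contents instead):

```julia
using Logging

# A wrapper logger that counts every log record before forwarding it
# to an inner logger. This is the three-method AbstractLogger interface.
struct CountingLogger <: AbstractLogger
    inner::AbstractLogger
    count::Ref{Int}
end

Logging.min_enabled_level(l::CountingLogger) = Logging.min_enabled_level(l.inner)
Logging.shouldlog(l::CountingLogger, args...) = Logging.shouldlog(l.inner, args...)
function Logging.handle_message(l::CountingLogger, args...; kwargs...)
    l.count[] += 1  # count the record, then forward it unchanged
    Logging.handle_message(l.inner, args...; kwargs...)
end

counter = Ref(0)
with_logger(CountingLogger(global_logger(), counter)) do
    @info "message one"
    @info "message two"
end
counter[]  # 2
```

The same wrapping pattern would let you intercept whatever the sampler logs during a run, rather than waiting for the chain to finish.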

AbstractMCMC has callback functions which are run once at the end of every iteration (note that this is different from each model evaluation, because every MCMC iteration may evaluate the model many times, especially with a sampler like NUTS).

See Callbacks · AbstractMCMC

Now, be warned because everything that follows on from this is Turing internals.

Inside the callback function you have access to a number of quantities. I’d recommend you start by playing with this and printing out all of them to see what’s inside them.

function my_callback(rng, model, sampler, transition, state, iteration; kwargs...)
    @show transition state iteration
end

sample(model, NUTS(), 1000; callback=my_callback)

The question is where in these things can you extract the info about whether the transition was accepted or not. This is the part where (unfortunately) you have to look at the things that the Turing sampler returns:

From this code you can see that there’s this thing called t.stat, which contains a boolean is_accept. Furthermore, t.stat is bundled as part of the transition object, which is a DynamicPPL.ParamsWithStats.

If you look up the definition of that struct

this will suggest that the right callback to use is something like this:

function my_callback(rng, model, sampler, transition, state, iteration; kwargs...)
    @show iteration transition.stats.is_accept
end

sample(model, NUTS(), 100; callback=my_callback)

(Of course, you can do whatever you want rather than just @show-ing it. For example, you could make the callback track the percentage of accepted transitions over time. That part is up to you 🙂)
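As a sketch of what such an acceptance tracker could look like (it assumes the transition.stats.is_accept field discussed above exists in your Turing version, so treat that field access as an assumption rather than a stable API):

```julia
# Running count of accepted transitions, shared with the callback via a Ref.
accepted = Ref(0)

function acceptance_callback(rng, model, sampler, transition, state, iteration; kwargs...)
    # `transition.stats.is_accept` is a Turing/AdvancedHMC internal; adjust
    # the field path if your version stores it differently.
    accepted[] += transition.stats.is_accept
    if iteration % 100 == 0
        @info "iteration $iteration: acceptance rate $(accepted[] / iteration)"
    end
end

# Hypothetical usage:
# sample(model, NUTS(), 1000; callback=acceptance_callback)
```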

This is unfortunately pretty finicky. It would be nice to have a better interface in the future, but this is what it is for now.

BTW, the callbacks docs page (linked above) also has an example of using TensorBoard to log these stats as sampling progresses. I’ve never tried to use it myself, but if you fancy that, you could try it out too. That should log a ton of different stats (not just the acceptance).
