Turing.jl - NUTS/sample parameters


I’m trying to learn Turing.jl and the concepts behind it, but I can’t find any documentation explaining the arguments passed to NUTS() when sampling. Several of the tutorials use the NUTS sampler, but I can’t find anything that explains what its arguments mean. For example, when using the NUTS sampler with a linear regression model:

chain = sample(model, NUTS(200, 0.65), 1500)

What are 200, 0.65 and 1500?

Here is the docstring: https://github.com/TuringLang/Turing.jl/blob/master/src/inference/hmc.jl#L214. There is an empty line between the docstring and the type definition, so it is not showing up in the help. I will fix that. 1500 is the number of samples.


200 is the number of adaptation samples. Adaptation samples are used to tune the sampler so that it generates good samples, and they are discarded by default, so your chain will end up with 1500 - 200 = 1300 samples.

0.65 is the target acceptance rate that the sampler aims for while adapting. For most use cases it’s fine to leave it at 0.65.
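To make the roles of 200 and 0.65 more concrete, here is a toy Julia sketch of dual-averaging step-size adaptation, the scheme described in the NUTS paper (Hoffman & Gelman, 2014). During each of the `n_adapts` warm-up steps, the step size is nudged so that the observed acceptance rate moves toward the target `δ`. Afterwards, the averaged step size is frozen. This is not Turing's code: `toy_acceptance` and `adapt_stepsize` are hypothetical names, the acceptance function is a made-up stand-in for a real NUTS transition, and I'm assuming Turing's defaults are close to the constants used here (γ = 0.05, t0 = 10, κ = 0.75). Turing also adapts the mass matrix during warm-up, which this sketch ignores.

```julia
# Hypothetical stand-in for the acceptance probability a real NUTS transition
# would report: it simply drops as the step size grows.
toy_acceptance(ε) = exp(-ε^2)

# Toy dual-averaging step-size adaptation (Hoffman & Gelman 2014, Algorithm 5).
# `n_adapts` plays the role of the 200 and `δ` the role of the 0.65 in NUTS(200, 0.65).
function adapt_stepsize(accept, n_adapts, δ; ε0 = 1.0, γ = 0.05, t0 = 10.0, κ = 0.75)
    μ = log(10 * ε0)        # value that log(step size) is shrunk toward
    Hbar = 0.0              # running average of (δ - observed acceptance)
    logeps = log(ε0)        # step size used for the next transition (log scale)
    logeps_bar = 0.0        # averaged step size (log scale), kept after warm-up
    for m in 1:n_adapts
        α = accept(exp(logeps))             # acceptance rate at the current step size
        w = 1 / (m + t0)
        Hbar = (1 - w) * Hbar + w * (δ - α)
        logeps = μ - sqrt(m) / γ * Hbar     # acceptance below δ => Hbar grows => step shrinks
        η = m^(-κ)
        logeps_bar = η * logeps + (1 - η) * logeps_bar
    end
    return exp(logeps_bar)  # step size frozen for the samples drawn after adaptation
end

ε = adapt_stepsize(toy_acceptance, 200, 0.65)
println("adapted step size: ", round(ε; digits = 3))
println("acceptance at that step size: ", round(toy_acceptance(ε); digits = 3))
```

With this toy acceptance function, raising the target (for example to 0.9) produces a smaller adapted step size. That trade-off is what the second argument to `NUTS` controls.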