Hello all Turing experts. I am having a vexing issue using NUTS to sample from the posterior of a semi-complicated model involving ODEs. It is a performance issue that I can’t track down. Unfortunately I can’t describe the whole modeling problem or boil it down to a minimum working example, so here are a few details.
I have a biological modeling problem involving a few dozen differential equations, of which I am interested in only a few parameters (see below for the Turing model outline).
All of the complexity of the model resides in a custom logp_fun that computes the summed log-likelihood of the data given the parameters.
I have been careful in the construction of this function. I have ensured that types can be inferred and that the remake process for continually re-simulating the ODEs follows SciML guidance for optimization contexts. Due to the structure of my data, I multi-thread within this function (so no MCMCThreads() later), but I have ensured there are no data races or other threading issues (a simple use of Threads.@threads on a loop, sketched below).
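(A minimal sketch of that race-free pattern; data_chunks and chunk_loglik are hypothetical stand-ins for my actual data and per-chunk likelihood:)

# Each task writes only to its own slot, so there are no data races.
# eltype(params_vec) keeps the buffer compatible with ForwardDiff dual numbers.
function logp_fun(params_vec)
    partial = zeros(eltype(params_vec), length(data_chunks))
    Threads.@threads for i in eachindex(data_chunks)
        partial[i] = chunk_loglik(params_vec, data_chunks[i])  # remake + solve inside
    end
    return sum(partial)
end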
I have tested this core function alone for both behavior and performance. Computing the function takes ~1-2 ms and computing its gradient also takes ~1-2 ms (from BenchmarkTools).
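(Timings measured along these lines; θ0 here is just an illustrative point in the prior support:)

using BenchmarkTools, ForwardDiff
θ0 = [-4.5, 0.0, 0.0, -4.5, -0.7]            # illustrative starting point
@btime logp_fun($θ0)                          # ~1-2 ms
@btime ForwardDiff.gradient($logp_fun, $θ0)   # ~1-2 ms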
Optimizing this model using Optimization.jl behaves and performs as expected: with Adam I get ~3 ms/step, and with BBO I get ~1 ms/step (no gradient there). The optimization results themselves are also what I expect.
To help Turing out, I first use Optimization.jl to get an MLE estimate, use that as the initial estimate for a MAP estimate via maximum_a_posteriori, and then use the MAP estimate as the initial_params for sampling. So I know the sampler starts in the support of the posterior (I also have print checks in logp_fun that tell me if the model has wandered out, and it has not).
However, using NUTS within Turing takes ~8 s per sample (single chain, 150 samples in 1293 s).
To reiterate: logp_fun and its gradient take 1-3 ms to compute, and Adam takes about 3 ms per step, but NUTS is taking 8 s per sample (1000x longer). I realize that NUTS requires multiple evaluations per sample, but 1000x? Any ideas?
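For scale: NUTS doubles a binary tree of leapfrog steps until a U-turn criterion fires or the maximum tree depth is hit (10 by default), so a single sample can cost up to roughly 2^10 ≈ 1000 gradient evaluations. At ~2 ms per gradient that is ~2 s per sample, so ~8 s per sample is at least consistent with the sampler saturating the tree at a tiny step size, rather than with any per-evaluation overhead.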
I realize there is not much to work with here since logp_fun is too complicated to include. But the function and its gradient clearly behave well both on their own and within a gradient-based optimizer (Adam). So there must be something about the NUTS algorithm, or the Turing implementation of it, that is leading to extremely slow sampling.
@model function construct_model(logp_fun)
    # Sample logs so the sampler works on the real line; parameter back-conversion happens in logp_fun.
    log_p1 ~ Normal(-4.5, 0.5)
    log_p2 ~ Normal(0.0, 1.0)
    log_p3 ~ Normal(0.0, 1.0)
    log_p4 ~ Normal(-4.5, 0.5)
    log_σ ~ Normal(-0.7, 1.0)
    σ = exp(log_σ)
    params_vec = [log_p1, log_p2, log_p3, log_p4, σ]
    # logp_fun is hand-crafted and cannot be reduced to a minimum working example.
    Turing.@addlogprob! logp_fun(params_vec)
end
model = construct_model(logp_fun)  # logp_fun defined elsewhere

# Optimization.jl used to get the initial MLE estimate (MLE_optim).
# Global optimization (BBO) was required, so Turing was not used for this step.
map_est = maximum_a_posteriori(model; initial_params = MLE_optim.u)
chain_NUTS = sample(model, NUTS(0.7), 100; initial_params=map_est.values.array)
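(Side note: I believe recent Turing versions also accept a max_depth keyword on NUTS; capping it bounds the worst case at ~2^max_depth gradient evaluations per sample, which can be handy as a diagnostic, though it is not a fix:)

# Diagnostic only (assumed keyword, see note above): caps each sample at ~2^8 gradient evaluations.
chain_test = sample(model, NUTS(0.7; max_depth = 8), 100; initial_params = map_est.values.array)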
We had some trouble with the combination of NUTS and ODE solvers that use adaptive time stepping. The adaptive time stepping can result in “noisy” gradients, which then force NUTS to take tiny steps.
I do not know if there is a formal investigation of this unfortunate interaction. Somewhat related is this paper, although they focus on optimization and do not use automatic differentiation:
Kavetski, D., Clark, M.P., 2010. Ancient numerical daemons of conceptual hydrological modeling: 2. Impact of time stepping schemes on model analysis and prediction. Water Resources Research 46. https://doi.org/10.1029/2009WR008896
The noise is related to the tolerance of the ODE solver. So generally for inverse problems the tolerances should be set to something like abstol = 1e-8, reltol = 1e-8 or lower. Did you do that in your case?
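For concreteness, that just means passing the tolerances through the solve call inside the likelihood; a minimal sketch, where prob stands for whatever ODEProblem your model builds:

using OrdinaryDiffEq
# Tight tolerances keep the adaptive stepper's "solver noise" well below
# the scale NUTS sees in the gradient.
sol = solve(prob, Tsit5(); abstol = 1e-8, reltol = 1e-8)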
(Generally, if this were your issue, you would see the same problem with BFGS/LBFGS, since those also do not work well with stochasticity in the gradient; but from the OP it seems that might not apply here.)
The noise definitely depends on the tolerances. We did some experiments, and I remember that the noise in the gradient also depended on the solver algorithm (I would need to look up the details; it was a while ago).
At least for our toy model, deactivating the adaptive time stepping sometimes led to bias, but we always got a smooth gradient, even for low tolerances.
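(For reference, deactivating adaptive stepping in OrdinaryDiffEq looks like the sketch below; dt then has to be chosen small enough for accuracy, which is where the bias we saw can creep in:)

# Fixed-step solve: gradients become smooth, but accuracy now rides entirely on dt.
sol = solve(prob, Tsit5(); adaptive = false, dt = 0.01)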
Thanks for the comments and suggestions. I’m hearing two potential issues.
Initializing at the MAP estimate may lead to issues either (i) because sampling starts at the highest-density point or (ii) because adaptation goes poorly, since gradients should be small at the MAP. I don’t think this is the case: I moved the initial_params a bit away from the MAP and things did not improve. The case @rto mentioned is for high-dimensional problems, and this is 5-dimensional.
@ChrisRackauckas suggested tighter ODE tolerances. This appears to work, but with a strange quirk. I tried three cases (all using AutoTsit5(Rosenbrock23())).
i) abstol = reltol = 1e-8: the logp_fun (and its gradient) take ~10 ms; NUTS takes ~180 ms per sample.
ii) abstol = reltol = 1e-4: the logp_fun and its gradient (ForwardDiff) take ~1.5 ms; NUTS takes ~180 ms per sample.
iii) No tolerances specified: the target function and its gradient take ~1.5 ms, but NUTS takes ~8 s per sample. Further, NUTS misbehaves in the sense that its progress changes erratically (based on the progress printing).
So here are the take-homes (for this specific problem):
2a) Specifying tolerances indeed improves NUTS sampling. However, tight tolerances aren’t needed for sampling to proceed; the act of specifying them at all seems to have some effect.
2b) Tight tolerances indeed slow the log-density and its gradient evaluation (by a factor of ~7), and the Optimization.jl MLE step slows accordingly (not shown). However, NUTS sampling takes about the same time.
2c) Posteriors obtained with tighter and looser tolerances are near-identical, though the ESS is a bit better at the tighter tolerance. The posterior without specified tolerances (which has to run overnight!!!) is terrible.
I’m not sure of the source of this, and the testing was not exhaustive, but it appears that specifying tolerances in the ODE solve may be important when using NUTS, even though the results aren’t that sensitive to the tolerance values.
Would be great if anyone has insights into this. But this is good enough for now.
For a complicated posterior, NUTS can need quite a lot of evaluations per sample. Turing stores some internal information from the sampler in the chain; you might want to check chain_NUTS[info] for info ∈ [:n_steps, :acceptance_rate, :tree_depth, :step_size] to get some additional insight and diagnose problematic sampling.
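For example:

using Statistics
# Per-sample sampler internals, stored alongside the parameter draws:
mean(chain_NUTS[:n_steps])          # leapfrog steps per sample
mean(chain_NUTS[:tree_depth])       # tree depth per sample
mean(chain_NUTS[:step_size])        # (adapted) step size
mean(chain_NUTS[:acceptance_rate])  # average acceptance statistic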
Thanks for the suggestion @bertschi. I wasn’t sure how to access this. I looked at these metrics for the cases where the ODE is solved 1) with and 2) without specified tolerances.
The first image is with abstol = reltol = 1e-4; the second is with no tolerance provided. As expected given that NUTS grinds to a halt, each sample requires an enormous number of steps and a deep tree when no tolerance is specified. The n_steps, tree_depth, and step_size are comparable between the 1e-4 and 1e-8 cases.
The question is why? A tolerance of 1e-4 is not exactly stringent.
It’s more stringent than the default reltol = 1e-3. I say for inverse problems a safe default is 1e-8, since it just needs to be “low enough”: the defaults are low enough for forward problems and a basic plot, but not for inverse problems.
I just tried things with explicit default tolerances of 1e-3 and got the same behavior as with no specified tolerance. Thanks.
For my problem, 1e-4 is sufficient, but NUTS time per sample is the same for 1e-4 and 1e-8 (despite the direct simulation time being 7x longer in the 1e-8 case), with slightly better ESS at the tighter tolerance. This will of course be problem-specific.