Very large performance difference between Turing and Stan [Gaussian Process]

A student I’m working with built a Gaussian process model in Turing, but it was very slow. Out of frustration, they reimplemented it in Stan and it’s much faster. For context, it fits (1000 samples from 1 chain) in about 200s in Stan but takes about 4000s in Turing (a 20x difference).

We have tried many things, including messing with the sampler (in Turing), playing with the autodiff backend, changing the prior, etc., but haven’t found the cause. In all these cases the speed difference stays roughly the same. At this point we would appreciate any tips on how to help the Julia model catch up to Stan, or else on understanding why Turing is struggling here (to better choose tools in the future).

Thanks!

Our model looks like:

y(s, t) \sim \mathcal{N}(\mu(s, t), \sigma)

where

\mu(s, t) = \mu_0 + \beta_\mu(s) * x(t)

with a Gaussian process prior over \beta_\mu(s).
Although this is kind of a silly model, it’s an important building block for future work.

Our Turing model looks like

@model function gp(x, y, D; cov_fn::Function = cov_exp_quad)
    
    # parse inputs
    N = size(y, 2)

    # priors
    μ0 ~ Normal(5, 3) # intercept for the mean
    β_m ~ Normal(0, 1)
    β_ρ ~ InverseGamma(5, 5) # kernel length parameter -- see Stan manual
    β_α ~ Truncated(Normal(0, 1), 0, Inf) # kernel variance -- see Stan manual

    # compute the kernel
    β_K = cov_fn(D, β_α, β_ρ)
    # often one adds noise to the diagonals of the covariance fn to account for noise
    # in this case this does not make sense -- the true value of β(s) is smooth

    # reject any sample where the covariance kernel is not positive definite
    if !isposdef(β_K)
        Turing.@addlogprob! -Inf
        return nothing
    end

    # spatial model for the coefficient
    μ_β ~ MvNormal(β_m * ones(N), β_K)

    # the mean
    μ = μ0 .+ hcat(x) .* transpose(hcat(μ_β))
    logs ~ Normal(0, 1) # log standard deviation
    s = exp(logs) # standard deviation

    # compute the distributions
    dists = Normal.(μ, s)
    y .~ dists

end

where

function cov_exp_quad(D, α, ρ)
    sq_dist = D .^ 2
    return α^2 * exp.(-sq_dist ./ (2 * ρ^2))
end

Our stan model looks like

data {
  int<lower=1> N_loc;
  int<lower=1> N_obs;
  vector[2] X[N_loc]; //locations, longitude & latitude
  vector[N_loc] y[N_obs]; //observations of rainfall, indexed by [time, location]
  vector[N_obs] x; // covariates
}

parameters{
  real<lower=0> mu0; // intercept for mean
  real beta_m; // coefficient for mean
  real<lower=0> beta_rho; // kernel length-scale parameter
  real<lower=0> beta_alpha; // kernel variance parameter
  real logs; // log of standard deviation
  vector[N_loc] mu_beta;// coefficient
}

model{
  matrix[N_loc, N_loc] beta_xi;
  matrix[N_loc, N_loc] beta_K = cov_exp_quad(X, beta_alpha, beta_rho); // covariance function
  vector[N_loc] mu[N_obs]; // mean value for each observation

  beta_xi = cholesky_decompose(beta_K);
  
  mu0 ~ normal(5, 3);
  beta_m ~ std_normal();
  beta_rho ~ inv_gamma(5, 5);
  beta_alpha ~ std_normal();
  logs ~ std_normal();
  
  mu_beta ~ multi_normal_cholesky(rep_vector(beta_m, N_loc), beta_xi);
  
  for (i in 1:N_obs){
    for (j in 1:N_loc){
      mu[i, j] = mu0 + x[i] * mu_beta[j];
    }
  }

  for (i in 1:N_obs){
    for (j in 1:N_loc){
      y[i, j] ~ normal(mu[i, j], exp(logs));
    }
  }
}

Have you checked CPU usage? Are they both using a single thread, or the same number of CPUs if not single-threaded?
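
On the Julia side, one quick check is the thread counts for Julia itself and for BLAS (which handles the linear algebra), e.g.:

using LinearAlgebra
Threads.nthreads()        # number of Julia threads
BLAS.get_num_threads()    # BLAS can multithread even when Julia itself runs on 1 thread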

Thanks for the suggestion. We’re using 1 CPU for both, so I don’t think that’s it.

I remember having a huge slowdown with truncated normals…

See if anything discussed in that thread helps?

Yes, that’s one improvement that can be made here. Truncated should not be used directly; use truncated instead. Secondly, truncated(Normal(0, 1), 0, Inf) is deprecated syntax, and the explicit Inf bound can introduce numerical instability. Use truncated(Normal(0, 1); lower=0) instead.
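
In the model above that’s a one-line change to the prior:

β_α ~ truncated(Normal(0, 1); lower=0)   # instead of β_α ~ Truncated(Normal(0, 1), 0, Inf)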

A few other improvements, with a rough sketch applying them after this list:

  • Move all computations that only need to happen once out of the model body and into the arguments. E.g. you can pass X = hcat(x) as a keyword argument and use that in the model. Likewise, D is elementwise squared on every model evaluation even though it never changes, which is wasted effort.
  • Don’t use isposdef(β_K) if you can avoid it. Internally, that cholesky decomposes your matrix, which slows down your gradients. Instead, if you can check the arguments, that would be much faster.
  • Use s ~ LogNormal(0, 1).
  • Use y ~ MvNormal(μ, Diagonal(s.^2))
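
Putting those together, here is a rough, untested sketch of what the model could look like with these tips applied (the sq_dist argument and the jitter keyword are my additions, not part of the original model):

using Turing, Distributions, LinearAlgebra

@model function gp_fast(x, y, sq_dist; jitter=1e-10)
    N = size(y, 2)

    # priors
    μ0 ~ Normal(5, 3)                        # intercept for the mean
    β_m ~ Normal(0, 1)                       # mean of the GP coefficient
    β_ρ ~ InverseGamma(5, 5)                 # kernel length scale
    β_α ~ truncated(Normal(0, 1); lower=0)   # kernel marginal standard deviation

    # kernel: squared distances are precomputed once outside the model,
    # and a small jitter on the diagonal keeps the Cholesky factorization stable
    β_K = β_α^2 .* exp.(-sq_dist ./ (2 * β_ρ^2)) + jitter * I

    # spatial model for the coefficient
    μ_β ~ MvNormal(fill(β_m, N), Symmetric(β_K))

    # observation noise as a LogNormal instead of exp(Normal)
    s ~ LogNormal(0, 1)

    # likelihood: one multivariate normal per time step instead of N_obs * N_loc scalars
    for i in axes(y, 1)
        y[i, :] ~ MvNormal(μ0 .+ x[i] .* μ_β, s^2 * I)
    end
end

# usage (hypothetical): chain = sample(gp_fast(x, y, D .^ 2), NUTS(), 1000)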

Thanks, these are very helpful tips. I’ll review them with her and see what the speedup looks like.

In theory our matrix should always be positive definite. However, in practice we sometimes got some numerical issues, so we introduced this hack. Any other suggestions for handling this? Maybe putting stronger priors on the kernel length scale to keep it from getting too small would help…

If you know certain low values are unreasonable, incorporating that in the prior is a good idea. Sometimes it makes sense to add something like 1e-10 * I to the matrix just to smooth over some numerical issues. It’s a hack but one that works quite well for many models.
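
In the Turing model above, that would look something like this (needs using LinearAlgebra for I):

β_K = cov_fn(D, β_α, β_ρ) + 1e-10 * I   # tiny jitter on the diagonal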

Thanks @sethaxen, better constraints on the parameters have helped with this, so we’ve been able to get rid of the isposdef line.

Things are still quite slow, as @dlakelan notes. It doesn’t get dramatically better if we turn that truncated Normal into something else. We’ve looked through the performance tips documentation, but we’ll scan the referenced discussion and see if we can find anything useful.

Thanks

One great thing about using Julia + Turing is that you can just run the regular old Julia profiler on it. Try that out and post a flame graph. I like the StatProfilerHTML.jl library for this.
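
For example, something like this (a minimal sketch; gp, x, y, D are the model and data from the original post, and @profilehtml is the macro StatProfilerHTML.jl exports):

using Turing, StatProfilerHTML

model = gp(x, y, D)
sample(model, NUTS(), 10)                 # short run first so compilation isn't in the profile
@profilehtml sample(model, NUTS(), 100)   # writes an interactive HTML flame graph to open in a browser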

Can you post the latest model with the changes that you have applied? Also, which AD backend are you using? Depending on the dimensionality, the choice of AD backend tends to be critical for performance.
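
For a model of this size, ReverseDiff with a compiled tape is often worth trying; the exact API depends on your Turing version, so treat this as a sketch:

using Turing, ReverseDiff

# older Turing versions: set the AD backend globally
Turing.setadbackend(:reversediff)
Turing.setrdcache(true)   # cache/compile the tape; helps when the model has no varying control flow

# newer Turing versions pass an adtype to the sampler instead, e.g.
# sample(gp(x, y, D), NUTS(; adtype=AutoReverseDiff(; compile=true)), 1000)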