A student I’m working with is building a Gaussian process model in Turing, but it was very slow. Out of frustration, they reimplemented it in Stan, and it’s much faster. For context, fitting 1000 samples from 1 chain takes about 200 s in Stan but about 4000 s in Turing (a 20x difference).
We have tried many things, including tweaking the sampler (in Turing), switching the autodiff backend, and changing the prior, but haven’t figured anything out: in all these cases the speed difference stays roughly the same. At this point, we would appreciate any tips on helping the Julia model catch up to Stan, or else on understanding why Turing is struggling here (to better choose tools in the future).
Thanks!
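Our model looks like:

$$ y(t, s) \sim \mathcal{N}\big(\mu(t, s),\ \sigma\big), $$

where

$$ \mu(t, s) = \mu_0 + \beta(s)\, x(t), $$

with a Gaussian process over $\beta(s)$.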
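For reference, here is the Gaussian process and the priors, transcribed from the Turing and Stan code below (so this is only as accurate as that transcription). $t$ indexes time, $s$ indexes location, and the second argument of $\mathcal{N}$ is a standard deviation, as in both Distributions.jl and Stan. One difference between the two implementations: the Stan version declares $\mu_0$ with `<lower=0>`, so its prior is effectively a truncated normal there, while the Turing version leaves $\mu_0$ unconstrained.

$$
\begin{aligned}
\beta(\cdot) &\sim \mathcal{GP}\big(\beta_m,\ k\big), \qquad k(s, s') = \alpha^2 \exp\!\left(-\frac{\lVert s - s' \rVert^2}{2\rho^2}\right),\\
\mu_0 &\sim \mathcal{N}(5, 3), \qquad \beta_m \sim \mathcal{N}(0, 1),\\
\rho &\sim \text{Inv-Gamma}(5, 5), \qquad \alpha \sim \mathcal{N}^{+}(0, 1),\\
\log \sigma &\sim \mathcal{N}(0, 1).
\end{aligned}
$$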
Although this is kind of a silly model, it’s an important building block for future work.
Our Turing model looks like:
```julia
using Turing, LinearAlgebra  # LinearAlgebra provides isposdef

@model function gp(x, y, D; cov_fn::Function = cov_exp_quad)
    # parse inputs
    N = size(y, 2)  # number of locations (columns of y)

    # priors
    μ0 ~ Normal(5, 3)                      # intercept for the mean
    β_m ~ Normal(0, 1)                     # GP mean of the coefficient
    β_ρ ~ InverseGamma(5, 5)               # kernel length parameter -- see Stan manual
    β_α ~ truncated(Normal(0, 1), 0, Inf)  # kernel variance -- see Stan manual

    # compute the kernel
    β_K = cov_fn(D, β_α, β_ρ)
    # often one adds noise to the diagonal of the covariance matrix to account for noise;
    # in this case that does not make sense -- the true value of β(s) is smooth

    # reject any sample where the covariance kernel is not positive definite
    if !isposdef(β_K)
        Turing.@addlogprob! -Inf
        return nothing
    end

    # spatial model for the coefficient
    μ_β ~ MvNormal(β_m * ones(N), β_K)

    # the mean: μ[i, j] = μ0 + x[i] * μ_β[j]  (time i, location j)
    μ = μ0 .+ x .* transpose(μ_β)

    logs ~ Normal(0, 1)  # log standard deviation
    s = exp(logs)        # standard deviation

    # compute the distributions: one independent Normal per observation
    dists = Normal.(μ, s)
    y .~ dists
end
```
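where `cov_exp_quad` is
```julia
# Exponentiated-quadratic (squared-exponential) kernel evaluated on a matrix D
# of pairwise distances between locations
function cov_exp_quad(D, α, ρ)
    sq_dist = D .^ 2
    return α^2 * exp.(-sq_dist ./ (2 * ρ^2))
end
```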
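The Turing model takes a precomputed distance matrix `D`, whereas the Stan model passes the raw locations `X` to `cov_exp_quad`. Below is a minimal sketch of building `D` and checking that `cov_exp_quad(D, α, ρ)` reproduces the kernel $k(s, s') = \alpha^2 \exp(-\lVert s - s'\rVert^2 / (2\rho^2))$ evaluated point by point. It assumes the locations are stored as a 2×N matrix (longitude, latitude) with Euclidean distance; `pairwise_distances`, `k`, and the example locations and parameter values are hypothetical, and `cov_exp_quad` is copied from above so the block runs on its own.

```julia
using LinearAlgebra

# Same kernel as above, repeated so this block is self-contained.
function cov_exp_quad(D, α, ρ)
    sq_dist = D .^ 2
    return α^2 * exp.(-sq_dist ./ (2 * ρ^2))
end

# Hypothetical helper: Euclidean distances between the columns of a 2×N
# matrix of locations (row 1 = longitude, row 2 = latitude).
function pairwise_distances(X::AbstractMatrix)
    N = size(X, 2)
    return [norm(X[:, i] - X[:, j]) for i in 1:N, j in 1:N]
end

# Three made-up locations, one per column.
X = [0.0 1.0 0.0;
     0.0 0.0 2.0]
D = pairwise_distances(X)
K = cov_exp_quad(D, 1.5, 0.7)

# Direct point-by-point evaluation of k(s, s') for comparison.
k(s, t; α = 1.5, ρ = 0.7) = α^2 * exp(-sum(abs2, s - t) / (2 * ρ^2))
K_direct = [k(X[:, i], X[:, j]) for i in 1:3, j in 1:3]
```

Since the kernel only uses `D` through `D .^ 2`, passing squared distances instead would save one broadcast per log-density evaluation, though I would not expect that alone to explain a 20x gap.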
Our Stan model looks like:
```stan
data {
  int<lower=1> N_loc;
  int<lower=1> N_obs;
  vector[2] X[N_loc];        // locations, longitude & latitude
  vector[N_loc] y[N_obs];    // observations of rainfall, indexed by [time, location]
  vector[N_obs] x;           // covariates
}
parameters {
  real<lower=0> mu0;         // intercept for mean
  real beta_m;               // coefficient for mean
  real<lower=0> beta_rho;    // kernel length parameter (inv_gamma prior requires rho > 0)
  real<lower=0> beta_alpha;  // kernel variance parameter
  real logs;                 // log of standard deviation
  vector[N_loc] mu_beta;     // coefficient
}
model {
  matrix[N_loc, N_loc] beta_xi;
  matrix[N_loc, N_loc] beta_K = cov_exp_quad(X, beta_alpha, beta_rho); // covariance function
  vector[N_loc] mu[N_obs];   // mean value for each observation
  beta_xi = cholesky_decompose(beta_K);  // lower Cholesky factor of the kernel

  mu0 ~ normal(5, 3);
  beta_m ~ std_normal();
  beta_rho ~ inv_gamma(5, 5);
  beta_alpha ~ std_normal();  // half-normal because of <lower=0>
  logs ~ std_normal();
  mu_beta ~ multi_normal_cholesky(rep_vector(beta_m, N_loc), beta_xi);

  for (i in 1:N_obs) {
    for (j in 1:N_loc) {
      mu[i, j] = mu0 + x[i] * mu_beta[j];
    }
  }
  for (i in 1:N_obs) {
    for (j in 1:N_loc) {
      y[i, j] ~ normal(mu[i, j], exp(logs));
    }
  }
}
```
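One structural difference between the two models: Stan factorizes the kernel once with `cholesky_decompose` and hands the factor to `multi_normal_cholesky`, so a non-positive-definite kernel shows up as a failed factorization. In the Turing model, as far as I can tell, `isposdef(β_K)` attempts a Cholesky factorization internally and `MvNormal(..., β_K)` then factorizes the same matrix again. The sketch below shows, using only the LinearAlgebra standard library, how a single factorization can serve both as the positive-definiteness check and for the log-density. `mvnormal_logpdf_chol` is a hypothetical helper for illustration, not a Turing or Distributions.jl API, and the matrices are made up.

```julia
using LinearAlgebra

# Hypothetical helper (not a Turing/Distributions API): log-density of
# MvNormal(m, K) from a single Cholesky factorization of K, analogous to
# Stan's cholesky_decompose + multi_normal_cholesky. Returns -Inf when K is
# not positive definite, playing the role of the isposdef check.
function mvnormal_logpdf_chol(v::AbstractVector, m::AbstractVector, K::AbstractMatrix)
    C = cholesky(Symmetric(K); check = false)  # does not throw on failure
    issuccess(C) || return -Inf                # not positive definite: reject
    L = C.L                                    # lower-triangular factor, K = L * L'
    z = L \ (v .- m)                           # whitened residual
    n = length(v)
    logdetK = 2 * sum(log, diag(L))            # log|K| from the factor's diagonal
    return -0.5 * (n * log(2π) + logdetK + dot(z, z))
end

# Example: compare against the textbook formula on a small positive-definite matrix.
K = [2.0 0.5; 0.5 1.0]
m = [0.0, 0.0]
v = [0.3, -0.2]
lp = mvnormal_logpdf_chol(v, m, K)
lp_direct = -0.5 * (2 * log(2π) + logdet(K) + dot(v - m, K \ (v - m)))
```

Whether the extra factorization accounts for a meaningful share of the 20x gap is untested; each factorization costs O(N_loc³), so it is at least worth ruling out.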