# Best, most idiomatic way to implement Bayesian Differential Equations in Turing?

Hi! As part of GSoC, I’m working on porting the posteriordb benchmarking suite to Turing. This includes a number of Bayesian differential equation models.

I’m really interested in putting Julia’s “best foot forward”, so I’m making an effort to write my models in the most efficient and idiomatic ways possible. The DifferentialEquations.jl models in particular seem like an area where it would be easy to miss potential optimizations.

I’m asking this here instead of on Slack, as I think this discussion could be broadly useful.

So, with that in mind, I’d appreciate feedback on the following attempt to implement a Bayesian Lotka-Volterra model in Diffeqs/Turing:

```julia
# Define the ODE in DifferentialEquations.jl form
function lotka_volterra!(dz, z, p, t)
    alpha, beta, gamma, delta = p
    u, v = z

    # Evaluate differential equations.
    dz[1] = (alpha - beta * v) * u    # prey
    dz[2] = (-gamma + delta * u) * v  # predator

    return nothing
end

# Define the Turing model
@model function lotka_volterra(N, ts, y_init, y)
    theta ~ arraydist([
        Truncated(Normal(1, 0.5), 0, Inf),
        Truncated(Normal(0.05, 0.05), 0, Inf),
        Truncated(Normal(1, 0.5), 0, Inf),
        Truncated(Normal(0.05, 0.05), 0, Inf)
    ])

    sigma ~ filldist(LogNormal(-1, 1), 2)
    z_init ~ filldist(LogNormal(log(10), 1), 2)

    # Create a trajectory for this parameter set
    prob = ODEProblem(lotka_volterra!, z_init, (0, ts[end]), theta)
    z = solve(prob, DP5(), saveat = ts)

    # Include trajectories in the chain to match Stan
    z1 := z[1, :]
    z2 := z[2, :]

    # If the solver failed, reject (taken roughly from DiffEqBayes.jl)
    if length(z[1, :]) < N || any(z .< 0)
        Turing.DynamicPPL.acclogp!!(__varinfo__, -Inf)
        return
    end

    # Likelihood: both the initial condition (y_init) and the
    # trajectory (y) are observed with lognormal noise
    for i in 1:2
        y_init[i] ~ LogNormal(log(z_init[i]), sigma[i])
        y[:, i] ~ MvLogNormal(log.(z[i, :]), sigma[i]^2 .* I)  # I from LinearAlgebra
    end

    # Generated quantities:
    # TODO: it's a bit finicky because the rand() calls break
    # autodiff, so not yet set on how to do this inside the model.
end
```

I am constrained by the fact that this implementation should match the Stan implementation, which motivates some of the decisions here (e.g. making `theta` an `arraydist` so the resulting chains look the same, naming the parameters whatever they are named in the Stan model, using DP5 to match their solver, and so on.)

Some ideas:

1. Should I avoid creating a new problem inside the model? Is this actually costly? If so, what’s the best way to set this up outside of the model?
2. Putting aside the odd names and lack of special characters, which I’ve chosen to match Stan, are there any changes I should make to the code to make it align better with some style conventions which I’m accidentally violating?
3. Is there some obvious optimization I’m missing?

A complete reproducible gist corresponding to the Julia code above is here.


Perhaps the most obvious speedup is to use StaticArrays with an out-of-place form for the LV model. For small ODEs, this is known to speed things up considerably.
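For reference, the out-of-place form might look something like this (a sketch, assuming the same parameter ordering as the in-place version above, and requiring the StaticArrays.jl package):

```julia
using StaticArrays

# Out-of-place form: return a new SVector instead of mutating dz.
# For tiny systems like this, static arrays let the solver avoid
# heap allocations entirely.
function lotka_volterra(z, p, t)
    alpha, beta, gamma, delta = p
    u, v = z
    du = (alpha - beta * v) * u    # prey
    dv = (-gamma + delta * u) * v  # predator
    return SVector(du, dv)
end

# The problem is then constructed from the out-of-place function,
# ideally with a static initial condition as well:
# prob = ODEProblem(lotka_volterra, SVector(10.0, 10.0), (0.0, ts[end]), theta)
```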


Yes, there are lots of allocations in a small model like this, so using static arrays would make sense.

Allocating like that doesn’t make sense. Just loop instead to avoid the alloc.


> Should I avoid creating a new problem inside the model? Is this actually costly? If so, what’s the best way to set this up outside of the model?

Yeah, don’t recreate the problem; either use `remake` explicitly or just pass the new parameters into `solve`, e.g. Bayesian Estimation of Differential Equations – Turing.jl
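A sketch of the `remake` approach, reusing the names from the model above (the placeholder `u0`/`p` values and the `ts` grid here are arbitrary, since they get overwritten on each model evaluation):

```julia
using OrdinaryDiffEq  # provides ODEProblem, remake, solve, DP5

function lotka_volterra!(dz, z, p, t)
    alpha, beta, gamma, delta = p
    dz[1] = (alpha - beta * z[2]) * z[1]
    dz[2] = (-gamma + delta * z[1]) * z[2]
    return nothing
end

ts = 0.0:0.1:10.0

# Construct the problem once, outside the model, with placeholder values:
prob = ODEProblem(lotka_volterra!, [10.0, 10.0], (0.0, ts[end]), ones(4))

# Inside the model body, update it per sample instead of rebuilding it:
theta = [1.5, 0.05, 1.0, 0.05]
z_init = [10.0, 10.0]
prob_i = remake(prob; u0 = z_init, p = theta)
z = solve(prob_i, DP5(); saveat = ts)
```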

> Putting aside the odd names and lack of special characters, which I’ve chosen to match Stan, are there any changes I should make to the code to make it align better with some style conventions which I’m accidentally violating?

1. Use `sol.retcode` to check whether the solver succeeded.
2. If you’re indeed only going to have two `y` vectors, perf will be better if you don’t use the for loop but write it out by hand.
3. Using `arraydist` with different distributions is not ideal. In the above example it should probably work out without any issues due to the types all being the same, but in general it’s not recommended for perf reasons. If you’re only working with a small number of variables with different distributions like this, it’s much better, both semantically and perf-wise, to write them out as independent `~` statements.
4. Explicit use of `Truncated` is discouraged; similarly, use the `lower` and `upper` kwargs instead of passing both values. That is, use `truncated(dist; lower=0)` instead of `Truncated(dist, 0, Inf)`.
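For point 4, a quick illustration using Distributions.jl’s lowercase `truncated` constructor, which accepts keyword bounds:

```julia
using Distributions

# Preferred: lowercase `truncated` with keyword bounds.
# Passing only `lower` truncates below at 0 without needing
# an explicit `Inf` upper bound.
prior = truncated(Normal(1, 0.5); lower = 0)

# Discouraged equivalent:
# prior_old = Truncated(Normal(1, 0.5), 0, Inf)
```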

> Allocating like that doesn’t make sense. Just loop instead to avoid the alloc.

Note the `:` in `:=` in these expressions; it’s new syntax for including the quantity on the RHS in the resulting chain under the name on the LHS. @JasonPekos is doing this on purpose to keep track of these :)
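A minimal sketch of what `:=` does (`demo` is just a made-up model name for illustration):

```julia
using Turing

@model function demo()
    x ~ Normal()
    y := 2x    # y is deterministic given x, but is still recorded in the chain
end

chain = sample(demo(), Prior(), 10)
# `y` shows up as a column in `chain` alongside `x`.
```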
