DiffEqJump/Catalyst, Turing possible?


I am wondering if it is possible to do parameter inference using Turing for a DiffEqJump model?

I understand that this type of problem might be better handled using an optimizer, but I am curious whether it can be done in a Bayesian way. The jumps would make it hard for AD, I suppose, and make exploring the parameter space inefficient. But perhaps there is a better parameterization or choice of options that might make it work? Or, better still, did I just do something silly that is easily fixed? :slight_smile:

Here is a toy model to make it clearer and for you to pick apart:

using DiffEqJump, Turing, Catalyst

# create dummy data 
p = [ 0.5 ]
u0 = [ 10 ]
tspan = ( 0.0, 5.0 )
dt = 0.2

rs = @reaction_network begin
  r, X--> 2X
end r 

dp = DiscreteProblem( rs, u0, tspan, p )
jp = JumpProblem(rs, dp, Direct() )

jsim = solve( jp, SSAStepper(), saveat=dt )

# sample and add noise
keep = sample(1:size(jsim)[2], 20, replace = false, ordered=true)
testdata = Array(jsim[keep] )' + (0.05 * randn(20))
datatimes = jsim.t[keep]

# Plots.scatter(datatimes, testdata)

# estimate params as a Poisson process
@model function simpleExp(y, prob, N=length(y) )
  # priors
  r ~ Normal( 1.0, 1.0 )
  y0 ~ truncated( Cauchy(0.0, 0.5), 0, Inf )
  sigma ~ truncated( Cauchy(0.0, 0.5), 0, Inf )
  jp2 = remake( prob, u0=[y0], tspan=tspan, p=[r] )
  jsol = solve( jp2, SSAStepper(), saveat=dt )  # save on the data grid so the time lookup below can match
  # likelihood
  for i in 1:N
    j = findall(t -> t==datatimes[i], jsol.t)
    if length(j) > 0
      y[i] ~ Normal( jsol.u[j[1]][1], sigma   )
    end
  end
end

simplemodel = simpleExp(testdata, jp)

s = sample(simplemodel, MH(), 5 )  # almost works... but slow
s = sample(simplemodel, SMC(), 5 )  # fails
s = sample(simplemodel, NUTS(0.5), 5) #fails

You can’t do NUTS right now because the algorithm is not compatible with traditional AD, though we will have a paper soon detailing how new ADs can be developed to handle it (along with a new AD engine that can handle it :wink:, but that probably won’t be integrated into Turing any time soon).

So for now you’d have to use derivative-free methods like Metropolis-Hastings (as you have found), but just wait a hot second and we’ll have real solutions coming out soon enough.
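
To make the AD problem concrete, here is a minimal sketch in plain Julia (no DiffEqJump involved). The helper `birth_count` is hypothetical and written only for this illustration: it simulates the same X --> 2X birth process with a fixed random seed and returns the final count. Because the state is an integer, the output is a piecewise-constant function of `r`, so a pathwise derivative is zero wherever it exists and tells a gradient-based sampler such as NUTS nothing useful.

```julia
using Random

# Hypothetical helper (not part of DiffEqJump): Gillespie-style simulation
# of the pure-birth reaction X --> 2X with per-capita rate r, run until time T.
# The seed is fixed so that the same random draws are reused for every r.
function birth_count(r, x0, T; seed=1)
    rng = MersenneTwister(seed)
    t, x = 0.0, x0
    while true
        # waiting time to the next jump is exponential with rate r * x
        t += randexp(rng) / (r * x)
        t > T && return x
        x += 1
    end
end

# Evaluate the final count on a fine grid of rates around r = 0.5
rgrid = 0.50:0.0005:0.51
counts = [birth_count(r, 10, 5.0) for r in rgrid]

# The counts are integers: flat stretches separated by jumps of size one or more.
println(collect(zip(rgrid, counts)))
```

Printing `counts` shows runs of identical integers with occasional steps up, which is why derivative-free samplers like MH are the fallback here.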

Hi, checking in: has the hot second passed yet?


Though it doesn’t quite work with Catalyst yet. It needs a different stepper because PoissonRandom.jl hasn’t set up a derivative rule, though I believe @gaurav-arya has a working version that hasn’t been made live yet.

Thank you, kind sir!