DiffEqJump/Catalyst with Turing: possible?


I am wondering whether it is possible to do parameter inference with Turing for a DiffEqJump model.

I understand that this type of problem might be better handled with an optimizer, but I am curious whether it can be done in a Bayesian setting. The jumps would make it hard for AD, I suppose, and make the sampler inefficient at exploring the parameter space. But perhaps there is a better parameterization or choice of options that would make it work? Or, better still, I just did something silly that is easily fixed? 🙂

Here is a toy model to make it clearer and for you to pick apart:

```julia
using DiffEqJump, Turing, Catalyst

# create dummy data
p = [0.5]
u0 = [10]
tspan = (0.0, 5.0)
dt = 0.2

rs = @reaction_network begin
  r, X --> 2X
end r

dp = DiscreteProblem(rs, u0, tspan, p)
jp = JumpProblem(rs, dp, Direct())

jsim = solve(jp, SSAStepper(), saveat=dt)

# sample 20 of the saved time points and add Gaussian noise
keep = sample(1:length(jsim.t), 20, replace=false, ordered=true)
testdata = [jsim.u[k][1] for k in keep] .+ 0.05 .* randn(20)
datatimes = jsim.t[keep]

# Plots.scatter(datatimes, testdata)

# estimate params as a Poisson process
@model function simpleExp(y, prob, ts, N=length(y))
  # priors (r truncated at 0: a negative rate is not a valid propensity)
  r ~ truncated(Normal(1.0, 1.0), 0, Inf)
  y0 ~ truncated(Cauchy(0.0, 0.5), 0, Inf)
  sigma ~ truncated(Cauchy(0.0, 0.5), 0, Inf)
  jp2 = remake(prob, u0=[y0], p=[r])
  # save at the observation times; without saveat the SSA only saves at
  # the jump times, which almost never coincide with the data times
  jsol = solve(jp2, SSAStepper(), saveat=ts)
  # likelihood
  for i in 1:N
    j = findfirst(t -> t == ts[i], jsol.t)
    if j !== nothing
      y[i] ~ Normal(jsol.u[j][1], sigma)
    end
  end
end

simplemodel = simpleExp(testdata, jp, datatimes)

s = sample(simplemodel, MH(), 5)       # almost works... but slow
s = sample(simplemodel, SMC(), 5)      # fails
s = sample(simplemodel, NUTS(0.5), 5)  # fails
```

You can’t use NUTS right now because the stochastic simulation algorithm is not compatible with traditional AD, though we will soon have a paper detailing how new ADs can be developed to handle it (along with a new AD engine that can handle it 😉, but that probably won’t be integrated into Turing any time soon).
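To see why naive differentiation breaks down, here is a small sketch in plain Julia (a hand-rolled Gillespie loop for X --> 2X, not DiffEqJump's implementation; `birth_ssa` is a hypothetical helper). With the random numbers held fixed, the simulated copy number X(1) is piecewise constant in r, so a finite difference (and likewise a dual-number AD pass) gives a derivative of zero almost surely, even though d/dr E[X(1)] = 10·e^0.5 ≈ 16.5 for this model.

```julia
using Random

# Hypothetical helper: Gillespie SSA for the pure-birth reaction X --> 2X
# with rate r, returning the copy number at time tobs. The uniform draws
# `us` are passed in so runs with different r share the same randomness.
function birth_ssa(r, x0, tobs, us)
    t, x = 0.0, x0
    for u in us
        t += -log(u) / (r * x)   # exponential waiting time, total propensity r*x
        t > tobs && return x     # next jump happens after tobs: state is x
        x += 1                   # X --> 2X adds one molecule
    end
    error("ran out of random numbers before reaching tobs")
end

us = rand(MersenneTwister(1), 10_000)   # common random numbers for both runs
r, h = 0.5, 1e-6
x_lo = birth_ssa(r, 10, 1.0, us)
x_hi = birth_ssa(r + h, 10, 1.0, us)
fd = (x_hi - x_lo) / h   # finite-difference "derivative" of X(1) w.r.t. r
println("X(1) = $x_lo, finite difference = $fd")
```

The jump times do move when r changes, but the integer state at a fixed observation time only changes when a jump crosses that time, which is why the pathwise gradient carries no information here.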

So for now you’d have to use derivative-free methods like Metropolis-Hastings (as you have found), but wait a little while: real solutions should be coming out soon enough.
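As a rough illustration of what "derivative-free" means here, the following is a minimal random-walk Metropolis-Hastings sketch in plain Julia (not Turing's `MH()` implementation; `rw_metropolis` and its `step`/`rng` keywords are hypothetical). It only ever evaluates the log-density, never its gradient, so it does not care whether that density comes from a differentiable model. The example target is a Normal(2, 1) log-density chosen purely for demonstration.

```julia
using Random, Statistics

# Hypothetical helper: random-walk Metropolis-Hastings. `logp` is any
# function returning an unnormalized log-density; no gradients are used.
function rw_metropolis(logp, x0, nsamples; step=1.0, rng=MersenneTwister(42))
    x, lp = x0, logp(x0)
    chain = Vector{Float64}(undef, nsamples)
    for i in 1:nsamples
        prop = x + step * randn(rng)        # symmetric Gaussian proposal
        lp_prop = logp(prop)
        if log(rand(rng)) < lp_prop - lp    # accept with prob min(1, p(prop)/p(x))
            x, lp = prop, lp_prop
        end
        chain[i] = x                        # repeat current state on rejection
    end
    return chain
end

# Example target: Normal(2, 1) up to an additive constant.
chain = rw_metropolis(x -> -0.5 * (x - 2)^2, 0.0, 50_000)
println(mean(chain[5_001:end]))   # should be close to 2
```

For the jump model, `logp` would be the log-posterior. Note that if each evaluation runs a fresh SSA simulation, the log-density is itself random, and this plain sketch treats each noisy value as exact.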