Differentiating through a Jump Problem

Hi all,

I’m thinking about Gillespie simulations, as one might use to simulate the individual molecules of a chemical reaction.

Clearly a Gillespie simulation (JumpProblem) is non-differentiable as it stands, because its output is non-deterministic: it depends on the realisation of random variables (when the next molecular collision event happens depends on the realisation of the waiting-time random variable).

However if I ‘freeze’ the noise on the forward simulation of the model, it seems like the simulation should be differentiable wrt parameters. i.e. a function of the form

(parameters, realisations of all random variables corresponding to reaction times and reaction choices) -> trajectory
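A minimal sketch of such a frozen-noise map, assuming a hand-written Direct-method loop for the SIR model rather than the library's JumpProblem machinery. The names `frozen_sir`, `us_time` and `us_choice` are hypothetical; the two vectors hold pre-drawn Uniform(0,1) numbers for the waiting times and the reaction choices:

```julia
using Random

# Direct-method SSA for the SIR model with all randomness supplied up front.
# `us_time` and `us_choice` are pre-drawn Uniform(0,1) numbers (hypothetical names).
function frozen_sir(p, us_time, us_choice; u0 = [999, 1, 0], tmax = 250.0)
    c1, c2 = p
    s, i, r = u0
    t = zero(c1)                      # time takes the parameter type
    ts = [t]
    states = [(s, i, r)]
    for (ut, uc) in zip(us_time, us_choice)
        a1 = c1 * s * i               # infection propensity: s + i --> 2i
        a2 = c2 * i                   # recovery propensity:  i --> r
        a0 = a1 + a2
        a0 == 0 && break              # no infected left, nothing can fire
        t += -log(ut) / a0            # waiting time from the frozen uniform
        t > tmax && break
        if uc * a0 < a1               # reaction choice from the frozen uniform
            s -= 1; i += 1
        else
            i -= 1; r += 1
        end
        push!(ts, t)
        push!(states, (s, i, r))
    end
    return ts, states
end

rng = MersenneTwister(1)
us_time, us_choice = rand(rng, 5000), rand(rng, 5000)
ts, states = frozen_sir([0.1 / 1000, 0.01], us_time, us_choice)
```

With the draws fixed, calling `frozen_sir` twice with the same arguments returns the same trajectory. In this sketch the jump times depend on the parameters through `a0`, while the sequence of states only changes when a parameter change flips one of the `uc * a0 < a1` comparisons.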

I naively tried to implement this on a simple SIR jump process (as given in the documentation here) using ForwardDiff, as follows:

using DifferentialEquations
using DiffEqBiological
using Plots
using ForwardDiff
using LinearAlgebra  # provides norm, used in the loss below

sir_model = @reaction_network SIR begin
    c1, s + i --> 2i
    c2, i --> r
end c1 c2

p = [0.1/1000,0.01]
prob = DiscreteProblem([999,1,0],(0.0,250.0),p)

jump_prob = JumpProblem(prob,Direct(),sir_model)

nomsol = solve(jump_prob,FunctionMap())
plot(nomsol)

function param_sol(p_, nomprob)
    println(typeof(p_))
    # probp = remake(nomsol.prob, p=p_;u0=convert.(eltype(p_),nomprob.u0))
    probp = remake(nomsol.prob;u0=convert.(eltype(p_),nomprob.u0),p=p_)
    return solve(probp)
end

function forward_pass(p_)
    prob = DiscreteProblem([999,1,0],(0.0,250.0),p_)
    prob = remake(prob; u0=convert.(eltype(p_),prob.u0),p=p_)
    jump_prob = JumpProblem(prob,sir_model)
    sol = solve(jump_prob,FunctionMap())
    loss = norm(sol[end] .- [1,200,700])
    return loss
end

loss_gradient = p_ -> ForwardDiff.gradient(forward_pass,p_)
println(loss_gradient(p))

I got an error:

LoadError: MethodError: no method matching JumpProblem(::DiscreteProblem{Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(forward_pass),Float64},Float64,2},1},Tuple{Float64,Float64},true,Array{ForwardDiff.Dual{ForwardDiff.Tag{typeof(forward_pass),Float64},Float64,2},1},DiscreteFunction{true,getfield(DiffEqBase, Symbol("##161#162")),Nothing,Nothing},Base.Iterators.Pairs{Union{},Union{},Tuple{},NamedTuple{(),Tuple{}}}}, ::SIR)

Is what I am trying to do doomed in principle to failure? Is there a way to modify this code so it could work?

Thanks a lot in advance!

In principle it can work. Let me take a stab at it later this week. Feel free to remind me if I forget!

Much appreciated! Massive fan of all the work you have been doing in DifferentialEquations.jl, it’s really helped me be much more productive as a researcher.

Do you typically have only end-time observations or is this for the sake of the example?

Hi, thanks for your interest :). I’m sorry I don’t know what you mean by an end-time observation.

In your example, the loss is defined in terms of the state of the process at terminal time T or am I misreading the code? Does this correspond to your application problem?

Oh I see,

No I’d ideally like something more general. I just thought this was an easy loss function for the example.

That said, I’d be interested if there were some easier way of doing this for the special case of end-time observations! Although my intuition says that is unlikely…

Ah, this is very close to what we are doing right now with https://github.com/mmider/BridgeSDEInference.jl , with a pull request pending cleaning up some things and implementing parameter inference for a Gillespie type model (a diffusion approximation to a jump process modelling auto-regulatory protein expression network https://mmider.github.io/BridgeSDEInference.jl/dev/examples/prokaryote/).

Looks very cool, I will keep an eye on it!

We will make an announcement here on Discourse when it is ready.

I took a stab at it, but it turns out to be impossible :laughing:. The reason is that after doing the rate calculations, you have to arrive at an integer for which jump is going to occur, and that integer value is going to “break” the differentiation chain (duals come out zero), so AD cannot be directly applied here. Mixings don’t seem to make sense.
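A minimal sketch of that break in the chain, assuming a cumulative-rate lookup for the reaction choice. `choose_reaction` and `sir_rates` are hypothetical helpers written for illustration, not library functions:

```julia
# Hypothetical helper: pick the reaction whose cumulative-rate bin contains u * total.
choose_reaction(rates, u) = searchsortedfirst(cumsum(rates), u * sum(rates))

# SIR propensities at the initial state s = 999, i = 1.
sir_rates(c1, c2; s = 999, i = 1) = [c1 * s * i, c2 * i]

u = 0.5                                    # frozen uniform draw
c1, c2, h = 0.1 / 1000, 0.01, 1e-8
k0 = choose_reaction(sir_rates(c1, c2), u)
k1 = choose_reaction(sir_rates(c1 + h, c2), u)
println((k0, k1, (k1 - k0) / h))           # index before, index after, difference quotient
```

The rates change with `c1`, but the returned index is an integer that stays put under a small perturbation, so a difference quotient (or a dual number carried through the same lookup) picks up nothing from this step.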

At a more fundamental level, I dug up a Petzold paper that made it a bit more explicit. They took a look at the chemical master equation (CME), and sure enough, the resulting sensitivity equations are not something that can be estimated directly with the SSA. Quoting the paper:

“Such an equation should be solved simultaneously with the CME. As with the CME, the infinite dimensionality of the coupled sensitivity-CME differential equation makes its analytical solution difficult to construct. Moreover, the SSA cannot be directly applied to solve the sensitivity equation without loss of rigorous physical basis (Gillespie, 1992a). These reasons motivate application of a black-box approach, such as finite difference, to estimate the sensitivity coefficients below”

The bigger issue than what they mention there is that the sensitivity (the derivative of a distribution) is not necessarily a distribution itself, so you’d have to find some way of normalizing it on the fly in order to create some sampling based approach for it. It looks like such a normalizing method is derived here:

https://royalsocietypublishing.org/doi/full/10.1098/rsif.2014.0979

but AD won’t give you back that same algorithm automatically. Instead, their method can be implemented, replacing finite differences with dual numbers, and it would work, but it’s fundamentally different than just taking the derivative along each trajectory like simple applications of AD here would do.
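The black-box finite-difference route mentioned in the quote can be sketched as follows. This is a minimal illustration under stated assumptions, not the normalizing method from the linked paper: it uses a hand-written Direct-method loop, reuses the same seeds on both sides of the difference (common random numbers), and the names `sir_end_state`, `mean_loss` and `fd_sensitivity` are hypothetical. The target `[1, 200, 700]` is taken from the first post.

```julia
using Random, LinearAlgebra

# Plain Direct-method SSA for the SIR model; returns the state at time tmax.
function sir_end_state(p, rng; u0 = [999, 1, 0], tmax = 250.0)
    c1, c2 = p
    s, i, r = u0
    t = 0.0
    while true
        a1, a2 = c1 * s * i, c2 * i
        a0 = a1 + a2
        a0 == 0 && break                  # nothing left to fire
        t += randexp(rng) / a0            # exponential waiting time
        t > tmax && break
        if rand(rng) * a0 < a1
            s -= 1; i += 1                # infection
        else
            i -= 1; r += 1                # recovery
        end
    end
    return [s, i, r]
end

# Monte Carlo estimate of the expected loss, one seed per replicate.
function mean_loss(p, seeds)
    total = sum(norm(sir_end_state(p, MersenneTwister(seed)) .- [1, 200, 700]) for seed in seeds)
    return total / length(seeds)
end

# Central finite difference in parameter k, reusing the same seeds on both sides.
function fd_sensitivity(p, k, seeds; relstep = 0.05)
    h = relstep * p[k]
    up = copy(p); up[k] += h
    dn = copy(p); dn[k] -= h
    return (mean_loss(up, seeds) - mean_loss(dn, seeds)) / (2h)
end

p = [0.1 / 1000, 0.01]
seeds = 1:200
println(fd_sensitivity(p, 1, seeds), " ", fd_sensitivity(p, 2, seeds))
```

The step size and the number of seeds are arbitrary choices here; the estimate is noisy and both would need tuning for a real problem.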

So, :man_shrugging: this needs an adjoint. Note that certain subsets of this problem are AD-differentiable, e.g. if you just have one rate and want to differentiate the solution with respect to a jump parameter in the affect function. A single jump with a rate can be differentiated with AD too (if time is a Dual number as well, so you’d need to make a small modification). But the moment you aggregate rates you need to do something fundamentally different.

This is pointing to the idea though that the Next Reaction Method is likely AD-differentiable, so it would be nice to do that (in a similar way to SDEs). However, I wouldn’t expect anything doing fancy aggregations to be AD-differentiable because of how it has to break the link between the rate computation and the choice of reaction.

@isaacsas this would be interesting to look more into.

It would be interesting to see this on the common interface. I see there’s some element of this going on already: https://github.com/mmider/BridgeSDEInference.jl/pull/47/files#diff-7d60385b868acac865650bab11dd1610 and it would be cool to have a good source to just point DiffEqBiological users to for parameter inference of their Gillespie model.

Perhaps one should ask what differential you want to compute. Do you want the differential of the stochastic flow, the Malliavin derivative… ?

One more reason to move to a continuum approximation :slight_smile:

You can’t always though. You need some very specific problems for that to hold.

Yes! I imagine we sit together (in a while, we are still busy here) and repeat the exercise of setting up an interface like we did before.

Hi,

My higher level goal was to ‘train’ the mean trajectory to minimise some loss function. As such, when I say I want a ‘gradient’, I mean I want ‘some direction in parameter space that improves the loss function in expectation’. The Malliavin derivative is appropriate when I want to explicitly consider all the random variables realised during a Gillespie simulation, as random variables. Basically I was wondering whether I could ‘cheat’, by a priori realising all the random variables (thus making the forward pass deterministic), and computing a standard gradient with respect to the now deterministic mapping from parameters to trajectory.

Of course, this would not give me the true Malliavin derivative. I doubt my gradient has any nice mathematical properties either. But very approximate gradients are often sufficient for fast learning, as the success of online gradient descent in machine learning demonstrates. (In that case one takes a stochastic gradient, where the stochastic component often hugely outcompetes the deterministic term when training a model.) I guess I wanted to test empirically whether the ‘gradient’ I wanted to compute would be effective, in the sense that ‘gradient’ descent would effectively decrease a loss function on the parameters.

TLDR: I want a direction in parameter space that improves a loss function, and is easy to compute.

Thanks so much for this really detailed analysis!

Yes so the step of the algorithm that chooses which reaction occurs, given the next reaction time, is obviously not a step that you can feed a dual through. That makes sense. Therefore we need something like ‘independent simulation threads’, that work out the next reaction time for each reaction individually. The problem is that the reaction rates of one species depend upon concentrations of other species, which constitute ‘different threads’.

It seems to me like a fixed timestep approximation (i.e. tau leaping) would solve this issue, and be differentiable through.

Thanks also for the references, they are useful.

I tried running the code posted in the first post, but the forward function did not run even outside differentiation. If you have a version that runs, I would like to try a hacky tool I put together to differentiate this kind of function.

Yeah, tau-leaping would do it. Try our tau-leaping method and it should just work.
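A minimal sketch of a fixed-step scheme with frozen noise, under loud assumptions. Poisson counts are still integers, so this sketch swaps in Gaussian increments (a Langevin-style continuum approximation) in place of Poisson tau-leaping; it is hand-written and says nothing about how the library's tau-leaping method behaves. `frozen_leap`, `loss` and the step size `tau = 0.5` are hypothetical choices, and the target `[1, 200, 700]` is taken from the first post.

```julia
using Random, LinearAlgebra

# Fixed-step leaping for the SIR model with Gaussian increments (a swapped-in
# Langevin-style approximation, not Poisson tau-leaping). `z` holds frozen
# standard normals, one row per step and one column per reaction.
function frozen_leap(p, z; u0 = [999.0, 1.0, 0.0], tau = 0.5)
    c1, c2 = p
    s, i, r = u0 .* one(c1)                # promote the state to the parameter type
    for n in axes(z, 1)
        m1 = max(c1 * s * i, 0) * tau      # expected infections in this step
        m2 = max(c2 * i, 0) * tau          # expected recoveries in this step
        k1 = clamp(m1 + sqrt(m1) * z[n, 1], 0, s)       # cannot infect more than s
        k2 = clamp(m2 + sqrt(m2) * z[n, 2], 0, i + k1)  # cannot recover more than i + k1
        s -= k1
        i += k1 - k2
        r += k2
    end
    return [s, i, r]
end

loss(p, z) = norm(frozen_leap(p, z) .- [1, 200, 700])

z = randn(MersenneTwister(1), 500, 2)      # 500 steps of tau = 0.5 span (0, 250)
p = [0.1 / 1000, 0.01]
h = 1e-7
dloss_dc1 = (loss([p[1] + h, p[2]], z) - loss([p[1] - h, p[2]], z)) / (2h)
println(dloss_dc1)
```

With `z` frozen, the state is a continuous function of the parameters, so the reaction-choice step that blocks differentiation in the Direct method does not appear. The clamps and the square root at zero are still kinks, so a dual-number pass through this sketch would need care at those points.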