Differentiating through a Jump Problem

A lot of interesting work has been done on sensitivity analysis for these problems, including, I think, coupling approaches through multilevel methods in recent years. David Anderson has a lot on this too.


Sorry, I posted sloppy code. I’ve cleaned it up now.

EDIT: I’ve made a less sloppy version of the above code that’s a bit more minimal and actually runs!

```julia
using DifferentialEquations
using DiffEqBiological
using ForwardDiff
using LinearAlgebra

# SIR model: infection (s + i -> 2i) at rate c1, recovery (i -> r) at rate c2
sir_model = @reaction_network SIR begin
    c1, s + i --> 2i
    c2, i --> r
end c1 c2

p = [0.1/1000, 0.01]
prob = DiscreteProblem([999, 1, 0], (0.0, 250.0), p)

function forward_pass(p_)
    prob = DiscreteProblem([999, 1, 0], (0.0, 250.0), p_)
    # promote the initial state to the parameter element type so it can hold dual numbers
    prob = remake(prob; u0 = convert.(eltype(p_), prob.u0), p = p_)
    jump_prob = JumpProblem(prob, Direct(), sir_model)
    sol = solve(jump_prob, FunctionMap())
    # distance of the final state from a target state
    loss = norm(sol[end] .- [1, 200, 700])
    return loss
end

loss_gradient = p_ -> ForwardDiff.gradient(forward_pass, p_)
println(loss_gradient(p))
```

It didn’t! Though that’s probably my coding inexperience. Code and error posted below.

However, maybe this whole idea is ill-conceived. Even with tau leaping, the function from parameters to trajectory isn’t differentiable in any useful sense: the trajectory is a step function (this is a discrete problem) that only takes integer values, so for a fixed random seed it is piecewise constant in the parameters, with derivative zero wherever it exists. I’m not seeing how ForwardDiff can propagate a dual through these discrete jumps.
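To illustrate the step-function point with a fixed random seed, here is a hypothetical, package-free tau-leaping SIR sketch in plain Julia. It is not DiffEqJump’s SimpleTauLeaping; `poisson_inv` and `tau_leap_sir` are made-up helpers, and the Poisson counts are drawn by inverse-CDF sampling from a fixed uniform stream (an assumption chosen so the dependence on `p` is easy to see). The final state, and therefore the loss, is then a piecewise-constant function of `p`: for a tiny `h` the finite-difference quotient is typically exactly zero, and when a count does flip, the quotient is of order 1/h instead.

```julia
using Random
using LinearAlgebra

# Hypothetical helper: inverse-CDF Poisson sampler. For a fixed uniform `u`,
# the returned count is an integer-valued step function of the rate `λ`.
# Assumes moderate λ (exp(-λ) must not underflow, i.e. λ well below ~700).
function poisson_inv(λ, u)
    k = 0
    pk = exp(-λ)          # P(X = 0)
    cdf = pk
    while u > cdf && pk > 0
        k += 1
        pk *= λ / k       # P(X = k) from P(X = k - 1)
        cdf += pk
    end
    return k
end

# Hypothetical tau-leaping SIR with a fixed seed (common random numbers):
# every call consumes exactly two uniforms per step, so only `p` changes.
function tau_leap_sir(p; u0 = (999, 1, 0), T = 250.0, dt = 1.0, seed = 1)
    rng = MersenneTwister(seed)
    s, i, r = u0
    for _ in 1:round(Int, T / dt)
        n1 = min(poisson_inv(p[1] * s * i * dt, rand(rng)), s)  # infections
        n2 = min(poisson_inv(p[2] * i * dt, rand(rng)), i)      # recoveries
        s -= n1
        i += n1 - n2
        r += n2
    end
    return [s, i, r]
end

loss(p) = norm(tau_leap_sir(p) .- [1, 200, 700])

p0 = [0.1 / 1000, 0.01]
for h in (1e-12, 1e-9, 1e-6)
    fd = (loss(p0 .+ [h, 0.0]) - loss(p0)) / h
    println("h = $h  finite-difference dL/dc1 = $fd")
end
```

The clamping with `min` keeps every compartment non-negative and the total population fixed at 1000; it is a simplification, not how any particular library handles negative populations.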

Code:

```julia
using DifferentialEquations
using DiffEqBiological
using ForwardDiff
using LinearAlgebra
using Plots
pyplot()

"""rates of each reaction"""
function rate(out, u, p, t)
    out[1] = p[1] * u[1] * u[2]  # infection
    out[2] = p[2] * u[2]         # recovery
end

u0 = [999.0, 1.0, 0.0]  # susceptible, infected, recovered
p0 = [0.1/1000, 0.01]

"""change matrix upon reaction"""
function c(dc, u, p, t, mark)
    # first reaction loses a susceptible, gains an infected
    dc[1, 1] = -1
    dc[2, 1] = 1
    # second reaction loses an infected, gains a recovered
    dc[2, 2] = -1
    dc[3, 2] = 1
end

dc = zeros(3, 2)
rj = RegularJump(rate, c, dc; constant_c = true)

# naive problem solution
prob = DiscreteProblem(u0, (0.0, 250.0), p0)
jump_prob = JumpProblem(prob, Direct(), rj)
sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)

function forward_pass(p_)
    prob = DiscreteProblem(u0, (0.0, 250.0), p_)
    # promote the initial state to the parameter element type so it can hold dual numbers
    prob = remake(prob; u0 = convert.(eltype(p_), prob.u0), p = p_)
    jump_prob = JumpProblem(prob, Direct(), rj)
    sol = solve(jump_prob, SimpleTauLeaping(); dt = 1.0)
    return loss(sol)
end

# distance of the final state from a target state
function loss(sol)
    return norm(sol[end] .- [1, 200, 700])
end

loss_gradient = p_ -> ForwardDiff.gradient(forward_pass, p_)
```

Error:

```
ERROR: MethodError: no method matching Float64(::ForwardDiff.Dual{ForwardDiff.Tag{typeof(forward_pass),Float64},Float64,2})
Closest candidates are:
Float64(::Real, ::RoundingMode) where T<:AbstractFloat at rounding.jl:194
Float64(::T<:Number) where T<:Number at boot.jl:741
Float64(::Int8) at float.jl:60
```
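This kind of MethodError usually means something tried to store a dual number in a container whose element type is Float64: `setindex!` converts the value to Float64, which falls back to calling `Float64(x)`, and there is no such method for a dual (it would have to drop the derivative). Below is a package-free reproduction. `ToyDual` is a hypothetical stand-in for `ForwardDiff.Dual` that only implements the one multiplication needed here, and `rate!` is a hypothetical copy of the `rate(out,u,p,t)` function above; which buffer is Float64 inside DiffEqJump is not shown by this sketch.

```julia
# Hypothetical stand-in for ForwardDiff.Dual: a value plus one derivative.
struct ToyDual <: Real
    val::Float64
    der::Float64
end

# Only the product needed below: dual times plain real (product rule with a constant).
Base.:*(a::ToyDual, b::Real) = ToyDual(a.val * b, a.der * b)

# Same in-place style as `rate(out, u, p, t)` in the post above.
function rate!(out, u, p, t)
    out[1] = p[1] * u[1] * u[2]  # infection
    out[2] = p[2] * u[2]         # recovery
end

p = [ToyDual(1e-4, 1.0), ToyDual(0.01, 0.0)]  # seed a derivative w.r.t. c1
u = [999.0, 1.0, 0.0]

# A Float64 output buffer cannot hold the duals.
out_bad = zeros(2)
err = try
    rate!(out_bad, u, p, 0.0)
    nothing
catch e
    e
end
println(err)  # MethodError: no method matching Float64(::ToyDual)

# A buffer whose element type follows the parameters works.
out_ok = Vector{eltype(p)}(undef, 2)
rate!(out_ok, u, p, 0.0)
println(out_ok[1].val, " ", out_ok[1].der)
```

Allocating the buffer from the parameter type is the same idea as `convert.(eltype(p_), prob.u0)` for the state. The later reply in this thread runs into the same kind of Float64 buffer for the rate output.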

It should be as differentiable as the SDE, in that there is a strong derivative. I’ll look at this a bit later.

Sure, thanks again!

I think there’s a difference: The output trajectory for a reactant is a step function that can only take integer values. You can’t have infinitesimal changes in the output trajectory over a fixed timestep. For an SDE, the output can be changed infinitesimally, as it takes real values.
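To make the integer-valued point concrete, here is a small sketch in plain Julia (no DiffEq packages; `poisson_inv` is a hypothetical inverse-CDF Poisson sampler). With the uniform draw held fixed, a Poisson count is a step function of its rate λ, so its pathwise derivative is 0 almost everywhere, yet the expected count is λ, whose derivative is 1. Averaging pathwise derivatives, roughly what propagating a dual along individual paths would estimate, therefore gives 0, while a finite difference with common random numbers, averaged over many draws, recovers the slope of the expectation.

```julia
using Random
using Statistics

# Hypothetical helper: inverse-CDF Poisson sampler. For a fixed uniform `u`
# the count is a nondecreasing integer step function of the rate λ.
function poisson_inv(λ, u)
    k = 0
    pk = exp(-λ)          # P(X = 0)
    cdf = pk
    while u > cdf && pk > 0
        k += 1
        pk *= λ / k       # P(X = k) from P(X = k - 1)
        cdf += pk
    end
    return k
end

λ, h = 5.0, 1e-3
us = rand(MersenneTwister(42), 100_000)   # common random numbers for λ and λ + h

# Pathwise difference quotients: 0 unless the count jumps, then (jump size)/h.
pathwise = [(poisson_inv(λ + h, u) - poisson_inv(λ, u)) / h for u in us]

frac_zero = count(iszero, pathwise) / length(pathwise)  # close to 1
mean_fd = mean(pathwise)  # estimates d/dλ E[X] = 1, since E[X] = λ

println("fraction of zero quotients: ", frac_zero)
println("mean finite difference:     ", mean_fd)
```

The mean is carried entirely by the rare paths where the count jumps, which is why such estimators are noisy; more refined couplings, like the work mentioned at the top of the thread, aim to reduce that variance.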

I’m having a stab at this and I can’t quite get the DiffEq part to accept my types. I managed to get the prob to be

```
julia> prob
DiscreteProblem with uType Array{Particles{Float64,500},1} and tType Float64. In-place: true
timespan: (0.0, 250.0)
u0: Particles{Float64,500}[999.0, 1.0, 0.0]
```

but the function rate is still called with the signature

```
(Array{Float64,1}, Array{Particles{Float64,500},1}, Array{Particles{Float64,500},1}, Float64)
```

i.e., the array to store the derivative in is still Array{Float64,1}.

Edit: PR https://github.com/JuliaDiffEq/DiffEqJump.jl/pull/88

I tried making the problem inplace=false but I couldn’t figure out how to do that.

Interesting. I’ll take a look at that. Got a lot going on this week, but this is always fun procrastination work.
