I am trying to compute gradients of the solution of an SDE when a callback is passed to the `solve`
call. It seems that forward-mode sensitivity (`ForwardDiffSensitivity`) is not compatible with callbacks, because I get "ERROR: type Pairs has no field callback".
Is there a way to make it work?
Here is a simple example derived from the Lotka-Volterra tutorials:
using DifferentialEquations, Flux, DiffEqFlux
using DiffEqSensitivity
"""
    dt!(du, u, p, t)

In-place deterministic part of the example system (Lotka-Volterra form).

Reads the two state components from `u` and the four coefficients `α, β, δ, γ`
from `p`, then writes the two derivative components into `du`. `t` is unused.
Returns the value stored in `du[2]`.
"""
function dt!(du, u, p, t)
    prey, predator = u
    α, β, δ, γ = p

    growth = α * prey - β * prey * predator
    du[1] = growth

    decline = -δ * predator + γ * prey * predator
    du[2] = decline

    return decline
end
"""
    dW!(du, u, p, t)

In-place noise term of the example system: each of the first two entries of `du`
is set to `0.1` times the matching entry of `u`. `p` and `t` are unused.
Returns the value stored in `du[2]`.
"""
function dW!(du, u, p, t)
    du[1] = 0.1 * u[1]
    second = 0.1 * u[2]
    du[2] = second
    return second
end
# Initial state, time span, and coefficients for the example problem.
u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
p = [2.2, 1.0, 2.0, 0.4]  # unpacked by `dt!` as α, β, δ, γ

# SDE built from the deterministic part `dt!` and the noise part `dW!`.
prob_sde = SDEProblem(dt!, dW!, u0, tspan, p)
"""
    condition(u, t, integrator)

Return `true` once `integrator.t` is greater than `9.0`. The `u` and `t`
arguments are not consulted.
"""
function condition(u, t, integrator)
    return integrator.t > 9.0  # some condition
end
"""
    affect!(integrator)

Print the line `Callback`. `integrator` is neither read nor modified.
"""
affect!(integrator) = println("Callback")  # some callback
# Discrete callback that runs `affect!` whenever `condition` holds; both
# `save_positions` flags are `false`.
cb = DiscreteCallback(condition, affect!; save_positions = (false, false))
"""
    predict_sde(p)

Solve `prob_sde` with the parameters `p` and return the saved solution as an `Array`.

The solve uses `EM()` with a fixed step `dt = 0.001`, saves every `0.1` time
units, attaches the callback `cb`, and requests `ForwardDiffSensitivity()` as
the sensitivity algorithm.

`p` is forwarded to `solve` through the `p` keyword. Without it the argument
would be ignored and the parameters stored in `prob_sde` would be used instead.

NOTE(review): the reported `type Pairs has no field callback` error presumably
comes from how the sensitivity package handles the `callback` keyword, not from
this function — confirm against the installed package version.
"""
function predict_sde(p)
    sol = solve(
        prob_sde, EM();
        p = p,  # make the solution depend on this function's argument
        saveat = 0.1,
        sensealg = ForwardDiffSensitivity(),
        dt = 0.001,
        callback = cb,
    )
    return Array(sol)
end
"""
    loss_sde(p)

Sum of squared deviations from `1` over every entry of `predict_sde(p)`.
"""
function loss_sde(p)
    prediction = predict_sde(p)
    return sum(abs2, x - 1 for x in prediction)
end
# Register `p` as the implicit parameter set to differentiate with respect to.
ps = Flux.params(p)

# Time the gradient of the loss with respect to `ps`.
# ERROR: type Pairs has no field callback
@time gs = gradient(() -> loss_sde(p), ps)