I am trying to get gradients of the solution of an SDE where a callback is passed to the `solve` call. It seems that the `ForwardDiffSensitivity` sensitivity algorithm is not compatible with callbacks, as I get
"ERROR: type Pairs has no field callback". Is there a way to make it work?
Here is a simple example derived from the Lotka-Volterra tutorials:
```julia
using DifferentialEquations, Flux, DiffEqFlux
using DiffEqSensitivity

# Lotka-Volterra drift
function dt!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = dx = α*x - β*x*y
    du[2] = dy = -δ*y + γ*x*y
end

# diagonal multiplicative noise
function dW!(du, u, p, t)
    du[1] = 0.1u[1]
    du[2] = 0.1u[2]
end

u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
p = [2.2, 1.0, 2.0, 0.4]
prob_sde = SDEProblem(dt!, dW!, u0, tspan, p)

condition(u, t, integrator) = integrator.t > 9.0 # some condition
function affect!(integrator)
    println("Callback") # some callback
end
cb = DiscreteCallback(condition, affect!, save_positions = (false, false))

function predict_sde(p)
    # pass p explicitly so the gradient is taken with respect to it
    return Array(solve(prob_sde, EM(), p = p, saveat = 0.1,
                       sensealg = ForwardDiffSensitivity(), dt = 0.001, callback = cb))
end

loss_sde(p) = sum(abs2, x - 1 for x in predict_sde(p))

ps = Flux.params(p)
@time gs = gradient(ps) do
    loss_sde(p)
end
# ERROR: type Pairs has no field callback
```
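One thing that might be worth trying as a workaround (a sketch only, not checked against the package versions used above): `ForwardDiffSensitivity` works by pushing ForwardDiff dual numbers through the solver, so calling `ForwardDiff.gradient` on the loss directly should do the same job while the callback is passed to `solve` as an ordinary keyword, bypassing the `sensealg` code path that raises the error. The names `predict_fd` and `loss_fd` are hypothetical, `affect!` is reduced to a no-op placeholder, and `remake` is used to set `p` and promote `u0` to the dual element type. Note that the noise is drawn fresh on every call, so each gradient is a single-sample stochastic estimate.

```julia
using DifferentialEquations  # SDEProblem, EM, DiscreteCallback, remake
using ForwardDiff

function dt!(du, u, p, t)
    x, y = u
    α, β, δ, γ = p
    du[1] = α*x - β*x*y
    du[2] = -δ*y + γ*x*y
end

function dW!(du, u, p, t)
    du[1] = 0.1u[1]
    du[2] = 0.1u[2]
end

u0 = [1.0, 1.0]
tspan = (0.0, 10.0)
p = [2.2, 1.0, 2.0, 0.4]
prob_sde = SDEProblem(dt!, dW!, u0, tspan, p)

condition(u, t, integrator) = integrator.t > 9.0
affect!(integrator) = nothing  # placeholder: the real callback goes here
cb = DiscreteCallback(condition, affect!, save_positions = (false, false))

# Hypothetical helper: solve with parameters p, no sensealg involved.
function predict_fd(p)
    # promote u0 to p's element type so dual numbers can flow through the state
    _prob = remake(prob_sde, u0 = convert.(eltype(p), prob_sde.u0), p = p)
    Array(solve(_prob, EM(), dt = 0.001, saveat = 0.1, callback = cb))
end

loss_fd(p) = sum(abs2, x - 1 for x in predict_fd(p))

# forward-mode gradient of the loss with respect to the four parameters
g = ForwardDiff.gradient(loss_fd, p)
```

If this works, `g` can be used in place of the Zygote gradient in a training loop; forward mode is reasonable here because there are only four parameters.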