No method matching terminate!(::DiffEqSensitivity.FakeIntegrator)

Hey,

I am trying to optimize the parameters of a differential equation using DiffEqFlux.
Since I am only interested in the new steady state, I use a callback to terminate the integration once that state is reached. It looks something like this:

# n, barX, prob and p come from the (omitted) model setup
condition(u,t,integrator) = sum( abs.(u[2*n+1:end]) )  < 1e-7*sum(abs.(barX))
affect!(integrator) = terminate!(integrator)
cb = DiscreteCallback(condition,affect!)
sol = solve(prob,BS3(),p=p,callback=cb)
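To make the terminate-at-steady-state pattern concrete, here is a minimal self-contained sketch. The model is a hypothetical toy system, du/dt = p[1] .- p[2] .* u, whose steady state is p[1]/p[2]; it is not the poster's model. The stopping test here measures the size of the right-hand side instead of a problem-specific block of the state, and the 1e-6 threshold is an arbitrary assumption.

```julia
using OrdinaryDiffEq

# Hypothetical toy system (not the poster's model):
# du/dt = p[1] .- p[2] .* u, with steady state u* = p[1] / p[2] in every component.
rhs(u, p, t) = p[1] .- p[2] .* u

u0 = [1.0, 1.0]
p = [2.0, 4.0]
tspan = (0.0, 20.0)
prob = ODEProblem(rhs, u0, tspan, p)

# Stop once the right-hand side is essentially zero, i.e. the state no longer moves.
# The 1e-6 threshold is an illustrative choice, not a recommended value.
condition(u, t, integrator) = sum(abs, rhs(u, integrator.p, t)) < 1e-6
affect!(integrator) = terminate!(integrator)
cb = DiscreteCallback(condition, affect!)

sol = solve(prob, BS3(), callback = cb)
println("stopped at t = ", sol.t[end], " with u = ", sol.u[end])
```

DiffEqCallbacks.jl also ships a ready-made TerminateSteadyState callback built on the same idea. As the poster notes, the forward solve is not the problem; the failure only shows up once a gradient is taken through the callback.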

However, if this callback fires while the gradient is being calculated, I get the following error message:

ERROR: MethodError: no method matching terminate!(::DiffEqSensitivity.FakeIntegrator{Vector{Float64}, Vector{Float64}, Float64})

So it seems terminate! is not yet implemented in the DiffEqSensitivity stack.
Is there an easy way to work around this?

(Sorry for not providing a complete working example. I hope it is still clear what the problem is.)

Yes, a pure-AD sensealg should work: sol = solve(prob,BS3(),p=p,callback=cb,sensealg=ReverseDiffAdjoint()). We should still open an issue about this, though. However…

In this case you don’t need to differentiate through the ODE solve at all! If you use a SteadyStateProblem, it has a special adjoint that handles this case nicely.
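A minimal sketch of the SteadyStateProblem route, assuming the package names of the DiffEqSensitivity era (newer releases use SciMLSensitivity). It again uses a hypothetical toy system rather than the poster's: in the poster's f!, every point with u[3] = 0 is a steady state, so the Jacobian there is singular, and the steady-state adjoint (which solves a linear system with that Jacobian) may not apply to it as-is. The toy system has an isolated steady state u* = p[1]/p[2], so the loss sum(u*) = 2p[1]/p[2] has the known gradient (2/p[2], -2p[1]/p[2]^2).

```julia
using OrdinaryDiffEq, SteadyStateDiffEq, DiffEqSensitivity, Zygote

# Hypothetical toy system with an isolated steady state u* = p[1] / p[2].
rhs(u, p, t) = p[1] .- p[2] .* u

u0 = [1.0, 1.0]
p = [2.0, 4.0]
prob = SteadyStateProblem(rhs, u0, p)

# DynamicSS time-steps the ODE with BS3 until du/dt is approximately zero,
# so no terminating callback is needed. SteadyStateAdjoint differentiates the
# steady-state condition instead of the time-stepping.
function steady_loss(p)
    sol = solve(prob, DynamicSS(BS3()); p = p, sensealg = SteadyStateAdjoint())
    return sum(sol)  # equals 2 * p[1] / p[2] at the exact steady state
end

dp = Zygote.gradient(steady_loss, p)[1]
println(dp)  # analytic value: [2 / p[2], -2 * p[1] / p[2]^2]
```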

And actually, I was just reviewing “[WIP] State-dependent Continuous Callbacks” by frankschae (Pull Request #445 on SciML/DiffEqSensitivity.jl, GitHub), which should solve this.

Thanks so much! However, if I do that, I get a different error:

#%%
using DifferentialEquations, DiffEqFlux, Plots, Flux

#%%
function f!(du, u, p, t)
  du[1] = - p[1]*u[3]
  du[2] = - p[2]*u[3]
  du[3] = - du[1]*u[2] - du[2]*u[1]
end

#%%
u0 = [1.,1.,1.]

tspan = (0.0, 10.0)

#%%
condition(u,t,integrator) = abs(u[3]) < 1e-4  # only u[3] is needed, no need to broadcast abs over all of u
affect!(integrator) = terminate!(integrator)
cb = DiscreteCallback(condition,affect!)

#%%
p = [1.,.5]
prob = ODEProblem(f!, u0, tspan,p)

sol = solve(prob,BS3(),p=p,callback = cb)
fit_sol = sol.u[end]

#%%
# Plot the reference solution
plot(sol)

#%%
function loss(pl)
    # Note: no callback here, so this solve always runs to tspan[2] = 10.0
    sol = solve(prob,BS3(),p=pl) # ,callback=cb,sensealg=ReverseDiffAdjoint())
    return sum(abs, sol.u[end][1:2] - fit_sol[1:2])
end

#%%
loss(p)

#########################
# THIS WORKS
#########################
#%%
pv = rand(2)
par = Flux.params(pv)
grads = gradient(() -> loss(pv), par )
grads[pv]

#%%
function loss(pl)
    sol = solve(prob,BS3(),p=pl,callback=cb,sensealg=ReverseDiffAdjoint())
    return sum(abs, sol.u[end][1:2] - fit_sol[1:2])
end

#%%
loss(p)

#####################
# Here I get an error
#####################
#%%
par = Flux.params(pv)
grads = gradient(() -> loss(pv), par )
grads[pv]

The error is

ERROR: LoadError: TrackedArrays do not support setindex!

and the stack trace (omitted here) points back to the sol = ... line in the loss function.
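A likely cause, offered as an assumption rather than a confirmed diagnosis: with ReverseDiffAdjoint the solver state becomes a ReverseDiff TrackedArray, and the in-place f! writes into du via setindex!, which TrackedArrays reject. A common workaround is an out-of-place right-hand side that builds and returns a new vector instead of mutating du. The sketch below rewrites f! that way (the name f is hypothetical) and checks that both forms agree; the commented usage lines are untested.

```julia
# Poster's in-place right-hand side, unchanged.
function f!(du, u, p, t)
    du[1] = -p[1] * u[3]
    du[2] = -p[2] * u[3]
    du[3] = -du[1] * u[2] - du[2] * u[1]
end

# Out-of-place equivalent (hypothetical name f): computes each component as a
# local scalar and returns a fresh vector, so it never calls setindex! on its inputs.
function f(u, p, t)
    du1 = -p[1] * u[3]
    du2 = -p[2] * u[3]
    du3 = -du1 * u[2] - du2 * u[1]
    return [du1, du2, du3]
end

# Sanity check: both forms agree at an arbitrary point.
u, p = [1.0, 2.0, 3.0], [1.0, 0.5]
du = similar(u)
f!(du, u, p, 0.0)
@assert du ≈ f(u, p, 0.0)

# Usage with the rest of the script above (untested):
# prob = ODEProblem(f, u0, tspan, p)
# sol  = solve(prob, BS3(), p = pl, callback = cb, sensealg = ReverseDiffAdjoint())
```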